// codetether_agent/provider/glm5.rs
1//! GLM-5 FP8 provider for Vast.ai serverless deployments
2//!
3//! This provider connects to a self-hosted vLLM endpoint running GLM-5 FP8
4//! on Vast.ai serverless infrastructure. The endpoint exposes an OpenAI-compatible
5//! chat completions API (vLLM default).
6//!
7//! ## Model: GLM-5-FP8
8//!
9//! - **Architecture**: 744B parameter Mixture-of-Experts (MoE)
10//! - **Active Parameters**: 40B per forward pass
11//! - **Quantization**: FP8 for efficient inference
12//! - **Hardware**: 8x A100 SXM4 80GB
13//! - **Features**: MTP speculative decoding enabled
14//!
15//! ## Configuration (Vault)
16//!
17//! Store under `secret/data/codetether/providers/glm5`:
18//! ```json
19//! {
20//!   "api_key": "<vast-endpoint-api-key>",
21//!   "base_url": "https://route.vast.ai/<endpoint-id>/<api-key>/v1",
22//!   "extra": {
23//!     "model_name": "glm-5-fp8"
24//!   }
25//! }
26//! ```
27//!
28//! ## Model reference format
29//!
30//! Use `glm5/glm-5-fp8`, `glm5/glm-5`, or just `glm-5-fp8` as the model string.
31//!
32//! ## Environment variable fallback
33//!
34//! - `GLM5_API_KEY` — API key for the Vast.ai endpoint
35//! - `GLM5_BASE_URL` — Base URL of the vLLM endpoint (required)
36//! - `GLM5_MODEL` — Model name override (default: glm-5-fp8)
37
38use super::{
39    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
40    Role, StreamChunk, ToolDefinition, Usage,
41};
42use anyhow::{Context, Result};
43use async_trait::async_trait;
44use futures::StreamExt;
45use reqwest::Client;
46use serde::Deserialize;
47use serde_json::{Value, json};
48
/// Default model name served by the Vast.ai vLLM endpoint.
pub const DEFAULT_MODEL: &str = "glm-5-fp8";

/// GLM-5 FP8 provider targeting a Vast.ai vLLM serverless endpoint.
///
/// Speaks the OpenAI-compatible chat-completions API exposed by vLLM.
pub struct Glm5Provider {
    // Process-wide pooled HTTP client (cloned in `with_model`).
    client: Client,
    // Bearer token for the endpoint; redacted by the manual `Debug` impl.
    api_key: String,
    // Full base URL including `/v1`, stored without a trailing slash.
    base_url: String,
    /// The actual model name to send in API requests (e.g. "glm-5-fp8").
    model_name: String,
}
60
61impl std::fmt::Debug for Glm5Provider {
62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63        f.debug_struct("Glm5Provider")
64            .field("base_url", &self.base_url)
65            .field("model_name", &self.model_name)
66            .field("api_key", &"<REDACTED>")
67            .finish()
68    }
69}
70
71impl Glm5Provider {
72    /// Create a new Glm5Provider pointing at `base_url` with `api_key`.
73    ///
74    /// `base_url` should be the full base URL including `/v1`, e.g.
75    /// `https://<vast-endpoint>.vast.ai/v1`.
76    #[allow(dead_code)]
77    pub fn new(api_key: String, base_url: String) -> Result<Self> {
78        Self::with_model(api_key, base_url, DEFAULT_MODEL.to_string())
79    }
80
81    /// Create with an explicit model name override.
82    pub fn with_model(api_key: String, base_url: String, model_name: String) -> Result<Self> {
83        let base_url = base_url.trim_end_matches('/').to_string();
84        tracing::debug!(
85            provider = "glm5",
86            base_url = %base_url,
87            model = %model_name,
88            "Creating GLM-5 FP8 provider"
89        );
90        Ok(Self {
91            client: crate::provider::shared_http::shared_client().clone(),
92            api_key,
93            base_url,
94            model_name,
95        })
96    }
97
98    /// Normalize a model string: strip provider prefix (`glm5/`, `glm5:`) and
99    /// map aliases to the canonical model name served by vLLM.
100    ///
101    /// Examples:
102    /// - `"glm5/glm-5-fp8"` → `"glm-5-fp8"`
103    /// - `"glm5:glm-5-fp8"` → `"glm-5-fp8"`
104    /// - `"glm-5-fp8"`      → `"glm-5-fp8"`
105    /// - `"glm-5"`          → `"glm-5-fp8"` (alias to fp8 variant)
106    /// - `""`               → `DEFAULT_MODEL`
107    pub fn normalize_model(model: &str) -> String {
108        let stripped = model
109            .trim()
110            .trim_start_matches("glm5/")
111            .trim_start_matches("glm5:");
112
113        match stripped {
114            "" | "glm-5" | "glm5" => DEFAULT_MODEL.to_string(),
115            other => other.to_string(),
116        }
117    }
118
119    fn convert_messages(messages: &[Message]) -> Vec<Value> {
120        messages
121            .iter()
122            .map(|msg| {
123                let role = match msg.role {
124                    Role::System => "system",
125                    Role::User => "user",
126                    Role::Assistant => "assistant",
127                    Role::Tool => "tool",
128                };
129
130                match msg.role {
131                    Role::Tool => {
132                        if let Some(ContentPart::ToolResult {
133                            tool_call_id,
134                            content,
135                        }) = msg.content.first()
136                        {
137                            json!({
138                                "role": "tool",
139                                "tool_call_id": tool_call_id,
140                                "content": content
141                            })
142                        } else {
143                            json!({"role": role, "content": ""})
144                        }
145                    }
146                    Role::Assistant => {
147                        let text: String = msg
148                            .content
149                            .iter()
150                            .filter_map(|p| match p {
151                                ContentPart::Text { text } => Some(text.clone()),
152                                _ => None,
153                            })
154                            .collect::<Vec<_>>()
155                            .join("");
156
157                        let tool_calls: Vec<Value> = msg
158                            .content
159                            .iter()
160                            .filter_map(|p| match p {
161                                ContentPart::ToolCall {
162                                    id,
163                                    name,
164                                    arguments,
165                                    ..
166                                } => {
167                                    // vLLM expects arguments as a JSON string
168                                    let args_string = serde_json::from_str::<Value>(arguments)
169                                        .map(|parsed| {
170                                            serde_json::to_string(&parsed)
171                                                .unwrap_or_else(|_| "{}".to_string())
172                                        })
173                                        .unwrap_or_else(|_| {
174                                            json!({"input": arguments}).to_string()
175                                        });
176                                    Some(json!({
177                                        "id": id,
178                                        "type": "function",
179                                        "function": {
180                                            "name": name,
181                                            "arguments": args_string
182                                        }
183                                    }))
184                                }
185                                _ => None,
186                            })
187                            .collect();
188
189                        let mut msg_json = json!({
190                            "role": "assistant",
191                            "content": if text.is_empty() { Value::Null } else { json!(text) },
192                        });
193                        if !tool_calls.is_empty() {
194                            msg_json["tool_calls"] = json!(tool_calls);
195                        }
196                        msg_json
197                    }
198                    _ => {
199                        let text: String = msg
200                            .content
201                            .iter()
202                            .filter_map(|p| match p {
203                                ContentPart::Text { text } => Some(text.clone()),
204                                _ => None,
205                            })
206                            .collect::<Vec<_>>()
207                            .join("\n");
208
209                        json!({"role": role, "content": text})
210                    }
211                }
212            })
213            .collect()
214    }
215
216    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
217        tools
218            .iter()
219            .map(|t| {
220                json!({
221                    "type": "function",
222                    "function": {
223                        "name": t.name,
224                        "description": t.description,
225                        "parameters": t.parameters
226                    }
227                })
228            })
229            .collect()
230    }
231
232    fn preview_text(text: &str, max_chars: usize) -> &str {
233        if max_chars == 0 {
234            return "";
235        }
236        if let Some((idx, _)) = text.char_indices().nth(max_chars) {
237            &text[..idx]
238        } else {
239            text
240        }
241    }
242}
243
244// ─── Response deserialization ────────────────────────────────────────────────
245
/// Top-level body of a non-streaming `/chat/completions` response.
#[derive(Debug, Deserialize)]
struct Glm5Response {
    choices: Vec<Glm5Choice>,
    /// Token accounting; optional because some servers omit it.
    #[serde(default)]
    usage: Option<Glm5Usage>,
}

/// A single completion choice; only the first is consumed.
#[derive(Debug, Deserialize)]
struct Glm5Choice {
    message: Glm5Message,
    /// e.g. "stop", "length", "tool_calls"; `None` when absent.
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Assistant message payload: optional text and/or tool calls.
#[derive(Debug, Deserialize)]
struct Glm5Message {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<Glm5ToolCall>>,
}

/// One tool invocation requested by the model.
#[derive(Debug, Deserialize)]
struct Glm5ToolCall {
    id: String,
    function: Glm5Function,
}

/// Function name plus arguments.
///
/// `arguments` is kept as a raw [`Value`] because servers may send either a
/// JSON-encoded string (OpenAI convention) or a structured object; the
/// consumer normalizes both to a string.
#[derive(Debug, Deserialize)]
struct Glm5Function {
    name: String,
    arguments: Value,
}

/// Token usage block; all counts default to 0 when omitted.
#[derive(Debug, Deserialize)]
struct Glm5Usage {
    #[serde(default)]
    prompt_tokens: usize,
    #[serde(default)]
    completion_tokens: usize,
    #[serde(default)]
    total_tokens: usize,
    /// OpenAI-compatible prompt-cache breakdown. GLM-5 surfaces
    /// `prompt_tokens_details.cached_tokens` when the implicit KV-cache
    /// hit on the input; we subtract it from `prompt_tokens` so the
    /// cost estimator can price the cached portion at the discounted
    /// rate (see [`crate::provider::pricing::cache_read_multiplier`]).
    #[serde(default)]
    prompt_tokens_details: Option<Glm5PromptTokenDetails>,
}

/// `prompt_tokens_details` sub-object carrying the cached-token count.
#[derive(Debug, Deserialize, Default)]
struct Glm5PromptTokenDetails {
    #[serde(default)]
    cached_tokens: usize,
}
302
303impl Glm5Usage {
304    fn to_usage(&self) -> Usage {
305        let cached = self
306            .prompt_tokens_details
307            .as_ref()
308            .map(|d| d.cached_tokens)
309            .unwrap_or(0);
310        Usage {
311            prompt_tokens: self.prompt_tokens.saturating_sub(cached),
312            completion_tokens: self.completion_tokens,
313            total_tokens: self.total_tokens,
314            cache_read_tokens: if cached > 0 { Some(cached) } else { None },
315            cache_write_tokens: None,
316        }
317    }
318}
319
/// OpenAI-style error envelope (`{"error": {...}}`) returned on non-2xx.
#[derive(Debug, Deserialize)]
struct Glm5Error {
    error: Glm5ErrorDetail,
}

/// Error detail: human-readable message plus an optional machine-readable
/// `type` discriminator.
#[derive(Debug, Deserialize)]
struct Glm5ErrorDetail {
    message: String,
    #[serde(default, rename = "type")]
    error_type: Option<String>,
}
331
332// ─── SSE stream types ────────────────────────────────────────────────────────
333
/// One SSE `data:` payload from a streaming `/chat/completions` response.
#[derive(Debug, Deserialize)]
struct Glm5StreamResponse {
    choices: Vec<Glm5StreamChoice>,
    /// Populated only on the usage chunk emitted when the request sets
    /// `stream_options.include_usage`.
    #[serde(default)]
    usage: Option<Glm5Usage>,
}

/// Incremental choice: a delta plus an optional terminal finish reason.
#[derive(Debug, Deserialize)]
struct Glm5StreamChoice {
    delta: Glm5StreamDelta,
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Delta fragment: may carry new text, tool-call fragments, or neither.
#[derive(Debug, Deserialize)]
struct Glm5StreamDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<Glm5StreamToolCall>>,
}

/// Streamed tool-call fragment; pieces may arrive spread across chunks.
#[derive(Debug, Deserialize)]
struct Glm5StreamToolCall {
    #[serde(default)]
    id: Option<String>,
    function: Option<Glm5StreamFunction>,
}

/// Partial function payload within a streamed tool call.
#[derive(Debug, Deserialize)]
struct Glm5StreamFunction {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<Value>,
}
370
371// ─── Provider impl ───────────────────────────────────────────────────────────
372
#[async_trait]
impl Provider for Glm5Provider {
    /// Provider identifier; also the prefix accepted in model strings.
    fn name(&self) -> &str {
        "glm5"
    }

    /// Static catalog of the models this endpoint serves.
    ///
    /// Per-token costs are `None` because the endpoint is self-hosted
    /// (billed per GPU-hour, not per token).
    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        Ok(vec![
            ModelInfo {
                id: "glm-5-fp8".to_string(),
                name: "GLM-5 FP8 (744B MoE, 40B active)".to_string(),
                provider: "glm5".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: false,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: None,
                output_cost_per_million: None,
            },
            ModelInfo {
                id: "glm-5".to_string(),
                name: "GLM-5 (alias → FP8)".to_string(),
                provider: "glm5".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: false,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: None,
                output_cost_per_million: None,
            },
        ])
    }

    /// Non-streaming completion via `POST {base_url}/chat/completions`.
    ///
    /// # Errors
    /// Fails on transport errors, non-2xx responses (the OpenAI-style error
    /// body is surfaced when parseable), an unparseable success body, or an
    /// empty `choices` array.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        let messages = Self::convert_messages(&request.messages);
        let tools = Self::convert_tools(&request.tools);

        // GLM-5 performs well at temperature 1.0
        let temperature = request.temperature.unwrap_or(1.0);
        // Resolve model alias
        let model = Self::normalize_model(&request.model);

        let mut body = json!({
            "model": model,
            "messages": messages,
            "temperature": temperature,
        });

        // Optional fields are only included when set so the endpoint's own
        // defaults apply otherwise.
        if !tools.is_empty() {
            body["tools"] = json!(tools);
        }
        if let Some(max) = request.max_tokens {
            body["max_tokens"] = json!(max);
        }

        tracing::debug!(model = %model, endpoint = %self.base_url, "GLM-5 FP8 request");

        let response = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await
            .context("Failed to send request to GLM-5 FP8 endpoint")?;

        // Read the body as text first so error payloads can be echoed
        // verbatim in failure messages.
        let status = response.status();
        let text = response
            .text()
            .await
            .context("Failed to read GLM-5 FP8 response")?;

        if !status.is_success() {
            // Prefer the structured OpenAI-style error message when present.
            if let Ok(err) = serde_json::from_str::<Glm5Error>(&text) {
                anyhow::bail!(
                    "GLM-5 FP8 API error: {} ({:?})",
                    err.error.message,
                    err.error.error_type
                );
            }
            anyhow::bail!("GLM-5 FP8 API error: {} {}", status, text);
        }

        // Include a short body preview in the parse error for debuggability.
        let parsed: Glm5Response = serde_json::from_str(&text).context(format!(
            "Failed to parse GLM-5 FP8 response: {}",
            Self::preview_text(&text, 200)
        ))?;

        let choice = parsed
            .choices
            .first()
            .ok_or_else(|| anyhow::anyhow!("No choices in GLM-5 FP8 response"))?;

        let mut content = Vec::new();
        let mut has_tool_calls = false;

        if let Some(text) = &choice.message.content
            && !text.is_empty()
        {
            content.push(ContentPart::Text { text: text.clone() });
        }

        if let Some(tool_calls) = &choice.message.tool_calls {
            has_tool_calls = !tool_calls.is_empty();
            for tc in tool_calls {
                // Arguments may arrive as a JSON string (OpenAI convention)
                // or as a structured value; normalize both to a string.
                let arguments = match &tc.function.arguments {
                    Value::String(s) => s.clone(),
                    other => serde_json::to_string(other).unwrap_or_default(),
                };
                content.push(ContentPart::ToolCall {
                    id: tc.id.clone(),
                    name: tc.function.name.clone(),
                    arguments,
                    thought_signature: None,
                });
            }
        }

        // Presence of tool calls overrides the server-reported finish reason.
        let finish_reason = if has_tool_calls {
            FinishReason::ToolCalls
        } else {
            match choice.finish_reason.as_deref() {
                Some("stop") => FinishReason::Stop,
                Some("length") => FinishReason::Length,
                Some("tool_calls") => FinishReason::ToolCalls,
                _ => FinishReason::Stop,
            }
        };

        Ok(CompletionResponse {
            message: Message {
                role: Role::Assistant,
                content,
            },
            usage: parsed
                .usage
                .as_ref()
                .map(Glm5Usage::to_usage)
                .unwrap_or_default(),
            finish_reason,
        })
    }

    /// Streaming completion: parses the vLLM SSE byte stream into
    /// [`StreamChunk`]s.
    ///
    /// `stream_options.include_usage` is requested so the final chunk
    /// carries token usage.
    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
        let messages = Self::convert_messages(&request.messages);
        let tools = Self::convert_tools(&request.tools);

        let temperature = request.temperature.unwrap_or(1.0);
        let model = Self::normalize_model(&request.model);

        let mut body = json!({
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "stream": true,
            "stream_options": { "include_usage": true },
        });

        if !tools.is_empty() {
            body["tools"] = json!(tools);
        }
        if let Some(max) = request.max_tokens {
            body["max_tokens"] = json!(max);
        }

        tracing::debug!(
            model = %model,
            endpoint = %self.base_url,
            "GLM-5 FP8 streaming request"
        );

        let response = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await
            .context("Failed to send streaming request to GLM-5 FP8 endpoint")?;

        // HTTP-level errors surface before any stream is returned.
        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();
            if let Ok(err) = serde_json::from_str::<Glm5Error>(&text) {
                anyhow::bail!(
                    "GLM-5 FP8 API error: {} ({:?})",
                    err.error.message,
                    err.error.error_type
                );
            }
            anyhow::bail!("GLM-5 FP8 streaming error: {} {}", status, text);
        }

        let stream = response.bytes_stream();
        // Carry-over for SSE lines split across network chunks; moved into
        // the closure and persists across invocations.
        let mut buffer = String::new();

        Ok(stream
            .flat_map(move |chunk_result| {
                let mut chunks: Vec<StreamChunk> = Vec::new();
                match chunk_result {
                    Ok(bytes) => {
                        let text = String::from_utf8_lossy(&bytes);
                        buffer.push_str(&text);

                        // Text deltas are coalesced into `text_buf` and
                        // flushed as a single Text chunk at boundaries.
                        let mut text_buf = String::new();
                        // Process every complete line; a trailing partial
                        // line stays in `buffer` for the next chunk.
                        while let Some(line_end) = buffer.find('\n') {
                            let line = buffer[..line_end].trim().to_string();
                            buffer = buffer[line_end + 1..].to_string();

                            if line == "data: [DONE]" {
                                if !text_buf.is_empty() {
                                    chunks.push(StreamChunk::Text(std::mem::take(&mut text_buf)));
                                }
                                // NOTE(review): if a finish_reason chunk already
                                // pushed `Done { usage }`, this emits a second
                                // `Done { usage: None }` — confirm downstream
                                // tolerates duplicate Done chunks.
                                chunks.push(StreamChunk::Done { usage: None });
                                continue;
                            }

                            // Unparseable `data:` payloads are silently skipped.
                            if let Some(data) = line.strip_prefix("data: ")
                                && let Ok(parsed) = serde_json::from_str::<Glm5StreamResponse>(data)
                            {
                                // Capture usage from the final chunk (stream_options)
                                let usage = parsed.usage.as_ref().map(Glm5Usage::to_usage);

                                if let Some(choice) = parsed.choices.first() {
                                    if let Some(ref content) = choice.delta.content {
                                        text_buf.push_str(content);
                                    }

                                    // Streaming tool calls
                                    if let Some(ref tool_calls) = choice.delta.tool_calls {
                                        // Flush buffered text so ordering is
                                        // preserved relative to tool chunks.
                                        if !text_buf.is_empty() {
                                            chunks.push(StreamChunk::Text(std::mem::take(
                                                &mut text_buf,
                                            )));
                                        }
                                        for tc in tool_calls {
                                            if let Some(ref func) = tc.function {
                                                // A `name` marks the start of a
                                                // new call; later fragments only
                                                // carry argument deltas.
                                                if let Some(ref name) = func.name {
                                                    chunks.push(StreamChunk::ToolCallStart {
                                                        id: tc.id.clone().unwrap_or_default(),
                                                        name: name.clone(),
                                                    });
                                                }
                                                if let Some(ref args) = func.arguments {
                                                    let delta = match args {
                                                        Value::String(s) => s.clone(),
                                                        other => serde_json::to_string(other)
                                                            .unwrap_or_default(),
                                                    };
                                                    if !delta.is_empty() {
                                                        chunks.push(StreamChunk::ToolCallDelta {
                                                            id: tc.id.clone().unwrap_or_default(),
                                                            arguments_delta: delta,
                                                        });
                                                    }
                                                }
                                            }
                                        }
                                    }

                                    if let Some(ref reason) = choice.finish_reason {
                                        if !text_buf.is_empty() {
                                            chunks.push(StreamChunk::Text(std::mem::take(
                                                &mut text_buf,
                                            )));
                                        }
                                        // NOTE(review): ToolCallEnd fires only
                                        // when the finish chunk's own delta
                                        // carries tool_calls; servers often send
                                        // an empty delta with finish_reason, in
                                        // which case no ToolCallEnd is emitted —
                                        // verify consumer expectations.
                                        if reason == "tool_calls"
                                            && let Some(ref tcs) = choice.delta.tool_calls
                                            && let Some(tc) = tcs.last()
                                        {
                                            chunks.push(StreamChunk::ToolCallEnd {
                                                id: tc.id.clone().unwrap_or_default(),
                                            });
                                        }
                                        chunks.push(StreamChunk::Done { usage });
                                    }
                                } else if usage.is_some() {
                                    // Usage-only chunk (empty choices)
                                    if !text_buf.is_empty() {
                                        chunks
                                            .push(StreamChunk::Text(std::mem::take(&mut text_buf)));
                                    }
                                    chunks.push(StreamChunk::Done { usage });
                                }
                            }
                        }
                        // Flush any text accumulated in this network chunk.
                        if !text_buf.is_empty() {
                            chunks.push(StreamChunk::Text(text_buf));
                        }
                    }
                    Err(e) => chunks.push(StreamChunk::Error(e.to_string())),
                }
                futures::stream::iter(chunks)
            })
            .boxed())
    }
}
677
678// ─── Tests ───────────────────────────────────────────────────────────────────
679
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn normalize_model_strips_prefix() {
        // Both prefix separators, and the bare name, resolve to the same model.
        for raw in ["glm5/glm-5-fp8", "glm5:glm-5-fp8", "glm-5-fp8"] {
            assert_eq!(Glm5Provider::normalize_model(raw), "glm-5-fp8");
        }
    }

    #[test]
    fn normalize_model_aliases_glm5_to_fp8() {
        // Every alias spelling collapses to the canonical default model.
        for raw in ["glm-5", "glm5", "", "glm5/glm-5", "glm5:glm-5"] {
            assert_eq!(Glm5Provider::normalize_model(raw), DEFAULT_MODEL);
        }
    }

    #[test]
    fn normalize_model_preserves_other_variants() {
        let normalized = Glm5Provider::normalize_model("glm5/glm-5-int4");
        assert_eq!(normalized, "glm-5-int4");
    }

    #[test]
    fn convert_messages_serializes_tool_arguments_as_json_string() {
        let call = ContentPart::ToolCall {
            id: "call_1".to_string(),
            name: "bash".to_string(),
            arguments: "{\"command\":\"ls\"}".to_string(),
            thought_signature: None,
        };
        let message = Message {
            role: Role::Assistant,
            content: vec![call],
        };

        let converted = Glm5Provider::convert_messages(&[message]);
        let args = converted[0]["tool_calls"][0]["function"]["arguments"]
            .as_str()
            .expect("arguments must be a string");

        assert_eq!(args, r#"{"command":"ls"}"#);
    }

    #[test]
    fn convert_messages_wraps_invalid_tool_arguments() {
        let call = ContentPart::ToolCall {
            id: "call_1".to_string(),
            name: "bash".to_string(),
            arguments: "command=ls".to_string(),
            thought_signature: None,
        };
        let message = Message {
            role: Role::Assistant,
            content: vec![call],
        };

        let converted = Glm5Provider::convert_messages(&[message]);
        let args = converted[0]["tool_calls"][0]["function"]["arguments"]
            .as_str()
            .expect("arguments must be a string");

        // Non-JSON input must be wrapped into a valid JSON object.
        let parsed: Value =
            serde_json::from_str(args).expect("wrapped arguments must be valid JSON");
        assert_eq!(parsed, json!({"input": "command=ls"}));
    }
}
749}