// codetether_agent/provider/glm5.rs
1//! GLM-5 FP8 provider for Vast.ai serverless deployments
2//!
3//! This provider connects to a self-hosted vLLM endpoint running GLM-5 FP8
4//! on Vast.ai serverless infrastructure. The endpoint exposes an OpenAI-compatible
5//! chat completions API (vLLM default).
6//!
7//! ## Model: GLM-5-FP8
8//!
9//! - **Architecture**: 744B parameter Mixture-of-Experts (MoE)
10//! - **Active Parameters**: 40B per forward pass
11//! - **Quantization**: FP8 for efficient inference
12//! - **Hardware**: 8x A100 SXM4 80GB
13//! - **Features**: MTP speculative decoding enabled
14//!
15//! ## Configuration (Vault)
16//!
17//! Store under `secret/data/codetether/providers/glm5`:
18//! ```json
19//! {
20//!   "api_key": "<vast-endpoint-api-key>",
21//!   "base_url": "https://route.vast.ai/<endpoint-id>/<api-key>/v1",
22//!   "extra": {
23//!     "model_name": "glm-5-fp8"
24//!   }
25//! }
26//! ```
27//!
28//! ## Model reference format
29//!
30//! Use `glm5/glm-5-fp8`, `glm5/glm-5`, or just `glm-5-fp8` as the model string.
31//!
32//! ## Environment variable fallback
33//!
34//! - `GLM5_API_KEY` — API key for the Vast.ai endpoint
35//! - `GLM5_BASE_URL` — Base URL of the vLLM endpoint (required)
36//! - `GLM5_MODEL` — Model name override (default: glm-5-fp8)
37
38use super::{
39    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
40    Role, StreamChunk, ToolDefinition, Usage,
41};
42use anyhow::{Context, Result};
43use async_trait::async_trait;
44use futures::StreamExt;
45use reqwest::Client;
46use serde::Deserialize;
47use serde_json::{Value, json};
48
/// Default model name served by the Vast.ai vLLM endpoint.
pub const DEFAULT_MODEL: &str = "glm-5-fp8";

/// GLM-5 FP8 provider targeting a Vast.ai vLLM serverless endpoint.
///
/// Speaks the OpenAI-compatible `/chat/completions` API that vLLM exposes.
pub struct Glm5Provider {
    // Reused HTTP client (connection pooling across requests).
    client: Client,
    // Bearer token sent in the `Authorization` header; redacted in `Debug`.
    api_key: String,
    // Full endpoint base URL including `/v1`, stored without a trailing slash.
    base_url: String,
    /// The actual model name to send in API requests (e.g. "glm-5-fp8").
    model_name: String,
}
60
61impl std::fmt::Debug for Glm5Provider {
62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63        f.debug_struct("Glm5Provider")
64            .field("base_url", &self.base_url)
65            .field("model_name", &self.model_name)
66            .field("api_key", &"<REDACTED>")
67            .finish()
68    }
69}
70
71impl Glm5Provider {
72    /// Create a new Glm5Provider pointing at `base_url` with `api_key`.
73    ///
74    /// `base_url` should be the full base URL including `/v1`, e.g.
75    /// `https://<vast-endpoint>.vast.ai/v1`.
76    #[allow(dead_code)]
77    pub fn new(api_key: String, base_url: String) -> Result<Self> {
78        Self::with_model(api_key, base_url, DEFAULT_MODEL.to_string())
79    }
80
81    /// Create with an explicit model name override.
82    pub fn with_model(api_key: String, base_url: String, model_name: String) -> Result<Self> {
83        let base_url = base_url.trim_end_matches('/').to_string();
84        tracing::debug!(
85            provider = "glm5",
86            base_url = %base_url,
87            model = %model_name,
88            "Creating GLM-5 FP8 provider"
89        );
90        Ok(Self {
91            client: Client::new(),
92            api_key,
93            base_url,
94            model_name,
95        })
96    }
97
98    /// Normalize a model string: strip provider prefix (`glm5/`, `glm5:`) and
99    /// map aliases to the canonical model name served by vLLM.
100    ///
101    /// Examples:
102    /// - `"glm5/glm-5-fp8"` → `"glm-5-fp8"`
103    /// - `"glm5:glm-5-fp8"` → `"glm-5-fp8"`
104    /// - `"glm-5-fp8"`      → `"glm-5-fp8"`
105    /// - `"glm-5"`          → `"glm-5-fp8"` (alias to fp8 variant)
106    /// - `""`               → `DEFAULT_MODEL`
107    pub fn normalize_model(model: &str) -> String {
108        let stripped = model
109            .trim()
110            .trim_start_matches("glm5/")
111            .trim_start_matches("glm5:");
112
113        match stripped {
114            "" | "glm-5" | "glm5" => DEFAULT_MODEL.to_string(),
115            other => other.to_string(),
116        }
117    }
118
119    fn convert_messages(messages: &[Message]) -> Vec<Value> {
120        messages
121            .iter()
122            .map(|msg| {
123                let role = match msg.role {
124                    Role::System => "system",
125                    Role::User => "user",
126                    Role::Assistant => "assistant",
127                    Role::Tool => "tool",
128                };
129
130                match msg.role {
131                    Role::Tool => {
132                        if let Some(ContentPart::ToolResult {
133                            tool_call_id,
134                            content,
135                        }) = msg.content.first()
136                        {
137                            json!({
138                                "role": "tool",
139                                "tool_call_id": tool_call_id,
140                                "content": content
141                            })
142                        } else {
143                            json!({"role": role, "content": ""})
144                        }
145                    }
146                    Role::Assistant => {
147                        let text: String = msg
148                            .content
149                            .iter()
150                            .filter_map(|p| match p {
151                                ContentPart::Text { text } => Some(text.clone()),
152                                _ => None,
153                            })
154                            .collect::<Vec<_>>()
155                            .join("");
156
157                        let tool_calls: Vec<Value> = msg
158                            .content
159                            .iter()
160                            .filter_map(|p| match p {
161                                ContentPart::ToolCall {
162                                    id,
163                                    name,
164                                    arguments,
165                                    ..
166                                } => {
167                                    // vLLM expects arguments as a JSON string
168                                    let args_string = serde_json::from_str::<Value>(arguments)
169                                        .map(|parsed| {
170                                            serde_json::to_string(&parsed)
171                                                .unwrap_or_else(|_| "{}".to_string())
172                                        })
173                                        .unwrap_or_else(|_| {
174                                            json!({"input": arguments}).to_string()
175                                        });
176                                    Some(json!({
177                                        "id": id,
178                                        "type": "function",
179                                        "function": {
180                                            "name": name,
181                                            "arguments": args_string
182                                        }
183                                    }))
184                                }
185                                _ => None,
186                            })
187                            .collect();
188
189                        let mut msg_json = json!({
190                            "role": "assistant",
191                            "content": if text.is_empty() { Value::Null } else { json!(text) },
192                        });
193                        if !tool_calls.is_empty() {
194                            msg_json["tool_calls"] = json!(tool_calls);
195                        }
196                        msg_json
197                    }
198                    _ => {
199                        let text: String = msg
200                            .content
201                            .iter()
202                            .filter_map(|p| match p {
203                                ContentPart::Text { text } => Some(text.clone()),
204                                _ => None,
205                            })
206                            .collect::<Vec<_>>()
207                            .join("\n");
208
209                        json!({"role": role, "content": text})
210                    }
211                }
212            })
213            .collect()
214    }
215
216    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
217        tools
218            .iter()
219            .map(|t| {
220                json!({
221                    "type": "function",
222                    "function": {
223                        "name": t.name,
224                        "description": t.description,
225                        "parameters": t.parameters
226                    }
227                })
228            })
229            .collect()
230    }
231
232    fn preview_text(text: &str, max_chars: usize) -> &str {
233        if max_chars == 0 {
234            return "";
235        }
236        if let Some((idx, _)) = text.char_indices().nth(max_chars) {
237            &text[..idx]
238        } else {
239            text
240        }
241    }
242}
243
244// ─── Response deserialization ────────────────────────────────────────────────
245
/// Non-streaming chat completion response body (OpenAI-compatible).
#[derive(Debug, Deserialize)]
struct Glm5Response {
    choices: Vec<Glm5Choice>,
    // Optional: error/edge payloads may omit the usage block entirely.
    #[serde(default)]
    usage: Option<Glm5Usage>,
}

/// One completion candidate within a response.
#[derive(Debug, Deserialize)]
struct Glm5Choice {
    message: Glm5Message,
    // e.g. "stop", "length", "tool_calls"; may be absent.
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Assistant message payload returned by the endpoint.
#[derive(Debug, Deserialize)]
struct Glm5Message {
    // None or empty when the reply consists solely of tool calls.
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<Glm5ToolCall>>,
}

/// A completed tool invocation in a non-streaming response.
#[derive(Debug, Deserialize)]
struct Glm5ToolCall {
    id: String,
    function: Glm5Function,
}

/// Function name plus arguments.
///
/// `arguments` is kept as a raw `Value` because servers may emit either a
/// JSON-encoded string or a structured object; `complete` accepts both.
#[derive(Debug, Deserialize)]
struct Glm5Function {
    name: String,
    arguments: Value,
}

/// Token accounting; each field defaults to 0 when the server omits it.
#[derive(Debug, Deserialize)]
struct Glm5Usage {
    #[serde(default)]
    prompt_tokens: usize,
    #[serde(default)]
    completion_tokens: usize,
    #[serde(default)]
    total_tokens: usize,
}

/// Error envelope: `{"error": {"message": …, "type": …}}`.
#[derive(Debug, Deserialize)]
struct Glm5Error {
    error: Glm5ErrorDetail,
}

/// Detail payload inside [`Glm5Error`].
#[derive(Debug, Deserialize)]
struct Glm5ErrorDetail {
    message: String,
    // Renamed: `type` is a reserved word in Rust.
    #[serde(default, rename = "type")]
    error_type: Option<String>,
}
301
302// ─── SSE stream types ────────────────────────────────────────────────────────
303
/// One decoded SSE `data:` chunk of a streaming completion.
#[derive(Debug, Deserialize)]
struct Glm5StreamResponse {
    choices: Vec<Glm5StreamChoice>,
    // Populated only on the final usage chunk when
    // `stream_options.include_usage` is set; that chunk has empty `choices`.
    #[serde(default)]
    usage: Option<Glm5Usage>,
}

/// Incremental choice delta within a stream chunk.
#[derive(Debug, Deserialize)]
struct Glm5StreamChoice {
    delta: Glm5StreamDelta,
    // Set on the final content chunk (e.g. "stop", "tool_calls").
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Partial message content carried by a single delta.
#[derive(Debug, Deserialize)]
struct Glm5StreamDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<Glm5StreamToolCall>>,
}

/// Incremental tool-call fragment; `id` appears only on the first fragment
/// of a call, later fragments carry argument deltas.
#[derive(Debug, Deserialize)]
struct Glm5StreamToolCall {
    #[serde(default)]
    id: Option<String>,
    function: Option<Glm5StreamFunction>,
}

/// Fragment of a streamed function call: name and/or an arguments delta.
#[derive(Debug, Deserialize)]
struct Glm5StreamFunction {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<Value>,
}
340
341// ─── Provider impl ───────────────────────────────────────────────────────────
342
#[async_trait]
impl Provider for Glm5Provider {
    /// Provider identifier, matching the `glm5/<model>` reference prefix.
    fn name(&self) -> &str {
        "glm5"
    }

    /// Static catalog of models served by the endpoint.
    ///
    /// Hard-coded rather than fetched from `/models`; both entries resolve
    /// to the same FP8 deployment via `normalize_model`. Costs are `None`
    /// because the endpoint is self-hosted.
    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        Ok(vec![
            ModelInfo {
                id: "glm-5-fp8".to_string(),
                name: "GLM-5 FP8 (744B MoE, 40B active)".to_string(),
                provider: "glm5".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: false,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: None,
                output_cost_per_million: None,
            },
            ModelInfo {
                id: "glm-5".to_string(),
                name: "GLM-5 (alias → FP8)".to_string(),
                provider: "glm5".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: false,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: None,
                output_cost_per_million: None,
            },
        ])
    }

    /// Non-streaming chat completion.
    ///
    /// Converts messages/tools to the OpenAI wire format, POSTs to
    /// `{base_url}/chat/completions`, and maps the reply back into a
    /// [`CompletionResponse`].
    ///
    /// # Errors
    /// Fails when the request cannot be sent, the endpoint returns a
    /// non-success status, the body cannot be parsed, or `choices` is empty.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        let messages = Self::convert_messages(&request.messages);
        let tools = Self::convert_tools(&request.tools);

        // GLM-5 performs well at temperature 1.0
        let temperature = request.temperature.unwrap_or(1.0);
        // Resolve model alias
        let model = Self::normalize_model(&request.model);

        let mut body = json!({
            "model": model,
            "messages": messages,
            "temperature": temperature,
        });

        // Optional fields are only attached when present, keeping the
        // payload minimal for vLLM.
        if !tools.is_empty() {
            body["tools"] = json!(tools);
        }
        if let Some(max) = request.max_tokens {
            body["max_tokens"] = json!(max);
        }

        tracing::debug!(model = %model, endpoint = %self.base_url, "GLM-5 FP8 request");

        let response = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await
            .context("Failed to send request to GLM-5 FP8 endpoint")?;

        // Read the body as text first so error payloads (which may not be
        // JSON) can still be surfaced verbatim.
        let status = response.status();
        let text = response
            .text()
            .await
            .context("Failed to read GLM-5 FP8 response")?;

        if !status.is_success() {
            // Prefer the structured OpenAI-style error envelope when the
            // body parses as one; otherwise fall back to the raw text.
            if let Ok(err) = serde_json::from_str::<Glm5Error>(&text) {
                anyhow::bail!(
                    "GLM-5 FP8 API error: {} ({:?})",
                    err.error.message,
                    err.error.error_type
                );
            }
            anyhow::bail!("GLM-5 FP8 API error: {} {}", status, text);
        }

        // Include a short, char-boundary-safe body preview in parse errors.
        let parsed: Glm5Response = serde_json::from_str(&text).context(format!(
            "Failed to parse GLM-5 FP8 response: {}",
            Self::preview_text(&text, 200)
        ))?;

        // Only the first choice is consumed (n > 1 is never requested).
        let choice = parsed
            .choices
            .first()
            .ok_or_else(|| anyhow::anyhow!("No choices in GLM-5 FP8 response"))?;

        let mut content = Vec::new();
        let mut has_tool_calls = false;

        if let Some(text) = &choice.message.content
            && !text.is_empty()
        {
            content.push(ContentPart::Text { text: text.clone() });
        }

        if let Some(tool_calls) = &choice.message.tool_calls {
            has_tool_calls = !tool_calls.is_empty();
            for tc in tool_calls {
                // Arguments may arrive as a JSON string or a structured
                // object; normalize both forms to a string.
                let arguments = match &tc.function.arguments {
                    Value::String(s) => s.clone(),
                    other => serde_json::to_string(other).unwrap_or_default(),
                };
                content.push(ContentPart::ToolCall {
                    id: tc.id.clone(),
                    name: tc.function.name.clone(),
                    arguments,
                    thought_signature: None,
                });
            }
        }

        // The presence of tool calls overrides whatever finish_reason the
        // server reported; unknown reasons default to Stop.
        let finish_reason = if has_tool_calls {
            FinishReason::ToolCalls
        } else {
            match choice.finish_reason.as_deref() {
                Some("stop") => FinishReason::Stop,
                Some("length") => FinishReason::Length,
                Some("tool_calls") => FinishReason::ToolCalls,
                _ => FinishReason::Stop,
            }
        };

        Ok(CompletionResponse {
            message: Message {
                role: Role::Assistant,
                content,
            },
            // Missing usage maps to zeros rather than an error.
            usage: Usage {
                prompt_tokens: parsed.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0),
                completion_tokens: parsed
                    .usage
                    .as_ref()
                    .map(|u| u.completion_tokens)
                    .unwrap_or(0),
                total_tokens: parsed.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0),
                cache_read_tokens: None,
                cache_write_tokens: None,
            },
            finish_reason,
        })
    }

    /// Streaming chat completion over SSE.
    ///
    /// Sends the same payload as [`complete`](Self::complete) with
    /// `stream: true` and `stream_options.include_usage`, then incrementally
    /// parses `data:` lines from the byte stream into [`StreamChunk`]s.
    ///
    /// # Errors
    /// Returns an error for send failures or a non-success initial status;
    /// mid-stream transport errors are surfaced as `StreamChunk::Error`.
    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
        let messages = Self::convert_messages(&request.messages);
        let tools = Self::convert_tools(&request.tools);

        let temperature = request.temperature.unwrap_or(1.0);
        let model = Self::normalize_model(&request.model);

        let mut body = json!({
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "stream": true,
            "stream_options": { "include_usage": true },
        });

        if !tools.is_empty() {
            body["tools"] = json!(tools);
        }
        if let Some(max) = request.max_tokens {
            body["max_tokens"] = json!(max);
        }

        tracing::debug!(
            model = %model,
            endpoint = %self.base_url,
            "GLM-5 FP8 streaming request"
        );

        let response = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await
            .context("Failed to send streaming request to GLM-5 FP8 endpoint")?;

        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();
            if let Ok(err) = serde_json::from_str::<Glm5Error>(&text) {
                anyhow::bail!(
                    "GLM-5 FP8 API error: {} ({:?})",
                    err.error.message,
                    err.error.error_type
                );
            }
            anyhow::bail!("GLM-5 FP8 streaming error: {} {}", status, text);
        }

        let stream = response.bytes_stream();
        // Carry-over buffer: SSE lines can be split across network chunks,
        // so the partial tail is kept here between `flat_map` invocations.
        let mut buffer = String::new();

        Ok(stream
            .flat_map(move |chunk_result| {
                let mut chunks: Vec<StreamChunk> = Vec::new();
                match chunk_result {
                    Ok(bytes) => {
                        let text = String::from_utf8_lossy(&bytes);
                        buffer.push_str(&text);

                        // Text deltas are coalesced per network chunk into
                        // text_buf and flushed as a single StreamChunk::Text.
                        let mut text_buf = String::new();
                        while let Some(line_end) = buffer.find('\n') {
                            let line = buffer[..line_end].trim().to_string();
                            buffer = buffer[line_end + 1..].to_string();

                            if line == "data: [DONE]" {
                                if !text_buf.is_empty() {
                                    chunks.push(StreamChunk::Text(std::mem::take(&mut text_buf)));
                                }
                                // NOTE(review): this emits Done { usage: None }
                                // even when a usage-bearing chunk already
                                // produced a Done below — consumers may see
                                // two Done chunks; confirm downstream tolerates
                                // the duplicate.
                                chunks.push(StreamChunk::Done { usage: None });
                                continue;
                            }

                            // Unparseable `data:` lines are silently skipped.
                            if let Some(data) = line.strip_prefix("data: ")
                                && let Ok(parsed) = serde_json::from_str::<Glm5StreamResponse>(data)
                            {
                                // Capture usage from the final chunk (stream_options)
                                let usage = parsed.usage.as_ref().map(|u| Usage {
                                    prompt_tokens: u.prompt_tokens,
                                    completion_tokens: u.completion_tokens,
                                    total_tokens: u.total_tokens,
                                    cache_read_tokens: None,
                                    cache_write_tokens: None,
                                });

                                if let Some(choice) = parsed.choices.first() {
                                    if let Some(ref content) = choice.delta.content {
                                        text_buf.push_str(content);
                                    }

                                    // Streaming tool calls
                                    if let Some(ref tool_calls) = choice.delta.tool_calls {
                                        // Flush pending text so ordering of
                                        // text vs. tool-call chunks is kept.
                                        if !text_buf.is_empty() {
                                            chunks.push(StreamChunk::Text(std::mem::take(
                                                &mut text_buf,
                                            )));
                                        }
                                        for tc in tool_calls {
                                            if let Some(ref func) = tc.function {
                                                // A fragment with a name opens
                                                // a new tool call.
                                                if let Some(ref name) = func.name {
                                                    chunks.push(StreamChunk::ToolCallStart {
                                                        id: tc.id.clone().unwrap_or_default(),
                                                        name: name.clone(),
                                                    });
                                                }
                                                if let Some(ref args) = func.arguments {
                                                    let delta = match args {
                                                        Value::String(s) => s.clone(),
                                                        other => serde_json::to_string(other)
                                                            .unwrap_or_default(),
                                                    };
                                                    if !delta.is_empty() {
                                                        chunks.push(StreamChunk::ToolCallDelta {
                                                            id: tc.id.clone().unwrap_or_default(),
                                                            arguments_delta: delta,
                                                        });
                                                    }
                                                }
                                            }
                                        }
                                    }

                                    if let Some(ref reason) = choice.finish_reason {
                                        if !text_buf.is_empty() {
                                            chunks.push(StreamChunk::Text(std::mem::take(
                                                &mut text_buf,
                                            )));
                                        }
                                        // NOTE(review): ToolCallEnd fires only
                                        // when the finishing chunk itself still
                                        // carries tool_calls; a bare
                                        // finish_reason chunk emits no End —
                                        // confirm consumers also close calls
                                        // on Done.
                                        if reason == "tool_calls"
                                            && let Some(ref tcs) = choice.delta.tool_calls
                                            && let Some(tc) = tcs.last()
                                        {
                                            chunks.push(StreamChunk::ToolCallEnd {
                                                id: tc.id.clone().unwrap_or_default(),
                                            });
                                        }
                                        chunks.push(StreamChunk::Done { usage });
                                    }
                                } else if usage.is_some() {
                                    // Usage-only chunk (empty choices)
                                    if !text_buf.is_empty() {
                                        chunks
                                            .push(StreamChunk::Text(std::mem::take(&mut text_buf)));
                                    }
                                    chunks.push(StreamChunk::Done { usage });
                                }
                            }
                        }
                        // Flush any text accumulated in this network chunk.
                        if !text_buf.is_empty() {
                            chunks.push(StreamChunk::Text(text_buf));
                        }
                    }
                    Err(e) => chunks.push(StreamChunk::Error(e.to_string())),
                }
                futures::stream::iter(chunks)
            })
            .boxed())
    }
}
659
660// ─── Tests ───────────────────────────────────────────────────────────────────
661
662#[cfg(test)]
663mod tests {
664    use super::*;
665
666    #[test]
667    fn normalize_model_strips_prefix() {
668        assert_eq!(Glm5Provider::normalize_model("glm5/glm-5-fp8"), "glm-5-fp8");
669        assert_eq!(Glm5Provider::normalize_model("glm5:glm-5-fp8"), "glm-5-fp8");
670        assert_eq!(Glm5Provider::normalize_model("glm-5-fp8"), "glm-5-fp8");
671    }
672
673    #[test]
674    fn normalize_model_aliases_glm5_to_fp8() {
675        assert_eq!(Glm5Provider::normalize_model("glm-5"), DEFAULT_MODEL);
676        assert_eq!(Glm5Provider::normalize_model("glm5"), DEFAULT_MODEL);
677        assert_eq!(Glm5Provider::normalize_model(""), DEFAULT_MODEL);
678        assert_eq!(Glm5Provider::normalize_model("glm5/glm-5"), DEFAULT_MODEL);
679        assert_eq!(Glm5Provider::normalize_model("glm5:glm-5"), DEFAULT_MODEL);
680    }
681
682    #[test]
683    fn normalize_model_preserves_other_variants() {
684        assert_eq!(
685            Glm5Provider::normalize_model("glm5/glm-5-int4"),
686            "glm-5-int4"
687        );
688    }
689
690    #[test]
691    fn convert_messages_serializes_tool_arguments_as_json_string() {
692        let messages = vec![Message {
693            role: Role::Assistant,
694            content: vec![ContentPart::ToolCall {
695                id: "call_1".to_string(),
696                name: "bash".to_string(),
697                arguments: "{\"command\":\"ls\"}".to_string(),
698                thought_signature: None,
699            }],
700        }];
701
702        let converted = Glm5Provider::convert_messages(&messages);
703        let args = converted[0]["tool_calls"][0]["function"]["arguments"]
704            .as_str()
705            .expect("arguments must be a string");
706
707        assert_eq!(args, r#"{"command":"ls"}"#);
708    }
709
710    #[test]
711    fn convert_messages_wraps_invalid_tool_arguments() {
712        let messages = vec![Message {
713            role: Role::Assistant,
714            content: vec![ContentPart::ToolCall {
715                id: "call_1".to_string(),
716                name: "bash".to_string(),
717                arguments: "command=ls".to_string(),
718                thought_signature: None,
719            }],
720        }];
721
722        let converted = Glm5Provider::convert_messages(&messages);
723        let args = converted[0]["tool_calls"][0]["function"]["arguments"]
724            .as_str()
725            .expect("arguments must be a string");
726        let parsed: Value =
727            serde_json::from_str(args).expect("wrapped arguments must be valid JSON");
728
729        assert_eq!(parsed, json!({"input": "command=ls"}));
730    }
731}