Skip to main content

vtcode_commons/
llm.rs

1//! Core LLM types shared across the project
2
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
6pub enum BackendKind {
7    Gemini,
8    OpenAI,
9    Anthropic,
10    DeepSeek,
11    OpenRouter,
12    Ollama,
13    XAI,
14    ZAI,
15    Moonshot,
16    HuggingFace,
17    Minimax,
18}
19
20#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
21pub struct Usage {
22    pub prompt_tokens: u32,
23    pub completion_tokens: u32,
24    pub total_tokens: u32,
25    pub cached_prompt_tokens: Option<u32>,
26    pub cache_creation_tokens: Option<u32>,
27    pub cache_read_tokens: Option<u32>,
28}
29
30impl Usage {
31    #[inline]
32    pub fn cache_hit_rate(&self) -> Option<f64> {
33        let read = self.cache_read_tokens? as f64;
34        let creation = self.cache_creation_tokens? as f64;
35        let total = read + creation;
36        if total > 0.0 {
37            Some((read / total) * 100.0)
38        } else {
39            None
40        }
41    }
42
43    #[inline]
44    pub fn is_cache_hit(&self) -> Option<bool> {
45        Some(self.cache_read_tokens? > 0)
46    }
47
48    #[inline]
49    pub fn is_cache_miss(&self) -> Option<bool> {
50        Some(self.cache_creation_tokens? > 0 && self.cache_read_tokens? == 0)
51    }
52
53    #[inline]
54    pub fn total_cache_tokens(&self) -> u32 {
55        let read = self.cache_read_tokens.unwrap_or(0);
56        let creation = self.cache_creation_tokens.unwrap_or(0);
57        read + creation
58    }
59
60    #[inline]
61    pub fn cache_savings_ratio(&self) -> Option<f64> {
62        let read = self.cache_read_tokens? as f64;
63        let prompt = self.prompt_tokens as f64;
64        if prompt > 0.0 {
65            Some(read / prompt)
66        } else {
67            None
68        }
69    }
70}
71
72#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
73pub enum FinishReason {
74    #[default]
75    Stop,
76    Length,
77    ToolCalls,
78    ContentFilter,
79    Pause,
80    Refusal,
81    Error(String),
82}
83
84/// Universal tool call that matches OpenAI/Anthropic/Gemini specifications
85#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
86pub struct ToolCall {
87    /// Unique identifier for this tool call (e.g., "call_123")
88    pub id: String,
89
90    /// The type of tool call: "function", "custom" (GPT-5 freeform), or other
91    #[serde(rename = "type")]
92    pub call_type: String,
93
94    /// Function call details (for function-type tools)
95    #[serde(skip_serializing_if = "Option::is_none")]
96    pub function: Option<FunctionCall>,
97
98    /// Raw text payload (for custom freeform tools in GPT-5)
99    #[serde(skip_serializing_if = "Option::is_none")]
100    pub text: Option<String>,
101
102    /// Gemini-specific thought signature for maintaining reasoning context
103    #[serde(skip_serializing_if = "Option::is_none")]
104    pub thought_signature: Option<String>,
105}
106
107/// Function call within a tool call
108#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
109pub struct FunctionCall {
110    /// The name of the function to call
111    pub name: String,
112
113    /// The arguments to pass to the function, as a JSON string
114    pub arguments: String,
115}
116
117impl ToolCall {
118    /// Create a new function tool call
119    pub fn function(id: String, name: String, arguments: String) -> Self {
120        Self {
121            id,
122            call_type: "function".to_owned(),
123            function: Some(FunctionCall { name, arguments }),
124            text: None,
125            thought_signature: None,
126        }
127    }
128
129    /// Create a new custom tool call with raw text payload (GPT-5 freeform)
130    pub fn custom(id: String, name: String, text: String) -> Self {
131        Self {
132            id,
133            call_type: "custom".to_owned(),
134            function: Some(FunctionCall {
135                name,
136                arguments: text.clone(),
137            }),
138            text: Some(text),
139            thought_signature: None,
140        }
141    }
142
143    /// Parse the arguments as JSON Value (for function-type tools)
144    pub fn parsed_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
145        if let Some(ref func) = self.function {
146            parse_tool_arguments(&func.arguments)
147        } else {
148            // Return an error by trying to parse invalid JSON
149            serde_json::from_str("")
150        }
151    }
152
153    /// Validate that this tool call is properly formed
154    pub fn validate(&self) -> Result<(), String> {
155        if self.id.is_empty() {
156            return Err("Tool call ID cannot be empty".to_owned());
157        }
158
159        match self.call_type.as_str() {
160            "function" => {
161                if let Some(func) = &self.function {
162                    if func.name.is_empty() {
163                        return Err("Function name cannot be empty".to_owned());
164                    }
165                    // Validate that arguments is valid JSON for function tools
166                    if let Err(e) = self.parsed_arguments() {
167                        return Err(format!("Invalid JSON in function arguments: {}", e));
168                    }
169                } else {
170                    return Err("Function tool call missing function details".to_owned());
171                }
172            }
173            "custom" => {
174                // For custom tools, we allow raw text payload without JSON validation
175                if let Some(func) = &self.function {
176                    if func.name.is_empty() {
177                        return Err("Custom tool name cannot be empty".to_owned());
178                    }
179                } else {
180                    return Err("Custom tool call missing function details".to_owned());
181                }
182            }
183            _ => return Err(format!("Unsupported tool call type: {}", self.call_type)),
184        }
185
186        Ok(())
187    }
188}
189
190fn parse_tool_arguments(raw_arguments: &str) -> Result<serde_json::Value, serde_json::Error> {
191    let trimmed = raw_arguments.trim();
192    match serde_json::from_str(trimmed) {
193        Ok(parsed) => Ok(parsed),
194        Err(primary_error) => {
195            if let Some(candidate) = extract_balanced_json(trimmed)
196                && let Ok(parsed) = serde_json::from_str(candidate)
197            {
198                return Ok(parsed);
199            }
200            Err(primary_error)
201        }
202    }
203}
204
/// Locate the first balanced JSON object or array embedded in `input`.
///
/// Scanning starts at the first `{` or `[`, and only brackets of that same
/// kind are counted. Text inside double-quoted strings, including
/// backslash-escaped characters, is skipped. Returns the slice from the
/// opening bracket through its matching close. Returns `None` if there is no
/// opening bracket or it is never closed. The slice is not validated as JSON.
fn extract_balanced_json(input: &str) -> Option<&str> {
    let begin = input.find(['{', '['])?;
    let bytes = input.as_bytes();
    let open = bytes[begin];
    let close = match open {
        b'{' => b'}',
        b'[' => b']',
        _ => return None,
    };

    // Scanning bytes is safe here: every delimiter is ASCII, and UTF-8
    // continuation bytes can never equal an ASCII byte.
    let mut nesting = 0usize;
    let mut inside_string = false;
    let mut after_backslash = false;

    for (idx, &byte) in bytes.iter().enumerate().skip(begin) {
        if inside_string {
            match byte {
                _ if after_backslash => after_backslash = false,
                b'\\' => after_backslash = true,
                b'"' => inside_string = false,
                _ => {}
            }
        } else if byte == b'"' {
            inside_string = true;
        } else if byte == open {
            nesting += 1;
        } else if byte == close {
            nesting = nesting.saturating_sub(1);
            if nesting == 0 {
                return input.get(begin..=idx);
            }
        }
    }

    None
}
250
251/// Universal LLM response structure
252#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
253pub struct LLMResponse {
254    /// The response content text
255    pub content: Option<String>,
256
257    /// Tool calls made by the model
258    pub tool_calls: Option<Vec<ToolCall>>,
259
260    /// The model that generated this response
261    pub model: String,
262
263    /// Token usage statistics
264    pub usage: Option<Usage>,
265
266    /// Why the response finished
267    pub finish_reason: FinishReason,
268
269    /// Reasoning content (for models that support it)
270    pub reasoning: Option<String>,
271
272    /// Detailed reasoning traces (for models that support it)
273    pub reasoning_details: Option<Vec<String>>,
274
275    /// Tool references for context
276    pub tool_references: Vec<String>,
277
278    /// Request ID from the provider
279    pub request_id: Option<String>,
280
281    /// Organization ID from the provider
282    pub organization_id: Option<String>,
283}
284
285impl LLMResponse {
286    /// Create a new LLM response with mandatory fields
287    pub fn new(model: impl Into<String>, content: impl Into<String>) -> Self {
288        Self {
289            content: Some(content.into()),
290            tool_calls: None,
291            model: model.into(),
292            usage: None,
293            finish_reason: FinishReason::Stop,
294            reasoning: None,
295            reasoning_details: None,
296            tool_references: Vec::new(),
297            request_id: None,
298            organization_id: None,
299        }
300    }
301
302    /// Get content or empty string
303    pub fn content_text(&self) -> &str {
304        self.content.as_deref().unwrap_or("")
305    }
306
307    /// Get content as String (clone)
308    pub fn content_string(&self) -> String {
309        self.content.clone().unwrap_or_default()
310    }
311}
312
313#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
314pub struct LLMErrorMetadata {
315    pub provider: Option<String>,
316    pub status: Option<u16>,
317    pub code: Option<String>,
318    pub request_id: Option<String>,
319    pub organization_id: Option<String>,
320    pub retry_after: Option<String>,
321    pub message: Option<String>,
322}
323
324impl LLMErrorMetadata {
325    pub fn new(
326        provider: impl Into<String>,
327        status: Option<u16>,
328        code: Option<String>,
329        request_id: Option<String>,
330        organization_id: Option<String>,
331        retry_after: Option<String>,
332        message: Option<String>,
333    ) -> Box<Self> {
334        Box::new(Self {
335            provider: Some(provider.into()),
336            status,
337            code,
338            request_id,
339            organization_id,
340            retry_after,
341            message,
342        })
343    }
344}
345
346/// LLM error types with optional provider metadata
347#[derive(Debug, thiserror::Error, Serialize, Deserialize, Clone)]
348#[serde(tag = "type", rename_all = "snake_case")]
349pub enum LLMError {
350    #[error("Authentication failed: {message}")]
351    Authentication {
352        message: String,
353        metadata: Option<Box<LLMErrorMetadata>>,
354    },
355    #[error("Rate limit exceeded")]
356    RateLimit {
357        metadata: Option<Box<LLMErrorMetadata>>,
358    },
359    #[error("Invalid request: {message}")]
360    InvalidRequest {
361        message: String,
362        metadata: Option<Box<LLMErrorMetadata>>,
363    },
364    #[error("Network error: {message}")]
365    Network {
366        message: String,
367        metadata: Option<Box<LLMErrorMetadata>>,
368    },
369    #[error("Provider error: {message}")]
370    Provider {
371        message: String,
372        metadata: Option<Box<LLMErrorMetadata>>,
373    },
374}
375
#[cfg(test)]
mod tests {
    use super::ToolCall;
    use serde_json::json;

    /// Build a `function`-type `read_file` tool call with the given raw arguments.
    fn read_file_call(arguments: &str) -> ToolCall {
        ToolCall::function(
            "call_read".to_string(),
            "read_file".to_string(),
            arguments.to_string(),
        )
    }

    #[test]
    fn parsed_arguments_accepts_trailing_characters() {
        let call = read_file_call(r#"{"path":"src/main.rs"} trailing text"#);

        let parsed = call
            .parsed_arguments()
            .expect("arguments with trailing text should recover");
        assert_eq!(parsed, json!({"path":"src/main.rs"}));
    }

    #[test]
    fn parsed_arguments_accepts_code_fenced_json() {
        let call = read_file_call("```json\n{\"path\":\"src/lib.rs\",\"limit\":25}\n```");

        let parsed = call
            .parsed_arguments()
            .expect("code-fenced arguments should recover");
        assert_eq!(parsed, json!({"path":"src/lib.rs","limit":25}));
    }

    #[test]
    fn parsed_arguments_rejects_incomplete_json() {
        let call = read_file_call(r#"{"path":"src/main.rs""#);

        assert!(call.parsed_arguments().is_err());
    }

    #[test]
    fn parsed_arguments_ignores_brackets_inside_strings() {
        // Braces inside a JSON string must not end the recovered span early.
        let call = read_file_call(r#"Result: {"pattern":"a}b{c"} done"#);

        let parsed = call
            .parsed_arguments()
            .expect("brackets inside strings should not break recovery");
        assert_eq!(parsed, json!({"pattern":"a}b{c"}));
    }

    #[test]
    fn parsed_arguments_errors_without_function_details() {
        let call = ToolCall {
            id: "call_x".to_string(),
            call_type: "function".to_string(),
            function: None,
            text: None,
            thought_signature: None,
        };

        assert!(call.parsed_arguments().is_err());
        assert_eq!(
            call.validate(),
            Err("Function tool call missing function details".to_string())
        );
    }

    #[test]
    fn validate_accepts_custom_freeform_payload() {
        let call = ToolCall::custom(
            "call_patch".to_string(),
            "apply_patch".to_string(),
            "*** Begin Patch\nnot json".to_string(),
        );

        assert_eq!(call.validate(), Ok(()));
    }

    #[test]
    fn validate_rejects_unparseable_function_arguments() {
        let err = read_file_call("not json at all")
            .validate()
            .expect_err("non-JSON arguments must be rejected");
        assert!(err.starts_with("Invalid JSON in function arguments"));
    }

    #[test]
    fn validate_rejects_unknown_call_type() {
        let mut call = read_file_call("{}");
        call.call_type = "mystery".to_string();

        assert_eq!(
            call.validate(),
            Err("Unsupported tool call type: mystery".to_string())
        );
    }
}