Skip to main content

vtcode_commons/
llm.rs

1//! Core LLM types shared across the project
2
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
6pub enum BackendKind {
7    Gemini,
8    OpenAI,
9    Anthropic,
10    DeepSeek,
11    OpenRouter,
12    Ollama,
13    ZAI,
14    Moonshot,
15    HuggingFace,
16    Minimax,
17    OpenCodeZen,
18    OpenCodeGo,
19}
20
21#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
22pub struct Usage {
23    pub prompt_tokens: u32,
24    pub completion_tokens: u32,
25    pub total_tokens: u32,
26    pub cached_prompt_tokens: Option<u32>,
27    pub cache_creation_tokens: Option<u32>,
28    pub cache_read_tokens: Option<u32>,
29}
30
31impl Usage {
32    #[inline]
33    fn has_cache_read_metric(&self) -> bool {
34        self.cache_read_tokens.is_some() || self.cached_prompt_tokens.is_some()
35    }
36
37    #[inline]
38    fn has_any_cache_metrics(&self) -> bool {
39        self.has_cache_read_metric() || self.cache_creation_tokens.is_some()
40    }
41
42    #[inline]
43    pub fn cache_read_tokens_or_fallback(&self) -> u32 {
44        self.cache_read_tokens
45            .or(self.cached_prompt_tokens)
46            .unwrap_or(0)
47    }
48
49    #[inline]
50    pub fn cache_creation_tokens_or_zero(&self) -> u32 {
51        self.cache_creation_tokens.unwrap_or(0)
52    }
53
54    #[inline]
55    pub fn cache_hit_rate(&self) -> Option<f64> {
56        if !self.has_any_cache_metrics() {
57            return None;
58        }
59        let read = self.cache_read_tokens_or_fallback() as f64;
60        let creation = self.cache_creation_tokens_or_zero() as f64;
61        let total = read + creation;
62        if total > 0.0 {
63            Some((read / total) * 100.0)
64        } else {
65            None
66        }
67    }
68
69    #[inline]
70    pub fn is_cache_hit(&self) -> Option<bool> {
71        self.has_any_cache_metrics()
72            .then(|| self.cache_read_tokens_or_fallback() > 0)
73    }
74
75    #[inline]
76    pub fn is_cache_miss(&self) -> Option<bool> {
77        self.has_any_cache_metrics().then(|| {
78            self.cache_creation_tokens_or_zero() > 0 && self.cache_read_tokens_or_fallback() == 0
79        })
80    }
81
82    #[inline]
83    pub fn total_cache_tokens(&self) -> u32 {
84        let read = self.cache_read_tokens_or_fallback();
85        let creation = self.cache_creation_tokens_or_zero();
86        read + creation
87    }
88
89    #[inline]
90    pub fn cache_savings_ratio(&self) -> Option<f64> {
91        if !self.has_cache_read_metric() {
92            return None;
93        }
94        let read = self.cache_read_tokens_or_fallback() as f64;
95        let prompt = self.prompt_tokens as f64;
96        if prompt > 0.0 {
97            Some(read / prompt)
98        } else {
99            None
100        }
101    }
102}
103
#[cfg(test)]
mod usage_tests {
    use super::Usage;

    /// Builds a usage record with fixed token totals and the given cache counters.
    fn usage_with_cache(
        cached_prompt_tokens: Option<u32>,
        cache_creation_tokens: Option<u32>,
        cache_read_tokens: Option<u32>,
    ) -> Usage {
        Usage {
            prompt_tokens: 1_000,
            completion_tokens: 200,
            total_tokens: 1_200,
            cached_prompt_tokens,
            cache_creation_tokens,
            cache_read_tokens,
        }
    }

    #[test]
    fn cache_helpers_fall_back_to_cached_prompt_tokens() {
        // No explicit cache-read counter: the cached prompt count stands in.
        let usage = usage_with_cache(Some(600), Some(150), None);

        assert_eq!(usage.cache_read_tokens_or_fallback(), 600);
        assert_eq!(usage.cache_creation_tokens_or_zero(), 150);
        assert_eq!(usage.total_cache_tokens(), 750);
        assert_eq!(usage.is_cache_hit(), Some(true));
        assert_eq!(usage.is_cache_miss(), Some(false));
        assert_eq!(usage.cache_savings_ratio(), Some(0.6));
        assert_eq!(usage.cache_hit_rate(), Some(80.0));
    }

    #[test]
    fn cache_helpers_preserve_unknown_without_metrics() {
        // With every counter absent, the Option-returning helpers stay `None`.
        let usage = usage_with_cache(None, None, None);

        assert_eq!(usage.total_cache_tokens(), 0);
        assert_eq!(usage.is_cache_hit(), None);
        assert_eq!(usage.is_cache_miss(), None);
        assert_eq!(usage.cache_savings_ratio(), None);
        assert_eq!(usage.cache_hit_rate(), None);
    }
}
146
147#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
148pub enum FinishReason {
149    #[default]
150    Stop,
151    Length,
152    ToolCalls,
153    ContentFilter,
154    Pause,
155    Refusal,
156    Error(String),
157}
158
159/// Universal tool call that matches OpenAI/Anthropic/Gemini specifications
160#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
161pub struct ToolCall {
162    /// Unique identifier for this tool call (e.g., "call_123")
163    pub id: String,
164
165    /// The type of tool call: "function", "custom" (GPT-5 freeform), or other
166    #[serde(rename = "type")]
167    pub call_type: String,
168
169    /// Function call details (for function-type tools)
170    #[serde(skip_serializing_if = "Option::is_none")]
171    pub function: Option<FunctionCall>,
172
173    /// Raw text payload (for custom freeform tools in GPT-5)
174    #[serde(skip_serializing_if = "Option::is_none")]
175    pub text: Option<String>,
176
177    /// Gemini-specific thought signature for maintaining reasoning context
178    #[serde(skip_serializing_if = "Option::is_none")]
179    pub thought_signature: Option<String>,
180}
181
182/// Function call within a tool call
183#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
184pub struct FunctionCall {
185    /// Optional namespace for grouped or deferred tools.
186    #[serde(default, skip_serializing_if = "Option::is_none")]
187    pub namespace: Option<String>,
188
189    /// The name of the function to call
190    pub name: String,
191
192    /// The arguments to pass to the function, as a JSON string
193    pub arguments: String,
194}
195
196impl ToolCall {
197    /// Create a new function tool call
198    pub fn function(id: String, name: String, arguments: String) -> Self {
199        Self::function_with_namespace(id, None, name, arguments)
200    }
201
202    /// Create a new function tool call with an optional namespace.
203    pub fn function_with_namespace(
204        id: String,
205        namespace: Option<String>,
206        name: String,
207        arguments: String,
208    ) -> Self {
209        Self {
210            id,
211            call_type: "function".to_owned(),
212            function: Some(FunctionCall {
213                namespace,
214                name,
215                arguments,
216            }),
217            text: None,
218            thought_signature: None,
219        }
220    }
221
222    /// Create a new custom tool call with raw text payload (GPT-5 freeform)
223    pub fn custom(id: String, name: String, text: String) -> Self {
224        Self {
225            id,
226            call_type: "custom".to_owned(),
227            function: Some(FunctionCall {
228                namespace: None,
229                name,
230                arguments: text.clone(),
231            }),
232            text: Some(text),
233            thought_signature: None,
234        }
235    }
236
237    /// Returns true when this tool call uses GPT-5 custom/freeform semantics.
238    pub fn is_custom(&self) -> bool {
239        self.call_type == "custom"
240    }
241
242    /// Returns the tool name when the call includes function details.
243    pub fn tool_name(&self) -> Option<&str> {
244        self.function
245            .as_ref()
246            .map(|function| function.name.as_str())
247    }
248
249    /// Returns the raw payload text exactly as emitted by the model.
250    pub fn raw_input(&self) -> Option<&str> {
251        self.text.as_deref().or_else(|| {
252            self.function
253                .as_ref()
254                .map(|function| function.arguments.as_str())
255        })
256    }
257
258    /// Parse the arguments as JSON Value (for function-type tools)
259    pub fn parsed_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
260        if let Some(ref func) = self.function {
261            parse_tool_arguments(&func.arguments)
262        } else {
263            // Return an error by trying to parse invalid JSON
264            serde_json::from_str("")
265        }
266    }
267
268    /// Returns the execution payload for this tool call.
269    ///
270    /// Function tools keep their JSON semantics. Custom tools execute with their
271    /// raw text payload wrapped as a JSON string value so freeform inputs can
272    /// flow through the existing tool pipeline.
273    pub fn execution_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
274        if self.is_custom() {
275            return Ok(serde_json::Value::String(
276                self.raw_input().unwrap_or_default().to_string(),
277            ));
278        }
279
280        self.parsed_arguments()
281    }
282
283    /// Validate that this tool call is properly formed
284    pub fn validate(&self) -> Result<(), String> {
285        if self.id.is_empty() {
286            return Err("Tool call ID cannot be empty".to_owned());
287        }
288
289        match self.call_type.as_str() {
290            "function" => {
291                if let Some(func) = &self.function {
292                    if func.name.is_empty() {
293                        return Err("Function name cannot be empty".to_owned());
294                    }
295                    // Validate that arguments is valid JSON for function tools
296                    if let Err(e) = self.parsed_arguments() {
297                        return Err(format!("Invalid JSON in function arguments: {}", e));
298                    }
299                } else {
300                    return Err("Function tool call missing function details".to_owned());
301                }
302            }
303            "custom" => {
304                // For custom tools, we allow raw text payload without JSON validation
305                if let Some(func) = &self.function {
306                    if func.name.is_empty() {
307                        return Err("Custom tool name cannot be empty".to_owned());
308                    }
309                } else {
310                    return Err("Custom tool call missing function details".to_owned());
311                }
312            }
313            _ => return Err(format!("Unsupported tool call type: {}", self.call_type)),
314        }
315
316        Ok(())
317    }
318}
319
320fn parse_tool_arguments(raw_arguments: &str) -> Result<serde_json::Value, serde_json::Error> {
321    let trimmed = raw_arguments.trim();
322    match serde_json::from_str(trimmed) {
323        Ok(parsed) => Ok(parsed),
324        Err(primary_error) => {
325            if let Some(candidate) = extract_balanced_json(trimmed)
326                && let Ok(parsed) = serde_json::from_str(candidate)
327            {
328                return Ok(parsed);
329            }
330            if let Some(candidate) = repair_tag_polluted_json(trimmed)
331                && let Ok(parsed) = serde_json::from_str(&candidate)
332            {
333                return Ok(parsed);
334            }
335            Err(primary_error)
336        }
337    }
338}
339
/// Returns the first bracket-balanced span starting at the first `{` or `[`.
///
/// Only the bracket kind that opens the span is counted, and brackets inside
/// double-quoted strings (with backslash escapes honoured) are ignored.
/// Returns `None` when the input has no opening bracket or the span never
/// closes.
fn extract_balanced_json(input: &str) -> Option<&str> {
    let start = input.find(['{', '['])?;
    let tail = &input[start..];
    // `find` guarantees the tail begins with one of the two openers.
    let (open, close) = if tail.starts_with('{') {
        ('{', '}')
    } else {
        ('[', ']')
    };

    let mut nesting = 0usize;
    let mut inside_string = false;
    let mut skip_next = false;

    for (idx, c) in tail.char_indices() {
        // `skip_next` is only ever set while inside a string.
        if skip_next {
            skip_next = false;
            continue;
        }

        if inside_string {
            match c {
                '\\' => skip_next = true,
                '"' => inside_string = false,
                _ => {}
            }
            continue;
        }

        if c == '"' {
            inside_string = true;
        } else if c == open {
            nesting += 1;
        } else if c == close {
            nesting = nesting.saturating_sub(1);
            if nesting == 0 {
                return tail.get(..idx + c.len_utf8());
            }
        }
    }

    None
}
385
386fn repair_tag_polluted_json(input: &str) -> Option<String> {
387    let start = input.find(['{', '['])?;
388    let candidate = input.get(start..)?;
389    let boundary = find_provider_markup_boundary(candidate)?;
390    if boundary == 0 {
391        return None;
392    }
393
394    close_incomplete_json_prefix(candidate[..boundary].trim_end())
395}
396
/// Returns the byte offset of the earliest provider markup marker in `input`.
///
/// Returns `None` when none of the known markers occurs.
fn find_provider_markup_boundary(input: &str) -> Option<usize> {
    const PROVIDER_MARKERS: &[&str] = &[
        "<</",
        "</parameter>",
        "</invoke>",
        "</minimax:tool_call>",
        "<minimax:tool_call>",
        "<parameter name=\"",
        "<invoke name=\"",
        "<tool_call>",
        "</tool_call>",
    ];

    // The earliest position at which any marker starts is the smallest of
    // the markers' first-occurrence offsets.
    PROVIDER_MARKERS
        .iter()
        .filter_map(|marker| input.find(*marker))
        .min()
}
418
/// Closes a truncated JSON prefix by appending the missing terminators.
///
/// The prefix is scanned once, tracking string state and the stack of open
/// `{` / `[` brackets. An unterminated string is closed with `"`, then every
/// still-open bracket is closed innermost first.
///
/// Returns `None` when `prefix` is empty or contains a closing bracket that
/// does not match the innermost open one. The result is not guaranteed to be
/// valid JSON; callers are expected to parse it.
fn close_incomplete_json_prefix(prefix: &str) -> Option<String> {
    if prefix.is_empty() {
        return None;
    }

    let mut repaired = String::with_capacity(prefix.len() + 8);
    let mut expected_closers = Vec::new();
    let mut in_string = false;
    let mut escaped = false;

    for ch in prefix.chars() {
        repaired.push(ch);

        if in_string {
            if escaped {
                escaped = false;
                continue;
            }

            match ch {
                '\\' => escaped = true,
                '"' => in_string = false,
                _ => {}
            }
            continue;
        }

        match ch {
            '"' => in_string = true,
            '{' => expected_closers.push('}'),
            '[' => expected_closers.push(']'),
            '}' | ']' => {
                if expected_closers.pop() != Some(ch) {
                    return None;
                }
            }
            _ => {}
        }
    }

    if in_string {
        if escaped {
            // The prefix was cut right after a backslash. Leaving it in place
            // would escape the quote appended below and keep the string
            // open, so the incomplete escape is dropped.
            repaired.pop();
        }
        repaired.push('"');
    }
    // Close the remaining brackets innermost first.
    repaired.extend(expected_closers.into_iter().rev());

    Some(repaired)
}
468
469/// Universal LLM response structure
470#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
471pub struct LLMResponse {
472    /// The response content text
473    pub content: Option<String>,
474
475    /// Tool calls made by the model
476    pub tool_calls: Option<Vec<ToolCall>>,
477
478    /// The model that generated this response
479    pub model: String,
480
481    /// Token usage statistics
482    pub usage: Option<Usage>,
483
484    /// Why the response finished
485    pub finish_reason: FinishReason,
486
487    /// Reasoning content (for models that support it)
488    pub reasoning: Option<String>,
489
490    /// Detailed reasoning traces (for models that support it)
491    pub reasoning_details: Option<Vec<String>>,
492
493    /// Tool references for context
494    pub tool_references: Vec<String>,
495
496    /// Request ID from the provider
497    pub request_id: Option<String>,
498
499    /// Organization ID from the provider
500    pub organization_id: Option<String>,
501}
502
503impl LLMResponse {
504    /// Create a new LLM response with mandatory fields
505    pub fn new(model: impl Into<String>, content: impl Into<String>) -> Self {
506        Self {
507            content: Some(content.into()),
508            tool_calls: None,
509            model: model.into(),
510            usage: None,
511            finish_reason: FinishReason::Stop,
512            reasoning: None,
513            reasoning_details: None,
514            tool_references: Vec::new(),
515            request_id: None,
516            organization_id: None,
517        }
518    }
519
520    /// Get content or empty string
521    pub fn content_text(&self) -> &str {
522        self.content.as_deref().unwrap_or("")
523    }
524
525    /// Get content as String (clone)
526    pub fn content_string(&self) -> String {
527        self.content.clone().unwrap_or_default()
528    }
529}
530
531#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
532pub struct LLMErrorMetadata {
533    pub provider: Option<String>,
534    pub status: Option<u16>,
535    pub code: Option<String>,
536    pub request_id: Option<String>,
537    pub organization_id: Option<String>,
538    pub retry_after: Option<String>,
539    pub message: Option<String>,
540}
541
542impl LLMErrorMetadata {
543    pub fn new(
544        provider: impl Into<String>,
545        status: Option<u16>,
546        code: Option<String>,
547        request_id: Option<String>,
548        organization_id: Option<String>,
549        retry_after: Option<String>,
550        message: Option<String>,
551    ) -> Box<Self> {
552        Box::new(Self {
553            provider: Some(provider.into()),
554            status,
555            code,
556            request_id,
557            organization_id,
558            retry_after,
559            message,
560        })
561    }
562}
563
564/// LLM error types with optional provider metadata
565#[derive(Debug, thiserror::Error, Serialize, Deserialize, Clone)]
566#[serde(tag = "type", rename_all = "snake_case")]
567pub enum LLMError {
568    #[error("Authentication failed: {message}")]
569    Authentication {
570        message: String,
571        metadata: Option<Box<LLMErrorMetadata>>,
572    },
573    #[error("Rate limit exceeded")]
574    RateLimit {
575        metadata: Option<Box<LLMErrorMetadata>>,
576    },
577    #[error("Invalid request: {message}")]
578    InvalidRequest {
579        message: String,
580        metadata: Option<Box<LLMErrorMetadata>>,
581    },
582    #[error("Network error: {message}")]
583    Network {
584        message: String,
585        metadata: Option<Box<LLMErrorMetadata>>,
586    },
587    #[error("Provider error: {message}")]
588    Provider {
589        message: String,
590        metadata: Option<Box<LLMErrorMetadata>>,
591    },
592}
593
#[cfg(test)]
mod tests {
    use super::ToolCall;
    use serde_json::json;

    /// Builds a `read_file` function call with the given raw arguments.
    fn read_file_call(arguments: &str) -> ToolCall {
        ToolCall::function(
            "call_read".to_owned(),
            "read_file".to_owned(),
            arguments.to_owned(),
        )
    }

    #[test]
    fn parsed_arguments_accepts_trailing_characters() {
        let parsed = read_file_call(r#"{"path":"src/main.rs"} trailing text"#)
            .parsed_arguments()
            .expect("arguments with trailing text should recover");

        assert_eq!(parsed, json!({"path":"src/main.rs"}));
    }

    #[test]
    fn parsed_arguments_accepts_code_fenced_json() {
        let parsed = read_file_call("```json\n{\"path\":\"src/lib.rs\",\"limit\":25}\n```")
            .parsed_arguments()
            .expect("code-fenced arguments should recover");

        assert_eq!(parsed, json!({"path":"src/lib.rs","limit":25}));
    }

    #[test]
    fn parsed_arguments_rejects_incomplete_json() {
        // Truncated JSON without provider markup is not repaired.
        let outcome = read_file_call(r#"{"path":"src/main.rs""#).parsed_arguments();

        assert!(outcome.is_err());
    }

    #[test]
    fn parsed_arguments_recovers_truncated_minimax_markup() {
        let call = ToolCall::function(
            "call_search".to_owned(),
            "unified_search".to_owned(),
            "{\"action\": \"grep\", \"pattern\": \"persistent_memory\", \"path\": \"vtcode-core/src</parameter>\n<</invoke>\n</minimax:tool_call>".to_owned(),
        );

        let parsed = call
            .parsed_arguments()
            .expect("minimax markup spillover should recover");

        let expected = json!({
            "action": "grep",
            "pattern": "persistent_memory",
            "path": "vtcode-core/src"
        });
        assert_eq!(parsed, expected);
    }

    #[test]
    fn function_call_serializes_optional_namespace() {
        let call = ToolCall::function_with_namespace(
            "call_read".to_owned(),
            Some("workspace".to_owned()),
            "read_file".to_owned(),
            r#"{"path":"src/main.rs"}"#.to_owned(),
        );

        let json = serde_json::to_value(&call).expect("tool call should serialize");
        let function = &json["function"];
        assert_eq!(function["namespace"], "workspace");
        assert_eq!(function["name"], "read_file");
    }

    #[test]
    fn custom_tool_call_exposes_raw_execution_arguments() {
        let patch = "*** Begin Patch\n*** End Patch\n".to_owned();
        let call = ToolCall::custom(
            "call_patch".to_owned(),
            "apply_patch".to_owned(),
            patch.clone(),
        );

        assert!(call.is_custom());
        assert_eq!(call.tool_name(), Some("apply_patch"));
        assert_eq!(call.raw_input(), Some(patch.as_str()));

        let execution = call.execution_arguments().expect("custom arguments");
        assert_eq!(execution, json!(patch));

        assert!(
            call.parsed_arguments().is_err(),
            "custom tool payload should stay freeform rather than JSON"
        );
    }
}