// vtcode_commons/llm.rs
//! Core LLM types shared across the project

use serde::{Deserialize, Serialize};

5#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
6pub enum BackendKind {
7    Gemini,
8    OpenAI,
9    Anthropic,
10    DeepSeek,
11    OpenRouter,
12    Ollama,
13    ZAI,
14    Moonshot,
15    HuggingFace,
16    Minimax,
17    OpenCodeZen,
18    OpenCodeGo,
19}
20
21#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
22pub struct Usage {
23    pub prompt_tokens: u32,
24    pub completion_tokens: u32,
25    pub total_tokens: u32,
26    pub cached_prompt_tokens: Option<u32>,
27    pub cache_creation_tokens: Option<u32>,
28    pub cache_read_tokens: Option<u32>,
29}
30
31impl Usage {
32    #[inline]
33    fn has_cache_read_metric(&self) -> bool {
34        self.cache_read_tokens.is_some() || self.cached_prompt_tokens.is_some()
35    }
36
37    #[inline]
38    fn has_any_cache_metrics(&self) -> bool {
39        self.has_cache_read_metric() || self.cache_creation_tokens.is_some()
40    }
41
42    #[inline]
43    pub fn cache_read_tokens_or_fallback(&self) -> u32 {
44        self.cache_read_tokens
45            .or(self.cached_prompt_tokens)
46            .unwrap_or(0)
47    }
48
49    #[inline]
50    pub fn cache_creation_tokens_or_zero(&self) -> u32 {
51        self.cache_creation_tokens.unwrap_or(0)
52    }
53
54    #[inline]
55    pub fn cache_hit_rate(&self) -> Option<f64> {
56        if !self.has_any_cache_metrics() {
57            return None;
58        }
59        let read = self.cache_read_tokens_or_fallback() as f64;
60        let creation = self.cache_creation_tokens_or_zero() as f64;
61        let total = read + creation;
62        if total > 0.0 {
63            Some((read / total) * 100.0)
64        } else {
65            None
66        }
67    }
68
69    #[inline]
70    pub fn is_cache_hit(&self) -> Option<bool> {
71        self.has_any_cache_metrics()
72            .then(|| self.cache_read_tokens_or_fallback() > 0)
73    }
74
75    #[inline]
76    pub fn is_cache_miss(&self) -> Option<bool> {
77        self.has_any_cache_metrics().then(|| {
78            self.cache_creation_tokens_or_zero() > 0 && self.cache_read_tokens_or_fallback() == 0
79        })
80    }
81
82    #[inline]
83    pub fn total_cache_tokens(&self) -> u32 {
84        let read = self.cache_read_tokens_or_fallback();
85        let creation = self.cache_creation_tokens_or_zero();
86        read + creation
87    }
88
89    #[inline]
90    pub fn cache_savings_ratio(&self) -> Option<f64> {
91        if !self.has_cache_read_metric() {
92            return None;
93        }
94        let read = self.cache_read_tokens_or_fallback() as f64;
95        let prompt = self.prompt_tokens as f64;
96        if prompt > 0.0 {
97            Some(read / prompt)
98        } else {
99            None
100        }
101    }
102}
103
104/// Provider-agnostic balance information for account status display.
105#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
106pub struct BalanceInfo {
107    /// Human-readable balance string (e.g. "100.00¥", "$50.00").
108    pub display: String,
109    /// Whether the account has sufficient balance for API calls.
110    pub is_available: bool,
111}
112
113/// DeepSeek-specific balance info from GET /user/balance
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct DeepSeekBalanceResponse {
116    pub is_available: bool,
117    pub balance_infos: Vec<DeepSeekCurrencyBalance>,
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize)]
121pub struct DeepSeekCurrencyBalance {
122    pub currency: String,
123    pub total_balance: String,
124    #[serde(default)]
125    pub granted_balance: String,
126    #[serde(default)]
127    pub topped_up_balance: String,
128}
129
130impl From<DeepSeekBalanceResponse> for BalanceInfo {
131    fn from(resp: DeepSeekBalanceResponse) -> Self {
132        let display = resp
133            .balance_infos
134            .first()
135            .map(|b| {
136                let symbol = match b.currency.as_str() {
137                    "CNY" => "¥",
138                    "USD" => "$",
139                    _ => &b.currency,
140                };
141                format!("{}{}", b.total_balance, symbol)
142            })
143            .unwrap_or_else(|| "N/A".to_string());
144        BalanceInfo {
145            display,
146            is_available: resp.is_available,
147        }
148    }
149}
150
#[cfg(test)]
mod usage_tests {
    use super::Usage;

    #[test]
    fn cache_helpers_fall_back_to_cached_prompt_tokens() {
        let usage = Usage {
            prompt_tokens: 1_000,
            completion_tokens: 200,
            total_tokens: 1_200,
            cached_prompt_tokens: Some(600),
            cache_creation_tokens: Some(150),
            cache_read_tokens: None,
        };

        assert_eq!(usage.cache_read_tokens_or_fallback(), 600);
        assert_eq!(usage.cache_creation_tokens_or_zero(), 150);
        assert_eq!(usage.total_cache_tokens(), 750);
        assert_eq!(usage.is_cache_hit(), Some(true));
        assert_eq!(usage.is_cache_miss(), Some(false));
        assert_eq!(usage.cache_savings_ratio(), Some(0.6));
        assert_eq!(usage.cache_hit_rate(), Some(80.0));
    }

    #[test]
    fn cache_helpers_prefer_cache_read_tokens_over_fallback() {
        let usage = Usage {
            prompt_tokens: 1_000,
            cached_prompt_tokens: Some(600),
            cache_read_tokens: Some(250),
            ..Usage::default()
        };

        assert_eq!(usage.cache_read_tokens_or_fallback(), 250);
        assert_eq!(usage.cache_creation_tokens_or_zero(), 0);
        assert_eq!(usage.is_cache_hit(), Some(true));
        assert_eq!(usage.is_cache_miss(), Some(false));
        assert_eq!(usage.cache_savings_ratio(), Some(0.25));
        assert_eq!(usage.cache_hit_rate(), Some(100.0));
    }

    #[test]
    fn cache_helpers_preserve_unknown_without_metrics() {
        let usage = Usage {
            prompt_tokens: 1_000,
            completion_tokens: 200,
            total_tokens: 1_200,
            cached_prompt_tokens: None,
            cache_creation_tokens: None,
            cache_read_tokens: None,
        };

        assert_eq!(usage.total_cache_tokens(), 0);
        assert_eq!(usage.is_cache_hit(), None);
        assert_eq!(usage.is_cache_miss(), None);
        assert_eq!(usage.cache_savings_ratio(), None);
        assert_eq!(usage.cache_hit_rate(), None);
    }
}

194#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
195pub enum FinishReason {
196    #[default]
197    Stop,
198    Length,
199    ToolCalls,
200    ContentFilter,
201    Pause,
202    Refusal,
203    Error(String),
204}
205
206/// Universal tool call that matches OpenAI/Anthropic/Gemini specifications
207#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
208pub struct ToolCall {
209    /// Unique identifier for this tool call (e.g., "call_123")
210    pub id: String,
211
212    /// The type of tool call: "function", "custom" (GPT-5 freeform), or other
213    #[serde(rename = "type")]
214    pub call_type: String,
215
216    /// Function call details (for function-type tools)
217    #[serde(skip_serializing_if = "Option::is_none")]
218    pub function: Option<FunctionCall>,
219
220    /// Raw text payload (for custom freeform tools in GPT-5)
221    #[serde(skip_serializing_if = "Option::is_none")]
222    pub text: Option<String>,
223
224    /// Gemini-specific thought signature for maintaining reasoning context
225    #[serde(skip_serializing_if = "Option::is_none")]
226    pub thought_signature: Option<String>,
227}
228
229/// Function call within a tool call
230#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
231pub struct FunctionCall {
232    /// Optional namespace for grouped or deferred tools.
233    #[serde(default, skip_serializing_if = "Option::is_none")]
234    pub namespace: Option<String>,
235
236    /// The name of the function to call
237    pub name: String,
238
239    /// The arguments to pass to the function, as a JSON string
240    pub arguments: String,
241}
242
243impl ToolCall {
244    /// Create a new function tool call
245    pub fn function(id: String, name: String, arguments: String) -> Self {
246        Self::function_with_namespace(id, None, name, arguments)
247    }
248
249    /// Create a new function tool call with an optional namespace.
250    pub fn function_with_namespace(
251        id: String,
252        namespace: Option<String>,
253        name: String,
254        arguments: String,
255    ) -> Self {
256        Self {
257            id,
258            call_type: "function".to_owned(),
259            function: Some(FunctionCall {
260                namespace,
261                name,
262                arguments,
263            }),
264            text: None,
265            thought_signature: None,
266        }
267    }
268
269    /// Create a new custom tool call with raw text payload (GPT-5 freeform)
270    pub fn custom(id: String, name: String, text: String) -> Self {
271        Self {
272            id,
273            call_type: "custom".to_owned(),
274            function: Some(FunctionCall {
275                namespace: None,
276                name,
277                arguments: text.clone(),
278            }),
279            text: Some(text),
280            thought_signature: None,
281        }
282    }
283
284    /// Returns true when this tool call uses GPT-5 custom/freeform semantics.
285    pub fn is_custom(&self) -> bool {
286        self.call_type == "custom"
287    }
288
289    /// Returns the tool name when the call includes function details.
290    pub fn tool_name(&self) -> Option<&str> {
291        self.function
292            .as_ref()
293            .map(|function| function.name.as_str())
294    }
295
296    /// Returns the raw payload text exactly as emitted by the model.
297    pub fn raw_input(&self) -> Option<&str> {
298        self.text.as_deref().or_else(|| {
299            self.function
300                .as_ref()
301                .map(|function| function.arguments.as_str())
302        })
303    }
304
305    /// Parse the arguments as JSON Value (for function-type tools)
306    pub fn parsed_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
307        if let Some(ref func) = self.function {
308            parse_tool_arguments(&func.arguments)
309        } else {
310            // Return an error by trying to parse invalid JSON
311            serde_json::from_str("")
312        }
313    }
314
315    /// Returns the execution payload for this tool call.
316    ///
317    /// Function tools keep their JSON semantics. Custom tools execute with their
318    /// raw text payload wrapped as a JSON string value so freeform inputs can
319    /// flow through the existing tool pipeline.
320    pub fn execution_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
321        if self.is_custom() {
322            return Ok(serde_json::Value::String(
323                self.raw_input().unwrap_or_default().to_string(),
324            ));
325        }
326
327        self.parsed_arguments()
328    }
329
330    /// Validate that this tool call is properly formed
331    pub fn validate(&self) -> Result<(), String> {
332        if self.id.is_empty() {
333            return Err("Tool call ID cannot be empty".to_owned());
334        }
335
336        match self.call_type.as_str() {
337            "function" => {
338                if let Some(func) = &self.function {
339                    if func.name.is_empty() {
340                        return Err("Function name cannot be empty".to_owned());
341                    }
342                    // Validate that arguments is valid JSON for function tools
343                    if let Err(e) = self.parsed_arguments() {
344                        return Err(format!("Invalid JSON in function arguments: {}", e));
345                    }
346                } else {
347                    return Err("Function tool call missing function details".to_owned());
348                }
349            }
350            "custom" => {
351                // For custom tools, we allow raw text payload without JSON validation
352                if let Some(func) = &self.function {
353                    if func.name.is_empty() {
354                        return Err("Custom tool name cannot be empty".to_owned());
355                    }
356                } else {
357                    return Err("Custom tool call missing function details".to_owned());
358                }
359            }
360            _ => return Err(format!("Unsupported tool call type: {}", self.call_type)),
361        }
362
363        Ok(())
364    }
365}
366
367fn parse_tool_arguments(raw_arguments: &str) -> Result<serde_json::Value, serde_json::Error> {
368    let trimmed = raw_arguments.trim();
369    match serde_json::from_str(trimmed) {
370        Ok(parsed) => Ok(parsed),
371        Err(primary_error) => {
372            if let Some(candidate) = extract_balanced_json(trimmed)
373                && let Ok(parsed) = serde_json::from_str(candidate)
374            {
375                return Ok(parsed);
376            }
377            if let Some(candidate) = repair_tag_polluted_json(trimmed)
378                && let Ok(parsed) = serde_json::from_str(&candidate)
379            {
380                return Ok(parsed);
381            }
382            Err(primary_error)
383        }
384    }
385}
386
/// Returns the first balanced JSON-looking span in `input`.
///
/// The span starts at the first `{` or `[` and ends at its matching closer.
/// Only brackets of the opening kind are counted, and bracket characters inside
/// string literals (honoring backslash escapes) are ignored. Returns `None` when
/// there is no opener or the span never closes. The span is not validated as JSON.
fn extract_balanced_json(input: &str) -> Option<&str> {
    let start = input.find(['{', '['])?;
    let bytes = input.as_bytes();
    let opening = *bytes.get(start)?;
    let closing = match opening {
        b'{' => b'}',
        b'[' => b']',
        _ => return None,
    };

    let mut depth = 0usize;
    let mut in_string = false;
    let mut escaped = false;

    // Scanning bytes is safe for UTF-8: every structural character is ASCII,
    // and ASCII bytes never occur inside multi-byte sequences.
    for (offset, &byte) in bytes[start..].iter().enumerate() {
        if in_string {
            if escaped {
                escaped = false;
            } else if byte == b'\\' {
                escaped = true;
            } else if byte == b'"' {
                in_string = false;
            }
            continue;
        }

        match byte {
            b'"' => in_string = true,
            b if b == opening => depth += 1,
            b if b == closing => {
                depth = depth.saturating_sub(1);
                if depth == 0 {
                    return input.get(start..start + offset + 1);
                }
            }
            _ => {}
        }
    }

    None
}

433fn repair_tag_polluted_json(input: &str) -> Option<String> {
434    let start = input.find(['{', '['])?;
435    let candidate = input.get(start..)?;
436    let boundary = find_provider_markup_boundary(candidate)?;
437    if boundary == 0 {
438        return None;
439    }
440
441    close_incomplete_json_prefix(candidate[..boundary].trim_end())
442}
443
/// Byte offset of the earliest provider tool-call markup marker in `input`.
///
/// Returns `None` when no marker occurs. Every marker starts with ASCII `<`,
/// so the returned offset is always a valid `char` boundary.
fn find_provider_markup_boundary(input: &str) -> Option<usize> {
    const PROVIDER_MARKERS: &[&str] = &[
        "<</",
        "</parameter>",
        "</invoke>",
        "</minimax:tool_call>",
        "<minimax:tool_call>",
        "<parameter name=\"",
        "<invoke name=\"",
        "<tool_call>",
        "</tool_call>",
    ];

    PROVIDER_MARKERS
        .iter()
        .filter_map(|&marker| input.find(marker))
        .min()
}

/// Terminates a truncated JSON prefix so it has a chance to parse.
///
/// Walks `prefix` while tracking string state and the stack of open `{`/`[`
/// containers. It then appends a closing quote if the prefix ends inside a
/// string, followed by the pending closers innermost-first.
///
/// Returns `None` for an empty prefix, or one whose closing brackets do not
/// match its openers. The result is not guaranteed to be valid JSON (a dangling
/// `,` or `:` is left as-is); callers must still parse it.
fn close_incomplete_json_prefix(prefix: &str) -> Option<String> {
    if prefix.is_empty() {
        return None;
    }

    let mut repaired = String::with_capacity(prefix.len() + 8);
    let mut expected_closers = Vec::new();
    let mut in_string = false;
    let mut escaped = false;

    for ch in prefix.chars() {
        repaired.push(ch);

        if in_string {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                in_string = false;
            }
            continue;
        }

        match ch {
            '"' => in_string = true,
            '{' => expected_closers.push('}'),
            '[' => expected_closers.push(']'),
            '}' | ']' => {
                if expected_closers.pop() != Some(ch) {
                    return None;
                }
            }
            _ => {}
        }
    }

    if in_string {
        if escaped {
            // A dangling backslash would escape the quote added below and leave
            // the string unterminated, so drop it.
            repaired.pop();
        }
        repaired.push('"');
    }
    // Containers must be closed innermost-first, i.e. in reverse push order.
    repaired.extend(expected_closers.iter().rev());

    Some(repaired)
}

516/// Universal LLM response structure
517#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
518pub struct LLMResponse {
519    /// The response content text
520    pub content: Option<String>,
521
522    /// Tool calls made by the model
523    pub tool_calls: Option<Vec<ToolCall>>,
524
525    /// The model that generated this response
526    pub model: String,
527
528    /// Token usage statistics
529    pub usage: Option<Usage>,
530
531    /// Why the response finished
532    pub finish_reason: FinishReason,
533
534    /// Reasoning content (for models that support it)
535    pub reasoning: Option<String>,
536
537    /// Detailed reasoning traces (for models that support it)
538    pub reasoning_details: Option<Vec<String>>,
539
540    /// Tool references for context
541    pub tool_references: Vec<String>,
542
543    /// Request ID from the provider
544    pub request_id: Option<String>,
545
546    /// Organization ID from the provider
547    pub organization_id: Option<String>,
548}
549
550impl LLMResponse {
551    /// Create a new LLM response with mandatory fields
552    pub fn new(model: impl Into<String>, content: impl Into<String>) -> Self {
553        Self {
554            content: Some(content.into()),
555            tool_calls: None,
556            model: model.into(),
557            usage: None,
558            finish_reason: FinishReason::Stop,
559            reasoning: None,
560            reasoning_details: None,
561            tool_references: Vec::new(),
562            request_id: None,
563            organization_id: None,
564        }
565    }
566
567    /// Get content or empty string
568    pub fn content_text(&self) -> &str {
569        self.content.as_deref().unwrap_or("")
570    }
571
572    /// Get content as String (clone)
573    pub fn content_string(&self) -> String {
574        self.content.clone().unwrap_or_default()
575    }
576}
577
578#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
579pub struct LLMErrorMetadata {
580    pub provider: Option<String>,
581    pub status: Option<u16>,
582    pub code: Option<String>,
583    pub request_id: Option<String>,
584    pub organization_id: Option<String>,
585    pub retry_after: Option<String>,
586    pub message: Option<String>,
587}
588
589impl LLMErrorMetadata {
590    /// Boxed constructor because metadata is always stored inside `Option<Box<LLMErrorMetadata>>`
591    /// in the LLMError enum variants.
592    #[must_use]
593    pub fn new(
594        provider: impl Into<String>,
595        status: Option<u16>,
596        code: Option<String>,
597        request_id: Option<String>,
598        organization_id: Option<String>,
599        retry_after: Option<String>,
600        message: Option<String>,
601    ) -> Box<Self> {
602        Box::new(Self {
603            provider: Some(provider.into()),
604            status,
605            code,
606            request_id,
607            organization_id,
608            retry_after,
609            message,
610        })
611    }
612}
613
614/// LLM error types with optional provider metadata
615#[derive(Debug, thiserror::Error, Serialize, Deserialize, Clone)]
616#[serde(tag = "type", rename_all = "snake_case")]
617pub enum LLMError {
618    #[error("Authentication failed: {message}")]
619    Authentication {
620        message: String,
621        metadata: Option<Box<LLMErrorMetadata>>,
622    },
623    #[error("Rate limit exceeded")]
624    RateLimit {
625        metadata: Option<Box<LLMErrorMetadata>>,
626    },
627    #[error("Invalid request: {message}")]
628    InvalidRequest {
629        message: String,
630        metadata: Option<Box<LLMErrorMetadata>>,
631    },
632    #[error("Network error: {message}")]
633    Network {
634        message: String,
635        metadata: Option<Box<LLMErrorMetadata>>,
636    },
637    #[error("Provider error: {message}")]
638    Provider {
639        message: String,
640        metadata: Option<Box<LLMErrorMetadata>>,
641    },
642}
643
#[cfg(test)]
mod tests {
    use super::ToolCall;
    use serde_json::json;

    #[test]
    fn parsed_arguments_accepts_trailing_characters() {
        let call = ToolCall::function(
            "call_read".to_string(),
            "read_file".to_string(),
            r#"{"path":"src/main.rs"} trailing text"#.to_string(),
        );

        let parsed = call
            .parsed_arguments()
            .expect("arguments with trailing text should recover");
        assert_eq!(parsed, json!({"path":"src/main.rs"}));
    }

    #[test]
    fn parsed_arguments_accepts_code_fenced_json() {
        let call = ToolCall::function(
            "call_read".to_string(),
            "read_file".to_string(),
            "```json\n{\"path\":\"src/lib.rs\",\"limit\":25}\n```".to_string(),
        );

        let parsed = call
            .parsed_arguments()
            .expect("code-fenced arguments should recover");
        assert_eq!(parsed, json!({"path":"src/lib.rs","limit":25}));
    }

    #[test]
    fn parsed_arguments_rejects_incomplete_json() {
        let call = ToolCall::function(
            "call_read".to_string(),
            "read_file".to_string(),
            r#"{"path":"src/main.rs""#.to_string(),
        );

        assert!(call.parsed_arguments().is_err());
    }

    #[test]
    fn parsed_arguments_errors_without_function_details() {
        let call = ToolCall {
            id: "call_empty".to_string(),
            call_type: "function".to_string(),
            function: None,
            text: Some("{}".to_string()),
            thought_signature: None,
        };

        assert!(call.parsed_arguments().is_err());
        assert!(call.validate().is_err());
    }

    #[test]
    fn parsed_arguments_recovers_truncated_minimax_markup() {
        let call = ToolCall::function(
            "call_search".to_string(),
            "unified_search".to_string(),
            "{\"action\": \"grep\", \"pattern\": \"persistent_memory\", \"path\": \"vtcode-core/src</parameter>\n<</invoke>\n</minimax:tool_call>".to_string(),
        );

        let parsed = call
            .parsed_arguments()
            .expect("minimax markup spillover should recover");
        assert_eq!(
            parsed,
            json!({
                "action": "grep",
                "pattern": "persistent_memory",
                "path": "vtcode-core/src"
            })
        );
    }

    #[test]
    fn function_call_serializes_optional_namespace() {
        let call = ToolCall::function_with_namespace(
            "call_read".to_string(),
            Some("workspace".to_string()),
            "read_file".to_string(),
            r#"{"path":"src/main.rs"}"#.to_string(),
        );

        let json = serde_json::to_value(&call).expect("tool call should serialize");
        assert_eq!(json["function"]["namespace"], "workspace");
        assert_eq!(json["function"]["name"], "read_file");
    }

    #[test]
    fn custom_tool_call_exposes_raw_execution_arguments() {
        let patch = "*** Begin Patch\n*** End Patch\n".to_string();
        let call = ToolCall::custom(
            "call_patch".to_string(),
            "apply_patch".to_string(),
            patch.clone(),
        );

        assert!(call.is_custom());
        assert_eq!(call.tool_name(), Some("apply_patch"));
        assert_eq!(call.raw_input(), Some(patch.as_str()));
        assert_eq!(
            call.execution_arguments().expect("custom arguments"),
            json!(patch)
        );
        assert!(
            call.parsed_arguments().is_err(),
            "custom tool payload should stay freeform rather than JSON"
        );
    }

    #[test]
    fn validate_accepts_custom_tool_with_freeform_payload() {
        let call = ToolCall::custom(
            "call_patch".to_string(),
            "apply_patch".to_string(),
            "not json at all".to_string(),
        );

        assert_eq!(call.validate(), Ok(()));
    }

    #[test]
    fn validate_rejects_function_call_with_unrecoverable_arguments() {
        let call = ToolCall::function(
            "call_read".to_string(),
            "read_file".to_string(),
            "not json at all".to_string(),
        );

        assert!(call.validate().is_err());
    }
}