//! LLM provider types and the static provider registry.
//!
//! This module contains:
//! - Response types returned by LLM providers ([`LlmResponse`], [`ContentBlock`], etc.)
//! - The [`ProviderSpec`] metadata struct and static [`PROVIDERS`] registry
//! - Lookup helpers: [`find_by_model`], [`find_gateway`], [`find_by_name`]

use serde::{Deserialize, Serialize};
use std::collections::HashMap;

// ── LLM response types ──────────────────────────────────────────────────

/// A complete response from an LLM provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmResponse {
    /// Provider-assigned response identifier.
    pub id: String,

    /// Content blocks in the response (text and/or tool-use requests).
    pub content: Vec<ContentBlock>,

    /// Why the model stopped generating.
    pub stop_reason: StopReason,

    /// Token usage for this request/response pair.
    pub usage: Usage,

    /// Arbitrary provider-specific metadata.
    ///
    /// `#[serde(default)]` lets payloads omit this field entirely;
    /// it then deserializes to an empty map.
    #[serde(default)]
    pub metadata: HashMap<String, serde_json::Value>,
}

/// A single block of content in an LLM response.
///
/// Serializes with an adjacent `"type"` tag in snake_case, e.g.
/// `{"type": "text", "text": "..."}` or `{"type": "tool_use", ...}`.
/// Marked `#[non_exhaustive]` so new block kinds can be added without
/// breaking downstream `match` statements.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
    /// A text block.
    Text {
        /// The text content.
        text: String,
    },
    /// A tool-use request from the model.
    ToolUse {
        /// Tool call identifier (for correlating results).
        id: String,
        /// Name of the tool the model wants to invoke.
        name: String,
        /// JSON arguments to pass to the tool.
        input: serde_json::Value,
    },
}

/// The reason a model stopped generating tokens.
///
/// Serializes as a snake_case string (e.g. `"end_turn"`). Marked
/// `#[non_exhaustive]` so providers that report new stop reasons can be
/// supported without a breaking change.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// Natural end of response.
    EndTurn,
    /// Hit the `max_tokens` limit.
    MaxTokens,
    /// A stop sequence was encountered.
    StopSequence,
    /// The model wants to use a tool.
    ToolUse,
}

/// Token usage statistics for a single LLM call.
///
/// This is the canonical usage type for the entire workspace. It stores
/// token counts as `u32` (token counts are never negative). The fields
/// use the clawft naming convention (`input_tokens`, `output_tokens`),
/// but serde aliases allow deserialization from the OpenAI naming
/// convention (`prompt_tokens`, `completion_tokens`) as well.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct Usage {
    /// Tokens consumed by the prompt / input.
    ///
    /// Deserializes from either `"input_tokens"` or `"prompt_tokens"`.
    #[serde(alias = "prompt_tokens")]
    pub input_tokens: u32,

    /// Tokens generated in the response.
    ///
    /// Deserializes from either `"output_tokens"` or `"completion_tokens"`.
    #[serde(alias = "completion_tokens")]
    pub output_tokens: u32,

    /// Total tokens used (input + output).
    ///
    /// When deserializing from providers that include `total_tokens`, this
    /// field is populated directly. Otherwise it defaults to 0 and callers
    /// can use [`Usage::total`] to compute it.
    #[serde(default)]
    pub total_tokens: u32,
}

99impl Usage {
100    /// Returns the total token count.
101    ///
102    /// If `total_tokens` was populated by the provider, returns that value.
103    /// Otherwise computes `input_tokens + output_tokens`.
104    pub fn total(&self) -> u32 {
105        if self.total_tokens > 0 {
106            self.total_tokens
107        } else {
108            self.input_tokens + self.output_tokens
109        }
110    }
111}
112
/// A tool-call request extracted from a model response.
///
/// This is a convenience struct for pipeline stages that need to
/// process tool calls without dealing with the full [`ContentBlock`].
/// Field meanings mirror [`ContentBlock::ToolUse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRequest {
    /// Tool call identifier.
    pub id: String,
    /// Name of the tool.
    pub name: String,
    /// JSON arguments.
    pub input: serde_json::Value,
}

// ── Provider registry ────────────────────────────────────────────────────

/// Metadata for a single LLM provider.
///
/// Used for model-name matching, API key detection, and URL prefixing.
/// All string fields are `&'static str` because instances live in the
/// static [`PROVIDERS`] array. Empty strings / empty slices mean
/// "not applicable" for the optional-ish fields below.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ProviderSpec {
    /// Config field name (e.g. `"dashscope"`).
    pub name: &'static str,

    /// Model-name keywords for matching (lowercase).
    pub keywords: &'static [&'static str],

    /// Environment variable for the API key (e.g. `"DASHSCOPE_API_KEY"`).
    /// Empty for OAuth-based providers.
    pub env_key: &'static str,

    /// Human-readable name shown in status output.
    pub display_name: &'static str,

    /// Prefix added to model names for routing (e.g. `"deepseek"` makes
    /// `deepseek-chat` become `deepseek/deepseek-chat`).
    pub litellm_prefix: &'static str,

    /// Do not add prefix when model already starts with one of these.
    pub skip_prefixes: &'static [&'static str],

    /// Routes any model (e.g. OpenRouter, AiHubMix).
    pub is_gateway: bool,

    /// Local deployment (e.g. vLLM).
    pub is_local: bool,

    /// Uses OAuth flow instead of API key.
    pub is_oauth: bool,

    /// Fallback base URL for the provider. Empty when the client library
    /// already knows the provider's default endpoint.
    pub default_api_base: &'static str,

    /// Match `api_key` prefix for auto-detection (e.g. `"sk-or-"`).
    pub detect_by_key_prefix: &'static str,

    /// Match substring in `api_base` URL for auto-detection.
    pub detect_by_base_keyword: &'static str,

    /// Strip `"provider/"` prefix before re-prefixing.
    pub strip_model_prefix: bool,
}

177impl ProviderSpec {
178    /// Display label: `display_name` if non-empty, otherwise title-cased `name`.
179    pub fn label(&self) -> &str {
180        if self.display_name.is_empty() {
181            self.name
182        } else {
183            self.display_name
184        }
185    }
186}
187
/// The provider registry. Order equals match priority (gateways first).
///
/// All 19 providers ported from the Python `nanobot/providers/registry.py`.
pub static PROVIDERS: &[ProviderSpec] = &[
    // === Custom (user-provided OpenAI-compatible endpoint) ===
    ProviderSpec {
        name: "custom",
        keywords: &[],
        env_key: "OPENAI_API_KEY",
        display_name: "Custom",
        litellm_prefix: "openai",
        skip_prefixes: &["openai/"],
        is_gateway: true,
        is_local: false,
        is_oauth: false,
        default_api_base: "",
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        strip_model_prefix: true,
    },
    // === Gateways ===
    ProviderSpec {
        name: "openrouter",
        keywords: &["openrouter"],
        env_key: "OPENROUTER_API_KEY",
        display_name: "OpenRouter",
        litellm_prefix: "openrouter",
        skip_prefixes: &[],
        is_gateway: true,
        is_local: false,
        is_oauth: false,
        default_api_base: "https://openrouter.ai/api/v1",
        detect_by_key_prefix: "sk-or-",
        detect_by_base_keyword: "openrouter",
        strip_model_prefix: false,
    },
    ProviderSpec {
        name: "aihubmix",
        keywords: &["aihubmix"],
        env_key: "OPENAI_API_KEY",
        display_name: "AiHubMix",
        litellm_prefix: "openai",
        skip_prefixes: &[],
        is_gateway: true,
        is_local: false,
        is_oauth: false,
        default_api_base: "https://aihubmix.com/v1",
        detect_by_key_prefix: "",
        detect_by_base_keyword: "aihubmix",
        strip_model_prefix: true,
    },
    // === Standard providers ===
    ProviderSpec {
        name: "anthropic",
        keywords: &["anthropic", "claude"],
        env_key: "ANTHROPIC_API_KEY",
        display_name: "Anthropic",
        litellm_prefix: "",
        skip_prefixes: &[],
        is_gateway: false,
        is_local: false,
        is_oauth: false,
        default_api_base: "",
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        strip_model_prefix: false,
    },
    ProviderSpec {
        name: "openai",
        keywords: &["openai", "gpt"],
        env_key: "OPENAI_API_KEY",
        display_name: "OpenAI",
        litellm_prefix: "",
        skip_prefixes: &[],
        is_gateway: false,
        is_local: false,
        is_oauth: false,
        default_api_base: "",
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        strip_model_prefix: false,
    },
    ProviderSpec {
        name: "openai_codex",
        keywords: &["openai-codex", "codex"],
        env_key: "",
        display_name: "OpenAI Codex",
        litellm_prefix: "",
        skip_prefixes: &[],
        is_gateway: false,
        is_local: false,
        is_oauth: true,
        default_api_base: "https://chatgpt.com/backend-api",
        detect_by_key_prefix: "",
        detect_by_base_keyword: "codex",
        strip_model_prefix: false,
    },
    ProviderSpec {
        name: "deepseek",
        keywords: &["deepseek"],
        env_key: "DEEPSEEK_API_KEY",
        display_name: "DeepSeek",
        litellm_prefix: "deepseek",
        skip_prefixes: &["deepseek/"],
        is_gateway: false,
        is_local: false,
        is_oauth: false,
        default_api_base: "",
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        strip_model_prefix: false,
    },
    ProviderSpec {
        name: "gemini",
        keywords: &["gemini"],
        env_key: "GOOGLE_GEMINI_API_KEY",
        display_name: "Gemini",
        litellm_prefix: "gemini",
        skip_prefixes: &["gemini/"],
        is_gateway: false,
        is_local: false,
        is_oauth: false,
        default_api_base: "",
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        strip_model_prefix: false,
    },
    ProviderSpec {
        name: "zhipu",
        keywords: &["zhipu", "glm", "zai"],
        env_key: "ZAI_API_KEY",
        display_name: "Zhipu AI",
        litellm_prefix: "zai",
        skip_prefixes: &["zhipu/", "zai/", "openrouter/", "hosted_vllm/"],
        is_gateway: false,
        is_local: false,
        is_oauth: false,
        default_api_base: "",
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        strip_model_prefix: false,
    },
    ProviderSpec {
        name: "dashscope",
        keywords: &["qwen", "dashscope"],
        env_key: "DASHSCOPE_API_KEY",
        display_name: "DashScope",
        litellm_prefix: "dashscope",
        skip_prefixes: &["dashscope/", "openrouter/"],
        is_gateway: false,
        is_local: false,
        is_oauth: false,
        default_api_base: "",
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        strip_model_prefix: false,
    },
    ProviderSpec {
        name: "moonshot",
        keywords: &["moonshot", "kimi"],
        env_key: "MOONSHOT_API_KEY",
        display_name: "Moonshot",
        litellm_prefix: "moonshot",
        skip_prefixes: &["moonshot/", "openrouter/"],
        is_gateway: false,
        is_local: false,
        is_oauth: false,
        default_api_base: "https://api.moonshot.ai/v1",
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        strip_model_prefix: false,
    },
    ProviderSpec {
        name: "minimax",
        keywords: &["minimax"],
        env_key: "MINIMAX_API_KEY",
        display_name: "MiniMax",
        litellm_prefix: "minimax",
        skip_prefixes: &["minimax/", "openrouter/"],
        is_gateway: false,
        is_local: false,
        is_oauth: false,
        default_api_base: "https://api.minimax.io/v1",
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        strip_model_prefix: false,
    },
    ProviderSpec {
        name: "vllm",
        keywords: &["vllm"],
        env_key: "HOSTED_VLLM_API_KEY",
        display_name: "vLLM/Local",
        litellm_prefix: "hosted_vllm",
        skip_prefixes: &[],
        is_gateway: false,
        is_local: true,
        is_oauth: false,
        default_api_base: "",
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        strip_model_prefix: false,
    },
    ProviderSpec {
        name: "groq",
        keywords: &["groq"],
        env_key: "GROQ_API_KEY",
        display_name: "Groq",
        litellm_prefix: "groq",
        skip_prefixes: &["groq/"],
        is_gateway: false,
        is_local: false,
        is_oauth: false,
        default_api_base: "",
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        strip_model_prefix: false,
    },
    ProviderSpec {
        name: "xai",
        keywords: &["xai", "grok"],
        env_key: "XAI_API_KEY",
        display_name: "xAI",
        litellm_prefix: "xai",
        skip_prefixes: &["xai/"],
        is_gateway: false,
        is_local: false,
        is_oauth: false,
        default_api_base: "https://api.x.ai/v1",
        detect_by_key_prefix: "xai-",
        detect_by_base_keyword: "x.ai",
        strip_model_prefix: false,
    },
    // === Local / air-gapped providers ===
    ProviderSpec {
        name: "local",
        keywords: &["local"],
        env_key: "LOCAL_LLM_API_KEY",
        display_name: "Local",
        litellm_prefix: "openai",
        skip_prefixes: &["local/"],
        is_gateway: false,
        is_local: true,
        is_oauth: false,
        default_api_base: "http://localhost:11434/v1",
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        strip_model_prefix: true,
    },
    ProviderSpec {
        name: "ollama",
        keywords: &["ollama"],
        env_key: "LOCAL_LLM_API_KEY",
        display_name: "Ollama",
        litellm_prefix: "openai",
        skip_prefixes: &["ollama/", "local/"],
        is_gateway: false,
        is_local: true,
        is_oauth: false,
        default_api_base: "http://localhost:11434/v1",
        detect_by_key_prefix: "",
        detect_by_base_keyword: ":11434",
        strip_model_prefix: true,
    },
    ProviderSpec {
        name: "lmstudio",
        keywords: &["lmstudio", "lm-studio"],
        env_key: "LOCAL_LLM_API_KEY",
        display_name: "LM Studio",
        litellm_prefix: "openai",
        skip_prefixes: &["lmstudio/"],
        is_gateway: false,
        is_local: true,
        is_oauth: false,
        default_api_base: "http://localhost:1234/v1",
        detect_by_key_prefix: "",
        detect_by_base_keyword: ":1234",
        strip_model_prefix: true,
    },
    ProviderSpec {
        name: "llamacpp",
        keywords: &["llamacpp", "llama-cpp", "llama.cpp"],
        env_key: "LOCAL_LLM_API_KEY",
        display_name: "llama.cpp",
        litellm_prefix: "openai",
        skip_prefixes: &["llamacpp/"],
        is_gateway: false,
        is_local: true,
        is_oauth: false,
        default_api_base: "http://localhost:8080/v1",
        detect_by_key_prefix: "",
        detect_by_base_keyword: ":8080",
        strip_model_prefix: true,
    },
];

483/// Find a standard provider by model-name keyword (case-insensitive).
484///
485/// Skips gateways and local providers -- those are matched by
486/// API key prefix or base URL instead.
487pub fn find_by_model(model: &str) -> Option<&'static ProviderSpec> {
488    let model_lower = model.to_lowercase();
489    PROVIDERS.iter().find(|spec| {
490        !spec.is_gateway
491            && !spec.is_local
492            && spec.keywords.iter().any(|kw| model_lower.contains(kw))
493    })
494}
495
496/// Detect a gateway or local provider.
497///
498/// Priority:
499/// 1. `provider_name` -- if it maps to a gateway/local spec, use it directly.
500/// 2. `api_key` prefix -- e.g. `"sk-or-"` matches OpenRouter.
501/// 3. `api_base` keyword -- e.g. `"aihubmix"` in the URL matches AiHubMix.
502pub fn find_gateway(
503    provider_name: Option<&str>,
504    api_key: Option<&str>,
505    api_base: Option<&str>,
506) -> Option<&'static ProviderSpec> {
507    // 1. Direct match by config key
508    if let Some(name) = provider_name
509        && let Some(spec) = find_by_name(name)
510        && (spec.is_gateway || spec.is_local)
511    {
512        return Some(spec);
513    }
514
515    // 2. Auto-detect by api_key prefix / api_base keyword
516    for spec in PROVIDERS {
517        if !spec.detect_by_key_prefix.is_empty()
518            && let Some(key) = api_key
519            && key.starts_with(spec.detect_by_key_prefix)
520        {
521            return Some(spec);
522        }
523        if !spec.detect_by_base_keyword.is_empty()
524            && let Some(base) = api_base
525            && base.contains(spec.detect_by_base_keyword)
526        {
527            return Some(spec);
528        }
529    }
530
531    None
532}
533
534/// Find a provider spec by config field name (e.g. `"dashscope"`).
535pub fn find_by_name(name: &str) -> Option<&'static ProviderSpec> {
536    PROVIDERS.iter().find(|spec| spec.name == name)
537}
538
// Unit tests: registry invariants, lookup helpers, and serde round-trips
// for the response types.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn provider_count() {
        assert_eq!(PROVIDERS.len(), 19);
    }

    #[test]
    fn find_anthropic_by_model() {
        let spec = find_by_model("anthropic/claude-opus-4-5").unwrap();
        assert_eq!(spec.name, "anthropic");
    }

    #[test]
    fn find_deepseek_by_model() {
        let spec = find_by_model("deepseek-chat").unwrap();
        assert_eq!(spec.name, "deepseek");
    }

    #[test]
    fn find_by_model_skips_gateways() {
        // "openrouter" is a keyword for the openrouter gateway but
        // find_by_model should skip gateways.
        let spec = find_by_model("openrouter/some-model");
        assert!(spec.is_none());
    }

    #[test]
    fn find_gateway_by_key_prefix() {
        let spec = find_gateway(None, Some("sk-or-abc123"), None).unwrap();
        assert_eq!(spec.name, "openrouter");
    }

    #[test]
    fn find_gateway_by_base_keyword() {
        let spec = find_gateway(None, None, Some("https://aihubmix.com/v1")).unwrap();
        assert_eq!(spec.name, "aihubmix");
    }

    #[test]
    fn find_gateway_by_name() {
        let spec = find_gateway(Some("vllm"), None, None).unwrap();
        assert_eq!(spec.name, "vllm");
        assert!(spec.is_local);
    }

    #[test]
    fn find_by_name_existing() {
        let spec = find_by_name("moonshot").unwrap();
        assert_eq!(spec.display_name, "Moonshot");
        assert_eq!(spec.default_api_base, "https://api.moonshot.ai/v1");
    }

    #[test]
    fn find_by_name_missing() {
        assert!(find_by_name("nonexistent").is_none());
    }

    #[test]
    fn provider_spec_label() {
        let spec = find_by_name("anthropic").unwrap();
        assert_eq!(spec.label(), "Anthropic");

        // custom has a display_name
        let spec = find_by_name("custom").unwrap();
        assert_eq!(spec.label(), "Custom");
    }

    #[test]
    fn openai_codex_is_oauth() {
        let spec = find_by_name("openai_codex").unwrap();
        assert!(spec.is_oauth);
        assert!(spec.env_key.is_empty());
    }

    #[test]
    fn find_local_by_name() {
        let spec = find_by_name("local").unwrap();
        assert!(spec.is_local);
        assert_eq!(spec.display_name, "Local");
        assert_eq!(spec.default_api_base, "http://localhost:11434/v1");
    }

    #[test]
    fn find_ollama_by_name() {
        let spec = find_by_name("ollama").unwrap();
        assert!(spec.is_local);
        assert_eq!(spec.display_name, "Ollama");
        assert_eq!(spec.default_api_base, "http://localhost:11434/v1");
    }

    #[test]
    fn find_lmstudio_by_name() {
        let spec = find_by_name("lmstudio").unwrap();
        assert!(spec.is_local);
        assert_eq!(spec.display_name, "LM Studio");
        assert_eq!(spec.default_api_base, "http://localhost:1234/v1");
    }

    #[test]
    fn find_llamacpp_by_name() {
        let spec = find_by_name("llamacpp").unwrap();
        assert!(spec.is_local);
        assert_eq!(spec.display_name, "llama.cpp");
        assert_eq!(spec.default_api_base, "http://localhost:8080/v1");
    }

    #[test]
    fn find_gateway_detects_local_by_name() {
        let spec = find_gateway(Some("local"), None, None).unwrap();
        assert_eq!(spec.name, "local");
        assert!(spec.is_local);
    }

    #[test]
    fn find_gateway_detects_ollama_by_name() {
        let spec = find_gateway(Some("ollama"), None, None).unwrap();
        assert_eq!(spec.name, "ollama");
        assert!(spec.is_local);
    }

    #[test]
    fn find_gateway_detects_ollama_by_port() {
        // Detection keys on the ":11434" port substring in the base URL.
        let spec =
            find_gateway(None, None, Some("http://192.168.1.5:11434/v1")).unwrap();
        assert_eq!(spec.name, "ollama");
    }

    #[test]
    fn find_gateway_detects_lmstudio_by_port() {
        let spec =
            find_gateway(None, None, Some("http://localhost:1234/v1")).unwrap();
        assert_eq!(spec.name, "lmstudio");
    }

    #[test]
    fn all_local_providers_are_local() {
        for name in &["local", "ollama", "lmstudio", "llamacpp", "vllm"] {
            let spec = find_by_name(name).unwrap();
            assert!(
                spec.is_local,
                "provider {} should be marked as local",
                name
            );
        }
    }

    #[test]
    fn llm_response_serde_roundtrip() {
        let resp = LlmResponse {
            id: "resp-001".into(),
            content: vec![ContentBlock::Text {
                text: "Hello!".into(),
            }],
            stop_reason: StopReason::EndTurn,
            usage: Usage {
                input_tokens: 10,
                output_tokens: 5,
                total_tokens: 15,
            },
            metadata: HashMap::new(),
        };
        let json = serde_json::to_string(&resp).unwrap();
        let restored: LlmResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.id, "resp-001");
        assert_eq!(restored.stop_reason, StopReason::EndTurn);
        assert_eq!(restored.usage.input_tokens, 10);
        assert_eq!(restored.usage.total(), 15);
    }

    #[test]
    fn usage_total_computed() {
        let usage = Usage {
            input_tokens: 10,
            output_tokens: 5,
            total_tokens: 0,
        };
        assert_eq!(usage.total(), 15);
    }

    #[test]
    fn usage_total_from_provider() {
        let usage = Usage {
            input_tokens: 10,
            output_tokens: 5,
            total_tokens: 20, // provider may count differently
        };
        assert_eq!(usage.total(), 20);
    }

    #[test]
    fn usage_deserializes_from_openai_field_names() {
        let json = r#"{"prompt_tokens": 100, "completion_tokens": 50, "total_tokens": 150}"#;
        let usage: Usage = serde_json::from_str(json).unwrap();
        assert_eq!(usage.input_tokens, 100);
        assert_eq!(usage.output_tokens, 50);
        assert_eq!(usage.total_tokens, 150);
    }

    #[test]
    fn usage_deserializes_from_canonical_field_names() {
        let json = r#"{"input_tokens": 100, "output_tokens": 50, "total_tokens": 150}"#;
        let usage: Usage = serde_json::from_str(json).unwrap();
        assert_eq!(usage.input_tokens, 100);
        assert_eq!(usage.output_tokens, 50);
        assert_eq!(usage.total_tokens, 150);
    }

    #[test]
    fn usage_deserializes_without_total() {
        let json = r#"{"input_tokens": 100, "output_tokens": 50}"#;
        let usage: Usage = serde_json::from_str(json).unwrap();
        assert_eq!(usage.input_tokens, 100);
        assert_eq!(usage.output_tokens, 50);
        assert_eq!(usage.total_tokens, 0);
        assert_eq!(usage.total(), 150);
    }

    #[test]
    fn content_block_tool_use_serde() {
        let block = ContentBlock::ToolUse {
            id: "call-1".into(),
            name: "web_search".into(),
            input: serde_json::json!({"query": "rust"}),
        };
        let json = serde_json::to_string(&block).unwrap();
        assert!(json.contains(r#""type":"tool_use""#));
        let restored: ContentBlock = serde_json::from_str(&json).unwrap();
        match restored {
            ContentBlock::ToolUse { id, name, input } => {
                assert_eq!(id, "call-1");
                assert_eq!(name, "web_search");
                assert_eq!(input["query"], "rust");
            }
            _ => panic!("expected ToolUse"),
        }
    }

    #[test]
    fn stop_reason_serde() {
        let reasons = [
            (StopReason::EndTurn, "\"end_turn\""),
            (StopReason::MaxTokens, "\"max_tokens\""),
            (StopReason::StopSequence, "\"stop_sequence\""),
            (StopReason::ToolUse, "\"tool_use\""),
        ];
        for (reason, expected_json) in &reasons {
            let json = serde_json::to_string(reason).unwrap();
            assert_eq!(&json, expected_json);
            let restored: StopReason = serde_json::from_str(&json).unwrap();
            assert_eq!(restored, *reason);
        }
    }

    #[test]
    fn tool_call_request_serde() {
        let req = ToolCallRequest {
            id: "tc-1".into(),
            name: "exec".into(),
            input: serde_json::json!({"command": "ls"}),
        };
        let json = serde_json::to_string(&req).unwrap();
        let restored: ToolCallRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.name, "exec");
    }
}