1use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
/// A provider-agnostic LLM response: content blocks plus stop reason and
/// token usage, with any provider-specific extras preserved in `metadata`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmResponse {
    /// Provider-assigned response identifier.
    pub id: String,

    /// Ordered content blocks (text and/or tool-use requests).
    pub content: Vec<ContentBlock>,

    /// Why generation stopped.
    pub stop_reason: StopReason,

    /// Token accounting for this exchange.
    pub usage: Usage,

    /// Provider-specific extras; defaults to an empty map when the field
    /// is absent from the JSON.
    #[serde(default)]
    pub metadata: HashMap<String, serde_json::Value>,
}
32
/// One piece of model output. Serialized with an internal `"type"` tag in
/// snake_case, e.g. `{"type":"text",...}` or `{"type":"tool_use",...}`.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
    /// Plain text produced by the model.
    Text {
        text: String,
    },
    /// A request from the model to invoke a tool.
    ToolUse {
        /// Provider-assigned call id (echoed back with the tool result —
        /// NOTE(review): presumed protocol convention, confirm at call site).
        id: String,
        /// Name of the tool to invoke.
        name: String,
        /// Tool arguments as free-form JSON.
        input: serde_json::Value,
    },
}
53
/// Why the model stopped generating. Serialized in snake_case
/// (`"end_turn"`, `"max_tokens"`, `"stop_sequence"`, `"tool_use"`).
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// Natural end of the model's turn.
    EndTurn,
    /// Hit the output token limit.
    MaxTokens,
    /// Matched a configured stop sequence.
    StopSequence,
    /// Stopped to request a tool invocation.
    ToolUse,
}
68
/// Token usage accounting. Deserializes from both the canonical field names
/// (`input_tokens`/`output_tokens`) and OpenAI-style names
/// (`prompt_tokens`/`completion_tokens`) via serde aliases.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct Usage {
    /// Tokens consumed by the prompt.
    #[serde(alias = "prompt_tokens")]
    pub input_tokens: u32,

    /// Tokens produced by the completion.
    #[serde(alias = "completion_tokens")]
    pub output_tokens: u32,

    /// Provider-reported total; 0 when the provider omits it
    /// (`Usage::total` computes a fallback in that case).
    #[serde(default)]
    pub total_tokens: u32,
}
98
99impl Usage {
100 pub fn total(&self) -> u32 {
105 if self.total_tokens > 0 {
106 self.total_tokens
107 } else {
108 self.input_tokens + self.output_tokens
109 }
110 }
111}
112
/// A standalone tool invocation request (same shape as the
/// `ContentBlock::ToolUse` variant, as an independent struct).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRequest {
    /// Provider-assigned call id.
    pub id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// Tool arguments as free-form JSON.
    pub input: serde_json::Value,
}
126
/// Static description of an LLM provider, gateway, or local backend,
/// consumed by the `find_*` lookup helpers below.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ProviderSpec {
    /// Canonical internal identifier; matched exactly by `find_by_name`.
    pub name: &'static str,

    /// Substrings matched (case-insensitively) against model names by
    /// `find_by_model`.
    pub keywords: &'static [&'static str],

    /// Environment variable holding the API key; empty for OAuth providers.
    pub env_key: &'static str,

    /// Human-readable name; `label()` falls back to `name` when empty.
    pub display_name: &'static str,

    /// Provider prefix for LiteLLM-style routing (empty when none).
    /// NOTE(review): presumably prepended to model ids by the caller —
    /// confirm against usage sites.
    pub litellm_prefix: &'static str,

    /// Model-name prefixes related to this provider.
    /// NOTE(review): presumably prefixes that must not be (re)applied when
    /// building the routed model id — confirm against usage sites.
    pub skip_prefixes: &'static [&'static str],

    /// True for aggregator gateways (custom OpenAI-compatible endpoints,
    /// OpenRouter, AiHubMix).
    pub is_gateway: bool,

    /// True for locally-hosted backends (vLLM, Ollama, LM Studio, llama.cpp).
    pub is_local: bool,

    /// True when the provider authenticates via OAuth instead of an env key.
    pub is_oauth: bool,

    /// Default API base URL; empty when no default applies here.
    pub default_api_base: &'static str,

    /// API-key prefix used by `find_gateway` for detection (e.g. "sk-or-");
    /// empty disables key-based detection for this provider.
    pub detect_by_key_prefix: &'static str,

    /// Substring of the API base URL used by `find_gateway` for detection;
    /// empty disables base-URL detection for this provider.
    pub detect_by_base_keyword: &'static str,

    /// Whether a provider prefix should be stripped from model names.
    /// NOTE(review): consumed outside this file — confirm exact semantics.
    pub strip_model_prefix: bool,
}
176
177impl ProviderSpec {
178 pub fn label(&self) -> &str {
180 if self.display_name.is_empty() {
181 self.name
182 } else {
183 self.display_name
184 }
185 }
186}
187
/// Registry of all known providers. Order matters: the `find_*` helpers
/// return the FIRST matching entry, so detection precedence follows table
/// order.
pub static PROVIDERS: &[ProviderSpec] = &[
    // Generic fallback for arbitrary OpenAI-compatible endpoints.
    ProviderSpec {
        name: "custom",
        keywords: &[],
        env_key: "OPENAI_API_KEY",
        display_name: "Custom",
        litellm_prefix: "openai",
        skip_prefixes: &["openai/"],
        is_gateway: true,
        is_local: false,
        is_oauth: false,
        default_api_base: "",
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        strip_model_prefix: true,
    },
    // Gateways (aggregators in front of many models).
    ProviderSpec {
        name: "openrouter",
        keywords: &["openrouter"],
        env_key: "OPENROUTER_API_KEY",
        display_name: "OpenRouter",
        litellm_prefix: "openrouter",
        skip_prefixes: &[],
        is_gateway: true,
        is_local: false,
        is_oauth: false,
        default_api_base: "https://openrouter.ai/api/v1",
        detect_by_key_prefix: "sk-or-",
        detect_by_base_keyword: "openrouter",
        strip_model_prefix: false,
    },
    ProviderSpec {
        name: "aihubmix",
        keywords: &["aihubmix"],
        env_key: "OPENAI_API_KEY",
        display_name: "AiHubMix",
        litellm_prefix: "openai",
        skip_prefixes: &[],
        is_gateway: true,
        is_local: false,
        is_oauth: false,
        default_api_base: "https://aihubmix.com/v1",
        detect_by_key_prefix: "",
        detect_by_base_keyword: "aihubmix",
        strip_model_prefix: true,
    },
    // First-party hosted providers.
    ProviderSpec {
        name: "anthropic",
        keywords: &["anthropic", "claude"],
        env_key: "ANTHROPIC_API_KEY",
        display_name: "Anthropic",
        litellm_prefix: "",
        skip_prefixes: &[],
        is_gateway: false,
        is_local: false,
        is_oauth: false,
        default_api_base: "",
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        strip_model_prefix: false,
    },
    ProviderSpec {
        name: "openai",
        keywords: &["openai", "gpt"],
        env_key: "OPENAI_API_KEY",
        display_name: "OpenAI",
        litellm_prefix: "",
        skip_prefixes: &[],
        is_gateway: false,
        is_local: false,
        is_oauth: false,
        default_api_base: "",
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        strip_model_prefix: false,
    },
    // OAuth-based (no env API key).
    ProviderSpec {
        name: "openai_codex",
        keywords: &["openai-codex", "codex"],
        env_key: "",
        display_name: "OpenAI Codex",
        litellm_prefix: "",
        skip_prefixes: &[],
        is_gateway: false,
        is_local: false,
        is_oauth: true,
        default_api_base: "https://chatgpt.com/backend-api",
        detect_by_key_prefix: "",
        detect_by_base_keyword: "codex",
        strip_model_prefix: false,
    },
    ProviderSpec {
        name: "deepseek",
        keywords: &["deepseek"],
        env_key: "DEEPSEEK_API_KEY",
        display_name: "DeepSeek",
        litellm_prefix: "deepseek",
        skip_prefixes: &["deepseek/"],
        is_gateway: false,
        is_local: false,
        is_oauth: false,
        default_api_base: "",
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        strip_model_prefix: false,
    },
    ProviderSpec {
        name: "gemini",
        keywords: &["gemini"],
        env_key: "GOOGLE_GEMINI_API_KEY",
        display_name: "Gemini",
        litellm_prefix: "gemini",
        skip_prefixes: &["gemini/"],
        is_gateway: false,
        is_local: false,
        is_oauth: false,
        default_api_base: "",
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        strip_model_prefix: false,
    },
    ProviderSpec {
        name: "zhipu",
        keywords: &["zhipu", "glm", "zai"],
        env_key: "ZAI_API_KEY",
        display_name: "Zhipu AI",
        litellm_prefix: "zai",
        skip_prefixes: &["zhipu/", "zai/", "openrouter/", "hosted_vllm/"],
        is_gateway: false,
        is_local: false,
        is_oauth: false,
        default_api_base: "",
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        strip_model_prefix: false,
    },
    ProviderSpec {
        name: "dashscope",
        keywords: &["qwen", "dashscope"],
        env_key: "DASHSCOPE_API_KEY",
        display_name: "DashScope",
        litellm_prefix: "dashscope",
        skip_prefixes: &["dashscope/", "openrouter/"],
        is_gateway: false,
        is_local: false,
        is_oauth: false,
        default_api_base: "",
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        strip_model_prefix: false,
    },
    ProviderSpec {
        name: "moonshot",
        keywords: &["moonshot", "kimi"],
        env_key: "MOONSHOT_API_KEY",
        display_name: "Moonshot",
        litellm_prefix: "moonshot",
        skip_prefixes: &["moonshot/", "openrouter/"],
        is_gateway: false,
        is_local: false,
        is_oauth: false,
        default_api_base: "https://api.moonshot.ai/v1",
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        strip_model_prefix: false,
    },
    ProviderSpec {
        name: "minimax",
        keywords: &["minimax"],
        env_key: "MINIMAX_API_KEY",
        display_name: "MiniMax",
        litellm_prefix: "minimax",
        skip_prefixes: &["minimax/", "openrouter/"],
        is_gateway: false,
        is_local: false,
        is_oauth: false,
        default_api_base: "https://api.minimax.io/v1",
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        strip_model_prefix: false,
    },
    // Local / self-hosted backends.
    ProviderSpec {
        name: "vllm",
        keywords: &["vllm"],
        env_key: "HOSTED_VLLM_API_KEY",
        display_name: "vLLM/Local",
        litellm_prefix: "hosted_vllm",
        skip_prefixes: &[],
        is_gateway: false,
        is_local: true,
        is_oauth: false,
        default_api_base: "",
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        strip_model_prefix: false,
    },
    ProviderSpec {
        name: "groq",
        keywords: &["groq"],
        env_key: "GROQ_API_KEY",
        display_name: "Groq",
        litellm_prefix: "groq",
        skip_prefixes: &["groq/"],
        is_gateway: false,
        is_local: false,
        is_oauth: false,
        default_api_base: "",
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        strip_model_prefix: false,
    },
    ProviderSpec {
        name: "xai",
        keywords: &["xai", "grok"],
        env_key: "XAI_API_KEY",
        display_name: "xAI",
        litellm_prefix: "xai",
        skip_prefixes: &["xai/"],
        is_gateway: false,
        is_local: false,
        is_oauth: false,
        default_api_base: "https://api.x.ai/v1",
        detect_by_key_prefix: "xai-",
        detect_by_base_keyword: "x.ai",
        strip_model_prefix: false,
    },
    // "local" and "ollama" share the default Ollama port (11434); only
    // "ollama" carries the port-based detection keyword.
    ProviderSpec {
        name: "local",
        keywords: &["local"],
        env_key: "LOCAL_LLM_API_KEY",
        display_name: "Local",
        litellm_prefix: "openai",
        skip_prefixes: &["local/"],
        is_gateway: false,
        is_local: true,
        is_oauth: false,
        default_api_base: "http://localhost:11434/v1",
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        strip_model_prefix: true,
    },
    ProviderSpec {
        name: "ollama",
        keywords: &["ollama"],
        env_key: "LOCAL_LLM_API_KEY",
        display_name: "Ollama",
        litellm_prefix: "openai",
        skip_prefixes: &["ollama/", "local/"],
        is_gateway: false,
        is_local: true,
        is_oauth: false,
        default_api_base: "http://localhost:11434/v1",
        detect_by_key_prefix: "",
        detect_by_base_keyword: ":11434",
        strip_model_prefix: true,
    },
    ProviderSpec {
        name: "lmstudio",
        keywords: &["lmstudio", "lm-studio"],
        env_key: "LOCAL_LLM_API_KEY",
        display_name: "LM Studio",
        litellm_prefix: "openai",
        skip_prefixes: &["lmstudio/"],
        is_gateway: false,
        is_local: true,
        is_oauth: false,
        default_api_base: "http://localhost:1234/v1",
        detect_by_key_prefix: "",
        detect_by_base_keyword: ":1234",
        strip_model_prefix: true,
    },
    ProviderSpec {
        name: "llamacpp",
        keywords: &["llamacpp", "llama-cpp", "llama.cpp"],
        env_key: "LOCAL_LLM_API_KEY",
        display_name: "llama.cpp",
        litellm_prefix: "openai",
        skip_prefixes: &["llamacpp/"],
        is_gateway: false,
        is_local: true,
        is_oauth: false,
        default_api_base: "http://localhost:8080/v1",
        detect_by_key_prefix: "",
        detect_by_base_keyword: ":8080",
        strip_model_prefix: true,
    },
];
482
483pub fn find_by_model(model: &str) -> Option<&'static ProviderSpec> {
488 let model_lower = model.to_lowercase();
489 PROVIDERS.iter().find(|spec| {
490 !spec.is_gateway
491 && !spec.is_local
492 && spec.keywords.iter().any(|kw| model_lower.contains(kw))
493 })
494}
495
496pub fn find_gateway(
503 provider_name: Option<&str>,
504 api_key: Option<&str>,
505 api_base: Option<&str>,
506) -> Option<&'static ProviderSpec> {
507 if let Some(name) = provider_name
509 && let Some(spec) = find_by_name(name)
510 && (spec.is_gateway || spec.is_local)
511 {
512 return Some(spec);
513 }
514
515 for spec in PROVIDERS {
517 if !spec.detect_by_key_prefix.is_empty()
518 && let Some(key) = api_key
519 && key.starts_with(spec.detect_by_key_prefix)
520 {
521 return Some(spec);
522 }
523 if !spec.detect_by_base_keyword.is_empty()
524 && let Some(base) = api_base
525 && base.contains(spec.detect_by_base_keyword)
526 {
527 return Some(spec);
528 }
529 }
530
531 None
532}
533
534pub fn find_by_name(name: &str) -> Option<&'static ProviderSpec> {
536 PROVIDERS.iter().find(|spec| spec.name == name)
537}
538
#[cfg(test)]
mod tests {
    use super::*;

    // --- provider table and lookups ---

    #[test]
    fn provider_count() {
        assert_eq!(PROVIDERS.len(), 19);
    }

    #[test]
    fn find_anthropic_by_model() {
        let found = find_by_model("anthropic/claude-opus-4-5").unwrap();
        assert_eq!(found.name, "anthropic");
    }

    #[test]
    fn find_deepseek_by_model() {
        let found = find_by_model("deepseek-chat").unwrap();
        assert_eq!(found.name, "deepseek");
    }

    #[test]
    fn find_by_model_skips_gateways() {
        assert!(find_by_model("openrouter/some-model").is_none());
    }

    #[test]
    fn find_gateway_by_key_prefix() {
        let found = find_gateway(None, Some("sk-or-abc123"), None).unwrap();
        assert_eq!(found.name, "openrouter");
    }

    #[test]
    fn find_gateway_by_base_keyword() {
        let found = find_gateway(None, None, Some("https://aihubmix.com/v1")).unwrap();
        assert_eq!(found.name, "aihubmix");
    }

    #[test]
    fn find_gateway_by_name() {
        let found = find_gateway(Some("vllm"), None, None).unwrap();
        assert_eq!(found.name, "vllm");
        assert!(found.is_local);
    }

    #[test]
    fn find_by_name_existing() {
        let found = find_by_name("moonshot").unwrap();
        assert_eq!(found.display_name, "Moonshot");
        assert_eq!(found.default_api_base, "https://api.moonshot.ai/v1");
    }

    #[test]
    fn find_by_name_missing() {
        assert!(find_by_name("nonexistent").is_none());
    }

    #[test]
    fn provider_spec_label() {
        assert_eq!(find_by_name("anthropic").unwrap().label(), "Anthropic");
        assert_eq!(find_by_name("custom").unwrap().label(), "Custom");
    }

    #[test]
    fn openai_codex_is_oauth() {
        let found = find_by_name("openai_codex").unwrap();
        assert!(found.is_oauth);
        assert!(found.env_key.is_empty());
    }

    // --- local backends ---

    #[test]
    fn find_local_by_name() {
        let found = find_by_name("local").unwrap();
        assert!(found.is_local);
        assert_eq!(found.display_name, "Local");
        assert_eq!(found.default_api_base, "http://localhost:11434/v1");
    }

    #[test]
    fn find_ollama_by_name() {
        let found = find_by_name("ollama").unwrap();
        assert!(found.is_local);
        assert_eq!(found.display_name, "Ollama");
        assert_eq!(found.default_api_base, "http://localhost:11434/v1");
    }

    #[test]
    fn find_lmstudio_by_name() {
        let found = find_by_name("lmstudio").unwrap();
        assert!(found.is_local);
        assert_eq!(found.display_name, "LM Studio");
        assert_eq!(found.default_api_base, "http://localhost:1234/v1");
    }

    #[test]
    fn find_llamacpp_by_name() {
        let found = find_by_name("llamacpp").unwrap();
        assert!(found.is_local);
        assert_eq!(found.display_name, "llama.cpp");
        assert_eq!(found.default_api_base, "http://localhost:8080/v1");
    }

    #[test]
    fn find_gateway_detects_local_by_name() {
        let found = find_gateway(Some("local"), None, None).unwrap();
        assert_eq!(found.name, "local");
        assert!(found.is_local);
    }

    #[test]
    fn find_gateway_detects_ollama_by_name() {
        let found = find_gateway(Some("ollama"), None, None).unwrap();
        assert_eq!(found.name, "ollama");
        assert!(found.is_local);
    }

    #[test]
    fn find_gateway_detects_ollama_by_port() {
        let found = find_gateway(None, None, Some("http://192.168.1.5:11434/v1")).unwrap();
        assert_eq!(found.name, "ollama");
    }

    #[test]
    fn find_gateway_detects_lmstudio_by_port() {
        let found = find_gateway(None, None, Some("http://localhost:1234/v1")).unwrap();
        assert_eq!(found.name, "lmstudio");
    }

    #[test]
    fn all_local_providers_are_local() {
        for name in &["local", "ollama", "lmstudio", "llamacpp", "vllm"] {
            let found = find_by_name(name).unwrap();
            assert!(found.is_local, "provider {} should be marked as local", name);
        }
    }

    // --- serde round-trips and Usage arithmetic ---

    #[test]
    fn llm_response_serde_roundtrip() {
        let original = LlmResponse {
            id: String::from("resp-001"),
            content: vec![ContentBlock::Text {
                text: String::from("Hello!"),
            }],
            stop_reason: StopReason::EndTurn,
            usage: Usage {
                input_tokens: 10,
                output_tokens: 5,
                total_tokens: 15,
            },
            metadata: HashMap::new(),
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: LlmResponse = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.id, "resp-001");
        assert_eq!(decoded.stop_reason, StopReason::EndTurn);
        assert_eq!(decoded.usage.input_tokens, 10);
        assert_eq!(decoded.usage.total(), 15);
    }

    #[test]
    fn usage_total_computed() {
        let sample = Usage {
            input_tokens: 10,
            output_tokens: 5,
            total_tokens: 0,
        };
        assert_eq!(sample.total(), 15);
    }

    #[test]
    fn usage_total_from_provider() {
        let sample = Usage {
            input_tokens: 10,
            output_tokens: 5,
            total_tokens: 20,
        };
        assert_eq!(sample.total(), 20);
    }

    #[test]
    fn usage_deserializes_from_openai_field_names() {
        let json = r#"{"prompt_tokens": 100, "completion_tokens": 50, "total_tokens": 150}"#;
        let decoded: Usage = serde_json::from_str(json).unwrap();
        assert_eq!(decoded.input_tokens, 100);
        assert_eq!(decoded.output_tokens, 50);
        assert_eq!(decoded.total_tokens, 150);
    }

    #[test]
    fn usage_deserializes_from_canonical_field_names() {
        let json = r#"{"input_tokens": 100, "output_tokens": 50, "total_tokens": 150}"#;
        let decoded: Usage = serde_json::from_str(json).unwrap();
        assert_eq!(decoded.input_tokens, 100);
        assert_eq!(decoded.output_tokens, 50);
        assert_eq!(decoded.total_tokens, 150);
    }

    #[test]
    fn usage_deserializes_without_total() {
        let json = r#"{"input_tokens": 100, "output_tokens": 50}"#;
        let decoded: Usage = serde_json::from_str(json).unwrap();
        assert_eq!(decoded.input_tokens, 100);
        assert_eq!(decoded.output_tokens, 50);
        assert_eq!(decoded.total_tokens, 0);
        assert_eq!(decoded.total(), 150);
    }

    #[test]
    fn content_block_tool_use_serde() {
        let original = ContentBlock::ToolUse {
            id: String::from("call-1"),
            name: String::from("web_search"),
            input: serde_json::json!({"query": "rust"}),
        };
        let encoded = serde_json::to_string(&original).unwrap();
        assert!(encoded.contains(r#""type":"tool_use""#));
        match serde_json::from_str::<ContentBlock>(&encoded).unwrap() {
            ContentBlock::ToolUse { id, name, input } => {
                assert_eq!(id, "call-1");
                assert_eq!(name, "web_search");
                assert_eq!(input["query"], "rust");
            }
            _ => panic!("expected ToolUse"),
        }
    }

    #[test]
    fn stop_reason_serde() {
        let cases = [
            (StopReason::EndTurn, "\"end_turn\""),
            (StopReason::MaxTokens, "\"max_tokens\""),
            (StopReason::StopSequence, "\"stop_sequence\""),
            (StopReason::ToolUse, "\"tool_use\""),
        ];
        for (reason, expected_json) in cases {
            let encoded = serde_json::to_string(&reason).unwrap();
            assert_eq!(encoded, expected_json);
            let decoded: StopReason = serde_json::from_str(&encoded).unwrap();
            assert_eq!(decoded, reason);
        }
    }

    #[test]
    fn tool_call_request_serde() {
        let original = ToolCallRequest {
            id: String::from("tc-1"),
            name: String::from("exec"),
            input: serde_json::json!({"command": "ls"}),
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ToolCallRequest = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.name, "exec");
    }
}