1use crate::config::OxiosConfig;
10use crate::credential::CredentialStore;
11use chrono::{DateTime, Utc};
12use parking_lot::RwLock;
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::path::PathBuf;
16use std::sync::Arc;
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
22#[serde(rename_all = "camelCase")]
23pub struct RoutingConfigSnapshot {
24 pub routing_enabled: bool,
26 pub prefer_cost_efficient: bool,
28 pub fallback_models: Vec<String>,
30 pub excluded_models: Vec<String>,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
36#[serde(rename_all = "camelCase")]
37pub struct RoutingStatsSnapshot {
38 pub model_calls: HashMap<String, u64>,
40 pub model_cost: HashMap<String, f64>,
42 pub total_requests: u64,
44 pub total_cost: f64,
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
50#[serde(rename_all = "camelCase")]
51pub struct FallbackEvent {
52 pub timestamp: DateTime<Utc>,
54 pub from_model: String,
56 pub to_model: String,
58 pub reason: String,
60 pub success: bool,
62}
63
64#[derive(Debug, Deserialize)]
66#[serde(rename_all = "camelCase")]
67pub struct RoutingUpdate {
68 pub routing_enabled: Option<bool>,
69 pub prefer_cost_efficient: Option<bool>,
70 pub fallback_models: Option<Vec<String>>,
71 pub excluded_models: Option<Vec<String>>,
72}
73
74pub struct RoutingStats {
79 calls: RwLock<HashMap<String, u64>>,
80 costs: RwLock<HashMap<String, f64>>,
81 fallbacks: RwLock<Vec<FallbackEvent>>,
83}
84
85impl Default for RoutingStats {
86 fn default() -> Self {
87 Self {
88 calls: RwLock::new(HashMap::new()),
89 costs: RwLock::new(HashMap::new()),
90 fallbacks: RwLock::new(Vec::new()),
91 }
92 }
93}
94
95impl RoutingStats {
96 pub fn new() -> Self {
98 Self::default()
99 }
100
101 pub fn record_model_usage(&self, model_id: &str, cost_usd: f64) {
103 let mut calls = self.calls.write();
104 *calls.entry(model_id.to_string()).or_insert(0) += 1;
105 if cost_usd > 0.0 {
106 let mut costs = self.costs.write();
107 *costs.entry(model_id.to_string()).or_insert(0.0) += cost_usd;
108 }
109 }
110
111 pub fn record_fallback(&self, event: FallbackEvent) {
113 let mut fb = self.fallbacks.write();
114 fb.push(event);
115 let keep = fb.len().saturating_sub(200);
116 if keep > 0 {
117 fb.drain(0..keep);
118 }
119 }
120
121 pub fn snapshot(&self) -> RoutingStatsSnapshot {
123 let calls = self.calls.read();
124 let costs = self.costs.read();
125 let total_requests: u64 = calls.values().sum();
126 let total_cost: f64 = costs.values().sum();
127 RoutingStatsSnapshot {
128 model_calls: calls.clone(),
129 model_cost: costs.clone(),
130 total_requests,
131 total_cost,
132 }
133 }
134
135 pub fn fallback_history(&self, limit: usize) -> Vec<FallbackEvent> {
137 let fb = self.fallbacks.read();
138 fb.iter().rev().take(limit).cloned().collect()
139 }
140}
141
142pub fn estimate_cost(model_id: &str, input_tokens: u64, output_tokens: u64) -> f64 {
147 let entries = oxi_sdk::get_provider_models(model_id.split('/').next().unwrap_or(model_id));
148 let entry = entries
149 .iter()
150 .find(|e| format!("{}/{}", e.provider, e.id) == model_id);
151 match entry {
152 Some(e) => {
153 (e.cost_input * input_tokens as f64 / 1_000_000.0)
154 + (e.cost_output * output_tokens as f64 / 1_000_000.0)
155 }
156 None => {
157 (0.003 * input_tokens as f64 / 1_000_000.0)
159 + (0.015 * output_tokens as f64 / 1_000_000.0)
160 }
161 }
162}
163
164#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
168#[serde(rename_all = "lowercase")]
169pub enum ProviderCategory {
170 Major,
172 Open,
174 Regional,
176 Local,
178}
179
180#[derive(Debug, Clone, Copy)]
193struct ProviderMeta {
194 id: &'static str,
196 display_name: &'static str,
198 category: ProviderCategory,
200 hidden: bool,
204 description: &'static str,
206 env_key: &'static str,
210 aliases: &'static [&'static str],
214}
215
216const PROVIDER_META: &[ProviderMeta] = &[
222 ProviderMeta {
224 id: "anthropic",
225 display_name: "Anthropic",
226 category: ProviderCategory::Major,
227 hidden: false,
228 description: "Claude models with extended thinking",
229 env_key: "ANTHROPIC_API_KEY",
230 aliases: &["anthropic"],
231 },
232 ProviderMeta {
233 id: "openai",
234 display_name: "OpenAI",
235 category: ProviderCategory::Major,
236 hidden: false,
237 description: "GPT, o-series, and Codex models",
238 env_key: "OPENAI_API_KEY",
239 aliases: &["openai"],
240 },
241 ProviderMeta {
242 id: "google",
243 display_name: "Google Gemini",
244 category: ProviderCategory::Major,
245 hidden: false,
246 description: "Gemini models with thinking and tool use",
247 env_key: "GOOGLE_API_KEY",
248 aliases: &["google"],
249 },
250 ProviderMeta {
252 id: "groq",
253 display_name: "Groq",
254 category: ProviderCategory::Open,
255 hidden: false,
256 description: "Fast Llama, Mixtral, and Gemma inference",
257 env_key: "GROQ_API_KEY",
258 aliases: &["groq"],
259 },
260 ProviderMeta {
261 id: "openrouter",
262 display_name: "OpenRouter",
263 category: ProviderCategory::Open,
264 hidden: false,
265 description: "Unified gateway to 200+ models",
266 env_key: "OPENROUTER_API_KEY",
267 aliases: &["openrouter"],
268 },
269 ProviderMeta {
270 id: "deepseek",
271 display_name: "DeepSeek",
272 category: ProviderCategory::Open,
273 hidden: false,
274 description: "DeepSeek-V3 and DeepSeek-R1",
275 env_key: "DEEPSEEK_API_KEY",
276 aliases: &["deepseek"],
277 },
278 ProviderMeta {
279 id: "mistral",
280 display_name: "Mistral",
281 category: ProviderCategory::Open,
282 hidden: false,
283 description: "Mistral and Codestral models",
284 env_key: "MISTRAL_API_KEY",
285 aliases: &["mistral"],
286 },
287 ProviderMeta {
288 id: "xai",
289 display_name: "xAI (Grok)",
290 category: ProviderCategory::Open,
291 hidden: false,
292 description: "Grok models from xAI",
293 env_key: "XAI_API_KEY",
294 aliases: &["xai", "grok"],
295 },
296 ProviderMeta {
297 id: "cerebras",
298 display_name: "Cerebras",
299 category: ProviderCategory::Open,
300 hidden: false,
301 description: "Ultra-fast open model inference",
302 env_key: "CEREBRAS_API_KEY",
303 aliases: &["cerebras"],
304 },
305 ProviderMeta {
306 id: "fireworks",
307 display_name: "Fireworks",
308 category: ProviderCategory::Open,
309 hidden: false,
310 description: "Fast open-source model serving",
311 env_key: "FIREWORKS_API_KEY",
312 aliases: &["fireworks"],
313 },
314 ProviderMeta {
315 id: "github-copilot",
316 display_name: "GitHub Copilot",
317 category: ProviderCategory::Open,
318 hidden: false,
319 description: "GitHub Copilot models (GPT-4, Claude)",
320 env_key: "GITHUB_COPILOT_TOKEN",
321 aliases: &["github-copilot", "copilot"],
322 },
323 ProviderMeta {
324 id: "huggingface",
325 display_name: "Hugging Face",
326 category: ProviderCategory::Open,
327 hidden: false,
328 description: "Open model inference hub",
329 env_key: "HUGGINGFACE_API_KEY",
330 aliases: &["huggingface", "hf"],
331 },
332 ProviderMeta {
333 id: "together",
334 display_name: "Together AI",
335 category: ProviderCategory::Open,
336 hidden: false,
337 description: "Open-source model hosting (Llama, Mixtral, ...)",
338 env_key: "TOGETHER_API_KEY",
339 aliases: &["together", "togetherai"],
340 },
341 ProviderMeta {
342 id: "opencode",
343 display_name: "OpenCode",
344 category: ProviderCategory::Open,
345 hidden: false,
346 description: "OpenCode coding agent gateway",
347 env_key: "",
348 aliases: &["opencode"],
349 },
350 ProviderMeta {
351 id: "perplexity",
352 display_name: "Perplexity",
353 category: ProviderCategory::Open,
354 hidden: false,
355 description: "Search-augmented answer models",
356 env_key: "PERPLEXITY_API_KEY",
357 aliases: &["perplexity"],
358 },
359 ProviderMeta {
360 id: "cohere",
361 display_name: "Cohere",
362 category: ProviderCategory::Open,
363 hidden: false,
364 description: "Cohere Command and Embed models",
365 env_key: "COHERE_API_KEY",
366 aliases: &["cohere"],
367 },
368 ProviderMeta {
370 id: "minimax",
371 display_name: "MiniMax",
372 category: ProviderCategory::Regional,
373 hidden: false,
374 description: "MiniMax-M2.7, abab models",
375 env_key: "MINIMAX_API_KEY",
376 aliases: &["minimax"],
377 },
378 ProviderMeta {
379 id: "moonshotai",
380 display_name: "Moonshot AI (Kimi)",
381 category: ProviderCategory::Regional,
382 hidden: false,
383 description: "Kimi models from Moonshot AI",
384 env_key: "MOONSHOT_API_KEY",
385 aliases: &["moonshotai", "moonshot", "kimi"],
386 },
387 ProviderMeta {
388 id: "kimi-coding",
389 display_name: "Kimi Coding",
390 category: ProviderCategory::Regional,
391 hidden: false,
392 description: "Kimi Coding Plan — optimized for coding",
393 env_key: "KIMI_CODING_API_KEY",
394 aliases: &["kimi-coding"],
395 },
396 ProviderMeta {
397 id: "zai",
398 display_name: "Z.AI (GLM)",
399 category: ProviderCategory::Regional,
400 hidden: false,
401 description: "Z.AI GLM models (coding plan)",
402 env_key: "ZAI_API_KEY",
403 aliases: &["zai"],
404 },
405 ProviderMeta {
411 id: "amazon-bedrock",
412 display_name: "Amazon Bedrock",
413 category: ProviderCategory::Open,
414 hidden: true,
415 description: "Multi-model via AWS Bedrock ConverseStream",
416 env_key: "AWS_ACCESS_KEY_ID",
417 aliases: &["amazon-bedrock", "aws-bedrock", "bedrock"],
418 },
419 ProviderMeta {
420 id: "azure-openai-responses",
421 display_name: "Azure OpenAI (Responses)",
422 category: ProviderCategory::Open,
423 hidden: true,
424 description: "OpenAI models via Azure Cognitive Services",
425 env_key: "AZURE_OPENAI_API_KEY",
426 aliases: &["azure-openai-responses", "azure"],
427 },
428 ProviderMeta {
429 id: "cloudflare-ai-gateway",
430 display_name: "Cloudflare AI Gateway",
431 category: ProviderCategory::Open,
432 hidden: true,
433 description: "Serverless AI via Cloudflare AI Gateway",
434 env_key: "CLOUDFLARE_API_TOKEN",
435 aliases: &["cloudflare-ai-gateway", "cf-ai-gateway"],
436 },
437 ProviderMeta {
438 id: "cloudflare-workers-ai",
439 display_name: "Cloudflare Workers AI",
440 category: ProviderCategory::Open,
441 hidden: true,
442 description: "Serverless AI via Cloudflare Workers",
443 env_key: "CLOUDFLARE_API_KEY",
444 aliases: &["cloudflare-workers-ai", "cloudflare", "workers-ai"],
445 },
446 ProviderMeta {
447 id: "google-vertex",
448 display_name: "Google Vertex AI",
449 category: ProviderCategory::Open,
450 hidden: true,
451 description: "Gemini via Google Cloud Vertex AI",
452 env_key: "GOOGLE_APPLICATION_CREDENTIALS",
453 aliases: &["google-vertex", "vertex"],
454 },
455 ProviderMeta {
456 id: "minimax-cn",
457 display_name: "MiniMax (China)",
458 category: ProviderCategory::Regional,
459 hidden: true,
460 description: "MiniMax China region endpoint",
461 env_key: "MINIMAX_CN_API_KEY",
462 aliases: &["minimax-cn"],
463 },
464 ProviderMeta {
465 id: "moonshotai-cn",
466 display_name: "Moonshot AI (China)",
467 category: ProviderCategory::Regional,
468 hidden: true,
469 description: "Kimi models — China region endpoint",
470 env_key: "MOONSHOT_CN_API_KEY",
471 aliases: &["moonshotai-cn", "moonshot-cn"],
472 },
473 ProviderMeta {
474 id: "openai-codex",
475 display_name: "OpenAI Codex",
476 category: ProviderCategory::Open,
477 hidden: true,
478 description: "OpenAI Codex coding agent (Responses API)",
479 env_key: "OPENAI_API_KEY",
480 aliases: &["openai-codex"],
481 },
482 ProviderMeta {
483 id: "opencode-go",
484 display_name: "OpenCode Go",
485 category: ProviderCategory::Open,
486 hidden: true,
487 description: "OpenCode Go Gateway",
488 env_key: "OPENCODE_GO_API_KEY",
489 aliases: &["opencode-go"],
490 },
491 ProviderMeta {
492 id: "vercel-ai-gateway",
493 display_name: "Vercel AI Gateway",
494 category: ProviderCategory::Open,
495 hidden: true,
496 description: "Vercel AI Gateway",
497 env_key: "VERCEL_API_KEY",
498 aliases: &["vercel-ai-gateway", "vercel"],
499 },
500 ProviderMeta {
501 id: "xiaomi",
502 display_name: "Xiaomi MiMo",
503 category: ProviderCategory::Regional,
504 hidden: true,
505 description: "Xiaomi MiMo models",
506 env_key: "XIAOMI_API_KEY",
507 aliases: &["xiaomi"],
508 },
509];
510
511fn provider_meta(id: &str) -> Option<&'static ProviderMeta> {
513 PROVIDER_META
514 .iter()
515 .find(|m| m.id == id || m.aliases.contains(&id))
516}
517
518fn provider_category(id: &str) -> ProviderCategory {
519 provider_meta(id)
520 .map(|m| m.category)
521 .unwrap_or(ProviderCategory::Open)
522}
523
524fn provider_display_name(id: &str) -> String {
530 provider_meta(id)
531 .map(|m| m.display_name.to_string())
532 .unwrap_or_else(|| fallback_display_name(id))
533}
534
/// Derives a display name from a raw provider id.
///
/// The id is split on `-` and `_`, empty segments are dropped, the first
/// character of every remaining segment is upper-cased, and the segments
/// are joined with single spaces (`"some-new_provider"` → `"Some New Provider"`).
/// An empty id yields an empty string.
fn fallback_display_name(id: &str) -> String {
    let mut words: Vec<String> = Vec::new();
    for segment in id.split(|c: char| c == '-' || c == '_') {
        let mut rest = segment.chars();
        // `None` here means the segment is empty (e.g. between "--"); skip it.
        if let Some(first) = rest.next() {
            let mut word: String = first.to_uppercase().collect();
            word.push_str(rest.as_str());
            words.push(word);
        }
    }
    words.join(" ")
}
553
554#[derive(Debug, Clone, Serialize, Deserialize)]
556#[serde(rename_all = "camelCase")]
557pub struct ProviderInfo {
558 pub id: String,
560 pub name: String,
562 pub category: ProviderCategory,
564 pub model_count: usize,
566 pub has_key: bool,
568 #[serde(default)]
571 pub description: String,
572 #[serde(default)]
576 pub env_key: String,
577}
578
579#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
581#[serde(rename_all = "lowercase")]
582pub enum InputModality {
583 Text,
585 Image,
587}
588
589#[derive(Debug, Clone, Serialize, Deserialize)]
591#[serde(rename_all = "camelCase")]
592pub struct ModelInfo {
593 pub id: String,
595 pub name: String,
597 pub api: String,
599 pub provider: String,
601 pub reasoning: bool,
603 pub input: Vec<InputModality>,
605 pub context_window: u32,
607 pub max_tokens: u32,
609 pub cost_input: f64,
611 pub cost_output: f64,
613 pub cost_cache_read: f64,
615 pub cost_cache_write: f64,
617}
618
619impl From<&oxi_sdk::ModelEntry> for ModelInfo {
620 fn from(entry: &oxi_sdk::ModelEntry) -> Self {
621 Self {
622 id: format!("{}/{}", entry.provider, entry.id),
623 name: entry.name.to_string(),
624 api: entry.api.to_string(),
625 provider: entry.provider.to_string(),
626 reasoning: entry.reasoning,
627 input: entry
628 .input
629 .iter()
630 .map(|m| match m {
631 oxi_sdk::InputModality::Text => InputModality::Text,
632 oxi_sdk::InputModality::Image => InputModality::Image,
633 _ => InputModality::Text,
634 })
635 .collect(),
636 context_window: entry.context_window,
637 max_tokens: entry.max_tokens,
638 cost_input: entry.cost_input,
639 cost_output: entry.cost_output,
640 cost_cache_read: entry.cost_cache_read,
641 cost_cache_write: entry.cost_cache_write,
642 }
643 }
644}
645
646#[derive(Debug, Clone, Serialize, Deserialize)]
648pub struct EngineConfigResponse {
649 pub default_model: String,
651 pub api_key_set: bool,
653 pub api_key_source: Option<String>,
655 pub provider: Option<String>,
657 pub routing: RoutingConfigSnapshot,
659}
660
661#[derive(Debug, Clone, Serialize, Deserialize)]
663pub struct ValidateKeyResult {
664 pub valid: bool,
666 pub provider: String,
668 pub message: Option<String>,
670}
671
672pub struct EngineApi {
684 config: Arc<RwLock<OxiosConfig>>,
685 config_path: PathBuf,
686 routing_stats: Arc<RoutingStats>,
687 engine_handle: Arc<crate::engine::EngineHandle>,
689}
690
691impl EngineApi {
692 pub fn new(
699 config: Arc<RwLock<OxiosConfig>>,
700 config_path: PathBuf,
701 routing_stats: Arc<RoutingStats>,
702 engine_handle: Arc<crate::engine::EngineHandle>,
703 ) -> Self {
704 Self {
705 config,
706 config_path,
707 routing_stats,
708 engine_handle,
709 }
710 }
711
712 pub fn routing_stats(&self) -> Arc<RoutingStats> {
714 Arc::clone(&self.routing_stats)
715 }
716
717 pub fn engine_handle(&self) -> &Arc<crate::engine::EngineHandle> {
719 &self.engine_handle
720 }
721
722 pub fn providers(&self) -> Vec<ProviderInfo> {
734 let all = oxi_sdk::get_providers();
735
736 all.into_iter()
737 .filter(|p| provider_meta(p).map(|m| !m.hidden).unwrap_or(true))
738 .map(|p| {
739 let model_count = oxi_sdk::get_provider_models(p).len();
740 let has_key = CredentialStore::has_credential(
741 p,
742 self.config
743 .read()
744 .engine
745 .api_key
746 .as_deref()
747 .filter(|k| !k.is_empty()),
748 );
749 let meta = provider_meta(p);
750 ProviderInfo {
751 id: p.to_string(),
752 name: provider_display_name(p),
753 category: provider_category(p),
754 model_count,
755 has_key,
756 description: meta.map(|m| m.description.to_string()).unwrap_or_default(),
757 env_key: meta.map(|m| m.env_key.to_string()).unwrap_or_default(),
758 }
759 })
760 .collect()
761 }
762
763 pub fn models(&self, provider: &str, query: Option<&str>) -> Vec<ModelInfo> {
765 let entries = oxi_sdk::get_provider_models(provider);
766 entries
767 .iter()
768 .filter(|e| {
769 !e.name.contains("latest")
771 })
772 .filter(|e| {
773 if let Some(q) = query {
774 let q = q.to_lowercase();
775 e.name.to_lowercase().contains(&q)
776 || e.id.to_lowercase().contains(&q)
777 || e.provider.to_lowercase().contains(&q)
778 } else {
779 true
780 }
781 })
782 .map(ModelInfo::from)
783 .collect()
784 }
785
786 pub fn search_models(&self, query: &str) -> Vec<ModelInfo> {
788 oxi_sdk::search_models(query)
789 .into_iter()
790 .map(ModelInfo::from)
791 .collect()
792 }
793
794 pub fn config(&self) -> EngineConfigResponse {
796 let cfg = self.config.read();
797 let provider =
798 CredentialStore::provider_from_model(&cfg.engine.default_model).map(|s| s.to_string());
799 let api_key_source = provider.as_deref().and_then(|p| {
800 CredentialStore::resolve(p, cfg.api_key().as_deref()).map(|(_, src)| {
801 match src {
802 crate::credential::CredentialSource::EnvVar => "env",
803 crate::credential::CredentialSource::Config => "config",
804 crate::credential::CredentialSource::OxiAuthStore => "auth_store",
805 }
806 .to_string()
807 })
808 });
809 let api_key_set = provider
810 .as_deref()
811 .map(|p| CredentialStore::has_credential(p, cfg.api_key().as_deref()))
812 .unwrap_or(false);
813
814 EngineConfigResponse {
815 default_model: cfg.engine.default_model.clone(),
816 api_key_set,
817 api_key_source,
818 provider,
819 routing: RoutingConfigSnapshot {
820 routing_enabled: cfg.engine.routing_enabled,
821 prefer_cost_efficient: cfg.engine.prefer_cost_efficient,
822 fallback_models: cfg.engine.fallback_models.clone(),
823 excluded_models: cfg.engine.excluded_models.clone(),
824 },
825 }
826 }
827
828 pub fn routing_stats_snapshot(&self) -> RoutingStatsSnapshot {
830 self.routing_stats.snapshot()
831 }
832
833 pub fn fallback_history(&self, limit: usize) -> Vec<FallbackEvent> {
835 self.routing_stats.fallback_history(limit)
836 }
837
838 pub fn set_model(&self, model_id: &str) -> anyhow::Result<()> {
845 {
846 let mut cfg = self.config.write();
847 cfg.engine.default_model = model_id.to_string();
848 self.persist(&cfg)?;
849 }
850 tracing::info!(model = %model_id, "Default model updated in config");
851 self.rebuild_and_swap();
852 Ok(())
853 }
854
855 pub fn set_api_key(&self, provider: &str, key: &str) -> anyhow::Result<()> {
861 CredentialStore::store(provider, key)?;
862
863 let cfg = self.config.read();
865 if let Some(current_provider) =
866 CredentialStore::provider_from_model(&cfg.engine.default_model)
867 {
868 if current_provider == provider {
869 drop(cfg);
870 let mut cfg = self.config.write();
871 cfg.engine.api_key = Some(key.to_string());
872 self.persist(&cfg)?;
873 }
874 }
875 tracing::info!(provider = %provider, "API key stored");
876 self.rebuild_and_swap();
877 Ok(())
878 }
879
880 pub fn set_provider_options(&self, opts: &oxi_sdk::ProviderOptions) -> anyhow::Result<()> {
885 {
886 let mut cfg = self.config.write();
887 cfg.engine.provider_options = Some(opts.clone());
888 self.persist(&cfg)?;
889 }
890 tracing::info!("Provider options updated and persisted");
891 Ok(())
894 }
895
896 pub fn set_routing(&self, update: RoutingUpdate) -> anyhow::Result<()> {
901 {
902 let mut cfg = self.config.write();
903 if let Some(v) = update.routing_enabled {
904 cfg.engine.routing_enabled = v;
905 }
906 if let Some(v) = update.prefer_cost_efficient {
907 cfg.engine.prefer_cost_efficient = v;
908 }
909 if let Some(v) = update.fallback_models {
910 cfg.engine.fallback_models = v;
911 }
912 if let Some(v) = update.excluded_models {
913 cfg.engine.excluded_models = v;
914 }
915 self.persist(&cfg)?;
916 }
917 tracing::info!("Routing configuration updated via API");
918 self.rebuild_and_swap();
919 Ok(())
920 }
921
922 pub fn validate_key(&self, provider: &str, api_key: &str) -> ValidateKeyResult {
927 let result = self.try_validate(provider, api_key);
929 match result {
930 Ok(()) => ValidateKeyResult {
931 valid: true,
932 provider: provider.to_string(),
933 message: Some("API key is valid".to_string()),
934 },
935 Err(e) => ValidateKeyResult {
936 valid: false,
937 provider: provider.to_string(),
938 message: Some(format!("Validation failed: {e}")),
939 },
940 }
941 }
942
943 fn try_validate(&self, provider: &str, api_key: &str) -> anyhow::Result<()> {
945 let builder = oxi_sdk::OxiBuilder::new()
947 .with_builtins()
948 .api_key(provider, api_key);
949 let oxi = builder.build();
950
951 let models = oxi_sdk::get_provider_models(provider);
953 if models.is_empty() {
954 anyhow::bail!("No models found for provider '{provider}'");
955 }
956
957 let model_id = format!("{}/{}", provider, models[0].id);
958 let _model = oxi.resolve_model(&model_id)?;
959
960 let _provider = oxi.create_provider(provider)?;
962
963 if api_key.is_empty() {
967 anyhow::bail!("API key is empty");
968 }
969 if api_key.len() < 8 {
970 anyhow::bail!("API key appears too short");
971 }
972
973 tracing::debug!(
974 provider = %provider,
975 model = %model_id,
976 "Key validation: provider resolved with injected key"
977 );
978 Ok(())
979 }
980
981 pub fn estimate_cost(model_id: &str, input_tokens: u64, output_tokens: u64) -> f64 {
985 estimate_cost(model_id, input_tokens, output_tokens)
986 }
987
988 fn persist(&self, config: &OxiosConfig) -> anyhow::Result<()> {
990 let content = toml::to_string_pretty(config)
991 .map_err(|e| anyhow::anyhow!("Failed to serialize config: {e}"))?;
992 std::fs::write(&self.config_path, content)?;
993 Ok(())
994 }
995
996 fn rebuild_and_swap(&self) {
1002 let cfg = self.config.read();
1003 let model_id = &cfg.engine.default_model;
1004 let new_engine =
1005 crate::engine::OxiosEngine::from_config(model_id, cfg.api_key().as_deref());
1006 drop(cfg);
1007 self.engine_handle.swap(new_engine);
1008 }
1009}
1010
1011impl std::fmt::Debug for EngineApi {
1012 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1013 f.debug_struct("EngineApi")
1014 .field("config_path", &self.config_path)
1015 .finish()
1016 }
1017}
1018
1019pub fn record_usage_to_stats(
1022 stats: &Option<Arc<RoutingStats>>,
1023 model_id: &str,
1024 input_tokens: u64,
1025 output_tokens: u64,
1026) {
1027 if let Some(s) = stats {
1028 let cost = estimate_cost(model_id, input_tokens, output_tokens);
1029 s.record_model_usage(model_id, cost);
1030 }
1031}
1032
#[cfg(test)]
mod tests {
    use super::*;

    /// Known provider ids resolve to the category stored in the metadata table.
    #[test]
    fn test_provider_category_known() {
        let expected = [
            ("anthropic", ProviderCategory::Major),
            ("openai", ProviderCategory::Major),
            ("google", ProviderCategory::Major),
            ("groq", ProviderCategory::Open),
            ("opencode", ProviderCategory::Open),
            ("minimax", ProviderCategory::Regional),
            ("moonshotai", ProviderCategory::Regional),
            ("kimi-coding", ProviderCategory::Regional),
            ("zai", ProviderCategory::Regional),
            ("minimax-cn", ProviderCategory::Regional),
            ("xiaomi", ProviderCategory::Regional),
        ];
        for &(id, category) in expected.iter() {
            assert_eq!(provider_category(id), category);
        }
    }

    /// Ids without metadata fall back to the `Open` category.
    #[test]
    fn test_provider_category_fallback() {
        for id in ["not-a-real-provider", ""].iter() {
            assert_eq!(provider_category(id), ProviderCategory::Open);
        }
    }

    /// Known provider ids use the display name from the metadata table.
    #[test]
    fn test_provider_display_name_known() {
        let expected = [
            ("anthropic", "Anthropic"),
            ("minimax", "MiniMax"),
            ("moonshotai", "Moonshot AI (Kimi)"),
            ("kimi-coding", "Kimi Coding"),
            ("zai", "Z.AI (GLM)"),
            ("opencode", "OpenCode"),
            ("amazon-bedrock", "Amazon Bedrock"),
        ];
        for &(id, name) in expected.iter() {
            assert_eq!(provider_display_name(id), name);
        }
    }

    /// Ids without metadata get a title-cased name derived from the id.
    #[test]
    fn test_provider_display_name_fallback() {
        let expected = [
            ("some-new-provider", "Some New Provider"),
            // Has a metadata entry, so this name comes from the table.
            ("kimi-coding", "Kimi Coding"),
            ("some_id", "Some Id"),
            ("", ""),
        ];
        for &(id, name) in expected.iter() {
            assert_eq!(provider_display_name(id), name);
        }
    }

    /// Aliases resolve to the same entry as the canonical id.
    #[test]
    fn test_provider_meta_lookup_by_alias() {
        let copilot = provider_meta("github-copilot").unwrap();
        assert_eq!(copilot.id, provider_meta("copilot").unwrap().id);

        let bedrock = provider_meta("amazon-bedrock").unwrap();
        let via_aws_alias = provider_meta("aws-bedrock").unwrap();
        let via_short_alias = provider_meta("bedrock").unwrap();
        assert_eq!(bedrock.id, via_aws_alias.id);
        assert_eq!(bedrock.id, via_short_alias.id);
    }

    /// Unknown and empty ids have no metadata entry.
    #[test]
    fn test_provider_meta_unknown_is_none() {
        assert!(provider_meta("not-a-real-provider").is_none());
        assert!(provider_meta("").is_none());
    }

    /// `ProviderInfo` serializes with camelCase keys and round-trips.
    #[test]
    fn test_provider_info_serialization() {
        let original = ProviderInfo {
            id: "anthropic".to_string(),
            name: "Anthropic".to_string(),
            category: ProviderCategory::Major,
            model_count: 15,
            has_key: true,
            description: "Claude models with extended thinking".to_string(),
            env_key: "ANTHROPIC_API_KEY".to_string(),
        };
        let json = serde_json::to_string(&original).unwrap();
        for fragment in [
            "\"modelCount\":15",
            "\"hasKey\":true",
            "\"envKey\":\"ANTHROPIC_API_KEY\"",
        ]
        .iter()
        {
            assert!(json.contains(fragment));
        }

        let restored: ProviderInfo = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.id, "anthropic");
        assert_eq!(restored.name, "Anthropic");
        assert_eq!(restored.model_count, 15);
        assert!(restored.has_key);
        assert_eq!(restored.env_key, "ANTHROPIC_API_KEY");
    }

    /// `description` and `envKey` may be absent and then default to empty.
    #[test]
    fn test_provider_info_serialization_missing_optional() {
        let json = r#"{
            "id": "anthropic",
            "name": "Anthropic",
            "category": "major",
            "modelCount": 15,
            "hasKey": true
        }"#;
        let info: ProviderInfo = serde_json::from_str(json).unwrap();
        assert_eq!(info.id, "anthropic");
        assert_eq!(info.description, "");
        assert_eq!(info.env_key, "");
    }

    /// `ModelInfo` survives a JSON round trip.
    #[test]
    fn test_model_info_serialization() {
        let original = ModelInfo {
            id: "anthropic/claude-sonnet-4".to_string(),
            name: "Claude Sonnet 4".to_string(),
            api: "anthropic-messages".to_string(),
            provider: "anthropic".to_string(),
            reasoning: true,
            input: vec![InputModality::Text, InputModality::Image],
            context_window: 200000,
            max_tokens: 16000,
            cost_input: 3.0,
            cost_output: 15.0,
            cost_cache_read: 0.3,
            cost_cache_write: 3.75,
        };
        let json = serde_json::to_string(&original).unwrap();
        let restored: ModelInfo = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.id, "anthropic/claude-sonnet-4");
        assert!(restored.reasoning);
        assert_eq!(restored.context_window, 200000);
        assert!(restored.input.contains(&InputModality::Image));
        assert_eq!(restored.api, "anthropic-messages");
    }

    /// `EngineConfigResponse` survives a JSON round trip.
    #[test]
    fn test_engine_config_response_serialization() {
        let routing = RoutingConfigSnapshot {
            routing_enabled: false,
            prefer_cost_efficient: false,
            fallback_models: vec![],
            excluded_models: vec![],
        };
        let original = EngineConfigResponse {
            default_model: "anthropic/claude-sonnet-4".to_string(),
            api_key_set: true,
            api_key_source: Some("config.toml".to_string()),
            provider: Some("anthropic".to_string()),
            routing,
        };
        let json = serde_json::to_string(&original).unwrap();
        let restored: EngineConfigResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.default_model, "anthropic/claude-sonnet-4");
        assert!(restored.api_key_set);
        assert_eq!(restored.api_key_source.as_deref(), Some("config.toml"));
        assert!(!restored.routing.routing_enabled);
    }

    /// A positive `ValidateKeyResult` survives a JSON round trip.
    #[test]
    fn test_validate_key_result_serialization() {
        let original = ValidateKeyResult {
            valid: true,
            provider: "openai".to_string(),
            message: Some("API key is valid".to_string()),
        };
        let json = serde_json::to_string(&original).unwrap();
        let restored: ValidateKeyResult = serde_json::from_str(&json).unwrap();
        assert!(restored.valid);
        assert_eq!(restored.provider, "openai");
    }

    /// A negative `ValidateKeyResult` keeps its flag and message.
    #[test]
    fn test_validate_key_result_invalid() {
        let result = ValidateKeyResult {
            valid: false,
            provider: "anthropic".to_string(),
            message: Some("Validation failed: key too short".to_string()),
        };
        assert!(!result.valid);
        assert!(result.message.as_ref().unwrap().contains("failed"));
    }

    /// Calls and costs are accumulated per model and summed in the snapshot.
    #[test]
    fn test_routing_stats_snapshot() {
        let stats = RoutingStats::new();
        let usages = [
            ("anthropic/claude-sonnet-4", 0.05),
            ("anthropic/claude-sonnet-4", 0.03),
            ("openai/gpt-4o-mini", 0.01),
        ];
        for &(model, cost) in usages.iter() {
            stats.record_model_usage(model, cost);
        }

        let snap = stats.snapshot();
        assert_eq!(snap.total_requests, 3);
        assert_eq!(snap.model_calls["anthropic/claude-sonnet-4"], 2);
        assert_eq!(snap.model_calls["openai/gpt-4o-mini"], 1);
        assert!((snap.total_cost - 0.09).abs() < 0.001);
    }

    /// The history keeps only the 200 most recent events, newest first.
    #[test]
    fn test_fallback_history_circular() {
        let stats = RoutingStats::new();
        for i in 0..210 {
            let event = FallbackEvent {
                timestamp: DateTime::from_timestamp(i as i64, 0).unwrap(),
                from_model: format!("model-{}", i),
                to_model: "fallback".to_string(),
                reason: "test".to_string(),
                success: true,
            };
            stats.record_fallback(event);
        }
        let history = stats.fallback_history(200);
        assert_eq!(history.len(), 200);
        // Newest event first; events 0..=9 were evicted.
        assert_eq!(history[0].from_model, "model-209");
        assert_eq!(history[199].from_model, "model-10");
    }
}