1use crate::config::OxiosConfig;
10use crate::credential::CredentialStore;
11use chrono::{DateTime, Utc};
12use parking_lot::RwLock;
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::path::PathBuf;
16use std::sync::Arc;
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
22#[serde(rename_all = "camelCase")]
23pub struct RoutingConfigSnapshot {
24 pub routing_enabled: bool,
26 pub prefer_cost_efficient: bool,
28 pub fallback_models: Vec<String>,
30 pub excluded_models: Vec<String>,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
36#[serde(rename_all = "camelCase")]
37pub struct RoutingStatsSnapshot {
38 pub model_calls: HashMap<String, u64>,
40 pub model_cost: HashMap<String, f64>,
42 pub total_requests: u64,
44 pub total_cost: f64,
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
50#[serde(rename_all = "camelCase")]
51pub struct FallbackEvent {
52 pub timestamp: DateTime<Utc>,
54 pub from_model: String,
56 pub to_model: String,
58 pub reason: String,
60 pub success: bool,
62}
63
64#[derive(Debug, Deserialize)]
66#[serde(rename_all = "camelCase")]
67pub struct RoutingUpdate {
68 pub routing_enabled: Option<bool>,
69 pub prefer_cost_efficient: Option<bool>,
70 pub fallback_models: Option<Vec<String>>,
71 pub excluded_models: Option<Vec<String>>,
72}
73
74pub struct RoutingStats {
79 calls: RwLock<HashMap<String, u64>>,
80 costs: RwLock<HashMap<String, f64>>,
81 fallbacks: RwLock<Vec<FallbackEvent>>,
83}
84
85impl Default for RoutingStats {
86 fn default() -> Self {
87 Self {
88 calls: RwLock::new(HashMap::new()),
89 costs: RwLock::new(HashMap::new()),
90 fallbacks: RwLock::new(Vec::new()),
91 }
92 }
93}
94
95impl RoutingStats {
96 pub fn new() -> Self {
98 Self::default()
99 }
100
101 pub fn record_model_usage(&self, model_id: &str, cost_usd: f64) {
103 let mut calls = self.calls.write();
104 *calls.entry(model_id.to_string()).or_insert(0) += 1;
105 if cost_usd > 0.0 {
106 let mut costs = self.costs.write();
107 *costs.entry(model_id.to_string()).or_insert(0.0) += cost_usd;
108 }
109 }
110
111 pub fn record_fallback(&self, event: FallbackEvent) {
113 let mut fb = self.fallbacks.write();
114 fb.push(event);
115 let keep = fb.len().saturating_sub(200);
116 if keep > 0 {
117 fb.drain(0..keep);
118 }
119 }
120
121 pub fn snapshot(&self) -> RoutingStatsSnapshot {
123 let calls = self.calls.read();
124 let costs = self.costs.read();
125 let total_requests: u64 = calls.values().sum();
126 let total_cost: f64 = costs.values().sum();
127 RoutingStatsSnapshot {
128 model_calls: calls.clone(),
129 model_cost: costs.clone(),
130 total_requests,
131 total_cost,
132 }
133 }
134
135 pub fn fallback_history(&self, limit: usize) -> Vec<FallbackEvent> {
137 let fb = self.fallbacks.read();
138 fb.iter().rev().take(limit).cloned().collect()
139 }
140}
141
142pub fn estimate_cost(model_id: &str, input_tokens: u64, output_tokens: u64) -> f64 {
147 let entries = oxi_sdk::get_provider_models(model_id.split('/').next().unwrap_or(model_id));
148 let entry = entries
149 .iter()
150 .find(|e| format!("{}/{}", e.provider, e.id) == model_id);
151 match entry {
152 Some(e) => {
153 (e.cost_input * input_tokens as f64 / 1_000_000.0)
154 + (e.cost_output * output_tokens as f64 / 1_000_000.0)
155 }
156 None => {
157 (0.003 * input_tokens as f64 / 1_000_000.0)
159 + (0.015 * output_tokens as f64 / 1_000_000.0)
160 }
161 }
162}
163
164#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
168#[serde(rename_all = "lowercase")]
169pub enum ProviderCategory {
170 Major,
172 Open,
174 Regional,
176 Local,
178}
179
180#[derive(Debug, Clone, Copy)]
193struct ProviderMeta {
194 id: &'static str,
196 display_name: &'static str,
198 category: ProviderCategory,
200 hidden: bool,
204 description: &'static str,
206 env_key: &'static str,
210 aliases: &'static [&'static str],
214}
215
/// Presentation metadata for every known provider.
///
/// `provider_meta` scans this table linearly and returns the first entry whose
/// `id` equals the requested id or whose `aliases` contain it. Ids and aliases
/// should therefore be unique across the whole table. No entry currently uses
/// `ProviderCategory::Local`.
const PROVIDER_META: &[ProviderMeta] = &[
    // ── Major providers ──────────────────────────────────────────────────
    ProviderMeta {
        id: "anthropic",
        display_name: "Anthropic",
        category: ProviderCategory::Major,
        hidden: false,
        description: "Claude models with extended thinking",
        env_key: "ANTHROPIC_API_KEY",
        aliases: &["anthropic"],
    },
    ProviderMeta {
        id: "openai",
        display_name: "OpenAI",
        category: ProviderCategory::Major,
        hidden: false,
        description: "GPT, o-series, and Codex models",
        env_key: "OPENAI_API_KEY",
        aliases: &["openai"],
    },
    ProviderMeta {
        id: "google",
        display_name: "Google Gemini",
        category: ProviderCategory::Major,
        hidden: false,
        description: "Gemini models with thinking and tool use",
        env_key: "GOOGLE_API_KEY",
        aliases: &["google"],
    },
    // ── Open providers (hosted inference, gateways, aggregators) ─────────
    ProviderMeta {
        id: "groq",
        display_name: "Groq",
        category: ProviderCategory::Open,
        hidden: false,
        description: "Fast Llama, Mixtral, and Gemma inference",
        env_key: "GROQ_API_KEY",
        aliases: &["groq"],
    },
    ProviderMeta {
        id: "openrouter",
        display_name: "OpenRouter",
        category: ProviderCategory::Open,
        hidden: false,
        description: "Unified gateway to 200+ models",
        env_key: "OPENROUTER_API_KEY",
        aliases: &["openrouter"],
    },
    ProviderMeta {
        id: "deepseek",
        display_name: "DeepSeek",
        category: ProviderCategory::Open,
        hidden: false,
        description: "DeepSeek-V3 and DeepSeek-R1",
        env_key: "DEEPSEEK_API_KEY",
        aliases: &["deepseek"],
    },
    ProviderMeta {
        id: "mistral",
        display_name: "Mistral",
        category: ProviderCategory::Open,
        hidden: false,
        description: "Mistral and Codestral models",
        env_key: "MISTRAL_API_KEY",
        aliases: &["mistral"],
    },
    ProviderMeta {
        id: "xai",
        display_name: "xAI (Grok)",
        category: ProviderCategory::Open,
        hidden: false,
        description: "Grok models from xAI",
        env_key: "XAI_API_KEY",
        aliases: &["xai", "grok"],
    },
    ProviderMeta {
        id: "cerebras",
        display_name: "Cerebras",
        category: ProviderCategory::Open,
        hidden: false,
        description: "Ultra-fast open model inference",
        env_key: "CEREBRAS_API_KEY",
        aliases: &["cerebras"],
    },
    ProviderMeta {
        id: "fireworks",
        display_name: "Fireworks",
        category: ProviderCategory::Open,
        hidden: false,
        description: "Fast open-source model serving",
        env_key: "FIREWORKS_API_KEY",
        aliases: &["fireworks"],
    },
    ProviderMeta {
        id: "github-copilot",
        display_name: "GitHub Copilot",
        category: ProviderCategory::Open,
        hidden: false,
        description: "GitHub Copilot models (GPT-4, Claude)",
        env_key: "GITHUB_COPILOT_TOKEN",
        aliases: &["github-copilot", "copilot"],
    },
    ProviderMeta {
        id: "huggingface",
        display_name: "Hugging Face",
        category: ProviderCategory::Open,
        hidden: false,
        description: "Open model inference hub",
        env_key: "HUGGINGFACE_API_KEY",
        aliases: &["huggingface", "hf"],
    },
    ProviderMeta {
        id: "together",
        display_name: "Together AI",
        category: ProviderCategory::Open,
        hidden: false,
        description: "Open-source model hosting (Llama, Mixtral, ...)",
        env_key: "TOGETHER_API_KEY",
        aliases: &["together", "togetherai"],
    },
    // NOTE(review): empty env_key, so `ProviderInfo.env_key` is "" for this
    // provider. Confirm that no environment variable is meant to be advertised.
    ProviderMeta {
        id: "opencode",
        display_name: "OpenCode",
        category: ProviderCategory::Open,
        hidden: false,
        description: "OpenCode coding agent gateway",
        env_key: "",
        aliases: &["opencode"],
    },
    ProviderMeta {
        id: "perplexity",
        display_name: "Perplexity",
        category: ProviderCategory::Open,
        hidden: false,
        description: "Search-augmented answer models",
        env_key: "PERPLEXITY_API_KEY",
        aliases: &["perplexity"],
    },
    ProviderMeta {
        id: "cohere",
        display_name: "Cohere",
        category: ProviderCategory::Open,
        hidden: false,
        description: "Cohere Command and Embed models",
        env_key: "COHERE_API_KEY",
        aliases: &["cohere"],
    },
    // ── Regional providers ───────────────────────────────────────────────
    ProviderMeta {
        id: "minimax",
        display_name: "MiniMax",
        category: ProviderCategory::Regional,
        hidden: false,
        description: "MiniMax-M2.7, abab models",
        env_key: "MINIMAX_API_KEY",
        aliases: &["minimax"],
    },
    ProviderMeta {
        id: "moonshotai",
        display_name: "Moonshot AI (Kimi)",
        category: ProviderCategory::Regional,
        hidden: false,
        description: "Kimi models from Moonshot AI",
        env_key: "MOONSHOT_API_KEY",
        aliases: &["moonshotai", "moonshot", "kimi"],
    },
    ProviderMeta {
        id: "kimi-coding",
        display_name: "Kimi Coding",
        category: ProviderCategory::Regional,
        hidden: false,
        description: "Kimi Coding Plan — optimized for coding",
        env_key: "KIMI_CODING_API_KEY",
        aliases: &["kimi-coding"],
    },
    ProviderMeta {
        id: "zai",
        display_name: "Z.AI (GLM)",
        category: ProviderCategory::Regional,
        hidden: false,
        description: "Z.AI GLM models (coding plan)",
        env_key: "ZAI_API_KEY",
        aliases: &["zai"],
    },
    // ── Hidden providers ─────────────────────────────────────────────────
    // `hidden: true` keeps these out of `EngineApi::providers`. They still
    // resolve through `provider_meta` by id or alias, so names and categories
    // stay correct wherever they are looked up directly.
    ProviderMeta {
        id: "amazon-bedrock",
        display_name: "Amazon Bedrock",
        category: ProviderCategory::Open,
        hidden: true,
        description: "Multi-model via AWS Bedrock ConverseStream",
        env_key: "AWS_ACCESS_KEY_ID",
        aliases: &["amazon-bedrock", "aws-bedrock", "bedrock"],
    },
    ProviderMeta {
        id: "azure-openai-responses",
        display_name: "Azure OpenAI (Responses)",
        category: ProviderCategory::Open,
        hidden: true,
        description: "OpenAI models via Azure Cognitive Services",
        env_key: "AZURE_OPENAI_API_KEY",
        aliases: &["azure-openai-responses", "azure"],
    },
    ProviderMeta {
        id: "cloudflare-ai-gateway",
        display_name: "Cloudflare AI Gateway",
        category: ProviderCategory::Open,
        hidden: true,
        description: "Serverless AI via Cloudflare AI Gateway",
        env_key: "CLOUDFLARE_API_TOKEN",
        aliases: &["cloudflare-ai-gateway", "cf-ai-gateway"],
    },
    // The bare alias "cloudflare" resolves to Workers AI, not the AI Gateway.
    ProviderMeta {
        id: "cloudflare-workers-ai",
        display_name: "Cloudflare Workers AI",
        category: ProviderCategory::Open,
        hidden: true,
        description: "Serverless AI via Cloudflare Workers",
        env_key: "CLOUDFLARE_API_KEY",
        aliases: &["cloudflare-workers-ai", "cloudflare", "workers-ai"],
    },
    ProviderMeta {
        id: "google-vertex",
        display_name: "Google Vertex AI",
        category: ProviderCategory::Open,
        hidden: true,
        description: "Gemini via Google Cloud Vertex AI",
        env_key: "GOOGLE_APPLICATION_CREDENTIALS",
        aliases: &["google-vertex", "vertex"],
    },
    ProviderMeta {
        id: "minimax-cn",
        display_name: "MiniMax (China)",
        category: ProviderCategory::Regional,
        hidden: true,
        description: "MiniMax China region endpoint",
        env_key: "MINIMAX_CN_API_KEY",
        aliases: &["minimax-cn"],
    },
    ProviderMeta {
        id: "moonshotai-cn",
        display_name: "Moonshot AI (China)",
        category: ProviderCategory::Regional,
        hidden: true,
        description: "Kimi models — China region endpoint",
        env_key: "MOONSHOT_CN_API_KEY",
        aliases: &["moonshotai-cn", "moonshot-cn"],
    },
    // Shares OPENAI_API_KEY with the visible "openai" entry.
    ProviderMeta {
        id: "openai-codex",
        display_name: "OpenAI Codex",
        category: ProviderCategory::Open,
        hidden: true,
        description: "OpenAI Codex coding agent (Responses API)",
        env_key: "OPENAI_API_KEY",
        aliases: &["openai-codex"],
    },
    ProviderMeta {
        id: "opencode-go",
        display_name: "OpenCode Go",
        category: ProviderCategory::Open,
        hidden: true,
        description: "OpenCode Go Gateway",
        env_key: "OPENCODE_GO_API_KEY",
        aliases: &["opencode-go"],
    },
    ProviderMeta {
        id: "vercel-ai-gateway",
        display_name: "Vercel AI Gateway",
        category: ProviderCategory::Open,
        hidden: true,
        description: "Vercel AI Gateway",
        env_key: "VERCEL_API_KEY",
        aliases: &["vercel-ai-gateway", "vercel"],
    },
    ProviderMeta {
        id: "xiaomi",
        display_name: "Xiaomi MiMo",
        category: ProviderCategory::Regional,
        hidden: true,
        description: "Xiaomi MiMo models",
        env_key: "XIAOMI_API_KEY",
        aliases: &["xiaomi"],
    },
];
510
511fn provider_meta(id: &str) -> Option<&'static ProviderMeta> {
513 PROVIDER_META
514 .iter()
515 .find(|m| m.id == id || m.aliases.contains(&id))
516}
517
518fn provider_category(id: &str) -> ProviderCategory {
519 provider_meta(id)
520 .map(|m| m.category)
521 .unwrap_or(ProviderCategory::Open)
522}
523
524fn provider_display_name(id: &str) -> String {
530 provider_meta(id)
531 .map(|m| m.display_name.to_string())
532 .unwrap_or_else(|| fallback_display_name(id))
533}
534
/// Derives a display name for a provider id that has no metadata entry.
///
/// The id is split on `-` and `_`, and empty pieces are skipped. The first
/// character of each piece is upper-cased using full Unicode case mapping, so
/// `ß` becomes `SS`. The pieces are joined with single spaces, e.g.
/// `"some-new_provider"` → `"Some New Provider"`.
fn fallback_display_name(id: &str) -> String {
    let mut words: Vec<String> = Vec::new();
    for piece in id.split(|c: char| c == '-' || c == '_') {
        let mut rest = piece.chars();
        let Some(first) = rest.next() else {
            continue;
        };
        let mut word: String = first.to_uppercase().collect();
        word.push_str(rest.as_str());
        words.push(word);
    }
    words.join(" ")
}
553
554#[derive(Debug, Clone, Serialize, Deserialize)]
556#[serde(rename_all = "camelCase")]
557pub struct ProviderInfo {
558 pub id: String,
560 pub name: String,
562 pub category: ProviderCategory,
564 pub model_count: usize,
566 pub has_key: bool,
568 #[serde(default)]
571 pub description: String,
572 #[serde(default)]
576 pub env_key: String,
577}
578
579#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
581#[serde(rename_all = "lowercase")]
582pub enum InputModality {
583 Text,
585 Image,
587}
588
589#[derive(Debug, Clone, Serialize, Deserialize)]
591#[serde(rename_all = "camelCase")]
592pub struct ModelInfo {
593 pub id: String,
595 pub name: String,
597 pub api: String,
599 pub provider: String,
601 pub reasoning: bool,
603 pub input: Vec<InputModality>,
605 pub context_window: u32,
607 pub max_tokens: u32,
609 pub cost_input: f64,
611 pub cost_output: f64,
613 pub cost_cache_read: f64,
615 pub cost_cache_write: f64,
617}
618
619impl From<&oxi_sdk::ModelEntry> for ModelInfo {
620 fn from(entry: &oxi_sdk::ModelEntry) -> Self {
621 Self {
622 id: format!("{}/{}", entry.provider, entry.id),
623 name: entry.name.to_string(),
624 api: entry.api.to_string(),
625 provider: entry.provider.to_string(),
626 reasoning: entry.reasoning,
627 input: entry
628 .input
629 .iter()
630 .map(|m| match m {
631 oxi_sdk::InputModality::Text => InputModality::Text,
632 oxi_sdk::InputModality::Image => InputModality::Image,
633 _ => InputModality::Text,
634 })
635 .collect(),
636 context_window: entry.context_window,
637 max_tokens: entry.max_tokens,
638 cost_input: entry.cost_input,
639 cost_output: entry.cost_output,
640 cost_cache_read: entry.cost_cache_read,
641 cost_cache_write: entry.cost_cache_write,
642 }
643 }
644}
645
646impl From<&oxi_sdk::CatalogModelEntry> for ModelInfo {
647 fn from(entry: &oxi_sdk::CatalogModelEntry) -> Self {
653 Self {
654 id: format!("{}/{}", entry.provider, entry.model_id),
655 name: entry.name.clone(),
656 api: entry.protocol.as_str().to_string(),
657 provider: entry.provider.clone(),
658 reasoning: entry.reasoning,
659 input: entry
660 .input_modalities
661 .iter()
662 .map(|m| match m.as_str() {
663 "image" => InputModality::Image,
664 _ => InputModality::Text,
665 })
666 .collect(),
667 context_window: entry.context_window,
668 max_tokens: entry.max_tokens,
669 cost_input: entry.cost_input,
670 cost_output: entry.cost_output,
671 cost_cache_read: entry.cost_cache_read,
672 cost_cache_write: entry.cost_cache_write,
673 }
674 }
675}
676
677#[derive(Debug, Clone, Serialize, Deserialize)]
679pub struct EngineConfigResponse {
680 pub default_model: String,
682 pub api_key_set: bool,
684 pub api_key_source: Option<String>,
686 pub provider: Option<String>,
688 pub routing: RoutingConfigSnapshot,
690}
691
692#[derive(Debug, Clone, Serialize, Deserialize)]
694pub struct ValidateKeyResult {
695 pub valid: bool,
697 pub provider: String,
699 pub message: Option<String>,
701}
702
703pub struct EngineApi {
715 config: Arc<RwLock<OxiosConfig>>,
716 config_path: PathBuf,
717 routing_stats: Arc<RoutingStats>,
718 engine_handle: Arc<crate::engine::EngineHandle>,
720}
721
722impl EngineApi {
723 pub fn new(
730 config: Arc<RwLock<OxiosConfig>>,
731 config_path: PathBuf,
732 routing_stats: Arc<RoutingStats>,
733 engine_handle: Arc<crate::engine::EngineHandle>,
734 ) -> Self {
735 Self {
736 config,
737 config_path,
738 routing_stats,
739 engine_handle,
740 }
741 }
742
743 pub fn routing_stats(&self) -> Arc<RoutingStats> {
745 Arc::clone(&self.routing_stats)
746 }
747
748 pub fn engine_handle(&self) -> &Arc<crate::engine::EngineHandle> {
750 &self.engine_handle
751 }
752
753 pub fn providers(&self) -> Vec<ProviderInfo> {
769 let catalog = self.engine_handle.get().oxi().catalog().clone();
770 let use_catalog = catalog.model_count_sync() > 0;
771 let all: Vec<String> = if use_catalog {
772 catalog.list_providers_sync()
773 } else {
774 oxi_sdk::get_providers()
775 .into_iter()
776 .map(|s| s.to_string())
777 .collect()
778 };
779
780 all.into_iter()
781 .filter(|p| provider_meta(p).map(|m| !m.hidden).unwrap_or(true))
782 .map(|p| {
783 let model_count = if use_catalog {
784 catalog.list_models_sync(&p).len()
785 } else {
786 oxi_sdk::get_provider_models(&p).len()
787 };
788 let has_key = CredentialStore::has_credential(
789 &p,
790 self.config
791 .read()
792 .engine
793 .api_key
794 .as_deref()
795 .filter(|k| !k.is_empty()),
796 );
797 let meta = provider_meta(&p);
798 ProviderInfo {
799 id: p.clone(),
800 name: provider_display_name(&p),
801 category: provider_category(&p),
802 model_count,
803 has_key,
804 description: meta.map(|m| m.description.to_string()).unwrap_or_default(),
805 env_key: meta.map(|m| m.env_key.to_string()).unwrap_or_default(),
806 }
807 })
808 .collect()
809 }
810
811 pub fn models(&self, provider: &str, query: Option<&str>) -> Vec<ModelInfo> {
817 let catalog = self.engine_handle.get().oxi().catalog().clone();
818 let live = catalog.list_models_sync(provider);
819 let models: Vec<ModelInfo> = if !live.is_empty() {
820 live.iter().map(ModelInfo::from).collect()
821 } else {
822 oxi_sdk::get_provider_models(provider)
823 .iter()
824 .map(ModelInfo::from)
825 .collect()
826 };
827 models
828 .into_iter()
829 .filter(|m| !m.name.contains("latest"))
830 .filter(|m| {
831 if let Some(q) = query {
832 let q = q.to_lowercase();
833 m.name.to_lowercase().contains(&q)
834 || m.id.to_lowercase().contains(&q)
835 || m.provider.to_lowercase().contains(&q)
836 } else {
837 true
838 }
839 })
840 .collect()
841 }
842
843 pub fn search_models(&self, query: &str) -> Vec<ModelInfo> {
848 let catalog = self.engine_handle.get().oxi().catalog().clone();
849 let live = catalog.search_sync(query);
850 if !live.is_empty() {
851 live.iter().map(ModelInfo::from).collect()
852 } else {
853 oxi_sdk::search_models(query)
854 .into_iter()
855 .map(ModelInfo::from)
856 .collect()
857 }
858 }
859
860 pub fn config(&self) -> EngineConfigResponse {
862 let cfg = self.config.read();
863 let provider =
864 CredentialStore::provider_from_model(&cfg.engine.default_model).map(|s| s.to_string());
865 let api_key_source = provider.as_deref().and_then(|p| {
866 CredentialStore::resolve(p, cfg.api_key().as_deref()).map(|(_, src)| {
867 match src {
868 crate::credential::CredentialSource::EnvVar => "env",
869 crate::credential::CredentialSource::Config => "config",
870 crate::credential::CredentialSource::OxiAuthStore => "auth_store",
871 }
872 .to_string()
873 })
874 });
875 let api_key_set = provider
876 .as_deref()
877 .map(|p| CredentialStore::has_credential(p, cfg.api_key().as_deref()))
878 .unwrap_or(false);
879
880 EngineConfigResponse {
881 default_model: cfg.engine.default_model.clone(),
882 api_key_set,
883 api_key_source,
884 provider,
885 routing: RoutingConfigSnapshot {
886 routing_enabled: cfg.engine.routing_enabled,
887 prefer_cost_efficient: cfg.engine.prefer_cost_efficient,
888 fallback_models: cfg.engine.fallback_models.clone(),
889 excluded_models: cfg.engine.excluded_models.clone(),
890 },
891 }
892 }
893
894 pub fn routing_stats_snapshot(&self) -> RoutingStatsSnapshot {
896 self.routing_stats.snapshot()
897 }
898
899 pub fn fallback_history(&self, limit: usize) -> Vec<FallbackEvent> {
901 self.routing_stats.fallback_history(limit)
902 }
903
904 pub fn set_model(&self, model_id: &str) -> anyhow::Result<()> {
911 {
912 let mut cfg = self.config.write();
913 cfg.engine.default_model = model_id.to_string();
914 self.persist(&cfg)?;
915 }
916 tracing::info!(model = %model_id, "Default model updated in config");
917 self.rebuild_and_swap();
918 Ok(())
919 }
920
921 pub fn set_api_key(&self, provider: &str, key: &str) -> anyhow::Result<()> {
927 CredentialStore::store(provider, key)?;
928
929 let cfg = self.config.read();
931 if let Some(current_provider) =
932 CredentialStore::provider_from_model(&cfg.engine.default_model)
933 && current_provider == provider
934 {
935 drop(cfg);
936 let mut cfg = self.config.write();
937 cfg.engine.api_key = Some(key.to_string());
938 self.persist(&cfg)?;
939 }
940 tracing::info!(provider = %provider, "API key stored");
941 self.rebuild_and_swap();
942 Ok(())
943 }
944
945 pub fn set_provider_options(&self, opts: &oxi_sdk::ProviderOptions) -> anyhow::Result<()> {
950 {
951 let mut cfg = self.config.write();
952 cfg.engine.provider_options = Some(opts.clone());
953 self.persist(&cfg)?;
954 }
955 tracing::info!("Provider options updated and persisted");
956 Ok(())
959 }
960
961 pub fn set_routing(&self, update: RoutingUpdate) -> anyhow::Result<()> {
966 {
967 let mut cfg = self.config.write();
968 if let Some(v) = update.routing_enabled {
969 cfg.engine.routing_enabled = v;
970 }
971 if let Some(v) = update.prefer_cost_efficient {
972 cfg.engine.prefer_cost_efficient = v;
973 }
974 if let Some(v) = update.fallback_models {
975 cfg.engine.fallback_models = v;
976 }
977 if let Some(v) = update.excluded_models {
978 cfg.engine.excluded_models = v;
979 }
980 self.persist(&cfg)?;
981 }
982 tracing::info!("Routing configuration updated via API");
983 self.rebuild_and_swap();
984 Ok(())
985 }
986
987 pub fn validate_key(&self, provider: &str, api_key: &str) -> ValidateKeyResult {
992 let result = self.try_validate(provider, api_key);
994 match result {
995 Ok(()) => ValidateKeyResult {
996 valid: true,
997 provider: provider.to_string(),
998 message: Some("API key is valid".to_string()),
999 },
1000 Err(e) => ValidateKeyResult {
1001 valid: false,
1002 provider: provider.to_string(),
1003 message: Some(format!("Validation failed: {e}")),
1004 },
1005 }
1006 }
1007
1008 fn try_validate(&self, provider: &str, api_key: &str) -> anyhow::Result<()> {
1010 let builder = oxi_sdk::OxiBuilder::new()
1012 .with_builtins()
1013 .api_key(provider, api_key);
1014 let oxi = builder.build();
1015
1016 let models = oxi_sdk::get_provider_models(provider);
1018 if models.is_empty() {
1019 anyhow::bail!("No models found for provider '{provider}'");
1020 }
1021
1022 let model_id = format!("{}/{}", provider, models[0].id);
1023 let _model = oxi.resolve_model(&model_id)?;
1024
1025 let _provider = oxi.create_provider(provider)?;
1027
1028 if api_key.is_empty() {
1032 anyhow::bail!("API key is empty");
1033 }
1034 if api_key.len() < 8 {
1035 anyhow::bail!("API key appears too short");
1036 }
1037
1038 tracing::debug!(
1039 provider = %provider,
1040 model = %model_id,
1041 "Key validation: provider resolved with injected key"
1042 );
1043 Ok(())
1044 }
1045
1046 pub fn estimate_cost(model_id: &str, input_tokens: u64, output_tokens: u64) -> f64 {
1050 estimate_cost(model_id, input_tokens, output_tokens)
1051 }
1052
1053 fn persist(&self, config: &OxiosConfig) -> anyhow::Result<()> {
1055 let content = toml::to_string_pretty(config)
1056 .map_err(|e| anyhow::anyhow!("Failed to serialize config: {e}"))?;
1057 std::fs::write(&self.config_path, content)?;
1058 Ok(())
1059 }
1060
1061 fn rebuild_and_swap(&self) {
1068 let cfg = self.config.read();
1069 let model_id = &cfg.engine.default_model;
1070 let catalog = self.engine_handle.get().oxi().catalog().clone();
1072 let new_engine = crate::engine::OxiosEngine::from_config_with_catalog(
1073 model_id,
1074 cfg.api_key().as_deref(),
1075 catalog,
1076 );
1077 drop(cfg);
1078 self.engine_handle.swap(new_engine);
1079 }
1080}
1081
1082impl std::fmt::Debug for EngineApi {
1083 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1084 f.debug_struct("EngineApi")
1085 .field("config_path", &self.config_path)
1086 .finish()
1087 }
1088}
1089
1090pub fn record_usage_to_stats(
1093 stats: &Option<Arc<RoutingStats>>,
1094 model_id: &str,
1095 input_tokens: u64,
1096 output_tokens: u64,
1097) {
1098 if let Some(s) = stats {
1099 let cost = estimate_cost(model_id, input_tokens, output_tokens);
1100 s.record_model_usage(model_id, cost);
1101 }
1102}
1103
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fallback event whose `from_model` is `model-{i}`.
    fn fallback_event(i: i64) -> FallbackEvent {
        FallbackEvent {
            timestamp: DateTime::from_timestamp(i, 0).unwrap(),
            from_model: format!("model-{}", i),
            to_model: "fallback".to_string(),
            reason: "test".to_string(),
            success: true,
        }
    }

    #[test]
    fn test_provider_category_known() {
        // Major
        assert_eq!(provider_category("anthropic"), ProviderCategory::Major);
        assert_eq!(provider_category("openai"), ProviderCategory::Major);
        assert_eq!(provider_category("google"), ProviderCategory::Major);
        // Open
        assert_eq!(provider_category("groq"), ProviderCategory::Open);
        assert_eq!(provider_category("opencode"), ProviderCategory::Open);
        // Regional, including hidden regional endpoints
        assert_eq!(provider_category("minimax"), ProviderCategory::Regional);
        assert_eq!(provider_category("moonshotai"), ProviderCategory::Regional);
        assert_eq!(provider_category("kimi-coding"), ProviderCategory::Regional);
        assert_eq!(provider_category("zai"), ProviderCategory::Regional);
        assert_eq!(provider_category("minimax-cn"), ProviderCategory::Regional);
        assert_eq!(provider_category("xiaomi"), ProviderCategory::Regional);
    }

    #[test]
    fn test_provider_category_via_alias() {
        assert_eq!(provider_category("kimi"), ProviderCategory::Regional);
        assert_eq!(provider_category("grok"), ProviderCategory::Open);
    }

    #[test]
    fn test_provider_category_fallback() {
        // Unknown ids default to Open.
        assert_eq!(
            provider_category("not-a-real-provider"),
            ProviderCategory::Open
        );
        assert_eq!(provider_category(""), ProviderCategory::Open);
    }

    #[test]
    fn test_provider_display_name_known() {
        assert_eq!(provider_display_name("anthropic"), "Anthropic");
        assert_eq!(provider_display_name("minimax"), "MiniMax");
        assert_eq!(provider_display_name("moonshotai"), "Moonshot AI (Kimi)");
        assert_eq!(provider_display_name("kimi-coding"), "Kimi Coding");
        assert_eq!(provider_display_name("zai"), "Z.AI (GLM)");
        assert_eq!(provider_display_name("opencode"), "OpenCode");
        assert_eq!(provider_display_name("amazon-bedrock"), "Amazon Bedrock");
        assert_eq!(provider_display_name("copilot"), "GitHub Copilot");
    }

    #[test]
    fn test_provider_display_name_fallback() {
        // Only ids absent from PROVIDER_META exercise the derived name.
        assert_eq!(
            provider_display_name("some-new-provider"),
            "Some New Provider"
        );
        assert_eq!(provider_display_name("some_id"), "Some Id");
        assert_eq!(provider_display_name("double--dash"), "Double Dash");
        assert_eq!(provider_display_name(""), "");
    }

    #[test]
    fn test_provider_meta_lookup_by_alias() {
        let by_id = provider_meta("github-copilot").unwrap();
        let by_alias = provider_meta("copilot").unwrap();
        assert_eq!(by_id.id, by_alias.id);

        let bedrock_id = provider_meta("amazon-bedrock").unwrap();
        let bedrock_alias = provider_meta("aws-bedrock").unwrap();
        let bedrock_short = provider_meta("bedrock").unwrap();
        assert_eq!(bedrock_id.id, bedrock_alias.id);
        assert_eq!(bedrock_id.id, bedrock_short.id);
    }

    #[test]
    fn test_provider_meta_hidden_flag() {
        assert!(!provider_meta("anthropic").unwrap().hidden);
        assert!(provider_meta("bedrock").unwrap().hidden);
        assert!(provider_meta("vercel").unwrap().hidden);
    }

    #[test]
    fn test_provider_meta_unknown_is_none() {
        assert!(provider_meta("not-a-real-provider").is_none());
        assert!(provider_meta("").is_none());
    }

    #[test]
    fn test_provider_info_serialization() {
        let info = ProviderInfo {
            id: "anthropic".to_string(),
            name: "Anthropic".to_string(),
            category: ProviderCategory::Major,
            model_count: 15,
            has_key: true,
            description: "Claude models with extended thinking".to_string(),
            env_key: "ANTHROPIC_API_KEY".to_string(),
        };
        let json = serde_json::to_string(&info).unwrap();
        assert!(json.contains("\"modelCount\":15"));
        assert!(json.contains("\"hasKey\":true"));
        assert!(json.contains("\"envKey\":\"ANTHROPIC_API_KEY\""));
        assert!(json.contains("\"category\":\"major\""));
        let restored: ProviderInfo = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.id, "anthropic");
        assert_eq!(restored.name, "Anthropic");
        assert_eq!(restored.model_count, 15);
        assert!(restored.has_key);
        assert_eq!(restored.env_key, "ANTHROPIC_API_KEY");
    }

    #[test]
    fn test_provider_info_serialization_missing_optional() {
        // `description` and `envKey` are #[serde(default)] and may be omitted.
        let json = r#"{
            "id": "anthropic",
            "name": "Anthropic",
            "category": "major",
            "modelCount": 15,
            "hasKey": true
        }"#;
        let info: ProviderInfo = serde_json::from_str(json).unwrap();
        assert_eq!(info.id, "anthropic");
        assert_eq!(info.description, "");
        assert_eq!(info.env_key, "");
    }

    #[test]
    fn test_model_info_serialization() {
        let info = ModelInfo {
            id: "anthropic/claude-sonnet-4".to_string(),
            name: "Claude Sonnet 4".to_string(),
            api: "anthropic-messages".to_string(),
            provider: "anthropic".to_string(),
            reasoning: true,
            input: vec![InputModality::Text, InputModality::Image],
            context_window: 200000,
            max_tokens: 16000,
            cost_input: 3.0,
            cost_output: 15.0,
            cost_cache_read: 0.3,
            cost_cache_write: 3.75,
        };
        let json = serde_json::to_string(&info).unwrap();
        assert!(json.contains("\"contextWindow\":200000"));
        assert!(json.contains("\"input\":[\"text\",\"image\"]"));
        let restored: ModelInfo = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.id, "anthropic/claude-sonnet-4");
        assert!(restored.reasoning);
        assert_eq!(restored.context_window, 200000);
        assert!(restored.input.contains(&InputModality::Image));
        assert_eq!(restored.api, "anthropic-messages");
    }

    #[test]
    fn test_engine_config_response_serialization() {
        let resp = EngineConfigResponse {
            default_model: "anthropic/claude-sonnet-4".to_string(),
            api_key_set: true,
            // One of the values `EngineApi::config` actually produces.
            api_key_source: Some("config".to_string()),
            provider: Some("anthropic".to_string()),
            routing: RoutingConfigSnapshot {
                routing_enabled: false,
                prefer_cost_efficient: false,
                fallback_models: vec![],
                excluded_models: vec![],
            },
        };
        let json = serde_json::to_string(&resp).unwrap();
        // Pin the current wire shape: snake_case at the top level, camelCase
        // inside `routing`.
        assert!(json.contains("\"default_model\""));
        assert!(json.contains("\"routingEnabled\":false"));
        let restored: EngineConfigResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.default_model, "anthropic/claude-sonnet-4");
        assert!(restored.api_key_set);
        assert_eq!(restored.api_key_source.as_deref(), Some("config"));
        assert!(!restored.routing.routing_enabled);
    }

    #[test]
    fn test_validate_key_result_serialization() {
        let result = ValidateKeyResult {
            valid: true,
            provider: "openai".to_string(),
            message: Some("API key is valid".to_string()),
        };
        let json = serde_json::to_string(&result).unwrap();
        let restored: ValidateKeyResult = serde_json::from_str(&json).unwrap();
        assert!(restored.valid);
        assert_eq!(restored.provider, "openai");
    }

    #[test]
    fn test_validate_key_result_invalid() {
        let result = ValidateKeyResult {
            valid: false,
            provider: "anthropic".to_string(),
            message: Some("Validation failed: key too short".to_string()),
        };
        assert!(!result.valid);
        assert!(result.message.as_ref().unwrap().contains("failed"));
    }

    #[test]
    fn test_routing_stats_snapshot() {
        let stats = RoutingStats::new();
        stats.record_model_usage("anthropic/claude-sonnet-4", 0.05);
        stats.record_model_usage("anthropic/claude-sonnet-4", 0.03);
        stats.record_model_usage("openai/gpt-4o-mini", 0.01);

        let snap = stats.snapshot();
        assert_eq!(snap.total_requests, 3);
        assert_eq!(snap.model_calls["anthropic/claude-sonnet-4"], 2);
        assert_eq!(snap.model_calls["openai/gpt-4o-mini"], 1);
        assert!((snap.total_cost - 0.09).abs() < 0.001);
    }

    #[test]
    fn test_record_model_usage_ignores_non_positive_cost() {
        let stats = RoutingStats::new();
        stats.record_model_usage("local/llama", 0.0);
        stats.record_model_usage("local/llama", -1.0);

        let snap = stats.snapshot();
        assert_eq!(snap.total_requests, 2);
        assert_eq!(snap.model_calls["local/llama"], 2);
        assert!(!snap.model_cost.contains_key("local/llama"));
        assert_eq!(snap.total_cost, 0.0);
    }

    #[test]
    fn test_fallback_history_circular() {
        let stats = RoutingStats::new();
        for i in 0..210 {
            stats.record_fallback(fallback_event(i));
        }
        let history = stats.fallback_history(200);
        assert_eq!(history.len(), 200);
        // Newest first; the 10 oldest events were evicted.
        assert_eq!(history[0].from_model, "model-209");
        assert_eq!(history[199].from_model, "model-10");
    }

    #[test]
    fn test_fallback_history_respects_limit() {
        let stats = RoutingStats::new();
        for i in 0..5 {
            stats.record_fallback(fallback_event(i));
        }
        let history = stats.fallback_history(2);
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].from_model, "model-4");
        assert_eq!(history[1].from_model, "model-3");
        assert!(stats.fallback_history(0).is_empty());
        assert_eq!(stats.fallback_history(100).len(), 5);
    }
}