1use crate::config::OxiosConfig;
10use crate::credential::CredentialStore;
11use anyhow::Context;
12use chrono::{DateTime, Utc};
13use parking_lot::RwLock;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::path::PathBuf;
17use std::sync::Arc;
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
23#[serde(rename_all = "camelCase")]
24pub struct RoutingConfigSnapshot {
25 pub routing_enabled: bool,
27 pub prefer_cost_efficient: bool,
29 pub fallback_models: Vec<String>,
31 pub excluded_models: Vec<String>,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
37#[serde(rename_all = "camelCase")]
38pub struct RoutingStatsSnapshot {
39 pub model_calls: HashMap<String, u64>,
41 pub model_cost: HashMap<String, f64>,
43 pub total_requests: u64,
45 pub total_cost: f64,
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
51#[serde(rename_all = "camelCase")]
52pub struct FallbackEvent {
53 pub timestamp: DateTime<Utc>,
55 pub from_model: String,
57 pub to_model: String,
59 pub reason: String,
61 pub success: bool,
63}
64
65#[derive(Debug, Deserialize)]
67#[serde(rename_all = "camelCase")]
68pub struct RoutingUpdate {
69 pub routing_enabled: Option<bool>,
70 pub prefer_cost_efficient: Option<bool>,
71 pub fallback_models: Option<Vec<String>>,
72 pub excluded_models: Option<Vec<String>>,
73}
74
75pub struct RoutingStats {
80 calls: RwLock<HashMap<String, u64>>,
81 costs: RwLock<HashMap<String, f64>>,
82 fallbacks: RwLock<std::collections::VecDeque<FallbackEvent>>,
84}
85
86impl Default for RoutingStats {
87 fn default() -> Self {
88 Self {
89 calls: RwLock::new(HashMap::new()),
90 costs: RwLock::new(HashMap::new()),
91 fallbacks: RwLock::new(std::collections::VecDeque::new()),
92 }
93 }
94}
95
96impl RoutingStats {
97 pub fn new() -> Self {
99 Self::default()
100 }
101
102 pub fn record_model_usage(&self, model_id: &str, cost_usd: f64) {
104 let mut calls = self.calls.write();
105 *calls.entry(model_id.to_string()).or_insert(0) += 1;
106 if cost_usd > 0.0 {
107 let mut costs = self.costs.write();
108 *costs.entry(model_id.to_string()).or_insert(0.0) += cost_usd;
109 }
110 }
111
112 pub fn record_fallback(&self, event: FallbackEvent) {
117 let mut fb = self.fallbacks.write();
118 fb.push_back(event);
119 while fb.len() > 200 {
120 fb.pop_front();
121 }
122 }
123
124 pub fn snapshot(&self) -> RoutingStatsSnapshot {
126 let calls = self.calls.read();
127 let costs = self.costs.read();
128 let total_requests: u64 = calls.values().sum();
129 let total_cost: f64 = costs.values().sum();
130 RoutingStatsSnapshot {
131 model_calls: calls.clone(),
132 model_cost: costs.clone(),
133 total_requests,
134 total_cost,
135 }
136 }
137
138 pub fn fallback_history(&self, limit: usize) -> Vec<FallbackEvent> {
140 let fb = self.fallbacks.read();
141 fb.iter().rev().take(limit).cloned().collect()
142 }
143}
144
145pub fn estimate_cost(model_id: &str, input_tokens: u64, output_tokens: u64) -> f64 {
150 let entries = oxi_sdk::get_provider_models(model_id.split('/').next().unwrap_or(model_id));
151 let entry = entries
152 .iter()
153 .find(|e| format!("{}/{}", e.provider, e.id) == model_id);
154 match entry {
155 Some(e) => {
156 (e.cost_input * input_tokens as f64 / 1_000_000.0)
157 + (e.cost_output * output_tokens as f64 / 1_000_000.0)
158 }
159 None => {
160 (0.003 * input_tokens as f64 / 1_000_000.0)
162 + (0.015 * output_tokens as f64 / 1_000_000.0)
163 }
164 }
165}
166
167#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
171#[serde(rename_all = "lowercase")]
172pub enum ProviderCategory {
173 Major,
175 Open,
177 Regional,
179 Local,
181}
182
183#[derive(Debug, Clone, Copy)]
196struct ProviderMeta {
197 id: &'static str,
199 display_name: &'static str,
201 category: ProviderCategory,
203 hidden: bool,
207 description: &'static str,
209 env_key: &'static str,
213 aliases: &'static [&'static str],
217}
218
219const PROVIDER_META: &[ProviderMeta] = &[
225 ProviderMeta {
227 id: "anthropic",
228 display_name: "Anthropic",
229 category: ProviderCategory::Major,
230 hidden: false,
231 description: "Claude models with extended thinking",
232 env_key: "ANTHROPIC_API_KEY",
233 aliases: &["anthropic"],
234 },
235 ProviderMeta {
236 id: "openai",
237 display_name: "OpenAI",
238 category: ProviderCategory::Major,
239 hidden: false,
240 description: "GPT, o-series, and Codex models",
241 env_key: "OPENAI_API_KEY",
242 aliases: &["openai"],
243 },
244 ProviderMeta {
245 id: "google",
246 display_name: "Google Gemini",
247 category: ProviderCategory::Major,
248 hidden: false,
249 description: "Gemini models with thinking and tool use",
250 env_key: "GOOGLE_API_KEY",
251 aliases: &["google"],
252 },
253 ProviderMeta {
255 id: "groq",
256 display_name: "Groq",
257 category: ProviderCategory::Open,
258 hidden: false,
259 description: "Fast Llama, Mixtral, and Gemma inference",
260 env_key: "GROQ_API_KEY",
261 aliases: &["groq"],
262 },
263 ProviderMeta {
264 id: "openrouter",
265 display_name: "OpenRouter",
266 category: ProviderCategory::Open,
267 hidden: false,
268 description: "Unified gateway to 200+ models",
269 env_key: "OPENROUTER_API_KEY",
270 aliases: &["openrouter"],
271 },
272 ProviderMeta {
273 id: "deepseek",
274 display_name: "DeepSeek",
275 category: ProviderCategory::Open,
276 hidden: false,
277 description: "DeepSeek-V3 and DeepSeek-R1",
278 env_key: "DEEPSEEK_API_KEY",
279 aliases: &["deepseek"],
280 },
281 ProviderMeta {
282 id: "mistral",
283 display_name: "Mistral",
284 category: ProviderCategory::Open,
285 hidden: false,
286 description: "Mistral and Codestral models",
287 env_key: "MISTRAL_API_KEY",
288 aliases: &["mistral"],
289 },
290 ProviderMeta {
291 id: "xai",
292 display_name: "xAI (Grok)",
293 category: ProviderCategory::Open,
294 hidden: false,
295 description: "Grok models from xAI",
296 env_key: "XAI_API_KEY",
297 aliases: &["xai", "grok"],
298 },
299 ProviderMeta {
300 id: "cerebras",
301 display_name: "Cerebras",
302 category: ProviderCategory::Open,
303 hidden: false,
304 description: "Ultra-fast open model inference",
305 env_key: "CEREBRAS_API_KEY",
306 aliases: &["cerebras"],
307 },
308 ProviderMeta {
309 id: "fireworks",
310 display_name: "Fireworks",
311 category: ProviderCategory::Open,
312 hidden: false,
313 description: "Fast open-source model serving",
314 env_key: "FIREWORKS_API_KEY",
315 aliases: &["fireworks"],
316 },
317 ProviderMeta {
318 id: "github-copilot",
319 display_name: "GitHub Copilot",
320 category: ProviderCategory::Open,
321 hidden: false,
322 description: "GitHub Copilot models (GPT-4, Claude)",
323 env_key: "GITHUB_COPILOT_TOKEN",
324 aliases: &["github-copilot", "copilot"],
325 },
326 ProviderMeta {
327 id: "huggingface",
328 display_name: "Hugging Face",
329 category: ProviderCategory::Open,
330 hidden: false,
331 description: "Open model inference hub",
332 env_key: "HUGGINGFACE_API_KEY",
333 aliases: &["huggingface", "hf"],
334 },
335 ProviderMeta {
336 id: "together",
337 display_name: "Together AI",
338 category: ProviderCategory::Open,
339 hidden: false,
340 description: "Open-source model hosting (Llama, Mixtral, ...)",
341 env_key: "TOGETHER_API_KEY",
342 aliases: &["together", "togetherai"],
343 },
344 ProviderMeta {
345 id: "opencode",
346 display_name: "OpenCode",
347 category: ProviderCategory::Open,
348 hidden: false,
349 description: "OpenCode coding agent gateway",
350 env_key: "",
351 aliases: &["opencode"],
352 },
353 ProviderMeta {
354 id: "perplexity",
355 display_name: "Perplexity",
356 category: ProviderCategory::Open,
357 hidden: false,
358 description: "Search-augmented answer models",
359 env_key: "PERPLEXITY_API_KEY",
360 aliases: &["perplexity"],
361 },
362 ProviderMeta {
363 id: "cohere",
364 display_name: "Cohere",
365 category: ProviderCategory::Open,
366 hidden: false,
367 description: "Cohere Command and Embed models",
368 env_key: "COHERE_API_KEY",
369 aliases: &["cohere"],
370 },
371 ProviderMeta {
373 id: "minimax",
374 display_name: "MiniMax",
375 category: ProviderCategory::Regional,
376 hidden: false,
377 description: "MiniMax-M2.7, abab models",
378 env_key: "MINIMAX_API_KEY",
379 aliases: &["minimax"],
380 },
381 ProviderMeta {
382 id: "moonshotai",
383 display_name: "Moonshot AI (Kimi)",
384 category: ProviderCategory::Regional,
385 hidden: false,
386 description: "Kimi models from Moonshot AI",
387 env_key: "MOONSHOT_API_KEY",
388 aliases: &["moonshotai", "moonshot", "kimi"],
389 },
390 ProviderMeta {
391 id: "kimi-coding",
392 display_name: "Kimi Coding",
393 category: ProviderCategory::Regional,
394 hidden: false,
395 description: "Kimi Coding Plan — optimized for coding",
396 env_key: "KIMI_CODING_API_KEY",
397 aliases: &["kimi-coding"],
398 },
399 ProviderMeta {
400 id: "zai",
401 display_name: "Z.AI (GLM)",
402 category: ProviderCategory::Regional,
403 hidden: false,
404 description: "Z.AI GLM models (coding plan)",
405 env_key: "ZAI_API_KEY",
406 aliases: &["zai"],
407 },
408 ProviderMeta {
414 id: "amazon-bedrock",
415 display_name: "Amazon Bedrock",
416 category: ProviderCategory::Open,
417 hidden: true,
418 description: "Multi-model via AWS Bedrock ConverseStream",
419 env_key: "AWS_ACCESS_KEY_ID",
420 aliases: &["amazon-bedrock", "aws-bedrock", "bedrock"],
421 },
422 ProviderMeta {
423 id: "azure-openai-responses",
424 display_name: "Azure OpenAI (Responses)",
425 category: ProviderCategory::Open,
426 hidden: true,
427 description: "OpenAI models via Azure Cognitive Services",
428 env_key: "AZURE_OPENAI_API_KEY",
429 aliases: &["azure-openai-responses", "azure"],
430 },
431 ProviderMeta {
432 id: "cloudflare-ai-gateway",
433 display_name: "Cloudflare AI Gateway",
434 category: ProviderCategory::Open,
435 hidden: true,
436 description: "Serverless AI via Cloudflare AI Gateway",
437 env_key: "CLOUDFLARE_API_TOKEN",
438 aliases: &["cloudflare-ai-gateway", "cf-ai-gateway"],
439 },
440 ProviderMeta {
441 id: "cloudflare-workers-ai",
442 display_name: "Cloudflare Workers AI",
443 category: ProviderCategory::Open,
444 hidden: true,
445 description: "Serverless AI via Cloudflare Workers",
446 env_key: "CLOUDFLARE_API_KEY",
447 aliases: &["cloudflare-workers-ai", "cloudflare", "workers-ai"],
448 },
449 ProviderMeta {
450 id: "google-vertex",
451 display_name: "Google Vertex AI",
452 category: ProviderCategory::Open,
453 hidden: true,
454 description: "Gemini via Google Cloud Vertex AI",
455 env_key: "GOOGLE_APPLICATION_CREDENTIALS",
456 aliases: &["google-vertex", "vertex"],
457 },
458 ProviderMeta {
459 id: "minimax-cn",
460 display_name: "MiniMax (China)",
461 category: ProviderCategory::Regional,
462 hidden: true,
463 description: "MiniMax China region endpoint",
464 env_key: "MINIMAX_CN_API_KEY",
465 aliases: &["minimax-cn"],
466 },
467 ProviderMeta {
468 id: "moonshotai-cn",
469 display_name: "Moonshot AI (China)",
470 category: ProviderCategory::Regional,
471 hidden: true,
472 description: "Kimi models — China region endpoint",
473 env_key: "MOONSHOT_CN_API_KEY",
474 aliases: &["moonshotai-cn", "moonshot-cn"],
475 },
476 ProviderMeta {
477 id: "openai-codex",
478 display_name: "OpenAI Codex",
479 category: ProviderCategory::Open,
480 hidden: true,
481 description: "OpenAI Codex coding agent (Responses API)",
482 env_key: "OPENAI_API_KEY",
483 aliases: &["openai-codex"],
484 },
485 ProviderMeta {
486 id: "opencode-go",
487 display_name: "OpenCode Go",
488 category: ProviderCategory::Open,
489 hidden: true,
490 description: "OpenCode Go Gateway",
491 env_key: "OPENCODE_GO_API_KEY",
492 aliases: &["opencode-go"],
493 },
494 ProviderMeta {
495 id: "vercel-ai-gateway",
496 display_name: "Vercel AI Gateway",
497 category: ProviderCategory::Open,
498 hidden: true,
499 description: "Vercel AI Gateway",
500 env_key: "VERCEL_API_KEY",
501 aliases: &["vercel-ai-gateway", "vercel"],
502 },
503 ProviderMeta {
504 id: "xiaomi",
505 display_name: "Xiaomi MiMo",
506 category: ProviderCategory::Regional,
507 hidden: true,
508 description: "Xiaomi MiMo models",
509 env_key: "XIAOMI_API_KEY",
510 aliases: &["xiaomi"],
511 },
512];
513
514fn provider_meta(id: &str) -> Option<&'static ProviderMeta> {
516 PROVIDER_META
517 .iter()
518 .find(|m| m.id == id || m.aliases.contains(&id))
519}
520
521fn provider_category(id: &str) -> ProviderCategory {
522 provider_meta(id)
523 .map(|m| m.category)
524 .unwrap_or(ProviderCategory::Open)
525}
526
527fn provider_display_name(id: &str) -> String {
533 provider_meta(id)
534 .map(|m| m.display_name.to_string())
535 .unwrap_or_else(|| fallback_display_name(id))
536}
537
/// Title-cases an unknown provider id.
///
/// The id is split on `-` and `_`, empty segments are dropped, the first
/// character of each segment is upper-cased, and the words are joined with
/// single spaces (e.g. `some-new_provider` → `Some New Provider`).
fn fallback_display_name(id: &str) -> String {
    let mut words: Vec<String> = Vec::new();
    for segment in id.split(['-', '_']) {
        let mut rest = segment.chars();
        if let Some(first) = rest.next() {
            // `to_uppercase` can yield several chars (e.g. 'ß' -> "SS").
            let mut word: String = first.to_uppercase().collect();
            word.push_str(rest.as_str());
            words.push(word);
        }
    }
    words.join(" ")
}
556
557#[derive(Debug, Clone, Serialize, Deserialize)]
559#[serde(rename_all = "camelCase")]
560pub struct ProviderInfo {
561 pub id: String,
563 pub name: String,
565 pub category: ProviderCategory,
567 pub model_count: usize,
569 pub has_key: bool,
571 #[serde(default)]
574 pub description: String,
575 #[serde(default)]
579 pub env_key: String,
580}
581
582#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
584#[serde(rename_all = "lowercase")]
585pub enum InputModality {
586 Text,
588 Image,
590}
591
592#[derive(Debug, Clone, Serialize, Deserialize)]
594#[serde(rename_all = "camelCase")]
595pub struct ModelInfo {
596 pub id: String,
598 pub name: String,
600 pub api: String,
602 pub provider: String,
604 pub reasoning: bool,
606 pub input: Vec<InputModality>,
608 pub context_window: u32,
610 pub max_tokens: u32,
612 pub cost_input: f64,
614 pub cost_output: f64,
616 pub cost_cache_read: f64,
618 pub cost_cache_write: f64,
620}
621
622impl From<&oxi_sdk::ModelEntry> for ModelInfo {
623 fn from(entry: &oxi_sdk::ModelEntry) -> Self {
624 Self {
625 id: format!("{}/{}", entry.provider, entry.id),
626 name: entry.name.to_string(),
627 api: entry.api.to_string(),
628 provider: entry.provider.to_string(),
629 reasoning: entry.reasoning,
630 input: entry
631 .input
632 .iter()
633 .map(|m| match m {
634 oxi_sdk::InputModality::Text => InputModality::Text,
635 oxi_sdk::InputModality::Image => InputModality::Image,
636 _ => InputModality::Text,
637 })
638 .collect(),
639 context_window: entry.context_window,
640 max_tokens: entry.max_tokens,
641 cost_input: entry.cost_input,
642 cost_output: entry.cost_output,
643 cost_cache_read: entry.cost_cache_read,
644 cost_cache_write: entry.cost_cache_write,
645 }
646 }
647}
648
649impl From<&oxi_sdk::CatalogModelEntry> for ModelInfo {
650 fn from(entry: &oxi_sdk::CatalogModelEntry) -> Self {
656 Self {
657 id: format!("{}/{}", entry.provider, entry.model_id),
658 name: entry.name.clone(),
659 api: entry.protocol.as_str().to_string(),
660 provider: entry.provider.clone(),
661 reasoning: entry.reasoning,
662 input: entry
663 .input_modalities
664 .iter()
665 .map(|m| match m.as_str() {
666 "image" => InputModality::Image,
667 _ => InputModality::Text,
668 })
669 .collect(),
670 context_window: entry.context_window,
671 max_tokens: entry.max_tokens,
672 cost_input: entry.cost_input,
673 cost_output: entry.cost_output,
674 cost_cache_read: entry.cost_cache_read,
675 cost_cache_write: entry.cost_cache_write,
676 }
677 }
678}
679
680#[derive(Debug, Clone, Serialize, Deserialize)]
682pub struct EngineConfigResponse {
683 pub default_model: String,
685 pub api_key_set: bool,
687 pub api_key_source: Option<String>,
689 pub provider: Option<String>,
691 pub routing: RoutingConfigSnapshot,
693}
694
695#[derive(Debug, Clone, Serialize, Deserialize)]
697pub struct ValidateKeyResult {
698 pub valid: bool,
700 pub provider: String,
702 pub message: Option<String>,
704}
705
706pub struct EngineApi {
718 config: Arc<RwLock<OxiosConfig>>,
719 config_path: PathBuf,
720 routing_stats: Arc<RoutingStats>,
721 engine_handle: Arc<crate::engine::EngineHandle>,
723}
724
725impl EngineApi {
726 pub fn new(
733 config: Arc<RwLock<OxiosConfig>>,
734 config_path: PathBuf,
735 routing_stats: Arc<RoutingStats>,
736 engine_handle: Arc<crate::engine::EngineHandle>,
737 ) -> Self {
738 Self {
739 config,
740 config_path,
741 routing_stats,
742 engine_handle,
743 }
744 }
745
746 pub fn routing_stats(&self) -> Arc<RoutingStats> {
748 Arc::clone(&self.routing_stats)
749 }
750
751 pub fn engine_handle(&self) -> &Arc<crate::engine::EngineHandle> {
753 &self.engine_handle
754 }
755
756 pub fn providers(&self) -> Vec<ProviderInfo> {
772 let catalog = self.engine_handle.get().oxi().catalog().clone();
773 let use_catalog = catalog.model_count_sync() > 0;
774 let all: Vec<String> = if use_catalog {
775 catalog.list_providers_sync()
776 } else {
777 oxi_sdk::get_providers()
778 .into_iter()
779 .map(|s| s.to_string())
780 .collect()
781 };
782
783 let api_key_override = {
787 let cfg = self.config.read();
788 cfg.engine
789 .api_key
790 .as_deref()
791 .filter(|k| !k.is_empty())
792 .map(str::to_owned)
793 };
794 all.into_iter()
795 .filter(|p| provider_meta(p).map(|m| !m.hidden).unwrap_or(true))
796 .map(|p| {
797 let model_count = if use_catalog {
798 catalog.list_models_sync(&p).len()
799 } else {
800 oxi_sdk::get_provider_models(&p).len()
801 };
802 let has_key = CredentialStore::has_credential(&p, api_key_override.as_deref());
803 let meta = provider_meta(&p);
804 ProviderInfo {
805 id: p.clone(),
806 name: provider_display_name(&p),
807 category: provider_category(&p),
808 model_count,
809 has_key,
810 description: meta.map(|m| m.description.to_string()).unwrap_or_default(),
811 env_key: meta.map(|m| m.env_key.to_string()).unwrap_or_default(),
812 }
813 })
814 .collect()
815 }
816
817 pub fn models(&self, provider: &str, query: Option<&str>) -> Vec<ModelInfo> {
823 let catalog = self.engine_handle.get().oxi().catalog().clone();
824 let live = catalog.list_models_sync(provider);
825 let models: Vec<ModelInfo> = if !live.is_empty() {
826 live.iter().map(ModelInfo::from).collect()
827 } else {
828 oxi_sdk::get_provider_models(provider)
829 .iter()
830 .map(ModelInfo::from)
831 .collect()
832 };
833 models
834 .into_iter()
835 .filter(|m| !m.name.contains("latest"))
836 .filter(|m| {
837 if let Some(q) = query {
838 let q = q.to_lowercase();
839 m.name.to_lowercase().contains(&q)
840 || m.id.to_lowercase().contains(&q)
841 || m.provider.to_lowercase().contains(&q)
842 } else {
843 true
844 }
845 })
846 .collect()
847 }
848
849 pub fn search_models(&self, query: &str) -> Vec<ModelInfo> {
854 let catalog = self.engine_handle.get().oxi().catalog().clone();
855 let live = catalog.search_sync(query);
856 if !live.is_empty() {
857 live.iter().map(ModelInfo::from).collect()
858 } else {
859 oxi_sdk::search_models(query)
860 .into_iter()
861 .map(ModelInfo::from)
862 .collect()
863 }
864 }
865
866 pub fn config(&self) -> EngineConfigResponse {
868 let cfg = self.config.read();
869 let provider =
870 CredentialStore::provider_from_model(&cfg.engine.default_model).map(|s| s.to_string());
871 let api_key_source = provider.as_deref().and_then(|p| {
872 CredentialStore::resolve(p, cfg.api_key().as_deref()).map(|(_, src)| {
873 match src {
874 crate::credential::CredentialSource::EnvVar => "env",
875 crate::credential::CredentialSource::Config => "config",
876 crate::credential::CredentialSource::OxiAuthStore => "auth_store",
877 }
878 .to_string()
879 })
880 });
881 let api_key_set = provider
882 .as_deref()
883 .map(|p| CredentialStore::has_credential(p, cfg.api_key().as_deref()))
884 .unwrap_or(false);
885
886 EngineConfigResponse {
887 default_model: cfg.engine.default_model.clone(),
888 api_key_set,
889 api_key_source,
890 provider,
891 routing: RoutingConfigSnapshot {
892 routing_enabled: cfg.engine.routing_enabled,
893 prefer_cost_efficient: cfg.engine.prefer_cost_efficient,
894 fallback_models: cfg.engine.fallback_models.clone(),
895 excluded_models: cfg.engine.excluded_models.clone(),
896 },
897 }
898 }
899
900 pub fn routing_stats_snapshot(&self) -> RoutingStatsSnapshot {
902 self.routing_stats.snapshot()
903 }
904
905 pub fn fallback_history(&self, limit: usize) -> Vec<FallbackEvent> {
907 self.routing_stats.fallback_history(limit)
908 }
909
910 pub fn set_model(&self, model_id: &str) -> anyhow::Result<()> {
917 {
923 let engine = self.engine_handle.get();
924 let model = engine
925 .resolve_model(model_id)
926 .with_context(|| format!("Unknown model '{model_id}'"))?;
927 engine.create_provider(&model.provider).with_context(|| {
928 format!(
929 "Provider '{}' is not configured for '{model_id}'",
930 model.provider
931 )
932 })?;
933 }
934 let snapshot = {
935 let mut cfg = self.config.write();
936 cfg.engine.default_model = model_id.to_string();
937 cfg.clone()
938 };
939 self.persist(&snapshot)?;
942 tracing::info!(model = %model_id, "Default model updated in config");
943 self.rebuild_and_swap();
944 Ok(())
945 }
946
947 pub fn set_api_key(&self, provider: &str, key: &str) -> anyhow::Result<()> {
953 CredentialStore::store(provider, key)?;
954
955 let snapshot = {
961 let mut cfg = self.config.write();
962 let matches = CredentialStore::provider_from_model(&cfg.engine.default_model)
963 .is_some_and(|current_provider| current_provider == provider);
964 if matches {
965 cfg.engine.api_key = Some(key.to_string());
966 Some(cfg.clone())
967 } else {
968 None
969 }
970 };
971 if let Some(snap) = snapshot {
972 self.persist(&snap)?;
974 }
975 tracing::info!(provider = %provider, "API key stored");
976 self.rebuild_and_swap();
977 Ok(())
978 }
979
980 pub fn set_provider_options(&self, opts: &oxi_sdk::ProviderOptions) -> anyhow::Result<()> {
985 let snapshot = {
986 let mut cfg = self.config.write();
987 cfg.engine.provider_options = Some(opts.clone());
988 cfg.clone()
989 };
990 self.persist(&snapshot)?;
991 tracing::info!("Provider options updated and persisted");
992 Ok(())
995 }
996
997 pub fn set_routing(&self, update: RoutingUpdate) -> anyhow::Result<()> {
1002 let snapshot = {
1003 let mut cfg = self.config.write();
1004 if let Some(v) = update.routing_enabled {
1005 cfg.engine.routing_enabled = v;
1006 }
1007 if let Some(v) = update.prefer_cost_efficient {
1008 cfg.engine.prefer_cost_efficient = v;
1009 }
1010 if let Some(v) = update.fallback_models {
1011 cfg.engine.fallback_models = v;
1012 }
1013 if let Some(v) = update.excluded_models {
1014 cfg.engine.excluded_models = v;
1015 }
1016 cfg.clone()
1017 };
1018 self.persist(&snapshot)?;
1019 tracing::info!("Routing configuration updated via API");
1020 self.rebuild_and_swap();
1021 Ok(())
1022 }
1023
1024 pub fn validate_key(&self, provider: &str, api_key: &str) -> ValidateKeyResult {
1029 let result = self.try_validate(provider, api_key);
1031 match result {
1032 Ok(()) => ValidateKeyResult {
1033 valid: true,
1034 provider: provider.to_string(),
1035 message: Some("API key is valid".to_string()),
1036 },
1037 Err(e) => ValidateKeyResult {
1038 valid: false,
1039 provider: provider.to_string(),
1040 message: Some(format!("Validation failed: {e}")),
1041 },
1042 }
1043 }
1044
1045 fn try_validate(&self, provider: &str, api_key: &str) -> anyhow::Result<()> {
1047 let builder = oxi_sdk::OxiBuilder::new()
1049 .with_builtins()
1050 .api_key(provider, api_key);
1051 let oxi = builder.build();
1052
1053 let models = oxi_sdk::get_provider_models(provider);
1055 if models.is_empty() {
1056 anyhow::bail!("No models found for provider '{provider}'");
1057 }
1058
1059 let model_id = format!("{}/{}", provider, models[0].id);
1060 let _model = oxi.resolve_model(&model_id)?;
1061
1062 let _provider = oxi.create_provider(provider)?;
1064
1065 if api_key.is_empty() {
1069 anyhow::bail!("API key is empty");
1070 }
1071 if api_key.len() < 8 {
1072 anyhow::bail!("API key appears too short");
1073 }
1074
1075 tracing::debug!(
1076 provider = %provider,
1077 model = %model_id,
1078 "Key validation: provider resolved with injected key"
1079 );
1080 Ok(())
1081 }
1082
1083 pub fn estimate_cost(model_id: &str, input_tokens: u64, output_tokens: u64) -> f64 {
1087 estimate_cost(model_id, input_tokens, output_tokens)
1088 }
1089
1090 fn persist(&self, config: &OxiosConfig) -> anyhow::Result<()> {
1092 let content = toml::to_string_pretty(config)
1093 .map_err(|e| anyhow::anyhow!("Failed to serialize config: {e}"))?;
1094 std::fs::write(&self.config_path, content)?;
1095 Ok(())
1096 }
1097
1098 fn rebuild_and_swap(&self) {
1105 let (model_id, api_key, catalog) = {
1109 let cfg = self.config.read();
1110 let catalog = self.engine_handle.get().oxi().catalog().clone();
1111 (cfg.engine.default_model.clone(), cfg.api_key(), catalog)
1112 };
1113 let new_engine = crate::engine::OxiosEngine::from_config_with_catalog(
1114 &model_id,
1115 api_key.as_deref(),
1116 catalog,
1117 );
1118 self.engine_handle.swap(new_engine);
1119 }
1120}
1121
1122impl std::fmt::Debug for EngineApi {
1123 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1124 f.debug_struct("EngineApi")
1125 .field("config_path", &self.config_path)
1126 .finish()
1127 }
1128}
1129
1130pub fn record_usage_to_stats(
1133 stats: &Option<Arc<RoutingStats>>,
1134 model_id: &str,
1135 input_tokens: u64,
1136 output_tokens: u64,
1137) {
1138 if let Some(s) = stats {
1139 let cost = estimate_cost(model_id, input_tokens, output_tokens);
1140 s.record_model_usage(model_id, cost);
1141 }
1142}
1143
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `EngineApi` over a default config and a fresh engine that
    /// persists to `path`.
    fn test_api(path: PathBuf) -> EngineApi {
        use crate::engine::{EngineHandle, OxiosEngine};

        let engine = Arc::new(OxiosEngine::new("anthropic/claude-sonnet-4-20250514"));
        let handle = Arc::new(EngineHandle::new(engine));
        let config = Arc::new(parking_lot::RwLock::new(OxiosConfig::default()));
        EngineApi::new(config, path, Arc::new(RoutingStats::new()), handle)
    }

    /// A fallback event whose `from_model` is `model-{i}` and whose timestamp is
    /// `i` seconds after the epoch.
    fn fallback_event(i: i64) -> FallbackEvent {
        FallbackEvent {
            timestamp: DateTime::from_timestamp(i, 0).unwrap(),
            from_model: format!("model-{i}"),
            to_model: "fallback".to_string(),
            reason: "test".to_string(),
            success: true,
        }
    }

    #[test]
    fn test_provider_category_known() {
        assert_eq!(provider_category("anthropic"), ProviderCategory::Major);
        assert_eq!(provider_category("openai"), ProviderCategory::Major);
        assert_eq!(provider_category("google"), ProviderCategory::Major);
        assert_eq!(provider_category("groq"), ProviderCategory::Open);
        assert_eq!(provider_category("opencode"), ProviderCategory::Open);
        assert_eq!(provider_category("minimax"), ProviderCategory::Regional);
        assert_eq!(provider_category("moonshotai"), ProviderCategory::Regional);
        assert_eq!(provider_category("kimi-coding"), ProviderCategory::Regional);
        assert_eq!(provider_category("zai"), ProviderCategory::Regional);
        assert_eq!(provider_category("minimax-cn"), ProviderCategory::Regional);
        assert_eq!(provider_category("xiaomi"), ProviderCategory::Regional);
    }

    #[test]
    fn test_provider_category_fallback() {
        assert_eq!(
            provider_category("not-a-real-provider"),
            ProviderCategory::Open
        );
        assert_eq!(provider_category(""), ProviderCategory::Open);
    }

    #[test]
    fn test_provider_display_name_known() {
        assert_eq!(provider_display_name("anthropic"), "Anthropic");
        assert_eq!(provider_display_name("minimax"), "MiniMax");
        assert_eq!(provider_display_name("moonshotai"), "Moonshot AI (Kimi)");
        assert_eq!(provider_display_name("kimi-coding"), "Kimi Coding");
        assert_eq!(provider_display_name("zai"), "Z.AI (GLM)");
        assert_eq!(provider_display_name("opencode"), "OpenCode");
        assert_eq!(provider_display_name("amazon-bedrock"), "Amazon Bedrock");
    }

    /// Only ids without a metadata entry exercise the title-casing fallback.
    #[test]
    fn test_provider_display_name_fallback() {
        assert_eq!(
            provider_display_name("some-new-provider"),
            "Some New Provider"
        );
        assert_eq!(
            provider_display_name("brand-new_provider"),
            "Brand New Provider"
        );
        assert_eq!(provider_display_name("some_id"), "Some Id");
        assert_eq!(provider_display_name(""), "");
    }

    #[test]
    fn test_provider_meta_lookup_by_alias() {
        let by_id = provider_meta("github-copilot").unwrap();
        let by_alias = provider_meta("copilot").unwrap();
        assert_eq!(by_id.id, by_alias.id);

        let bedrock_id = provider_meta("amazon-bedrock").unwrap();
        let bedrock_alias = provider_meta("aws-bedrock").unwrap();
        let bedrock_canonical = provider_meta("bedrock").unwrap();
        assert_eq!(bedrock_id.id, bedrock_alias.id);
        assert_eq!(bedrock_id.id, bedrock_canonical.id);
    }

    #[test]
    fn test_provider_meta_unknown_is_none() {
        assert!(provider_meta("not-a-real-provider").is_none());
        assert!(provider_meta("").is_none());
    }

    #[test]
    fn test_provider_info_serialization() {
        let info = ProviderInfo {
            id: "anthropic".to_string(),
            name: "Anthropic".to_string(),
            category: ProviderCategory::Major,
            model_count: 15,
            has_key: true,
            description: "Claude models with extended thinking".to_string(),
            env_key: "ANTHROPIC_API_KEY".to_string(),
        };
        let json = serde_json::to_string(&info).unwrap();
        assert!(json.contains("\"modelCount\":15"));
        assert!(json.contains("\"hasKey\":true"));
        assert!(json.contains("\"envKey\":\"ANTHROPIC_API_KEY\""));
        let restored: ProviderInfo = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.id, "anthropic");
        assert_eq!(restored.name, "Anthropic");
        assert_eq!(restored.model_count, 15);
        assert!(restored.has_key);
        assert_eq!(restored.env_key, "ANTHROPIC_API_KEY");
    }

    /// `description` and `envKey` are optional on input and default to "".
    #[test]
    fn test_provider_info_serialization_missing_optional() {
        let json = r#"{
            "id": "anthropic",
            "name": "Anthropic",
            "category": "major",
            "modelCount": 15,
            "hasKey": true
        }"#;
        let info: ProviderInfo = serde_json::from_str(json).unwrap();
        assert_eq!(info.id, "anthropic");
        assert_eq!(info.description, "");
        assert_eq!(info.env_key, "");
    }

    #[test]
    fn test_model_info_serialization() {
        let info = ModelInfo {
            id: "anthropic/claude-sonnet-4".to_string(),
            name: "Claude Sonnet 4".to_string(),
            api: "anthropic-messages".to_string(),
            provider: "anthropic".to_string(),
            reasoning: true,
            input: vec![InputModality::Text, InputModality::Image],
            context_window: 200000,
            max_tokens: 16000,
            cost_input: 3.0,
            cost_output: 15.0,
            cost_cache_read: 0.3,
            cost_cache_write: 3.75,
        };
        let json = serde_json::to_string(&info).unwrap();
        let restored: ModelInfo = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.id, "anthropic/claude-sonnet-4");
        assert!(restored.reasoning);
        assert_eq!(restored.context_window, 200000);
        assert!(restored.input.contains(&InputModality::Image));
        assert_eq!(restored.api, "anthropic-messages");
    }

    #[test]
    fn test_engine_config_response_serialization() {
        let resp = EngineConfigResponse {
            default_model: "anthropic/claude-sonnet-4".to_string(),
            api_key_set: true,
            api_key_source: Some("config".to_string()),
            provider: Some("anthropic".to_string()),
            routing: RoutingConfigSnapshot {
                routing_enabled: false,
                prefer_cost_efficient: false,
                fallback_models: vec![],
                excluded_models: vec![],
            },
        };
        let json = serde_json::to_string(&resp).unwrap();
        // Pin the wire format: top-level keys are snake_case, routing is camelCase.
        assert!(json.contains("\"default_model\""), "unexpected keys: {json}");
        assert!(json.contains("\"routingEnabled\""), "unexpected keys: {json}");
        let restored: EngineConfigResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.default_model, "anthropic/claude-sonnet-4");
        assert!(restored.api_key_set);
        assert_eq!(restored.api_key_source.as_deref(), Some("config"));
        assert!(!restored.routing.routing_enabled);
    }

    #[test]
    fn test_validate_key_result_serialization() {
        let result = ValidateKeyResult {
            valid: true,
            provider: "openai".to_string(),
            message: Some("API key is valid".to_string()),
        };
        let json = serde_json::to_string(&result).unwrap();
        let restored: ValidateKeyResult = serde_json::from_str(&json).unwrap();
        assert!(restored.valid);
        assert_eq!(restored.provider, "openai");
    }

    #[test]
    fn test_validate_key_result_invalid() {
        let result = ValidateKeyResult {
            valid: false,
            provider: "anthropic".to_string(),
            message: Some("Validation failed: key too short".to_string()),
        };
        assert!(!result.valid);
        assert!(result.message.as_ref().unwrap().contains("failed"));
    }

    #[test]
    fn test_routing_stats_snapshot() {
        let stats = RoutingStats::new();
        stats.record_model_usage("anthropic/claude-sonnet-4", 0.05);
        stats.record_model_usage("anthropic/claude-sonnet-4", 0.03);
        stats.record_model_usage("openai/gpt-4o-mini", 0.01);

        let snap = stats.snapshot();
        assert_eq!(snap.total_requests, 3);
        assert_eq!(snap.model_calls["anthropic/claude-sonnet-4"], 2);
        assert_eq!(snap.model_calls["openai/gpt-4o-mini"], 1);
        assert!((snap.total_cost - 0.09).abs() < 0.001);
    }

    /// Zero-cost calls are counted but never create a cost entry.
    #[test]
    fn test_routing_stats_zero_cost_is_counted_not_priced() {
        let stats = RoutingStats::new();
        stats.record_model_usage("local/free-model", 0.0);

        let snap = stats.snapshot();
        assert_eq!(snap.total_requests, 1);
        assert_eq!(snap.model_calls["local/free-model"], 1);
        assert!(!snap.model_cost.contains_key("local/free-model"));
        assert!(snap.total_cost == 0.0);
    }

    /// The history keeps only the newest 200 events and returns them newest first.
    #[test]
    fn test_fallback_history_circular() {
        let stats = RoutingStats::new();
        for i in 0..210 {
            stats.record_fallback(fallback_event(i));
        }
        let history = stats.fallback_history(200);
        assert_eq!(history.len(), 200);
        assert_eq!(history[0].from_model, "model-209");
        assert_eq!(history[199].from_model, "model-10");
    }

    #[test]
    fn test_fallback_history_respects_limit() {
        let stats = RoutingStats::new();
        for i in 0..5 {
            stats.record_fallback(fallback_event(i));
        }
        let history = stats.fallback_history(2);
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].from_model, "model-4");
        assert_eq!(history[1].from_model, "model-3");
    }

    /// An unknown model must be rejected before anything is persisted: both the
    /// in-memory config and the config file stay untouched.
    #[test]
    fn set_model_rejects_unknown_model_before_persist() {
        let path = std::env::temp_dir().join(format!(
            "oxios-set-model-reject-{}.toml",
            std::process::id()
        ));
        let _ = std::fs::remove_file(&path);
        let api = test_api(path.clone());

        let before = api.config.read().engine.default_model.clone();
        let err = api.set_model("zai-coding-plan/glm-5-turbo").unwrap_err();
        assert!(
            err.to_string().contains("Unknown model"),
            "expected unknown-model error, got: {err}"
        );
        assert_eq!(api.config.read().engine.default_model, before);
        assert!(!path.exists(), "config must not be written on validation failure");
    }

    /// A known built-in model is never rejected as unknown. Provider
    /// configuration may still fail in test environments without credentials.
    #[test]
    fn set_model_accepts_known_builtin_model() {
        let tmp =
            std::env::temp_dir().join(format!("oxios-set-model-ok-{}.toml", std::process::id()));
        let api = test_api(tmp.clone());

        match api.set_model("openai/gpt-4o") {
            Ok(()) => assert_eq!(api.config.read().engine.default_model, "openai/gpt-4o"),
            Err(e) => assert!(
                !e.to_string().contains("Unknown model"),
                "known model rejected as unknown: {e}"
            ),
        }
        let _ = std::fs::remove_file(&tmp);
    }
}