1use crate::config::OxiosConfig;
10use crate::credential::CredentialStore;
11use anyhow::Context;
12use chrono::{DateTime, Utc};
13use parking_lot::RwLock;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::path::PathBuf;
17use std::sync::Arc;
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
23#[serde(rename_all = "camelCase")]
24pub struct RoutingConfigSnapshot {
25 pub routing_enabled: bool,
27 pub prefer_cost_efficient: bool,
29 pub fallback_models: Vec<String>,
31 pub excluded_models: Vec<String>,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
37#[serde(rename_all = "camelCase")]
38pub struct RoutingStatsSnapshot {
39 pub model_calls: HashMap<String, u64>,
41 pub model_cost: HashMap<String, f64>,
43 pub total_requests: u64,
45 pub total_cost: f64,
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
51#[serde(rename_all = "camelCase")]
52pub struct FallbackEvent {
53 pub timestamp: DateTime<Utc>,
55 pub from_model: String,
57 pub to_model: String,
59 pub reason: String,
61 pub success: bool,
63}
64
65#[derive(Debug, Deserialize)]
67#[serde(rename_all = "camelCase")]
68pub struct RoutingUpdate {
69 pub routing_enabled: Option<bool>,
70 pub prefer_cost_efficient: Option<bool>,
71 pub fallback_models: Option<Vec<String>>,
72 pub excluded_models: Option<Vec<String>>,
73}
74
75pub struct RoutingStats {
80 calls: RwLock<HashMap<String, u64>>,
81 costs: RwLock<HashMap<String, f64>>,
82 fallbacks: RwLock<Vec<FallbackEvent>>,
84}
85
86impl Default for RoutingStats {
87 fn default() -> Self {
88 Self {
89 calls: RwLock::new(HashMap::new()),
90 costs: RwLock::new(HashMap::new()),
91 fallbacks: RwLock::new(Vec::new()),
92 }
93 }
94}
95
96impl RoutingStats {
97 pub fn new() -> Self {
99 Self::default()
100 }
101
102 pub fn record_model_usage(&self, model_id: &str, cost_usd: f64) {
104 let mut calls = self.calls.write();
105 *calls.entry(model_id.to_string()).or_insert(0) += 1;
106 if cost_usd > 0.0 {
107 let mut costs = self.costs.write();
108 *costs.entry(model_id.to_string()).or_insert(0.0) += cost_usd;
109 }
110 }
111
112 pub fn record_fallback(&self, event: FallbackEvent) {
114 let mut fb = self.fallbacks.write();
115 fb.push(event);
116 let keep = fb.len().saturating_sub(200);
117 if keep > 0 {
118 fb.drain(0..keep);
119 }
120 }
121
122 pub fn snapshot(&self) -> RoutingStatsSnapshot {
124 let calls = self.calls.read();
125 let costs = self.costs.read();
126 let total_requests: u64 = calls.values().sum();
127 let total_cost: f64 = costs.values().sum();
128 RoutingStatsSnapshot {
129 model_calls: calls.clone(),
130 model_cost: costs.clone(),
131 total_requests,
132 total_cost,
133 }
134 }
135
136 pub fn fallback_history(&self, limit: usize) -> Vec<FallbackEvent> {
138 let fb = self.fallbacks.read();
139 fb.iter().rev().take(limit).cloned().collect()
140 }
141}
142
143pub fn estimate_cost(model_id: &str, input_tokens: u64, output_tokens: u64) -> f64 {
148 let entries = oxi_sdk::get_provider_models(model_id.split('/').next().unwrap_or(model_id));
149 let entry = entries
150 .iter()
151 .find(|e| format!("{}/{}", e.provider, e.id) == model_id);
152 match entry {
153 Some(e) => {
154 (e.cost_input * input_tokens as f64 / 1_000_000.0)
155 + (e.cost_output * output_tokens as f64 / 1_000_000.0)
156 }
157 None => {
158 (0.003 * input_tokens as f64 / 1_000_000.0)
160 + (0.015 * output_tokens as f64 / 1_000_000.0)
161 }
162 }
163}
164
165#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
169#[serde(rename_all = "lowercase")]
170pub enum ProviderCategory {
171 Major,
173 Open,
175 Regional,
177 Local,
179}
180
181#[derive(Debug, Clone, Copy)]
194struct ProviderMeta {
195 id: &'static str,
197 display_name: &'static str,
199 category: ProviderCategory,
201 hidden: bool,
205 description: &'static str,
207 env_key: &'static str,
211 aliases: &'static [&'static str],
215}
216
/// Metadata for every provider this module knows about.
///
/// `provider_meta` scans this table front to back and returns the first entry
/// whose `id` or `aliases` match, so entry order decides precedence if an alias
/// were ever listed twice. Entries with `hidden: true` are filtered out of
/// `EngineApi::providers`; ids not listed here are still shown, as `Open`.
const PROVIDER_META: &[ProviderMeta] = &[
    // --- Major providers (visible) ---
    ProviderMeta {
        id: "anthropic",
        display_name: "Anthropic",
        category: ProviderCategory::Major,
        hidden: false,
        description: "Claude models with extended thinking",
        env_key: "ANTHROPIC_API_KEY",
        aliases: &["anthropic"],
    },
    ProviderMeta {
        id: "openai",
        display_name: "OpenAI",
        category: ProviderCategory::Major,
        hidden: false,
        description: "GPT, o-series, and Codex models",
        env_key: "OPENAI_API_KEY",
        aliases: &["openai"],
    },
    ProviderMeta {
        id: "google",
        display_name: "Google Gemini",
        category: ProviderCategory::Major,
        hidden: false,
        description: "Gemini models with thinking and tool use",
        env_key: "GOOGLE_API_KEY",
        aliases: &["google"],
    },
    // --- Open providers (visible) ---
    ProviderMeta {
        id: "groq",
        display_name: "Groq",
        category: ProviderCategory::Open,
        hidden: false,
        description: "Fast Llama, Mixtral, and Gemma inference",
        env_key: "GROQ_API_KEY",
        aliases: &["groq"],
    },
    ProviderMeta {
        id: "openrouter",
        display_name: "OpenRouter",
        category: ProviderCategory::Open,
        hidden: false,
        description: "Unified gateway to 200+ models",
        env_key: "OPENROUTER_API_KEY",
        aliases: &["openrouter"],
    },
    ProviderMeta {
        id: "deepseek",
        display_name: "DeepSeek",
        category: ProviderCategory::Open,
        hidden: false,
        description: "DeepSeek-V3 and DeepSeek-R1",
        env_key: "DEEPSEEK_API_KEY",
        aliases: &["deepseek"],
    },
    ProviderMeta {
        id: "mistral",
        display_name: "Mistral",
        category: ProviderCategory::Open,
        hidden: false,
        description: "Mistral and Codestral models",
        env_key: "MISTRAL_API_KEY",
        aliases: &["mistral"],
    },
    ProviderMeta {
        id: "xai",
        display_name: "xAI (Grok)",
        category: ProviderCategory::Open,
        hidden: false,
        description: "Grok models from xAI",
        env_key: "XAI_API_KEY",
        aliases: &["xai", "grok"],
    },
    ProviderMeta {
        id: "cerebras",
        display_name: "Cerebras",
        category: ProviderCategory::Open,
        hidden: false,
        description: "Ultra-fast open model inference",
        env_key: "CEREBRAS_API_KEY",
        aliases: &["cerebras"],
    },
    ProviderMeta {
        id: "fireworks",
        display_name: "Fireworks",
        category: ProviderCategory::Open,
        hidden: false,
        description: "Fast open-source model serving",
        env_key: "FIREWORKS_API_KEY",
        aliases: &["fireworks"],
    },
    ProviderMeta {
        id: "github-copilot",
        display_name: "GitHub Copilot",
        category: ProviderCategory::Open,
        hidden: false,
        description: "GitHub Copilot models (GPT-4, Claude)",
        env_key: "GITHUB_COPILOT_TOKEN",
        aliases: &["github-copilot", "copilot"],
    },
    ProviderMeta {
        id: "huggingface",
        display_name: "Hugging Face",
        category: ProviderCategory::Open,
        hidden: false,
        description: "Open model inference hub",
        env_key: "HUGGINGFACE_API_KEY",
        aliases: &["huggingface", "hf"],
    },
    ProviderMeta {
        id: "together",
        display_name: "Together AI",
        category: ProviderCategory::Open,
        hidden: false,
        description: "Open-source model hosting (Llama, Mixtral, ...)",
        env_key: "TOGETHER_API_KEY",
        aliases: &["together", "togetherai"],
    },
    ProviderMeta {
        id: "opencode",
        display_name: "OpenCode",
        category: ProviderCategory::Open,
        hidden: false,
        description: "OpenCode coding agent gateway",
        // Intentionally empty: no env var is advertised for this provider.
        // NOTE(review): confirm how OpenCode credentials are supplied.
        env_key: "",
        aliases: &["opencode"],
    },
    ProviderMeta {
        id: "perplexity",
        display_name: "Perplexity",
        category: ProviderCategory::Open,
        hidden: false,
        description: "Search-augmented answer models",
        env_key: "PERPLEXITY_API_KEY",
        aliases: &["perplexity"],
    },
    ProviderMeta {
        id: "cohere",
        display_name: "Cohere",
        category: ProviderCategory::Open,
        hidden: false,
        description: "Cohere Command and Embed models",
        env_key: "COHERE_API_KEY",
        aliases: &["cohere"],
    },
    // --- Regional providers (visible) ---
    ProviderMeta {
        id: "minimax",
        display_name: "MiniMax",
        category: ProviderCategory::Regional,
        hidden: false,
        description: "MiniMax-M2.7, abab models",
        env_key: "MINIMAX_API_KEY",
        aliases: &["minimax"],
    },
    ProviderMeta {
        id: "moonshotai",
        display_name: "Moonshot AI (Kimi)",
        category: ProviderCategory::Regional,
        hidden: false,
        description: "Kimi models from Moonshot AI",
        env_key: "MOONSHOT_API_KEY",
        // "kimi" resolves here; the separate "kimi-coding" entry has its own id.
        aliases: &["moonshotai", "moonshot", "kimi"],
    },
    ProviderMeta {
        id: "kimi-coding",
        display_name: "Kimi Coding",
        category: ProviderCategory::Regional,
        hidden: false,
        description: "Kimi Coding Plan — optimized for coding",
        env_key: "KIMI_CODING_API_KEY",
        aliases: &["kimi-coding"],
    },
    ProviderMeta {
        id: "zai",
        display_name: "Z.AI (GLM)",
        category: ProviderCategory::Regional,
        hidden: false,
        description: "Z.AI GLM models (coding plan)",
        env_key: "ZAI_API_KEY",
        aliases: &["zai"],
    },
    // --- Hidden providers: known for naming/category, but not listed by `providers()` ---
    ProviderMeta {
        id: "amazon-bedrock",
        display_name: "Amazon Bedrock",
        category: ProviderCategory::Open,
        hidden: true,
        description: "Multi-model via AWS Bedrock ConverseStream",
        env_key: "AWS_ACCESS_KEY_ID",
        aliases: &["amazon-bedrock", "aws-bedrock", "bedrock"],
    },
    ProviderMeta {
        id: "azure-openai-responses",
        display_name: "Azure OpenAI (Responses)",
        category: ProviderCategory::Open,
        hidden: true,
        description: "OpenAI models via Azure Cognitive Services",
        env_key: "AZURE_OPENAI_API_KEY",
        aliases: &["azure-openai-responses", "azure"],
    },
    ProviderMeta {
        id: "cloudflare-ai-gateway",
        display_name: "Cloudflare AI Gateway",
        category: ProviderCategory::Open,
        hidden: true,
        description: "Serverless AI via Cloudflare AI Gateway",
        env_key: "CLOUDFLARE_API_TOKEN",
        aliases: &["cloudflare-ai-gateway", "cf-ai-gateway"],
    },
    ProviderMeta {
        id: "cloudflare-workers-ai",
        display_name: "Cloudflare Workers AI",
        category: ProviderCategory::Open,
        hidden: true,
        description: "Serverless AI via Cloudflare Workers",
        env_key: "CLOUDFLARE_API_KEY",
        // The bare "cloudflare" alias resolves to Workers AI, not the AI Gateway entry.
        aliases: &["cloudflare-workers-ai", "cloudflare", "workers-ai"],
    },
    ProviderMeta {
        id: "google-vertex",
        display_name: "Google Vertex AI",
        category: ProviderCategory::Open,
        hidden: true,
        description: "Gemini via Google Cloud Vertex AI",
        env_key: "GOOGLE_APPLICATION_CREDENTIALS",
        aliases: &["google-vertex", "vertex"],
    },
    ProviderMeta {
        id: "minimax-cn",
        display_name: "MiniMax (China)",
        category: ProviderCategory::Regional,
        hidden: true,
        description: "MiniMax China region endpoint",
        env_key: "MINIMAX_CN_API_KEY",
        aliases: &["minimax-cn"],
    },
    ProviderMeta {
        id: "moonshotai-cn",
        display_name: "Moonshot AI (China)",
        category: ProviderCategory::Regional,
        hidden: true,
        description: "Kimi models — China region endpoint",
        env_key: "MOONSHOT_CN_API_KEY",
        aliases: &["moonshotai-cn", "moonshot-cn"],
    },
    ProviderMeta {
        id: "openai-codex",
        display_name: "OpenAI Codex",
        category: ProviderCategory::Open,
        hidden: true,
        description: "OpenAI Codex coding agent (Responses API)",
        // Shares the env var with the `openai` entry.
        env_key: "OPENAI_API_KEY",
        aliases: &["openai-codex"],
    },
    ProviderMeta {
        id: "opencode-go",
        display_name: "OpenCode Go",
        category: ProviderCategory::Open,
        hidden: true,
        description: "OpenCode Go Gateway",
        env_key: "OPENCODE_GO_API_KEY",
        aliases: &["opencode-go"],
    },
    ProviderMeta {
        id: "vercel-ai-gateway",
        display_name: "Vercel AI Gateway",
        category: ProviderCategory::Open,
        hidden: true,
        description: "Vercel AI Gateway",
        env_key: "VERCEL_API_KEY",
        aliases: &["vercel-ai-gateway", "vercel"],
    },
    ProviderMeta {
        id: "xiaomi",
        display_name: "Xiaomi MiMo",
        category: ProviderCategory::Regional,
        hidden: true,
        description: "Xiaomi MiMo models",
        env_key: "XIAOMI_API_KEY",
        aliases: &["xiaomi"],
    },
];
511
512fn provider_meta(id: &str) -> Option<&'static ProviderMeta> {
514 PROVIDER_META
515 .iter()
516 .find(|m| m.id == id || m.aliases.contains(&id))
517}
518
519fn provider_category(id: &str) -> ProviderCategory {
520 provider_meta(id)
521 .map(|m| m.category)
522 .unwrap_or(ProviderCategory::Open)
523}
524
525fn provider_display_name(id: &str) -> String {
531 provider_meta(id)
532 .map(|m| m.display_name.to_string())
533 .unwrap_or_else(|| fallback_display_name(id))
534}
535
/// Title-cases a provider id: splits on `-` and `_`, drops empty pieces,
/// upper-cases the first character of each piece, and joins with single spaces.
///
/// For example, `"some-new-provider"` becomes `"Some New Provider"`.
fn fallback_display_name(id: &str) -> String {
    let mut title = String::with_capacity(id.len());
    for word in id.split(|c: char| c == '-' || c == '_') {
        let mut rest = word.chars();
        let Some(head) = rest.next() else {
            // Consecutive or trailing separators produce empty pieces; skip them.
            continue;
        };
        if !title.is_empty() {
            title.push(' ');
        }
        // `to_uppercase` may yield several chars (e.g. 'ß' -> "SS").
        title.extend(head.to_uppercase());
        title.push_str(rest.as_str());
    }
    title
}
554
555#[derive(Debug, Clone, Serialize, Deserialize)]
557#[serde(rename_all = "camelCase")]
558pub struct ProviderInfo {
559 pub id: String,
561 pub name: String,
563 pub category: ProviderCategory,
565 pub model_count: usize,
567 pub has_key: bool,
569 #[serde(default)]
572 pub description: String,
573 #[serde(default)]
577 pub env_key: String,
578}
579
580#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
582#[serde(rename_all = "lowercase")]
583pub enum InputModality {
584 Text,
586 Image,
588}
589
590#[derive(Debug, Clone, Serialize, Deserialize)]
592#[serde(rename_all = "camelCase")]
593pub struct ModelInfo {
594 pub id: String,
596 pub name: String,
598 pub api: String,
600 pub provider: String,
602 pub reasoning: bool,
604 pub input: Vec<InputModality>,
606 pub context_window: u32,
608 pub max_tokens: u32,
610 pub cost_input: f64,
612 pub cost_output: f64,
614 pub cost_cache_read: f64,
616 pub cost_cache_write: f64,
618}
619
620impl From<&oxi_sdk::ModelEntry> for ModelInfo {
621 fn from(entry: &oxi_sdk::ModelEntry) -> Self {
622 Self {
623 id: format!("{}/{}", entry.provider, entry.id),
624 name: entry.name.to_string(),
625 api: entry.api.to_string(),
626 provider: entry.provider.to_string(),
627 reasoning: entry.reasoning,
628 input: entry
629 .input
630 .iter()
631 .map(|m| match m {
632 oxi_sdk::InputModality::Text => InputModality::Text,
633 oxi_sdk::InputModality::Image => InputModality::Image,
634 _ => InputModality::Text,
635 })
636 .collect(),
637 context_window: entry.context_window,
638 max_tokens: entry.max_tokens,
639 cost_input: entry.cost_input,
640 cost_output: entry.cost_output,
641 cost_cache_read: entry.cost_cache_read,
642 cost_cache_write: entry.cost_cache_write,
643 }
644 }
645}
646
647impl From<&oxi_sdk::CatalogModelEntry> for ModelInfo {
648 fn from(entry: &oxi_sdk::CatalogModelEntry) -> Self {
654 Self {
655 id: format!("{}/{}", entry.provider, entry.model_id),
656 name: entry.name.clone(),
657 api: entry.protocol.as_str().to_string(),
658 provider: entry.provider.clone(),
659 reasoning: entry.reasoning,
660 input: entry
661 .input_modalities
662 .iter()
663 .map(|m| match m.as_str() {
664 "image" => InputModality::Image,
665 _ => InputModality::Text,
666 })
667 .collect(),
668 context_window: entry.context_window,
669 max_tokens: entry.max_tokens,
670 cost_input: entry.cost_input,
671 cost_output: entry.cost_output,
672 cost_cache_read: entry.cost_cache_read,
673 cost_cache_write: entry.cost_cache_write,
674 }
675 }
676}
677
678#[derive(Debug, Clone, Serialize, Deserialize)]
680pub struct EngineConfigResponse {
681 pub default_model: String,
683 pub api_key_set: bool,
685 pub api_key_source: Option<String>,
687 pub provider: Option<String>,
689 pub routing: RoutingConfigSnapshot,
691}
692
693#[derive(Debug, Clone, Serialize, Deserialize)]
695pub struct ValidateKeyResult {
696 pub valid: bool,
698 pub provider: String,
700 pub message: Option<String>,
702}
703
704pub struct EngineApi {
716 config: Arc<RwLock<OxiosConfig>>,
717 config_path: PathBuf,
718 routing_stats: Arc<RoutingStats>,
719 engine_handle: Arc<crate::engine::EngineHandle>,
721}
722
723impl EngineApi {
724 pub fn new(
731 config: Arc<RwLock<OxiosConfig>>,
732 config_path: PathBuf,
733 routing_stats: Arc<RoutingStats>,
734 engine_handle: Arc<crate::engine::EngineHandle>,
735 ) -> Self {
736 Self {
737 config,
738 config_path,
739 routing_stats,
740 engine_handle,
741 }
742 }
743
744 pub fn routing_stats(&self) -> Arc<RoutingStats> {
746 Arc::clone(&self.routing_stats)
747 }
748
749 pub fn engine_handle(&self) -> &Arc<crate::engine::EngineHandle> {
751 &self.engine_handle
752 }
753
754 pub fn providers(&self) -> Vec<ProviderInfo> {
770 let catalog = self.engine_handle.get().oxi().catalog().clone();
771 let use_catalog = catalog.model_count_sync() > 0;
772 let all: Vec<String> = if use_catalog {
773 catalog.list_providers_sync()
774 } else {
775 oxi_sdk::get_providers()
776 .into_iter()
777 .map(|s| s.to_string())
778 .collect()
779 };
780
781 all.into_iter()
782 .filter(|p| provider_meta(p).map(|m| !m.hidden).unwrap_or(true))
783 .map(|p| {
784 let model_count = if use_catalog {
785 catalog.list_models_sync(&p).len()
786 } else {
787 oxi_sdk::get_provider_models(&p).len()
788 };
789 let has_key = CredentialStore::has_credential(
790 &p,
791 self.config
792 .read()
793 .engine
794 .api_key
795 .as_deref()
796 .filter(|k| !k.is_empty()),
797 );
798 let meta = provider_meta(&p);
799 ProviderInfo {
800 id: p.clone(),
801 name: provider_display_name(&p),
802 category: provider_category(&p),
803 model_count,
804 has_key,
805 description: meta.map(|m| m.description.to_string()).unwrap_or_default(),
806 env_key: meta.map(|m| m.env_key.to_string()).unwrap_or_default(),
807 }
808 })
809 .collect()
810 }
811
812 pub fn models(&self, provider: &str, query: Option<&str>) -> Vec<ModelInfo> {
818 let catalog = self.engine_handle.get().oxi().catalog().clone();
819 let live = catalog.list_models_sync(provider);
820 let models: Vec<ModelInfo> = if !live.is_empty() {
821 live.iter().map(ModelInfo::from).collect()
822 } else {
823 oxi_sdk::get_provider_models(provider)
824 .iter()
825 .map(ModelInfo::from)
826 .collect()
827 };
828 models
829 .into_iter()
830 .filter(|m| !m.name.contains("latest"))
831 .filter(|m| {
832 if let Some(q) = query {
833 let q = q.to_lowercase();
834 m.name.to_lowercase().contains(&q)
835 || m.id.to_lowercase().contains(&q)
836 || m.provider.to_lowercase().contains(&q)
837 } else {
838 true
839 }
840 })
841 .collect()
842 }
843
844 pub fn search_models(&self, query: &str) -> Vec<ModelInfo> {
849 let catalog = self.engine_handle.get().oxi().catalog().clone();
850 let live = catalog.search_sync(query);
851 if !live.is_empty() {
852 live.iter().map(ModelInfo::from).collect()
853 } else {
854 oxi_sdk::search_models(query)
855 .into_iter()
856 .map(ModelInfo::from)
857 .collect()
858 }
859 }
860
861 pub fn config(&self) -> EngineConfigResponse {
863 let cfg = self.config.read();
864 let provider =
865 CredentialStore::provider_from_model(&cfg.engine.default_model).map(|s| s.to_string());
866 let api_key_source = provider.as_deref().and_then(|p| {
867 CredentialStore::resolve(p, cfg.api_key().as_deref()).map(|(_, src)| {
868 match src {
869 crate::credential::CredentialSource::EnvVar => "env",
870 crate::credential::CredentialSource::Config => "config",
871 crate::credential::CredentialSource::OxiAuthStore => "auth_store",
872 }
873 .to_string()
874 })
875 });
876 let api_key_set = provider
877 .as_deref()
878 .map(|p| CredentialStore::has_credential(p, cfg.api_key().as_deref()))
879 .unwrap_or(false);
880
881 EngineConfigResponse {
882 default_model: cfg.engine.default_model.clone(),
883 api_key_set,
884 api_key_source,
885 provider,
886 routing: RoutingConfigSnapshot {
887 routing_enabled: cfg.engine.routing_enabled,
888 prefer_cost_efficient: cfg.engine.prefer_cost_efficient,
889 fallback_models: cfg.engine.fallback_models.clone(),
890 excluded_models: cfg.engine.excluded_models.clone(),
891 },
892 }
893 }
894
895 pub fn routing_stats_snapshot(&self) -> RoutingStatsSnapshot {
897 self.routing_stats.snapshot()
898 }
899
900 pub fn fallback_history(&self, limit: usize) -> Vec<FallbackEvent> {
902 self.routing_stats.fallback_history(limit)
903 }
904
905 pub fn set_model(&self, model_id: &str) -> anyhow::Result<()> {
912 {
918 let engine = self.engine_handle.get();
919 let model = engine
920 .resolve_model(model_id)
921 .with_context(|| format!("Unknown model '{model_id}'"))?;
922 engine.create_provider(&model.provider).with_context(|| {
923 format!(
924 "Provider '{}' is not configured for '{model_id}'",
925 model.provider
926 )
927 })?;
928 }
929 {
930 let mut cfg = self.config.write();
931 cfg.engine.default_model = model_id.to_string();
932 self.persist(&cfg)?;
933 }
934 tracing::info!(model = %model_id, "Default model updated in config");
935 self.rebuild_and_swap();
936 Ok(())
937 }
938
939 pub fn set_api_key(&self, provider: &str, key: &str) -> anyhow::Result<()> {
945 CredentialStore::store(provider, key)?;
946
947 let cfg = self.config.read();
949 if let Some(current_provider) =
950 CredentialStore::provider_from_model(&cfg.engine.default_model)
951 && current_provider == provider
952 {
953 drop(cfg);
954 let mut cfg = self.config.write();
955 cfg.engine.api_key = Some(key.to_string());
956 self.persist(&cfg)?;
957 }
958 tracing::info!(provider = %provider, "API key stored");
959 self.rebuild_and_swap();
960 Ok(())
961 }
962
963 pub fn set_provider_options(&self, opts: &oxi_sdk::ProviderOptions) -> anyhow::Result<()> {
968 {
969 let mut cfg = self.config.write();
970 cfg.engine.provider_options = Some(opts.clone());
971 self.persist(&cfg)?;
972 }
973 tracing::info!("Provider options updated and persisted");
974 Ok(())
977 }
978
979 pub fn set_routing(&self, update: RoutingUpdate) -> anyhow::Result<()> {
984 {
985 let mut cfg = self.config.write();
986 if let Some(v) = update.routing_enabled {
987 cfg.engine.routing_enabled = v;
988 }
989 if let Some(v) = update.prefer_cost_efficient {
990 cfg.engine.prefer_cost_efficient = v;
991 }
992 if let Some(v) = update.fallback_models {
993 cfg.engine.fallback_models = v;
994 }
995 if let Some(v) = update.excluded_models {
996 cfg.engine.excluded_models = v;
997 }
998 self.persist(&cfg)?;
999 }
1000 tracing::info!("Routing configuration updated via API");
1001 self.rebuild_and_swap();
1002 Ok(())
1003 }
1004
1005 pub fn validate_key(&self, provider: &str, api_key: &str) -> ValidateKeyResult {
1010 let result = self.try_validate(provider, api_key);
1012 match result {
1013 Ok(()) => ValidateKeyResult {
1014 valid: true,
1015 provider: provider.to_string(),
1016 message: Some("API key is valid".to_string()),
1017 },
1018 Err(e) => ValidateKeyResult {
1019 valid: false,
1020 provider: provider.to_string(),
1021 message: Some(format!("Validation failed: {e}")),
1022 },
1023 }
1024 }
1025
1026 fn try_validate(&self, provider: &str, api_key: &str) -> anyhow::Result<()> {
1028 let builder = oxi_sdk::OxiBuilder::new()
1030 .with_builtins()
1031 .api_key(provider, api_key);
1032 let oxi = builder.build();
1033
1034 let models = oxi_sdk::get_provider_models(provider);
1036 if models.is_empty() {
1037 anyhow::bail!("No models found for provider '{provider}'");
1038 }
1039
1040 let model_id = format!("{}/{}", provider, models[0].id);
1041 let _model = oxi.resolve_model(&model_id)?;
1042
1043 let _provider = oxi.create_provider(provider)?;
1045
1046 if api_key.is_empty() {
1050 anyhow::bail!("API key is empty");
1051 }
1052 if api_key.len() < 8 {
1053 anyhow::bail!("API key appears too short");
1054 }
1055
1056 tracing::debug!(
1057 provider = %provider,
1058 model = %model_id,
1059 "Key validation: provider resolved with injected key"
1060 );
1061 Ok(())
1062 }
1063
1064 pub fn estimate_cost(model_id: &str, input_tokens: u64, output_tokens: u64) -> f64 {
1068 estimate_cost(model_id, input_tokens, output_tokens)
1069 }
1070
1071 fn persist(&self, config: &OxiosConfig) -> anyhow::Result<()> {
1073 let content = toml::to_string_pretty(config)
1074 .map_err(|e| anyhow::anyhow!("Failed to serialize config: {e}"))?;
1075 std::fs::write(&self.config_path, content)?;
1076 Ok(())
1077 }
1078
1079 fn rebuild_and_swap(&self) {
1086 let cfg = self.config.read();
1087 let model_id = &cfg.engine.default_model;
1088 let catalog = self.engine_handle.get().oxi().catalog().clone();
1090 let new_engine = crate::engine::OxiosEngine::from_config_with_catalog(
1091 model_id,
1092 cfg.api_key().as_deref(),
1093 catalog,
1094 );
1095 drop(cfg);
1096 self.engine_handle.swap(new_engine);
1097 }
1098}
1099
1100impl std::fmt::Debug for EngineApi {
1101 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1102 f.debug_struct("EngineApi")
1103 .field("config_path", &self.config_path)
1104 .finish()
1105 }
1106}
1107
1108pub fn record_usage_to_stats(
1111 stats: &Option<Arc<RoutingStats>>,
1112 model_id: &str,
1113 input_tokens: u64,
1114 output_tokens: u64,
1115) {
1116 if let Some(s) = stats {
1117 let cost = estimate_cost(model_id, input_tokens, output_tokens);
1118 s.record_model_usage(model_id, cost);
1119 }
1120}
1121
/// Unit tests for provider metadata, DTO serialization, routing statistics,
/// and `EngineApi::set_model` validation.
#[cfg(test)]
mod tests {
    use super::*;

    /// Known ids map to the category recorded in `PROVIDER_META`,
    /// including hidden entries such as `minimax-cn` and `xiaomi`.
    #[test]
    fn test_provider_category_known() {
        assert_eq!(provider_category("anthropic"), ProviderCategory::Major);
        assert_eq!(provider_category("openai"), ProviderCategory::Major);
        assert_eq!(provider_category("google"), ProviderCategory::Major);
        assert_eq!(provider_category("groq"), ProviderCategory::Open);
        assert_eq!(provider_category("opencode"), ProviderCategory::Open);
        assert_eq!(provider_category("minimax"), ProviderCategory::Regional);
        assert_eq!(provider_category("moonshotai"), ProviderCategory::Regional);
        assert_eq!(provider_category("kimi-coding"), ProviderCategory::Regional);
        assert_eq!(provider_category("zai"), ProviderCategory::Regional);
        assert_eq!(provider_category("minimax-cn"), ProviderCategory::Regional);
        assert_eq!(provider_category("xiaomi"), ProviderCategory::Regional);
    }

    /// Unknown and empty ids default to `Open`.
    #[test]
    fn test_provider_category_fallback() {
        assert_eq!(
            provider_category("not-a-real-provider"),
            ProviderCategory::Open
        );
        assert_eq!(provider_category(""), ProviderCategory::Open);
    }

    /// Known ids use the curated display name from the table.
    #[test]
    fn test_provider_display_name_known() {
        assert_eq!(provider_display_name("anthropic"), "Anthropic");
        assert_eq!(provider_display_name("minimax"), "MiniMax");
        assert_eq!(provider_display_name("moonshotai"), "Moonshot AI (Kimi)");
        assert_eq!(provider_display_name("kimi-coding"), "Kimi Coding");
        assert_eq!(provider_display_name("zai"), "Z.AI (GLM)");
        assert_eq!(provider_display_name("opencode"), "OpenCode");
        assert_eq!(provider_display_name("amazon-bedrock"), "Amazon Bedrock");
    }

    /// Unknown ids are title-cased by `fallback_display_name`.
    #[test]
    fn test_provider_display_name_fallback() {
        assert_eq!(
            provider_display_name("some-new-provider"),
            "Some New Provider"
        );
        // NOTE(review): "kimi-coding" is in PROVIDER_META, so this line checks
        // the table entry rather than the fallback formatter.
        assert_eq!(provider_display_name("kimi-coding"), "Kimi Coding");
        assert_eq!(provider_display_name("some_id"), "Some Id");
        assert_eq!(provider_display_name(""), "");
    }

    /// An alias resolves to the same entry as its canonical id.
    #[test]
    fn test_provider_meta_lookup_by_alias() {
        let by_id = provider_meta("github-copilot").unwrap();
        let by_alias = provider_meta("copilot").unwrap();
        assert_eq!(by_id.id, by_alias.id);

        let bedrock_id = provider_meta("amazon-bedrock").unwrap();
        let bedrock_alias = provider_meta("aws-bedrock").unwrap();
        let bedrock_canonical = provider_meta("bedrock").unwrap();
        assert_eq!(bedrock_id.id, bedrock_alias.id);
        assert_eq!(bedrock_id.id, bedrock_canonical.id);
    }

    /// Unknown and empty ids have no metadata.
    #[test]
    fn test_provider_meta_unknown_is_none() {
        assert!(provider_meta("not-a-real-provider").is_none());
        assert!(provider_meta("").is_none());
    }

    /// `ProviderInfo` serializes with camelCase keys and round-trips.
    #[test]
    fn test_provider_info_serialization() {
        let info = ProviderInfo {
            id: "anthropic".to_string(),
            name: "Anthropic".to_string(),
            category: ProviderCategory::Major,
            model_count: 15,
            has_key: true,
            description: "Claude models with extended thinking".to_string(),
            env_key: "ANTHROPIC_API_KEY".to_string(),
        };
        let json = serde_json::to_string(&info).unwrap();
        assert!(json.contains("\"modelCount\":15"));
        assert!(json.contains("\"hasKey\":true"));
        assert!(json.contains("\"envKey\":\"ANTHROPIC_API_KEY\""));
        let restored: ProviderInfo = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.id, "anthropic");
        assert_eq!(restored.name, "Anthropic");
        assert_eq!(restored.model_count, 15);
        assert!(restored.has_key);
        assert_eq!(restored.env_key, "ANTHROPIC_API_KEY");
    }

    /// `description` and `envKey` may be omitted thanks to `#[serde(default)]`.
    #[test]
    fn test_provider_info_serialization_missing_optional() {
        let json = r#"{
            "id": "anthropic",
            "name": "Anthropic",
            "category": "major",
            "modelCount": 15,
            "hasKey": true
        }"#;
        let info: ProviderInfo = serde_json::from_str(json).unwrap();
        assert_eq!(info.id, "anthropic");
        assert_eq!(info.description, "");
        assert_eq!(info.env_key, "");
    }

    /// `ModelInfo` round-trips through JSON, including modalities.
    #[test]
    fn test_model_info_serialization() {
        let info = ModelInfo {
            id: "anthropic/claude-sonnet-4".to_string(),
            name: "Claude Sonnet 4".to_string(),
            api: "anthropic-messages".to_string(),
            provider: "anthropic".to_string(),
            reasoning: true,
            input: vec![InputModality::Text, InputModality::Image],
            context_window: 200000,
            max_tokens: 16000,
            cost_input: 3.0,
            cost_output: 15.0,
            cost_cache_read: 0.3,
            cost_cache_write: 3.75,
        };
        let json = serde_json::to_string(&info).unwrap();
        let restored: ModelInfo = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.id, "anthropic/claude-sonnet-4");
        assert!(restored.reasoning);
        assert_eq!(restored.context_window, 200000);
        assert!(restored.input.contains(&InputModality::Image));
        assert_eq!(restored.api, "anthropic-messages");
    }

    /// `EngineConfigResponse` round-trips through JSON.
    #[test]
    fn test_engine_config_response_serialization() {
        // NOTE(review): `EngineApi::config` emits "env" / "config" / "auth_store"
        // as the source; "config.toml" here only exercises the round-trip.
        let resp = EngineConfigResponse {
            default_model: "anthropic/claude-sonnet-4".to_string(),
            api_key_set: true,
            api_key_source: Some("config.toml".to_string()),
            provider: Some("anthropic".to_string()),
            routing: RoutingConfigSnapshot {
                routing_enabled: false,
                prefer_cost_efficient: false,
                fallback_models: vec![],
                excluded_models: vec![],
            },
        };
        let json = serde_json::to_string(&resp).unwrap();
        let restored: EngineConfigResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.default_model, "anthropic/claude-sonnet-4");
        assert!(restored.api_key_set);
        assert_eq!(restored.api_key_source.as_deref(), Some("config.toml"));
        assert!(!restored.routing.routing_enabled);
    }

    /// `ValidateKeyResult` round-trips through JSON.
    #[test]
    fn test_validate_key_result_serialization() {
        let result = ValidateKeyResult {
            valid: true,
            provider: "openai".to_string(),
            message: Some("API key is valid".to_string()),
        };
        let json = serde_json::to_string(&result).unwrap();
        let restored: ValidateKeyResult = serde_json::from_str(&json).unwrap();
        assert!(restored.valid);
        assert_eq!(restored.provider, "openai");
    }

    /// Field-level sanity check on a hand-built failure result.
    #[test]
    fn test_validate_key_result_invalid() {
        let result = ValidateKeyResult {
            valid: false,
            provider: "anthropic".to_string(),
            message: Some("Validation failed: key too short".to_string()),
        };
        assert!(!result.valid);
        assert!(result.message.as_ref().unwrap().contains("failed"));
    }

    /// Calls and costs accumulate per model, and totals sum across models.
    #[test]
    fn test_routing_stats_snapshot() {
        let stats = RoutingStats::new();
        stats.record_model_usage("anthropic/claude-sonnet-4", 0.05);
        stats.record_model_usage("anthropic/claude-sonnet-4", 0.03);
        stats.record_model_usage("openai/gpt-4o-mini", 0.01);

        let snap = stats.snapshot();
        assert_eq!(snap.total_requests, 3);
        assert_eq!(snap.model_calls["anthropic/claude-sonnet-4"], 2);
        assert_eq!(snap.model_calls["openai/gpt-4o-mini"], 1);
        assert!((snap.total_cost - 0.09).abs() < 0.001);
    }

    /// After 210 events only the newest 200 remain, returned newest first.
    #[test]
    fn test_fallback_history_circular() {
        let stats = RoutingStats::new();
        for i in 0..210 {
            stats.record_fallback(FallbackEvent {
                timestamp: DateTime::from_timestamp(i as i64, 0).unwrap(),
                from_model: format!("model-{}", i),
                to_model: "fallback".to_string(),
                reason: "test".to_string(),
                success: true,
            });
        }
        let history = stats.fallback_history(200);
        assert_eq!(history.len(), 200);
        assert_eq!(history[0].from_model, "model-209");
        assert_eq!(history[199].from_model, "model-10");
    }

    /// An unknown model is rejected before the config is mutated or persisted.
    /// The path is never written, because validation fails first.
    #[test]
    fn set_model_rejects_unknown_model_before_persist() {
        use crate::engine::{EngineHandle, OxiosEngine};

        let engine = Arc::new(OxiosEngine::new("anthropic/claude-sonnet-4-20250514"));
        let handle = Arc::new(EngineHandle::new(engine));
        let config = Arc::new(parking_lot::RwLock::new(OxiosConfig::default()));
        let path = PathBuf::from("/tmp/oxios-set-model-test-NONEXISTENT.toml");
        let api = EngineApi::new(config, path, Arc::new(RoutingStats::new()), handle);

        let before = api.config.read().engine.default_model.clone();
        let err = api.set_model("zai-coding-plan/glm-5-turbo").unwrap_err();
        assert!(
            err.to_string().contains("Unknown model"),
            "expected unknown-model error, got: {err}"
        );
        assert_eq!(api.config.read().engine.default_model, before);
    }

    /// A built-in model is never reported as unknown. Success depends on
    /// whether the environment can build its provider, so either outcome is
    /// accepted as long as the error is not an unknown-model error.
    #[test]
    fn set_model_accepts_known_builtin_model() {
        use crate::engine::{EngineHandle, OxiosEngine};

        let engine = Arc::new(OxiosEngine::new("anthropic/claude-sonnet-4-20250514"));
        let handle = Arc::new(EngineHandle::new(engine));
        let config = Arc::new(parking_lot::RwLock::new(OxiosConfig::default()));
        let tmp =
            std::env::temp_dir().join(format!("oxios-set-model-ok-{}.toml", std::process::id()));
        let api = EngineApi::new(config, tmp.clone(), Arc::new(RoutingStats::new()), handle);

        let result = api.set_model("openai/gpt-4o");
        match result {
            Ok(()) => assert_eq!(api.config.read().engine.default_model, "openai/gpt-4o"),
            Err(e) => assert!(
                !e.to_string().contains("Unknown model"),
                "known model rejected as unknown: {e}"
            ),
        }
        // Best-effort cleanup of the config file written on success.
        let _ = std::fs::remove_file(&tmp);
    }
}