1use serde::{Deserialize, Serialize};
2use std::cell::RefCell;
3use std::collections::BTreeMap;
4use std::sync::OnceLock;
5
/// Process-wide providers configuration, initialized once by `load_config`.
static CONFIG: OnceLock<ProvidersConfig> = OnceLock::new();
/// Path of the TOML file `CONFIG` was loaded from; it stays unset when the
/// built-in `default_config()` is used instead.
static CONFIG_PATH: OnceLock<String> = OnceLock::new();

thread_local! {
    /// Per-thread overlay that `effective_config` merges on top of `CONFIG`.
    /// Replaced through `set_user_overrides` / `clear_user_overrides`.
    static USER_OVERRIDES: RefCell<Option<ProvidersConfig>> = const { RefCell::new(None) };
}
16
/// Complete providers configuration: provider endpoints, model aliases, the
/// model catalog, and the rules used to infer a model's provider and tier.
///
/// Every field is `#[serde(default)]`, so a TOML file may specify any subset.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct ProvidersConfig {
    /// Provider definitions keyed by provider name (e.g. `"anthropic"`).
    #[serde(default)]
    pub providers: BTreeMap<String, ProviderDef>,
    /// Alias name -> concrete model id and provider.
    #[serde(default)]
    pub aliases: BTreeMap<String, AliasDef>,
    /// Model catalog keyed by model id.
    #[serde(default)]
    pub models: BTreeMap<String, ModelDef>,
    /// Provider name -> default QC model id; `qc_default_model` looks entries
    /// up by the lower-cased provider name.
    #[serde(default)]
    pub qc_defaults: BTreeMap<String, String>,
    /// Ordered provider-inference rules; the first matching rule wins.
    #[serde(default)]
    pub inference_rules: Vec<InferenceRule>,
    /// Ordered tier-assignment rules; the first matching rule wins.
    #[serde(default)]
    pub tier_rules: Vec<TierRule>,
    /// Tier used when no tier rule or built-in heuristic matches.
    #[serde(default)]
    pub tier_defaults: TierDefaults,
    /// Glob pattern -> default parameters collected by `model_params`.
    #[serde(default)]
    pub model_defaults: BTreeMap<String, BTreeMap<String, toml::Value>>,
}
36
37impl ProvidersConfig {
38 pub fn is_empty(&self) -> bool {
39 self.providers.is_empty()
40 && self.aliases.is_empty()
41 && self.models.is_empty()
42 && self.qc_defaults.is_empty()
43 && self.inference_rules.is_empty()
44 && self.tier_rules.is_empty()
45 && self.model_defaults.is_empty()
46 && self.tier_defaults.default == default_mid()
47 }
48
49 pub fn merge_from(&mut self, overlay: &ProvidersConfig) {
50 self.providers.extend(overlay.providers.clone());
51 self.aliases.extend(overlay.aliases.clone());
52 self.models.extend(overlay.models.clone());
53 self.qc_defaults.extend(overlay.qc_defaults.clone());
54
55 if !overlay.inference_rules.is_empty() {
56 let mut merged = overlay.inference_rules.clone();
57 merged.extend(self.inference_rules.clone());
58 self.inference_rules = merged;
59 }
60
61 if !overlay.tier_rules.is_empty() {
62 let mut merged = overlay.tier_rules.clone();
63 merged.extend(self.tier_rules.clone());
64 self.tier_rules = merged;
65 }
66
67 if overlay.tier_defaults.default != default_mid() {
68 self.tier_defaults = overlay.tier_defaults.clone();
69 }
70
71 for (pattern, defaults) in &overlay.model_defaults {
72 self.model_defaults
73 .entry(pattern.clone())
74 .or_default()
75 .extend(defaults.clone());
76 }
77 }
78}
79
/// Connection, authentication, and cost metadata for one LLM provider.
#[derive(Debug, Clone, Deserialize)]
pub struct ProviderDef {
    /// Optional display name.
    #[serde(default)]
    pub display_name: Option<String>,
    /// Optional icon identifier. NOTE(review): its format is not visible here; confirm.
    #[serde(default)]
    pub icon: Option<String>,
    /// Base URL, used by `resolve_base_url` when `base_url_env` yields nothing.
    pub base_url: String,
    /// Name of an env var whose trimmed, non-empty value overrides `base_url`.
    #[serde(default)]
    pub base_url_env: Option<String>,
    /// Authentication scheme; defaults to `"bearer"`. The value `"none"` makes
    /// `provider_key_available` treat the provider as needing no key.
    #[serde(default = "default_bearer")]
    pub auth_style: String,
    /// Header name, which the built-in defaults pair with `auth_style = "header"`
    /// (e.g. `x-api-key`). Presumably this header carries the key; confirm in the client.
    #[serde(default)]
    pub auth_header: Option<String>,
    /// Env var name(s) that may hold the API key.
    #[serde(default)]
    pub auth_env: AuthEnv,
    /// Additional static headers (header name -> value).
    #[serde(default)]
    pub extra_headers: BTreeMap<String, String>,
    /// Chat endpoint path. Presumably it is appended to the resolved base URL; confirm.
    #[serde(default)]
    pub chat_endpoint: String,
    /// Optional completion endpoint path, presumably relative like `chat_endpoint`.
    #[serde(default)]
    pub completion_endpoint: Option<String>,
    /// Optional request description used for health checks.
    #[serde(default)]
    pub healthcheck: Option<HealthcheckDef>,
    /// Feature flags. The flag `"native_tools"` makes `default_tool_format`
    /// choose `"native"`.
    #[serde(default)]
    pub features: Vec<String>,
    /// NOTE(review): not read in this chunk; presumably names a fallback provider.
    #[serde(default)]
    pub fallback: Option<String>,
    /// NOTE(review): not read in this chunk; presumably the retry count for failed requests.
    #[serde(default)]
    pub retry_count: Option<u32>,
    /// NOTE(review): not read in this chunk; presumably the delay between retries, in ms.
    #[serde(default)]
    pub retry_delay_ms: Option<u64>,
    /// NOTE(review): not read in this chunk; presumably a requests-per-minute limit.
    #[serde(default)]
    pub rpm: Option<u32>,
    /// Price per 1k input tokens. It is returned by `provider_economics` and is
    /// the fallback in `pricing_per_1k_for`.
    #[serde(default)]
    pub cost_per_1k_in: Option<f64>,
    /// Price per 1k output tokens. It is returned by `provider_economics` and is
    /// the fallback in `pricing_per_1k_for`.
    #[serde(default)]
    pub cost_per_1k_out: Option<f64>,
    /// p50 latency in milliseconds, returned by `provider_economics`.
    #[serde(default)]
    pub latency_p50_ms: Option<u64>,
}
127
128impl Default for ProviderDef {
129 fn default() -> Self {
130 Self {
131 display_name: None,
132 icon: None,
133 base_url: String::new(),
134 base_url_env: None,
135 auth_style: default_bearer(),
136 auth_header: None,
137 auth_env: AuthEnv::None,
138 extra_headers: BTreeMap::new(),
139 chat_endpoint: String::new(),
140 completion_endpoint: None,
141 healthcheck: None,
142 features: Vec::new(),
143 fallback: None,
144 retry_count: None,
145 retry_delay_ms: None,
146 rpm: None,
147 cost_per_1k_in: None,
148 cost_per_1k_out: None,
149 latency_p50_ms: None,
150 }
151 }
152}
153
/// Serde default for `ProviderDef::auth_style`, which selects bearer-token authentication.
fn default_bearer() -> String {
    String::from("bearer")
}
157
/// Environment variable name(s) that may hold a provider's API key.
///
/// Deserialized untagged: a TOML string becomes `Single` and an array of
/// strings becomes `Multiple`. Omitting the field (see `ProviderDef::auth_env`)
/// yields `None`.
#[derive(Debug, Clone, Deserialize, Default)]
#[serde(untagged)]
pub enum AuthEnv {
    /// No key variable is configured.
    #[default]
    None,
    /// A single environment variable name.
    Single(String),
    /// Several candidate names. `provider_key_available` accepts any of them
    /// holding a non-blank value.
    Multiple(Vec<String>),
}
168
/// Describes a lightweight request for probing a provider.
/// NOTE(review): not consumed in this chunk; confirm how the client uses it.
#[derive(Debug, Clone, Deserialize)]
pub struct HealthcheckDef {
    /// HTTP method, e.g. `"GET"` or `"POST"` in the built-in defaults.
    pub method: String,
    /// Request path. Presumably it is relative to the provider's base URL.
    #[serde(default)]
    pub path: Option<String>,
    /// Absolute URL; the built-in Hugging Face entry uses this instead of `path`.
    #[serde(default)]
    pub url: Option<String>,
    /// Optional request body; the built-in Anthropic entry uses a JSON body.
    #[serde(default)]
    pub body: Option<String>,
}
179
/// A named shortcut for a concrete model on a specific provider.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct AliasDef {
    /// Concrete model id the alias resolves to.
    pub id: String,
    /// Name of the provider serving `id`.
    pub provider: String,
    /// Optional tool-call format override. When unset, the format is derived
    /// from the provider's features (`"native"` or `"text"`).
    #[serde(default)]
    pub tool_format: Option<String>,
}
191
/// Catalog pricing for a model, per million tokens.
/// `pricing_per_1k_for` divides these by 1000 to get per-1k prices.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct ModelPricing {
    /// Price per million input tokens.
    pub input_per_mtok: f64,
    /// Price per million output tokens.
    pub output_per_mtok: f64,
    /// Optional cache-read price per million tokens (not read in this chunk).
    #[serde(default)]
    pub cache_read_per_mtok: Option<f64>,
    /// Optional cache-write price per million tokens (not read in this chunk).
    #[serde(default)]
    pub cache_write_per_mtok: Option<f64>,
}
201
/// One entry of the model catalog (`ProvidersConfig::models`).
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct ModelDef {
    /// Model name. NOTE(review): presumably a display name; confirm.
    pub name: String,
    /// Provider serving the model; `model_catalog_entries` sorts by it.
    pub provider: String,
    /// Context window size, presumably in tokens; confirm.
    pub context_window: u64,
    /// Optional streaming timeout. Units are not visible here, presumably seconds; confirm.
    #[serde(default)]
    pub stream_timeout: Option<f64>,
    /// Capability tags.
    #[serde(default)]
    pub capabilities: Vec<String>,
    /// Optional pricing, read by `model_pricing_per_mtok`.
    #[serde(default)]
    pub pricing: Option<ModelPricing>,
}
214
/// A model selector resolved to concrete routing data by `resolve_model_info`.
#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
pub struct ResolvedModel {
    /// Concrete model id (routing prefixes stripped for non-alias selectors).
    pub id: String,
    /// Provider name.
    pub provider: String,
    /// The alias name the selector matched, if any.
    pub alias: Option<String>,
    /// Tool-call format: the alias override, or `"native"`/`"text"` from provider features.
    pub tool_format: String,
    /// Tier name assigned to `id`.
    pub tier: String,
}
223
/// Maps model ids to a provider.
///
/// A rule matches when any of its set matchers hits. Rules are tried in order
/// and the first match wins.
#[derive(Debug, Clone, Deserialize)]
pub struct InferenceRule {
    /// Glob pattern in which `*` is a wildcard, matched with `glob_match`.
    #[serde(default)]
    pub pattern: Option<String>,
    /// Substring that must occur in the model id.
    #[serde(default)]
    pub contains: Option<String>,
    /// Exact model id.
    #[serde(default)]
    pub exact: Option<String>,
    /// Provider returned when the rule matches.
    pub provider: String,
}
234
/// Maps model ids to a tier name.
///
/// A rule matches when any of its set matchers hits. Rules are tried in order
/// and the first match wins.
#[derive(Debug, Clone, Deserialize)]
pub struct TierRule {
    /// Glob pattern in which `*` is a wildcard, matched with `glob_match`.
    #[serde(default)]
    pub pattern: Option<String>,
    /// Substring that must occur in the model id.
    #[serde(default)]
    pub contains: Option<String>,
    /// Exact model id.
    #[serde(default)]
    pub exact: Option<String>,
    /// Tier returned when the rule matches (e.g. `"small"`, `"mid"`, `"frontier"`).
    pub tier: String,
}
245
/// Fallback tier settings.
#[derive(Debug, Clone, Deserialize)]
pub struct TierDefaults {
    /// Tier assigned when no tier rule or built-in heuristic matches; defaults to `"mid"`.
    #[serde(default = "default_mid")]
    pub default: String,
}
251
252impl Default for TierDefaults {
253 fn default() -> Self {
254 Self {
255 default: default_mid(),
256 }
257 }
258}
259
/// Serde default for `TierDefaults::default`: the `"mid"` tier.
fn default_mid() -> String {
    String::from("mid")
}
263
264pub fn load_config() -> &'static ProvidersConfig {
266 CONFIG.get_or_init(|| {
267 let verbose_config_logging = matches!(
268 std::env::var("HARN_VERBOSE_CONFIG").ok().as_deref(),
269 Some("1" | "true" | "TRUE" | "yes" | "YES")
270 ) || matches!(
271 std::env::var("HARN_ACP_VERBOSE").ok().as_deref(),
272 Some("1" | "true" | "TRUE" | "yes" | "YES")
273 );
274 if let Ok(path) = std::env::var("HARN_PROVIDERS_CONFIG") {
275 match std::fs::read_to_string(&path) {
276 Ok(content) => match toml::from_str::<ProvidersConfig>(&content) {
277 Ok(config) => {
278 if verbose_config_logging {
279 eprintln!(
280 "[llm_config] Loaded {} providers, {} aliases from {}",
281 config.providers.len(),
282 config.aliases.len(),
283 path
284 );
285 }
286 let _ = CONFIG_PATH.set(path);
287 return config;
288 }
289 Err(e) => eprintln!("[llm_config] TOML parse error in {}: {}", path, e),
290 },
291 Err(e) => eprintln!("[llm_config] Cannot read {}: {}", path, e),
292 }
293 }
294 if let Some(home) = dirs_or_home() {
295 let path = format!("{home}/.config/harn/providers.toml");
296 if let Ok(content) = std::fs::read_to_string(&path) {
297 if let Ok(config) = toml::from_str::<ProvidersConfig>(&content) {
298 let _ = CONFIG_PATH.set(path);
299 return config;
300 }
301 }
302 }
303 default_config()
304 })
305}
306
307pub fn loaded_config_path() -> Option<std::path::PathBuf> {
310 let _ = load_config();
312 CONFIG_PATH.get().map(std::path::PathBuf::from)
313}
314
315pub fn set_user_overrides(config: Option<ProvidersConfig>) {
319 USER_OVERRIDES.with(|cell| *cell.borrow_mut() = config);
320}
321
322pub fn clear_user_overrides() {
324 set_user_overrides(None);
325}
326
327fn effective_config() -> ProvidersConfig {
328 let mut merged = load_config().clone();
329 USER_OVERRIDES.with(|cell| {
330 if let Some(overlay) = cell.borrow().as_ref() {
331 merged.merge_from(overlay);
332 }
333 });
334 merged
335}
336
337pub fn resolve_model(alias: &str) -> (String, Option<String>) {
339 let config = effective_config();
340 if let Some(a) = config.aliases.get(alias) {
341 return (a.id.clone(), Some(a.provider.clone()));
342 }
343 (normalize_model_id(alias), None)
344}
345
/// Strips one routing prefix (`ollama:`, `local:`, `huggingface:`, `hf:`) from `raw`.
///
/// Only the first matching prefix, checked in that order, is removed. Ids
/// without a routing prefix are returned unchanged.
pub fn normalize_model_id(raw: &str) -> String {
    const ROUTING_PREFIXES: [&str; 4] = ["ollama:", "local:", "huggingface:", "hf:"];
    ROUTING_PREFIXES
        .iter()
        .find_map(|prefix| raw.strip_prefix(*prefix))
        .unwrap_or(raw)
        .to_owned()
}
358
359pub fn resolve_model_info(selector: &str) -> ResolvedModel {
362 let config = effective_config();
363 if let Some(alias) = config.aliases.get(selector) {
364 let id = alias.id.clone();
365 let provider = alias.provider.clone();
366 let tool_format = alias
367 .tool_format
368 .clone()
369 .unwrap_or_else(|| default_tool_format_with_config(&config, &id, &provider));
370 return ResolvedModel {
371 tier: model_tier_with_config(&config, &id),
372 id,
373 provider,
374 alias: Some(selector.to_string()),
375 tool_format,
376 };
377 }
378
379 let provider = infer_provider_with_config(&config, selector);
380 let id = normalize_model_id(selector);
381 let tool_format = default_tool_format_with_config(&config, &id, &provider);
382 let tier = model_tier_with_config(&config, &id);
383 ResolvedModel {
384 id,
385 provider,
386 alias: None,
387 tool_format,
388 tier,
389 }
390}
391
392pub fn infer_provider(model_id: &str) -> String {
394 let config = effective_config();
395 infer_provider_with_config(&config, model_id)
396}
397
398fn infer_provider_with_config(config: &ProvidersConfig, model_id: &str) -> String {
399 if model_id.starts_with("local:") {
400 return "local".to_string();
401 }
402 if model_id.starts_with("ollama:") {
403 return "ollama".to_string();
404 }
405 if model_id.starts_with("huggingface:") || model_id.starts_with("hf:") {
406 return "huggingface".to_string();
407 }
408 for rule in &config.inference_rules {
409 if let Some(exact) = &rule.exact {
410 if model_id == exact {
411 return rule.provider.clone();
412 }
413 }
414 if let Some(pattern) = &rule.pattern {
415 if glob_match(pattern, model_id) {
416 return rule.provider.clone();
417 }
418 }
419 if let Some(substr) = &rule.contains {
420 if model_id.contains(substr.as_str()) {
421 return rule.provider.clone();
422 }
423 }
424 }
425 if model_id.starts_with("claude-") {
430 return "anthropic".to_string();
431 }
432 if model_id.to_lowercase().starts_with("or-") {
433 return "openrouter".to_string();
434 }
435 if model_id.starts_with("gpt-") || model_id.starts_with("o1") || model_id.starts_with("o3") {
436 return "openai".to_string();
437 }
438 if model_id.contains('/') {
439 return "openrouter".to_string();
440 }
441 if model_id.contains(':') {
442 return "ollama".to_string();
443 }
444 "anthropic".to_string()
445}
446
447pub fn model_tier(model_id: &str) -> String {
449 let config = effective_config();
450 model_tier_with_config(&config, model_id)
451}
452
453fn model_tier_with_config(config: &ProvidersConfig, model_id: &str) -> String {
454 for rule in &config.tier_rules {
455 if let Some(exact) = &rule.exact {
456 if model_id == exact {
457 return rule.tier.clone();
458 }
459 }
460 if let Some(pattern) = &rule.pattern {
461 if glob_match(pattern, model_id) {
462 return rule.tier.clone();
463 }
464 }
465 if let Some(substr) = &rule.contains {
466 if model_id.contains(substr.as_str()) {
467 return rule.tier.clone();
468 }
469 }
470 }
471 let lower = model_id.to_lowercase();
472 if lower.contains("9b") || lower.contains("a3b") {
473 return "small".to_string();
474 }
475 if lower.starts_with("claude-") || lower == "gpt-4o" {
476 return "frontier".to_string();
477 }
478 config.tier_defaults.default.clone()
479}
480
481pub fn provider_config(name: &str) -> Option<ProviderDef> {
483 effective_config().providers.get(name).cloned()
484}
485
486pub fn model_params(model_id: &str) -> BTreeMap<String, toml::Value> {
489 let config = effective_config();
490 let mut params = BTreeMap::new();
491 for (pattern, defaults) in &config.model_defaults {
492 if glob_match(pattern, model_id) {
493 for (k, v) in defaults {
494 params.insert(k.clone(), v.clone());
495 }
496 }
497 }
498 params
499}
500
501pub fn provider_names() -> Vec<String> {
503 effective_config().providers.keys().cloned().collect()
504}
505
506pub fn known_model_names() -> Vec<String> {
508 effective_config().aliases.keys().cloned().collect()
509}
510
511pub fn alias_entries() -> Vec<(String, AliasDef)> {
512 effective_config().aliases.into_iter().collect()
513}
514
515pub fn model_catalog_entries() -> Vec<(String, ModelDef)> {
517 let mut entries: Vec<_> = effective_config().models.into_iter().collect();
518 entries.sort_by(|(id_a, model_a), (id_b, model_b)| {
519 model_a
520 .provider
521 .cmp(&model_b.provider)
522 .then_with(|| id_a.cmp(id_b))
523 });
524 entries
525}
526
527pub fn model_catalog_entry(model_id: &str) -> Option<ModelDef> {
528 effective_config().models.get(model_id).cloned()
529}
530
531pub fn qc_default_model(provider: &str) -> Option<String> {
532 std::env::var("BURIN_QC_MODEL")
533 .ok()
534 .filter(|value| !value.trim().is_empty())
535 .or_else(|| {
536 effective_config()
537 .qc_defaults
538 .get(&provider.to_lowercase())
539 .cloned()
540 })
541}
542
543pub fn qc_defaults() -> BTreeMap<String, String> {
544 effective_config().qc_defaults
545}
546
547pub fn model_pricing_per_mtok(model_id: &str) -> Option<ModelPricing> {
548 effective_config()
549 .models
550 .get(model_id)
551 .and_then(|model| model.pricing.clone())
552}
553
554pub fn pricing_per_1k_for(provider: &str, model_id: &str) -> Option<(f64, f64)> {
555 model_pricing_per_mtok(model_id)
556 .map(|pricing| {
557 (
558 pricing.input_per_mtok / 1000.0,
559 pricing.output_per_mtok / 1000.0,
560 )
561 })
562 .or_else(|| {
563 let (input, output, _) = provider_economics(provider);
564 match (input, output) {
565 (Some(input), Some(output)) => Some((input, output)),
566 _ => None,
567 }
568 })
569}
570
571pub fn auth_env_names(auth_env: &AuthEnv) -> Vec<String> {
572 match auth_env {
573 AuthEnv::None => Vec::new(),
574 AuthEnv::Single(name) => vec![name.clone()],
575 AuthEnv::Multiple(names) => names.clone(),
576 }
577}
578
579pub fn provider_key_available(provider: &str) -> bool {
580 let Some(pdef) = provider_config(provider) else {
581 return provider == "ollama";
582 };
583 if pdef.auth_style == "none" || matches!(pdef.auth_env, AuthEnv::None) {
584 return true;
585 }
586 auth_env_names(&pdef.auth_env).into_iter().any(|env_name| {
587 std::env::var(env_name)
588 .ok()
589 .is_some_and(|value| !value.trim().is_empty())
590 })
591}
592
593pub fn available_provider_names() -> Vec<String> {
594 provider_names()
595 .into_iter()
596 .filter(|provider| provider_key_available(provider))
597 .collect()
598}
599
600pub fn provider_has_feature(provider: &str, feature: &str) -> bool {
602 provider_config(provider)
603 .map(|p| p.features.iter().any(|f| f == feature))
604 .unwrap_or(false)
605}
606
607pub fn provider_economics(provider: &str) -> (Option<f64>, Option<f64>, Option<u64>) {
611 provider_config(provider)
612 .map(|p| (p.cost_per_1k_in, p.cost_per_1k_out, p.latency_p50_ms))
613 .unwrap_or((None, None, None))
614}
615
616pub fn default_tool_format(model: &str, provider: &str) -> String {
619 let config = effective_config();
620 default_tool_format_with_config(&config, model, provider)
621}
622
623fn default_tool_format_with_config(
624 config: &ProvidersConfig,
625 model: &str,
626 provider: &str,
627) -> String {
628 for (name, alias) in &config.aliases {
630 let matches = (alias.id == model && alias.provider == provider) || name == model;
631 if matches {
632 if let Some(ref fmt) = alias.tool_format {
633 return fmt.clone();
634 }
635 }
636 }
637 if config
638 .providers
639 .get(provider)
640 .map(|p| p.features.iter().any(|f| f == "native_tools"))
641 .unwrap_or(false)
642 {
643 "native".to_string()
644 } else {
645 "text".to_string()
646 }
647}
648
649pub fn resolve_tier_model(
651 target: &str,
652 preferred_provider: Option<&str>,
653) -> Option<(String, String)> {
654 let config = effective_config();
655
656 if let Some(alias) = config.aliases.get(target) {
657 return Some((alias.id.clone(), alias.provider.clone()));
658 }
659
660 let candidate_aliases = if let Some(provider) = preferred_provider {
661 vec![
662 format!("{provider}/{target}"),
663 format!("{provider}:{target}"),
664 format!("tier/{target}"),
665 target.to_string(),
666 ]
667 } else {
668 vec![format!("tier/{target}"), target.to_string()]
669 };
670
671 for alias_name in candidate_aliases {
672 if let Some(alias) = config.aliases.get(&alias_name) {
673 return Some((alias.id.clone(), alias.provider.clone()));
674 }
675 }
676
677 None
678}
679
680pub fn tier_candidates(target: &str) -> Vec<(String, String)> {
684 let config = effective_config();
685 let mut seen = std::collections::BTreeSet::new();
686 let mut candidates = Vec::new();
687
688 for alias in config.aliases.values() {
689 let pair = (alias.id.clone(), alias.provider.clone());
690 if seen.contains(&pair) {
691 continue;
692 }
693 if model_tier(&alias.id) == target {
694 seen.insert(pair.clone());
695 candidates.push(pair);
696 }
697 }
698
699 candidates.sort_by(|(model_a, provider_a), (model_b, provider_b)| {
700 provider_a
701 .cmp(provider_b)
702 .then_with(|| model_a.cmp(model_b))
703 });
704 candidates
705}
706
707pub fn all_model_candidates() -> Vec<(String, String)> {
710 let config = effective_config();
711 let mut seen = std::collections::BTreeSet::new();
712 let mut candidates = Vec::new();
713
714 for alias in config.aliases.values() {
715 let pair = (alias.id.clone(), alias.provider.clone());
716 if seen.insert(pair.clone()) {
717 candidates.push(pair);
718 }
719 }
720
721 candidates.sort_by(|(model_a, provider_a), (model_b, provider_b)| {
722 provider_a
723 .cmp(provider_b)
724 .then_with(|| model_a.cmp(model_b))
725 });
726 candidates
727}
728
/// Matches `input` against a glob `pattern`.
///
/// In the pattern, `*` stands for any (possibly empty) run of characters, and
/// every other character is literal. A pattern may contain any number of
/// wildcards (`"claude-*"`, `"*-mini"`, `"*qwen*"`, `"gpt-*-mini"`). A pattern
/// without `*` must equal `input` exactly. The literal pieces never overlap,
/// so `"ab*ba"` does not match `"aba"`.
fn glob_match(pattern: &str, input: &str) -> bool {
    let mut pieces = pattern.split('*');
    // `split` always yields at least one piece: the literal head.
    let head = pieces.next().unwrap_or("");
    let Some(mut rest) = input.strip_prefix(head) else {
        return false;
    };
    let mut tail: Vec<&str> = pieces.collect();
    let Some(last) = tail.pop() else {
        // No `*` at all, so `head` must have consumed the whole input.
        return rest.is_empty();
    };
    // Middle pieces must appear in order. Taking the leftmost occurrence of
    // each leaves the most room for the pieces that follow.
    for middle in tail {
        match rest.find(middle) {
            Some(at) => rest = &rest[at + middle.len()..],
            None => return false,
        }
    }
    rest.ends_with(last)
}
746
/// Returns the user's home directory from `$HOME`.
///
/// An unset or blank `HOME` yields `None`. A blank value would otherwise turn
/// `"{home}/.config/..."` into a path at the filesystem root.
fn dirs_or_home() -> Option<String> {
    std::env::var("HOME")
        .ok()
        .filter(|home| !home.trim().is_empty())
}
750
751pub fn resolve_base_url(pdef: &ProviderDef) -> String {
754 if let Some(env_name) = &pdef.base_url_env {
755 if let Ok(val) = std::env::var(env_name) {
756 let trimmed = val.trim().trim_matches('"').trim_matches('\'');
758 if !trimmed.is_empty() {
759 return trimmed.to_string();
760 }
761 }
762 }
763 pdef.base_url.clone()
764}
765
766fn default_config() -> ProvidersConfig {
767 let mut config = ProvidersConfig::default();
768
769 config.providers.insert(
770 "anthropic".to_string(),
771 ProviderDef {
772 base_url: "https://api.anthropic.com/v1".to_string(),
773 auth_style: "header".to_string(),
774 auth_header: Some("x-api-key".to_string()),
775 auth_env: AuthEnv::Single("ANTHROPIC_API_KEY".to_string()),
776 extra_headers: BTreeMap::from([(
777 "anthropic-version".to_string(),
778 "2023-06-01".to_string(),
779 )]),
780 chat_endpoint: "/messages".to_string(),
781 completion_endpoint: None,
782 healthcheck: Some(HealthcheckDef {
783 method: "POST".to_string(),
784 path: Some("/messages/count_tokens".to_string()),
785 url: None,
786 body: Some(
787 r#"{"model":"claude-sonnet-4-20250514","messages":[{"role":"user","content":"x"}]}"#
788 .to_string(),
789 ),
790 }),
791 features: vec!["prompt_caching".to_string(), "thinking".to_string()],
792 cost_per_1k_in: Some(0.003),
793 cost_per_1k_out: Some(0.015),
794 latency_p50_ms: Some(2500),
795 ..Default::default()
796 },
797 );
798
799 config.providers.insert(
801 "openai".to_string(),
802 ProviderDef {
803 base_url: "https://api.openai.com/v1".to_string(),
804 auth_style: "bearer".to_string(),
805 auth_env: AuthEnv::Single("OPENAI_API_KEY".to_string()),
806 chat_endpoint: "/chat/completions".to_string(),
807 completion_endpoint: Some("/completions".to_string()),
808 healthcheck: Some(HealthcheckDef {
809 method: "GET".to_string(),
810 path: Some("/models".to_string()),
811 url: None,
812 body: None,
813 }),
814 cost_per_1k_in: Some(0.0025),
815 cost_per_1k_out: Some(0.010),
816 latency_p50_ms: Some(1800),
817 ..Default::default()
818 },
819 );
820
821 config.providers.insert(
823 "openrouter".to_string(),
824 ProviderDef {
825 base_url: "https://openrouter.ai/api/v1".to_string(),
826 auth_style: "bearer".to_string(),
827 auth_env: AuthEnv::Single("OPENROUTER_API_KEY".to_string()),
828 chat_endpoint: "/chat/completions".to_string(),
829 completion_endpoint: Some("/completions".to_string()),
830 healthcheck: Some(HealthcheckDef {
831 method: "GET".to_string(),
832 path: Some("/auth/key".to_string()),
833 url: None,
834 body: None,
835 }),
836 cost_per_1k_in: Some(0.003),
837 cost_per_1k_out: Some(0.015),
838 latency_p50_ms: Some(2200),
839 ..Default::default()
840 },
841 );
842
843 config.providers.insert(
845 "huggingface".to_string(),
846 ProviderDef {
847 base_url: "https://router.huggingface.co/v1".to_string(),
848 auth_style: "bearer".to_string(),
849 auth_env: AuthEnv::Multiple(vec![
850 "HF_TOKEN".to_string(),
851 "HUGGINGFACE_API_KEY".to_string(),
852 ]),
853 chat_endpoint: "/chat/completions".to_string(),
854 completion_endpoint: Some("/completions".to_string()),
855 healthcheck: Some(HealthcheckDef {
856 method: "GET".to_string(),
857 url: Some("https://huggingface.co/api/whoami-v2".to_string()),
858 path: None,
859 body: None,
860 }),
861 cost_per_1k_in: Some(0.0002),
862 cost_per_1k_out: Some(0.0006),
863 latency_p50_ms: Some(2400),
864 ..Default::default()
865 },
866 );
867
868 config.providers.insert(
877 "ollama".to_string(),
878 ProviderDef {
879 base_url: "http://localhost:11434".to_string(),
880 base_url_env: Some("OLLAMA_HOST".to_string()),
881 auth_style: "none".to_string(),
882 chat_endpoint: "/api/chat".to_string(),
883 completion_endpoint: Some("/api/generate".to_string()),
884 healthcheck: Some(HealthcheckDef {
885 method: "GET".to_string(),
886 path: Some("/api/tags".to_string()),
887 url: None,
888 body: None,
889 }),
890 cost_per_1k_in: Some(0.0),
891 cost_per_1k_out: Some(0.0),
892 latency_p50_ms: Some(1200),
893 ..Default::default()
894 },
895 );
896
897 config.providers.insert(
899 "together".to_string(),
900 ProviderDef {
901 base_url: "https://api.together.xyz/v1".to_string(),
902 base_url_env: Some("TOGETHER_AI_BASE_URL".to_string()),
903 auth_style: "bearer".to_string(),
904 auth_env: AuthEnv::Single("TOGETHER_AI_API_KEY".to_string()),
905 chat_endpoint: "/chat/completions".to_string(),
906 completion_endpoint: Some("/completions".to_string()),
907 healthcheck: Some(HealthcheckDef {
908 method: "GET".to_string(),
909 path: Some("/models".to_string()),
910 url: None,
911 body: None,
912 }),
913 cost_per_1k_in: Some(0.0002),
914 cost_per_1k_out: Some(0.0006),
915 latency_p50_ms: Some(1600),
916 ..Default::default()
917 },
918 );
919
920 config.providers.insert(
922 "groq".to_string(),
923 ProviderDef {
924 base_url: "https://api.groq.com/openai/v1".to_string(),
925 base_url_env: Some("GROQ_BASE_URL".to_string()),
926 auth_style: "bearer".to_string(),
927 auth_env: AuthEnv::Single("GROQ_API_KEY".to_string()),
928 chat_endpoint: "/chat/completions".to_string(),
929 completion_endpoint: Some("/completions".to_string()),
930 healthcheck: Some(HealthcheckDef {
931 method: "GET".to_string(),
932 path: Some("/models".to_string()),
933 url: None,
934 body: None,
935 }),
936 cost_per_1k_in: Some(0.0001),
937 cost_per_1k_out: Some(0.0003),
938 latency_p50_ms: Some(450),
939 ..Default::default()
940 },
941 );
942
943 config.providers.insert(
945 "deepseek".to_string(),
946 ProviderDef {
947 base_url: "https://api.deepseek.com/v1".to_string(),
948 base_url_env: Some("DEEPSEEK_BASE_URL".to_string()),
949 auth_style: "bearer".to_string(),
950 auth_env: AuthEnv::Single("DEEPSEEK_API_KEY".to_string()),
951 chat_endpoint: "/chat/completions".to_string(),
952 completion_endpoint: Some("/completions".to_string()),
953 healthcheck: Some(HealthcheckDef {
954 method: "GET".to_string(),
955 path: Some("/models".to_string()),
956 url: None,
957 body: None,
958 }),
959 cost_per_1k_in: Some(0.00014),
960 cost_per_1k_out: Some(0.00028),
961 latency_p50_ms: Some(1800),
962 ..Default::default()
963 },
964 );
965
966 config.providers.insert(
968 "fireworks".to_string(),
969 ProviderDef {
970 base_url: "https://api.fireworks.ai/inference/v1".to_string(),
971 base_url_env: Some("FIREWORKS_BASE_URL".to_string()),
972 auth_style: "bearer".to_string(),
973 auth_env: AuthEnv::Single("FIREWORKS_API_KEY".to_string()),
974 chat_endpoint: "/chat/completions".to_string(),
975 completion_endpoint: Some("/completions".to_string()),
976 healthcheck: Some(HealthcheckDef {
977 method: "GET".to_string(),
978 path: Some("/models".to_string()),
979 url: None,
980 body: None,
981 }),
982 cost_per_1k_in: Some(0.0002),
983 cost_per_1k_out: Some(0.0006),
984 latency_p50_ms: Some(1400),
985 ..Default::default()
986 },
987 );
988
989 config.providers.insert(
991 "dashscope".to_string(),
992 ProviderDef {
993 base_url: "https://dashscope-intl.aliyuncs.com/compatible-mode/v1".to_string(),
994 base_url_env: Some("DASHSCOPE_BASE_URL".to_string()),
995 auth_style: "bearer".to_string(),
996 auth_env: AuthEnv::Single("DASHSCOPE_API_KEY".to_string()),
997 chat_endpoint: "/chat/completions".to_string(),
998 completion_endpoint: Some("/completions".to_string()),
999 healthcheck: Some(HealthcheckDef {
1000 method: "GET".to_string(),
1001 path: Some("/models".to_string()),
1002 url: None,
1003 body: None,
1004 }),
1005 cost_per_1k_in: Some(0.0003),
1006 cost_per_1k_out: Some(0.0012),
1007 latency_p50_ms: Some(1600),
1008 ..Default::default()
1009 },
1010 );
1011
1012 config.providers.insert(
1014 "local".to_string(),
1015 ProviderDef {
1016 base_url: "http://localhost:8000".to_string(),
1017 base_url_env: Some("LOCAL_LLM_BASE_URL".to_string()),
1018 auth_style: "none".to_string(),
1019 chat_endpoint: "/v1/chat/completions".to_string(),
1020 completion_endpoint: Some("/v1/completions".to_string()),
1021 healthcheck: Some(HealthcheckDef {
1022 method: "GET".to_string(),
1023 path: Some("/v1/models".to_string()),
1024 url: None,
1025 body: None,
1026 }),
1027 cost_per_1k_in: Some(0.0),
1028 cost_per_1k_out: Some(0.0),
1029 latency_p50_ms: Some(900),
1030 ..Default::default()
1031 },
1032 );
1033
1034 config.providers.insert(
1038 "mlx".to_string(),
1039 ProviderDef {
1040 base_url: "http://127.0.0.1:8002".to_string(),
1041 base_url_env: Some("MLX_BASE_URL".to_string()),
1042 auth_style: "none".to_string(),
1043 chat_endpoint: "/v1/chat/completions".to_string(),
1044 completion_endpoint: Some("/v1/completions".to_string()),
1045 healthcheck: Some(HealthcheckDef {
1046 method: "GET".to_string(),
1047 path: Some("/v1/models".to_string()),
1048 url: None,
1049 body: None,
1050 }),
1051 cost_per_1k_in: Some(0.0),
1052 cost_per_1k_out: Some(0.0),
1053 latency_p50_ms: Some(900),
1054 ..Default::default()
1055 },
1056 );
1057
1058 config.providers.insert(
1060 "vllm".to_string(),
1061 ProviderDef {
1062 base_url: "http://localhost:8000".to_string(),
1063 base_url_env: Some("VLLM_BASE_URL".to_string()),
1064 auth_style: "none".to_string(),
1065 chat_endpoint: "/v1/chat/completions".to_string(),
1066 completion_endpoint: Some("/v1/completions".to_string()),
1067 healthcheck: Some(HealthcheckDef {
1068 method: "GET".to_string(),
1069 path: Some("/v1/models".to_string()),
1070 url: None,
1071 body: None,
1072 }),
1073 cost_per_1k_in: Some(0.0),
1074 cost_per_1k_out: Some(0.0),
1075 latency_p50_ms: Some(800),
1076 ..Default::default()
1077 },
1078 );
1079
1080 config.providers.insert(
1082 "tgi".to_string(),
1083 ProviderDef {
1084 base_url: "http://localhost:8080".to_string(),
1085 base_url_env: Some("TGI_BASE_URL".to_string()),
1086 auth_style: "none".to_string(),
1087 chat_endpoint: "/v1/chat/completions".to_string(),
1088 completion_endpoint: Some("/v1/completions".to_string()),
1089 healthcheck: Some(HealthcheckDef {
1090 method: "GET".to_string(),
1091 path: Some("/health".to_string()),
1092 url: None,
1093 body: None,
1094 }),
1095 cost_per_1k_in: Some(0.0),
1096 cost_per_1k_out: Some(0.0),
1097 latency_p50_ms: Some(950),
1098 ..Default::default()
1099 },
1100 );
1101
1102 config.inference_rules = vec![
1104 InferenceRule {
1105 pattern: Some("claude-*".to_string()),
1106 contains: None,
1107 exact: None,
1108 provider: "anthropic".to_string(),
1109 },
1110 InferenceRule {
1111 pattern: Some("gpt-*".to_string()),
1112 contains: None,
1113 exact: None,
1114 provider: "openai".to_string(),
1115 },
1116 InferenceRule {
1117 pattern: Some("o1*".to_string()),
1118 contains: None,
1119 exact: None,
1120 provider: "openai".to_string(),
1121 },
1122 InferenceRule {
1123 pattern: Some("o3*".to_string()),
1124 contains: None,
1125 exact: None,
1126 provider: "openai".to_string(),
1127 },
1128 InferenceRule {
1129 pattern: Some("local:*".to_string()),
1130 contains: None,
1131 exact: None,
1132 provider: "local".to_string(),
1133 },
1134 InferenceRule {
1135 pattern: None,
1136 contains: Some("/".to_string()),
1137 exact: None,
1138 provider: "openrouter".to_string(),
1139 },
1140 InferenceRule {
1141 pattern: None,
1142 contains: Some(":".to_string()),
1143 exact: None,
1144 provider: "ollama".to_string(),
1145 },
1146 ];
1147
1148 config.tier_rules = vec![
1150 TierRule {
1151 contains: Some("9b".to_string()),
1152 pattern: None,
1153 exact: None,
1154 tier: "small".to_string(),
1155 },
1156 TierRule {
1157 contains: Some("a3b".to_string()),
1158 pattern: None,
1159 exact: None,
1160 tier: "small".to_string(),
1161 },
1162 TierRule {
1163 contains: Some("gemma-4-e2b".to_string()),
1164 pattern: None,
1165 exact: None,
1166 tier: "small".to_string(),
1167 },
1168 TierRule {
1169 contains: Some("gemma-4-e4b".to_string()),
1170 pattern: None,
1171 exact: None,
1172 tier: "small".to_string(),
1173 },
1174 TierRule {
1175 contains: Some("gemma-4-26b".to_string()),
1176 pattern: None,
1177 exact: None,
1178 tier: "mid".to_string(),
1179 },
1180 TierRule {
1181 contains: Some("gemma-4-31b".to_string()),
1182 pattern: None,
1183 exact: None,
1184 tier: "frontier".to_string(),
1185 },
1186 TierRule {
1187 contains: Some("gemma4:26b".to_string()),
1188 pattern: None,
1189 exact: None,
1190 tier: "mid".to_string(),
1191 },
1192 TierRule {
1193 contains: Some("gemma4:31b".to_string()),
1194 pattern: None,
1195 exact: None,
1196 tier: "frontier".to_string(),
1197 },
1198 TierRule {
1199 pattern: Some("claude-*".to_string()),
1200 contains: None,
1201 exact: None,
1202 tier: "frontier".to_string(),
1203 },
1204 TierRule {
1205 exact: Some("gpt-4o".to_string()),
1206 contains: None,
1207 pattern: None,
1208 tier: "frontier".to_string(),
1209 },
1210 ];
1211
1212 config.tier_defaults = TierDefaults {
1213 default: "mid".to_string(),
1214 };
1215
1216 config.aliases.insert(
1217 "frontier".to_string(),
1218 AliasDef {
1219 id: "claude-sonnet-4-20250514".to_string(),
1220 provider: "anthropic".to_string(),
1221 tool_format: None,
1222 },
1223 );
1224 config.aliases.insert(
1225 "tier/frontier".to_string(),
1226 AliasDef {
1227 id: "claude-sonnet-4-20250514".to_string(),
1228 provider: "anthropic".to_string(),
1229 tool_format: None,
1230 },
1231 );
1232 config.aliases.insert(
1233 "mid".to_string(),
1234 AliasDef {
1235 id: "gpt-4o-mini".to_string(),
1236 provider: "openai".to_string(),
1237 tool_format: None,
1238 },
1239 );
1240 config.aliases.insert(
1241 "tier/mid".to_string(),
1242 AliasDef {
1243 id: "gpt-4o-mini".to_string(),
1244 provider: "openai".to_string(),
1245 tool_format: None,
1246 },
1247 );
1248 config.aliases.insert(
1249 "small".to_string(),
1250 AliasDef {
1251 id: "Qwen/Qwen3.5-9B".to_string(),
1252 provider: "openrouter".to_string(),
1253 tool_format: None,
1254 },
1255 );
1256 config.aliases.insert(
1257 "tier/small".to_string(),
1258 AliasDef {
1259 id: "Qwen/Qwen3.5-9B".to_string(),
1260 provider: "openrouter".to_string(),
1261 tool_format: None,
1262 },
1263 );
1264 config.aliases.insert(
1265 "local-gemma4".to_string(),
1266 AliasDef {
1267 id: "gemma-4-26b-a4b-it".to_string(),
1268 provider: "local".to_string(),
1269 tool_format: None,
1270 },
1271 );
1272 config.aliases.insert(
1273 "local-gemma4-26b".to_string(),
1274 AliasDef {
1275 id: "gemma-4-26b-a4b-it".to_string(),
1276 provider: "local".to_string(),
1277 tool_format: None,
1278 },
1279 );
1280 config.aliases.insert(
1281 "local-gemma4-31b".to_string(),
1282 AliasDef {
1283 id: "gemma-4-31b-it".to_string(),
1284 provider: "local".to_string(),
1285 tool_format: None,
1286 },
1287 );
1288 config.aliases.insert(
1289 "local-gemma4-e4b".to_string(),
1290 AliasDef {
1291 id: "gemma-4-e4b-it".to_string(),
1292 provider: "local".to_string(),
1293 tool_format: None,
1294 },
1295 );
1296 config.aliases.insert(
1297 "local-gemma4-e2b".to_string(),
1298 AliasDef {
1299 id: "gemma-4-e2b-it".to_string(),
1300 provider: "local".to_string(),
1301 tool_format: None,
1302 },
1303 );
1304 config.aliases.insert(
1305 "mlx-qwen36-27b".to_string(),
1306 AliasDef {
1307 id: "unsloth/Qwen3.6-27B-UD-MLX-4bit".to_string(),
1308 provider: "mlx".to_string(),
1309 tool_format: None,
1310 },
1311 );
1312
1313 config.qc_defaults.extend(BTreeMap::from([
1314 (
1315 "anthropic".to_string(),
1316 "claude-3-5-haiku-20241022".to_string(),
1317 ),
1318 ("openai".to_string(), "gpt-4o-mini".to_string()),
1319 (
1320 "openrouter".to_string(),
1321 "google/gemini-2.5-flash".to_string(),
1322 ),
1323 ("ollama".to_string(), "llama3.2".to_string()),
1324 ("local".to_string(), "gpt-4o".to_string()),
1325 ]));
1326
1327 config.models.extend(BTreeMap::from([
1328 (
1329 "claude-sonnet-4-20250514".to_string(),
1330 ModelDef {
1331 name: "Claude Sonnet 4".to_string(),
1332 provider: "anthropic".to_string(),
1333 context_window: 200_000,
1334 stream_timeout: None,
1335 capabilities: vec![
1336 "tools".to_string(),
1337 "streaming".to_string(),
1338 "prompt_caching".to_string(),
1339 "thinking".to_string(),
1340 ],
1341 pricing: Some(ModelPricing {
1342 input_per_mtok: 3.0,
1343 output_per_mtok: 15.0,
1344 cache_read_per_mtok: Some(0.3),
1345 cache_write_per_mtok: Some(3.75),
1346 }),
1347 },
1348 ),
1349 (
1350 "gpt-4o-mini".to_string(),
1351 ModelDef {
1352 name: "GPT-4o Mini".to_string(),
1353 provider: "openai".to_string(),
1354 context_window: 128_000,
1355 stream_timeout: None,
1356 capabilities: vec!["tools".to_string(), "streaming".to_string()],
1357 pricing: Some(ModelPricing {
1358 input_per_mtok: 0.15,
1359 output_per_mtok: 0.60,
1360 cache_read_per_mtok: None,
1361 cache_write_per_mtok: None,
1362 }),
1363 },
1364 ),
1365 (
1366 "Qwen/Qwen3.5-9B".to_string(),
1367 ModelDef {
1368 name: "Qwen3.5 9B".to_string(),
1369 provider: "openrouter".to_string(),
1370 context_window: 131_072,
1371 stream_timeout: None,
1372 capabilities: vec!["tools".to_string(), "streaming".to_string()],
1373 pricing: None,
1374 },
1375 ),
1376 (
1377 "llama3.2".to_string(),
1378 ModelDef {
1379 name: "Llama 3.2".to_string(),
1380 provider: "ollama".to_string(),
1381 context_window: 32_000,
1382 stream_timeout: Some(300.0),
1383 capabilities: vec!["tools".to_string(), "streaming".to_string()],
1384 pricing: None,
1385 },
1386 ),
1387 ]));
1388
1389 config
1390}
1391
#[cfg(test)]
mod tests {
    use super::*;

    /// RAII guard that keeps the thread-local user overrides clean around a test.
    ///
    /// `OverridesGuard::new()` clears any overrides already installed on the current
    /// thread, and dropping the guard clears them again. `Drop` also runs while a
    /// failing assertion unwinds, so overrides installed by one test cannot leak into
    /// a later test that happens to run on the same thread. A trailing
    /// `clear_user_overrides()` call would be skipped in that situation.
    ///
    /// Bind it to a named variable (`let _guard = ...`). `let _ = ...` drops it
    /// immediately.
    struct OverridesGuard;

    impl OverridesGuard {
        fn new() -> Self {
            clear_user_overrides();
            OverridesGuard
        }
    }

    impl Drop for OverridesGuard {
        fn drop(&mut self) {
            clear_user_overrides();
        }
    }

    #[test]
    fn test_glob_match_prefix() {
        assert!(glob_match("claude-*", "claude-sonnet-4-20250514"));
        assert!(glob_match("gpt-*", "gpt-4o"));
        assert!(!glob_match("claude-*", "gpt-4o"));
    }

    #[test]
    fn test_glob_match_suffix() {
        assert!(glob_match("*-latest", "llama3.2-latest"));
        assert!(!glob_match("*-latest", "llama3.2"));
    }

    #[test]
    fn test_glob_match_middle() {
        assert!(glob_match("claude-*-latest", "claude-sonnet-latest"));
        assert!(!glob_match("claude-*-latest", "claude-sonnet-beta"));
    }

    #[test]
    fn test_glob_match_exact() {
        assert!(glob_match("gpt-4o", "gpt-4o"));
        assert!(!glob_match("gpt-4o", "gpt-4o-mini"));
    }

    #[test]
    fn test_infer_provider_from_defaults() {
        assert_eq!(infer_provider("claude-sonnet-4-20250514"), "anthropic");
        assert_eq!(infer_provider("gpt-4o"), "openai");
        assert_eq!(infer_provider("o1-preview"), "openai");
        assert_eq!(infer_provider("o3-mini"), "openai");
        assert_eq!(infer_provider("qwen/qwen3-coder"), "openrouter");
        assert_eq!(infer_provider("llama3.2:latest"), "ollama");
        assert_eq!(infer_provider("unknown-model"), "anthropic");
    }

    #[test]
    fn test_infer_provider_local_prefix() {
        assert_eq!(infer_provider("local:gemma-4-e4b-it"), "local");
        assert_eq!(infer_provider("local:qwen2.5"), "local");
        assert_eq!(infer_provider("local:owner/model"), "local");
    }

    #[test]
    fn test_resolve_model_info_normalizes_provider_prefixes() {
        let local = resolve_model_info("local:gemma-4-e4b-it");
        assert_eq!(local.id, "gemma-4-e4b-it");
        assert_eq!(local.provider, "local");

        let ollama = resolve_model_info("ollama:qwen3:30b-a3b");
        assert_eq!(ollama.id, "qwen3:30b-a3b");
        assert_eq!(ollama.provider, "ollama");

        let hf = resolve_model_info("hf:Qwen/Qwen3.6-35B-A3B");
        assert_eq!(hf.id, "Qwen/Qwen3.6-35B-A3B");
        assert_eq!(hf.provider, "huggingface");
    }

    #[test]
    fn test_model_tier_from_defaults() {
        assert_eq!(model_tier("claude-sonnet-4-20250514"), "frontier");
        assert_eq!(model_tier("gpt-4o"), "frontier");
        assert_eq!(model_tier("Qwen3.5-9B"), "small");
        assert_eq!(model_tier("deepseek-v3"), "mid");
    }

    #[test]
    fn test_resolve_model_unknown_alias() {
        let (id, provider) = resolve_model("gpt-4o");
        assert_eq!(id, "gpt-4o");
        assert!(provider.is_none());
    }

    #[test]
    fn test_provider_names() {
        let names = provider_names();
        assert!(names.len() >= 7);
        assert!(names.contains(&"anthropic".to_string()));
        assert!(names.contains(&"together".to_string()));
        assert!(names.contains(&"local".to_string()));
        assert!(names.contains(&"mlx".to_string()));
        assert!(names.contains(&"openai".to_string()));
        assert!(names.contains(&"ollama".to_string()));
    }

    #[test]
    fn test_resolve_tier_model_default_aliases() {
        let (model, provider) = resolve_tier_model("frontier", None).unwrap();
        assert_eq!(model, "claude-sonnet-4-20250514");
        assert_eq!(provider, "anthropic");

        let (model, provider) = resolve_tier_model("small", None).unwrap();
        assert_eq!(model, "Qwen/Qwen3.5-9B");
        assert_eq!(provider, "openrouter");
    }

    #[test]
    fn test_resolve_tier_model_prefers_provider_scoped_aliases() {
        let (model, provider) = resolve_tier_model("mid", Some("openai")).unwrap();
        assert_eq!(model, "gpt-4o-mini");
        assert_eq!(provider, "openai");
    }

    #[test]
    fn test_provider_config_anthropic() {
        let pdef = provider_config("anthropic").unwrap();
        assert_eq!(pdef.auth_style, "header");
        assert_eq!(pdef.auth_header.as_deref(), Some("x-api-key"));
    }

    #[test]
    fn test_provider_config_mlx() {
        let pdef = provider_config("mlx").unwrap();
        assert_eq!(pdef.base_url, "http://127.0.0.1:8002");
        assert_eq!(pdef.base_url_env.as_deref(), Some("MLX_BASE_URL"));
        assert_eq!(
            pdef.healthcheck.unwrap().path.as_deref(),
            Some("/v1/models")
        );

        let (model, provider) = resolve_model("mlx-qwen36-27b");
        assert_eq!(model, "unsloth/Qwen3.6-27B-UD-MLX-4bit");
        assert_eq!(provider.as_deref(), Some("mlx"));
    }

    #[test]
    fn test_resolve_base_url_no_env() {
        let pdef = ProviderDef {
            base_url: "https://example.com".to_string(),
            ..Default::default()
        };
        assert_eq!(resolve_base_url(&pdef), "https://example.com");
    }

    #[test]
    fn test_default_config_roundtrip() {
        let config = default_config();
        assert!(!config.providers.is_empty());
        assert!(!config.inference_rules.is_empty());
        assert!(!config.tier_rules.is_empty());
        assert_eq!(config.tier_defaults.default, "mid");
    }

    #[test]
    fn test_model_params_empty() {
        let params = model_params("claude-sonnet-4-20250514");
        assert!(params.is_empty());
    }

    #[test]
    fn test_user_overrides_add_provider_and_alias() {
        // Cleared on entry and again on exit, even if an assertion panics.
        let _guard = OverridesGuard::new();

        let mut overlay = ProvidersConfig::default();
        overlay.providers.insert(
            "acme".to_string(),
            ProviderDef {
                base_url: "https://llm.acme.test/v1".to_string(),
                chat_endpoint: "/chat/completions".to_string(),
                ..Default::default()
            },
        );
        overlay.aliases.insert(
            "acme-fast".to_string(),
            AliasDef {
                id: "acme/model-fast".to_string(),
                provider: "acme".to_string(),
                tool_format: Some("native".to_string()),
            },
        );
        set_user_overrides(Some(overlay));

        let (model, provider) = resolve_model("acme-fast");
        assert_eq!(model, "acme/model-fast");
        assert_eq!(provider.as_deref(), Some("acme"));
        assert!(provider_names().contains(&"acme".to_string()));
        assert_eq!(
            provider_config("acme").map(|provider| provider.base_url),
            Some("https://llm.acme.test/v1".to_string())
        );
    }

    #[test]
    fn test_user_overrides_add_model_catalog_pricing_and_qc_defaults() {
        // Cleared on entry and again on exit, even if an assertion panics.
        let _guard = OverridesGuard::new();

        let mut overlay = ProvidersConfig::default();
        overlay.models.insert(
            "acme/model-fast".to_string(),
            ModelDef {
                name: "Acme Fast".to_string(),
                provider: "acme".to_string(),
                context_window: 65_536,
                stream_timeout: Some(42.0),
                capabilities: vec!["tools".to_string(), "streaming".to_string()],
                pricing: Some(ModelPricing {
                    input_per_mtok: 1.25,
                    output_per_mtok: 2.5,
                    cache_read_per_mtok: Some(0.25),
                    cache_write_per_mtok: None,
                }),
            },
        );
        overlay
            .qc_defaults
            .insert("acme".to_string(), "acme/model-cheap".to_string());
        set_user_overrides(Some(overlay));

        let entry = model_catalog_entry("acme/model-fast").expect("catalog entry");
        assert_eq!(entry.context_window, 65_536);
        assert_eq!(
            entry.pricing.as_ref().map(|pricing| pricing.input_per_mtok),
            Some(1.25)
        );
        assert_eq!(
            pricing_per_1k_for("acme", "acme/model-fast"),
            Some((0.00125, 0.0025))
        );
        assert_eq!(
            qc_default_model("acme").as_deref(),
            Some("acme/model-cheap")
        );
    }

    #[test]
    fn test_user_overrides_prepend_inference_rules() {
        // Cleared on entry and again on exit, even if an assertion panics.
        let _guard = OverridesGuard::new();

        let mut overlay = ProvidersConfig::default();
        overlay.inference_rules.push(InferenceRule {
            pattern: Some("internal-*".to_string()),
            contains: None,
            exact: None,
            provider: "openai".to_string(),
        });
        set_user_overrides(Some(overlay));

        assert_eq!(infer_provider("internal-foo"), "openai");
        // The overlay rules are prepended rather than substituted, so built-in
        // inference for unrelated models must keep working.
        assert_eq!(infer_provider("gpt-4o"), "openai");
    }
}