1use serde::Deserialize;
2use std::cell::RefCell;
3use std::collections::BTreeMap;
4use std::sync::OnceLock;
5
/// Process-wide providers configuration, parsed once on first access.
static CONFIG: OnceLock<ProvidersConfig> = OnceLock::new();
/// Path the configuration was loaded from, set only when it came from disk.
static CONFIG_PATH: OnceLock<String> = OnceLock::new();

thread_local! {
    // Per-thread configuration overlay applied on top of the global config
    // (see `set_user_overrides` / `effective_config`). Being thread-local,
    // overrides installed on one thread are invisible to other threads.
    static USER_OVERRIDES: RefCell<Option<ProvidersConfig>> = const { RefCell::new(None) };
}
16
/// Root of the providers configuration (TOML-deserializable).
///
/// Every section is optional in the TOML file; missing sections fall back to
/// their `Default` values.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct ProvidersConfig {
    /// Provider definitions keyed by provider name (e.g. "anthropic").
    #[serde(default)]
    pub providers: BTreeMap<String, ProviderDef>,
    /// Model aliases keyed by alias name (e.g. "tier/mid").
    #[serde(default)]
    pub aliases: BTreeMap<String, AliasDef>,
    /// Ordered rules mapping a model id to a provider; first match wins.
    #[serde(default)]
    pub inference_rules: Vec<InferenceRule>,
    /// Ordered rules mapping a model id to a tier; first match wins.
    #[serde(default)]
    pub tier_rules: Vec<TierRule>,
    /// Fallback tier used when no tier rule matches.
    #[serde(default)]
    pub tier_defaults: TierDefaults,
    /// Default request parameters keyed by model-id glob pattern.
    #[serde(default)]
    pub model_defaults: BTreeMap<String, BTreeMap<String, toml::Value>>,
}
32
33impl ProvidersConfig {
34 pub fn is_empty(&self) -> bool {
35 self.providers.is_empty()
36 && self.aliases.is_empty()
37 && self.inference_rules.is_empty()
38 && self.tier_rules.is_empty()
39 && self.model_defaults.is_empty()
40 && self.tier_defaults.default == default_mid()
41 }
42
43 pub fn merge_from(&mut self, overlay: &ProvidersConfig) {
44 self.providers.extend(overlay.providers.clone());
45 self.aliases.extend(overlay.aliases.clone());
46
47 if !overlay.inference_rules.is_empty() {
48 let mut merged = overlay.inference_rules.clone();
49 merged.extend(self.inference_rules.clone());
50 self.inference_rules = merged;
51 }
52
53 if !overlay.tier_rules.is_empty() {
54 let mut merged = overlay.tier_rules.clone();
55 merged.extend(self.tier_rules.clone());
56 self.tier_rules = merged;
57 }
58
59 if overlay.tier_defaults.default != default_mid() {
60 self.tier_defaults = overlay.tier_defaults.clone();
61 }
62
63 for (pattern, defaults) in &overlay.model_defaults {
64 self.model_defaults
65 .entry(pattern.clone())
66 .or_default()
67 .extend(defaults.clone());
68 }
69 }
70}
71
/// Connection, auth, and economics description of a single LLM provider.
#[derive(Debug, Clone, Deserialize)]
pub struct ProviderDef {
    /// Base URL for API requests (may be overridden via `base_url_env`).
    pub base_url: String,
    /// Env var that, when set and non-empty, overrides `base_url`
    /// (see `resolve_base_url`).
    #[serde(default)]
    pub base_url_env: Option<String>,
    /// How credentials are sent: "bearer" (default), "header", or "none"
    /// in the shipped defaults.
    #[serde(default = "default_bearer")]
    pub auth_style: String,
    /// Header name used when `auth_style` is "header" (e.g. "x-api-key").
    #[serde(default)]
    pub auth_header: Option<String>,
    /// Env var(s) holding the API key.
    #[serde(default)]
    pub auth_env: AuthEnv,
    /// Extra headers attached to requests (e.g. "anthropic-version").
    #[serde(default)]
    pub extra_headers: BTreeMap<String, String>,
    /// Chat-completions path, relative to the base URL.
    #[serde(default)]
    pub chat_endpoint: String,
    /// Optional plain-completions path, relative to the base URL.
    #[serde(default)]
    pub completion_endpoint: Option<String>,
    /// Optional liveness-probe description.
    #[serde(default)]
    pub healthcheck: Option<HealthcheckDef>,
    /// Capability flags (e.g. "native_tools" — see `default_tool_format`).
    #[serde(default)]
    pub features: Vec<String>,
    /// Presumably the name of a provider to fall back to — not read in this
    /// file; verify against callers.
    #[serde(default)]
    pub fallback: Option<String>,
    /// Retry attempts on failure — not read in this file; verify at call site.
    #[serde(default)]
    pub retry_count: Option<u32>,
    /// Delay between retries in ms — not read in this file.
    #[serde(default)]
    pub retry_delay_ms: Option<u64>,
    /// Requests-per-minute budget — not read in this file.
    #[serde(default)]
    pub rpm: Option<u32>,
    /// Cost per 1k input tokens (returned by `provider_economics`).
    #[serde(default)]
    pub cost_per_1k_in: Option<f64>,
    /// Cost per 1k output tokens (returned by `provider_economics`).
    #[serde(default)]
    pub cost_per_1k_out: Option<f64>,
    /// Typical (p50) latency in milliseconds.
    #[serde(default)]
    pub latency_p50_ms: Option<u64>,
}
115
116impl Default for ProviderDef {
117 fn default() -> Self {
118 Self {
119 base_url: String::new(),
120 base_url_env: None,
121 auth_style: default_bearer(),
122 auth_header: None,
123 auth_env: AuthEnv::None,
124 extra_headers: BTreeMap::new(),
125 chat_endpoint: String::new(),
126 completion_endpoint: None,
127 healthcheck: None,
128 features: Vec::new(),
129 fallback: None,
130 retry_count: None,
131 retry_delay_ms: None,
132 rpm: None,
133 cost_per_1k_in: None,
134 cost_per_1k_out: None,
135 latency_p50_ms: None,
136 }
137 }
138}
139
/// Serde default for `ProviderDef::auth_style`.
fn default_bearer() -> String {
    String::from("bearer")
}
143
/// Where to find a provider's API key: nowhere, one env var, or several
/// candidate env vars.
///
/// `untagged` lets TOML accept either a bare string or an array of strings.
#[derive(Debug, Clone, Deserialize, Default)]
#[serde(untagged)]
pub enum AuthEnv {
    /// No credentials configured.
    #[default]
    None,
    /// A single env var name.
    Single(String),
    /// Several candidate env var names (presumably tried in order —
    /// confirm against the consuming code).
    Multiple(Vec<String>),
}
154
/// Describes how to probe a provider for liveness.
#[derive(Debug, Clone, Deserialize)]
pub struct HealthcheckDef {
    /// HTTP method ("GET" or "POST" in the shipped defaults).
    pub method: String,
    /// Path relative to the provider base URL; the shipped defaults set
    /// either `path` or `url`, never both.
    #[serde(default)]
    pub path: Option<String>,
    /// Absolute URL, used when the probe lives on a different host
    /// (e.g. the Hugging Face whoami endpoint).
    #[serde(default)]
    pub url: Option<String>,
    /// Optional request body (used with the POST probe for Anthropic).
    #[serde(default)]
    pub body: Option<String>,
}
165
/// Maps a friendly alias name to a concrete model id and provider.
#[derive(Debug, Clone, Deserialize)]
pub struct AliasDef {
    /// Concrete model identifier the alias resolves to.
    pub id: String,
    /// Provider that serves the model.
    pub provider: String,
    /// Optional tool-call format override ("native"/"text" — see
    /// `default_tool_format`).
    #[serde(default)]
    pub tool_format: Option<String>,
}
177
/// Rule mapping a model id to a provider. The matcher (`infer_provider`)
/// checks `exact`, then `pattern`, then `contains` for each rule in order.
#[derive(Debug, Clone, Deserialize)]
pub struct InferenceRule {
    /// Glob pattern matched against the model id (see `glob_match`).
    #[serde(default)]
    pub pattern: Option<String>,
    /// Substring matched anywhere in the model id.
    #[serde(default)]
    pub contains: Option<String>,
    /// Exact model-id match.
    #[serde(default)]
    pub exact: Option<String>,
    /// Provider returned when the rule matches.
    pub provider: String,
}
188
/// Rule mapping a model id to a capability tier. The matcher (`model_tier`)
/// checks `exact`, then `pattern`, then `contains` for each rule in order.
#[derive(Debug, Clone, Deserialize)]
pub struct TierRule {
    /// Glob pattern matched against the model id (see `glob_match`).
    #[serde(default)]
    pub pattern: Option<String>,
    /// Substring matched anywhere in the model id.
    #[serde(default)]
    pub contains: Option<String>,
    /// Exact model-id match.
    #[serde(default)]
    pub exact: Option<String>,
    /// Tier name returned when the rule matches (e.g. "small", "frontier").
    pub tier: String,
}
199
/// Fallback tier configuration.
#[derive(Debug, Clone, Deserialize)]
pub struct TierDefaults {
    /// Tier returned when no tier rule or heuristic matches; "mid" by default.
    #[serde(default = "default_mid")]
    pub default: String,
}
205
206impl Default for TierDefaults {
207 fn default() -> Self {
208 Self {
209 default: default_mid(),
210 }
211 }
212}
213
/// Serde default for `TierDefaults::default`.
fn default_mid() -> String {
    String::from("mid")
}
217
/// Load the global providers configuration (computed once, cached in `CONFIG`).
///
/// Resolution order:
/// 1. the TOML file named by `HARN_PROVIDERS_CONFIG` (read/parse errors are
///    logged to stderr, then resolution continues);
/// 2. `~/.config/harn/providers.toml` (errors silently ignored — best effort);
/// 3. the built-in `default_config()`.
pub fn load_config() -> &'static ProvidersConfig {
    CONFIG.get_or_init(|| {
        // Verbose logging is gated by HARN_VERBOSE_CONFIG or HARN_ACP_VERBOSE.
        let verbose_config_logging = matches!(
            std::env::var("HARN_VERBOSE_CONFIG").ok().as_deref(),
            Some("1" | "true" | "TRUE" | "yes" | "YES")
        ) || matches!(
            std::env::var("HARN_ACP_VERBOSE").ok().as_deref(),
            Some("1" | "true" | "TRUE" | "yes" | "YES")
        );
        if let Ok(path) = std::env::var("HARN_PROVIDERS_CONFIG") {
            match std::fs::read_to_string(&path) {
                Ok(content) => match toml::from_str::<ProvidersConfig>(&content) {
                    Ok(config) => {
                        if verbose_config_logging {
                            eprintln!(
                                "[llm_config] Loaded {} providers, {} aliases from {}",
                                config.providers.len(),
                                config.aliases.len(),
                                path
                            );
                        }
                        // Record the source path; `set` can only fail if it
                        // was already recorded, which is fine.
                        let _ = CONFIG_PATH.set(path);
                        return config;
                    }
                    Err(e) => eprintln!("[llm_config] TOML parse error in {}: {}", path, e),
                },
                Err(e) => eprintln!("[llm_config] Cannot read {}: {}", path, e),
            }
        }
        if let Some(home) = dirs_or_home() {
            let path = format!("{home}/.config/harn/providers.toml");
            if let Ok(content) = std::fs::read_to_string(&path) {
                if let Ok(config) = toml::from_str::<ProvidersConfig>(&content) {
                    let _ = CONFIG_PATH.set(path);
                    return config;
                }
            }
        }
        default_config()
    })
}
260
261pub fn loaded_config_path() -> Option<std::path::PathBuf> {
264 let _ = load_config();
266 CONFIG_PATH.get().map(std::path::PathBuf::from)
267}
268
269pub fn set_user_overrides(config: Option<ProvidersConfig>) {
273 USER_OVERRIDES.with(|cell| *cell.borrow_mut() = config);
274}
275
276pub fn clear_user_overrides() {
278 set_user_overrides(None);
279}
280
281fn effective_config() -> ProvidersConfig {
282 let mut merged = load_config().clone();
283 USER_OVERRIDES.with(|cell| {
284 if let Some(overlay) = cell.borrow().as_ref() {
285 merged.merge_from(overlay);
286 }
287 });
288 merged
289}
290
291pub fn resolve_model(alias: &str) -> (String, Option<String>) {
293 let config = effective_config();
294 if let Some(a) = config.aliases.get(alias) {
295 return (a.id.clone(), Some(a.provider.clone()));
296 }
297 (alias.to_string(), None)
298}
299
300pub fn infer_provider(model_id: &str) -> String {
302 let config = effective_config();
303 for rule in &config.inference_rules {
304 if let Some(exact) = &rule.exact {
305 if model_id == exact {
306 return rule.provider.clone();
307 }
308 }
309 if let Some(pattern) = &rule.pattern {
310 if glob_match(pattern, model_id) {
311 return rule.provider.clone();
312 }
313 }
314 if let Some(substr) = &rule.contains {
315 if model_id.contains(substr.as_str()) {
316 return rule.provider.clone();
317 }
318 }
319 }
320 if model_id.starts_with("local:") {
325 return "local".to_string();
326 }
327 if model_id.starts_with("claude-") {
328 return "anthropic".to_string();
329 }
330 if model_id.starts_with("gpt-") || model_id.starts_with("o1") || model_id.starts_with("o3") {
331 return "openai".to_string();
332 }
333 if model_id.contains('/') {
334 return "openrouter".to_string();
335 }
336 if model_id.contains(':') {
337 return "ollama".to_string();
338 }
339 "anthropic".to_string()
340}
341
342pub fn model_tier(model_id: &str) -> String {
344 let config = effective_config();
345 for rule in &config.tier_rules {
346 if let Some(exact) = &rule.exact {
347 if model_id == exact {
348 return rule.tier.clone();
349 }
350 }
351 if let Some(pattern) = &rule.pattern {
352 if glob_match(pattern, model_id) {
353 return rule.tier.clone();
354 }
355 }
356 if let Some(substr) = &rule.contains {
357 if model_id.contains(substr.as_str()) {
358 return rule.tier.clone();
359 }
360 }
361 }
362 let lower = model_id.to_lowercase();
363 if lower.contains("9b") || lower.contains("a3b") {
364 return "small".to_string();
365 }
366 if lower.starts_with("claude-") || lower == "gpt-4o" {
367 return "frontier".to_string();
368 }
369 config.tier_defaults.default.clone()
370}
371
372pub fn provider_config(name: &str) -> Option<ProviderDef> {
374 effective_config().providers.get(name).cloned()
375}
376
377pub fn model_params(model_id: &str) -> BTreeMap<String, toml::Value> {
380 let config = effective_config();
381 let mut params = BTreeMap::new();
382 for (pattern, defaults) in &config.model_defaults {
383 if glob_match(pattern, model_id) {
384 for (k, v) in defaults {
385 params.insert(k.clone(), v.clone());
386 }
387 }
388 }
389 params
390}
391
392pub fn provider_names() -> Vec<String> {
394 effective_config().providers.keys().cloned().collect()
395}
396
397pub fn provider_has_feature(provider: &str, feature: &str) -> bool {
399 provider_config(provider)
400 .map(|p| p.features.iter().any(|f| f == feature))
401 .unwrap_or(false)
402}
403
404pub fn provider_economics(provider: &str) -> (Option<f64>, Option<f64>, Option<u64>) {
408 provider_config(provider)
409 .map(|p| (p.cost_per_1k_in, p.cost_per_1k_out, p.latency_p50_ms))
410 .unwrap_or((None, None, None))
411}
412
413pub fn default_tool_format(model: &str, provider: &str) -> String {
416 let config = effective_config();
417 for (name, alias) in &config.aliases {
419 let matches = (alias.id == model && alias.provider == provider) || name == model;
420 if matches {
421 if let Some(ref fmt) = alias.tool_format {
422 return fmt.clone();
423 }
424 }
425 }
426 if provider_has_feature(provider, "native_tools") {
427 "native".to_string()
428 } else {
429 "text".to_string()
430 }
431}
432
433pub fn resolve_tier_model(
435 target: &str,
436 preferred_provider: Option<&str>,
437) -> Option<(String, String)> {
438 let config = effective_config();
439
440 if let Some(alias) = config.aliases.get(target) {
441 return Some((alias.id.clone(), alias.provider.clone()));
442 }
443
444 let candidate_aliases = if let Some(provider) = preferred_provider {
445 vec![
446 format!("{provider}/{target}"),
447 format!("{provider}:{target}"),
448 format!("tier/{target}"),
449 target.to_string(),
450 ]
451 } else {
452 vec![format!("tier/{target}"), target.to_string()]
453 };
454
455 for alias_name in candidate_aliases {
456 if let Some(alias) = config.aliases.get(&alias_name) {
457 return Some((alias.id.clone(), alias.provider.clone()));
458 }
459 }
460
461 None
462}
463
464pub fn tier_candidates(target: &str) -> Vec<(String, String)> {
468 let config = effective_config();
469 let mut seen = std::collections::BTreeSet::new();
470 let mut candidates = Vec::new();
471
472 for alias in config.aliases.values() {
473 let pair = (alias.id.clone(), alias.provider.clone());
474 if seen.contains(&pair) {
475 continue;
476 }
477 if model_tier(&alias.id) == target {
478 seen.insert(pair.clone());
479 candidates.push(pair);
480 }
481 }
482
483 candidates.sort_by(|(model_a, provider_a), (model_b, provider_b)| {
484 provider_a
485 .cmp(provider_b)
486 .then_with(|| model_a.cmp(model_b))
487 });
488 candidates
489}
490
491pub fn all_model_candidates() -> Vec<(String, String)> {
494 let config = effective_config();
495 let mut seen = std::collections::BTreeSet::new();
496 let mut candidates = Vec::new();
497
498 for alias in config.aliases.values() {
499 let pair = (alias.id.clone(), alias.provider.clone());
500 if seen.insert(pair.clone()) {
501 candidates.push(pair);
502 }
503 }
504
505 candidates.sort_by(|(model_a, provider_a), (model_b, provider_b)| {
506 provider_a
507 .cmp(provider_b)
508 .then_with(|| model_a.cmp(model_b))
509 });
510 candidates
511}
512
/// Match `input` against a simple glob `pattern` in which `*` matches any
/// (possibly empty) run of characters. Supports any number of `*`s: the
/// literal segments between stars must appear in order, with the first
/// segment anchored at the start and the last anchored at the end.
///
/// Fixes over the previous implementation: patterns with a star at both ends
/// ("*x*") degraded to `starts_with("*x")` and never matched; patterns with
/// three or more segments ("a*b*c") fell back to exact comparison; and the
/// two-segment case allowed prefix/suffix overlap ("a*a" matched "a").
fn glob_match(pattern: &str, input: &str) -> bool {
    // No wildcard: plain equality.
    if !pattern.contains('*') {
        return input == pattern;
    }

    // split('*') yields >= 2 segments here, so both unwraps are safe.
    let segments: Vec<&str> = pattern.split('*').collect();
    let (first, rest) = segments.split_first().expect("split yields >= 1 part");
    let (last, middle) = rest.split_last().expect("pattern contains '*'");

    if !input.starts_with(first) {
        return false;
    }

    // Greedily locate each interior segment after the previous match; `pos`
    // always sits on a char boundary because it advances by matched lengths.
    let mut pos = first.len();
    for part in middle {
        match input[pos..].find(part) {
            Some(offset) => pos += offset + part.len(),
            None => return false,
        }
    }

    // The final segment must fit entirely in the remaining input as a suffix
    // (the length check prevents overlap with already-consumed text).
    input.len() >= pos + last.len() && input[pos..].ends_with(last)
}
530
/// Home directory from `$HOME`, avoiding a `dirs`-crate dependency.
fn dirs_or_home() -> Option<String> {
    match std::env::var("HOME") {
        Ok(home) => Some(home),
        Err(_) => None,
    }
}
534
535pub fn resolve_base_url(pdef: &ProviderDef) -> String {
538 if let Some(env_name) = &pdef.base_url_env {
539 if let Ok(val) = std::env::var(env_name) {
540 let trimmed = val.trim().trim_matches('"').trim_matches('\'');
542 if !trimmed.is_empty() {
543 return trimmed.to_string();
544 }
545 }
546 }
547 pdef.base_url.clone()
548}
549
550fn default_config() -> ProvidersConfig {
551 let mut config = ProvidersConfig::default();
552
553 config.providers.insert(
554 "anthropic".to_string(),
555 ProviderDef {
556 base_url: "https://api.anthropic.com/v1".to_string(),
557 auth_style: "header".to_string(),
558 auth_header: Some("x-api-key".to_string()),
559 auth_env: AuthEnv::Single("ANTHROPIC_API_KEY".to_string()),
560 extra_headers: BTreeMap::from([(
561 "anthropic-version".to_string(),
562 "2023-06-01".to_string(),
563 )]),
564 chat_endpoint: "/messages".to_string(),
565 completion_endpoint: None,
566 healthcheck: Some(HealthcheckDef {
567 method: "POST".to_string(),
568 path: Some("/messages/count_tokens".to_string()),
569 url: None,
570 body: Some(
571 r#"{"model":"claude-sonnet-4-20250514","messages":[{"role":"user","content":"x"}]}"#
572 .to_string(),
573 ),
574 }),
575 features: vec!["prompt_caching".to_string(), "thinking".to_string()],
576 cost_per_1k_in: Some(0.003),
577 cost_per_1k_out: Some(0.015),
578 latency_p50_ms: Some(2500),
579 ..Default::default()
580 },
581 );
582
583 config.providers.insert(
585 "openai".to_string(),
586 ProviderDef {
587 base_url: "https://api.openai.com/v1".to_string(),
588 auth_style: "bearer".to_string(),
589 auth_env: AuthEnv::Single("OPENAI_API_KEY".to_string()),
590 chat_endpoint: "/chat/completions".to_string(),
591 completion_endpoint: Some("/completions".to_string()),
592 healthcheck: Some(HealthcheckDef {
593 method: "GET".to_string(),
594 path: Some("/models".to_string()),
595 url: None,
596 body: None,
597 }),
598 cost_per_1k_in: Some(0.0025),
599 cost_per_1k_out: Some(0.010),
600 latency_p50_ms: Some(1800),
601 ..Default::default()
602 },
603 );
604
605 config.providers.insert(
607 "openrouter".to_string(),
608 ProviderDef {
609 base_url: "https://openrouter.ai/api/v1".to_string(),
610 auth_style: "bearer".to_string(),
611 auth_env: AuthEnv::Single("OPENROUTER_API_KEY".to_string()),
612 chat_endpoint: "/chat/completions".to_string(),
613 completion_endpoint: Some("/completions".to_string()),
614 healthcheck: Some(HealthcheckDef {
615 method: "GET".to_string(),
616 path: Some("/auth/key".to_string()),
617 url: None,
618 body: None,
619 }),
620 cost_per_1k_in: Some(0.003),
621 cost_per_1k_out: Some(0.015),
622 latency_p50_ms: Some(2200),
623 ..Default::default()
624 },
625 );
626
627 config.providers.insert(
629 "huggingface".to_string(),
630 ProviderDef {
631 base_url: "https://router.huggingface.co/v1".to_string(),
632 auth_style: "bearer".to_string(),
633 auth_env: AuthEnv::Multiple(vec![
634 "HF_TOKEN".to_string(),
635 "HUGGINGFACE_API_KEY".to_string(),
636 ]),
637 chat_endpoint: "/chat/completions".to_string(),
638 completion_endpoint: Some("/completions".to_string()),
639 healthcheck: Some(HealthcheckDef {
640 method: "GET".to_string(),
641 url: Some("https://huggingface.co/api/whoami-v2".to_string()),
642 path: None,
643 body: None,
644 }),
645 cost_per_1k_in: Some(0.0002),
646 cost_per_1k_out: Some(0.0006),
647 latency_p50_ms: Some(2400),
648 ..Default::default()
649 },
650 );
651
652 config.providers.insert(
661 "ollama".to_string(),
662 ProviderDef {
663 base_url: "http://localhost:11434".to_string(),
664 base_url_env: Some("OLLAMA_HOST".to_string()),
665 auth_style: "none".to_string(),
666 chat_endpoint: "/api/chat".to_string(),
667 completion_endpoint: Some("/api/generate".to_string()),
668 healthcheck: Some(HealthcheckDef {
669 method: "GET".to_string(),
670 path: Some("/api/tags".to_string()),
671 url: None,
672 body: None,
673 }),
674 cost_per_1k_in: Some(0.0),
675 cost_per_1k_out: Some(0.0),
676 latency_p50_ms: Some(1200),
677 ..Default::default()
678 },
679 );
680
681 config.providers.insert(
683 "together".to_string(),
684 ProviderDef {
685 base_url: "https://api.together.xyz/v1".to_string(),
686 base_url_env: Some("TOGETHER_AI_BASE_URL".to_string()),
687 auth_style: "bearer".to_string(),
688 auth_env: AuthEnv::Single("TOGETHER_AI_API_KEY".to_string()),
689 chat_endpoint: "/chat/completions".to_string(),
690 completion_endpoint: Some("/completions".to_string()),
691 healthcheck: Some(HealthcheckDef {
692 method: "GET".to_string(),
693 path: Some("/models".to_string()),
694 url: None,
695 body: None,
696 }),
697 cost_per_1k_in: Some(0.0002),
698 cost_per_1k_out: Some(0.0006),
699 latency_p50_ms: Some(1600),
700 ..Default::default()
701 },
702 );
703
704 config.providers.insert(
706 "groq".to_string(),
707 ProviderDef {
708 base_url: "https://api.groq.com/openai/v1".to_string(),
709 base_url_env: Some("GROQ_BASE_URL".to_string()),
710 auth_style: "bearer".to_string(),
711 auth_env: AuthEnv::Single("GROQ_API_KEY".to_string()),
712 chat_endpoint: "/chat/completions".to_string(),
713 completion_endpoint: Some("/completions".to_string()),
714 healthcheck: Some(HealthcheckDef {
715 method: "GET".to_string(),
716 path: Some("/models".to_string()),
717 url: None,
718 body: None,
719 }),
720 cost_per_1k_in: Some(0.0001),
721 cost_per_1k_out: Some(0.0003),
722 latency_p50_ms: Some(450),
723 ..Default::default()
724 },
725 );
726
727 config.providers.insert(
729 "deepseek".to_string(),
730 ProviderDef {
731 base_url: "https://api.deepseek.com/v1".to_string(),
732 base_url_env: Some("DEEPSEEK_BASE_URL".to_string()),
733 auth_style: "bearer".to_string(),
734 auth_env: AuthEnv::Single("DEEPSEEK_API_KEY".to_string()),
735 chat_endpoint: "/chat/completions".to_string(),
736 completion_endpoint: Some("/completions".to_string()),
737 healthcheck: Some(HealthcheckDef {
738 method: "GET".to_string(),
739 path: Some("/models".to_string()),
740 url: None,
741 body: None,
742 }),
743 cost_per_1k_in: Some(0.00014),
744 cost_per_1k_out: Some(0.00028),
745 latency_p50_ms: Some(1800),
746 ..Default::default()
747 },
748 );
749
750 config.providers.insert(
752 "fireworks".to_string(),
753 ProviderDef {
754 base_url: "https://api.fireworks.ai/inference/v1".to_string(),
755 base_url_env: Some("FIREWORKS_BASE_URL".to_string()),
756 auth_style: "bearer".to_string(),
757 auth_env: AuthEnv::Single("FIREWORKS_API_KEY".to_string()),
758 chat_endpoint: "/chat/completions".to_string(),
759 completion_endpoint: Some("/completions".to_string()),
760 healthcheck: Some(HealthcheckDef {
761 method: "GET".to_string(),
762 path: Some("/models".to_string()),
763 url: None,
764 body: None,
765 }),
766 cost_per_1k_in: Some(0.0002),
767 cost_per_1k_out: Some(0.0006),
768 latency_p50_ms: Some(1400),
769 ..Default::default()
770 },
771 );
772
773 config.providers.insert(
775 "dashscope".to_string(),
776 ProviderDef {
777 base_url: "https://dashscope-intl.aliyuncs.com/compatible-mode/v1".to_string(),
778 base_url_env: Some("DASHSCOPE_BASE_URL".to_string()),
779 auth_style: "bearer".to_string(),
780 auth_env: AuthEnv::Single("DASHSCOPE_API_KEY".to_string()),
781 chat_endpoint: "/chat/completions".to_string(),
782 completion_endpoint: Some("/completions".to_string()),
783 healthcheck: Some(HealthcheckDef {
784 method: "GET".to_string(),
785 path: Some("/models".to_string()),
786 url: None,
787 body: None,
788 }),
789 cost_per_1k_in: Some(0.0003),
790 cost_per_1k_out: Some(0.0012),
791 latency_p50_ms: Some(1600),
792 ..Default::default()
793 },
794 );
795
796 config.providers.insert(
798 "local".to_string(),
799 ProviderDef {
800 base_url: "http://localhost:8000".to_string(),
801 base_url_env: Some("LOCAL_LLM_BASE_URL".to_string()),
802 auth_style: "none".to_string(),
803 chat_endpoint: "/v1/chat/completions".to_string(),
804 completion_endpoint: Some("/v1/completions".to_string()),
805 healthcheck: Some(HealthcheckDef {
806 method: "GET".to_string(),
807 path: Some("/v1/models".to_string()),
808 url: None,
809 body: None,
810 }),
811 cost_per_1k_in: Some(0.0),
812 cost_per_1k_out: Some(0.0),
813 latency_p50_ms: Some(900),
814 ..Default::default()
815 },
816 );
817
818 config.providers.insert(
820 "vllm".to_string(),
821 ProviderDef {
822 base_url: "http://localhost:8000".to_string(),
823 base_url_env: Some("VLLM_BASE_URL".to_string()),
824 auth_style: "none".to_string(),
825 chat_endpoint: "/v1/chat/completions".to_string(),
826 completion_endpoint: Some("/v1/completions".to_string()),
827 healthcheck: Some(HealthcheckDef {
828 method: "GET".to_string(),
829 path: Some("/v1/models".to_string()),
830 url: None,
831 body: None,
832 }),
833 cost_per_1k_in: Some(0.0),
834 cost_per_1k_out: Some(0.0),
835 latency_p50_ms: Some(800),
836 ..Default::default()
837 },
838 );
839
840 config.providers.insert(
842 "tgi".to_string(),
843 ProviderDef {
844 base_url: "http://localhost:8080".to_string(),
845 base_url_env: Some("TGI_BASE_URL".to_string()),
846 auth_style: "none".to_string(),
847 chat_endpoint: "/v1/chat/completions".to_string(),
848 completion_endpoint: Some("/v1/completions".to_string()),
849 healthcheck: Some(HealthcheckDef {
850 method: "GET".to_string(),
851 path: Some("/health".to_string()),
852 url: None,
853 body: None,
854 }),
855 cost_per_1k_in: Some(0.0),
856 cost_per_1k_out: Some(0.0),
857 latency_p50_ms: Some(950),
858 ..Default::default()
859 },
860 );
861
862 config.inference_rules = vec![
864 InferenceRule {
865 pattern: Some("claude-*".to_string()),
866 contains: None,
867 exact: None,
868 provider: "anthropic".to_string(),
869 },
870 InferenceRule {
871 pattern: Some("gpt-*".to_string()),
872 contains: None,
873 exact: None,
874 provider: "openai".to_string(),
875 },
876 InferenceRule {
877 pattern: Some("o1*".to_string()),
878 contains: None,
879 exact: None,
880 provider: "openai".to_string(),
881 },
882 InferenceRule {
883 pattern: Some("o3*".to_string()),
884 contains: None,
885 exact: None,
886 provider: "openai".to_string(),
887 },
888 InferenceRule {
889 pattern: Some("local:*".to_string()),
890 contains: None,
891 exact: None,
892 provider: "local".to_string(),
893 },
894 InferenceRule {
895 pattern: None,
896 contains: Some("/".to_string()),
897 exact: None,
898 provider: "openrouter".to_string(),
899 },
900 InferenceRule {
901 pattern: None,
902 contains: Some(":".to_string()),
903 exact: None,
904 provider: "ollama".to_string(),
905 },
906 ];
907
908 config.tier_rules = vec![
910 TierRule {
911 contains: Some("9b".to_string()),
912 pattern: None,
913 exact: None,
914 tier: "small".to_string(),
915 },
916 TierRule {
917 contains: Some("a3b".to_string()),
918 pattern: None,
919 exact: None,
920 tier: "small".to_string(),
921 },
922 TierRule {
923 contains: Some("gemma-4-e2b".to_string()),
924 pattern: None,
925 exact: None,
926 tier: "small".to_string(),
927 },
928 TierRule {
929 contains: Some("gemma-4-e4b".to_string()),
930 pattern: None,
931 exact: None,
932 tier: "small".to_string(),
933 },
934 TierRule {
935 contains: Some("gemma-4-26b".to_string()),
936 pattern: None,
937 exact: None,
938 tier: "mid".to_string(),
939 },
940 TierRule {
941 contains: Some("gemma-4-31b".to_string()),
942 pattern: None,
943 exact: None,
944 tier: "frontier".to_string(),
945 },
946 TierRule {
947 contains: Some("gemma4:26b".to_string()),
948 pattern: None,
949 exact: None,
950 tier: "mid".to_string(),
951 },
952 TierRule {
953 contains: Some("gemma4:31b".to_string()),
954 pattern: None,
955 exact: None,
956 tier: "frontier".to_string(),
957 },
958 TierRule {
959 pattern: Some("claude-*".to_string()),
960 contains: None,
961 exact: None,
962 tier: "frontier".to_string(),
963 },
964 TierRule {
965 exact: Some("gpt-4o".to_string()),
966 contains: None,
967 pattern: None,
968 tier: "frontier".to_string(),
969 },
970 ];
971
972 config.tier_defaults = TierDefaults {
973 default: "mid".to_string(),
974 };
975
976 config.aliases.insert(
977 "frontier".to_string(),
978 AliasDef {
979 id: "claude-sonnet-4-20250514".to_string(),
980 provider: "anthropic".to_string(),
981 tool_format: None,
982 },
983 );
984 config.aliases.insert(
985 "tier/frontier".to_string(),
986 AliasDef {
987 id: "claude-sonnet-4-20250514".to_string(),
988 provider: "anthropic".to_string(),
989 tool_format: None,
990 },
991 );
992 config.aliases.insert(
993 "mid".to_string(),
994 AliasDef {
995 id: "gpt-4o-mini".to_string(),
996 provider: "openai".to_string(),
997 tool_format: None,
998 },
999 );
1000 config.aliases.insert(
1001 "tier/mid".to_string(),
1002 AliasDef {
1003 id: "gpt-4o-mini".to_string(),
1004 provider: "openai".to_string(),
1005 tool_format: None,
1006 },
1007 );
1008 config.aliases.insert(
1009 "small".to_string(),
1010 AliasDef {
1011 id: "Qwen/Qwen3.5-9B".to_string(),
1012 provider: "openrouter".to_string(),
1013 tool_format: None,
1014 },
1015 );
1016 config.aliases.insert(
1017 "tier/small".to_string(),
1018 AliasDef {
1019 id: "Qwen/Qwen3.5-9B".to_string(),
1020 provider: "openrouter".to_string(),
1021 tool_format: None,
1022 },
1023 );
1024 config.aliases.insert(
1025 "local-gemma4".to_string(),
1026 AliasDef {
1027 id: "gemma-4-26b-a4b-it".to_string(),
1028 provider: "local".to_string(),
1029 tool_format: None,
1030 },
1031 );
1032 config.aliases.insert(
1033 "local-gemma4-26b".to_string(),
1034 AliasDef {
1035 id: "gemma-4-26b-a4b-it".to_string(),
1036 provider: "local".to_string(),
1037 tool_format: None,
1038 },
1039 );
1040 config.aliases.insert(
1041 "local-gemma4-31b".to_string(),
1042 AliasDef {
1043 id: "gemma-4-31b-it".to_string(),
1044 provider: "local".to_string(),
1045 tool_format: None,
1046 },
1047 );
1048 config.aliases.insert(
1049 "local-gemma4-e4b".to_string(),
1050 AliasDef {
1051 id: "gemma-4-e4b-it".to_string(),
1052 provider: "local".to_string(),
1053 tool_format: None,
1054 },
1055 );
1056 config.aliases.insert(
1057 "local-gemma4-e2b".to_string(),
1058 AliasDef {
1059 id: "gemma-4-e2b-it".to_string(),
1060 provider: "local".to_string(),
1061 tool_format: None,
1062 },
1063 );
1064
1065 config
1066}
1067
#[cfg(test)]
mod tests {
    //! Unit tests for glob matching, provider inference, tier classification,
    //! the built-in defaults, and per-thread user overrides. Override tests
    //! reset the thread-local overlay before and after to stay independent.

    use super::*;

    // Clear this thread's overlay so override tests don't leak state.
    fn reset_overrides() {
        clear_user_overrides();
    }

    #[test]
    fn test_glob_match_prefix() {
        assert!(glob_match("claude-*", "claude-sonnet-4-20250514"));
        assert!(glob_match("gpt-*", "gpt-4o"));
        assert!(!glob_match("claude-*", "gpt-4o"));
    }

    #[test]
    fn test_glob_match_suffix() {
        assert!(glob_match("*-latest", "llama3.2-latest"));
        assert!(!glob_match("*-latest", "llama3.2"));
    }

    #[test]
    fn test_glob_match_middle() {
        assert!(glob_match("claude-*-latest", "claude-sonnet-latest"));
        assert!(!glob_match("claude-*-latest", "claude-sonnet-beta"));
    }

    #[test]
    fn test_glob_match_exact() {
        assert!(glob_match("gpt-4o", "gpt-4o"));
        assert!(!glob_match("gpt-4o", "gpt-4o-mini"));
    }

    // Exercises both the default inference rules and the hard-coded
    // heuristic fallbacks (slash -> openrouter, colon -> ollama, etc.).
    #[test]
    fn test_infer_provider_from_defaults() {
        assert_eq!(infer_provider("claude-sonnet-4-20250514"), "anthropic");
        assert_eq!(infer_provider("gpt-4o"), "openai");
        assert_eq!(infer_provider("o1-preview"), "openai");
        assert_eq!(infer_provider("o3-mini"), "openai");
        assert_eq!(infer_provider("qwen/qwen3-coder"), "openrouter");
        assert_eq!(infer_provider("llama3.2:latest"), "ollama");
        assert_eq!(infer_provider("unknown-model"), "anthropic");
    }

    // "local:" ids must route to the local provider even when they also
    // contain "/" or ":" (which would otherwise match other rules).
    #[test]
    fn test_infer_provider_local_prefix() {
        assert_eq!(infer_provider("local:gemma-4-e4b-it"), "local");
        assert_eq!(infer_provider("local:qwen2.5"), "local");
        assert_eq!(infer_provider("local:owner/model"), "local");
    }

    #[test]
    fn test_model_tier_from_defaults() {
        assert_eq!(model_tier("claude-sonnet-4-20250514"), "frontier");
        assert_eq!(model_tier("gpt-4o"), "frontier");
        assert_eq!(model_tier("Qwen3.5-9B"), "small");
        // No rule matches, so the configured default tier applies.
        assert_eq!(model_tier("deepseek-v3"), "mid");
    }

    #[test]
    fn test_resolve_model_unknown_alias() {
        let (id, provider) = resolve_model("gpt-4o");
        assert_eq!(id, "gpt-4o");
        assert!(provider.is_none());
    }

    #[test]
    fn test_provider_names() {
        let names = provider_names();
        assert!(names.len() >= 7);
        assert!(names.contains(&"anthropic".to_string()));
        assert!(names.contains(&"together".to_string()));
        assert!(names.contains(&"local".to_string()));
        assert!(names.contains(&"openai".to_string()));
        assert!(names.contains(&"ollama".to_string()));
    }

    #[test]
    fn test_resolve_tier_model_default_aliases() {
        let (model, provider) = resolve_tier_model("frontier", None).unwrap();
        assert_eq!(model, "claude-sonnet-4-20250514");
        assert_eq!(provider, "anthropic");

        let (model, provider) = resolve_tier_model("small", None).unwrap();
        assert_eq!(model, "Qwen/Qwen3.5-9B");
        assert_eq!(provider, "openrouter");
    }

    #[test]
    fn test_resolve_tier_model_prefers_provider_scoped_aliases() {
        let (model, provider) = resolve_tier_model("mid", Some("openai")).unwrap();
        assert_eq!(model, "gpt-4o-mini");
        assert_eq!(provider, "openai");
    }

    #[test]
    fn test_provider_config_anthropic() {
        let pdef = provider_config("anthropic").unwrap();
        assert_eq!(pdef.auth_style, "header");
        assert_eq!(pdef.auth_header.as_deref(), Some("x-api-key"));
    }

    #[test]
    fn test_resolve_base_url_no_env() {
        let pdef = ProviderDef {
            base_url: "https://example.com".to_string(),
            ..Default::default()
        };
        assert_eq!(resolve_base_url(&pdef), "https://example.com");
    }

    #[test]
    fn test_default_config_roundtrip() {
        let config = default_config();
        assert!(!config.providers.is_empty());
        assert!(!config.inference_rules.is_empty());
        assert!(!config.tier_rules.is_empty());
        assert_eq!(config.tier_defaults.default, "mid");
    }

    // The built-in defaults ship no model_defaults patterns.
    #[test]
    fn test_model_params_empty() {
        let params = model_params("claude-sonnet-4-20250514");
        assert!(params.is_empty());
    }

    #[test]
    fn test_user_overrides_add_provider_and_alias() {
        reset_overrides();
        let mut overlay = ProvidersConfig::default();
        overlay.providers.insert(
            "acme".to_string(),
            ProviderDef {
                base_url: "https://llm.acme.test/v1".to_string(),
                chat_endpoint: "/chat/completions".to_string(),
                ..Default::default()
            },
        );
        overlay.aliases.insert(
            "acme-fast".to_string(),
            AliasDef {
                id: "acme/model-fast".to_string(),
                provider: "acme".to_string(),
                tool_format: Some("native".to_string()),
            },
        );
        set_user_overrides(Some(overlay));

        let (model, provider) = resolve_model("acme-fast");
        assert_eq!(model, "acme/model-fast");
        assert_eq!(provider.as_deref(), Some("acme"));
        assert!(provider_names().contains(&"acme".to_string()));
        assert_eq!(
            provider_config("acme").map(|provider| provider.base_url),
            Some("https://llm.acme.test/v1".to_string())
        );

        reset_overrides();
    }

    // Overlay rules are prepended, so they win over the built-in rules.
    #[test]
    fn test_user_overrides_prepend_inference_rules() {
        reset_overrides();
        let mut overlay = ProvidersConfig::default();
        overlay.inference_rules.push(InferenceRule {
            pattern: Some("internal-*".to_string()),
            contains: None,
            exact: None,
            provider: "openai".to_string(),
        });
        set_user_overrides(Some(overlay));

        assert_eq!(infer_provider("internal-foo"), "openai");

        reset_overrides();
    }
}
1247}