1use serde::Deserialize;
2use std::cell::RefCell;
3use std::collections::BTreeMap;
4use std::sync::OnceLock;
5
/// Process-wide base provider configuration, populated exactly once by `load_config`.
static CONFIG: OnceLock<ProvidersConfig> = OnceLock::new();
/// Path of the TOML file the base configuration was read from. Only set by `load_config`
/// when a file was successfully parsed; stays unset when the built-in defaults are used.
static CONFIG_PATH: OnceLock<String> = OnceLock::new();
8
thread_local! {
    /// Per-thread overlay that `effective_config` merges on top of the base configuration.
    /// `None` means no overrides are active on the current thread.
    static USER_OVERRIDES: RefCell<Option<ProvidersConfig>> = const { RefCell::new(None) };
}
16
17#[derive(Debug, Clone, Deserialize, Default)]
18pub struct ProvidersConfig {
19 #[serde(default)]
20 pub providers: BTreeMap<String, ProviderDef>,
21 #[serde(default)]
22 pub aliases: BTreeMap<String, AliasDef>,
23 #[serde(default)]
24 pub inference_rules: Vec<InferenceRule>,
25 #[serde(default)]
26 pub tier_rules: Vec<TierRule>,
27 #[serde(default)]
28 pub tier_defaults: TierDefaults,
29 #[serde(default)]
30 pub model_defaults: BTreeMap<String, BTreeMap<String, toml::Value>>,
31}
32
33impl ProvidersConfig {
34 pub fn is_empty(&self) -> bool {
35 self.providers.is_empty()
36 && self.aliases.is_empty()
37 && self.inference_rules.is_empty()
38 && self.tier_rules.is_empty()
39 && self.model_defaults.is_empty()
40 && self.tier_defaults.default == default_mid()
41 }
42
43 pub fn merge_from(&mut self, overlay: &ProvidersConfig) {
44 self.providers.extend(overlay.providers.clone());
45 self.aliases.extend(overlay.aliases.clone());
46
47 if !overlay.inference_rules.is_empty() {
48 let mut merged = overlay.inference_rules.clone();
49 merged.extend(self.inference_rules.clone());
50 self.inference_rules = merged;
51 }
52
53 if !overlay.tier_rules.is_empty() {
54 let mut merged = overlay.tier_rules.clone();
55 merged.extend(self.tier_rules.clone());
56 self.tier_rules = merged;
57 }
58
59 if overlay.tier_defaults.default != default_mid() {
60 self.tier_defaults = overlay.tier_defaults.clone();
61 }
62
63 for (pattern, defaults) in &overlay.model_defaults {
64 self.model_defaults
65 .entry(pattern.clone())
66 .or_default()
67 .extend(defaults.clone());
68 }
69 }
70}
71
/// Connection and authentication settings for a single LLM provider.
///
/// Only `base_url` is required in TOML; every other field has a serde default.
#[derive(Debug, Clone, Deserialize)]
pub struct ProviderDef {
    /// Base URL for API requests; may be overridden at runtime via `base_url_env`.
    pub base_url: String,
    /// Name of an environment variable that, when set to a non-empty value (after trimming
    /// whitespace and surrounding quotes), replaces `base_url` (see `resolve_base_url`).
    #[serde(default)]
    pub base_url_env: Option<String>,
    /// Authentication style; defaults to `"bearer"`. The built-in defaults also use
    /// `"header"` and `"none"`. NOTE(review): interpretation lives outside this file.
    #[serde(default = "default_bearer")]
    pub auth_style: String,
    /// Header name carrying the credential, presumably used with `auth_style = "header"`
    /// (the anthropic default pairs it with `x-api-key`) — confirm against the consumer.
    #[serde(default)]
    pub auth_header: Option<String>,
    /// Environment variable(s) holding the credential; `AuthEnv::None` when omitted.
    #[serde(default)]
    pub auth_env: AuthEnv,
    /// Extra static headers, presumably sent with every request (e.g. `anthropic-version`).
    #[serde(default)]
    pub extra_headers: BTreeMap<String, String>,
    /// Path of the chat endpoint; empty string when omitted.
    #[serde(default)]
    pub chat_endpoint: String,
    /// Path of the raw-completion endpoint, if the provider offers one.
    #[serde(default)]
    pub completion_endpoint: Option<String>,
    /// Optional health-check request description.
    #[serde(default)]
    pub healthcheck: Option<HealthcheckDef>,
    /// Capability flags, queried via `provider_has_feature` (e.g. `"native_tools"` is
    /// checked by `default_tool_format`).
    #[serde(default)]
    pub features: Vec<String>,
    /// NOTE(review): not read in this file; presumably the name of a provider to fall
    /// back to on failure — confirm.
    #[serde(default)]
    pub fallback: Option<String>,
    /// NOTE(review): not read in this file; presumably a per-provider retry count.
    #[serde(default)]
    pub retry_count: Option<u32>,
    /// NOTE(review): not read in this file; presumably the delay between retries, in ms.
    #[serde(default)]
    pub retry_delay_ms: Option<u64>,
    /// NOTE(review): not read in this file; presumably a requests-per-minute limit.
    #[serde(default)]
    pub rpm: Option<u32>,
}
106
107impl Default for ProviderDef {
108 fn default() -> Self {
109 Self {
110 base_url: String::new(),
111 base_url_env: None,
112 auth_style: default_bearer(),
113 auth_header: None,
114 auth_env: AuthEnv::None,
115 extra_headers: BTreeMap::new(),
116 chat_endpoint: String::new(),
117 completion_endpoint: None,
118 healthcheck: None,
119 features: Vec::new(),
120 fallback: None,
121 retry_count: None,
122 retry_delay_ms: None,
123 rpm: None,
124 }
125 }
126}
127
/// Serde default for `ProviderDef::auth_style`: the `"bearer"` authentication style.
fn default_bearer() -> String {
    String::from("bearer")
}
131
/// Environment variable name(s) from which a provider credential is read.
///
/// Deserialized untagged, so TOML may write `auth_env = "KEY"` (-> `Single`) or
/// `auth_env = ["A", "B"]` (-> `Multiple`). When the field is omitted, the variant
/// marked `#[default]` (`None`) is used.
#[derive(Debug, Clone, Deserialize, Default)]
#[serde(untagged)]
pub enum AuthEnv {
    /// No credential variable (e.g. unauthenticated local servers).
    #[default]
    None,
    /// A single variable name.
    Single(String),
    /// Several candidate variable names. NOTE(review): presumably tried in order until
    /// one is set — confirm against the consumer.
    Multiple(Vec<String>),
}
142
/// Description of a provider health-check request.
///
/// NOTE(review): the code that issues this request is not in this file. The field notes
/// below are inferred from the built-in defaults and should be confirmed there.
#[derive(Debug, Clone, Deserialize)]
pub struct HealthcheckDef {
    /// HTTP method (`"GET"` or `"POST"` in the built-in defaults).
    pub method: String,
    /// Request path, presumably relative to the provider's base URL (e.g. `"/models"`).
    #[serde(default)]
    pub path: Option<String>,
    /// Absolute URL, presumably used instead of base URL + `path`. The huggingface default
    /// sets this and leaves `path` unset.
    #[serde(default)]
    pub url: Option<String>,
    /// Optional request body (the anthropic default sends a small JSON payload).
    #[serde(default)]
    pub body: Option<String>,
}
153
/// A named shortcut that resolves to a concrete model on a specific provider.
#[derive(Debug, Clone, Deserialize)]
pub struct AliasDef {
    /// Concrete model id the alias resolves to.
    pub id: String,
    /// Provider name serving the model (presumably a key of `ProvidersConfig::providers`).
    pub provider: String,
    /// Tool-calling format override. `default_tool_format` consults this first and
    /// otherwise picks `"native"`/`"text"` from the provider's features.
    #[serde(default)]
    pub tool_format: Option<String>,
}
165
/// One entry of the ordered provider-inference table consulted by `infer_provider`.
///
/// A rule matches a model id when any of its criteria hits. They are checked in the order
/// `exact`, `pattern`, `contains`, and a rule with none of them set never matches. The
/// first matching rule in the list decides the provider.
#[derive(Debug, Clone, Deserialize)]
pub struct InferenceRule {
    /// Glob pattern (`*` wildcard) matched against the model id via `glob_match`.
    #[serde(default)]
    pub pattern: Option<String>,
    /// Substring that must occur in the model id (case-sensitive).
    #[serde(default)]
    pub contains: Option<String>,
    /// Model id that must match exactly.
    #[serde(default)]
    pub exact: Option<String>,
    /// Provider name returned when the rule matches.
    pub provider: String,
}
176
/// One entry of the ordered tier-classification table consulted by `model_tier`.
///
/// Matching works exactly like `InferenceRule`: `exact`, then `pattern`, then `contains`;
/// the first matching rule in the list decides the tier.
#[derive(Debug, Clone, Deserialize)]
pub struct TierRule {
    /// Glob pattern (`*` wildcard) matched against the model id via `glob_match`.
    #[serde(default)]
    pub pattern: Option<String>,
    /// Substring that must occur in the model id (case-sensitive).
    #[serde(default)]
    pub contains: Option<String>,
    /// Model id that must match exactly.
    #[serde(default)]
    pub exact: Option<String>,
    /// Tier name returned when the rule matches (the defaults use small/mid/frontier).
    pub tier: String,
}
187
/// Fallback tier for model ids that neither a tier rule nor the built-in heuristics in
/// `model_tier` classify.
#[derive(Debug, Clone, Deserialize)]
pub struct TierDefaults {
    /// Tier name to return; `"mid"` when omitted from TOML.
    #[serde(default = "default_mid")]
    pub default: String,
}
193
194impl Default for TierDefaults {
195 fn default() -> Self {
196 Self {
197 default: default_mid(),
198 }
199 }
200}
201
/// Default tier name (`"mid"`). It is used for `TierDefaults::default` and also serves as
/// the "unset" sentinel in `ProvidersConfig::merge_from` / `is_empty`.
fn default_mid() -> String {
    String::from("mid")
}
205
206pub fn load_config() -> &'static ProvidersConfig {
208 CONFIG.get_or_init(|| {
209 let verbose_config_logging = matches!(
210 std::env::var("HARN_VERBOSE_CONFIG").ok().as_deref(),
211 Some("1" | "true" | "TRUE" | "yes" | "YES")
212 ) || matches!(
213 std::env::var("HARN_ACP_VERBOSE").ok().as_deref(),
214 Some("1" | "true" | "TRUE" | "yes" | "YES")
215 );
216 if let Ok(path) = std::env::var("HARN_PROVIDERS_CONFIG") {
217 match std::fs::read_to_string(&path) {
218 Ok(content) => match toml::from_str::<ProvidersConfig>(&content) {
219 Ok(config) => {
220 if verbose_config_logging {
221 eprintln!(
222 "[llm_config] Loaded {} providers, {} aliases from {}",
223 config.providers.len(),
224 config.aliases.len(),
225 path
226 );
227 }
228 let _ = CONFIG_PATH.set(path);
229 return config;
230 }
231 Err(e) => eprintln!("[llm_config] TOML parse error in {}: {}", path, e),
232 },
233 Err(e) => eprintln!("[llm_config] Cannot read {}: {}", path, e),
234 }
235 }
236 if let Some(home) = dirs_or_home() {
237 let path = format!("{home}/.config/harn/providers.toml");
238 if let Ok(content) = std::fs::read_to_string(&path) {
239 if let Ok(config) = toml::from_str::<ProvidersConfig>(&content) {
240 let _ = CONFIG_PATH.set(path);
241 return config;
242 }
243 }
244 }
245 default_config()
246 })
247}
248
249pub fn loaded_config_path() -> Option<std::path::PathBuf> {
252 let _ = load_config();
254 CONFIG_PATH.get().map(std::path::PathBuf::from)
255}
256
257pub fn set_user_overrides(config: Option<ProvidersConfig>) {
261 USER_OVERRIDES.with(|cell| *cell.borrow_mut() = config);
262}
263
264pub fn clear_user_overrides() {
266 set_user_overrides(None);
267}
268
269fn effective_config() -> ProvidersConfig {
270 let mut merged = load_config().clone();
271 USER_OVERRIDES.with(|cell| {
272 if let Some(overlay) = cell.borrow().as_ref() {
273 merged.merge_from(overlay);
274 }
275 });
276 merged
277}
278
279pub fn resolve_model(alias: &str) -> (String, Option<String>) {
281 let config = effective_config();
282 if let Some(a) = config.aliases.get(alias) {
283 return (a.id.clone(), Some(a.provider.clone()));
284 }
285 (alias.to_string(), None)
286}
287
288pub fn infer_provider(model_id: &str) -> String {
290 let config = effective_config();
291 for rule in &config.inference_rules {
292 if let Some(exact) = &rule.exact {
293 if model_id == exact {
294 return rule.provider.clone();
295 }
296 }
297 if let Some(pattern) = &rule.pattern {
298 if glob_match(pattern, model_id) {
299 return rule.provider.clone();
300 }
301 }
302 if let Some(substr) = &rule.contains {
303 if model_id.contains(substr.as_str()) {
304 return rule.provider.clone();
305 }
306 }
307 }
308 if model_id.starts_with("local:") {
313 return "local".to_string();
314 }
315 if model_id.starts_with("claude-") {
316 return "anthropic".to_string();
317 }
318 if model_id.starts_with("gpt-") || model_id.starts_with("o1") || model_id.starts_with("o3") {
319 return "openai".to_string();
320 }
321 if model_id.contains('/') {
322 return "openrouter".to_string();
323 }
324 if model_id.contains(':') {
325 return "ollama".to_string();
326 }
327 "anthropic".to_string()
328}
329
330pub fn model_tier(model_id: &str) -> String {
332 let config = effective_config();
333 for rule in &config.tier_rules {
334 if let Some(exact) = &rule.exact {
335 if model_id == exact {
336 return rule.tier.clone();
337 }
338 }
339 if let Some(pattern) = &rule.pattern {
340 if glob_match(pattern, model_id) {
341 return rule.tier.clone();
342 }
343 }
344 if let Some(substr) = &rule.contains {
345 if model_id.contains(substr.as_str()) {
346 return rule.tier.clone();
347 }
348 }
349 }
350 let lower = model_id.to_lowercase();
351 if lower.contains("9b") || lower.contains("a3b") {
352 return "small".to_string();
353 }
354 if lower.starts_with("claude-") || lower == "gpt-4o" {
355 return "frontier".to_string();
356 }
357 config.tier_defaults.default.clone()
358}
359
360pub fn provider_config(name: &str) -> Option<ProviderDef> {
362 effective_config().providers.get(name).cloned()
363}
364
365pub fn model_params(model_id: &str) -> BTreeMap<String, toml::Value> {
368 let config = effective_config();
369 let mut params = BTreeMap::new();
370 for (pattern, defaults) in &config.model_defaults {
371 if glob_match(pattern, model_id) {
372 for (k, v) in defaults {
373 params.insert(k.clone(), v.clone());
374 }
375 }
376 }
377 params
378}
379
380pub fn provider_names() -> Vec<String> {
382 effective_config().providers.keys().cloned().collect()
383}
384
385pub fn provider_has_feature(provider: &str, feature: &str) -> bool {
387 provider_config(provider)
388 .map(|p| p.features.iter().any(|f| f == feature))
389 .unwrap_or(false)
390}
391
392pub fn default_tool_format(model: &str, provider: &str) -> String {
395 let config = effective_config();
396 for (name, alias) in &config.aliases {
398 let matches = (alias.id == model && alias.provider == provider) || name == model;
399 if matches {
400 if let Some(ref fmt) = alias.tool_format {
401 return fmt.clone();
402 }
403 }
404 }
405 if provider_has_feature(provider, "native_tools") {
406 "native".to_string()
407 } else {
408 "text".to_string()
409 }
410}
411
412pub fn resolve_tier_model(
414 target: &str,
415 preferred_provider: Option<&str>,
416) -> Option<(String, String)> {
417 let config = effective_config();
418
419 if let Some(alias) = config.aliases.get(target) {
420 return Some((alias.id.clone(), alias.provider.clone()));
421 }
422
423 let candidate_aliases = if let Some(provider) = preferred_provider {
424 vec![
425 format!("{provider}/{target}"),
426 format!("{provider}:{target}"),
427 format!("tier/{target}"),
428 target.to_string(),
429 ]
430 } else {
431 vec![format!("tier/{target}"), target.to_string()]
432 };
433
434 for alias_name in candidate_aliases {
435 if let Some(alias) = config.aliases.get(&alias_name) {
436 return Some((alias.id.clone(), alias.provider.clone()));
437 }
438 }
439
440 None
441}
442
443pub fn tier_candidates(target: &str) -> Vec<(String, String)> {
447 let config = effective_config();
448 let mut seen = std::collections::BTreeSet::new();
449 let mut candidates = Vec::new();
450
451 for alias in config.aliases.values() {
452 let pair = (alias.id.clone(), alias.provider.clone());
453 if seen.contains(&pair) {
454 continue;
455 }
456 if model_tier(&alias.id) == target {
457 seen.insert(pair.clone());
458 candidates.push(pair);
459 }
460 }
461
462 candidates.sort_by(|(model_a, provider_a), (model_b, provider_b)| {
463 provider_a
464 .cmp(provider_b)
465 .then_with(|| model_a.cmp(model_b))
466 });
467 candidates
468}
469
/// Matches `input` against a glob `pattern` in which `*` stands for any (possibly empty)
/// run of characters. All other characters match literally and case-sensitively.
///
/// The whole input must match. The text before the first `*` is anchored at the start,
/// the text after the last `*` is anchored at the end, and the pieces in between must
/// appear in order without overlapping. Any number of `*` is supported, and a pattern
/// with no `*` is an exact comparison.
fn glob_match(pattern: &str, input: &str) -> bool {
    let mut segments = pattern.split('*');
    // `split` always yields at least one item: the literal head before the first `*`.
    let head = segments.next().unwrap_or("");
    let Some(mut rest) = input.strip_prefix(head) else {
        return false;
    };

    let tail: Vec<&str> = segments.collect();
    let Some((last, middle)) = tail.split_last() else {
        // No `*` at all: the head must have consumed the entire input.
        return rest.is_empty();
    };

    // Leftmost matches leave the longest remainder, which is always best for the suffix.
    for segment in middle {
        match rest.find(segment) {
            Some(pos) => rest = &rest[pos + segment.len()..],
            None => return false,
        }
    }
    // Checking against `rest` (not `input`) keeps the suffix from overlapping the prefix.
    rest.ends_with(last)
}
487
/// Returns the current user's home directory, taken from `HOME` or, failing that (e.g.
/// on Windows), from `USERPROFILE`.
///
/// Empty values are skipped. An empty `HOME` would otherwise send config lookups to a
/// path under the filesystem root (`/.config/...`).
fn dirs_or_home() -> Option<String> {
    ["HOME", "USERPROFILE"]
        .iter()
        .filter_map(|&var| std::env::var(var).ok())
        .find(|dir| !dir.is_empty())
}
491
492pub fn resolve_base_url(pdef: &ProviderDef) -> String {
495 if let Some(env_name) = &pdef.base_url_env {
496 if let Ok(val) = std::env::var(env_name) {
497 let trimmed = val.trim().trim_matches('"').trim_matches('\'');
499 if !trimmed.is_empty() {
500 return trimmed.to_string();
501 }
502 }
503 }
504 pdef.base_url.clone()
505}
506
507fn default_config() -> ProvidersConfig {
508 let mut config = ProvidersConfig::default();
509
510 config.providers.insert(
511 "anthropic".to_string(),
512 ProviderDef {
513 base_url: "https://api.anthropic.com/v1".to_string(),
514 auth_style: "header".to_string(),
515 auth_header: Some("x-api-key".to_string()),
516 auth_env: AuthEnv::Single("ANTHROPIC_API_KEY".to_string()),
517 extra_headers: BTreeMap::from([(
518 "anthropic-version".to_string(),
519 "2023-06-01".to_string(),
520 )]),
521 chat_endpoint: "/messages".to_string(),
522 completion_endpoint: None,
523 healthcheck: Some(HealthcheckDef {
524 method: "POST".to_string(),
525 path: Some("/messages/count_tokens".to_string()),
526 url: None,
527 body: Some(
528 r#"{"model":"claude-sonnet-4-20250514","messages":[{"role":"user","content":"x"}]}"#
529 .to_string(),
530 ),
531 }),
532 features: vec!["prompt_caching".to_string(), "thinking".to_string()],
533 ..Default::default()
534 },
535 );
536
537 config.providers.insert(
539 "openai".to_string(),
540 ProviderDef {
541 base_url: "https://api.openai.com/v1".to_string(),
542 auth_style: "bearer".to_string(),
543 auth_env: AuthEnv::Single("OPENAI_API_KEY".to_string()),
544 chat_endpoint: "/chat/completions".to_string(),
545 completion_endpoint: Some("/completions".to_string()),
546 healthcheck: Some(HealthcheckDef {
547 method: "GET".to_string(),
548 path: Some("/models".to_string()),
549 url: None,
550 body: None,
551 }),
552 ..Default::default()
553 },
554 );
555
556 config.providers.insert(
558 "openrouter".to_string(),
559 ProviderDef {
560 base_url: "https://openrouter.ai/api/v1".to_string(),
561 auth_style: "bearer".to_string(),
562 auth_env: AuthEnv::Single("OPENROUTER_API_KEY".to_string()),
563 chat_endpoint: "/chat/completions".to_string(),
564 completion_endpoint: Some("/completions".to_string()),
565 healthcheck: Some(HealthcheckDef {
566 method: "GET".to_string(),
567 path: Some("/auth/key".to_string()),
568 url: None,
569 body: None,
570 }),
571 ..Default::default()
572 },
573 );
574
575 config.providers.insert(
577 "huggingface".to_string(),
578 ProviderDef {
579 base_url: "https://router.huggingface.co/v1".to_string(),
580 auth_style: "bearer".to_string(),
581 auth_env: AuthEnv::Multiple(vec![
582 "HF_TOKEN".to_string(),
583 "HUGGINGFACE_API_KEY".to_string(),
584 ]),
585 chat_endpoint: "/chat/completions".to_string(),
586 completion_endpoint: Some("/completions".to_string()),
587 healthcheck: Some(HealthcheckDef {
588 method: "GET".to_string(),
589 url: Some("https://huggingface.co/api/whoami-v2".to_string()),
590 path: None,
591 body: None,
592 }),
593 ..Default::default()
594 },
595 );
596
597 config.providers.insert(
606 "ollama".to_string(),
607 ProviderDef {
608 base_url: "http://localhost:11434".to_string(),
609 base_url_env: Some("OLLAMA_HOST".to_string()),
610 auth_style: "none".to_string(),
611 chat_endpoint: "/api/chat".to_string(),
612 completion_endpoint: Some("/api/generate".to_string()),
613 healthcheck: Some(HealthcheckDef {
614 method: "GET".to_string(),
615 path: Some("/api/tags".to_string()),
616 url: None,
617 body: None,
618 }),
619 ..Default::default()
620 },
621 );
622
623 config.providers.insert(
625 "together".to_string(),
626 ProviderDef {
627 base_url: "https://api.together.xyz/v1".to_string(),
628 base_url_env: Some("TOGETHER_AI_BASE_URL".to_string()),
629 auth_style: "bearer".to_string(),
630 auth_env: AuthEnv::Single("TOGETHER_AI_API_KEY".to_string()),
631 chat_endpoint: "/chat/completions".to_string(),
632 completion_endpoint: Some("/completions".to_string()),
633 healthcheck: Some(HealthcheckDef {
634 method: "GET".to_string(),
635 path: Some("/models".to_string()),
636 url: None,
637 body: None,
638 }),
639 ..Default::default()
640 },
641 );
642
643 config.providers.insert(
645 "local".to_string(),
646 ProviderDef {
647 base_url: "http://localhost:8000".to_string(),
648 base_url_env: Some("LOCAL_LLM_BASE_URL".to_string()),
649 auth_style: "none".to_string(),
650 chat_endpoint: "/v1/chat/completions".to_string(),
651 completion_endpoint: Some("/v1/completions".to_string()),
652 healthcheck: Some(HealthcheckDef {
653 method: "GET".to_string(),
654 path: Some("/v1/models".to_string()),
655 url: None,
656 body: None,
657 }),
658 ..Default::default()
659 },
660 );
661
662 config.inference_rules = vec![
664 InferenceRule {
665 pattern: Some("claude-*".to_string()),
666 contains: None,
667 exact: None,
668 provider: "anthropic".to_string(),
669 },
670 InferenceRule {
671 pattern: Some("gpt-*".to_string()),
672 contains: None,
673 exact: None,
674 provider: "openai".to_string(),
675 },
676 InferenceRule {
677 pattern: Some("o1*".to_string()),
678 contains: None,
679 exact: None,
680 provider: "openai".to_string(),
681 },
682 InferenceRule {
683 pattern: Some("o3*".to_string()),
684 contains: None,
685 exact: None,
686 provider: "openai".to_string(),
687 },
688 InferenceRule {
689 pattern: Some("local:*".to_string()),
690 contains: None,
691 exact: None,
692 provider: "local".to_string(),
693 },
694 InferenceRule {
695 pattern: None,
696 contains: Some("/".to_string()),
697 exact: None,
698 provider: "openrouter".to_string(),
699 },
700 InferenceRule {
701 pattern: None,
702 contains: Some(":".to_string()),
703 exact: None,
704 provider: "ollama".to_string(),
705 },
706 ];
707
708 config.tier_rules = vec![
710 TierRule {
711 contains: Some("9b".to_string()),
712 pattern: None,
713 exact: None,
714 tier: "small".to_string(),
715 },
716 TierRule {
717 contains: Some("a3b".to_string()),
718 pattern: None,
719 exact: None,
720 tier: "small".to_string(),
721 },
722 TierRule {
723 contains: Some("gemma-4-e2b".to_string()),
724 pattern: None,
725 exact: None,
726 tier: "small".to_string(),
727 },
728 TierRule {
729 contains: Some("gemma-4-e4b".to_string()),
730 pattern: None,
731 exact: None,
732 tier: "small".to_string(),
733 },
734 TierRule {
735 contains: Some("gemma-4-26b".to_string()),
736 pattern: None,
737 exact: None,
738 tier: "mid".to_string(),
739 },
740 TierRule {
741 contains: Some("gemma-4-31b".to_string()),
742 pattern: None,
743 exact: None,
744 tier: "frontier".to_string(),
745 },
746 TierRule {
747 contains: Some("gemma4:26b".to_string()),
748 pattern: None,
749 exact: None,
750 tier: "mid".to_string(),
751 },
752 TierRule {
753 contains: Some("gemma4:31b".to_string()),
754 pattern: None,
755 exact: None,
756 tier: "frontier".to_string(),
757 },
758 TierRule {
759 pattern: Some("claude-*".to_string()),
760 contains: None,
761 exact: None,
762 tier: "frontier".to_string(),
763 },
764 TierRule {
765 exact: Some("gpt-4o".to_string()),
766 contains: None,
767 pattern: None,
768 tier: "frontier".to_string(),
769 },
770 ];
771
772 config.tier_defaults = TierDefaults {
773 default: "mid".to_string(),
774 };
775
776 config.aliases.insert(
777 "frontier".to_string(),
778 AliasDef {
779 id: "claude-sonnet-4-20250514".to_string(),
780 provider: "anthropic".to_string(),
781 tool_format: None,
782 },
783 );
784 config.aliases.insert(
785 "tier/frontier".to_string(),
786 AliasDef {
787 id: "claude-sonnet-4-20250514".to_string(),
788 provider: "anthropic".to_string(),
789 tool_format: None,
790 },
791 );
792 config.aliases.insert(
793 "mid".to_string(),
794 AliasDef {
795 id: "gpt-4o-mini".to_string(),
796 provider: "openai".to_string(),
797 tool_format: None,
798 },
799 );
800 config.aliases.insert(
801 "tier/mid".to_string(),
802 AliasDef {
803 id: "gpt-4o-mini".to_string(),
804 provider: "openai".to_string(),
805 tool_format: None,
806 },
807 );
808 config.aliases.insert(
809 "small".to_string(),
810 AliasDef {
811 id: "Qwen/Qwen3.5-9B".to_string(),
812 provider: "openrouter".to_string(),
813 tool_format: None,
814 },
815 );
816 config.aliases.insert(
817 "tier/small".to_string(),
818 AliasDef {
819 id: "Qwen/Qwen3.5-9B".to_string(),
820 provider: "openrouter".to_string(),
821 tool_format: None,
822 },
823 );
824 config.aliases.insert(
825 "local-gemma4".to_string(),
826 AliasDef {
827 id: "gemma-4-26b-a4b-it".to_string(),
828 provider: "local".to_string(),
829 tool_format: None,
830 },
831 );
832 config.aliases.insert(
833 "local-gemma4-26b".to_string(),
834 AliasDef {
835 id: "gemma-4-26b-a4b-it".to_string(),
836 provider: "local".to_string(),
837 tool_format: None,
838 },
839 );
840 config.aliases.insert(
841 "local-gemma4-31b".to_string(),
842 AliasDef {
843 id: "gemma-4-31b-it".to_string(),
844 provider: "local".to_string(),
845 tool_format: None,
846 },
847 );
848 config.aliases.insert(
849 "local-gemma4-e4b".to_string(),
850 AliasDef {
851 id: "gemma-4-e4b-it".to_string(),
852 provider: "local".to_string(),
853 tool_format: None,
854 },
855 );
856 config.aliases.insert(
857 "local-gemma4-e2b".to_string(),
858 AliasDef {
859 id: "gemma-4-e2b-it".to_string(),
860 provider: "local".to_string(),
861 tool_format: None,
862 },
863 );
864
865 config
866}
867
#[cfg(test)]
mod tests {
    use super::*;

    /// Removes any thread-local overlay so a test observes only the base configuration.
    fn reset_overrides() {
        clear_user_overrides();
    }

    /// Asserts `glob_match(pattern, input) == expected` for every case.
    fn check_globs(cases: &[(&str, &str, bool)]) {
        for &(pattern, input, expected) in cases {
            assert_eq!(
                glob_match(pattern, input),
                expected,
                "glob_match({pattern:?}, {input:?})"
            );
        }
    }

    #[test]
    fn test_glob_match_prefix() {
        check_globs(&[
            ("claude-*", "claude-sonnet-4-20250514", true),
            ("gpt-*", "gpt-4o", true),
            ("claude-*", "gpt-4o", false),
        ]);
    }

    #[test]
    fn test_glob_match_suffix() {
        check_globs(&[
            ("*-latest", "llama3.2-latest", true),
            ("*-latest", "llama3.2", false),
        ]);
    }

    #[test]
    fn test_glob_match_middle() {
        check_globs(&[
            ("claude-*-latest", "claude-sonnet-latest", true),
            ("claude-*-latest", "claude-sonnet-beta", false),
        ]);
    }

    #[test]
    fn test_glob_match_exact() {
        check_globs(&[("gpt-4o", "gpt-4o", true), ("gpt-4o", "gpt-4o-mini", false)]);
    }

    #[test]
    fn test_infer_provider_from_defaults() {
        let expectations = [
            ("claude-sonnet-4-20250514", "anthropic"),
            ("gpt-4o", "openai"),
            ("o1-preview", "openai"),
            ("o3-mini", "openai"),
            ("qwen/qwen3-coder", "openrouter"),
            ("llama3.2:latest", "ollama"),
            ("unknown-model", "anthropic"),
        ];
        for (model, provider) in expectations {
            assert_eq!(infer_provider(model), provider, "model {model}");
        }
    }

    #[test]
    fn test_infer_provider_local_prefix() {
        for model in ["local:gemma-4-e4b-it", "local:qwen2.5", "local:owner/model"] {
            assert_eq!(infer_provider(model), "local", "model {model}");
        }
    }

    #[test]
    fn test_model_tier_from_defaults() {
        let expectations = [
            ("claude-sonnet-4-20250514", "frontier"),
            ("gpt-4o", "frontier"),
            ("Qwen3.5-9B", "small"),
            ("deepseek-v3", "mid"),
        ];
        for (model, tier) in expectations {
            assert_eq!(model_tier(model), tier, "model {model}");
        }
    }

    #[test]
    fn test_resolve_model_unknown_alias() {
        assert_eq!(resolve_model("gpt-4o"), ("gpt-4o".to_string(), None));
    }

    #[test]
    fn test_provider_names() {
        let names = provider_names();
        assert!(names.len() >= 7);
        for expected in ["anthropic", "together", "local", "openai", "ollama"] {
            assert!(
                names.iter().any(|name| name == expected),
                "missing provider {expected}"
            );
        }
    }

    #[test]
    fn test_resolve_tier_model_default_aliases() {
        let pair = |model: &str, provider: &str| Some((model.to_string(), provider.to_string()));
        assert_eq!(
            resolve_tier_model("frontier", None),
            pair("claude-sonnet-4-20250514", "anthropic")
        );
        assert_eq!(
            resolve_tier_model("small", None),
            pair("Qwen/Qwen3.5-9B", "openrouter")
        );
    }

    #[test]
    fn test_resolve_tier_model_prefers_provider_scoped_aliases() {
        assert_eq!(
            resolve_tier_model("mid", Some("openai")),
            Some(("gpt-4o-mini".to_string(), "openai".to_string()))
        );
    }

    #[test]
    fn test_provider_config_anthropic() {
        let anthropic = provider_config("anthropic").expect("anthropic is a built-in provider");
        assert_eq!(anthropic.auth_style, "header");
        assert_eq!(anthropic.auth_header.as_deref(), Some("x-api-key"));
    }

    #[test]
    fn test_resolve_base_url_no_env() {
        let pdef = ProviderDef {
            base_url: "https://example.com".to_string(),
            ..ProviderDef::default()
        };
        assert_eq!(resolve_base_url(&pdef), "https://example.com");
    }

    #[test]
    fn test_default_config_roundtrip() {
        let ProvidersConfig {
            providers,
            inference_rules,
            tier_rules,
            tier_defaults,
            ..
        } = default_config();
        assert!(!providers.is_empty());
        assert!(!inference_rules.is_empty());
        assert!(!tier_rules.is_empty());
        assert_eq!(tier_defaults.default, "mid");
    }

    #[test]
    fn test_model_params_empty() {
        assert!(model_params("claude-sonnet-4-20250514").is_empty());
    }

    #[test]
    fn test_user_overrides_add_provider_and_alias() {
        reset_overrides();
        let overlay = ProvidersConfig {
            providers: std::collections::BTreeMap::from([(
                "acme".to_string(),
                ProviderDef {
                    base_url: "https://llm.acme.test/v1".to_string(),
                    chat_endpoint: "/chat/completions".to_string(),
                    ..Default::default()
                },
            )]),
            aliases: std::collections::BTreeMap::from([(
                "acme-fast".to_string(),
                AliasDef {
                    id: "acme/model-fast".to_string(),
                    provider: "acme".to_string(),
                    tool_format: Some("native".to_string()),
                },
            )]),
            ..Default::default()
        };
        set_user_overrides(Some(overlay));

        assert_eq!(
            resolve_model("acme-fast"),
            ("acme/model-fast".to_string(), Some("acme".to_string()))
        );
        assert!(provider_names().iter().any(|name| name == "acme"));
        assert_eq!(
            provider_config("acme").map(|def| def.base_url).as_deref(),
            Some("https://llm.acme.test/v1")
        );

        reset_overrides();
    }

    #[test]
    fn test_user_overrides_prepend_inference_rules() {
        reset_overrides();
        set_user_overrides(Some(ProvidersConfig {
            inference_rules: vec![InferenceRule {
                pattern: Some("internal-*".to_string()),
                contains: None,
                exact: None,
                provider: "openai".to_string(),
            }],
            ..Default::default()
        }));

        assert_eq!(infer_provider("internal-foo"), "openai");

        reset_overrides();
    }
}