//! harn_vm/llm_config.rs — LLM provider/model configuration: loads
//! `providers.toml` (or built-in defaults), resolves model aliases, infers
//! providers and capability tiers, and applies per-run user overlays.
1use serde::Deserialize;
2use std::cell::RefCell;
3use std::collections::BTreeMap;
4use std::sync::OnceLock;
5
/// Process-wide providers config, initialized lazily by `load_config`.
static CONFIG: OnceLock<ProvidersConfig> = OnceLock::new();
/// Filesystem path the config was loaded from; stays unset when the
/// built-in defaults are active.
static CONFIG_PATH: OnceLock<String> = OnceLock::new();

thread_local! {
    /// Thread-local provider config overlays installed by the CLI after it
    /// reads the nearest `harn.toml` plus any installed package manifests.
    /// Kept thread-local so tests and multi-VM hosts can scope extensions to
    /// the current run without mutating the process-wide default config.
    static USER_OVERRIDES: RefCell<Option<ProvidersConfig>> = const { RefCell::new(None) };
}
16
/// Top-level provider configuration, deserialized from `providers.toml`
/// (or assembled in `default_config` when no file is present).
#[derive(Debug, Clone, Deserialize, Default)]
pub struct ProvidersConfig {
    /// Provider definitions keyed by provider name (e.g. "anthropic").
    #[serde(default)]
    pub providers: BTreeMap<String, ProviderDef>,
    /// Model aliases keyed by alias name (e.g. "frontier", "tier/small").
    #[serde(default)]
    pub aliases: BTreeMap<String, AliasDef>,
    /// Ordered model-ID → provider rules; first match wins (see `infer_provider`).
    #[serde(default)]
    pub inference_rules: Vec<InferenceRule>,
    /// Ordered model-ID → tier rules; first match wins (see `model_tier`).
    #[serde(default)]
    pub tier_rules: Vec<TierRule>,
    /// Fallback tier used when no tier rule or heuristic matches.
    #[serde(default)]
    pub tier_defaults: TierDefaults,
    /// Default request parameters keyed by model-ID glob pattern
    /// (matched in `model_params`).
    #[serde(default)]
    pub model_defaults: BTreeMap<String, BTreeMap<String, toml::Value>>,
}
32
33impl ProvidersConfig {
34    pub fn is_empty(&self) -> bool {
35        self.providers.is_empty()
36            && self.aliases.is_empty()
37            && self.inference_rules.is_empty()
38            && self.tier_rules.is_empty()
39            && self.model_defaults.is_empty()
40            && self.tier_defaults.default == default_mid()
41    }
42
43    pub fn merge_from(&mut self, overlay: &ProvidersConfig) {
44        self.providers.extend(overlay.providers.clone());
45        self.aliases.extend(overlay.aliases.clone());
46
47        if !overlay.inference_rules.is_empty() {
48            let mut merged = overlay.inference_rules.clone();
49            merged.extend(self.inference_rules.clone());
50            self.inference_rules = merged;
51        }
52
53        if !overlay.tier_rules.is_empty() {
54            let mut merged = overlay.tier_rules.clone();
55            merged.extend(self.tier_rules.clone());
56            self.tier_rules = merged;
57        }
58
59        if overlay.tier_defaults.default != default_mid() {
60            self.tier_defaults = overlay.tier_defaults.clone();
61        }
62
63        for (pattern, defaults) in &overlay.model_defaults {
64            self.model_defaults
65                .entry(pattern.clone())
66                .or_default()
67                .extend(defaults.clone());
68        }
69    }
70}
71
/// Connection, auth, and endpoint settings for a single LLM provider.
#[derive(Debug, Clone, Deserialize)]
pub struct ProviderDef {
    /// Base URL for API requests; endpoint paths are appended to this.
    pub base_url: String,
    /// Env var that, when set and non-empty, overrides `base_url`
    /// (see `resolve_base_url`).
    #[serde(default)]
    pub base_url_env: Option<String>,
    /// Auth scheme, e.g. "bearer" (the default), "header", or "none".
    #[serde(default = "default_bearer")]
    pub auth_style: String,
    /// Header name for header-style auth (e.g. "x-api-key" for anthropic).
    #[serde(default)]
    pub auth_header: Option<String>,
    /// Env var name(s) holding the API key; see [`AuthEnv`].
    #[serde(default)]
    pub auth_env: AuthEnv,
    /// Extra headers sent with every request (e.g. "anthropic-version").
    #[serde(default)]
    pub extra_headers: BTreeMap<String, String>,
    /// Chat endpoint path, appended to the base URL.
    #[serde(default)]
    pub chat_endpoint: String,
    /// Optional plain-completion endpoint path.
    #[serde(default)]
    pub completion_endpoint: Option<String>,
    /// Optional healthcheck request definition.
    #[serde(default)]
    pub healthcheck: Option<HealthcheckDef>,
    /// Feature flags advertised by the provider (e.g. "native_tools",
    /// "prompt_caching"); queried via `provider_has_feature`.
    #[serde(default)]
    pub features: Vec<String>,
    /// Fallback provider name to try if this provider fails.
    #[serde(default)]
    pub fallback: Option<String>,
    /// Number of retries before falling back (default 0).
    #[serde(default)]
    pub retry_count: Option<u32>,
    /// Delay between retries in milliseconds (default 1000).
    #[serde(default)]
    pub retry_delay_ms: Option<u64>,
    /// Maximum requests per minute. None = unlimited.
    #[serde(default)]
    pub rpm: Option<u32>,
}
106
107impl Default for ProviderDef {
108    fn default() -> Self {
109        Self {
110            base_url: String::new(),
111            base_url_env: None,
112            auth_style: default_bearer(),
113            auth_header: None,
114            auth_env: AuthEnv::None,
115            extra_headers: BTreeMap::new(),
116            chat_endpoint: String::new(),
117            completion_endpoint: None,
118            healthcheck: None,
119            features: Vec::new(),
120            fallback: None,
121            retry_count: None,
122            retry_delay_ms: None,
123            rpm: None,
124        }
125    }
126}
127
/// Serde default for `ProviderDef::auth_style`.
fn default_bearer() -> String {
    String::from("bearer")
}
131
/// Auth env var name(s) for the provider. Can be a single string or an array
/// (tried in order until one is set).
#[derive(Debug, Clone, Deserialize, Default)]
#[serde(untagged)]
pub enum AuthEnv {
    /// No API key env var configured (e.g. local or ollama servers).
    #[default]
    None,
    /// A single env var name, e.g. "OPENAI_API_KEY".
    Single(String),
    /// Several env var names, tried in order (e.g. HF_TOKEN then
    /// HUGGINGFACE_API_KEY).
    Multiple(Vec<String>),
}
142
/// Definition of a provider health-check request.
#[derive(Debug, Clone, Deserialize)]
pub struct HealthcheckDef {
    /// HTTP method, e.g. "GET" or "POST".
    pub method: String,
    /// Path relative to the provider base URL.
    #[serde(default)]
    pub path: Option<String>,
    /// Absolute URL; the huggingface default uses this instead of `path`
    /// because its healthcheck lives on a different host.
    #[serde(default)]
    pub url: Option<String>,
    /// Optional request body (used by POST healthchecks).
    #[serde(default)]
    pub body: Option<String>,
}
153
/// A model alias: maps a short name to a concrete model ID plus provider.
#[derive(Debug, Clone, Deserialize)]
pub struct AliasDef {
    /// Concrete model ID the alias resolves to.
    pub id: String,
    /// Name of the provider serving the model.
    pub provider: String,
    /// Per-model tool format override: "native" or "text". When set, this
    /// takes precedence over the provider-level default. Models with strong
    /// tool-calling fine-tuning (Kimi-K2.5, GPT-4o) should use "native";
    /// models better served by text-based tool calling use "text".
    #[serde(default)]
    pub tool_format: Option<String>,
}
165
/// A rule mapping a model ID to a provider. Within a rule the match fields
/// are tried in the order exact, pattern (glob), contains; across rules the
/// first match wins (see `infer_provider`).
#[derive(Debug, Clone, Deserialize)]
pub struct InferenceRule {
    /// Glob pattern (e.g. "claude-*") matched against the model ID.
    #[serde(default)]
    pub pattern: Option<String>,
    /// Substring matched anywhere in the model ID.
    #[serde(default)]
    pub contains: Option<String>,
    /// Exact model ID match.
    #[serde(default)]
    pub exact: Option<String>,
    /// Provider to route to when the rule matches.
    pub provider: String,
}
176
/// A rule mapping a model ID to a capability tier. Same matching semantics
/// as [`InferenceRule`]: exact, then glob pattern, then substring; first
/// matching rule wins (see `model_tier`).
#[derive(Debug, Clone, Deserialize)]
pub struct TierRule {
    /// Glob pattern matched against the model ID.
    #[serde(default)]
    pub pattern: Option<String>,
    /// Substring matched anywhere in the model ID.
    #[serde(default)]
    pub contains: Option<String>,
    /// Exact model ID match.
    #[serde(default)]
    pub exact: Option<String>,
    /// Tier assigned when the rule matches ("small", "mid", "frontier").
    pub tier: String,
}
187
/// Fallback tier configuration.
#[derive(Debug, Clone, Deserialize)]
pub struct TierDefaults {
    /// Tier returned when no tier rule or built-in heuristic matches
    /// ("mid" unless overridden).
    #[serde(default = "default_mid")]
    pub default: String,
}
193
194impl Default for TierDefaults {
195    fn default() -> Self {
196        Self {
197            default: default_mid(),
198        }
199    }
200}
201
/// Serde default for `TierDefaults::default`.
fn default_mid() -> String {
    String::from("mid")
}
205
/// Load and cache the providers config. Called once at VM startup.
///
/// Resolution order:
/// 1. The file named by `HARN_PROVIDERS_CONFIG` (read/parse failures are
///    logged to stderr and fall through to the next step).
/// 2. `~/.config/harn/providers.toml` (failures fall through silently).
/// 3. Built-in defaults from `default_config()`.
///
/// When a file is used, its path is recorded in `CONFIG_PATH` for
/// `loaded_config_path`.
pub fn load_config() -> &'static ProvidersConfig {
    CONFIG.get_or_init(|| {
        // Verbose startup logging is opt-in via either env flag.
        let verbose_config_logging = matches!(
            std::env::var("HARN_VERBOSE_CONFIG").ok().as_deref(),
            Some("1" | "true" | "TRUE" | "yes" | "YES")
        ) || matches!(
            std::env::var("HARN_ACP_VERBOSE").ok().as_deref(),
            Some("1" | "true" | "TRUE" | "yes" | "YES")
        );
        if let Ok(path) = std::env::var("HARN_PROVIDERS_CONFIG") {
            match std::fs::read_to_string(&path) {
                Ok(content) => match toml::from_str::<ProvidersConfig>(&content) {
                    Ok(config) => {
                        if verbose_config_logging {
                            eprintln!(
                                "[llm_config] Loaded {} providers, {} aliases from {}",
                                config.providers.len(),
                                config.aliases.len(),
                                path
                            );
                        }
                        // `set` can only fail if already set; ignore that race.
                        let _ = CONFIG_PATH.set(path);
                        return config;
                    }
                    Err(e) => eprintln!("[llm_config] TOML parse error in {}: {}", path, e),
                },
                Err(e) => eprintln!("[llm_config] Cannot read {}: {}", path, e),
            }
        }
        if let Some(home) = dirs_or_home() {
            let path = format!("{home}/.config/harn/providers.toml");
            if let Ok(content) = std::fs::read_to_string(&path) {
                if let Ok(config) = toml::from_str::<ProvidersConfig>(&content) {
                    let _ = CONFIG_PATH.set(path);
                    return config;
                }
            }
        }
        default_config()
    })
}
248
249/// Returns the filesystem path of the currently-loaded providers config, if
250/// any. Returns `None` when built-in defaults are active.
251pub fn loaded_config_path() -> Option<std::path::PathBuf> {
252    // Force lazy init so CONFIG_PATH is populated if a file was loaded.
253    let _ = load_config();
254    CONFIG_PATH.get().map(std::path::PathBuf::from)
255}
256
257/// Install per-run provider config overlays. The overlay uses the same shape as
258/// `providers.toml`, but lives under `[llm]` in `harn.toml` and package
259/// manifests. Passing `None` clears the overlay.
260pub fn set_user_overrides(config: Option<ProvidersConfig>) {
261    USER_OVERRIDES.with(|cell| *cell.borrow_mut() = config);
262}
263
264/// Clear per-run provider config overlays.
265pub fn clear_user_overrides() {
266    set_user_overrides(None);
267}
268
269fn effective_config() -> ProvidersConfig {
270    let mut merged = load_config().clone();
271    USER_OVERRIDES.with(|cell| {
272        if let Some(overlay) = cell.borrow().as_ref() {
273            merged.merge_from(overlay);
274        }
275    });
276    merged
277}
278
279/// Resolve a model alias to (model_id, provider_name).
280pub fn resolve_model(alias: &str) -> (String, Option<String>) {
281    let config = effective_config();
282    if let Some(a) = config.aliases.get(alias) {
283        return (a.id.clone(), Some(a.provider.clone()));
284    }
285    (alias.to_string(), None)
286}
287
288/// Infer provider from a model ID using inference rules.
289pub fn infer_provider(model_id: &str) -> String {
290    let config = effective_config();
291    for rule in &config.inference_rules {
292        if let Some(exact) = &rule.exact {
293            if model_id == exact {
294                return rule.provider.clone();
295            }
296        }
297        if let Some(pattern) = &rule.pattern {
298            if glob_match(pattern, model_id) {
299                return rule.provider.clone();
300            }
301        }
302        if let Some(substr) = &rule.contains {
303            if model_id.contains(substr.as_str()) {
304                return rule.provider.clone();
305            }
306        }
307    }
308    // Fallback to hardcoded inference.
309    // Order matters: `local:` must beat the generic `:` → ollama rule, and
310    // any prefix-based rule must beat the generic `/` → openrouter rule for
311    // ids like `local:owner/model`.
312    if model_id.starts_with("local:") {
313        return "local".to_string();
314    }
315    if model_id.starts_with("claude-") {
316        return "anthropic".to_string();
317    }
318    if model_id.starts_with("gpt-") || model_id.starts_with("o1") || model_id.starts_with("o3") {
319        return "openai".to_string();
320    }
321    if model_id.contains('/') {
322        return "openrouter".to_string();
323    }
324    if model_id.contains(':') {
325        return "ollama".to_string();
326    }
327    "anthropic".to_string()
328}
329
330/// Get model tier ("small", "mid", "frontier").
331pub fn model_tier(model_id: &str) -> String {
332    let config = effective_config();
333    for rule in &config.tier_rules {
334        if let Some(exact) = &rule.exact {
335            if model_id == exact {
336                return rule.tier.clone();
337            }
338        }
339        if let Some(pattern) = &rule.pattern {
340            if glob_match(pattern, model_id) {
341                return rule.tier.clone();
342            }
343        }
344        if let Some(substr) = &rule.contains {
345            if model_id.contains(substr.as_str()) {
346                return rule.tier.clone();
347            }
348        }
349    }
350    let lower = model_id.to_lowercase();
351    if lower.contains("9b") || lower.contains("a3b") {
352        return "small".to_string();
353    }
354    if lower.starts_with("claude-") || lower == "gpt-4o" {
355        return "frontier".to_string();
356    }
357    config.tier_defaults.default.clone()
358}
359
360/// Get provider config for resolving base_url, auth, etc.
361pub fn provider_config(name: &str) -> Option<ProviderDef> {
362    effective_config().providers.get(name).cloned()
363}
364
365/// Get model-specific default parameters (temperature, etc.).
366/// Matches glob patterns in model_defaults keys.
367pub fn model_params(model_id: &str) -> BTreeMap<String, toml::Value> {
368    let config = effective_config();
369    let mut params = BTreeMap::new();
370    for (pattern, defaults) in &config.model_defaults {
371        if glob_match(pattern, model_id) {
372            for (k, v) in defaults {
373                params.insert(k.clone(), v.clone());
374            }
375        }
376    }
377    params
378}
379
380/// Get list of configured provider names.
381pub fn provider_names() -> Vec<String> {
382    effective_config().providers.keys().cloned().collect()
383}
384
385/// Check if a provider advertises a feature (e.g., "native_tools").
386pub fn provider_has_feature(provider: &str, feature: &str) -> bool {
387    provider_config(provider)
388        .map(|p| p.features.iter().any(|f| f == feature))
389        .unwrap_or(false)
390}
391
392/// Resolve the default tool format for a model+provider combination.
393/// Priority: alias `tool_format` (matched by model ID) > provider feature > "text".
394pub fn default_tool_format(model: &str, provider: &str) -> String {
395    let config = effective_config();
396    // Aliases match by model ID + provider, or by alias name.
397    for (name, alias) in &config.aliases {
398        let matches = (alias.id == model && alias.provider == provider) || name == model;
399        if matches {
400            if let Some(ref fmt) = alias.tool_format {
401                return fmt.clone();
402            }
403        }
404    }
405    if provider_has_feature(provider, "native_tools") {
406        "native".to_string()
407    } else {
408        "text".to_string()
409    }
410}
411
412/// Resolve a tier or alias into a concrete model/provider pair.
413pub fn resolve_tier_model(
414    target: &str,
415    preferred_provider: Option<&str>,
416) -> Option<(String, String)> {
417    let config = effective_config();
418
419    if let Some(alias) = config.aliases.get(target) {
420        return Some((alias.id.clone(), alias.provider.clone()));
421    }
422
423    let candidate_aliases = if let Some(provider) = preferred_provider {
424        vec![
425            format!("{provider}/{target}"),
426            format!("{provider}:{target}"),
427            format!("tier/{target}"),
428            target.to_string(),
429        ]
430    } else {
431        vec![format!("tier/{target}"), target.to_string()]
432    };
433
434    for alias_name in candidate_aliases {
435        if let Some(alias) = config.aliases.get(&alias_name) {
436            return Some((alias.id.clone(), alias.provider.clone()));
437        }
438    }
439
440    None
441}
442
443/// Return all configured alias-backed model/provider pairs whose resolved
444/// model falls into the requested capability tier. The result is de-duplicated
445/// and sorted deterministically by provider then model id.
446pub fn tier_candidates(target: &str) -> Vec<(String, String)> {
447    let config = effective_config();
448    let mut seen = std::collections::BTreeSet::new();
449    let mut candidates = Vec::new();
450
451    for alias in config.aliases.values() {
452        let pair = (alias.id.clone(), alias.provider.clone());
453        if seen.contains(&pair) {
454            continue;
455        }
456        if model_tier(&alias.id) == target {
457            seen.insert(pair.clone());
458            candidates.push(pair);
459        }
460    }
461
462    candidates.sort_by(|(model_a, provider_a), (model_b, provider_b)| {
463        provider_a
464            .cmp(provider_b)
465            .then_with(|| model_a.cmp(model_b))
466    });
467    candidates
468}
469
/// Simple glob matching for patterns like "claude-*", "qwen/*", "ollama:*".
/// Supports one `*` as prefix, suffix, or interior wildcard; patterns with
/// more than one interior `*` (or none) fall back to exact comparison.
fn glob_match(pattern: &str, input: &str) -> bool {
    if let Some(prefix) = pattern.strip_suffix('*') {
        return input.starts_with(prefix);
    }
    if let Some(suffix) = pattern.strip_prefix('*') {
        return input.ends_with(suffix);
    }
    let pieces: Vec<&str> = pattern.split('*').collect();
    match pieces.as_slice() {
        // Exactly one interior `*`: match both ends around it.
        [head, tail] => input.starts_with(head) && input.ends_with(tail),
        // No `*` at all, or several: literal comparison.
        _ => input == pattern,
    }
}
487
/// Minimal stand-in for the `dirs` crate: only `$HOME` is consulted.
fn dirs_or_home() -> Option<String> {
    match std::env::var("HOME") {
        Ok(home) => Some(home),
        Err(_) => None,
    }
}
491
492/// Resolve the effective base URL for a provider, checking the `base_url_env`
493/// override first, then falling back to the configured `base_url`.
494pub fn resolve_base_url(pdef: &ProviderDef) -> String {
495    if let Some(env_name) = &pdef.base_url_env {
496        if let Ok(val) = std::env::var(env_name) {
497            // Strip surrounding quotes that some .env parsers leave intact.
498            let trimmed = val.trim().trim_matches('"').trim_matches('\'');
499            if !trimmed.is_empty() {
500                return trimmed.to_string();
501            }
502        }
503    }
504    pdef.base_url.clone()
505}
506
/// Built-in provider configuration used when no `providers.toml` is found:
/// the common hosted providers (anthropic, openai, openrouter, huggingface,
/// together), local/ollama endpoints, default inference and tier rules, and
/// the standard tier aliases (bare and "tier/"-scoped forms).
fn default_config() -> ProvidersConfig {
    let mut config = ProvidersConfig::default();

    // Anthropic (header-style auth with a pinned API version).
    config.providers.insert(
        "anthropic".to_string(),
        ProviderDef {
            base_url: "https://api.anthropic.com/v1".to_string(),
            auth_style: "header".to_string(),
            auth_header: Some("x-api-key".to_string()),
            auth_env: AuthEnv::Single("ANTHROPIC_API_KEY".to_string()),
            extra_headers: BTreeMap::from([(
                "anthropic-version".to_string(),
                "2023-06-01".to_string(),
            )]),
            chat_endpoint: "/messages".to_string(),
            completion_endpoint: None,
            healthcheck: Some(HealthcheckDef {
                method: "POST".to_string(),
                path: Some("/messages/count_tokens".to_string()),
                url: None,
                body: Some(
                    r#"{"model":"claude-sonnet-4-20250514","messages":[{"role":"user","content":"x"}]}"#
                        .to_string(),
                ),
            }),
            features: vec!["prompt_caching".to_string(), "thinking".to_string()],
            ..Default::default()
        },
    );

    // OpenAI
    config.providers.insert(
        "openai".to_string(),
        ProviderDef {
            base_url: "https://api.openai.com/v1".to_string(),
            auth_style: "bearer".to_string(),
            auth_env: AuthEnv::Single("OPENAI_API_KEY".to_string()),
            chat_endpoint: "/chat/completions".to_string(),
            completion_endpoint: Some("/completions".to_string()),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                path: Some("/models".to_string()),
                url: None,
                body: None,
            }),
            ..Default::default()
        },
    );

    // OpenRouter
    config.providers.insert(
        "openrouter".to_string(),
        ProviderDef {
            base_url: "https://openrouter.ai/api/v1".to_string(),
            auth_style: "bearer".to_string(),
            auth_env: AuthEnv::Single("OPENROUTER_API_KEY".to_string()),
            chat_endpoint: "/chat/completions".to_string(),
            completion_endpoint: Some("/completions".to_string()),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                path: Some("/auth/key".to_string()),
                url: None,
                body: None,
            }),
            ..Default::default()
        },
    );

    // HuggingFace (two possible token env vars; healthcheck lives on a
    // different host, hence `url` instead of `path`).
    config.providers.insert(
        "huggingface".to_string(),
        ProviderDef {
            base_url: "https://router.huggingface.co/v1".to_string(),
            auth_style: "bearer".to_string(),
            auth_env: AuthEnv::Multiple(vec![
                "HF_TOKEN".to_string(),
                "HUGGINGFACE_API_KEY".to_string(),
            ]),
            chat_endpoint: "/chat/completions".to_string(),
            completion_endpoint: Some("/completions".to_string()),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                url: Some("https://huggingface.co/api/whoami-v2".to_string()),
                path: None,
                body: None,
            }),
            ..Default::default()
        },
    );

    // Ollama default. Note: Burin overrides this to `/v1/chat/completions`
    // via its bundled `providers.toml` (loaded by setting
    // `HARN_PROVIDERS_CONFIG` in the host process). The OpenAI-compat
    // path bypasses Ollama's per-model tool-call post-processors
    // (qwen3coder.go, qwen35.go) which raise HTTP 500s on text-mode
    // responses for the Qwen3.5 family. The default here stays on
    // `/api/chat` so the harn-vm test stub keeps working with Ollama's
    // native NDJSON wire format.
    config.providers.insert(
        "ollama".to_string(),
        ProviderDef {
            base_url: "http://localhost:11434".to_string(),
            base_url_env: Some("OLLAMA_HOST".to_string()),
            auth_style: "none".to_string(),
            chat_endpoint: "/api/chat".to_string(),
            completion_endpoint: Some("/api/generate".to_string()),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                path: Some("/api/tags".to_string()),
                url: None,
                body: None,
            }),
            ..Default::default()
        },
    );

    // Together AI (OpenAI-compatible)
    config.providers.insert(
        "together".to_string(),
        ProviderDef {
            base_url: "https://api.together.xyz/v1".to_string(),
            base_url_env: Some("TOGETHER_AI_BASE_URL".to_string()),
            auth_style: "bearer".to_string(),
            auth_env: AuthEnv::Single("TOGETHER_AI_API_KEY".to_string()),
            chat_endpoint: "/chat/completions".to_string(),
            completion_endpoint: Some("/completions".to_string()),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                path: Some("/models".to_string()),
                url: None,
                body: None,
            }),
            ..Default::default()
        },
    );

    // Local OpenAI-compatible server
    config.providers.insert(
        "local".to_string(),
        ProviderDef {
            base_url: "http://localhost:8000".to_string(),
            base_url_env: Some("LOCAL_LLM_BASE_URL".to_string()),
            auth_style: "none".to_string(),
            chat_endpoint: "/v1/chat/completions".to_string(),
            completion_endpoint: Some("/v1/completions".to_string()),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                path: Some("/v1/models".to_string()),
                url: None,
                body: None,
            }),
            ..Default::default()
        },
    );

    // Default inference rules (first match wins; see `infer_provider`).
    config.inference_rules = vec![
        InferenceRule {
            pattern: Some("claude-*".to_string()),
            contains: None,
            exact: None,
            provider: "anthropic".to_string(),
        },
        InferenceRule {
            pattern: Some("gpt-*".to_string()),
            contains: None,
            exact: None,
            provider: "openai".to_string(),
        },
        InferenceRule {
            pattern: Some("o1*".to_string()),
            contains: None,
            exact: None,
            provider: "openai".to_string(),
        },
        InferenceRule {
            pattern: Some("o3*".to_string()),
            contains: None,
            exact: None,
            provider: "openai".to_string(),
        },
        InferenceRule {
            pattern: Some("local:*".to_string()),
            contains: None,
            exact: None,
            provider: "local".to_string(),
        },
        InferenceRule {
            pattern: None,
            contains: Some("/".to_string()),
            exact: None,
            provider: "openrouter".to_string(),
        },
        InferenceRule {
            pattern: None,
            contains: Some(":".to_string()),
            exact: None,
            provider: "ollama".to_string(),
        },
    ];

    // Default tier rules (first match wins; see `model_tier`).
    config.tier_rules = vec![
        TierRule {
            contains: Some("9b".to_string()),
            pattern: None,
            exact: None,
            tier: "small".to_string(),
        },
        TierRule {
            contains: Some("a3b".to_string()),
            pattern: None,
            exact: None,
            tier: "small".to_string(),
        },
        TierRule {
            contains: Some("gemma-4-e2b".to_string()),
            pattern: None,
            exact: None,
            tier: "small".to_string(),
        },
        TierRule {
            contains: Some("gemma-4-e4b".to_string()),
            pattern: None,
            exact: None,
            tier: "small".to_string(),
        },
        TierRule {
            contains: Some("gemma-4-26b".to_string()),
            pattern: None,
            exact: None,
            tier: "mid".to_string(),
        },
        TierRule {
            contains: Some("gemma-4-31b".to_string()),
            pattern: None,
            exact: None,
            tier: "frontier".to_string(),
        },
        TierRule {
            contains: Some("gemma4:26b".to_string()),
            pattern: None,
            exact: None,
            tier: "mid".to_string(),
        },
        TierRule {
            contains: Some("gemma4:31b".to_string()),
            pattern: None,
            exact: None,
            tier: "frontier".to_string(),
        },
        TierRule {
            pattern: Some("claude-*".to_string()),
            contains: None,
            exact: None,
            tier: "frontier".to_string(),
        },
        TierRule {
            exact: Some("gpt-4o".to_string()),
            contains: None,
            pattern: None,
            tier: "frontier".to_string(),
        },
    ];

    // Fallback tier when no rule or heuristic matches.
    config.tier_defaults = TierDefaults {
        default: "mid".to_string(),
    };

    // Tier + convenience aliases. Bare names and their "tier/"-scoped twins
    // are kept in sync.
    config.aliases.insert(
        "frontier".to_string(),
        AliasDef {
            id: "claude-sonnet-4-20250514".to_string(),
            provider: "anthropic".to_string(),
            tool_format: None,
        },
    );
    config.aliases.insert(
        "tier/frontier".to_string(),
        AliasDef {
            id: "claude-sonnet-4-20250514".to_string(),
            provider: "anthropic".to_string(),
            tool_format: None,
        },
    );
    config.aliases.insert(
        "mid".to_string(),
        AliasDef {
            id: "gpt-4o-mini".to_string(),
            provider: "openai".to_string(),
            tool_format: None,
        },
    );
    config.aliases.insert(
        "tier/mid".to_string(),
        AliasDef {
            id: "gpt-4o-mini".to_string(),
            provider: "openai".to_string(),
            tool_format: None,
        },
    );
    config.aliases.insert(
        "small".to_string(),
        AliasDef {
            id: "Qwen/Qwen3.5-9B".to_string(),
            provider: "openrouter".to_string(),
            tool_format: None,
        },
    );
    config.aliases.insert(
        "tier/small".to_string(),
        AliasDef {
            id: "Qwen/Qwen3.5-9B".to_string(),
            provider: "openrouter".to_string(),
            tool_format: None,
        },
    );
    config.aliases.insert(
        "local-gemma4".to_string(),
        AliasDef {
            id: "gemma-4-26b-a4b-it".to_string(),
            provider: "local".to_string(),
            tool_format: None,
        },
    );
    config.aliases.insert(
        "local-gemma4-26b".to_string(),
        AliasDef {
            id: "gemma-4-26b-a4b-it".to_string(),
            provider: "local".to_string(),
            tool_format: None,
        },
    );
    config.aliases.insert(
        "local-gemma4-31b".to_string(),
        AliasDef {
            id: "gemma-4-31b-it".to_string(),
            provider: "local".to_string(),
            tool_format: None,
        },
    );
    config.aliases.insert(
        "local-gemma4-e4b".to_string(),
        AliasDef {
            id: "gemma-4-e4b-it".to_string(),
            provider: "local".to_string(),
            tool_format: None,
        },
    );
    config.aliases.insert(
        "local-gemma4-e2b".to_string(),
        AliasDef {
            id: "gemma-4-e2b-it".to_string(),
            provider: "local".to_string(),
            tool_format: None,
        },
    );

    config
}
867
#[cfg(test)]
mod tests {
    use super::*;

    /// Restore a clean thread-local override state for the current test.
    fn reset_overrides() {
        clear_user_overrides();
    }

    #[test]
    fn test_glob_match_prefix() {
        // Trailing-star patterns match any suffix after the literal prefix.
        for (pat, id, want) in [
            ("claude-*", "claude-sonnet-4-20250514", true),
            ("gpt-*", "gpt-4o", true),
            ("claude-*", "gpt-4o", false),
        ] {
            assert_eq!(glob_match(pat, id), want, "{pat} vs {id}");
        }
    }

    #[test]
    fn test_glob_match_suffix() {
        // Leading-star patterns match any prefix before the literal suffix.
        for (pat, id, want) in [
            ("*-latest", "llama3.2-latest", true),
            ("*-latest", "llama3.2", false),
        ] {
            assert_eq!(glob_match(pat, id), want, "{pat} vs {id}");
        }
    }

    #[test]
    fn test_glob_match_middle() {
        // A star in the middle requires both the prefix and suffix to match.
        for (pat, id, want) in [
            ("claude-*-latest", "claude-sonnet-latest", true),
            ("claude-*-latest", "claude-sonnet-beta", false),
        ] {
            assert_eq!(glob_match(pat, id), want, "{pat} vs {id}");
        }
    }

    #[test]
    fn test_glob_match_exact() {
        // A pattern without `*` must match the id verbatim.
        for (pat, id, want) in [
            ("gpt-4o", "gpt-4o", true),
            ("gpt-4o", "gpt-4o-mini", false),
        ] {
            assert_eq!(glob_match(pat, id), want, "{pat} vs {id}");
        }
    }

    #[test]
    fn test_infer_provider_from_defaults() {
        let cases = [
            ("claude-sonnet-4-20250514", "anthropic"),
            ("gpt-4o", "openai"),
            ("o1-preview", "openai"),
            ("o3-mini", "openai"),
            ("qwen/qwen3-coder", "openrouter"),
            ("llama3.2:latest", "ollama"),
            // Ids matching no rule fall back to the default provider.
            ("unknown-model", "anthropic"),
        ];
        for (model, provider) in cases {
            assert_eq!(infer_provider(model), provider, "model {model}");
        }
    }

    #[test]
    fn test_infer_provider_local_prefix() {
        // `local:` must route to the local OpenAI-compatible provider, not
        // ollama (which would otherwise swallow everything containing `:`).
        // Even when the id also contains `/`, the `local:` prefix wins.
        for model in ["local:gemma-4-e4b-it", "local:qwen2.5", "local:owner/model"] {
            assert_eq!(infer_provider(model), "local", "model {model}");
        }
    }

    #[test]
    fn test_model_tier_from_defaults() {
        let expectations = [
            ("claude-sonnet-4-20250514", "frontier"),
            ("gpt-4o", "frontier"),
            ("Qwen3.5-9B", "small"),
            ("deepseek-v3", "mid"),
        ];
        for (model, tier) in expectations {
            assert_eq!(model_tier(model), tier, "model {model}");
        }
    }

    #[test]
    fn test_resolve_model_unknown_alias() {
        // Ids that are not registered aliases pass through unchanged and
        // carry no explicit provider.
        let (resolved, provider) = resolve_model("gpt-4o");
        assert_eq!(resolved, "gpt-4o");
        assert!(provider.is_none());
    }

    #[test]
    fn test_provider_names() {
        let names = provider_names();
        assert!(names.len() >= 7, "expected at least 7 providers, got {}", names.len());
        for expected in ["anthropic", "together", "local", "openai", "ollama"] {
            assert!(names.contains(&expected.to_string()), "missing provider {expected}");
        }
    }

    #[test]
    fn test_resolve_tier_model_default_aliases() {
        let (frontier_model, frontier_provider) = resolve_tier_model("frontier", None).unwrap();
        assert_eq!(frontier_model, "claude-sonnet-4-20250514");
        assert_eq!(frontier_provider, "anthropic");

        let (small_model, small_provider) = resolve_tier_model("small", None).unwrap();
        assert_eq!(small_model, "Qwen/Qwen3.5-9B");
        assert_eq!(small_provider, "openrouter");
    }

    #[test]
    fn test_resolve_tier_model_prefers_provider_scoped_aliases() {
        // With a provider hint, the provider-scoped alias wins over the
        // tier-wide default.
        let (mid_model, mid_provider) = resolve_tier_model("mid", Some("openai")).unwrap();
        assert_eq!(mid_model, "gpt-4o-mini");
        assert_eq!(mid_provider, "openai");
    }

    #[test]
    fn test_provider_config_anthropic() {
        let def = provider_config("anthropic").expect("anthropic provider must exist");
        assert_eq!(def.auth_style, "header");
        assert_eq!(def.auth_header.as_deref(), Some("x-api-key"));
    }

    #[test]
    fn test_resolve_base_url_no_env() {
        // With no env override in play, the configured base URL is returned.
        let def = ProviderDef {
            base_url: "https://example.com".to_string(),
            ..Default::default()
        };
        assert_eq!(resolve_base_url(&def), "https://example.com");
    }

    #[test]
    fn test_default_config_roundtrip() {
        let config = default_config();
        assert!(!config.providers.is_empty(), "providers must be populated");
        assert!(!config.inference_rules.is_empty(), "inference rules must be populated");
        assert!(!config.tier_rules.is_empty(), "tier rules must be populated");
        assert_eq!(config.tier_defaults.default, "mid");
    }

    #[test]
    fn test_model_params_empty() {
        // A model without an entry in model_defaults yields no extra params.
        assert!(model_params("claude-sonnet-4-20250514").is_empty());
    }

    #[test]
    fn test_user_overrides_add_provider_and_alias() {
        reset_overrides();

        let overlay = ProvidersConfig {
            providers: BTreeMap::from([(
                "acme".to_string(),
                ProviderDef {
                    base_url: "https://llm.acme.test/v1".to_string(),
                    chat_endpoint: "/chat/completions".to_string(),
                    ..Default::default()
                },
            )]),
            aliases: BTreeMap::from([(
                "acme-fast".to_string(),
                AliasDef {
                    id: "acme/model-fast".to_string(),
                    provider: "acme".to_string(),
                    tool_format: Some("native".to_string()),
                },
            )]),
            ..Default::default()
        };
        set_user_overrides(Some(overlay));

        // The overlay's alias and provider are both visible through the
        // public lookup functions.
        let (model, provider) = resolve_model("acme-fast");
        assert_eq!(model, "acme/model-fast");
        assert_eq!(provider.as_deref(), Some("acme"));
        assert!(provider_names().contains(&"acme".to_string()));
        let base = provider_config("acme").map(|def| def.base_url);
        assert_eq!(base, Some("https://llm.acme.test/v1".to_string()));

        reset_overrides();
    }

    #[test]
    fn test_user_overrides_prepend_inference_rules() {
        reset_overrides();

        let overlay = ProvidersConfig {
            inference_rules: vec![InferenceRule {
                pattern: Some("internal-*".to_string()),
                contains: None,
                exact: None,
                provider: "openai".to_string(),
            }],
            ..Default::default()
        };
        set_user_overrides(Some(overlay));

        // Overlay rules are consulted before the built-in defaults.
        assert_eq!(infer_provider("internal-foo"), "openai");

        reset_overrides();
    }
}