//! harn_vm/llm_config.rs — LLM provider configuration: loading
//! `providers.toml`, model alias resolution, provider inference, and
//! capability tiers.

1use serde::Deserialize;
2use std::collections::BTreeMap;
3use std::sync::OnceLock;
4
// Process-wide, lazily-initialized providers configuration (set exactly once
// by `load_config`).
static CONFIG: OnceLock<ProvidersConfig> = OnceLock::new();
// Filesystem path the active config was loaded from; left unset when the
// built-in defaults are active.
static CONFIG_PATH: OnceLock<String> = OnceLock::new();
7
/// Top-level providers configuration, deserialized from `providers.toml`.
/// Every section is optional; missing sections default to empty.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct ProvidersConfig {
    /// Provider definitions keyed by provider name (e.g. "anthropic").
    #[serde(default)]
    pub providers: BTreeMap<String, ProviderDef>,
    /// Model aliases keyed by alias name (e.g. "frontier", "tier/mid").
    #[serde(default)]
    pub aliases: BTreeMap<String, AliasDef>,
    /// Ordered model-id → provider rules; first match wins (see `infer_provider`).
    #[serde(default)]
    pub inference_rules: Vec<InferenceRule>,
    /// Ordered model-id → tier rules; first match wins (see `model_tier`).
    #[serde(default)]
    pub tier_rules: Vec<TierRule>,
    /// Fallback tier used when no tier rule matches.
    #[serde(default)]
    pub tier_defaults: TierDefaults,
    /// Per-model default request parameters, keyed by a glob pattern over
    /// model IDs; values are arbitrary TOML (temperature, etc.).
    #[serde(default)]
    pub model_defaults: BTreeMap<String, BTreeMap<String, toml::Value>>,
}
23
/// Connection, auth, and retry definition for a single LLM provider endpoint.
#[derive(Debug, Clone, Deserialize)]
pub struct ProviderDef {
    /// Base URL for API requests (e.g. "https://api.openai.com/v1").
    pub base_url: String,
    /// Env var that, when set and non-empty, overrides `base_url`
    /// (see `resolve_base_url`).
    #[serde(default)]
    pub base_url_env: Option<String>,
    /// Auth mechanism; the built-in configs use "bearer" (default),
    /// "header", or "none".
    #[serde(default = "default_bearer")]
    pub auth_style: String,
    /// Header name used with the "header" auth style (e.g. "x-api-key").
    #[serde(default)]
    pub auth_header: Option<String>,
    /// Env var name(s) holding the API key; see [`AuthEnv`].
    #[serde(default)]
    pub auth_env: AuthEnv,
    /// Extra headers attached to requests (e.g. "anthropic-version").
    #[serde(default)]
    pub extra_headers: BTreeMap<String, String>,
    /// Chat endpoint path appended to the base URL (e.g. "/chat/completions").
    #[serde(default)]
    pub chat_endpoint: String,
    /// Optional legacy text-completion endpoint path.
    #[serde(default)]
    pub completion_endpoint: Option<String>,
    /// Optional liveness-probe definition.
    #[serde(default)]
    pub healthcheck: Option<HealthcheckDef>,
    /// Feature flags advertised by the provider (e.g. "native_tools",
    /// "prompt_caching"); queried via `provider_has_feature`.
    #[serde(default)]
    pub features: Vec<String>,
    /// Fallback provider name to try if this provider fails.
    #[serde(default)]
    pub fallback: Option<String>,
    /// Number of retries before falling back (default 0).
    #[serde(default)]
    pub retry_count: Option<u32>,
    /// Delay between retries in milliseconds (default 1000).
    #[serde(default)]
    pub retry_delay_ms: Option<u64>,
    /// Maximum requests per minute. None = unlimited.
    #[serde(default)]
    pub rpm: Option<u32>,
}
58
59impl Default for ProviderDef {
60    fn default() -> Self {
61        Self {
62            base_url: String::new(),
63            base_url_env: None,
64            auth_style: default_bearer(),
65            auth_header: None,
66            auth_env: AuthEnv::None,
67            extra_headers: BTreeMap::new(),
68            chat_endpoint: String::new(),
69            completion_endpoint: None,
70            healthcheck: None,
71            features: Vec::new(),
72            fallback: None,
73            retry_count: None,
74            retry_delay_ms: None,
75            rpm: None,
76        }
77    }
78}
79
/// Serde default for `ProviderDef::auth_style`.
fn default_bearer() -> String {
    String::from("bearer")
}
83
/// Auth env var name(s) for the provider. Can be a single string or an array
/// (tried in order until one is set).
#[derive(Debug, Clone, Deserialize, Default)]
#[serde(untagged)]
pub enum AuthEnv {
    /// No API key env var configured (e.g. local servers).
    #[default]
    None,
    /// A single env var name holding the key.
    Single(String),
    /// Several candidate env var names, tried in declaration order.
    Multiple(Vec<String>),
}
94
/// How to probe a provider for liveness.
#[derive(Debug, Clone, Deserialize)]
pub struct HealthcheckDef {
    /// HTTP method, e.g. "GET" or "POST".
    pub method: String,
    /// Path relative to the provider's base URL.
    #[serde(default)]
    pub path: Option<String>,
    /// Absolute URL, used when the probe lives on a different host than the
    /// API base (e.g. HuggingFace's whoami endpoint in the defaults).
    #[serde(default)]
    pub url: Option<String>,
    /// Optional request body (used by POST probes).
    #[serde(default)]
    pub body: Option<String>,
}
105
/// A model alias: maps a short name to a concrete model ID and provider.
#[derive(Debug, Clone, Deserialize)]
pub struct AliasDef {
    /// Concrete model ID (e.g. "claude-sonnet-4-20250514").
    pub id: String,
    /// Name of the provider serving this model.
    pub provider: String,
    /// Per-model tool format override: "native" or "text". When set, this
    /// takes precedence over the provider-level default. Models with strong
    /// tool-calling fine-tuning (Kimi-K2.5, GPT-4o) should use "native";
    /// models better served by text-based tool calling use "text".
    #[serde(default)]
    pub tool_format: Option<String>,
}
117
/// A rule routing a model ID to a provider. Typically exactly one of
/// `pattern`, `contains`, or `exact` is set; `infer_provider` checks them in
/// the order exact, pattern, contains.
#[derive(Debug, Clone, Deserialize)]
pub struct InferenceRule {
    /// Glob pattern over the model ID (see `glob_match`).
    #[serde(default)]
    pub pattern: Option<String>,
    /// Substring looked for anywhere in the model ID.
    #[serde(default)]
    pub contains: Option<String>,
    /// Exact model ID match.
    #[serde(default)]
    pub exact: Option<String>,
    /// Provider name returned when the rule matches.
    pub provider: String,
}
128
/// A rule assigning a model ID to a capability tier. Same matcher fields as
/// [`InferenceRule`]; `model_tier` checks exact, then pattern, then contains.
#[derive(Debug, Clone, Deserialize)]
pub struct TierRule {
    /// Glob pattern over the model ID (see `glob_match`).
    #[serde(default)]
    pub pattern: Option<String>,
    /// Substring looked for anywhere in the model ID.
    #[serde(default)]
    pub contains: Option<String>,
    /// Exact model ID match.
    #[serde(default)]
    pub exact: Option<String>,
    /// Tier name returned when the rule matches ("small", "mid", "frontier").
    pub tier: String,
}
139
/// Fallback tier assignment used when no tier rule matches.
#[derive(Debug, Clone, Deserialize)]
pub struct TierDefaults {
    /// Default tier name; "mid" unless overridden by the config file.
    #[serde(default = "default_mid")]
    pub default: String,
}
145
146impl Default for TierDefaults {
147    fn default() -> Self {
148        Self {
149            default: default_mid(),
150        }
151    }
152}
153
/// Serde default for `TierDefaults::default`.
fn default_mid() -> String {
    String::from("mid")
}
157
158/// Load and cache the providers config. Called once at VM startup.
159pub fn load_config() -> &'static ProvidersConfig {
160    CONFIG.get_or_init(|| {
161        let verbose_config_logging = matches!(
162            std::env::var("HARN_VERBOSE_CONFIG").ok().as_deref(),
163            Some("1" | "true" | "TRUE" | "yes" | "YES")
164        ) || matches!(
165            std::env::var("HARN_ACP_VERBOSE").ok().as_deref(),
166            Some("1" | "true" | "TRUE" | "yes" | "YES")
167        );
168        if let Ok(path) = std::env::var("HARN_PROVIDERS_CONFIG") {
169            match std::fs::read_to_string(&path) {
170                Ok(content) => match toml::from_str::<ProvidersConfig>(&content) {
171                    Ok(config) => {
172                        if verbose_config_logging {
173                            eprintln!(
174                                "[llm_config] Loaded {} providers, {} aliases from {}",
175                                config.providers.len(),
176                                config.aliases.len(),
177                                path
178                            );
179                        }
180                        let _ = CONFIG_PATH.set(path);
181                        return config;
182                    }
183                    Err(e) => eprintln!("[llm_config] TOML parse error in {}: {}", path, e),
184                },
185                Err(e) => eprintln!("[llm_config] Cannot read {}: {}", path, e),
186            }
187        }
188        if let Some(home) = dirs_or_home() {
189            let path = format!("{home}/.config/harn/providers.toml");
190            if let Ok(content) = std::fs::read_to_string(&path) {
191                if let Ok(config) = toml::from_str::<ProvidersConfig>(&content) {
192                    let _ = CONFIG_PATH.set(path);
193                    return config;
194                }
195            }
196        }
197        default_config()
198    })
199}
200
201/// Returns the filesystem path of the currently-loaded providers config, if
202/// any. Returns `None` when built-in defaults are active.
203pub fn loaded_config_path() -> Option<std::path::PathBuf> {
204    // Force lazy init so CONFIG_PATH is populated if a file was loaded.
205    let _ = load_config();
206    CONFIG_PATH.get().map(std::path::PathBuf::from)
207}
208
209/// Resolve a model alias to (model_id, provider_name).
210pub fn resolve_model(alias: &str) -> (String, Option<String>) {
211    let config = load_config();
212    if let Some(a) = config.aliases.get(alias) {
213        return (a.id.clone(), Some(a.provider.clone()));
214    }
215    (alias.to_string(), None)
216}
217
218/// Infer provider from a model ID using inference rules.
219pub fn infer_provider(model_id: &str) -> String {
220    let config = load_config();
221    for rule in &config.inference_rules {
222        if let Some(exact) = &rule.exact {
223            if model_id == exact {
224                return rule.provider.clone();
225            }
226        }
227        if let Some(pattern) = &rule.pattern {
228            if glob_match(pattern, model_id) {
229                return rule.provider.clone();
230            }
231        }
232        if let Some(substr) = &rule.contains {
233            if model_id.contains(substr.as_str()) {
234                return rule.provider.clone();
235            }
236        }
237    }
238    // Fallback to hardcoded inference.
239    // Order matters: `local:` must beat the generic `:` → ollama rule, and
240    // any prefix-based rule must beat the generic `/` → openrouter rule for
241    // ids like `local:owner/model`.
242    if model_id.starts_with("local:") {
243        return "local".to_string();
244    }
245    if model_id.starts_with("claude-") {
246        return "anthropic".to_string();
247    }
248    if model_id.starts_with("gpt-") || model_id.starts_with("o1") || model_id.starts_with("o3") {
249        return "openai".to_string();
250    }
251    if model_id.contains('/') {
252        return "openrouter".to_string();
253    }
254    if model_id.contains(':') {
255        return "ollama".to_string();
256    }
257    "anthropic".to_string()
258}
259
260/// Get model tier ("small", "mid", "frontier").
261pub fn model_tier(model_id: &str) -> String {
262    let config = load_config();
263    for rule in &config.tier_rules {
264        if let Some(exact) = &rule.exact {
265            if model_id == exact {
266                return rule.tier.clone();
267            }
268        }
269        if let Some(pattern) = &rule.pattern {
270            if glob_match(pattern, model_id) {
271                return rule.tier.clone();
272            }
273        }
274        if let Some(substr) = &rule.contains {
275            if model_id.contains(substr.as_str()) {
276                return rule.tier.clone();
277            }
278        }
279    }
280    let lower = model_id.to_lowercase();
281    if lower.contains("9b") || lower.contains("a3b") {
282        return "small".to_string();
283    }
284    if lower.starts_with("claude-") || lower == "gpt-4o" {
285        return "frontier".to_string();
286    }
287    config.tier_defaults.default.clone()
288}
289
290/// Get provider config for resolving base_url, auth, etc.
291pub fn provider_config(name: &str) -> Option<&'static ProviderDef> {
292    load_config().providers.get(name)
293}
294
295/// Get model-specific default parameters (temperature, etc.).
296/// Matches glob patterns in model_defaults keys.
297pub fn model_params(model_id: &str) -> BTreeMap<String, toml::Value> {
298    let config = load_config();
299    let mut params = BTreeMap::new();
300    for (pattern, defaults) in &config.model_defaults {
301        if glob_match(pattern, model_id) {
302            for (k, v) in defaults {
303                params.insert(k.clone(), v.clone());
304            }
305        }
306    }
307    params
308}
309
310/// Get list of configured provider names.
311pub fn provider_names() -> Vec<String> {
312    load_config().providers.keys().cloned().collect()
313}
314
315/// Check if a provider advertises a feature (e.g., "native_tools").
316pub fn provider_has_feature(provider: &str, feature: &str) -> bool {
317    provider_config(provider)
318        .map(|p| p.features.iter().any(|f| f == feature))
319        .unwrap_or(false)
320}
321
322/// Resolve the default tool format for a model+provider combination.
323/// Priority: alias `tool_format` (matched by model ID) > provider feature > "text".
324pub fn default_tool_format(model: &str, provider: &str) -> String {
325    let config = load_config();
326    // Aliases match by model ID + provider, or by alias name.
327    for (name, alias) in &config.aliases {
328        let matches = (alias.id == model && alias.provider == provider) || name == model;
329        if matches {
330            if let Some(ref fmt) = alias.tool_format {
331                return fmt.clone();
332            }
333        }
334    }
335    if provider_has_feature(provider, "native_tools") {
336        "native".to_string()
337    } else {
338        "text".to_string()
339    }
340}
341
342/// Resolve a tier or alias into a concrete model/provider pair.
343pub fn resolve_tier_model(
344    target: &str,
345    preferred_provider: Option<&str>,
346) -> Option<(String, String)> {
347    let config = load_config();
348
349    if let Some(alias) = config.aliases.get(target) {
350        return Some((alias.id.clone(), alias.provider.clone()));
351    }
352
353    let candidate_aliases = if let Some(provider) = preferred_provider {
354        vec![
355            format!("{provider}/{target}"),
356            format!("{provider}:{target}"),
357            format!("tier/{target}"),
358            target.to_string(),
359        ]
360    } else {
361        vec![format!("tier/{target}"), target.to_string()]
362    };
363
364    for alias_name in candidate_aliases {
365        if let Some(alias) = config.aliases.get(&alias_name) {
366            return Some((alias.id.clone(), alias.provider.clone()));
367        }
368    }
369
370    None
371}
372
373/// Return all configured alias-backed model/provider pairs whose resolved
374/// model falls into the requested capability tier. The result is de-duplicated
375/// and sorted deterministically by provider then model id.
376pub fn tier_candidates(target: &str) -> Vec<(String, String)> {
377    let config = load_config();
378    let mut seen = std::collections::BTreeSet::new();
379    let mut candidates = Vec::new();
380
381    for alias in config.aliases.values() {
382        let pair = (alias.id.clone(), alias.provider.clone());
383        if seen.contains(&pair) {
384            continue;
385        }
386        if model_tier(&alias.id) == target {
387            seen.insert(pair.clone());
388            candidates.push(pair);
389        }
390    }
391
392    candidates.sort_by(|(model_a, provider_a), (model_b, provider_b)| {
393        provider_a
394            .cmp(provider_b)
395            .then_with(|| model_a.cmp(model_b))
396    });
397    candidates
398}
399
/// Simple glob matching for patterns like "claude-*", "qwen/*", "ollama:*".
///
/// Supports a single `*` as prefix wildcard (`foo*`), suffix wildcard
/// (`*foo`), or infix wildcard (`foo*bar`). Patterns containing more than
/// one `*` (after the prefix/suffix cases) fall back to exact comparison.
fn glob_match(pattern: &str, input: &str) -> bool {
    if let Some(prefix) = pattern.strip_suffix('*') {
        input.starts_with(prefix)
    } else if let Some(suffix) = pattern.strip_prefix('*') {
        input.ends_with(suffix)
    } else if pattern.contains('*') {
        let parts: Vec<&str> = pattern.split('*').collect();
        if parts.len() == 2 {
            // Length guard fixes an overlap bug: without it "ab*ba" matched
            // "aba", because the matched prefix and suffix shared bytes.
            input.len() >= parts[0].len() + parts[1].len()
                && input.starts_with(parts[0])
                && input.ends_with(parts[1])
        } else {
            input == pattern
        }
    } else {
        input == pattern
    }
}
417
/// Home directory from `$HOME`; avoids pulling in a `dirs`-style crate.
fn dirs_or_home() -> Option<String> {
    match std::env::var("HOME") {
        Ok(home) => Some(home),
        Err(_) => None,
    }
}
421
422/// Resolve the effective base URL for a provider, checking the `base_url_env`
423/// override first, then falling back to the configured `base_url`.
424pub fn resolve_base_url(pdef: &ProviderDef) -> String {
425    if let Some(env_name) = &pdef.base_url_env {
426        if let Ok(val) = std::env::var(env_name) {
427            // Strip surrounding quotes that some .env parsers leave intact.
428            let trimmed = val.trim().trim_matches('"').trim_matches('\'');
429            if !trimmed.is_empty() {
430                return trimmed.to_string();
431            }
432        }
433    }
434    pdef.base_url.clone()
435}
436
/// Built-in fallback configuration, used when no providers.toml file could
/// be loaded from the env var path or the home directory.
fn default_config() -> ProvidersConfig {
    let mut config = ProvidersConfig::default();

    // Anthropic (header auth with x-api-key; POST token-count healthcheck)
    config.providers.insert(
        "anthropic".to_string(),
        ProviderDef {
            base_url: "https://api.anthropic.com/v1".to_string(),
            auth_style: "header".to_string(),
            auth_header: Some("x-api-key".to_string()),
            auth_env: AuthEnv::Single("ANTHROPIC_API_KEY".to_string()),
            extra_headers: BTreeMap::from([(
                "anthropic-version".to_string(),
                "2023-06-01".to_string(),
            )]),
            chat_endpoint: "/messages".to_string(),
            completion_endpoint: None,
            healthcheck: Some(HealthcheckDef {
                method: "POST".to_string(),
                path: Some("/messages/count_tokens".to_string()),
                url: None,
                body: Some(
                    r#"{"model":"claude-sonnet-4-20250514","messages":[{"role":"user","content":"x"}]}"#
                        .to_string(),
                ),
            }),
            features: vec!["prompt_caching".to_string(), "thinking".to_string()],
            ..Default::default()
        },
    );

    // OpenAI
    config.providers.insert(
        "openai".to_string(),
        ProviderDef {
            base_url: "https://api.openai.com/v1".to_string(),
            auth_style: "bearer".to_string(),
            auth_env: AuthEnv::Single("OPENAI_API_KEY".to_string()),
            chat_endpoint: "/chat/completions".to_string(),
            completion_endpoint: Some("/completions".to_string()),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                path: Some("/models".to_string()),
                url: None,
                body: None,
            }),
            ..Default::default()
        },
    );

    // OpenRouter
    config.providers.insert(
        "openrouter".to_string(),
        ProviderDef {
            base_url: "https://openrouter.ai/api/v1".to_string(),
            auth_style: "bearer".to_string(),
            auth_env: AuthEnv::Single("OPENROUTER_API_KEY".to_string()),
            chat_endpoint: "/chat/completions".to_string(),
            completion_endpoint: Some("/completions".to_string()),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                path: Some("/auth/key".to_string()),
                url: None,
                body: None,
            }),
            ..Default::default()
        },
    );

    // HuggingFace (two candidate key env vars; healthcheck hits a different
    // host than the API base, hence `url` instead of `path`)
    config.providers.insert(
        "huggingface".to_string(),
        ProviderDef {
            base_url: "https://router.huggingface.co/v1".to_string(),
            auth_style: "bearer".to_string(),
            auth_env: AuthEnv::Multiple(vec![
                "HF_TOKEN".to_string(),
                "HUGGINGFACE_API_KEY".to_string(),
            ]),
            chat_endpoint: "/chat/completions".to_string(),
            completion_endpoint: Some("/completions".to_string()),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                url: Some("https://huggingface.co/api/whoami-v2".to_string()),
                path: None,
                body: None,
            }),
            ..Default::default()
        },
    );

    // Ollama default. Note: Burin overrides this to `/v1/chat/completions`
    // via its bundled `providers.toml` (loaded by setting
    // `HARN_PROVIDERS_CONFIG` in the host process). The OpenAI-compat
    // path bypasses Ollama's per-model tool-call post-processors
    // (qwen3coder.go, qwen35.go) which raise HTTP 500s on text-mode
    // responses for the Qwen3.5 family. The default here stays on
    // `/api/chat` so the harn-vm test stub keeps working with Ollama's
    // native NDJSON wire format.
    config.providers.insert(
        "ollama".to_string(),
        ProviderDef {
            base_url: "http://localhost:11434".to_string(),
            base_url_env: Some("OLLAMA_HOST".to_string()),
            auth_style: "none".to_string(),
            chat_endpoint: "/api/chat".to_string(),
            completion_endpoint: Some("/api/generate".to_string()),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                path: Some("/api/tags".to_string()),
                url: None,
                body: None,
            }),
            ..Default::default()
        },
    );

    // Together AI (OpenAI-compatible)
    config.providers.insert(
        "together".to_string(),
        ProviderDef {
            base_url: "https://api.together.xyz/v1".to_string(),
            base_url_env: Some("TOGETHER_AI_BASE_URL".to_string()),
            auth_style: "bearer".to_string(),
            auth_env: AuthEnv::Single("TOGETHER_AI_API_KEY".to_string()),
            chat_endpoint: "/chat/completions".to_string(),
            completion_endpoint: Some("/completions".to_string()),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                path: Some("/models".to_string()),
                url: None,
                body: None,
            }),
            ..Default::default()
        },
    );

    // Local OpenAI-compatible server
    config.providers.insert(
        "local".to_string(),
        ProviderDef {
            base_url: "http://localhost:8000".to_string(),
            base_url_env: Some("LOCAL_LLM_BASE_URL".to_string()),
            auth_style: "none".to_string(),
            chat_endpoint: "/v1/chat/completions".to_string(),
            completion_endpoint: Some("/v1/completions".to_string()),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                path: Some("/v1/models".to_string()),
                url: None,
                body: None,
            }),
            ..Default::default()
        },
    );

    // Default inference rules (ordered: specific prefixes before the generic
    // `/` → openrouter and `:` → ollama catch-alls)
    config.inference_rules = vec![
        InferenceRule {
            pattern: Some("claude-*".to_string()),
            contains: None,
            exact: None,
            provider: "anthropic".to_string(),
        },
        InferenceRule {
            pattern: Some("gpt-*".to_string()),
            contains: None,
            exact: None,
            provider: "openai".to_string(),
        },
        InferenceRule {
            pattern: Some("o1*".to_string()),
            contains: None,
            exact: None,
            provider: "openai".to_string(),
        },
        InferenceRule {
            pattern: Some("o3*".to_string()),
            contains: None,
            exact: None,
            provider: "openai".to_string(),
        },
        InferenceRule {
            pattern: Some("local:*".to_string()),
            contains: None,
            exact: None,
            provider: "local".to_string(),
        },
        InferenceRule {
            pattern: None,
            contains: Some("/".to_string()),
            exact: None,
            provider: "openrouter".to_string(),
        },
        InferenceRule {
            pattern: None,
            contains: Some(":".to_string()),
            exact: None,
            provider: "ollama".to_string(),
        },
    ];

    // Default tier rules
    config.tier_rules = vec![
        TierRule {
            contains: Some("9b".to_string()),
            pattern: None,
            exact: None,
            tier: "small".to_string(),
        },
        TierRule {
            contains: Some("a3b".to_string()),
            pattern: None,
            exact: None,
            tier: "small".to_string(),
        },
        TierRule {
            contains: Some("gemma-4-e2b".to_string()),
            pattern: None,
            exact: None,
            tier: "small".to_string(),
        },
        TierRule {
            contains: Some("gemma-4-e4b".to_string()),
            pattern: None,
            exact: None,
            tier: "small".to_string(),
        },
        TierRule {
            contains: Some("gemma-4-26b".to_string()),
            pattern: None,
            exact: None,
            tier: "mid".to_string(),
        },
        TierRule {
            contains: Some("gemma-4-31b".to_string()),
            pattern: None,
            exact: None,
            tier: "frontier".to_string(),
        },
        TierRule {
            contains: Some("gemma4:26b".to_string()),
            pattern: None,
            exact: None,
            tier: "mid".to_string(),
        },
        TierRule {
            contains: Some("gemma4:31b".to_string()),
            pattern: None,
            exact: None,
            tier: "frontier".to_string(),
        },
        TierRule {
            pattern: Some("claude-*".to_string()),
            contains: None,
            exact: None,
            tier: "frontier".to_string(),
        },
        TierRule {
            exact: Some("gpt-4o".to_string()),
            contains: None,
            pattern: None,
            tier: "frontier".to_string(),
        },
    ];

    config.tier_defaults = TierDefaults {
        default: "mid".to_string(),
    };

    // Tier aliases: each tier is registered under both its bare name and the
    // "tier/"-prefixed form used by resolve_tier_model's candidate list.
    config.aliases.insert(
        "frontier".to_string(),
        AliasDef {
            id: "claude-sonnet-4-20250514".to_string(),
            provider: "anthropic".to_string(),
            tool_format: None,
        },
    );
    config.aliases.insert(
        "tier/frontier".to_string(),
        AliasDef {
            id: "claude-sonnet-4-20250514".to_string(),
            provider: "anthropic".to_string(),
            tool_format: None,
        },
    );
    config.aliases.insert(
        "mid".to_string(),
        AliasDef {
            id: "gpt-4o-mini".to_string(),
            provider: "openai".to_string(),
            tool_format: None,
        },
    );
    config.aliases.insert(
        "tier/mid".to_string(),
        AliasDef {
            id: "gpt-4o-mini".to_string(),
            provider: "openai".to_string(),
            tool_format: None,
        },
    );
    config.aliases.insert(
        "small".to_string(),
        AliasDef {
            id: "Qwen/Qwen3.5-9B".to_string(),
            provider: "openrouter".to_string(),
            tool_format: None,
        },
    );
    config.aliases.insert(
        "tier/small".to_string(),
        AliasDef {
            id: "Qwen/Qwen3.5-9B".to_string(),
            provider: "openrouter".to_string(),
            tool_format: None,
        },
    );
    // Convenience aliases for local gemma-4 model variants.
    config.aliases.insert(
        "local-gemma4".to_string(),
        AliasDef {
            id: "gemma-4-26b-a4b-it".to_string(),
            provider: "local".to_string(),
            tool_format: None,
        },
    );
    config.aliases.insert(
        "local-gemma4-26b".to_string(),
        AliasDef {
            id: "gemma-4-26b-a4b-it".to_string(),
            provider: "local".to_string(),
            tool_format: None,
        },
    );
    config.aliases.insert(
        "local-gemma4-31b".to_string(),
        AliasDef {
            id: "gemma-4-31b-it".to_string(),
            provider: "local".to_string(),
            tool_format: None,
        },
    );
    config.aliases.insert(
        "local-gemma4-e4b".to_string(),
        AliasDef {
            id: "gemma-4-e4b-it".to_string(),
            provider: "local".to_string(),
            tool_format: None,
        },
    );
    config.aliases.insert(
        "local-gemma4-e2b".to_string(),
        AliasDef {
            id: "gemma-4-e2b-it".to_string(),
            provider: "local".to_string(),
            tool_format: None,
        },
    );

    config
}
797
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): these tests exercise the built-in defaults — they assume
    // HARN_PROVIDERS_CONFIG is unset (and no ~/.config/harn/providers.toml
    // exists) in the test environment; confirm in CI.

    #[test]
    fn test_glob_match_prefix() {
        assert!(glob_match("claude-*", "claude-sonnet-4-20250514"));
        assert!(glob_match("gpt-*", "gpt-4o"));
        assert!(!glob_match("claude-*", "gpt-4o"));
    }

    #[test]
    fn test_glob_match_suffix() {
        assert!(glob_match("*-latest", "llama3.2-latest"));
        assert!(!glob_match("*-latest", "llama3.2"));
    }

    #[test]
    fn test_glob_match_middle() {
        assert!(glob_match("claude-*-latest", "claude-sonnet-latest"));
        assert!(!glob_match("claude-*-latest", "claude-sonnet-beta"));
    }

    #[test]
    fn test_glob_match_exact() {
        assert!(glob_match("gpt-4o", "gpt-4o"));
        assert!(!glob_match("gpt-4o", "gpt-4o-mini"));
    }

    #[test]
    fn test_infer_provider_from_defaults() {
        assert_eq!(infer_provider("claude-sonnet-4-20250514"), "anthropic");
        assert_eq!(infer_provider("gpt-4o"), "openai");
        assert_eq!(infer_provider("o1-preview"), "openai");
        assert_eq!(infer_provider("o3-mini"), "openai");
        assert_eq!(infer_provider("qwen/qwen3-coder"), "openrouter");
        assert_eq!(infer_provider("llama3.2:latest"), "ollama");
        assert_eq!(infer_provider("unknown-model"), "anthropic");
    }

    #[test]
    fn test_infer_provider_local_prefix() {
        // `local:` must route to the local OpenAI-compatible provider, not
        // ollama (which would otherwise swallow everything containing `:`).
        assert_eq!(infer_provider("local:gemma-4-e4b-it"), "local");
        assert_eq!(infer_provider("local:qwen2.5"), "local");
        // Even when the id also contains `/`, the `local:` prefix wins.
        assert_eq!(infer_provider("local:owner/model"), "local");
    }

    #[test]
    fn test_model_tier_from_defaults() {
        assert_eq!(model_tier("claude-sonnet-4-20250514"), "frontier");
        assert_eq!(model_tier("gpt-4o"), "frontier");
        assert_eq!(model_tier("Qwen3.5-9B"), "small");
        // No rule matches, so the configured default ("mid") applies.
        assert_eq!(model_tier("deepseek-v3"), "mid");
    }

    #[test]
    fn test_resolve_model_unknown_alias() {
        // Unknown aliases pass through unchanged with no provider.
        let (id, provider) = resolve_model("gpt-4o");
        assert_eq!(id, "gpt-4o");
        assert!(provider.is_none());
    }

    #[test]
    fn test_provider_names() {
        let names = provider_names();
        assert!(names.len() >= 7);
        assert!(names.contains(&"anthropic".to_string()));
        assert!(names.contains(&"together".to_string()));
        assert!(names.contains(&"local".to_string()));
        assert!(names.contains(&"openai".to_string()));
        assert!(names.contains(&"ollama".to_string()));
    }

    #[test]
    fn test_resolve_tier_model_default_aliases() {
        let (model, provider) = resolve_tier_model("frontier", None).unwrap();
        assert_eq!(model, "claude-sonnet-4-20250514");
        assert_eq!(provider, "anthropic");

        let (model, provider) = resolve_tier_model("small", None).unwrap();
        assert_eq!(model, "Qwen/Qwen3.5-9B");
        assert_eq!(provider, "openrouter");
    }

    #[test]
    fn test_resolve_tier_model_prefers_provider_scoped_aliases() {
        // NOTE(review): with the defaults this resolves via the bare "mid"
        // alias (checked first), which happens to be openai-backed anyway.
        let (model, provider) = resolve_tier_model("mid", Some("openai")).unwrap();
        assert_eq!(model, "gpt-4o-mini");
        assert_eq!(provider, "openai");
    }

    #[test]
    fn test_provider_config_anthropic() {
        let pdef = provider_config("anthropic").unwrap();
        assert_eq!(pdef.auth_style, "header");
        assert_eq!(pdef.auth_header.as_deref(), Some("x-api-key"));
    }

    #[test]
    fn test_resolve_base_url_no_env() {
        // No base_url_env set, so the configured base_url is returned as-is.
        let pdef = ProviderDef {
            base_url: "https://example.com".to_string(),
            ..Default::default()
        };
        assert_eq!(resolve_base_url(&pdef), "https://example.com");
    }

    #[test]
    fn test_default_config_roundtrip() {
        let config = default_config();
        assert!(!config.providers.is_empty());
        assert!(!config.inference_rules.is_empty());
        assert!(!config.tier_rules.is_empty());
        assert_eq!(config.tier_defaults.default, "mid");
    }

    #[test]
    fn test_model_params_empty() {
        // Built-in defaults declare no model_defaults entries.
        let params = model_params("claude-sonnet-4-20250514");
        assert!(params.is_empty());
    }
}