Skip to main content

harn_vm/
llm_config.rs

1use serde::Deserialize;
2use std::cell::RefCell;
3use std::collections::BTreeMap;
4use std::sync::OnceLock;
5
/// Process-wide providers config, populated on first use by `load_config`.
static CONFIG: OnceLock<ProvidersConfig> = OnceLock::new();
/// Path of the file `CONFIG` was parsed from. Left unset when `load_config`
/// falls back to the built-in defaults.
static CONFIG_PATH: OnceLock<String> = OnceLock::new();
8
thread_local! {
    /// Thread-local provider config overlays installed by the CLI after it
    /// reads the nearest `harn.toml` plus any installed package manifests.
    /// Kept thread-local so tests and multi-VM hosts can scope extensions to
    /// the current run without mutating the process-wide default config.
    ///
    /// Written by `set_user_overrides` and merged on top of the loaded config
    /// by `effective_config`; `None` means no overlay is installed.
    static USER_OVERRIDES: RefCell<Option<ProvidersConfig>> = const { RefCell::new(None) };
}
16
/// Top-level shape of a providers config file (and of the per-run overlay
/// passed to `set_user_overrides`). Every section is optional when
/// deserializing and falls back to its `Default`.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct ProvidersConfig {
    /// Provider definitions keyed by provider name.
    #[serde(default)]
    pub providers: BTreeMap<String, ProviderDef>,
    /// Model aliases keyed by alias name.
    #[serde(default)]
    pub aliases: BTreeMap<String, AliasDef>,
    /// Rules mapping a model id to a provider; scanned in order by
    /// `infer_provider`, first match wins.
    #[serde(default)]
    pub inference_rules: Vec<InferenceRule>,
    /// Rules mapping a model id to a tier; scanned in order by `model_tier`,
    /// first match wins.
    #[serde(default)]
    pub tier_rules: Vec<TierRule>,
    /// Tier returned by `model_tier` when neither a rule nor a built-in
    /// heuristic matches.
    #[serde(default)]
    pub tier_defaults: TierDefaults,
    /// Default parameters per model, keyed by a glob pattern that
    /// `model_params` matches against the model id.
    #[serde(default)]
    pub model_defaults: BTreeMap<String, BTreeMap<String, toml::Value>>,
}
32
33impl ProvidersConfig {
34    pub fn is_empty(&self) -> bool {
35        self.providers.is_empty()
36            && self.aliases.is_empty()
37            && self.inference_rules.is_empty()
38            && self.tier_rules.is_empty()
39            && self.model_defaults.is_empty()
40            && self.tier_defaults.default == default_mid()
41    }
42
43    pub fn merge_from(&mut self, overlay: &ProvidersConfig) {
44        self.providers.extend(overlay.providers.clone());
45        self.aliases.extend(overlay.aliases.clone());
46
47        if !overlay.inference_rules.is_empty() {
48            let mut merged = overlay.inference_rules.clone();
49            merged.extend(self.inference_rules.clone());
50            self.inference_rules = merged;
51        }
52
53        if !overlay.tier_rules.is_empty() {
54            let mut merged = overlay.tier_rules.clone();
55            merged.extend(self.tier_rules.clone());
56            self.tier_rules = merged;
57        }
58
59        if overlay.tier_defaults.default != default_mid() {
60            self.tier_defaults = overlay.tier_defaults.clone();
61        }
62
63        for (pattern, defaults) in &overlay.model_defaults {
64            self.model_defaults
65                .entry(pattern.clone())
66                .or_default()
67                .extend(defaults.clone());
68        }
69    }
70}
71
/// Connection, auth and catalog settings for a single provider, i.e. one
/// entry of `ProvidersConfig::providers`. When deserializing, `base_url` is
/// the only field without a serde default.
#[derive(Debug, Clone, Deserialize)]
pub struct ProviderDef {
    /// Configured base URL; returned by `resolve_base_url` when no
    /// environment override applies.
    pub base_url: String,
    /// Name of an environment variable whose non-empty value replaces
    /// `base_url` (see `resolve_base_url`).
    #[serde(default)]
    pub base_url_env: Option<String>,
    /// Auth scheme; defaults to "bearer". The built-in providers also use
    /// "header" and "none".
    #[serde(default = "default_bearer")]
    pub auth_style: String,
    /// Header name carrying the credential (the built-in Anthropic entry
    /// sets "x-api-key"). NOTE(review): presumably only consulted for the
    /// "header" auth style; confirm against the request builder.
    #[serde(default)]
    pub auth_header: Option<String>,
    /// Environment variable name(s) holding the credential.
    #[serde(default)]
    pub auth_env: AuthEnv,
    /// Additional header name/value pairs configured for the provider.
    #[serde(default)]
    pub extra_headers: BTreeMap<String, String>,
    /// Chat endpoint path (e.g. "/chat/completions"). NOTE(review):
    /// presumably appended to the resolved base URL; confirm against caller.
    #[serde(default)]
    pub chat_endpoint: String,
    /// Completion endpoint path, when the provider has one.
    #[serde(default)]
    pub completion_endpoint: Option<String>,
    /// Optional healthcheck request description.
    #[serde(default)]
    pub healthcheck: Option<HealthcheckDef>,
    /// Feature flags advertised by the provider; queried by
    /// `provider_has_feature` (e.g. "native_tools").
    #[serde(default)]
    pub features: Vec<String>,
    /// Fallback provider name to try if this provider fails.
    #[serde(default)]
    pub fallback: Option<String>,
    /// Number of retries before falling back (default 0).
    #[serde(default)]
    pub retry_count: Option<u32>,
    /// Delay between retries in milliseconds (default 1000).
    #[serde(default)]
    pub retry_delay_ms: Option<u64>,
    /// Maximum requests per minute. None = unlimited.
    #[serde(default)]
    pub rpm: Option<u32>,
    /// Provider/catalog pricing in USD per 1k input tokens.
    #[serde(default)]
    pub cost_per_1k_in: Option<f64>,
    /// Provider/catalog pricing in USD per 1k output tokens.
    #[serde(default)]
    pub cost_per_1k_out: Option<f64>,
    /// Observed or configured p50 latency in milliseconds.
    #[serde(default)]
    pub latency_p50_ms: Option<u64>,
}
115
116impl Default for ProviderDef {
117    fn default() -> Self {
118        Self {
119            base_url: String::new(),
120            base_url_env: None,
121            auth_style: default_bearer(),
122            auth_header: None,
123            auth_env: AuthEnv::None,
124            extra_headers: BTreeMap::new(),
125            chat_endpoint: String::new(),
126            completion_endpoint: None,
127            healthcheck: None,
128            features: Vec::new(),
129            fallback: None,
130            retry_count: None,
131            retry_delay_ms: None,
132            rpm: None,
133            cost_per_1k_in: None,
134            cost_per_1k_out: None,
135            latency_p50_ms: None,
136        }
137    }
138}
139
/// Serde default for `ProviderDef::auth_style`.
fn default_bearer() -> String {
    String::from("bearer")
}
143
/// Auth env var name(s) for the provider. Can be a single string or an array
/// (tried in order until one is set).
///
/// Deserialized untagged, so the config value's own shape (string vs. array)
/// selects the variant.
#[derive(Debug, Clone, Deserialize, Default)]
#[serde(untagged)]
pub enum AuthEnv {
    /// No credential variable configured; this is the default.
    #[default]
    None,
    /// A single environment variable name.
    Single(String),
    /// Several candidate environment variable names.
    Multiple(Vec<String>),
}
154
/// Description of the request used to probe a provider.
/// NOTE(review): how `path` and `url` are combined is decided by the caller,
/// which is not visible here; the built-in providers set exactly one of them.
#[derive(Debug, Clone, Deserialize)]
pub struct HealthcheckDef {
    /// HTTP method name; the built-in providers use "GET" and "POST".
    pub method: String,
    /// Endpoint path (e.g. "/models"); presumably relative to the provider's
    /// base URL — confirm against caller.
    #[serde(default)]
    pub path: Option<String>,
    /// Full URL, for probes that live on a different host than the API.
    #[serde(default)]
    pub url: Option<String>,
    /// Request body to send, if any.
    #[serde(default)]
    pub body: Option<String>,
}
165
/// A named shortcut for a concrete model/provider pair, stored under its
/// alias name in `ProvidersConfig::aliases`.
#[derive(Debug, Clone, Deserialize)]
pub struct AliasDef {
    /// Concrete model id the alias resolves to.
    pub id: String,
    /// Name of the provider serving that model.
    pub provider: String,
    /// Per-model tool format override: "native" or "text". When set, this
    /// takes precedence over the provider-level default. Models with strong
    /// tool-calling fine-tuning (Kimi-K2.5, GPT-4o) should use "native";
    /// models better served by text-based tool calling use "text".
    #[serde(default)]
    pub tool_format: Option<String>,
}
177
/// One provider-inference rule. `infer_provider` treats a rule as matching
/// when any of its configured matchers (`exact`, `pattern`, `contains`)
/// accepts the model id.
#[derive(Debug, Clone, Deserialize)]
pub struct InferenceRule {
    /// Glob pattern tested with `glob_match`.
    #[serde(default)]
    pub pattern: Option<String>,
    /// Substring the model id must contain.
    #[serde(default)]
    pub contains: Option<String>,
    /// Model id that must match exactly.
    #[serde(default)]
    pub exact: Option<String>,
    /// Provider name returned when the rule matches.
    pub provider: String,
}
188
/// One tier-assignment rule. `model_tier` treats a rule as matching when any
/// of its configured matchers (`exact`, `pattern`, `contains`) accepts the
/// model id.
#[derive(Debug, Clone, Deserialize)]
pub struct TierRule {
    /// Glob pattern tested with `glob_match`.
    #[serde(default)]
    pub pattern: Option<String>,
    /// Substring the model id must contain.
    #[serde(default)]
    pub contains: Option<String>,
    /// Model id that must match exactly.
    #[serde(default)]
    pub exact: Option<String>,
    /// Tier name returned when the rule matches.
    pub tier: String,
}
199
/// Fallback tier settings used when no tier rule or heuristic applies.
#[derive(Debug, Clone, Deserialize)]
pub struct TierDefaults {
    /// Tier name to fall back to; "mid" when omitted from the config.
    #[serde(default = "default_mid")]
    pub default: String,
}
205
206impl Default for TierDefaults {
207    fn default() -> Self {
208        Self {
209            default: default_mid(),
210        }
211    }
212}
213
/// Built-in fallback tier name.
fn default_mid() -> String {
    String::from("mid")
}
217
218/// Load and cache the providers config. Called once at VM startup.
219pub fn load_config() -> &'static ProvidersConfig {
220    CONFIG.get_or_init(|| {
221        let verbose_config_logging = matches!(
222            std::env::var("HARN_VERBOSE_CONFIG").ok().as_deref(),
223            Some("1" | "true" | "TRUE" | "yes" | "YES")
224        ) || matches!(
225            std::env::var("HARN_ACP_VERBOSE").ok().as_deref(),
226            Some("1" | "true" | "TRUE" | "yes" | "YES")
227        );
228        if let Ok(path) = std::env::var("HARN_PROVIDERS_CONFIG") {
229            match std::fs::read_to_string(&path) {
230                Ok(content) => match toml::from_str::<ProvidersConfig>(&content) {
231                    Ok(config) => {
232                        if verbose_config_logging {
233                            eprintln!(
234                                "[llm_config] Loaded {} providers, {} aliases from {}",
235                                config.providers.len(),
236                                config.aliases.len(),
237                                path
238                            );
239                        }
240                        let _ = CONFIG_PATH.set(path);
241                        return config;
242                    }
243                    Err(e) => eprintln!("[llm_config] TOML parse error in {}: {}", path, e),
244                },
245                Err(e) => eprintln!("[llm_config] Cannot read {}: {}", path, e),
246            }
247        }
248        if let Some(home) = dirs_or_home() {
249            let path = format!("{home}/.config/harn/providers.toml");
250            if let Ok(content) = std::fs::read_to_string(&path) {
251                if let Ok(config) = toml::from_str::<ProvidersConfig>(&content) {
252                    let _ = CONFIG_PATH.set(path);
253                    return config;
254                }
255            }
256        }
257        default_config()
258    })
259}
260
261/// Returns the filesystem path of the currently-loaded providers config, if
262/// any. Returns `None` when built-in defaults are active.
263pub fn loaded_config_path() -> Option<std::path::PathBuf> {
264    // Force lazy init so CONFIG_PATH is populated if a file was loaded.
265    let _ = load_config();
266    CONFIG_PATH.get().map(std::path::PathBuf::from)
267}
268
269/// Install per-run provider config overlays. The overlay uses the same shape as
270/// `providers.toml`, but lives under `[llm]` in `harn.toml` and package
271/// manifests. Passing `None` clears the overlay.
272pub fn set_user_overrides(config: Option<ProvidersConfig>) {
273    USER_OVERRIDES.with(|cell| *cell.borrow_mut() = config);
274}
275
276/// Clear per-run provider config overlays.
277pub fn clear_user_overrides() {
278    set_user_overrides(None);
279}
280
281fn effective_config() -> ProvidersConfig {
282    let mut merged = load_config().clone();
283    USER_OVERRIDES.with(|cell| {
284        if let Some(overlay) = cell.borrow().as_ref() {
285            merged.merge_from(overlay);
286        }
287    });
288    merged
289}
290
291/// Resolve a model alias to (model_id, provider_name).
292pub fn resolve_model(alias: &str) -> (String, Option<String>) {
293    let config = effective_config();
294    if let Some(a) = config.aliases.get(alias) {
295        return (a.id.clone(), Some(a.provider.clone()));
296    }
297    (alias.to_string(), None)
298}
299
300/// Infer provider from a model ID using inference rules.
301pub fn infer_provider(model_id: &str) -> String {
302    let config = effective_config();
303    for rule in &config.inference_rules {
304        if let Some(exact) = &rule.exact {
305            if model_id == exact {
306                return rule.provider.clone();
307            }
308        }
309        if let Some(pattern) = &rule.pattern {
310            if glob_match(pattern, model_id) {
311                return rule.provider.clone();
312            }
313        }
314        if let Some(substr) = &rule.contains {
315            if model_id.contains(substr.as_str()) {
316                return rule.provider.clone();
317            }
318        }
319    }
320    // Fallback to hardcoded inference.
321    // Order matters: `local:` must beat the generic `:` → ollama rule, and
322    // any prefix-based rule must beat the generic `/` → openrouter rule for
323    // ids like `local:owner/model`.
324    if model_id.starts_with("local:") {
325        return "local".to_string();
326    }
327    if model_id.starts_with("claude-") {
328        return "anthropic".to_string();
329    }
330    if model_id.starts_with("gpt-") || model_id.starts_with("o1") || model_id.starts_with("o3") {
331        return "openai".to_string();
332    }
333    if model_id.contains('/') {
334        return "openrouter".to_string();
335    }
336    if model_id.contains(':') {
337        return "ollama".to_string();
338    }
339    "anthropic".to_string()
340}
341
342/// Get model tier ("small", "mid", "frontier").
343pub fn model_tier(model_id: &str) -> String {
344    let config = effective_config();
345    for rule in &config.tier_rules {
346        if let Some(exact) = &rule.exact {
347            if model_id == exact {
348                return rule.tier.clone();
349            }
350        }
351        if let Some(pattern) = &rule.pattern {
352            if glob_match(pattern, model_id) {
353                return rule.tier.clone();
354            }
355        }
356        if let Some(substr) = &rule.contains {
357            if model_id.contains(substr.as_str()) {
358                return rule.tier.clone();
359            }
360        }
361    }
362    let lower = model_id.to_lowercase();
363    if lower.contains("9b") || lower.contains("a3b") {
364        return "small".to_string();
365    }
366    if lower.starts_with("claude-") || lower == "gpt-4o" {
367        return "frontier".to_string();
368    }
369    config.tier_defaults.default.clone()
370}
371
372/// Get provider config for resolving base_url, auth, etc.
373pub fn provider_config(name: &str) -> Option<ProviderDef> {
374    effective_config().providers.get(name).cloned()
375}
376
377/// Get model-specific default parameters (temperature, etc.).
378/// Matches glob patterns in model_defaults keys.
379pub fn model_params(model_id: &str) -> BTreeMap<String, toml::Value> {
380    let config = effective_config();
381    let mut params = BTreeMap::new();
382    for (pattern, defaults) in &config.model_defaults {
383        if glob_match(pattern, model_id) {
384            for (k, v) in defaults {
385                params.insert(k.clone(), v.clone());
386            }
387        }
388    }
389    params
390}
391
392/// Get list of configured provider names.
393pub fn provider_names() -> Vec<String> {
394    effective_config().providers.keys().cloned().collect()
395}
396
397/// Check if a provider advertises a feature (e.g., "native_tools").
398pub fn provider_has_feature(provider: &str, feature: &str) -> bool {
399    provider_config(provider)
400        .map(|p| p.features.iter().any(|f| f == feature))
401        .unwrap_or(false)
402}
403
404/// Provider-level catalog pricing/latency. Model-specific static pricing in
405/// `llm::cost` still wins when available; this is the adapter-level fallback
406/// used by routing and portal summaries.
407pub fn provider_economics(provider: &str) -> (Option<f64>, Option<f64>, Option<u64>) {
408    provider_config(provider)
409        .map(|p| (p.cost_per_1k_in, p.cost_per_1k_out, p.latency_p50_ms))
410        .unwrap_or((None, None, None))
411}
412
413/// Resolve the default tool format for a model+provider combination.
414/// Priority: alias `tool_format` (matched by model ID) > provider feature > "text".
415pub fn default_tool_format(model: &str, provider: &str) -> String {
416    let config = effective_config();
417    // Aliases match by model ID + provider, or by alias name.
418    for (name, alias) in &config.aliases {
419        let matches = (alias.id == model && alias.provider == provider) || name == model;
420        if matches {
421            if let Some(ref fmt) = alias.tool_format {
422                return fmt.clone();
423            }
424        }
425    }
426    if provider_has_feature(provider, "native_tools") {
427        "native".to_string()
428    } else {
429        "text".to_string()
430    }
431}
432
433/// Resolve a tier or alias into a concrete model/provider pair.
434pub fn resolve_tier_model(
435    target: &str,
436    preferred_provider: Option<&str>,
437) -> Option<(String, String)> {
438    let config = effective_config();
439
440    if let Some(alias) = config.aliases.get(target) {
441        return Some((alias.id.clone(), alias.provider.clone()));
442    }
443
444    let candidate_aliases = if let Some(provider) = preferred_provider {
445        vec![
446            format!("{provider}/{target}"),
447            format!("{provider}:{target}"),
448            format!("tier/{target}"),
449            target.to_string(),
450        ]
451    } else {
452        vec![format!("tier/{target}"), target.to_string()]
453    };
454
455    for alias_name in candidate_aliases {
456        if let Some(alias) = config.aliases.get(&alias_name) {
457            return Some((alias.id.clone(), alias.provider.clone()));
458        }
459    }
460
461    None
462}
463
464/// Return all configured alias-backed model/provider pairs whose resolved
465/// model falls into the requested capability tier. The result is de-duplicated
466/// and sorted deterministically by provider then model id.
467pub fn tier_candidates(target: &str) -> Vec<(String, String)> {
468    let config = effective_config();
469    let mut seen = std::collections::BTreeSet::new();
470    let mut candidates = Vec::new();
471
472    for alias in config.aliases.values() {
473        let pair = (alias.id.clone(), alias.provider.clone());
474        if seen.contains(&pair) {
475            continue;
476        }
477        if model_tier(&alias.id) == target {
478            seen.insert(pair.clone());
479            candidates.push(pair);
480        }
481    }
482
483    candidates.sort_by(|(model_a, provider_a), (model_b, provider_b)| {
484        provider_a
485            .cmp(provider_b)
486            .then_with(|| model_a.cmp(model_b))
487    });
488    candidates
489}
490
491/// Return all configured alias-backed model/provider pairs. Used by routing
492/// policies that need to compare alternatives across tiers.
493pub fn all_model_candidates() -> Vec<(String, String)> {
494    let config = effective_config();
495    let mut seen = std::collections::BTreeSet::new();
496    let mut candidates = Vec::new();
497
498    for alias in config.aliases.values() {
499        let pair = (alias.id.clone(), alias.provider.clone());
500        if seen.insert(pair.clone()) {
501            candidates.push(pair);
502        }
503    }
504
505    candidates.sort_by(|(model_a, provider_a), (model_b, provider_b)| {
506        provider_a
507            .cmp(provider_b)
508            .then_with(|| model_a.cmp(model_b))
509    });
510    candidates
511}
512
/// Simple glob matching for patterns like "claude-*", "qwen/*", "ollama:*".
///
/// `*` is the only wildcard and matches any run of characters, including the
/// empty one. It may appear any number of times and anywhere in the pattern
/// ("*qwen*", "a*b*c"). A pattern without `*` must equal `input` exactly.
/// Matching is case-sensitive.
fn glob_match(pattern: &str, input: &str) -> bool {
    let mut segments = pattern.split('*');
    // `split` always yields at least one piece: the text before the first `*`.
    let first = segments.next().unwrap_or("");
    let Some(last) = segments.next_back() else {
        // No wildcard at all: plain equality.
        return input == pattern;
    };

    // Anchor the literal head and tail. Stripping them one after the other
    // guarantees they cannot overlap ("ab*ba" must not match "aba").
    let Some(rest) = input.strip_prefix(first) else {
        return false;
    };
    let Some(mut middle) = rest.strip_suffix(last) else {
        return false;
    };

    // Remaining literal pieces must occur in order in what is left; taking
    // the leftmost occurrence each time leaves the most room for later ones.
    for segment in segments {
        match middle.find(segment) {
            Some(at) => middle = &middle[at + segment.len()..],
            None => return false,
        }
    }
    true
}
530
/// Value of the `HOME` environment variable, or `None` when it is unset or
/// not valid Unicode.
fn dirs_or_home() -> Option<String> {
    match std::env::var("HOME") {
        Ok(home) => Some(home),
        Err(_) => None,
    }
}
534
535/// Resolve the effective base URL for a provider, checking the `base_url_env`
536/// override first, then falling back to the configured `base_url`.
537pub fn resolve_base_url(pdef: &ProviderDef) -> String {
538    if let Some(env_name) = &pdef.base_url_env {
539        if let Ok(val) = std::env::var(env_name) {
540            // Strip surrounding quotes that some .env parsers leave intact.
541            let trimmed = val.trim().trim_matches('"').trim_matches('\'');
542            if !trimmed.is_empty() {
543                return trimmed.to_string();
544            }
545        }
546    }
547    pdef.base_url.clone()
548}
549
550fn default_config() -> ProvidersConfig {
551    let mut config = ProvidersConfig::default();
552
553    config.providers.insert(
554        "anthropic".to_string(),
555        ProviderDef {
556            base_url: "https://api.anthropic.com/v1".to_string(),
557            auth_style: "header".to_string(),
558            auth_header: Some("x-api-key".to_string()),
559            auth_env: AuthEnv::Single("ANTHROPIC_API_KEY".to_string()),
560            extra_headers: BTreeMap::from([(
561                "anthropic-version".to_string(),
562                "2023-06-01".to_string(),
563            )]),
564            chat_endpoint: "/messages".to_string(),
565            completion_endpoint: None,
566            healthcheck: Some(HealthcheckDef {
567                method: "POST".to_string(),
568                path: Some("/messages/count_tokens".to_string()),
569                url: None,
570                body: Some(
571                    r#"{"model":"claude-sonnet-4-20250514","messages":[{"role":"user","content":"x"}]}"#
572                        .to_string(),
573                ),
574            }),
575            features: vec!["prompt_caching".to_string(), "thinking".to_string()],
576            cost_per_1k_in: Some(0.003),
577            cost_per_1k_out: Some(0.015),
578            latency_p50_ms: Some(2500),
579            ..Default::default()
580        },
581    );
582
583    // OpenAI
584    config.providers.insert(
585        "openai".to_string(),
586        ProviderDef {
587            base_url: "https://api.openai.com/v1".to_string(),
588            auth_style: "bearer".to_string(),
589            auth_env: AuthEnv::Single("OPENAI_API_KEY".to_string()),
590            chat_endpoint: "/chat/completions".to_string(),
591            completion_endpoint: Some("/completions".to_string()),
592            healthcheck: Some(HealthcheckDef {
593                method: "GET".to_string(),
594                path: Some("/models".to_string()),
595                url: None,
596                body: None,
597            }),
598            cost_per_1k_in: Some(0.0025),
599            cost_per_1k_out: Some(0.010),
600            latency_p50_ms: Some(1800),
601            ..Default::default()
602        },
603    );
604
605    // OpenRouter
606    config.providers.insert(
607        "openrouter".to_string(),
608        ProviderDef {
609            base_url: "https://openrouter.ai/api/v1".to_string(),
610            auth_style: "bearer".to_string(),
611            auth_env: AuthEnv::Single("OPENROUTER_API_KEY".to_string()),
612            chat_endpoint: "/chat/completions".to_string(),
613            completion_endpoint: Some("/completions".to_string()),
614            healthcheck: Some(HealthcheckDef {
615                method: "GET".to_string(),
616                path: Some("/auth/key".to_string()),
617                url: None,
618                body: None,
619            }),
620            cost_per_1k_in: Some(0.003),
621            cost_per_1k_out: Some(0.015),
622            latency_p50_ms: Some(2200),
623            ..Default::default()
624        },
625    );
626
627    // HuggingFace
628    config.providers.insert(
629        "huggingface".to_string(),
630        ProviderDef {
631            base_url: "https://router.huggingface.co/v1".to_string(),
632            auth_style: "bearer".to_string(),
633            auth_env: AuthEnv::Multiple(vec![
634                "HF_TOKEN".to_string(),
635                "HUGGINGFACE_API_KEY".to_string(),
636            ]),
637            chat_endpoint: "/chat/completions".to_string(),
638            completion_endpoint: Some("/completions".to_string()),
639            healthcheck: Some(HealthcheckDef {
640                method: "GET".to_string(),
641                url: Some("https://huggingface.co/api/whoami-v2".to_string()),
642                path: None,
643                body: None,
644            }),
645            cost_per_1k_in: Some(0.0002),
646            cost_per_1k_out: Some(0.0006),
647            latency_p50_ms: Some(2400),
648            ..Default::default()
649        },
650    );
651
652    // Ollama default. Hosts can override this to `/v1/chat/completions`
653    // via a bundled `providers.toml` (loaded by setting
654    // `HARN_PROVIDERS_CONFIG` in the host process). The OpenAI-compat
655    // path bypasses Ollama's per-model tool-call post-processors
656    // (qwen3coder.go, qwen35.go) which raise HTTP 500s on text-mode
657    // responses for the Qwen3.5 family. The default here stays on
658    // `/api/chat` so the harn-vm test stub keeps working with Ollama's
659    // native NDJSON wire format.
660    config.providers.insert(
661        "ollama".to_string(),
662        ProviderDef {
663            base_url: "http://localhost:11434".to_string(),
664            base_url_env: Some("OLLAMA_HOST".to_string()),
665            auth_style: "none".to_string(),
666            chat_endpoint: "/api/chat".to_string(),
667            completion_endpoint: Some("/api/generate".to_string()),
668            healthcheck: Some(HealthcheckDef {
669                method: "GET".to_string(),
670                path: Some("/api/tags".to_string()),
671                url: None,
672                body: None,
673            }),
674            cost_per_1k_in: Some(0.0),
675            cost_per_1k_out: Some(0.0),
676            latency_p50_ms: Some(1200),
677            ..Default::default()
678        },
679    );
680
681    // Together AI (OpenAI-compatible)
682    config.providers.insert(
683        "together".to_string(),
684        ProviderDef {
685            base_url: "https://api.together.xyz/v1".to_string(),
686            base_url_env: Some("TOGETHER_AI_BASE_URL".to_string()),
687            auth_style: "bearer".to_string(),
688            auth_env: AuthEnv::Single("TOGETHER_AI_API_KEY".to_string()),
689            chat_endpoint: "/chat/completions".to_string(),
690            completion_endpoint: Some("/completions".to_string()),
691            healthcheck: Some(HealthcheckDef {
692                method: "GET".to_string(),
693                path: Some("/models".to_string()),
694                url: None,
695                body: None,
696            }),
697            cost_per_1k_in: Some(0.0002),
698            cost_per_1k_out: Some(0.0006),
699            latency_p50_ms: Some(1600),
700            ..Default::default()
701        },
702    );
703
704    // Groq (OpenAI-compatible)
705    config.providers.insert(
706        "groq".to_string(),
707        ProviderDef {
708            base_url: "https://api.groq.com/openai/v1".to_string(),
709            base_url_env: Some("GROQ_BASE_URL".to_string()),
710            auth_style: "bearer".to_string(),
711            auth_env: AuthEnv::Single("GROQ_API_KEY".to_string()),
712            chat_endpoint: "/chat/completions".to_string(),
713            completion_endpoint: Some("/completions".to_string()),
714            healthcheck: Some(HealthcheckDef {
715                method: "GET".to_string(),
716                path: Some("/models".to_string()),
717                url: None,
718                body: None,
719            }),
720            cost_per_1k_in: Some(0.0001),
721            cost_per_1k_out: Some(0.0003),
722            latency_p50_ms: Some(450),
723            ..Default::default()
724        },
725    );
726
727    // DeepSeek (OpenAI-compatible)
728    config.providers.insert(
729        "deepseek".to_string(),
730        ProviderDef {
731            base_url: "https://api.deepseek.com/v1".to_string(),
732            base_url_env: Some("DEEPSEEK_BASE_URL".to_string()),
733            auth_style: "bearer".to_string(),
734            auth_env: AuthEnv::Single("DEEPSEEK_API_KEY".to_string()),
735            chat_endpoint: "/chat/completions".to_string(),
736            completion_endpoint: Some("/completions".to_string()),
737            healthcheck: Some(HealthcheckDef {
738                method: "GET".to_string(),
739                path: Some("/models".to_string()),
740                url: None,
741                body: None,
742            }),
743            cost_per_1k_in: Some(0.00014),
744            cost_per_1k_out: Some(0.00028),
745            latency_p50_ms: Some(1800),
746            ..Default::default()
747        },
748    );
749
750    // Fireworks (OpenAI-compatible open-weight hosting)
751    config.providers.insert(
752        "fireworks".to_string(),
753        ProviderDef {
754            base_url: "https://api.fireworks.ai/inference/v1".to_string(),
755            base_url_env: Some("FIREWORKS_BASE_URL".to_string()),
756            auth_style: "bearer".to_string(),
757            auth_env: AuthEnv::Single("FIREWORKS_API_KEY".to_string()),
758            chat_endpoint: "/chat/completions".to_string(),
759            completion_endpoint: Some("/completions".to_string()),
760            healthcheck: Some(HealthcheckDef {
761                method: "GET".to_string(),
762                path: Some("/models".to_string()),
763                url: None,
764                body: None,
765            }),
766            cost_per_1k_in: Some(0.0002),
767            cost_per_1k_out: Some(0.0006),
768            latency_p50_ms: Some(1400),
769            ..Default::default()
770        },
771    );
772
773    // Alibaba DashScope (OpenAI-compatible Qwen host)
774    config.providers.insert(
775        "dashscope".to_string(),
776        ProviderDef {
777            base_url: "https://dashscope-intl.aliyuncs.com/compatible-mode/v1".to_string(),
778            base_url_env: Some("DASHSCOPE_BASE_URL".to_string()),
779            auth_style: "bearer".to_string(),
780            auth_env: AuthEnv::Single("DASHSCOPE_API_KEY".to_string()),
781            chat_endpoint: "/chat/completions".to_string(),
782            completion_endpoint: Some("/completions".to_string()),
783            healthcheck: Some(HealthcheckDef {
784                method: "GET".to_string(),
785                path: Some("/models".to_string()),
786                url: None,
787                body: None,
788            }),
789            cost_per_1k_in: Some(0.0003),
790            cost_per_1k_out: Some(0.0012),
791            latency_p50_ms: Some(1600),
792            ..Default::default()
793        },
794    );
795
796    // Local OpenAI-compatible server
797    config.providers.insert(
798        "local".to_string(),
799        ProviderDef {
800            base_url: "http://localhost:8000".to_string(),
801            base_url_env: Some("LOCAL_LLM_BASE_URL".to_string()),
802            auth_style: "none".to_string(),
803            chat_endpoint: "/v1/chat/completions".to_string(),
804            completion_endpoint: Some("/v1/completions".to_string()),
805            healthcheck: Some(HealthcheckDef {
806                method: "GET".to_string(),
807                path: Some("/v1/models".to_string()),
808                url: None,
809                body: None,
810            }),
811            cost_per_1k_in: Some(0.0),
812            cost_per_1k_out: Some(0.0),
813            latency_p50_ms: Some(900),
814            ..Default::default()
815        },
816    );
817
818    // vLLM OpenAI-compatible server.
819    config.providers.insert(
820        "vllm".to_string(),
821        ProviderDef {
822            base_url: "http://localhost:8000".to_string(),
823            base_url_env: Some("VLLM_BASE_URL".to_string()),
824            auth_style: "none".to_string(),
825            chat_endpoint: "/v1/chat/completions".to_string(),
826            completion_endpoint: Some("/v1/completions".to_string()),
827            healthcheck: Some(HealthcheckDef {
828                method: "GET".to_string(),
829                path: Some("/v1/models".to_string()),
830                url: None,
831                body: None,
832            }),
833            cost_per_1k_in: Some(0.0),
834            cost_per_1k_out: Some(0.0),
835            latency_p50_ms: Some(800),
836            ..Default::default()
837        },
838    );
839
840    // HuggingFace Text Generation Inference OpenAI-compatible endpoint.
841    config.providers.insert(
842        "tgi".to_string(),
843        ProviderDef {
844            base_url: "http://localhost:8080".to_string(),
845            base_url_env: Some("TGI_BASE_URL".to_string()),
846            auth_style: "none".to_string(),
847            chat_endpoint: "/v1/chat/completions".to_string(),
848            completion_endpoint: Some("/v1/completions".to_string()),
849            healthcheck: Some(HealthcheckDef {
850                method: "GET".to_string(),
851                path: Some("/health".to_string()),
852                url: None,
853                body: None,
854            }),
855            cost_per_1k_in: Some(0.0),
856            cost_per_1k_out: Some(0.0),
857            latency_p50_ms: Some(950),
858            ..Default::default()
859        },
860    );
861
862    // Default inference rules
863    config.inference_rules = vec![
864        InferenceRule {
865            pattern: Some("claude-*".to_string()),
866            contains: None,
867            exact: None,
868            provider: "anthropic".to_string(),
869        },
870        InferenceRule {
871            pattern: Some("gpt-*".to_string()),
872            contains: None,
873            exact: None,
874            provider: "openai".to_string(),
875        },
876        InferenceRule {
877            pattern: Some("o1*".to_string()),
878            contains: None,
879            exact: None,
880            provider: "openai".to_string(),
881        },
882        InferenceRule {
883            pattern: Some("o3*".to_string()),
884            contains: None,
885            exact: None,
886            provider: "openai".to_string(),
887        },
888        InferenceRule {
889            pattern: Some("local:*".to_string()),
890            contains: None,
891            exact: None,
892            provider: "local".to_string(),
893        },
894        InferenceRule {
895            pattern: None,
896            contains: Some("/".to_string()),
897            exact: None,
898            provider: "openrouter".to_string(),
899        },
900        InferenceRule {
901            pattern: None,
902            contains: Some(":".to_string()),
903            exact: None,
904            provider: "ollama".to_string(),
905        },
906    ];
907
908    // Default tier rules
909    config.tier_rules = vec![
910        TierRule {
911            contains: Some("9b".to_string()),
912            pattern: None,
913            exact: None,
914            tier: "small".to_string(),
915        },
916        TierRule {
917            contains: Some("a3b".to_string()),
918            pattern: None,
919            exact: None,
920            tier: "small".to_string(),
921        },
922        TierRule {
923            contains: Some("gemma-4-e2b".to_string()),
924            pattern: None,
925            exact: None,
926            tier: "small".to_string(),
927        },
928        TierRule {
929            contains: Some("gemma-4-e4b".to_string()),
930            pattern: None,
931            exact: None,
932            tier: "small".to_string(),
933        },
934        TierRule {
935            contains: Some("gemma-4-26b".to_string()),
936            pattern: None,
937            exact: None,
938            tier: "mid".to_string(),
939        },
940        TierRule {
941            contains: Some("gemma-4-31b".to_string()),
942            pattern: None,
943            exact: None,
944            tier: "frontier".to_string(),
945        },
946        TierRule {
947            contains: Some("gemma4:26b".to_string()),
948            pattern: None,
949            exact: None,
950            tier: "mid".to_string(),
951        },
952        TierRule {
953            contains: Some("gemma4:31b".to_string()),
954            pattern: None,
955            exact: None,
956            tier: "frontier".to_string(),
957        },
958        TierRule {
959            pattern: Some("claude-*".to_string()),
960            contains: None,
961            exact: None,
962            tier: "frontier".to_string(),
963        },
964        TierRule {
965            exact: Some("gpt-4o".to_string()),
966            contains: None,
967            pattern: None,
968            tier: "frontier".to_string(),
969        },
970    ];
971
972    config.tier_defaults = TierDefaults {
973        default: "mid".to_string(),
974    };
975
976    config.aliases.insert(
977        "frontier".to_string(),
978        AliasDef {
979            id: "claude-sonnet-4-20250514".to_string(),
980            provider: "anthropic".to_string(),
981            tool_format: None,
982        },
983    );
984    config.aliases.insert(
985        "tier/frontier".to_string(),
986        AliasDef {
987            id: "claude-sonnet-4-20250514".to_string(),
988            provider: "anthropic".to_string(),
989            tool_format: None,
990        },
991    );
992    config.aliases.insert(
993        "mid".to_string(),
994        AliasDef {
995            id: "gpt-4o-mini".to_string(),
996            provider: "openai".to_string(),
997            tool_format: None,
998        },
999    );
1000    config.aliases.insert(
1001        "tier/mid".to_string(),
1002        AliasDef {
1003            id: "gpt-4o-mini".to_string(),
1004            provider: "openai".to_string(),
1005            tool_format: None,
1006        },
1007    );
1008    config.aliases.insert(
1009        "small".to_string(),
1010        AliasDef {
1011            id: "Qwen/Qwen3.5-9B".to_string(),
1012            provider: "openrouter".to_string(),
1013            tool_format: None,
1014        },
1015    );
1016    config.aliases.insert(
1017        "tier/small".to_string(),
1018        AliasDef {
1019            id: "Qwen/Qwen3.5-9B".to_string(),
1020            provider: "openrouter".to_string(),
1021            tool_format: None,
1022        },
1023    );
1024    config.aliases.insert(
1025        "local-gemma4".to_string(),
1026        AliasDef {
1027            id: "gemma-4-26b-a4b-it".to_string(),
1028            provider: "local".to_string(),
1029            tool_format: None,
1030        },
1031    );
1032    config.aliases.insert(
1033        "local-gemma4-26b".to_string(),
1034        AliasDef {
1035            id: "gemma-4-26b-a4b-it".to_string(),
1036            provider: "local".to_string(),
1037            tool_format: None,
1038        },
1039    );
1040    config.aliases.insert(
1041        "local-gemma4-31b".to_string(),
1042        AliasDef {
1043            id: "gemma-4-31b-it".to_string(),
1044            provider: "local".to_string(),
1045            tool_format: None,
1046        },
1047    );
1048    config.aliases.insert(
1049        "local-gemma4-e4b".to_string(),
1050        AliasDef {
1051            id: "gemma-4-e4b-it".to_string(),
1052            provider: "local".to_string(),
1053            tool_format: None,
1054        },
1055    );
1056    config.aliases.insert(
1057        "local-gemma4-e2b".to_string(),
1058        AliasDef {
1059            id: "gemma-4-e2b-it".to_string(),
1060            provider: "local".to_string(),
1061            tool_format: None,
1062        },
1063    );
1064
1065    config
1066}
1067
#[cfg(test)]
mod tests {
    //! Unit tests for glob matching, provider/tier inference against the
    //! built-in defaults, and thread-local user overrides.

    use super::*;

    /// Clears any thread-local user overrides installed on the current thread.
    fn reset_overrides() {
        clear_user_overrides();
    }

    /// Scope guard that installs a user overlay for the current thread and
    /// removes it again when dropped.
    ///
    /// Cleanup lives in `Drop` rather than in a trailing `reset_overrides()`
    /// call so that a failing assertion (which unwinds through the test body)
    /// cannot leave the overlay installed for another test that happens to
    /// run on the same thread.
    struct OverrideGuard;

    impl OverrideGuard {
        /// Replaces whatever overlay is currently installed with `overlay`.
        fn install(overlay: ProvidersConfig) -> Self {
            reset_overrides();
            set_user_overrides(Some(overlay));
            OverrideGuard
        }
    }

    impl Drop for OverrideGuard {
        fn drop(&mut self) {
            reset_overrides();
        }
    }

    #[test]
    fn test_glob_match_prefix() {
        assert!(glob_match("claude-*", "claude-sonnet-4-20250514"));
        assert!(glob_match("gpt-*", "gpt-4o"));
        assert!(!glob_match("claude-*", "gpt-4o"));
    }

    #[test]
    fn test_glob_match_suffix() {
        assert!(glob_match("*-latest", "llama3.2-latest"));
        assert!(!glob_match("*-latest", "llama3.2"));
    }

    #[test]
    fn test_glob_match_middle() {
        assert!(glob_match("claude-*-latest", "claude-sonnet-latest"));
        assert!(!glob_match("claude-*-latest", "claude-sonnet-beta"));
    }

    #[test]
    fn test_glob_match_exact() {
        // A pattern without `*` must match the whole string, not a prefix.
        assert!(glob_match("gpt-4o", "gpt-4o"));
        assert!(!glob_match("gpt-4o", "gpt-4o-mini"));
    }

    #[test]
    fn test_infer_provider_from_defaults() {
        assert_eq!(infer_provider("claude-sonnet-4-20250514"), "anthropic");
        assert_eq!(infer_provider("gpt-4o"), "openai");
        assert_eq!(infer_provider("o1-preview"), "openai");
        assert_eq!(infer_provider("o3-mini"), "openai");
        assert_eq!(infer_provider("qwen/qwen3-coder"), "openrouter");
        assert_eq!(infer_provider("llama3.2:latest"), "ollama");
        // No rule matches: the fallback provider is returned.
        assert_eq!(infer_provider("unknown-model"), "anthropic");
    }

    #[test]
    fn test_infer_provider_local_prefix() {
        // `local:` must route to the local OpenAI-compatible provider, not
        // ollama (which would otherwise swallow everything containing `:`).
        assert_eq!(infer_provider("local:gemma-4-e4b-it"), "local");
        assert_eq!(infer_provider("local:qwen2.5"), "local");
        // Even when the id also contains `/`, the `local:` prefix wins.
        assert_eq!(infer_provider("local:owner/model"), "local");
    }

    #[test]
    fn test_model_tier_from_defaults() {
        assert_eq!(model_tier("claude-sonnet-4-20250514"), "frontier");
        assert_eq!(model_tier("gpt-4o"), "frontier");
        assert_eq!(model_tier("Qwen3.5-9B"), "small");
        // No tier rule matches: falls back to the configured default tier.
        assert_eq!(model_tier("deepseek-v3"), "mid");
    }

    #[test]
    fn test_resolve_model_unknown_alias() {
        // A name that is not an alias is passed through without a provider.
        let (id, provider) = resolve_model("gpt-4o");
        assert_eq!(id, "gpt-4o");
        assert!(provider.is_none());
    }

    #[test]
    fn test_provider_names() {
        let names = provider_names();
        assert!(names.len() >= 7);
        for expected in ["anthropic", "together", "local", "openai", "ollama"] {
            assert!(
                names.contains(&expected.to_string()),
                "missing built-in provider `{expected}`"
            );
        }
    }

    #[test]
    fn test_resolve_tier_model_default_aliases() {
        let (model, provider) = resolve_tier_model("frontier", None).unwrap();
        assert_eq!(model, "claude-sonnet-4-20250514");
        assert_eq!(provider, "anthropic");

        let (model, provider) = resolve_tier_model("small", None).unwrap();
        assert_eq!(model, "Qwen/Qwen3.5-9B");
        assert_eq!(provider, "openrouter");
    }

    #[test]
    fn test_resolve_tier_model_prefers_provider_scoped_aliases() {
        let (model, provider) = resolve_tier_model("mid", Some("openai")).unwrap();
        assert_eq!(model, "gpt-4o-mini");
        assert_eq!(provider, "openai");
    }

    #[test]
    fn test_provider_config_anthropic() {
        let pdef = provider_config("anthropic").unwrap();
        assert_eq!(pdef.auth_style, "header");
        assert_eq!(pdef.auth_header.as_deref(), Some("x-api-key"));
    }

    #[test]
    fn test_resolve_base_url_no_env() {
        // With no `base_url_env` configured, the literal `base_url` is used.
        let pdef = ProviderDef {
            base_url: "https://example.com".to_string(),
            ..Default::default()
        };
        assert_eq!(resolve_base_url(&pdef), "https://example.com");
    }

    #[test]
    fn test_default_config_roundtrip() {
        let config = default_config();
        assert!(!config.providers.is_empty());
        assert!(!config.inference_rules.is_empty());
        assert!(!config.tier_rules.is_empty());
        assert_eq!(config.tier_defaults.default, "mid");
    }

    #[test]
    fn test_model_params_empty() {
        let params = model_params("claude-sonnet-4-20250514");
        assert!(params.is_empty());
    }

    #[test]
    fn test_user_overrides_add_provider_and_alias() {
        let mut overlay = ProvidersConfig::default();
        overlay.providers.insert(
            "acme".to_string(),
            ProviderDef {
                base_url: "https://llm.acme.test/v1".to_string(),
                chat_endpoint: "/chat/completions".to_string(),
                ..Default::default()
            },
        );
        overlay.aliases.insert(
            "acme-fast".to_string(),
            AliasDef {
                id: "acme/model-fast".to_string(),
                provider: "acme".to_string(),
                tool_format: Some("native".to_string()),
            },
        );
        // Bind to a named variable: `let _ = ...` would drop the guard (and
        // clear the overlay) immediately.
        let _guard = OverrideGuard::install(overlay);

        let (model, provider) = resolve_model("acme-fast");
        assert_eq!(model, "acme/model-fast");
        assert_eq!(provider.as_deref(), Some("acme"));
        assert!(provider_names().contains(&"acme".to_string()));
        assert_eq!(
            provider_config("acme").map(|provider| provider.base_url),
            Some("https://llm.acme.test/v1".to_string())
        );
    }

    #[test]
    fn test_user_overrides_prepend_inference_rules() {
        let mut overlay = ProvidersConfig::default();
        overlay.inference_rules.push(InferenceRule {
            pattern: Some("internal-*".to_string()),
            contains: None,
            exact: None,
            provider: "openai".to_string(),
        });
        // This rule overlaps the built-in `claude-*` -> anthropic rule, so it
        // can only take effect if overlay rules are consulted before the
        // defaults. Without it the test would pass for append ordering too.
        overlay.inference_rules.push(InferenceRule {
            pattern: Some("claude-internal-*".to_string()),
            contains: None,
            exact: None,
            provider: "openai".to_string(),
        });
        // Bind to a named variable: `let _ = ...` would drop the guard (and
        // clear the overlay) immediately.
        let _guard = OverrideGuard::install(overlay);

        assert_eq!(infer_provider("internal-foo"), "openai");
        assert_eq!(infer_provider("claude-internal-foo"), "openai");
        // Built-in rules that the overlay does not shadow still apply.
        assert_eq!(infer_provider("claude-sonnet-4-20250514"), "anthropic");
    }
}