Skip to main content

harn_vm/
llm_config.rs

1use serde::Deserialize;
2use std::collections::BTreeMap;
3use std::sync::OnceLock;
4
/// Process-wide cache for the parsed providers config. It is populated by the
/// first call to `load_config` (via `get_or_init`) and read on every later call.
static CONFIG: OnceLock<ProvidersConfig> = OnceLock::new();
6
7// =============================================================================
8// Config structs
9// =============================================================================
10
11#[derive(Debug, Clone, Deserialize, Default)]
12pub struct ProvidersConfig {
13    #[serde(default)]
14    pub providers: BTreeMap<String, ProviderDef>,
15    #[serde(default)]
16    pub aliases: BTreeMap<String, AliasDef>,
17    #[serde(default)]
18    pub inference_rules: Vec<InferenceRule>,
19    #[serde(default)]
20    pub tier_rules: Vec<TierRule>,
21    #[serde(default)]
22    pub tier_defaults: TierDefaults,
23    #[serde(default)]
24    pub model_defaults: BTreeMap<String, BTreeMap<String, toml::Value>>,
25}
26
27#[derive(Debug, Clone, Deserialize)]
28pub struct ProviderDef {
29    pub base_url: String,
30    #[serde(default)]
31    pub base_url_env: Option<String>,
32    #[serde(default = "default_bearer")]
33    pub auth_style: String,
34    #[serde(default)]
35    pub auth_header: Option<String>,
36    #[serde(default)]
37    pub auth_env: AuthEnv,
38    #[serde(default)]
39    pub extra_headers: BTreeMap<String, String>,
40    #[serde(default)]
41    pub chat_endpoint: String,
42    #[serde(default)]
43    pub healthcheck: Option<HealthcheckDef>,
44    #[serde(default)]
45    pub features: Vec<String>,
46    /// Fallback provider name to try if this provider fails.
47    #[serde(default)]
48    pub fallback: Option<String>,
49    /// Number of retries before falling back (default 0).
50    #[serde(default)]
51    pub retry_count: Option<u32>,
52    /// Delay between retries in milliseconds (default 1000).
53    #[serde(default)]
54    pub retry_delay_ms: Option<u64>,
55}
56
57impl Default for ProviderDef {
58    fn default() -> Self {
59        Self {
60            base_url: String::new(),
61            base_url_env: None,
62            auth_style: default_bearer(),
63            auth_header: None,
64            auth_env: AuthEnv::None,
65            extra_headers: BTreeMap::new(),
66            chat_endpoint: String::new(),
67            healthcheck: None,
68            features: Vec::new(),
69            fallback: None,
70            retry_count: None,
71            retry_delay_ms: None,
72        }
73    }
74}
75
/// Default value for `ProviderDef::auth_style`: the string `"bearer"`.
fn default_bearer() -> String {
    String::from("bearer")
}
79
80/// Auth env var name(s) for the provider. Can be a single string or an array
81/// (tried in order until one is set).
82#[derive(Debug, Clone, Deserialize, Default)]
83#[serde(untagged)]
84pub enum AuthEnv {
85    #[default]
86    None,
87    Single(String),
88    Multiple(Vec<String>),
89}
90
91#[derive(Debug, Clone, Deserialize)]
92pub struct HealthcheckDef {
93    pub method: String,
94    #[serde(default)]
95    pub path: Option<String>,
96    #[serde(default)]
97    pub url: Option<String>,
98    #[serde(default)]
99    pub body: Option<String>,
100}
101
102#[derive(Debug, Clone, Deserialize)]
103pub struct AliasDef {
104    pub id: String,
105    pub provider: String,
106}
107
108#[derive(Debug, Clone, Deserialize)]
109pub struct InferenceRule {
110    #[serde(default)]
111    pub pattern: Option<String>,
112    #[serde(default)]
113    pub contains: Option<String>,
114    #[serde(default)]
115    pub exact: Option<String>,
116    pub provider: String,
117}
118
119#[derive(Debug, Clone, Deserialize)]
120pub struct TierRule {
121    #[serde(default)]
122    pub pattern: Option<String>,
123    #[serde(default)]
124    pub contains: Option<String>,
125    #[serde(default)]
126    pub exact: Option<String>,
127    pub tier: String,
128}
129
130#[derive(Debug, Clone, Deserialize)]
131pub struct TierDefaults {
132    #[serde(default = "default_mid")]
133    pub default: String,
134}
135
136impl Default for TierDefaults {
137    fn default() -> Self {
138        Self {
139            default: default_mid(),
140        }
141    }
142}
143
/// Default value for `TierDefaults::default`: the string `"mid"`.
fn default_mid() -> String {
    String::from("mid")
}
147
148// =============================================================================
149// Config loading
150// =============================================================================
151
152/// Load and cache the providers config. Called once at VM startup.
153pub fn load_config() -> &'static ProvidersConfig {
154    CONFIG.get_or_init(|| {
155        // Try explicit env var path first
156        if let Ok(path) = std::env::var("HARN_PROVIDERS_CONFIG") {
157            match std::fs::read_to_string(&path) {
158                Ok(content) => match toml::from_str::<ProvidersConfig>(&content) {
159                    Ok(config) => {
160                        eprintln!(
161                            "[llm_config] Loaded {} providers, {} aliases from {}",
162                            config.providers.len(),
163                            config.aliases.len(),
164                            path
165                        );
166                        return config;
167                    }
168                    Err(e) => eprintln!("[llm_config] TOML parse error in {}: {}", path, e),
169                },
170                Err(e) => eprintln!("[llm_config] Cannot read {}: {}", path, e),
171            }
172        }
173        // Try ~/.config/harn/providers.toml
174        if let Some(home) = dirs_or_home() {
175            let path = format!("{home}/.config/harn/providers.toml");
176            if let Ok(content) = std::fs::read_to_string(&path) {
177                if let Ok(config) = toml::from_str::<ProvidersConfig>(&content) {
178                    return config;
179                }
180            }
181        }
182        // Fallback: built-in defaults
183        default_config()
184    })
185}
186
187/// Resolve a model alias to (model_id, provider_name).
188pub fn resolve_model(alias: &str) -> (String, Option<String>) {
189    let config = load_config();
190    if let Some(a) = config.aliases.get(alias) {
191        return (a.id.clone(), Some(a.provider.clone()));
192    }
193    (alias.to_string(), None)
194}
195
196/// Infer provider from a model ID using inference rules.
197pub fn infer_provider(model_id: &str) -> String {
198    let config = load_config();
199    for rule in &config.inference_rules {
200        if let Some(exact) = &rule.exact {
201            if model_id == exact {
202                return rule.provider.clone();
203            }
204        }
205        if let Some(pattern) = &rule.pattern {
206            if glob_match(pattern, model_id) {
207                return rule.provider.clone();
208            }
209        }
210        if let Some(substr) = &rule.contains {
211            if model_id.contains(substr.as_str()) {
212                return rule.provider.clone();
213            }
214        }
215    }
216    // Fallback to hardcoded inference
217    if model_id.starts_with("claude-") {
218        return "anthropic".to_string();
219    }
220    if model_id.starts_with("gpt-") || model_id.starts_with("o1") || model_id.starts_with("o3") {
221        return "openai".to_string();
222    }
223    if model_id.contains('/') {
224        return "openrouter".to_string();
225    }
226    if model_id.contains(':') {
227        return "ollama".to_string();
228    }
229    "anthropic".to_string()
230}
231
232/// Get model tier ("small", "mid", "frontier").
233pub fn model_tier(model_id: &str) -> String {
234    let config = load_config();
235    for rule in &config.tier_rules {
236        if let Some(exact) = &rule.exact {
237            if model_id == exact {
238                return rule.tier.clone();
239            }
240        }
241        if let Some(pattern) = &rule.pattern {
242            if glob_match(pattern, model_id) {
243                return rule.tier.clone();
244            }
245        }
246        if let Some(substr) = &rule.contains {
247            if model_id.contains(substr.as_str()) {
248                return rule.tier.clone();
249            }
250        }
251    }
252    // Fallback
253    let lower = model_id.to_lowercase();
254    if lower.contains("9b") || lower.contains("a3b") {
255        return "small".to_string();
256    }
257    if lower.starts_with("claude-") || lower == "gpt-4o" {
258        return "frontier".to_string();
259    }
260    config.tier_defaults.default.clone()
261}
262
263/// Get provider config for resolving base_url, auth, etc.
264pub fn provider_config(name: &str) -> Option<&'static ProviderDef> {
265    load_config().providers.get(name)
266}
267
268/// Get model-specific default parameters (temperature, etc.).
269/// Matches glob patterns in model_defaults keys.
270pub fn model_params(model_id: &str) -> BTreeMap<String, toml::Value> {
271    let config = load_config();
272    let mut params = BTreeMap::new();
273    for (pattern, defaults) in &config.model_defaults {
274        if glob_match(pattern, model_id) {
275            for (k, v) in defaults {
276                params.insert(k.clone(), v.clone());
277            }
278        }
279    }
280    params
281}
282
283/// Get list of configured provider names.
284pub fn provider_names() -> Vec<String> {
285    load_config().providers.keys().cloned().collect()
286}
287
288// =============================================================================
289// Helpers
290// =============================================================================
291
/// Simple glob matching for patterns like "claude-*", "qwen/*", "ollama:*".
///
/// `*` is the only special character and matches any run of characters,
/// including the empty one. A pattern may contain any number of wildcards
/// (`*sonnet*`, `a*b*c`); a pattern without a wildcard must equal `input`
/// exactly. Matching is case-sensitive.
///
/// The literal segments between wildcards must appear in `input` in order and
/// without overlapping, so `claude-*-latest` does not match `claude-latest`.
fn glob_match(pattern: &str, input: &str) -> bool {
    let mut segments = pattern.split('*');

    // The text before the first `*` is anchored at the start of the input.
    // (`split` always yields at least one item, possibly empty.)
    let head = segments.next().unwrap_or("");
    let Some(mut rest) = input.strip_prefix(head) else {
        return false;
    };

    // No further segment means the pattern had no wildcard: exact match only.
    let Some(tail) = segments.next_back() else {
        return rest.is_empty();
    };

    // Middle segments are consumed left to right at their earliest occurrence,
    // which leaves as much input as possible for the segments that follow.
    for segment in segments {
        match rest.find(segment) {
            Some(at) => rest = &rest[at + segment.len()..],
            None => return false,
        }
    }

    // The text after the last `*` is anchored at the end of what remains.
    rest.ends_with(tail)
}
309
/// Home directory taken from the `HOME` environment variable.
///
/// Returns `None` when `HOME` is unset or its value is not valid Unicode.
fn dirs_or_home() -> Option<String> {
    std::env::var_os("HOME")?.into_string().ok()
}
313
314/// Resolve the effective base URL for a provider, checking the `base_url_env`
315/// override first, then falling back to the configured `base_url`.
316pub fn resolve_base_url(pdef: &ProviderDef) -> String {
317    if let Some(env_name) = &pdef.base_url_env {
318        if let Ok(val) = std::env::var(env_name) {
319            if !val.is_empty() {
320                return val;
321            }
322        }
323    }
324    pdef.base_url.clone()
325}
326
327// =============================================================================
328// Built-in default config (matches current hardcoded behavior)
329// =============================================================================
330
331fn default_config() -> ProvidersConfig {
332    let mut config = ProvidersConfig::default();
333
334    // Anthropic
335    config.providers.insert(
336        "anthropic".to_string(),
337        ProviderDef {
338            base_url: "https://api.anthropic.com/v1".to_string(),
339            auth_style: "header".to_string(),
340            auth_header: Some("x-api-key".to_string()),
341            auth_env: AuthEnv::Single("ANTHROPIC_API_KEY".to_string()),
342            extra_headers: BTreeMap::from([(
343                "anthropic-version".to_string(),
344                "2023-06-01".to_string(),
345            )]),
346            chat_endpoint: "/messages".to_string(),
347            healthcheck: Some(HealthcheckDef {
348                method: "POST".to_string(),
349                path: Some("/messages/count_tokens".to_string()),
350                url: None,
351                body: Some(
352                    r#"{"model":"claude-sonnet-4-20250514","messages":[{"role":"user","content":"x"}]}"#
353                        .to_string(),
354                ),
355            }),
356            features: vec!["prompt_caching".to_string(), "thinking".to_string()],
357            ..Default::default()
358        },
359    );
360
361    // OpenAI
362    config.providers.insert(
363        "openai".to_string(),
364        ProviderDef {
365            base_url: "https://api.openai.com/v1".to_string(),
366            auth_style: "bearer".to_string(),
367            auth_env: AuthEnv::Single("OPENAI_API_KEY".to_string()),
368            chat_endpoint: "/chat/completions".to_string(),
369            healthcheck: Some(HealthcheckDef {
370                method: "GET".to_string(),
371                path: Some("/models".to_string()),
372                url: None,
373                body: None,
374            }),
375            ..Default::default()
376        },
377    );
378
379    // OpenRouter
380    config.providers.insert(
381        "openrouter".to_string(),
382        ProviderDef {
383            base_url: "https://openrouter.ai/api/v1".to_string(),
384            auth_style: "bearer".to_string(),
385            auth_env: AuthEnv::Single("OPENROUTER_API_KEY".to_string()),
386            chat_endpoint: "/chat/completions".to_string(),
387            healthcheck: Some(HealthcheckDef {
388                method: "GET".to_string(),
389                path: Some("/auth/key".to_string()),
390                url: None,
391                body: None,
392            }),
393            ..Default::default()
394        },
395    );
396
397    // HuggingFace
398    config.providers.insert(
399        "huggingface".to_string(),
400        ProviderDef {
401            base_url: "https://router.huggingface.co/v1".to_string(),
402            auth_style: "bearer".to_string(),
403            auth_env: AuthEnv::Multiple(vec![
404                "HF_TOKEN".to_string(),
405                "HUGGINGFACE_API_KEY".to_string(),
406            ]),
407            chat_endpoint: "/chat/completions".to_string(),
408            healthcheck: Some(HealthcheckDef {
409                method: "GET".to_string(),
410                url: Some("https://huggingface.co/api/whoami-v2".to_string()),
411                path: None,
412                body: None,
413            }),
414            ..Default::default()
415        },
416    );
417
418    // Ollama
419    config.providers.insert(
420        "ollama".to_string(),
421        ProviderDef {
422            base_url: "http://localhost:11434".to_string(),
423            base_url_env: Some("OLLAMA_HOST".to_string()),
424            auth_style: "none".to_string(),
425            chat_endpoint: "/api/chat".to_string(),
426            healthcheck: Some(HealthcheckDef {
427                method: "GET".to_string(),
428                path: Some("/api/tags".to_string()),
429                url: None,
430                body: None,
431            }),
432            ..Default::default()
433        },
434    );
435
436    // Default inference rules
437    config.inference_rules = vec![
438        InferenceRule {
439            pattern: Some("claude-*".to_string()),
440            contains: None,
441            exact: None,
442            provider: "anthropic".to_string(),
443        },
444        InferenceRule {
445            pattern: Some("gpt-*".to_string()),
446            contains: None,
447            exact: None,
448            provider: "openai".to_string(),
449        },
450        InferenceRule {
451            pattern: Some("o1*".to_string()),
452            contains: None,
453            exact: None,
454            provider: "openai".to_string(),
455        },
456        InferenceRule {
457            pattern: Some("o3*".to_string()),
458            contains: None,
459            exact: None,
460            provider: "openai".to_string(),
461        },
462        InferenceRule {
463            pattern: None,
464            contains: Some("/".to_string()),
465            exact: None,
466            provider: "openrouter".to_string(),
467        },
468        InferenceRule {
469            pattern: None,
470            contains: Some(":".to_string()),
471            exact: None,
472            provider: "ollama".to_string(),
473        },
474    ];
475
476    // Default tier rules
477    config.tier_rules = vec![
478        TierRule {
479            contains: Some("9b".to_string()),
480            pattern: None,
481            exact: None,
482            tier: "small".to_string(),
483        },
484        TierRule {
485            contains: Some("a3b".to_string()),
486            pattern: None,
487            exact: None,
488            tier: "small".to_string(),
489        },
490        TierRule {
491            pattern: Some("claude-*".to_string()),
492            contains: None,
493            exact: None,
494            tier: "frontier".to_string(),
495        },
496        TierRule {
497            exact: Some("gpt-4o".to_string()),
498            contains: None,
499            pattern: None,
500            tier: "frontier".to_string(),
501        },
502    ];
503
504    config.tier_defaults = TierDefaults {
505        default: "mid".to_string(),
506    };
507
508    config
509}
510
511// =============================================================================
512// Unit tests
513// =============================================================================
514
/// Unit tests for glob matching and the config accessors.
///
/// Tests that go through `load_config` assume the built-in default config is
/// what gets loaded, i.e. no `HARN_PROVIDERS_CONFIG` override and no user
/// config file in the home directory.
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert `glob_match(pattern, input) == expected` for every listed case.
    fn check_glob(cases: &[(&str, &str, bool)]) {
        for &(pattern, input, expected) in cases {
            assert_eq!(
                glob_match(pattern, input),
                expected,
                "pattern {pattern:?} against input {input:?}"
            );
        }
    }

    #[test]
    fn test_glob_match_prefix() {
        check_glob(&[
            ("claude-*", "claude-sonnet-4-20250514", true),
            ("gpt-*", "gpt-4o", true),
            ("claude-*", "gpt-4o", false),
        ]);
    }

    #[test]
    fn test_glob_match_suffix() {
        check_glob(&[
            ("*-latest", "llama3.2-latest", true),
            ("*-latest", "llama3.2", false),
        ]);
    }

    #[test]
    fn test_glob_match_middle() {
        check_glob(&[
            ("claude-*-latest", "claude-sonnet-latest", true),
            ("claude-*-latest", "claude-sonnet-beta", false),
        ]);
    }

    #[test]
    fn test_glob_match_exact() {
        check_glob(&[("gpt-4o", "gpt-4o", true), ("gpt-4o", "gpt-4o-mini", false)]);
    }

    #[test]
    fn test_infer_provider_from_defaults() {
        // With the built-in config, most of these are answered by the default
        // inference rules; only "unknown-model" falls through to the
        // hard-coded fallback at the end of `infer_provider`.
        let cases = [
            ("claude-sonnet-4-20250514", "anthropic"),
            ("gpt-4o", "openai"),
            ("o1-preview", "openai"),
            ("o3-mini", "openai"),
            ("qwen/qwen3-coder", "openrouter"),
            ("llama3.2:latest", "ollama"),
            ("unknown-model", "anthropic"),
        ];
        for (model, provider) in cases {
            assert_eq!(infer_provider(model), provider, "model {model:?}");
        }
    }

    #[test]
    fn test_model_tier_from_defaults() {
        let cases = [
            ("claude-sonnet-4-20250514", "frontier"),
            ("gpt-4o", "frontier"),
            ("Qwen3.5-9B", "small"),
            ("deepseek-v3", "mid"),
        ];
        for (model, tier) in cases {
            assert_eq!(model_tier(model), tier, "model {model:?}");
        }
    }

    #[test]
    fn test_resolve_model_unknown_alias() {
        // An unknown alias is passed through as the model ID with no provider.
        let resolved = resolve_model("gpt-4o");
        assert_eq!(resolved.0, "gpt-4o");
        assert!(resolved.1.is_none());
    }

    #[test]
    fn test_provider_names() {
        let names = provider_names();
        assert!(names.len() >= 5);
        for expected in ["anthropic", "openai", "ollama"] {
            assert!(names.contains(&expected.to_string()));
        }
    }

    #[test]
    fn test_provider_config_anthropic() {
        let anthropic = provider_config("anthropic").unwrap();
        assert_eq!(anthropic.auth_style, "header");
        assert_eq!(anthropic.auth_header.as_deref(), Some("x-api-key"));
    }

    #[test]
    fn test_resolve_base_url_no_env() {
        // No `base_url_env` is set, so the configured URL comes back untouched.
        let mut pdef = ProviderDef::default();
        pdef.base_url = "https://example.com".to_string();
        assert_eq!(resolve_base_url(&pdef), "https://example.com");
    }

    #[test]
    fn test_default_config_roundtrip() {
        let built_in = default_config();
        assert!(!built_in.providers.is_empty());
        assert!(!built_in.inference_rules.is_empty());
        assert!(!built_in.tier_rules.is_empty());
        assert_eq!(built_in.tier_defaults.default, "mid");
    }

    #[test]
    fn test_model_params_empty() {
        // Default config has no model_defaults, so nothing can match.
        assert!(model_params("claude-sonnet-4-20250514").is_empty());
    }
}