Skip to main content

harn_vm/
llm_config.rs

1use serde::Deserialize;
2use std::collections::BTreeMap;
3use std::sync::OnceLock;
4
/// Process-wide providers configuration. Populated by the first call to
/// `load_config` (via `OnceLock::get_or_init`) and never reloaded afterwards.
static CONFIG: OnceLock<ProvidersConfig> = OnceLock::new();
6
7// =============================================================================
8// Config structs
9// =============================================================================
10
11#[derive(Debug, Clone, Deserialize, Default)]
12pub struct ProvidersConfig {
13    #[serde(default)]
14    pub providers: BTreeMap<String, ProviderDef>,
15    #[serde(default)]
16    pub aliases: BTreeMap<String, AliasDef>,
17    #[serde(default)]
18    pub inference_rules: Vec<InferenceRule>,
19    #[serde(default)]
20    pub tier_rules: Vec<TierRule>,
21    #[serde(default)]
22    pub tier_defaults: TierDefaults,
23    #[serde(default)]
24    pub model_defaults: BTreeMap<String, BTreeMap<String, toml::Value>>,
25}
26
/// Connection and authentication settings for a single LLM provider.
#[derive(Debug, Clone, Deserialize)]
pub struct ProviderDef {
    /// Base URL for API requests. This is the only required key in TOML
    /// (no `serde(default)`). `resolve_base_url` lets `base_url_env`
    /// override it at runtime.
    pub base_url: String,
    /// Optional environment variable whose non-empty value replaces
    /// `base_url` (see `resolve_base_url`).
    #[serde(default)]
    pub base_url_env: Option<String>,
    /// Authentication style. The built-in defaults use "bearer", "header"
    /// and "none"; the value defaults to "bearer" when omitted.
    /// NOTE(review): how each style is applied is decided outside this module.
    #[serde(default = "default_bearer")]
    pub auth_style: String,
    /// Header name carrying the key, presumably for `auth_style = "header"`
    /// (the built-in anthropic entry uses "x-api-key").
    #[serde(default)]
    pub auth_header: Option<String>,
    /// Environment variable name(s) holding the credential; `AuthEnv::None`
    /// when omitted.
    #[serde(default)]
    pub auth_env: AuthEnv,
    /// Additional header name/value pairs (e.g. "anthropic-version").
    /// NOTE(review): presumably attached to every request; confirm in the HTTP layer.
    #[serde(default)]
    pub extra_headers: BTreeMap<String, String>,
    /// Chat endpoint path such as "/chat/completions" or "/messages";
    /// presumably appended to the resolved base URL. Empty when omitted.
    #[serde(default)]
    pub chat_endpoint: String,
    /// Optional completion endpoint path such as "/completions".
    #[serde(default)]
    pub completion_endpoint: Option<String>,
    /// Optional health-check request description (see `HealthcheckDef`).
    #[serde(default)]
    pub healthcheck: Option<HealthcheckDef>,
    /// Capability flags, e.g. "prompt_caching" and "thinking" for anthropic.
    #[serde(default)]
    pub features: Vec<String>,
    /// Fallback provider name to try if this provider fails.
    #[serde(default)]
    pub fallback: Option<String>,
    /// Number of retries before falling back (default 0).
    /// NOTE(review): this field is `None` when unset; the 0 default is
    /// applied by the caller, not here.
    #[serde(default)]
    pub retry_count: Option<u32>,
    /// Delay between retries in milliseconds (default 1000).
    /// NOTE(review): this field is `None` when unset; the 1000 ms default is
    /// applied by the caller, not here.
    #[serde(default)]
    pub retry_delay_ms: Option<u64>,
}
58
59impl Default for ProviderDef {
60    fn default() -> Self {
61        Self {
62            base_url: String::new(),
63            base_url_env: None,
64            auth_style: default_bearer(),
65            auth_header: None,
66            auth_env: AuthEnv::None,
67            extra_headers: BTreeMap::new(),
68            chat_endpoint: String::new(),
69            completion_endpoint: None,
70            healthcheck: None,
71            features: Vec::new(),
72            fallback: None,
73            retry_count: None,
74            retry_delay_ms: None,
75        }
76    }
77}
78
/// Serde default for `ProviderDef::auth_style`.
fn default_bearer() -> String {
    String::from("bearer")
}
82
/// Auth env var name(s) for the provider. Can be a single string or an array
/// (tried in order until one is set).
///
/// Deserialized `untagged`: a TOML string becomes `Single`, and an array of
/// strings becomes `Multiple`. When the key is absent, the `#[serde(default)]`
/// on `ProviderDef::auth_env` yields `None`.
#[derive(Debug, Clone, Deserialize, Default)]
#[serde(untagged)]
pub enum AuthEnv {
    /// No credential variable. The built-in ollama entry uses this.
    #[default]
    None,
    /// A single environment variable name.
    Single(String),
    /// Several variable names, tried in order until one is set.
    /// NOTE(review): the lookup itself happens outside this module.
    Multiple(Vec<String>),
}
93
/// Description of a request that can be used to health-check a provider.
/// NOTE(review): the code that issues this request lives outside this module.
#[derive(Debug, Clone, Deserialize)]
pub struct HealthcheckDef {
    /// HTTP method, e.g. "GET" or "POST".
    pub method: String,
    /// Path, presumably relative to the provider's base URL (e.g. "/models").
    #[serde(default)]
    pub path: Option<String>,
    /// Absolute URL. The built-in huggingface entry sets this instead of `path`.
    #[serde(default)]
    pub url: Option<String>,
    /// Optional request body (a JSON string in the built-in anthropic entry).
    #[serde(default)]
    pub body: Option<String>,
}
104
/// A named alias for a concrete model on a specific provider.
///
/// Looked up by alias name in `resolve_model` and `resolve_tier_model`.
#[derive(Debug, Clone, Deserialize)]
pub struct AliasDef {
    /// Concrete model id the alias resolves to (e.g. "gpt-4o-mini").
    pub id: String,
    /// Provider name for `id`; presumably a key of `ProvidersConfig::providers`.
    pub provider: String,
}
110
/// Maps model ids to a provider name. Rules are consulted in order by
/// `infer_provider`.
///
/// A rule applies when any of its set matchers succeed (all case-sensitive):
/// `exact` equality, glob `pattern`, or `contains` substring.
#[derive(Debug, Clone, Deserialize)]
pub struct InferenceRule {
    /// Glob pattern in which only `*` is special, e.g. "claude-*".
    #[serde(default)]
    pub pattern: Option<String>,
    /// Substring that must occur in the model id, e.g. "/".
    #[serde(default)]
    pub contains: Option<String>,
    /// Model id that must match exactly.
    #[serde(default)]
    pub exact: Option<String>,
    /// Provider name returned when the rule applies.
    pub provider: String,
}
121
/// Maps model ids to a tier name. Rules are consulted in order by
/// `model_tier`.
///
/// A rule applies when any of its set matchers succeed: `exact` equality,
/// glob `pattern`, or `contains` substring. Unlike `model_tier`'s built-in
/// fallback, these matches are case-sensitive (the model id is not lowercased).
#[derive(Debug, Clone, Deserialize)]
pub struct TierRule {
    /// Glob pattern in which only `*` is special, e.g. "claude-*".
    #[serde(default)]
    pub pattern: Option<String>,
    /// Substring that must occur in the model id, e.g. "9b".
    #[serde(default)]
    pub contains: Option<String>,
    /// Model id that must match exactly, e.g. "gpt-4o".
    #[serde(default)]
    pub exact: Option<String>,
    /// Tier returned when the rule applies ("small", "mid", "frontier" in
    /// the built-in defaults).
    pub tier: String,
}
132
133#[derive(Debug, Clone, Deserialize)]
134pub struct TierDefaults {
135    #[serde(default = "default_mid")]
136    pub default: String,
137}
138
139impl Default for TierDefaults {
140    fn default() -> Self {
141        Self {
142            default: default_mid(),
143        }
144    }
145}
146
/// Default tier name used by `TierDefaults`.
fn default_mid() -> String {
    String::from("mid")
}
150
151// =============================================================================
152// Config loading
153// =============================================================================
154
155/// Load and cache the providers config. Called once at VM startup.
156pub fn load_config() -> &'static ProvidersConfig {
157    CONFIG.get_or_init(|| {
158        // Try explicit env var path first
159        if let Ok(path) = std::env::var("HARN_PROVIDERS_CONFIG") {
160            match std::fs::read_to_string(&path) {
161                Ok(content) => match toml::from_str::<ProvidersConfig>(&content) {
162                    Ok(config) => {
163                        eprintln!(
164                            "[llm_config] Loaded {} providers, {} aliases from {}",
165                            config.providers.len(),
166                            config.aliases.len(),
167                            path
168                        );
169                        return config;
170                    }
171                    Err(e) => eprintln!("[llm_config] TOML parse error in {}: {}", path, e),
172                },
173                Err(e) => eprintln!("[llm_config] Cannot read {}: {}", path, e),
174            }
175        }
176        // Try ~/.config/harn/providers.toml
177        if let Some(home) = dirs_or_home() {
178            let path = format!("{home}/.config/harn/providers.toml");
179            if let Ok(content) = std::fs::read_to_string(&path) {
180                if let Ok(config) = toml::from_str::<ProvidersConfig>(&content) {
181                    return config;
182                }
183            }
184        }
185        // Fallback: built-in defaults
186        default_config()
187    })
188}
189
190/// Resolve a model alias to (model_id, provider_name).
191pub fn resolve_model(alias: &str) -> (String, Option<String>) {
192    let config = load_config();
193    if let Some(a) = config.aliases.get(alias) {
194        return (a.id.clone(), Some(a.provider.clone()));
195    }
196    (alias.to_string(), None)
197}
198
199/// Infer provider from a model ID using inference rules.
200pub fn infer_provider(model_id: &str) -> String {
201    let config = load_config();
202    for rule in &config.inference_rules {
203        if let Some(exact) = &rule.exact {
204            if model_id == exact {
205                return rule.provider.clone();
206            }
207        }
208        if let Some(pattern) = &rule.pattern {
209            if glob_match(pattern, model_id) {
210                return rule.provider.clone();
211            }
212        }
213        if let Some(substr) = &rule.contains {
214            if model_id.contains(substr.as_str()) {
215                return rule.provider.clone();
216            }
217        }
218    }
219    // Fallback to hardcoded inference
220    if model_id.starts_with("claude-") {
221        return "anthropic".to_string();
222    }
223    if model_id.starts_with("gpt-") || model_id.starts_with("o1") || model_id.starts_with("o3") {
224        return "openai".to_string();
225    }
226    if model_id.contains('/') {
227        return "openrouter".to_string();
228    }
229    if model_id.contains(':') {
230        return "ollama".to_string();
231    }
232    "anthropic".to_string()
233}
234
235/// Get model tier ("small", "mid", "frontier").
236pub fn model_tier(model_id: &str) -> String {
237    let config = load_config();
238    for rule in &config.tier_rules {
239        if let Some(exact) = &rule.exact {
240            if model_id == exact {
241                return rule.tier.clone();
242            }
243        }
244        if let Some(pattern) = &rule.pattern {
245            if glob_match(pattern, model_id) {
246                return rule.tier.clone();
247            }
248        }
249        if let Some(substr) = &rule.contains {
250            if model_id.contains(substr.as_str()) {
251                return rule.tier.clone();
252            }
253        }
254    }
255    // Fallback
256    let lower = model_id.to_lowercase();
257    if lower.contains("9b") || lower.contains("a3b") {
258        return "small".to_string();
259    }
260    if lower.starts_with("claude-") || lower == "gpt-4o" {
261        return "frontier".to_string();
262    }
263    config.tier_defaults.default.clone()
264}
265
266/// Get provider config for resolving base_url, auth, etc.
267pub fn provider_config(name: &str) -> Option<&'static ProviderDef> {
268    load_config().providers.get(name)
269}
270
271/// Get model-specific default parameters (temperature, etc.).
272/// Matches glob patterns in model_defaults keys.
273pub fn model_params(model_id: &str) -> BTreeMap<String, toml::Value> {
274    let config = load_config();
275    let mut params = BTreeMap::new();
276    for (pattern, defaults) in &config.model_defaults {
277        if glob_match(pattern, model_id) {
278            for (k, v) in defaults {
279                params.insert(k.clone(), v.clone());
280            }
281        }
282    }
283    params
284}
285
286/// Get list of configured provider names.
287pub fn provider_names() -> Vec<String> {
288    load_config().providers.keys().cloned().collect()
289}
290
291/// Resolve a tier or alias into a concrete model/provider pair.
292pub fn resolve_tier_model(
293    target: &str,
294    preferred_provider: Option<&str>,
295) -> Option<(String, String)> {
296    let config = load_config();
297
298    if let Some(alias) = config.aliases.get(target) {
299        return Some((alias.id.clone(), alias.provider.clone()));
300    }
301
302    let candidate_aliases = if let Some(provider) = preferred_provider {
303        vec![
304            format!("{provider}/{target}"),
305            format!("{provider}:{target}"),
306            format!("tier/{target}"),
307            target.to_string(),
308        ]
309    } else {
310        vec![format!("tier/{target}"), target.to_string()]
311    };
312
313    for alias_name in candidate_aliases {
314        if let Some(alias) = config.aliases.get(&alias_name) {
315            return Some((alias.id.clone(), alias.provider.clone()));
316        }
317    }
318
319    None
320}
321
322// =============================================================================
323// Helpers
324// =============================================================================
325
/// Simple glob matching for patterns like "claude-*", "qwen/*", "ollama:*".
///
/// `*` matches any (possibly empty) run of characters and may appear any
/// number of times, e.g. "*coder*" or "qwen*coder*32b". Every other
/// character matches itself literally and case-sensitively. A pattern
/// without `*` must equal the input exactly.
fn glob_match(pattern: &str, input: &str) -> bool {
    // No wildcard at all: plain equality.
    let Some((head, tail)) = pattern.split_once('*') else {
        return input == pattern;
    };
    // `head` is anchored at the start and `last` at the end. `middle` holds
    // the literal pieces between the first and last '*'.
    let (middle, last) = tail.rsplit_once('*').unwrap_or(("", tail));
    // The anchored pieces must fit without overlapping (e.g. "ab*ba" must not
    // match "aba").
    if input.len() < head.len() + last.len()
        || !input.starts_with(head)
        || !input.ends_with(last)
    {
        return false;
    }
    // Both bounds fall on char boundaries because `head` is a prefix and
    // `last` is a suffix of `input`.
    let mut rest = &input[head.len()..input.len() - last.len()];
    // Matching each middle piece at its leftmost occurrence leaves the most
    // room for the pieces that follow.
    for piece in middle.split('*').filter(|piece| !piece.is_empty()) {
        match rest.find(piece) {
            Some(at) => rest = &rest[at + piece.len()..],
            None => return false,
        }
    }
    true
}
343
/// Best-effort lookup of the user's home directory.
///
/// Checks `HOME` first (Unix and many Windows shells), then `USERPROFILE`
/// (native Windows). Variables that are unset, not valid Unicode, or empty
/// are skipped. This prevents an empty `HOME` from producing paths rooted
/// at `/`.
fn dirs_or_home() -> Option<String> {
    ["HOME", "USERPROFILE"]
        .iter()
        .filter_map(|name| std::env::var(*name).ok())
        .find(|value| !value.is_empty())
}
347
348/// Resolve the effective base URL for a provider, checking the `base_url_env`
349/// override first, then falling back to the configured `base_url`.
350pub fn resolve_base_url(pdef: &ProviderDef) -> String {
351    if let Some(env_name) = &pdef.base_url_env {
352        if let Ok(val) = std::env::var(env_name) {
353            if !val.is_empty() {
354                return val;
355            }
356        }
357    }
358    pdef.base_url.clone()
359}
360
361// =============================================================================
362// Built-in default config (matches current hardcoded behavior)
363// =============================================================================
364
365fn default_config() -> ProvidersConfig {
366    let mut config = ProvidersConfig::default();
367
368    // Anthropic
369    config.providers.insert(
370        "anthropic".to_string(),
371        ProviderDef {
372            base_url: "https://api.anthropic.com/v1".to_string(),
373            auth_style: "header".to_string(),
374            auth_header: Some("x-api-key".to_string()),
375            auth_env: AuthEnv::Single("ANTHROPIC_API_KEY".to_string()),
376            extra_headers: BTreeMap::from([(
377                "anthropic-version".to_string(),
378                "2023-06-01".to_string(),
379            )]),
380            chat_endpoint: "/messages".to_string(),
381            completion_endpoint: None,
382            healthcheck: Some(HealthcheckDef {
383                method: "POST".to_string(),
384                path: Some("/messages/count_tokens".to_string()),
385                url: None,
386                body: Some(
387                    r#"{"model":"claude-sonnet-4-20250514","messages":[{"role":"user","content":"x"}]}"#
388                        .to_string(),
389                ),
390            }),
391            features: vec!["prompt_caching".to_string(), "thinking".to_string()],
392            ..Default::default()
393        },
394    );
395
396    // OpenAI
397    config.providers.insert(
398        "openai".to_string(),
399        ProviderDef {
400            base_url: "https://api.openai.com/v1".to_string(),
401            auth_style: "bearer".to_string(),
402            auth_env: AuthEnv::Single("OPENAI_API_KEY".to_string()),
403            chat_endpoint: "/chat/completions".to_string(),
404            completion_endpoint: Some("/completions".to_string()),
405            healthcheck: Some(HealthcheckDef {
406                method: "GET".to_string(),
407                path: Some("/models".to_string()),
408                url: None,
409                body: None,
410            }),
411            ..Default::default()
412        },
413    );
414
415    // OpenRouter
416    config.providers.insert(
417        "openrouter".to_string(),
418        ProviderDef {
419            base_url: "https://openrouter.ai/api/v1".to_string(),
420            auth_style: "bearer".to_string(),
421            auth_env: AuthEnv::Single("OPENROUTER_API_KEY".to_string()),
422            chat_endpoint: "/chat/completions".to_string(),
423            completion_endpoint: Some("/completions".to_string()),
424            healthcheck: Some(HealthcheckDef {
425                method: "GET".to_string(),
426                path: Some("/auth/key".to_string()),
427                url: None,
428                body: None,
429            }),
430            ..Default::default()
431        },
432    );
433
434    // HuggingFace
435    config.providers.insert(
436        "huggingface".to_string(),
437        ProviderDef {
438            base_url: "https://router.huggingface.co/v1".to_string(),
439            auth_style: "bearer".to_string(),
440            auth_env: AuthEnv::Multiple(vec![
441                "HF_TOKEN".to_string(),
442                "HUGGINGFACE_API_KEY".to_string(),
443            ]),
444            chat_endpoint: "/chat/completions".to_string(),
445            completion_endpoint: Some("/completions".to_string()),
446            healthcheck: Some(HealthcheckDef {
447                method: "GET".to_string(),
448                url: Some("https://huggingface.co/api/whoami-v2".to_string()),
449                path: None,
450                body: None,
451            }),
452            ..Default::default()
453        },
454    );
455
456    // Ollama
457    config.providers.insert(
458        "ollama".to_string(),
459        ProviderDef {
460            base_url: "http://localhost:11434".to_string(),
461            base_url_env: Some("OLLAMA_HOST".to_string()),
462            auth_style: "none".to_string(),
463            chat_endpoint: "/api/chat".to_string(),
464            completion_endpoint: Some("/api/generate".to_string()),
465            healthcheck: Some(HealthcheckDef {
466                method: "GET".to_string(),
467                path: Some("/api/tags".to_string()),
468                url: None,
469                body: None,
470            }),
471            ..Default::default()
472        },
473    );
474
475    // Default inference rules
476    config.inference_rules = vec![
477        InferenceRule {
478            pattern: Some("claude-*".to_string()),
479            contains: None,
480            exact: None,
481            provider: "anthropic".to_string(),
482        },
483        InferenceRule {
484            pattern: Some("gpt-*".to_string()),
485            contains: None,
486            exact: None,
487            provider: "openai".to_string(),
488        },
489        InferenceRule {
490            pattern: Some("o1*".to_string()),
491            contains: None,
492            exact: None,
493            provider: "openai".to_string(),
494        },
495        InferenceRule {
496            pattern: Some("o3*".to_string()),
497            contains: None,
498            exact: None,
499            provider: "openai".to_string(),
500        },
501        InferenceRule {
502            pattern: None,
503            contains: Some("/".to_string()),
504            exact: None,
505            provider: "openrouter".to_string(),
506        },
507        InferenceRule {
508            pattern: None,
509            contains: Some(":".to_string()),
510            exact: None,
511            provider: "ollama".to_string(),
512        },
513    ];
514
515    // Default tier rules
516    config.tier_rules = vec![
517        TierRule {
518            contains: Some("9b".to_string()),
519            pattern: None,
520            exact: None,
521            tier: "small".to_string(),
522        },
523        TierRule {
524            contains: Some("a3b".to_string()),
525            pattern: None,
526            exact: None,
527            tier: "small".to_string(),
528        },
529        TierRule {
530            pattern: Some("claude-*".to_string()),
531            contains: None,
532            exact: None,
533            tier: "frontier".to_string(),
534        },
535        TierRule {
536            exact: Some("gpt-4o".to_string()),
537            contains: None,
538            pattern: None,
539            tier: "frontier".to_string(),
540        },
541    ];
542
543    config.tier_defaults = TierDefaults {
544        default: "mid".to_string(),
545    };
546
547    config.aliases.insert(
548        "frontier".to_string(),
549        AliasDef {
550            id: "claude-sonnet-4-20250514".to_string(),
551            provider: "anthropic".to_string(),
552        },
553    );
554    config.aliases.insert(
555        "tier/frontier".to_string(),
556        AliasDef {
557            id: "claude-sonnet-4-20250514".to_string(),
558            provider: "anthropic".to_string(),
559        },
560    );
561    config.aliases.insert(
562        "mid".to_string(),
563        AliasDef {
564            id: "gpt-4o-mini".to_string(),
565            provider: "openai".to_string(),
566        },
567    );
568    config.aliases.insert(
569        "tier/mid".to_string(),
570        AliasDef {
571            id: "gpt-4o-mini".to_string(),
572            provider: "openai".to_string(),
573        },
574    );
575    config.aliases.insert(
576        "small".to_string(),
577        AliasDef {
578            id: "Qwen/Qwen3.5-9B".to_string(),
579            provider: "openrouter".to_string(),
580        },
581    );
582    config.aliases.insert(
583        "tier/small".to_string(),
584        AliasDef {
585            id: "Qwen/Qwen3.5-9B".to_string(),
586            provider: "openrouter".to_string(),
587        },
588    );
589
590    config
591}
592
593// =============================================================================
594// Unit tests
595// =============================================================================
596
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `glob_match` against a table of `(pattern, input, expected)` cases.
    fn check_globs(cases: &[(&str, &str, bool)]) {
        for &(pattern, input, expected) in cases {
            assert_eq!(
                glob_match(pattern, input),
                expected,
                "glob_match({pattern:?}, {input:?})"
            );
        }
    }

    #[test]
    fn test_glob_match_prefix() {
        check_globs(&[
            ("claude-*", "claude-sonnet-4-20250514", true),
            ("gpt-*", "gpt-4o", true),
            ("claude-*", "gpt-4o", false),
        ]);
    }

    #[test]
    fn test_glob_match_suffix() {
        check_globs(&[
            ("*-latest", "llama3.2-latest", true),
            ("*-latest", "llama3.2", false),
        ]);
    }

    #[test]
    fn test_glob_match_middle() {
        check_globs(&[
            ("claude-*-latest", "claude-sonnet-latest", true),
            ("claude-*-latest", "claude-sonnet-beta", false),
        ]);
    }

    #[test]
    fn test_glob_match_exact() {
        check_globs(&[("gpt-4o", "gpt-4o", true), ("gpt-4o", "gpt-4o-mini", false)]);
    }

    #[test]
    fn test_infer_provider_from_defaults() {
        let expectations = [
            ("claude-sonnet-4-20250514", "anthropic"),
            ("gpt-4o", "openai"),
            ("o1-preview", "openai"),
            ("o3-mini", "openai"),
            ("qwen/qwen3-coder", "openrouter"),
            ("llama3.2:latest", "ollama"),
            // Unknown ids fall through to the final default.
            ("unknown-model", "anthropic"),
        ];
        for (model, provider) in expectations {
            assert_eq!(infer_provider(model), provider, "model {model}");
        }
    }

    #[test]
    fn test_model_tier_from_defaults() {
        let expectations = [
            ("claude-sonnet-4-20250514", "frontier"),
            ("gpt-4o", "frontier"),
            ("Qwen3.5-9B", "small"),
            ("deepseek-v3", "mid"),
        ];
        for (model, tier) in expectations {
            assert_eq!(model_tier(model), tier, "model {model}");
        }
    }

    #[test]
    fn test_resolve_model_unknown_alias() {
        let (id, provider) = resolve_model("gpt-4o");
        assert_eq!(id, "gpt-4o");
        assert_eq!(provider, None);
    }

    #[test]
    fn test_provider_names() {
        let names = provider_names();
        assert!(names.len() >= 5);
        for expected in ["anthropic", "openai", "ollama"] {
            assert!(
                names.iter().any(|name| name == expected),
                "missing provider {expected}"
            );
        }
    }

    #[test]
    fn test_resolve_tier_model_default_aliases() {
        let expectations = [
            ("frontier", "claude-sonnet-4-20250514", "anthropic"),
            ("small", "Qwen/Qwen3.5-9B", "openrouter"),
        ];
        for (tier, model, provider) in expectations {
            let resolved = resolve_tier_model(tier, None).expect("built-in tier alias");
            assert_eq!(resolved, (model.to_string(), provider.to_string()));
        }
    }

    #[test]
    fn test_resolve_tier_model_prefers_provider_scoped_aliases() {
        let resolved = resolve_tier_model("mid", Some("openai")).expect("mid alias");
        assert_eq!(resolved, ("gpt-4o-mini".to_string(), "openai".to_string()));
    }

    #[test]
    fn test_provider_config_anthropic() {
        let anthropic = provider_config("anthropic").expect("anthropic is a built-in provider");
        assert_eq!(anthropic.auth_style, "header");
        assert_eq!(anthropic.auth_header, Some("x-api-key".to_string()));
    }

    #[test]
    fn test_resolve_base_url_no_env() {
        let pdef = ProviderDef {
            base_url: "https://example.com".to_string(),
            ..Default::default()
        };
        assert_eq!(resolve_base_url(&pdef), "https://example.com");
    }

    #[test]
    fn test_default_config_roundtrip() {
        let config = default_config();
        assert!(!config.providers.is_empty());
        assert!(!config.inference_rules.is_empty());
        assert!(!config.tier_rules.is_empty());
        assert_eq!(config.tier_defaults.default, "mid");
    }

    #[test]
    fn test_model_params_empty() {
        // The built-in config defines no model_defaults, so nothing matches.
        assert!(model_params("claude-sonnet-4-20250514").is_empty());
    }
}