1use serde::Deserialize;
2use std::collections::BTreeMap;
3use std::sync::OnceLock;
4
/// Process-wide cached configuration; initialized exactly once on first
/// call to [`load_config`] and then shared for the program's lifetime.
static CONFIG: OnceLock<ProvidersConfig> = OnceLock::new();
6
/// Top-level provider registry, deserialized from a `providers.toml` file
/// or built via [`default_config`]. Every section is optional in the TOML.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct ProvidersConfig {
    /// Provider name -> connection/auth definition.
    #[serde(default)]
    pub providers: BTreeMap<String, ProviderDef>,
    /// Alias name -> concrete (model id, provider) pair.
    #[serde(default)]
    pub aliases: BTreeMap<String, AliasDef>,
    /// Rules mapping a model id to a provider name (checked in order).
    #[serde(default)]
    pub inference_rules: Vec<InferenceRule>,
    /// Rules mapping a model id to a capability tier (checked in order).
    #[serde(default)]
    pub tier_rules: Vec<TierRule>,
    /// Fallback tier used when no tier rule or heuristic matches.
    #[serde(default)]
    pub tier_defaults: TierDefaults,
    /// Model-id glob pattern -> request parameters. In `model_params`,
    /// patterns are applied in key order; later matches overwrite earlier keys.
    #[serde(default)]
    pub model_defaults: BTreeMap<String, BTreeMap<String, toml::Value>>,
}
26
/// Connection, authentication, and capability description for one provider.
#[derive(Debug, Clone, Deserialize)]
pub struct ProviderDef {
    /// Static API base URL (e.g. `https://api.openai.com/v1`).
    pub base_url: String,
    /// Optional env var that overrides `base_url` when set and non-empty
    /// (see [`resolve_base_url`]).
    #[serde(default)]
    pub base_url_env: Option<String>,
    /// Auth mechanism tag; defaults to `"bearer"`. Other values seen in the
    /// defaults: `"header"` (custom header) and `"none"`.
    #[serde(default = "default_bearer")]
    pub auth_style: String,
    /// Custom header name for credentials (e.g. `x-api-key`) — presumably
    /// consulted when `auth_style` is `"header"`; confirm at the call site.
    #[serde(default)]
    pub auth_header: Option<String>,
    /// Env var(s) holding the API key; see [`AuthEnv`].
    #[serde(default)]
    pub auth_env: AuthEnv,
    /// Additional headers sent on every request (e.g. `anthropic-version`).
    #[serde(default)]
    pub extra_headers: BTreeMap<String, String>,
    /// Path of the chat endpoint, appended to the base URL.
    #[serde(default)]
    pub chat_endpoint: String,
    /// Optional path of the plain-completion endpoint.
    #[serde(default)]
    pub completion_endpoint: Option<String>,
    /// Optional liveness probe description.
    #[serde(default)]
    pub healthcheck: Option<HealthcheckDef>,
    /// Free-form feature tags (e.g. `prompt_caching`, `thinking`).
    #[serde(default)]
    pub features: Vec<String>,
    /// Name of another provider to fall back to — semantics live in the caller.
    #[serde(default)]
    pub fallback: Option<String>,
    /// Optional retry count override; interpretation is up to the caller.
    #[serde(default)]
    pub retry_count: Option<u32>,
    /// Optional delay between retries, in milliseconds.
    #[serde(default)]
    pub retry_delay_ms: Option<u64>,
}
58
59impl Default for ProviderDef {
60 fn default() -> Self {
61 Self {
62 base_url: String::new(),
63 base_url_env: None,
64 auth_style: default_bearer(),
65 auth_header: None,
66 auth_env: AuthEnv::None,
67 extra_headers: BTreeMap::new(),
68 chat_endpoint: String::new(),
69 completion_endpoint: None,
70 healthcheck: None,
71 features: Vec::new(),
72 fallback: None,
73 retry_count: None,
74 retry_delay_ms: None,
75 }
76 }
77}
78
/// Serde default for `ProviderDef::auth_style`.
fn default_bearer() -> String {
    String::from("bearer")
}
82
/// Env-var source(s) for a provider's API key.
///
/// Deserialized untagged, so the TOML value's shape picks the variant:
/// a string becomes `Single`, an array of strings becomes `Multiple`,
/// and an omitted field becomes `None` via `#[serde(default)]`.
#[derive(Debug, Clone, Deserialize, Default)]
#[serde(untagged)]
pub enum AuthEnv {
    /// No credential env var (e.g. local Ollama).
    #[default]
    None,
    /// One env var name.
    Single(String),
    /// Several env var names — presumably tried in order; confirm at call site.
    Multiple(Vec<String>),
}
93
/// Describes how to probe a provider for liveness.
#[derive(Debug, Clone, Deserialize)]
pub struct HealthcheckDef {
    /// HTTP method ("GET" or "POST" in the built-in defaults).
    pub method: String,
    /// Path relative to the provider's base URL; mutually exclusive with
    /// `url` in the built-in defaults (only one is ever set).
    #[serde(default)]
    pub path: Option<String>,
    /// Absolute URL probed instead of base-URL + path (used by huggingface).
    #[serde(default)]
    pub url: Option<String>,
    /// Optional request body (used by anthropic's count_tokens probe).
    #[serde(default)]
    pub body: Option<String>,
}
104
/// Target of a model alias: a concrete model id on a specific provider.
#[derive(Debug, Clone, Deserialize)]
pub struct AliasDef {
    /// Concrete model id (e.g. `claude-sonnet-4-20250514`).
    pub id: String,
    /// Provider name, keyed into `ProvidersConfig::providers`.
    pub provider: String,
}
110
/// One model-id -> provider rule. Matchers are optional; in
/// [`infer_provider`] a rule fires if `exact`, `pattern` (glob), or
/// `contains` matches — checked in that order.
#[derive(Debug, Clone, Deserialize)]
pub struct InferenceRule {
    /// Glob pattern (see [`glob_match`]).
    #[serde(default)]
    pub pattern: Option<String>,
    /// Substring match.
    #[serde(default)]
    pub contains: Option<String>,
    /// Exact model-id match.
    #[serde(default)]
    pub exact: Option<String>,
    /// Provider returned when the rule matches.
    pub provider: String,
}
121
/// One model-id -> tier rule; same matcher semantics as [`InferenceRule`],
/// evaluated by [`model_tier`].
#[derive(Debug, Clone, Deserialize)]
pub struct TierRule {
    /// Glob pattern (see [`glob_match`]).
    #[serde(default)]
    pub pattern: Option<String>,
    /// Substring match.
    #[serde(default)]
    pub contains: Option<String>,
    /// Exact model-id match.
    #[serde(default)]
    pub exact: Option<String>,
    /// Tier name returned when the rule matches (e.g. "small", "frontier").
    pub tier: String,
}
132
/// Fallback tier configuration used when no [`TierRule`] or heuristic matches.
#[derive(Debug, Clone, Deserialize)]
pub struct TierDefaults {
    /// Default tier name; `"mid"` unless overridden in the TOML.
    #[serde(default = "default_mid")]
    pub default: String,
}
138
139impl Default for TierDefaults {
140 fn default() -> Self {
141 Self {
142 default: default_mid(),
143 }
144 }
145}
146
/// Serde default for `TierDefaults::default`.
fn default_mid() -> String {
    String::from("mid")
}
150
151pub fn load_config() -> &'static ProvidersConfig {
157 CONFIG.get_or_init(|| {
158 if let Ok(path) = std::env::var("HARN_PROVIDERS_CONFIG") {
160 match std::fs::read_to_string(&path) {
161 Ok(content) => match toml::from_str::<ProvidersConfig>(&content) {
162 Ok(config) => {
163 eprintln!(
164 "[llm_config] Loaded {} providers, {} aliases from {}",
165 config.providers.len(),
166 config.aliases.len(),
167 path
168 );
169 return config;
170 }
171 Err(e) => eprintln!("[llm_config] TOML parse error in {}: {}", path, e),
172 },
173 Err(e) => eprintln!("[llm_config] Cannot read {}: {}", path, e),
174 }
175 }
176 if let Some(home) = dirs_or_home() {
178 let path = format!("{home}/.config/harn/providers.toml");
179 if let Ok(content) = std::fs::read_to_string(&path) {
180 if let Ok(config) = toml::from_str::<ProvidersConfig>(&content) {
181 return config;
182 }
183 }
184 }
185 default_config()
187 })
188}
189
190pub fn resolve_model(alias: &str) -> (String, Option<String>) {
192 let config = load_config();
193 if let Some(a) = config.aliases.get(alias) {
194 return (a.id.clone(), Some(a.provider.clone()));
195 }
196 (alias.to_string(), None)
197}
198
199pub fn infer_provider(model_id: &str) -> String {
201 let config = load_config();
202 for rule in &config.inference_rules {
203 if let Some(exact) = &rule.exact {
204 if model_id == exact {
205 return rule.provider.clone();
206 }
207 }
208 if let Some(pattern) = &rule.pattern {
209 if glob_match(pattern, model_id) {
210 return rule.provider.clone();
211 }
212 }
213 if let Some(substr) = &rule.contains {
214 if model_id.contains(substr.as_str()) {
215 return rule.provider.clone();
216 }
217 }
218 }
219 if model_id.starts_with("claude-") {
221 return "anthropic".to_string();
222 }
223 if model_id.starts_with("gpt-") || model_id.starts_with("o1") || model_id.starts_with("o3") {
224 return "openai".to_string();
225 }
226 if model_id.contains('/') {
227 return "openrouter".to_string();
228 }
229 if model_id.contains(':') {
230 return "ollama".to_string();
231 }
232 "anthropic".to_string()
233}
234
235pub fn model_tier(model_id: &str) -> String {
237 let config = load_config();
238 for rule in &config.tier_rules {
239 if let Some(exact) = &rule.exact {
240 if model_id == exact {
241 return rule.tier.clone();
242 }
243 }
244 if let Some(pattern) = &rule.pattern {
245 if glob_match(pattern, model_id) {
246 return rule.tier.clone();
247 }
248 }
249 if let Some(substr) = &rule.contains {
250 if model_id.contains(substr.as_str()) {
251 return rule.tier.clone();
252 }
253 }
254 }
255 let lower = model_id.to_lowercase();
257 if lower.contains("9b") || lower.contains("a3b") {
258 return "small".to_string();
259 }
260 if lower.starts_with("claude-") || lower == "gpt-4o" {
261 return "frontier".to_string();
262 }
263 config.tier_defaults.default.clone()
264}
265
266pub fn provider_config(name: &str) -> Option<&'static ProviderDef> {
268 load_config().providers.get(name)
269}
270
271pub fn model_params(model_id: &str) -> BTreeMap<String, toml::Value> {
274 let config = load_config();
275 let mut params = BTreeMap::new();
276 for (pattern, defaults) in &config.model_defaults {
277 if glob_match(pattern, model_id) {
278 for (k, v) in defaults {
279 params.insert(k.clone(), v.clone());
280 }
281 }
282 }
283 params
284}
285
286pub fn provider_names() -> Vec<String> {
288 load_config().providers.keys().cloned().collect()
289}
290
291pub fn resolve_tier_model(
293 target: &str,
294 preferred_provider: Option<&str>,
295) -> Option<(String, String)> {
296 let config = load_config();
297
298 if let Some(alias) = config.aliases.get(target) {
299 return Some((alias.id.clone(), alias.provider.clone()));
300 }
301
302 let candidate_aliases = if let Some(provider) = preferred_provider {
303 vec![
304 format!("{provider}/{target}"),
305 format!("{provider}:{target}"),
306 format!("tier/{target}"),
307 target.to_string(),
308 ]
309 } else {
310 vec![format!("tier/{target}"), target.to_string()]
311 };
312
313 for alias_name in candidate_aliases {
314 if let Some(alias) = config.aliases.get(&alias_name) {
315 return Some((alias.id.clone(), alias.provider.clone()));
316 }
317 }
318
319 None
320}
321
/// Minimal glob matcher: each `*` in `pattern` matches any (possibly empty)
/// substring of `input`; everything else must match literally. A pattern
/// with no `*` requires exact equality.
///
/// The previous implementation only handled at most one `*` correctly;
/// patterns such as `a*b*c` or `*a*` silently degraded to an exact string
/// comparison against the pattern text. This version handles any number of
/// wildcards while preserving all prior single-`*` and exact behavior.
fn glob_match(pattern: &str, input: &str) -> bool {
    if !pattern.contains('*') {
        return input == pattern;
    }
    let parts: Vec<&str> = pattern.split('*').collect();

    // The literal before the first `*` must anchor the start of the input.
    let mut rest = match input.strip_prefix(parts[0]) {
        Some(r) => r,
        None => return false,
    };
    // The literal after the last `*` must anchor the end.
    let last = parts[parts.len() - 1];
    rest = match rest.strip_suffix(last) {
        Some(r) => r,
        None => return false,
    };
    // Interior literals must appear, in order, in what remains; a greedy
    // left-to-right scan is sufficient because `*` can absorb any gap.
    for segment in &parts[1..parts.len() - 1] {
        if segment.is_empty() {
            continue;
        }
        match rest.find(segment) {
            Some(idx) => rest = &rest[idx + segment.len()..],
            None => return false,
        }
    }
    true
}
343
/// Returns the value of `$HOME`, or `None` when it is unset/invalid.
fn dirs_or_home() -> Option<String> {
    match std::env::var("HOME") {
        Ok(home) => Some(home),
        Err(_) => None,
    }
}
347
348pub fn resolve_base_url(pdef: &ProviderDef) -> String {
351 if let Some(env_name) = &pdef.base_url_env {
352 if let Ok(val) = std::env::var(env_name) {
353 if !val.is_empty() {
354 return val;
355 }
356 }
357 }
358 pdef.base_url.clone()
359}
360
/// Builds the compiled-in configuration used when no TOML file is found:
/// five providers (anthropic, openai, openrouter, huggingface, ollama),
/// inference and tier rules mirroring the hard-coded heuristics, and
/// tier aliases (bare and `tier/`-namespaced) for frontier/mid/small.
fn default_config() -> ProvidersConfig {
    let mut config = ProvidersConfig::default();

    // Anthropic: custom-header auth plus the required anthropic-version header;
    // healthcheck POSTs a minimal count_tokens request.
    config.providers.insert(
        "anthropic".to_string(),
        ProviderDef {
            base_url: "https://api.anthropic.com/v1".to_string(),
            auth_style: "header".to_string(),
            auth_header: Some("x-api-key".to_string()),
            auth_env: AuthEnv::Single("ANTHROPIC_API_KEY".to_string()),
            extra_headers: BTreeMap::from([(
                "anthropic-version".to_string(),
                "2023-06-01".to_string(),
            )]),
            chat_endpoint: "/messages".to_string(),
            completion_endpoint: None,
            healthcheck: Some(HealthcheckDef {
                method: "POST".to_string(),
                path: Some("/messages/count_tokens".to_string()),
                url: None,
                body: Some(
                    r#"{"model":"claude-sonnet-4-20250514","messages":[{"role":"user","content":"x"}]}"#
                        .to_string(),
                ),
            }),
            features: vec!["prompt_caching".to_string(), "thinking".to_string()],
            ..Default::default()
        },
    );

    // OpenAI: standard bearer auth, /models as a cheap health probe.
    config.providers.insert(
        "openai".to_string(),
        ProviderDef {
            base_url: "https://api.openai.com/v1".to_string(),
            auth_style: "bearer".to_string(),
            auth_env: AuthEnv::Single("OPENAI_API_KEY".to_string()),
            chat_endpoint: "/chat/completions".to_string(),
            completion_endpoint: Some("/completions".to_string()),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                path: Some("/models".to_string()),
                url: None,
                body: None,
            }),
            ..Default::default()
        },
    );

    // OpenRouter: OpenAI-compatible API; /auth/key validates the credential.
    config.providers.insert(
        "openrouter".to_string(),
        ProviderDef {
            base_url: "https://openrouter.ai/api/v1".to_string(),
            auth_style: "bearer".to_string(),
            auth_env: AuthEnv::Single("OPENROUTER_API_KEY".to_string()),
            chat_endpoint: "/chat/completions".to_string(),
            completion_endpoint: Some("/completions".to_string()),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                path: Some("/auth/key".to_string()),
                url: None,
                body: None,
            }),
            ..Default::default()
        },
    );

    // Hugging Face: two accepted key env vars; healthcheck hits an absolute
    // URL on a different host than the inference router.
    config.providers.insert(
        "huggingface".to_string(),
        ProviderDef {
            base_url: "https://router.huggingface.co/v1".to_string(),
            auth_style: "bearer".to_string(),
            auth_env: AuthEnv::Multiple(vec![
                "HF_TOKEN".to_string(),
                "HUGGINGFACE_API_KEY".to_string(),
            ]),
            chat_endpoint: "/chat/completions".to_string(),
            completion_endpoint: Some("/completions".to_string()),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                url: Some("https://huggingface.co/api/whoami-v2".to_string()),
                path: None,
                body: None,
            }),
            ..Default::default()
        },
    );

    // Ollama: local server, no auth; base URL overridable via OLLAMA_HOST.
    config.providers.insert(
        "ollama".to_string(),
        ProviderDef {
            base_url: "http://localhost:11434".to_string(),
            base_url_env: Some("OLLAMA_HOST".to_string()),
            auth_style: "none".to_string(),
            chat_endpoint: "/api/chat".to_string(),
            completion_endpoint: Some("/api/generate".to_string()),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                path: Some("/api/tags".to_string()),
                url: None,
                body: None,
            }),
            ..Default::default()
        },
    );

    // Provider-inference rules: mirror the hard-coded fallbacks in
    // `infer_provider` so TOML configs can override them wholesale.
    config.inference_rules = vec![
        InferenceRule {
            pattern: Some("claude-*".to_string()),
            contains: None,
            exact: None,
            provider: "anthropic".to_string(),
        },
        InferenceRule {
            pattern: Some("gpt-*".to_string()),
            contains: None,
            exact: None,
            provider: "openai".to_string(),
        },
        InferenceRule {
            pattern: Some("o1*".to_string()),
            contains: None,
            exact: None,
            provider: "openai".to_string(),
        },
        InferenceRule {
            pattern: Some("o3*".to_string()),
            contains: None,
            exact: None,
            provider: "openai".to_string(),
        },
        InferenceRule {
            pattern: None,
            contains: Some("/".to_string()),
            exact: None,
            provider: "openrouter".to_string(),
        },
        InferenceRule {
            pattern: None,
            contains: Some(":".to_string()),
            exact: None,
            provider: "ollama".to_string(),
        },
    ];

    // Tier rules: mirror the heuristics in `model_tier`.
    config.tier_rules = vec![
        TierRule {
            contains: Some("9b".to_string()),
            pattern: None,
            exact: None,
            tier: "small".to_string(),
        },
        TierRule {
            contains: Some("a3b".to_string()),
            pattern: None,
            exact: None,
            tier: "small".to_string(),
        },
        TierRule {
            pattern: Some("claude-*".to_string()),
            contains: None,
            exact: None,
            tier: "frontier".to_string(),
        },
        TierRule {
            exact: Some("gpt-4o".to_string()),
            contains: None,
            pattern: None,
            tier: "frontier".to_string(),
        },
    ];

    config.tier_defaults = TierDefaults {
        default: "mid".to_string(),
    };

    // Tier aliases, each registered under both its bare name and the
    // `tier/` namespace so `resolve_tier_model` finds either spelling.
    config.aliases.insert(
        "frontier".to_string(),
        AliasDef {
            id: "claude-sonnet-4-20250514".to_string(),
            provider: "anthropic".to_string(),
        },
    );
    config.aliases.insert(
        "tier/frontier".to_string(),
        AliasDef {
            id: "claude-sonnet-4-20250514".to_string(),
            provider: "anthropic".to_string(),
        },
    );
    config.aliases.insert(
        "mid".to_string(),
        AliasDef {
            id: "gpt-4o-mini".to_string(),
            provider: "openai".to_string(),
        },
    );
    config.aliases.insert(
        "tier/mid".to_string(),
        AliasDef {
            id: "gpt-4o-mini".to_string(),
            provider: "openai".to_string(),
        },
    );
    config.aliases.insert(
        "small".to_string(),
        AliasDef {
            id: "Qwen/Qwen3.5-9B".to_string(),
            provider: "openrouter".to_string(),
        },
    );
    config.aliases.insert(
        "tier/small".to_string(),
        AliasDef {
            id: "Qwen/Qwen3.5-9B".to_string(),
            provider: "openrouter".to_string(),
        },
    );

    config
}
592
// Unit tests. NOTE: tests that call load_config()-backed functions depend on
// HARN_PROVIDERS_CONFIG / ~/.config/harn/providers.toml NOT overriding the
// built-in defaults in the test environment.
#[cfg(test)]
mod tests {
    use super::*;

    // --- glob_match ---

    #[test]
    fn test_glob_match_prefix() {
        assert!(glob_match("claude-*", "claude-sonnet-4-20250514"));
        assert!(glob_match("gpt-*", "gpt-4o"));
        assert!(!glob_match("claude-*", "gpt-4o"));
    }

    #[test]
    fn test_glob_match_suffix() {
        assert!(glob_match("*-latest", "llama3.2-latest"));
        assert!(!glob_match("*-latest", "llama3.2"));
    }

    #[test]
    fn test_glob_match_middle() {
        assert!(glob_match("claude-*-latest", "claude-sonnet-latest"));
        assert!(!glob_match("claude-*-latest", "claude-sonnet-beta"));
    }

    #[test]
    fn test_glob_match_exact() {
        assert!(glob_match("gpt-4o", "gpt-4o"));
        assert!(!glob_match("gpt-4o", "gpt-4o-mini"));
    }

    // --- inference / tier heuristics against the built-in defaults ---

    #[test]
    fn test_infer_provider_from_defaults() {
        assert_eq!(infer_provider("claude-sonnet-4-20250514"), "anthropic");
        assert_eq!(infer_provider("gpt-4o"), "openai");
        assert_eq!(infer_provider("o1-preview"), "openai");
        assert_eq!(infer_provider("o3-mini"), "openai");
        assert_eq!(infer_provider("qwen/qwen3-coder"), "openrouter");
        assert_eq!(infer_provider("llama3.2:latest"), "ollama");
        assert_eq!(infer_provider("unknown-model"), "anthropic");
    }

    #[test]
    fn test_model_tier_from_defaults() {
        assert_eq!(model_tier("claude-sonnet-4-20250514"), "frontier");
        assert_eq!(model_tier("gpt-4o"), "frontier");
        assert_eq!(model_tier("Qwen3.5-9B"), "small");
        assert_eq!(model_tier("deepseek-v3"), "mid");
    }

    // --- alias resolution ---

    #[test]
    fn test_resolve_model_unknown_alias() {
        let (id, provider) = resolve_model("gpt-4o");
        assert_eq!(id, "gpt-4o");
        assert!(provider.is_none());
    }

    #[test]
    fn test_provider_names() {
        let names = provider_names();
        assert!(names.len() >= 5);
        assert!(names.contains(&"anthropic".to_string()));
        assert!(names.contains(&"openai".to_string()));
        assert!(names.contains(&"ollama".to_string()));
    }

    #[test]
    fn test_resolve_tier_model_default_aliases() {
        let (model, provider) = resolve_tier_model("frontier", None).unwrap();
        assert_eq!(model, "claude-sonnet-4-20250514");
        assert_eq!(provider, "anthropic");

        let (model, provider) = resolve_tier_model("small", None).unwrap();
        assert_eq!(model, "Qwen/Qwen3.5-9B");
        assert_eq!(provider, "openrouter");
    }

    // NOTE(review): the bare "mid" alias actually resolves before any
    // provider-scoped alias, so this passes via the plain alias lookup.
    #[test]
    fn test_resolve_tier_model_prefers_provider_scoped_aliases() {
        let (model, provider) = resolve_tier_model("mid", Some("openai")).unwrap();
        assert_eq!(model, "gpt-4o-mini");
        assert_eq!(provider, "openai");
    }

    #[test]
    fn test_provider_config_anthropic() {
        let pdef = provider_config("anthropic").unwrap();
        assert_eq!(pdef.auth_style, "header");
        assert_eq!(pdef.auth_header.as_deref(), Some("x-api-key"));
    }

    #[test]
    fn test_resolve_base_url_no_env() {
        let pdef = ProviderDef {
            base_url: "https://example.com".to_string(),
            ..Default::default()
        };
        assert_eq!(resolve_base_url(&pdef), "https://example.com");
    }

    #[test]
    fn test_default_config_roundtrip() {
        let config = default_config();
        assert!(!config.providers.is_empty());
        assert!(!config.inference_rules.is_empty());
        assert!(!config.tier_rules.is_empty());
        assert_eq!(config.tier_defaults.default, "mid");
    }

    // The built-in defaults define no model_defaults section.
    #[test]
    fn test_model_params_empty() {
        let params = model_params("claude-sonnet-4-20250514");
        assert!(params.is_empty());
    }
}