1use serde::Deserialize;
2use std::collections::BTreeMap;
3use std::sync::OnceLock;
4
/// Process-wide, lazily-initialized provider configuration (see `load_config`).
static CONFIG: OnceLock<ProvidersConfig> = OnceLock::new();
/// Path of the config file actually loaded, if any; stays unset when the
/// built-in defaults are used (see `loaded_config_path`).
static CONFIG_PATH: OnceLock<String> = OnceLock::new();
7
/// Root of the providers TOML configuration.
///
/// Every section is optional in the file; missing sections fall back to
/// empty collections (and `TierDefaults::default()` for `tier_defaults`).
#[derive(Debug, Clone, Deserialize, Default)]
pub struct ProvidersConfig {
    /// Provider name -> connection/auth definition.
    #[serde(default)]
    pub providers: BTreeMap<String, ProviderDef>,
    /// Alias name -> concrete model id + provider.
    #[serde(default)]
    pub aliases: BTreeMap<String, AliasDef>,
    /// Ordered rules mapping a model id to a provider (see `infer_provider`).
    #[serde(default)]
    pub inference_rules: Vec<InferenceRule>,
    /// Ordered rules mapping a model id to a tier (see `model_tier`).
    #[serde(default)]
    pub tier_rules: Vec<TierRule>,
    /// Fallback tier used when no tier rule matches.
    #[serde(default)]
    pub tier_defaults: TierDefaults,
    /// Glob pattern -> default request parameters for matching models
    /// (see `model_params`).
    #[serde(default)]
    pub model_defaults: BTreeMap<String, BTreeMap<String, toml::Value>>,
}
27
/// Connection, authentication, and capability settings for one provider.
#[derive(Debug, Clone, Deserialize)]
pub struct ProviderDef {
    /// Base URL for API requests.
    pub base_url: String,
    /// Env var that overrides `base_url` when set (see `resolve_base_url`).
    #[serde(default)]
    pub base_url_env: Option<String>,
    /// Auth scheme; "bearer" by default. The built-in defaults also use
    /// "header" and "none".
    #[serde(default = "default_bearer")]
    pub auth_style: String,
    /// Header name used when `auth_style` is "header" (e.g. "x-api-key").
    #[serde(default)]
    pub auth_header: Option<String>,
    /// Env var(s) holding the API key (see `AuthEnv`).
    #[serde(default)]
    pub auth_env: AuthEnv,
    /// Extra headers attached to requests for this provider.
    #[serde(default)]
    pub extra_headers: BTreeMap<String, String>,
    /// Chat endpoint path, relative to `base_url`.
    #[serde(default)]
    pub chat_endpoint: String,
    /// Optional raw-completion endpoint path.
    #[serde(default)]
    pub completion_endpoint: Option<String>,
    /// Optional request used to probe provider health.
    #[serde(default)]
    pub healthcheck: Option<HealthcheckDef>,
    /// Capability flags, e.g. "native_tools" (checked by
    /// `default_tool_format`), "prompt_caching", "thinking".
    #[serde(default)]
    pub features: Vec<String>,
    /// Name of a provider to fall back to; consumed by the request layer,
    /// not used in this module.
    #[serde(default)]
    pub fallback: Option<String>,
    /// Retry attempts for failed requests (request-layer setting).
    #[serde(default)]
    pub retry_count: Option<u32>,
    /// Delay between retries, in milliseconds (request-layer setting).
    #[serde(default)]
    pub retry_delay_ms: Option<u64>,
    /// Requests-per-minute rate limit (request-layer setting).
    #[serde(default)]
    pub rpm: Option<u32>,
}
62
63impl Default for ProviderDef {
64 fn default() -> Self {
65 Self {
66 base_url: String::new(),
67 base_url_env: None,
68 auth_style: default_bearer(),
69 auth_header: None,
70 auth_env: AuthEnv::None,
71 extra_headers: BTreeMap::new(),
72 chat_endpoint: String::new(),
73 completion_endpoint: None,
74 healthcheck: None,
75 features: Vec::new(),
76 fallback: None,
77 retry_count: None,
78 retry_delay_ms: None,
79 rpm: None,
80 }
81 }
82}
83
/// Serde default for `ProviderDef::auth_style`.
fn default_bearer() -> String {
    String::from("bearer")
}
87
/// Environment-variable source(s) for a provider's API key.
///
/// Deserialized untagged, so the TOML value may be absent, a single
/// string, or a list of strings.
#[derive(Debug, Clone, Deserialize, Default)]
#[serde(untagged)]
pub enum AuthEnv {
    /// No API key env var configured.
    #[default]
    None,
    /// A single env var name.
    Single(String),
    /// Several candidate env var names; presumably tried in order by the
    /// request layer — not consumed in this module.
    Multiple(Vec<String>),
}
98
/// HTTP probe used to check that a provider is reachable.
#[derive(Debug, Clone, Deserialize)]
pub struct HealthcheckDef {
    /// HTTP method, e.g. "GET" or "POST".
    pub method: String,
    /// Path relative to the provider's base URL.
    #[serde(default)]
    pub path: Option<String>,
    /// Absolute URL alternative to `path`.
    /// NOTE(review): precedence between `path` and `url` is decided by the
    /// health-check runner, not visible in this module.
    #[serde(default)]
    pub url: Option<String>,
    /// Optional request body (JSON in the built-in defaults).
    #[serde(default)]
    pub body: Option<String>,
}
109
/// Maps a short alias name to a concrete model on a specific provider.
#[derive(Debug, Clone, Deserialize)]
pub struct AliasDef {
    /// Concrete model identifier sent to the provider.
    pub id: String,
    /// Provider name (key into `ProvidersConfig::providers`).
    pub provider: String,
    /// Tool-call format override; see `default_tool_format`.
    #[serde(default)]
    pub tool_format: Option<String>,
}
121
/// Rule mapping a model id to a provider. Within a rule, `exact`, then
/// `pattern`, then `contains` are tried; the first rule with any match
/// wins (see `infer_provider`).
#[derive(Debug, Clone, Deserialize)]
pub struct InferenceRule {
    /// Glob pattern matcher (see `glob_match`).
    #[serde(default)]
    pub pattern: Option<String>,
    /// Substring matcher.
    #[serde(default)]
    pub contains: Option<String>,
    /// Exact model-id matcher.
    #[serde(default)]
    pub exact: Option<String>,
    /// Provider returned when this rule matches.
    pub provider: String,
}
132
/// Rule mapping a model id to a tier name; same matcher semantics as
/// `InferenceRule` (see `model_tier`).
#[derive(Debug, Clone, Deserialize)]
pub struct TierRule {
    /// Glob pattern matcher (see `glob_match`).
    #[serde(default)]
    pub pattern: Option<String>,
    /// Substring matcher.
    #[serde(default)]
    pub contains: Option<String>,
    /// Exact model-id matcher.
    #[serde(default)]
    pub exact: Option<String>,
    /// Tier returned when this rule matches (e.g. "small", "mid", "frontier").
    pub tier: String,
}
143
/// Fallback tier used when no `TierRule` (or built-in heuristic) matches.
#[derive(Debug, Clone, Deserialize)]
pub struct TierDefaults {
    /// Tier name; defaults to "mid".
    #[serde(default = "default_mid")]
    pub default: String,
}
149
150impl Default for TierDefaults {
151 fn default() -> Self {
152 Self {
153 default: default_mid(),
154 }
155 }
156}
157
/// Serde default for `TierDefaults::default`.
fn default_mid() -> String {
    String::from("mid")
}
161
/// Returns the process-wide provider configuration, loading it on the
/// first call (subsequent calls return the cached value).
///
/// Sources, in priority order:
/// 1. the file named by the `HARN_PROVIDERS_CONFIG` env var — read and
///    parse failures are logged to stderr, then fall through;
/// 2. `$HOME/.config/harn/providers.toml` — failures fall through silently;
/// 3. the built-in `default_config()`.
///
/// When a file is used, its path is recorded in `CONFIG_PATH` for
/// `loaded_config_path()`.
pub fn load_config() -> &'static ProvidersConfig {
    CONFIG.get_or_init(|| {
        // 1. Explicit override via environment variable.
        if let Ok(path) = std::env::var("HARN_PROVIDERS_CONFIG") {
            match std::fs::read_to_string(&path) {
                Ok(content) => match toml::from_str::<ProvidersConfig>(&content) {
                    Ok(config) => {
                        eprintln!(
                            "[llm_config] Loaded {} providers, {} aliases from {}",
                            config.providers.len(),
                            config.aliases.len(),
                            path
                        );
                        // Best-effort: ignore the error if already set.
                        let _ = CONFIG_PATH.set(path);
                        return config;
                    }
                    Err(e) => eprintln!("[llm_config] TOML parse error in {}: {}", path, e),
                },
                Err(e) => eprintln!("[llm_config] Cannot read {}: {}", path, e),
            }
        }
        // 2. Per-user config file under $HOME.
        if let Some(home) = dirs_or_home() {
            let path = format!("{home}/.config/harn/providers.toml");
            if let Ok(content) = std::fs::read_to_string(&path) {
                if let Ok(config) = toml::from_str::<ProvidersConfig>(&content) {
                    let _ = CONFIG_PATH.set(path);
                    return config;
                }
            }
        }
        // 3. Compiled-in defaults (CONFIG_PATH stays unset).
        default_config()
    })
}
202
203pub fn loaded_config_path() -> Option<std::path::PathBuf> {
206 let _ = load_config();
208 CONFIG_PATH.get().map(std::path::PathBuf::from)
209}
210
211pub fn resolve_model(alias: &str) -> (String, Option<String>) {
213 let config = load_config();
214 if let Some(a) = config.aliases.get(alias) {
215 return (a.id.clone(), Some(a.provider.clone()));
216 }
217 (alias.to_string(), None)
218}
219
220pub fn infer_provider(model_id: &str) -> String {
222 let config = load_config();
223 for rule in &config.inference_rules {
224 if let Some(exact) = &rule.exact {
225 if model_id == exact {
226 return rule.provider.clone();
227 }
228 }
229 if let Some(pattern) = &rule.pattern {
230 if glob_match(pattern, model_id) {
231 return rule.provider.clone();
232 }
233 }
234 if let Some(substr) = &rule.contains {
235 if model_id.contains(substr.as_str()) {
236 return rule.provider.clone();
237 }
238 }
239 }
240 if model_id.starts_with("claude-") {
242 return "anthropic".to_string();
243 }
244 if model_id.starts_with("gpt-") || model_id.starts_with("o1") || model_id.starts_with("o3") {
245 return "openai".to_string();
246 }
247 if model_id.contains('/') {
248 return "openrouter".to_string();
249 }
250 if model_id.contains(':') {
251 return "ollama".to_string();
252 }
253 "anthropic".to_string()
254}
255
256pub fn model_tier(model_id: &str) -> String {
258 let config = load_config();
259 for rule in &config.tier_rules {
260 if let Some(exact) = &rule.exact {
261 if model_id == exact {
262 return rule.tier.clone();
263 }
264 }
265 if let Some(pattern) = &rule.pattern {
266 if glob_match(pattern, model_id) {
267 return rule.tier.clone();
268 }
269 }
270 if let Some(substr) = &rule.contains {
271 if model_id.contains(substr.as_str()) {
272 return rule.tier.clone();
273 }
274 }
275 }
276 let lower = model_id.to_lowercase();
278 if lower.contains("9b") || lower.contains("a3b") {
279 return "small".to_string();
280 }
281 if lower.starts_with("claude-") || lower == "gpt-4o" {
282 return "frontier".to_string();
283 }
284 config.tier_defaults.default.clone()
285}
286
287pub fn provider_config(name: &str) -> Option<&'static ProviderDef> {
289 load_config().providers.get(name)
290}
291
292pub fn model_params(model_id: &str) -> BTreeMap<String, toml::Value> {
295 let config = load_config();
296 let mut params = BTreeMap::new();
297 for (pattern, defaults) in &config.model_defaults {
298 if glob_match(pattern, model_id) {
299 for (k, v) in defaults {
300 params.insert(k.clone(), v.clone());
301 }
302 }
303 }
304 params
305}
306
307pub fn provider_names() -> Vec<String> {
309 load_config().providers.keys().cloned().collect()
310}
311
312pub fn provider_has_feature(provider: &str, feature: &str) -> bool {
314 provider_config(provider)
315 .map(|p| p.features.iter().any(|f| f == feature))
316 .unwrap_or(false)
317}
318
319pub fn default_tool_format(model: &str, provider: &str) -> String {
322 let config = load_config();
323 for (name, alias) in &config.aliases {
325 let matches = (alias.id == model && alias.provider == provider) || name == model;
326 if matches {
327 if let Some(ref fmt) = alias.tool_format {
328 return fmt.clone();
329 }
330 }
331 }
332 if provider_has_feature(provider, "native_tools") {
334 "native".to_string()
335 } else {
336 "text".to_string()
337 }
338}
339
340pub fn resolve_tier_model(
342 target: &str,
343 preferred_provider: Option<&str>,
344) -> Option<(String, String)> {
345 let config = load_config();
346
347 if let Some(alias) = config.aliases.get(target) {
348 return Some((alias.id.clone(), alias.provider.clone()));
349 }
350
351 let candidate_aliases = if let Some(provider) = preferred_provider {
352 vec![
353 format!("{provider}/{target}"),
354 format!("{provider}:{target}"),
355 format!("tier/{target}"),
356 target.to_string(),
357 ]
358 } else {
359 vec![format!("tier/{target}"), target.to_string()]
360 };
361
362 for alias_name in candidate_aliases {
363 if let Some(alias) = config.aliases.get(&alias_name) {
364 return Some((alias.id.clone(), alias.provider.clone()));
365 }
366 }
367
368 None
369}
370
/// Minimal glob matcher: each `*` matches any run of characters
/// (including the empty run); everything else is literal. No `?` or
/// character classes — patterns come from config rules.
///
/// Fixes two defects of the previous version: prefix/suffix segments
/// could overlap (`"ab*ba"` wrongly matched `"aba"`), and patterns with
/// two or more `*` silently degraded to an exact comparison.
fn glob_match(pattern: &str, input: &str) -> bool {
    // Fast path: no wildcard means plain equality.
    if !pattern.contains('*') {
        return input == pattern;
    }
    let segments: Vec<&str> = pattern.split('*').collect();
    // The text before the first '*' must anchor at the start.
    let mut rest = match input.strip_prefix(segments[0]) {
        Some(r) => r,
        None => return false,
    };
    // Middle segments must appear in order; consuming `rest` as we go
    // guarantees segments never overlap.
    let last = segments.len() - 1;
    for seg in &segments[1..last] {
        match rest.find(seg) {
            Some(pos) => rest = &rest[pos + seg.len()..],
            None => return false,
        }
    }
    // The text after the last '*' must anchor at the end of what's left;
    // ends_with also rejects a suffix longer than the remaining input.
    rest.ends_with(segments[last])
}
392
/// Home directory from `$HOME`; `None` when unset or not valid UTF-8.
fn dirs_or_home() -> Option<String> {
    match std::env::var("HOME") {
        Ok(home) => Some(home),
        Err(_) => None,
    }
}
396
397pub fn resolve_base_url(pdef: &ProviderDef) -> String {
400 if let Some(env_name) = &pdef.base_url_env {
401 if let Ok(val) = std::env::var(env_name) {
402 let trimmed = val.trim().trim_matches('"').trim_matches('\'');
404 if !trimmed.is_empty() {
405 return trimmed.to_string();
406 }
407 }
408 }
409 pdef.base_url.clone()
410}
411
/// Built-in configuration used when no providers.toml could be loaded.
///
/// Defines seven providers (anthropic, openai, openrouter, huggingface,
/// ollama, together, local), the default inference and tier rules, and
/// the frontier/mid/small tier aliases (each also under "tier/...").
fn default_config() -> ProvidersConfig {
    let mut config = ProvidersConfig::default();

    // --- Providers -------------------------------------------------------

    // Anthropic: key in the x-api-key header plus a pinned API version;
    // healthcheck POSTs to the cheap count_tokens route.
    config.providers.insert(
        "anthropic".to_string(),
        ProviderDef {
            base_url: "https://api.anthropic.com/v1".to_string(),
            auth_style: "header".to_string(),
            auth_header: Some("x-api-key".to_string()),
            auth_env: AuthEnv::Single("ANTHROPIC_API_KEY".to_string()),
            extra_headers: BTreeMap::from([(
                "anthropic-version".to_string(),
                "2023-06-01".to_string(),
            )]),
            chat_endpoint: "/messages".to_string(),
            completion_endpoint: None,
            healthcheck: Some(HealthcheckDef {
                method: "POST".to_string(),
                path: Some("/messages/count_tokens".to_string()),
                url: None,
                body: Some(
                    r#"{"model":"claude-sonnet-4-20250514","messages":[{"role":"user","content":"x"}]}"#
                        .to_string(),
                ),
            }),
            features: vec!["prompt_caching".to_string(), "thinking".to_string()],
            ..Default::default()
        },
    );

    // OpenAI: standard bearer auth against the chat/completions API.
    config.providers.insert(
        "openai".to_string(),
        ProviderDef {
            base_url: "https://api.openai.com/v1".to_string(),
            auth_style: "bearer".to_string(),
            auth_env: AuthEnv::Single("OPENAI_API_KEY".to_string()),
            chat_endpoint: "/chat/completions".to_string(),
            completion_endpoint: Some("/completions".to_string()),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                path: Some("/models".to_string()),
                url: None,
                body: None,
            }),
            ..Default::default()
        },
    );

    // OpenRouter: OpenAI-compatible API; healthcheck hits the key-info route.
    config.providers.insert(
        "openrouter".to_string(),
        ProviderDef {
            base_url: "https://openrouter.ai/api/v1".to_string(),
            auth_style: "bearer".to_string(),
            auth_env: AuthEnv::Single("OPENROUTER_API_KEY".to_string()),
            chat_endpoint: "/chat/completions".to_string(),
            completion_endpoint: Some("/completions".to_string()),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                path: Some("/auth/key".to_string()),
                url: None,
                body: None,
            }),
            ..Default::default()
        },
    );

    // Hugging Face router: two candidate key env vars; healthcheck uses an
    // absolute URL (whoami) instead of a base_url-relative path.
    config.providers.insert(
        "huggingface".to_string(),
        ProviderDef {
            base_url: "https://router.huggingface.co/v1".to_string(),
            auth_style: "bearer".to_string(),
            auth_env: AuthEnv::Multiple(vec![
                "HF_TOKEN".to_string(),
                "HUGGINGFACE_API_KEY".to_string(),
            ]),
            chat_endpoint: "/chat/completions".to_string(),
            completion_endpoint: Some("/completions".to_string()),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                url: Some("https://huggingface.co/api/whoami-v2".to_string()),
                path: None,
                body: None,
            }),
            ..Default::default()
        },
    );

    // Ollama: local daemon, no auth; OLLAMA_HOST can relocate it.
    config.providers.insert(
        "ollama".to_string(),
        ProviderDef {
            base_url: "http://localhost:11434".to_string(),
            base_url_env: Some("OLLAMA_HOST".to_string()),
            auth_style: "none".to_string(),
            chat_endpoint: "/api/chat".to_string(),
            completion_endpoint: Some("/api/generate".to_string()),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                path: Some("/api/tags".to_string()),
                url: None,
                body: None,
            }),
            ..Default::default()
        },
    );

    // Together AI: OpenAI-compatible, base URL overridable via env.
    config.providers.insert(
        "together".to_string(),
        ProviderDef {
            base_url: "https://api.together.xyz/v1".to_string(),
            base_url_env: Some("TOGETHER_AI_BASE_URL".to_string()),
            auth_style: "bearer".to_string(),
            auth_env: AuthEnv::Single("TOGETHER_AI_API_KEY".to_string()),
            chat_endpoint: "/chat/completions".to_string(),
            completion_endpoint: Some("/completions".to_string()),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                path: Some("/models".to_string()),
                url: None,
                body: None,
            }),
            ..Default::default()
        },
    );

    // Generic local OpenAI-compatible server (vLLM-style /v1 routes).
    config.providers.insert(
        "local".to_string(),
        ProviderDef {
            base_url: "http://localhost:8000".to_string(),
            base_url_env: Some("LOCAL_LLM_BASE_URL".to_string()),
            auth_style: "none".to_string(),
            chat_endpoint: "/v1/chat/completions".to_string(),
            completion_endpoint: Some("/v1/completions".to_string()),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                path: Some("/v1/models".to_string()),
                url: None,
                body: None,
            }),
            ..Default::default()
        },
    );

    // --- Inference rules (mirror infer_provider's built-in heuristics) ---
    config.inference_rules = vec![
        InferenceRule {
            pattern: Some("claude-*".to_string()),
            contains: None,
            exact: None,
            provider: "anthropic".to_string(),
        },
        InferenceRule {
            pattern: Some("gpt-*".to_string()),
            contains: None,
            exact: None,
            provider: "openai".to_string(),
        },
        InferenceRule {
            pattern: Some("o1*".to_string()),
            contains: None,
            exact: None,
            provider: "openai".to_string(),
        },
        InferenceRule {
            pattern: Some("o3*".to_string()),
            contains: None,
            exact: None,
            provider: "openai".to_string(),
        },
        InferenceRule {
            pattern: None,
            contains: Some("/".to_string()),
            exact: None,
            provider: "openrouter".to_string(),
        },
        InferenceRule {
            pattern: None,
            contains: Some(":".to_string()),
            exact: None,
            provider: "ollama".to_string(),
        },
    ];

    // --- Tier rules (mirror model_tier's built-in heuristics) -------------
    config.tier_rules = vec![
        TierRule {
            contains: Some("9b".to_string()),
            pattern: None,
            exact: None,
            tier: "small".to_string(),
        },
        TierRule {
            contains: Some("a3b".to_string()),
            pattern: None,
            exact: None,
            tier: "small".to_string(),
        },
        TierRule {
            pattern: Some("claude-*".to_string()),
            contains: None,
            exact: None,
            tier: "frontier".to_string(),
        },
        TierRule {
            exact: Some("gpt-4o".to_string()),
            contains: None,
            pattern: None,
            tier: "frontier".to_string(),
        },
    ];

    config.tier_defaults = TierDefaults {
        default: "mid".to_string(),
    };

    // --- Tier aliases: each tier under its bare name and "tier/..." ------
    config.aliases.insert(
        "frontier".to_string(),
        AliasDef {
            id: "claude-sonnet-4-20250514".to_string(),
            provider: "anthropic".to_string(),
            tool_format: None,
        },
    );
    config.aliases.insert(
        "tier/frontier".to_string(),
        AliasDef {
            id: "claude-sonnet-4-20250514".to_string(),
            provider: "anthropic".to_string(),
            tool_format: None,
        },
    );
    config.aliases.insert(
        "mid".to_string(),
        AliasDef {
            id: "gpt-4o-mini".to_string(),
            provider: "openai".to_string(),
            tool_format: None,
        },
    );
    config.aliases.insert(
        "tier/mid".to_string(),
        AliasDef {
            id: "gpt-4o-mini".to_string(),
            provider: "openai".to_string(),
            tool_format: None,
        },
    );
    config.aliases.insert(
        "small".to_string(),
        AliasDef {
            id: "Qwen/Qwen3.5-9B".to_string(),
            provider: "openrouter".to_string(),
            tool_format: None,
        },
    );
    config.aliases.insert(
        "tier/small".to_string(),
        AliasDef {
            id: "Qwen/Qwen3.5-9B".to_string(),
            provider: "openrouter".to_string(),
            tool_format: None,
        },
    );

    config
}
688
#[cfg(test)]
mod tests {
    // NOTE(review): these tests exercise the built-in default_config()
    // values; they assume HARN_PROVIDERS_CONFIG is unset (and no
    // ~/.config/harn/providers.toml overrides) in the test environment.
    use super::*;

    #[test]
    fn test_glob_match_prefix() {
        assert!(glob_match("claude-*", "claude-sonnet-4-20250514"));
        assert!(glob_match("gpt-*", "gpt-4o"));
        assert!(!glob_match("claude-*", "gpt-4o"));
    }

    #[test]
    fn test_glob_match_suffix() {
        assert!(glob_match("*-latest", "llama3.2-latest"));
        assert!(!glob_match("*-latest", "llama3.2"));
    }

    #[test]
    fn test_glob_match_middle() {
        assert!(glob_match("claude-*-latest", "claude-sonnet-latest"));
        assert!(!glob_match("claude-*-latest", "claude-sonnet-beta"));
    }

    #[test]
    fn test_glob_match_exact() {
        assert!(glob_match("gpt-4o", "gpt-4o"));
        assert!(!glob_match("gpt-4o", "gpt-4o-mini"));
    }

    #[test]
    fn test_infer_provider_from_defaults() {
        assert_eq!(infer_provider("claude-sonnet-4-20250514"), "anthropic");
        assert_eq!(infer_provider("gpt-4o"), "openai");
        assert_eq!(infer_provider("o1-preview"), "openai");
        assert_eq!(infer_provider("o3-mini"), "openai");
        assert_eq!(infer_provider("qwen/qwen3-coder"), "openrouter");
        assert_eq!(infer_provider("llama3.2:latest"), "ollama");
        // No rule or heuristic matches -> final "anthropic" fallback.
        assert_eq!(infer_provider("unknown-model"), "anthropic");
    }

    #[test]
    fn test_model_tier_from_defaults() {
        assert_eq!(model_tier("claude-sonnet-4-20250514"), "frontier");
        assert_eq!(model_tier("gpt-4o"), "frontier");
        assert_eq!(model_tier("Qwen3.5-9B"), "small");
        // No rule matches -> configured default tier.
        assert_eq!(model_tier("deepseek-v3"), "mid");
    }

    #[test]
    fn test_resolve_model_unknown_alias() {
        // A non-alias model id passes through unchanged with no provider.
        let (id, provider) = resolve_model("gpt-4o");
        assert_eq!(id, "gpt-4o");
        assert!(provider.is_none());
    }

    #[test]
    fn test_provider_names() {
        let names = provider_names();
        assert!(names.len() >= 7);
        assert!(names.contains(&"anthropic".to_string()));
        assert!(names.contains(&"together".to_string()));
        assert!(names.contains(&"local".to_string()));
        assert!(names.contains(&"openai".to_string()));
        assert!(names.contains(&"ollama".to_string()));
    }

    #[test]
    fn test_resolve_tier_model_default_aliases() {
        let (model, provider) = resolve_tier_model("frontier", None).unwrap();
        assert_eq!(model, "claude-sonnet-4-20250514");
        assert_eq!(provider, "anthropic");

        let (model, provider) = resolve_tier_model("small", None).unwrap();
        assert_eq!(model, "Qwen/Qwen3.5-9B");
        assert_eq!(provider, "openrouter");
    }

    #[test]
    fn test_resolve_tier_model_prefers_provider_scoped_aliases() {
        // Default config has no "openai/mid" alias, so this falls through
        // to the bare "mid" alias.
        let (model, provider) = resolve_tier_model("mid", Some("openai")).unwrap();
        assert_eq!(model, "gpt-4o-mini");
        assert_eq!(provider, "openai");
    }

    #[test]
    fn test_provider_config_anthropic() {
        let pdef = provider_config("anthropic").unwrap();
        assert_eq!(pdef.auth_style, "header");
        assert_eq!(pdef.auth_header.as_deref(), Some("x-api-key"));
    }

    #[test]
    fn test_resolve_base_url_no_env() {
        // No base_url_env configured -> the literal base_url is returned.
        let pdef = ProviderDef {
            base_url: "https://example.com".to_string(),
            ..Default::default()
        };
        assert_eq!(resolve_base_url(&pdef), "https://example.com");
    }

    #[test]
    fn test_default_config_roundtrip() {
        let config = default_config();
        assert!(!config.providers.is_empty());
        assert!(!config.inference_rules.is_empty());
        assert!(!config.tier_rules.is_empty());
        assert_eq!(config.tier_defaults.default, "mid");
    }

    #[test]
    fn test_model_params_empty() {
        // Default config defines no model_defaults sections.
        let params = model_params("claude-sonnet-4-20250514");
        assert!(params.is_empty());
    }
}