1use serde::Deserialize;
2use std::collections::BTreeMap;
3use std::sync::OnceLock;
4
// Process-wide configuration cache; populated exactly once on first access.
static CONFIG: OnceLock<ProvidersConfig> = OnceLock::new();
// Path of the config file actually loaded, if any (unset when built-in defaults are used).
static CONFIG_PATH: OnceLock<String> = OnceLock::new();
7
/// Top-level schema of `providers.toml`.
///
/// Every section carries `#[serde(default)]`, so a partial (or empty) config
/// file still deserializes; missing sections become empty collections.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct ProvidersConfig {
    // Provider name -> connection/auth definition.
    #[serde(default)]
    pub providers: BTreeMap<String, ProviderDef>,
    // Alias name -> concrete (model id, provider) mapping.
    #[serde(default)]
    pub aliases: BTreeMap<String, AliasDef>,
    // Ordered model-id -> provider rules; first match wins (see `infer_provider`).
    #[serde(default)]
    pub inference_rules: Vec<InferenceRule>,
    // Ordered model-id -> tier rules; first match wins (see `model_tier`).
    #[serde(default)]
    pub tier_rules: Vec<TierRule>,
    // Fallback tier used when no tier rule or heuristic matches.
    #[serde(default)]
    pub tier_defaults: TierDefaults,
    // Model-id glob pattern -> default request parameters (see `model_params`).
    #[serde(default)]
    pub model_defaults: BTreeMap<String, BTreeMap<String, toml::Value>>,
}
27
/// Connection, authentication, and capability settings for one provider.
#[derive(Debug, Clone, Deserialize)]
pub struct ProviderDef {
    // Base API URL; may be overridden at runtime via `base_url_env`.
    pub base_url: String,
    // Env var that, when set and non-empty, overrides `base_url` (see `resolve_base_url`).
    #[serde(default)]
    pub base_url_env: Option<String>,
    // Authentication scheme name; defaults to "bearer".
    #[serde(default = "default_bearer")]
    pub auth_style: String,
    // Header name for header-style auth (e.g. "x-api-key" for Anthropic).
    #[serde(default)]
    pub auth_header: Option<String>,
    // Env var(s) holding the API key; presumably tried in order by the HTTP
    // client — TODO confirm against the caller (not visible in this file).
    #[serde(default)]
    pub auth_env: AuthEnv,
    // Extra headers sent with requests (e.g. "anthropic-version").
    #[serde(default)]
    pub extra_headers: BTreeMap<String, String>,
    // Chat endpoint path, relative to the base URL.
    #[serde(default)]
    pub chat_endpoint: String,
    // Legacy text-completion endpoint path, if the provider has one.
    #[serde(default)]
    pub completion_endpoint: Option<String>,
    // Optional probe request used to check provider availability.
    #[serde(default)]
    pub healthcheck: Option<HealthcheckDef>,
    // Capability flags, e.g. "prompt_caching", "thinking", "native_tools"
    // (the latter is consulted by `default_tool_format`).
    #[serde(default)]
    pub features: Vec<String>,
    // NOTE(review): the four fields below have no consumer in this file —
    // semantics inferred from names only; confirm against the HTTP layer.
    #[serde(default)]
    pub fallback: Option<String>,
    #[serde(default)]
    pub retry_count: Option<u32>,
    #[serde(default)]
    pub retry_delay_ms: Option<u64>,
    #[serde(default)]
    pub rpm: Option<u32>,
}
62
// Manual `Default` impl: cannot be derived because `auth_style` must start as
// "bearer" (matching its serde default), not `String::default()` ("").
impl Default for ProviderDef {
    fn default() -> Self {
        Self {
            base_url: String::new(),
            base_url_env: None,
            auth_style: default_bearer(),
            auth_header: None,
            auth_env: AuthEnv::None,
            extra_headers: BTreeMap::new(),
            chat_endpoint: String::new(),
            completion_endpoint: None,
            healthcheck: None,
            features: Vec::new(),
            fallback: None,
            retry_count: None,
            retry_delay_ms: None,
            rpm: None,
        }
    }
}
83
/// Serde default for `ProviderDef::auth_style`.
fn default_bearer() -> String {
    String::from("bearer")
}
87
/// API-key environment-variable specification.
///
/// `#[serde(untagged)]` lets TOML supply either a single string or a list of
/// strings; an absent field deserializes to `None` via the `Default` derive.
#[derive(Debug, Clone, Deserialize, Default)]
#[serde(untagged)]
pub enum AuthEnv {
    // No API key env var (e.g. local providers with auth_style "none").
    #[default]
    None,
    // One env var name, e.g. "OPENAI_API_KEY".
    Single(String),
    // Several candidate env vars, e.g. HF_TOKEN / HUGGINGFACE_API_KEY.
    Multiple(Vec<String>),
}
98
/// A probe request used to verify that a provider is reachable.
#[derive(Debug, Clone, Deserialize)]
pub struct HealthcheckDef {
    // HTTP method, e.g. "GET" or "POST".
    pub method: String,
    // Path relative to the provider's base URL. The built-in defaults set
    // exactly one of `path`/`url` per provider.
    #[serde(default)]
    pub path: Option<String>,
    // Absolute URL overriding the provider base (e.g. HuggingFace whoami).
    #[serde(default)]
    pub url: Option<String>,
    // Optional request body, used by POST probes.
    #[serde(default)]
    pub body: Option<String>,
}
109
/// Maps a short alias name to a concrete model on a specific provider.
#[derive(Debug, Clone, Deserialize)]
pub struct AliasDef {
    // Concrete model identifier, e.g. "claude-sonnet-4-20250514".
    pub id: String,
    // Provider key into `ProvidersConfig::providers`.
    pub provider: String,
    // Optional tool-call format override; consulted by `default_tool_format`.
    #[serde(default)]
    pub tool_format: Option<String>,
}
121
/// One model-id -> provider matching rule.
///
/// Typically exactly one matcher is set, but all present matchers are tried
/// (exact, then glob pattern, then substring — see `infer_provider`).
#[derive(Debug, Clone, Deserialize)]
pub struct InferenceRule {
    // Glob pattern matched via `glob_match`, e.g. "claude-*".
    #[serde(default)]
    pub pattern: Option<String>,
    // Substring match, e.g. "/" for OpenRouter-style ids.
    #[serde(default)]
    pub contains: Option<String>,
    // Exact model-id match.
    #[serde(default)]
    pub exact: Option<String>,
    // Provider to select when this rule matches.
    pub provider: String,
}
132
/// One model-id -> tier matching rule (same matcher semantics as
/// `InferenceRule`; see `model_tier`).
#[derive(Debug, Clone, Deserialize)]
pub struct TierRule {
    // Glob pattern matched via `glob_match`.
    #[serde(default)]
    pub pattern: Option<String>,
    // Substring match.
    #[serde(default)]
    pub contains: Option<String>,
    // Exact model-id match.
    #[serde(default)]
    pub exact: Option<String>,
    // Tier to assign when this rule matches, e.g. "small" / "mid" / "frontier".
    pub tier: String,
}
143
/// Fallback tier configuration.
#[derive(Debug, Clone, Deserialize)]
pub struct TierDefaults {
    // Tier returned when no rule or heuristic matches; defaults to "mid".
    #[serde(default = "default_mid")]
    pub default: String,
}
149
150impl Default for TierDefaults {
151 fn default() -> Self {
152 Self {
153 default: default_mid(),
154 }
155 }
156}
157
/// Serde default for `TierDefaults::default`.
fn default_mid() -> String {
    String::from("mid")
}
161
162pub fn load_config() -> &'static ProvidersConfig {
168 CONFIG.get_or_init(|| {
169 let verbose_config_logging = matches!(
170 std::env::var("HARN_VERBOSE_CONFIG").ok().as_deref(),
171 Some("1" | "true" | "TRUE" | "yes" | "YES")
172 ) || matches!(
173 std::env::var("HARN_ACP_VERBOSE").ok().as_deref(),
174 Some("1" | "true" | "TRUE" | "yes" | "YES")
175 );
176 if let Ok(path) = std::env::var("HARN_PROVIDERS_CONFIG") {
178 match std::fs::read_to_string(&path) {
179 Ok(content) => match toml::from_str::<ProvidersConfig>(&content) {
180 Ok(config) => {
181 if verbose_config_logging {
182 eprintln!(
183 "[llm_config] Loaded {} providers, {} aliases from {}",
184 config.providers.len(),
185 config.aliases.len(),
186 path
187 );
188 }
189 let _ = CONFIG_PATH.set(path);
190 return config;
191 }
192 Err(e) => eprintln!("[llm_config] TOML parse error in {}: {}", path, e),
193 },
194 Err(e) => eprintln!("[llm_config] Cannot read {}: {}", path, e),
195 }
196 }
197 if let Some(home) = dirs_or_home() {
199 let path = format!("{home}/.config/harn/providers.toml");
200 if let Ok(content) = std::fs::read_to_string(&path) {
201 if let Ok(config) = toml::from_str::<ProvidersConfig>(&content) {
202 let _ = CONFIG_PATH.set(path);
203 return config;
204 }
205 }
206 }
207 default_config()
209 })
210}
211
212pub fn loaded_config_path() -> Option<std::path::PathBuf> {
215 let _ = load_config();
217 CONFIG_PATH.get().map(std::path::PathBuf::from)
218}
219
220pub fn resolve_model(alias: &str) -> (String, Option<String>) {
222 let config = load_config();
223 if let Some(a) = config.aliases.get(alias) {
224 return (a.id.clone(), Some(a.provider.clone()));
225 }
226 (alias.to_string(), None)
227}
228
229pub fn infer_provider(model_id: &str) -> String {
231 let config = load_config();
232 for rule in &config.inference_rules {
233 if let Some(exact) = &rule.exact {
234 if model_id == exact {
235 return rule.provider.clone();
236 }
237 }
238 if let Some(pattern) = &rule.pattern {
239 if glob_match(pattern, model_id) {
240 return rule.provider.clone();
241 }
242 }
243 if let Some(substr) = &rule.contains {
244 if model_id.contains(substr.as_str()) {
245 return rule.provider.clone();
246 }
247 }
248 }
249 if model_id.starts_with("claude-") {
251 return "anthropic".to_string();
252 }
253 if model_id.starts_with("gpt-") || model_id.starts_with("o1") || model_id.starts_with("o3") {
254 return "openai".to_string();
255 }
256 if model_id.contains('/') {
257 return "openrouter".to_string();
258 }
259 if model_id.contains(':') {
260 return "ollama".to_string();
261 }
262 "anthropic".to_string()
263}
264
265pub fn model_tier(model_id: &str) -> String {
267 let config = load_config();
268 for rule in &config.tier_rules {
269 if let Some(exact) = &rule.exact {
270 if model_id == exact {
271 return rule.tier.clone();
272 }
273 }
274 if let Some(pattern) = &rule.pattern {
275 if glob_match(pattern, model_id) {
276 return rule.tier.clone();
277 }
278 }
279 if let Some(substr) = &rule.contains {
280 if model_id.contains(substr.as_str()) {
281 return rule.tier.clone();
282 }
283 }
284 }
285 let lower = model_id.to_lowercase();
287 if lower.contains("9b") || lower.contains("a3b") {
288 return "small".to_string();
289 }
290 if lower.starts_with("claude-") || lower == "gpt-4o" {
291 return "frontier".to_string();
292 }
293 config.tier_defaults.default.clone()
294}
295
296pub fn provider_config(name: &str) -> Option<&'static ProviderDef> {
298 load_config().providers.get(name)
299}
300
301pub fn model_params(model_id: &str) -> BTreeMap<String, toml::Value> {
304 let config = load_config();
305 let mut params = BTreeMap::new();
306 for (pattern, defaults) in &config.model_defaults {
307 if glob_match(pattern, model_id) {
308 for (k, v) in defaults {
309 params.insert(k.clone(), v.clone());
310 }
311 }
312 }
313 params
314}
315
316pub fn provider_names() -> Vec<String> {
318 load_config().providers.keys().cloned().collect()
319}
320
321pub fn provider_has_feature(provider: &str, feature: &str) -> bool {
323 provider_config(provider)
324 .map(|p| p.features.iter().any(|f| f == feature))
325 .unwrap_or(false)
326}
327
328pub fn default_tool_format(model: &str, provider: &str) -> String {
331 let config = load_config();
332 for (name, alias) in &config.aliases {
334 let matches = (alias.id == model && alias.provider == provider) || name == model;
335 if matches {
336 if let Some(ref fmt) = alias.tool_format {
337 return fmt.clone();
338 }
339 }
340 }
341 if provider_has_feature(provider, "native_tools") {
343 "native".to_string()
344 } else {
345 "text".to_string()
346 }
347}
348
349pub fn resolve_tier_model(
351 target: &str,
352 preferred_provider: Option<&str>,
353) -> Option<(String, String)> {
354 let config = load_config();
355
356 if let Some(alias) = config.aliases.get(target) {
357 return Some((alias.id.clone(), alias.provider.clone()));
358 }
359
360 let candidate_aliases = if let Some(provider) = preferred_provider {
361 vec![
362 format!("{provider}/{target}"),
363 format!("{provider}:{target}"),
364 format!("tier/{target}"),
365 target.to_string(),
366 ]
367 } else {
368 vec![format!("tier/{target}"), target.to_string()]
369 };
370
371 for alias_name in candidate_aliases {
372 if let Some(alias) = config.aliases.get(&alias_name) {
373 return Some((alias.id.clone(), alias.provider.clone()));
374 }
375 }
376
377 None
378}
379
/// Minimal glob matcher supporting a single `*` wildcard.
///
/// Supported forms: `prefix*`, `*suffix`, `prefix*suffix`, and literal
/// equality. Patterns with more than one `*` fall back to literal comparison.
///
/// Bug fix: the `prefix*suffix` branch previously matched inputs shorter than
/// prefix + suffix because the two could overlap (e.g. pattern "ab*bc"
/// matched "abc"). A length guard now requires the input to contain both
/// parts disjointly (`*` may still match the empty string).
fn glob_match(pattern: &str, input: &str) -> bool {
    if let Some(prefix) = pattern.strip_suffix('*') {
        input.starts_with(prefix)
    } else if let Some(suffix) = pattern.strip_prefix('*') {
        input.ends_with(suffix)
    } else if pattern.contains('*') {
        let parts: Vec<&str> = pattern.split('*').collect();
        if parts.len() == 2 {
            // Length guard prevents the prefix and suffix from overlapping.
            input.len() >= parts[0].len() + parts[1].len()
                && input.starts_with(parts[0])
                && input.ends_with(parts[1])
        } else {
            input == pattern
        }
    } else {
        input == pattern
    }
}
401
// Minimal stand-in for the `dirs` crate: resolve $HOME only.
fn dirs_or_home() -> Option<String> {
    match std::env::var("HOME") {
        Ok(home) => Some(home),
        Err(_) => None,
    }
}
405
406pub fn resolve_base_url(pdef: &ProviderDef) -> String {
409 if let Some(env_name) = &pdef.base_url_env {
410 if let Ok(val) = std::env::var(env_name) {
411 let trimmed = val.trim().trim_matches('"').trim_matches('\'');
413 if !trimmed.is_empty() {
414 return trimmed.to_string();
415 }
416 }
417 }
418 pdef.base_url.clone()
419}
420
/// Built-in configuration used when no providers.toml could be loaded.
///
/// Registers seven providers, the default inference and tier rules, and the
/// tier aliases ("frontier"/"mid"/"small", each also under a "tier/" prefix).
fn default_config() -> ProvidersConfig {
    let mut config = ProvidersConfig::default();

    // Anthropic: header-style auth (x-api-key) plus a pinned API version header.
    config.providers.insert(
        "anthropic".to_string(),
        ProviderDef {
            base_url: "https://api.anthropic.com/v1".to_string(),
            auth_style: "header".to_string(),
            auth_header: Some("x-api-key".to_string()),
            auth_env: AuthEnv::Single("ANTHROPIC_API_KEY".to_string()),
            extra_headers: BTreeMap::from([(
                "anthropic-version".to_string(),
                "2023-06-01".to_string(),
            )]),
            chat_endpoint: "/messages".to_string(),
            completion_endpoint: None,
            // Token counting is a cheap authenticated POST, so it doubles as
            // a health probe.
            healthcheck: Some(HealthcheckDef {
                method: "POST".to_string(),
                path: Some("/messages/count_tokens".to_string()),
                url: None,
                body: Some(
                    r#"{"model":"claude-sonnet-4-20250514","messages":[{"role":"user","content":"x"}]}"#
                        .to_string(),
                ),
            }),
            features: vec!["prompt_caching".to_string(), "thinking".to_string()],
            ..Default::default()
        },
    );

    // OpenAI: standard bearer auth against the chat/completions API.
    config.providers.insert(
        "openai".to_string(),
        ProviderDef {
            base_url: "https://api.openai.com/v1".to_string(),
            auth_style: "bearer".to_string(),
            auth_env: AuthEnv::Single("OPENAI_API_KEY".to_string()),
            chat_endpoint: "/chat/completions".to_string(),
            completion_endpoint: Some("/completions".to_string()),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                path: Some("/models".to_string()),
                url: None,
                body: None,
            }),
            ..Default::default()
        },
    );

    // OpenRouter: OpenAI-compatible API; /auth/key validates the key.
    config.providers.insert(
        "openrouter".to_string(),
        ProviderDef {
            base_url: "https://openrouter.ai/api/v1".to_string(),
            auth_style: "bearer".to_string(),
            auth_env: AuthEnv::Single("OPENROUTER_API_KEY".to_string()),
            chat_endpoint: "/chat/completions".to_string(),
            completion_endpoint: Some("/completions".to_string()),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                path: Some("/auth/key".to_string()),
                url: None,
                body: None,
            }),
            ..Default::default()
        },
    );

    // HuggingFace router: accepts either of two token env vars; the health
    // probe uses an absolute URL outside the router base.
    config.providers.insert(
        "huggingface".to_string(),
        ProviderDef {
            base_url: "https://router.huggingface.co/v1".to_string(),
            auth_style: "bearer".to_string(),
            auth_env: AuthEnv::Multiple(vec![
                "HF_TOKEN".to_string(),
                "HUGGINGFACE_API_KEY".to_string(),
            ]),
            chat_endpoint: "/chat/completions".to_string(),
            completion_endpoint: Some("/completions".to_string()),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                url: Some("https://huggingface.co/api/whoami-v2".to_string()),
                path: None,
                body: None,
            }),
            ..Default::default()
        },
    );

    // Ollama: local daemon, no auth; OLLAMA_HOST can point elsewhere.
    config.providers.insert(
        "ollama".to_string(),
        ProviderDef {
            base_url: "http://localhost:11434".to_string(),
            base_url_env: Some("OLLAMA_HOST".to_string()),
            auth_style: "none".to_string(),
            chat_endpoint: "/api/chat".to_string(),
            completion_endpoint: Some("/api/generate".to_string()),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                path: Some("/api/tags".to_string()),
                url: None,
                body: None,
            }),
            ..Default::default()
        },
    );

    // Together AI: OpenAI-compatible hosted API.
    config.providers.insert(
        "together".to_string(),
        ProviderDef {
            base_url: "https://api.together.xyz/v1".to_string(),
            base_url_env: Some("TOGETHER_AI_BASE_URL".to_string()),
            auth_style: "bearer".to_string(),
            auth_env: AuthEnv::Single("TOGETHER_AI_API_KEY".to_string()),
            chat_endpoint: "/chat/completions".to_string(),
            completion_endpoint: Some("/completions".to_string()),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                path: Some("/models".to_string()),
                url: None,
                body: None,
            }),
            ..Default::default()
        },
    );

    // Generic local OpenAI-compatible server (vLLM, llama.cpp, etc.).
    config.providers.insert(
        "local".to_string(),
        ProviderDef {
            base_url: "http://localhost:8000".to_string(),
            base_url_env: Some("LOCAL_LLM_BASE_URL".to_string()),
            auth_style: "none".to_string(),
            chat_endpoint: "/v1/chat/completions".to_string(),
            completion_endpoint: Some("/v1/completions".to_string()),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                path: Some("/v1/models".to_string()),
                url: None,
                body: None,
            }),
            ..Default::default()
        },
    );

    // Rule order matters: `infer_provider` returns on the first hit, so these
    // mirror the hardcoded fallback heuristics.
    config.inference_rules = vec![
        InferenceRule {
            pattern: Some("claude-*".to_string()),
            contains: None,
            exact: None,
            provider: "anthropic".to_string(),
        },
        InferenceRule {
            pattern: Some("gpt-*".to_string()),
            contains: None,
            exact: None,
            provider: "openai".to_string(),
        },
        InferenceRule {
            pattern: Some("o1*".to_string()),
            contains: None,
            exact: None,
            provider: "openai".to_string(),
        },
        InferenceRule {
            pattern: Some("o3*".to_string()),
            contains: None,
            exact: None,
            provider: "openai".to_string(),
        },
        InferenceRule {
            pattern: None,
            contains: Some("/".to_string()),
            exact: None,
            provider: "openrouter".to_string(),
        },
        InferenceRule {
            pattern: None,
            contains: Some(":".to_string()),
            exact: None,
            provider: "ollama".to_string(),
        },
    ];

    // Tier rules, also first-match-wins (see `model_tier`).
    config.tier_rules = vec![
        TierRule {
            contains: Some("9b".to_string()),
            pattern: None,
            exact: None,
            tier: "small".to_string(),
        },
        TierRule {
            contains: Some("a3b".to_string()),
            pattern: None,
            exact: None,
            tier: "small".to_string(),
        },
        TierRule {
            pattern: Some("claude-*".to_string()),
            contains: None,
            exact: None,
            tier: "frontier".to_string(),
        },
        TierRule {
            exact: Some("gpt-4o".to_string()),
            contains: None,
            pattern: None,
            tier: "frontier".to_string(),
        },
    ];

    config.tier_defaults = TierDefaults {
        default: "mid".to_string(),
    };

    // Tier aliases: each registered twice, bare and under the "tier/" prefix,
    // so both `resolve_tier_model("mid", _)` lookup forms succeed.
    config.aliases.insert(
        "frontier".to_string(),
        AliasDef {
            id: "claude-sonnet-4-20250514".to_string(),
            provider: "anthropic".to_string(),
            tool_format: None,
        },
    );
    config.aliases.insert(
        "tier/frontier".to_string(),
        AliasDef {
            id: "claude-sonnet-4-20250514".to_string(),
            provider: "anthropic".to_string(),
            tool_format: None,
        },
    );
    config.aliases.insert(
        "mid".to_string(),
        AliasDef {
            id: "gpt-4o-mini".to_string(),
            provider: "openai".to_string(),
            tool_format: None,
        },
    );
    config.aliases.insert(
        "tier/mid".to_string(),
        AliasDef {
            id: "gpt-4o-mini".to_string(),
            provider: "openai".to_string(),
            tool_format: None,
        },
    );
    config.aliases.insert(
        "small".to_string(),
        AliasDef {
            id: "Qwen/Qwen3.5-9B".to_string(),
            provider: "openrouter".to_string(),
            tool_format: None,
        },
    );
    config.aliases.insert(
        "tier/small".to_string(),
        AliasDef {
            id: "Qwen/Qwen3.5-9B".to_string(),
            provider: "openrouter".to_string(),
            tool_format: None,
        },
    );

    config
}
697
#[cfg(test)]
mod tests {
    // NOTE(review): most tests below go through `load_config`'s process-global
    // cache, so they implicitly assume no HARN_PROVIDERS_CONFIG override and
    // no ~/.config/harn/providers.toml in the test environment — confirm CI
    // runs with a clean environment.
    use super::*;

    #[test]
    fn test_glob_match_prefix() {
        assert!(glob_match("claude-*", "claude-sonnet-4-20250514"));
        assert!(glob_match("gpt-*", "gpt-4o"));
        assert!(!glob_match("claude-*", "gpt-4o"));
    }

    #[test]
    fn test_glob_match_suffix() {
        assert!(glob_match("*-latest", "llama3.2-latest"));
        assert!(!glob_match("*-latest", "llama3.2"));
    }

    #[test]
    fn test_glob_match_middle() {
        assert!(glob_match("claude-*-latest", "claude-sonnet-latest"));
        assert!(!glob_match("claude-*-latest", "claude-sonnet-beta"));
    }

    #[test]
    fn test_glob_match_exact() {
        assert!(glob_match("gpt-4o", "gpt-4o"));
        assert!(!glob_match("gpt-4o", "gpt-4o-mini"));
    }

    // Exercises both the default inference rules and the hardcoded fallbacks.
    #[test]
    fn test_infer_provider_from_defaults() {
        assert_eq!(infer_provider("claude-sonnet-4-20250514"), "anthropic");
        assert_eq!(infer_provider("gpt-4o"), "openai");
        assert_eq!(infer_provider("o1-preview"), "openai");
        assert_eq!(infer_provider("o3-mini"), "openai");
        assert_eq!(infer_provider("qwen/qwen3-coder"), "openrouter");
        assert_eq!(infer_provider("llama3.2:latest"), "ollama");
        assert_eq!(infer_provider("unknown-model"), "anthropic");
    }

    #[test]
    fn test_model_tier_from_defaults() {
        assert_eq!(model_tier("claude-sonnet-4-20250514"), "frontier");
        assert_eq!(model_tier("gpt-4o"), "frontier");
        assert_eq!(model_tier("Qwen3.5-9B"), "small");
        // No rule matches, so the configured default tier applies.
        assert_eq!(model_tier("deepseek-v3"), "mid");
    }

    #[test]
    fn test_resolve_model_unknown_alias() {
        let (id, provider) = resolve_model("gpt-4o");
        assert_eq!(id, "gpt-4o");
        assert!(provider.is_none());
    }

    #[test]
    fn test_provider_names() {
        let names = provider_names();
        // default_config registers exactly 7 providers.
        assert!(names.len() >= 7);
        assert!(names.contains(&"anthropic".to_string()));
        assert!(names.contains(&"together".to_string()));
        assert!(names.contains(&"local".to_string()));
        assert!(names.contains(&"openai".to_string()));
        assert!(names.contains(&"ollama".to_string()));
    }

    #[test]
    fn test_resolve_tier_model_default_aliases() {
        let (model, provider) = resolve_tier_model("frontier", None).unwrap();
        assert_eq!(model, "claude-sonnet-4-20250514");
        assert_eq!(provider, "anthropic");

        let (model, provider) = resolve_tier_model("small", None).unwrap();
        assert_eq!(model, "Qwen/Qwen3.5-9B");
        assert_eq!(provider, "openrouter");
    }

    #[test]
    fn test_resolve_tier_model_prefers_provider_scoped_aliases() {
        let (model, provider) = resolve_tier_model("mid", Some("openai")).unwrap();
        assert_eq!(model, "gpt-4o-mini");
        assert_eq!(provider, "openai");
    }

    #[test]
    fn test_provider_config_anthropic() {
        let pdef = provider_config("anthropic").unwrap();
        assert_eq!(pdef.auth_style, "header");
        assert_eq!(pdef.auth_header.as_deref(), Some("x-api-key"));
    }

    #[test]
    fn test_resolve_base_url_no_env() {
        // base_url_env is None, so the static base_url is returned verbatim.
        let pdef = ProviderDef {
            base_url: "https://example.com".to_string(),
            ..Default::default()
        };
        assert_eq!(resolve_base_url(&pdef), "https://example.com");
    }

    #[test]
    fn test_default_config_roundtrip() {
        let config = default_config();
        assert!(!config.providers.is_empty());
        assert!(!config.inference_rules.is_empty());
        assert!(!config.tier_rules.is_empty());
        assert_eq!(config.tier_defaults.default, "mid");
    }

    #[test]
    fn test_model_params_empty() {
        // Defaults define no model_defaults section, so nothing matches.
        let params = model_params("claude-sonnet-4-20250514");
        assert!(params.is_empty());
    }
}