1use serde::Deserialize;
2use std::collections::BTreeMap;
3use std::sync::OnceLock;
4
/// Process-wide cache for the configuration resolved by `load_config`.
/// A `OnceLock` is initialized at most once, so the configuration is
/// loaded on first access and never reloaded for the life of the process.
static CONFIG: OnceLock<ProvidersConfig> = OnceLock::new();
6
7#[derive(Debug, Clone, Deserialize, Default)]
12pub struct ProvidersConfig {
13 #[serde(default)]
14 pub providers: BTreeMap<String, ProviderDef>,
15 #[serde(default)]
16 pub aliases: BTreeMap<String, AliasDef>,
17 #[serde(default)]
18 pub inference_rules: Vec<InferenceRule>,
19 #[serde(default)]
20 pub tier_rules: Vec<TierRule>,
21 #[serde(default)]
22 pub tier_defaults: TierDefaults,
23 #[serde(default)]
24 pub model_defaults: BTreeMap<String, BTreeMap<String, toml::Value>>,
25}
26
/// Connection and authentication settings for one LLM provider.
///
/// Only `base_url` is required when deserializing. Every other field falls
/// back to its serde default; `auth_style` defaults to `"bearer"`.
#[derive(Debug, Clone, Deserialize)]
pub struct ProviderDef {
    /// Base URL for API requests. Used unless `base_url_env` names a set,
    /// non-empty environment variable (see `resolve_base_url`).
    pub base_url: String,
    /// Name of an environment variable whose non-empty value replaces
    /// `base_url`. The built-in `ollama` entry uses `OLLAMA_HOST`.
    #[serde(default)]
    pub base_url_env: Option<String>,
    /// Authentication scheme. The built-in providers use `"bearer"`,
    /// `"header"` (paired with `auth_header`) and `"none"`. How each value
    /// is applied is decided by the request code, not this file.
    #[serde(default = "default_bearer")]
    pub auth_style: String,
    /// Header name carrying the credential (e.g. `x-api-key` for anthropic).
    /// Presumably only consulted when `auth_style` is `"header"` — confirm.
    #[serde(default)]
    pub auth_header: Option<String>,
    /// Environment variable(s) to read the credential from.
    #[serde(default)]
    pub auth_env: AuthEnv,
    /// Additional headers, presumably added to each request
    /// (e.g. `anthropic-version`).
    #[serde(default)]
    pub extra_headers: BTreeMap<String, String>,
    /// Chat endpoint path such as `"/chat/completions"` or `"/messages"`.
    /// Empty when not configured.
    #[serde(default)]
    pub chat_endpoint: String,
    /// Optional availability probe for this provider.
    #[serde(default)]
    pub healthcheck: Option<HealthcheckDef>,
    /// Free-form capability flags. The built-in anthropic entry lists
    /// `"prompt_caching"` and `"thinking"`.
    #[serde(default)]
    pub features: Vec<String>,
    /// Name of another provider, presumably tried when this one fails.
    /// Not read anywhere in this file — confirm against the caller.
    #[serde(default)]
    pub fallback: Option<String>,
    /// Retry count. Not read in this file.
    /// NOTE(review): whether it includes the first attempt is up to the caller.
    #[serde(default)]
    pub retry_count: Option<u32>,
    /// Delay between retries, in milliseconds per the field name.
    /// Not read in this file.
    #[serde(default)]
    pub retry_delay_ms: Option<u64>,
}
56
57impl Default for ProviderDef {
58 fn default() -> Self {
59 Self {
60 base_url: String::new(),
61 base_url_env: None,
62 auth_style: default_bearer(),
63 auth_header: None,
64 auth_env: AuthEnv::None,
65 extra_headers: BTreeMap::new(),
66 chat_endpoint: String::new(),
67 healthcheck: None,
68 features: Vec::new(),
69 fallback: None,
70 retry_count: None,
71 retry_delay_ms: None,
72 }
73 }
74}
75
/// Serde default for `ProviderDef::auth_style`: the `"bearer"` scheme.
fn default_bearer() -> String {
    String::from("bearer")
}
79
/// Environment variable name(s) from which a provider's credential is read,
/// judging by built-in values such as `ANTHROPIC_API_KEY`.
///
/// The enum is deserialized untagged, so TOML may give either a single
/// string (`auth_env = "OPENAI_API_KEY"`) or an array of names. When the key
/// is absent, the field default is `None` (e.g. the built-in `ollama`
/// provider).
#[derive(Debug, Clone, Deserialize, Default)]
#[serde(untagged)]
pub enum AuthEnv {
    /// No credential variable configured.
    #[default]
    None,
    /// A single variable name.
    Single(String),
    /// Several candidate names; presumably tried in order — confirm
    /// against the caller.
    Multiple(Vec<String>),
}
90
/// Request used to probe whether a provider is reachable/authorized.
#[derive(Debug, Clone, Deserialize)]
pub struct HealthcheckDef {
    /// HTTP method. The built-in defaults use `"GET"` and `"POST"`.
    pub method: String,
    /// Request path, presumably appended to the provider's base URL —
    /// confirm against the healthcheck caller.
    #[serde(default)]
    pub path: Option<String>,
    /// Absolute URL. The built-in `huggingface` entry sets this instead of
    /// `path`. NOTE(review): precedence between `url` and `path` is decided
    /// by the caller, not visible here.
    #[serde(default)]
    pub url: Option<String>,
    /// Optional request body. The built-in anthropic probe uses a JSON body.
    #[serde(default)]
    pub body: Option<String>,
}
101
/// Target of a model alias, as returned by `resolve_model`.
#[derive(Debug, Clone, Deserialize)]
pub struct AliasDef {
    /// Real model id the alias resolves to.
    pub id: String,
    /// Provider name to use for that model.
    pub provider: String,
}
107
/// One rule mapping a model id to a provider name (used by `infer_provider`).
///
/// Within a rule, `exact`, then `pattern` (a `*` glob), then `contains` are
/// checked. Any one matching selects `provider`, and the first matching rule
/// in the list wins. A rule with none of the three set never matches.
#[derive(Debug, Clone, Deserialize)]
pub struct InferenceRule {
    /// Glob over the model id, e.g. `"claude-*"`.
    #[serde(default)]
    pub pattern: Option<String>,
    /// Substring that must occur in the model id (case-sensitive).
    #[serde(default)]
    pub contains: Option<String>,
    /// Model id that must match exactly.
    #[serde(default)]
    pub exact: Option<String>,
    /// Provider name returned when the rule matches.
    pub provider: String,
}
118
/// One rule mapping a model id to a tier name (used by `model_tier`).
///
/// Matching works like `InferenceRule`: `exact`, then `pattern` (a `*`
/// glob), then `contains` are checked, and the first matching rule in the
/// list wins.
#[derive(Debug, Clone, Deserialize)]
pub struct TierRule {
    /// Glob over the model id, e.g. `"claude-*"`.
    #[serde(default)]
    pub pattern: Option<String>,
    /// Substring that must occur in the model id. This is case-sensitive,
    /// unlike the lowercase built-in fallback in `model_tier`.
    #[serde(default)]
    pub contains: Option<String>,
    /// Model id that must match exactly.
    #[serde(default)]
    pub exact: Option<String>,
    /// Tier name returned when the rule matches (e.g. `"small"`, `"frontier"`).
    pub tier: String,
}
129
/// Fallback tier settings.
#[derive(Debug, Clone, Deserialize)]
pub struct TierDefaults {
    /// Tier returned by `model_tier` when nothing else matches.
    /// Defaults to `"mid"`.
    #[serde(default = "default_mid")]
    pub default: String,
}
135
136impl Default for TierDefaults {
137 fn default() -> Self {
138 Self {
139 default: default_mid(),
140 }
141 }
142}
143
/// Serde/`Default` value for `TierDefaults::default`: the `"mid"` tier.
fn default_mid() -> String {
    "mid".to_owned()
}
147
148pub fn load_config() -> &'static ProvidersConfig {
154 CONFIG.get_or_init(|| {
155 if let Ok(path) = std::env::var("HARN_PROVIDERS_CONFIG") {
157 match std::fs::read_to_string(&path) {
158 Ok(content) => match toml::from_str::<ProvidersConfig>(&content) {
159 Ok(config) => {
160 eprintln!(
161 "[llm_config] Loaded {} providers, {} aliases from {}",
162 config.providers.len(),
163 config.aliases.len(),
164 path
165 );
166 return config;
167 }
168 Err(e) => eprintln!("[llm_config] TOML parse error in {}: {}", path, e),
169 },
170 Err(e) => eprintln!("[llm_config] Cannot read {}: {}", path, e),
171 }
172 }
173 if let Some(home) = dirs_or_home() {
175 let path = format!("{home}/.config/harn/providers.toml");
176 if let Ok(content) = std::fs::read_to_string(&path) {
177 if let Ok(config) = toml::from_str::<ProvidersConfig>(&content) {
178 return config;
179 }
180 }
181 }
182 default_config()
184 })
185}
186
187pub fn resolve_model(alias: &str) -> (String, Option<String>) {
189 let config = load_config();
190 if let Some(a) = config.aliases.get(alias) {
191 return (a.id.clone(), Some(a.provider.clone()));
192 }
193 (alias.to_string(), None)
194}
195
196pub fn infer_provider(model_id: &str) -> String {
198 let config = load_config();
199 for rule in &config.inference_rules {
200 if let Some(exact) = &rule.exact {
201 if model_id == exact {
202 return rule.provider.clone();
203 }
204 }
205 if let Some(pattern) = &rule.pattern {
206 if glob_match(pattern, model_id) {
207 return rule.provider.clone();
208 }
209 }
210 if let Some(substr) = &rule.contains {
211 if model_id.contains(substr.as_str()) {
212 return rule.provider.clone();
213 }
214 }
215 }
216 if model_id.starts_with("claude-") {
218 return "anthropic".to_string();
219 }
220 if model_id.starts_with("gpt-") || model_id.starts_with("o1") || model_id.starts_with("o3") {
221 return "openai".to_string();
222 }
223 if model_id.contains('/') {
224 return "openrouter".to_string();
225 }
226 if model_id.contains(':') {
227 return "ollama".to_string();
228 }
229 "anthropic".to_string()
230}
231
232pub fn model_tier(model_id: &str) -> String {
234 let config = load_config();
235 for rule in &config.tier_rules {
236 if let Some(exact) = &rule.exact {
237 if model_id == exact {
238 return rule.tier.clone();
239 }
240 }
241 if let Some(pattern) = &rule.pattern {
242 if glob_match(pattern, model_id) {
243 return rule.tier.clone();
244 }
245 }
246 if let Some(substr) = &rule.contains {
247 if model_id.contains(substr.as_str()) {
248 return rule.tier.clone();
249 }
250 }
251 }
252 let lower = model_id.to_lowercase();
254 if lower.contains("9b") || lower.contains("a3b") {
255 return "small".to_string();
256 }
257 if lower.starts_with("claude-") || lower == "gpt-4o" {
258 return "frontier".to_string();
259 }
260 config.tier_defaults.default.clone()
261}
262
263pub fn provider_config(name: &str) -> Option<&'static ProviderDef> {
265 load_config().providers.get(name)
266}
267
268pub fn model_params(model_id: &str) -> BTreeMap<String, toml::Value> {
271 let config = load_config();
272 let mut params = BTreeMap::new();
273 for (pattern, defaults) in &config.model_defaults {
274 if glob_match(pattern, model_id) {
275 for (k, v) in defaults {
276 params.insert(k.clone(), v.clone());
277 }
278 }
279 }
280 params
281}
282
283pub fn provider_names() -> Vec<String> {
285 load_config().providers.keys().cloned().collect()
286}
287
/// Matches `input` against a glob `pattern`, where `*` matches any (possibly
/// empty) run of characters and every other character is literal.
///
/// Any number of wildcards is supported, for example `"claude-*"`,
/// `"*-latest"`, `"claude-*-latest"`, `"*sonnet*"` or `"a*b*c"`. A pattern
/// without `*` is an exact comparison.
fn glob_match(pattern: &str, input: &str) -> bool {
    let mut segments = pattern.split('*');
    // `split` always yields at least one item: the literal text before the
    // first `*` (or the whole pattern when it has no `*`).
    let head = segments.next().unwrap_or("");
    let Some(rest) = input.strip_prefix(head) else {
        return false;
    };

    let remaining: Vec<&str> = segments.collect();
    let Some((tail, middle)) = remaining.split_last() else {
        // No `*` at all: the prefix must have consumed the whole input.
        return rest.is_empty();
    };

    // Anchor the suffix against what is left after the prefix, so the two can
    // never overlap (e.g. `a*ab` must not match `ab`).
    let Some(mut rest) = rest.strip_suffix(*tail) else {
        return false;
    };

    // Place each middle literal at its leftmost occurrence, in order. Taking
    // the earliest match always leaves the most room for later segments.
    for segment in middle {
        match rest.find(*segment) {
            Some(at) => rest = &rest[at + segment.len()..],
            None => return false,
        }
    }
    true
}
309
/// Returns the current user's home directory, if one can be determined.
///
/// Reads `HOME` first, then `USERPROFILE` (the variable Windows sets instead).
/// Empty values are treated as unset, so callers never build paths rooted at
/// `/` from a blank variable.
fn dirs_or_home() -> Option<String> {
    ["HOME", "USERPROFILE"]
        .iter()
        .filter_map(|name| std::env::var(*name).ok())
        .find(|dir| !dir.is_empty())
}
313
314pub fn resolve_base_url(pdef: &ProviderDef) -> String {
317 if let Some(env_name) = &pdef.base_url_env {
318 if let Ok(val) = std::env::var(env_name) {
319 if !val.is_empty() {
320 return val;
321 }
322 }
323 }
324 pdef.base_url.clone()
325}
326
327fn default_config() -> ProvidersConfig {
332 let mut config = ProvidersConfig::default();
333
334 config.providers.insert(
336 "anthropic".to_string(),
337 ProviderDef {
338 base_url: "https://api.anthropic.com/v1".to_string(),
339 auth_style: "header".to_string(),
340 auth_header: Some("x-api-key".to_string()),
341 auth_env: AuthEnv::Single("ANTHROPIC_API_KEY".to_string()),
342 extra_headers: BTreeMap::from([(
343 "anthropic-version".to_string(),
344 "2023-06-01".to_string(),
345 )]),
346 chat_endpoint: "/messages".to_string(),
347 healthcheck: Some(HealthcheckDef {
348 method: "POST".to_string(),
349 path: Some("/messages/count_tokens".to_string()),
350 url: None,
351 body: Some(
352 r#"{"model":"claude-sonnet-4-20250514","messages":[{"role":"user","content":"x"}]}"#
353 .to_string(),
354 ),
355 }),
356 features: vec!["prompt_caching".to_string(), "thinking".to_string()],
357 ..Default::default()
358 },
359 );
360
361 config.providers.insert(
363 "openai".to_string(),
364 ProviderDef {
365 base_url: "https://api.openai.com/v1".to_string(),
366 auth_style: "bearer".to_string(),
367 auth_env: AuthEnv::Single("OPENAI_API_KEY".to_string()),
368 chat_endpoint: "/chat/completions".to_string(),
369 healthcheck: Some(HealthcheckDef {
370 method: "GET".to_string(),
371 path: Some("/models".to_string()),
372 url: None,
373 body: None,
374 }),
375 ..Default::default()
376 },
377 );
378
379 config.providers.insert(
381 "openrouter".to_string(),
382 ProviderDef {
383 base_url: "https://openrouter.ai/api/v1".to_string(),
384 auth_style: "bearer".to_string(),
385 auth_env: AuthEnv::Single("OPENROUTER_API_KEY".to_string()),
386 chat_endpoint: "/chat/completions".to_string(),
387 healthcheck: Some(HealthcheckDef {
388 method: "GET".to_string(),
389 path: Some("/auth/key".to_string()),
390 url: None,
391 body: None,
392 }),
393 ..Default::default()
394 },
395 );
396
397 config.providers.insert(
399 "huggingface".to_string(),
400 ProviderDef {
401 base_url: "https://router.huggingface.co/v1".to_string(),
402 auth_style: "bearer".to_string(),
403 auth_env: AuthEnv::Multiple(vec![
404 "HF_TOKEN".to_string(),
405 "HUGGINGFACE_API_KEY".to_string(),
406 ]),
407 chat_endpoint: "/chat/completions".to_string(),
408 healthcheck: Some(HealthcheckDef {
409 method: "GET".to_string(),
410 url: Some("https://huggingface.co/api/whoami-v2".to_string()),
411 path: None,
412 body: None,
413 }),
414 ..Default::default()
415 },
416 );
417
418 config.providers.insert(
420 "ollama".to_string(),
421 ProviderDef {
422 base_url: "http://localhost:11434".to_string(),
423 base_url_env: Some("OLLAMA_HOST".to_string()),
424 auth_style: "none".to_string(),
425 chat_endpoint: "/api/chat".to_string(),
426 healthcheck: Some(HealthcheckDef {
427 method: "GET".to_string(),
428 path: Some("/api/tags".to_string()),
429 url: None,
430 body: None,
431 }),
432 ..Default::default()
433 },
434 );
435
436 config.inference_rules = vec![
438 InferenceRule {
439 pattern: Some("claude-*".to_string()),
440 contains: None,
441 exact: None,
442 provider: "anthropic".to_string(),
443 },
444 InferenceRule {
445 pattern: Some("gpt-*".to_string()),
446 contains: None,
447 exact: None,
448 provider: "openai".to_string(),
449 },
450 InferenceRule {
451 pattern: Some("o1*".to_string()),
452 contains: None,
453 exact: None,
454 provider: "openai".to_string(),
455 },
456 InferenceRule {
457 pattern: Some("o3*".to_string()),
458 contains: None,
459 exact: None,
460 provider: "openai".to_string(),
461 },
462 InferenceRule {
463 pattern: None,
464 contains: Some("/".to_string()),
465 exact: None,
466 provider: "openrouter".to_string(),
467 },
468 InferenceRule {
469 pattern: None,
470 contains: Some(":".to_string()),
471 exact: None,
472 provider: "ollama".to_string(),
473 },
474 ];
475
476 config.tier_rules = vec![
478 TierRule {
479 contains: Some("9b".to_string()),
480 pattern: None,
481 exact: None,
482 tier: "small".to_string(),
483 },
484 TierRule {
485 contains: Some("a3b".to_string()),
486 pattern: None,
487 exact: None,
488 tier: "small".to_string(),
489 },
490 TierRule {
491 pattern: Some("claude-*".to_string()),
492 contains: None,
493 exact: None,
494 tier: "frontier".to_string(),
495 },
496 TierRule {
497 exact: Some("gpt-4o".to_string()),
498 contains: None,
499 pattern: None,
500 tier: "frontier".to_string(),
501 },
502 ];
503
504 config.tier_defaults = TierDefaults {
505 default: "mid".to_string(),
506 };
507
508 config
509}
510
#[cfg(test)]
mod tests {
    use super::*;

    // Tests that go through `load_config` observe whichever configuration the
    // process resolves first: HARN_PROVIDERS_CONFIG, then the home-directory
    // providers.toml, then `default_config`. They assume neither file is
    // present, or that neither overrides the built-in defaults.

    /// Asserts `glob_match(pattern, input) == expected` for every case.
    fn assert_globs(cases: &[(&str, &str, bool)]) {
        for &(pattern, input, expected) in cases {
            assert_eq!(
                glob_match(pattern, input),
                expected,
                "glob_match({:?}, {:?})",
                pattern,
                input
            );
        }
    }

    #[test]
    fn test_glob_match_prefix() {
        assert_globs(&[
            ("claude-*", "claude-sonnet-4-20250514", true),
            ("gpt-*", "gpt-4o", true),
            ("claude-*", "gpt-4o", false),
        ]);
    }

    #[test]
    fn test_glob_match_suffix() {
        assert_globs(&[
            ("*-latest", "llama3.2-latest", true),
            ("*-latest", "llama3.2", false),
        ]);
    }

    #[test]
    fn test_glob_match_middle() {
        assert_globs(&[
            ("claude-*-latest", "claude-sonnet-latest", true),
            ("claude-*-latest", "claude-sonnet-beta", false),
        ]);
    }

    #[test]
    fn test_glob_match_exact() {
        assert_globs(&[("gpt-4o", "gpt-4o", true), ("gpt-4o", "gpt-4o-mini", false)]);
    }

    #[test]
    fn test_infer_provider_from_defaults() {
        let expectations = [
            ("claude-sonnet-4-20250514", "anthropic"),
            ("gpt-4o", "openai"),
            ("o1-preview", "openai"),
            ("o3-mini", "openai"),
            ("qwen/qwen3-coder", "openrouter"),
            ("llama3.2:latest", "ollama"),
            ("unknown-model", "anthropic"),
        ];
        for &(model, provider) in &expectations {
            assert_eq!(infer_provider(model), provider, "provider for {}", model);
        }
    }

    #[test]
    fn test_model_tier_from_defaults() {
        let expectations = [
            ("claude-sonnet-4-20250514", "frontier"),
            ("gpt-4o", "frontier"),
            ("Qwen3.5-9B", "small"),
            ("deepseek-v3", "mid"),
        ];
        for &(model, tier) in &expectations {
            assert_eq!(model_tier(model), tier, "tier for {}", model);
        }
    }

    #[test]
    fn test_resolve_model_unknown_alias() {
        let (id, provider) = resolve_model("gpt-4o");
        assert_eq!(id, "gpt-4o");
        assert_eq!(provider, None);
    }

    #[test]
    fn test_provider_names() {
        let names = provider_names();
        assert!(names.len() >= 5);
        for expected in ["anthropic", "openai", "ollama"] {
            assert!(
                names.iter().any(|name| name == expected),
                "missing provider {}",
                expected
            );
        }
    }

    #[test]
    fn test_provider_config_anthropic() {
        let anthropic =
            provider_config("anthropic").expect("anthropic provider should be configured");
        assert_eq!(anthropic.auth_style, "header");
        assert_eq!(anthropic.auth_header.as_deref(), Some("x-api-key"));
    }

    #[test]
    fn test_resolve_base_url_no_env() {
        let pdef = ProviderDef {
            base_url: "https://example.com".to_string(),
            ..Default::default()
        };
        assert_eq!(resolve_base_url(&pdef), "https://example.com");
    }

    #[test]
    fn test_default_config_roundtrip() {
        let ProvidersConfig {
            providers,
            inference_rules,
            tier_rules,
            tier_defaults,
            ..
        } = default_config();
        assert!(!providers.is_empty());
        assert!(!inference_rules.is_empty());
        assert!(!tier_rules.is_empty());
        assert_eq!(tier_defaults.default, "mid");
    }

    #[test]
    fn test_model_params_empty() {
        assert!(model_params("claude-sonnet-4-20250514").is_empty());
    }
}