hefa_core/
config.rs

1use std::collections::HashMap;
2
3use dotenvy::dotenv;
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6
7/// Supported LLM provider backends.
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
9pub enum ProviderKind {
10    OpenAi,
11    Ollama,
12    LmStudio,
13}
14
15/// Fully-resolved configuration for calling an LLM provider.
16#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
17pub struct ProviderConfig {
18    pub kind: ProviderKind,
19    pub model: String,
20    pub base_url: String,
21    pub api_key: Option<String>,
22}
23
24/// Errors raised while resolving provider configuration.
25#[derive(Debug, Error)]
26pub enum ConfigError {
27    #[error("missing environment configuration for provider {0:?}")]
28    MissingEnv(ProviderKind),
29    #[error("could not infer provider from model `{model}`")]
30    UnknownProvider { model: String },
31}
32
/// Immutable snapshot of environment variables used for provider lookup.
///
/// Build it from the live process (plus an optional `.env` file) with `from_process`,
/// or from explicit key/value pairs with `from_pairs`.
#[derive(Debug, Clone)]
pub struct Environment {
    /// Variable name mapped to its value.
    vars: HashMap<String, String>,
}
38
39impl Environment {
40    pub fn from_process() -> Self {
41        // Ignore errors: dotenv is optional during runtime.
42        dotenv().ok();
43        let vars = std::env::vars().collect();
44        Self { vars }
45    }
46
47    pub fn from_pairs<I, K, V>(pairs: I) -> Self
48    where
49        I: IntoIterator<Item = (K, V)>,
50        K: Into<String>,
51        V: Into<String>,
52    {
53        Self {
54            vars: pairs
55                .into_iter()
56                .map(|(k, v)| (k.into(), v.into()))
57                .collect(),
58        }
59    }
60
61    pub fn get(&self, key: &str) -> Option<&str> {
62        self.vars.get(key).map(|s| s.as_str())
63    }
64}
65
66/// Resolves provider configuration based on model identifiers and env vars.
67#[derive(Debug, Clone)]
68pub struct ProviderResolver {
69    env: Environment,
70}
71
72impl ProviderResolver {
73    pub fn new(env: Environment) -> Self {
74        Self { env }
75    }
76
77    pub fn from_process() -> Self {
78        Self::new(Environment::from_process())
79    }
80
81    pub fn resolve(&self, model: &str) -> Result<ProviderConfig, ConfigError> {
82        let (explicit_provider, normalized_model) = parse_model_identifier(model);
83        let provider = explicit_provider
84            .or_else(|| infer_from_env(&self.env))
85            .ok_or_else(|| ConfigError::UnknownProvider {
86                model: model.to_string(),
87            })?;
88
89        self.resolve_with_kind(provider, normalized_model)
90    }
91
92    pub fn resolve_with_kind(
93        &self,
94        kind: ProviderKind,
95        model: &str,
96    ) -> Result<ProviderConfig, ConfigError> {
97        match kind {
98            ProviderKind::OpenAi => self.resolve_openai(model),
99            ProviderKind::Ollama => self.resolve_ollama(model),
100            ProviderKind::LmStudio => self.resolve_lmstudio(model),
101        }
102    }
103
104    fn resolve_openai(&self, model: &str) -> Result<ProviderConfig, ConfigError> {
105        let api_key = self.env.get("OPENAI_API_KEY").map(|s| s.to_string());
106        if api_key.is_none() {
107            return Err(ConfigError::MissingEnv(ProviderKind::OpenAi));
108        }
109        let base = self
110            .env
111            .get("OPENAI_API_BASE")
112            .unwrap_or("https://api.openai.com/v1")
113            .to_string();
114        Ok(ProviderConfig {
115            kind: ProviderKind::OpenAi,
116            model: model.to_string(),
117            base_url: base,
118            api_key,
119        })
120    }
121
122    fn resolve_ollama(&self, model: &str) -> Result<ProviderConfig, ConfigError> {
123        let base = self
124            .env
125            .get("OLLAMA_API_BASE")
126            .unwrap_or("http://localhost:11434")
127            .to_string();
128        Ok(ProviderConfig {
129            kind: ProviderKind::Ollama,
130            model: model.to_string(),
131            base_url: base,
132            api_key: None,
133        })
134    }
135
136    fn resolve_lmstudio(&self, model: &str) -> Result<ProviderConfig, ConfigError> {
137        let base = self
138            .env
139            .get("LMSTUDIO_API_BASE")
140            .unwrap_or("http://localhost:1234/v1")
141            .to_string();
142        Ok(ProviderConfig {
143            kind: ProviderKind::LmStudio,
144            model: model.to_string(),
145            base_url: base,
146            api_key: None,
147        })
148    }
149}
150
151fn parse_model_identifier(model: &str) -> (Option<ProviderKind>, &str) {
152    if let Some(stripped) = model.strip_prefix("openai:") {
153        return (Some(ProviderKind::OpenAi), stripped);
154    }
155    if let Some(stripped) = model.strip_prefix("ollama:") {
156        return (Some(ProviderKind::Ollama), stripped);
157    }
158    if let Some(stripped) = model.strip_prefix("lmstudio:") {
159        return (Some(ProviderKind::LmStudio), stripped);
160    }
161    (None, model)
162}
163
164fn infer_from_env(env: &Environment) -> Option<ProviderKind> {
165    if env.get("OPENAI_API_KEY").is_some() {
166        return Some(ProviderKind::OpenAi);
167    }
168    if env.get("OLLAMA_API_BASE").is_some() {
169        return Some(ProviderKind::Ollama);
170    }
171    if env.get("LMSTUDIO_API_BASE").is_some() {
172        return Some(ProviderKind::LmStudio);
173    }
174    None
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a resolver over an in-memory environment made of `pairs`.
    fn resolver_from(pairs: &[(&str, &str)]) -> ProviderResolver {
        ProviderResolver::new(Environment::from_pairs(pairs.iter().copied()))
    }

    #[test]
    fn resolves_openai_with_default_model() {
        let ProviderConfig {
            kind,
            model,
            base_url,
            api_key,
        } = resolver_from(&[
            ("OPENAI_API_KEY", "test-key"),
            ("OPENAI_API_BASE", "https://api.openai.com/v1"),
        ])
        .resolve("gpt-4o")
        .expect("config");

        assert_eq!(kind, ProviderKind::OpenAi);
        assert_eq!(model, "gpt-4o");
        assert_eq!(base_url, "https://api.openai.com/v1");
        assert_eq!(api_key.as_deref(), Some("test-key"));
    }

    #[test]
    fn resolves_ollama_with_prefix() {
        // The explicit prefix must win even though an OpenAI key is also present.
        let resolver = resolver_from(&[
            ("OPENAI_API_KEY", "x"),
            ("OLLAMA_API_BASE", "http://localhost:11434"),
        ]);
        let ProviderConfig {
            kind,
            model,
            base_url,
            api_key,
        } = resolver.resolve("ollama:llama3").expect("config");

        assert_eq!(kind, ProviderKind::Ollama);
        assert_eq!(model, "llama3");
        assert_eq!(base_url, "http://localhost:11434");
        assert!(api_key.is_none());
    }

    #[test]
    fn resolves_lmstudio_with_prefix() {
        let cfg = resolver_from(&[("LMSTUDIO_API_BASE", "http://127.0.0.1:1234/v1")])
            .resolve("lmstudio:phi-3")
            .expect("config");

        assert_eq!(cfg.kind, ProviderKind::LmStudio);
        assert_eq!(cfg.model, "phi-3");
        assert_eq!(cfg.base_url, "http://127.0.0.1:1234/v1");
    }

    #[test]
    fn errors_when_openai_key_missing() {
        match resolver_from(&[]).resolve("openai:gpt-4o") {
            Err(ConfigError::MissingEnv(ProviderKind::OpenAi)) => {}
            other => panic!("expected MissingEnv(OpenAi), got {other:?}"),
        }
    }

    #[test]
    fn errors_when_provider_cannot_be_inferred() {
        match resolver_from(&[]).resolve("gpt-4o") {
            Err(ConfigError::UnknownProvider { model }) if model == "gpt-4o" => {}
            other => panic!("expected UnknownProvider for `gpt-4o`, got {other:?}"),
        }
    }
}