1use std::collections::HashMap;
2
3use dotenvy::dotenv;
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6
/// Identifies which LLM backend a model is routed to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderKind {
    /// Hosted OpenAI API; requires `OPENAI_API_KEY` to resolve.
    OpenAi,
    /// Ollama server (base URL defaults to `http://localhost:11434`).
    Ollama,
    /// LM Studio server (base URL defaults to `http://localhost:1234/v1`).
    LmStudio,
}
14
/// Fully resolved connection settings for one model/provider pair.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderConfig {
    /// Which backend this configuration targets.
    pub kind: ProviderKind,
    /// Model identifier with any `provider:` prefix stripped.
    pub model: String,
    /// Base URL of the provider's HTTP API.
    pub base_url: String,
    /// API key when the provider requires one (only OpenAI here).
    pub api_key: Option<String>,
}
23
/// Errors produced while resolving a provider configuration.
#[derive(Debug, Error)]
pub enum ConfigError {
    /// Required environment configuration for the provider is absent
    /// (currently only `OPENAI_API_KEY` for OpenAI).
    #[error("missing environment configuration for provider {0:?}")]
    MissingEnv(ProviderKind),
    /// The model string carried no recognized provider prefix and no
    /// provider could be inferred from the environment.
    #[error("could not infer provider from model `{model}`")]
    UnknownProvider { model: String },
}
32
/// Immutable snapshot of environment variables, abstracted so tests can
/// inject fixed values instead of touching real process state.
#[derive(Debug, Clone)]
pub struct Environment {
    // Flat key/value copy captured once; never re-read from the process.
    vars: HashMap<String, String>,
}
38
39impl Environment {
40 pub fn from_process() -> Self {
41 dotenv().ok();
43 let vars = std::env::vars().collect();
44 Self { vars }
45 }
46
47 pub fn from_pairs<I, K, V>(pairs: I) -> Self
48 where
49 I: IntoIterator<Item = (K, V)>,
50 K: Into<String>,
51 V: Into<String>,
52 {
53 Self {
54 vars: pairs
55 .into_iter()
56 .map(|(k, v)| (k.into(), v.into()))
57 .collect(),
58 }
59 }
60
61 pub fn get(&self, key: &str) -> Option<&str> {
62 self.vars.get(key).map(|s| s.as_str())
63 }
64}
65
/// Resolves model identifiers into [`ProviderConfig`]s using an
/// [`Environment`] snapshot as the source of keys and base URLs.
#[derive(Debug, Clone)]
pub struct ProviderResolver {
    // Environment snapshot consulted for API keys and base-URL overrides.
    env: Environment,
}
71
72impl ProviderResolver {
73 pub fn new(env: Environment) -> Self {
74 Self { env }
75 }
76
77 pub fn from_process() -> Self {
78 Self::new(Environment::from_process())
79 }
80
81 pub fn resolve(&self, model: &str) -> Result<ProviderConfig, ConfigError> {
82 let (explicit_provider, normalized_model) = parse_model_identifier(model);
83 let provider = explicit_provider
84 .or_else(|| infer_from_env(&self.env))
85 .ok_or_else(|| ConfigError::UnknownProvider {
86 model: model.to_string(),
87 })?;
88
89 self.resolve_with_kind(provider, normalized_model)
90 }
91
92 pub fn resolve_with_kind(
93 &self,
94 kind: ProviderKind,
95 model: &str,
96 ) -> Result<ProviderConfig, ConfigError> {
97 match kind {
98 ProviderKind::OpenAi => self.resolve_openai(model),
99 ProviderKind::Ollama => self.resolve_ollama(model),
100 ProviderKind::LmStudio => self.resolve_lmstudio(model),
101 }
102 }
103
104 fn resolve_openai(&self, model: &str) -> Result<ProviderConfig, ConfigError> {
105 let api_key = self.env.get("OPENAI_API_KEY").map(|s| s.to_string());
106 if api_key.is_none() {
107 return Err(ConfigError::MissingEnv(ProviderKind::OpenAi));
108 }
109 let base = self
110 .env
111 .get("OPENAI_API_BASE")
112 .unwrap_or("https://api.openai.com/v1")
113 .to_string();
114 Ok(ProviderConfig {
115 kind: ProviderKind::OpenAi,
116 model: model.to_string(),
117 base_url: base,
118 api_key,
119 })
120 }
121
122 fn resolve_ollama(&self, model: &str) -> Result<ProviderConfig, ConfigError> {
123 let base = self
124 .env
125 .get("OLLAMA_API_BASE")
126 .unwrap_or("http://localhost:11434")
127 .to_string();
128 Ok(ProviderConfig {
129 kind: ProviderKind::Ollama,
130 model: model.to_string(),
131 base_url: base,
132 api_key: None,
133 })
134 }
135
136 fn resolve_lmstudio(&self, model: &str) -> Result<ProviderConfig, ConfigError> {
137 let base = self
138 .env
139 .get("LMSTUDIO_API_BASE")
140 .unwrap_or("http://localhost:1234/v1")
141 .to_string();
142 Ok(ProviderConfig {
143 kind: ProviderKind::LmStudio,
144 model: model.to_string(),
145 base_url: base,
146 api_key: None,
147 })
148 }
149}
150
151fn parse_model_identifier(model: &str) -> (Option<ProviderKind>, &str) {
152 if let Some(stripped) = model.strip_prefix("openai:") {
153 return (Some(ProviderKind::OpenAi), stripped);
154 }
155 if let Some(stripped) = model.strip_prefix("ollama:") {
156 return (Some(ProviderKind::Ollama), stripped);
157 }
158 if let Some(stripped) = model.strip_prefix("lmstudio:") {
159 return (Some(ProviderKind::LmStudio), stripped);
160 }
161 (None, model)
162}
163
164fn infer_from_env(env: &Environment) -> Option<ProviderKind> {
165 if env.get("OPENAI_API_KEY").is_some() {
166 return Some(ProviderKind::OpenAi);
167 }
168 if env.get("OLLAMA_API_BASE").is_some() {
169 return Some(ProviderKind::Ollama);
170 }
171 if env.get("LMSTUDIO_API_BASE").is_some() {
172 return Some(ProviderKind::LmStudio);
173 }
174 None
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a resolver over a fixed, fake environment.
    fn resolver_with(vars: &[(&str, &str)]) -> ProviderResolver {
        ProviderResolver::new(Environment::from_pairs(vars.iter().cloned()))
    }

    #[test]
    fn resolves_openai_with_default_model() {
        let resolver = resolver_with(&[
            ("OPENAI_API_KEY", "test-key"),
            ("OPENAI_API_BASE", "https://api.openai.com/v1"),
        ]);
        let cfg = resolver.resolve("gpt-4o").expect("config");
        // ProviderConfig derives PartialEq, so compare the whole value.
        assert_eq!(
            cfg,
            ProviderConfig {
                kind: ProviderKind::OpenAi,
                model: "gpt-4o".to_string(),
                base_url: "https://api.openai.com/v1".to_string(),
                api_key: Some("test-key".to_string()),
            }
        );
    }

    #[test]
    fn resolves_ollama_with_prefix() {
        // OPENAI_API_KEY is also set; the explicit prefix must win.
        let resolver = resolver_with(&[
            ("OPENAI_API_KEY", "x"),
            ("OLLAMA_API_BASE", "http://localhost:11434"),
        ]);
        let cfg = resolver.resolve("ollama:llama3").expect("config");
        assert_eq!(
            cfg,
            ProviderConfig {
                kind: ProviderKind::Ollama,
                model: "llama3".to_string(),
                base_url: "http://localhost:11434".to_string(),
                api_key: None,
            }
        );
    }

    #[test]
    fn resolves_lmstudio_with_prefix() {
        let resolver = resolver_with(&[("LMSTUDIO_API_BASE", "http://127.0.0.1:1234/v1")]);
        let cfg = resolver.resolve("lmstudio:phi-3").expect("config");
        assert_eq!(cfg.kind, ProviderKind::LmStudio);
        assert_eq!(
            (cfg.model.as_str(), cfg.base_url.as_str()),
            ("phi-3", "http://127.0.0.1:1234/v1")
        );
    }

    #[test]
    fn errors_when_openai_key_missing() {
        let err = resolver_with(&[]).resolve("openai:gpt-4o").unwrap_err();
        assert!(matches!(err, ConfigError::MissingEnv(ProviderKind::OpenAi)));
    }

    #[test]
    fn errors_when_provider_cannot_be_inferred() {
        let err = resolver_with(&[]).resolve("gpt-4o").unwrap_err();
        match err {
            ConfigError::UnknownProvider { model } => assert_eq!(model, "gpt-4o"),
            other => panic!("unexpected error: {:?}", other),
        }
    }
}