1use crate::{LlmGatewayError, LlmProvider};
2use autoagents_llm::builder::LLMBuilder;
3use autoagents_llm::chat::ReasoningEffort;
4use serde_json::Value;
5use std::sync::Arc;
6
/// Backend-agnostic settings used to construct an `autoagents_llm` provider.
///
/// `provider` selects the backend (see `build_autoagents_provider`). Every other
/// field is optional: when it is `None` the matching `LLMBuilder` setter is simply
/// not called (see `apply_common`).
#[derive(Debug, Clone, Default, PartialEq)]
pub struct AutoagentsProviderConfig {
    /// Backend identifier, e.g. `"openai"`, `"anthropic"` or `"azure-openai"`.
    pub provider: String,
    /// Model name forwarded to `LLMBuilder::model`.
    pub model: Option<String>,
    /// API key forwarded to `LLMBuilder::api_key`.
    pub api_key: Option<String>,
    /// Endpoint override forwarded to `LLMBuilder::base_url`.
    pub base_url: Option<String>,
    /// Token limit forwarded to `LLMBuilder::max_tokens`.
    pub max_tokens: Option<u32>,
    /// Sampling temperature forwarded to `LLMBuilder::temperature`.
    pub temperature: Option<f32>,
    /// Request timeout forwarded to `LLMBuilder::timeout_seconds`.
    pub timeout_seconds: Option<u64>,
    /// Reasoning toggle forwarded to `LLMBuilder::reasoning`.
    pub reasoning: Option<bool>,
    /// Reasoning effort level, expected to be `low`, `medium` or `high`;
    /// unrecognised values are not forwarded to the builder.
    pub reasoning_effort: Option<String>,
    /// Reasoning token budget forwarded to `LLMBuilder::reasoning_budget_tokens`.
    pub reasoning_budget_tokens: Option<u32>,
    /// `top_p` sampling value forwarded to `LLMBuilder::top_p`.
    pub top_p: Option<f32>,
    /// `top_k` sampling value forwarded to `LLMBuilder::top_k`.
    pub top_k: Option<u32>,
    /// Flag forwarded to `LLMBuilder::normalize_response`.
    pub normalize_response: Option<bool>,
    /// Arbitrary JSON forwarded to `LLMBuilder::extra_body`
    /// (presumably merged into request bodies — confirm against the backend).
    pub extra_body: Option<Value>,
    /// API version forwarded to `LLMBuilder::api_version`
    /// (set for `azure-openai` in this module's tests).
    pub api_version: Option<String>,
    /// Deployment id forwarded to `LLMBuilder::deployment_id`
    /// (set for `azure-openai` in this module's tests).
    pub deployment_id: Option<String>,
}
26
27impl AutoagentsProviderConfig {
28 pub fn new(provider: impl Into<String>) -> Self {
29 Self {
30 provider: provider.into(),
31 ..Self::default()
32 }
33 }
34
35 pub fn from_env(provider: impl Into<String>, model: impl Into<String>) -> Self {
36 let provider = provider.into();
37 let env_prefix = provider.to_ascii_uppercase().replace(['-', '.'], "_");
38 Self {
39 provider,
40 model: Some(model.into()),
41 api_key: std::env::var(format!("{env_prefix}_API_KEY")).ok(),
42 base_url: std::env::var(format!("{env_prefix}_BASE_URL")).ok(),
43 api_version: std::env::var(format!("{env_prefix}_API_VERSION")).ok(),
44 deployment_id: std::env::var(format!("{env_prefix}_DEPLOYMENT_ID")).ok(),
45 ..Self::default()
46 }
47 }
48
49 pub fn with_model_config(mut self, config: Option<&Value>) -> Self {
50 let Some(config) = config.and_then(Value::as_object) else {
51 return self;
52 };
53
54 self.max_tokens = config
55 .get("max_tokens")
56 .and_then(Value::as_u64)
57 .and_then(|value| u32::try_from(value).ok())
58 .or(self.max_tokens);
59 self.temperature = config
60 .get("temperature")
61 .and_then(Value::as_f64)
62 .map(|value| value as f32)
63 .or(self.temperature);
64 self.timeout_seconds = config
65 .get("timeout_seconds")
66 .and_then(Value::as_u64)
67 .or(self.timeout_seconds);
68 self.reasoning = config
69 .get("reasoning")
70 .and_then(Value::as_bool)
71 .or(self.reasoning);
72 self.reasoning_effort = config
73 .get("reasoning_effort")
74 .and_then(Value::as_str)
75 .map(str::to_string)
76 .or(self.reasoning_effort);
77 self.reasoning_budget_tokens = config
78 .get("reasoning_budget_tokens")
79 .and_then(Value::as_u64)
80 .and_then(|value| u32::try_from(value).ok())
81 .or(self.reasoning_budget_tokens);
82 self.top_p = config
83 .get("top_p")
84 .and_then(Value::as_f64)
85 .map(|value| value as f32)
86 .or(self.top_p);
87 self.top_k = config
88 .get("top_k")
89 .and_then(Value::as_u64)
90 .and_then(|value| u32::try_from(value).ok())
91 .or(self.top_k);
92 self.normalize_response = config
93 .get("normalize_response")
94 .and_then(Value::as_bool)
95 .or(self.normalize_response);
96 self.extra_body = config.get("extra_body").cloned().or(self.extra_body);
97 self.api_version = config
98 .get("api_version")
99 .and_then(Value::as_str)
100 .map(str::to_string)
101 .or(self.api_version);
102 self.deployment_id = config
103 .get("deployment_id")
104 .and_then(Value::as_str)
105 .map(str::to_string)
106 .or(self.deployment_id);
107 self
108 }
109}
110
111pub fn build_autoagents_provider(
112 config: AutoagentsProviderConfig,
113) -> Result<Arc<dyn LlmProvider>, LlmGatewayError> {
114 match config.provider.as_str() {
115 "openai" => build::<autoagents_llm::backends::openai::OpenAI>(config),
116 "anthropic" => build::<autoagents_llm::backends::anthropic::Anthropic>(config),
117 "ollama" => build::<autoagents_llm::backends::ollama::Ollama>(config),
118 "deepseek" => build::<autoagents_llm::backends::deepseek::DeepSeek>(config),
119 "xai" => build::<autoagents_llm::backends::xai::XAI>(config),
120 "phind" => build::<autoagents_llm::backends::phind::Phind>(config),
121 "google" => build::<autoagents_llm::backends::google::Google>(config),
122 "groq" => build::<autoagents_llm::backends::groq::Groq>(config),
123 "azure-openai" => build::<autoagents_llm::backends::azure_openai::AzureOpenAI>(config),
124 "openrouter" => build::<autoagents_llm::backends::openrouter::OpenRouter>(config),
125 "minimax" => build::<autoagents_llm::backends::minimax::MiniMax>(config),
126 other => Err(LlmGatewayError::UnknownProvider(other.to_string())),
127 }
128}
129
130fn build<T>(config: AutoagentsProviderConfig) -> Result<Arc<dyn LlmProvider>, LlmGatewayError>
131where
132 T: autoagents_llm::LLMProvider + autoagents_llm::HasConfig,
133 LLMBuilder<T>: BuildAutoagentsProvider<T>,
134{
135 let provider_name = config.provider.clone();
136 BuildAutoagentsProvider::build_provider(apply_common::<T>(config)).map_err(|source| {
137 LlmGatewayError::Provider {
138 provider: provider_name,
139 message: source.to_string(),
140 }
141 })
142}
143
144fn apply_common<T>(config: AutoagentsProviderConfig) -> LLMBuilder<T>
145where
146 T: autoagents_llm::LLMProvider + autoagents_llm::HasConfig,
147{
148 let mut builder = LLMBuilder::<T>::new();
149 if let Some(api_key) = config.api_key {
150 builder = builder.api_key(api_key);
151 }
152 if let Some(base_url) = config.base_url {
153 builder = builder.base_url(base_url);
154 }
155 if let Some(model) = config.model {
156 builder = builder.model(model);
157 }
158 if let Some(max_tokens) = config.max_tokens {
159 builder = builder.max_tokens(max_tokens);
160 }
161 if let Some(temperature) = config.temperature {
162 builder = builder.temperature(temperature);
163 }
164 if let Some(timeout_seconds) = config.timeout_seconds {
165 builder = builder.timeout_seconds(timeout_seconds);
166 }
167 if let Some(reasoning) = config.reasoning {
168 builder = builder.reasoning(reasoning);
169 }
170 if let Some(reasoning_effort) = config.reasoning_effort {
171 builder = match reasoning_effort.as_str() {
172 "low" => builder.reasoning_effort(ReasoningEffort::Low),
173 "medium" => builder.reasoning_effort(ReasoningEffort::Medium),
174 "high" => builder.reasoning_effort(ReasoningEffort::High),
175 _ => builder,
176 };
177 }
178 if let Some(reasoning_budget_tokens) = config.reasoning_budget_tokens {
179 builder = builder.reasoning_budget_tokens(reasoning_budget_tokens);
180 }
181 if let Some(top_p) = config.top_p {
182 builder = builder.top_p(top_p);
183 }
184 if let Some(top_k) = config.top_k {
185 builder = builder.top_k(top_k);
186 }
187 if let Some(normalize_response) = config.normalize_response {
188 builder = builder.normalize_response(normalize_response);
189 }
190 if let Some(extra_body) = config.extra_body {
191 builder = builder.extra_body(extra_body);
192 }
193 if let Some(api_version) = config.api_version {
194 builder = builder.api_version(api_version);
195 }
196 if let Some(deployment_id) = config.deployment_id {
197 builder = builder.deployment_id(deployment_id);
198 }
199 builder
200}
201
/// Converts a configured `LLMBuilder<T>` into a type-erased [`LlmProvider`].
///
/// Implemented per backend by `impl_build_provider!`. `build` requires
/// `LLMBuilder<T>: BuildAutoagentsProvider<T>` so it can stay generic over the backend.
pub trait BuildAutoagentsProvider<T> {
    /// Consumes the builder and returns the built provider, or the builder's error.
    fn build_provider(self) -> Result<Arc<dyn LlmProvider>, autoagents_llm::error::LLMError>;
}
205
/// Implements [`BuildAutoagentsProvider`] for `LLMBuilder<Backend>` for each listed
/// backend type. Each impl delegates to the builder's own `build()` and upcasts the
/// result to `Arc<dyn LlmProvider>`. A trailing comma after the last path is accepted.
macro_rules! impl_build_provider {
    ($($backend:path),+ $(,)?) => {
        $(
            impl BuildAutoagentsProvider<$backend> for LLMBuilder<$backend> {
                fn build_provider(
                    self,
                ) -> Result<Arc<dyn LlmProvider>, autoagents_llm::error::LLMError> {
                    let provider: Arc<dyn LlmProvider> = self.build()?;
                    Ok(provider)
                }
            }
        )+
    };
}
217
218impl_build_provider!(
219 autoagents_llm::backends::openai::OpenAI,
220 autoagents_llm::backends::anthropic::Anthropic,
221 autoagents_llm::backends::ollama::Ollama,
222 autoagents_llm::backends::deepseek::DeepSeek,
223 autoagents_llm::backends::xai::XAI,
224 autoagents_llm::backends::phind::Phind,
225 autoagents_llm::backends::google::Google,
226 autoagents_llm::backends::groq::Groq,
227 autoagents_llm::backends::azure_openai::AzureOpenAI,
228 autoagents_llm::backends::openrouter::OpenRouter,
229 autoagents_llm::backends::minimax::MiniMax,
230);
231
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds `config` and returns the error, panicking with `msg` on success.
    fn expect_build_error(config: AutoagentsProviderConfig, msg: &str) -> LlmGatewayError {
        let Err(err) = build_autoagents_provider(config) else {
            panic!("{msg}");
        };
        err
    }

    /// A config with a dummy API key and the given model.
    fn keyed(provider: &str, model: &str) -> AutoagentsProviderConfig {
        AutoagentsProviderConfig {
            api_key: Some("test".to_string()),
            model: Some(model.to_string()),
            ..AutoagentsProviderConfig::new(provider)
        }
    }

    #[test]
    fn unsupported_provider_is_error() {
        let err = expect_build_error(
            AutoagentsProviderConfig::new("missing"),
            "unknown provider should fail",
        );
        assert!(matches!(err, LlmGatewayError::UnknownProvider(_)));
    }

    #[test]
    fn provider_build_errors_are_preserved() {
        let err = expect_build_error(
            AutoagentsProviderConfig::new("openai"),
            "missing key should fail",
        );
        assert!(err.to_string().contains("OpenAI"));
    }

    #[test]
    fn configured_autoagents_providers_build_without_network() {
        let cases = [
            keyed("openai", "gpt-test"),
            keyed("anthropic", "claude-test"),
            AutoagentsProviderConfig {
                base_url: Some("http://localhost:11434".to_string()),
                model: Some("llama-test".to_string()),
                ..AutoagentsProviderConfig::new("ollama")
            },
            keyed("deepseek", "deepseek-test"),
            keyed("xai", "grok-test"),
            AutoagentsProviderConfig {
                model: Some("phind-test".to_string()),
                ..AutoagentsProviderConfig::new("phind")
            },
            keyed("google", "gemini-test"),
            keyed("groq", "llama-test"),
            AutoagentsProviderConfig {
                base_url: Some("https://example.test".to_string()),
                api_version: Some("2024-02-01".to_string()),
                deployment_id: Some("deployment".to_string()),
                ..keyed("azure-openai", "gpt-test")
            },
            keyed("openrouter", "openrouter-test"),
            keyed("minimax", "minimax-test"),
        ];

        for config in cases {
            build_autoagents_provider(config).expect("provider builds");
        }
    }

    #[test]
    fn model_config_overrides_present_keys() {
        let config = AutoagentsProviderConfig::new("openai").with_model_config(Some(&json!({
            "max_tokens": 256,
            "temperature": 0.5,
            "timeout_seconds": 30,
            "reasoning": true,
            "reasoning_effort": "high",
            "reasoning_budget_tokens": 1024,
            "top_p": 0.25,
            "top_k": 40,
            "normalize_response": false,
            "extra_body": { "seed": 7 },
            "api_version": "2024-02-01",
            "deployment_id": "deployment"
        })));

        assert_eq!(config.max_tokens, Some(256));
        assert_eq!(config.temperature, Some(0.5));
        assert_eq!(config.timeout_seconds, Some(30));
        assert_eq!(config.reasoning, Some(true));
        assert_eq!(config.reasoning_effort.as_deref(), Some("high"));
        assert_eq!(config.reasoning_budget_tokens, Some(1024));
        assert_eq!(config.top_p, Some(0.25));
        assert_eq!(config.top_k, Some(40));
        assert_eq!(config.normalize_response, Some(false));
        assert_eq!(config.extra_body, Some(json!({ "seed": 7 })));
        assert_eq!(config.api_version.as_deref(), Some("2024-02-01"));
        assert_eq!(config.deployment_id.as_deref(), Some("deployment"));
    }

    #[test]
    fn model_config_keeps_existing_values_for_missing_or_invalid_keys() {
        let base = AutoagentsProviderConfig {
            max_tokens: Some(100),
            temperature: Some(0.5),
            top_k: Some(5),
            reasoning_effort: Some("low".to_string()),
            ..AutoagentsProviderConfig::new("openai")
        };
        let too_big = u64::from(u32::MAX) + 1;
        let merged = base.clone().with_model_config(Some(&json!({
            "max_tokens": too_big,
            "temperature": "hot",
            "reasoning_effort": 3
        })));

        assert_eq!(merged, base);
    }

    #[test]
    fn model_config_ignores_missing_or_non_object_config() {
        let base = AutoagentsProviderConfig {
            top_p: Some(0.5),
            ..AutoagentsProviderConfig::new("openai")
        };

        assert_eq!(base.clone().with_model_config(None), base);
        assert_eq!(base.clone().with_model_config(Some(&json!([1, 2, 3]))), base);
    }
}