Skip to main content

fastllm_core/
autoagents.rs

1use crate::{LlmGatewayError, LlmProvider};
2use autoagents_llm::builder::LLMBuilder;
3use autoagents_llm::chat::ReasoningEffort;
4use serde_json::Value;
5use std::sync::Arc;
6
7#[derive(Debug, Clone, Default, PartialEq)]
8pub struct AutoagentsProviderConfig {
9    pub provider: String,
10    pub model: Option<String>,
11    pub api_key: Option<String>,
12    pub base_url: Option<String>,
13    pub max_tokens: Option<u32>,
14    pub temperature: Option<f32>,
15    pub timeout_seconds: Option<u64>,
16    pub reasoning: Option<bool>,
17    pub reasoning_effort: Option<String>,
18    pub reasoning_budget_tokens: Option<u32>,
19    pub top_p: Option<f32>,
20    pub top_k: Option<u32>,
21    pub normalize_response: Option<bool>,
22    pub extra_body: Option<Value>,
23    pub api_version: Option<String>,
24    pub deployment_id: Option<String>,
25}
26
27impl AutoagentsProviderConfig {
28    pub fn new(provider: impl Into<String>) -> Self {
29        Self {
30            provider: provider.into(),
31            ..Self::default()
32        }
33    }
34
35    pub fn from_env(provider: impl Into<String>, model: impl Into<String>) -> Self {
36        let provider = provider.into();
37        let env_prefix = provider.to_ascii_uppercase().replace(['-', '.'], "_");
38        Self {
39            provider,
40            model: Some(model.into()),
41            api_key: std::env::var(format!("{env_prefix}_API_KEY")).ok(),
42            base_url: std::env::var(format!("{env_prefix}_BASE_URL")).ok(),
43            api_version: std::env::var(format!("{env_prefix}_API_VERSION")).ok(),
44            deployment_id: std::env::var(format!("{env_prefix}_DEPLOYMENT_ID")).ok(),
45            ..Self::default()
46        }
47    }
48
49    pub fn with_model_config(mut self, config: Option<&Value>) -> Self {
50        let Some(config) = config.and_then(Value::as_object) else {
51            return self;
52        };
53
54        self.max_tokens = config
55            .get("max_tokens")
56            .and_then(Value::as_u64)
57            .and_then(|value| u32::try_from(value).ok())
58            .or(self.max_tokens);
59        self.temperature = config
60            .get("temperature")
61            .and_then(Value::as_f64)
62            .map(|value| value as f32)
63            .or(self.temperature);
64        self.timeout_seconds = config
65            .get("timeout_seconds")
66            .and_then(Value::as_u64)
67            .or(self.timeout_seconds);
68        self.reasoning = config
69            .get("reasoning")
70            .and_then(Value::as_bool)
71            .or(self.reasoning);
72        self.reasoning_effort = config
73            .get("reasoning_effort")
74            .and_then(Value::as_str)
75            .map(str::to_string)
76            .or(self.reasoning_effort);
77        self.reasoning_budget_tokens = config
78            .get("reasoning_budget_tokens")
79            .and_then(Value::as_u64)
80            .and_then(|value| u32::try_from(value).ok())
81            .or(self.reasoning_budget_tokens);
82        self.top_p = config
83            .get("top_p")
84            .and_then(Value::as_f64)
85            .map(|value| value as f32)
86            .or(self.top_p);
87        self.top_k = config
88            .get("top_k")
89            .and_then(Value::as_u64)
90            .and_then(|value| u32::try_from(value).ok())
91            .or(self.top_k);
92        self.normalize_response = config
93            .get("normalize_response")
94            .and_then(Value::as_bool)
95            .or(self.normalize_response);
96        self.extra_body = config.get("extra_body").cloned().or(self.extra_body);
97        self.api_version = config
98            .get("api_version")
99            .and_then(Value::as_str)
100            .map(str::to_string)
101            .or(self.api_version);
102        self.deployment_id = config
103            .get("deployment_id")
104            .and_then(Value::as_str)
105            .map(str::to_string)
106            .or(self.deployment_id);
107        self
108    }
109}
110
111pub fn build_autoagents_provider(
112    config: AutoagentsProviderConfig,
113) -> Result<Arc<dyn LlmProvider>, LlmGatewayError> {
114    match config.provider.as_str() {
115        "openai" => build::<autoagents_llm::backends::openai::OpenAI>(config),
116        "anthropic" => build::<autoagents_llm::backends::anthropic::Anthropic>(config),
117        "ollama" => build::<autoagents_llm::backends::ollama::Ollama>(config),
118        "deepseek" => build::<autoagents_llm::backends::deepseek::DeepSeek>(config),
119        "xai" => build::<autoagents_llm::backends::xai::XAI>(config),
120        "phind" => build::<autoagents_llm::backends::phind::Phind>(config),
121        "google" => build::<autoagents_llm::backends::google::Google>(config),
122        "groq" => build::<autoagents_llm::backends::groq::Groq>(config),
123        "azure-openai" => build::<autoagents_llm::backends::azure_openai::AzureOpenAI>(config),
124        "openrouter" => build::<autoagents_llm::backends::openrouter::OpenRouter>(config),
125        "minimax" => build::<autoagents_llm::backends::minimax::MiniMax>(config),
126        other => Err(LlmGatewayError::UnknownProvider(other.to_string())),
127    }
128}
129
130fn build<T>(config: AutoagentsProviderConfig) -> Result<Arc<dyn LlmProvider>, LlmGatewayError>
131where
132    T: autoagents_llm::LLMProvider + autoagents_llm::HasConfig,
133    LLMBuilder<T>: BuildAutoagentsProvider<T>,
134{
135    let provider_name = config.provider.clone();
136    BuildAutoagentsProvider::build_provider(apply_common::<T>(config)).map_err(|source| {
137        LlmGatewayError::Provider {
138            provider: provider_name,
139            message: source.to_string(),
140        }
141    })
142}
143
144fn apply_common<T>(config: AutoagentsProviderConfig) -> LLMBuilder<T>
145where
146    T: autoagents_llm::LLMProvider + autoagents_llm::HasConfig,
147{
148    let mut builder = LLMBuilder::<T>::new();
149    if let Some(api_key) = config.api_key {
150        builder = builder.api_key(api_key);
151    }
152    if let Some(base_url) = config.base_url {
153        builder = builder.base_url(base_url);
154    }
155    if let Some(model) = config.model {
156        builder = builder.model(model);
157    }
158    if let Some(max_tokens) = config.max_tokens {
159        builder = builder.max_tokens(max_tokens);
160    }
161    if let Some(temperature) = config.temperature {
162        builder = builder.temperature(temperature);
163    }
164    if let Some(timeout_seconds) = config.timeout_seconds {
165        builder = builder.timeout_seconds(timeout_seconds);
166    }
167    if let Some(reasoning) = config.reasoning {
168        builder = builder.reasoning(reasoning);
169    }
170    if let Some(reasoning_effort) = config.reasoning_effort {
171        builder = match reasoning_effort.as_str() {
172            "low" => builder.reasoning_effort(ReasoningEffort::Low),
173            "medium" => builder.reasoning_effort(ReasoningEffort::Medium),
174            "high" => builder.reasoning_effort(ReasoningEffort::High),
175            _ => builder,
176        };
177    }
178    if let Some(reasoning_budget_tokens) = config.reasoning_budget_tokens {
179        builder = builder.reasoning_budget_tokens(reasoning_budget_tokens);
180    }
181    if let Some(top_p) = config.top_p {
182        builder = builder.top_p(top_p);
183    }
184    if let Some(top_k) = config.top_k {
185        builder = builder.top_k(top_k);
186    }
187    if let Some(normalize_response) = config.normalize_response {
188        builder = builder.normalize_response(normalize_response);
189    }
190    if let Some(extra_body) = config.extra_body {
191        builder = builder.extra_body(extra_body);
192    }
193    if let Some(api_version) = config.api_version {
194        builder = builder.api_version(api_version);
195    }
196    if let Some(deployment_id) = config.deployment_id {
197        builder = builder.deployment_id(deployment_id);
198    }
199    builder
200}
201
202pub trait BuildAutoagentsProvider<T> {
203    fn build_provider(self) -> Result<Arc<dyn LlmProvider>, autoagents_llm::error::LLMError>;
204}
205
/// Implements `BuildAutoagentsProvider<B>` for `LLMBuilder<B>` for each listed backend type
/// `B`, delegating to that builder's own `build()`.
macro_rules! impl_build_provider {
    ($($backend:path),+ $(,)?) => {
        $(
            impl BuildAutoagentsProvider<$backend> for LLMBuilder<$backend> {
                fn build_provider(
                    self,
                ) -> Result<Arc<dyn LlmProvider>, autoagents_llm::error::LLMError> {
                    // `?` converts the builder's error into `LLMError`. The success value
                    // must coerce to `Arc<dyn LlmProvider>` when wrapped in `Ok`.
                    let provider = self.build()?;
                    Ok(provider)
                }
            }
        )+
    };
}
217
218impl_build_provider!(
219    autoagents_llm::backends::openai::OpenAI,
220    autoagents_llm::backends::anthropic::Anthropic,
221    autoagents_llm::backends::ollama::Ollama,
222    autoagents_llm::backends::deepseek::DeepSeek,
223    autoagents_llm::backends::xai::XAI,
224    autoagents_llm::backends::phind::Phind,
225    autoagents_llm::backends::google::Google,
226    autoagents_llm::backends::groq::Groq,
227    autoagents_llm::backends::azure_openai::AzureOpenAI,
228    autoagents_llm::backends::openrouter::OpenRouter,
229    autoagents_llm::backends::minimax::MiniMax,
230);
231
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a test config with a model and optional API key; other fields are defaulted.
    fn case_config(provider: &str, model: &str, api_key: Option<&str>) -> AutoagentsProviderConfig {
        AutoagentsProviderConfig {
            provider: provider.to_string(),
            model: Some(model.to_string()),
            api_key: api_key.map(str::to_string),
            ..AutoagentsProviderConfig::default()
        }
    }

    #[test]
    fn unsupported_provider_is_error() {
        // `.err()` avoids requiring `Debug` on the `Arc<dyn LlmProvider>` success type.
        let err = build_autoagents_provider(AutoagentsProviderConfig::new("missing"))
            .err()
            .expect("unknown provider should fail");

        assert!(matches!(err, LlmGatewayError::UnknownProvider(_)));
    }

    #[test]
    fn provider_build_errors_are_preserved() {
        let err = build_autoagents_provider(AutoagentsProviderConfig::new("openai"))
            .err()
            .expect("missing key should fail");

        assert!(err.to_string().contains("OpenAI"));
    }

    #[test]
    fn configured_autoagents_providers_build_without_network() {
        let cases = [
            case_config("openai", "gpt-test", Some("test")),
            case_config("anthropic", "claude-test", Some("test")),
            AutoagentsProviderConfig {
                base_url: Some("http://localhost:11434".to_string()),
                ..case_config("ollama", "llama-test", None)
            },
            case_config("deepseek", "deepseek-test", Some("test")),
            case_config("xai", "grok-test", Some("test")),
            case_config("phind", "phind-test", None),
            case_config("google", "gemini-test", Some("test")),
            case_config("groq", "llama-test", Some("test")),
            AutoagentsProviderConfig {
                base_url: Some("https://example.test".to_string()),
                api_version: Some("2024-02-01".to_string()),
                deployment_id: Some("deployment".to_string()),
                ..case_config("azure-openai", "gpt-test", Some("test"))
            },
            case_config("openrouter", "openrouter-test", Some("test")),
            case_config("minimax", "minimax-test", Some("test")),
        ];

        for case in cases {
            // Capture the name first: the config is consumed by the build call, and the
            // failure message should say which backend broke.
            let provider = case.provider.clone();
            if let Err(err) = build_autoagents_provider(case) {
                panic!("provider `{provider}` failed to build: {err:?}");
            }
        }
    }
}