1use std::{collections::HashMap, sync::Arc, time::Duration};
2
3use crate::{
4 api::ChatProvider,
5 client::{AiClient, AiClientBuilder, Provider},
6 config::ResilienceConfig,
7 metrics::Metrics,
8 provider::{chat_provider::AdapterProvider, config::ProviderConfig, generic::GenericAdapter},
9 transport::DynHttpTransportRef,
10 types::AiLibError,
11};
12
/// Generates a provider-specific builder struct that wraps an `AiClientBuilder`.
///
/// * `$name` - identifier of the generated builder type.
/// * `$provider_variant` - expression passed to `AiClientBuilder::new` when the
///   generated builder is created.
///
/// Every generated method consumes the builder, forwards its arguments to the
/// method of the same name on the wrapped `AiClientBuilder`, and returns the
/// wrapper again so calls can be chained.
macro_rules! define_provider_builder {
    ($name:ident, $provider_variant:expr) => {
        /// Builder bound to a single provider; all configuration is delegated to
        /// the wrapped `AiClientBuilder`.
        pub struct $name {
            inner: AiClientBuilder,
        }

        impl $name {
            /// Creates the builder around `AiClientBuilder::new` for this provider.
            pub fn new() -> Self {
                let inner = AiClientBuilder::new($provider_variant);
                Self { inner }
            }

            /// Forwards to `AiClientBuilder::with_base_url`.
            pub fn with_base_url(self, base_url: &str) -> Self {
                Self {
                    inner: self.inner.with_base_url(base_url),
                }
            }

            /// Forwards to `AiClientBuilder::with_proxy`.
            pub fn with_proxy(self, proxy_url: Option<&str>) -> Self {
                Self {
                    inner: self.inner.with_proxy(proxy_url),
                }
            }

            /// Forwards to `AiClientBuilder::without_proxy`.
            pub fn without_proxy(self) -> Self {
                Self {
                    inner: self.inner.without_proxy(),
                }
            }

            /// Forwards to `AiClientBuilder::with_timeout`.
            pub fn with_timeout(self, timeout: Duration) -> Self {
                Self {
                    inner: self.inner.with_timeout(timeout),
                }
            }

            /// Forwards to `AiClientBuilder::with_pool_config`.
            pub fn with_pool_config(self, max_idle: usize, idle_timeout: Duration) -> Self {
                Self {
                    inner: self.inner.with_pool_config(max_idle, idle_timeout),
                }
            }

            /// Forwards to `AiClientBuilder::with_metrics`.
            pub fn with_metrics(self, metrics: Arc<dyn Metrics>) -> Self {
                Self {
                    inner: self.inner.with_metrics(metrics),
                }
            }

            /// Forwards to `AiClientBuilder::with_default_chat_model`.
            pub fn with_default_chat_model(self, model: &str) -> Self {
                Self {
                    inner: self.inner.with_default_chat_model(model),
                }
            }

            /// Forwards to `AiClientBuilder::with_default_multimodal_model`.
            pub fn with_default_multimodal_model(self, model: &str) -> Self {
                Self {
                    inner: self.inner.with_default_multimodal_model(model),
                }
            }

            /// Forwards to `AiClientBuilder::with_smart_defaults`.
            pub fn with_smart_defaults(self) -> Self {
                Self {
                    inner: self.inner.with_smart_defaults(),
                }
            }

            /// Forwards to `AiClientBuilder::for_production`.
            pub fn for_production(self) -> Self {
                Self {
                    inner: self.inner.for_production(),
                }
            }

            /// Forwards to `AiClientBuilder::for_development`.
            pub fn for_development(self) -> Self {
                Self {
                    inner: self.inner.for_development(),
                }
            }

            /// Forwards to `AiClientBuilder::with_max_concurrency`.
            pub fn with_max_concurrency(self, max: usize) -> Self {
                Self {
                    inner: self.inner.with_max_concurrency(max),
                }
            }

            /// Forwards to `AiClientBuilder::with_resilience_config`.
            pub fn with_resilience_config(self, config: ResilienceConfig) -> Self {
                Self {
                    inner: self.inner.with_resilience_config(config),
                }
            }

            /// Forwards to `AiClientBuilder::with_strategy`.
            pub fn with_strategy(self, strategy: Box<dyn ChatProvider>) -> Self {
                Self {
                    inner: self.inner.with_strategy(strategy),
                }
            }

            /// Consumes the builder and returns whatever `AiClientBuilder::build`
            /// produces.
            ///
            /// # Errors
            ///
            /// Propagates the `AiLibError` returned by the wrapped builder.
            pub fn build(self) -> Result<AiClient, AiLibError> {
                let Self { inner } = self;
                inner.build()
            }

            /// Consumes the builder and returns whatever
            /// `AiClientBuilder::build_provider` produces.
            ///
            /// # Errors
            ///
            /// Propagates the `AiLibError` returned by the wrapped builder.
            pub fn build_provider(self) -> Result<Box<dyn ChatProvider>, AiLibError> {
                let Self { inner } = self;
                inner.build_provider()
            }
        }

        impl Default for $name {
            /// Equivalent to calling `new()`.
            fn default() -> Self {
                Self::new()
            }
        }
    };
}
112
// Concrete builder types, one per `Provider` variant. Each invocation expands
// `define_provider_builder!` into a public struct with the given name whose
// `new()` constructs an `AiClientBuilder` for the paired variant.
define_provider_builder!(GroqBuilder, Provider::Groq);
define_provider_builder!(XaiGrokBuilder, Provider::XaiGrok);
define_provider_builder!(OllamaBuilder, Provider::Ollama);
define_provider_builder!(DeepSeekBuilder, Provider::DeepSeek);
define_provider_builder!(AnthropicBuilder, Provider::Anthropic);
define_provider_builder!(AzureOpenAiBuilder, Provider::AzureOpenAI);
define_provider_builder!(HuggingFaceBuilder, Provider::HuggingFace);
define_provider_builder!(TogetherAiBuilder, Provider::TogetherAI);
define_provider_builder!(OpenRouterBuilder, Provider::OpenRouter);
define_provider_builder!(ReplicateBuilder, Provider::Replicate);
define_provider_builder!(BaiduWenxinBuilder, Provider::BaiduWenxin);
define_provider_builder!(TencentHunyuanBuilder, Provider::TencentHunyuan);
define_provider_builder!(IflytekSparkBuilder, Provider::IflytekSpark);
define_provider_builder!(MoonshotBuilder, Provider::Moonshot);
define_provider_builder!(QwenBuilder, Provider::Qwen);
define_provider_builder!(ZhipuAiBuilder, Provider::ZhipuAI);
define_provider_builder!(MiniMaxBuilder, Provider::MiniMax);
define_provider_builder!(OpenAiBuilder, Provider::OpenAI);
define_provider_builder!(GeminiBuilder, Provider::Gemini);
define_provider_builder!(MistralBuilder, Provider::Mistral);
define_provider_builder!(CohereBuilder, Provider::Cohere);
define_provider_builder!(PerplexityBuilder, Provider::Perplexity);
define_provider_builder!(Ai21Builder, Provider::AI21);
136
/// Builder for a provider that is not covered by the predefined builders.
///
/// The collected settings are turned into a `ProviderConfig` (via
/// `ProviderConfig::openai_compatible`) and a `GenericAdapter` when
/// `build_provider` is called.
pub struct CustomProviderBuilder {
    /// Provider name; handed to `AdapterProvider::new` and used to derive the
    /// default API-key environment variable name when `api_key_env` is unset.
    name: String,
    /// Base URL of the provider. Required: `build_provider` returns a
    /// `ConfigurationError` when this is still `None`.
    base_url: Option<String>,
    /// Name of the API-key environment variable passed to the provider config.
    /// When `None`, `build_provider` derives `<NAME>_API_KEY` from `name`
    /// (ASCII alphanumerics upper-cased, every other character replaced by `_`).
    api_key_env: Option<String>,
    /// Explicit API key forwarded to the `GenericAdapter` constructor.
    /// NOTE(review): presumably takes precedence over the environment variable;
    /// confirm against `GenericAdapter`.
    api_key_override: Option<String>,
    /// Default chat model; `build_provider` falls back to `"gpt-3.5-turbo"`.
    chat_model: Option<String>,
    /// Optional default multimodal model, passed through to the provider config.
    multimodal_model: Option<String>,
    /// Chat endpoint written to the provider config; `new` initialises it to
    /// `"/chat/completions"`.
    chat_endpoint: String,
    /// Upload endpoint written to the provider config; `new` initialises it to
    /// `Some("/v1/files")`.
    upload_endpoint: Option<String>,
    /// Models endpoint written to the provider config; `new` initialises it to
    /// `Some("/models")`.
    models_endpoint: Option<String>,
    /// Extra headers merged into the provider config's headers at build time.
    headers: HashMap<String, String>,
    /// Optional custom transport; when set, the adapter is created with
    /// `GenericAdapter::with_transport_ref_api_key` instead of
    /// `GenericAdapter::new_with_api_key`.
    transport: Option<DynHttpTransportRef>,
}
176
177impl CustomProviderBuilder {
178 pub fn new(name: impl Into<String>) -> Self {
180 Self {
181 name: name.into(),
182 base_url: None,
183 api_key_env: None,
184 api_key_override: None,
185 chat_model: None,
186 multimodal_model: None,
187 chat_endpoint: "/chat/completions".to_string(),
188 upload_endpoint: Some("/v1/files".to_string()),
189 models_endpoint: Some("/models".to_string()),
190 headers: HashMap::new(),
191 transport: None,
192 }
193 }
194
195 pub fn with_base_url(mut self, base_url: &str) -> Self {
197 self.base_url = Some(base_url.to_string());
198 self
199 }
200
201 pub fn with_api_key_env(mut self, env_var: &str) -> Self {
203 self.api_key_env = Some(env_var.to_string());
204 self
205 }
206
207 pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
209 self.api_key_override = Some(api_key.into());
210 self
211 }
212
213 pub fn with_default_chat_model(mut self, model: &str) -> Self {
215 self.chat_model = Some(model.to_string());
216 self
217 }
218
219 pub fn with_default_multimodal_model(mut self, model: &str) -> Self {
221 self.multimodal_model = Some(model.to_string());
222 self
223 }
224
225 pub fn with_chat_endpoint(mut self, endpoint: &str) -> Self {
227 self.chat_endpoint = endpoint.to_string();
228 self
229 }
230
231 pub fn with_upload_endpoint(mut self, endpoint: Option<&str>) -> Self {
233 self.upload_endpoint = endpoint.map(|e| e.to_string());
234 self
235 }
236
237 pub fn with_models_endpoint(mut self, endpoint: Option<&str>) -> Self {
239 self.models_endpoint = endpoint.map(|e| e.to_string());
240 self
241 }
242
243 pub fn with_headers<I, K, V>(mut self, headers: I) -> Self
245 where
246 I: IntoIterator<Item = (K, V)>,
247 K: Into<String>,
248 V: Into<String>,
249 {
250 for (k, v) in headers {
251 self.headers.insert(k.into(), v.into());
252 }
253 self
254 }
255
256 pub fn with_transport(mut self, transport: DynHttpTransportRef) -> Self {
258 self.transport = Some(transport);
259 self
260 }
261
262 pub fn build_provider(self) -> Result<Box<dyn ChatProvider>, AiLibError> {
264 let base_url = self.base_url.ok_or_else(|| {
265 AiLibError::ConfigurationError(
266 "CustomProviderBuilder requires `with_base_url` to be set".to_string(),
267 )
268 })?;
269
270 let chat_model = self
271 .chat_model
272 .unwrap_or_else(|| "gpt-3.5-turbo".to_string());
273 let env_key = self.api_key_env.unwrap_or_else(|| {
274 let upper = self
275 .name
276 .chars()
277 .map(|c| {
278 if c.is_ascii_alphanumeric() {
279 c.to_ascii_uppercase()
280 } else {
281 '_'
282 }
283 })
284 .collect::<String>();
285 format!("{upper}_API_KEY")
286 });
287
288 let mut config = ProviderConfig::openai_compatible(
289 &base_url,
290 &env_key,
291 &chat_model,
292 self.multimodal_model.as_deref(),
293 );
294 config.chat_endpoint = self.chat_endpoint;
295 config.upload_endpoint = self.upload_endpoint;
296 config.models_endpoint = self.models_endpoint;
297 config.headers.extend(self.headers);
298
299 let adapter = match (self.transport, self.api_key_override) {
300 (Some(transport), api_key) => {
301 GenericAdapter::with_transport_ref_api_key(config, transport, api_key)?
302 }
303 (None, api_key) => GenericAdapter::new_with_api_key(config, api_key)?,
304 };
305
306 Ok(AdapterProvider::new(self.name, Box::new(adapter)).boxed())
307 }
308}