// ai_lib/provider/builders.rs

use std::{collections::HashMap, sync::Arc, time::Duration};

use crate::{
    api::ChatProvider,
    client::{AiClient, AiClientBuilder, Provider},
    config::ResilienceConfig,
    metrics::Metrics,
    provider::{chat_provider::AdapterProvider, config::ProviderConfig, generic::GenericAdapter},
    transport::DynHttpTransportRef,
    types::AiLibError,
};

/// Generates a thin, provider-specific wrapper around `AiClientBuilder`.
///
/// `define_provider_builder!(Name, provider)` expands to:
///
/// - a public struct `Name` whose only field is an `AiClientBuilder`;
/// - a `Default` impl that calls `Name::new()`;
/// - `Name::new()`, which creates the inner builder via `AiClientBuilder::new(provider)`;
/// - one by-value method per configuration option; each passes its arguments unchanged to
///   the `AiClientBuilder` method of the same name and keeps the builder it returns;
/// - `build` / `build_provider`, which consume the wrapper and return whatever the inner
///   builder returns.
///
/// The wrapper adds no logic of its own; the meaning of each option is defined by
/// `AiClientBuilder`.
macro_rules! define_provider_builder {
    ($name:ident, $provider_variant:expr) => {
        /// Provider-specific builder generated by `define_provider_builder!`.
        ///
        /// Every method forwards to the wrapped `AiClientBuilder`.
        #[must_use = "builder methods return the updated builder; finish with `build` or `build_provider`"]
        pub struct $name {
            inner: AiClientBuilder,
        }

        impl Default for $name {
            /// Same as calling `new()`.
            fn default() -> Self {
                Self::new()
            }
        }

        impl $name {
            /// Creates the wrapper around `AiClientBuilder::new` for this builder's provider.
            pub fn new() -> Self {
                Self {
                    inner: AiClientBuilder::new($provider_variant),
                }
            }

            /// Forwards to `AiClientBuilder::with_base_url`.
            pub fn with_base_url(self, base_url: &str) -> Self {
                Self {
                    inner: self.inner.with_base_url(base_url),
                }
            }

            /// Forwards to `AiClientBuilder::with_proxy`.
            pub fn with_proxy(self, proxy_url: Option<&str>) -> Self {
                Self {
                    inner: self.inner.with_proxy(proxy_url),
                }
            }

            /// Forwards to `AiClientBuilder::without_proxy`.
            pub fn without_proxy(self) -> Self {
                Self {
                    inner: self.inner.without_proxy(),
                }
            }

            /// Forwards to `AiClientBuilder::with_timeout`.
            pub fn with_timeout(self, timeout: Duration) -> Self {
                Self {
                    inner: self.inner.with_timeout(timeout),
                }
            }

            /// Forwards to `AiClientBuilder::with_pool_config`.
            pub fn with_pool_config(self, max_idle: usize, idle_timeout: Duration) -> Self {
                Self {
                    inner: self.inner.with_pool_config(max_idle, idle_timeout),
                }
            }

            /// Forwards to `AiClientBuilder::with_metrics`.
            pub fn with_metrics(self, metrics: Arc<dyn Metrics>) -> Self {
                Self {
                    inner: self.inner.with_metrics(metrics),
                }
            }

            /// Forwards to `AiClientBuilder::with_default_chat_model`.
            pub fn with_default_chat_model(self, model: &str) -> Self {
                Self {
                    inner: self.inner.with_default_chat_model(model),
                }
            }

            /// Forwards to `AiClientBuilder::with_default_multimodal_model`.
            pub fn with_default_multimodal_model(self, model: &str) -> Self {
                Self {
                    inner: self.inner.with_default_multimodal_model(model),
                }
            }

            /// Forwards to `AiClientBuilder::with_smart_defaults`.
            pub fn with_smart_defaults(self) -> Self {
                Self {
                    inner: self.inner.with_smart_defaults(),
                }
            }

            /// Forwards to `AiClientBuilder::for_production`.
            pub fn for_production(self) -> Self {
                Self {
                    inner: self.inner.for_production(),
                }
            }

            /// Forwards to `AiClientBuilder::for_development`.
            pub fn for_development(self) -> Self {
                Self {
                    inner: self.inner.for_development(),
                }
            }

            /// Forwards to `AiClientBuilder::with_max_concurrency`.
            pub fn with_max_concurrency(self, max: usize) -> Self {
                Self {
                    inner: self.inner.with_max_concurrency(max),
                }
            }

            /// Forwards to `AiClientBuilder::with_resilience_config`.
            pub fn with_resilience_config(self, config: ResilienceConfig) -> Self {
                Self {
                    inner: self.inner.with_resilience_config(config),
                }
            }

            /// Forwards to `AiClientBuilder::with_strategy`.
            pub fn with_strategy(self, strategy: Box<dyn ChatProvider>) -> Self {
                Self {
                    inner: self.inner.with_strategy(strategy),
                }
            }

            /// Consumes the wrapper and returns the result of `AiClientBuilder::build`.
            ///
            /// # Errors
            ///
            /// Returns whatever error the inner builder's `build` reports.
            pub fn build(self) -> Result<AiClient, AiLibError> {
                self.inner.build()
            }

            /// Consumes the wrapper and returns the result of
            /// `AiClientBuilder::build_provider`.
            ///
            /// # Errors
            ///
            /// Returns whatever error the inner builder's `build_provider` reports.
            pub fn build_provider(self) -> Result<Box<dyn ChatProvider>, AiLibError> {
                self.inner.build_provider()
            }
        }
    };
}

113define_provider_builder!(GroqBuilder, Provider::Groq);
114define_provider_builder!(XaiGrokBuilder, Provider::XaiGrok);
115define_provider_builder!(OllamaBuilder, Provider::Ollama);
116define_provider_builder!(DeepSeekBuilder, Provider::DeepSeek);
117define_provider_builder!(AnthropicBuilder, Provider::Anthropic);
118define_provider_builder!(AzureOpenAiBuilder, Provider::AzureOpenAI);
119define_provider_builder!(HuggingFaceBuilder, Provider::HuggingFace);
120define_provider_builder!(TogetherAiBuilder, Provider::TogetherAI);
121define_provider_builder!(OpenRouterBuilder, Provider::OpenRouter);
122define_provider_builder!(ReplicateBuilder, Provider::Replicate);
123define_provider_builder!(BaiduWenxinBuilder, Provider::BaiduWenxin);
124define_provider_builder!(TencentHunyuanBuilder, Provider::TencentHunyuan);
125define_provider_builder!(IflytekSparkBuilder, Provider::IflytekSpark);
126define_provider_builder!(MoonshotBuilder, Provider::Moonshot);
127define_provider_builder!(QwenBuilder, Provider::Qwen);
128define_provider_builder!(ZhipuAiBuilder, Provider::ZhipuAI);
129define_provider_builder!(MiniMaxBuilder, Provider::MiniMax);
130define_provider_builder!(OpenAiBuilder, Provider::OpenAI);
131define_provider_builder!(GeminiBuilder, Provider::Gemini);
132define_provider_builder!(MistralBuilder, Provider::Mistral);
133define_provider_builder!(CohereBuilder, Provider::Cohere);
134define_provider_builder!(PerplexityBuilder, Provider::Perplexity);
135define_provider_builder!(Ai21Builder, Provider::AI21);
136
137/// Builder for OpenAI-compatible custom providers without editing the `Provider` enum.
138///
139/// This builder allows you to create a `ChatProvider` for any service that exposes
140/// an OpenAI-compatible API, even if it's not natively supported by the library.
141///
142/// # Example
143///
144/// ```rust
145/// # use ai_lib::provider::builders::CustomProviderBuilder;
146/// # use ai_lib::AiClientBuilder;
147/// # use ai_lib::Provider;
148/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
149/// // Create a provider for a hypothetical service "MyService"
150/// let my_provider = CustomProviderBuilder::new("MyService")
151///     .with_base_url("https://api.myservice.com/v1")
152///     .with_api_key_env("MY_SERVICE_API_KEY")
153///     .with_default_chat_model("my-model-v1")
154///     .build_provider()?;
155///
156/// // Inject it into the client
157/// let client = AiClientBuilder::new(Provider::OpenAI) // Enum ignored
158///     .with_strategy(my_provider)
159///     .build()?;
160/// # Ok(())
161/// # }
162/// ```
163pub struct CustomProviderBuilder {
164    name: String,
165    base_url: Option<String>,
166    api_key_env: Option<String>,
167    api_key_override: Option<String>,
168    chat_model: Option<String>,
169    multimodal_model: Option<String>,
170    chat_endpoint: String,
171    upload_endpoint: Option<String>,
172    models_endpoint: Option<String>,
173    headers: HashMap<String, String>,
174    transport: Option<DynHttpTransportRef>,
175}
176
177impl CustomProviderBuilder {
178    /// Create a new builder with the human-readable provider name.
179    pub fn new(name: impl Into<String>) -> Self {
180        Self {
181            name: name.into(),
182            base_url: None,
183            api_key_env: None,
184            api_key_override: None,
185            chat_model: None,
186            multimodal_model: None,
187            chat_endpoint: "/chat/completions".to_string(),
188            upload_endpoint: Some("/v1/files".to_string()),
189            models_endpoint: Some("/models".to_string()),
190            headers: HashMap::new(),
191            transport: None,
192        }
193    }
194
195    /// Set the base URL (required) for the custom provider.
196    pub fn with_base_url(mut self, base_url: &str) -> Self {
197        self.base_url = Some(base_url.to_string());
198        self
199    }
200
201    /// Override the environment variable used to fetch the API key at runtime.
202    pub fn with_api_key_env(mut self, env_var: &str) -> Self {
203        self.api_key_env = Some(env_var.to_string());
204        self
205    }
206
207    /// Inject a literal API key instead of relying on environment variables.
208    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
209        self.api_key_override = Some(api_key.into());
210        self
211    }
212
213    /// Override the default chat model used for simple helpers.
214    pub fn with_default_chat_model(mut self, model: &str) -> Self {
215        self.chat_model = Some(model.to_string());
216        self
217    }
218
219    /// Override the default multimodal model (optional).
220    pub fn with_default_multimodal_model(mut self, model: &str) -> Self {
221        self.multimodal_model = Some(model.to_string());
222        self
223    }
224
225    /// Override the chat completion endpoint (default: `/chat/completions`).
226    pub fn with_chat_endpoint(mut self, endpoint: &str) -> Self {
227        self.chat_endpoint = endpoint.to_string();
228        self
229    }
230
231    /// Override the upload endpoint (default: `/v1/files`).
232    pub fn with_upload_endpoint(mut self, endpoint: Option<&str>) -> Self {
233        self.upload_endpoint = endpoint.map(|e| e.to_string());
234        self
235    }
236
237    /// Override the models endpoint (default: `/models`).
238    pub fn with_models_endpoint(mut self, endpoint: Option<&str>) -> Self {
239        self.models_endpoint = endpoint.map(|e| e.to_string());
240        self
241    }
242
243    /// Merge custom HTTP headers (e.g., vendor-specific auth scopes).
244    pub fn with_headers<I, K, V>(mut self, headers: I) -> Self
245    where
246        I: IntoIterator<Item = (K, V)>,
247        K: Into<String>,
248        V: Into<String>,
249    {
250        for (k, v) in headers {
251            self.headers.insert(k.into(), v.into());
252        }
253        self
254    }
255
256    /// Provide a pre-built transport (shared client, proxy, custom TLS, etc.).
257    pub fn with_transport(mut self, transport: DynHttpTransportRef) -> Self {
258        self.transport = Some(transport);
259        self
260    }
261
262    /// Build a boxed `ChatProvider` that can be passed to `AiClientBuilder::with_strategy`.
263    pub fn build_provider(self) -> Result<Box<dyn ChatProvider>, AiLibError> {
264        let base_url = self.base_url.ok_or_else(|| {
265            AiLibError::ConfigurationError(
266                "CustomProviderBuilder requires `with_base_url` to be set".to_string(),
267            )
268        })?;
269
270        let chat_model = self
271            .chat_model
272            .unwrap_or_else(|| "gpt-3.5-turbo".to_string());
273        let env_key = self.api_key_env.unwrap_or_else(|| {
274            let upper = self
275                .name
276                .chars()
277                .map(|c| {
278                    if c.is_ascii_alphanumeric() {
279                        c.to_ascii_uppercase()
280                    } else {
281                        '_'
282                    }
283                })
284                .collect::<String>();
285            format!("{upper}_API_KEY")
286        });
287
288        let mut config = ProviderConfig::openai_compatible(
289            &base_url,
290            &env_key,
291            &chat_model,
292            self.multimodal_model.as_deref(),
293        );
294        config.chat_endpoint = self.chat_endpoint;
295        config.upload_endpoint = self.upload_endpoint;
296        config.models_endpoint = self.models_endpoint;
297        config.headers.extend(self.headers);
298
299        let adapter = match (self.transport, self.api_key_override) {
300            (Some(transport), api_key) => {
301                GenericAdapter::with_transport_ref_api_key(config, transport, api_key)?
302            }
303            (None, api_key) => GenericAdapter::new_with_api_key(config, api_key)?,
304        };
305
306        Ok(AdapterProvider::new(self.name, Box::new(adapter)).boxed())
307    }
308}