ai_lib/client/
builder.rs

1// AI Client Builder Module
2//
3// This module contains the AiClientBuilder implementation with progressive
4// custom configuration options for creating AiClient instances.
5
6use super::metadata::metadata_from_provider;
7use super::{AiClient, Provider, ProviderFactory};
8use crate::api::ChatProvider;
9use crate::config::{ConnectionOptions, ResilienceConfig};
10use crate::metrics::{Metrics, NoopMetrics};
11use crate::model::ModelResolver;
12use crate::provider::classification::ProviderClassification;
13use crate::provider::strategies::{FailoverProvider, RoundRobinProvider, RoutingStrategyBuilder};
14use crate::rate_limiter::BackpressureController;
15use crate::types::AiLibError;
16use std::sync::Arc;
17
18/// AI client builder with progressive custom configuration
19pub struct AiClientBuilder {
20    provider: Provider,
21    base_url: Option<String>,
22    proxy_url: Option<String>,
23    timeout: Option<std::time::Duration>,
24    pool_max_idle: Option<usize>,
25    pool_idle_timeout: Option<std::time::Duration>,
26    metrics: Option<Arc<dyn Metrics>>,
27    // Model configuration options
28    default_chat_model: Option<String>,
29    default_multimodal_model: Option<String>,
30    // Resilience configuration
31    resilience_config: ResilienceConfig,
32    #[cfg(feature = "interceptors")]
33    interceptor_pipeline: Option<crate::interceptors::InterceptorPipeline>,
34    #[cfg(feature = "interceptors")]
35    interceptor_builder: Option<crate::interceptors::default::DefaultInterceptorsBuilder>,
36    // Optional pre-built provider strategy (allows building provider directly)
37    strategy: Option<Box<dyn ChatProvider>>,
38    model_resolver: Option<Arc<ModelResolver>>,
39}
40
41impl AiClientBuilder {
42    /// Create a new builder instance
43    pub fn new(provider: Provider) -> Self {
44        Self {
45            provider,
46            base_url: None,
47            proxy_url: None,
48            timeout: None,
49            pool_max_idle: None,
50            pool_idle_timeout: None,
51            metrics: None,
52            default_chat_model: None,
53            default_multimodal_model: None,
54            resilience_config: ResilienceConfig::default(),
55            #[cfg(feature = "interceptors")]
56            interceptor_pipeline: None,
57            #[cfg(feature = "interceptors")]
58            interceptor_builder: None,
59            strategy: None,
60            model_resolver: None,
61        }
62    }
63
64    /// Set custom base URL
65    pub fn with_base_url(mut self, base_url: &str) -> Self {
66        self.base_url = Some(base_url.to_string());
67        self
68    }
69
70    /// Set custom proxy URL
71    pub fn with_proxy(mut self, proxy_url: Option<&str>) -> Self {
72        self.proxy_url = proxy_url.map(|s| s.to_string());
73        self
74    }
75
76    /// Explicitly disable proxy usage
77    pub fn without_proxy(mut self) -> Self {
78        self.proxy_url = Some("".to_string());
79        self
80    }
81
82    /// Set custom timeout duration
83    pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
84        self.timeout = Some(timeout);
85        self
86    }
87
88    /// Set connection pool configuration
89    pub fn with_pool_config(mut self, max_idle: usize, idle_timeout: std::time::Duration) -> Self {
90        self.pool_max_idle = Some(max_idle);
91        self.pool_idle_timeout = Some(idle_timeout);
92        self
93    }
94
95    /// Set custom metrics implementation
96    pub fn with_metrics(mut self, metrics: Arc<dyn Metrics>) -> Self {
97        self.metrics = Some(metrics);
98        self
99    }
100
101    #[cfg(feature = "interceptors")]
102    pub fn with_interceptor_pipeline(
103        mut self,
104        pipeline: crate::interceptors::InterceptorPipeline,
105    ) -> Self {
106        self.interceptor_pipeline = Some(pipeline);
107        self
108    }
109
110    #[cfg(feature = "interceptors")]
111    pub fn enable_default_interceptors(mut self) -> Self {
112        let p = crate::interceptors::create_default_interceptors();
113        self.interceptor_pipeline = Some(p);
114        self
115    }
116
117    #[cfg(feature = "interceptors")]
118    pub fn enable_minimal_interceptors(mut self) -> Self {
119        let p = crate::interceptors::default::DefaultInterceptorsBuilder::new()
120            .enable_circuit_breaker(false)
121            .enable_rate_limit(false)
122            .build();
123        self.interceptor_pipeline = Some(p);
124        self
125    }
126
127    // --- Interceptor Configuration Proxies ---
128
129    #[cfg(feature = "interceptors")]
130    pub fn with_rate_limit(mut self, requests_per_minute: u32) -> Self {
131        let builder = self.interceptor_builder.unwrap_or_default();
132        self.interceptor_builder = Some(builder.with_rate_limit(requests_per_minute));
133        self
134    }
135
136    #[cfg(feature = "interceptors")]
137    pub fn with_circuit_breaker(mut self, threshold: u32, recovery: std::time::Duration) -> Self {
138        let builder = self.interceptor_builder.unwrap_or_default();
139        self.interceptor_builder = Some(builder.with_circuit_breaker(threshold, recovery));
140        self
141    }
142
143    #[cfg(feature = "interceptors")]
144    pub fn with_retry(
145        mut self,
146        max_attempts: u32,
147        base_delay: std::time::Duration,
148        max_delay: std::time::Duration,
149    ) -> Self {
150        let builder = self.interceptor_builder.unwrap_or_default();
151        self.interceptor_builder = Some(builder.with_retry(max_attempts, base_delay, max_delay));
152        self
153    }
154
155    #[cfg(feature = "interceptors")]
156    pub fn with_interceptor_timeout(mut self, duration: std::time::Duration) -> Self {
157        let builder = self.interceptor_builder.unwrap_or_default();
158        self.interceptor_builder = Some(builder.with_timeout(duration));
159        self
160    }
161
162    #[cfg(feature = "interceptors")]
163    pub fn enable_retry(mut self, enable: bool) -> Self {
164        let builder = self.interceptor_builder.unwrap_or_default();
165        self.interceptor_builder = Some(builder.enable_retry(enable));
166        self
167    }
168
169    #[cfg(feature = "interceptors")]
170    pub fn enable_circuit_breaker(mut self, enable: bool) -> Self {
171        let builder = self.interceptor_builder.unwrap_or_default();
172        self.interceptor_builder = Some(builder.enable_circuit_breaker(enable));
173        self
174    }
175
176    #[cfg(feature = "interceptors")]
177    pub fn enable_rate_limit(mut self, enable: bool) -> Self {
178        let builder = self.interceptor_builder.unwrap_or_default();
179        self.interceptor_builder = Some(builder.enable_rate_limit(enable));
180        self
181    }
182
183    #[cfg(feature = "interceptors")]
184    pub fn enable_interceptor_timeout(mut self, enable: bool) -> Self {
185        let builder = self.interceptor_builder.unwrap_or_default();
186        self.interceptor_builder = Some(builder.enable_timeout(enable));
187        self
188    }
189
190    /// Set default chat model for the client
191    pub fn with_default_chat_model(mut self, model: &str) -> Self {
192        self.default_chat_model = Some(model.to_string());
193        self
194    }
195
196    /// Set default multimodal model for the client
197    pub fn with_default_multimodal_model(mut self, model: &str) -> Self {
198        self.default_multimodal_model = Some(model.to_string());
199        self
200    }
201
202    /// Inject a custom model resolver (advanced usage).
203    pub fn with_model_resolver(mut self, resolver: Arc<ModelResolver>) -> Self {
204        self.model_resolver = Some(resolver);
205        self
206    }
207
208    /// Enable smart defaults for resilience features
209    pub fn with_smart_defaults(mut self) -> Self {
210        self.resilience_config = ResilienceConfig::smart_defaults();
211        self
212    }
213
214    /// Configure for production environment
215    pub fn for_production(mut self) -> Self {
216        self.resilience_config = ResilienceConfig::production();
217        self
218    }
219
220    /// Configure for development environment
221    pub fn for_development(mut self) -> Self {
222        self.resilience_config = ResilienceConfig::development();
223        self
224    }
225
226    /// Build and return a boxed `ChatProvider` according to the current builder configuration.
227    /// If a custom `strategy` was provided via `with_strategy`, it will be returned directly.
228    pub fn build_provider(mut self) -> Result<Box<dyn ChatProvider>, AiLibError> {
229        // If caller supplied a custom strategy, use it directly
230        if let Some(p) = self.strategy.take() {
231            return Ok(p);
232        }
233
234        // Otherwise, construct provider according to builder configuration
235        let base_url = self.determine_base_url()?;
236        let proxy_url = self.determine_proxy_url();
237        let timeout = self
238            .timeout
239            .unwrap_or_else(|| std::time::Duration::from_secs(30));
240        let transport = self.create_custom_transport(proxy_url, timeout)?;
241
242        ProviderFactory::create_adapter(self.provider, None, Some(base_url), transport)
243    }
244
245    /// Configure a simple max concurrent requests backpressure guard
246    pub fn with_max_concurrency(mut self, max_concurrent_requests: usize) -> Self {
247        let mut cfg = self.resilience_config.clone();
248        cfg.backpressure = Some(crate::config::BackpressureConfig {
249            max_concurrent_requests,
250        });
251        self.resilience_config = cfg;
252        self
253    }
254
255    /// Set custom resilience configuration
256    pub fn with_resilience_config(mut self, config: ResilienceConfig) -> Self {
257        self.resilience_config = config;
258        self
259    }
260
261    /// Provide a custom provider strategy (boxed ChatProvider)
262    ///
263    /// This allows injecting a fully custom implementation of `ChatProvider`,
264    /// bypassing the standard provider factory logic.
265    ///
266    /// # Example
267    ///
268    /// ```rust
269    /// # use ai_lib::{AiClientBuilder, Provider};
270    /// # use ai_lib::provider::strategies::RoundRobinProvider;
271    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
272    /// // Create a custom strategy (e.g., manually built RoundRobin)
273    /// let strategy = RoundRobinProvider::new(vec![])?;
274    ///
275    /// let client = AiClientBuilder::new(Provider::OpenAI) // Provider enum ignored here
276    ///     .with_strategy(Box::new(strategy))
277    ///     .build()?;
278    /// # Ok(())
279    /// # }
280    /// ```
281    pub fn with_strategy(mut self, strategy: Box<dyn ChatProvider>) -> Self {
282        self.strategy = Some(strategy);
283        self
284    }
285
286    /// Compose a round-robin strategy from the provided providers.
287    ///
288    /// This method takes a collection of boxed `ChatProvider` instances and wraps them
289    /// in a `RoundRobinProvider`.
290    ///
291    /// # Example
292    ///
293    /// ```rust
294    /// # use ai_lib::{AiClientBuilder, Provider, AiClient};
295    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
296    /// let p1 = AiClientBuilder::new(Provider::OpenAI).build_provider()?;
297    /// let p2 = AiClientBuilder::new(Provider::Anthropic).build_provider()?;
298    ///
299    /// let client = AiClientBuilder::new(Provider::OpenAI)
300    ///     .with_round_robin_strategy(vec![p1, p2])?
301    ///     .build()?;
302    /// # Ok(())
303    /// # }
304    /// ```
305    pub fn with_round_robin_strategy<I>(mut self, providers: I) -> Result<Self, AiLibError>
306    where
307        I: IntoIterator<Item = Box<dyn ChatProvider>>,
308    {
309        let providers_vec: Vec<_> = providers.into_iter().collect();
310        let rr = RoundRobinProvider::new(providers_vec)?;
311        self.strategy = Some(Box::new(rr));
312        Ok(self)
313    }
314
315    /// Compose a failover strategy from the provided providers.
316    pub fn with_failover_strategy<I>(mut self, providers: I) -> Result<Self, AiLibError>
317    where
318        I: IntoIterator<Item = Box<dyn ChatProvider>>,
319    {
320        let providers_vec: Vec<_> = providers.into_iter().collect();
321        let failover = FailoverProvider::new(providers_vec)?;
322        self.strategy = Some(Box::new(failover));
323        Ok(self)
324    }
325
326    /// Compose a round-robin strategy from built-in `Provider` variants.
327    pub fn with_round_robin_chain(mut self, providers: Vec<Provider>) -> Result<Self, AiLibError> {
328        let adapters = Self::build_strategy_chain(providers)?;
329        let rr = RoundRobinProvider::new(adapters)?;
330        self.strategy = Some(Box::new(rr));
331        Ok(self)
332    }
333
334    /// Compose a failover strategy from built-in `Provider` variants.
335    pub fn with_failover_chain(mut self, providers: Vec<Provider>) -> Result<Self, AiLibError> {
336        let adapters = Self::build_strategy_chain(providers)?;
337        let failover = FailoverProvider::new(adapters)?;
338        self.strategy = Some(Box::new(failover));
339        Ok(self)
340    }
341
342    /// Use `RoutingStrategyBuilder` to configure a round-robin strategy inline.
343    pub fn with_round_robin_builder(
344        mut self,
345        builder: RoutingStrategyBuilder,
346    ) -> Result<Self, AiLibError> {
347        let rr = builder.build_round_robin()?;
348        self.strategy = Some(Box::new(rr));
349        Ok(self)
350    }
351
352    /// Use `RoutingStrategyBuilder` to configure a failover strategy inline.
353    pub fn with_failover_builder(
354        mut self,
355        builder: RoutingStrategyBuilder,
356    ) -> Result<Self, AiLibError> {
357        let failover = builder.build_failover()?;
358        self.strategy = Some(Box::new(failover));
359        Ok(self)
360    }
361
362    /// Build AiClient instance
363    pub fn build(self) -> Result<AiClient, AiLibError> {
364        let mut builder = self;
365        // 1. Determine base_url: explicit setting > environment variable > default
366        let base_url = builder.determine_base_url()?;
367
368        // 2. Determine proxy_url: explicit setting > environment variable
369        let proxy_url = builder.determine_proxy_url();
370
371        // 3. Determine timeout: explicit setting > default
372        let timeout = builder
373            .timeout
374            .unwrap_or_else(|| std::time::Duration::from_secs(30));
375
376        // 4. Create provider strategy (custom or factory-built)
377        let transport = builder.create_custom_transport(proxy_url, timeout)?;
378        let chat_provider = if let Some(strategy) = builder.strategy.take() {
379            strategy
380        } else {
381            ProviderFactory::create_adapter(
382                builder.provider,
383                None, // API key handled by config or env
384                Some(base_url.clone()),
385                transport,
386            )?
387        };
388
389        // 5. Build backpressure controller if configured
390        let bp_ctrl: Option<Arc<BackpressureController>> = builder
391            .resilience_config
392            .backpressure
393            .as_ref()
394            .map(|cfg| Arc::new(BackpressureController::new(cfg.max_concurrent_requests)));
395
396        // 6. Create AiClient
397        let metadata = metadata_from_provider(
398            builder.provider,
399            chat_provider.name().to_string(),
400            Some(base_url.clone()),
401            None,
402            None,
403        );
404
405        let model_resolver = builder
406            .model_resolver
407            .unwrap_or_else(|| Arc::new(ModelResolver::new()));
408
409        let client = AiClient {
410            chat_provider,
411            metadata,
412            metrics: builder
413                .metrics
414                .unwrap_or_else(|| Arc::new(NoopMetrics::new())),
415            model_resolver,
416            connection_options: None,
417            custom_default_chat_model: builder.default_chat_model,
418            custom_default_multimodal_model: builder.default_multimodal_model,
419            backpressure: bp_ctrl,
420            #[cfg(feature = "interceptors")]
421            interceptor_pipeline: builder
422                .interceptor_pipeline
423                .or_else(|| builder.interceptor_builder.map(|b| b.build())),
424        };
425
426        Ok(client)
427    }
428
429    /// Determine base_url, priority: explicit setting > environment variable > default
430    fn determine_base_url(&self) -> Result<String, AiLibError> {
431        resolve_base_url(self.provider, self.base_url.clone())
432    }
433
434    fn build_strategy_chain(
435        providers: Vec<Provider>,
436    ) -> Result<Vec<Box<dyn ChatProvider>>, AiLibError> {
437        if providers.is_empty() {
438            return Err(AiLibError::ConfigurationError(
439                "routing strategy requires at least one provider".to_string(),
440            ));
441        }
442        providers
443            .into_iter()
444            .map(|provider| Self::create_adapter_from_env(provider))
445            .collect()
446    }
447
448    fn create_adapter_from_env(provider: Provider) -> Result<Box<dyn ChatProvider>, AiLibError> {
449        let opts = ConnectionOptions::default().hydrate_with_env(provider.env_prefix());
450        let resolved_base_url = resolve_base_url(provider, opts.base_url.clone())?;
451        let transport = Self::transport_from_options(&opts)?;
452        ProviderFactory::create_adapter(
453            provider,
454            opts.api_key.clone(),
455            Some(resolved_base_url),
456            transport,
457        )
458    }
459
460    fn transport_from_options(
461        opts: &ConnectionOptions,
462    ) -> Result<Option<crate::transport::DynHttpTransportRef>, AiLibError> {
463        let effective_proxy = if opts.disable_proxy {
464            None
465        } else {
466            opts.proxy.clone()
467        };
468
469        if effective_proxy.is_none() && opts.timeout.is_none() {
470            return Ok(None);
471        }
472
473        let transport_config = crate::transport::HttpTransportConfig {
474            timeout: opts
475                .timeout
476                .unwrap_or_else(|| std::time::Duration::from_secs(30)),
477            proxy: effective_proxy,
478            pool_max_idle_per_host: None,
479            pool_idle_timeout: None,
480        };
481        Ok(Some(
482            crate::transport::HttpTransport::new_with_config(transport_config)?.boxed(),
483        ))
484    }
485
486    /// Determine proxy_url, priority: explicit setting > environment variable
487    fn determine_proxy_url(&self) -> Option<String> {
488        // 1. Explicitly set proxy_url
489        if let Some(ref proxy_url) = self.proxy_url {
490            // If proxy_url is empty string, it means explicitly no proxy
491            if proxy_url.is_empty() {
492                return None;
493            }
494            return Some(proxy_url.clone());
495        }
496
497        // 2. AI_PROXY_URL from environment variable
498        std::env::var("AI_PROXY_URL").ok()
499    }
500
501    /// Create custom HttpTransport
502    fn create_custom_transport(
503        &self,
504        proxy_url: Option<String>,
505        timeout: std::time::Duration,
506    ) -> Result<Option<crate::transport::DynHttpTransportRef>, AiLibError> {
507        // If no custom configuration, return None (use default transport)
508        if proxy_url.is_none() && self.pool_max_idle.is_none() && self.pool_idle_timeout.is_none() {
509            return Ok(None);
510        }
511
512        // Create custom HttpTransportConfig
513        let transport_config = crate::transport::HttpTransportConfig {
514            timeout,
515            proxy: proxy_url,
516            pool_max_idle_per_host: self.pool_max_idle,
517            pool_idle_timeout: self.pool_idle_timeout,
518        };
519
520        // Create custom HttpTransport
521        let transport = crate::transport::HttpTransport::new_with_config(transport_config)?;
522        Ok(Some(transport.boxed()))
523    }
524}
525
526pub(crate) fn resolve_base_url(
527    provider: Provider,
528    explicit: Option<String>,
529) -> Result<String, AiLibError> {
530    if let Some(url) = explicit {
531        return Ok(url);
532    }
533
534    if let Some(env_var) = base_url_env_var(provider) {
535        if let Ok(value) = std::env::var(env_var) {
536            return Ok(value);
537        }
538    }
539
540    if provider.is_config_driven() {
541        return provider.get_default_config().map(|config| config.base_url);
542    }
543
544    default_base_url(provider)
545}
546
547fn base_url_env_var(provider: Provider) -> Option<&'static str> {
548    match provider {
549        Provider::Groq => Some("GROQ_BASE_URL"),
550        Provider::XaiGrok => Some("GROK_BASE_URL"),
551        Provider::Ollama => Some("OLLAMA_BASE_URL"),
552        Provider::DeepSeek => Some("DEEPSEEK_BASE_URL"),
553        Provider::Qwen => Some("DASHSCOPE_BASE_URL"),
554        Provider::BaiduWenxin => Some("BAIDU_WENXIN_BASE_URL"),
555        Provider::TencentHunyuan => Some("TENCENT_HUNYUAN_BASE_URL"),
556        Provider::IflytekSpark => Some("IFLYTEK_BASE_URL"),
557        Provider::Moonshot => Some("MOONSHOT_BASE_URL"),
558        Provider::Anthropic => Some("ANTHROPIC_BASE_URL"),
559        Provider::AzureOpenAI => Some("AZURE_OPENAI_BASE_URL"),
560        Provider::HuggingFace => Some("HUGGINGFACE_BASE_URL"),
561        Provider::TogetherAI => Some("TOGETHER_BASE_URL"),
562        Provider::OpenRouter => Some("OPENROUTER_BASE_URL"),
563        Provider::Replicate => Some("REPLICATE_BASE_URL"),
564        Provider::ZhipuAI => Some("ZHIPU_BASE_URL"),
565        Provider::MiniMax => Some("MINIMAX_BASE_URL"),
566        Provider::Perplexity => Some("PERPLEXITY_BASE_URL"),
567        Provider::AI21 => Some("AI21_BASE_URL"),
568        // Independent adapters use fixed endpoints
569        Provider::OpenAI | Provider::Gemini | Provider::Mistral | Provider::Cohere => None,
570    }
571}
572
573fn default_base_url(provider: Provider) -> Result<String, AiLibError> {
574    match provider {
575        Provider::OpenAI => Ok("https://api.openai.com".to_string()),
576        Provider::Gemini => Ok("https://generativelanguage.googleapis.com".to_string()),
577        Provider::Mistral => Ok("https://api.mistral.ai".to_string()),
578        Provider::Cohere => Ok("https://api.cohere.ai".to_string()),
579        other => Err(AiLibError::ConfigurationError(format!(
580            "Unknown provider for base URL determination: {other:?}"
581        ))),
582    }
583}