1use super::metadata::metadata_from_provider;
7use super::{AiClient, Provider, ProviderFactory};
8use crate::api::ChatProvider;
9use crate::config::{ConnectionOptions, ResilienceConfig};
10use crate::metrics::{Metrics, NoopMetrics};
11use crate::model::ModelResolver;
12use crate::provider::classification::ProviderClassification;
13use crate::provider::strategies::{FailoverProvider, RoundRobinProvider, RoutingStrategyBuilder};
14use crate::rate_limiter::BackpressureController;
15use crate::types::AiLibError;
16use std::sync::Arc;
17
/// Builder that accumulates the settings used to construct an `AiClient`.
///
/// Setters consume the builder and hand it back, so calls can be chained
/// before finishing with `build` or `build_provider`.
pub struct AiClientBuilder {
    /// Provider the client targets when no routing strategy is installed.
    provider: Provider,
    /// Explicit base URL; when `None`, `resolve_base_url` chooses one.
    base_url: Option<String>,
    /// Explicit proxy setting. `Some("")` means "no proxy"; `None` defers to
    /// the `AI_PROXY_URL` environment variable.
    proxy_url: Option<String>,
    /// Request timeout; `build`/`build_provider` fall back to 30 seconds.
    timeout: Option<std::time::Duration>,
    /// Forwarded to `HttpTransportConfig::pool_max_idle_per_host`.
    pool_max_idle: Option<usize>,
    /// Forwarded to `HttpTransportConfig::pool_idle_timeout`.
    pool_idle_timeout: Option<std::time::Duration>,
    /// Metrics sink; `build` substitutes `NoopMetrics` when unset.
    metrics: Option<Arc<dyn Metrics>>,
    /// Stored on the client as `custom_default_chat_model`.
    default_chat_model: Option<String>,
    /// Stored on the client as `custom_default_multimodal_model`.
    default_multimodal_model: Option<String>,
    /// Resilience settings; `build` reads only its `backpressure` field.
    resilience_config: ResilienceConfig,
    /// Fully built interceptor pipeline; takes precedence over
    /// `interceptor_builder` in `build`.
    #[cfg(feature = "interceptors")]
    interceptor_pipeline: Option<crate::interceptors::InterceptorPipeline>,
    /// Incrementally configured interceptor builder, built in `build` only
    /// when no explicit pipeline was supplied.
    #[cfg(feature = "interceptors")]
    interceptor_builder: Option<crate::interceptors::default::DefaultInterceptorsBuilder>,
    /// Pre-built provider (e.g. a routing strategy) used instead of creating
    /// an adapter for `provider`.
    strategy: Option<Box<dyn ChatProvider>>,
    /// Model resolver; `build` creates a fresh `ModelResolver` when unset.
    model_resolver: Option<Arc<ModelResolver>>,
}
40
41impl AiClientBuilder {
42 pub fn new(provider: Provider) -> Self {
44 Self {
45 provider,
46 base_url: None,
47 proxy_url: None,
48 timeout: None,
49 pool_max_idle: None,
50 pool_idle_timeout: None,
51 metrics: None,
52 default_chat_model: None,
53 default_multimodal_model: None,
54 resilience_config: ResilienceConfig::default(),
55 #[cfg(feature = "interceptors")]
56 interceptor_pipeline: None,
57 #[cfg(feature = "interceptors")]
58 interceptor_builder: None,
59 strategy: None,
60 model_resolver: None,
61 }
62 }
63
64 pub fn with_base_url(mut self, base_url: &str) -> Self {
66 self.base_url = Some(base_url.to_string());
67 self
68 }
69
70 pub fn with_proxy(mut self, proxy_url: Option<&str>) -> Self {
72 self.proxy_url = proxy_url.map(|s| s.to_string());
73 self
74 }
75
76 pub fn without_proxy(mut self) -> Self {
78 self.proxy_url = Some("".to_string());
79 self
80 }
81
82 pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
84 self.timeout = Some(timeout);
85 self
86 }
87
88 pub fn with_pool_config(mut self, max_idle: usize, idle_timeout: std::time::Duration) -> Self {
90 self.pool_max_idle = Some(max_idle);
91 self.pool_idle_timeout = Some(idle_timeout);
92 self
93 }
94
95 pub fn with_metrics(mut self, metrics: Arc<dyn Metrics>) -> Self {
97 self.metrics = Some(metrics);
98 self
99 }
100
/// Installs a ready-made interceptor pipeline.
///
/// In `build`, an explicit pipeline wins over anything configured through
/// the incremental interceptor setters.
#[cfg(feature = "interceptors")]
pub fn with_interceptor_pipeline(
    self,
    pipeline: crate::interceptors::InterceptorPipeline,
) -> Self {
    Self {
        interceptor_pipeline: Some(pipeline),
        ..self
    }
}
109
/// Installs the pipeline returned by
/// `crate::interceptors::create_default_interceptors()`, replacing any
/// pipeline set earlier.
#[cfg(feature = "interceptors")]
pub fn enable_default_interceptors(self) -> Self {
    Self {
        interceptor_pipeline: Some(crate::interceptors::create_default_interceptors()),
        ..self
    }
}
116
/// Installs a pipeline built from `DefaultInterceptorsBuilder` with the
/// circuit breaker and the rate limiter switched off, replacing any
/// pipeline set earlier.
#[cfg(feature = "interceptors")]
pub fn enable_minimal_interceptors(self) -> Self {
    Self {
        interceptor_pipeline: Some(
            crate::interceptors::default::DefaultInterceptorsBuilder::new()
                .enable_circuit_breaker(false)
                .enable_rate_limit(false)
                .build(),
        ),
        ..self
    }
}
126
/// Forwards `requests_per_minute` to the interceptor builder's
/// `with_rate_limit`, creating a default builder first if none exists yet.
#[cfg(feature = "interceptors")]
pub fn with_rate_limit(mut self, requests_per_minute: u32) -> Self {
    let configured = self
        .interceptor_builder
        .take()
        .unwrap_or_default()
        .with_rate_limit(requests_per_minute);
    self.interceptor_builder = Some(configured);
    self
}
135
/// Forwards `threshold` and `recovery` to the interceptor builder's
/// `with_circuit_breaker`, creating a default builder first if none exists
/// yet.
#[cfg(feature = "interceptors")]
pub fn with_circuit_breaker(mut self, threshold: u32, recovery: std::time::Duration) -> Self {
    let configured = self
        .interceptor_builder
        .take()
        .unwrap_or_default()
        .with_circuit_breaker(threshold, recovery);
    self.interceptor_builder = Some(configured);
    self
}
142
/// Forwards the retry parameters to the interceptor builder's `with_retry`,
/// creating a default builder first if none exists yet.
#[cfg(feature = "interceptors")]
pub fn with_retry(
    mut self,
    max_attempts: u32,
    base_delay: std::time::Duration,
    max_delay: std::time::Duration,
) -> Self {
    let configured = self
        .interceptor_builder
        .take()
        .unwrap_or_default()
        .with_retry(max_attempts, base_delay, max_delay);
    self.interceptor_builder = Some(configured);
    self
}
154
/// Forwards `duration` to the interceptor builder's `with_timeout`,
/// creating a default builder first if none exists yet.
#[cfg(feature = "interceptors")]
pub fn with_interceptor_timeout(mut self, duration: std::time::Duration) -> Self {
    let configured = self
        .interceptor_builder
        .take()
        .unwrap_or_default()
        .with_timeout(duration);
    self.interceptor_builder = Some(configured);
    self
}
161
/// Forwards `enable` to the interceptor builder's `enable_retry`, creating
/// a default builder first if none exists yet.
#[cfg(feature = "interceptors")]
pub fn enable_retry(mut self, enable: bool) -> Self {
    let configured = self
        .interceptor_builder
        .take()
        .unwrap_or_default()
        .enable_retry(enable);
    self.interceptor_builder = Some(configured);
    self
}
168
/// Forwards `enable` to the interceptor builder's `enable_circuit_breaker`,
/// creating a default builder first if none exists yet.
#[cfg(feature = "interceptors")]
pub fn enable_circuit_breaker(mut self, enable: bool) -> Self {
    let configured = self
        .interceptor_builder
        .take()
        .unwrap_or_default()
        .enable_circuit_breaker(enable);
    self.interceptor_builder = Some(configured);
    self
}
175
/// Forwards `enable` to the interceptor builder's `enable_rate_limit`,
/// creating a default builder first if none exists yet.
#[cfg(feature = "interceptors")]
pub fn enable_rate_limit(mut self, enable: bool) -> Self {
    let configured = self
        .interceptor_builder
        .take()
        .unwrap_or_default()
        .enable_rate_limit(enable);
    self.interceptor_builder = Some(configured);
    self
}
182
/// Forwards `enable` to the interceptor builder's `enable_timeout`,
/// creating a default builder first if none exists yet.
#[cfg(feature = "interceptors")]
pub fn enable_interceptor_timeout(mut self, enable: bool) -> Self {
    let configured = self
        .interceptor_builder
        .take()
        .unwrap_or_default()
        .enable_timeout(enable);
    self.interceptor_builder = Some(configured);
    self
}
189
190 pub fn with_default_chat_model(mut self, model: &str) -> Self {
192 self.default_chat_model = Some(model.to_string());
193 self
194 }
195
196 pub fn with_default_multimodal_model(mut self, model: &str) -> Self {
198 self.default_multimodal_model = Some(model.to_string());
199 self
200 }
201
202 pub fn with_model_resolver(mut self, resolver: Arc<ModelResolver>) -> Self {
204 self.model_resolver = Some(resolver);
205 self
206 }
207
208 pub fn with_smart_defaults(mut self) -> Self {
210 self.resilience_config = ResilienceConfig::smart_defaults();
211 self
212 }
213
214 pub fn for_production(mut self) -> Self {
216 self.resilience_config = ResilienceConfig::production();
217 self
218 }
219
220 pub fn for_development(mut self) -> Self {
222 self.resilience_config = ResilienceConfig::development();
223 self
224 }
225
226 pub fn build_provider(mut self) -> Result<Box<dyn ChatProvider>, AiLibError> {
229 if let Some(p) = self.strategy.take() {
231 return Ok(p);
232 }
233
234 let base_url = self.determine_base_url()?;
236 let proxy_url = self.determine_proxy_url();
237 let timeout = self
238 .timeout
239 .unwrap_or_else(|| std::time::Duration::from_secs(30));
240 let transport = self.create_custom_transport(proxy_url, timeout)?;
241
242 ProviderFactory::create_adapter(self.provider, None, Some(base_url), transport)
243 }
244
245 pub fn with_max_concurrency(mut self, max_concurrent_requests: usize) -> Self {
247 let mut cfg = self.resilience_config.clone();
248 cfg.backpressure = Some(crate::config::BackpressureConfig {
249 max_concurrent_requests,
250 });
251 self.resilience_config = cfg;
252 self
253 }
254
255 pub fn with_resilience_config(mut self, config: ResilienceConfig) -> Self {
257 self.resilience_config = config;
258 self
259 }
260
261 pub fn with_strategy(mut self, strategy: Box<dyn ChatProvider>) -> Self {
282 self.strategy = Some(strategy);
283 self
284 }
285
286 pub fn with_round_robin_strategy<I>(mut self, providers: I) -> Result<Self, AiLibError>
306 where
307 I: IntoIterator<Item = Box<dyn ChatProvider>>,
308 {
309 let providers_vec: Vec<_> = providers.into_iter().collect();
310 let rr = RoundRobinProvider::new(providers_vec)?;
311 self.strategy = Some(Box::new(rr));
312 Ok(self)
313 }
314
315 pub fn with_failover_strategy<I>(mut self, providers: I) -> Result<Self, AiLibError>
317 where
318 I: IntoIterator<Item = Box<dyn ChatProvider>>,
319 {
320 let providers_vec: Vec<_> = providers.into_iter().collect();
321 let failover = FailoverProvider::new(providers_vec)?;
322 self.strategy = Some(Box::new(failover));
323 Ok(self)
324 }
325
326 pub fn with_round_robin_chain(mut self, providers: Vec<Provider>) -> Result<Self, AiLibError> {
328 let adapters = Self::build_strategy_chain(providers)?;
329 let rr = RoundRobinProvider::new(adapters)?;
330 self.strategy = Some(Box::new(rr));
331 Ok(self)
332 }
333
334 pub fn with_failover_chain(mut self, providers: Vec<Provider>) -> Result<Self, AiLibError> {
336 let adapters = Self::build_strategy_chain(providers)?;
337 let failover = FailoverProvider::new(adapters)?;
338 self.strategy = Some(Box::new(failover));
339 Ok(self)
340 }
341
342 pub fn with_round_robin_builder(
344 mut self,
345 builder: RoutingStrategyBuilder,
346 ) -> Result<Self, AiLibError> {
347 let rr = builder.build_round_robin()?;
348 self.strategy = Some(Box::new(rr));
349 Ok(self)
350 }
351
352 pub fn with_failover_builder(
354 mut self,
355 builder: RoutingStrategyBuilder,
356 ) -> Result<Self, AiLibError> {
357 let failover = builder.build_failover()?;
358 self.strategy = Some(Box::new(failover));
359 Ok(self)
360 }
361
362 pub fn build(self) -> Result<AiClient, AiLibError> {
364 let mut builder = self;
365 let base_url = builder.determine_base_url()?;
367
368 let proxy_url = builder.determine_proxy_url();
370
371 let timeout = builder
373 .timeout
374 .unwrap_or_else(|| std::time::Duration::from_secs(30));
375
376 let transport = builder.create_custom_transport(proxy_url, timeout)?;
378 let chat_provider = if let Some(strategy) = builder.strategy.take() {
379 strategy
380 } else {
381 ProviderFactory::create_adapter(
382 builder.provider,
383 None, Some(base_url.clone()),
385 transport,
386 )?
387 };
388
389 let bp_ctrl: Option<Arc<BackpressureController>> = builder
391 .resilience_config
392 .backpressure
393 .as_ref()
394 .map(|cfg| Arc::new(BackpressureController::new(cfg.max_concurrent_requests)));
395
396 let metadata = metadata_from_provider(
398 builder.provider,
399 chat_provider.name().to_string(),
400 Some(base_url.clone()),
401 None,
402 None,
403 );
404
405 let model_resolver = builder
406 .model_resolver
407 .unwrap_or_else(|| Arc::new(ModelResolver::new()));
408
409 let client = AiClient {
410 chat_provider,
411 metadata,
412 metrics: builder
413 .metrics
414 .unwrap_or_else(|| Arc::new(NoopMetrics::new())),
415 model_resolver,
416 connection_options: None,
417 custom_default_chat_model: builder.default_chat_model,
418 custom_default_multimodal_model: builder.default_multimodal_model,
419 backpressure: bp_ctrl,
420 #[cfg(feature = "interceptors")]
421 interceptor_pipeline: builder
422 .interceptor_pipeline
423 .or_else(|| builder.interceptor_builder.map(|b| b.build())),
424 };
425
426 Ok(client)
427 }
428
429 fn determine_base_url(&self) -> Result<String, AiLibError> {
431 resolve_base_url(self.provider, self.base_url.clone())
432 }
433
434 fn build_strategy_chain(
435 providers: Vec<Provider>,
436 ) -> Result<Vec<Box<dyn ChatProvider>>, AiLibError> {
437 if providers.is_empty() {
438 return Err(AiLibError::ConfigurationError(
439 "routing strategy requires at least one provider".to_string(),
440 ));
441 }
442 providers
443 .into_iter()
444 .map(|provider| Self::create_adapter_from_env(provider))
445 .collect()
446 }
447
448 fn create_adapter_from_env(provider: Provider) -> Result<Box<dyn ChatProvider>, AiLibError> {
449 let opts = ConnectionOptions::default().hydrate_with_env(provider.env_prefix());
450 let resolved_base_url = resolve_base_url(provider, opts.base_url.clone())?;
451 let transport = Self::transport_from_options(&opts)?;
452 ProviderFactory::create_adapter(
453 provider,
454 opts.api_key.clone(),
455 Some(resolved_base_url),
456 transport,
457 )
458 }
459
460 fn transport_from_options(
461 opts: &ConnectionOptions,
462 ) -> Result<Option<crate::transport::DynHttpTransportRef>, AiLibError> {
463 let effective_proxy = if opts.disable_proxy {
464 None
465 } else {
466 opts.proxy.clone()
467 };
468
469 if effective_proxy.is_none() && opts.timeout.is_none() {
470 return Ok(None);
471 }
472
473 let transport_config = crate::transport::HttpTransportConfig {
474 timeout: opts
475 .timeout
476 .unwrap_or_else(|| std::time::Duration::from_secs(30)),
477 proxy: effective_proxy,
478 pool_max_idle_per_host: None,
479 pool_idle_timeout: None,
480 };
481 Ok(Some(
482 crate::transport::HttpTransport::new_with_config(transport_config)?.boxed(),
483 ))
484 }
485
486 fn determine_proxy_url(&self) -> Option<String> {
488 if let Some(ref proxy_url) = self.proxy_url {
490 if proxy_url.is_empty() {
492 return None;
493 }
494 return Some(proxy_url.clone());
495 }
496
497 std::env::var("AI_PROXY_URL").ok()
499 }
500
501 fn create_custom_transport(
503 &self,
504 proxy_url: Option<String>,
505 timeout: std::time::Duration,
506 ) -> Result<Option<crate::transport::DynHttpTransportRef>, AiLibError> {
507 if proxy_url.is_none() && self.pool_max_idle.is_none() && self.pool_idle_timeout.is_none() {
509 return Ok(None);
510 }
511
512 let transport_config = crate::transport::HttpTransportConfig {
514 timeout,
515 proxy: proxy_url,
516 pool_max_idle_per_host: self.pool_max_idle,
517 pool_idle_timeout: self.pool_idle_timeout,
518 };
519
520 let transport = crate::transport::HttpTransport::new_with_config(transport_config)?;
522 Ok(Some(transport.boxed()))
523 }
524}
525
526pub(crate) fn resolve_base_url(
527 provider: Provider,
528 explicit: Option<String>,
529) -> Result<String, AiLibError> {
530 if let Some(url) = explicit {
531 return Ok(url);
532 }
533
534 if let Some(env_var) = base_url_env_var(provider) {
535 if let Ok(value) = std::env::var(env_var) {
536 return Ok(value);
537 }
538 }
539
540 if provider.is_config_driven() {
541 return provider.get_default_config().map(|config| config.base_url);
542 }
543
544 default_base_url(provider)
545}
546
547fn base_url_env_var(provider: Provider) -> Option<&'static str> {
548 match provider {
549 Provider::Groq => Some("GROQ_BASE_URL"),
550 Provider::XaiGrok => Some("GROK_BASE_URL"),
551 Provider::Ollama => Some("OLLAMA_BASE_URL"),
552 Provider::DeepSeek => Some("DEEPSEEK_BASE_URL"),
553 Provider::Qwen => Some("DASHSCOPE_BASE_URL"),
554 Provider::BaiduWenxin => Some("BAIDU_WENXIN_BASE_URL"),
555 Provider::TencentHunyuan => Some("TENCENT_HUNYUAN_BASE_URL"),
556 Provider::IflytekSpark => Some("IFLYTEK_BASE_URL"),
557 Provider::Moonshot => Some("MOONSHOT_BASE_URL"),
558 Provider::Anthropic => Some("ANTHROPIC_BASE_URL"),
559 Provider::AzureOpenAI => Some("AZURE_OPENAI_BASE_URL"),
560 Provider::HuggingFace => Some("HUGGINGFACE_BASE_URL"),
561 Provider::TogetherAI => Some("TOGETHER_BASE_URL"),
562 Provider::OpenRouter => Some("OPENROUTER_BASE_URL"),
563 Provider::Replicate => Some("REPLICATE_BASE_URL"),
564 Provider::ZhipuAI => Some("ZHIPU_BASE_URL"),
565 Provider::MiniMax => Some("MINIMAX_BASE_URL"),
566 Provider::Perplexity => Some("PERPLEXITY_BASE_URL"),
567 Provider::AI21 => Some("AI21_BASE_URL"),
568 Provider::OpenAI | Provider::Gemini | Provider::Mistral | Provider::Cohere => None,
570 }
571}
572
573fn default_base_url(provider: Provider) -> Result<String, AiLibError> {
574 match provider {
575 Provider::OpenAI => Ok("https://api.openai.com".to_string()),
576 Provider::Gemini => Ok("https://generativelanguage.googleapis.com".to_string()),
577 Provider::Mistral => Ok("https://api.mistral.ai".to_string()),
578 Provider::Cohere => Ok("https://api.cohere.ai".to_string()),
579 other => Err(AiLibError::ConfigurationError(format!(
580 "Unknown provider for base URL determination: {other:?}"
581 ))),
582 }
583}