// ai_lib/client.rs
use crate::api::{ChatApi, ChatCompletionChunk};
use crate::config::{ConnectionOptions, ResilienceConfig};
use crate::metrics::{Metrics, NoopMetrics};
use crate::provider::{
    classification::ProviderClassification,
    CohereAdapter, GeminiAdapter, GenericAdapter, MistralAdapter, OpenAiAdapter, ProviderConfigs,
};
use crate::types::{AiLibError, ChatCompletionRequest, ChatCompletionResponse};
use futures::stream::Stream;
use futures::Future;
use std::sync::Arc;
use tokio::sync::oneshot;

/// Model configuration options for explicit model selection
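///
/// A minimal usage sketch (illustrative: the model names are placeholders, and the
/// import path assumes `ModelOptions` is re-exported at the crate root like `AiClient`):
///
/// ```ignore
/// use ai_lib::ModelOptions;
///
/// let options = ModelOptions::new()
///     .with_chat_model("my-chat-model")
///     .with_fallback_models(vec!["fallback-a", "fallback-b"])
///     .with_auto_discovery(false);
/// assert_eq!(options.chat_model.as_deref(), Some("my-chat-model"));
/// ```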
#[derive(Debug, Clone)]
pub struct ModelOptions {
    pub chat_model: Option<String>,
    pub multimodal_model: Option<String>,
    pub fallback_models: Vec<String>,
    pub auto_discovery: bool,
}

impl Default for ModelOptions {
    fn default() -> Self {
        Self { chat_model: None, multimodal_model: None, fallback_models: Vec::new(), auto_discovery: true }
    }
}

impl ModelOptions {
    /// Create default model options
    pub fn new() -> Self { Self::default() }

    /// Set chat model
    pub fn with_chat_model(mut self, model: &str) -> Self {
        self.chat_model = Some(model.to_string());
        self
    }

    /// Set multimodal model
    pub fn with_multimodal_model(mut self, model: &str) -> Self {
        self.multimodal_model = Some(model.to_string());
        self
    }

    /// Set fallback models
    pub fn with_fallback_models(mut self, models: Vec<&str>) -> Self {
        self.fallback_models = models.into_iter().map(|s| s.to_string()).collect();
        self
    }

    /// Enable or disable auto discovery
    pub fn with_auto_discovery(mut self, enabled: bool) -> Self {
        self.auto_discovery = enabled;
        self
    }
}

/// Helper function to create GenericAdapter with optional custom transport
fn create_generic_adapter(
    config: crate::provider::config::ProviderConfig,
    transport: Option<crate::transport::DynHttpTransportRef>,
) -> Result<Box<dyn ChatApi>, AiLibError> {
    if let Some(custom_transport) = transport {
        Ok(Box::new(GenericAdapter::with_transport_ref(config, custom_transport)?))
    } else {
        Ok(Box::new(GenericAdapter::new(config)?))
    }
}

/// AI model provider enumeration for the unified AI client.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Provider {
    // Config-driven providers
    Groq,
    XaiGrok,
    Ollama,
    DeepSeek,
    Anthropic,
    AzureOpenAI,
    HuggingFace,
    TogetherAI,
    // Chinese providers (OpenAI-compatible or config-driven)
    BaiduWenxin,
    TencentHunyuan,
    IflytekSpark,
    Moonshot,
    // Independent adapters
    OpenAI,
    Qwen,
    Gemini,
    Mistral,
    Cohere,
    // Bedrock removed (deferred)
}

impl Provider {
    /// Get the provider's preferred default chat model.
    /// These should mirror the values used inside `ProviderConfigs`.
    pub fn default_chat_model(&self) -> &'static str {
        match self {
            Provider::Groq => "llama-3.1-8b-instant",
            Provider::XaiGrok => "grok-beta",
            Provider::Ollama => "llama3-8b",
            Provider::DeepSeek => "deepseek-chat",
            Provider::Anthropic => "claude-3-5-sonnet-20241022",
            Provider::AzureOpenAI => "gpt-35-turbo",
            Provider::HuggingFace => "microsoft/DialoGPT-medium",
            Provider::TogetherAI => "meta-llama/Llama-3-8b-chat-hf",
            Provider::BaiduWenxin => "ernie-3.5",
            Provider::TencentHunyuan => "hunyuan-standard",
            Provider::IflytekSpark => "spark-v3.0",
            Provider::Moonshot => "moonshot-v1-8k",
            Provider::OpenAI => "gpt-3.5-turbo",
            Provider::Qwen => "qwen-turbo",
            Provider::Gemini => "gemini-pro",      // widely used text/chat model
            Provider::Mistral => "mistral-small",  // generic default
            Provider::Cohere => "command-r",       // chat-capable model
        }
    }

    /// Get the provider's preferred multimodal model (if any).
    pub fn default_multimodal_model(&self) -> Option<&'static str> {
        match self {
            Provider::OpenAI => Some("gpt-4o"),
            Provider::AzureOpenAI => Some("gpt-4o"),
            Provider::Anthropic => Some("claude-3-5-sonnet-20241022"),
            Provider::Groq => None, // No multimodal model currently available
            Provider::Gemini => Some("gemini-1.5-flash"),
            Provider::Cohere => Some("command-r-plus"),
            // Others presently have no clearly documented multimodal endpoint or are not yet wired.
            _ => None,
        }
    }
}

/// Unified AI client
///
/// Usage example:
/// ```rust
/// use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Role};
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
///     // Switch model provider by changing Provider value
///     let client = AiClient::new(Provider::Groq)?;
///
///     let request = ChatCompletionRequest::new(
///         "test-model".to_string(),
///         vec![Message {
///             role: Role::User,
///             content: ai_lib::types::common::Content::Text("Hello".to_string()),
///             function_call: None,
///         }],
///     );
///
///     // Note: Set GROQ_API_KEY environment variable for actual API calls
///     // Optional: Set AI_PROXY_URL environment variable to use proxy server
///     // let response = client.chat_completion(request).await?;
///
///     println!("Client created successfully with provider: {:?}", client.current_provider());
///     println!("Request prepared for model: {}", request.model);
///
///     Ok(())
/// }
/// ```
///
/// # Proxy Configuration
///
/// Configure proxy server by setting the `AI_PROXY_URL` environment variable:
///
/// ```bash
/// export AI_PROXY_URL=http://proxy.example.com:8080
/// ```
///
/// Supported proxy formats:
/// - HTTP proxy: `http://proxy.example.com:8080`
/// - HTTPS proxy: `https://proxy.example.com:8080`
/// - With authentication: `http://user:pass@proxy.example.com:8080`
pub struct AiClient {
    provider: Provider,
    adapter: Box<dyn ChatApi>,
    metrics: Arc<dyn Metrics>,
    connection_options: Option<ConnectionOptions>,
    // Custom default models (override provider defaults)
    custom_default_chat_model: Option<String>,
    custom_default_multimodal_model: Option<String>,
    #[cfg(feature = "routing_mvp")]
    routing_array: Option<crate::provider::models::ModelArray>,
}

impl AiClient {
    /// Get the effective default chat model for this client (honors custom override)
    pub fn default_chat_model(&self) -> String {
        self.custom_default_chat_model
            .clone()
            .unwrap_or_else(|| self.provider.default_chat_model().to_string())
    }

    /// Create a new AI client
    ///
    /// # Arguments
    /// * `provider` - The AI model provider to use
    ///
    /// # Returns
    /// * `Result<Self, AiLibError>` - Client instance on success, error on failure
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClient, Provider};
    ///
    /// let client = AiClient::new(Provider::Groq)?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```
    pub fn new(provider: Provider) -> Result<Self, AiLibError> {
        // Use the builder to create the client with automatic environment variable detection
        let mut c = AiClientBuilder::new(provider).build()?;
        c.connection_options = None;
        Ok(c)
    }

    #[cfg(feature = "routing_mvp")]
    /// Set a routing model array to enable basic endpoint selection before requests.
    pub fn with_routing_array(mut self, array: crate::provider::models::ModelArray) -> Self {
        self.routing_array = Some(array);
        self
    }

    // #[cfg(feature = "routing_mvp")]
    // fn select_routed_model(&mut self, fallback: &str) -> String {
    //     if let Some(arr) = self.routing_array.as_mut() {
    //         if let Some(ep) = arr.select_endpoint() {
    //             return ep.model_name.clone();
    //         }
    //     }
    //     fallback.to_string()
    // }

    /// Create a client with minimal explicit options (base_url / proxy / timeout / api_key).
    /// Not all providers support every override; unsupported or unspecified fields are ignored gracefully.
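    ///
    /// A minimal sketch (illustrative): it assumes `ConnectionOptions` is publicly
    /// constructible with the fields read below and implements `Default`; adjust to the
    /// actual definition in `crate::config`.
    ///
    /// ```ignore
    /// use ai_lib::{AiClient, Provider};
    /// use ai_lib::config::ConnectionOptions;
    ///
    /// let opts = ConnectionOptions {
    ///     base_url: Some("https://custom.groq.com".to_string()),
    ///     api_key: Some("your-api-key".to_string()),
    ///     timeout: Some(std::time::Duration::from_secs(30)),
    ///     disable_proxy: true,
    ///     ..Default::default()
    /// };
    /// let client = AiClient::with_options(Provider::Groq, opts)?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```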
    pub fn with_options(provider: Provider, opts: ConnectionOptions) -> Result<Self, AiLibError> {
        let config_driven = provider.is_config_driven();
        let need_builder = config_driven
            && (opts.base_url.is_some()
                || opts.proxy.is_some()
                || opts.timeout.is_some()
                || opts.disable_proxy);
        if need_builder {
            let mut b = AiClient::builder(provider);
            if let Some(ref base) = opts.base_url {
                b = b.with_base_url(base);
            }
            if opts.disable_proxy {
                b = b.without_proxy();
            } else if let Some(ref proxy) = opts.proxy {
                if proxy.is_empty() {
                    b = b.without_proxy();
                } else {
                    b = b.with_proxy(Some(proxy));
                }
            }
            if let Some(t) = opts.timeout {
                b = b.with_timeout(t);
            }
            let mut client = b.build()?;
            // If api_key override + generic provider path: re-wrap adapter using override
            if opts.api_key.is_some() {
                // Only applies to config-driven generic adapter providers
                let new_adapter: Option<Box<dyn ChatApi>> = match provider {
                    Provider::Groq => Some(Box::new(GenericAdapter::new_with_api_key(
                        ProviderConfigs::groq(),
                        opts.api_key.clone(),
                    )?)),
                    Provider::XaiGrok => Some(Box::new(GenericAdapter::new_with_api_key(
                        ProviderConfigs::xai_grok(),
                        opts.api_key.clone(),
                    )?)),
                    Provider::Ollama => Some(Box::new(GenericAdapter::new_with_api_key(
                        ProviderConfigs::ollama(),
                        opts.api_key.clone(),
                    )?)),
                    Provider::DeepSeek => Some(Box::new(GenericAdapter::new_with_api_key(
                        ProviderConfigs::deepseek(),
                        opts.api_key.clone(),
                    )?)),
                    Provider::Qwen => Some(Box::new(GenericAdapter::new_with_api_key(
                        ProviderConfigs::qwen(),
                        opts.api_key.clone(),
                    )?)),
                    Provider::BaiduWenxin => Some(Box::new(GenericAdapter::new_with_api_key(
                        ProviderConfigs::baidu_wenxin(),
                        opts.api_key.clone(),
                    )?)),
                    Provider::TencentHunyuan => Some(Box::new(GenericAdapter::new_with_api_key(
                        ProviderConfigs::tencent_hunyuan(),
                        opts.api_key.clone(),
                    )?)),
                    Provider::IflytekSpark => Some(Box::new(GenericAdapter::new_with_api_key(
                        ProviderConfigs::iflytek_spark(),
                        opts.api_key.clone(),
                    )?)),
                    Provider::Moonshot => Some(Box::new(GenericAdapter::new_with_api_key(
                        ProviderConfigs::moonshot(),
                        opts.api_key.clone(),
                    )?)),
                    Provider::Anthropic => Some(Box::new(GenericAdapter::new_with_api_key(
                        ProviderConfigs::anthropic(),
                        opts.api_key.clone(),
                    )?)),
                    Provider::AzureOpenAI => Some(Box::new(GenericAdapter::new_with_api_key(
                        ProviderConfigs::azure_openai(),
                        opts.api_key.clone(),
                    )?)),
                    Provider::HuggingFace => Some(Box::new(GenericAdapter::new_with_api_key(
                        ProviderConfigs::huggingface(),
                        opts.api_key.clone(),
                    )?)),
                    Provider::TogetherAI => Some(Box::new(GenericAdapter::new_with_api_key(
                        ProviderConfigs::together_ai(),
                        opts.api_key.clone(),
                    )?)),
                    _ => None,
                };
                if let Some(a) = new_adapter {
                    client.adapter = a;
                }
            }
            client.connection_options = Some(opts);
            return Ok(client);
        }

        // Independent adapters: OpenAI / Gemini / Mistral / Cohere
        if provider.is_independent() {
            let adapter: Box<dyn ChatApi> = match provider {
                Provider::OpenAI => {
                    if let Some(ref k) = opts.api_key {
                        let inner =
                            OpenAiAdapter::new_with_overrides(k.clone(), opts.base_url.clone())?;
                        Box::new(inner)
                    } else {
                        let inner = OpenAiAdapter::new()?;
                        Box::new(inner)
                    }
                }
                Provider::Gemini => {
                    if let Some(ref k) = opts.api_key {
                        let inner =
                            GeminiAdapter::new_with_overrides(k.clone(), opts.base_url.clone())?;
                        Box::new(inner)
                    } else {
                        let inner = GeminiAdapter::new()?;
                        Box::new(inner)
                    }
                }
                Provider::Mistral => {
                    if opts.api_key.is_some() || opts.base_url.is_some() {
                        let inner = MistralAdapter::new_with_overrides(
                            opts.api_key.clone(),
                            opts.base_url.clone(),
                        )?;
                        Box::new(inner)
                    } else {
                        let inner = MistralAdapter::new()?;
                        Box::new(inner)
                    }
                }
                Provider::Cohere => {
                    if let Some(ref k) = opts.api_key {
                        let inner =
                            CohereAdapter::new_with_overrides(k.clone(), opts.base_url.clone())?;
                        Box::new(inner)
                    } else {
                        let inner = CohereAdapter::new()?;
                        Box::new(inner)
                    }
                }
                _ => unreachable!(),
            };
            return Ok(AiClient {
                provider,
                adapter,
                metrics: Arc::new(NoopMetrics::new()),
                connection_options: Some(opts),
                custom_default_chat_model: None,
                custom_default_multimodal_model: None,
                #[cfg(feature = "routing_mvp")]
                routing_array: None,
            });
        }

        // Simple config-driven without overrides
        let mut client = AiClient::new(provider)?;
        if let Some(ref k) = opts.api_key {
            let override_adapter: Option<Box<dyn ChatApi>> = match provider {
                Provider::Groq => Some(Box::new(GenericAdapter::new_with_api_key(
                    ProviderConfigs::groq(),
                    Some(k.clone()),
                )?)),
                Provider::XaiGrok => Some(Box::new(GenericAdapter::new_with_api_key(
                    ProviderConfigs::xai_grok(),
                    Some(k.clone()),
                )?)),
                Provider::Ollama => Some(Box::new(GenericAdapter::new_with_api_key(
                    ProviderConfigs::ollama(),
                    Some(k.clone()),
                )?)),
                Provider::DeepSeek => Some(Box::new(GenericAdapter::new_with_api_key(
                    ProviderConfigs::deepseek(),
                    Some(k.clone()),
                )?)),
                Provider::Qwen => Some(Box::new(GenericAdapter::new_with_api_key(
                    ProviderConfigs::qwen(),
                    Some(k.clone()),
                )?)),
                Provider::BaiduWenxin => Some(Box::new(GenericAdapter::new_with_api_key(
                    ProviderConfigs::baidu_wenxin(),
                    Some(k.clone()),
                )?)),
                Provider::TencentHunyuan => Some(Box::new(GenericAdapter::new_with_api_key(
                    ProviderConfigs::tencent_hunyuan(),
                    Some(k.clone()),
                )?)),
                Provider::IflytekSpark => Some(Box::new(GenericAdapter::new_with_api_key(
                    ProviderConfigs::iflytek_spark(),
                    Some(k.clone()),
                )?)),
                Provider::Moonshot => Some(Box::new(GenericAdapter::new_with_api_key(
                    ProviderConfigs::moonshot(),
                    Some(k.clone()),
                )?)),
                Provider::Anthropic => Some(Box::new(GenericAdapter::new_with_api_key(
                    ProviderConfigs::anthropic(),
                    Some(k.clone()),
                )?)),
                Provider::AzureOpenAI => Some(Box::new(GenericAdapter::new_with_api_key(
                    ProviderConfigs::azure_openai(),
                    Some(k.clone()),
                )?)),
                Provider::HuggingFace => Some(Box::new(GenericAdapter::new_with_api_key(
                    ProviderConfigs::huggingface(),
                    Some(k.clone()),
                )?)),
                Provider::TogetherAI => Some(Box::new(GenericAdapter::new_with_api_key(
                    ProviderConfigs::together_ai(),
                    Some(k.clone()),
                )?)),
                _ => None,
            };
            if let Some(a) = override_adapter {
                client.adapter = a;
            }
        }
        client.connection_options = Some(opts);
        Ok(client)
    }

    pub fn connection_options(&self) -> Option<&ConnectionOptions> {
        self.connection_options.as_ref()
    }

    /// Create a new AI client builder
    ///
    /// The builder pattern allows more flexible client configuration:
    /// - Automatic environment variable detection
    /// - Support for custom base_url and proxy
    /// - Support for custom timeout and connection pool configuration
    ///
    /// # Arguments
    /// * `provider` - The AI model provider to use
    ///
    /// # Returns
    /// * `AiClientBuilder` - Builder instance
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClient, Provider};
    ///
    /// // Simplest usage - automatic environment variable detection
    /// let client = AiClient::builder(Provider::Groq).build()?;
    ///
    /// // Custom base_url and proxy
    /// let client = AiClient::builder(Provider::Groq)
    ///     .with_base_url("https://custom.groq.com")
    ///     .with_proxy(Some("http://proxy.example.com:8080"))
    ///     .build()?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```
    pub fn builder(provider: Provider) -> AiClientBuilder {
        AiClientBuilder::new(provider)
    }

    /// Create AiClient with injected metrics implementation
    pub fn new_with_metrics(
        provider: Provider,
        metrics: Arc<dyn Metrics>,
    ) -> Result<Self, AiLibError> {
        let adapter: Box<dyn ChatApi> = match provider {
            Provider::Groq => Box::new(GenericAdapter::new(ProviderConfigs::groq())?),
            Provider::XaiGrok => Box::new(GenericAdapter::new(ProviderConfigs::xai_grok())?),
            Provider::Ollama => Box::new(GenericAdapter::new(ProviderConfigs::ollama())?),
            Provider::DeepSeek => Box::new(GenericAdapter::new(ProviderConfigs::deepseek())?),
            Provider::Qwen => Box::new(GenericAdapter::new(ProviderConfigs::qwen())?),
            Provider::Anthropic => Box::new(GenericAdapter::new(ProviderConfigs::anthropic())?),
            Provider::BaiduWenxin => {
                Box::new(GenericAdapter::new(ProviderConfigs::baidu_wenxin())?)
            }
            Provider::TencentHunyuan => {
                Box::new(GenericAdapter::new(ProviderConfigs::tencent_hunyuan())?)
            }
            Provider::IflytekSpark => {
                Box::new(GenericAdapter::new(ProviderConfigs::iflytek_spark())?)
            }
            Provider::Moonshot => Box::new(GenericAdapter::new(ProviderConfigs::moonshot())?),
            Provider::AzureOpenAI => {
                Box::new(GenericAdapter::new(ProviderConfigs::azure_openai())?)
            }
            Provider::HuggingFace => Box::new(GenericAdapter::new(ProviderConfigs::huggingface())?),
            Provider::TogetherAI => Box::new(GenericAdapter::new(ProviderConfigs::together_ai())?),
            Provider::OpenAI => Box::new(OpenAiAdapter::new()?),
            Provider::Gemini => Box::new(GeminiAdapter::new()?),
            Provider::Mistral => Box::new(MistralAdapter::new()?),
            Provider::Cohere => Box::new(CohereAdapter::new()?),
        };

        Ok(Self {
            provider,
            adapter,
            metrics,
            connection_options: None,
            custom_default_chat_model: None,
            custom_default_multimodal_model: None,
            #[cfg(feature = "routing_mvp")]
            routing_array: None,
        })
    }

    /// Set metrics implementation on client
    pub fn with_metrics(mut self, metrics: Arc<dyn Metrics>) -> Self {
        self.metrics = metrics;
        self
    }

    /// Send chat completion request
    ///
    /// # Arguments
    /// * `request` - Chat completion request
    ///
    /// # Returns
    /// * `Result<ChatCompletionResponse, AiLibError>` - Response on success, error on failure
    pub async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        #[cfg(feature = "routing_mvp")]
        {
            // If request.model equals a sentinel like "__route__", pick from routing array
            if request.model == "__route__" {
                let _ = self.metrics.incr_counter("routing_mvp.request", 1).await;
                let mut chosen = self.provider.default_chat_model().to_string();
                if let Some(arr) = &self.routing_array {
                    let mut arr_clone = arr.clone();
                    if let Some(ep) = arr_clone.select_endpoint() {
                        match crate::provider::utils::health_check(&ep.url).await {
                            Ok(()) => {
                                let _ = self.metrics.incr_counter("routing_mvp.selected", 1).await;
                                chosen = ep.model_name.clone();
                            }
                            Err(_) => {
                                let _ = self.metrics.incr_counter("routing_mvp.health_fail", 1).await;
                                chosen = self.provider.default_chat_model().to_string();
                                let _ = self.metrics.incr_counter("routing_mvp.fallback_default", 1).await;
                            }
                        }
                    } else {
                        let _ = self.metrics.incr_counter("routing_mvp.no_endpoint", 1).await;
                    }
                } else {
                    let _ = self.metrics.incr_counter("routing_mvp.missing_array", 1).await;
                }
                let mut req2 = request;
                req2.model = chosen;
                return self.adapter.chat_completion(req2).await;
            }
        }
        self.adapter.chat_completion(request).await
    }

    /// Streaming chat completion request
    ///
    /// # Arguments
    /// * `request` - Chat completion request
    ///
    /// # Returns
    /// * `Result<impl Stream<Item = Result<ChatCompletionChunk, AiLibError>>, AiLibError>` - Stream response on success
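    ///
    /// # Example
    ///
    /// A consumption sketch (not run here; requires provider credentials at runtime):
    ///
    /// ```no_run
    /// use ai_lib::{AiClient, Provider};
    /// use futures::StreamExt;
    ///
    /// # async fn demo() -> Result<(), Box<dyn std::error::Error>> {
    /// let client = AiClient::new(Provider::Groq)?;
    /// let request = client.build_simple_request("Stream a short greeting");
    /// let mut stream = client.chat_completion_stream(request).await?;
    /// while let Some(chunk) = stream.next().await {
    ///     let _chunk = chunk?; // each item is a `ChatCompletionChunk`
    /// }
    /// # Ok(())
    /// # }
    /// ```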
    pub async fn chat_completion_stream(
        &self,
        mut request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    > {
        request.stream = Some(true);
        self.adapter.chat_completion_stream(request).await
    }

    /// Streaming chat completion request with cancel control
    ///
    /// # Arguments
    /// * `request` - Chat completion request
    ///
    /// # Returns
    /// * `Result<(impl Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin, CancelHandle), AiLibError>` - Returns streaming response and cancel handle on success
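    ///
    /// # Example
    ///
    /// A cancellation sketch (not run here; requires provider credentials at runtime):
    ///
    /// ```no_run
    /// use ai_lib::{AiClient, Provider};
    /// use futures::StreamExt;
    ///
    /// # async fn demo() -> Result<(), Box<dyn std::error::Error>> {
    /// let client = AiClient::new(Provider::Groq)?;
    /// let request = client.build_simple_request("Stream a long answer");
    /// let (mut stream, cancel) = client.chat_completion_stream_with_cancel(request).await?;
    ///
    /// // Consume only the first chunk, then cancel the rest of the stream.
    /// if let Some(first) = stream.next().await {
    ///     let _ = first?;
    /// }
    /// cancel.cancel();
    /// # Ok(())
    /// # }
    /// ```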
    pub async fn chat_completion_stream_with_cancel(
        &self,
        mut request: ChatCompletionRequest,
    ) -> Result<
        (
            Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
            CancelHandle,
        ),
        AiLibError,
    > {
        request.stream = Some(true);
        let stream = self.adapter.chat_completion_stream(request).await?;
        let (cancel_tx, cancel_rx) = oneshot::channel();
        let cancel_handle = CancelHandle {
            sender: Some(cancel_tx),
        };

        let controlled_stream = ControlledStream::new(stream, cancel_rx);
        Ok((Box::new(controlled_stream), cancel_handle))
    }

    /// Batch chat completion requests
    ///
    /// # Arguments
    /// * `requests` - List of chat completion requests
    /// * `concurrency_limit` - Maximum concurrent request count (None means unlimited)
    ///
    /// # Returns
    /// * `Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError>` - Returns response results for all requests
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Role};
    /// use ai_lib::types::common::Content;
    ///
    /// #[tokio::main]
    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
    ///     let client = AiClient::new(Provider::Groq)?;
    ///
    ///     let requests = vec![
    ///         ChatCompletionRequest::new(
    ///             "llama3-8b-8192".to_string(),
    ///             vec![Message {
    ///                 role: Role::User,
    ///                 content: Content::Text("Hello".to_string()),
    ///                 function_call: None,
    ///             }],
    ///         ),
    ///         ChatCompletionRequest::new(
    ///             "llama3-8b-8192".to_string(),
    ///             vec![Message {
    ///                 role: Role::User,
    ///                 content: Content::Text("How are you?".to_string()),
    ///                 function_call: None,
    ///             }],
    ///         ),
    ///     ];
    ///
    ///     // Limit concurrency to 5
    ///     let responses = client.chat_completion_batch(requests, Some(5)).await?;
    ///
    ///     for (i, response) in responses.iter().enumerate() {
    ///         match response {
    ///             Ok(resp) => println!("Request {}: {}", i, resp.choices[0].message.content.as_text()),
    ///             Err(e) => println!("Request {} failed: {}", i, e),
    ///         }
    ///     }
    ///
    ///     Ok(())
    /// }
    /// ```
    pub async fn chat_completion_batch(
        &self,
        requests: Vec<ChatCompletionRequest>,
        concurrency_limit: Option<usize>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        self.adapter
            .chat_completion_batch(requests, concurrency_limit)
            .await
    }

    /// Smart batch processing: automatically choose processing strategy based on request count
    ///
    /// # Arguments
    /// * `requests` - List of chat completion requests
    ///
    /// # Returns
    /// * `Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError>` - Returns response results for all requests
    pub async fn chat_completion_batch_smart(
        &self,
        requests: Vec<ChatCompletionRequest>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        // Use sequential processing for small batches, concurrent processing for large batches
        let concurrency_limit = if requests.len() <= 3 { None } else { Some(10) };
        self.chat_completion_batch(requests, concurrency_limit)
            .await
    }

    /// Get list of supported models
    ///
    /// # Returns
    /// * `Result<Vec<String>, AiLibError>` - Returns model list on success, error on failure
    pub async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
        self.adapter.list_models().await
    }

    /// Switch AI model provider
    ///
    /// # Arguments
    /// * `provider` - New provider
    ///
    /// # Returns
    /// * `Result<(), AiLibError>` - Returns () on success, error on failure
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClient, Provider};
    ///
    /// let mut client = AiClient::new(Provider::Groq)?;
    /// // Switch from Groq to Groq (demonstrating switch functionality)
    /// client.switch_provider(Provider::Groq)?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```
    pub fn switch_provider(&mut self, provider: Provider) -> Result<(), AiLibError> {
        let new_adapter: Box<dyn ChatApi> = match provider {
            Provider::Groq => Box::new(GenericAdapter::new(ProviderConfigs::groq())?),
            Provider::XaiGrok => Box::new(GenericAdapter::new(ProviderConfigs::xai_grok())?),
            Provider::Ollama => Box::new(GenericAdapter::new(ProviderConfigs::ollama())?),
            Provider::DeepSeek => Box::new(GenericAdapter::new(ProviderConfigs::deepseek())?),
            Provider::Qwen => Box::new(GenericAdapter::new(ProviderConfigs::qwen())?),
            Provider::OpenAI => Box::new(OpenAiAdapter::new()?),
            Provider::Anthropic => Box::new(GenericAdapter::new(ProviderConfigs::anthropic())?),
            Provider::BaiduWenxin => {
                Box::new(GenericAdapter::new(ProviderConfigs::baidu_wenxin())?)
            }
            Provider::TencentHunyuan => {
                Box::new(GenericAdapter::new(ProviderConfigs::tencent_hunyuan())?)
            }
            Provider::IflytekSpark => {
                Box::new(GenericAdapter::new(ProviderConfigs::iflytek_spark())?)
            }
            Provider::Moonshot => Box::new(GenericAdapter::new(ProviderConfigs::moonshot())?),
            Provider::Gemini => Box::new(GeminiAdapter::new()?),
            Provider::AzureOpenAI => {
                Box::new(GenericAdapter::new(ProviderConfigs::azure_openai())?)
            }
            Provider::HuggingFace => Box::new(GenericAdapter::new(ProviderConfigs::huggingface())?),
            Provider::TogetherAI => Box::new(GenericAdapter::new(ProviderConfigs::together_ai())?),
            Provider::Mistral => Box::new(MistralAdapter::new()?),
            Provider::Cohere => Box::new(CohereAdapter::new()?),
            // Provider::Bedrock => Box::new(BedrockAdapter::new()?),
        };

        self.provider = provider;
        self.adapter = new_adapter;
        Ok(())
    }

    /// Get current provider
    pub fn current_provider(&self) -> Provider {
        self.provider
    }

    /// Convenience helper: construct a request with the provider's default chat model.
    /// This does NOT send the request.
    /// Uses custom default model if set via AiClientBuilder, otherwise uses provider default.
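    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClient, Provider};
    ///
    /// let client = AiClient::new(Provider::Groq)?;
    /// let request = client.build_simple_request("Hello");
    /// // With no custom override, the request uses the provider's default chat model.
    /// assert_eq!(request.model, client.default_chat_model());
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```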
    pub fn build_simple_request<S: Into<String>>(&self, prompt: S) -> ChatCompletionRequest {
        let model = self.custom_default_chat_model
            .clone()
            .unwrap_or_else(|| self.provider.default_chat_model().to_string());

        ChatCompletionRequest::new(
            model,
            vec![crate::types::Message {
                role: crate::types::Role::User,
                content: crate::types::common::Content::Text(prompt.into()),
                function_call: None,
            }],
        )
    }

    /// Convenience helper: construct a request with an explicitly specified chat model.
    /// This does NOT send the request.
    pub fn build_simple_request_with_model<S: Into<String>>(
        &self,
        prompt: S,
        model: S,
    ) -> ChatCompletionRequest {
        ChatCompletionRequest::new(
            model.into(),
            vec![crate::types::Message {
                role: crate::types::Role::User,
                content: crate::types::common::Content::Text(prompt.into()),
                function_call: None,
            }],
        )
    }

    /// Convenience helper: construct a request with the provider's default multimodal model.
    /// This does NOT send the request.
    /// Uses custom default model if set via AiClientBuilder, otherwise uses provider default.
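    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClient, Provider};
    ///
    /// let client = AiClient::new(Provider::Groq)?;
    /// // Groq currently has no default multimodal model, so this returns a configuration error
    /// // unless a custom default was set via `AiClientBuilder::with_default_multimodal_model`.
    /// assert!(client.build_multimodal_request("Describe this image").is_err());
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```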
    pub fn build_multimodal_request<S: Into<String>>(
        &self,
        prompt: S,
    ) -> Result<ChatCompletionRequest, AiLibError> {
        let model = self.custom_default_multimodal_model
            .clone()
            .or_else(|| self.provider.default_multimodal_model().map(|s| s.to_string()))
            .ok_or_else(|| AiLibError::ConfigurationError(
                format!("No multimodal model available for provider {:?}", self.provider)
            ))?;

        Ok(ChatCompletionRequest::new(
            model,
            vec![crate::types::Message {
                role: crate::types::Role::User,
                content: crate::types::common::Content::Text(prompt.into()),
                function_call: None,
            }],
        ))
    }

    /// Convenience helper: construct a request with an explicitly specified multimodal model.
    /// This does NOT send the request.
    pub fn build_multimodal_request_with_model<S: Into<String>>(
        &self,
        prompt: S,
        model: S,
    ) -> ChatCompletionRequest {
        ChatCompletionRequest::new(
            model.into(),
            vec![crate::types::Message {
                role: crate::types::Role::User,
                content: crate::types::common::Content::Text(prompt.into()),
                function_call: None,
            }],
        )
    }

    /// One-shot helper: create a client for `provider`, send a single user prompt using the
    /// default chat model, and return plain text content (first choice).
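    ///
    /// # Example
    ///
    /// Not run here; requires the provider's API key (e.g. `GROQ_API_KEY`) at runtime:
    ///
    /// ```no_run
    /// use ai_lib::{AiClient, Provider};
    ///
    /// # async fn demo() -> Result<(), ai_lib::AiLibError> {
    /// let text = AiClient::quick_chat_text(Provider::Groq, "Say hello in one sentence").await?;
    /// println!("{}", text);
    /// # Ok(())
    /// # }
    /// ```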
    pub async fn quick_chat_text<P: Into<String>>(
        provider: Provider,
        prompt: P,
    ) -> Result<String, AiLibError> {
        let client = AiClient::new(provider)?;
        let req = client.build_simple_request(prompt.into());
        let resp = client.chat_completion(req).await?;
        resp.first_text().map(|s| s.to_string())
    }

    /// One-shot helper: create a client for `provider`, send a single user prompt using an
    /// explicitly specified chat model, and return plain text content (first choice).
    pub async fn quick_chat_text_with_model<P: Into<String>, M: Into<String>>(
        provider: Provider,
        prompt: P,
        model: M,
    ) -> Result<String, AiLibError> {
        let client = AiClient::new(provider)?;
        let req = client.build_simple_request_with_model(prompt.into(), model.into());
        let resp = client.chat_completion(req).await?;
        resp.first_text().map(|s| s.to_string())
    }

    /// One-shot helper: create a client for `provider`, send a single user prompt using the
    /// default multimodal model, and return plain text content (first choice).
    pub async fn quick_multimodal_text<P: Into<String>>(
        provider: Provider,
        prompt: P,
    ) -> Result<String, AiLibError> {
        let client = AiClient::new(provider)?;
        let req = client.build_multimodal_request(prompt.into())?;
        let resp = client.chat_completion(req).await?;
        resp.first_text().map(|s| s.to_string())
    }

    /// One-shot helper: create a client for `provider`, send a single user prompt using an
    /// explicitly specified multimodal model, and return plain text content (first choice).
    pub async fn quick_multimodal_text_with_model<P: Into<String>, M: Into<String>>(
        provider: Provider,
        prompt: P,
        model: M,
    ) -> Result<String, AiLibError> {
        let client = AiClient::new(provider)?;
        let req = client.build_multimodal_request_with_model(prompt.into(), model.into());
        let resp = client.chat_completion(req).await?;
        resp.first_text().map(|s| s.to_string())
    }

    /// One-shot helper with model options: create a client for `provider`, send a single user prompt
    /// using specified model options, and return plain text content (first choice).
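    ///
    /// A minimal sketch (not run here; assumes `ModelOptions` is re-exported at the crate
    /// root and that the chosen model name is valid for the provider):
    ///
    /// ```ignore
    /// use ai_lib::{AiClient, ModelOptions, Provider};
    ///
    /// # async fn demo() -> Result<(), ai_lib::AiLibError> {
    /// let options = ModelOptions::new().with_chat_model("llama-3.1-8b-instant");
    /// let text = AiClient::quick_chat_text_with_options(Provider::Groq, "Hello", options).await?;
    /// println!("{}", text);
    /// # Ok(())
    /// # }
    /// ```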
    pub async fn quick_chat_text_with_options<P: Into<String>>(
        provider: Provider,
        prompt: P,
        options: ModelOptions,
    ) -> Result<String, AiLibError> {
        let client = AiClient::new(provider)?;

        // Determine which model to use based on options
        let model = if let Some(chat_model) = options.chat_model {
            chat_model
        } else {
            provider.default_chat_model().to_string()
        };

        let req = client.build_simple_request_with_model(prompt.into(), model);
        let resp = client.chat_completion(req).await?;
        resp.first_text().map(|s| s.to_string())
    }
}

/// Streaming response cancel handle
pub struct CancelHandle {
    sender: Option<oneshot::Sender<()>>,
}

impl CancelHandle {
    /// Cancel streaming response
    pub fn cancel(mut self) {
        if let Some(sender) = self.sender.take() {
            let _ = sender.send(());
        }
    }
}

/// AI client builder with progressive custom configuration
///
/// Usage examples:
/// ```rust
/// use ai_lib::{AiClientBuilder, Provider};
///
/// // Simplest usage - automatic environment variable detection
/// let client = AiClientBuilder::new(Provider::Groq).build()?;
///
/// // Custom base_url and proxy
/// let client = AiClientBuilder::new(Provider::Groq)
///     .with_base_url("https://custom.groq.com")
///     .with_proxy(Some("http://proxy.example.com:8080"))
///     .build()?;
///
/// // Full custom configuration
/// let client = AiClientBuilder::new(Provider::Groq)
///     .with_base_url("https://custom.groq.com")
///     .with_proxy(Some("http://proxy.example.com:8080"))
///     .with_timeout(std::time::Duration::from_secs(60))
///     .with_pool_config(32, std::time::Duration::from_secs(90))
///     .build()?;
/// # Ok::<(), ai_lib::AiLibError>(())
/// ```
pub struct AiClientBuilder {
    provider: Provider,
    base_url: Option<String>,
    proxy_url: Option<String>,
    timeout: Option<std::time::Duration>,
    pool_max_idle: Option<usize>,
    pool_idle_timeout: Option<std::time::Duration>,
    metrics: Option<Arc<dyn Metrics>>,
    // Model configuration options
    default_chat_model: Option<String>,
    default_multimodal_model: Option<String>,
    // Resilience configuration
    resilience_config: ResilienceConfig,
    #[cfg(feature = "routing_mvp")]
    routing_array: Option<crate::provider::models::ModelArray>,
}

impl AiClientBuilder {
    /// Create a new builder instance
    ///
    /// # Arguments
    /// * `provider` - The AI model provider to use
    ///
    /// # Returns
    /// * `Self` - Builder instance
    pub fn new(provider: Provider) -> Self {
        Self {
            provider,
            base_url: None,
            proxy_url: None,
            timeout: None,
            pool_max_idle: None,
            pool_idle_timeout: None,
            metrics: None,
            default_chat_model: None,
            default_multimodal_model: None,
            resilience_config: ResilienceConfig::default(),
            #[cfg(feature = "routing_mvp")]
            routing_array: None,
        }
    }

    /// Check if provider is config-driven (uses GenericAdapter)
    fn is_config_driven_provider(provider: Provider) -> bool {
        provider.is_config_driven()
    }

    /// Set custom base URL
    ///
    /// # Arguments
    /// * `base_url` - Custom base URL
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    pub fn with_base_url(mut self, base_url: &str) -> Self {
        self.base_url = Some(base_url.to_string());
        self
    }

    /// Set custom proxy URL
    ///
    /// # Arguments
    /// * `proxy_url` - Custom proxy URL, or None to use AI_PROXY_URL environment variable
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    ///
    /// # Examples
    /// ```rust
    /// use ai_lib::{AiClientBuilder, Provider};
    ///
    /// // Use specific proxy URL
    /// let client = AiClientBuilder::new(Provider::Groq)
    ///     .with_proxy(Some("http://proxy.example.com:8080"))
    ///     .build()?;
    ///
    /// // Use AI_PROXY_URL environment variable
    /// let client = AiClientBuilder::new(Provider::Groq)
    ///     .with_proxy(None)
    ///     .build()?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```
    pub fn with_proxy(mut self, proxy_url: Option<&str>) -> Self {
        self.proxy_url = proxy_url.map(|s| s.to_string());
        self
    }

    /// Explicitly disable proxy usage
    ///
    /// This method ensures that no proxy will be used, regardless of environment variables.
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClientBuilder, Provider};
    ///
    /// let client = AiClientBuilder::new(Provider::Groq)
    ///     .without_proxy()
    ///     .build()?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```
    pub fn without_proxy(mut self) -> Self {
        self.proxy_url = Some("".to_string());
        self
    }

    /// Set custom timeout duration
    ///
    /// # Arguments
    /// * `timeout` - Custom timeout duration
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Set connection pool configuration
    ///
    /// # Arguments
    /// * `max_idle` - Maximum idle connections per host
    /// * `idle_timeout` - Idle connection timeout duration
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    pub fn with_pool_config(mut self, max_idle: usize, idle_timeout: std::time::Duration) -> Self {
        self.pool_max_idle = Some(max_idle);
        self.pool_idle_timeout = Some(idle_timeout);
        self
    }

    /// Set custom metrics implementation
    ///
    /// # Arguments
    /// * `metrics` - Custom metrics implementation
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    pub fn with_metrics(mut self, metrics: Arc<dyn Metrics>) -> Self {
        self.metrics = Some(metrics);
        self
    }

    /// Set default chat model for the client
    ///
    /// # Arguments
    /// * `model` - Default chat model name
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClientBuilder, Provider};
    ///
    /// let client = AiClientBuilder::new(Provider::Groq)
    ///     .with_default_chat_model("llama-3.1-8b-instant")
    ///     .build()?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```
    pub fn with_default_chat_model(mut self, model: &str) -> Self {
        self.default_chat_model = Some(model.to_string());
        self
    }

    /// Set default multimodal model for the client
    ///
    /// # Arguments
    /// * `model` - Default multimodal model name
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClientBuilder, Provider};
    ///
    /// let client = AiClientBuilder::new(Provider::Groq)
    ///     .with_default_multimodal_model("llama-3.2-11b-vision")
    ///     .build()?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```
    pub fn with_default_multimodal_model(mut self, model: &str) -> Self {
        self.default_multimodal_model = Some(model.to_string());
        self
    }

    /// Enable smart defaults for resilience features
    ///
    /// This method enables reasonable default configurations for circuit breaker,
    /// rate limiting, and error handling without requiring detailed configuration.
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClientBuilder, Provider};
    ///
    /// let client = AiClientBuilder::new(Provider::Groq)
    ///     .with_smart_defaults()
    ///     .build()?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```
    pub fn with_smart_defaults(mut self) -> Self {
        self.resilience_config = ResilienceConfig::smart_defaults();
        self
    }

    /// Configure for production environment
    ///
    /// This method applies production-ready configurations for all resilience
    /// features with conservative settings for maximum reliability.
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClientBuilder, Provider};
    ///
    /// let client = AiClientBuilder::new(Provider::Groq)
    ///     .for_production()
    ///     .build()?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```
    pub fn for_production(mut self) -> Self {
        self.resilience_config = ResilienceConfig::production();
        self
    }

    /// Configure for development environment
    ///
    /// This method applies development-friendly configurations with more
    /// lenient settings for easier debugging and testing.
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClientBuilder, Provider};
    ///
    /// let client = AiClientBuilder::new(Provider::Groq)
    ///     .for_development()
    ///     .build()?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```
    pub fn for_development(mut self) -> Self {
        self.resilience_config = ResilienceConfig::development();
        self
    }

    /// Set custom resilience configuration
    ///
    /// # Arguments
    /// * `config` - Custom resilience configuration
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
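    ///
    /// # Example
    ///
    /// A sketch reusing one of the built-in presets shown above (assumes `ResilienceConfig`
    /// is publicly reachable, e.g. via `ai_lib::config`; a fully custom configuration would
    /// be built according to its own definition):
    ///
    /// ```ignore
    /// use ai_lib::{AiClientBuilder, Provider};
    /// use ai_lib::config::ResilienceConfig;
    ///
    /// let client = AiClientBuilder::new(Provider::Groq)
    ///     .with_resilience_config(ResilienceConfig::development())
    ///     .build()?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```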
    pub fn with_resilience_config(mut self, config: ResilienceConfig) -> Self {
        self.resilience_config = config;
        self
    }

    #[cfg(feature = "routing_mvp")]
    /// Provide a `ModelArray` for client-side routing and fallback.
    pub fn with_routing_array(
        mut self,
        array: crate::provider::models::ModelArray,
    ) -> Self {
        self.routing_array = Some(array);
        self
    }

    /// Build AiClient instance
    ///
    /// The build process applies configuration in the following priority order:
    /// 1. Explicitly set configuration (via with_* methods)
    /// 2. Environment variable configuration
    /// 3. Default configuration
    ///
    /// # Returns
    /// * `Result<AiClient, AiLibError>` - Returns client instance on success, error on failure
    pub fn build(self) -> Result<AiClient, AiLibError> {
        // 1. Determine base_url: explicit setting > environment variable > default
        let base_url = self.determine_base_url()?;

        // 2. Determine proxy_url: explicit setting > environment variable
        let proxy_url = self.determine_proxy_url();

        // 3. Determine timeout: explicit setting > default
        let timeout = self
            .timeout
            .unwrap_or_else(|| std::time::Duration::from_secs(30));

        // 4. Create adapter
        let adapter: Box<dyn ChatApi> = if Self::is_config_driven_provider(self.provider) {
            // All config-driven providers use the same logic - much cleaner!
            // Create custom ProviderConfig (if needed)
            let config = self.create_custom_config(base_url)?;
            // Create custom HttpTransport (if needed)
            let transport = self.create_custom_transport(proxy_url.clone(), timeout)?;
            create_generic_adapter(config, transport)?
        } else {
            // Independent adapters - simple one-liners
            match self.provider {
                Provider::OpenAI => Box::new(OpenAiAdapter::new()?),
                Provider::Gemini => Box::new(GeminiAdapter::new()?),
                Provider::Mistral => Box::new(MistralAdapter::new()?),
                Provider::Cohere => Box::new(CohereAdapter::new()?),
                _ => unreachable!("All providers should be handled by now"),
            }
        };

        // 5. Create AiClient
        let client = AiClient {
            provider: self.provider,
            adapter,
            metrics: self.metrics.unwrap_or_else(|| Arc::new(NoopMetrics::new())),
            connection_options: None,
            custom_default_chat_model: self.default_chat_model,
            custom_default_multimodal_model: self.default_multimodal_model,
            #[cfg(feature = "routing_mvp")]
            routing_array: self.routing_array,
        };

        Ok(client)
    }

    /// Determine base_url, priority: explicit setting > environment variable > default
    fn determine_base_url(&self) -> Result<String, AiLibError> {
        // 1. Explicitly set base_url
        if let Some(ref base_url) = self.base_url {
            return Ok(base_url.clone());
        }

        // 2. base_url from environment variable
        let env_var_name = self.get_base_url_env_var_name();
        if let Ok(base_url) = std::env::var(&env_var_name) {
            return Ok(base_url);
        }

        // 3. Use default configuration (only for config-driven providers)
        if Self::is_config_driven_provider(self.provider) {
            let default_config = self.get_default_provider_config()?;
            Ok(default_config.base_url)
        } else {
            // For independent providers, return a default base URL
            match self.provider {
                Provider::OpenAI => Ok("https://api.openai.com".to_string()),
                Provider::Gemini => Ok("https://generativelanguage.googleapis.com".to_string()),
                Provider::Mistral => Ok("https://api.mistral.ai".to_string()),
                Provider::Cohere => Ok("https://api.cohere.ai".to_string()),
                _ => Err(AiLibError::ConfigurationError(
                    "Unknown provider for base URL determination".to_string(),
                )),
            }
        }
    }

    /// Determine proxy_url, priority: explicit setting > environment variable
    fn determine_proxy_url(&self) -> Option<String> {
        // 1. Explicitly set proxy_url
        if let Some(ref proxy_url) = self.proxy_url {
            // If proxy_url is empty string, it means explicitly no proxy
            if proxy_url.is_empty() {
                return None;
            }
            return Some(proxy_url.clone());
        }

        // 2. AI_PROXY_URL from environment variable
        std::env::var("AI_PROXY_URL").ok()
    }

    /// Get environment variable name for corresponding provider
    fn get_base_url_env_var_name(&self) -> String {
        match self.provider {
            Provider::Groq => "GROQ_BASE_URL".to_string(),
            Provider::XaiGrok => "GROK_BASE_URL".to_string(),
            Provider::Ollama => "OLLAMA_BASE_URL".to_string(),
            Provider::DeepSeek => "DEEPSEEK_BASE_URL".to_string(),
            Provider::Qwen => "DASHSCOPE_BASE_URL".to_string(),
            Provider::BaiduWenxin => "BAIDU_WENXIN_BASE_URL".to_string(),
            Provider::TencentHunyuan => "TENCENT_HUNYUAN_BASE_URL".to_string(),
            Provider::IflytekSpark => "IFLYTEK_BASE_URL".to_string(),
            Provider::Moonshot => "MOONSHOT_BASE_URL".to_string(),
            Provider::Anthropic => "ANTHROPIC_BASE_URL".to_string(),
            Provider::AzureOpenAI => "AZURE_OPENAI_BASE_URL".to_string(),
            Provider::HuggingFace => "HUGGINGFACE_BASE_URL".to_string(),
            Provider::TogetherAI => "TOGETHER_BASE_URL".to_string(),
            // These providers don't support custom base_url
            Provider::OpenAI | Provider::Gemini | Provider::Mistral | Provider::Cohere => {
                "".to_string()
            }
        }
    }

    /// Get default provider configuration
    fn get_default_provider_config(
        &self,
    ) -> Result<crate::provider::config::ProviderConfig, AiLibError> {
        self.provider.get_default_config()
    }

    /// Create custom ProviderConfig
    fn create_custom_config(
        &self,
        base_url: String,
    ) -> Result<crate::provider::config::ProviderConfig, AiLibError> {
        let mut config = self.get_default_provider_config()?;
        config.base_url = base_url;
        Ok(config)
    }

    /// Create custom HttpTransport
    fn create_custom_transport(
        &self,
        proxy_url: Option<String>,
        timeout: std::time::Duration,
    ) -> Result<Option<crate::transport::DynHttpTransportRef>, AiLibError> {
        // If no custom configuration, return None (use default transport)
        if proxy_url.is_none() && self.pool_max_idle.is_none() && self.pool_idle_timeout.is_none() {
            return Ok(None);
        }

        // Create custom HttpTransportConfig
        let transport_config = crate::transport::HttpTransportConfig {
            timeout,
            proxy: proxy_url,
            pool_max_idle_per_host: self.pool_max_idle,
            pool_idle_timeout: self.pool_idle_timeout,
        };

        // Create custom HttpTransport
        let transport = crate::transport::HttpTransport::new_with_config(transport_config)?;
        Ok(Some(transport.boxed()))
    }
}

/// Controllable streaming response
struct ControlledStream {
    inner: Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
    cancel_rx: Option<oneshot::Receiver<()>>,
}

impl ControlledStream {
    fn new(
        inner: Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        cancel_rx: oneshot::Receiver<()>,
    ) -> Self {
        Self {
            inner,
            cancel_rx: Some(cancel_rx),
        }
    }
}

impl Stream for ControlledStream {
    type Item = Result<ChatCompletionChunk, AiLibError>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        use futures::stream::StreamExt;
        use std::task::Poll;

        // Check if cancelled
        if let Some(ref mut cancel_rx) = self.cancel_rx {
            match Future::poll(std::pin::Pin::new(cancel_rx), cx) {
                Poll::Ready(_) => {
                    self.cancel_rx = None;
                    return Poll::Ready(Some(Err(AiLibError::ProviderError(
                        "Stream cancelled".to_string(),
                    ))));
                }
                Poll::Pending => {}
            }
        }

        // Poll inner stream
        self.inner.poll_next_unpin(cx)
    }
}