ai_lib/client.rs
use crate::api::{ChatApi, ChatCompletionChunk};
use crate::config::{ConnectionOptions, ResilienceConfig};
use crate::metrics::{Metrics, NoopMetrics};
use crate::provider::{
    classification::ProviderClassification,
    CohereAdapter, GeminiAdapter, GenericAdapter, MistralAdapter, OpenAiAdapter, ProviderConfigs,
};
use crate::types::{AiLibError, ChatCompletionRequest, ChatCompletionResponse};
use futures::stream::Stream;
use futures::Future;
use std::sync::Arc;
use tokio::sync::oneshot;

/// Model configuration options for explicit model selection
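///
/// # Example
/// A minimal sketch (model names are illustrative, and the crate-root re-export
/// of `ModelOptions` is assumed):
/// ```rust,ignore
/// use ai_lib::ModelOptions;
///
/// let opts = ModelOptions::default()
///     .with_chat_model("llama-3.1-8b-instant")
///     .with_fallback_models(vec!["mixtral-8x7b"])
///     .with_auto_discovery(false);
/// ```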
#[derive(Debug, Clone)]
pub struct ModelOptions {
    pub chat_model: Option<String>,
    pub multimodal_model: Option<String>,
    pub fallback_models: Vec<String>,
    pub auto_discovery: bool,
}

impl Default for ModelOptions {
    fn default() -> Self {
        Self {
            chat_model: None,
            multimodal_model: None,
            fallback_models: Vec::new(),
            auto_discovery: true,
        }
    }
}

impl ModelOptions {
    /// Set chat model
    pub fn with_chat_model(mut self, model: &str) -> Self {
        self.chat_model = Some(model.to_string());
        self
    }

    /// Set multimodal model
    pub fn with_multimodal_model(mut self, model: &str) -> Self {
        self.multimodal_model = Some(model.to_string());
        self
    }

    /// Set fallback models
    pub fn with_fallback_models(mut self, models: Vec<&str>) -> Self {
        self.fallback_models = models.into_iter().map(|s| s.to_string()).collect();
        self
    }

    /// Enable or disable auto discovery
    pub fn with_auto_discovery(mut self, enabled: bool) -> Self {
        self.auto_discovery = enabled;
        self
    }
}

/// Helper function to create a GenericAdapter with an optional custom transport.
fn create_generic_adapter(
    config: crate::provider::config::ProviderConfig,
    transport: Option<crate::transport::DynHttpTransportRef>,
) -> Result<Box<dyn ChatApi>, AiLibError> {
    if let Some(custom_transport) = transport {
        Ok(Box::new(GenericAdapter::with_transport_ref(
            config,
            custom_transport,
        )?))
    } else {
        Ok(Box::new(GenericAdapter::new(config)?))
    }
}

/// AI model provider enumeration
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Provider {
    // Config-driven providers (served by GenericAdapter)
    Groq,
    XaiGrok,
    Ollama,
    DeepSeek,
    Qwen,
    Anthropic,
    AzureOpenAI,
    HuggingFace,
    TogetherAI,
    // Chinese providers (OpenAI-compatible or config-driven)
    BaiduWenxin,
    TencentHunyuan,
    IflytekSpark,
    Moonshot,
    // Independent adapters
    OpenAI,
    Gemini,
    Mistral,
    Cohere,
    // Bedrock removed (deferred)
}

impl Provider {
    /// Get the provider's preferred default chat model.
    /// These should mirror the values used inside `ProviderConfigs`.
    pub fn default_chat_model(&self) -> &'static str {
        match self {
            Provider::Groq => "llama-3.1-8b-instant",
            Provider::XaiGrok => "grok-beta",
            Provider::Ollama => "llama3-8b",
            Provider::DeepSeek => "deepseek-chat",
            Provider::Anthropic => "claude-3-5-sonnet-20241022",
            Provider::AzureOpenAI => "gpt-35-turbo",
            Provider::HuggingFace => "microsoft/DialoGPT-medium",
            Provider::TogetherAI => "meta-llama/Llama-3-8b-chat-hf",
            Provider::BaiduWenxin => "ernie-3.5",
            Provider::TencentHunyuan => "hunyuan-standard",
            Provider::IflytekSpark => "spark-v3.0",
            Provider::Moonshot => "moonshot-v1-8k",
            Provider::OpenAI => "gpt-3.5-turbo",
            Provider::Qwen => "qwen-turbo",
            Provider::Gemini => "gemini-pro",     // widely used text/chat model
            Provider::Mistral => "mistral-small", // generic default
            Provider::Cohere => "command-r",      // chat-capable model
        }
    }

    /// Get the provider's preferred multimodal model (if any).
    pub fn default_multimodal_model(&self) -> Option<&'static str> {
        match self {
            Provider::OpenAI => Some("gpt-4o"),
            Provider::AzureOpenAI => Some("gpt-4o"),
            Provider::Anthropic => Some("claude-3-5-sonnet-20241022"),
            Provider::Groq => None, // no multimodal model currently available
            Provider::Gemini => Some("gemini-1.5-flash"),
            Provider::Cohere => Some("command-r-plus"),
            // Others presently have no clearly documented multimodal endpoint or are not yet wired.
            _ => None,
        }
    }
}

/// Unified AI client
///
/// Usage example:
/// ```rust
/// use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Role};
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
///     // Switch model provider by changing the Provider value
///     let client = AiClient::new(Provider::Groq)?;
///
///     let request = ChatCompletionRequest::new(
///         "test-model".to_string(),
///         vec![Message {
///             role: Role::User,
///             content: ai_lib::types::common::Content::Text("Hello".to_string()),
///             function_call: None,
///         }],
///     );
///
///     // Note: set the GROQ_API_KEY environment variable for actual API calls
///     // Optional: set the AI_PROXY_URL environment variable to use a proxy server
///     // let response = client.chat_completion(request).await?;
///
///     println!("Client created successfully with provider: {:?}", client.current_provider());
///     println!("Request prepared for model: {}", request.model);
///
///     Ok(())
/// }
/// ```
///
/// # Proxy Configuration
///
/// Configure a proxy server by setting the `AI_PROXY_URL` environment variable:
///
/// ```bash
/// export AI_PROXY_URL=http://proxy.example.com:8080
/// ```
///
/// Supported proxy formats:
/// - HTTP proxy: `http://proxy.example.com:8080`
/// - HTTPS proxy: `https://proxy.example.com:8080`
/// - With authentication: `http://user:pass@proxy.example.com:8080`
pub struct AiClient {
    provider: Provider,
    adapter: Box<dyn ChatApi>,
    metrics: Arc<dyn Metrics>,
    connection_options: Option<ConnectionOptions>,
    // Custom default models (override provider defaults)
    custom_default_chat_model: Option<String>,
    custom_default_multimodal_model: Option<String>,
}

impl AiClient {
    /// Create a new AI client
    ///
    /// # Arguments
    /// * `provider` - The AI model provider to use
    ///
    /// # Returns
    /// * `Result<Self, AiLibError>` - Client instance on success, error on failure
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClient, Provider};
    ///
    /// let client = AiClient::new(Provider::Groq)?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```
    pub fn new(provider: Provider) -> Result<Self, AiLibError> {
        // Delegate to the builder, which auto-detects environment variables
        // (the builder already leaves connection_options unset).
        AiClientBuilder::new(provider).build()
    }

    /// Create a client with minimal explicit options (base_url/proxy/timeout/api_key).
    /// Not every provider supports every override; unsupported fields are ignored gracefully.
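    ///
    /// # Example
    /// A hedged sketch; it only uses the `ConnectionOptions` fields referenced in
    /// this file, and the import path is assumed:
    /// ```rust,ignore
    /// use ai_lib::{AiClient, Provider};
    /// use ai_lib::config::ConnectionOptions; // path assumed
    ///
    /// let opts = ConnectionOptions {
    ///     base_url: Some("https://custom.groq.com".to_string()),
    ///     api_key: std::env::var("GROQ_API_KEY").ok(),
    ///     proxy: None,
    ///     timeout: Some(std::time::Duration::from_secs(30)),
    ///     disable_proxy: false,
    /// };
    /// let client = AiClient::with_options(Provider::Groq, opts)?;
    /// ```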
    pub fn with_options(provider: Provider, opts: ConnectionOptions) -> Result<Self, AiLibError> {
        let config_driven = provider.is_config_driven();
        let need_builder = config_driven
            && (opts.base_url.is_some()
                || opts.proxy.is_some()
                || opts.timeout.is_some()
                || opts.disable_proxy);
        if need_builder {
            let mut b = AiClient::builder(provider);
            if let Some(ref base) = opts.base_url {
                b = b.with_base_url(base);
            }
            if opts.disable_proxy {
                b = b.without_proxy();
            } else if let Some(ref proxy) = opts.proxy {
                if proxy.is_empty() {
                    b = b.without_proxy();
                } else {
                    b = b.with_proxy(Some(proxy));
                }
            }
            if let Some(t) = opts.timeout {
                b = b.with_timeout(t);
            }
            let mut client = b.build()?;
            // If an api_key override is present, re-wrap the adapter for
            // config-driven (GenericAdapter) providers using that key.
            if opts.api_key.is_some() {
                if let Some(config) = Self::generic_config_for(provider) {
                    client.adapter = Box::new(GenericAdapter::new_with_api_key(
                        config,
                        opts.api_key.clone(),
                    )?);
                }
            }
            client.connection_options = Some(opts);
            return Ok(client);
        }

        // Independent adapters: OpenAI / Gemini / Mistral / Cohere
        if provider.is_independent() {
            let adapter: Box<dyn ChatApi> = match provider {
                Provider::OpenAI => {
                    if let Some(ref k) = opts.api_key {
                        Box::new(OpenAiAdapter::new_with_overrides(
                            k.clone(),
                            opts.base_url.clone(),
                        )?)
                    } else {
                        Box::new(OpenAiAdapter::new()?)
                    }
                }
                Provider::Gemini => {
                    if let Some(ref k) = opts.api_key {
                        Box::new(GeminiAdapter::new_with_overrides(
                            k.clone(),
                            opts.base_url.clone(),
                        )?)
                    } else {
                        Box::new(GeminiAdapter::new()?)
                    }
                }
                Provider::Mistral => {
                    if opts.api_key.is_some() || opts.base_url.is_some() {
                        Box::new(MistralAdapter::new_with_overrides(
                            opts.api_key.clone(),
                            opts.base_url.clone(),
                        )?)
                    } else {
                        Box::new(MistralAdapter::new()?)
                    }
                }
                Provider::Cohere => {
                    if let Some(ref k) = opts.api_key {
                        Box::new(CohereAdapter::new_with_overrides(
                            k.clone(),
                            opts.base_url.clone(),
                        )?)
                    } else {
                        Box::new(CohereAdapter::new()?)
                    }
                }
                _ => unreachable!(),
            };
            return Ok(AiClient {
                provider,
                adapter,
                metrics: Arc::new(NoopMetrics::new()),
                connection_options: Some(opts),
                custom_default_chat_model: None,
                custom_default_multimodal_model: None,
            });
        }

        // Config-driven provider without transport overrides; only an optional
        // api_key override needs handling.
        let mut client = AiClient::new(provider)?;
        if opts.api_key.is_some() {
            if let Some(config) = Self::generic_config_for(provider) {
                client.adapter = Box::new(GenericAdapter::new_with_api_key(
                    config,
                    opts.api_key.clone(),
                )?);
            }
        }
        client.connection_options = Some(opts);
        Ok(client)
    }

    /// Map config-driven providers to their GenericAdapter config.
    /// Returns `None` for providers served by independent adapters.
    fn generic_config_for(provider: Provider) -> Option<crate::provider::config::ProviderConfig> {
        match provider {
            Provider::Groq => Some(ProviderConfigs::groq()),
            Provider::XaiGrok => Some(ProviderConfigs::xai_grok()),
            Provider::Ollama => Some(ProviderConfigs::ollama()),
            Provider::DeepSeek => Some(ProviderConfigs::deepseek()),
            Provider::Qwen => Some(ProviderConfigs::qwen()),
            Provider::BaiduWenxin => Some(ProviderConfigs::baidu_wenxin()),
            Provider::TencentHunyuan => Some(ProviderConfigs::tencent_hunyuan()),
            Provider::IflytekSpark => Some(ProviderConfigs::iflytek_spark()),
            Provider::Moonshot => Some(ProviderConfigs::moonshot()),
            Provider::Anthropic => Some(ProviderConfigs::anthropic()),
            Provider::AzureOpenAI => Some(ProviderConfigs::azure_openai()),
            Provider::HuggingFace => Some(ProviderConfigs::huggingface()),
            Provider::TogetherAI => Some(ProviderConfigs::together_ai()),
            _ => None,
        }
    }

    pub fn connection_options(&self) -> Option<&ConnectionOptions> {
        self.connection_options.as_ref()
    }

    /// Create a new AI client builder
    ///
    /// The builder pattern allows more flexible client configuration:
    /// - Automatic environment variable detection
    /// - Support for custom base_url and proxy
    /// - Support for custom timeout and connection pool configuration
    ///
    /// # Arguments
    /// * `provider` - The AI model provider to use
    ///
    /// # Returns
    /// * `AiClientBuilder` - Builder instance
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClient, Provider};
    ///
    /// // Simplest usage - automatic environment variable detection
    /// let client = AiClient::builder(Provider::Groq).build()?;
    ///
    /// // Custom base_url and proxy
    /// let client = AiClient::builder(Provider::Groq)
    ///     .with_base_url("https://custom.groq.com")
    ///     .with_proxy(Some("http://proxy.example.com:8080"))
    ///     .build()?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```
    pub fn builder(provider: Provider) -> AiClientBuilder {
        AiClientBuilder::new(provider)
    }

465
466 /// Create AiClient with injected metrics implementation
467 pub fn new_with_metrics(
468 provider: Provider,
469 metrics: Arc<dyn Metrics>,
470 ) -> Result<Self, AiLibError> {
471 let adapter: Box<dyn ChatApi> = match provider {
472 Provider::Groq => Box::new(GenericAdapter::new(ProviderConfigs::groq())?),
473 Provider::XaiGrok => Box::new(GenericAdapter::new(ProviderConfigs::xai_grok())?),
474 Provider::Ollama => Box::new(GenericAdapter::new(ProviderConfigs::ollama())?),
475 Provider::DeepSeek => Box::new(GenericAdapter::new(ProviderConfigs::deepseek())?),
476 Provider::Qwen => Box::new(GenericAdapter::new(ProviderConfigs::qwen())?),
477 Provider::Anthropic => Box::new(GenericAdapter::new(ProviderConfigs::anthropic())?),
478 Provider::BaiduWenxin => {
479 Box::new(GenericAdapter::new(ProviderConfigs::baidu_wenxin())?)
480 }
481 Provider::TencentHunyuan => {
482 Box::new(GenericAdapter::new(ProviderConfigs::tencent_hunyuan())?)
483 }
484 Provider::IflytekSpark => {
485 Box::new(GenericAdapter::new(ProviderConfigs::iflytek_spark())?)
486 }
487 Provider::Moonshot => Box::new(GenericAdapter::new(ProviderConfigs::moonshot())?),
488 Provider::AzureOpenAI => {
489 Box::new(GenericAdapter::new(ProviderConfigs::azure_openai())?)
490 }
491 Provider::HuggingFace => Box::new(GenericAdapter::new(ProviderConfigs::huggingface())?),
492 Provider::TogetherAI => Box::new(GenericAdapter::new(ProviderConfigs::together_ai())?),
493 Provider::OpenAI => Box::new(OpenAiAdapter::new()?),
494 Provider::Gemini => Box::new(GeminiAdapter::new()?),
495 Provider::Mistral => Box::new(MistralAdapter::new()?),
496 Provider::Cohere => Box::new(CohereAdapter::new()?),
497 };
498
499 Ok(Self {
500 provider,
501 adapter,
502 metrics,
503 connection_options: None,
504 custom_default_chat_model: None,
505 custom_default_multimodal_model: None,
506 })
507 }
508
509 /// Set metrics implementation on client
510 pub fn with_metrics(mut self, metrics: Arc<dyn Metrics>) -> Self {
511 self.metrics = metrics;
512 self
513 }
514
    /// Send chat completion request
    ///
    /// # Arguments
    /// * `request` - Chat completion request
    ///
    /// # Returns
    /// * `Result<ChatCompletionResponse, AiLibError>` - Response on success, error on failure
    pub async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        self.adapter.chat_completion(request).await
    }

    /// Streaming chat completion request
    ///
    /// # Arguments
    /// * `request` - Chat completion request
    ///
    /// # Returns
    /// * `Result<impl Stream<Item = Result<ChatCompletionChunk, AiLibError>>, AiLibError>` - Stream response on success
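    ///
    /// # Example
    /// A hedged sketch (needs a valid provider API key, so it is not run):
    /// ```rust,no_run
    /// # async fn demo() -> Result<(), ai_lib::AiLibError> {
    /// use ai_lib::{AiClient, Provider};
    /// use futures::StreamExt;
    ///
    /// let client = AiClient::new(Provider::Groq)?;
    /// let req = client.build_simple_request("Stream a short reply");
    /// let mut stream = client.chat_completion_stream(req).await?;
    /// while let Some(chunk) = stream.next().await {
    ///     let _chunk = chunk?; // each chunk carries an incremental delta
    /// }
    /// # Ok(())
    /// # }
    /// ```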
    pub async fn chat_completion_stream(
        &self,
        mut request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    > {
        request.stream = Some(true);
        self.adapter.chat_completion_stream(request).await
    }

    /// Streaming chat completion request with cancel control
    ///
    /// # Arguments
    /// * `request` - Chat completion request
    ///
    /// # Returns
    /// * `Result<(impl Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin, CancelHandle), AiLibError>` - Returns streaming response and cancel handle on success
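    ///
    /// # Example
    /// A hedged sketch (needs a valid provider API key, so it is not run):
    /// ```rust,no_run
    /// # async fn demo() -> Result<(), ai_lib::AiLibError> {
    /// use ai_lib::{AiClient, Provider};
    /// use futures::StreamExt;
    ///
    /// let client = AiClient::new(Provider::Groq)?;
    /// let req = client.build_simple_request("Hello");
    /// let (mut stream, cancel) = client.chat_completion_stream_with_cancel(req).await?;
    /// if let Some(first) = stream.next().await {
    ///     let _ = first?;
    /// }
    /// // Stop consuming and cancel the remainder of the stream.
    /// cancel.cancel();
    /// # Ok(())
    /// # }
    /// ```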
    pub async fn chat_completion_stream_with_cancel(
        &self,
        mut request: ChatCompletionRequest,
    ) -> Result<
        (
            Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
            CancelHandle,
        ),
        AiLibError,
    > {
        request.stream = Some(true);
        let stream = self.adapter.chat_completion_stream(request).await?;
        let (cancel_tx, cancel_rx) = oneshot::channel();
        let cancel_handle = CancelHandle {
            sender: Some(cancel_tx),
        };

        let controlled_stream = ControlledStream::new(stream, cancel_rx);
        Ok((Box::new(controlled_stream), cancel_handle))
    }

    /// Batch chat completion requests
    ///
    /// # Arguments
    /// * `requests` - List of chat completion requests
    /// * `concurrency_limit` - Maximum concurrent request count (None means unlimited)
    ///
    /// # Returns
    /// * `Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError>` - Returns response results for all requests
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Role};
    /// use ai_lib::types::common::Content;
    ///
    /// #[tokio::main]
    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
    ///     let client = AiClient::new(Provider::Groq)?;
    ///
    ///     let requests = vec![
    ///         ChatCompletionRequest::new(
    ///             "llama3-8b-8192".to_string(),
    ///             vec![Message {
    ///                 role: Role::User,
    ///                 content: Content::Text("Hello".to_string()),
    ///                 function_call: None,
    ///             }],
    ///         ),
    ///         ChatCompletionRequest::new(
    ///             "llama3-8b-8192".to_string(),
    ///             vec![Message {
    ///                 role: Role::User,
    ///                 content: Content::Text("How are you?".to_string()),
    ///                 function_call: None,
    ///             }],
    ///         ),
    ///     ];
    ///
    ///     // Limit concurrency to 5
    ///     let responses = client.chat_completion_batch(requests, Some(5)).await?;
    ///
    ///     for (i, response) in responses.iter().enumerate() {
    ///         match response {
    ///             Ok(resp) => println!("Request {}: {}", i, resp.choices[0].message.content.as_text()),
    ///             Err(e) => println!("Request {} failed: {}", i, e),
    ///         }
    ///     }
    ///
    ///     Ok(())
    /// }
    /// ```
    pub async fn chat_completion_batch(
        &self,
        requests: Vec<ChatCompletionRequest>,
        concurrency_limit: Option<usize>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        self.adapter
            .chat_completion_batch(requests, concurrency_limit)
            .await
    }

    /// Smart batch processing: automatically choose a processing strategy based on request count
    ///
    /// # Arguments
    /// * `requests` - List of chat completion requests
    ///
    /// # Returns
    /// * `Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError>` - Returns response results for all requests
    pub async fn chat_completion_batch_smart(
        &self,
        requests: Vec<ChatCompletionRequest>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        // Small batches (<= 3) pass None (no explicit limit); larger batches cap
        // concurrency at 10.
        let concurrency_limit = if requests.len() <= 3 { None } else { Some(10) };
        self.chat_completion_batch(requests, concurrency_limit)
            .await
    }

    /// Get list of supported models
    ///
    /// # Returns
    /// * `Result<Vec<String>, AiLibError>` - Returns model list on success, error on failure
    pub async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
        self.adapter.list_models().await
    }

    /// Switch AI model provider
    ///
    /// # Arguments
    /// * `provider` - New provider
    ///
    /// # Returns
    /// * `Result<(), AiLibError>` - Returns () on success, error on failure
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClient, Provider};
    ///
    /// let mut client = AiClient::new(Provider::Groq)?;
    /// // Switch from Groq to Groq (demonstrating switch functionality)
    /// client.switch_provider(Provider::Groq)?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```
    pub fn switch_provider(&mut self, provider: Provider) -> Result<(), AiLibError> {
        // `?` runs before the assignment, so `self` is left unchanged on failure.
        self.adapter = Self::default_adapter_for(provider)?;
        self.provider = provider;
        Ok(())
    }

    /// Get current provider
    pub fn current_provider(&self) -> Provider {
        self.provider
    }

    /// Convenience helper: construct a request with the provider's default chat model.
    /// This does NOT send the request.
    /// Uses the custom default model if set via AiClientBuilder, otherwise the provider default.
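    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClient, Provider};
    ///
    /// let client = AiClient::new(Provider::Groq)?;
    /// let req = client.build_simple_request("Hello");
    /// // With no custom default configured, the provider default is used.
    /// assert_eq!(req.model, Provider::Groq.default_chat_model());
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```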
    pub fn build_simple_request<S: Into<String>>(&self, prompt: S) -> ChatCompletionRequest {
        let model = self
            .custom_default_chat_model
            .clone()
            .unwrap_or_else(|| self.provider.default_chat_model().to_string());

        ChatCompletionRequest::new(
            model,
            vec![crate::types::Message {
                role: crate::types::Role::User,
                content: crate::types::common::Content::Text(prompt.into()),
                function_call: None,
            }],
        )
    }

    /// Convenience helper: construct a request with an explicitly specified chat model.
    /// This does NOT send the request.
    pub fn build_simple_request_with_model<P: Into<String>, M: Into<String>>(
        &self,
        prompt: P,
        model: M,
    ) -> ChatCompletionRequest {
        ChatCompletionRequest::new(
            model.into(),
            vec![crate::types::Message {
                role: crate::types::Role::User,
                content: crate::types::common::Content::Text(prompt.into()),
                function_call: None,
            }],
        )
    }

    /// Convenience helper: construct a request with the provider's default multimodal model.
    /// This does NOT send the request.
    /// Uses the custom default model if set via AiClientBuilder, otherwise the provider default.
    pub fn build_multimodal_request<S: Into<String>>(
        &self,
        prompt: S,
    ) -> Result<ChatCompletionRequest, AiLibError> {
        let model = self
            .custom_default_multimodal_model
            .clone()
            .or_else(|| {
                self.provider
                    .default_multimodal_model()
                    .map(|s| s.to_string())
            })
            .ok_or_else(|| {
                AiLibError::ConfigurationError(format!(
                    "No multimodal model available for provider {:?}",
                    self.provider
                ))
            })?;

        Ok(ChatCompletionRequest::new(
            model,
            vec![crate::types::Message {
                role: crate::types::Role::User,
                content: crate::types::common::Content::Text(prompt.into()),
                function_call: None,
            }],
        ))
    }

    /// Convenience helper: construct a request with an explicitly specified multimodal model.
    /// This does NOT send the request.
    pub fn build_multimodal_request_with_model<P: Into<String>, M: Into<String>>(
        &self,
        prompt: P,
        model: M,
    ) -> ChatCompletionRequest {
        ChatCompletionRequest::new(
            model.into(),
            vec![crate::types::Message {
                role: crate::types::Role::User,
                content: crate::types::common::Content::Text(prompt.into()),
                function_call: None,
            }],
        )
    }

    /// One-shot helper: create a client for `provider`, send a single user prompt using the
    /// default chat model, and return plain text content (first choice).
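    ///
    /// # Example
    /// A hedged sketch (needs GROQ_API_KEY, so it is not run):
    /// ```rust,no_run
    /// # async fn demo() -> Result<(), ai_lib::AiLibError> {
    /// use ai_lib::{AiClient, Provider};
    ///
    /// let text = AiClient::quick_chat_text(Provider::Groq, "Hello").await?;
    /// println!("{}", text);
    /// # Ok(())
    /// # }
    /// ```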
    pub async fn quick_chat_text<P: Into<String>>(
        provider: Provider,
        prompt: P,
    ) -> Result<String, AiLibError> {
        let client = AiClient::new(provider)?;
        let req = client.build_simple_request(prompt.into());
        let resp = client.chat_completion(req).await?;
        resp.first_text().map(|s| s.to_string())
    }

    /// One-shot helper: create a client for `provider`, send a single user prompt using an
    /// explicitly specified chat model, and return plain text content (first choice).
    pub async fn quick_chat_text_with_model<P: Into<String>, M: Into<String>>(
        provider: Provider,
        prompt: P,
        model: M,
    ) -> Result<String, AiLibError> {
        let client = AiClient::new(provider)?;
        let req = client.build_simple_request_with_model(prompt.into(), model.into());
        let resp = client.chat_completion(req).await?;
        resp.first_text().map(|s| s.to_string())
    }

    /// One-shot helper: create a client for `provider`, send a single user prompt using the
    /// default multimodal model, and return plain text content (first choice).
    pub async fn quick_multimodal_text<P: Into<String>>(
        provider: Provider,
        prompt: P,
    ) -> Result<String, AiLibError> {
        let client = AiClient::new(provider)?;
        let req = client.build_multimodal_request(prompt.into())?;
        let resp = client.chat_completion(req).await?;
        resp.first_text().map(|s| s.to_string())
    }

    /// One-shot helper: create a client for `provider`, send a single user prompt using an
    /// explicitly specified multimodal model, and return plain text content (first choice).
    pub async fn quick_multimodal_text_with_model<P: Into<String>, M: Into<String>>(
        provider: Provider,
        prompt: P,
        model: M,
    ) -> Result<String, AiLibError> {
        let client = AiClient::new(provider)?;
        let req = client.build_multimodal_request_with_model(prompt.into(), model.into());
        let resp = client.chat_completion(req).await?;
        resp.first_text().map(|s| s.to_string())
    }

    /// One-shot helper with model options: create a client for `provider`, send a single user
    /// prompt using the specified model options, and return plain text content (first choice).
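    ///
    /// # Example
    /// A hedged sketch (needs GROQ_API_KEY, and assumes `ModelOptions` is
    /// re-exported at the crate root), so it is not compiled here:
    /// ```rust,ignore
    /// use ai_lib::{AiClient, ModelOptions, Provider};
    ///
    /// let opts = ModelOptions::default().with_chat_model("llama-3.1-8b-instant");
    /// let text = AiClient::quick_chat_text_with_options(Provider::Groq, "Hello", opts).await?;
    /// ```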
    pub async fn quick_chat_text_with_options<P: Into<String>>(
        provider: Provider,
        prompt: P,
        options: ModelOptions,
    ) -> Result<String, AiLibError> {
        let client = AiClient::new(provider)?;

        // Prefer the explicitly configured chat model, falling back to the provider
        // default. The remaining ModelOptions fields are not consulted here yet.
        let model = options
            .chat_model
            .unwrap_or_else(|| provider.default_chat_model().to_string());

        let req = client.build_simple_request_with_model(prompt.into(), model);
        let resp = client.chat_completion(req).await?;
        resp.first_text().map(|s| s.to_string())
    }
}

/// Streaming response cancel handle
pub struct CancelHandle {
    sender: Option<oneshot::Sender<()>>,
}

impl CancelHandle {
    /// Cancel the streaming response
    pub fn cancel(mut self) {
        if let Some(sender) = self.sender.take() {
            let _ = sender.send(());
        }
    }
}

/// AI client builder with progressive custom configuration
///
/// Usage examples:
/// ```rust
/// use ai_lib::{AiClientBuilder, Provider};
///
/// // Simplest usage - automatic environment variable detection
/// let client = AiClientBuilder::new(Provider::Groq).build()?;
///
/// // Custom base_url and proxy
/// let client = AiClientBuilder::new(Provider::Groq)
///     .with_base_url("https://custom.groq.com")
///     .with_proxy(Some("http://proxy.example.com:8080"))
///     .build()?;
///
/// // Full custom configuration
/// let client = AiClientBuilder::new(Provider::Groq)
///     .with_base_url("https://custom.groq.com")
///     .with_proxy(Some("http://proxy.example.com:8080"))
///     .with_timeout(std::time::Duration::from_secs(60))
///     .with_pool_config(32, std::time::Duration::from_secs(90))
///     .build()?;
/// # Ok::<(), ai_lib::AiLibError>(())
/// ```
pub struct AiClientBuilder {
    provider: Provider,
    base_url: Option<String>,
    proxy_url: Option<String>,
    timeout: Option<std::time::Duration>,
    pool_max_idle: Option<usize>,
    pool_idle_timeout: Option<std::time::Duration>,
    metrics: Option<Arc<dyn Metrics>>,
    // Model configuration options
    default_chat_model: Option<String>,
    default_multimodal_model: Option<String>,
    // Resilience configuration
    resilience_config: ResilienceConfig,
}

impl AiClientBuilder {
    /// Create a new builder instance
    ///
    /// # Arguments
    /// * `provider` - The AI model provider to use
    ///
    /// # Returns
    /// * `Self` - Builder instance
    pub fn new(provider: Provider) -> Self {
        Self {
            provider,
            base_url: None,
            proxy_url: None,
            timeout: None,
            pool_max_idle: None,
            pool_idle_timeout: None,
            metrics: None,
            default_chat_model: None,
            default_multimodal_model: None,
            resilience_config: ResilienceConfig::default(),
        }
    }

    /// Set custom base URL
    ///
    /// # Arguments
    /// * `base_url` - Custom base URL
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    pub fn with_base_url(mut self, base_url: &str) -> Self {
        self.base_url = Some(base_url.to_string());
        self
    }

    /// Set custom proxy URL
    ///
    /// # Arguments
    /// * `proxy_url` - Custom proxy URL, or None to use the AI_PROXY_URL environment variable
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    ///
    /// # Examples
    /// ```rust
    /// use ai_lib::{AiClientBuilder, Provider};
    ///
    /// // Use a specific proxy URL
    /// let client = AiClientBuilder::new(Provider::Groq)
    ///     .with_proxy(Some("http://proxy.example.com:8080"))
    ///     .build()?;
    ///
    /// // Use the AI_PROXY_URL environment variable
    /// let client = AiClientBuilder::new(Provider::Groq)
    ///     .with_proxy(None)
    ///     .build()?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```
    pub fn with_proxy(mut self, proxy_url: Option<&str>) -> Self {
        self.proxy_url = proxy_url.map(|s| s.to_string());
        self
    }

    /// Explicitly disable proxy usage
    ///
    /// This method ensures that no proxy will be used, regardless of environment variables.
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClientBuilder, Provider};
    ///
    /// let client = AiClientBuilder::new(Provider::Groq)
    ///     .without_proxy()
    ///     .build()?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```
    pub fn without_proxy(mut self) -> Self {
        // An empty string is the internal sentinel for "explicitly no proxy";
        // `determine_proxy_url` maps it to `None` and skips AI_PROXY_URL.
        self.proxy_url = Some("".to_string());
        self
    }

    /// Set custom timeout duration
    ///
    /// # Arguments
    /// * `timeout` - Custom timeout duration
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Set connection pool configuration
    ///
    /// # Arguments
    /// * `max_idle` - Maximum idle connections per host
    /// * `idle_timeout` - Idle connection timeout duration
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
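    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClientBuilder, Provider};
    ///
    /// let client = AiClientBuilder::new(Provider::Groq)
    ///     .with_pool_config(32, std::time::Duration::from_secs(90))
    ///     .build()?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```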
    pub fn with_pool_config(mut self, max_idle: usize, idle_timeout: std::time::Duration) -> Self {
        self.pool_max_idle = Some(max_idle);
        self.pool_idle_timeout = Some(idle_timeout);
        self
    }

    /// Set custom metrics implementation
    ///
    /// # Arguments
    /// * `metrics` - Custom metrics implementation
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    pub fn with_metrics(mut self, metrics: Arc<dyn Metrics>) -> Self {
        self.metrics = Some(metrics);
        self
    }

    /// Set default chat model for the client
    ///
    /// # Arguments
    /// * `model` - Default chat model name
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClientBuilder, Provider};
    ///
    /// let client = AiClientBuilder::new(Provider::Groq)
    ///     .with_default_chat_model("llama-3.1-8b-instant")
    ///     .build()?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```
    pub fn with_default_chat_model(mut self, model: &str) -> Self {
        self.default_chat_model = Some(model.to_string());
        self
    }

    /// Set default multimodal model for the client
    ///
    /// # Arguments
    /// * `model` - Default multimodal model name
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClientBuilder, Provider};
    ///
    /// let client = AiClientBuilder::new(Provider::Groq)
    ///     .with_default_multimodal_model("llama-3.2-11b-vision")
    ///     .build()?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```
    pub fn with_default_multimodal_model(mut self, model: &str) -> Self {
        self.default_multimodal_model = Some(model.to_string());
        self
    }

    /// Enable smart defaults for resilience features
    ///
    /// This method enables reasonable default configurations for circuit breaker,
    /// rate limiting, and error handling without requiring detailed configuration.
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClientBuilder, Provider};
    ///
    /// let client = AiClientBuilder::new(Provider::Groq)
    ///     .with_smart_defaults()
    ///     .build()?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```
    pub fn with_smart_defaults(mut self) -> Self {
        self.resilience_config = ResilienceConfig::smart_defaults();
        self
    }

    /// Configure for production environment
    ///
    /// This method applies production-ready configurations for all resilience
    /// features with conservative settings for maximum reliability.
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClientBuilder, Provider};
    ///
    /// let client = AiClientBuilder::new(Provider::Groq)
    ///     .for_production()
    ///     .build()?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```
    pub fn for_production(mut self) -> Self {
        self.resilience_config = ResilienceConfig::production();
        self
    }

    /// Configure for development environment
    ///
    /// This method applies development-friendly configurations with more
    /// lenient settings for easier debugging and testing.
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClientBuilder, Provider};
    ///
    /// let client = AiClientBuilder::new(Provider::Groq)
    ///     .for_development()
    ///     .build()?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```
    pub fn for_development(mut self) -> Self {
        self.resilience_config = ResilienceConfig::development();
        self
    }

    /// Set custom resilience configuration
    ///
    /// # Arguments
    /// * `config` - Custom resilience configuration
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
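    ///
    /// # Example
    /// A hedged sketch; the `ResilienceConfig` import path is assumed from this
    /// file's own `use crate::config::...` and may differ in the public API:
    /// ```rust,ignore
    /// use ai_lib::{AiClientBuilder, Provider};
    /// use ai_lib::config::ResilienceConfig; // path assumed
    ///
    /// let client = AiClientBuilder::new(Provider::Groq)
    ///     .with_resilience_config(ResilienceConfig::development())
    ///     .build()?;
    /// ```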
    pub fn with_resilience_config(mut self, config: ResilienceConfig) -> Self {
        self.resilience_config = config;
        self
    }

    /// Build AiClient instance
    ///
    /// The build process applies configuration in the following priority order:
    /// 1. Explicitly set configuration (via with_* methods)
    /// 2. Environment variable configuration
    /// 3. Default configuration
    ///
    /// # Returns
    /// * `Result<AiClient, AiLibError>` - Returns client instance on success, error on failure
    pub fn build(self) -> Result<AiClient, AiLibError> {
        // 1. Determine base_url: explicit setting > environment variable > default
        let base_url = self.determine_base_url()?;

        // 2. Determine proxy_url: explicit setting > environment variable
        let proxy_url = self.determine_proxy_url();

        // 3. Determine timeout: explicit setting > default (30s)
        let timeout = self
            .timeout
            .unwrap_or_else(|| std::time::Duration::from_secs(30));

        // 4. Create custom ProviderConfig (if needed)
        let config = self.create_custom_config(base_url)?;

        // 5. Create custom HttpTransport (if needed)
        let transport = self.create_custom_transport(proxy_url.clone(), timeout)?;

        // 6. Create the adapter
        let adapter: Box<dyn ChatApi> = if self.provider.is_config_driven() {
            // All config-driven providers share the GenericAdapter path
            create_generic_adapter(config, transport)?
        } else {
            // Independent adapters; the custom transport/base_url are not applied here
            match self.provider {
                Provider::OpenAI => Box::new(OpenAiAdapter::new()?),
                Provider::Gemini => Box::new(GeminiAdapter::new()?),
                Provider::Mistral => Box::new(MistralAdapter::new()?),
                Provider::Cohere => Box::new(CohereAdapter::new()?),
                _ => unreachable!("all providers should be handled above"),
            }
        };

        // 7. Assemble the client
        Ok(AiClient {
            provider: self.provider,
            adapter,
            metrics: self.metrics.unwrap_or_else(|| Arc::new(NoopMetrics::new())),
            connection_options: None,
            custom_default_chat_model: self.default_chat_model,
            custom_default_multimodal_model: self.default_multimodal_model,
        })
    }

    /// Determine base_url, priority: explicit setting > environment variable > default
    fn determine_base_url(&self) -> Result<String, AiLibError> {
        // 1. Explicitly set base_url
        if let Some(ref base_url) = self.base_url {
            return Ok(base_url.clone());
        }

        // 2. base_url from the provider-specific environment variable
        let env_var_name = self.get_base_url_env_var_name();
        if let Ok(base_url) = std::env::var(&env_var_name) {
            return Ok(base_url);
        }

        // 3. Fall back to the default configuration
        let default_config = self.get_default_provider_config()?;
        Ok(default_config.base_url)
    }

    /// Determine proxy_url, priority: explicit setting > environment variable
    fn determine_proxy_url(&self) -> Option<String> {
        // 1. Explicitly set proxy_url
        if let Some(ref proxy_url) = self.proxy_url {
            // An empty string means "explicitly no proxy" (see `without_proxy`)
            if proxy_url.is_empty() {
                return None;
            }
            return Some(proxy_url.clone());
        }

        // 2. AI_PROXY_URL from the environment
        std::env::var("AI_PROXY_URL").ok()
    }

    /// Get the base-url environment variable name for the current provider
    fn get_base_url_env_var_name(&self) -> String {
        match self.provider {
            Provider::Groq => "GROQ_BASE_URL".to_string(),
            Provider::XaiGrok => "GROK_BASE_URL".to_string(),
            Provider::Ollama => "OLLAMA_BASE_URL".to_string(),
            Provider::DeepSeek => "DEEPSEEK_BASE_URL".to_string(),
            Provider::Qwen => "DASHSCOPE_BASE_URL".to_string(),
            Provider::BaiduWenxin => "BAIDU_WENXIN_BASE_URL".to_string(),
            Provider::TencentHunyuan => "TENCENT_HUNYUAN_BASE_URL".to_string(),
            Provider::IflytekSpark => "IFLYTEK_BASE_URL".to_string(),
            Provider::Moonshot => "MOONSHOT_BASE_URL".to_string(),
            Provider::Anthropic => "ANTHROPIC_BASE_URL".to_string(),
            Provider::AzureOpenAI => "AZURE_OPENAI_BASE_URL".to_string(),
            Provider::HuggingFace => "HUGGINGFACE_BASE_URL".to_string(),
            Provider::TogetherAI => "TOGETHER_BASE_URL".to_string(),
            // These providers don't support a custom base_url; the empty name never
            // matches an environment variable, so the default config is used instead.
            Provider::OpenAI | Provider::Gemini | Provider::Mistral | Provider::Cohere => {
                "".to_string()
            }
        }
    }

    /// Get default provider configuration
    fn get_default_provider_config(
        &self,
    ) -> Result<crate::provider::config::ProviderConfig, AiLibError> {
        self.provider.get_default_config()
    }

    /// Create a custom ProviderConfig with the resolved base_url
    fn create_custom_config(
        &self,
        base_url: String,
    ) -> Result<crate::provider::config::ProviderConfig, AiLibError> {
        let mut config = self.get_default_provider_config()?;
        config.base_url = base_url;
        Ok(config)
    }

    /// Create a custom HttpTransport if any transport-level option was customized
    fn create_custom_transport(
        &self,
        proxy_url: Option<String>,
        timeout: std::time::Duration,
    ) -> Result<Option<crate::transport::DynHttpTransportRef>, AiLibError> {
        // No custom configuration: return None so the default transport is used
        if proxy_url.is_none() && self.pool_max_idle.is_none() && self.pool_idle_timeout.is_none() {
            return Ok(None);
        }

        // Create a custom HttpTransportConfig
        let transport_config = crate::transport::HttpTransportConfig {
            timeout,
            proxy: proxy_url,
            pool_max_idle_per_host: self.pool_max_idle,
            pool_idle_timeout: self.pool_idle_timeout,
        };

        // Create the custom HttpTransport
        let transport = crate::transport::HttpTransport::new_with_config(transport_config)?;
        Ok(Some(transport.boxed()))
    }
}

/// Controllable streaming response that stops when its cancel handle fires
struct ControlledStream {
    inner: Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
    cancel_rx: Option<oneshot::Receiver<()>>,
}

impl ControlledStream {
    fn new(
        inner: Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        cancel_rx: oneshot::Receiver<()>,
    ) -> Self {
        Self {
            inner,
            cancel_rx: Some(cancel_rx),
        }
    }
}

impl Stream for ControlledStream {
    type Item = Result<ChatCompletionChunk, AiLibError>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        use futures::stream::StreamExt;
        use std::task::Poll;

        // Check for cancellation first; a fired (or dropped) cancel sender
        // surfaces one "Stream cancelled" error item and stops polling for it.
        if let Some(ref mut cancel_rx) = self.cancel_rx {
            match Future::poll(std::pin::Pin::new(cancel_rx), cx) {
                Poll::Ready(_) => {
                    self.cancel_rx = None;
                    return Poll::Ready(Some(Err(AiLibError::ProviderError(
                        "Stream cancelled".to_string(),
                    ))));
                }
                Poll::Pending => {}
            }
        }

        // Otherwise, poll the inner stream
        self.inner.poll_next_unpin(cx)
    }
}