ai_lib/client.rs

//! Unified AI client module.

use crate::api::{ChatApi, ChatCompletionChunk};
use crate::metrics::{Metrics, NoopMetrics};
use crate::provider::{
    CohereAdapter, GeminiAdapter, GenericAdapter, MistralAdapter, OpenAiAdapter, ProviderConfigs,
};
use crate::types::{AiLibError, ChatCompletionRequest, ChatCompletionResponse};
use futures::stream::Stream;
use futures::Future;
use std::sync::Arc;
use tokio::sync::oneshot;

/// AI model provider enumeration.
#[derive(Debug, Clone, Copy)]
pub enum Provider {
    // Config-driven providers
    Groq,
    XaiGrok,
    Ollama,
    DeepSeek,
    Anthropic,
    AzureOpenAI,
    HuggingFace,
    TogetherAI,
    // Chinese providers (OpenAI-compatible or config-driven)
    BaiduWenxin,
    TencentHunyuan,
    IflytekSpark,
    Moonshot,
    Qwen,
    // Independent adapters
    OpenAI,
    Gemini,
    Mistral,
    Cohere,
    // Bedrock removed (deferred)
}

/// Unified AI client
///
/// Usage example:
/// ```rust
/// use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Role};
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
///     // Switch model provider by changing Provider value
///     let client = AiClient::new(Provider::Groq)?;
///
///     let request = ChatCompletionRequest::new(
///         "test-model".to_string(),
///         vec![Message {
///             role: Role::User,
///             content: ai_lib::types::common::Content::Text("Hello".to_string()),
///             function_call: None,
///         }],
///     );
///
///     // Note: Set GROQ_API_KEY environment variable for actual API calls
///     // Optional: Set AI_PROXY_URL environment variable to use proxy server
///     // let response = client.chat_completion(request).await?;
///
///     println!("Client created successfully with provider: {:?}", client.current_provider());
///     println!("Request prepared for model: {}", request.model);
///
///     Ok(())
/// }
/// ```
///
/// # Proxy Configuration
///
/// Configure a proxy server by setting the `AI_PROXY_URL` environment variable:
///
/// ```bash
/// export AI_PROXY_URL=http://proxy.example.com:8080
/// ```
///
/// Supported proxy formats:
/// - HTTP proxy: `http://proxy.example.com:8080`
/// - HTTPS proxy: `https://proxy.example.com:8080`
/// - With authentication: `http://user:pass@proxy.example.com:8080`
pub struct AiClient {
    provider: Provider,
    adapter: Box<dyn ChatApi>,
    metrics: Arc<dyn Metrics>,
}

impl AiClient {
    /// Create a new AI client
    ///
    /// # Arguments
    /// * `provider` - The AI model provider to use
    ///
    /// # Returns
    /// * `Result<Self, AiLibError>` - Client instance on success, error on failure
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClient, Provider};
    ///
    /// let client = AiClient::new(Provider::Groq)?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```
    pub fn new(provider: Provider) -> Result<Self, AiLibError> {
        // Delegate to the builder, which auto-detects environment variables
        AiClientBuilder::new(provider).build()
    }

    /// Create a new AI client builder
    ///
    /// The builder pattern allows more flexible client configuration:
    /// - Automatic environment variable detection
    /// - Support for custom base_url and proxy
    /// - Support for custom timeout and connection pool configuration
    ///
    /// # Arguments
    /// * `provider` - The AI model provider to use
    ///
    /// # Returns
    /// * `AiClientBuilder` - Builder instance
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClient, Provider};
    ///
    /// // Simplest usage - automatic environment variable detection
    /// let client = AiClient::builder(Provider::Groq).build()?;
    ///
    /// // Custom base_url and proxy
    /// let client = AiClient::builder(Provider::Groq)
    ///     .with_base_url("https://custom.groq.com")
    ///     .with_proxy(Some("http://proxy.example.com:8080"))
    ///     .build()?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```
    pub fn builder(provider: Provider) -> AiClientBuilder {
        AiClientBuilder::new(provider)
    }

    /// Create AiClient with injected metrics implementation
    pub fn new_with_metrics(
        provider: Provider,
        metrics: Arc<dyn Metrics>,
    ) -> Result<Self, AiLibError> {
        let adapter = Self::default_adapter(provider)?;

        Ok(Self {
            provider,
            adapter,
            metrics,
        })
    }

    /// Construct the default adapter for a provider.
    ///
    /// Shared by `new_with_metrics` and `switch_provider`: config-driven
    /// providers use the generic adapter; the rest use independent adapters.
    fn default_adapter(provider: Provider) -> Result<Box<dyn ChatApi>, AiLibError> {
        let adapter: Box<dyn ChatApi> = match provider {
            // Config-driven providers
            Provider::Groq => Box::new(GenericAdapter::new(ProviderConfigs::groq())?),
            Provider::XaiGrok => Box::new(GenericAdapter::new(ProviderConfigs::xai_grok())?),
            Provider::Ollama => Box::new(GenericAdapter::new(ProviderConfigs::ollama())?),
            Provider::DeepSeek => Box::new(GenericAdapter::new(ProviderConfigs::deepseek())?),
            Provider::Qwen => Box::new(GenericAdapter::new(ProviderConfigs::qwen())?),
            Provider::Anthropic => Box::new(GenericAdapter::new(ProviderConfigs::anthropic())?),
            Provider::BaiduWenxin => {
                Box::new(GenericAdapter::new(ProviderConfigs::baidu_wenxin())?)
            }
            Provider::TencentHunyuan => {
                Box::new(GenericAdapter::new(ProviderConfigs::tencent_hunyuan())?)
            }
            Provider::IflytekSpark => {
                Box::new(GenericAdapter::new(ProviderConfigs::iflytek_spark())?)
            }
            Provider::Moonshot => Box::new(GenericAdapter::new(ProviderConfigs::moonshot())?),
            Provider::AzureOpenAI => {
                Box::new(GenericAdapter::new(ProviderConfigs::azure_openai())?)
            }
            Provider::HuggingFace => Box::new(GenericAdapter::new(ProviderConfigs::huggingface())?),
            Provider::TogetherAI => Box::new(GenericAdapter::new(ProviderConfigs::together_ai())?),
            // Independent adapters
            Provider::OpenAI => Box::new(OpenAiAdapter::new()?),
            Provider::Gemini => Box::new(GeminiAdapter::new()?),
            Provider::Mistral => Box::new(MistralAdapter::new()?),
            Provider::Cohere => Box::new(CohereAdapter::new()?),
        };
        Ok(adapter)
    }

    /// Set metrics implementation on client
    pub fn with_metrics(mut self, metrics: Arc<dyn Metrics>) -> Self {
        self.metrics = metrics;
        self
    }

    /// Send chat completion request
    ///
    /// # Arguments
    /// * `request` - Chat completion request
    ///
    /// # Returns
    /// * `Result<ChatCompletionResponse, AiLibError>` - Response on success, error on failure
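    ///
    /// # Example
    ///
    /// A minimal usage sketch (the model name is a placeholder; a real call
    /// requires the provider's API key, e.g. `GROQ_API_KEY`):
    /// ```rust,no_run
    /// use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Role};
    /// use ai_lib::types::common::Content;
    ///
    /// # async fn demo() -> Result<(), Box<dyn std::error::Error>> {
    /// let client = AiClient::new(Provider::Groq)?;
    /// let request = ChatCompletionRequest::new(
    ///     "llama3-8b-8192".to_string(),
    ///     vec![Message {
    ///         role: Role::User,
    ///         content: Content::Text("Hello".to_string()),
    ///         function_call: None,
    ///     }],
    /// );
    /// let response = client.chat_completion(request).await?;
    /// println!("{}", response.choices[0].message.content.as_text());
    /// # Ok(())
    /// # }
    /// ```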
    pub async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        self.adapter.chat_completion(request).await
    }

    /// Streaming chat completion request
    ///
    /// # Arguments
    /// * `request` - Chat completion request
    ///
    /// # Returns
    /// * `Result<impl Stream<Item = Result<ChatCompletionChunk, AiLibError>>, AiLibError>` - Stream response on success
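    ///
    /// # Example
    ///
    /// A minimal consumption sketch (uses `futures::StreamExt`; chunk handling
    /// is left schematic):
    /// ```rust,no_run
    /// use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Role};
    /// use ai_lib::types::common::Content;
    /// use futures::StreamExt;
    ///
    /// # async fn demo() -> Result<(), Box<dyn std::error::Error>> {
    /// let client = AiClient::new(Provider::Groq)?;
    /// let request = ChatCompletionRequest::new(
    ///     "llama3-8b-8192".to_string(),
    ///     vec![Message {
    ///         role: Role::User,
    ///         content: Content::Text("Hello".to_string()),
    ///         function_call: None,
    ///     }],
    /// );
    /// let mut stream = client.chat_completion_stream(request).await?;
    /// while let Some(chunk) = stream.next().await {
    ///     let _chunk = chunk?;
    ///     // Handle the incremental chunk here (token deltas, finish reason, ...)
    /// }
    /// # Ok(())
    /// # }
    /// ```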
    pub async fn chat_completion_stream(
        &self,
        mut request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    > {
        request.stream = Some(true);
        self.adapter.chat_completion_stream(request).await
    }

    /// Streaming chat completion request with cancel control
    ///
    /// # Arguments
    /// * `request` - Chat completion request
    ///
    /// # Returns
    /// * `Result<(impl Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin, CancelHandle), AiLibError>` - Returns streaming response and cancel handle on success
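    ///
    /// # Example
    ///
    /// A sketch of cancelling a stream partway through (request construction as
    /// in [`chat_completion_stream`](Self::chat_completion_stream)):
    /// ```rust,no_run
    /// use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Role};
    /// use ai_lib::types::common::Content;
    /// use futures::StreamExt;
    ///
    /// # async fn demo() -> Result<(), Box<dyn std::error::Error>> {
    /// let client = AiClient::new(Provider::Groq)?;
    /// let request = ChatCompletionRequest::new(
    ///     "llama3-8b-8192".to_string(),
    ///     vec![Message {
    ///         role: Role::User,
    ///         content: Content::Text("Hello".to_string()),
    ///         function_call: None,
    ///     }],
    /// );
    /// let (mut stream, cancel) = client.chat_completion_stream_with_cancel(request).await?;
    /// // Read the first chunk, then cancel the rest of the stream.
    /// if let Some(first) = stream.next().await {
    ///     let _ = first?;
    /// }
    /// cancel.cancel();
    /// # Ok(())
    /// # }
    /// ```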
    pub async fn chat_completion_stream_with_cancel(
        &self,
        mut request: ChatCompletionRequest,
    ) -> Result<
        (
            Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
            CancelHandle,
        ),
        AiLibError,
    > {
        request.stream = Some(true);
        let stream = self.adapter.chat_completion_stream(request).await?;
        let (cancel_tx, cancel_rx) = oneshot::channel();
        let cancel_handle = CancelHandle {
            sender: Some(cancel_tx),
        };

        let controlled_stream = ControlledStream::new(stream, cancel_rx);
        Ok((Box::new(controlled_stream), cancel_handle))
    }

    /// Batch chat completion requests
    ///
    /// # Arguments
    /// * `requests` - List of chat completion requests
    /// * `concurrency_limit` - Maximum concurrent request count (None means unlimited)
    ///
    /// # Returns
    /// * `Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError>` - Returns response results for all requests
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Role};
    /// use ai_lib::types::common::Content;
    ///
    /// #[tokio::main]
    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
    ///     let client = AiClient::new(Provider::Groq)?;
    ///
    ///     let requests = vec![
    ///         ChatCompletionRequest::new(
    ///             "llama3-8b-8192".to_string(),
    ///             vec![Message {
    ///                 role: Role::User,
    ///                 content: Content::Text("Hello".to_string()),
    ///                 function_call: None,
    ///             }],
    ///         ),
    ///         ChatCompletionRequest::new(
    ///             "llama3-8b-8192".to_string(),
    ///             vec![Message {
    ///                 role: Role::User,
    ///                 content: Content::Text("How are you?".to_string()),
    ///                 function_call: None,
    ///             }],
    ///         ),
    ///     ];
    ///
    ///     // Limit concurrency to 5
    ///     let responses = client.chat_completion_batch(requests, Some(5)).await?;
    ///
    ///     for (i, response) in responses.iter().enumerate() {
    ///         match response {
    ///             Ok(resp) => println!("Request {}: {}", i, resp.choices[0].message.content.as_text()),
    ///             Err(e) => println!("Request {} failed: {}", i, e),
    ///         }
    ///     }
    ///
    ///     Ok(())
    /// }
    /// ```
    pub async fn chat_completion_batch(
        &self,
        requests: Vec<ChatCompletionRequest>,
        concurrency_limit: Option<usize>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        self.adapter
            .chat_completion_batch(requests, concurrency_limit)
            .await
    }

    /// Smart batch processing: automatically choose a processing strategy based on request count
    ///
    /// Batches of three or fewer requests run without a concurrency cap; larger
    /// batches are limited to 10 concurrent requests.
    ///
    /// # Arguments
    /// * `requests` - List of chat completion requests
    ///
    /// # Returns
    /// * `Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError>` - Returns response results for all requests
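    ///
    /// # Example
    ///
    /// A brief sketch (requests built as in
    /// [`chat_completion_batch`](Self::chat_completion_batch)):
    /// ```rust,no_run
    /// # use ai_lib::{AiClient, ChatCompletionRequest};
    /// # async fn demo(client: AiClient, requests: Vec<ChatCompletionRequest>) -> Result<(), ai_lib::AiLibError> {
    /// // Three or fewer requests run uncapped; larger batches are capped at 10.
    /// let responses = client.chat_completion_batch_smart(requests).await?;
    /// println!("Got {} results", responses.len());
    /// # Ok(())
    /// # }
    /// ```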
    pub async fn chat_completion_batch_smart(
        &self,
        requests: Vec<ChatCompletionRequest>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        // No cap for small batches (None = unlimited); cap larger batches at 10 concurrent requests
        let concurrency_limit = if requests.len() <= 3 { None } else { Some(10) };
        self.chat_completion_batch(requests, concurrency_limit)
            .await
    }

    /// Get list of supported models
    ///
    /// # Returns
    /// * `Result<Vec<String>, AiLibError>` - Returns model list on success, error on failure
    pub async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
        self.adapter.list_models().await
    }

    /// Switch AI model provider
    ///
    /// # Arguments
    /// * `provider` - New provider
    ///
    /// # Returns
    /// * `Result<(), AiLibError>` - Returns () on success, error on failure
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClient, Provider};
    ///
    /// let mut client = AiClient::new(Provider::Groq)?;
    /// // Switch from Groq to Groq (demonstrating switch functionality)
    /// client.switch_provider(Provider::Groq)?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```
    pub fn switch_provider(&mut self, provider: Provider) -> Result<(), AiLibError> {
        let new_adapter = Self::default_adapter(provider)?;
        self.provider = provider;
        self.adapter = new_adapter;
        Ok(())
    }

    /// Get current provider
    pub fn current_provider(&self) -> Provider {
        self.provider
    }
}

/// Streaming response cancel handle
pub struct CancelHandle {
    sender: Option<oneshot::Sender<()>>,
}

impl CancelHandle {
    /// Cancel streaming response
    pub fn cancel(mut self) {
        if let Some(sender) = self.sender.take() {
            let _ = sender.send(());
        }
    }
}

/// AI client builder with progressive custom configuration
///
/// Usage examples:
/// ```rust
/// use ai_lib::{AiClientBuilder, Provider};
///
/// // Simplest usage - automatic environment variable detection
/// let client = AiClientBuilder::new(Provider::Groq).build()?;
///
/// // Custom base_url and proxy
/// let client = AiClientBuilder::new(Provider::Groq)
///     .with_base_url("https://custom.groq.com")
///     .with_proxy(Some("http://proxy.example.com:8080"))
///     .build()?;
///
/// // Full custom configuration
/// let client = AiClientBuilder::new(Provider::Groq)
///     .with_base_url("https://custom.groq.com")
///     .with_proxy(Some("http://proxy.example.com:8080"))
///     .with_timeout(std::time::Duration::from_secs(60))
///     .with_pool_config(32, std::time::Duration::from_secs(90))
///     .build()?;
/// # Ok::<(), ai_lib::AiLibError>(())
/// ```
pub struct AiClientBuilder {
    provider: Provider,
    base_url: Option<String>,
    proxy_url: Option<String>,
    timeout: Option<std::time::Duration>,
    pool_max_idle: Option<usize>,
    pool_idle_timeout: Option<std::time::Duration>,
    metrics: Option<Arc<dyn Metrics>>,
}

impl AiClientBuilder {
    /// Create a new builder instance
    ///
    /// # Arguments
    /// * `provider` - The AI model provider to use
    ///
    /// # Returns
    /// * `Self` - Builder instance
    pub fn new(provider: Provider) -> Self {
        Self {
            provider,
            base_url: None,
            proxy_url: None,
            timeout: None,
            pool_max_idle: None,
            pool_idle_timeout: None,
            metrics: None,
        }
    }

    /// Set custom base URL
    ///
    /// # Arguments
    /// * `base_url` - Custom base URL
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    pub fn with_base_url(mut self, base_url: &str) -> Self {
        self.base_url = Some(base_url.to_string());
        self
    }

    /// Set custom proxy URL
    ///
    /// # Arguments
    /// * `proxy_url` - Custom proxy URL, or None to use AI_PROXY_URL environment variable
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    ///
    /// # Examples
    /// ```rust
    /// use ai_lib::{AiClientBuilder, Provider};
    ///
    /// // Use specific proxy URL
    /// let client = AiClientBuilder::new(Provider::Groq)
    ///     .with_proxy(Some("http://proxy.example.com:8080"))
    ///     .build()?;
    ///
    /// // Use AI_PROXY_URL environment variable
    /// let client = AiClientBuilder::new(Provider::Groq)
    ///     .with_proxy(None)
    ///     .build()?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```
    pub fn with_proxy(mut self, proxy_url: Option<&str>) -> Self {
        self.proxy_url = proxy_url.map(|s| s.to_string());
        self
    }

    /// Explicitly disable proxy usage
    ///
    /// This method ensures that no proxy will be used, regardless of environment variables.
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClientBuilder, Provider};
    ///
    /// let client = AiClientBuilder::new(Provider::Groq)
    ///     .without_proxy()
    ///     .build()?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```
    pub fn without_proxy(mut self) -> Self {
        // An empty string is the internal sentinel for "explicitly no proxy";
        // `determine_proxy_url` maps it to `None` and skips AI_PROXY_URL.
        self.proxy_url = Some("".to_string());
        self
    }

    /// Set custom timeout duration
    ///
    /// # Arguments
    /// * `timeout` - Custom timeout duration
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Set connection pool configuration
    ///
    /// # Arguments
    /// * `max_idle` - Maximum idle connections per host
    /// * `idle_timeout` - Idle connection timeout duration
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    pub fn with_pool_config(mut self, max_idle: usize, idle_timeout: std::time::Duration) -> Self {
        self.pool_max_idle = Some(max_idle);
        self.pool_idle_timeout = Some(idle_timeout);
        self
    }

    /// Set custom metrics implementation
    ///
    /// # Arguments
    /// * `metrics` - Custom metrics implementation
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
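    ///
    /// # Example
    ///
    /// A sketch using the crate's no-op implementation (assuming `NoopMetrics`
    /// is exported at `ai_lib::metrics::NoopMetrics`):
    /// ```rust,no_run
    /// use std::sync::Arc;
    /// use ai_lib::{AiClientBuilder, Provider};
    /// use ai_lib::metrics::NoopMetrics;
    ///
    /// let client = AiClientBuilder::new(Provider::Groq)
    ///     .with_metrics(Arc::new(NoopMetrics::new()))
    ///     .build()?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```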
    pub fn with_metrics(mut self, metrics: Arc<dyn Metrics>) -> Self {
        self.metrics = Some(metrics);
        self
    }

    /// Build AiClient instance
    ///
    /// The build process applies configuration in the following priority order:
    /// 1. Explicitly set configuration (via with_* methods)
    /// 2. Environment variable configuration
    /// 3. Default configuration
    ///
    /// Providers with independent adapters (OpenAI, Gemini, Mistral, Cohere) do
    /// not support custom configuration; builder settings are ignored for them.
    ///
    /// # Returns
    /// * `Result<AiClient, AiLibError>` - Returns client instance on success, error on failure
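    ///
    /// # Example
    ///
    /// An illustrative sketch of the precedence rules (the URL is a placeholder):
    /// ```rust,no_run
    /// use ai_lib::{AiClientBuilder, Provider};
    ///
    /// // An explicit base_url wins over GROQ_BASE_URL and the built-in default.
    /// let client = AiClientBuilder::new(Provider::Groq)
    ///     .with_base_url("https://groq.internal.example.com")
    ///     .build()?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```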
    pub fn build(self) -> Result<AiClient, AiLibError> {
        // 1. Determine proxy_url: explicit setting > environment variable
        let proxy_url = self.determine_proxy_url();

        // 2. Determine timeout: explicit setting > default (30s)
        let timeout = self
            .timeout
            .unwrap_or_else(|| std::time::Duration::from_secs(30));

        // 3. Create the adapter. Config and transport are only resolved for
        //    config-driven providers, so providers with independent adapters
        //    are not rejected by the config lookup.
        let adapter: Box<dyn ChatApi> = match self.provider {
            // Config-driven providers share the generic adapter
            Provider::Groq
            | Provider::XaiGrok
            | Provider::Ollama
            | Provider::DeepSeek
            | Provider::Qwen
            | Provider::BaiduWenxin
            | Provider::TencentHunyuan
            | Provider::IflytekSpark
            | Provider::Moonshot
            | Provider::Anthropic
            | Provider::AzureOpenAI
            | Provider::HuggingFace
            | Provider::TogetherAI => {
                // base_url priority: explicit setting > environment variable > default
                let base_url = self.determine_base_url()?;
                let config = self.create_custom_config(base_url)?;
                // A custom transport is only built when proxy or pool options were set
                if let Some(custom_transport) = self.create_custom_transport(proxy_url, timeout)? {
                    Box::new(GenericAdapter::with_transport_ref(config, custom_transport)?)
                } else {
                    Box::new(GenericAdapter::new(config)?)
                }
            }
            // Independent adapters (these don't support custom configuration)
            Provider::OpenAI => Box::new(OpenAiAdapter::new()?),
            Provider::Gemini => Box::new(GeminiAdapter::new()?),
            Provider::Mistral => Box::new(MistralAdapter::new()?),
            Provider::Cohere => Box::new(CohereAdapter::new()?),
        };

        // 4. Assemble the client, falling back to no-op metrics
        Ok(AiClient {
            provider: self.provider,
            adapter,
            metrics: self.metrics.unwrap_or_else(|| Arc::new(NoopMetrics::new())),
        })
    }

    /// Determine base_url, priority: explicit setting > environment variable > default
    fn determine_base_url(&self) -> Result<String, AiLibError> {
        // 1. Explicitly set base_url
        if let Some(ref base_url) = self.base_url {
            return Ok(base_url.clone());
        }

        // 2. base_url from environment variable
        let env_var_name = self.get_base_url_env_var_name();
        if let Ok(base_url) = std::env::var(&env_var_name) {
            return Ok(base_url);
        }

        // 3. Use default configuration
        let default_config = self.get_default_provider_config()?;
        Ok(default_config.base_url)
    }

    /// Determine proxy_url, priority: explicit setting > environment variable
    fn determine_proxy_url(&self) -> Option<String> {
        // 1. Explicitly set proxy_url
        if let Some(ref proxy_url) = self.proxy_url {
            // An empty string means "explicitly no proxy" (set by `without_proxy`)
            if proxy_url.is_empty() {
                return None;
            }
            return Some(proxy_url.clone());
        }

        // 2. AI_PROXY_URL from environment variable
        std::env::var("AI_PROXY_URL").ok()
    }

    /// Get the base_url environment variable name for the current provider
    fn get_base_url_env_var_name(&self) -> String {
        match self.provider {
            Provider::Groq => "GROQ_BASE_URL".to_string(),
            Provider::XaiGrok => "GROK_BASE_URL".to_string(),
            Provider::Ollama => "OLLAMA_BASE_URL".to_string(),
            Provider::DeepSeek => "DEEPSEEK_BASE_URL".to_string(),
            Provider::Qwen => "DASHSCOPE_BASE_URL".to_string(),
            Provider::BaiduWenxin => "BAIDU_WENXIN_BASE_URL".to_string(),
            Provider::TencentHunyuan => "TENCENT_HUNYUAN_BASE_URL".to_string(),
            Provider::IflytekSpark => "IFLYTEK_BASE_URL".to_string(),
            Provider::Moonshot => "MOONSHOT_BASE_URL".to_string(),
            Provider::Anthropic => "ANTHROPIC_BASE_URL".to_string(),
            Provider::AzureOpenAI => "AZURE_OPENAI_BASE_URL".to_string(),
            Provider::HuggingFace => "HUGGINGFACE_BASE_URL".to_string(),
            Provider::TogetherAI => "TOGETHER_BASE_URL".to_string(),
            // These providers don't support custom base_url
            Provider::OpenAI | Provider::Gemini | Provider::Mistral | Provider::Cohere => {
                "".to_string()
            }
        }
    }

    /// Get default provider configuration
    fn get_default_provider_config(
        &self,
    ) -> Result<crate::provider::config::ProviderConfig, AiLibError> {
        match self.provider {
            Provider::Groq => Ok(ProviderConfigs::groq()),
            Provider::XaiGrok => Ok(ProviderConfigs::xai_grok()),
            Provider::Ollama => Ok(ProviderConfigs::ollama()),
            Provider::DeepSeek => Ok(ProviderConfigs::deepseek()),
            Provider::Qwen => Ok(ProviderConfigs::qwen()),
            Provider::BaiduWenxin => Ok(ProviderConfigs::baidu_wenxin()),
            Provider::TencentHunyuan => Ok(ProviderConfigs::tencent_hunyuan()),
            Provider::IflytekSpark => Ok(ProviderConfigs::iflytek_spark()),
            Provider::Moonshot => Ok(ProviderConfigs::moonshot()),
            Provider::Anthropic => Ok(ProviderConfigs::anthropic()),
            Provider::AzureOpenAI => Ok(ProviderConfigs::azure_openai()),
            Provider::HuggingFace => Ok(ProviderConfigs::huggingface()),
            Provider::TogetherAI => Ok(ProviderConfigs::together_ai()),
            // These providers don't support custom configuration
            Provider::OpenAI | Provider::Gemini | Provider::Mistral | Provider::Cohere => {
                Err(AiLibError::ConfigurationError(
                    "This provider does not support custom configuration".to_string(),
                ))
            }
        }
    }

    /// Create custom ProviderConfig
    fn create_custom_config(
        &self,
        base_url: String,
    ) -> Result<crate::provider::config::ProviderConfig, AiLibError> {
        let mut config = self.get_default_provider_config()?;
        config.base_url = base_url;
        Ok(config)
    }

    /// Create custom HttpTransport
    fn create_custom_transport(
        &self,
        proxy_url: Option<String>,
        timeout: std::time::Duration,
    ) -> Result<Option<crate::transport::DynHttpTransportRef>, AiLibError> {
        // If no custom configuration, return None (use default transport)
        if proxy_url.is_none() && self.pool_max_idle.is_none() && self.pool_idle_timeout.is_none() {
            return Ok(None);
        }

        // Create custom HttpTransportConfig
        let transport_config = crate::transport::HttpTransportConfig {
            timeout,
            proxy: proxy_url,
            pool_max_idle_per_host: self.pool_max_idle,
            pool_idle_timeout: self.pool_idle_timeout,
        };

        // Create custom HttpTransport
        let transport = crate::transport::HttpTransport::new_with_config(transport_config)?;
        Ok(Some(transport.boxed()))
    }
}

/// Controllable streaming response
struct ControlledStream {
    inner: Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
    cancel_rx: Option<oneshot::Receiver<()>>,
}

impl ControlledStream {
    fn new(
        inner: Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        cancel_rx: oneshot::Receiver<()>,
    ) -> Self {
        Self {
            inner,
            cancel_rx: Some(cancel_rx),
        }
    }
}

impl Stream for ControlledStream {
    type Item = Result<ChatCompletionChunk, AiLibError>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        use futures::stream::StreamExt;
        use std::task::Poll;

        // Check for cancellation. Note the oneshot also completes (with an
        // error) if the CancelHandle is dropped without calling cancel(), so
        // either path terminates the stream with a "Stream cancelled" error.
        if let Some(ref mut cancel_rx) = self.cancel_rx {
            match Future::poll(std::pin::Pin::new(cancel_rx), cx) {
                Poll::Ready(_) => {
                    self.cancel_rx = None;
                    return Poll::Ready(Some(Err(AiLibError::ProviderError(
                        "Stream cancelled".to_string(),
                    ))));
                }
                Poll::Pending => {}
            }
        }

        // Poll the inner stream
        self.inner.poll_next_unpin(cx)
    }
}