ai_lib/client.rs
//! Unified AI client module.

use crate::api::{ChatApi, ChatCompletionChunk};
use crate::metrics::{Metrics, NoopMetrics};
use crate::provider::{
    CohereAdapter, GeminiAdapter, GenericAdapter, MistralAdapter, OpenAiAdapter, ProviderConfigs,
};
use crate::types::{AiLibError, ChatCompletionRequest, ChatCompletionResponse};
use futures::stream::Stream;
use futures::Future;
use std::sync::Arc;
use tokio::sync::oneshot;

/// AI model provider enumeration
#[derive(Debug, Clone, Copy)]
pub enum Provider {
    // Config-driven providers
    Groq,
    XaiGrok,
    Ollama,
    DeepSeek,
    Qwen,
    Anthropic,
    AzureOpenAI,
    HuggingFace,
    TogetherAI,
    // Chinese providers (OpenAI-compatible or config-driven)
    BaiduWenxin,
    TencentHunyuan,
    IflytekSpark,
    Moonshot,
    // Independent adapters
    OpenAI,
    Gemini,
    Mistral,
    Cohere,
    // Bedrock removed (deferred)
}

/// Unified AI client
///
/// Usage example:
/// ```rust
/// use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Role};
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
///     // Switch model provider by changing the Provider value
///     let client = AiClient::new(Provider::Groq)?;
///
///     let request = ChatCompletionRequest::new(
///         "test-model".to_string(),
///         vec![Message {
///             role: Role::User,
///             content: ai_lib::types::common::Content::Text("Hello".to_string()),
///             function_call: None,
///         }],
///     );
///
///     // Note: set the GROQ_API_KEY environment variable for actual API calls.
///     // Optional: set the AI_PROXY_URL environment variable to use a proxy server.
///     // let response = client.chat_completion(request).await?;
///
///     println!("Client created successfully with provider: {:?}", client.current_provider());
///     println!("Request prepared for model: {}", request.model);
///
///     Ok(())
/// }
/// ```
///
/// # Proxy Configuration
///
/// Configure a proxy server by setting the `AI_PROXY_URL` environment variable:
///
/// ```bash
/// export AI_PROXY_URL=http://proxy.example.com:8080
/// ```
///
/// Supported proxy formats:
/// - HTTP proxy: `http://proxy.example.com:8080`
/// - HTTPS proxy: `https://proxy.example.com:8080`
/// - With authentication: `http://user:pass@proxy.example.com:8080`
pub struct AiClient {
    provider: Provider,
    adapter: Box<dyn ChatApi>,
    metrics: Arc<dyn Metrics>,
}

impl AiClient {
    /// Create a new AI client
    ///
    /// # Arguments
    /// * `provider` - The AI model provider to use
    ///
    /// # Returns
    /// * `Result<Self, AiLibError>` - Client instance on success, error on failure
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClient, Provider};
    ///
    /// let client = AiClient::new(Provider::Groq)?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```
    pub fn new(provider: Provider) -> Result<Self, AiLibError> {
        // Use the builder to create a client with automatic environment variable detection
        AiClientBuilder::new(provider).build()
    }

    /// Create a new AI client builder
    ///
    /// The builder pattern allows more flexible client configuration:
    /// - Automatic environment variable detection
    /// - Support for custom base_url and proxy
    /// - Support for custom timeout and connection pool configuration
    ///
    /// # Arguments
    /// * `provider` - The AI model provider to use
    ///
    /// # Returns
    /// * `AiClientBuilder` - Builder instance
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClient, Provider};
    ///
    /// // Simplest usage - automatic environment variable detection
    /// let client = AiClient::builder(Provider::Groq).build()?;
    ///
    /// // Custom base_url and proxy
    /// let client = AiClient::builder(Provider::Groq)
    ///     .with_base_url("https://custom.groq.com")
    ///     .with_proxy(Some("http://proxy.example.com:8080"))
    ///     .build()?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```
    pub fn builder(provider: Provider) -> AiClientBuilder {
        AiClientBuilder::new(provider)
    }

    /// Create an AiClient with an injected metrics implementation
    pub fn new_with_metrics(
        provider: Provider,
        metrics: Arc<dyn Metrics>,
    ) -> Result<Self, AiLibError> {
        let adapter: Box<dyn ChatApi> = match provider {
            Provider::Groq => Box::new(GenericAdapter::new(ProviderConfigs::groq())?),
            Provider::XaiGrok => Box::new(GenericAdapter::new(ProviderConfigs::xai_grok())?),
            Provider::Ollama => Box::new(GenericAdapter::new(ProviderConfigs::ollama())?),
            Provider::DeepSeek => Box::new(GenericAdapter::new(ProviderConfigs::deepseek())?),
            Provider::Qwen => Box::new(GenericAdapter::new(ProviderConfigs::qwen())?),
            Provider::Anthropic => Box::new(GenericAdapter::new(ProviderConfigs::anthropic())?),
            Provider::BaiduWenxin => {
                Box::new(GenericAdapter::new(ProviderConfigs::baidu_wenxin())?)
            }
            Provider::TencentHunyuan => {
                Box::new(GenericAdapter::new(ProviderConfigs::tencent_hunyuan())?)
            }
            Provider::IflytekSpark => {
                Box::new(GenericAdapter::new(ProviderConfigs::iflytek_spark())?)
            }
            Provider::Moonshot => Box::new(GenericAdapter::new(ProviderConfigs::moonshot())?),
            Provider::AzureOpenAI => {
                Box::new(GenericAdapter::new(ProviderConfigs::azure_openai())?)
            }
            Provider::HuggingFace => Box::new(GenericAdapter::new(ProviderConfigs::huggingface())?),
            Provider::TogetherAI => Box::new(GenericAdapter::new(ProviderConfigs::together_ai())?),
            Provider::OpenAI => Box::new(OpenAiAdapter::new()?),
            Provider::Gemini => Box::new(GeminiAdapter::new()?),
            Provider::Mistral => Box::new(MistralAdapter::new()?),
            Provider::Cohere => Box::new(CohereAdapter::new()?),
        };

        Ok(Self {
            provider,
            adapter,
            metrics,
        })
    }

    /// Set the metrics implementation on this client
    pub fn with_metrics(mut self, metrics: Arc<dyn Metrics>) -> Self {
        self.metrics = metrics;
        self
    }

    /// Send a chat completion request
    ///
    /// # Arguments
    /// * `request` - Chat completion request
    ///
    /// # Returns
    /// * `Result<ChatCompletionResponse, AiLibError>` - Response on success, error on failure
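    ///
    /// # Example
    /// A minimal call sketch (model name and prompt are illustrative; requires the
    /// provider's API key, e.g. GROQ_API_KEY, in the environment):
    /// ```rust,no_run
    /// use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Role};
    /// use ai_lib::types::common::Content;
    ///
    /// #[tokio::main]
    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
    ///     let client = AiClient::new(Provider::Groq)?;
    ///     let request = ChatCompletionRequest::new(
    ///         "llama3-8b-8192".to_string(),
    ///         vec![Message {
    ///             role: Role::User,
    ///             content: Content::Text("Hello".to_string()),
    ///             function_call: None,
    ///         }],
    ///     );
    ///     let response = client.chat_completion(request).await?;
    ///     println!("{}", response.choices[0].message.content.as_text());
    ///     Ok(())
    /// }
    /// ```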
    pub async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        self.adapter.chat_completion(request).await
    }

    /// Streaming chat completion request
    ///
    /// # Arguments
    /// * `request` - Chat completion request
    ///
    /// # Returns
    /// * `Result<impl Stream<Item = Result<ChatCompletionChunk, AiLibError>>, AiLibError>` - Stream response on success
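    ///
    /// # Example
    /// A minimal consumption sketch. The field access assumes an OpenAI-style
    /// delta shape (`choices[0].delta.content`); adjust to the actual
    /// `ChatCompletionChunk` definition if it differs:
    /// ```rust,ignore
    /// use futures::StreamExt;
    ///
    /// let mut stream = client.chat_completion_stream(request).await?;
    /// while let Some(chunk) = stream.next().await {
    ///     let chunk = chunk?;
    ///     if let Some(text) = &chunk.choices[0].delta.content {
    ///         print!("{}", text);
    ///     }
    /// }
    /// ```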
    pub async fn chat_completion_stream(
        &self,
        mut request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    > {
        request.stream = Some(true);
        self.adapter.chat_completion_stream(request).await
    }

    /// Streaming chat completion request with cancellation support
    ///
    /// # Arguments
    /// * `request` - Chat completion request
    ///
    /// # Returns
    /// * `Result<(impl Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin, CancelHandle), AiLibError>` - Streaming response and cancel handle on success
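    ///
    /// # Example
    /// A minimal cancellation sketch (client and request setup elided; see
    /// `chat_completion` above). After `cancel()` the stream yields a final
    /// `Err(AiLibError::ProviderError("Stream cancelled"))`:
    /// ```rust,ignore
    /// use futures::StreamExt;
    ///
    /// let (mut stream, cancel) = client.chat_completion_stream_with_cancel(request).await?;
    /// // Cancel from another task, e.g. after a timeout:
    /// tokio::spawn(async move {
    ///     tokio::time::sleep(std::time::Duration::from_secs(5)).await;
    ///     cancel.cancel();
    /// });
    /// while let Some(chunk) = stream.next().await {
    ///     match chunk {
    ///         Ok(c) => { /* handle chunk */ }
    ///         Err(e) => {
    ///             eprintln!("stream ended: {e}");
    ///             break;
    ///         }
    ///     }
    /// }
    /// ```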
    pub async fn chat_completion_stream_with_cancel(
        &self,
        mut request: ChatCompletionRequest,
    ) -> Result<
        (
            Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
            CancelHandle,
        ),
        AiLibError,
    > {
        request.stream = Some(true);
        let stream = self.adapter.chat_completion_stream(request).await?;
        let (cancel_tx, cancel_rx) = oneshot::channel();
        let cancel_handle = CancelHandle {
            sender: Some(cancel_tx),
        };

        let controlled_stream = ControlledStream::new(stream, cancel_rx);
        Ok((Box::new(controlled_stream), cancel_handle))
    }

    /// Batch chat completion requests
    ///
    /// # Arguments
    /// * `requests` - List of chat completion requests
    /// * `concurrency_limit` - Maximum concurrent request count (None means unlimited)
    ///
    /// # Returns
    /// * `Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError>` - Response results for all requests
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Role};
    /// use ai_lib::types::common::Content;
    ///
    /// #[tokio::main]
    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
    ///     let client = AiClient::new(Provider::Groq)?;
    ///
    ///     let requests = vec![
    ///         ChatCompletionRequest::new(
    ///             "llama3-8b-8192".to_string(),
    ///             vec![Message {
    ///                 role: Role::User,
    ///                 content: Content::Text("Hello".to_string()),
    ///                 function_call: None,
    ///             }],
    ///         ),
    ///         ChatCompletionRequest::new(
    ///             "llama3-8b-8192".to_string(),
    ///             vec![Message {
    ///                 role: Role::User,
    ///                 content: Content::Text("How are you?".to_string()),
    ///                 function_call: None,
    ///             }],
    ///         ),
    ///     ];
    ///
    ///     // Limit concurrency to 5
    ///     let responses = client.chat_completion_batch(requests, Some(5)).await?;
    ///
    ///     for (i, response) in responses.iter().enumerate() {
    ///         match response {
    ///             Ok(resp) => println!("Request {}: {}", i, resp.choices[0].message.content.as_text()),
    ///             Err(e) => println!("Request {} failed: {}", i, e),
    ///         }
    ///     }
    ///
    ///     Ok(())
    /// }
    /// ```
    pub async fn chat_completion_batch(
        &self,
        requests: Vec<ChatCompletionRequest>,
        concurrency_limit: Option<usize>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        self.adapter
            .chat_completion_batch(requests, concurrency_limit)
            .await
    }

    /// Smart batch processing: automatically choose a processing strategy based on request count
    ///
    /// # Arguments
    /// * `requests` - List of chat completion requests
    ///
    /// # Returns
    /// * `Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError>` - Response results for all requests
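    ///
    /// # Example
    /// A minimal sketch (request construction as in `chat_completion_batch`):
    /// ```rust,ignore
    /// // Up to 3 requests run sequentially; larger batches run with a
    /// // concurrency limit of 10.
    /// let responses = client.chat_completion_batch_smart(requests).await?;
    /// ```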
    pub async fn chat_completion_batch_smart(
        &self,
        requests: Vec<ChatCompletionRequest>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        // Use sequential processing for small batches, concurrent processing for large batches
        let concurrency_limit = if requests.len() <= 3 { None } else { Some(10) };
        self.chat_completion_batch(requests, concurrency_limit)
            .await
    }

    /// Get the list of supported models
    ///
    /// # Returns
    /// * `Result<Vec<String>, AiLibError>` - Model list on success, error on failure
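    ///
    /// # Example
    /// A minimal listing sketch (requires the provider's API key in the environment):
    /// ```rust,no_run
    /// use ai_lib::{AiClient, Provider};
    ///
    /// #[tokio::main]
    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
    ///     let client = AiClient::new(Provider::Groq)?;
    ///     for model in client.list_models().await? {
    ///         println!("{}", model);
    ///     }
    ///     Ok(())
    /// }
    /// ```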
    pub async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
        self.adapter.list_models().await
    }

    /// Switch AI model provider
    ///
    /// # Arguments
    /// * `provider` - New provider
    ///
    /// # Returns
    /// * `Result<(), AiLibError>` - Returns () on success, error on failure
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClient, Provider};
    ///
    /// let mut client = AiClient::new(Provider::Groq)?;
    /// // Switch from Groq to Groq (demonstrating the switch call)
    /// client.switch_provider(Provider::Groq)?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```
    pub fn switch_provider(&mut self, provider: Provider) -> Result<(), AiLibError> {
        let new_adapter: Box<dyn ChatApi> = match provider {
            Provider::Groq => Box::new(GenericAdapter::new(ProviderConfigs::groq())?),
            Provider::XaiGrok => Box::new(GenericAdapter::new(ProviderConfigs::xai_grok())?),
            Provider::Ollama => Box::new(GenericAdapter::new(ProviderConfigs::ollama())?),
            Provider::DeepSeek => Box::new(GenericAdapter::new(ProviderConfigs::deepseek())?),
            Provider::Qwen => Box::new(GenericAdapter::new(ProviderConfigs::qwen())?),
            Provider::OpenAI => Box::new(OpenAiAdapter::new()?),
            Provider::Anthropic => Box::new(GenericAdapter::new(ProviderConfigs::anthropic())?),
            Provider::BaiduWenxin => {
                Box::new(GenericAdapter::new(ProviderConfigs::baidu_wenxin())?)
            }
            Provider::TencentHunyuan => {
                Box::new(GenericAdapter::new(ProviderConfigs::tencent_hunyuan())?)
            }
            Provider::IflytekSpark => {
                Box::new(GenericAdapter::new(ProviderConfigs::iflytek_spark())?)
            }
            Provider::Moonshot => Box::new(GenericAdapter::new(ProviderConfigs::moonshot())?),
            Provider::Gemini => Box::new(GeminiAdapter::new()?),
            Provider::AzureOpenAI => {
                Box::new(GenericAdapter::new(ProviderConfigs::azure_openai())?)
            }
            Provider::HuggingFace => Box::new(GenericAdapter::new(ProviderConfigs::huggingface())?),
            Provider::TogetherAI => Box::new(GenericAdapter::new(ProviderConfigs::together_ai())?),
            Provider::Mistral => Box::new(MistralAdapter::new()?),
            Provider::Cohere => Box::new(CohereAdapter::new()?),
            // Provider::Bedrock => Box::new(BedrockAdapter::new()?),
        };

        self.provider = provider;
        self.adapter = new_adapter;
        Ok(())
    }

    /// Get the current provider
    pub fn current_provider(&self) -> Provider {
        self.provider
    }
}

441/// Streaming response cancel handle
442pub struct CancelHandle {
443 sender: Option<oneshot::Sender<()>>,
444}
445
446impl CancelHandle {
447 /// Cancel streaming response
448 pub fn cancel(mut self) {
449 if let Some(sender) = self.sender.take() {
450 let _ = sender.send(());
451 }
452 }
453}
454
455/// AI client builder with progressive custom configuration
456///
457/// Usage examples:
458/// ```rust
459/// use ai_lib::{AiClientBuilder, Provider};
460///
461/// // Simplest usage - automatic environment variable detection
462/// let client = AiClientBuilder::new(Provider::Groq).build()?;
463///
464/// // Custom base_url and proxy
465/// let client = AiClientBuilder::new(Provider::Groq)
466/// .with_base_url("https://custom.groq.com")
467/// .with_proxy(Some("http://proxy.example.com:8080"))
468/// .build()?;
469///
470/// // Full custom configuration
471/// let client = AiClientBuilder::new(Provider::Groq)
472/// .with_base_url("https://custom.groq.com")
473/// .with_proxy(Some("http://proxy.example.com:8080"))
474/// .with_timeout(std::time::Duration::from_secs(60))
475/// .with_pool_config(32, std::time::Duration::from_secs(90))
476/// .build()?;
477/// # Ok::<(), ai_lib::AiLibError>(())
478/// ```
479pub struct AiClientBuilder {
480 provider: Provider,
481 base_url: Option<String>,
482 proxy_url: Option<String>,
483 timeout: Option<std::time::Duration>,
484 pool_max_idle: Option<usize>,
485 pool_idle_timeout: Option<std::time::Duration>,
486 metrics: Option<Arc<dyn Metrics>>,
487}

impl AiClientBuilder {
    /// Create a new builder instance
    ///
    /// # Arguments
    /// * `provider` - The AI model provider to use
    ///
    /// # Returns
    /// * `Self` - Builder instance
    pub fn new(provider: Provider) -> Self {
        Self {
            provider,
            base_url: None,
            proxy_url: None,
            timeout: None,
            pool_max_idle: None,
            pool_idle_timeout: None,
            metrics: None,
        }
    }

    /// Set a custom base URL
    ///
    /// # Arguments
    /// * `base_url` - Custom base URL
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    pub fn with_base_url(mut self, base_url: &str) -> Self {
        self.base_url = Some(base_url.to_string());
        self
    }

    /// Set a custom proxy URL
    ///
    /// # Arguments
    /// * `proxy_url` - Custom proxy URL, or None to use the AI_PROXY_URL environment variable
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    ///
    /// # Examples
    /// ```rust
    /// use ai_lib::{AiClientBuilder, Provider};
    ///
    /// // Use a specific proxy URL
    /// let client = AiClientBuilder::new(Provider::Groq)
    ///     .with_proxy(Some("http://proxy.example.com:8080"))
    ///     .build()?;
    ///
    /// // Use the AI_PROXY_URL environment variable
    /// let client = AiClientBuilder::new(Provider::Groq)
    ///     .with_proxy(None)
    ///     .build()?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```
    pub fn with_proxy(mut self, proxy_url: Option<&str>) -> Self {
        self.proxy_url = proxy_url.map(|s| s.to_string());
        self
    }

    /// Explicitly disable proxy usage
    ///
    /// This method ensures that no proxy will be used, regardless of environment variables.
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClientBuilder, Provider};
    ///
    /// let client = AiClientBuilder::new(Provider::Groq)
    ///     .without_proxy()
    ///     .build()?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```
    pub fn without_proxy(mut self) -> Self {
        // An empty string is the internal sentinel for "explicitly no proxy";
        // determine_proxy_url() maps it to None without falling back to AI_PROXY_URL.
        self.proxy_url = Some(String::new());
        self
    }

    /// Set a custom timeout duration
    ///
    /// # Arguments
    /// * `timeout` - Custom timeout duration
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Set connection pool configuration
    ///
    /// # Arguments
    /// * `max_idle` - Maximum idle connections per host
    /// * `idle_timeout` - Idle connection timeout duration
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    pub fn with_pool_config(mut self, max_idle: usize, idle_timeout: std::time::Duration) -> Self {
        self.pool_max_idle = Some(max_idle);
        self.pool_idle_timeout = Some(idle_timeout);
        self
    }

    /// Set a custom metrics implementation
    ///
    /// # Arguments
    /// * `metrics` - Custom metrics implementation
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
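    ///
    /// # Example
    /// A minimal sketch using the crate's no-op implementation (assumes
    /// `NoopMetrics` is re-exported at `ai_lib::metrics`; adjust the path to
    /// the actual crate layout if it differs):
    /// ```rust,ignore
    /// use std::sync::Arc;
    /// use ai_lib::{AiClientBuilder, Provider};
    /// use ai_lib::metrics::NoopMetrics;
    ///
    /// let client = AiClientBuilder::new(Provider::Groq)
    ///     .with_metrics(Arc::new(NoopMetrics::new()))
    ///     .build()?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```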
    pub fn with_metrics(mut self, metrics: Arc<dyn Metrics>) -> Self {
        self.metrics = Some(metrics);
        self
    }

    /// Build the AiClient instance
    ///
    /// The build process applies configuration in the following priority order:
    /// 1. Explicitly set configuration (via with_* methods)
    /// 2. Environment variable configuration
    /// 3. Default configuration
    ///
    /// Independent-adapter providers (OpenAI, Gemini, Mistral, Cohere) manage
    /// their own configuration and ignore builder settings.
    ///
    /// # Returns
    /// * `Result<AiClient, AiLibError>` - Client instance on success, error on failure
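    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClientBuilder, Provider};
    ///
    /// // GROQ_BASE_URL and AI_PROXY_URL are picked up automatically when set.
    /// let client = AiClientBuilder::new(Provider::Groq).build()?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```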
    pub fn build(self) -> Result<AiClient, AiLibError> {
        let adapter: Box<dyn ChatApi> = match self.provider {
            // Independent adapters manage their own configuration; construct
            // them directly so building does not fail on the base_url/config
            // lookups they do not support.
            Provider::OpenAI => Box::new(OpenAiAdapter::new()?),
            Provider::Gemini => Box::new(GeminiAdapter::new()?),
            Provider::Mistral => Box::new(MistralAdapter::new()?),
            Provider::Cohere => Box::new(CohereAdapter::new()?),
            // All remaining providers are config-driven and share the generic adapter.
            _ => {
                // 1. Determine base_url: explicit setting > environment variable > default
                let base_url = self.determine_base_url()?;

                // 2. Determine proxy_url: explicit setting > environment variable
                let proxy_url = self.determine_proxy_url();

                // 3. Determine timeout: explicit setting > default
                let timeout = self
                    .timeout
                    .unwrap_or_else(|| std::time::Duration::from_secs(30));

                // 4. Create a ProviderConfig carrying the resolved base_url
                let config = self.create_custom_config(base_url)?;

                // 5. Create a custom HttpTransport only if custom settings require one
                match self.create_custom_transport(proxy_url, timeout)? {
                    Some(custom_transport) => Box::new(GenericAdapter::with_transport_ref(
                        config,
                        custom_transport,
                    )?),
                    None => Box::new(GenericAdapter::new(config)?),
                }
            }
        };

        // 6. Assemble the AiClient
        Ok(AiClient {
            provider: self.provider,
            adapter,
            metrics: self.metrics.unwrap_or_else(|| Arc::new(NoopMetrics::new())),
        })
    }

    /// Determine base_url; priority: explicit setting > environment variable > default
    fn determine_base_url(&self) -> Result<String, AiLibError> {
        // 1. Explicitly set base_url
        if let Some(ref base_url) = self.base_url {
            return Ok(base_url.clone());
        }

        // 2. base_url from the provider-specific environment variable
        let env_var_name = self.get_base_url_env_var_name();
        if let Ok(base_url) = std::env::var(&env_var_name) {
            return Ok(base_url);
        }

        // 3. Fall back to the default configuration
        let default_config = self.get_default_provider_config()?;
        Ok(default_config.base_url)
    }

    /// Determine proxy_url; priority: explicit setting > environment variable
    fn determine_proxy_url(&self) -> Option<String> {
        // 1. Explicitly set proxy_url
        if let Some(ref proxy_url) = self.proxy_url {
            // An empty string means "explicitly no proxy" (set via without_proxy())
            if proxy_url.is_empty() {
                return None;
            }
            return Some(proxy_url.clone());
        }

        // 2. AI_PROXY_URL from the environment
        std::env::var("AI_PROXY_URL").ok()
    }

    /// Get the base-URL environment variable name for the current provider
    fn get_base_url_env_var_name(&self) -> String {
        match self.provider {
            Provider::Groq => "GROQ_BASE_URL".to_string(),
            Provider::XaiGrok => "GROK_BASE_URL".to_string(),
            Provider::Ollama => "OLLAMA_BASE_URL".to_string(),
            Provider::DeepSeek => "DEEPSEEK_BASE_URL".to_string(),
            Provider::Qwen => "DASHSCOPE_BASE_URL".to_string(),
            Provider::BaiduWenxin => "BAIDU_WENXIN_BASE_URL".to_string(),
            Provider::TencentHunyuan => "TENCENT_HUNYUAN_BASE_URL".to_string(),
            Provider::IflytekSpark => "IFLYTEK_BASE_URL".to_string(),
            Provider::Moonshot => "MOONSHOT_BASE_URL".to_string(),
            Provider::Anthropic => "ANTHROPIC_BASE_URL".to_string(),
            Provider::AzureOpenAI => "AZURE_OPENAI_BASE_URL".to_string(),
            Provider::HuggingFace => "HUGGINGFACE_BASE_URL".to_string(),
            Provider::TogetherAI => "TOGETHER_BASE_URL".to_string(),
            // These providers don't support a custom base_url
            Provider::OpenAI | Provider::Gemini | Provider::Mistral | Provider::Cohere => {
                String::new()
            }
        }
    }

    /// Get the default provider configuration
    fn get_default_provider_config(
        &self,
    ) -> Result<crate::provider::config::ProviderConfig, AiLibError> {
        match self.provider {
            Provider::Groq => Ok(ProviderConfigs::groq()),
            Provider::XaiGrok => Ok(ProviderConfigs::xai_grok()),
            Provider::Ollama => Ok(ProviderConfigs::ollama()),
            Provider::DeepSeek => Ok(ProviderConfigs::deepseek()),
            Provider::Qwen => Ok(ProviderConfigs::qwen()),
            Provider::BaiduWenxin => Ok(ProviderConfigs::baidu_wenxin()),
            Provider::TencentHunyuan => Ok(ProviderConfigs::tencent_hunyuan()),
            Provider::IflytekSpark => Ok(ProviderConfigs::iflytek_spark()),
            Provider::Moonshot => Ok(ProviderConfigs::moonshot()),
            Provider::Anthropic => Ok(ProviderConfigs::anthropic()),
            Provider::AzureOpenAI => Ok(ProviderConfigs::azure_openai()),
            Provider::HuggingFace => Ok(ProviderConfigs::huggingface()),
            Provider::TogetherAI => Ok(ProviderConfigs::together_ai()),
            // These providers don't support custom configuration
            Provider::OpenAI | Provider::Gemini | Provider::Mistral | Provider::Cohere => {
                Err(AiLibError::ConfigurationError(
                    "This provider does not support custom configuration".to_string(),
                ))
            }
        }
    }

    /// Create a custom ProviderConfig with the resolved base_url
    fn create_custom_config(
        &self,
        base_url: String,
    ) -> Result<crate::provider::config::ProviderConfig, AiLibError> {
        let mut config = self.get_default_provider_config()?;
        config.base_url = base_url;
        Ok(config)
    }

    /// Create a custom HttpTransport
    fn create_custom_transport(
        &self,
        proxy_url: Option<String>,
        timeout: std::time::Duration,
    ) -> Result<Option<crate::transport::DynHttpTransportRef>, AiLibError> {
        // Without custom settings, return None (use the default transport).
        // A custom timeout also requires a custom transport, so it is part of
        // this check.
        if proxy_url.is_none()
            && self.timeout.is_none()
            && self.pool_max_idle.is_none()
            && self.pool_idle_timeout.is_none()
        {
            return Ok(None);
        }

        // Build a custom HttpTransportConfig
        let transport_config = crate::transport::HttpTransportConfig {
            timeout,
            proxy: proxy_url,
            pool_max_idle_per_host: self.pool_max_idle,
            pool_idle_timeout: self.pool_idle_timeout,
        };

        // Construct the custom HttpTransport
        let transport = crate::transport::HttpTransport::new_with_config(transport_config)?;
        Ok(Some(transport.boxed()))
    }
}

904/// Controllable streaming response
905struct ControlledStream {
906 inner: Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
907 cancel_rx: Option<oneshot::Receiver<()>>,
908}
909
910impl ControlledStream {
911 fn new(
912 inner: Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
913 cancel_rx: oneshot::Receiver<()>,
914 ) -> Self {
915 Self {
916 inner,
917 cancel_rx: Some(cancel_rx),
918 }
919 }
920}
921
impl Stream for ControlledStream {
    type Item = Result<ChatCompletionChunk, AiLibError>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        use futures::stream::StreamExt;
        use std::task::Poll;

        // Check for cancellation. The oneshot receiver resolves both when
        // CancelHandle::cancel() fires and when the handle is dropped, so
        // either event ends the stream with a final error item.
        if let Some(ref mut cancel_rx) = self.cancel_rx {
            match Future::poll(std::pin::Pin::new(cancel_rx), cx) {
                Poll::Ready(_) => {
                    self.cancel_rx = None;
                    return Poll::Ready(Some(Err(AiLibError::ProviderError(
                        "Stream cancelled".to_string(),
                    ))));
                }
                Poll::Pending => {}
            }
        }

        // Poll the inner stream
        self.inner.poll_next_unpin(cx)
    }
}