ai_lib/client.rs
use crate::api::{ChatApi, ChatCompletionChunk};
use crate::metrics::{Metrics, NoopMetrics};
use crate::provider::{
    CohereAdapter, GeminiAdapter, GenericAdapter, MistralAdapter, OpenAiAdapter, ProviderConfigs,
};
use crate::types::{AiLibError, ChatCompletionRequest, ChatCompletionResponse};
use futures::stream::Stream;
use futures::Future;
use std::sync::Arc;
use tokio::sync::oneshot;

/// AI model provider enumeration for the unified AI client.
#[derive(Debug, Clone, Copy)]
pub enum Provider {
    // Config-driven providers
    Groq,
    XaiGrok,
    Ollama,
    DeepSeek,
    Qwen,
    Anthropic,
    AzureOpenAI,
    HuggingFace,
    TogetherAI,
    // Chinese providers (OpenAI-compatible or config-driven)
    BaiduWenxin,
    TencentHunyuan,
    IflytekSpark,
    Moonshot,
    // Independent adapters
    OpenAI,
    Gemini,
    Mistral,
    Cohere,
    // Bedrock removed (deferred)
}

/// Unified AI client
///
/// Usage example:
/// ```rust
/// use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Role};
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
///     // Switch model provider by changing the Provider value
///     let client = AiClient::new(Provider::Groq)?;
///
///     let request = ChatCompletionRequest::new(
///         "test-model".to_string(),
///         vec![Message {
///             role: Role::User,
///             content: ai_lib::types::common::Content::Text("Hello".to_string()),
///             function_call: None,
///         }],
///     );
///
///     // Note: Set GROQ_API_KEY environment variable for actual API calls
///     // Optional: Set AI_PROXY_URL environment variable to use a proxy server
///     // let response = client.chat_completion(request).await?;
///
///     println!("Client created successfully with provider: {:?}", client.current_provider());
///     println!("Request prepared for model: {}", request.model);
///
///     Ok(())
/// }
/// ```
///
/// # Proxy Configuration
///
/// Configure a proxy server by setting the `AI_PROXY_URL` environment variable:
///
/// ```bash
/// export AI_PROXY_URL=http://proxy.example.com:8080
/// ```
///
/// Supported proxy formats:
/// - HTTP proxy: `http://proxy.example.com:8080`
/// - HTTPS proxy: `https://proxy.example.com:8080`
/// - With authentication: `http://user:pass@proxy.example.com:8080`
pub struct AiClient {
    provider: Provider,
    adapter: Box<dyn ChatApi>,
    metrics: Arc<dyn Metrics>,
}

impl AiClient {
    /// Create a new AI client
    ///
    /// # Arguments
    /// * `provider` - The AI model provider to use
    ///
    /// # Returns
    /// * `Result<Self, AiLibError>` - Client instance on success, error on failure
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClient, Provider};
    ///
    /// let client = AiClient::new(Provider::Groq)?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```
    pub fn new(provider: Provider) -> Result<Self, AiLibError> {
        // Use the builder, which auto-detects environment variables
        AiClientBuilder::new(provider).build()
    }

    /// Create a new AI client builder
    ///
    /// The builder pattern allows more flexible client configuration:
    /// - Automatic environment variable detection
    /// - Support for custom base_url and proxy
    /// - Support for custom timeout and connection pool configuration
    ///
    /// # Arguments
    /// * `provider` - The AI model provider to use
    ///
    /// # Returns
    /// * `AiClientBuilder` - Builder instance
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClient, Provider};
    ///
    /// // Simplest usage - automatic environment variable detection
    /// let client = AiClient::builder(Provider::Groq).build()?;
    ///
    /// // Custom base_url and proxy
    /// let client = AiClient::builder(Provider::Groq)
    ///     .with_base_url("https://custom.groq.com")
    ///     .with_proxy("http://proxy.example.com:8080")
    ///     .build()?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```
    pub fn builder(provider: Provider) -> AiClientBuilder {
        AiClientBuilder::new(provider)
    }

    /// Create AiClient with injected metrics implementation
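    ///
    /// # Example
    /// A minimal sketch using the built-in no-op recorder (this assumes `NoopMetrics`
    /// is publicly exported from `ai_lib::metrics`):
    /// ```rust,no_run
    /// use std::sync::Arc;
    /// use ai_lib::{AiClient, Provider};
    /// use ai_lib::metrics::NoopMetrics;
    ///
    /// let client = AiClient::new_with_metrics(Provider::Groq, Arc::new(NoopMetrics::new()))?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```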
    pub fn new_with_metrics(
        provider: Provider,
        metrics: Arc<dyn Metrics>,
    ) -> Result<Self, AiLibError> {
        let adapter: Box<dyn ChatApi> = match provider {
            Provider::Groq => Box::new(GenericAdapter::new(ProviderConfigs::groq())?),
            Provider::XaiGrok => Box::new(GenericAdapter::new(ProviderConfigs::xai_grok())?),
            Provider::Ollama => Box::new(GenericAdapter::new(ProviderConfigs::ollama())?),
            Provider::DeepSeek => Box::new(GenericAdapter::new(ProviderConfigs::deepseek())?),
            Provider::Qwen => Box::new(GenericAdapter::new(ProviderConfigs::qwen())?),
            Provider::Anthropic => Box::new(GenericAdapter::new(ProviderConfigs::anthropic())?),
            Provider::BaiduWenxin => {
                Box::new(GenericAdapter::new(ProviderConfigs::baidu_wenxin())?)
            }
            Provider::TencentHunyuan => {
                Box::new(GenericAdapter::new(ProviderConfigs::tencent_hunyuan())?)
            }
            Provider::IflytekSpark => {
                Box::new(GenericAdapter::new(ProviderConfigs::iflytek_spark())?)
            }
            Provider::Moonshot => Box::new(GenericAdapter::new(ProviderConfigs::moonshot())?),
            Provider::AzureOpenAI => {
                Box::new(GenericAdapter::new(ProviderConfigs::azure_openai())?)
            }
            Provider::HuggingFace => Box::new(GenericAdapter::new(ProviderConfigs::huggingface())?),
            Provider::TogetherAI => Box::new(GenericAdapter::new(ProviderConfigs::together_ai())?),
            Provider::OpenAI => Box::new(OpenAiAdapter::new()?),
            Provider::Gemini => Box::new(GeminiAdapter::new()?),
            Provider::Mistral => Box::new(MistralAdapter::new()?),
            Provider::Cohere => Box::new(CohereAdapter::new()?),
        };

        Ok(Self {
            provider,
            adapter,
            metrics,
        })
    }

    /// Set metrics implementation on client
    pub fn with_metrics(mut self, metrics: Arc<dyn Metrics>) -> Self {
        self.metrics = metrics;
        self
    }

    /// Send chat completion request
    ///
    /// # Arguments
    /// * `request` - Chat completion request
    ///
    /// # Returns
    /// * `Result<ChatCompletionResponse, AiLibError>` - Response on success, error on failure
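    ///
    /// # Example
    /// A minimal sketch of a full round trip (assumes `GROQ_API_KEY` is set; the
    /// model name is illustrative):
    /// ```rust,no_run
    /// use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Role};
    /// use ai_lib::types::common::Content;
    ///
    /// # async fn demo() -> Result<(), Box<dyn std::error::Error>> {
    /// let client = AiClient::new(Provider::Groq)?;
    /// let request = ChatCompletionRequest::new(
    ///     "llama3-8b-8192".to_string(),
    ///     vec![Message {
    ///         role: Role::User,
    ///         content: Content::Text("Hello".to_string()),
    ///         function_call: None,
    ///     }],
    /// );
    /// let response = client.chat_completion(request).await?;
    /// println!("{}", response.choices[0].message.content.as_text());
    /// # Ok(())
    /// # }
    /// ```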
    pub async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        self.adapter.chat_completion(request).await
    }

    /// Streaming chat completion request
    ///
    /// # Arguments
    /// * `request` - Chat completion request
    ///
    /// # Returns
    /// * A boxed stream of `Result<ChatCompletionChunk, AiLibError>` items on success
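    ///
    /// # Example
    /// A minimal sketch that drains the stream and debug-prints each chunk (assumes
    /// `GROQ_API_KEY` is set and that `ChatCompletionChunk` derives `Debug`; the
    /// model name is illustrative):
    /// ```rust,no_run
    /// use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Role};
    /// use ai_lib::types::common::Content;
    /// use futures::StreamExt;
    ///
    /// # async fn demo() -> Result<(), Box<dyn std::error::Error>> {
    /// let client = AiClient::new(Provider::Groq)?;
    /// let request = ChatCompletionRequest::new(
    ///     "llama3-8b-8192".to_string(),
    ///     vec![Message {
    ///         role: Role::User,
    ///         content: Content::Text("Hello".to_string()),
    ///         function_call: None,
    ///     }],
    /// );
    /// let mut stream = client.chat_completion_stream(request).await?;
    /// while let Some(chunk) = stream.next().await {
    ///     println!("{:?}", chunk?);
    /// }
    /// # Ok(())
    /// # }
    /// ```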
    pub async fn chat_completion_stream(
        &self,
        mut request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    > {
        request.stream = Some(true);
        self.adapter.chat_completion_stream(request).await
    }

    /// Streaming chat completion request with cancel control
    ///
    /// # Arguments
    /// * `request` - Chat completion request
    ///
    /// # Returns
    /// * A boxed chunk stream plus a `CancelHandle` that can abort it, on success
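    ///
    /// # Example
    /// A minimal sketch that cancels after the first chunk (assumes `GROQ_API_KEY`
    /// is set; the model name is illustrative). Once `cancel()` fires, the next poll
    /// of the stream yields `Err(AiLibError::ProviderError("Stream cancelled"))`:
    /// ```rust,no_run
    /// use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Role};
    /// use ai_lib::types::common::Content;
    /// use futures::StreamExt;
    ///
    /// # async fn demo() -> Result<(), Box<dyn std::error::Error>> {
    /// let client = AiClient::new(Provider::Groq)?;
    /// let request = ChatCompletionRequest::new(
    ///     "llama3-8b-8192".to_string(),
    ///     vec![Message {
    ///         role: Role::User,
    ///         content: Content::Text("Hello".to_string()),
    ///         function_call: None,
    ///     }],
    /// );
    /// let (mut stream, cancel) = client.chat_completion_stream_with_cancel(request).await?;
    /// if let Some(first) = stream.next().await {
    ///     println!("{:?}", first?);
    ///     cancel.cancel();
    /// }
    /// # Ok(())
    /// # }
    /// ```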
    pub async fn chat_completion_stream_with_cancel(
        &self,
        mut request: ChatCompletionRequest,
    ) -> Result<
        (
            Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
            CancelHandle,
        ),
        AiLibError,
    > {
        request.stream = Some(true);
        let stream = self.adapter.chat_completion_stream(request).await?;
        let (cancel_tx, cancel_rx) = oneshot::channel();
        let cancel_handle = CancelHandle {
            sender: Some(cancel_tx),
        };

        let controlled_stream = ControlledStream::new(stream, cancel_rx);
        Ok((Box::new(controlled_stream), cancel_handle))
    }

    /// Batch chat completion requests
    ///
    /// # Arguments
    /// * `requests` - List of chat completion requests
    /// * `concurrency_limit` - Maximum concurrent request count (None means unlimited)
    ///
    /// # Returns
    /// * `Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError>` - Returns response results for all requests
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Role};
    /// use ai_lib::types::common::Content;
    ///
    /// #[tokio::main]
    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
    ///     let client = AiClient::new(Provider::Groq)?;
    ///
    ///     let requests = vec![
    ///         ChatCompletionRequest::new(
    ///             "llama3-8b-8192".to_string(),
    ///             vec![Message {
    ///                 role: Role::User,
    ///                 content: Content::Text("Hello".to_string()),
    ///                 function_call: None,
    ///             }],
    ///         ),
    ///         ChatCompletionRequest::new(
    ///             "llama3-8b-8192".to_string(),
    ///             vec![Message {
    ///                 role: Role::User,
    ///                 content: Content::Text("How are you?".to_string()),
    ///                 function_call: None,
    ///             }],
    ///         ),
    ///     ];
    ///
    ///     // Limit concurrency to 5
    ///     let responses = client.chat_completion_batch(requests, Some(5)).await?;
    ///
    ///     for (i, response) in responses.iter().enumerate() {
    ///         match response {
    ///             Ok(resp) => println!("Request {}: {}", i, resp.choices[0].message.content.as_text()),
    ///             Err(e) => println!("Request {} failed: {}", i, e),
    ///         }
    ///     }
    ///
    ///     Ok(())
    /// }
    /// ```
    pub async fn chat_completion_batch(
        &self,
        requests: Vec<ChatCompletionRequest>,
        concurrency_limit: Option<usize>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        self.adapter
            .chat_completion_batch(requests, concurrency_limit)
            .await
    }

    /// Smart batch processing: automatically choose a processing strategy based on request count
    ///
    /// # Arguments
    /// * `requests` - List of chat completion requests
    ///
    /// # Returns
    /// * `Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError>` - Returns response results for all requests
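    ///
    /// # Example
    /// A minimal sketch (assumes `GROQ_API_KEY` is set; the model name is
    /// illustrative). With three or fewer requests the batch runs sequentially,
    /// otherwise with up to 10 concurrent requests:
    /// ```rust,no_run
    /// use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Role};
    /// use ai_lib::types::common::Content;
    ///
    /// # async fn demo() -> Result<(), Box<dyn std::error::Error>> {
    /// let client = AiClient::new(Provider::Groq)?;
    /// let requests: Vec<ChatCompletionRequest> = ["Hello", "How are you?"]
    ///     .iter()
    ///     .map(|text| {
    ///         ChatCompletionRequest::new(
    ///             "llama3-8b-8192".to_string(),
    ///             vec![Message {
    ///                 role: Role::User,
    ///                 content: Content::Text(text.to_string()),
    ///                 function_call: None,
    ///             }],
    ///         )
    ///     })
    ///     .collect();
    /// let responses = client.chat_completion_batch_smart(requests).await?;
    /// assert_eq!(responses.len(), 2);
    /// # Ok(())
    /// # }
    /// ```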
    pub async fn chat_completion_batch_smart(
        &self,
        requests: Vec<ChatCompletionRequest>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        // Sequential processing for small batches, concurrent processing for large batches
        let concurrency_limit = if requests.len() <= 3 { None } else { Some(10) };
        self.chat_completion_batch(requests, concurrency_limit)
            .await
    }

    /// Get list of supported models
    ///
    /// # Returns
    /// * `Result<Vec<String>, AiLibError>` - Returns model list on success, error on failure
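    ///
    /// # Example
    /// A minimal sketch (assumes `GROQ_API_KEY` is set):
    /// ```rust,no_run
    /// use ai_lib::{AiClient, Provider};
    ///
    /// # async fn demo() -> Result<(), Box<dyn std::error::Error>> {
    /// let client = AiClient::new(Provider::Groq)?;
    /// for model in client.list_models().await? {
    ///     println!("{}", model);
    /// }
    /// # Ok(())
    /// # }
    /// ```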
    pub async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
        self.adapter.list_models().await
    }

    /// Switch AI model provider
    ///
    /// # Arguments
    /// * `provider` - New provider
    ///
    /// # Returns
    /// * `Result<(), AiLibError>` - Returns () on success, error on failure
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClient, Provider};
    ///
    /// let mut client = AiClient::new(Provider::Groq)?;
    /// // Re-select the provider (reusing Groq keeps this example self-contained)
    /// client.switch_provider(Provider::Groq)?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```
    pub fn switch_provider(&mut self, provider: Provider) -> Result<(), AiLibError> {
        let new_adapter: Box<dyn ChatApi> = match provider {
            Provider::Groq => Box::new(GenericAdapter::new(ProviderConfigs::groq())?),
            Provider::XaiGrok => Box::new(GenericAdapter::new(ProviderConfigs::xai_grok())?),
            Provider::Ollama => Box::new(GenericAdapter::new(ProviderConfigs::ollama())?),
            Provider::DeepSeek => Box::new(GenericAdapter::new(ProviderConfigs::deepseek())?),
            Provider::Qwen => Box::new(GenericAdapter::new(ProviderConfigs::qwen())?),
            Provider::OpenAI => Box::new(OpenAiAdapter::new()?),
            Provider::Anthropic => Box::new(GenericAdapter::new(ProviderConfigs::anthropic())?),
            Provider::BaiduWenxin => {
                Box::new(GenericAdapter::new(ProviderConfigs::baidu_wenxin())?)
            }
            Provider::TencentHunyuan => {
                Box::new(GenericAdapter::new(ProviderConfigs::tencent_hunyuan())?)
            }
            Provider::IflytekSpark => {
                Box::new(GenericAdapter::new(ProviderConfigs::iflytek_spark())?)
            }
            Provider::Moonshot => Box::new(GenericAdapter::new(ProviderConfigs::moonshot())?),
            Provider::Gemini => Box::new(GeminiAdapter::new()?),
            Provider::AzureOpenAI => {
                Box::new(GenericAdapter::new(ProviderConfigs::azure_openai())?)
            }
            Provider::HuggingFace => Box::new(GenericAdapter::new(ProviderConfigs::huggingface())?),
            Provider::TogetherAI => Box::new(GenericAdapter::new(ProviderConfigs::together_ai())?),
            Provider::Mistral => Box::new(MistralAdapter::new()?),
            Provider::Cohere => Box::new(CohereAdapter::new()?),
            // Provider::Bedrock => Box::new(BedrockAdapter::new()?),
        };

        self.provider = provider;
        self.adapter = new_adapter;
        Ok(())
    }

    /// Get current provider
    pub fn current_provider(&self) -> Provider {
        self.provider
    }
}

/// Streaming response cancel handle
pub struct CancelHandle {
    sender: Option<oneshot::Sender<()>>,
}

impl CancelHandle {
    /// Cancel the streaming response
    pub fn cancel(mut self) {
        if let Some(sender) = self.sender.take() {
            let _ = sender.send(());
        }
    }
}

/// AI client builder with progressive custom configuration
///
/// Usage examples:
/// ```rust
/// use ai_lib::{AiClientBuilder, Provider};
///
/// // Simplest usage - automatic environment variable detection
/// let client = AiClientBuilder::new(Provider::Groq).build()?;
///
/// // Custom base_url and proxy
/// let client = AiClientBuilder::new(Provider::Groq)
///     .with_base_url("https://custom.groq.com")
///     .with_proxy("http://proxy.example.com:8080")
///     .build()?;
///
/// // Full custom configuration
/// let client = AiClientBuilder::new(Provider::Groq)
///     .with_base_url("https://custom.groq.com")
///     .with_proxy("http://proxy.example.com:8080")
///     .with_timeout(std::time::Duration::from_secs(60))
///     .with_pool_config(32, std::time::Duration::from_secs(90))
///     .build()?;
/// # Ok::<(), ai_lib::AiLibError>(())
/// ```
pub struct AiClientBuilder {
    provider: Provider,
    base_url: Option<String>,
    proxy_url: Option<String>,
    timeout: Option<std::time::Duration>,
    pool_max_idle: Option<usize>,
    pool_idle_timeout: Option<std::time::Duration>,
    metrics: Option<Arc<dyn Metrics>>,
}

impl AiClientBuilder {
    /// Create a new builder instance
    ///
    /// # Arguments
    /// * `provider` - The AI model provider to use
    ///
    /// # Returns
    /// * `Self` - Builder instance
    pub fn new(provider: Provider) -> Self {
        Self {
            provider,
            base_url: None,
            proxy_url: None,
            timeout: None,
            pool_max_idle: None,
            pool_idle_timeout: None,
            metrics: None,
        }
    }

    /// Set custom base URL
    ///
    /// # Arguments
    /// * `base_url` - Custom base URL
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    pub fn with_base_url(mut self, base_url: &str) -> Self {
        self.base_url = Some(base_url.to_string());
        self
    }

    /// Set custom proxy URL
    ///
    /// # Arguments
    /// * `proxy_url` - Custom proxy URL
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    pub fn with_proxy(mut self, proxy_url: &str) -> Self {
        self.proxy_url = Some(proxy_url.to_string());
        self
    }

    /// Set custom timeout duration
    ///
    /// # Arguments
    /// * `timeout` - Custom timeout duration
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Set connection pool configuration
    ///
    /// # Arguments
    /// * `max_idle` - Maximum idle connections per host
    /// * `idle_timeout` - Idle connection timeout duration
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    pub fn with_pool_config(mut self, max_idle: usize, idle_timeout: std::time::Duration) -> Self {
        self.pool_max_idle = Some(max_idle);
        self.pool_idle_timeout = Some(idle_timeout);
        self
    }

    /// Set custom metrics implementation
    ///
    /// # Arguments
    /// * `metrics` - Custom metrics implementation
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    pub fn with_metrics(mut self, metrics: Arc<dyn Metrics>) -> Self {
        self.metrics = Some(metrics);
        self
    }

    /// Build AiClient instance
    ///
    /// The build process applies configuration in the following priority order:
    /// 1. Explicitly set configuration (via with_* methods)
    /// 2. Environment variable configuration
    /// 3. Default configuration
    ///
    /// # Returns
    /// * `Result<AiClient, AiLibError>` - Returns client instance on success, error on failure
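    ///
    /// # Example
    /// A minimal sketch of the precedence rules (URLs are illustrative):
    /// ```rust,no_run
    /// use ai_lib::{AiClientBuilder, Provider};
    ///
    /// // With no explicit settings, GROQ_BASE_URL (if set) or the built-in
    /// // default base_url is used:
    /// let from_env_or_default = AiClientBuilder::new(Provider::Groq).build()?;
    ///
    /// // An explicit with_base_url always wins over both:
    /// let explicit = AiClientBuilder::new(Provider::Groq)
    ///     .with_base_url("https://groq.internal.example.com")
    ///     .build()?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```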
    pub fn build(self) -> Result<AiClient, AiLibError> {
        let adapter: Box<dyn ChatApi> = match self.provider {
            // Independent adapters don't support custom configuration; construct
            // them directly and skip the base_url/config resolution below, which
            // would otherwise error out for these providers.
            Provider::OpenAI => Box::new(OpenAiAdapter::new()?),
            Provider::Gemini => Box::new(GeminiAdapter::new()?),
            Provider::Mistral => Box::new(MistralAdapter::new()?),
            Provider::Cohere => Box::new(CohereAdapter::new()?),
            // All other providers use the config-driven generic adapter.
            Provider::Groq
            | Provider::XaiGrok
            | Provider::Ollama
            | Provider::DeepSeek
            | Provider::Qwen
            | Provider::Anthropic
            | Provider::BaiduWenxin
            | Provider::TencentHunyuan
            | Provider::IflytekSpark
            | Provider::Moonshot
            | Provider::AzureOpenAI
            | Provider::HuggingFace
            | Provider::TogetherAI => {
                // 1. Determine base_url: explicit setting > environment variable > default
                let base_url = self.determine_base_url()?;

                // 2. Determine proxy_url: explicit setting > environment variable
                let proxy_url = self.determine_proxy_url();

                // 3. Determine timeout: explicit setting > default
                let timeout = self
                    .timeout
                    .unwrap_or_else(|| std::time::Duration::from_secs(30));

                // 4. Apply the resolved base_url on top of the provider's default config
                let config = self.create_custom_config(base_url)?;

                // 5. Create a custom HttpTransport only when proxy/timeout/pool
                //    options were supplied; otherwise use the default transport
                match self.create_custom_transport(proxy_url, timeout)? {
                    Some(custom_transport) => Box::new(GenericAdapter::with_transport_ref(
                        config,
                        custom_transport,
                    )?),
                    None => Box::new(GenericAdapter::new(config)?),
                }
            }
        };

        // 6. Create AiClient
        Ok(AiClient {
            provider: self.provider,
            adapter,
            metrics: self.metrics.unwrap_or_else(|| Arc::new(NoopMetrics::new())),
        })
    }

    /// Determine base_url, priority: explicit setting > environment variable > default
    fn determine_base_url(&self) -> Result<String, AiLibError> {
        // 1. Explicitly set base_url
        if let Some(ref base_url) = self.base_url {
            return Ok(base_url.clone());
        }

        // 2. base_url from environment variable
        let env_var_name = self.get_base_url_env_var_name();
        if let Ok(base_url) = std::env::var(&env_var_name) {
            return Ok(base_url);
        }

        // 3. Use default configuration
        let default_config = self.get_default_provider_config()?;
        Ok(default_config.base_url)
    }

    /// Determine proxy_url, priority: explicit setting > environment variable
    fn determine_proxy_url(&self) -> Option<String> {
        // 1. Explicitly set proxy_url
        if let Some(ref proxy_url) = self.proxy_url {
            return Some(proxy_url.clone());
        }

        // 2. AI_PROXY_URL from environment variable
        std::env::var("AI_PROXY_URL").ok()
    }

    /// Get the base_url environment variable name for the current provider
    fn get_base_url_env_var_name(&self) -> String {
        match self.provider {
            Provider::Groq => "GROQ_BASE_URL".to_string(),
            Provider::XaiGrok => "GROK_BASE_URL".to_string(),
            Provider::Ollama => "OLLAMA_BASE_URL".to_string(),
            Provider::DeepSeek => "DEEPSEEK_BASE_URL".to_string(),
            Provider::Qwen => "DASHSCOPE_BASE_URL".to_string(),
            Provider::BaiduWenxin => "BAIDU_WENXIN_BASE_URL".to_string(),
            Provider::TencentHunyuan => "TENCENT_HUNYUAN_BASE_URL".to_string(),
            Provider::IflytekSpark => "IFLYTEK_BASE_URL".to_string(),
            Provider::Moonshot => "MOONSHOT_BASE_URL".to_string(),
            Provider::Anthropic => "ANTHROPIC_BASE_URL".to_string(),
            Provider::AzureOpenAI => "AZURE_OPENAI_BASE_URL".to_string(),
            Provider::HuggingFace => "HUGGINGFACE_BASE_URL".to_string(),
            Provider::TogetherAI => "TOGETHER_BASE_URL".to_string(),
            // These providers don't support a custom base_url
            Provider::OpenAI | Provider::Gemini | Provider::Mistral | Provider::Cohere => {
                "".to_string()
            }
        }
    }

    /// Get default provider configuration
    fn get_default_provider_config(
        &self,
    ) -> Result<crate::provider::config::ProviderConfig, AiLibError> {
        match self.provider {
            Provider::Groq => Ok(ProviderConfigs::groq()),
            Provider::XaiGrok => Ok(ProviderConfigs::xai_grok()),
            Provider::Ollama => Ok(ProviderConfigs::ollama()),
            Provider::DeepSeek => Ok(ProviderConfigs::deepseek()),
            Provider::Qwen => Ok(ProviderConfigs::qwen()),
            Provider::BaiduWenxin => Ok(ProviderConfigs::baidu_wenxin()),
            Provider::TencentHunyuan => Ok(ProviderConfigs::tencent_hunyuan()),
            Provider::IflytekSpark => Ok(ProviderConfigs::iflytek_spark()),
            Provider::Moonshot => Ok(ProviderConfigs::moonshot()),
            Provider::Anthropic => Ok(ProviderConfigs::anthropic()),
            Provider::AzureOpenAI => Ok(ProviderConfigs::azure_openai()),
            Provider::HuggingFace => Ok(ProviderConfigs::huggingface()),
            Provider::TogetherAI => Ok(ProviderConfigs::together_ai()),
            // These providers don't support custom configuration
            Provider::OpenAI | Provider::Gemini | Provider::Mistral | Provider::Cohere => {
                Err(AiLibError::ConfigurationError(
                    "This provider does not support custom configuration".to_string(),
                ))
            }
        }
    }

    /// Create custom ProviderConfig
    fn create_custom_config(
        &self,
        base_url: String,
    ) -> Result<crate::provider::config::ProviderConfig, AiLibError> {
        let mut config = self.get_default_provider_config()?;
        config.base_url = base_url;
        Ok(config)
    }

    /// Create custom HttpTransport
    fn create_custom_transport(
        &self,
        proxy_url: Option<String>,
        timeout: std::time::Duration,
    ) -> Result<Option<crate::transport::DynHttpTransportRef>, AiLibError> {
        // Without any custom proxy, timeout, or pool settings, return None (use default transport)
        if proxy_url.is_none()
            && self.timeout.is_none()
            && self.pool_max_idle.is_none()
            && self.pool_idle_timeout.is_none()
        {
            return Ok(None);
        }

        // Create custom HttpTransportConfig
        let transport_config = crate::transport::HttpTransportConfig {
            timeout,
            proxy: proxy_url,
            pool_max_idle_per_host: self.pool_max_idle,
            pool_idle_timeout: self.pool_idle_timeout,
        };

        // Create custom HttpTransport
        let transport = crate::transport::HttpTransport::new_with_config(transport_config)?;
        Ok(Some(transport.boxed()))
    }
}

/// Controllable streaming response
struct ControlledStream {
    inner: Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
    cancel_rx: Option<oneshot::Receiver<()>>,
}

impl ControlledStream {
    fn new(
        inner: Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        cancel_rx: oneshot::Receiver<()>,
    ) -> Self {
        Self {
            inner,
            cancel_rx: Some(cancel_rx),
        }
    }
}

impl Stream for ControlledStream {
    type Item = Result<ChatCompletionChunk, AiLibError>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        use futures::stream::StreamExt;
        use std::task::Poll;

        // Check for cancellation: the oneshot resolves both when `cancel()` is
        // called and when the CancelHandle is dropped without being used.
        if let Some(ref mut cancel_rx) = self.cancel_rx {
            match Future::poll(std::pin::Pin::new(cancel_rx), cx) {
                Poll::Ready(_) => {
                    self.cancel_rx = None;
                    return Poll::Ready(Some(Err(AiLibError::ProviderError(
                        "Stream cancelled".to_string(),
                    ))));
                }
                Poll::Pending => {}
            }
        }

        // Poll the inner stream
        self.inner.poll_next_unpin(cx)
    }
}