// ai_lib/client.rs
1use crate::api::{ChatApi, ChatCompletionChunk};
2use crate::config::ConnectionOptions;
3use crate::metrics::{Metrics, NoopMetrics};
4use crate::provider::{
5 CohereAdapter, GeminiAdapter, GenericAdapter, MistralAdapter, OpenAiAdapter, ProviderConfigs,
6};
7use crate::types::{AiLibError, ChatCompletionRequest, ChatCompletionResponse};
8use futures::stream::Stream;
9use futures::Future;
10use std::sync::Arc;
11use tokio::sync::oneshot;
12
// Unified AI client module.

/// AI model provider enumeration
#[derive(Debug, Clone, Copy)]
pub enum Provider {
    // Config-driven providers (served by `GenericAdapter` + `ProviderConfigs`)
    Groq,
    XaiGrok,
    Ollama,
    DeepSeek,
    Anthropic,
    AzureOpenAI,
    HuggingFace,
    TogetherAI,
    // Chinese providers (OpenAI-compatible or config-driven)
    BaiduWenxin,
    TencentHunyuan,
    IflytekSpark,
    Moonshot,
    // Independent adapters (dedicated adapter types such as `OpenAiAdapter`)
    OpenAI,
    // NOTE(review): Qwen is grouped here but is constructed via the
    // config-driven `GenericAdapter` (see `new_with_metrics`/`switch_provider`)
    // — consider moving it up to the config-driven group.
    Qwen,
    Gemini,
    Mistral,
    Cohere,
    // Bedrock removed (deferred)
}
40
41/// Unified AI client
42///
43/// Usage example:
44/// ```rust
45/// use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Role};
46///
47/// #[tokio::main]
48/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
49/// // Switch model provider by changing Provider value
50/// let client = AiClient::new(Provider::Groq)?;
51///
52/// let request = ChatCompletionRequest::new(
53/// "test-model".to_string(),
54/// vec![Message {
55/// role: Role::User,
56/// content: ai_lib::types::common::Content::Text("Hello".to_string()),
57/// function_call: None,
58/// }],
59/// );
60///
61/// // Note: Set GROQ_API_KEY environment variable for actual API calls
62/// // Optional: Set AI_PROXY_URL environment variable to use proxy server
63/// // let response = client.chat_completion(request).await?;
64///
65/// println!("Client created successfully with provider: {:?}", client.current_provider());
66/// println!("Request prepared for model: {}", request.model);
67///
68/// Ok(())
69/// }
70/// ```
71///
72/// # Proxy Configuration
73///
74/// Configure proxy server by setting the `AI_PROXY_URL` environment variable:
75///
76/// ```bash
77/// export AI_PROXY_URL=http://proxy.example.com:8080
78/// ```
79///
80/// Supported proxy formats:
81/// - HTTP proxy: `http://proxy.example.com:8080`
82/// - HTTPS proxy: `https://proxy.example.com:8080`
83/// - With authentication: `http://user:pass@proxy.example.com:8080`
pub struct AiClient {
    /// Currently selected provider (changeable via `switch_provider`).
    provider: Provider,
    /// Provider-specific implementation behind the unified `ChatApi` trait.
    adapter: Box<dyn ChatApi>,
    /// Metrics sink; defaults to `NoopMetrics` unless injected.
    metrics: Arc<dyn Metrics>,
    /// Options recorded when built via `with_options`; `None` otherwise.
    connection_options: Option<ConnectionOptions>,
}
90
91impl AiClient {
92 /// Create a new AI client
93 ///
94 /// # Arguments
95 /// * `provider` - The AI model provider to use
96 ///
97 /// # Returns
98 /// * `Result<Self, AiLibError>` - Client instance on success, error on failure
99 ///
100 /// # Example
101 /// ```rust
102 /// use ai_lib::{AiClient, Provider};
103 ///
104 /// let client = AiClient::new(Provider::Groq)?;
105 /// # Ok::<(), ai_lib::AiLibError>(())
106 /// ```
107 pub fn new(provider: Provider) -> Result<Self, AiLibError> {
108 // Use the new builder to create client with automatic environment variable detection
109 let mut c = AiClientBuilder::new(provider).build()?;
110 c.connection_options = None;
111 Ok(c)
112 }
113
114 /// Create client with minimal explicit options (base_url/proxy/timeout). Not all providers
115 /// support overrides; unsupported providers ignore unspecified fields gracefully.
116 pub fn with_options(provider: Provider, opts: ConnectionOptions) -> Result<Self, AiLibError> {
117 let config_driven = matches!(
118 provider,
119 Provider::Groq
120 | Provider::XaiGrok
121 | Provider::Ollama
122 | Provider::DeepSeek
123 | Provider::Qwen
124 | Provider::BaiduWenxin
125 | Provider::TencentHunyuan
126 | Provider::IflytekSpark
127 | Provider::Moonshot
128 | Provider::Anthropic
129 | Provider::AzureOpenAI
130 | Provider::HuggingFace
131 | Provider::TogetherAI
132 );
133 let need_builder = config_driven
134 && (opts.base_url.is_some()
135 || opts.proxy.is_some()
136 || opts.timeout.is_some()
137 || opts.disable_proxy);
138 if need_builder {
139 let mut b = AiClient::builder(provider);
140 if let Some(ref base) = opts.base_url {
141 b = b.with_base_url(base);
142 }
143 if opts.disable_proxy {
144 b = b.without_proxy();
145 } else if let Some(ref proxy) = opts.proxy {
146 if proxy.is_empty() {
147 b = b.without_proxy();
148 } else {
149 b = b.with_proxy(Some(proxy));
150 }
151 }
152 if let Some(t) = opts.timeout {
153 b = b.with_timeout(t);
154 }
155 let mut client = b.build()?;
156 // If api_key override + generic provider path: re-wrap adapter using override
157 if opts.api_key.is_some() {
158 // Only applies to config-driven generic adapter providers
159 let new_adapter: Option<Box<dyn ChatApi>> = match provider {
160 Provider::Groq => Some(Box::new(GenericAdapter::new_with_api_key(
161 ProviderConfigs::groq(),
162 opts.api_key.clone(),
163 )?)),
164 Provider::XaiGrok => Some(Box::new(GenericAdapter::new_with_api_key(
165 ProviderConfigs::xai_grok(),
166 opts.api_key.clone(),
167 )?)),
168 Provider::Ollama => Some(Box::new(GenericAdapter::new_with_api_key(
169 ProviderConfigs::ollama(),
170 opts.api_key.clone(),
171 )?)),
172 Provider::DeepSeek => Some(Box::new(GenericAdapter::new_with_api_key(
173 ProviderConfigs::deepseek(),
174 opts.api_key.clone(),
175 )?)),
176 Provider::Qwen => Some(Box::new(GenericAdapter::new_with_api_key(
177 ProviderConfigs::qwen(),
178 opts.api_key.clone(),
179 )?)),
180 Provider::BaiduWenxin => Some(Box::new(GenericAdapter::new_with_api_key(
181 ProviderConfigs::baidu_wenxin(),
182 opts.api_key.clone(),
183 )?)),
184 Provider::TencentHunyuan => Some(Box::new(GenericAdapter::new_with_api_key(
185 ProviderConfigs::tencent_hunyuan(),
186 opts.api_key.clone(),
187 )?)),
188 Provider::IflytekSpark => Some(Box::new(GenericAdapter::new_with_api_key(
189 ProviderConfigs::iflytek_spark(),
190 opts.api_key.clone(),
191 )?)),
192 Provider::Moonshot => Some(Box::new(GenericAdapter::new_with_api_key(
193 ProviderConfigs::moonshot(),
194 opts.api_key.clone(),
195 )?)),
196 Provider::Anthropic => Some(Box::new(GenericAdapter::new_with_api_key(
197 ProviderConfigs::anthropic(),
198 opts.api_key.clone(),
199 )?)),
200 Provider::AzureOpenAI => Some(Box::new(GenericAdapter::new_with_api_key(
201 ProviderConfigs::azure_openai(),
202 opts.api_key.clone(),
203 )?)),
204 Provider::HuggingFace => Some(Box::new(GenericAdapter::new_with_api_key(
205 ProviderConfigs::huggingface(),
206 opts.api_key.clone(),
207 )?)),
208 Provider::TogetherAI => Some(Box::new(GenericAdapter::new_with_api_key(
209 ProviderConfigs::together_ai(),
210 opts.api_key.clone(),
211 )?)),
212 _ => None,
213 };
214 if let Some(a) = new_adapter {
215 client.adapter = a;
216 }
217 }
218 client.connection_options = Some(opts);
219 return Ok(client);
220 }
221
222 // Independent adapters: OpenAI / Gemini / Mistral / Cohere
223 if matches!(
224 provider,
225 Provider::OpenAI | Provider::Gemini | Provider::Mistral | Provider::Cohere
226 ) {
227 let adapter: Box<dyn ChatApi> = match provider {
228 Provider::OpenAI => {
229 if let Some(ref k) = opts.api_key {
230 let inner =
231 OpenAiAdapter::new_with_overrides(k.clone(), opts.base_url.clone())?;
232 Box::new(inner)
233 } else {
234 let inner = OpenAiAdapter::new()?;
235 Box::new(inner)
236 }
237 }
238 Provider::Gemini => {
239 if let Some(ref k) = opts.api_key {
240 let inner =
241 GeminiAdapter::new_with_overrides(k.clone(), opts.base_url.clone())?;
242 Box::new(inner)
243 } else {
244 let inner = GeminiAdapter::new()?;
245 Box::new(inner)
246 }
247 }
248 Provider::Mistral => {
249 if opts.api_key.is_some() || opts.base_url.is_some() {
250 let inner = MistralAdapter::new_with_overrides(
251 opts.api_key.clone(),
252 opts.base_url.clone(),
253 )?;
254 Box::new(inner)
255 } else {
256 let inner = MistralAdapter::new()?;
257 Box::new(inner)
258 }
259 }
260 Provider::Cohere => {
261 if let Some(ref k) = opts.api_key {
262 let inner =
263 CohereAdapter::new_with_overrides(k.clone(), opts.base_url.clone())?;
264 Box::new(inner)
265 } else {
266 let inner = CohereAdapter::new()?;
267 Box::new(inner)
268 }
269 }
270 _ => unreachable!(),
271 };
272 return Ok(AiClient {
273 provider,
274 adapter,
275 metrics: Arc::new(NoopMetrics::new()),
276 connection_options: Some(opts),
277 });
278 }
279
280 // Simple config-driven without overrides
281 let mut client = AiClient::new(provider)?;
282 if let Some(ref k) = opts.api_key {
283 let override_adapter: Option<Box<dyn ChatApi>> = match provider {
284 Provider::Groq => Some(Box::new(GenericAdapter::new_with_api_key(
285 ProviderConfigs::groq(),
286 Some(k.clone()),
287 )?)),
288 Provider::XaiGrok => Some(Box::new(GenericAdapter::new_with_api_key(
289 ProviderConfigs::xai_grok(),
290 Some(k.clone()),
291 )?)),
292 Provider::Ollama => Some(Box::new(GenericAdapter::new_with_api_key(
293 ProviderConfigs::ollama(),
294 Some(k.clone()),
295 )?)),
296 Provider::DeepSeek => Some(Box::new(GenericAdapter::new_with_api_key(
297 ProviderConfigs::deepseek(),
298 Some(k.clone()),
299 )?)),
300 Provider::Qwen => Some(Box::new(GenericAdapter::new_with_api_key(
301 ProviderConfigs::qwen(),
302 Some(k.clone()),
303 )?)),
304 Provider::BaiduWenxin => Some(Box::new(GenericAdapter::new_with_api_key(
305 ProviderConfigs::baidu_wenxin(),
306 Some(k.clone()),
307 )?)),
308 Provider::TencentHunyuan => Some(Box::new(GenericAdapter::new_with_api_key(
309 ProviderConfigs::tencent_hunyuan(),
310 Some(k.clone()),
311 )?)),
312 Provider::IflytekSpark => Some(Box::new(GenericAdapter::new_with_api_key(
313 ProviderConfigs::iflytek_spark(),
314 Some(k.clone()),
315 )?)),
316 Provider::Moonshot => Some(Box::new(GenericAdapter::new_with_api_key(
317 ProviderConfigs::moonshot(),
318 Some(k.clone()),
319 )?)),
320 Provider::Anthropic => Some(Box::new(GenericAdapter::new_with_api_key(
321 ProviderConfigs::anthropic(),
322 Some(k.clone()),
323 )?)),
324 Provider::AzureOpenAI => Some(Box::new(GenericAdapter::new_with_api_key(
325 ProviderConfigs::azure_openai(),
326 Some(k.clone()),
327 )?)),
328 Provider::HuggingFace => Some(Box::new(GenericAdapter::new_with_api_key(
329 ProviderConfigs::huggingface(),
330 Some(k.clone()),
331 )?)),
332 Provider::TogetherAI => Some(Box::new(GenericAdapter::new_with_api_key(
333 ProviderConfigs::together_ai(),
334 Some(k.clone()),
335 )?)),
336 _ => None,
337 };
338 if let Some(a) = override_adapter {
339 client.adapter = a;
340 }
341 }
342 client.connection_options = Some(opts);
343 Ok(client)
344 }
345
    /// Connection options this client was created with.
    ///
    /// Set only by [`AiClient::with_options`]; clients built via
    /// `new`/`builder` return `None`.
    pub fn connection_options(&self) -> Option<&ConnectionOptions> {
        self.connection_options.as_ref()
    }
349
    /// Create a new AI client builder
    ///
    /// Equivalent to calling [`AiClientBuilder::new`] directly.
    ///
    /// The builder pattern allows more flexible client configuration:
    /// - Automatic environment variable detection
    /// - Support for custom base_url and proxy
    /// - Support for custom timeout and connection pool configuration
    ///
    /// # Arguments
    /// * `provider` - The AI model provider to use
    ///
    /// # Returns
    /// * `AiClientBuilder` - Builder instance
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClient, Provider};
    ///
    /// // Simplest usage - automatic environment variable detection
    /// let client = AiClient::builder(Provider::Groq).build()?;
    ///
    /// // Custom base_url and proxy
    /// let client = AiClient::builder(Provider::Groq)
    ///     .with_base_url("https://custom.groq.com")
    ///     .with_proxy(Some("http://proxy.example.com:8080"))
    ///     .build()?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```
    pub fn builder(provider: Provider) -> AiClientBuilder {
        AiClientBuilder::new(provider)
    }
380
381 /// Create AiClient with injected metrics implementation
382 pub fn new_with_metrics(
383 provider: Provider,
384 metrics: Arc<dyn Metrics>,
385 ) -> Result<Self, AiLibError> {
386 let adapter: Box<dyn ChatApi> = match provider {
387 Provider::Groq => Box::new(GenericAdapter::new(ProviderConfigs::groq())?),
388 Provider::XaiGrok => Box::new(GenericAdapter::new(ProviderConfigs::xai_grok())?),
389 Provider::Ollama => Box::new(GenericAdapter::new(ProviderConfigs::ollama())?),
390 Provider::DeepSeek => Box::new(GenericAdapter::new(ProviderConfigs::deepseek())?),
391 Provider::Qwen => Box::new(GenericAdapter::new(ProviderConfigs::qwen())?),
392 Provider::Anthropic => Box::new(GenericAdapter::new(ProviderConfigs::anthropic())?),
393 Provider::BaiduWenxin => {
394 Box::new(GenericAdapter::new(ProviderConfigs::baidu_wenxin())?)
395 }
396 Provider::TencentHunyuan => {
397 Box::new(GenericAdapter::new(ProviderConfigs::tencent_hunyuan())?)
398 }
399 Provider::IflytekSpark => {
400 Box::new(GenericAdapter::new(ProviderConfigs::iflytek_spark())?)
401 }
402 Provider::Moonshot => Box::new(GenericAdapter::new(ProviderConfigs::moonshot())?),
403 Provider::AzureOpenAI => {
404 Box::new(GenericAdapter::new(ProviderConfigs::azure_openai())?)
405 }
406 Provider::HuggingFace => Box::new(GenericAdapter::new(ProviderConfigs::huggingface())?),
407 Provider::TogetherAI => Box::new(GenericAdapter::new(ProviderConfigs::together_ai())?),
408 Provider::OpenAI => Box::new(OpenAiAdapter::new()?),
409 Provider::Gemini => Box::new(GeminiAdapter::new()?),
410 Provider::Mistral => Box::new(MistralAdapter::new()?),
411 Provider::Cohere => Box::new(CohereAdapter::new()?),
412 };
413
414 Ok(Self {
415 provider,
416 adapter,
417 metrics,
418 connection_options: None,
419 })
420 }
421
422 /// Set metrics implementation on client
423 pub fn with_metrics(mut self, metrics: Arc<dyn Metrics>) -> Self {
424 self.metrics = metrics;
425 self
426 }
427
    /// Send chat completion request
    ///
    /// Delegates directly to the underlying provider adapter; no transformation
    /// is applied at this layer.
    ///
    /// # Arguments
    /// * `request` - Chat completion request
    ///
    /// # Returns
    /// * `Result<ChatCompletionResponse, AiLibError>` - Response on success, error on failure
    pub async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        self.adapter.chat_completion(request).await
    }
441
    /// Streaming chat completion request
    ///
    /// `request.stream` is forced to `Some(true)` before the request is
    /// forwarded to the provider adapter.
    ///
    /// # Arguments
    /// * `request` - Chat completion request
    ///
    /// # Returns
    /// * `Result<impl Stream<Item = Result<ChatCompletionChunk, AiLibError>>, AiLibError>` - Stream response on success
    pub async fn chat_completion_stream(
        &self,
        mut request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    > {
        // Ensure the provider actually streams regardless of caller input.
        request.stream = Some(true);
        self.adapter.chat_completion_stream(request).await
    }
459
460 /// Streaming chat completion request with cancel control
461 ///
462 /// # Arguments
463 /// * `request` - Chat completion request
464 ///
465 /// # Returns
466 /// * `Result<(impl Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin, CancelHandle), AiLibError>` - Returns streaming response and cancel handle on success
467 pub async fn chat_completion_stream_with_cancel(
468 &self,
469 mut request: ChatCompletionRequest,
470 ) -> Result<
471 (
472 Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
473 CancelHandle,
474 ),
475 AiLibError,
476 > {
477 request.stream = Some(true);
478 let stream = self.adapter.chat_completion_stream(request).await?;
479 let (cancel_tx, cancel_rx) = oneshot::channel();
480 let cancel_handle = CancelHandle {
481 sender: Some(cancel_tx),
482 };
483
484 let controlled_stream = ControlledStream::new(stream, cancel_rx);
485 Ok((Box::new(controlled_stream), cancel_handle))
486 }
487
    /// Batch chat completion requests
    ///
    /// Delegates to the adapter's batch implementation; each request gets its
    /// own `Result` so one failure does not abort the batch.
    ///
    /// # Arguments
    /// * `requests` - List of chat completion requests
    /// * `concurrency_limit` - Maximum concurrent request count (None means unlimited)
    ///
    /// # Returns
    /// * `Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError>` - Returns response results for all requests
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Role};
    /// use ai_lib::types::common::Content;
    ///
    /// #[tokio::main]
    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
    ///     let client = AiClient::new(Provider::Groq)?;
    ///
    ///     let requests = vec![
    ///         ChatCompletionRequest::new(
    ///             "llama3-8b-8192".to_string(),
    ///             vec![Message {
    ///                 role: Role::User,
    ///                 content: Content::Text("Hello".to_string()),
    ///                 function_call: None,
    ///             }],
    ///         ),
    ///         ChatCompletionRequest::new(
    ///             "llama3-8b-8192".to_string(),
    ///             vec![Message {
    ///                 role: Role::User,
    ///                 content: Content::Text("How are you?".to_string()),
    ///                 function_call: None,
    ///             }],
    ///         ),
    ///     ];
    ///
    ///     // Limit concurrency to 5
    ///     let responses = client.chat_completion_batch(requests, Some(5)).await?;
    ///
    ///     for (i, response) in responses.iter().enumerate() {
    ///         match response {
    ///             Ok(resp) => println!("Request {}: {}", i, resp.choices[0].message.content.as_text()),
    ///             Err(e) => println!("Request {} failed: {}", i, e),
    ///         }
    ///     }
    ///
    ///     Ok(())
    /// }
    /// ```
    pub async fn chat_completion_batch(
        &self,
        requests: Vec<ChatCompletionRequest>,
        concurrency_limit: Option<usize>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        self.adapter
            .chat_completion_batch(requests, concurrency_limit)
            .await
    }
547
548 /// Smart batch processing: automatically choose processing strategy based on request count
549 ///
550 /// # Arguments
551 /// * `requests` - List of chat completion requests
552 ///
553 /// # Returns
554 /// * `Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError>` - Returns response results for all requests
555 pub async fn chat_completion_batch_smart(
556 &self,
557 requests: Vec<ChatCompletionRequest>,
558 ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
559 // Use sequential processing for small batches, concurrent processing for large batches
560 let concurrency_limit = if requests.len() <= 3 { None } else { Some(10) };
561 self.chat_completion_batch(requests, concurrency_limit)
562 .await
563 }
564
565 /// Batch chat completion requests
566 ///
567 /// # Arguments
568 /// * `requests` - List of chat completion requests
569 /// * `concurrency_limit` - Maximum concurrent request count (None means unlimited)
570 ///
571 /// # Returns
572 /// * `Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError>` - Returns response results for all requests
573 ///
574 /// # Example
575 /// ```rust
576 /// use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Role};
577 /// use ai_lib::types::common::Content;
578 ///
579 /// #[tokio::main]
580 /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
581 /// let client = AiClient::new(Provider::Groq)?;
582 ///
583 /// let requests = vec![
584 /// ChatCompletionRequest::new(
585 /// "llama3-8b-8192".to_string(),
586 /// vec![Message {
587 /// role: Role::User,
588 /// content: Content::Text("Hello".to_string()),
589 /// function_call: None,
590 /// }],
591 /// ),
592 /// ChatCompletionRequest::new(
593 /// "llama3-8b-8192".to_string(),
594 /// vec![Message {
595 /// role: Role::User,
596 /// content: Content::Text("How are you?".to_string()),
597 /// function_call: None,
598 /// }],
599 /// ),
600 /// ];
601 ///
602 /// // Limit concurrency to 5
603 /// let responses = client.chat_completion_batch(requests, Some(5)).await?;
604 ///
605 /// for (i, response) in responses.iter().enumerate() {
606 /// match response {
607 /// Ok(resp) => println!("Request {}: {}", i, resp.choices[0].message.content.as_text()),
608 /// Err(e) => println!("Request {} failed: {}", i, e),
609 /// }
610 /// }
611 ///
612 /// Ok(())
613 /// }
614 /// ```
615 ///
616 /// Get list of supported models
617 ///
618 /// # Returns
619 /// * `Result<Vec<String>, AiLibError>` - Returns model list on success, error on failure
620 pub async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
621 self.adapter.list_models().await
622 }
623
624 /// Switch AI model provider
625 ///
626 /// # Arguments
627 /// * `provider` - New provider
628 ///
629 /// # Returns
630 /// * `Result<(), AiLibError>` - Returns () on success, error on failure
631 ///
632 /// # Example
633 /// ```rust
634 /// use ai_lib::{AiClient, Provider};
635 ///
636 /// let mut client = AiClient::new(Provider::Groq)?;
637 /// // Switch from Groq to Groq (demonstrating switch functionality)
638 /// client.switch_provider(Provider::Groq)?;
639 /// # Ok::<(), ai_lib::AiLibError>(())
640 /// ```
641 pub fn switch_provider(&mut self, provider: Provider) -> Result<(), AiLibError> {
642 let new_adapter: Box<dyn ChatApi> = match provider {
643 Provider::Groq => Box::new(GenericAdapter::new(ProviderConfigs::groq())?),
644 Provider::XaiGrok => Box::new(GenericAdapter::new(ProviderConfigs::xai_grok())?),
645 Provider::Ollama => Box::new(GenericAdapter::new(ProviderConfigs::ollama())?),
646 Provider::DeepSeek => Box::new(GenericAdapter::new(ProviderConfigs::deepseek())?),
647 Provider::Qwen => Box::new(GenericAdapter::new(ProviderConfigs::qwen())?),
648 Provider::OpenAI => Box::new(OpenAiAdapter::new()?),
649 Provider::Anthropic => Box::new(GenericAdapter::new(ProviderConfigs::anthropic())?),
650 Provider::BaiduWenxin => {
651 Box::new(GenericAdapter::new(ProviderConfigs::baidu_wenxin())?)
652 }
653 Provider::TencentHunyuan => {
654 Box::new(GenericAdapter::new(ProviderConfigs::tencent_hunyuan())?)
655 }
656 Provider::IflytekSpark => {
657 Box::new(GenericAdapter::new(ProviderConfigs::iflytek_spark())?)
658 }
659 Provider::Moonshot => Box::new(GenericAdapter::new(ProviderConfigs::moonshot())?),
660 Provider::Gemini => Box::new(GeminiAdapter::new()?),
661 Provider::AzureOpenAI => {
662 Box::new(GenericAdapter::new(ProviderConfigs::azure_openai())?)
663 }
664 Provider::HuggingFace => Box::new(GenericAdapter::new(ProviderConfigs::huggingface())?),
665 Provider::TogetherAI => Box::new(GenericAdapter::new(ProviderConfigs::together_ai())?),
666 Provider::Mistral => Box::new(MistralAdapter::new()?),
667 Provider::Cohere => Box::new(CohereAdapter::new()?),
668 // Provider::Bedrock => Box::new(BedrockAdapter::new()?),
669 };
670
671 self.provider = provider;
672 self.adapter = new_adapter;
673 Ok(())
674 }
675
    /// Get current provider
    ///
    /// Returns the `Provider` this client is currently routing requests to.
    pub fn current_provider(&self) -> Provider {
        self.provider
    }
680}
681
/// Streaming response cancel handle
pub struct CancelHandle {
    /// One-shot signal consumed by `cancel()`; `None` once used.
    sender: Option<oneshot::Sender<()>>,
}
686
687impl CancelHandle {
688 /// Cancel streaming response
689 pub fn cancel(mut self) {
690 if let Some(sender) = self.sender.take() {
691 let _ = sender.send(());
692 }
693 }
694}
695
696/// AI client builder with progressive custom configuration
697///
698/// Usage examples:
699/// ```rust
700/// use ai_lib::{AiClientBuilder, Provider};
701///
702/// // Simplest usage - automatic environment variable detection
703/// let client = AiClientBuilder::new(Provider::Groq).build()?;
704///
705/// // Custom base_url and proxy
706/// let client = AiClientBuilder::new(Provider::Groq)
707/// .with_base_url("https://custom.groq.com")
708/// .with_proxy(Some("http://proxy.example.com:8080"))
709/// .build()?;
710///
711/// // Full custom configuration
712/// let client = AiClientBuilder::new(Provider::Groq)
713/// .with_base_url("https://custom.groq.com")
714/// .with_proxy(Some("http://proxy.example.com:8080"))
715/// .with_timeout(std::time::Duration::from_secs(60))
716/// .with_pool_config(32, std::time::Duration::from_secs(90))
717/// .build()?;
718/// # Ok::<(), ai_lib::AiLibError>(())
719/// ```
pub struct AiClientBuilder {
    /// Target provider for the client being built.
    provider: Provider,
    /// Explicit base URL override (else env var, else provider default).
    base_url: Option<String>,
    /// Proxy override; `Some("")` is the "explicitly no proxy" sentinel
    /// set by `without_proxy`.
    proxy_url: Option<String>,
    /// Request timeout override; `build` defaults to 30 seconds when unset.
    timeout: Option<std::time::Duration>,
    /// Max idle connections per host for the connection pool.
    pool_max_idle: Option<usize>,
    /// Idle connection timeout for the connection pool.
    pool_idle_timeout: Option<std::time::Duration>,
    /// Custom metrics sink; `NoopMetrics` is used when unset.
    metrics: Option<Arc<dyn Metrics>>,
}
729
730impl AiClientBuilder {
    /// Create a new builder instance
    ///
    /// All overrides start unset; `build` falls back to environment variables
    /// and provider defaults for anything not configured explicitly.
    ///
    /// # Arguments
    /// * `provider` - The AI model provider to use
    ///
    /// # Returns
    /// * `Self` - Builder instance
    pub fn new(provider: Provider) -> Self {
        Self {
            provider,
            base_url: None,
            proxy_url: None,
            timeout: None,
            pool_max_idle: None,
            pool_idle_timeout: None,
            metrics: None,
        }
    }
749
    /// Set custom base URL
    ///
    /// Takes precedence over the provider's environment variable and default
    /// in `build`.
    ///
    /// # Arguments
    /// * `base_url` - Custom base URL
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    pub fn with_base_url(mut self, base_url: &str) -> Self {
        self.base_url = Some(base_url.to_string());
        self
    }
761
762 /// Set custom proxy URL
763 ///
764 /// # Arguments
765 /// * `proxy_url` - Custom proxy URL, or None to use AI_PROXY_URL environment variable
766 ///
767 /// # Returns
768 /// * `Self` - Builder instance for method chaining
769 ///
770 /// # Examples
771 /// ```rust
772 /// use ai_lib::{AiClientBuilder, Provider};
773 ///
774 /// // Use specific proxy URL
775 /// let client = AiClientBuilder::new(Provider::Groq)
776 /// .with_proxy(Some("http://proxy.example.com:8080"))
777 /// .build()?;
778 ///
779 /// // Use AI_PROXY_URL environment variable
780 /// let client = AiClientBuilder::new(Provider::Groq)
781 /// .with_proxy(None)
782 /// .build()?;
783 /// # Ok::<(), ai_lib::AiLibError>(())
784 /// ```
785 pub fn with_proxy(mut self, proxy_url: Option<&str>) -> Self {
786 self.proxy_url = proxy_url.map(|s| s.to_string());
787 self
788 }
789
    /// Explicitly disable proxy usage
    ///
    /// This method ensures that no proxy will be used, regardless of environment variables.
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    ///
    /// # Example
    /// ```rust
    /// use ai_lib::{AiClientBuilder, Provider};
    ///
    /// let client = AiClientBuilder::new(Provider::Groq)
    ///     .without_proxy()
    ///     .build()?;
    /// # Ok::<(), ai_lib::AiLibError>(())
    /// ```
    pub fn without_proxy(mut self) -> Self {
        // Empty string is the internal sentinel for "explicitly no proxy";
        // `determine_proxy_url` maps it to `None` and skips AI_PROXY_URL.
        self.proxy_url = Some("".to_string());
        self
    }
809
    /// Set custom timeout duration
    ///
    /// When unset, `build` uses a 30-second default.
    ///
    /// # Arguments
    /// * `timeout` - Custom timeout duration
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }
821
    /// Set connection pool configuration
    ///
    /// # Arguments
    /// * `max_idle` - Maximum idle connections per host
    /// * `idle_timeout` - Idle connection timeout duration
    ///
    /// # Returns
    /// * `Self` - Builder instance for method chaining
    pub fn with_pool_config(mut self, max_idle: usize, idle_timeout: std::time::Duration) -> Self {
        self.pool_max_idle = Some(max_idle);
        self.pool_idle_timeout = Some(idle_timeout);
        self
    }
835
836 /// Set custom metrics implementation
837 ///
838 /// # Arguments
839 /// * `metrics` - Custom metrics implementation
840 ///
841 /// # Returns
842 /// * `Self` - Builder instance for method chaining
843 pub fn with_metrics(mut self, metrics: Arc<dyn Metrics>) -> Self {
844 self.metrics = Some(metrics);
845 self
846 }
847
848 /// Build AiClient instance
849 ///
850 /// The build process applies configuration in the following priority order:
851 /// 1. Explicitly set configuration (via with_* methods)
852 /// 2. Environment variable configuration
853 /// 3. Default configuration
854 ///
855 /// # Returns
856 /// * `Result<AiClient, AiLibError>` - Returns client instance on success, error on failure
857 pub fn build(self) -> Result<AiClient, AiLibError> {
858 // 1. Determine base_url: explicit setting > environment variable > default
859 let base_url = self.determine_base_url()?;
860
861 // 2. Determine proxy_url: explicit setting > environment variable
862 let proxy_url = self.determine_proxy_url();
863
864 // 3. Determine timeout: explicit setting > default
865 let timeout = self
866 .timeout
867 .unwrap_or_else(|| std::time::Duration::from_secs(30));
868
869 // 4. Create custom ProviderConfig (if needed)
870 let config = self.create_custom_config(base_url)?;
871
872 // 5. Create custom HttpTransport (if needed)
873 let transport = self.create_custom_transport(proxy_url.clone(), timeout)?;
874
875 // 6. Create adapter
876 let adapter: Box<dyn ChatApi> = match self.provider {
877 // Use config-driven generic adapter
878 Provider::Groq => {
879 if let Some(custom_transport) = transport {
880 Box::new(GenericAdapter::with_transport_ref(
881 config,
882 custom_transport,
883 )?)
884 } else {
885 Box::new(GenericAdapter::new(config)?)
886 }
887 }
888 Provider::XaiGrok => {
889 if let Some(custom_transport) = transport {
890 Box::new(GenericAdapter::with_transport_ref(
891 config,
892 custom_transport,
893 )?)
894 } else {
895 Box::new(GenericAdapter::new(config)?)
896 }
897 }
898 Provider::Ollama => {
899 if let Some(custom_transport) = transport {
900 Box::new(GenericAdapter::with_transport_ref(
901 config,
902 custom_transport,
903 )?)
904 } else {
905 Box::new(GenericAdapter::new(config)?)
906 }
907 }
908 Provider::DeepSeek => {
909 if let Some(custom_transport) = transport {
910 Box::new(GenericAdapter::with_transport_ref(
911 config,
912 custom_transport,
913 )?)
914 } else {
915 Box::new(GenericAdapter::new(config)?)
916 }
917 }
918 Provider::Qwen => {
919 if let Some(custom_transport) = transport {
920 Box::new(GenericAdapter::with_transport_ref(
921 config,
922 custom_transport,
923 )?)
924 } else {
925 Box::new(GenericAdapter::new(config)?)
926 }
927 }
928 Provider::BaiduWenxin => {
929 if let Some(custom_transport) = transport {
930 Box::new(GenericAdapter::with_transport_ref(
931 config,
932 custom_transport,
933 )?)
934 } else {
935 Box::new(GenericAdapter::new(config)?)
936 }
937 }
938 Provider::TencentHunyuan => {
939 if let Some(custom_transport) = transport {
940 Box::new(GenericAdapter::with_transport_ref(
941 config,
942 custom_transport,
943 )?)
944 } else {
945 Box::new(GenericAdapter::new(config)?)
946 }
947 }
948 Provider::IflytekSpark => {
949 if let Some(custom_transport) = transport {
950 Box::new(GenericAdapter::with_transport_ref(
951 config,
952 custom_transport,
953 )?)
954 } else {
955 Box::new(GenericAdapter::new(config)?)
956 }
957 }
958 Provider::Moonshot => {
959 if let Some(custom_transport) = transport {
960 Box::new(GenericAdapter::with_transport_ref(
961 config,
962 custom_transport,
963 )?)
964 } else {
965 Box::new(GenericAdapter::new(config)?)
966 }
967 }
968 Provider::Anthropic => {
969 if let Some(custom_transport) = transport {
970 Box::new(GenericAdapter::with_transport_ref(
971 config,
972 custom_transport,
973 )?)
974 } else {
975 Box::new(GenericAdapter::new(config)?)
976 }
977 }
978 Provider::AzureOpenAI => {
979 if let Some(custom_transport) = transport {
980 Box::new(GenericAdapter::with_transport_ref(
981 config,
982 custom_transport,
983 )?)
984 } else {
985 Box::new(GenericAdapter::new(config)?)
986 }
987 }
988 Provider::HuggingFace => {
989 if let Some(custom_transport) = transport {
990 Box::new(GenericAdapter::with_transport_ref(
991 config,
992 custom_transport,
993 )?)
994 } else {
995 Box::new(GenericAdapter::new(config)?)
996 }
997 }
998 Provider::TogetherAI => {
999 if let Some(custom_transport) = transport {
1000 Box::new(GenericAdapter::with_transport_ref(
1001 config,
1002 custom_transport,
1003 )?)
1004 } else {
1005 Box::new(GenericAdapter::new(config)?)
1006 }
1007 }
1008 // Use independent adapters (these don't support custom configuration)
1009 Provider::OpenAI => Box::new(OpenAiAdapter::new()?),
1010 Provider::Gemini => Box::new(GeminiAdapter::new()?),
1011 Provider::Mistral => Box::new(MistralAdapter::new()?),
1012 Provider::Cohere => Box::new(CohereAdapter::new()?),
1013 };
1014
1015 // 7. Create AiClient
1016 let client = AiClient {
1017 provider: self.provider,
1018 adapter,
1019 metrics: self.metrics.unwrap_or_else(|| Arc::new(NoopMetrics::new())),
1020 connection_options: None,
1021 };
1022
1023 Ok(client)
1024 }
1025
1026 /// Determine base_url, priority: explicit setting > environment variable > default
1027 fn determine_base_url(&self) -> Result<String, AiLibError> {
1028 // 1. Explicitly set base_url
1029 if let Some(ref base_url) = self.base_url {
1030 return Ok(base_url.clone());
1031 }
1032
1033 // 2. base_url from environment variable
1034 let env_var_name = self.get_base_url_env_var_name();
1035 if let Ok(base_url) = std::env::var(&env_var_name) {
1036 return Ok(base_url);
1037 }
1038
1039 // 3. Use default configuration
1040 let default_config = self.get_default_provider_config()?;
1041 Ok(default_config.base_url)
1042 }
1043
1044 /// Determine proxy_url, priority: explicit setting > environment variable
1045 fn determine_proxy_url(&self) -> Option<String> {
1046 // 1. Explicitly set proxy_url
1047 if let Some(ref proxy_url) = self.proxy_url {
1048 // If proxy_url is empty string, it means explicitly no proxy
1049 if proxy_url.is_empty() {
1050 return None;
1051 }
1052 return Some(proxy_url.clone());
1053 }
1054
1055 // 2. AI_PROXY_URL from environment variable
1056 std::env::var("AI_PROXY_URL").ok()
1057 }
1058
1059 /// Get environment variable name for corresponding provider
1060 fn get_base_url_env_var_name(&self) -> String {
1061 match self.provider {
1062 Provider::Groq => "GROQ_BASE_URL".to_string(),
1063 Provider::XaiGrok => "GROK_BASE_URL".to_string(),
1064 Provider::Ollama => "OLLAMA_BASE_URL".to_string(),
1065 Provider::DeepSeek => "DEEPSEEK_BASE_URL".to_string(),
1066 Provider::Qwen => "DASHSCOPE_BASE_URL".to_string(),
1067 Provider::BaiduWenxin => "BAIDU_WENXIN_BASE_URL".to_string(),
1068 Provider::TencentHunyuan => "TENCENT_HUNYUAN_BASE_URL".to_string(),
1069 Provider::IflytekSpark => "IFLYTEK_BASE_URL".to_string(),
1070 Provider::Moonshot => "MOONSHOT_BASE_URL".to_string(),
1071 Provider::Anthropic => "ANTHROPIC_BASE_URL".to_string(),
1072 Provider::AzureOpenAI => "AZURE_OPENAI_BASE_URL".to_string(),
1073 Provider::HuggingFace => "HUGGINGFACE_BASE_URL".to_string(),
1074 Provider::TogetherAI => "TOGETHER_BASE_URL".to_string(),
1075 // These providers don't support custom base_url
1076 Provider::OpenAI | Provider::Gemini | Provider::Mistral | Provider::Cohere => {
1077 "".to_string()
1078 }
1079 }
1080 }
1081
1082 /// Get default provider configuration
1083 fn get_default_provider_config(
1084 &self,
1085 ) -> Result<crate::provider::config::ProviderConfig, AiLibError> {
1086 match self.provider {
1087 Provider::Groq => Ok(ProviderConfigs::groq()),
1088 Provider::XaiGrok => Ok(ProviderConfigs::xai_grok()),
1089 Provider::Ollama => Ok(ProviderConfigs::ollama()),
1090 Provider::DeepSeek => Ok(ProviderConfigs::deepseek()),
1091 Provider::Qwen => Ok(ProviderConfigs::qwen()),
1092 Provider::BaiduWenxin => Ok(ProviderConfigs::baidu_wenxin()),
1093 Provider::TencentHunyuan => Ok(ProviderConfigs::tencent_hunyuan()),
1094 Provider::IflytekSpark => Ok(ProviderConfigs::iflytek_spark()),
1095 Provider::Moonshot => Ok(ProviderConfigs::moonshot()),
1096 Provider::Anthropic => Ok(ProviderConfigs::anthropic()),
1097 Provider::AzureOpenAI => Ok(ProviderConfigs::azure_openai()),
1098 Provider::HuggingFace => Ok(ProviderConfigs::huggingface()),
1099 Provider::TogetherAI => Ok(ProviderConfigs::together_ai()),
1100 // These providers don't support custom configuration
1101 Provider::OpenAI | Provider::Gemini | Provider::Mistral | Provider::Cohere => {
1102 Err(AiLibError::ConfigurationError(
1103 "This provider does not support custom configuration".to_string(),
1104 ))
1105 }
1106 }
1107 }
1108
1109 /// Create custom ProviderConfig
1110 fn create_custom_config(
1111 &self,
1112 base_url: String,
1113 ) -> Result<crate::provider::config::ProviderConfig, AiLibError> {
1114 let mut config = self.get_default_provider_config()?;
1115 config.base_url = base_url;
1116 Ok(config)
1117 }
1118
1119 /// Create custom HttpTransport
1120 fn create_custom_transport(
1121 &self,
1122 proxy_url: Option<String>,
1123 timeout: std::time::Duration,
1124 ) -> Result<Option<crate::transport::DynHttpTransportRef>, AiLibError> {
1125 // If no custom configuration, return None (use default transport)
1126 if proxy_url.is_none() && self.pool_max_idle.is_none() && self.pool_idle_timeout.is_none() {
1127 return Ok(None);
1128 }
1129
1130 // Create custom HttpTransportConfig
1131 let transport_config = crate::transport::HttpTransportConfig {
1132 timeout,
1133 proxy: proxy_url,
1134 pool_max_idle_per_host: self.pool_max_idle,
1135 pool_idle_timeout: self.pool_idle_timeout,
1136 };
1137
1138 // Create custom HttpTransport
1139 let transport = crate::transport::HttpTransport::new_with_config(transport_config)?;
1140 Ok(Some(transport.boxed()))
1141 }
1142}
1143
/// Controllable streaming response
///
/// Wraps a provider chunk stream together with a oneshot receiver so the
/// consumer can cancel the stream while it is in flight.
struct ControlledStream {
    // The underlying provider stream of chat completion chunks.
    inner: Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
    // Cancellation signal; taken (set to `None`) once it has resolved.
    cancel_rx: Option<oneshot::Receiver<()>>,
}
1149
1150impl ControlledStream {
1151 fn new(
1152 inner: Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
1153 cancel_rx: oneshot::Receiver<()>,
1154 ) -> Self {
1155 Self {
1156 inner,
1157 cancel_rx: Some(cancel_rx),
1158 }
1159 }
1160}
1161
1162impl Stream for ControlledStream {
1163 type Item = Result<ChatCompletionChunk, AiLibError>;
1164
1165 fn poll_next(
1166 mut self: std::pin::Pin<&mut Self>,
1167 cx: &mut std::task::Context<'_>,
1168 ) -> std::task::Poll<Option<Self::Item>> {
1169 use futures::stream::StreamExt;
1170 use std::task::Poll;
1171
1172 // Check if cancelled
1173 if let Some(ref mut cancel_rx) = self.cancel_rx {
1174 match Future::poll(std::pin::Pin::new(cancel_rx), cx) {
1175 Poll::Ready(_) => {
1176 self.cancel_rx = None;
1177 return Poll::Ready(Some(Err(AiLibError::ProviderError(
1178 "Stream cancelled".to_string(),
1179 ))));
1180 }
1181 Poll::Pending => {}
1182 }
1183 }
1184
1185 // Poll inner stream
1186 self.inner.poll_next_unpin(cx)
1187 }
1188}