ai_lib/
client.rs

1use crate::api::{ChatApi, ChatCompletionChunk};
2use crate::metrics::{Metrics, NoopMetrics};
3use crate::provider::{
4    CohereAdapter, GeminiAdapter, GenericAdapter, MistralAdapter, OpenAiAdapter, ProviderConfigs,
5};
6use crate::types::{AiLibError, ChatCompletionRequest, ChatCompletionResponse};
7use futures::stream::Stream;
8use futures::Future;
9use std::sync::Arc;
10use tokio::sync::oneshot;
11
12/// AI模型提供商枚举
/// Supported AI model providers.
///
/// Most variants are served by the config-driven `GenericAdapter`; the ones
/// listed under "dedicated adapters" have bespoke implementations.
/// `PartialEq`/`Eq`/`Hash` are derived so providers can be compared and used
/// as map keys (fieldless `Copy` enum, so the derives are free).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Provider {
    // Config-driven providers (OpenAI-compatible wire format)
    Groq,
    XaiGrok,
    Ollama,
    DeepSeek,
    Anthropic,
    AzureOpenAI,
    HuggingFace,
    TogetherAI,
    // China-region providers (OpenAI-compatible or config-driven)
    BaiduWenxin,
    TencentHunyuan,
    IflytekSpark,
    Moonshot,
    // Dedicated adapters
    OpenAI,
    Qwen,
    Gemini,
    Mistral,
    Cohere,
    // Bedrock removed (deferred)
}
37
/// Unified AI client.
///
/// Example:
/// ```rust
/// use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Role};
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
///     // To switch model providers, just change the Provider value.
///     let client = AiClient::new(Provider::Groq)?;
///     
///     let request = ChatCompletionRequest::new(
///         "test-model".to_string(),
///         vec![Message {
///             role: Role::User,
///             content: ai_lib::types::common::Content::Text("Hello".to_string()),
///             function_call: None,
///         }],
///     );
///     
///     // NOTE: the GROQ_API_KEY environment variable must be set to actually call the API.
///     // Optional: set the AI_PROXY_URL environment variable to use a proxy server.
///     // let response = client.chat_completion(request).await?;
///     
///     println!("Client created successfully with provider: {:?}", client.current_provider());
///     println!("Request prepared for model: {}", request.model);
///     
///     Ok(())
/// }
/// ```
///
/// # Proxy server configuration
///
/// Configure a proxy server via the `AI_PROXY_URL` environment variable:
///
/// ```bash
/// export AI_PROXY_URL=http://proxy.example.com:8080
/// ```
///
/// Supported proxy formats:
/// - HTTP proxy: `http://proxy.example.com:8080`
/// - HTTPS proxy: `https://proxy.example.com:8080`  
/// - With authentication: `http://user:pass@proxy.example.com:8080`
pub struct AiClient {
    // Which provider this client currently targets (see `current_provider`).
    provider: Provider,
    // Transport/protocol adapter implementing the provider-specific API calls.
    adapter: Box<dyn ChatApi>,
    // Metrics sink; defaults to `NoopMetrics` unless injected by the caller.
    metrics: Arc<dyn Metrics>,
}
86
87impl AiClient {
88    /// 创建新的AI客户端
89    ///
90    /// # Arguments
91    /// * `provider` - 选择要使用的AI模型提供商
92    ///
93    /// # Returns
94    /// * `Result<Self, AiLibError>` - 成功时返回客户端实例,失败时返回错误
95    ///
96    /// # Example
97    /// ```rust
98    /// use ai_lib::{AiClient, Provider};
99    ///
100    /// let client = AiClient::new(Provider::Groq)?;
101    /// # Ok::<(), ai_lib::AiLibError>(())
102    /// ```
103    pub fn new(provider: Provider) -> Result<Self, AiLibError> {
104        let adapter: Box<dyn ChatApi> = match provider {
105            // 使用配置驱动的通用适配器
106            Provider::Groq => Box::new(GenericAdapter::new(ProviderConfigs::groq())?),
107            Provider::XaiGrok => Box::new(GenericAdapter::new(ProviderConfigs::xai_grok())?),
108            Provider::Ollama => Box::new(GenericAdapter::new(ProviderConfigs::ollama())?),
109            Provider::DeepSeek => Box::new(GenericAdapter::new(ProviderConfigs::deepseek())?),
110            Provider::Qwen => Box::new(GenericAdapter::new(ProviderConfigs::qwen())?),
111            Provider::BaiduWenxin => Box::new(GenericAdapter::new(ProviderConfigs::baidu_wenxin())?),
112            Provider::TencentHunyuan => Box::new(GenericAdapter::new(ProviderConfigs::tencent_hunyuan())?),
113            Provider::IflytekSpark => Box::new(GenericAdapter::new(ProviderConfigs::iflytek_spark())?),
114            Provider::Moonshot => Box::new(GenericAdapter::new(ProviderConfigs::moonshot())?),
115            Provider::Anthropic => Box::new(GenericAdapter::new(ProviderConfigs::anthropic())?),
116            Provider::AzureOpenAI => {
117                Box::new(GenericAdapter::new(ProviderConfigs::azure_openai())?)
118            }
119            Provider::HuggingFace => Box::new(GenericAdapter::new(ProviderConfigs::huggingface())?),
120            Provider::TogetherAI => Box::new(GenericAdapter::new(ProviderConfigs::together_ai())?),
121            // 使用独立适配器
122            Provider::OpenAI => Box::new(OpenAiAdapter::new()?),
123            Provider::Gemini => Box::new(GeminiAdapter::new()?),
124            Provider::Mistral => Box::new(MistralAdapter::new()?),
125            Provider::Cohere => Box::new(CohereAdapter::new()?),
126            // Bedrock deferred; not available
127        };
128
129        Ok(Self {
130            provider,
131            adapter,
132            metrics: Arc::new(NoopMetrics::new()),
133        })
134    }
135
136    /// Create AiClient with injected metrics implementation
137    pub fn new_with_metrics(
138        provider: Provider,
139        metrics: Arc<dyn Metrics>,
140    ) -> Result<Self, AiLibError> {
141        let adapter: Box<dyn ChatApi> = match provider {
142            Provider::Groq => Box::new(GenericAdapter::new(ProviderConfigs::groq())?),
143            Provider::XaiGrok => Box::new(GenericAdapter::new(ProviderConfigs::xai_grok())?),
144            Provider::Ollama => Box::new(GenericAdapter::new(ProviderConfigs::ollama())?),
145            Provider::DeepSeek => Box::new(GenericAdapter::new(ProviderConfigs::deepseek())?),
146            Provider::Qwen => Box::new(GenericAdapter::new(ProviderConfigs::qwen())?),
147            Provider::Anthropic => Box::new(GenericAdapter::new(ProviderConfigs::anthropic())?),
148            Provider::BaiduWenxin => Box::new(GenericAdapter::new(ProviderConfigs::baidu_wenxin())?),
149            Provider::TencentHunyuan => Box::new(GenericAdapter::new(ProviderConfigs::tencent_hunyuan())?),
150            Provider::IflytekSpark => Box::new(GenericAdapter::new(ProviderConfigs::iflytek_spark())?),
151            Provider::Moonshot => Box::new(GenericAdapter::new(ProviderConfigs::moonshot())?),
152            Provider::AzureOpenAI => {
153                Box::new(GenericAdapter::new(ProviderConfigs::azure_openai())?)
154            }
155            Provider::HuggingFace => Box::new(GenericAdapter::new(ProviderConfigs::huggingface())?),
156            Provider::TogetherAI => Box::new(GenericAdapter::new(ProviderConfigs::together_ai())?),
157            Provider::OpenAI => Box::new(OpenAiAdapter::new()?),
158            Provider::Gemini => Box::new(GeminiAdapter::new()?),
159            Provider::Mistral => Box::new(MistralAdapter::new()?),
160            Provider::Cohere => Box::new(CohereAdapter::new()?),
161        };
162
163        Ok(Self {
164            provider,
165            adapter,
166            metrics,
167        })
168    }
169
170    /// Set metrics implementation on client
171    pub fn with_metrics(mut self, metrics: Arc<dyn Metrics>) -> Self {
172        self.metrics = metrics;
173        self
174    }
175
176    /// 发送聊天完成请求
177    ///
178    /// # Arguments
179    /// * `request` - 聊天完成请求
180    ///
181    /// # Returns
182    /// * `Result<ChatCompletionResponse, AiLibError>` - 成功时返回响应,失败时返回错误
183    pub async fn chat_completion(
184        &self,
185        request: ChatCompletionRequest,
186    ) -> Result<ChatCompletionResponse, AiLibError> {
187        self.adapter.chat_completion(request).await
188    }
189
190    /// 流式聊天完成请求
191    ///
192    /// # Arguments
193    /// * `request` - 聊天完成请求
194    ///
195    /// # Returns
196    /// * `Result<impl Stream<Item = Result<ChatCompletionChunk, AiLibError>>, AiLibError>` - 成功时返回流式响应
197    pub async fn chat_completion_stream(
198        &self,
199        mut request: ChatCompletionRequest,
200    ) -> Result<
201        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
202        AiLibError,
203    > {
204        request.stream = Some(true);
205        self.adapter.chat_completion_stream(request).await
206    }
207
208    /// 带取消控制的流式聊天完成请求
209    ///
210    /// # Arguments
211    /// * `request` - 聊天完成请求
212    ///
213    /// # Returns
214    /// * `Result<(impl Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin, CancelHandle), AiLibError>` - 成功时返回流式响应和取消句柄
215    pub async fn chat_completion_stream_with_cancel(
216        &self,
217        mut request: ChatCompletionRequest,
218    ) -> Result<
219        (
220            Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
221            CancelHandle,
222        ),
223        AiLibError,
224    > {
225        request.stream = Some(true);
226        let stream = self.adapter.chat_completion_stream(request).await?;
227        let (cancel_tx, cancel_rx) = oneshot::channel();
228        let cancel_handle = CancelHandle {
229            sender: Some(cancel_tx),
230        };
231        
232        let controlled_stream = ControlledStream::new(stream, cancel_rx);
233        Ok((Box::new(controlled_stream), cancel_handle))
234    }
235
236    /// 批量聊天完成请求
237    ///
238    /// # Arguments
239    /// * `requests` - 聊天完成请求列表
240    /// * `concurrency_limit` - 最大并发请求数(None表示无限制)
241    ///
242    /// # Returns
243    /// * `Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError>` - 返回所有请求的响应结果
244    ///
245    /// # Example
246    /// ```rust
247    /// use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Role};
248    /// use ai_lib::types::common::Content;
249    ///
250    /// #[tokio::main]
251    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
252    ///     let client = AiClient::new(Provider::Groq)?;
253    ///     
254    ///     let requests = vec![
255    ///         ChatCompletionRequest::new(
256    ///             "llama3-8b-8192".to_string(),
257    ///             vec![Message {
258    ///                 role: Role::User,
259    ///                 content: Content::Text("Hello".to_string()),
260    ///                 function_call: None,
261    ///             }],
262    ///         ),
263    ///         ChatCompletionRequest::new(
264    ///             "llama3-8b-8192".to_string(),
265    ///             vec![Message {
266    ///                 role: Role::User,
267    ///                 content: Content::Text("How are you?".to_string()),
268    ///                 function_call: None,
269    ///             }],
270    ///         ),
271    ///     ];
272    ///     
273    ///     // 限制并发数为5
274    ///     let responses = client.chat_completion_batch(requests, Some(5)).await?;
275    ///     
276    ///     for (i, response) in responses.iter().enumerate() {
277    ///         match response {
278    ///             Ok(resp) => println!("Request {}: {}", i, resp.choices[0].message.content.as_text()),
279    ///             Err(e) => println!("Request {} failed: {}", i, e),
280    ///         }
281    ///     }
282    ///     
283    ///     Ok(())
284    /// }
285    /// ```
286    pub async fn chat_completion_batch(
287        &self,
288        requests: Vec<ChatCompletionRequest>,
289        concurrency_limit: Option<usize>,
290    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
291        self.adapter.chat_completion_batch(requests, concurrency_limit).await
292    }
293
294    /// 智能批量处理:根据请求数量自动选择处理策略
295    ///
296    /// # Arguments
297    /// * `requests` - 聊天完成请求列表
298    ///
299    /// # Returns
300    /// * `Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError>` - 返回所有请求的响应结果
301    pub async fn chat_completion_batch_smart(
302        &self,
303        requests: Vec<ChatCompletionRequest>,
304    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
305        // 小批量使用顺序处理,大批量使用并发处理
306        let concurrency_limit = if requests.len() <= 3 { None } else { Some(10) };
307        self.chat_completion_batch(requests, concurrency_limit).await
308    }
309
310    /// 批量聊天完成请求
311    ///
312    /// # Arguments
313    /// * `requests` - 聊天完成请求列表
314    /// * `concurrency_limit` - 最大并发请求数(None表示无限制)
315    ///
316    /// # Returns
317    /// * `Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError>` - 返回所有请求的响应结果
318    ///
319    /// # Example
320    /// ```rust
321    /// use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Role};
322    /// use ai_lib::types::common::Content;
323    ///
324    /// #[tokio::main]
325    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
326    ///     let client = AiClient::new(Provider::Groq)?;
327    ///     
328    ///     let requests = vec![
329    ///         ChatCompletionRequest::new(
330    ///             "llama3-8b-8192".to_string(),
331    ///             vec![Message {
332    ///                 role: Role::User,
333    ///                 content: Content::Text("Hello".to_string()),
334    ///                 function_call: None,
335    ///             }],
336    ///         ),
337    ///         ChatCompletionRequest::new(
338    ///             "llama3-8b-8192".to_string(),
339    ///             vec![Message {
340    ///                 role: Role::User,
341    ///                 content: Content::Text("How are you?".to_string()),
342    ///                 function_call: None,
343    ///             }],
344    ///         ),
345    ///     ];
346    ///     
347    ///     // 限制并发数为5
348    ///     let responses = client.chat_completion_batch(requests, Some(5)).await?;
349    ///     
350    ///     for (i, response) in responses.iter().enumerate() {
351    ///         match response {
352    ///             Ok(resp) => println!("Request {}: {}", i, resp.choices[0].message.content.as_text()),
353    ///             Err(e) => println!("Request {} failed: {}", i, e),
354    ///         }
355    ///     }
356    ///     
357    ///     Ok(())
358    /// }
359    /// 获取支持的模型列表
360    ///
361    /// # Returns
362    /// * `Result<Vec<String>, AiLibError>` - 成功时返回模型列表,失败时返回错误
363    pub async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
364        self.adapter.list_models().await
365    }
366
367    /// 切换AI模型提供商
368    ///
369    /// # Arguments
370    /// * `provider` - 新的提供商
371    ///
372    /// # Returns
373    /// * `Result<(), AiLibError>` - 成功时返回(),失败时返回错误
374    ///
375    /// # Example
376    /// ```rust
377    /// use ai_lib::{AiClient, Provider};
378    ///
379    /// let mut client = AiClient::new(Provider::Groq)?;
380    /// // 从Groq切换到Groq(演示切换功能)
381    /// client.switch_provider(Provider::Groq)?;
382    /// # Ok::<(), ai_lib::AiLibError>(())
383    /// ```
384    pub fn switch_provider(&mut self, provider: Provider) -> Result<(), AiLibError> {
385        let new_adapter: Box<dyn ChatApi> = match provider {
386            Provider::Groq => Box::new(GenericAdapter::new(ProviderConfigs::groq())?),
387            Provider::XaiGrok => Box::new(GenericAdapter::new(ProviderConfigs::xai_grok())?),
388            Provider::Ollama => Box::new(GenericAdapter::new(ProviderConfigs::ollama())?),
389            Provider::DeepSeek => Box::new(GenericAdapter::new(ProviderConfigs::deepseek())?),
390            Provider::Qwen => Box::new(GenericAdapter::new(ProviderConfigs::qwen())?),
391            Provider::OpenAI => Box::new(OpenAiAdapter::new()?),
392            Provider::Anthropic => Box::new(GenericAdapter::new(ProviderConfigs::anthropic())?),
393            Provider::BaiduWenxin => Box::new(GenericAdapter::new(ProviderConfigs::baidu_wenxin())?),
394            Provider::TencentHunyuan => Box::new(GenericAdapter::new(ProviderConfigs::tencent_hunyuan())?),
395            Provider::IflytekSpark => Box::new(GenericAdapter::new(ProviderConfigs::iflytek_spark())?),
396            Provider::Moonshot => Box::new(GenericAdapter::new(ProviderConfigs::moonshot())?),
397            Provider::Gemini => Box::new(GeminiAdapter::new()?),
398            Provider::AzureOpenAI => {
399                Box::new(GenericAdapter::new(ProviderConfigs::azure_openai())?)
400            }
401            Provider::HuggingFace => Box::new(GenericAdapter::new(ProviderConfigs::huggingface())?),
402            Provider::TogetherAI => Box::new(GenericAdapter::new(ProviderConfigs::together_ai())?),
403            Provider::Mistral => Box::new(MistralAdapter::new()?),
404            Provider::Cohere => Box::new(CohereAdapter::new()?),
405            // Provider::Bedrock => Box::new(BedrockAdapter::new()?),
406        };
407
408        self.provider = provider;
409        self.adapter = new_adapter;
410        Ok(())
411    }
412
413    /// 获取当前使用的提供商
414    pub fn current_provider(&self) -> Provider {
415        self.provider
416    }
417}
418
419/// 流式响应取消句柄
420pub struct CancelHandle {
421    sender: Option<oneshot::Sender<()>>,
422}
423
424impl CancelHandle {
425    /// 取消流式响应
426    pub fn cancel(mut self) {
427        if let Some(sender) = self.sender.take() {
428            let _ = sender.send(());
429        }
430    }
431}
432
433/// 可控制的流式响应
434struct ControlledStream {
435    inner: Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
436    cancel_rx: Option<oneshot::Receiver<()>>,
437}
438
439impl ControlledStream {
440    fn new(
441        inner: Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
442        cancel_rx: oneshot::Receiver<()>,
443    ) -> Self {
444        Self {
445            inner,
446            cancel_rx: Some(cancel_rx),
447        }
448    }
449}
450
451impl Stream for ControlledStream {
452    type Item = Result<ChatCompletionChunk, AiLibError>;
453
454    fn poll_next(
455        mut self: std::pin::Pin<&mut Self>,
456        cx: &mut std::task::Context<'_>,
457    ) -> std::task::Poll<Option<Self::Item>> {
458        use futures::stream::StreamExt;
459        use std::task::Poll;
460
461        // 检查是否被取消
462        if let Some(ref mut cancel_rx) = self.cancel_rx {
463            match Future::poll(std::pin::Pin::new(cancel_rx), cx) {
464                Poll::Ready(_) => {
465                    self.cancel_rx = None;
466                    return Poll::Ready(Some(Err(AiLibError::ProviderError(
467                        "Stream cancelled".to_string(),
468                    ))));
469                }
470                Poll::Pending => {}
471            }
472        }
473
474        // 轮询内部流
475        self.inner.poll_next_unpin(cx)
476    }
477}