ai_lib/
client.rs

1use crate::api::{ChatApi, ChatCompletionChunk};
2use crate::provider::{GeminiAdapter, GenericAdapter, OpenAiAdapter, ProviderConfigs};
3use crate::types::{AiLibError, ChatCompletionRequest, ChatCompletionResponse};
4use futures::stream::Stream;
5use futures::Future;
6use tokio::sync::oneshot;
7
/// Supported AI model providers.
///
/// `Groq`, `DeepSeek` and `Anthropic` are served by the config-driven
/// `GenericAdapter`; `OpenAI` and `Gemini` use dedicated adapters
/// (see `AiClient::new`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Provider {
    Groq,
    OpenAI,
    DeepSeek,
    Gemini,
    Anthropic,
}
19
/// Unified AI client.
///
/// Usage example:
/// ```rust
/// use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Role};
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
///     // To switch model providers, just change the Provider value.
///     let client = AiClient::new(Provider::Groq)?;
///     
///     let request = ChatCompletionRequest::new(
///         "test-model".to_string(),
///         vec![Message {
///             role: Role::User,
///             content: "Hello".to_string(),
///         }],
///     );
///     
///     // NOTE: the GROQ_API_KEY environment variable must be set to actually call the API.
///     // Optional: set the AI_PROXY_URL environment variable to use a proxy server.
///     // let response = client.chat_completion(request).await?;
///     
///     println!("Client created successfully with provider: {:?}", client.current_provider());
///     println!("Request prepared for model: {}", request.model);
///     
///     Ok(())
/// }
/// ```
///
/// # Proxy server configuration
///
/// Configure a proxy server by setting the `AI_PROXY_URL` environment variable:
///
/// ```bash
/// export AI_PROXY_URL=http://proxy.example.com:8080
/// ```
///
/// Supported proxy formats:
/// - HTTP proxy: `http://proxy.example.com:8080`
/// - HTTPS proxy: `https://proxy.example.com:8080`
/// - With authentication: `http://user:pass@proxy.example.com:8080`
pub struct AiClient {
    // Currently selected provider; reported by `current_provider()`.
    provider: Provider,
    // Provider-specific adapter behind the common `ChatApi` trait.
    adapter: Box<dyn ChatApi>,
}
66
67impl AiClient {
68    /// 创建新的AI客户端
69    ///
70    /// # Arguments
71    /// * `provider` - 选择要使用的AI模型提供商
72    ///
73    /// # Returns
74    /// * `Result<Self, AiLibError>` - 成功时返回客户端实例,失败时返回错误
75    ///
76    /// # Example
77    /// ```rust
78    /// use ai_lib::{AiClient, Provider};
79    ///
80    /// let client = AiClient::new(Provider::Groq)?;
81    /// # Ok::<(), ai_lib::AiLibError>(())
82    /// ```
83    pub fn new(provider: Provider) -> Result<Self, AiLibError> {
84        let adapter: Box<dyn ChatApi> = match provider {
85            // 使用配置驱动的通用适配器
86            Provider::Groq => Box::new(GenericAdapter::new(ProviderConfigs::groq())?),
87            Provider::DeepSeek => Box::new(GenericAdapter::new(ProviderConfigs::deepseek())?),
88            Provider::Anthropic => Box::new(GenericAdapter::new(ProviderConfigs::anthropic())?),
89            // 使用独立适配器
90            Provider::OpenAI => Box::new(OpenAiAdapter::new()?),
91            Provider::Gemini => Box::new(GeminiAdapter::new()?),
92        };
93
94        Ok(Self { provider, adapter })
95    }
96
97    /// 发送聊天完成请求
98    ///
99    /// # Arguments
100    /// * `request` - 聊天完成请求
101    ///
102    /// # Returns
103    /// * `Result<ChatCompletionResponse, AiLibError>` - 成功时返回响应,失败时返回错误
104    pub async fn chat_completion(
105        &self,
106        request: ChatCompletionRequest,
107    ) -> Result<ChatCompletionResponse, AiLibError> {
108        self.adapter.chat_completion(request).await
109    }
110
111    /// 流式聊天完成请求
112    ///
113    /// # Arguments
114    /// * `request` - 聊天完成请求
115    ///
116    /// # Returns
117    /// * `Result<impl Stream<Item = Result<ChatCompletionChunk, AiLibError>>, AiLibError>` - 成功时返回流式响应
118    pub async fn chat_completion_stream(
119        &self,
120        mut request: ChatCompletionRequest,
121    ) -> Result<
122        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
123        AiLibError,
124    > {
125        request.stream = Some(true);
126        self.adapter.chat_completion_stream(request).await
127    }
128
129    /// 带取消控制的流式聊天完成请求
130    ///
131    /// # Arguments
132    /// * `request` - 聊天完成请求
133    ///
134    /// # Returns
135    /// * `(Stream, CancelHandle)` - 流式响应和取消句柄
136    pub async fn chat_completion_stream_with_cancel(
137        &self,
138        request: ChatCompletionRequest,
139    ) -> Result<
140        (
141            Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
142            CancelHandle,
143        ),
144        AiLibError,
145    > {
146        let (cancel_tx, cancel_rx) = oneshot::channel();
147        let stream = self.chat_completion_stream(request).await?;
148
149        let cancel_handle = CancelHandle {
150            sender: Some(cancel_tx),
151        };
152        let controlled_stream = ControlledStream::new(stream, cancel_rx);
153
154        Ok((Box::new(Box::pin(controlled_stream)), cancel_handle))
155    }
156
157    /// 获取支持的模型列表
158    ///
159    /// # Returns
160    /// * `Result<Vec<String>, AiLibError>` - 成功时返回模型列表,失败时返回错误
161    pub async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
162        self.adapter.list_models().await
163    }
164
165    /// 切换AI模型提供商
166    ///
167    /// # Arguments
168    /// * `provider` - 新的提供商
169    ///
170    /// # Returns
171    /// * `Result<(), AiLibError>` - 成功时返回(),失败时返回错误
172    ///
173    /// # Example
174    /// ```rust
175    /// use ai_lib::{AiClient, Provider};
176    ///
177    /// let mut client = AiClient::new(Provider::Groq)?;
178    /// // 从Groq切换到Groq(演示切换功能)
179    /// client.switch_provider(Provider::Groq)?;
180    /// # Ok::<(), ai_lib::AiLibError>(())
181    /// ```
182    pub fn switch_provider(&mut self, provider: Provider) -> Result<(), AiLibError> {
183        let new_adapter: Box<dyn ChatApi> = match provider {
184            Provider::Groq => Box::new(GenericAdapter::new(ProviderConfigs::groq())?),
185            Provider::DeepSeek => Box::new(GenericAdapter::new(ProviderConfigs::deepseek())?),
186            Provider::OpenAI => Box::new(OpenAiAdapter::new()?),
187            Provider::Anthropic => Box::new(GenericAdapter::new(ProviderConfigs::anthropic())?),
188            Provider::Gemini => Box::new(GeminiAdapter::new()?),
189        };
190
191        self.provider = provider;
192        self.adapter = new_adapter;
193        Ok(())
194    }
195
196    /// 获取当前使用的提供商
197    pub fn current_provider(&self) -> Provider {
198        self.provider
199    }
200}
201
/// Cancellation handle for a streaming response.
pub struct CancelHandle {
    // One-shot sender; taken on `cancel()` so the signal fires at most once.
    sender: Option<oneshot::Sender<()>>,
}
206
207impl CancelHandle {
208    /// 取消流式响应
209    pub fn cancel(mut self) {
210        if let Some(sender) = self.sender.take() {
211            let _ = sender.send(());
212        }
213    }
214}
215
/// Stream wrapper that can be interrupted via a `CancelHandle`.
struct ControlledStream {
    // Underlying provider stream whose items are forwarded.
    inner: Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
    // Cancellation signal; set to `None` once it has resolved.
    cancel_rx: Option<oneshot::Receiver<()>>,
}
221
222impl ControlledStream {
223    fn new(
224        inner: Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
225        cancel_rx: oneshot::Receiver<()>,
226    ) -> Self {
227        Self {
228            inner,
229            cancel_rx: Some(cancel_rx),
230        }
231    }
232}
233
234impl Stream for ControlledStream {
235    type Item = Result<ChatCompletionChunk, AiLibError>;
236
237    fn poll_next(
238        mut self: std::pin::Pin<&mut Self>,
239        cx: &mut std::task::Context<'_>,
240    ) -> std::task::Poll<Option<Self::Item>> {
241        use futures::stream::StreamExt;
242        use std::task::Poll;
243
244        // 检查是否被取消
245        if let Some(ref mut cancel_rx) = self.cancel_rx {
246            match Future::poll(std::pin::Pin::new(cancel_rx), cx) {
247                Poll::Ready(_) => {
248                    self.cancel_rx = None;
249                    return Poll::Ready(Some(Err(AiLibError::ProviderError(
250                        "Stream cancelled".to_string(),
251                    ))));
252                }
253                Poll::Pending => {}
254            }
255        }
256
257        // 轮询内部流
258        self.inner.poll_next_unpin(cx)
259    }
260}