// ai_lib/client.rs
1use crate::api::{ChatApi, ChatCompletionChunk};
2use crate::metrics::{Metrics, NoopMetrics};
3use crate::provider::{
4 CohereAdapter, GeminiAdapter, GenericAdapter, MistralAdapter, OpenAiAdapter, ProviderConfigs,
5};
6use crate::types::{AiLibError, ChatCompletionRequest, ChatCompletionResponse};
7use futures::stream::Stream;
8use futures::Future;
9use std::sync::Arc;
10use tokio::sync::oneshot;
11
/// Supported AI model providers.
///
/// `PartialEq`/`Eq`/`Hash` are derived so providers can be compared and
/// used as `HashMap`/`HashSet` keys (e.g. per-provider caches or metrics).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Provider {
    // Config-driven providers (served by the generic adapter).
    Groq,
    XaiGrok,
    Ollama,
    DeepSeek,
    Anthropic,
    AzureOpenAI,
    HuggingFace,
    TogetherAI,
    // Chinese-region providers (OpenAI-compatible or config-driven).
    BaiduWenxin,
    TencentHunyuan,
    IflytekSpark,
    Moonshot,
    // Providers with dedicated adapters.
    OpenAI,
    Qwen,
    Gemini,
    Mistral,
    Cohere,
    // Bedrock removed (deferred)
}
37
38/// 统一AI客户端
39///
40/// 使用示例:
41/// ```rust
42/// use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Role};
43///
44/// #[tokio::main]
45/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
46/// // 切换模型提供商,只需更改 Provider 的值
47/// let client = AiClient::new(Provider::Groq)?;
48///
49/// let request = ChatCompletionRequest::new(
50/// "test-model".to_string(),
51/// vec![Message {
52/// role: Role::User,
53/// content: ai_lib::types::common::Content::Text("Hello".to_string()),
54/// function_call: None,
55/// }],
56/// );
57///
58/// // 注意:这里需要设置GROQ_API_KEY环境变量才能实际调用API
59/// // 可选:设置AI_PROXY_URL环境变量使用代理服务器
60/// // let response = client.chat_completion(request).await?;
61///
62/// println!("Client created successfully with provider: {:?}", client.current_provider());
63/// println!("Request prepared for model: {}", request.model);
64///
65/// Ok(())
66/// }
67/// ```
68///
69/// # 代理服务器配置
70///
71/// 通过设置 `AI_PROXY_URL` 环境变量来配置代理服务器:
72///
73/// ```bash
74/// export AI_PROXY_URL=http://proxy.example.com:8080
75/// ```
76///
77/// 支持的代理格式:
78/// - HTTP代理: `http://proxy.example.com:8080`
79/// - HTTPS代理: `https://proxy.example.com:8080`
80/// - 带认证: `http://user:pass@proxy.example.com:8080`
pub struct AiClient {
    // Currently selected provider; reported by `current_provider()` and
    // replaced by `switch_provider()`.
    provider: Provider,
    // Provider-specific transport adapter all requests are delegated to.
    adapter: Box<dyn ChatApi>,
    // Metrics sink; defaults to a no-op implementation, injectable via
    // `new_with_metrics()` or `with_metrics()`.
    metrics: Arc<dyn Metrics>,
}
86
87impl AiClient {
88 /// 创建新的AI客户端
89 ///
90 /// # Arguments
91 /// * `provider` - 选择要使用的AI模型提供商
92 ///
93 /// # Returns
94 /// * `Result<Self, AiLibError>` - 成功时返回客户端实例,失败时返回错误
95 ///
96 /// # Example
97 /// ```rust
98 /// use ai_lib::{AiClient, Provider};
99 ///
100 /// let client = AiClient::new(Provider::Groq)?;
101 /// # Ok::<(), ai_lib::AiLibError>(())
102 /// ```
103 pub fn new(provider: Provider) -> Result<Self, AiLibError> {
104 let adapter: Box<dyn ChatApi> = match provider {
105 // 使用配置驱动的通用适配器
106 Provider::Groq => Box::new(GenericAdapter::new(ProviderConfigs::groq())?),
107 Provider::XaiGrok => Box::new(GenericAdapter::new(ProviderConfigs::xai_grok())?),
108 Provider::Ollama => Box::new(GenericAdapter::new(ProviderConfigs::ollama())?),
109 Provider::DeepSeek => Box::new(GenericAdapter::new(ProviderConfigs::deepseek())?),
110 Provider::Qwen => Box::new(GenericAdapter::new(ProviderConfigs::qwen())?),
111 Provider::BaiduWenxin => Box::new(GenericAdapter::new(ProviderConfigs::baidu_wenxin())?),
112 Provider::TencentHunyuan => Box::new(GenericAdapter::new(ProviderConfigs::tencent_hunyuan())?),
113 Provider::IflytekSpark => Box::new(GenericAdapter::new(ProviderConfigs::iflytek_spark())?),
114 Provider::Moonshot => Box::new(GenericAdapter::new(ProviderConfigs::moonshot())?),
115 Provider::Anthropic => Box::new(GenericAdapter::new(ProviderConfigs::anthropic())?),
116 Provider::AzureOpenAI => {
117 Box::new(GenericAdapter::new(ProviderConfigs::azure_openai())?)
118 }
119 Provider::HuggingFace => Box::new(GenericAdapter::new(ProviderConfigs::huggingface())?),
120 Provider::TogetherAI => Box::new(GenericAdapter::new(ProviderConfigs::together_ai())?),
121 // 使用独立适配器
122 Provider::OpenAI => Box::new(OpenAiAdapter::new()?),
123 Provider::Gemini => Box::new(GeminiAdapter::new()?),
124 Provider::Mistral => Box::new(MistralAdapter::new()?),
125 Provider::Cohere => Box::new(CohereAdapter::new()?),
126 // Bedrock deferred; not available
127 };
128
129 Ok(Self {
130 provider,
131 adapter,
132 metrics: Arc::new(NoopMetrics::new()),
133 })
134 }
135
136 /// Create AiClient with injected metrics implementation
137 pub fn new_with_metrics(
138 provider: Provider,
139 metrics: Arc<dyn Metrics>,
140 ) -> Result<Self, AiLibError> {
141 let adapter: Box<dyn ChatApi> = match provider {
142 Provider::Groq => Box::new(GenericAdapter::new(ProviderConfigs::groq())?),
143 Provider::XaiGrok => Box::new(GenericAdapter::new(ProviderConfigs::xai_grok())?),
144 Provider::Ollama => Box::new(GenericAdapter::new(ProviderConfigs::ollama())?),
145 Provider::DeepSeek => Box::new(GenericAdapter::new(ProviderConfigs::deepseek())?),
146 Provider::Qwen => Box::new(GenericAdapter::new(ProviderConfigs::qwen())?),
147 Provider::Anthropic => Box::new(GenericAdapter::new(ProviderConfigs::anthropic())?),
148 Provider::BaiduWenxin => Box::new(GenericAdapter::new(ProviderConfigs::baidu_wenxin())?),
149 Provider::TencentHunyuan => Box::new(GenericAdapter::new(ProviderConfigs::tencent_hunyuan())?),
150 Provider::IflytekSpark => Box::new(GenericAdapter::new(ProviderConfigs::iflytek_spark())?),
151 Provider::Moonshot => Box::new(GenericAdapter::new(ProviderConfigs::moonshot())?),
152 Provider::AzureOpenAI => {
153 Box::new(GenericAdapter::new(ProviderConfigs::azure_openai())?)
154 }
155 Provider::HuggingFace => Box::new(GenericAdapter::new(ProviderConfigs::huggingface())?),
156 Provider::TogetherAI => Box::new(GenericAdapter::new(ProviderConfigs::together_ai())?),
157 Provider::OpenAI => Box::new(OpenAiAdapter::new()?),
158 Provider::Gemini => Box::new(GeminiAdapter::new()?),
159 Provider::Mistral => Box::new(MistralAdapter::new()?),
160 Provider::Cohere => Box::new(CohereAdapter::new()?),
161 };
162
163 Ok(Self {
164 provider,
165 adapter,
166 metrics,
167 })
168 }
169
170 /// Set metrics implementation on client
171 pub fn with_metrics(mut self, metrics: Arc<dyn Metrics>) -> Self {
172 self.metrics = metrics;
173 self
174 }
175
176 /// 发送聊天完成请求
177 ///
178 /// # Arguments
179 /// * `request` - 聊天完成请求
180 ///
181 /// # Returns
182 /// * `Result<ChatCompletionResponse, AiLibError>` - 成功时返回响应,失败时返回错误
183 pub async fn chat_completion(
184 &self,
185 request: ChatCompletionRequest,
186 ) -> Result<ChatCompletionResponse, AiLibError> {
187 self.adapter.chat_completion(request).await
188 }
189
190 /// 流式聊天完成请求
191 ///
192 /// # Arguments
193 /// * `request` - 聊天完成请求
194 ///
195 /// # Returns
196 /// * `Result<impl Stream<Item = Result<ChatCompletionChunk, AiLibError>>, AiLibError>` - 成功时返回流式响应
197 pub async fn chat_completion_stream(
198 &self,
199 mut request: ChatCompletionRequest,
200 ) -> Result<
201 Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
202 AiLibError,
203 > {
204 request.stream = Some(true);
205 self.adapter.chat_completion_stream(request).await
206 }
207
208 /// 带取消控制的流式聊天完成请求
209 ///
210 /// # Arguments
211 /// * `request` - 聊天完成请求
212 ///
213 /// # Returns
214 /// * `Result<(impl Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin, CancelHandle), AiLibError>` - 成功时返回流式响应和取消句柄
215 pub async fn chat_completion_stream_with_cancel(
216 &self,
217 mut request: ChatCompletionRequest,
218 ) -> Result<
219 (
220 Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
221 CancelHandle,
222 ),
223 AiLibError,
224 > {
225 request.stream = Some(true);
226 let stream = self.adapter.chat_completion_stream(request).await?;
227 let (cancel_tx, cancel_rx) = oneshot::channel();
228 let cancel_handle = CancelHandle {
229 sender: Some(cancel_tx),
230 };
231
232 let controlled_stream = ControlledStream::new(stream, cancel_rx);
233 Ok((Box::new(controlled_stream), cancel_handle))
234 }
235
236 /// 批量聊天完成请求
237 ///
238 /// # Arguments
239 /// * `requests` - 聊天完成请求列表
240 /// * `concurrency_limit` - 最大并发请求数(None表示无限制)
241 ///
242 /// # Returns
243 /// * `Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError>` - 返回所有请求的响应结果
244 ///
245 /// # Example
246 /// ```rust
247 /// use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Role};
248 /// use ai_lib::types::common::Content;
249 ///
250 /// #[tokio::main]
251 /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
252 /// let client = AiClient::new(Provider::Groq)?;
253 ///
254 /// let requests = vec![
255 /// ChatCompletionRequest::new(
256 /// "llama3-8b-8192".to_string(),
257 /// vec![Message {
258 /// role: Role::User,
259 /// content: Content::Text("Hello".to_string()),
260 /// function_call: None,
261 /// }],
262 /// ),
263 /// ChatCompletionRequest::new(
264 /// "llama3-8b-8192".to_string(),
265 /// vec![Message {
266 /// role: Role::User,
267 /// content: Content::Text("How are you?".to_string()),
268 /// function_call: None,
269 /// }],
270 /// ),
271 /// ];
272 ///
273 /// // 限制并发数为5
274 /// let responses = client.chat_completion_batch(requests, Some(5)).await?;
275 ///
276 /// for (i, response) in responses.iter().enumerate() {
277 /// match response {
278 /// Ok(resp) => println!("Request {}: {}", i, resp.choices[0].message.content.as_text()),
279 /// Err(e) => println!("Request {} failed: {}", i, e),
280 /// }
281 /// }
282 ///
283 /// Ok(())
284 /// }
285 /// ```
286 pub async fn chat_completion_batch(
287 &self,
288 requests: Vec<ChatCompletionRequest>,
289 concurrency_limit: Option<usize>,
290 ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
291 self.adapter.chat_completion_batch(requests, concurrency_limit).await
292 }
293
294 /// 智能批量处理:根据请求数量自动选择处理策略
295 ///
296 /// # Arguments
297 /// * `requests` - 聊天完成请求列表
298 ///
299 /// # Returns
300 /// * `Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError>` - 返回所有请求的响应结果
301 pub async fn chat_completion_batch_smart(
302 &self,
303 requests: Vec<ChatCompletionRequest>,
304 ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
305 // 小批量使用顺序处理,大批量使用并发处理
306 let concurrency_limit = if requests.len() <= 3 { None } else { Some(10) };
307 self.chat_completion_batch(requests, concurrency_limit).await
308 }
309
310 /// 批量聊天完成请求
311 ///
312 /// # Arguments
313 /// * `requests` - 聊天完成请求列表
314 /// * `concurrency_limit` - 最大并发请求数(None表示无限制)
315 ///
316 /// # Returns
317 /// * `Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError>` - 返回所有请求的响应结果
318 ///
319 /// # Example
320 /// ```rust
321 /// use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Role};
322 /// use ai_lib::types::common::Content;
323 ///
324 /// #[tokio::main]
325 /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
326 /// let client = AiClient::new(Provider::Groq)?;
327 ///
328 /// let requests = vec![
329 /// ChatCompletionRequest::new(
330 /// "llama3-8b-8192".to_string(),
331 /// vec![Message {
332 /// role: Role::User,
333 /// content: Content::Text("Hello".to_string()),
334 /// function_call: None,
335 /// }],
336 /// ),
337 /// ChatCompletionRequest::new(
338 /// "llama3-8b-8192".to_string(),
339 /// vec![Message {
340 /// role: Role::User,
341 /// content: Content::Text("How are you?".to_string()),
342 /// function_call: None,
343 /// }],
344 /// ),
345 /// ];
346 ///
347 /// // 限制并发数为5
348 /// let responses = client.chat_completion_batch(requests, Some(5)).await?;
349 ///
350 /// for (i, response) in responses.iter().enumerate() {
351 /// match response {
352 /// Ok(resp) => println!("Request {}: {}", i, resp.choices[0].message.content.as_text()),
353 /// Err(e) => println!("Request {} failed: {}", i, e),
354 /// }
355 /// }
356 ///
357 /// Ok(())
358 /// }
359 /// 获取支持的模型列表
360 ///
361 /// # Returns
362 /// * `Result<Vec<String>, AiLibError>` - 成功时返回模型列表,失败时返回错误
363 pub async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
364 self.adapter.list_models().await
365 }
366
367 /// 切换AI模型提供商
368 ///
369 /// # Arguments
370 /// * `provider` - 新的提供商
371 ///
372 /// # Returns
373 /// * `Result<(), AiLibError>` - 成功时返回(),失败时返回错误
374 ///
375 /// # Example
376 /// ```rust
377 /// use ai_lib::{AiClient, Provider};
378 ///
379 /// let mut client = AiClient::new(Provider::Groq)?;
380 /// // 从Groq切换到Groq(演示切换功能)
381 /// client.switch_provider(Provider::Groq)?;
382 /// # Ok::<(), ai_lib::AiLibError>(())
383 /// ```
384 pub fn switch_provider(&mut self, provider: Provider) -> Result<(), AiLibError> {
385 let new_adapter: Box<dyn ChatApi> = match provider {
386 Provider::Groq => Box::new(GenericAdapter::new(ProviderConfigs::groq())?),
387 Provider::XaiGrok => Box::new(GenericAdapter::new(ProviderConfigs::xai_grok())?),
388 Provider::Ollama => Box::new(GenericAdapter::new(ProviderConfigs::ollama())?),
389 Provider::DeepSeek => Box::new(GenericAdapter::new(ProviderConfigs::deepseek())?),
390 Provider::Qwen => Box::new(GenericAdapter::new(ProviderConfigs::qwen())?),
391 Provider::OpenAI => Box::new(OpenAiAdapter::new()?),
392 Provider::Anthropic => Box::new(GenericAdapter::new(ProviderConfigs::anthropic())?),
393 Provider::BaiduWenxin => Box::new(GenericAdapter::new(ProviderConfigs::baidu_wenxin())?),
394 Provider::TencentHunyuan => Box::new(GenericAdapter::new(ProviderConfigs::tencent_hunyuan())?),
395 Provider::IflytekSpark => Box::new(GenericAdapter::new(ProviderConfigs::iflytek_spark())?),
396 Provider::Moonshot => Box::new(GenericAdapter::new(ProviderConfigs::moonshot())?),
397 Provider::Gemini => Box::new(GeminiAdapter::new()?),
398 Provider::AzureOpenAI => {
399 Box::new(GenericAdapter::new(ProviderConfigs::azure_openai())?)
400 }
401 Provider::HuggingFace => Box::new(GenericAdapter::new(ProviderConfigs::huggingface())?),
402 Provider::TogetherAI => Box::new(GenericAdapter::new(ProviderConfigs::together_ai())?),
403 Provider::Mistral => Box::new(MistralAdapter::new()?),
404 Provider::Cohere => Box::new(CohereAdapter::new()?),
405 // Provider::Bedrock => Box::new(BedrockAdapter::new()?),
406 };
407
408 self.provider = provider;
409 self.adapter = new_adapter;
410 Ok(())
411 }
412
413 /// 获取当前使用的提供商
414 pub fn current_provider(&self) -> Provider {
415 self.provider
416 }
417}
418
/// Cancellation handle for a streaming response.
pub struct CancelHandle {
    // One-shot signal to the ControlledStream; `None` once consumed.
    sender: Option<oneshot::Sender<()>>,
}
423
424impl CancelHandle {
425 /// 取消流式响应
426 pub fn cancel(mut self) {
427 if let Some(sender) = self.sender.take() {
428 let _ = sender.send(());
429 }
430 }
431}
432
/// Controllable streaming response (a stream that can be cancelled).
struct ControlledStream {
    // Underlying provider stream whose items are forwarded.
    inner: Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
    // Pending cancellation signal; set to `None` once it has resolved.
    cancel_rx: Option<oneshot::Receiver<()>>,
}
438
439impl ControlledStream {
440 fn new(
441 inner: Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
442 cancel_rx: oneshot::Receiver<()>,
443 ) -> Self {
444 Self {
445 inner,
446 cancel_rx: Some(cancel_rx),
447 }
448 }
449}
450
451impl Stream for ControlledStream {
452 type Item = Result<ChatCompletionChunk, AiLibError>;
453
454 fn poll_next(
455 mut self: std::pin::Pin<&mut Self>,
456 cx: &mut std::task::Context<'_>,
457 ) -> std::task::Poll<Option<Self::Item>> {
458 use futures::stream::StreamExt;
459 use std::task::Poll;
460
461 // 检查是否被取消
462 if let Some(ref mut cancel_rx) = self.cancel_rx {
463 match Future::poll(std::pin::Pin::new(cancel_rx), cx) {
464 Poll::Ready(_) => {
465 self.cancel_rx = None;
466 return Poll::Ready(Some(Err(AiLibError::ProviderError(
467 "Stream cancelled".to_string(),
468 ))));
469 }
470 Poll::Pending => {}
471 }
472 }
473
474 // 轮询内部流
475 self.inner.poll_next_unpin(cx)
476 }
477}