ai_lib/api/chat.rs

use crate::types::{AiLibError, ChatCompletionRequest, ChatCompletionResponse};
use async_trait::async_trait;
use futures::stream::Stream;

/// Generic chat API interface
///
/// This trait defines the core capabilities that all AI services should have,
/// without depending on any specific model's implementation details.
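///
/// # Example
///
/// A minimal implementor skeleton. `MyProvider` is a hypothetical type and the
/// method bodies are elided; this only illustrates the shape of an
/// implementation:
///
/// ```ignore
/// struct MyProvider;
///
/// #[async_trait]
/// impl ChatApi for MyProvider {
///     async fn chat_completion(
///         &self,
///         request: ChatCompletionRequest,
///     ) -> Result<ChatCompletionResponse, AiLibError> {
///         todo!("call the underlying provider and map its response")
///     }
///
///     // chat_completion_stream, list_models and get_model_info elided
/// }
/// ```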
#[async_trait]
pub trait ChatApi: Send + Sync {
    /// Send chat completion request
    ///
    /// # Arguments
    /// * `request` - Generic chat completion request
    ///
    /// # Returns
    /// * `Result<ChatCompletionResponse, AiLibError>` - Returns response on success, error on failure
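    ///
    /// # Example
    ///
    /// A sketch of a single non-streaming call; `provider` stands in for any
    /// `ChatApi` implementor, and `build_request()` is a hypothetical helper,
    /// since this trait does not prescribe how requests are constructed:
    ///
    /// ```ignore
    /// let request: ChatCompletionRequest = build_request(); // hypothetical helper
    /// let response = provider.chat_completion(request).await?;
    /// ```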
    async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError>;

    /// Streaming chat completion request
    ///
    /// # Arguments
    /// * `request` - Generic chat completion request
    ///
    /// # Returns
    /// * A boxed `Stream` of `Result<ChatCompletionChunk, AiLibError>` items on success, error on failure
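    ///
    /// # Example
    ///
    /// A sketch of consuming the stream and accumulating content deltas;
    /// `provider` and `request` are assumed to exist as in the example above:
    ///
    /// ```ignore
    /// use futures::StreamExt;
    ///
    /// let mut stream = provider.chat_completion_stream(request).await?;
    /// let mut content = String::new();
    /// while let Some(chunk) = stream.next().await {
    ///     for choice in chunk?.choices {
    ///         if let Some(delta) = choice.delta.content {
    ///             content.push_str(&delta);
    ///         }
    ///     }
    /// }
    /// ```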
    async fn chat_completion_stream(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    >;

    /// Get list of supported models
    ///
    /// # Returns
    /// * `Result<Vec<String>, AiLibError>` - Returns model list on success, error on failure
    async fn list_models(&self) -> Result<Vec<String>, AiLibError>;

    /// Get model information
    ///
    /// # Arguments
    /// * `model_id` - Model ID
    ///
    /// # Returns
    /// * `Result<ModelInfo, AiLibError>` - Returns model information on success, error on failure
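    ///
    /// # Example
    ///
    /// A sketch that looks up metadata for every advertised model; `provider`
    /// is any `ChatApi` implementor:
    ///
    /// ```ignore
    /// for model_id in provider.list_models().await? {
    ///     let info = provider.get_model_info(&model_id).await?;
    ///     println!("{} is owned by {}", info.id, info.owned_by);
    /// }
    /// ```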
    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError>;

    /// Batch chat completion requests
    ///
    /// # Arguments
    /// * `requests` - Vector of chat completion requests
    /// * `concurrency_limit` - Optional limit on how many requests are processed concurrently
    ///
    /// # Returns
    /// * `Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError>` - One result per request, in the original request order
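    ///
    /// # Example
    ///
    /// A sketch of a batch call with at most two requests in flight at once;
    /// `provider` is any `ChatApi` implementor and `requests` is assumed to be
    /// built elsewhere:
    ///
    /// ```ignore
    /// let results = provider.chat_completion_batch(requests, Some(2)).await?;
    /// for (index, result) in results.iter().enumerate() {
    ///     match result {
    ///         Ok(_) => println!("request {} succeeded", index),
    ///         Err(error) => eprintln!("request {} failed: {:?}", index, error),
    ///     }
    /// }
    /// ```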
    async fn chat_completion_batch(
        &self,
        requests: Vec<ChatCompletionRequest>,
        concurrency_limit: Option<usize>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        batch_utils::process_batch_concurrent(self, requests, concurrency_limit).await
    }
}

/// Streaming response data chunk
#[derive(Debug, Clone)]
pub struct ChatCompletionChunk {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
}

/// Streaming response choice delta
#[derive(Debug, Clone)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: MessageDelta,
    pub finish_reason: Option<String>,
}

/// Message delta
#[derive(Debug, Clone)]
pub struct MessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
}

/// Model information
#[derive(Debug, Clone)]
pub struct ModelInfo {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub owned_by: String,
    pub permission: Vec<ModelPermission>,
}

/// Model permission
#[derive(Debug, Clone)]
pub struct ModelPermission {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub allow_create_engine: bool,
    pub allow_sampling: bool,
    pub allow_logprobs: bool,
    pub allow_search_indices: bool,
    pub allow_view: bool,
    pub allow_fine_tuning: bool,
    pub organization: String,
    pub group: Option<String>,
    pub is_blocking: bool,
}

// `Role` is imported here because streaming message deltas also carry a role
use crate::types::Role;

/// Batch processing result containing successful and failed responses
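///
/// # Example
///
/// A sketch of tallying batch outcomes; `responses` is a hypothetical
/// `Vec<Result<ChatCompletionResponse, AiLibError>>` returned by a batch call:
///
/// ```ignore
/// let mut batch = BatchResult::new(responses.len());
/// for (index, outcome) in responses.into_iter().enumerate() {
///     match outcome {
///         Ok(response) => batch.add_success(response),
///         Err(error) => batch.add_failure(index, error),
///     }
/// }
/// println!("success rate: {:.1}%", batch.success_rate());
/// ```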
#[derive(Debug)]
pub struct BatchResult {
    pub successful: Vec<ChatCompletionResponse>,
    pub failed: Vec<(usize, AiLibError)>,
    pub total_requests: usize,
    pub total_successful: usize,
    pub total_failed: usize,
}

impl BatchResult {
    /// Create a new batch result
    pub fn new(total_requests: usize) -> Self {
        Self {
            successful: Vec::new(),
            failed: Vec::new(),
            total_requests,
            total_successful: 0,
            total_failed: 0,
        }
    }

    /// Add a successful response
    pub fn add_success(&mut self, response: ChatCompletionResponse) {
        self.successful.push(response);
        self.total_successful += 1;
    }

    /// Add a failed response with its original request index
    pub fn add_failure(&mut self, index: usize, error: AiLibError) {
        self.failed.push((index, error));
        self.total_failed += 1;
    }

    /// Check if all requests were successful
    pub fn all_successful(&self) -> bool {
        self.total_failed == 0
    }

    /// Get the success rate as a percentage (0.0 for an empty batch)
    pub fn success_rate(&self) -> f64 {
        if self.total_requests == 0 {
            0.0
        } else {
            (self.total_successful as f64 / self.total_requests as f64) * 100.0
        }
    }
}

/// Batch processing utility functions
pub mod batch_utils {
    use super::*;
    use futures::stream::{self, StreamExt};
    use std::sync::Arc;
    use tokio::sync::Semaphore;

    /// Default implementation for concurrent batch processing
    ///
    /// Results are returned in the original request order regardless of the
    /// order in which individual requests complete.
    pub async fn process_batch_concurrent<T: ChatApi + ?Sized>(
        api: &T,
        requests: Vec<ChatCompletionRequest>,
        concurrency_limit: Option<usize>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        if requests.is_empty() {
            return Ok(Vec::new());
        }

        let total_requests = requests.len();
        let semaphore = concurrency_limit.map(|limit| Arc::new(Semaphore::new(limit)));

        let futures = requests.into_iter().enumerate().map(|(index, request)| {
            let api_ref = api;
            let semaphore_ref = semaphore.clone();

            async move {
                // Acquire a permit if a concurrency limit is set; the permit is
                // held for the duration of the request and released on drop
                let _permit = if let Some(sem) = &semaphore_ref {
                    match sem.acquire().await {
                        Ok(permit) => Some(permit),
                        Err(_) => {
                            return (
                                index,
                                Err(AiLibError::ProviderError(
                                    "Failed to acquire semaphore permit".to_string(),
                                )),
                            )
                        }
                    }
                } else {
                    None
                };

                // Process the request, tagging the result with its original index
                (index, api_ref.chat_completion(request).await)
            }
        });

        // Execute the futures concurrently; with no explicit limit, allow the
        // whole batch in flight at once
        let mut results: Vec<_> = stream::iter(futures)
            .buffer_unordered(concurrency_limit.unwrap_or(total_requests))
            .collect()
            .await;

        // Restore the original request order
        results.sort_by_key(|(index, _)| *index);
        Ok(results.into_iter().map(|(_, result)| result).collect())
    }

    /// Sequential batch processing implementation
    pub async fn process_batch_sequential<T: ChatApi + ?Sized>(
        api: &T,
        requests: Vec<ChatCompletionRequest>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        let mut results = Vec::with_capacity(requests.len());

        for request in requests {
            results.push(api.chat_completion(request).await);
        }

        Ok(results)
    }

    /// Smart batch processing: automatically choose a processing strategy
    /// based on batch size (sequential for small batches, concurrent otherwise)
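    ///
    /// # Example
    ///
    /// A sketch of delegating strategy selection to this helper; `provider` is
    /// any `ChatApi` implementor and `requests` is assumed to be built
    /// elsewhere:
    ///
    /// ```ignore
    /// let results = batch_utils::process_batch_smart(&provider, requests, Some(4)).await?;
    /// ```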
    pub async fn process_batch_smart<T: ChatApi + ?Sized>(
        api: &T,
        requests: Vec<ChatCompletionRequest>,
        concurrency_limit: Option<usize>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        let request_count = requests.len();

        // For small batches, use sequential processing
        if request_count <= 3 {
            return process_batch_sequential(api, requests).await;
        }

        // For larger batches, use concurrent processing
        process_batch_concurrent(api, requests, concurrency_limit).await
    }
}