ai_lib/api/chat.rs

use crate::types::{AiLibError, ChatCompletionRequest, ChatCompletionResponse};
use async_trait::async_trait;
use futures::stream::Stream;

/// Unified chat provider trait exposed to the rest of the crate.
///
/// This supersedes the legacy `ChatApi` naming so that downstream consumers
/// can implement a single trait (`ChatProvider`) whether they are composing
/// routing strategies or writing bespoke adapters.
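///
/// # Example
///
/// A minimal usage sketch (not compiled as a doctest). It assumes a
/// `ChatCompletionRequest` has already been constructed elsewhere and that the
/// provider is used behind a trait object:
///
/// ```ignore
/// async fn ask(provider: &dyn ChatProvider, request: ChatCompletionRequest)
///     -> Result<ChatCompletionResponse, AiLibError>
/// {
///     // The provider name is handy for logging which backend served the call.
///     println!("sending request via {}", provider.name());
///     provider.chat(request).await
/// }
/// ```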
#[async_trait]
pub trait ChatProvider: Send + Sync {
    /// Human readable provider name (used in logs/metrics/strategies).
    fn name(&self) -> &str;

    /// Send chat completion request
    ///
    /// # Arguments
    /// * `request` - Generic chat completion request
    ///
    /// # Returns
    /// * `Result<ChatCompletionResponse, AiLibError>` - Returns response on success, error on failure
    async fn chat(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError>;

    /// Streaming chat completion request
    ///
    /// # Arguments
    /// * `request` - Generic chat completion request
    ///
    /// # Returns
    /// * `Result<Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>, AiLibError>` - Returns a boxed chunk stream on success, error on failure
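    ///
    /// # Example
    ///
    /// An illustrative sketch (not compiled as a doctest) of draining the chunk
    /// stream; `provider` and `request` are assumed to be in scope, and
    /// `futures::StreamExt` must be imported at the call site:
    ///
    /// ```ignore
    /// use futures::StreamExt;
    ///
    /// let mut stream = provider.stream(request).await?;
    /// let mut content = String::new();
    /// while let Some(chunk) = stream.next().await {
    ///     // Each chunk carries zero or more choice deltas; append any text.
    ///     for choice in chunk?.choices {
    ///         if let Some(text) = choice.delta.content {
    ///             content.push_str(&text);
    ///         }
    ///     }
    /// }
    /// ```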
    async fn stream(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    >;

    /// Get list of supported models
    ///
    /// # Returns
    /// * `Result<Vec<String>, AiLibError>` - Returns model list on success, error on failure
    async fn list_models(&self) -> Result<Vec<String>, AiLibError>;

    /// Get model information
    ///
    /// # Arguments
    /// * `model_id` - Model ID
    ///
    /// # Returns
    /// * `Result<ModelInfo, AiLibError>` - Returns model information on success, error on failure
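    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled as a doctest) combining `list_models`
    /// with this method; `provider` is assumed to be in scope:
    ///
    /// ```ignore
    /// let models = provider.list_models().await?;
    /// if let Some(first) = models.first() {
    ///     let info = provider.get_model_info(first).await?;
    ///     println!("{} is owned by {}", info.id, info.owned_by);
    /// }
    /// ```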
    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError>;

    /// Batch chat completion requests
    ///
    /// # Arguments
    /// * `requests` - Vector of chat completion requests
    /// * `concurrency_limit` - Optional concurrency limit for concurrent processing
    ///
    /// # Returns
    /// * `Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError>` - Returns vector of results
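    ///
    /// # Example
    ///
    /// An illustrative sketch (not compiled as a doctest); `requests` is
    /// assumed to be a pre-built `Vec<ChatCompletionRequest>`:
    ///
    /// ```ignore
    /// // Process at most 4 requests at a time; results keep the input order.
    /// let results = provider.batch(requests, Some(4)).await?;
    /// for (index, result) in results.iter().enumerate() {
    ///     match result {
    ///         Ok(_response) => println!("request {} succeeded", index),
    ///         Err(error) => eprintln!("request {} failed: {:?}", index, error),
    ///     }
    /// }
    /// ```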
    async fn batch(
        &self,
        requests: Vec<ChatCompletionRequest>,
        concurrency_limit: Option<usize>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        batch_utils::process_batch_concurrent(self, requests, concurrency_limit).await
    }
}

/// Backwards compatibility alias for the legacy `ChatApi` name.
pub use self::ChatProvider as ChatApi;

/// Streaming response data chunk
#[derive(Debug, Clone)]
pub struct ChatCompletionChunk {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
}

/// Streaming response choice delta
#[derive(Debug, Clone)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: MessageDelta,
    pub finish_reason: Option<String>,
}

/// Message delta
#[derive(Debug, Clone)]
pub struct MessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
}

/// Model information
#[derive(Debug, Clone)]
pub struct ModelInfo {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub owned_by: String,
    pub permission: Vec<ModelPermission>,
}

/// Model permission
#[derive(Debug, Clone)]
pub struct ModelPermission {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub allow_create_engine: bool,
    pub allow_sampling: bool,
    pub allow_logprobs: bool,
    pub allow_search_indices: bool,
    pub allow_view: bool,
    pub allow_fine_tuning: bool,
    pub organization: String,
    pub group: Option<String>,
    pub is_blocking: bool,
}

// Import the Role type here as well, since it is also used in streaming response deltas
use crate::types::Role;

/// Batch processing result containing successful and failed responses
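///
/// # Example
///
/// A small sketch (not compiled as a doctest) of aggregating the output of a
/// batch call into a `BatchResult`; `provider` and `requests` are assumed to
/// be in scope:
///
/// ```ignore
/// let results = provider.batch(requests, Some(4)).await?;
/// let mut summary = BatchResult::new(results.len());
/// for (index, result) in results.into_iter().enumerate() {
///     match result {
///         Ok(response) => summary.add_success(response),
///         Err(error) => summary.add_failure(index, error),
///     }
/// }
/// println!("success rate: {:.1}%", summary.success_rate());
/// ```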
#[derive(Debug)]
pub struct BatchResult {
    pub successful: Vec<ChatCompletionResponse>,
    pub failed: Vec<(usize, AiLibError)>,
    pub total_requests: usize,
    pub total_successful: usize,
    pub total_failed: usize,
}

impl BatchResult {
    /// Create a new batch result
    pub fn new(total_requests: usize) -> Self {
        Self {
            successful: Vec::new(),
            failed: Vec::new(),
            total_requests,
            total_successful: 0,
            total_failed: 0,
        }
    }

    /// Add a successful response
    pub fn add_success(&mut self, response: ChatCompletionResponse) {
        self.successful.push(response);
        self.total_successful += 1;
    }

    /// Add a failed response with index
    pub fn add_failure(&mut self, index: usize, error: AiLibError) {
        self.failed.push((index, error));
        self.total_failed += 1;
    }

    /// Check if all requests were successful
    pub fn all_successful(&self) -> bool {
        self.total_failed == 0
    }

    /// Get success rate as a percentage
    pub fn success_rate(&self) -> f64 {
        if self.total_requests == 0 {
            0.0
        } else {
            (self.total_successful as f64 / self.total_requests as f64) * 100.0
        }
    }
}

/// Batch processing utility functions
pub mod batch_utils {
    use super::*;
    use futures::stream::{self, StreamExt};
    use std::sync::Arc;
    use tokio::sync::Semaphore;

    /// Default implementation for concurrent batch processing
    pub async fn process_batch_concurrent<T: ChatProvider + ?Sized>(
        api: &T,
        requests: Vec<ChatCompletionRequest>,
        concurrency_limit: Option<usize>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        if requests.is_empty() {
            return Ok(Vec::new());
        }

        let semaphore = concurrency_limit.map(|limit| Arc::new(Semaphore::new(limit)));

        let futures = requests.into_iter().enumerate().map(|(index, request)| {
            let api_ref = api;
            let semaphore_ref = semaphore.clone();

            async move {
                // Acquire permit if concurrency limit is set
                let _permit = if let Some(sem) = &semaphore_ref {
                    match sem.acquire().await {
                        Ok(permit) => Some(permit),
                        Err(_) => {
                            return (
                                index,
                                Err(AiLibError::ProviderError(
                                    "Failed to acquire semaphore permit".to_string(),
                                )),
                            )
                        }
                    }
                } else {
                    None
                };

                // Process the request
                let result = api_ref.chat(request).await;

                // Return result with index for ordering
                (index, result)
            }
        });

        // Execute all futures concurrently
        let results: Vec<_> = stream::iter(futures)
            .buffer_unordered(concurrency_limit.unwrap_or(usize::MAX))
            .collect()
            .await;

        // Place each result back at its original index to preserve input order
        let mut sorted_results = Vec::with_capacity(results.len());
        sorted_results.resize_with(results.len(), || {
            Err(AiLibError::ProviderError("Placeholder".to_string()))
        });
        for (index, result) in results {
            sorted_results[index] = result;
        }

        Ok(sorted_results)
    }

    /// Sequential batch processing implementation
    pub async fn process_batch_sequential<T: ChatProvider + ?Sized>(
        api: &T,
        requests: Vec<ChatCompletionRequest>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        let mut results = Vec::with_capacity(requests.len());

        for request in requests {
            let result = api.chat(request).await;
            results.push(result);
        }

        Ok(results)
    }

    /// Smart batch processing: automatically choose a processing strategy based on batch size
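    ///
    /// Batches of three or fewer requests are handled sequentially; larger
    /// batches fall back to `process_batch_concurrent`. An illustrative sketch
    /// (not compiled as a doctest); `provider`, `small_requests`, and
    /// `large_requests` are assumed to be in scope:
    ///
    /// ```ignore
    /// // Two requests: processed one after another.
    /// let small = batch_utils::process_batch_smart(provider, small_requests, None).await?;
    /// // Ten requests: processed concurrently, at most 4 in flight.
    /// let large = batch_utils::process_batch_smart(provider, large_requests, Some(4)).await?;
    /// ```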
    pub async fn process_batch_smart<T: ChatProvider + ?Sized>(
        api: &T,
        requests: Vec<ChatCompletionRequest>,
        concurrency_limit: Option<usize>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        let request_count = requests.len();

        // For small batches, use sequential processing
        if request_count <= 3 {
            return process_batch_sequential(api, requests).await;
        }

        // For larger batches, use concurrent processing
        process_batch_concurrent(api, requests, concurrency_limit).await
    }
}