use crate::types::{AiLibError, ChatCompletionRequest, ChatCompletionResponse, Role};
use async_trait::async_trait;
use futures::stream::Stream;

/// Unified chat completion interface.
///
/// Implementations adapt a specific provider's API to this common surface.
#[async_trait]
pub trait ChatApi: Send + Sync {
    /// Send a single chat completion request and await the full response.
    async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError>;

    /// Send a chat completion request and receive the response incrementally
    /// as a stream of `ChatCompletionChunk`s.
    async fn chat_completion_stream(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    >;

    /// List the model identifiers available from this provider.
    async fn list_models(&self) -> Result<Vec<String>, AiLibError>;

    /// Fetch metadata for a single model.
    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError>;

    /// Process multiple requests, optionally capping how many run
    /// concurrently (`None` means unbounded). Results are returned in the
    /// same order as the input requests.
    async fn chat_completion_batch(
        &self,
        requests: Vec<ChatCompletionRequest>,
        concurrency_limit: Option<usize>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        batch_utils::process_batch_concurrent(self, requests, concurrency_limit).await
    }
}
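
// Usage sketch (illustrative only, not part of the public API): submit a
// batch through the trait's default `chat_completion_batch` and report each
// outcome. Assumes `AiLibError: Debug`, which its use in `Result` suggests.
#[allow(dead_code)]
async fn example_batch(
    api: &dyn ChatApi,
    requests: Vec<ChatCompletionRequest>,
) -> Result<(), AiLibError> {
    // Cap concurrency at 4 in-flight requests; `None` would mean unbounded.
    let results = api.chat_completion_batch(requests, Some(4)).await?;
    for (index, result) in results.iter().enumerate() {
        match result {
            Ok(_response) => println!("request {} succeeded", index),
            Err(error) => eprintln!("request {} failed: {:?}", index, error),
        }
    }
    Ok(())
}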

/// One incremental chunk of a streamed chat completion response.
#[derive(Debug, Clone)]
pub struct ChatCompletionChunk {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
}

/// A single choice inside a streamed chunk.
#[derive(Debug, Clone)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: MessageDelta,
    pub finish_reason: Option<String>,
}

/// The incremental message payload carried by a `ChoiceDelta`.
#[derive(Debug, Clone)]
pub struct MessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
}
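
// Usage sketch (illustrative only, not part of the public API): drain a
// response stream and assemble the streamed deltas into the full message text.
#[allow(dead_code)]
async fn example_collect_stream(
    api: &dyn ChatApi,
    request: ChatCompletionRequest,
) -> Result<String, AiLibError> {
    use futures::StreamExt;

    let mut stream = api.chat_completion_stream(request).await?;
    let mut content = String::new();

    // Each chunk carries zero or more choice deltas; append any text payload.
    while let Some(chunk) = stream.next().await {
        for choice in chunk?.choices {
            if let Some(text) = choice.delta.content {
                content.push_str(&text);
            }
        }
    }

    Ok(content)
}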

/// Metadata describing a single model.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub owned_by: String,
    pub permission: Vec<ModelPermission>,
}

/// Capability flags attached to a model.
#[derive(Debug, Clone)]
pub struct ModelPermission {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub allow_create_engine: bool,
    pub allow_sampling: bool,
    pub allow_logprobs: bool,
    pub allow_search_indices: bool,
    pub allow_view: bool,
    pub allow_fine_tuning: bool,
    pub organization: String,
    pub group: Option<String>,
    pub is_blocking: bool,
}
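
// Usage sketch (illustrative only, not part of the public API): enumerate the
// available models and fetch metadata for each one.
#[allow(dead_code)]
async fn example_list_models(api: &dyn ChatApi) -> Result<Vec<ModelInfo>, AiLibError> {
    let mut infos = Vec::new();
    for model_id in api.list_models().await? {
        infos.push(api.get_model_info(&model_id).await?);
    }
    Ok(infos)
}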

/// Aggregated summary of a batch run: per-request outcomes plus counters.
#[derive(Debug)]
pub struct BatchResult {
    pub successful: Vec<ChatCompletionResponse>,
    pub failed: Vec<(usize, AiLibError)>,
    pub total_requests: usize,
    pub total_successful: usize,
    pub total_failed: usize,
}

impl BatchResult {
    /// Create an empty result for a batch of `total_requests` requests.
    pub fn new(total_requests: usize) -> Self {
        Self {
            successful: Vec::new(),
            failed: Vec::new(),
            total_requests,
            total_successful: 0,
            total_failed: 0,
        }
    }

    /// Record a successful response.
    pub fn add_success(&mut self, response: ChatCompletionResponse) {
        self.successful.push(response);
        self.total_successful += 1;
    }

    /// Record a failure together with the index of the originating request.
    pub fn add_failure(&mut self, index: usize, error: AiLibError) {
        self.failed.push((index, error));
        self.total_failed += 1;
    }

    /// True if no request failed.
    pub fn all_successful(&self) -> bool {
        self.total_failed == 0
    }

    /// Percentage of requests that succeeded (0.0 for an empty batch).
    pub fn success_rate(&self) -> f64 {
        if self.total_requests == 0 {
            0.0
        } else {
            (self.total_successful as f64 / self.total_requests as f64) * 100.0
        }
    }
}
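
// Usage sketch (illustrative only, not part of the public API): fold the
// output of `chat_completion_batch` into a `BatchResult` summary.
#[allow(dead_code)]
async fn example_summarize_batch(
    api: &dyn ChatApi,
    requests: Vec<ChatCompletionRequest>,
) -> Result<BatchResult, AiLibError> {
    let total = requests.len();
    let results = api.chat_completion_batch(requests, Some(4)).await?;

    let mut summary = BatchResult::new(total);
    for (index, result) in results.into_iter().enumerate() {
        match result {
            Ok(response) => summary.add_success(response),
            Err(error) => summary.add_failure(index, error),
        }
    }

    Ok(summary)
}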

/// Helpers for running batches of chat completion requests.
pub mod batch_utils {
    use super::*;
    use futures::stream::{self, StreamExt};
    use std::sync::Arc;
    use tokio::sync::Semaphore;

    /// Run all requests concurrently, optionally capped by `concurrency_limit`.
    /// Results are returned in the same order as the input requests.
    pub async fn process_batch_concurrent<T: ChatApi + ?Sized>(
        api: &T,
        requests: Vec<ChatCompletionRequest>,
        concurrency_limit: Option<usize>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        if requests.is_empty() {
            return Ok(Vec::new());
        }

        // A semaphore enforces the concurrency cap; `buffer_unordered` below
        // additionally bounds how many futures are polled at once.
        let semaphore = concurrency_limit.map(|limit| Arc::new(Semaphore::new(limit)));

        let futures = requests.into_iter().enumerate().map(|(index, request)| {
            let api_ref = api;
            let semaphore_ref = semaphore.clone();

            async move {
                // Hold a permit (if a limit was set) for the duration of the call.
                let _permit = if let Some(sem) = &semaphore_ref {
                    match sem.acquire().await {
                        Ok(permit) => Some(permit),
                        Err(_) => {
                            return (
                                index,
                                Err(AiLibError::ProviderError(
                                    "Failed to acquire semaphore permit".to_string(),
                                )),
                            )
                        }
                    }
                } else {
                    None
                };

                let result = api_ref.chat_completion(request).await;

                (index, result)
            }
        });

        let results: Vec<_> = stream::iter(futures)
            .buffer_unordered(concurrency_limit.unwrap_or(usize::MAX))
            .collect()
            .await;

        // `buffer_unordered` yields results out of order; restore input order
        // by writing each result back into its original slot.
        let mut sorted_results = Vec::with_capacity(results.len());
        sorted_results.resize_with(results.len(), || {
            Err(AiLibError::ProviderError("Placeholder".to_string()))
        });
        for (index, result) in results {
            sorted_results[index] = result;
        }

        Ok(sorted_results)
    }

    /// Run requests one at a time, in input order.
    pub async fn process_batch_sequential<T: ChatApi + ?Sized>(
        api: &T,
        requests: Vec<ChatCompletionRequest>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        let mut results = Vec::with_capacity(requests.len());

        for request in requests {
            let result = api.chat_completion(request).await;
            results.push(result);
        }

        Ok(results)
    }

    /// Choose a strategy based on batch size: small batches (three requests
    /// or fewer) run sequentially to avoid concurrency overhead; larger
    /// batches run concurrently.
    pub async fn process_batch_smart<T: ChatApi + ?Sized>(
        api: &T,
        requests: Vec<ChatCompletionRequest>,
        concurrency_limit: Option<usize>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        let request_count = requests.len();

        if request_count <= 3 {
            return process_batch_sequential(api, requests).await;
        }

        process_batch_concurrent(api, requests, concurrency_limit).await
    }
}
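
// Usage sketch (illustrative only, not part of the public API): route every
// batch through `process_batch_smart`, which falls back to sequential
// execution for very small batches, so callers need not pick a strategy.
#[allow(dead_code)]
async fn example_smart_batch(
    api: &dyn ChatApi,
    requests: Vec<ChatCompletionRequest>,
) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
    // A cap of 8 concurrent requests is an arbitrary example value.
    batch_utils::process_batch_smart(api, requests, Some(8)).await
}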