use crate::types::{AiLibError, ChatCompletionRequest, ChatCompletionResponse, Role};
use async_trait::async_trait;
use futures::stream::Stream;

/// Unified chat-completion API implemented by each backing provider.
#[async_trait]
pub trait ChatApi: Send + Sync {
    /// Send a single chat completion request and await the full response.
    async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError>;

    /// Send a chat completion request and receive the response incrementally
    /// as a stream of [`ChatCompletionChunk`]s.
    async fn chat_completion_stream(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    >;

    /// List the model identifiers available from this provider.
    async fn list_models(&self) -> Result<Vec<String>, AiLibError>;

    /// Fetch metadata for a single model.
    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError>;

    /// Process a batch of requests, optionally capping concurrency.
    ///
    /// Results are returned in the same order as the input requests; each
    /// entry carries its own success or failure.
    async fn chat_completion_batch(
        &self,
        requests: Vec<ChatCompletionRequest>,
        concurrency_limit: Option<usize>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        batch_utils::process_batch_concurrent(self, requests, concurrency_limit).await
    }
}
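
// A minimal call sketch (illustrative only, not part of the crate's public
// API): any `ChatApi` implementor can be driven through a trait object.
#[allow(dead_code)]
async fn chat_once_example(
    api: &dyn ChatApi,
    request: ChatCompletionRequest,
) -> Result<ChatCompletionResponse, AiLibError> {
    // One non-streaming round trip; provider failures surface as `AiLibError`.
    api.chat_completion(request).await
}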

/// One incremental chunk of a streaming chat completion.
#[derive(Debug, Clone)]
pub struct ChatCompletionChunk {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
}

/// A per-choice incremental update within a streamed chunk.
#[derive(Debug, Clone)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: MessageDelta,
    pub finish_reason: Option<String>,
}

/// The incremental message content carried by a [`ChoiceDelta`].
#[derive(Debug, Clone)]
pub struct MessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
}
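
// A minimal consumption sketch (illustrative only): accumulate streamed
// `content` deltas into a single string. `StreamExt::next` requires the
// stream to be `Unpin`, which the trait's boxed return type guarantees.
#[allow(dead_code)]
async fn collect_stream_example(
    api: &dyn ChatApi,
    request: ChatCompletionRequest,
) -> Result<String, AiLibError> {
    use futures::StreamExt;

    let mut stream = api.chat_completion_stream(request).await?;
    let mut text = String::new();
    while let Some(chunk) = stream.next().await {
        for choice in chunk?.choices {
            if let Some(content) = choice.delta.content {
                text.push_str(&content);
            }
        }
    }
    Ok(text)
}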

/// Metadata describing a single model exposed by a provider.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub owned_by: String,
    pub permission: Vec<ModelPermission>,
}

/// Capability flags attached to a model.
#[derive(Debug, Clone)]
pub struct ModelPermission {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub allow_create_engine: bool,
    pub allow_sampling: bool,
    pub allow_logprobs: bool,
    pub allow_search_indices: bool,
    pub allow_view: bool,
    pub allow_fine_tuning: bool,
    pub organization: String,
    pub group: Option<String>,
    pub is_blocking: bool,
}

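// An enumeration sketch (illustrative only): list the available model ids,
// then fetch metadata for each one sequentially.
#[allow(dead_code)]
async fn list_models_example(api: &dyn ChatApi) -> Result<Vec<ModelInfo>, AiLibError> {
    let mut infos = Vec::new();
    for model_id in api.list_models().await? {
        infos.push(api.get_model_info(&model_id).await?);
    }
    Ok(infos)
}
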
/// Aggregated outcome of a batch of chat completion requests.
#[derive(Debug)]
pub struct BatchResult {
    pub successful: Vec<ChatCompletionResponse>,
    pub failed: Vec<(usize, AiLibError)>,
    pub total_requests: usize,
    pub total_successful: usize,
    pub total_failed: usize,
}

impl BatchResult {
    /// Create an empty result for a batch of `total_requests` requests.
    pub fn new(total_requests: usize) -> Self {
        Self {
            successful: Vec::new(),
            failed: Vec::new(),
            total_requests,
            total_successful: 0,
            total_failed: 0,
        }
    }

    /// Record a successful response.
    pub fn add_success(&mut self, response: ChatCompletionResponse) {
        self.successful.push(response);
        self.total_successful += 1;
    }

    /// Record a failure together with the index of the originating request.
    pub fn add_failure(&mut self, index: usize, error: AiLibError) {
        self.failed.push((index, error));
        self.total_failed += 1;
    }

    /// Whether every request in the batch succeeded.
    pub fn all_successful(&self) -> bool {
        self.total_failed == 0
    }

    /// Success rate as a percentage (0.0 for an empty batch).
    pub fn success_rate(&self) -> f64 {
        if self.total_requests == 0 {
            0.0
        } else {
            (self.total_successful as f64 / self.total_requests as f64) * 100.0
        }
    }
}

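// A bookkeeping sketch (illustrative only): fold the ordered per-request
// results from `chat_completion_batch` into a `BatchResult` summary. The
// concurrency cap of 4 is an arbitrary example value.
#[allow(dead_code)]
async fn batch_summary_example(
    api: &dyn ChatApi,
    requests: Vec<ChatCompletionRequest>,
) -> Result<BatchResult, AiLibError> {
    let mut summary = BatchResult::new(requests.len());
    let outcomes = api.chat_completion_batch(requests, Some(4)).await?;
    for (index, outcome) in outcomes.into_iter().enumerate() {
        match outcome {
            Ok(response) => summary.add_success(response),
            Err(error) => summary.add_failure(index, error),
        }
    }
    Ok(summary)
}
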
/// Helpers for running many chat completion requests as a batch.
pub mod batch_utils {
    use super::*;
    use futures::stream::{self, StreamExt};
    use std::sync::Arc;
    use tokio::sync::Semaphore;

    /// Run all requests concurrently, optionally capped by `concurrency_limit`,
    /// and return the results in the original request order.
    pub async fn process_batch_concurrent<T: ChatApi + ?Sized>(
        api: &T,
        requests: Vec<ChatCompletionRequest>,
        concurrency_limit: Option<usize>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        if requests.is_empty() {
            return Ok(Vec::new());
        }

        // The semaphore caps in-flight requests alongside `buffer_unordered`;
        // it also guards implementors that spawn extra work per call.
        let semaphore = concurrency_limit.map(|limit| Arc::new(Semaphore::new(limit)));

        let futures = requests.into_iter().enumerate().map(|(index, request)| {
            let semaphore = semaphore.clone();

            async move {
                // Hold a permit for the duration of the call when a limit is set.
                let _permit = if let Some(sem) = &semaphore {
                    match sem.acquire().await {
                        Ok(permit) => Some(permit),
                        Err(_) => {
                            return (
                                index,
                                Err(AiLibError::ProviderError(
                                    "Failed to acquire semaphore permit".to_string(),
                                )),
                            )
                        }
                    }
                } else {
                    None
                };

                (index, api.chat_completion(request).await)
            }
        });

        let mut results: Vec<_> = stream::iter(futures)
            .buffer_unordered(concurrency_limit.unwrap_or(usize::MAX))
            .collect()
            .await;

        // `buffer_unordered` yields results in completion order; sort by the
        // captured index to restore the original request order.
        results.sort_by_key(|(index, _)| *index);

        Ok(results.into_iter().map(|(_, result)| result).collect())
    }

    /// Run the requests one at a time, in order.
    pub async fn process_batch_sequential<T: ChatApi + ?Sized>(
        api: &T,
        requests: Vec<ChatCompletionRequest>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        let mut results = Vec::with_capacity(requests.len());

        for request in requests {
            results.push(api.chat_completion(request).await);
        }

        Ok(results)
    }

    /// Choose a batch strategy by size: small batches run sequentially to
    /// avoid concurrency overhead, larger ones run concurrently.
    pub async fn process_batch_smart<T: ChatApi + ?Sized>(
        api: &T,
        requests: Vec<ChatCompletionRequest>,
        concurrency_limit: Option<usize>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        // Threshold of 3: below this, concurrency buys little.
        if requests.len() <= 3 {
            return process_batch_sequential(api, requests).await;
        }

        process_batch_concurrent(api, requests, concurrency_limit).await
    }
}
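
// An end-to-end sketch (illustrative only): let `process_batch_smart` choose
// between sequential and concurrent execution, with a hypothetical cap of 8.
#[allow(dead_code)]
async fn smart_batch_example(
    api: &dyn ChatApi,
    requests: Vec<ChatCompletionRequest>,
) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
    batch_utils::process_batch_smart(api, requests, Some(8)).await
}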