use crate::types::{AiLibError, ChatCompletionRequest, ChatCompletionResponse, Role};
use async_trait::async_trait;
use futures::stream::Stream;

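/// Unified chat-completion interface implemented by provider backends.
///
/// # Example
///
/// A minimal sketch of driving a provider through a trait object. The
/// `make_request` helper is hypothetical and crate paths are omitted:
///
/// ```ignore
/// async fn demo(provider: &dyn ChatProvider) -> Result<(), AiLibError> {
///     // One-shot completion (`make_request` is a hypothetical helper).
///     let response = provider.chat(make_request("Hello!")).await?;
///
///     // Model discovery.
///     let models = provider.list_models().await?;
///     println!("{} exposes {} models", provider.name(), models.len());
///     Ok(())
/// }
/// ```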
#[async_trait]
pub trait ChatProvider: Send + Sync {
    /// Human-readable provider name, used for logging and diagnostics.
    fn name(&self) -> &str;

    /// Send a single chat-completion request and await the full response.
    async fn chat(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError>;

    /// Send a chat-completion request and receive the response incrementally
    /// as a stream of [`ChatCompletionChunk`]s.
    async fn stream(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    >;

    /// List the identifiers of the models this provider exposes.
    async fn list_models(&self) -> Result<Vec<String>, AiLibError>;

    /// Fetch metadata for a single model.
    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError>;

    /// Process several requests, optionally bounding how many run at once.
    ///
    /// Results come back in the same order as the input; each entry carries
    /// its own success or failure, so one failed request does not abort the
    /// batch.
    async fn batch(
        &self,
        requests: Vec<ChatCompletionRequest>,
        concurrency_limit: Option<usize>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        batch_utils::process_batch_concurrent(self, requests, concurrency_limit).await
    }
}

/// `ChatApi` is an alias for [`ChatProvider`].
pub use self::ChatProvider as ChatApi;

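/// One incremental piece of a streamed chat completion.
///
/// # Example
///
/// A sketch of draining a stream of chunks; request construction and error
/// handling are illustrative, not part of this crate:
///
/// ```ignore
/// use futures::StreamExt;
///
/// let mut stream = provider.stream(request).await?;
/// while let Some(chunk) = stream.next().await {
///     let chunk = chunk?;
///     for choice in &chunk.choices {
///         if let Some(text) = &choice.delta.content {
///             print!("{text}");
///         }
///     }
/// }
/// ```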
#[derive(Debug, Clone)]
pub struct ChatCompletionChunk {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
}

/// A single choice's delta within a [`ChatCompletionChunk`].
#[derive(Debug, Clone)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: MessageDelta,
    pub finish_reason: Option<String>,
}

/// The incremental message content carried by a [`ChoiceDelta`].
#[derive(Debug, Clone)]
pub struct MessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
}

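/// Metadata describing a model, mirroring the shape of the OpenAI `/models`
/// object.
///
/// # Example
///
/// A sketch of model discovery; paths and error handling are illustrative:
///
/// ```ignore
/// let models = provider.list_models().await?;
/// for id in &models {
///     let info = provider.get_model_info(id).await?;
///     println!("{} (owned by {})", info.id, info.owned_by);
/// }
/// ```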
#[derive(Debug, Clone)]
pub struct ModelInfo {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub owned_by: String,
    pub permission: Vec<ModelPermission>,
}

/// Fine-grained capability flags attached to a [`ModelInfo`].
#[derive(Debug, Clone)]
pub struct ModelPermission {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub allow_create_engine: bool,
    pub allow_sampling: bool,
    pub allow_logprobs: bool,
    pub allow_search_indices: bool,
    pub allow_view: bool,
    pub allow_fine_tuning: bool,
    pub organization: String,
    pub group: Option<String>,
    pub is_blocking: bool,
}

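/// Aggregated outcome of a batch of chat requests.
///
/// # Example
///
/// A sketch of tallying outcomes by hand; only the failure arm is shown
/// because constructing a real response is provider-specific:
///
/// ```ignore
/// let mut result = BatchResult::new(2);
/// result.add_failure(0, AiLibError::ProviderError("timeout".to_string()));
/// assert!(!result.all_successful());
/// assert_eq!(result.success_rate(), 0.0);
/// ```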
#[derive(Debug)]
pub struct BatchResult {
    /// Responses that completed successfully.
    pub successful: Vec<ChatCompletionResponse>,
    /// Failures paired with the index of the originating request.
    pub failed: Vec<(usize, AiLibError)>,
    pub total_requests: usize,
    pub total_successful: usize,
    pub total_failed: usize,
}

impl BatchResult {
    /// Create an empty result for a batch of `total_requests` requests.
    pub fn new(total_requests: usize) -> Self {
        Self {
            successful: Vec::new(),
            failed: Vec::new(),
            total_requests,
            total_successful: 0,
            total_failed: 0,
        }
    }

    /// Record a successful response.
    pub fn add_success(&mut self, response: ChatCompletionResponse) {
        self.successful.push(response);
        self.total_successful += 1;
    }

    /// Record a failure for the request at `index`.
    pub fn add_failure(&mut self, index: usize, error: AiLibError) {
        self.failed.push((index, error));
        self.total_failed += 1;
    }

    /// `true` if no request in the batch failed.
    pub fn all_successful(&self) -> bool {
        self.total_failed == 0
    }

    /// Percentage of requests that succeeded, in `0.0..=100.0`.
    ///
    /// Returns `0.0` for an empty batch.
    pub fn success_rate(&self) -> f64 {
        if self.total_requests == 0 {
            0.0
        } else {
            (self.total_successful as f64 / self.total_requests as f64) * 100.0
        }
    }
}

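/// Helpers backing [`ChatProvider::batch`] and alternative batch strategies.
///
/// # Example
///
/// A sketch of bounding a batch to four concurrent requests; the
/// `make_request` helper is hypothetical:
///
/// ```ignore
/// let requests = vec![make_request("a"), make_request("b")];
/// let results = batch_utils::process_batch_smart(&provider, requests, Some(4)).await?;
/// for (i, result) in results.iter().enumerate() {
///     match result {
///         Ok(_) => println!("request {i} succeeded"),
///         Err(err) => eprintln!("request {i} failed: {err:?}"),
///     }
/// }
/// ```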
pub mod batch_utils {
    use super::*;
    use futures::stream::{self, StreamExt};
    use std::sync::Arc;
    use tokio::sync::Semaphore;

    /// Run `requests` concurrently, optionally capped by `concurrency_limit`,
    /// and return the per-request results in the original request order.
    pub async fn process_batch_concurrent<T: ChatProvider + ?Sized>(
        api: &T,
        requests: Vec<ChatCompletionRequest>,
        concurrency_limit: Option<usize>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        if requests.is_empty() {
            return Ok(Vec::new());
        }

        // The semaphore bounds in-flight requests. `buffer_unordered` below
        // enforces the same cap, so the semaphore acts as an extra safeguard.
        let semaphore = concurrency_limit.map(|limit| Arc::new(Semaphore::new(limit)));

        // Tag each request with its index so the unordered results can be
        // restored to input order afterwards.
        let futures = requests.into_iter().enumerate().map(|(index, request)| {
            let api_ref = api;
            let semaphore_ref = semaphore.clone();

            async move {
                let _permit = if let Some(sem) = &semaphore_ref {
                    match sem.acquire().await {
                        Ok(permit) => Some(permit),
                        Err(_) => {
                            return (
                                index,
                                Err(AiLibError::ProviderError(
                                    "Failed to acquire semaphore permit".to_string(),
                                )),
                            )
                        }
                    }
                } else {
                    None
                };

                let result = api_ref.chat(request).await;

                (index, result)
            }
        });

        let results: Vec<_> = stream::iter(futures)
            .buffer_unordered(concurrency_limit.unwrap_or(usize::MAX))
            .collect()
            .await;

        // `buffer_unordered` yields results as they finish; scatter them back
        // into the slot for their original index. Every placeholder is
        // overwritten because each index in `0..len` appears exactly once.
        let mut sorted_results = Vec::with_capacity(results.len());
        sorted_results.resize_with(results.len(), || {
            Err(AiLibError::ProviderError("Placeholder".to_string()))
        });
        for (index, result) in results {
            sorted_results[index] = result;
        }

        Ok(sorted_results)
    }

    /// Run `requests` one at a time, in order.
    pub async fn process_batch_sequential<T: ChatProvider + ?Sized>(
        api: &T,
        requests: Vec<ChatCompletionRequest>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        let mut results = Vec::with_capacity(requests.len());

        for request in requests {
            let result = api.chat(request).await;
            results.push(result);
        }

        Ok(results)
    }

    /// Pick a strategy based on batch size: very small batches (three or
    /// fewer requests) run sequentially to skip the concurrency overhead,
    /// larger ones go through [`process_batch_concurrent`].
    pub async fn process_batch_smart<T: ChatProvider + ?Sized>(
        api: &T,
        requests: Vec<ChatCompletionRequest>,
        concurrency_limit: Option<usize>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        let request_count = requests.len();

        if request_count <= 3 {
            return process_batch_sequential(api, requests).await;
        }

        process_batch_concurrent(api, requests, concurrency_limit).await
    }
}