1use crate::error::EngineError;
6use crate::file_store::FileStore;
7use crate::job::*;
8use crate::queue::JobQueue;
9use crate::webhook::{WebhookDelivery, WebhookQueue};
10use std::sync::Arc;
11
12pub struct BatchEngine<Q: JobQueue, W: WebhookQueue> {
22 pub queue: Arc<Q>,
23 pub file_store: FileStore,
24 pub webhook_queue: Arc<W>,
25 pub global_webhook_urls: Vec<String>,
27 pub webhook_signing_secret: Option<String>,
29}
30
31impl<Q: JobQueue, W: WebhookQueue> BatchEngine<Q, W> {
32 pub async fn submit(&self, submission: BatchSubmission) -> Result<BatchJob, EngineError> {
37 self.file_store
38 .get_meta(&submission.input_file_id)
39 .await
40 .map_err(|e| EngineError::Backend(e.to_string()))?
41 .ok_or_else(|| EngineError::FileNotFound(submission.input_file_id.clone()))?;
42
43 const DEFAULT_MAX_RETRIES: u8 = 3;
44
45 let epoch = crate::db::epoch_secs();
46 let now = crate::db::format_epoch_iso8601(epoch);
47 let batch_id = BatchId::new();
48 let total = submission.items.len() as u32;
49
50 let job = BatchJob {
51 id: batch_id.clone(),
52 status: BatchStatus::Queued,
53 execution_mode: submission.execution_mode.clone(),
54 priority: submission.priority,
55 key_id: submission.key_id,
56 input_file_id: submission.input_file_id,
57 webhook_url: submission.webhook_url.clone(),
58 metadata: submission.metadata,
59 request_counts: RequestCounts {
60 total,
61 ..Default::default()
62 },
63 created_at: now.clone(),
64 started_at: None,
65 completed_at: None,
66 expires_at: crate::db::epoch_plus_hours_iso8601(epoch, 24),
68 };
69
70 let items: Vec<BatchItem> = submission
71 .items
72 .into_iter()
73 .map(|si| BatchItem {
74 id: ItemId::new(),
75 batch_id: batch_id.clone(),
76 custom_id: si.custom_id,
77 status: ItemStatus::Pending,
78 request: BatchItemRequest {
79 model: si.model,
80 body: si.body,
81 source_format: si.source_format,
82 },
83 result: None,
84 attempts: 0,
85 max_retries: DEFAULT_MAX_RETRIES,
86 last_error: None,
87 next_retry_at: None,
88 lease_id: None,
89 lease_expires_at: None,
90 idempotency_key: None,
91 created_at: now.clone(),
92 completed_at: None,
93 })
94 .collect();
95
96 self.queue
97 .enqueue(&job, &items)
98 .await
99 .map_err(EngineError::Queue)?;
100
101 self.fire_webhook(
102 &batch_id,
103 "batch.queued",
104 serde_json::json!({
105 "batch_id": batch_id.0,
106 "total_items": total,
107 "execution_mode": job.execution_mode.as_str(),
108 }),
109 None,
110 )
111 .await;
112
113 Ok(job)
114 }
115
116 pub async fn get(&self, id: &BatchId) -> Result<Option<BatchJob>, EngineError> {
118 self.queue.get(id).await.map_err(EngineError::Queue)
119 }
120
121 pub async fn list(
123 &self,
124 key_id: Option<i64>,
125 cursor: Option<&str>,
126 limit: u32,
127 ) -> Result<Vec<BatchJob>, EngineError> {
128 self.queue
129 .list(key_id, cursor, limit)
130 .await
131 .map_err(EngineError::Queue)
132 }
133
134 pub async fn cancel(&self, id: &BatchId) -> Result<BatchJob, EngineError> {
136 let job = self.queue.cancel(id).await.map_err(EngineError::Queue)?;
137
138 if job.status == BatchStatus::Cancelled {
139 self.fire_webhook(
140 id,
141 "batch.cancelled",
142 serde_json::json!({ "batch_id": id.0 }),
143 job.webhook_url.as_deref(),
144 )
145 .await;
146 }
147
148 Ok(job)
149 }
150
151 pub async fn get_items(&self, id: &BatchId) -> Result<Vec<BatchItem>, EngineError> {
153 self.queue.get_items(id).await.map_err(EngineError::Queue)
154 }
155
156 async fn fire_webhook(
160 &self,
161 batch_id: &BatchId,
162 event_type: &str,
163 payload: serde_json::Value,
164 batch_webhook_url: Option<&str>,
165 ) {
166 const DEFAULT_MAX_RETRIES: u8 = 3;
167
168 let event_id = format!("evt_{}", uuid::Uuid::new_v4());
169
170 let mut urls: Vec<(String, Option<String>)> = self
171 .global_webhook_urls
172 .iter()
173 .map(|u| (u.clone(), self.webhook_signing_secret.clone()))
174 .collect();
175
176 if let Some(url) = batch_webhook_url {
177 urls.push((url.to_string(), self.webhook_signing_secret.clone()));
178 }
179
180 let full_payload = serde_json::json!({
181 "event_id": event_id,
182 "event_type": event_type,
183 "data": payload,
184 });
185
186 for (url, secret) in urls {
187 let delivery = WebhookDelivery {
188 delivery_id: format!("whd_{}", uuid::Uuid::new_v4()),
189 event_id: event_id.clone(),
190 batch_id: batch_id.0.clone(),
191 url,
192 payload: full_payload.clone(),
193 signing_secret: secret,
194 attempts: 0,
195 max_retries: DEFAULT_MAX_RETRIES,
196 next_retry_at: None,
197 };
198 if let Err(e) = self.webhook_queue.enqueue(delivery).await {
199 tracing::error!(error = %e, "failed to enqueue webhook delivery");
200 }
201 }
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::init_batch_engine_tables;
    use crate::file_store::FileStore;
    use crate::queue::sqlite::SqliteQueue;
    use crate::webhook::sqlite::SqliteWebhookQueue;
    use rusqlite::Connection;
    use std::sync::Arc;
    use std::sync::Mutex;

    /// Engine flavour used by every test in this module.
    type TestEngine = BatchEngine<SqliteQueue, SqliteWebhookQueue>;

    /// Builds an engine whose queue, file store, and webhook queue all share one
    /// in-memory SQLite connection. No global webhook URLs or signing secret are set.
    async fn test_engine() -> TestEngine {
        let conn = Connection::open_in_memory().unwrap();
        init_batch_engine_tables(&conn).unwrap();
        let shared = Arc::new(Mutex::new(conn));

        let queue = Arc::new(SqliteQueue::new(Arc::clone(&shared)));
        let file_store = FileStore::new(Arc::clone(&shared));
        let webhook_queue = Arc::new(SqliteWebhookQueue::new(shared));

        BatchEngine {
            queue,
            file_store,
            webhook_queue,
            global_webhook_urls: Vec::new(),
            webhook_signing_secret: None,
        }
    }

    /// A minimal OpenAI-format item targeting `gpt-4o` with an empty JSON body.
    fn openai_item(custom_id: &'static str) -> SubmissionItem {
        SubmissionItem {
            custom_id: custom_id.into(),
            model: "gpt-4o".into(),
            body: serde_json::json!({}),
            source_format: SourceFormat::OpenAI,
        }
    }

    /// A proxy-native submission with no key, webhook, or metadata and priority 0.
    fn proxy_native_submission(
        input_file_id: &'static str,
        items: Vec<SubmissionItem>,
    ) -> BatchSubmission {
        BatchSubmission {
            items,
            execution_mode: ExecutionMode::ProxyNative,
            input_file_id: input_file_id.into(),
            key_id: None,
            webhook_url: None,
            metadata: None,
            priority: 0,
        }
    }

    /// A submitted batch comes back queued with the right counts and can be fetched by id.
    #[tokio::test]
    async fn submit_and_get() {
        let engine = test_engine().await;

        // The input file must exist before a submission referencing it is accepted.
        engine
            .file_store
            .insert("file-sub1", None, None, b"test", 2)
            .await
            .unwrap();

        let submission = BatchSubmission {
            key_id: Some(42),
            ..proxy_native_submission(
                "file-sub1",
                vec![openai_item("req-1"), openai_item("req-2")],
            )
        };
        let job = engine.submit(submission).await.unwrap();

        assert_eq!(job.status, BatchStatus::Queued);
        assert_eq!(job.request_counts.total, 2);
        assert_eq!(job.key_id, Some(42));

        let fetched = engine.get(&job.id).await.unwrap().unwrap();
        assert_eq!(fetched.id, job.id);
    }

    /// Submitting against an input file that was never inserted is rejected.
    #[tokio::test]
    async fn submit_missing_file() {
        let engine = test_engine().await;

        let result = engine
            .submit(proxy_native_submission("file-nope", vec![]))
            .await;

        assert!(result.is_err());
    }

    /// Cancelling a freshly queued batch reports it as cancelled.
    #[tokio::test]
    async fn cancel_job() {
        let engine = test_engine().await;
        engine
            .file_store
            .insert("file-cancel", None, None, b"test", 1)
            .await
            .unwrap();

        let job = engine
            .submit(proxy_native_submission(
                "file-cancel",
                vec![openai_item("r1")],
            ))
            .await
            .unwrap();

        let cancelled = engine.cancel(&job.id).await.unwrap();
        assert_eq!(cancelled.status, BatchStatus::Cancelled);
    }
}