Skip to main content

anyllm_batch_engine/
engine.rs

1// crates/batch_engine/src/engine.rs
2//! BatchEngine: the main entry point for batch operations.
3//! Thin facade over JobQueue, FileStore, and WebhookQueue.
4
5use crate::error::EngineError;
6use crate::file_store::FileStore;
7use crate::job::*;
8use crate::queue::JobQueue;
9use crate::webhook::{WebhookDelivery, WebhookQueue};
10use std::sync::Arc;
11
12/// The main batch engine. Holds references to queue, file store, and webhook queue.
13///
14/// Each concern is separated into its own abstraction:
15/// - `queue` owns job/item lifecycle and worker lease management.
16/// - `file_store` owns JSONL file storage (referenced by job input_file_id).
17/// - `webhook_queue` owns durable outbound webhook delivery with retry.
18///
19/// The three abstractions share a single SQLite connection via `Arc<Mutex<Connection>>`
20/// so they can participate in the same ACID transaction without an external coordinator.
21pub struct BatchEngine<Q: JobQueue, W: WebhookQueue> {
22    pub queue: Arc<Q>,
23    pub file_store: FileStore,
24    pub webhook_queue: Arc<W>,
25    /// Global webhook URLs notified on every batch event. Supplemented by per-batch URLs.
26    pub global_webhook_urls: Vec<String>,
27    /// Optional HMAC signing secret for webhook payloads (X-Signature-256 header).
28    pub webhook_signing_secret: Option<String>,
29}
30
31impl<Q: JobQueue, W: WebhookQueue> BatchEngine<Q, W> {
32    /// Validate the submission, create a `BatchJob` and its `BatchItem` list, enqueue both,
33    /// and fire a `batch.queued` webhook. Returns the created job on success.
34    ///
35    /// Fails if the referenced `input_file_id` does not exist in the file store.
36    pub async fn submit(&self, submission: BatchSubmission) -> Result<BatchJob, EngineError> {
37        self.file_store
38            .get_meta(&submission.input_file_id)
39            .await
40            .map_err(|e| EngineError::Backend(e.to_string()))?
41            .ok_or_else(|| EngineError::FileNotFound(submission.input_file_id.clone()))?;
42
43        const DEFAULT_MAX_RETRIES: u8 = 3;
44
45        let epoch = crate::db::epoch_secs();
46        let now = crate::db::format_epoch_iso8601(epoch);
47        let batch_id = BatchId::new();
48        let total = submission.items.len() as u32;
49
50        let job = BatchJob {
51            id: batch_id.clone(),
52            status: BatchStatus::Queued,
53            execution_mode: submission.execution_mode.clone(),
54            priority: submission.priority,
55            key_id: submission.key_id,
56            input_file_id: submission.input_file_id,
57            webhook_url: submission.webhook_url.clone(),
58            metadata: submission.metadata,
59            request_counts: RequestCounts {
60                total,
61                ..Default::default()
62            },
63            created_at: now.clone(),
64            started_at: None,
65            completed_at: None,
66            // 24h TTL matches the Anthropic batch API contract (results expire after 24h).
67            expires_at: crate::db::epoch_plus_hours_iso8601(epoch, 24),
68        };
69
70        let items: Vec<BatchItem> = submission
71            .items
72            .into_iter()
73            .map(|si| BatchItem {
74                id: ItemId::new(),
75                batch_id: batch_id.clone(),
76                custom_id: si.custom_id,
77                status: ItemStatus::Pending,
78                request: BatchItemRequest {
79                    model: si.model,
80                    body: si.body,
81                    source_format: si.source_format,
82                },
83                result: None,
84                attempts: 0,
85                max_retries: DEFAULT_MAX_RETRIES,
86                last_error: None,
87                next_retry_at: None,
88                lease_id: None,
89                lease_expires_at: None,
90                idempotency_key: None,
91                created_at: now.clone(),
92                completed_at: None,
93            })
94            .collect();
95
96        self.queue
97            .enqueue(&job, &items)
98            .await
99            .map_err(EngineError::Queue)?;
100
101        self.fire_webhook(
102            &batch_id,
103            "batch.queued",
104            serde_json::json!({
105                "batch_id": batch_id.0,
106                "total_items": total,
107                "execution_mode": job.execution_mode.as_str(),
108            }),
109            None,
110        )
111        .await;
112
113        Ok(job)
114    }
115
116    /// Get a batch job by ID.
117    pub async fn get(&self, id: &BatchId) -> Result<Option<BatchJob>, EngineError> {
118        self.queue.get(id).await.map_err(EngineError::Queue)
119    }
120
121    /// List batch jobs.
122    pub async fn list(
123        &self,
124        key_id: Option<i64>,
125        cursor: Option<&str>,
126        limit: u32,
127    ) -> Result<Vec<BatchJob>, EngineError> {
128        self.queue
129            .list(key_id, cursor, limit)
130            .await
131            .map_err(EngineError::Queue)
132    }
133
134    /// Cancel a batch job.
135    pub async fn cancel(&self, id: &BatchId) -> Result<BatchJob, EngineError> {
136        let job = self.queue.cancel(id).await.map_err(EngineError::Queue)?;
137
138        if job.status == BatchStatus::Cancelled {
139            self.fire_webhook(
140                id,
141                "batch.cancelled",
142                serde_json::json!({ "batch_id": id.0 }),
143                job.webhook_url.as_deref(),
144            )
145            .await;
146        }
147
148        Ok(job)
149    }
150
151    /// Get items for a batch (used for result retrieval).
152    pub async fn get_items(&self, id: &BatchId) -> Result<Vec<BatchItem>, EngineError> {
153        self.queue.get_items(id).await.map_err(EngineError::Queue)
154    }
155
156    /// Fire a webhook to all configured URLs.
157    /// `batch_webhook_url`: per-batch URL for terminal events; callers pass it from the job
158    /// they already hold to avoid an extra database round-trip.
159    async fn fire_webhook(
160        &self,
161        batch_id: &BatchId,
162        event_type: &str,
163        payload: serde_json::Value,
164        batch_webhook_url: Option<&str>,
165    ) {
166        const DEFAULT_MAX_RETRIES: u8 = 3;
167
168        let event_id = format!("evt_{}", uuid::Uuid::new_v4());
169
170        let mut urls: Vec<(String, Option<String>)> = self
171            .global_webhook_urls
172            .iter()
173            .map(|u| (u.clone(), self.webhook_signing_secret.clone()))
174            .collect();
175
176        if let Some(url) = batch_webhook_url {
177            urls.push((url.to_string(), self.webhook_signing_secret.clone()));
178        }
179
180        let full_payload = serde_json::json!({
181            "event_id": event_id,
182            "event_type": event_type,
183            "data": payload,
184        });
185
186        for (url, secret) in urls {
187            let delivery = WebhookDelivery {
188                delivery_id: format!("whd_{}", uuid::Uuid::new_v4()),
189                event_id: event_id.clone(),
190                batch_id: batch_id.0.clone(),
191                url,
192                payload: full_payload.clone(),
193                signing_secret: secret,
194                attempts: 0,
195                max_retries: DEFAULT_MAX_RETRIES,
196                next_retry_at: None,
197            };
198            if let Err(e) = self.webhook_queue.enqueue(delivery).await {
199                tracing::error!(error = %e, "failed to enqueue webhook delivery");
200            }
201        }
202    }
203}
204
205#[cfg(test)]
206mod tests {
207    use super::*;
208    use crate::db::init_batch_engine_tables;
209    use crate::file_store::FileStore;
210    use crate::queue::sqlite::SqliteQueue;
211    use crate::webhook::sqlite::SqliteWebhookQueue;
212    use rusqlite::Connection;
213    use std::sync::Arc;
214    use std::sync::Mutex;
215
216    async fn test_engine() -> BatchEngine<SqliteQueue, SqliteWebhookQueue> {
217        let conn = Connection::open_in_memory().unwrap();
218        init_batch_engine_tables(&conn).unwrap();
219        let db = Arc::new(Mutex::new(conn));
220
221        BatchEngine {
222            queue: Arc::new(SqliteQueue::new(db.clone())),
223            file_store: FileStore::new(db.clone()),
224            webhook_queue: Arc::new(SqliteWebhookQueue::new(db)),
225            global_webhook_urls: vec![],
226            webhook_signing_secret: None,
227        }
228    }
229
230    #[tokio::test]
231    async fn submit_and_get() {
232        let engine = test_engine().await;
233
234        // Upload a file first.
235        engine
236            .file_store
237            .insert("file-sub1", None, None, b"test", 2)
238            .await
239            .unwrap();
240
241        let job = engine
242            .submit(BatchSubmission {
243                items: vec![
244                    SubmissionItem {
245                        custom_id: "req-1".into(),
246                        model: "gpt-4o".into(),
247                        body: serde_json::json!({}),
248                        source_format: SourceFormat::OpenAI,
249                    },
250                    SubmissionItem {
251                        custom_id: "req-2".into(),
252                        model: "gpt-4o".into(),
253                        body: serde_json::json!({}),
254                        source_format: SourceFormat::OpenAI,
255                    },
256                ],
257                execution_mode: ExecutionMode::ProxyNative,
258                input_file_id: "file-sub1".into(),
259                key_id: Some(42),
260                webhook_url: None,
261                metadata: None,
262                priority: 0,
263            })
264            .await
265            .unwrap();
266
267        assert_eq!(job.status, BatchStatus::Queued);
268        assert_eq!(job.request_counts.total, 2);
269        assert_eq!(job.key_id, Some(42));
270
271        let fetched = engine.get(&job.id).await.unwrap().unwrap();
272        assert_eq!(fetched.id, job.id);
273    }
274
275    #[tokio::test]
276    async fn submit_missing_file() {
277        let engine = test_engine().await;
278        let result = engine
279            .submit(BatchSubmission {
280                items: vec![],
281                execution_mode: ExecutionMode::ProxyNative,
282                input_file_id: "file-nope".into(),
283                key_id: None,
284                webhook_url: None,
285                metadata: None,
286                priority: 0,
287            })
288            .await;
289        assert!(result.is_err());
290    }
291
292    #[tokio::test]
293    async fn cancel_job() {
294        let engine = test_engine().await;
295        engine
296            .file_store
297            .insert("file-cancel", None, None, b"test", 1)
298            .await
299            .unwrap();
300
301        let job = engine
302            .submit(BatchSubmission {
303                items: vec![SubmissionItem {
304                    custom_id: "r1".into(),
305                    model: "gpt-4o".into(),
306                    body: serde_json::json!({}),
307                    source_format: SourceFormat::OpenAI,
308                }],
309                execution_mode: ExecutionMode::ProxyNative,
310                input_file_id: "file-cancel".into(),
311                key_id: None,
312                webhook_url: None,
313                metadata: None,
314                priority: 0,
315            })
316            .await
317            .unwrap();
318
319        let cancelled = engine.cancel(&job.id).await.unwrap();
320        assert_eq!(cancelled.status, BatchStatus::Cancelled);
321    }
322}