1pub mod event_persistence;
31#[cfg(feature = "distributed-locks")]
32pub mod lock;
33pub mod result_store;
34
35pub use event_persistence::{DbEventPersister, DbEventPersisterConfig};
36
37use async_trait::async_trait;
38pub use celers_backend_redis::{
39 BackendError, ChordState, Result, ResultBackend, TaskMeta, TaskResult, TaskTtlConfig,
40};
41use chrono::{DateTime, Utc};
42use serde_json::json;
43use sqlx::{postgres::PgPoolOptions, MySqlPool, PgPool, Row};
44use std::time::Duration;
45use uuid::Uuid;
46
47#[derive(Clone)]
49pub struct PostgresResultBackend {
50 pool: PgPool,
51 ttl_config: TaskTtlConfig,
52}
53
54impl PostgresResultBackend {
55 pub async fn new(database_url: &str) -> Result<Self> {
60 let pool = PgPoolOptions::new()
61 .max_connections(20)
62 .acquire_timeout(Duration::from_secs(5))
63 .connect(database_url)
64 .await
65 .map_err(|e| {
66 BackendError::Connection(format!("Failed to connect to database: {}", e))
67 })?;
68
69 Ok(Self {
70 pool,
71 ttl_config: TaskTtlConfig::new(),
72 })
73 }
74
75 pub fn with_ttl_config(mut self, config: TaskTtlConfig) -> Self {
77 self.ttl_config = config;
78 self
79 }
80
81 pub fn ttl_config(&self) -> &TaskTtlConfig {
83 &self.ttl_config
84 }
85
86 pub fn ttl_config_mut(&mut self) -> &mut TaskTtlConfig {
88 &mut self.ttl_config
89 }
90
91 pub async fn migrate(&self) -> Result<()> {
93 let migration_sql = include_str!("../migrations/001_init_postgres.sql");
94
95 sqlx::query(migration_sql)
96 .execute(&self.pool)
97 .await
98 .map_err(|e| BackendError::Connection(format!("Migration failed: {}", e)))?;
99
100 Ok(())
101 }
102
103 pub fn pool(&self) -> &PgPool {
105 &self.pool
106 }
107
108 pub async fn cleanup_expired(&self) -> Result<usize> {
110 let row = sqlx::query("SELECT cleanup_expired_results()")
111 .fetch_one(&self.pool)
112 .await
113 .map_err(|e| {
114 BackendError::Connection(format!("Failed to cleanup expired results: {}", e))
115 })?;
116
117 let count: i32 = row.get(0);
118 Ok(count as usize)
119 }
120}
121
122#[async_trait]
123impl ResultBackend for PostgresResultBackend {
124 async fn store_result(&mut self, task_id: Uuid, meta: &TaskMeta) -> Result<()> {
125 let (result_state, result_data, error_message, retry_count) = match &meta.result {
126 TaskResult::Pending => ("pending", None, None, None),
127 TaskResult::Started => ("started", None, None, None),
128 TaskResult::Success(data) => ("success", Some(data.clone()), None, None),
129 TaskResult::Failure(err) => ("failure", None, Some(err.clone()), None),
130 TaskResult::Revoked => ("revoked", None, None, None),
131 TaskResult::Retry(count) => ("retry", None, None, Some(*count as i32)),
132 };
133
134 sqlx::query(
135 r#"
136 INSERT INTO celers_task_results
137 (task_id, task_name, result_state, result_data, error_message, retry_count,
138 created_at, started_at, completed_at, worker)
139 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
140 ON CONFLICT (task_id) DO UPDATE SET
141 result_state = EXCLUDED.result_state,
142 result_data = EXCLUDED.result_data,
143 error_message = EXCLUDED.error_message,
144 retry_count = EXCLUDED.retry_count,
145 started_at = EXCLUDED.started_at,
146 completed_at = EXCLUDED.completed_at,
147 worker = EXCLUDED.worker
148 "#,
149 )
150 .bind(task_id)
151 .bind(&meta.task_name)
152 .bind(result_state)
153 .bind(result_data)
154 .bind(error_message)
155 .bind(retry_count)
156 .bind(meta.created_at)
157 .bind(meta.started_at)
158 .bind(meta.completed_at)
159 .bind(&meta.worker)
160 .execute(&self.pool)
161 .await
162 .map_err(|e| BackendError::Connection(format!("Failed to store result: {}", e)))?;
163
164 if let Some(ttl) = self.ttl_config.get_ttl(&meta.task_name) {
166 self.set_expiration(task_id, ttl).await?;
167 }
168
169 Ok(())
170 }
171
172 async fn get_result(&mut self, task_id: Uuid) -> Result<Option<TaskMeta>> {
173 let row = sqlx::query(
174 r#"
175 SELECT task_id, task_name, result_state, result_data, error_message,
176 retry_count, created_at, started_at, completed_at, worker
177 FROM celers_task_results
178 WHERE task_id = $1
179 "#,
180 )
181 .bind(task_id)
182 .fetch_optional(&self.pool)
183 .await
184 .map_err(|e| BackendError::Connection(format!("Failed to get result: {}", e)))?;
185
186 match row {
187 Some(row) => {
188 let result_state: String = row.get("result_state");
189 let result_data: Option<serde_json::Value> = row.get("result_data");
190 let error_message: Option<String> = row.get("error_message");
191 let retry_count: Option<i32> = row.get("retry_count");
192
193 let result = match result_state.as_str() {
194 "pending" => TaskResult::Pending,
195 "started" => TaskResult::Started,
196 "success" => TaskResult::Success(result_data.unwrap_or(json!(null))),
197 "failure" => TaskResult::Failure(error_message.unwrap_or_default()),
198 "revoked" => TaskResult::Revoked,
199 "retry" => TaskResult::Retry(retry_count.unwrap_or(0) as u32),
200 _ => TaskResult::Pending,
201 };
202
203 let meta = TaskMeta {
204 task_id: row.get("task_id"),
205 task_name: row.get("task_name"),
206 result,
207 created_at: row.get("created_at"),
208 started_at: row.get("started_at"),
209 completed_at: row.get("completed_at"),
210 worker: row.get("worker"),
211 progress: None,
212 version: 0,
213 tags: Vec::new(),
214 metadata: std::collections::HashMap::new(),
215 worker_hostname: None,
216 runtime_ms: None,
217 memory_bytes: None,
218 retries: None,
219 queue: None,
220 };
221
222 Ok(Some(meta))
223 }
224 None => Ok(None),
225 }
226 }
227
228 async fn delete_result(&mut self, task_id: Uuid) -> Result<()> {
229 sqlx::query("DELETE FROM celers_task_results WHERE task_id = $1")
230 .bind(task_id)
231 .execute(&self.pool)
232 .await
233 .map_err(|e| BackendError::Connection(format!("Failed to delete result: {}", e)))?;
234
235 Ok(())
236 }
237
238 async fn set_expiration(&mut self, task_id: Uuid, ttl: Duration) -> Result<()> {
239 let expires_at = Utc::now()
240 + chrono::Duration::from_std(ttl)
241 .map_err(|e| BackendError::Serialization(format!("Invalid TTL duration: {}", e)))?;
242
243 sqlx::query("UPDATE celers_task_results SET expires_at = $1 WHERE task_id = $2")
244 .bind(expires_at)
245 .bind(task_id)
246 .execute(&self.pool)
247 .await
248 .map_err(|e| BackendError::Connection(format!("Failed to set expiration: {}", e)))?;
249
250 Ok(())
251 }
252
253 async fn chord_init(&mut self, state: ChordState) -> Result<()> {
254 let task_ids = serde_json::to_value(&state.task_ids)
255 .map_err(|e| BackendError::Serialization(e.to_string()))?;
256
257 sqlx::query(
258 r#"
259 INSERT INTO celers_chord_state (chord_id, total, completed, callback, task_ids, created_at, timeout_seconds, cancelled, cancellation_reason)
260 VALUES ($1, $2, 0, $3, $4, $5, $6, $7, $8)
261 ON CONFLICT (chord_id) DO UPDATE SET
262 total = EXCLUDED.total,
263 callback = EXCLUDED.callback,
264 task_ids = EXCLUDED.task_ids,
265 timeout_seconds = EXCLUDED.timeout_seconds,
266 cancelled = EXCLUDED.cancelled,
267 cancellation_reason = EXCLUDED.cancellation_reason
268 "#,
269 )
270 .bind(state.chord_id)
271 .bind(state.total as i32)
272 .bind(&state.callback)
273 .bind(task_ids)
274 .bind(state.created_at)
275 .bind(state.timeout.map(|d| d.as_secs() as i64))
276 .bind(state.cancelled)
277 .bind(&state.cancellation_reason)
278 .execute(&self.pool)
279 .await
280 .map_err(|e| BackendError::Connection(format!("Failed to init chord: {}", e)))?;
281
282 Ok(())
283 }
284
285 async fn chord_complete_task(&mut self, chord_id: Uuid) -> Result<usize> {
286 let row = sqlx::query("SELECT chord_increment_counter($1)")
287 .bind(chord_id)
288 .fetch_one(&self.pool)
289 .await
290 .map_err(|e| {
291 BackendError::Connection(format!("Failed to increment chord counter: {}", e))
292 })?;
293
294 let count: i32 = row.get(0);
295 Ok(count as usize)
296 }
297
298 async fn chord_get_state(&mut self, chord_id: Uuid) -> Result<Option<ChordState>> {
299 let row = sqlx::query(
300 r#"
301 SELECT chord_id, total, completed, callback, task_ids, created_at, timeout_seconds, cancelled, cancellation_reason
302 FROM celers_chord_state
303 WHERE chord_id = $1
304 "#,
305 )
306 .bind(chord_id)
307 .fetch_optional(&self.pool)
308 .await
309 .map_err(|e| BackendError::Connection(format!("Failed to get chord state: {}", e)))?;
310
311 match row {
312 Some(row) => {
313 let task_ids_json: serde_json::Value = row.get("task_ids");
314 let task_ids: Vec<Uuid> = serde_json::from_value(task_ids_json)
315 .map_err(|e| BackendError::Serialization(e.to_string()))?;
316
317 let state = ChordState {
318 chord_id: row.get("chord_id"),
319 total: row.get::<i32, _>("total") as usize,
320 completed: row.get::<i32, _>("completed") as usize,
321 callback: row.get("callback"),
322 task_ids,
323 created_at: row.get("created_at"),
324 timeout: row
325 .get::<Option<i64>, _>("timeout_seconds")
326 .map(|s| std::time::Duration::from_secs(s as u64)),
327 cancelled: row.get("cancelled"),
328 cancellation_reason: row.get("cancellation_reason"),
329 retry_count: 0,
330 max_retries: None,
331 };
332
333 Ok(Some(state))
334 }
335 None => Ok(None),
336 }
337 }
338
339 async fn store_results_batch(&mut self, results: &[(Uuid, TaskMeta)]) -> Result<()> {
342 if results.is_empty() {
343 return Ok(());
344 }
345
346 let mut tx =
347 self.pool.begin().await.map_err(|e| {
348 BackendError::Connection(format!("Failed to begin transaction: {}", e))
349 })?;
350
351 for (task_id, meta) in results {
352 let (result_state, result_data, error_message, retry_count) = match &meta.result {
353 TaskResult::Pending => ("pending", None, None, None),
354 TaskResult::Started => ("started", None, None, None),
355 TaskResult::Success(data) => ("success", Some(data.clone()), None, None),
356 TaskResult::Failure(err) => ("failure", None, Some(err.clone()), None),
357 TaskResult::Revoked => ("revoked", None, None, None),
358 TaskResult::Retry(count) => ("retry", None, None, Some(*count as i32)),
359 };
360
361 sqlx::query(
362 r#"
363 INSERT INTO celers_task_results
364 (task_id, task_name, result_state, result_data, error_message, retry_count,
365 created_at, started_at, completed_at, worker)
366 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
367 ON CONFLICT (task_id) DO UPDATE SET
368 result_state = EXCLUDED.result_state,
369 result_data = EXCLUDED.result_data,
370 error_message = EXCLUDED.error_message,
371 retry_count = EXCLUDED.retry_count,
372 started_at = EXCLUDED.started_at,
373 completed_at = EXCLUDED.completed_at,
374 worker = EXCLUDED.worker
375 "#,
376 )
377 .bind(task_id)
378 .bind(&meta.task_name)
379 .bind(result_state)
380 .bind(result_data)
381 .bind(error_message)
382 .bind(retry_count)
383 .bind(meta.created_at)
384 .bind(meta.started_at)
385 .bind(meta.completed_at)
386 .bind(&meta.worker)
387 .execute(&mut *tx)
388 .await
389 .map_err(|e| BackendError::Connection(format!("Failed to store result: {}", e)))?;
390 }
391
392 tx.commit().await.map_err(|e| {
393 BackendError::Connection(format!("Failed to commit transaction: {}", e))
394 })?;
395
396 Ok(())
397 }
398
399 async fn get_results_batch(&mut self, task_ids: &[Uuid]) -> Result<Vec<Option<TaskMeta>>> {
400 if task_ids.is_empty() {
401 return Ok(Vec::new());
402 }
403
404 let rows = sqlx::query(
406 r#"
407 SELECT task_id, task_name, result_state, result_data, error_message,
408 retry_count, created_at, started_at, completed_at, worker
409 FROM celers_task_results
410 WHERE task_id = ANY($1)
411 "#,
412 )
413 .bind(task_ids)
414 .fetch_all(&self.pool)
415 .await
416 .map_err(|e| BackendError::Connection(format!("Failed to get results: {}", e)))?;
417
418 let mut results_map = std::collections::HashMap::new();
420 for row in rows {
421 let task_id: Uuid = row.get("task_id");
422 let result_state: String = row.get("result_state");
423 let result_data: Option<serde_json::Value> = row.get("result_data");
424 let error_message: Option<String> = row.get("error_message");
425 let retry_count: Option<i32> = row.get("retry_count");
426
427 let result = match result_state.as_str() {
428 "pending" => TaskResult::Pending,
429 "started" => TaskResult::Started,
430 "success" => TaskResult::Success(result_data.unwrap_or(json!(null))),
431 "failure" => TaskResult::Failure(error_message.unwrap_or_default()),
432 "revoked" => TaskResult::Revoked,
433 "retry" => TaskResult::Retry(retry_count.unwrap_or(0) as u32),
434 _ => TaskResult::Pending,
435 };
436
437 let meta = TaskMeta {
438 task_id: row.get("task_id"),
439 task_name: row.get("task_name"),
440 result,
441 created_at: row.get("created_at"),
442 started_at: row.get("started_at"),
443 completed_at: row.get("completed_at"),
444 worker: row.get("worker"),
445 progress: None,
446 version: 0,
447 tags: Vec::new(),
448 metadata: std::collections::HashMap::new(),
449 worker_hostname: None,
450 runtime_ms: None,
451 memory_bytes: None,
452 retries: None,
453 queue: None,
454 };
455
456 results_map.insert(task_id, meta);
457 }
458
459 Ok(task_ids
461 .iter()
462 .map(|id| results_map.get(id).cloned())
463 .collect())
464 }
465
466 async fn delete_results_batch(&mut self, task_ids: &[Uuid]) -> Result<()> {
467 if task_ids.is_empty() {
468 return Ok(());
469 }
470
471 sqlx::query("DELETE FROM celers_task_results WHERE task_id = ANY($1)")
472 .bind(task_ids)
473 .execute(&self.pool)
474 .await
475 .map_err(|e| BackendError::Connection(format!("Failed to delete results: {}", e)))?;
476
477 Ok(())
478 }
479}
480
481#[derive(Clone)]
483pub struct MysqlResultBackend {
484 pool: MySqlPool,
485 ttl_config: TaskTtlConfig,
486}
487
488impl MysqlResultBackend {
489 pub async fn new(database_url: &str) -> Result<Self> {
494 let pool = sqlx::mysql::MySqlPoolOptions::new()
495 .max_connections(20)
496 .acquire_timeout(Duration::from_secs(5))
497 .connect(database_url)
498 .await
499 .map_err(|e| {
500 BackendError::Connection(format!("Failed to connect to database: {}", e))
501 })?;
502
503 Ok(Self {
504 pool,
505 ttl_config: TaskTtlConfig::new(),
506 })
507 }
508
509 pub fn with_ttl_config(mut self, config: TaskTtlConfig) -> Self {
511 self.ttl_config = config;
512 self
513 }
514
515 pub fn ttl_config(&self) -> &TaskTtlConfig {
517 &self.ttl_config
518 }
519
520 pub fn ttl_config_mut(&mut self) -> &mut TaskTtlConfig {
522 &mut self.ttl_config
523 }
524
525 pub async fn migrate(&self) -> Result<()> {
527 let migration_sql = include_str!("../migrations/001_init_mysql.sql");
528
529 let statements: Vec<&str> = migration_sql.split("DELIMITER //").collect();
531
532 if let Some(main_sql) = statements.first() {
534 for statement in main_sql.split(';') {
535 let trimmed = statement.trim();
536 if !trimmed.is_empty() && !trimmed.starts_with("--") {
537 sqlx::query(trimmed)
538 .execute(&self.pool)
539 .await
540 .map_err(|e| {
541 BackendError::Connection(format!("Migration failed: {}", e))
542 })?;
543 }
544 }
545 }
546
547 for &proc_section in statements.iter().skip(1) {
549 if let Some(proc_sql) = proc_section.split("DELIMITER ;").next() {
550 let trimmed = proc_sql.trim();
551 if !trimmed.is_empty() {
552 sqlx::query(trimmed)
553 .execute(&self.pool)
554 .await
555 .map_err(|e| {
556 BackendError::Connection(format!(
557 "Stored procedure creation failed: {}",
558 e
559 ))
560 })?;
561 }
562 }
563 }
564
565 Ok(())
566 }
567
568 pub fn pool(&self) -> &MySqlPool {
570 &self.pool
571 }
572}
573
574#[async_trait]
575impl ResultBackend for MysqlResultBackend {
576 async fn store_result(&mut self, task_id: Uuid, meta: &TaskMeta) -> Result<()> {
577 let (result_state, result_data, error_message, retry_count) = match &meta.result {
578 TaskResult::Pending => ("pending", None, None, None),
579 TaskResult::Started => ("started", None, None, None),
580 TaskResult::Success(data) => ("success", Some(data.clone()), None, None),
581 TaskResult::Failure(err) => ("failure", None, Some(err.clone()), None),
582 TaskResult::Revoked => ("revoked", None, None, None),
583 TaskResult::Retry(count) => ("retry", None, None, Some(*count as i32)),
584 };
585
586 let result_data_str =
587 result_data.map(|v| serde_json::to_string(&v).unwrap_or_else(|_| "null".to_string()));
588
589 sqlx::query(
590 r#"
591 INSERT INTO celers_task_results
592 (task_id, task_name, result_state, result_data, error_message, retry_count,
593 created_at, started_at, completed_at, worker)
594 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
595 ON DUPLICATE KEY UPDATE
596 result_state = VALUES(result_state),
597 result_data = VALUES(result_data),
598 error_message = VALUES(error_message),
599 retry_count = VALUES(retry_count),
600 started_at = VALUES(started_at),
601 completed_at = VALUES(completed_at),
602 worker = VALUES(worker)
603 "#,
604 )
605 .bind(task_id.to_string())
606 .bind(&meta.task_name)
607 .bind(result_state)
608 .bind(result_data_str)
609 .bind(error_message)
610 .bind(retry_count)
611 .bind(meta.created_at)
612 .bind(meta.started_at)
613 .bind(meta.completed_at)
614 .bind(&meta.worker)
615 .execute(&self.pool)
616 .await
617 .map_err(|e| BackendError::Connection(format!("Failed to store result: {}", e)))?;
618
619 if let Some(ttl) = self.ttl_config.get_ttl(&meta.task_name) {
621 self.set_expiration(task_id, ttl).await?;
622 }
623
624 Ok(())
625 }
626
627 async fn get_result(&mut self, task_id: Uuid) -> Result<Option<TaskMeta>> {
628 let row = sqlx::query(
629 r#"
630 SELECT task_id, task_name, result_state, result_data, error_message,
631 retry_count, created_at, started_at, completed_at, worker
632 FROM celers_task_results
633 WHERE task_id = ?
634 "#,
635 )
636 .bind(task_id.to_string())
637 .fetch_optional(&self.pool)
638 .await
639 .map_err(|e| BackendError::Connection(format!("Failed to get result: {}", e)))?;
640
641 match row {
642 Some(row) => {
643 let task_id_str: String = row.get("task_id");
644 let result_state: String = row.get("result_state");
645 let result_data_str: Option<String> = row.get("result_data");
646 let error_message: Option<String> = row.get("error_message");
647 let retry_count: Option<i32> = row.get("retry_count");
648
649 let result_data = result_data_str.and_then(|s| serde_json::from_str(&s).ok());
650
651 let result = match result_state.as_str() {
652 "pending" => TaskResult::Pending,
653 "started" => TaskResult::Started,
654 "success" => TaskResult::Success(result_data.unwrap_or(json!(null))),
655 "failure" => TaskResult::Failure(error_message.unwrap_or_default()),
656 "revoked" => TaskResult::Revoked,
657 "retry" => TaskResult::Retry(retry_count.unwrap_or(0) as u32),
658 _ => TaskResult::Pending,
659 };
660
661 let meta = TaskMeta {
662 task_id: Uuid::parse_str(&task_id_str)
663 .map_err(|e| BackendError::Serialization(e.to_string()))?,
664 task_name: row.get("task_name"),
665 result,
666 created_at: row.get::<DateTime<Utc>, _>("created_at"),
667 started_at: row.get("started_at"),
668 completed_at: row.get("completed_at"),
669 worker: row.get("worker"),
670 progress: None,
671 version: 0,
672 tags: Vec::new(),
673 metadata: std::collections::HashMap::new(),
674 worker_hostname: None,
675 runtime_ms: None,
676 memory_bytes: None,
677 retries: None,
678 queue: None,
679 };
680
681 Ok(Some(meta))
682 }
683 None => Ok(None),
684 }
685 }
686
687 async fn delete_result(&mut self, task_id: Uuid) -> Result<()> {
688 sqlx::query("DELETE FROM celers_task_results WHERE task_id = ?")
689 .bind(task_id.to_string())
690 .execute(&self.pool)
691 .await
692 .map_err(|e| BackendError::Connection(format!("Failed to delete result: {}", e)))?;
693
694 Ok(())
695 }
696
697 async fn set_expiration(&mut self, task_id: Uuid, ttl: Duration) -> Result<()> {
698 let expires_at = Utc::now()
699 + chrono::Duration::from_std(ttl)
700 .map_err(|e| BackendError::Serialization(format!("Invalid TTL duration: {}", e)))?;
701
702 sqlx::query("UPDATE celers_task_results SET expires_at = ? WHERE task_id = ?")
703 .bind(expires_at)
704 .bind(task_id.to_string())
705 .execute(&self.pool)
706 .await
707 .map_err(|e| BackendError::Connection(format!("Failed to set expiration: {}", e)))?;
708
709 Ok(())
710 }
711
712 async fn chord_init(&mut self, state: ChordState) -> Result<()> {
713 let task_ids = serde_json::to_string(&state.task_ids)
714 .map_err(|e| BackendError::Serialization(e.to_string()))?;
715
716 sqlx::query(
717 r#"
718 INSERT INTO celers_chord_state (chord_id, total, completed, callback, task_ids, created_at, timeout_seconds, cancelled, cancellation_reason)
719 VALUES (?, ?, 0, ?, ?, ?, ?, ?, ?)
720 ON DUPLICATE KEY UPDATE
721 total = VALUES(total),
722 callback = VALUES(callback),
723 task_ids = VALUES(task_ids),
724 timeout_seconds = VALUES(timeout_seconds),
725 cancelled = VALUES(cancelled),
726 cancellation_reason = VALUES(cancellation_reason)
727 "#,
728 )
729 .bind(state.chord_id.to_string())
730 .bind(state.total as i32)
731 .bind(&state.callback)
732 .bind(task_ids)
733 .bind(state.created_at)
734 .bind(state.timeout.map(|d| d.as_secs() as i64))
735 .bind(state.cancelled)
736 .bind(&state.cancellation_reason)
737 .execute(&self.pool)
738 .await
739 .map_err(|e| BackendError::Connection(format!("Failed to init chord: {}", e)))?;
740
741 Ok(())
742 }
743
744 async fn chord_complete_task(&mut self, chord_id: Uuid) -> Result<usize> {
745 sqlx::query("UPDATE celers_chord_state SET completed = completed + 1 WHERE chord_id = ?")
748 .bind(chord_id.to_string())
749 .execute(&self.pool)
750 .await
751 .map_err(|e| {
752 BackendError::Connection(format!("Failed to increment chord counter: {}", e))
753 })?;
754
755 let row = sqlx::query("SELECT completed FROM celers_chord_state WHERE chord_id = ?")
756 .bind(chord_id.to_string())
757 .fetch_one(&self.pool)
758 .await
759 .map_err(|e| BackendError::Connection(format!("Failed to get chord counter: {}", e)))?;
760
761 let count: i32 = row.get("completed");
762 Ok(count as usize)
763 }
764
765 async fn chord_get_state(&mut self, chord_id: Uuid) -> Result<Option<ChordState>> {
766 let row = sqlx::query(
767 r#"
768 SELECT chord_id, total, completed, callback, task_ids, created_at, timeout_seconds, cancelled, cancellation_reason
769 FROM celers_chord_state
770 WHERE chord_id = ?
771 "#,
772 )
773 .bind(chord_id.to_string())
774 .fetch_optional(&self.pool)
775 .await
776 .map_err(|e| BackendError::Connection(format!("Failed to get chord state: {}", e)))?;
777
778 match row {
779 Some(row) => {
780 let chord_id_str: String = row.get("chord_id");
781 let task_ids_str: String = row.get("task_ids");
782 let task_ids: Vec<Uuid> = serde_json::from_str(&task_ids_str)
783 .map_err(|e| BackendError::Serialization(e.to_string()))?;
784
785 let state = ChordState {
786 chord_id: Uuid::parse_str(&chord_id_str)
787 .map_err(|e| BackendError::Serialization(e.to_string()))?,
788 total: row.get::<i32, _>("total") as usize,
789 completed: row.get::<i32, _>("completed") as usize,
790 callback: row.get("callback"),
791 task_ids,
792 created_at: row.get("created_at"),
793 timeout: row
794 .get::<Option<i64>, _>("timeout_seconds")
795 .map(|s| std::time::Duration::from_secs(s as u64)),
796 cancelled: row.get("cancelled"),
797 cancellation_reason: row.get("cancellation_reason"),
798 retry_count: 0,
799 max_retries: None,
800 };
801
802 Ok(Some(state))
803 }
804 None => Ok(None),
805 }
806 }
807
808 async fn store_results_batch(&mut self, results: &[(Uuid, TaskMeta)]) -> Result<()> {
811 if results.is_empty() {
812 return Ok(());
813 }
814
815 let mut tx =
816 self.pool.begin().await.map_err(|e| {
817 BackendError::Connection(format!("Failed to begin transaction: {}", e))
818 })?;
819
820 for (task_id, meta) in results {
821 let (result_state, result_data, error_message, retry_count) = match &meta.result {
822 TaskResult::Pending => ("pending", None, None, None),
823 TaskResult::Started => ("started", None, None, None),
824 TaskResult::Success(data) => ("success", Some(data.clone()), None, None),
825 TaskResult::Failure(err) => ("failure", None, Some(err.clone()), None),
826 TaskResult::Revoked => ("revoked", None, None, None),
827 TaskResult::Retry(count) => ("retry", None, None, Some(*count as i32)),
828 };
829
830 sqlx::query(
831 r#"
832 INSERT INTO celers_task_results
833 (task_id, task_name, result_state, result_data, error_message, retry_count,
834 created_at, started_at, completed_at, worker)
835 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
836 ON DUPLICATE KEY UPDATE
837 result_state = VALUES(result_state),
838 result_data = VALUES(result_data),
839 error_message = VALUES(error_message),
840 retry_count = VALUES(retry_count),
841 started_at = VALUES(started_at),
842 completed_at = VALUES(completed_at),
843 worker = VALUES(worker)
844 "#,
845 )
846 .bind(task_id)
847 .bind(&meta.task_name)
848 .bind(result_state)
849 .bind(result_data)
850 .bind(error_message)
851 .bind(retry_count)
852 .bind(meta.created_at)
853 .bind(meta.started_at)
854 .bind(meta.completed_at)
855 .bind(&meta.worker)
856 .execute(&mut *tx)
857 .await
858 .map_err(|e| BackendError::Connection(format!("Failed to store result: {}", e)))?;
859 }
860
861 tx.commit().await.map_err(|e| {
862 BackendError::Connection(format!("Failed to commit transaction: {}", e))
863 })?;
864
865 Ok(())
866 }
867
868 async fn get_results_batch(&mut self, task_ids: &[Uuid]) -> Result<Vec<Option<TaskMeta>>> {
869 if task_ids.is_empty() {
870 return Ok(Vec::new());
871 }
872
873 let placeholders = task_ids.iter().map(|_| "?").collect::<Vec<_>>().join(", ");
875 let query_str = format!(
876 r#"
877 SELECT task_id, task_name, result_state, result_data, error_message,
878 retry_count, created_at, started_at, completed_at, worker
879 FROM celers_task_results
880 WHERE task_id IN ({})
881 "#,
882 placeholders
883 );
884
885 let mut query = sqlx::query(&query_str);
886 for task_id in task_ids {
887 query = query.bind(task_id);
888 }
889
890 let rows = query
891 .fetch_all(&self.pool)
892 .await
893 .map_err(|e| BackendError::Connection(format!("Failed to get results: {}", e)))?;
894
895 let mut results_map = std::collections::HashMap::new();
897 for row in rows {
898 let task_id: Uuid = row.get("task_id");
899 let result_state: String = row.get("result_state");
900 let result_data: Option<serde_json::Value> = row.get("result_data");
901 let error_message: Option<String> = row.get("error_message");
902 let retry_count: Option<i32> = row.get("retry_count");
903
904 let result = match result_state.as_str() {
905 "pending" => TaskResult::Pending,
906 "started" => TaskResult::Started,
907 "success" => TaskResult::Success(result_data.unwrap_or(json!(null))),
908 "failure" => TaskResult::Failure(error_message.unwrap_or_default()),
909 "revoked" => TaskResult::Revoked,
910 "retry" => TaskResult::Retry(retry_count.unwrap_or(0) as u32),
911 _ => TaskResult::Pending,
912 };
913
914 let meta = TaskMeta {
915 task_id: row.get("task_id"),
916 task_name: row.get("task_name"),
917 result,
918 created_at: row.get("created_at"),
919 started_at: row.get("started_at"),
920 completed_at: row.get("completed_at"),
921 worker: row.get("worker"),
922 progress: None,
923 version: 0,
924 tags: Vec::new(),
925 metadata: std::collections::HashMap::new(),
926 worker_hostname: None,
927 runtime_ms: None,
928 memory_bytes: None,
929 retries: None,
930 queue: None,
931 };
932
933 results_map.insert(task_id, meta);
934 }
935
936 Ok(task_ids
938 .iter()
939 .map(|id| results_map.get(id).cloned())
940 .collect())
941 }
942
943 async fn delete_results_batch(&mut self, task_ids: &[Uuid]) -> Result<()> {
944 if task_ids.is_empty() {
945 return Ok(());
946 }
947
948 let placeholders = task_ids.iter().map(|_| "?").collect::<Vec<_>>().join(", ");
950 let query_str = format!(
951 "DELETE FROM celers_task_results WHERE task_id IN ({})",
952 placeholders
953 );
954
955 let mut query = sqlx::query(&query_str);
956 for task_id in task_ids {
957 query = query.bind(task_id);
958 }
959
960 query
961 .execute(&self.pool)
962 .await
963 .map_err(|e| BackendError::Connection(format!("Failed to delete results: {}", e)))?;
964
965 Ok(())
966 }
967}
968
#[cfg(test)]
mod tests {
    use super::*;

    /// Connects to the database named by `DATABASE_URL` (or a local default).
    /// Ignored by default because it needs a reachable PostgreSQL server.
    #[tokio::test]
    #[ignore]
    async fn test_postgres_backend_creation() {
        let url = match std::env::var("DATABASE_URL") {
            Ok(url) => url,
            Err(_) => "postgres://postgres:postgres@localhost/celers_test".to_string(),
        };

        assert!(PostgresResultBackend::new(&url).await.is_ok());
    }

    /// Connects to the database named by `MYSQL_URL` (or a local default).
    /// Ignored by default because it needs a reachable MySQL server.
    #[tokio::test]
    #[ignore]
    async fn test_mysql_backend_creation() {
        let url = match std::env::var("MYSQL_URL") {
            Ok(url) => url,
            Err(_) => "mysql://root:password@localhost/celers_test".to_string(),
        };

        assert!(MysqlResultBackend::new(&url).await.is_ok());
    }
}