Skip to main content

celers_backend_db/
lib.rs

//! Database result backend for CeleRS
//!
//! This crate provides PostgreSQL- and MySQL-based storage for task results and workflow state.
//!
//! # Features
//!
//! - Task result storage with expiration
//! - Chord state management (barrier synchronization)
//! - Atomic counter operations
//! - SQL-based result queries and analytics
//! - Support for both PostgreSQL and MySQL
//!
//! # Example
//!
//! ```ignore
//! use celers_backend_db::{PostgresResultBackend, ResultBackend, TaskMeta};
//! use uuid::Uuid;
//!
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! // `store_result` takes `&mut self`, hence `let mut`.
//! let mut backend = PostgresResultBackend::new("postgres://localhost/celers").await?;
//! backend.migrate().await?;
//!
//! // Store a task result
//! let task_id = Uuid::new_v4();
//! let meta = TaskMeta::new(task_id, "my_task".to_string());
//! backend.store_result(task_id, &meta).await?;
//! # Ok(())
//! # }
//! ```
29
30pub mod event_persistence;
31#[cfg(feature = "distributed-locks")]
32pub mod lock;
33pub mod result_store;
34
35pub use event_persistence::{DbEventPersister, DbEventPersisterConfig};
36
37use async_trait::async_trait;
38pub use celers_backend_redis::{
39    BackendError, ChordState, Result, ResultBackend, TaskMeta, TaskResult, TaskTtlConfig,
40};
41use chrono::{DateTime, Utc};
42use serde_json::json;
43use sqlx::{postgres::PgPoolOptions, MySqlPool, PgPool, Row};
44use std::time::Duration;
45use uuid::Uuid;
46
47/// PostgreSQL result backend implementation
48#[derive(Clone)]
49pub struct PostgresResultBackend {
50    pool: PgPool,
51    ttl_config: TaskTtlConfig,
52}
53
54impl PostgresResultBackend {
55    /// Create a new PostgreSQL result backend
56    ///
57    /// # Arguments
58    /// * `database_url` - PostgreSQL connection string (e.g., "postgres://user:pass@localhost/db")
59    pub async fn new(database_url: &str) -> Result<Self> {
60        let pool = PgPoolOptions::new()
61            .max_connections(20)
62            .acquire_timeout(Duration::from_secs(5))
63            .connect(database_url)
64            .await
65            .map_err(|e| {
66                BackendError::Connection(format!("Failed to connect to database: {}", e))
67            })?;
68
69        Ok(Self {
70            pool,
71            ttl_config: TaskTtlConfig::new(),
72        })
73    }
74
75    /// Configure per-task-type TTL
76    pub fn with_ttl_config(mut self, config: TaskTtlConfig) -> Self {
77        self.ttl_config = config;
78        self
79    }
80
81    /// Get the TTL configuration
82    pub fn ttl_config(&self) -> &TaskTtlConfig {
83        &self.ttl_config
84    }
85
86    /// Get a mutable reference to the TTL configuration
87    pub fn ttl_config_mut(&mut self) -> &mut TaskTtlConfig {
88        &mut self.ttl_config
89    }
90
91    /// Run database migrations
92    pub async fn migrate(&self) -> Result<()> {
93        let migration_sql = include_str!("../migrations/001_init_postgres.sql");
94
95        sqlx::query(migration_sql)
96            .execute(&self.pool)
97            .await
98            .map_err(|e| BackendError::Connection(format!("Migration failed: {}", e)))?;
99
100        Ok(())
101    }
102
103    /// Get the underlying connection pool
104    pub fn pool(&self) -> &PgPool {
105        &self.pool
106    }
107
108    /// Clean up expired results (returns number of deleted rows)
109    pub async fn cleanup_expired(&self) -> Result<usize> {
110        let row = sqlx::query("SELECT cleanup_expired_results()")
111            .fetch_one(&self.pool)
112            .await
113            .map_err(|e| {
114                BackendError::Connection(format!("Failed to cleanup expired results: {}", e))
115            })?;
116
117        let count: i32 = row.get(0);
118        Ok(count as usize)
119    }
120}
121
122#[async_trait]
123impl ResultBackend for PostgresResultBackend {
124    async fn store_result(&mut self, task_id: Uuid, meta: &TaskMeta) -> Result<()> {
125        let (result_state, result_data, error_message, retry_count) = match &meta.result {
126            TaskResult::Pending => ("pending", None, None, None),
127            TaskResult::Started => ("started", None, None, None),
128            TaskResult::Success(data) => ("success", Some(data.clone()), None, None),
129            TaskResult::Failure(err) => ("failure", None, Some(err.clone()), None),
130            TaskResult::Revoked => ("revoked", None, None, None),
131            TaskResult::Retry(count) => ("retry", None, None, Some(*count as i32)),
132        };
133
134        sqlx::query(
135            r#"
136            INSERT INTO celers_task_results
137                (task_id, task_name, result_state, result_data, error_message, retry_count,
138                 created_at, started_at, completed_at, worker)
139            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
140            ON CONFLICT (task_id) DO UPDATE SET
141                result_state = EXCLUDED.result_state,
142                result_data = EXCLUDED.result_data,
143                error_message = EXCLUDED.error_message,
144                retry_count = EXCLUDED.retry_count,
145                started_at = EXCLUDED.started_at,
146                completed_at = EXCLUDED.completed_at,
147                worker = EXCLUDED.worker
148            "#,
149        )
150        .bind(task_id)
151        .bind(&meta.task_name)
152        .bind(result_state)
153        .bind(result_data)
154        .bind(error_message)
155        .bind(retry_count)
156        .bind(meta.created_at)
157        .bind(meta.started_at)
158        .bind(meta.completed_at)
159        .bind(&meta.worker)
160        .execute(&self.pool)
161        .await
162        .map_err(|e| BackendError::Connection(format!("Failed to store result: {}", e)))?;
163
164        // Apply per-task TTL if configured
165        if let Some(ttl) = self.ttl_config.get_ttl(&meta.task_name) {
166            self.set_expiration(task_id, ttl).await?;
167        }
168
169        Ok(())
170    }
171
172    async fn get_result(&mut self, task_id: Uuid) -> Result<Option<TaskMeta>> {
173        let row = sqlx::query(
174            r#"
175            SELECT task_id, task_name, result_state, result_data, error_message,
176                   retry_count, created_at, started_at, completed_at, worker
177            FROM celers_task_results
178            WHERE task_id = $1
179            "#,
180        )
181        .bind(task_id)
182        .fetch_optional(&self.pool)
183        .await
184        .map_err(|e| BackendError::Connection(format!("Failed to get result: {}", e)))?;
185
186        match row {
187            Some(row) => {
188                let result_state: String = row.get("result_state");
189                let result_data: Option<serde_json::Value> = row.get("result_data");
190                let error_message: Option<String> = row.get("error_message");
191                let retry_count: Option<i32> = row.get("retry_count");
192
193                let result = match result_state.as_str() {
194                    "pending" => TaskResult::Pending,
195                    "started" => TaskResult::Started,
196                    "success" => TaskResult::Success(result_data.unwrap_or(json!(null))),
197                    "failure" => TaskResult::Failure(error_message.unwrap_or_default()),
198                    "revoked" => TaskResult::Revoked,
199                    "retry" => TaskResult::Retry(retry_count.unwrap_or(0) as u32),
200                    _ => TaskResult::Pending,
201                };
202
203                let meta = TaskMeta {
204                    task_id: row.get("task_id"),
205                    task_name: row.get("task_name"),
206                    result,
207                    created_at: row.get("created_at"),
208                    started_at: row.get("started_at"),
209                    completed_at: row.get("completed_at"),
210                    worker: row.get("worker"),
211                    progress: None,
212                    version: 0,
213                    tags: Vec::new(),
214                    metadata: std::collections::HashMap::new(),
215                    worker_hostname: None,
216                    runtime_ms: None,
217                    memory_bytes: None,
218                    retries: None,
219                    queue: None,
220                };
221
222                Ok(Some(meta))
223            }
224            None => Ok(None),
225        }
226    }
227
228    async fn delete_result(&mut self, task_id: Uuid) -> Result<()> {
229        sqlx::query("DELETE FROM celers_task_results WHERE task_id = $1")
230            .bind(task_id)
231            .execute(&self.pool)
232            .await
233            .map_err(|e| BackendError::Connection(format!("Failed to delete result: {}", e)))?;
234
235        Ok(())
236    }
237
238    async fn set_expiration(&mut self, task_id: Uuid, ttl: Duration) -> Result<()> {
239        let expires_at = Utc::now()
240            + chrono::Duration::from_std(ttl)
241                .map_err(|e| BackendError::Serialization(format!("Invalid TTL duration: {}", e)))?;
242
243        sqlx::query("UPDATE celers_task_results SET expires_at = $1 WHERE task_id = $2")
244            .bind(expires_at)
245            .bind(task_id)
246            .execute(&self.pool)
247            .await
248            .map_err(|e| BackendError::Connection(format!("Failed to set expiration: {}", e)))?;
249
250        Ok(())
251    }
252
253    async fn chord_init(&mut self, state: ChordState) -> Result<()> {
254        let task_ids = serde_json::to_value(&state.task_ids)
255            .map_err(|e| BackendError::Serialization(e.to_string()))?;
256
257        sqlx::query(
258            r#"
259            INSERT INTO celers_chord_state (chord_id, total, completed, callback, task_ids, created_at, timeout_seconds, cancelled, cancellation_reason)
260            VALUES ($1, $2, 0, $3, $4, $5, $6, $7, $8)
261            ON CONFLICT (chord_id) DO UPDATE SET
262                total = EXCLUDED.total,
263                callback = EXCLUDED.callback,
264                task_ids = EXCLUDED.task_ids,
265                timeout_seconds = EXCLUDED.timeout_seconds,
266                cancelled = EXCLUDED.cancelled,
267                cancellation_reason = EXCLUDED.cancellation_reason
268            "#,
269        )
270        .bind(state.chord_id)
271        .bind(state.total as i32)
272        .bind(&state.callback)
273        .bind(task_ids)
274        .bind(state.created_at)
275        .bind(state.timeout.map(|d| d.as_secs() as i64))
276        .bind(state.cancelled)
277        .bind(&state.cancellation_reason)
278        .execute(&self.pool)
279        .await
280        .map_err(|e| BackendError::Connection(format!("Failed to init chord: {}", e)))?;
281
282        Ok(())
283    }
284
285    async fn chord_complete_task(&mut self, chord_id: Uuid) -> Result<usize> {
286        let row = sqlx::query("SELECT chord_increment_counter($1)")
287            .bind(chord_id)
288            .fetch_one(&self.pool)
289            .await
290            .map_err(|e| {
291                BackendError::Connection(format!("Failed to increment chord counter: {}", e))
292            })?;
293
294        let count: i32 = row.get(0);
295        Ok(count as usize)
296    }
297
298    async fn chord_get_state(&mut self, chord_id: Uuid) -> Result<Option<ChordState>> {
299        let row = sqlx::query(
300            r#"
301            SELECT chord_id, total, completed, callback, task_ids, created_at, timeout_seconds, cancelled, cancellation_reason
302            FROM celers_chord_state
303            WHERE chord_id = $1
304            "#,
305        )
306        .bind(chord_id)
307        .fetch_optional(&self.pool)
308        .await
309        .map_err(|e| BackendError::Connection(format!("Failed to get chord state: {}", e)))?;
310
311        match row {
312            Some(row) => {
313                let task_ids_json: serde_json::Value = row.get("task_ids");
314                let task_ids: Vec<Uuid> = serde_json::from_value(task_ids_json)
315                    .map_err(|e| BackendError::Serialization(e.to_string()))?;
316
317                let state = ChordState {
318                    chord_id: row.get("chord_id"),
319                    total: row.get::<i32, _>("total") as usize,
320                    completed: row.get::<i32, _>("completed") as usize,
321                    callback: row.get("callback"),
322                    task_ids,
323                    created_at: row.get("created_at"),
324                    timeout: row
325                        .get::<Option<i64>, _>("timeout_seconds")
326                        .map(|s| std::time::Duration::from_secs(s as u64)),
327                    cancelled: row.get("cancelled"),
328                    cancellation_reason: row.get("cancellation_reason"),
329                    retry_count: 0,
330                    max_retries: None,
331                };
332
333                Ok(Some(state))
334            }
335            None => Ok(None),
336        }
337    }
338
339    // Batch operations using transactions for atomic multi-row operations
340
341    async fn store_results_batch(&mut self, results: &[(Uuid, TaskMeta)]) -> Result<()> {
342        if results.is_empty() {
343            return Ok(());
344        }
345
346        let mut tx =
347            self.pool.begin().await.map_err(|e| {
348                BackendError::Connection(format!("Failed to begin transaction: {}", e))
349            })?;
350
351        for (task_id, meta) in results {
352            let (result_state, result_data, error_message, retry_count) = match &meta.result {
353                TaskResult::Pending => ("pending", None, None, None),
354                TaskResult::Started => ("started", None, None, None),
355                TaskResult::Success(data) => ("success", Some(data.clone()), None, None),
356                TaskResult::Failure(err) => ("failure", None, Some(err.clone()), None),
357                TaskResult::Revoked => ("revoked", None, None, None),
358                TaskResult::Retry(count) => ("retry", None, None, Some(*count as i32)),
359            };
360
361            sqlx::query(
362                r#"
363                INSERT INTO celers_task_results
364                    (task_id, task_name, result_state, result_data, error_message, retry_count,
365                     created_at, started_at, completed_at, worker)
366                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
367                ON CONFLICT (task_id) DO UPDATE SET
368                    result_state = EXCLUDED.result_state,
369                    result_data = EXCLUDED.result_data,
370                    error_message = EXCLUDED.error_message,
371                    retry_count = EXCLUDED.retry_count,
372                    started_at = EXCLUDED.started_at,
373                    completed_at = EXCLUDED.completed_at,
374                    worker = EXCLUDED.worker
375                "#,
376            )
377            .bind(task_id)
378            .bind(&meta.task_name)
379            .bind(result_state)
380            .bind(result_data)
381            .bind(error_message)
382            .bind(retry_count)
383            .bind(meta.created_at)
384            .bind(meta.started_at)
385            .bind(meta.completed_at)
386            .bind(&meta.worker)
387            .execute(&mut *tx)
388            .await
389            .map_err(|e| BackendError::Connection(format!("Failed to store result: {}", e)))?;
390        }
391
392        tx.commit().await.map_err(|e| {
393            BackendError::Connection(format!("Failed to commit transaction: {}", e))
394        })?;
395
396        Ok(())
397    }
398
399    async fn get_results_batch(&mut self, task_ids: &[Uuid]) -> Result<Vec<Option<TaskMeta>>> {
400        if task_ids.is_empty() {
401            return Ok(Vec::new());
402        }
403
404        // PostgreSQL supports = ANY($1) for array queries
405        let rows = sqlx::query(
406            r#"
407            SELECT task_id, task_name, result_state, result_data, error_message,
408                   retry_count, created_at, started_at, completed_at, worker
409            FROM celers_task_results
410            WHERE task_id = ANY($1)
411            "#,
412        )
413        .bind(task_ids)
414        .fetch_all(&self.pool)
415        .await
416        .map_err(|e| BackendError::Connection(format!("Failed to get results: {}", e)))?;
417
418        // Create a HashMap for O(1) lookup
419        let mut results_map = std::collections::HashMap::new();
420        for row in rows {
421            let task_id: Uuid = row.get("task_id");
422            let result_state: String = row.get("result_state");
423            let result_data: Option<serde_json::Value> = row.get("result_data");
424            let error_message: Option<String> = row.get("error_message");
425            let retry_count: Option<i32> = row.get("retry_count");
426
427            let result = match result_state.as_str() {
428                "pending" => TaskResult::Pending,
429                "started" => TaskResult::Started,
430                "success" => TaskResult::Success(result_data.unwrap_or(json!(null))),
431                "failure" => TaskResult::Failure(error_message.unwrap_or_default()),
432                "revoked" => TaskResult::Revoked,
433                "retry" => TaskResult::Retry(retry_count.unwrap_or(0) as u32),
434                _ => TaskResult::Pending,
435            };
436
437            let meta = TaskMeta {
438                task_id: row.get("task_id"),
439                task_name: row.get("task_name"),
440                result,
441                created_at: row.get("created_at"),
442                started_at: row.get("started_at"),
443                completed_at: row.get("completed_at"),
444                worker: row.get("worker"),
445                progress: None,
446                version: 0,
447                tags: Vec::new(),
448                metadata: std::collections::HashMap::new(),
449                worker_hostname: None,
450                runtime_ms: None,
451                memory_bytes: None,
452                retries: None,
453                queue: None,
454            };
455
456            results_map.insert(task_id, meta);
457        }
458
459        // Return results in the same order as input task_ids
460        Ok(task_ids
461            .iter()
462            .map(|id| results_map.get(id).cloned())
463            .collect())
464    }
465
466    async fn delete_results_batch(&mut self, task_ids: &[Uuid]) -> Result<()> {
467        if task_ids.is_empty() {
468            return Ok(());
469        }
470
471        sqlx::query("DELETE FROM celers_task_results WHERE task_id = ANY($1)")
472            .bind(task_ids)
473            .execute(&self.pool)
474            .await
475            .map_err(|e| BackendError::Connection(format!("Failed to delete results: {}", e)))?;
476
477        Ok(())
478    }
479}
480
481/// MySQL result backend implementation
482#[derive(Clone)]
483pub struct MysqlResultBackend {
484    pool: MySqlPool,
485    ttl_config: TaskTtlConfig,
486}
487
488impl MysqlResultBackend {
489    /// Create a new MySQL result backend
490    ///
491    /// # Arguments
492    /// * `database_url` - MySQL connection string (e.g., "mysql://user:pass@localhost/db")
493    pub async fn new(database_url: &str) -> Result<Self> {
494        let pool = sqlx::mysql::MySqlPoolOptions::new()
495            .max_connections(20)
496            .acquire_timeout(Duration::from_secs(5))
497            .connect(database_url)
498            .await
499            .map_err(|e| {
500                BackendError::Connection(format!("Failed to connect to database: {}", e))
501            })?;
502
503        Ok(Self {
504            pool,
505            ttl_config: TaskTtlConfig::new(),
506        })
507    }
508
509    /// Configure per-task-type TTL
510    pub fn with_ttl_config(mut self, config: TaskTtlConfig) -> Self {
511        self.ttl_config = config;
512        self
513    }
514
515    /// Get the TTL configuration
516    pub fn ttl_config(&self) -> &TaskTtlConfig {
517        &self.ttl_config
518    }
519
520    /// Get a mutable reference to the TTL configuration
521    pub fn ttl_config_mut(&mut self) -> &mut TaskTtlConfig {
522        &mut self.ttl_config
523    }
524
525    /// Run database migrations
526    pub async fn migrate(&self) -> Result<()> {
527        let migration_sql = include_str!("../migrations/001_init_mysql.sql");
528
529        // Split and execute MySQL migration (handle DELIMITER sections)
530        let statements: Vec<&str> = migration_sql.split("DELIMITER //").collect();
531
532        // Execute main DDL
533        if let Some(main_sql) = statements.first() {
534            for statement in main_sql.split(';') {
535                let trimmed = statement.trim();
536                if !trimmed.is_empty() && !trimmed.starts_with("--") {
537                    sqlx::query(trimmed)
538                        .execute(&self.pool)
539                        .await
540                        .map_err(|e| {
541                            BackendError::Connection(format!("Migration failed: {}", e))
542                        })?;
543                }
544            }
545        }
546
547        // Execute stored procedures
548        for &proc_section in statements.iter().skip(1) {
549            if let Some(proc_sql) = proc_section.split("DELIMITER ;").next() {
550                let trimmed = proc_sql.trim();
551                if !trimmed.is_empty() {
552                    sqlx::query(trimmed)
553                        .execute(&self.pool)
554                        .await
555                        .map_err(|e| {
556                            BackendError::Connection(format!(
557                                "Stored procedure creation failed: {}",
558                                e
559                            ))
560                        })?;
561                }
562            }
563        }
564
565        Ok(())
566    }
567
568    /// Get the underlying connection pool
569    pub fn pool(&self) -> &MySqlPool {
570        &self.pool
571    }
572}
573
574#[async_trait]
575impl ResultBackend for MysqlResultBackend {
576    async fn store_result(&mut self, task_id: Uuid, meta: &TaskMeta) -> Result<()> {
577        let (result_state, result_data, error_message, retry_count) = match &meta.result {
578            TaskResult::Pending => ("pending", None, None, None),
579            TaskResult::Started => ("started", None, None, None),
580            TaskResult::Success(data) => ("success", Some(data.clone()), None, None),
581            TaskResult::Failure(err) => ("failure", None, Some(err.clone()), None),
582            TaskResult::Revoked => ("revoked", None, None, None),
583            TaskResult::Retry(count) => ("retry", None, None, Some(*count as i32)),
584        };
585
586        let result_data_str =
587            result_data.map(|v| serde_json::to_string(&v).unwrap_or_else(|_| "null".to_string()));
588
589        sqlx::query(
590            r#"
591            INSERT INTO celers_task_results
592                (task_id, task_name, result_state, result_data, error_message, retry_count,
593                 created_at, started_at, completed_at, worker)
594            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
595            ON DUPLICATE KEY UPDATE
596                result_state = VALUES(result_state),
597                result_data = VALUES(result_data),
598                error_message = VALUES(error_message),
599                retry_count = VALUES(retry_count),
600                started_at = VALUES(started_at),
601                completed_at = VALUES(completed_at),
602                worker = VALUES(worker)
603            "#,
604        )
605        .bind(task_id.to_string())
606        .bind(&meta.task_name)
607        .bind(result_state)
608        .bind(result_data_str)
609        .bind(error_message)
610        .bind(retry_count)
611        .bind(meta.created_at)
612        .bind(meta.started_at)
613        .bind(meta.completed_at)
614        .bind(&meta.worker)
615        .execute(&self.pool)
616        .await
617        .map_err(|e| BackendError::Connection(format!("Failed to store result: {}", e)))?;
618
619        // Apply per-task TTL if configured
620        if let Some(ttl) = self.ttl_config.get_ttl(&meta.task_name) {
621            self.set_expiration(task_id, ttl).await?;
622        }
623
624        Ok(())
625    }
626
627    async fn get_result(&mut self, task_id: Uuid) -> Result<Option<TaskMeta>> {
628        let row = sqlx::query(
629            r#"
630            SELECT task_id, task_name, result_state, result_data, error_message,
631                   retry_count, created_at, started_at, completed_at, worker
632            FROM celers_task_results
633            WHERE task_id = ?
634            "#,
635        )
636        .bind(task_id.to_string())
637        .fetch_optional(&self.pool)
638        .await
639        .map_err(|e| BackendError::Connection(format!("Failed to get result: {}", e)))?;
640
641        match row {
642            Some(row) => {
643                let task_id_str: String = row.get("task_id");
644                let result_state: String = row.get("result_state");
645                let result_data_str: Option<String> = row.get("result_data");
646                let error_message: Option<String> = row.get("error_message");
647                let retry_count: Option<i32> = row.get("retry_count");
648
649                let result_data = result_data_str.and_then(|s| serde_json::from_str(&s).ok());
650
651                let result = match result_state.as_str() {
652                    "pending" => TaskResult::Pending,
653                    "started" => TaskResult::Started,
654                    "success" => TaskResult::Success(result_data.unwrap_or(json!(null))),
655                    "failure" => TaskResult::Failure(error_message.unwrap_or_default()),
656                    "revoked" => TaskResult::Revoked,
657                    "retry" => TaskResult::Retry(retry_count.unwrap_or(0) as u32),
658                    _ => TaskResult::Pending,
659                };
660
661                let meta = TaskMeta {
662                    task_id: Uuid::parse_str(&task_id_str)
663                        .map_err(|e| BackendError::Serialization(e.to_string()))?,
664                    task_name: row.get("task_name"),
665                    result,
666                    created_at: row.get::<DateTime<Utc>, _>("created_at"),
667                    started_at: row.get("started_at"),
668                    completed_at: row.get("completed_at"),
669                    worker: row.get("worker"),
670                    progress: None,
671                    version: 0,
672                    tags: Vec::new(),
673                    metadata: std::collections::HashMap::new(),
674                    worker_hostname: None,
675                    runtime_ms: None,
676                    memory_bytes: None,
677                    retries: None,
678                    queue: None,
679                };
680
681                Ok(Some(meta))
682            }
683            None => Ok(None),
684        }
685    }
686
687    async fn delete_result(&mut self, task_id: Uuid) -> Result<()> {
688        sqlx::query("DELETE FROM celers_task_results WHERE task_id = ?")
689            .bind(task_id.to_string())
690            .execute(&self.pool)
691            .await
692            .map_err(|e| BackendError::Connection(format!("Failed to delete result: {}", e)))?;
693
694        Ok(())
695    }
696
697    async fn set_expiration(&mut self, task_id: Uuid, ttl: Duration) -> Result<()> {
698        let expires_at = Utc::now()
699            + chrono::Duration::from_std(ttl)
700                .map_err(|e| BackendError::Serialization(format!("Invalid TTL duration: {}", e)))?;
701
702        sqlx::query("UPDATE celers_task_results SET expires_at = ? WHERE task_id = ?")
703            .bind(expires_at)
704            .bind(task_id.to_string())
705            .execute(&self.pool)
706            .await
707            .map_err(|e| BackendError::Connection(format!("Failed to set expiration: {}", e)))?;
708
709        Ok(())
710    }
711
712    async fn chord_init(&mut self, state: ChordState) -> Result<()> {
713        let task_ids = serde_json::to_string(&state.task_ids)
714            .map_err(|e| BackendError::Serialization(e.to_string()))?;
715
716        sqlx::query(
717            r#"
718            INSERT INTO celers_chord_state (chord_id, total, completed, callback, task_ids, created_at, timeout_seconds, cancelled, cancellation_reason)
719            VALUES (?, ?, 0, ?, ?, ?, ?, ?, ?)
720            ON DUPLICATE KEY UPDATE
721                total = VALUES(total),
722                callback = VALUES(callback),
723                task_ids = VALUES(task_ids),
724                timeout_seconds = VALUES(timeout_seconds),
725                cancelled = VALUES(cancelled),
726                cancellation_reason = VALUES(cancellation_reason)
727            "#,
728        )
729        .bind(state.chord_id.to_string())
730        .bind(state.total as i32)
731        .bind(&state.callback)
732        .bind(task_ids)
733        .bind(state.created_at)
734        .bind(state.timeout.map(|d| d.as_secs() as i64))
735        .bind(state.cancelled)
736        .bind(&state.cancellation_reason)
737        .execute(&self.pool)
738        .await
739        .map_err(|e| BackendError::Connection(format!("Failed to init chord: {}", e)))?;
740
741        Ok(())
742    }
743
744    async fn chord_complete_task(&mut self, chord_id: Uuid) -> Result<usize> {
745        // MySQL doesn't support function returns in SELECT, use procedure with OUT parameter
746        // For now, use a simpler UPDATE + SELECT approach
747        sqlx::query("UPDATE celers_chord_state SET completed = completed + 1 WHERE chord_id = ?")
748            .bind(chord_id.to_string())
749            .execute(&self.pool)
750            .await
751            .map_err(|e| {
752                BackendError::Connection(format!("Failed to increment chord counter: {}", e))
753            })?;
754
755        let row = sqlx::query("SELECT completed FROM celers_chord_state WHERE chord_id = ?")
756            .bind(chord_id.to_string())
757            .fetch_one(&self.pool)
758            .await
759            .map_err(|e| BackendError::Connection(format!("Failed to get chord counter: {}", e)))?;
760
761        let count: i32 = row.get("completed");
762        Ok(count as usize)
763    }
764
765    async fn chord_get_state(&mut self, chord_id: Uuid) -> Result<Option<ChordState>> {
766        let row = sqlx::query(
767            r#"
768            SELECT chord_id, total, completed, callback, task_ids, created_at, timeout_seconds, cancelled, cancellation_reason
769            FROM celers_chord_state
770            WHERE chord_id = ?
771            "#,
772        )
773        .bind(chord_id.to_string())
774        .fetch_optional(&self.pool)
775        .await
776        .map_err(|e| BackendError::Connection(format!("Failed to get chord state: {}", e)))?;
777
778        match row {
779            Some(row) => {
780                let chord_id_str: String = row.get("chord_id");
781                let task_ids_str: String = row.get("task_ids");
782                let task_ids: Vec<Uuid> = serde_json::from_str(&task_ids_str)
783                    .map_err(|e| BackendError::Serialization(e.to_string()))?;
784
785                let state = ChordState {
786                    chord_id: Uuid::parse_str(&chord_id_str)
787                        .map_err(|e| BackendError::Serialization(e.to_string()))?,
788                    total: row.get::<i32, _>("total") as usize,
789                    completed: row.get::<i32, _>("completed") as usize,
790                    callback: row.get("callback"),
791                    task_ids,
792                    created_at: row.get("created_at"),
793                    timeout: row
794                        .get::<Option<i64>, _>("timeout_seconds")
795                        .map(|s| std::time::Duration::from_secs(s as u64)),
796                    cancelled: row.get("cancelled"),
797                    cancellation_reason: row.get("cancellation_reason"),
798                    retry_count: 0,
799                    max_retries: None,
800                };
801
802                Ok(Some(state))
803            }
804            None => Ok(None),
805        }
806    }
807
808    // Batch operations using transactions for atomic multi-row operations
809
810    async fn store_results_batch(&mut self, results: &[(Uuid, TaskMeta)]) -> Result<()> {
811        if results.is_empty() {
812            return Ok(());
813        }
814
815        let mut tx =
816            self.pool.begin().await.map_err(|e| {
817                BackendError::Connection(format!("Failed to begin transaction: {}", e))
818            })?;
819
820        for (task_id, meta) in results {
821            let (result_state, result_data, error_message, retry_count) = match &meta.result {
822                TaskResult::Pending => ("pending", None, None, None),
823                TaskResult::Started => ("started", None, None, None),
824                TaskResult::Success(data) => ("success", Some(data.clone()), None, None),
825                TaskResult::Failure(err) => ("failure", None, Some(err.clone()), None),
826                TaskResult::Revoked => ("revoked", None, None, None),
827                TaskResult::Retry(count) => ("retry", None, None, Some(*count as i32)),
828            };
829
830            sqlx::query(
831                r#"
832                INSERT INTO celers_task_results
833                    (task_id, task_name, result_state, result_data, error_message, retry_count,
834                     created_at, started_at, completed_at, worker)
835                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
836                ON DUPLICATE KEY UPDATE
837                    result_state = VALUES(result_state),
838                    result_data = VALUES(result_data),
839                    error_message = VALUES(error_message),
840                    retry_count = VALUES(retry_count),
841                    started_at = VALUES(started_at),
842                    completed_at = VALUES(completed_at),
843                    worker = VALUES(worker)
844                "#,
845            )
846            .bind(task_id)
847            .bind(&meta.task_name)
848            .bind(result_state)
849            .bind(result_data)
850            .bind(error_message)
851            .bind(retry_count)
852            .bind(meta.created_at)
853            .bind(meta.started_at)
854            .bind(meta.completed_at)
855            .bind(&meta.worker)
856            .execute(&mut *tx)
857            .await
858            .map_err(|e| BackendError::Connection(format!("Failed to store result: {}", e)))?;
859        }
860
861        tx.commit().await.map_err(|e| {
862            BackendError::Connection(format!("Failed to commit transaction: {}", e))
863        })?;
864
865        Ok(())
866    }
867
868    async fn get_results_batch(&mut self, task_ids: &[Uuid]) -> Result<Vec<Option<TaskMeta>>> {
869        if task_ids.is_empty() {
870            return Ok(Vec::new());
871        }
872
873        // MySQL requires IN clause with placeholders
874        let placeholders = task_ids.iter().map(|_| "?").collect::<Vec<_>>().join(", ");
875        let query_str = format!(
876            r#"
877            SELECT task_id, task_name, result_state, result_data, error_message,
878                   retry_count, created_at, started_at, completed_at, worker
879            FROM celers_task_results
880            WHERE task_id IN ({})
881            "#,
882            placeholders
883        );
884
885        let mut query = sqlx::query(&query_str);
886        for task_id in task_ids {
887            query = query.bind(task_id);
888        }
889
890        let rows = query
891            .fetch_all(&self.pool)
892            .await
893            .map_err(|e| BackendError::Connection(format!("Failed to get results: {}", e)))?;
894
895        // Create a HashMap for O(1) lookup
896        let mut results_map = std::collections::HashMap::new();
897        for row in rows {
898            let task_id: Uuid = row.get("task_id");
899            let result_state: String = row.get("result_state");
900            let result_data: Option<serde_json::Value> = row.get("result_data");
901            let error_message: Option<String> = row.get("error_message");
902            let retry_count: Option<i32> = row.get("retry_count");
903
904            let result = match result_state.as_str() {
905                "pending" => TaskResult::Pending,
906                "started" => TaskResult::Started,
907                "success" => TaskResult::Success(result_data.unwrap_or(json!(null))),
908                "failure" => TaskResult::Failure(error_message.unwrap_or_default()),
909                "revoked" => TaskResult::Revoked,
910                "retry" => TaskResult::Retry(retry_count.unwrap_or(0) as u32),
911                _ => TaskResult::Pending,
912            };
913
914            let meta = TaskMeta {
915                task_id: row.get("task_id"),
916                task_name: row.get("task_name"),
917                result,
918                created_at: row.get("created_at"),
919                started_at: row.get("started_at"),
920                completed_at: row.get("completed_at"),
921                worker: row.get("worker"),
922                progress: None,
923                version: 0,
924                tags: Vec::new(),
925                metadata: std::collections::HashMap::new(),
926                worker_hostname: None,
927                runtime_ms: None,
928                memory_bytes: None,
929                retries: None,
930                queue: None,
931            };
932
933            results_map.insert(task_id, meta);
934        }
935
936        // Return results in the same order as input task_ids
937        Ok(task_ids
938            .iter()
939            .map(|id| results_map.get(id).cloned())
940            .collect())
941    }
942
943    async fn delete_results_batch(&mut self, task_ids: &[Uuid]) -> Result<()> {
944        if task_ids.is_empty() {
945            return Ok(());
946        }
947
948        // MySQL requires IN clause with placeholders
949        let placeholders = task_ids.iter().map(|_| "?").collect::<Vec<_>>().join(", ");
950        let query_str = format!(
951            "DELETE FROM celers_task_results WHERE task_id IN ({})",
952            placeholders
953        );
954
955        let mut query = sqlx::query(&query_str);
956        for task_id in task_ids {
957            query = query.bind(task_id);
958        }
959
960        query
961            .execute(&self.pool)
962            .await
963            .map_err(|e| BackendError::Connection(format!("Failed to delete results: {}", e)))?;
964
965        Ok(())
966    }
967}
968
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the value of environment variable `var`, or `default` when it is
    /// unset or not valid Unicode.
    fn env_or(var: &str, default: &str) -> String {
        std::env::var(var).unwrap_or_else(|_| default.to_string())
    }

    #[tokio::test]
    #[ignore] // Requires PostgreSQL running
    async fn test_postgres_backend_creation() {
        let database_url = env_or(
            "DATABASE_URL",
            "postgres://postgres:postgres@localhost/celers_test",
        );

        let backend = PostgresResultBackend::new(&database_url).await;
        // Include the error in the failure message; a bare `is_ok()` assertion
        // hides why the connection failed.
        assert!(
            backend.is_ok(),
            "failed to create PostgreSQL backend: {:?}",
            backend.err()
        );
    }

    #[tokio::test]
    #[ignore] // Requires MySQL running
    async fn test_mysql_backend_creation() {
        let database_url = env_or(
            "MYSQL_URL",
            "mysql://root:password@localhost/celers_test",
        );

        let backend = MysqlResultBackend::new(&database_url).await;
        // Include the error in the failure message; a bare `is_ok()` assertion
        // hides why the connection failed.
        assert!(
            backend.is_ok(),
            "failed to create MySQL backend: {:?}",
            backend.err()
        );
    }
}