// ceres_db/job_repository.rs
//! Job repository for PostgreSQL with SELECT FOR UPDATE SKIP LOCKED.
//!
//! Implements the [`JobQueue`] trait for persistent job storage with safe
//! concurrent job claiming using PostgreSQL's row-level locking.

use chrono::{DateTime, Utc};
use sqlx::{PgPool, Pool, Postgres};
use uuid::Uuid;

use ceres_core::SyncStats;
use ceres_core::error::AppError;
use ceres_core::job::{CreateJobRequest, HarvestJob, JobStatus};
use ceres_core::job_queue::JobQueue;
15/// PostgreSQL implementation of the job queue.
16///
17/// Uses `SELECT FOR UPDATE SKIP LOCKED` for safe concurrent job claiming,
18/// ensuring that multiple workers can process jobs without conflicts.
19#[derive(Clone)]
20pub struct JobRepository {
21    pool: Pool<Postgres>,
22}
23
24impl JobRepository {
25    /// Create a new job repository with the given connection pool.
26    pub fn new(pool: PgPool) -> Self {
27        Self { pool }
28    }
29}

// =============================================================================
// Helper Types for Database Mapping
// =============================================================================

35/// Helper struct for deserializing job rows from the database.
36#[derive(sqlx::FromRow)]
37struct JobRow {
38    id: Uuid,
39    portal_url: String,
40    portal_name: Option<String>,
41    status: String,
42    created_at: DateTime<Utc>,
43    updated_at: DateTime<Utc>,
44    started_at: Option<DateTime<Utc>>,
45    completed_at: Option<DateTime<Utc>>,
46    retry_count: i32,
47    max_retries: i32,
48    next_retry_at: Option<DateTime<Utc>>,
49    error_message: Option<String>,
50    sync_stats: Option<sqlx::types::Json<SyncStatsJson>>,
51    worker_id: Option<String>,
52    force_full_sync: bool,
53}
55/// JSON representation of SyncStats for database storage.
56#[derive(serde::Serialize, serde::Deserialize)]
57struct SyncStatsJson {
58    unchanged: usize,
59    updated: usize,
60    created: usize,
61    failed: usize,
62    /// Number of datasets skipped due to circuit breaker.
63    /// Default to 0 for backwards compatibility with old records.
64    #[serde(default)]
65    skipped: usize,
66}
68impl From<&SyncStats> for SyncStatsJson {
69    fn from(stats: &SyncStats) -> Self {
70        Self {
71            unchanged: stats.unchanged,
72            updated: stats.updated,
73            created: stats.created,
74            failed: stats.failed,
75            skipped: stats.skipped,
76        }
77    }
78}
80impl From<SyncStatsJson> for SyncStats {
81    fn from(json: SyncStatsJson) -> Self {
82        Self {
83            unchanged: json.unchanged,
84            updated: json.updated,
85            created: json.created,
86            failed: json.failed,
87            skipped: json.skipped,
88        }
89    }
90}
92impl From<JobRow> for HarvestJob {
93    fn from(row: JobRow) -> Self {
94        Self {
95            id: row.id,
96            portal_url: row.portal_url,
97            portal_name: row.portal_name,
98            status: row.status.parse().unwrap_or(JobStatus::Pending),
99            created_at: row.created_at,
100            updated_at: row.updated_at,
101            started_at: row.started_at,
102            completed_at: row.completed_at,
103            retry_count: row.retry_count as u32,
104            max_retries: row.max_retries as u32,
105            next_retry_at: row.next_retry_at,
106            error_message: row.error_message,
107            sync_stats: row.sync_stats.map(|j| j.0.into()),
108            worker_id: row.worker_id,
109            force_full_sync: row.force_full_sync,
110        }
111    }
112}

// =============================================================================
// JobQueue Trait Implementation
// =============================================================================

118impl JobQueue for JobRepository {
119    async fn create_job(&self, request: CreateJobRequest) -> Result<HarvestJob, AppError> {
120        let max_retries = request.max_retries.unwrap_or(3) as i32;
121
122        let row: JobRow = sqlx::query_as(
123            r#"
124            INSERT INTO harvest_jobs (portal_url, portal_name, force_full_sync, max_retries)
125            VALUES ($1, $2, $3, $4)
126            RETURNING *
127            "#,
128        )
129        .bind(&request.portal_url)
130        .bind(&request.portal_name)
131        .bind(request.force_full_sync)
132        .bind(max_retries)
133        .fetch_one(&self.pool)
134        .await?;
135
136        Ok(row.into())
137    }
138
139    async fn claim_job(&self, worker_id: &str) -> Result<Option<HarvestJob>, AppError> {
140        // Use SELECT FOR UPDATE SKIP LOCKED for safe concurrent claiming.
141        // This query claims jobs that are:
142        // 1. Pending with no retry scheduled (new jobs), OR
143        // 2. Pending with retry time passed (retry-ready jobs)
144        //
145        // Priority:
146        // - Non-retry jobs first (next_retry_at IS NULL)
147        // - Then by creation order (oldest first)
148        let row: Option<JobRow> = sqlx::query_as(
149            r#"
150            UPDATE harvest_jobs
151            SET
152                status = 'running',
153                worker_id = $1,
154                started_at = NOW(),
155                updated_at = NOW()
156            WHERE id = (
157                SELECT id FROM harvest_jobs
158                WHERE status = 'pending'
159                  AND (next_retry_at IS NULL OR next_retry_at <= NOW())
160                ORDER BY
161                    next_retry_at NULLS FIRST,
162                    created_at ASC
163                FOR UPDATE SKIP LOCKED
164                LIMIT 1
165            )
166            RETURNING *
167            "#,
168        )
169        .bind(worker_id)
170        .fetch_optional(&self.pool)
171        .await?;
172
173        Ok(row.map(Into::into))
174    }
175
176    async fn complete_job(&self, job_id: Uuid, stats: SyncStats) -> Result<(), AppError> {
177        let stats_json = serde_json::to_value(SyncStatsJson::from(&stats))
178            .map_err(AppError::SerializationError)?;
179
180        sqlx::query(
181            r#"
182            UPDATE harvest_jobs
183            SET
184                status = 'completed',
185                completed_at = NOW(),
186                updated_at = NOW(),
187                sync_stats = $2,
188                error_message = NULL,
189                worker_id = NULL
190            WHERE id = $1
191            "#,
192        )
193        .bind(job_id)
194        .bind(stats_json)
195        .execute(&self.pool)
196        .await?;
197
198        Ok(())
199    }
200
201    async fn fail_job(
202        &self,
203        job_id: Uuid,
204        error: &str,
205        next_retry_at: Option<DateTime<Utc>>,
206    ) -> Result<(), AppError> {
207        // If next_retry_at is provided, reset to pending for retry.
208        // Otherwise, mark as permanently failed.
209        let (new_status, should_increment) = if next_retry_at.is_some() {
210            ("pending", true) // Will retry
211        } else {
212            ("failed", false) // Permanently failed
213        };
214
215        sqlx::query(
216            r#"
217            UPDATE harvest_jobs
218            SET
219                status = $2,
220                error_message = $3,
221                next_retry_at = $4,
222                retry_count = CASE WHEN $5 THEN retry_count + 1 ELSE retry_count END,
223                updated_at = NOW(),
224                completed_at = CASE WHEN $2 = 'failed' THEN NOW() ELSE NULL END,
225                worker_id = NULL,
226                started_at = NULL
227            WHERE id = $1
228            "#,
229        )
230        .bind(job_id)
231        .bind(new_status)
232        .bind(error)
233        .bind(next_retry_at)
234        .bind(should_increment)
235        .execute(&self.pool)
236        .await?;
237
238        Ok(())
239    }
240
241    async fn cancel_job(&self, job_id: Uuid, stats: Option<SyncStats>) -> Result<(), AppError> {
242        let stats_json = stats
243            .as_ref()
244            .map(|s| serde_json::to_value(SyncStatsJson::from(s)))
245            .transpose()
246            .map_err(AppError::SerializationError)?;
247
248        sqlx::query(
249            r#"
250            UPDATE harvest_jobs
251            SET
252                status = 'cancelled',
253                completed_at = NOW(),
254                updated_at = NOW(),
255                sync_stats = COALESCE($2, sync_stats),
256                worker_id = NULL
257            WHERE id = $1
258            "#,
259        )
260        .bind(job_id)
261        .bind(stats_json)
262        .execute(&self.pool)
263        .await?;
264
265        Ok(())
266    }
267
268    async fn get_job(&self, job_id: Uuid) -> Result<Option<HarvestJob>, AppError> {
269        let row: Option<JobRow> = sqlx::query_as("SELECT * FROM harvest_jobs WHERE id = $1")
270            .bind(job_id)
271            .fetch_optional(&self.pool)
272            .await?;
273
274        Ok(row.map(Into::into))
275    }
276
277    async fn list_jobs(
278        &self,
279        status: Option<JobStatus>,
280        limit: usize,
281    ) -> Result<Vec<HarvestJob>, AppError> {
282        let rows: Vec<JobRow> = if let Some(s) = status {
283            sqlx::query_as(
284                r#"
285                SELECT * FROM harvest_jobs
286                WHERE status = $1
287                ORDER BY created_at DESC
288                LIMIT $2
289                "#,
290            )
291            .bind(s.as_str())
292            .bind(limit as i64)
293            .fetch_all(&self.pool)
294            .await?
295        } else {
296            sqlx::query_as(
297                r#"
298                SELECT * FROM harvest_jobs
299                ORDER BY created_at DESC
300                LIMIT $1
301                "#,
302            )
303            .bind(limit as i64)
304            .fetch_all(&self.pool)
305            .await?
306        };
307
308        Ok(rows.into_iter().map(Into::into).collect())
309    }
310
311    async fn release_job(&self, job_id: Uuid) -> Result<(), AppError> {
312        sqlx::query(
313            r#"
314            UPDATE harvest_jobs
315            SET
316                status = 'pending',
317                worker_id = NULL,
318                started_at = NULL,
319                updated_at = NOW()
320            WHERE id = $1 AND status = 'running'
321            "#,
322        )
323        .bind(job_id)
324        .execute(&self.pool)
325        .await?;
326
327        Ok(())
328    }
329
330    async fn release_worker_jobs(&self, worker_id: &str) -> Result<u64, AppError> {
331        let result = sqlx::query(
332            r#"
333            UPDATE harvest_jobs
334            SET
335                status = 'pending',
336                worker_id = NULL,
337                started_at = NULL,
338                updated_at = NOW()
339            WHERE worker_id = $1 AND status = 'running'
340            "#,
341        )
342        .bind(worker_id)
343        .execute(&self.pool)
344        .await?;
345
346        Ok(result.rows_affected())
347    }
348
349    async fn count_by_status(&self, status: JobStatus) -> Result<i64, AppError> {
350        let (count,): (i64,) =
351            sqlx::query_as("SELECT COUNT(*) FROM harvest_jobs WHERE status = $1")
352                .bind(status.as_str())
353                .fetch_one(&self.pool)
354                .await?;
355
356        Ok(count)
357    }
358}
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trip SyncStats -> SyncStatsJson -> SyncStats preserves every counter.
    #[test]
    fn test_sync_stats_json_conversion() {
        let original = SyncStats {
            unchanged: 10,
            updated: 5,
            created: 3,
            failed: 2,
            skipped: 1,
        };

        let encoded = SyncStatsJson::from(&original);
        assert_eq!(encoded.unchanged, 10);
        assert_eq!(encoded.updated, 5);
        assert_eq!(encoded.created, 3);
        assert_eq!(encoded.failed, 2);
        assert_eq!(encoded.skipped, 1);

        let decoded: SyncStats = encoded.into();
        assert_eq!(decoded.unchanged, original.unchanged);
        assert_eq!(decoded.updated, original.updated);
        assert_eq!(decoded.created, original.created);
        assert_eq!(decoded.failed, original.failed);
        assert_eq!(decoded.skipped, original.skipped);
    }

    /// Records written before the `skipped` field existed must still decode.
    #[test]
    fn test_sync_stats_json_backwards_compatibility() {
        let legacy = r#"{"unchanged":10,"updated":5,"created":3,"failed":2}"#;
        let parsed: SyncStatsJson = serde_json::from_str(legacy).unwrap();
        assert_eq!(parsed.skipped, 0); // Should default to 0
    }

    /// Every status string stored in the DB maps to its JobStatus variant.
    #[test]
    fn test_job_status_parsing() {
        let cases = [
            ("pending", JobStatus::Pending),
            ("running", JobStatus::Running),
            ("completed", JobStatus::Completed),
            ("failed", JobStatus::Failed),
            ("cancelled", JobStatus::Cancelled),
        ];
        for (text, expected) in cases {
            assert_eq!(text.parse::<JobStatus>(), Ok(expected));
        }
        assert!("invalid".parse::<JobStatus>().is_err());
    }
}