Skip to main content

jamjet_state/
sqlite.rs

1use crate::backend::{
2    ApiToken, BackendResult, ReclaimResult, StateBackend, StateBackendError, WorkItem, WorkItemId,
3    WorkflowDefinition,
4};
5use crate::event::{Event, EventKind, EventSequence};
6use crate::snapshot::Snapshot;
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use jamjet_core::workflow::{ExecutionId, WorkflowExecution, WorkflowStatus};
10use sqlx::{sqlite::SqliteConnectOptions, Row, SqlitePool};
11use std::str::FromStr;
12use tracing::instrument;
13use uuid::Uuid;
14
15/// SQLite-backed state store for local development.
16///
17/// Run migrations with [`SqliteBackend::migrate`] before first use.
18pub struct SqliteBackend {
19    pool: SqlitePool,
20}
21
22impl SqliteBackend {
23    /// Connect to the SQLite database at `database_url` and return a backend.
24    /// `database_url` is a SQLx-compatible URL, e.g. `sqlite:///path/to/db.sqlite3`.
25    /// The database file is created automatically if it does not exist.
26    pub async fn connect(database_url: &str) -> Result<Self, sqlx::Error> {
27        let opts = SqliteConnectOptions::from_str(database_url)?.create_if_missing(true);
28        let pool = SqlitePool::connect_with(opts).await?;
29        Ok(Self { pool })
30    }
31
32    /// Run embedded migrations against the connected database.
33    pub async fn migrate(&self) -> Result<(), sqlx::migrate::MigrateError> {
34        sqlx::migrate!("./migrations").run(&self.pool).await
35    }
36
37    /// Convenience: connect and immediately run migrations.
38    pub async fn open(
39        database_url: &str,
40    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
41        let backend = Self::connect(database_url).await?;
42        backend.migrate().await?;
43        Ok(backend)
44    }
45
46    /// Create a tenant-scoped view of this backend.
47    ///
48    /// All operations on the returned backend are filtered by `tenant_id`,
49    /// ensuring complete data isolation between tenants.
50    pub fn for_tenant(
51        &self,
52        tenant_id: crate::tenant::TenantId,
53    ) -> crate::tenant_scoped::TenantScopedSqliteBackend {
54        crate::tenant_scoped::TenantScopedSqliteBackend::new(self.pool.clone(), tenant_id)
55    }
56
57    /// Get a clone of the underlying connection pool.
58    pub fn pool(&self) -> SqlitePool {
59        self.pool.clone()
60    }
61}
62
63// ── Helpers ───────────────────────────────────────────────────────────────────
64
65pub(crate) fn map_db_err(e: sqlx::Error) -> StateBackendError {
66    StateBackendError::Database(e.to_string())
67}
68
69fn execution_id_str(id: &ExecutionId) -> String {
70    id.0.to_string()
71}
72
73pub(crate) fn parse_execution_id(s: &str) -> BackendResult<ExecutionId> {
74    let uuid = Uuid::parse_str(s)
75        .map_err(|e| StateBackendError::Database(format!("invalid execution_id: {e}")))?;
76    Ok(ExecutionId(uuid))
77}
78
79pub(crate) fn parse_datetime(s: &str) -> BackendResult<DateTime<Utc>> {
80    DateTime::parse_from_rfc3339(s)
81        .map(|dt| dt.with_timezone(&Utc))
82        .map_err(|e| StateBackendError::Database(format!("invalid datetime: {e}")))
83}
84
85fn status_to_str(s: &WorkflowStatus) -> &'static str {
86    match s {
87        WorkflowStatus::Pending => "pending",
88        WorkflowStatus::Running => "running",
89        WorkflowStatus::Paused => "paused",
90        WorkflowStatus::Completed => "completed",
91        WorkflowStatus::Failed => "failed",
92        WorkflowStatus::Cancelled => "cancelled",
93        WorkflowStatus::LimitExceeded => "limit_exceeded",
94    }
95}
96
97fn str_to_status(s: &str) -> BackendResult<WorkflowStatus> {
98    match s {
99        "pending" => Ok(WorkflowStatus::Pending),
100        "running" => Ok(WorkflowStatus::Running),
101        "paused" => Ok(WorkflowStatus::Paused),
102        "completed" => Ok(WorkflowStatus::Completed),
103        "failed" => Ok(WorkflowStatus::Failed),
104        "cancelled" => Ok(WorkflowStatus::Cancelled),
105        "limit_exceeded" => Ok(WorkflowStatus::LimitExceeded),
106        other => Err(StateBackendError::Database(format!(
107            "unknown status: {other}"
108        ))),
109    }
110}
111
112fn row_to_execution(row: &sqlx::sqlite::SqliteRow) -> BackendResult<WorkflowExecution> {
113    let execution_id =
114        parse_execution_id(row.try_get::<&str, _>("execution_id").map_err(map_db_err)?)?;
115    let status = str_to_status(row.try_get::<&str, _>("status").map_err(map_db_err)?)?;
116    let initial_input: serde_json::Value = serde_json::from_str(
117        row.try_get::<&str, _>("initial_input")
118            .map_err(map_db_err)?,
119    )
120    .map_err(StateBackendError::Serialization)?;
121    let current_state: serde_json::Value = serde_json::from_str(
122        row.try_get::<&str, _>("current_state")
123            .map_err(map_db_err)?,
124    )
125    .map_err(StateBackendError::Serialization)?;
126    let started_at = parse_datetime(row.try_get::<&str, _>("started_at").map_err(map_db_err)?)?;
127    let updated_at = parse_datetime(row.try_get::<&str, _>("updated_at").map_err(map_db_err)?)?;
128    let completed_at: Option<DateTime<Utc>> = row
129        .try_get::<Option<&str>, _>("completed_at")
130        .map_err(map_db_err)?
131        .map(parse_datetime)
132        .transpose()?;
133
134    Ok(WorkflowExecution {
135        execution_id,
136        workflow_id: row
137            .try_get::<String, _>("workflow_id")
138            .map_err(map_db_err)?,
139        workflow_version: row
140            .try_get::<String, _>("workflow_version")
141            .map_err(map_db_err)?,
142        status,
143        initial_input,
144        current_state,
145        started_at,
146        updated_at,
147        completed_at,
148        session_type: None,
149    })
150}
151
152fn row_to_event(row: &sqlx::sqlite::SqliteRow) -> BackendResult<Event> {
153    let id = Uuid::parse_str(row.try_get::<&str, _>("id").map_err(map_db_err)?)
154        .map_err(|e| StateBackendError::Database(e.to_string()))?;
155    let execution_id =
156        parse_execution_id(row.try_get::<&str, _>("execution_id").map_err(map_db_err)?)?;
157    let sequence: i64 = row.try_get("sequence").map_err(map_db_err)?;
158    let kind: EventKind =
159        serde_json::from_str(row.try_get::<&str, _>("kind_json").map_err(map_db_err)?)
160            .map_err(StateBackendError::Serialization)?;
161    let created_at = parse_datetime(row.try_get::<&str, _>("created_at").map_err(map_db_err)?)?;
162
163    Ok(Event {
164        id,
165        execution_id,
166        sequence,
167        kind,
168        created_at,
169    })
170}
171
172fn row_to_work_item(row: &sqlx::sqlite::SqliteRow) -> BackendResult<WorkItem> {
173    let id = Uuid::parse_str(row.try_get::<&str, _>("id").map_err(map_db_err)?)
174        .map_err(|e| StateBackendError::Database(e.to_string()))?;
175    let execution_id =
176        parse_execution_id(row.try_get::<&str, _>("execution_id").map_err(map_db_err)?)?;
177    let payload: serde_json::Value =
178        serde_json::from_str(row.try_get::<&str, _>("payload_json").map_err(map_db_err)?)
179            .map_err(StateBackendError::Serialization)?;
180    let lease_expires_at: Option<DateTime<Utc>> = row
181        .try_get::<Option<&str>, _>("lease_expires_at")
182        .map_err(map_db_err)?
183        .map(parse_datetime)
184        .transpose()?;
185    let created_at = parse_datetime(row.try_get::<&str, _>("created_at").map_err(map_db_err)?)?;
186    let attempt: i64 = row.try_get("attempt").map_err(map_db_err)?;
187
188    let max_attempts: i64 = row.try_get("max_attempts").unwrap_or(3);
189
190    let tenant_id: String = row
191        .try_get("tenant_id")
192        .unwrap_or_else(|_| crate::tenant::DEFAULT_TENANT.to_string());
193
194    Ok(WorkItem {
195        id,
196        execution_id,
197        node_id: row.try_get::<String, _>("node_id").map_err(map_db_err)?,
198        queue_type: row.try_get::<String, _>("queue_type").map_err(map_db_err)?,
199        payload,
200        attempt: attempt as u32,
201        max_attempts: max_attempts as u32,
202        created_at,
203        lease_expires_at,
204        worker_id: row
205            .try_get::<Option<String>, _>("worker_id")
206            .map_err(map_db_err)?,
207        tenant_id,
208    })
209}
210
211// ── StateBackend impl ─────────────────────────────────────────────────────────
212
213#[async_trait]
214impl StateBackend for SqliteBackend {
215    // ── Workflow definitions ──────────────────────────────────────────────
216
217    #[instrument(skip(self, def), fields(workflow_id = %def.workflow_id, version = %def.version))]
218    async fn store_workflow(&self, def: WorkflowDefinition) -> BackendResult<()> {
219        let ir_json = serde_json::to_string(&def.ir)?;
220        let created_at = def.created_at.to_rfc3339();
221
222        sqlx::query(
223            r#"INSERT OR REPLACE INTO workflow_definitions (workflow_id, version, ir_json, created_at, tenant_id)
224               VALUES (?, ?, ?, ?, ?)"#,
225        )
226        .bind(&def.workflow_id)
227        .bind(&def.version)
228        .bind(&ir_json)
229        .bind(&created_at)
230        .bind(&def.tenant_id)
231        .execute(&self.pool)
232        .await
233        .map_err(map_db_err)?;
234
235        Ok(())
236    }
237
238    #[instrument(skip(self), fields(workflow_id = workflow_id, version = version))]
239    async fn get_workflow(
240        &self,
241        workflow_id: &str,
242        version: &str,
243    ) -> BackendResult<Option<WorkflowDefinition>> {
244        let row =
245            sqlx::query("SELECT * FROM workflow_definitions WHERE workflow_id = ? AND version = ?")
246                .bind(workflow_id)
247                .bind(version)
248                .fetch_optional(&self.pool)
249                .await
250                .map_err(map_db_err)?;
251
252        let Some(row) = row else { return Ok(None) };
253
254        let ir: serde_json::Value =
255            serde_json::from_str(row.try_get::<&str, _>("ir_json").map_err(map_db_err)?)
256                .map_err(StateBackendError::Serialization)?;
257        let created_at = parse_datetime(row.try_get::<&str, _>("created_at").map_err(map_db_err)?)?;
258
259        let tenant_id: String = row
260            .try_get("tenant_id")
261            .unwrap_or_else(|_| crate::tenant::DEFAULT_TENANT.to_string());
262
263        Ok(Some(WorkflowDefinition {
264            workflow_id: row
265                .try_get::<String, _>("workflow_id")
266                .map_err(map_db_err)?,
267            version: row.try_get::<String, _>("version").map_err(map_db_err)?,
268            ir,
269            created_at,
270            tenant_id,
271        }))
272    }
273
274    // ── Executions ────────────────────────────────────────────────────────
275
276    #[instrument(skip(self, execution), fields(execution_id = %execution.execution_id))]
277    async fn create_execution(&self, execution: WorkflowExecution) -> BackendResult<()> {
278        let id = execution_id_str(&execution.execution_id);
279        let status = status_to_str(&execution.status);
280        let initial_input = serde_json::to_string(&execution.initial_input)?;
281        let current_state = serde_json::to_string(&execution.current_state)?;
282        let started_at = execution.started_at.to_rfc3339();
283        let updated_at = execution.updated_at.to_rfc3339();
284        let completed_at = execution.completed_at.map(|dt| dt.to_rfc3339());
285
286        sqlx::query(
287            r#"INSERT INTO workflow_executions
288               (execution_id, workflow_id, workflow_version, status, initial_input, current_state,
289                started_at, updated_at, completed_at)
290               VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"#,
291        )
292        .bind(&id)
293        .bind(&execution.workflow_id)
294        .bind(&execution.workflow_version)
295        .bind(status)
296        .bind(&initial_input)
297        .bind(&current_state)
298        .bind(&started_at)
299        .bind(&updated_at)
300        .bind(completed_at.as_deref())
301        .execute(&self.pool)
302        .await
303        .map_err(map_db_err)?;
304
305        Ok(())
306    }
307
308    #[instrument(skip(self), fields(execution_id = %id))]
309    async fn get_execution(&self, id: &ExecutionId) -> BackendResult<Option<WorkflowExecution>> {
310        let id_str = execution_id_str(id);
311        let row = sqlx::query("SELECT * FROM workflow_executions WHERE execution_id = ?")
312            .bind(&id_str)
313            .fetch_optional(&self.pool)
314            .await
315            .map_err(map_db_err)?;
316
317        row.map(|r| row_to_execution(&r)).transpose()
318    }
319
320    #[instrument(skip(self), fields(execution_id = %id, status = ?status))]
321    async fn update_execution_status(
322        &self,
323        id: &ExecutionId,
324        status: WorkflowStatus,
325    ) -> BackendResult<()> {
326        let id_str = execution_id_str(id);
327        let status_str = status_to_str(&status);
328        let now = Utc::now().to_rfc3339();
329        let completed_at = if status.is_terminal() {
330            Some(now.clone())
331        } else {
332            None
333        };
334
335        let rows_affected = sqlx::query(
336            "UPDATE workflow_executions SET status = ?, updated_at = ?, completed_at = COALESCE(?, completed_at) WHERE execution_id = ?",
337        )
338        .bind(status_str)
339        .bind(&now)
340        .bind(completed_at.as_deref())
341        .bind(&id_str)
342        .execute(&self.pool)
343        .await
344        .map_err(map_db_err)?
345        .rows_affected();
346
347        if rows_affected == 0 {
348            return Err(StateBackendError::NotFound(id_str));
349        }
350        Ok(())
351    }
352
353    async fn update_execution_current_state(
354        &self,
355        id: &ExecutionId,
356        current_state: &serde_json::Value,
357    ) -> BackendResult<()> {
358        let id_str = execution_id_str(id);
359        let state_str =
360            serde_json::to_string(current_state).map_err(StateBackendError::Serialization)?;
361        let now = Utc::now().to_rfc3339();
362        sqlx::query(
363            "UPDATE workflow_executions SET current_state = ?, updated_at = ? WHERE execution_id = ?",
364        )
365        .bind(&state_str)
366        .bind(&now)
367        .bind(&id_str)
368        .execute(&self.pool)
369        .await
370        .map_err(map_db_err)?;
371        Ok(())
372    }
373
374    async fn patch_append_array(
375        &self,
376        execution_id: &ExecutionId,
377        key: &str,
378        value: serde_json::Value,
379    ) -> BackendResult<()> {
380        let exec = self
381            .get_execution(execution_id)
382            .await?
383            .ok_or_else(|| StateBackendError::NotFound(format!("execution {execution_id}")))?;
384        let mut state = exec.current_state.clone();
385        let arr = state
386            .as_object_mut()
387            .ok_or_else(|| StateBackendError::Database("state is not a JSON object".into()))?
388            .entry(key)
389            .or_insert_with(|| serde_json::json!([]));
390        arr.as_array_mut()
391            .ok_or_else(|| {
392                StateBackendError::Database(format!("{key} is not an array"))
393            })?
394            .push(value);
395        self.update_execution_current_state(execution_id, &state)
396            .await
397    }
398
399    #[instrument(skip(self))]
400    async fn list_executions(
401        &self,
402        status: Option<WorkflowStatus>,
403        limit: u32,
404        offset: u32,
405    ) -> BackendResult<Vec<WorkflowExecution>> {
406        let rows = match status {
407            Some(s) => {
408                let status_str = status_to_str(&s);
409                sqlx::query(
410                    "SELECT * FROM workflow_executions WHERE status = ? ORDER BY updated_at DESC LIMIT ? OFFSET ?",
411                )
412                .bind(status_str)
413                .bind(limit as i64)
414                .bind(offset as i64)
415                .fetch_all(&self.pool)
416                .await
417                .map_err(map_db_err)?
418            }
419            None => sqlx::query(
420                "SELECT * FROM workflow_executions ORDER BY updated_at DESC LIMIT ? OFFSET ?",
421            )
422            .bind(limit as i64)
423            .bind(offset as i64)
424            .fetch_all(&self.pool)
425            .await
426            .map_err(map_db_err)?,
427        };
428
429        rows.iter().map(row_to_execution).collect()
430    }
431
432    // ── Event log ─────────────────────────────────────────────────────────
433
434    #[instrument(skip(self, event), fields(execution_id = %event.execution_id, seq = event.sequence))]
435    async fn append_event(&self, event: Event) -> BackendResult<EventSequence> {
436        let id = event.id.to_string();
437        let execution_id = execution_id_str(&event.execution_id);
438        let kind_json = serde_json::to_string(&event.kind)?;
439        let created_at = event.created_at.to_rfc3339();
440
441        sqlx::query(
442            r#"INSERT INTO events (id, execution_id, sequence, kind_json, created_at)
443               VALUES (?, ?, ?, ?, ?)"#,
444        )
445        .bind(&id)
446        .bind(&execution_id)
447        .bind(event.sequence)
448        .bind(&kind_json)
449        .bind(&created_at)
450        .execute(&self.pool)
451        .await
452        .map_err(map_db_err)?;
453
454        Ok(event.sequence)
455    }
456
457    #[instrument(skip(self), fields(execution_id = %execution_id))]
458    async fn get_events(&self, execution_id: &ExecutionId) -> BackendResult<Vec<Event>> {
459        let id_str = execution_id_str(execution_id);
460        let rows = sqlx::query("SELECT * FROM events WHERE execution_id = ? ORDER BY sequence ASC")
461            .bind(&id_str)
462            .fetch_all(&self.pool)
463            .await
464            .map_err(map_db_err)?;
465
466        rows.iter().map(row_to_event).collect()
467    }
468
469    #[instrument(skip(self), fields(execution_id = %execution_id, since = since_sequence))]
470    async fn get_events_since(
471        &self,
472        execution_id: &ExecutionId,
473        since_sequence: EventSequence,
474    ) -> BackendResult<Vec<Event>> {
475        let id_str = execution_id_str(execution_id);
476        let rows = sqlx::query(
477            "SELECT * FROM events WHERE execution_id = ? AND sequence > ? ORDER BY sequence ASC",
478        )
479        .bind(&id_str)
480        .bind(since_sequence)
481        .fetch_all(&self.pool)
482        .await
483        .map_err(map_db_err)?;
484
485        rows.iter().map(row_to_event).collect()
486    }
487
488    #[instrument(skip(self), fields(execution_id = %execution_id))]
489    async fn latest_sequence(&self, execution_id: &ExecutionId) -> BackendResult<EventSequence> {
490        let id_str = execution_id_str(execution_id);
491        let row = sqlx::query(
492            "SELECT COALESCE(MAX(sequence), 0) as seq FROM events WHERE execution_id = ?",
493        )
494        .bind(&id_str)
495        .fetch_one(&self.pool)
496        .await
497        .map_err(map_db_err)?;
498
499        Ok(row.try_get::<i64, _>("seq").map_err(map_db_err)?)
500    }
501
502    // ── Snapshots ─────────────────────────────────────────────────────────
503
504    #[instrument(skip(self, snapshot), fields(execution_id = %snapshot.execution_id, at_seq = snapshot.at_sequence))]
505    async fn write_snapshot(&self, snapshot: Snapshot) -> BackendResult<()> {
506        let id = snapshot.id.to_string();
507        let execution_id = execution_id_str(&snapshot.execution_id);
508        let state_json = serde_json::to_string(&snapshot.state)?;
509        let created_at = snapshot.created_at.to_rfc3339();
510
511        sqlx::query(
512            r#"INSERT OR REPLACE INTO snapshots (id, execution_id, at_sequence, state_json, created_at)
513               VALUES (?, ?, ?, ?, ?)"#,
514        )
515        .bind(&id)
516        .bind(&execution_id)
517        .bind(snapshot.at_sequence)
518        .bind(&state_json)
519        .bind(&created_at)
520        .execute(&self.pool)
521        .await
522        .map_err(map_db_err)?;
523
524        Ok(())
525    }
526
527    #[instrument(skip(self), fields(execution_id = %execution_id))]
528    async fn latest_snapshot(&self, execution_id: &ExecutionId) -> BackendResult<Option<Snapshot>> {
529        let id_str = execution_id_str(execution_id);
530        let row = sqlx::query(
531            "SELECT * FROM snapshots WHERE execution_id = ? ORDER BY at_sequence DESC LIMIT 1",
532        )
533        .bind(&id_str)
534        .fetch_optional(&self.pool)
535        .await
536        .map_err(map_db_err)?;
537
538        let Some(row) = row else { return Ok(None) };
539
540        let id = Uuid::parse_str(row.try_get::<&str, _>("id").map_err(map_db_err)?)
541            .map_err(|e| StateBackendError::Database(e.to_string()))?;
542        let execution_id =
543            parse_execution_id(row.try_get::<&str, _>("execution_id").map_err(map_db_err)?)?;
544        let at_sequence: i64 = row.try_get("at_sequence").map_err(map_db_err)?;
545        let state: serde_json::Value =
546            serde_json::from_str(row.try_get::<&str, _>("state_json").map_err(map_db_err)?)
547                .map_err(StateBackendError::Serialization)?;
548        let created_at = parse_datetime(row.try_get::<&str, _>("created_at").map_err(map_db_err)?)?;
549
550        Ok(Some(Snapshot {
551            id,
552            execution_id,
553            at_sequence,
554            state,
555            created_at,
556        }))
557    }
558
559    // ── Work item queue ───────────────────────────────────────────────────
560
561    #[instrument(skip(self, item), fields(execution_id = %item.execution_id, node_id = %item.node_id))]
562    async fn enqueue_work_item(&self, item: WorkItem) -> BackendResult<WorkItemId> {
563        let id = item.id.to_string();
564        let execution_id = execution_id_str(&item.execution_id);
565        let payload_json = serde_json::to_string(&item.payload)?;
566        let created_at = item.created_at.to_rfc3339();
567
568        sqlx::query(
569            r#"INSERT INTO work_items
570               (id, execution_id, node_id, queue_type, payload_json, attempt, max_attempts, status, created_at)
571               VALUES (?, ?, ?, ?, ?, ?, ?, 'pending', ?)"#,
572        )
573        .bind(&id)
574        .bind(&execution_id)
575        .bind(&item.node_id)
576        .bind(&item.queue_type)
577        .bind(&payload_json)
578        .bind(item.attempt as i64)
579        .bind(item.max_attempts as i64)
580        .bind(&created_at)
581        .execute(&self.pool)
582        .await
583        .map_err(map_db_err)?;
584
585        Ok(item.id)
586    }
587
588    #[instrument(skip(self), fields(worker_id = worker_id))]
589    async fn claim_work_item(
590        &self,
591        worker_id: &str,
592        queue_types: &[&str],
593    ) -> BackendResult<Option<WorkItem>> {
594        if queue_types.is_empty() {
595            return Ok(None);
596        }
597
598        // Expire stale leases first
599        let now = Utc::now().to_rfc3339();
600        sqlx::query(
601            "UPDATE work_items SET status = 'pending', worker_id = NULL, lease_expires_at = NULL \
602             WHERE status = 'claimed' AND lease_expires_at < ?",
603        )
604        .bind(&now)
605        .execute(&self.pool)
606        .await
607        .map_err(map_db_err)?;
608
609        // SQLite doesn't support UPDATE ... RETURNING with a subquery easily, so
610        // we use a transaction: SELECT FOR UPDATE equivalent via exclusive transaction.
611        let mut tx = self.pool.begin().await.map_err(map_db_err)?;
612
613        // Build placeholders for queue_types IN clause
614        let placeholders = queue_types
615            .iter()
616            .map(|_| "?")
617            .collect::<Vec<_>>()
618            .join(",");
619        let query_str = format!(
620            "SELECT * FROM work_items WHERE status = 'pending' AND queue_type IN ({}) \
621             AND (retry_after IS NULL OR retry_after <= ?) ORDER BY created_at ASC LIMIT 1",
622            placeholders
623        );
624        let mut q = sqlx::query(&query_str);
625        for qt in queue_types {
626            q = q.bind(*qt);
627        }
628        q = q.bind(&now); // for retry_after <= now check
629        let row = q.fetch_optional(&mut *tx).await.map_err(map_db_err)?;
630
631        let Some(row) = row else {
632            tx.rollback().await.map_err(map_db_err)?;
633            return Ok(None);
634        };
635
636        let item = row_to_work_item(&row)?;
637        let item_id = item.id.to_string();
638        let lease_expires_at = (Utc::now() + chrono::Duration::seconds(30)).to_rfc3339();
639        let claimed_at = Utc::now().to_rfc3339();
640
641        sqlx::query(
642            "UPDATE work_items SET status = 'claimed', worker_id = ?, lease_expires_at = ?, claimed_at = ? WHERE id = ?",
643        )
644        .bind(worker_id)
645        .bind(&lease_expires_at)
646        .bind(&claimed_at)
647        .bind(&item_id)
648        .execute(&mut *tx)
649        .await
650        .map_err(map_db_err)?;
651
652        tx.commit().await.map_err(map_db_err)?;
653
654        // Return item with updated fields
655        let mut claimed = item;
656        claimed.worker_id = Some(worker_id.to_string());
657        claimed.lease_expires_at = Some(
658            DateTime::parse_from_rfc3339(&lease_expires_at)
659                .map(|dt| dt.with_timezone(&Utc))
660                .map_err(|e| StateBackendError::Database(e.to_string()))?,
661        );
662        Ok(Some(claimed))
663    }
664
665    #[instrument(skip(self), fields(item_id = %item_id, worker_id = worker_id))]
666    async fn renew_lease(&self, item_id: WorkItemId, worker_id: &str) -> BackendResult<()> {
667        let lease_expires_at = (Utc::now() + chrono::Duration::seconds(30)).to_rfc3339();
668        let id_str = item_id.to_string();
669
670        let rows_affected = sqlx::query(
671            "UPDATE work_items SET lease_expires_at = ? WHERE id = ? AND worker_id = ? AND status = 'claimed'",
672        )
673        .bind(&lease_expires_at)
674        .bind(&id_str)
675        .bind(worker_id)
676        .execute(&self.pool)
677        .await
678        .map_err(map_db_err)?
679        .rows_affected();
680
681        if rows_affected == 0 {
682            return Err(StateBackendError::NotFound(id_str));
683        }
684        Ok(())
685    }
686
687    #[instrument(skip(self), fields(item_id = %item_id))]
688    async fn complete_work_item(&self, item_id: WorkItemId) -> BackendResult<()> {
689        let id_str = item_id.to_string();
690        let completed_at = Utc::now().to_rfc3339();
691
692        let rows_affected = sqlx::query(
693            "UPDATE work_items SET status = 'completed', completed_at = ?, lease_expires_at = NULL WHERE id = ?",
694        )
695        .bind(&completed_at)
696        .bind(&id_str)
697        .execute(&self.pool)
698        .await
699        .map_err(map_db_err)?
700        .rows_affected();
701
702        if rows_affected == 0 {
703            return Err(StateBackendError::NotFound(id_str));
704        }
705        Ok(())
706    }
707
708    #[instrument(skip(self, error), fields(item_id = %item_id))]
709    async fn fail_work_item(&self, item_id: WorkItemId, error: &str) -> BackendResult<()> {
710        let id_str = item_id.to_string();
711        let _ = error; // logged by caller; stored in event log not here
712
713        let rows_affected = sqlx::query(
714            "UPDATE work_items SET status = 'failed', lease_expires_at = NULL, worker_id = NULL WHERE id = ?",
715        )
716        .bind(&id_str)
717        .execute(&self.pool)
718        .await
719        .map_err(map_db_err)?
720        .rows_affected();
721
722        if rows_affected == 0 {
723            return Err(StateBackendError::NotFound(id_str));
724        }
725        Ok(())
726    }
727
728    #[instrument(skip(self))]
729    async fn reclaim_expired_leases(&self) -> BackendResult<ReclaimResult> {
730        let now = Utc::now().to_rfc3339();
731
732        // Find all claimed items whose lease has expired.
733        let rows = sqlx::query(
734            "SELECT * FROM work_items WHERE status = 'claimed' AND lease_expires_at < ? ORDER BY created_at ASC",
735        )
736        .bind(&now)
737        .fetch_all(&self.pool)
738        .await
739        .map_err(map_db_err)?;
740
741        let mut result = ReclaimResult::default();
742
743        for row in &rows {
744            let item = row_to_work_item(row)?;
745            let new_attempt = item.attempt + 1;
746            let id_str = item.id.to_string();
747
748            if new_attempt >= item.max_attempts {
749                // Exhausted — move to dead-letter (caller emits the event)
750                let dead_lettered_at = Utc::now().to_rfc3339();
751                sqlx::query(
752                    r#"INSERT OR IGNORE INTO dead_letter_items
753                       (id, execution_id, node_id, queue_type, payload_json, attempt, last_error, created_at, dead_lettered_at)
754                       VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"#,
755                )
756                .bind(&id_str)
757                .bind(execution_id_str(&item.execution_id))
758                .bind(&item.node_id)
759                .bind(&item.queue_type)
760                .bind(serde_json::to_string(&item.payload)?)
761                .bind(new_attempt as i64)
762                .bind("lease expired: worker dead")
763                .bind(item.created_at.to_rfc3339())
764                .bind(dead_lettered_at)
765                .execute(&self.pool)
766                .await
767                .map_err(map_db_err)?;
768
769                sqlx::query("UPDATE work_items SET status = 'dead_lettered', attempt = ?, lease_expires_at = NULL, worker_id = NULL WHERE id = ?")
770                    .bind(new_attempt as i64)
771                    .bind(&id_str)
772                    .execute(&self.pool)
773                    .await
774                    .map_err(map_db_err)?;
775
776                let mut exhausted_item = item;
777                exhausted_item.attempt = new_attempt;
778                result.exhausted.push(exhausted_item);
779            } else {
780                // Retryable — reset to pending with incremented attempt.
781                // Apply exponential backoff: 2^attempt seconds.
782                let backoff_secs = 1u64 << new_attempt.min(6); // max 64s
783                let retry_after =
784                    (Utc::now() + chrono::Duration::seconds(backoff_secs as i64)).to_rfc3339();
785
786                sqlx::query(
787                    "UPDATE work_items SET status = 'pending', attempt = ?, worker_id = NULL, lease_expires_at = NULL, retry_after = ? WHERE id = ?",
788                )
789                .bind(new_attempt as i64)
790                .bind(&retry_after)
791                .bind(&id_str)
792                .execute(&self.pool)
793                .await
794                .map_err(map_db_err)?;
795
796                let mut retry_item = item;
797                retry_item.attempt = new_attempt;
798                result.retryable.push(retry_item);
799            }
800        }
801
802        Ok(result)
803    }
804
805    #[instrument(skip(self, last_error), fields(item_id = %item_id))]
806    async fn move_to_dead_letter(
807        &self,
808        item_id: WorkItemId,
809        last_error: &str,
810    ) -> BackendResult<()> {
811        let id_str = item_id.to_string();
812
813        // Load the item first to copy fields.
814        let row = sqlx::query("SELECT * FROM work_items WHERE id = ?")
815            .bind(&id_str)
816            .fetch_optional(&self.pool)
817            .await
818            .map_err(map_db_err)?;
819
820        let Some(row) = row else {
821            return Err(StateBackendError::NotFound(id_str));
822        };
823        let item = row_to_work_item(&row)?;
824        let dead_lettered_at = Utc::now().to_rfc3339();
825
826        sqlx::query(
827            r#"INSERT OR REPLACE INTO dead_letter_items
828               (id, execution_id, node_id, queue_type, payload_json, attempt, last_error, created_at, dead_lettered_at)
829               VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"#,
830        )
831        .bind(&id_str)
832        .bind(execution_id_str(&item.execution_id))
833        .bind(&item.node_id)
834        .bind(&item.queue_type)
835        .bind(serde_json::to_string(&item.payload)?)
836        .bind(item.attempt as i64)
837        .bind(last_error)
838        .bind(item.created_at.to_rfc3339())
839        .bind(dead_lettered_at)
840        .execute(&self.pool)
841        .await
842        .map_err(map_db_err)?;
843
844        sqlx::query("UPDATE work_items SET status = 'dead_lettered', lease_expires_at = NULL, worker_id = NULL WHERE id = ?")
845            .bind(&id_str)
846            .execute(&self.pool)
847            .await
848            .map_err(map_db_err)?;
849
850        Ok(())
851    }
852
853    async fn create_token(&self, name: &str, role: &str) -> BackendResult<(String, ApiToken)> {
854        use rand::Rng;
855        use sha2::{Digest, Sha256};
856
857        // Generate a random 32-byte token, hex-encoded with a prefix.
858        let random_bytes: [u8; 32] = rand::thread_rng().gen();
859        let token = format!(
860            "jj_{}",
861            random_bytes
862                .iter()
863                .map(|b| format!("{b:02x}"))
864                .collect::<String>()
865        );
866        let token_hash = format!("{:x}", Sha256::digest(token.as_bytes()));
867
868        let id = uuid::Uuid::new_v4().to_string();
869        let now = Utc::now().to_rfc3339();
870
871        sqlx::query(
872            r#"INSERT INTO api_tokens (id, token_hash, name, role, created_at)
873               VALUES (?, ?, ?, ?, ?)"#,
874        )
875        .bind(&id)
876        .bind(&token_hash)
877        .bind(name)
878        .bind(role)
879        .bind(&now)
880        .execute(&self.pool)
881        .await
882        .map_err(map_db_err)?;
883
884        let info = ApiToken {
885            id,
886            name: name.to_string(),
887            role: role.to_string(),
888            created_at: Utc::now(),
889            expires_at: None,
890            tenant_id: crate::tenant::DEFAULT_TENANT.to_string(),
891        };
892        Ok((token, info))
893    }
894
895    async fn validate_token(&self, token: &str) -> BackendResult<Option<ApiToken>> {
896        use sha2::{Digest, Sha256};
897
898        let token_hash = format!("{:x}", Sha256::digest(token.as_bytes()));
899        let now = Utc::now().to_rfc3339();
900
901        let row = sqlx::query(
902            r#"SELECT id, name, role, created_at, expires_at
903               FROM api_tokens
904               WHERE token_hash = ?
905                 AND revoked_at IS NULL
906                 AND (expires_at IS NULL OR expires_at > ?)"#,
907        )
908        .bind(&token_hash)
909        .bind(&now)
910        .fetch_optional(&self.pool)
911        .await
912        .map_err(map_db_err)?;
913
914        let Some(row) = row else { return Ok(None) };
915
916        // Update last_used_at.
917        let id: String = row.get("id");
918        sqlx::query("UPDATE api_tokens SET last_used_at = ? WHERE id = ?")
919            .bind(&now)
920            .bind(&id)
921            .execute(&self.pool)
922            .await
923            .map_err(map_db_err)?;
924
925        let expires_at: Option<String> = row.get("expires_at");
926        let tenant_id: String = row
927            .try_get("tenant_id")
928            .unwrap_or_else(|_| crate::tenant::DEFAULT_TENANT.to_string());
929        let info = ApiToken {
930            id,
931            name: row.get("name"),
932            role: row.get("role"),
933            created_at: row
934                .get::<String, _>("created_at")
935                .parse::<chrono::DateTime<Utc>>()
936                .unwrap_or_else(|_| Utc::now()),
937            expires_at: expires_at.and_then(|s| s.parse().ok()),
938            tenant_id,
939        };
940        Ok(Some(info))
941    }
942}
943
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use jamjet_core::workflow::{WorkflowExecution, WorkflowStatus};
    use serde_json::json;

    /// Open a fresh, migrated in-memory database for a single test.
    async fn open_test_db() -> SqliteBackend {
        SqliteBackend::open("sqlite::memory:")
            .await
            .expect("failed to open in-memory SQLite")
    }

    /// A pending execution of `test-wf@1.0.0` with a fresh id.
    fn sample_execution() -> WorkflowExecution {
        let now = Utc::now();
        WorkflowExecution {
            execution_id: ExecutionId::new(),
            workflow_id: "test-wf".to_string(),
            workflow_version: "1.0.0".to_string(),
            status: WorkflowStatus::Pending,
            initial_input: json!({"x": 1}),
            current_state: json!({}),
            started_at: now,
            updated_at: now,
            completed_at: None,
            session_type: None,
        }
    }

    /// An unclaimed work item for `node-1` on the `default` queue of `execution_id`.
    fn sample_work_item(execution_id: ExecutionId) -> WorkItem {
        WorkItem {
            id: Uuid::new_v4(),
            execution_id,
            node_id: "node-1".to_string(),
            queue_type: "default".to_string(),
            payload: json!({}),
            attempt: 0,
            max_attempts: 3,
            created_at: Utc::now(),
            lease_expires_at: None,
            worker_id: None,
            tenant_id: crate::tenant::DEFAULT_TENANT.to_string(),
        }
    }

    #[tokio::test]
    async fn test_create_and_get_execution() {
        let db = open_test_db().await;
        let exec = sample_execution();
        let id = exec.execution_id.clone();
        db.create_execution(exec).await.unwrap();
        let fetched = db.get_execution(&id).await.unwrap().unwrap();
        assert_eq!(fetched.workflow_id, "test-wf");
        assert_eq!(fetched.status, WorkflowStatus::Pending);
    }

    #[tokio::test]
    async fn test_update_status() {
        let db = open_test_db().await;
        let exec = sample_execution();
        let id = exec.execution_id.clone();
        db.create_execution(exec).await.unwrap();
        db.update_execution_status(&id, WorkflowStatus::Running)
            .await
            .unwrap();
        let fetched = db.get_execution(&id).await.unwrap().unwrap();
        assert_eq!(fetched.status, WorkflowStatus::Running);
    }

    #[tokio::test]
    async fn test_list_executions() {
        let db = open_test_db().await;
        db.create_execution(sample_execution()).await.unwrap();
        db.create_execution(sample_execution()).await.unwrap();
        let all = db.list_executions(None, 10, 0).await.unwrap();
        assert_eq!(all.len(), 2);
    }

    #[tokio::test]
    async fn test_event_log() {
        use crate::event::{Event, EventKind};
        let db = open_test_db().await;
        let exec = sample_execution();
        let exec_id = exec.execution_id.clone();
        db.create_execution(exec).await.unwrap();

        let event = Event::new(
            exec_id.clone(),
            1,
            EventKind::WorkflowStarted {
                workflow_id: "test-wf".to_string(),
                workflow_version: "1.0.0".to_string(),
                initial_input: json!({"x": 1}),
            },
        );
        db.append_event(event).await.unwrap();

        let events = db.get_events(&exec_id).await.unwrap();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].sequence, 1);

        let seq = db.latest_sequence(&exec_id).await.unwrap();
        assert_eq!(seq, 1);
    }

    #[tokio::test]
    async fn test_snapshot() {
        use crate::snapshot::Snapshot;
        let db = open_test_db().await;
        let exec = sample_execution();
        let exec_id = exec.execution_id.clone();
        db.create_execution(exec).await.unwrap();

        let snap = Snapshot::new(exec_id.clone(), 5, json!({"nodes_completed": ["a", "b"]}));
        db.write_snapshot(snap).await.unwrap();

        let loaded = db.latest_snapshot(&exec_id).await.unwrap().unwrap();
        assert_eq!(loaded.at_sequence, 5);
    }

    #[tokio::test]
    async fn test_workflow_definition() {
        use crate::backend::WorkflowDefinition;
        let db = open_test_db().await;

        let def = WorkflowDefinition {
            workflow_id: "my-wf".to_string(),
            version: "1.0.0".to_string(),
            ir: json!({"workflow_id": "my-wf", "version": "1.0.0", "nodes": {}}),
            created_at: Utc::now(),
            tenant_id: crate::tenant::DEFAULT_TENANT.to_string(),
        };
        db.store_workflow(def).await.unwrap();

        let loaded = db.get_workflow("my-wf", "1.0.0").await.unwrap().unwrap();
        assert_eq!(loaded.workflow_id, "my-wf");
        assert_eq!(loaded.version, "1.0.0");

        // Non-existent version returns None
        let missing = db.get_workflow("my-wf", "2.0.0").await.unwrap();
        assert!(missing.is_none());
    }

    #[tokio::test]
    async fn test_work_item_queue() {
        let db = open_test_db().await;
        let exec = sample_execution();
        let exec_id = exec.execution_id.clone();
        db.create_execution(exec).await.unwrap();

        let item = sample_work_item(exec_id);
        let item_id = item.id;
        db.enqueue_work_item(item).await.unwrap();

        let claimed = db
            .claim_work_item("worker-1", &["default"])
            .await
            .unwrap()
            .unwrap();
        assert_eq!(claimed.node_id, "node-1");
        assert_eq!(claimed.worker_id.as_deref(), Some("worker-1"));

        db.complete_work_item(item_id).await.unwrap();

        // No more items
        let none = db.claim_work_item("worker-1", &["default"]).await.unwrap();
        assert!(none.is_none());
    }

    #[tokio::test]
    async fn test_move_to_dead_letter() {
        let db = open_test_db().await;
        let exec = sample_execution();
        let exec_id = exec.execution_id.clone();
        db.create_execution(exec).await.unwrap();

        let item = sample_work_item(exec_id);
        let item_id = item.id;
        db.enqueue_work_item(item).await.unwrap();
        db.claim_work_item("worker-1", &["default"])
            .await
            .unwrap()
            .expect("enqueued item should be claimable");

        db.move_to_dead_letter(item_id, "boom").await.unwrap();

        // A dead-lettered item must not be handed out again.
        let none = db.claim_work_item("worker-1", &["default"]).await.unwrap();
        assert!(none.is_none());
    }

    #[tokio::test]
    async fn test_move_to_dead_letter_missing_item() {
        let db = open_test_db().await;
        let err = db
            .move_to_dead_letter(Uuid::new_v4(), "boom")
            .await
            .unwrap_err();
        assert!(matches!(err, StateBackendError::NotFound(_)));
    }

    #[tokio::test]
    async fn test_api_token_round_trip() {
        let db = open_test_db().await;
        let (token, info) = db.create_token("ci", "admin").await.unwrap();
        assert!(token.starts_with("jj_"));
        assert_eq!(token.len(), "jj_".len() + 64);

        let validated = db
            .validate_token(&token)
            .await
            .unwrap()
            .expect("freshly created token should validate");
        assert_eq!(validated.id, info.id);
        assert_eq!(validated.name, "ci");
        assert_eq!(validated.role, "admin");

        // An unknown token is rejected, not an error.
        let unknown = db.validate_token("jj_not-a-real-token").await.unwrap();
        assert!(unknown.is_none());
    }

    #[tokio::test]
    async fn test_patch_append_array() {
        let db = open_test_db().await;
        let exec = sample_execution();
        let id = exec.execution_id.clone();
        db.create_execution(exec).await.unwrap();

        db.patch_append_array(&id, "agent_tool_events", json!({"type": "progress", "chunk": 0}))
            .await
            .unwrap();
        db.patch_append_array(&id, "agent_tool_events", json!({"type": "progress", "chunk": 1}))
            .await
            .unwrap();

        let fetched = db.get_execution(&id).await.unwrap().unwrap();
        let events = fetched.current_state["agent_tool_events"]
            .as_array()
            .unwrap();
        assert_eq!(events.len(), 2);
        assert_eq!(events[0]["chunk"], 0);
        assert_eq!(events[1]["chunk"], 1);
    }
}