Skip to main content

jamjet_state/
sqlite.rs

1use crate::backend::{
2    ApiToken, BackendResult, ReclaimResult, StateBackend, StateBackendError, WorkItem, WorkItemId,
3    WorkflowDefinition,
4};
5use crate::event::{Event, EventKind, EventSequence};
6use crate::snapshot::Snapshot;
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use jamjet_core::workflow::{ExecutionId, WorkflowExecution, WorkflowStatus};
10use sqlx::{sqlite::SqliteConnectOptions, Row, SqlitePool};
11use std::str::FromStr;
12use tracing::instrument;
13use uuid::Uuid;
14
15/// SQLite-backed state store for local development.
16///
17/// Run migrations with [`SqliteBackend::migrate`] before first use.
18pub struct SqliteBackend {
19    pool: SqlitePool,
20}
21
22impl SqliteBackend {
23    /// Connect to the SQLite database at `database_url` and return a backend.
24    /// `database_url` is a SQLx-compatible URL, e.g. `sqlite:///path/to/db.sqlite3`.
25    /// The database file is created automatically if it does not exist.
26    pub async fn connect(database_url: &str) -> Result<Self, sqlx::Error> {
27        let opts = SqliteConnectOptions::from_str(database_url)?
28            .create_if_missing(true)
29            // WAL + a busy timeout let the scheduler, worker pool, API, and audit
30            // log share one SQLite file without spurious SQLITE_BUSY errors under
31            // concurrency. (Ignored for in-memory databases.)
32            .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
33            .busy_timeout(std::time::Duration::from_secs(5));
34        let pool = SqlitePool::connect_with(opts).await?;
35        Ok(Self { pool })
36    }
37
38    /// Run embedded migrations against the connected database.
39    pub async fn migrate(&self) -> Result<(), sqlx::migrate::MigrateError> {
40        sqlx::migrate!("./migrations").run(&self.pool).await
41    }
42
43    /// Convenience: connect and immediately run migrations.
44    pub async fn open(
45        database_url: &str,
46    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
47        let backend = Self::connect(database_url).await?;
48        backend.migrate().await?;
49        Ok(backend)
50    }
51
52    /// Create a tenant-scoped view of this backend.
53    ///
54    /// All operations on the returned backend are filtered by `tenant_id`,
55    /// ensuring complete data isolation between tenants.
56    pub fn for_tenant(
57        &self,
58        tenant_id: crate::tenant::TenantId,
59    ) -> crate::tenant_scoped::TenantScopedSqliteBackend {
60        crate::tenant_scoped::TenantScopedSqliteBackend::new(self.pool.clone(), tenant_id)
61    }
62
63    /// Get a clone of the underlying connection pool.
64    pub fn pool(&self) -> SqlitePool {
65        self.pool.clone()
66    }
67}
68
69// ── Helpers ───────────────────────────────────────────────────────────────────
70
71pub(crate) fn map_db_err(e: sqlx::Error) -> StateBackendError {
72    StateBackendError::Database(e.to_string())
73}
74
75fn execution_id_str(id: &ExecutionId) -> String {
76    id.0.to_string()
77}
78
79pub(crate) fn parse_execution_id(s: &str) -> BackendResult<ExecutionId> {
80    let uuid = Uuid::parse_str(s)
81        .map_err(|e| StateBackendError::Database(format!("invalid execution_id: {e}")))?;
82    Ok(ExecutionId(uuid))
83}
84
85pub(crate) fn parse_datetime(s: &str) -> BackendResult<DateTime<Utc>> {
86    DateTime::parse_from_rfc3339(s)
87        .map(|dt| dt.with_timezone(&Utc))
88        .map_err(|e| StateBackendError::Database(format!("invalid datetime: {e}")))
89}
90
91fn status_to_str(s: &WorkflowStatus) -> &'static str {
92    match s {
93        WorkflowStatus::Pending => "pending",
94        WorkflowStatus::Running => "running",
95        WorkflowStatus::Paused => "paused",
96        WorkflowStatus::Completed => "completed",
97        WorkflowStatus::Failed => "failed",
98        WorkflowStatus::Cancelled => "cancelled",
99        WorkflowStatus::LimitExceeded => "limit_exceeded",
100    }
101}
102
103fn str_to_status(s: &str) -> BackendResult<WorkflowStatus> {
104    match s {
105        "pending" => Ok(WorkflowStatus::Pending),
106        "running" => Ok(WorkflowStatus::Running),
107        "paused" => Ok(WorkflowStatus::Paused),
108        "completed" => Ok(WorkflowStatus::Completed),
109        "failed" => Ok(WorkflowStatus::Failed),
110        "cancelled" => Ok(WorkflowStatus::Cancelled),
111        "limit_exceeded" => Ok(WorkflowStatus::LimitExceeded),
112        other => Err(StateBackendError::Database(format!(
113            "unknown status: {other}"
114        ))),
115    }
116}
117
118fn row_to_execution(row: &sqlx::sqlite::SqliteRow) -> BackendResult<WorkflowExecution> {
119    let execution_id =
120        parse_execution_id(row.try_get::<&str, _>("execution_id").map_err(map_db_err)?)?;
121    let status = str_to_status(row.try_get::<&str, _>("status").map_err(map_db_err)?)?;
122    let initial_input: serde_json::Value = serde_json::from_str(
123        row.try_get::<&str, _>("initial_input")
124            .map_err(map_db_err)?,
125    )
126    .map_err(StateBackendError::Serialization)?;
127    let current_state: serde_json::Value = serde_json::from_str(
128        row.try_get::<&str, _>("current_state")
129            .map_err(map_db_err)?,
130    )
131    .map_err(StateBackendError::Serialization)?;
132    let started_at = parse_datetime(row.try_get::<&str, _>("started_at").map_err(map_db_err)?)?;
133    let updated_at = parse_datetime(row.try_get::<&str, _>("updated_at").map_err(map_db_err)?)?;
134    let completed_at: Option<DateTime<Utc>> = row
135        .try_get::<Option<&str>, _>("completed_at")
136        .map_err(map_db_err)?
137        .map(parse_datetime)
138        .transpose()?;
139
140    Ok(WorkflowExecution {
141        execution_id,
142        workflow_id: row
143            .try_get::<String, _>("workflow_id")
144            .map_err(map_db_err)?,
145        workflow_version: row
146            .try_get::<String, _>("workflow_version")
147            .map_err(map_db_err)?,
148        status,
149        initial_input,
150        current_state,
151        started_at,
152        updated_at,
153        completed_at,
154        session_type: None,
155    })
156}
157
158fn row_to_event(row: &sqlx::sqlite::SqliteRow) -> BackendResult<Event> {
159    let id = Uuid::parse_str(row.try_get::<&str, _>("id").map_err(map_db_err)?)
160        .map_err(|e| StateBackendError::Database(e.to_string()))?;
161    let execution_id =
162        parse_execution_id(row.try_get::<&str, _>("execution_id").map_err(map_db_err)?)?;
163    let sequence: i64 = row.try_get("sequence").map_err(map_db_err)?;
164    let kind: EventKind =
165        serde_json::from_str(row.try_get::<&str, _>("kind_json").map_err(map_db_err)?)
166            .map_err(StateBackendError::Serialization)?;
167    let created_at = parse_datetime(row.try_get::<&str, _>("created_at").map_err(map_db_err)?)?;
168
169    Ok(Event {
170        id,
171        execution_id,
172        sequence,
173        kind,
174        created_at,
175    })
176}
177
178fn row_to_work_item(row: &sqlx::sqlite::SqliteRow) -> BackendResult<WorkItem> {
179    let id = Uuid::parse_str(row.try_get::<&str, _>("id").map_err(map_db_err)?)
180        .map_err(|e| StateBackendError::Database(e.to_string()))?;
181    let execution_id =
182        parse_execution_id(row.try_get::<&str, _>("execution_id").map_err(map_db_err)?)?;
183    let payload: serde_json::Value =
184        serde_json::from_str(row.try_get::<&str, _>("payload_json").map_err(map_db_err)?)
185            .map_err(StateBackendError::Serialization)?;
186    let lease_expires_at: Option<DateTime<Utc>> = row
187        .try_get::<Option<&str>, _>("lease_expires_at")
188        .map_err(map_db_err)?
189        .map(parse_datetime)
190        .transpose()?;
191    let created_at = parse_datetime(row.try_get::<&str, _>("created_at").map_err(map_db_err)?)?;
192    let attempt: i64 = row.try_get("attempt").map_err(map_db_err)?;
193
194    let max_attempts: i64 = row.try_get("max_attempts").unwrap_or(3);
195
196    let tenant_id: String = row
197        .try_get("tenant_id")
198        .unwrap_or_else(|_| crate::tenant::DEFAULT_TENANT.to_string());
199
200    Ok(WorkItem {
201        id,
202        execution_id,
203        node_id: row.try_get::<String, _>("node_id").map_err(map_db_err)?,
204        queue_type: row.try_get::<String, _>("queue_type").map_err(map_db_err)?,
205        payload,
206        attempt: attempt as u32,
207        max_attempts: max_attempts as u32,
208        created_at,
209        lease_expires_at,
210        worker_id: row
211            .try_get::<Option<String>, _>("worker_id")
212            .map_err(map_db_err)?,
213        tenant_id,
214    })
215}
216
217// ── StateBackend impl ─────────────────────────────────────────────────────────
218
219#[async_trait]
220impl StateBackend for SqliteBackend {
221    // ── Workflow definitions ──────────────────────────────────────────────
222
223    #[instrument(skip(self, def), fields(workflow_id = %def.workflow_id, version = %def.version))]
224    async fn store_workflow(&self, def: WorkflowDefinition) -> BackendResult<()> {
225        let ir_json = serde_json::to_string(&def.ir)?;
226        let created_at = def.created_at.to_rfc3339();
227
228        sqlx::query(
229            r#"INSERT OR REPLACE INTO workflow_definitions (workflow_id, version, ir_json, created_at, tenant_id)
230               VALUES (?, ?, ?, ?, ?)"#,
231        )
232        .bind(&def.workflow_id)
233        .bind(&def.version)
234        .bind(&ir_json)
235        .bind(&created_at)
236        .bind(&def.tenant_id)
237        .execute(&self.pool)
238        .await
239        .map_err(map_db_err)?;
240
241        Ok(())
242    }
243
244    #[instrument(skip(self), fields(workflow_id = workflow_id, version = version))]
245    async fn get_workflow(
246        &self,
247        workflow_id: &str,
248        version: &str,
249    ) -> BackendResult<Option<WorkflowDefinition>> {
250        let row =
251            sqlx::query("SELECT * FROM workflow_definitions WHERE workflow_id = ? AND version = ?")
252                .bind(workflow_id)
253                .bind(version)
254                .fetch_optional(&self.pool)
255                .await
256                .map_err(map_db_err)?;
257
258        let Some(row) = row else { return Ok(None) };
259
260        let ir: serde_json::Value =
261            serde_json::from_str(row.try_get::<&str, _>("ir_json").map_err(map_db_err)?)
262                .map_err(StateBackendError::Serialization)?;
263        let created_at = parse_datetime(row.try_get::<&str, _>("created_at").map_err(map_db_err)?)?;
264
265        let tenant_id: String = row
266            .try_get("tenant_id")
267            .unwrap_or_else(|_| crate::tenant::DEFAULT_TENANT.to_string());
268
269        Ok(Some(WorkflowDefinition {
270            workflow_id: row
271                .try_get::<String, _>("workflow_id")
272                .map_err(map_db_err)?,
273            version: row.try_get::<String, _>("version").map_err(map_db_err)?,
274            ir,
275            created_at,
276            tenant_id,
277        }))
278    }
279
280    // ── Executions ────────────────────────────────────────────────────────
281
282    #[instrument(skip(self, execution), fields(execution_id = %execution.execution_id))]
283    async fn create_execution(&self, execution: WorkflowExecution) -> BackendResult<()> {
284        let id = execution_id_str(&execution.execution_id);
285        let status = status_to_str(&execution.status);
286        let initial_input = serde_json::to_string(&execution.initial_input)?;
287        let current_state = serde_json::to_string(&execution.current_state)?;
288        let started_at = execution.started_at.to_rfc3339();
289        let updated_at = execution.updated_at.to_rfc3339();
290        let completed_at = execution.completed_at.map(|dt| dt.to_rfc3339());
291
292        sqlx::query(
293            r#"INSERT INTO workflow_executions
294               (execution_id, workflow_id, workflow_version, status, initial_input, current_state,
295                started_at, updated_at, completed_at)
296               VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"#,
297        )
298        .bind(&id)
299        .bind(&execution.workflow_id)
300        .bind(&execution.workflow_version)
301        .bind(status)
302        .bind(&initial_input)
303        .bind(&current_state)
304        .bind(&started_at)
305        .bind(&updated_at)
306        .bind(completed_at.as_deref())
307        .execute(&self.pool)
308        .await
309        .map_err(map_db_err)?;
310
311        Ok(())
312    }
313
314    #[instrument(skip(self), fields(execution_id = %id))]
315    async fn get_execution(&self, id: &ExecutionId) -> BackendResult<Option<WorkflowExecution>> {
316        let id_str = execution_id_str(id);
317        let row = sqlx::query("SELECT * FROM workflow_executions WHERE execution_id = ?")
318            .bind(&id_str)
319            .fetch_optional(&self.pool)
320            .await
321            .map_err(map_db_err)?;
322
323        row.map(|r| row_to_execution(&r)).transpose()
324    }
325
326    #[instrument(skip(self), fields(execution_id = %id, status = ?status))]
327    async fn update_execution_status(
328        &self,
329        id: &ExecutionId,
330        status: WorkflowStatus,
331    ) -> BackendResult<()> {
332        let id_str = execution_id_str(id);
333        let status_str = status_to_str(&status);
334        let now = Utc::now().to_rfc3339();
335        let completed_at = if status.is_terminal() {
336            Some(now.clone())
337        } else {
338            None
339        };
340
341        let rows_affected = sqlx::query(
342            "UPDATE workflow_executions SET status = ?, updated_at = ?, completed_at = COALESCE(?, completed_at) WHERE execution_id = ?",
343        )
344        .bind(status_str)
345        .bind(&now)
346        .bind(completed_at.as_deref())
347        .bind(&id_str)
348        .execute(&self.pool)
349        .await
350        .map_err(map_db_err)?
351        .rows_affected();
352
353        if rows_affected == 0 {
354            return Err(StateBackendError::NotFound(id_str));
355        }
356        Ok(())
357    }
358
359    async fn update_execution_current_state(
360        &self,
361        id: &ExecutionId,
362        current_state: &serde_json::Value,
363    ) -> BackendResult<()> {
364        let id_str = execution_id_str(id);
365        let state_str =
366            serde_json::to_string(current_state).map_err(StateBackendError::Serialization)?;
367        let now = Utc::now().to_rfc3339();
368        sqlx::query(
369            "UPDATE workflow_executions SET current_state = ?, updated_at = ? WHERE execution_id = ?",
370        )
371        .bind(&state_str)
372        .bind(&now)
373        .bind(&id_str)
374        .execute(&self.pool)
375        .await
376        .map_err(map_db_err)?;
377        Ok(())
378    }
379
380    async fn patch_append_array(
381        &self,
382        execution_id: &ExecutionId,
383        key: &str,
384        value: serde_json::Value,
385    ) -> BackendResult<()> {
386        let exec = self
387            .get_execution(execution_id)
388            .await?
389            .ok_or_else(|| StateBackendError::NotFound(format!("execution {execution_id}")))?;
390        let mut state = exec.current_state.clone();
391        let arr = state
392            .as_object_mut()
393            .ok_or_else(|| StateBackendError::Database("state is not a JSON object".into()))?
394            .entry(key)
395            .or_insert_with(|| serde_json::json!([]));
396        arr.as_array_mut()
397            .ok_or_else(|| StateBackendError::Database(format!("{key} is not an array")))?
398            .push(value);
399        self.update_execution_current_state(execution_id, &state)
400            .await
401    }
402
403    #[instrument(skip(self))]
404    async fn list_executions(
405        &self,
406        status: Option<WorkflowStatus>,
407        limit: u32,
408        offset: u32,
409    ) -> BackendResult<Vec<WorkflowExecution>> {
410        let rows = match status {
411            Some(s) => {
412                let status_str = status_to_str(&s);
413                sqlx::query(
414                    "SELECT * FROM workflow_executions WHERE status = ? ORDER BY updated_at DESC LIMIT ? OFFSET ?",
415                )
416                .bind(status_str)
417                .bind(limit as i64)
418                .bind(offset as i64)
419                .fetch_all(&self.pool)
420                .await
421                .map_err(map_db_err)?
422            }
423            None => sqlx::query(
424                "SELECT * FROM workflow_executions ORDER BY updated_at DESC LIMIT ? OFFSET ?",
425            )
426            .bind(limit as i64)
427            .bind(offset as i64)
428            .fetch_all(&self.pool)
429            .await
430            .map_err(map_db_err)?,
431        };
432
433        rows.iter().map(row_to_execution).collect()
434    }
435
436    // ── Event log ─────────────────────────────────────────────────────────
437
438    #[instrument(skip(self, event), fields(execution_id = %event.execution_id, seq = event.sequence))]
439    async fn append_event(&self, event: Event) -> BackendResult<EventSequence> {
440        let id = event.id.to_string();
441        let execution_id = execution_id_str(&event.execution_id);
442        let kind_json = serde_json::to_string(&event.kind)?;
443        let created_at = event.created_at.to_rfc3339();
444
445        // Assign the sequence atomically inside a transaction so concurrent
446        // appends to the same execution cannot compute the same number and
447        // collide on UNIQUE(execution_id, sequence). The caller-supplied
448        // sequence is advisory; the database is the source of truth.
449        let mut tx = self.pool.begin().await.map_err(map_db_err)?;
450        let seq_row = sqlx::query(
451            "SELECT COALESCE(MAX(sequence), 0) + 1 AS seq FROM events WHERE execution_id = ?",
452        )
453        .bind(&execution_id)
454        .fetch_one(&mut *tx)
455        .await
456        .map_err(map_db_err)?;
457        let sequence: i64 = seq_row.try_get::<i64, _>("seq").map_err(map_db_err)?;
458
459        sqlx::query(
460            r#"INSERT INTO events (id, execution_id, sequence, kind_json, created_at)
461               VALUES (?, ?, ?, ?, ?)"#,
462        )
463        .bind(&id)
464        .bind(&execution_id)
465        .bind(sequence)
466        .bind(&kind_json)
467        .bind(&created_at)
468        .execute(&mut *tx)
469        .await
470        .map_err(map_db_err)?;
471        tx.commit().await.map_err(map_db_err)?;
472
473        Ok(sequence)
474    }
475
476    #[instrument(skip(self), fields(execution_id = %execution_id))]
477    async fn get_events(&self, execution_id: &ExecutionId) -> BackendResult<Vec<Event>> {
478        let id_str = execution_id_str(execution_id);
479        let rows = sqlx::query("SELECT * FROM events WHERE execution_id = ? ORDER BY sequence ASC")
480            .bind(&id_str)
481            .fetch_all(&self.pool)
482            .await
483            .map_err(map_db_err)?;
484
485        rows.iter().map(row_to_event).collect()
486    }
487
488    #[instrument(skip(self), fields(execution_id = %execution_id, since = since_sequence))]
489    async fn get_events_since(
490        &self,
491        execution_id: &ExecutionId,
492        since_sequence: EventSequence,
493    ) -> BackendResult<Vec<Event>> {
494        let id_str = execution_id_str(execution_id);
495        let rows = sqlx::query(
496            "SELECT * FROM events WHERE execution_id = ? AND sequence > ? ORDER BY sequence ASC",
497        )
498        .bind(&id_str)
499        .bind(since_sequence)
500        .fetch_all(&self.pool)
501        .await
502        .map_err(map_db_err)?;
503
504        rows.iter().map(row_to_event).collect()
505    }
506
507    #[instrument(skip(self), fields(execution_id = %execution_id))]
508    async fn latest_sequence(&self, execution_id: &ExecutionId) -> BackendResult<EventSequence> {
509        let id_str = execution_id_str(execution_id);
510        let row = sqlx::query(
511            "SELECT COALESCE(MAX(sequence), 0) as seq FROM events WHERE execution_id = ?",
512        )
513        .bind(&id_str)
514        .fetch_one(&self.pool)
515        .await
516        .map_err(map_db_err)?;
517
518        Ok(row.try_get::<i64, _>("seq").map_err(map_db_err)?)
519    }
520
521    // ── Snapshots ─────────────────────────────────────────────────────────
522
523    #[instrument(skip(self, snapshot), fields(execution_id = %snapshot.execution_id, at_seq = snapshot.at_sequence))]
524    async fn write_snapshot(&self, snapshot: Snapshot) -> BackendResult<()> {
525        let id = snapshot.id.to_string();
526        let execution_id = execution_id_str(&snapshot.execution_id);
527        let state_json = serde_json::to_string(&snapshot.state)?;
528        let created_at = snapshot.created_at.to_rfc3339();
529
530        sqlx::query(
531            r#"INSERT OR REPLACE INTO snapshots (id, execution_id, at_sequence, state_json, created_at)
532               VALUES (?, ?, ?, ?, ?)"#,
533        )
534        .bind(&id)
535        .bind(&execution_id)
536        .bind(snapshot.at_sequence)
537        .bind(&state_json)
538        .bind(&created_at)
539        .execute(&self.pool)
540        .await
541        .map_err(map_db_err)?;
542
543        Ok(())
544    }
545
546    #[instrument(skip(self), fields(execution_id = %execution_id))]
547    async fn latest_snapshot(&self, execution_id: &ExecutionId) -> BackendResult<Option<Snapshot>> {
548        let id_str = execution_id_str(execution_id);
549        let row = sqlx::query(
550            "SELECT * FROM snapshots WHERE execution_id = ? ORDER BY at_sequence DESC LIMIT 1",
551        )
552        .bind(&id_str)
553        .fetch_optional(&self.pool)
554        .await
555        .map_err(map_db_err)?;
556
557        let Some(row) = row else { return Ok(None) };
558
559        let id = Uuid::parse_str(row.try_get::<&str, _>("id").map_err(map_db_err)?)
560            .map_err(|e| StateBackendError::Database(e.to_string()))?;
561        let execution_id =
562            parse_execution_id(row.try_get::<&str, _>("execution_id").map_err(map_db_err)?)?;
563        let at_sequence: i64 = row.try_get("at_sequence").map_err(map_db_err)?;
564        let state: serde_json::Value =
565            serde_json::from_str(row.try_get::<&str, _>("state_json").map_err(map_db_err)?)
566                .map_err(StateBackendError::Serialization)?;
567        let created_at = parse_datetime(row.try_get::<&str, _>("created_at").map_err(map_db_err)?)?;
568
569        Ok(Some(Snapshot {
570            id,
571            execution_id,
572            at_sequence,
573            state,
574            created_at,
575        }))
576    }
577
578    // ── Work item queue ───────────────────────────────────────────────────
579
580    #[instrument(skip(self, item), fields(execution_id = %item.execution_id, node_id = %item.node_id))]
581    async fn enqueue_work_item(&self, item: WorkItem) -> BackendResult<WorkItemId> {
582        let id = item.id.to_string();
583        let execution_id = execution_id_str(&item.execution_id);
584        let payload_json = serde_json::to_string(&item.payload)?;
585        let created_at = item.created_at.to_rfc3339();
586
587        sqlx::query(
588            r#"INSERT INTO work_items
589               (id, execution_id, node_id, queue_type, payload_json, attempt, max_attempts, status, created_at)
590               VALUES (?, ?, ?, ?, ?, ?, ?, 'pending', ?)"#,
591        )
592        .bind(&id)
593        .bind(&execution_id)
594        .bind(&item.node_id)
595        .bind(&item.queue_type)
596        .bind(&payload_json)
597        .bind(item.attempt as i64)
598        .bind(item.max_attempts as i64)
599        .bind(&created_at)
600        .execute(&self.pool)
601        .await
602        .map_err(map_db_err)?;
603
604        Ok(item.id)
605    }
606
607    #[instrument(skip(self), fields(worker_id = worker_id))]
608    async fn claim_work_item(
609        &self,
610        worker_id: &str,
611        queue_types: &[&str],
612    ) -> BackendResult<Option<WorkItem>> {
613        if queue_types.is_empty() {
614            return Ok(None);
615        }
616
617        // Expire stale leases first
618        let now = Utc::now().to_rfc3339();
619        sqlx::query(
620            "UPDATE work_items SET status = 'pending', worker_id = NULL, lease_expires_at = NULL \
621             WHERE status = 'claimed' AND lease_expires_at < ?",
622        )
623        .bind(&now)
624        .execute(&self.pool)
625        .await
626        .map_err(map_db_err)?;
627
628        // SQLite doesn't support UPDATE ... RETURNING with a subquery easily, so
629        // we use a transaction: SELECT FOR UPDATE equivalent via exclusive transaction.
630        let mut tx = self.pool.begin().await.map_err(map_db_err)?;
631
632        // Build placeholders for queue_types IN clause
633        let placeholders = queue_types
634            .iter()
635            .map(|_| "?")
636            .collect::<Vec<_>>()
637            .join(",");
638        let query_str = format!(
639            "SELECT * FROM work_items WHERE status = 'pending' AND queue_type IN ({}) \
640             AND (retry_after IS NULL OR retry_after <= ?) ORDER BY created_at ASC LIMIT 1",
641            placeholders
642        );
643        let mut q = sqlx::query(&query_str);
644        for qt in queue_types {
645            q = q.bind(*qt);
646        }
647        q = q.bind(&now); // for retry_after <= now check
648        let row = q.fetch_optional(&mut *tx).await.map_err(map_db_err)?;
649
650        let Some(row) = row else {
651            tx.rollback().await.map_err(map_db_err)?;
652            return Ok(None);
653        };
654
655        let item = row_to_work_item(&row)?;
656        let item_id = item.id.to_string();
657        let lease_expires_at = (Utc::now() + chrono::Duration::seconds(30)).to_rfc3339();
658        let claimed_at = Utc::now().to_rfc3339();
659
660        sqlx::query(
661            "UPDATE work_items SET status = 'claimed', worker_id = ?, lease_expires_at = ?, claimed_at = ? WHERE id = ?",
662        )
663        .bind(worker_id)
664        .bind(&lease_expires_at)
665        .bind(&claimed_at)
666        .bind(&item_id)
667        .execute(&mut *tx)
668        .await
669        .map_err(map_db_err)?;
670
671        tx.commit().await.map_err(map_db_err)?;
672
673        // Return item with updated fields
674        let mut claimed = item;
675        claimed.worker_id = Some(worker_id.to_string());
676        claimed.lease_expires_at = Some(
677            DateTime::parse_from_rfc3339(&lease_expires_at)
678                .map(|dt| dt.with_timezone(&Utc))
679                .map_err(|e| StateBackendError::Database(e.to_string()))?,
680        );
681        Ok(Some(claimed))
682    }
683
684    #[instrument(skip(self), fields(item_id = %item_id, worker_id = worker_id))]
685    async fn renew_lease(&self, item_id: WorkItemId, worker_id: &str) -> BackendResult<()> {
686        let lease_expires_at = (Utc::now() + chrono::Duration::seconds(30)).to_rfc3339();
687        let id_str = item_id.to_string();
688
689        let rows_affected = sqlx::query(
690            "UPDATE work_items SET lease_expires_at = ? WHERE id = ? AND worker_id = ? AND status = 'claimed'",
691        )
692        .bind(&lease_expires_at)
693        .bind(&id_str)
694        .bind(worker_id)
695        .execute(&self.pool)
696        .await
697        .map_err(map_db_err)?
698        .rows_affected();
699
700        if rows_affected == 0 {
701            return Err(StateBackendError::NotFound(id_str));
702        }
703        Ok(())
704    }
705
706    #[instrument(skip(self), fields(item_id = %item_id))]
707    async fn complete_work_item(&self, item_id: WorkItemId) -> BackendResult<()> {
708        let id_str = item_id.to_string();
709        let completed_at = Utc::now().to_rfc3339();
710
711        let rows_affected = sqlx::query(
712            "UPDATE work_items SET status = 'completed', completed_at = ?, lease_expires_at = NULL WHERE id = ?",
713        )
714        .bind(&completed_at)
715        .bind(&id_str)
716        .execute(&self.pool)
717        .await
718        .map_err(map_db_err)?
719        .rows_affected();
720
721        if rows_affected == 0 {
722            return Err(StateBackendError::NotFound(id_str));
723        }
724        Ok(())
725    }
726
727    #[instrument(skip(self, error), fields(item_id = %item_id))]
728    async fn fail_work_item(&self, item_id: WorkItemId, error: &str) -> BackendResult<()> {
729        let id_str = item_id.to_string();
730        let _ = error; // logged by caller; stored in event log not here
731
732        let rows_affected = sqlx::query(
733            "UPDATE work_items SET status = 'failed', lease_expires_at = NULL, worker_id = NULL WHERE id = ?",
734        )
735        .bind(&id_str)
736        .execute(&self.pool)
737        .await
738        .map_err(map_db_err)?
739        .rows_affected();
740
741        if rows_affected == 0 {
742            return Err(StateBackendError::NotFound(id_str));
743        }
744        Ok(())
745    }
746
747    #[instrument(skip(self))]
748    async fn reclaim_expired_leases(&self) -> BackendResult<ReclaimResult> {
749        let now = Utc::now().to_rfc3339();
750
751        // Find all claimed items whose lease has expired.
752        let rows = sqlx::query(
753            "SELECT * FROM work_items WHERE status = 'claimed' AND lease_expires_at < ? ORDER BY created_at ASC",
754        )
755        .bind(&now)
756        .fetch_all(&self.pool)
757        .await
758        .map_err(map_db_err)?;
759
760        let mut result = ReclaimResult::default();
761
762        for row in &rows {
763            let item = row_to_work_item(row)?;
764            let new_attempt = item.attempt + 1;
765            let id_str = item.id.to_string();
766
767            if new_attempt >= item.max_attempts {
768                // Exhausted — move to dead-letter (caller emits the event)
769                let dead_lettered_at = Utc::now().to_rfc3339();
770                sqlx::query(
771                    r#"INSERT OR IGNORE INTO dead_letter_items
772                       (id, execution_id, node_id, queue_type, payload_json, attempt, last_error, created_at, dead_lettered_at)
773                       VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"#,
774                )
775                .bind(&id_str)
776                .bind(execution_id_str(&item.execution_id))
777                .bind(&item.node_id)
778                .bind(&item.queue_type)
779                .bind(serde_json::to_string(&item.payload)?)
780                .bind(new_attempt as i64)
781                .bind("lease expired: worker dead")
782                .bind(item.created_at.to_rfc3339())
783                .bind(dead_lettered_at)
784                .execute(&self.pool)
785                .await
786                .map_err(map_db_err)?;
787
788                sqlx::query("UPDATE work_items SET status = 'dead_lettered', attempt = ?, lease_expires_at = NULL, worker_id = NULL WHERE id = ?")
789                    .bind(new_attempt as i64)
790                    .bind(&id_str)
791                    .execute(&self.pool)
792                    .await
793                    .map_err(map_db_err)?;
794
795                let mut exhausted_item = item;
796                exhausted_item.attempt = new_attempt;
797                result.exhausted.push(exhausted_item);
798            } else {
799                // Retryable — reset to pending with incremented attempt.
800                // Apply exponential backoff: 2^attempt seconds.
801                let backoff_secs = 1u64 << new_attempt.min(6); // max 64s
802                let retry_after =
803                    (Utc::now() + chrono::Duration::seconds(backoff_secs as i64)).to_rfc3339();
804
805                sqlx::query(
806                    "UPDATE work_items SET status = 'pending', attempt = ?, worker_id = NULL, lease_expires_at = NULL, retry_after = ? WHERE id = ?",
807                )
808                .bind(new_attempt as i64)
809                .bind(&retry_after)
810                .bind(&id_str)
811                .execute(&self.pool)
812                .await
813                .map_err(map_db_err)?;
814
815                let mut retry_item = item;
816                retry_item.attempt = new_attempt;
817                result.retryable.push(retry_item);
818            }
819        }
820
821        Ok(result)
822    }
823
824    #[instrument(skip(self, last_error), fields(item_id = %item_id))]
825    async fn move_to_dead_letter(
826        &self,
827        item_id: WorkItemId,
828        last_error: &str,
829    ) -> BackendResult<()> {
830        let id_str = item_id.to_string();
831
832        // Load the item first to copy fields.
833        let row = sqlx::query("SELECT * FROM work_items WHERE id = ?")
834            .bind(&id_str)
835            .fetch_optional(&self.pool)
836            .await
837            .map_err(map_db_err)?;
838
839        let Some(row) = row else {
840            return Err(StateBackendError::NotFound(id_str));
841        };
842        let item = row_to_work_item(&row)?;
843        let dead_lettered_at = Utc::now().to_rfc3339();
844
845        sqlx::query(
846            r#"INSERT OR REPLACE INTO dead_letter_items
847               (id, execution_id, node_id, queue_type, payload_json, attempt, last_error, created_at, dead_lettered_at)
848               VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"#,
849        )
850        .bind(&id_str)
851        .bind(execution_id_str(&item.execution_id))
852        .bind(&item.node_id)
853        .bind(&item.queue_type)
854        .bind(serde_json::to_string(&item.payload)?)
855        .bind(item.attempt as i64)
856        .bind(last_error)
857        .bind(item.created_at.to_rfc3339())
858        .bind(dead_lettered_at)
859        .execute(&self.pool)
860        .await
861        .map_err(map_db_err)?;
862
863        sqlx::query("UPDATE work_items SET status = 'dead_lettered', lease_expires_at = NULL, worker_id = NULL WHERE id = ?")
864            .bind(&id_str)
865            .execute(&self.pool)
866            .await
867            .map_err(map_db_err)?;
868
869        Ok(())
870    }
871
872    async fn create_token(&self, name: &str, role: &str) -> BackendResult<(String, ApiToken)> {
873        use rand::Rng;
874        use sha2::{Digest, Sha256};
875
876        // Generate a random 32-byte token, hex-encoded with a prefix.
877        let random_bytes: [u8; 32] = rand::thread_rng().gen();
878        let token = format!(
879            "jj_{}",
880            random_bytes
881                .iter()
882                .map(|b| format!("{b:02x}"))
883                .collect::<String>()
884        );
885        let token_hash = format!("{:x}", Sha256::digest(token.as_bytes()));
886
887        let id = uuid::Uuid::new_v4().to_string();
888        let now = Utc::now().to_rfc3339();
889
890        sqlx::query(
891            r#"INSERT INTO api_tokens (id, token_hash, name, role, created_at)
892               VALUES (?, ?, ?, ?, ?)"#,
893        )
894        .bind(&id)
895        .bind(&token_hash)
896        .bind(name)
897        .bind(role)
898        .bind(&now)
899        .execute(&self.pool)
900        .await
901        .map_err(map_db_err)?;
902
903        let info = ApiToken {
904            id,
905            name: name.to_string(),
906            role: role.to_string(),
907            created_at: Utc::now(),
908            expires_at: None,
909            tenant_id: crate::tenant::DEFAULT_TENANT.to_string(),
910        };
911        Ok((token, info))
912    }
913
914    async fn validate_token(&self, token: &str) -> BackendResult<Option<ApiToken>> {
915        use sha2::{Digest, Sha256};
916
917        let token_hash = format!("{:x}", Sha256::digest(token.as_bytes()));
918        let now = Utc::now().to_rfc3339();
919
920        let row = sqlx::query(
921            r#"SELECT id, name, role, created_at, expires_at
922               FROM api_tokens
923               WHERE token_hash = ?
924                 AND revoked_at IS NULL
925                 AND (expires_at IS NULL OR expires_at > ?)"#,
926        )
927        .bind(&token_hash)
928        .bind(&now)
929        .fetch_optional(&self.pool)
930        .await
931        .map_err(map_db_err)?;
932
933        let Some(row) = row else { return Ok(None) };
934
935        // Update last_used_at.
936        let id: String = row.get("id");
937        sqlx::query("UPDATE api_tokens SET last_used_at = ? WHERE id = ?")
938            .bind(&now)
939            .bind(&id)
940            .execute(&self.pool)
941            .await
942            .map_err(map_db_err)?;
943
944        let expires_at: Option<String> = row.get("expires_at");
945        let tenant_id: String = row
946            .try_get("tenant_id")
947            .unwrap_or_else(|_| crate::tenant::DEFAULT_TENANT.to_string());
948        let info = ApiToken {
949            id,
950            name: row.get("name"),
951            role: row.get("role"),
952            created_at: row
953                .get::<String, _>("created_at")
954                .parse::<chrono::DateTime<Utc>>()
955                .unwrap_or_else(|_| Utc::now()),
956            expires_at: expires_at.and_then(|s| s.parse().ok()),
957            tenant_id,
958        };
959        Ok(Some(info))
960    }
961}
962
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use jamjet_core::workflow::{WorkflowExecution, WorkflowStatus};
    use serde_json::json;

    /// Open a fresh in-memory database with migrations applied.
    async fn open_test_db() -> SqliteBackend {
        SqliteBackend::open("sqlite::memory:")
            .await
            .expect("failed to open in-memory SQLite")
    }

    /// A pending execution of `test-wf@1.0.0` with a fresh id.
    fn sample_execution() -> WorkflowExecution {
        let now = Utc::now();
        WorkflowExecution {
            execution_id: ExecutionId::new(),
            workflow_id: "test-wf".to_string(),
            workflow_version: "1.0.0".to_string(),
            status: WorkflowStatus::Pending,
            initial_input: json!({"x": 1}),
            current_state: json!({}),
            started_at: now,
            updated_at: now,
            completed_at: None,
            session_type: None,
        }
    }

    /// A fresh, unclaimed work item for `execution_id` on the `default` queue.
    fn sample_work_item(execution_id: ExecutionId) -> WorkItem {
        WorkItem {
            id: Uuid::new_v4(),
            execution_id,
            node_id: "node-1".to_string(),
            queue_type: "default".to_string(),
            payload: json!({}),
            attempt: 0,
            max_attempts: 3,
            created_at: Utc::now(),
            lease_expires_at: None,
            worker_id: None,
            tenant_id: crate::tenant::DEFAULT_TENANT.to_string(),
        }
    }

    #[tokio::test]
    async fn test_create_and_get_execution() {
        let db = open_test_db().await;
        let exec = sample_execution();
        let id = exec.execution_id.clone();
        db.create_execution(exec).await.unwrap();
        let fetched = db.get_execution(&id).await.unwrap().unwrap();
        assert_eq!(fetched.workflow_id, "test-wf");
        assert_eq!(fetched.status, WorkflowStatus::Pending);
    }

    #[tokio::test]
    async fn test_update_status() {
        let db = open_test_db().await;
        let exec = sample_execution();
        let id = exec.execution_id.clone();
        db.create_execution(exec).await.unwrap();
        db.update_execution_status(&id, WorkflowStatus::Running)
            .await
            .unwrap();
        let fetched = db.get_execution(&id).await.unwrap().unwrap();
        assert_eq!(fetched.status, WorkflowStatus::Running);
    }

    #[tokio::test]
    async fn test_list_executions() {
        let db = open_test_db().await;
        db.create_execution(sample_execution()).await.unwrap();
        db.create_execution(sample_execution()).await.unwrap();
        let all = db.list_executions(None, 10, 0).await.unwrap();
        assert_eq!(all.len(), 2);
    }

    #[tokio::test]
    async fn test_event_log() {
        use crate::event::{Event, EventKind};
        let db = open_test_db().await;
        let exec = sample_execution();
        let exec_id = exec.execution_id.clone();
        db.create_execution(exec).await.unwrap();

        let event = Event::new(
            exec_id.clone(),
            1,
            EventKind::WorkflowStarted {
                workflow_id: "test-wf".to_string(),
                workflow_version: "1.0.0".to_string(),
                initial_input: json!({"x": 1}),
            },
        );
        db.append_event(event).await.unwrap();

        let events = db.get_events(&exec_id).await.unwrap();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].sequence, 1);

        let seq = db.latest_sequence(&exec_id).await.unwrap();
        assert_eq!(seq, 1);
    }

    #[tokio::test]
    async fn test_snapshot() {
        use crate::snapshot::Snapshot;
        let db = open_test_db().await;
        let exec = sample_execution();
        let exec_id = exec.execution_id.clone();
        db.create_execution(exec).await.unwrap();

        let snap = Snapshot::new(exec_id.clone(), 5, json!({"nodes_completed": ["a", "b"]}));
        db.write_snapshot(snap).await.unwrap();

        let loaded = db.latest_snapshot(&exec_id).await.unwrap().unwrap();
        assert_eq!(loaded.at_sequence, 5);
    }

    #[tokio::test]
    async fn test_workflow_definition() {
        use crate::backend::WorkflowDefinition;
        let db = open_test_db().await;

        let def = WorkflowDefinition {
            workflow_id: "my-wf".to_string(),
            version: "1.0.0".to_string(),
            ir: json!({"workflow_id": "my-wf", "version": "1.0.0", "nodes": {}}),
            created_at: Utc::now(),
            tenant_id: crate::tenant::DEFAULT_TENANT.to_string(),
        };
        db.store_workflow(def).await.unwrap();

        let loaded = db.get_workflow("my-wf", "1.0.0").await.unwrap().unwrap();
        assert_eq!(loaded.workflow_id, "my-wf");
        assert_eq!(loaded.version, "1.0.0");

        // Non-existent version returns None
        let missing = db.get_workflow("my-wf", "2.0.0").await.unwrap();
        assert!(missing.is_none());
    }

    #[tokio::test]
    async fn test_work_item_queue() {
        let db = open_test_db().await;
        let exec = sample_execution();
        let exec_id = exec.execution_id.clone();
        db.create_execution(exec).await.unwrap();

        let item = sample_work_item(exec_id);
        let item_id = item.id;
        db.enqueue_work_item(item).await.unwrap();

        let claimed = db
            .claim_work_item("worker-1", &["default"])
            .await
            .unwrap()
            .unwrap();
        assert_eq!(claimed.node_id, "node-1");
        assert_eq!(claimed.worker_id.as_deref(), Some("worker-1"));

        db.complete_work_item(item_id).await.unwrap();

        // No more items
        let none = db.claim_work_item("worker-1", &["default"]).await.unwrap();
        assert!(none.is_none());
    }

    #[tokio::test]
    async fn test_move_to_dead_letter() {
        let db = open_test_db().await;
        let exec = sample_execution();
        let exec_id = exec.execution_id.clone();
        db.create_execution(exec).await.unwrap();

        let item = sample_work_item(exec_id);
        let item_id = item.id;
        db.enqueue_work_item(item).await.unwrap();

        db.move_to_dead_letter(item_id, "boom").await.unwrap();

        // A dead-lettered item is no longer handed out to workers.
        let none = db.claim_work_item("worker-1", &["default"]).await.unwrap();
        assert!(none.is_none());

        // Unknown ids are reported as NotFound.
        let err = db
            .move_to_dead_letter(Uuid::new_v4(), "boom")
            .await
            .unwrap_err();
        assert!(matches!(err, StateBackendError::NotFound(_)));
    }

    #[tokio::test]
    async fn test_token_roundtrip() {
        let db = open_test_db().await;

        let (token, created) = db.create_token("ci", "admin").await.unwrap();
        assert!(token.starts_with("jj_"));
        assert_eq!(token.len(), 3 + 64);

        let validated = db.validate_token(&token).await.unwrap().unwrap();
        assert_eq!(validated.id, created.id);
        assert_eq!(validated.name, "ci");
        assert_eq!(validated.role, "admin");

        // A token that was never issued is rejected.
        let unknown = db.validate_token("jj_not-a-real-token").await.unwrap();
        assert!(unknown.is_none());
    }

    #[tokio::test]
    async fn test_patch_append_array() {
        let db = open_test_db().await;
        let exec = sample_execution();
        let id = exec.execution_id.clone();
        db.create_execution(exec).await.unwrap();

        db.patch_append_array(
            &id,
            "agent_tool_events",
            json!({"type": "progress", "chunk": 0}),
        )
        .await
        .unwrap();
        db.patch_append_array(
            &id,
            "agent_tool_events",
            json!({"type": "progress", "chunk": 1}),
        )
        .await
        .unwrap();

        let fetched = db.get_execution(&id).await.unwrap().unwrap();
        let events = fetched.current_state["agent_tool_events"]
            .as_array()
            .unwrap();
        assert_eq!(events.len(), 2);
        assert_eq!(events[0]["chunk"], 0);
        assert_eq!(events[1]["chunk"], 1);
    }
}