Skip to main content

kernex_memory/store/
checkpoints.rs

//! Phase checkpoint storage for resumable pipeline runs.
//!
//! Each checkpoint records the status and output of one pipeline phase
//! within a run. A `run_id` (UUID) groups all phases belonging to the
//! same pipeline execution, allowing a failed run to be resumed from
//! the last completed phase.
//!
//! # Lifecycle
//!
//! A phase normally moves through the states below. Only the `status`,
//! `output` and `error_message` arguments are shown; the topology, sender,
//! project and attempt arguments are omitted for brevity.
//!
//! ```text
//! upsert_phase_checkpoint(run_id, phase, "pending",     None,   None)
//! upsert_phase_checkpoint(run_id, phase, "in_progress", None,   None)
//! upsert_phase_checkpoint(run_id, phase, "completed",   output, None)
//! // or
//! upsert_phase_checkpoint(run_id, phase, "failed",      None,   Some(error))
//! ```
//!
//! To resume a run, call `get_run_checkpoints` and skip phases whose
//! `status` is `"completed"`.

21use super::Store;
22use crate::error::MemoryError;
23
/// Raw tuple decoded by `sqlx::query_as` from one `phase_checkpoints` row.
///
/// Element order mirrors the column list of the `SELECT` statements in this
/// module; keep the two in sync when adding or reordering columns.
type CheckpointRow = (
    String,         // id
    String,         // run_id
    String,         // topology_name
    String,         // phase_name
    String,         // sender_id
    String,         // project
    String,         // status
    Option<String>, // output
    Option<String>, // error_message
    i64,            // attempt
    String,         // created_at
    String,         // updated_at
);
38
/// A recorded snapshot of one pipeline phase within a run.
///
/// Read back via `Store::get_phase_checkpoint` / `Store::get_run_checkpoints`
/// and written with `Store::upsert_phase_checkpoint`. The store does not
/// validate field contents; `status`, `output` and `error_message` hold
/// exactly what the caller last passed in.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PhaseCheckpoint {
    /// Unique row identifier (a UUID assigned when the row is first inserted).
    pub id: String,
    /// UUID identifying the enclosing pipeline run.
    pub run_id: String,
    /// Name of the topology (e.g. `"my-pipeline"`).
    pub topology_name: String,
    /// Name of the phase within the topology. Unique within a run.
    pub phase_name: String,
    /// Agent / sender identifier that owns this run.
    pub sender_id: String,
    /// Project scope. Empty string = no project.
    pub project: String,
    /// Expected to be one of `"pending"`, `"in_progress"`, `"completed"`,
    /// `"failed"`; not enforced by the store.
    pub status: String,
    /// Phase output text from the most recent upsert. By convention only
    /// `Some` once the phase has completed successfully.
    pub output: Option<String>,
    /// Error detail from the most recent upsert. By convention only `Some`
    /// when `status` is `"failed"`.
    pub error_message: Option<String>,
    /// Zero-based attempt index supplied by the caller (the first attempt
    /// is `0`).
    pub attempt: i64,
    /// Row creation timestamp as stored by SQLite. Not changed by later
    /// upserts of the same phase.
    pub created_at: String,
    /// Last-update timestamp as stored by SQLite. Refreshed with
    /// `datetime('now')` (UTC, `YYYY-MM-DD HH:MM:SS`) whenever an existing
    /// checkpoint is upserted.
    pub updated_at: String,
}
67
68impl Store {
69    /// Create or update a checkpoint for one phase within a run.
70    ///
71    /// Uses `INSERT OR REPLACE` keyed on `(run_id, phase_name)`, so calling
72    /// this multiple times as the phase progresses is safe and idempotent.
73    #[allow(clippy::too_many_arguments)]
74    pub async fn upsert_phase_checkpoint(
75        &self,
76        run_id: &str,
77        topology_name: &str,
78        phase_name: &str,
79        sender_id: &str,
80        project: &str,
81        status: &str,
82        output: Option<&str>,
83        error_message: Option<&str>,
84        attempt: i64,
85    ) -> Result<(), MemoryError> {
86        let id = uuid::Uuid::new_v4().to_string();
87        sqlx::query(
88            "INSERT INTO phase_checkpoints \
89             (id, run_id, topology_name, phase_name, sender_id, project, \
90              status, output, error_message, attempt) \
91             VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) \
92             ON CONFLICT (run_id, phase_name) DO UPDATE SET \
93               topology_name = excluded.topology_name, \
94               sender_id     = excluded.sender_id, \
95               project       = excluded.project, \
96               status        = excluded.status, \
97               output        = excluded.output, \
98               error_message = excluded.error_message, \
99               attempt       = excluded.attempt, \
100               updated_at    = datetime('now')",
101        )
102        .bind(&id)
103        .bind(run_id)
104        .bind(topology_name)
105        .bind(phase_name)
106        .bind(sender_id)
107        .bind(project)
108        .bind(status)
109        .bind(output)
110        .bind(error_message)
111        .bind(attempt)
112        .execute(&self.pool)
113        .await
114        .map_err(|e| MemoryError::sqlite("upsert phase checkpoint", e))?;
115        Ok(())
116    }
117
118    /// Fetch the checkpoint for a specific phase within a run.
119    pub async fn get_phase_checkpoint(
120        &self,
121        run_id: &str,
122        phase_name: &str,
123    ) -> Result<Option<PhaseCheckpoint>, MemoryError> {
124        let row: Option<CheckpointRow> = sqlx::query_as(
125            "SELECT id, run_id, topology_name, phase_name, sender_id, project, \
126                    status, output, error_message, attempt, created_at, updated_at \
127             FROM phase_checkpoints WHERE run_id = ? AND phase_name = ?",
128        )
129        .bind(run_id)
130        .bind(phase_name)
131        .fetch_optional(&self.pool)
132        .await
133        .map_err(|e| MemoryError::sqlite("get phase checkpoint", e))?;
134
135        Ok(row.map(
136            |(
137                id,
138                run_id,
139                topology_name,
140                phase_name,
141                sender_id,
142                project,
143                status,
144                output,
145                error_message,
146                attempt,
147                created_at,
148                updated_at,
149            )| PhaseCheckpoint {
150                id,
151                run_id,
152                topology_name,
153                phase_name,
154                sender_id,
155                project,
156                status,
157                output,
158                error_message,
159                attempt,
160                created_at,
161                updated_at,
162            },
163        ))
164    }
165
166    /// Fetch all phase checkpoints for a run, ordered by creation time.
167    ///
168    /// Use this to inspect which phases have already completed when resuming
169    /// a failed run.
170    pub async fn get_run_checkpoints(
171        &self,
172        run_id: &str,
173    ) -> Result<Vec<PhaseCheckpoint>, MemoryError> {
174        let rows: Vec<CheckpointRow> = sqlx::query_as(
175            "SELECT id, run_id, topology_name, phase_name, sender_id, project, \
176                    status, output, error_message, attempt, created_at, updated_at \
177             FROM phase_checkpoints WHERE run_id = ? ORDER BY created_at ASC",
178        )
179        .bind(run_id)
180        .fetch_all(&self.pool)
181        .await
182        .map_err(|e| MemoryError::sqlite("get run checkpoints", e))?;
183
184        Ok(rows
185            .into_iter()
186            .map(
187                |(
188                    id,
189                    run_id,
190                    topology_name,
191                    phase_name,
192                    sender_id,
193                    project,
194                    status,
195                    output,
196                    error_message,
197                    attempt,
198                    created_at,
199                    updated_at,
200                )| PhaseCheckpoint {
201                    id,
202                    run_id,
203                    topology_name,
204                    phase_name,
205                    sender_id,
206                    project,
207                    status,
208                    output,
209                    error_message,
210                    attempt,
211                    created_at,
212                    updated_at,
213                },
214            )
215            .collect())
216    }
217
218    /// Delete all checkpoints for a run.
219    ///
220    /// Call this after a pipeline run completes successfully to reclaim space,
221    /// or before re-running a pipeline from scratch.
222    pub async fn clear_run_checkpoints(&self, run_id: &str) -> Result<(), MemoryError> {
223        sqlx::query("DELETE FROM phase_checkpoints WHERE run_id = ?")
224            .bind(run_id)
225            .execute(&self.pool)
226            .await
227            .map_err(|e| MemoryError::sqlite("clear run checkpoints", e))?;
228        Ok(())
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use kernex_core::config::MemoryConfig;

    /// Open a fresh store backed by a database file in a temporary directory.
    ///
    /// The returned `TempDir` must be kept alive for the duration of the test;
    /// dropping it deletes the database.
    async fn test_store() -> (Store, tempfile::TempDir) {
        let tmp_dir = tempfile::TempDir::new().unwrap();
        let db_path = tmp_dir.path().join("checkpoints.db");
        let config = MemoryConfig {
            db_path: db_path.to_str().unwrap().to_string(),
            ..Default::default()
        };
        let store = Store::new(&config).await.unwrap();
        (store, tmp_dir)
    }

    #[tokio::test]
    async fn test_upsert_and_get_checkpoint() {
        let (store, _tmp_dir) = test_store().await;
        let run_id = uuid::Uuid::new_v4().to_string();

        store
            .upsert_phase_checkpoint(
                &run_id,
                "my-pipeline",
                "phase-1",
                "user-1",
                "",
                "completed",
                Some("phase output"),
                None,
                0,
            )
            .await
            .unwrap();

        let cp = store
            .get_phase_checkpoint(&run_id, "phase-1")
            .await
            .unwrap()
            .unwrap();

        assert_eq!(cp.run_id, run_id);
        assert_eq!(cp.topology_name, "my-pipeline");
        assert_eq!(cp.phase_name, "phase-1");
        assert_eq!(cp.status, "completed");
        assert_eq!(cp.output.as_deref(), Some("phase output"));
        assert!(cp.error_message.is_none());
    }

    #[tokio::test]
    async fn test_upsert_updates_existing() {
        let (store, _tmp_dir) = test_store().await;
        let run_id = uuid::Uuid::new_v4().to_string();

        store
            .upsert_phase_checkpoint(
                &run_id,
                "topo",
                "alpha",
                "user-1",
                "",
                "in_progress",
                None,
                None,
                0,
            )
            .await
            .unwrap();

        store
            .upsert_phase_checkpoint(
                &run_id,
                "topo",
                "alpha",
                "user-1",
                "",
                "completed",
                Some("done"),
                None,
                0,
            )
            .await
            .unwrap();

        let cp = store
            .get_phase_checkpoint(&run_id, "alpha")
            .await
            .unwrap()
            .unwrap();

        assert_eq!(cp.status, "completed");
        assert_eq!(cp.output.as_deref(), Some("done"));
    }

    #[tokio::test]
    async fn test_get_run_checkpoints_ordered() {
        let (store, _tmp_dir) = test_store().await;
        let run_id = uuid::Uuid::new_v4().to_string();

        for phase in &["phase-1", "phase-2", "phase-3"] {
            store
                .upsert_phase_checkpoint(
                    &run_id,
                    "topo",
                    phase,
                    "user-1",
                    "",
                    "completed",
                    None,
                    None,
                    0,
                )
                .await
                .unwrap();
        }

        let checkpoints = store.get_run_checkpoints(&run_id).await.unwrap();
        assert_eq!(checkpoints.len(), 3);
        assert_eq!(checkpoints[0].phase_name, "phase-1");
        assert_eq!(checkpoints[1].phase_name, "phase-2");
        assert_eq!(checkpoints[2].phase_name, "phase-3");
    }

    #[tokio::test]
    async fn test_clear_run_checkpoints() {
        let (store, _tmp_dir) = test_store().await;
        let run_id = uuid::Uuid::new_v4().to_string();

        store
            .upsert_phase_checkpoint(
                &run_id,
                "topo",
                "phase-1",
                "user-1",
                "",
                "completed",
                None,
                None,
                0,
            )
            .await
            .unwrap();

        store.clear_run_checkpoints(&run_id).await.unwrap();

        let checkpoints = store.get_run_checkpoints(&run_id).await.unwrap();
        assert!(checkpoints.is_empty());
    }

    #[tokio::test]
    async fn test_failed_checkpoint_stores_error() {
        let (store, _tmp_dir) = test_store().await;
        let run_id = uuid::Uuid::new_v4().to_string();

        store
            .upsert_phase_checkpoint(
                &run_id,
                "topo",
                "phase-1",
                "user-1",
                "proj-a",
                "failed",
                None,
                Some("provider timeout"),
                1,
            )
            .await
            .unwrap();

        let cp = store
            .get_phase_checkpoint(&run_id, "phase-1")
            .await
            .unwrap()
            .unwrap();

        assert_eq!(cp.status, "failed");
        assert_eq!(cp.error_message.as_deref(), Some("provider timeout"));
        assert_eq!(cp.attempt, 1);
        assert_eq!(cp.project, "proj-a");
    }

    #[tokio::test]
    async fn test_get_missing_checkpoint_returns_none() {
        let (store, _tmp_dir) = test_store().await;
        let run_id = uuid::Uuid::new_v4().to_string();

        let cp = store
            .get_phase_checkpoint(&run_id, "phase-1")
            .await
            .unwrap();
        assert!(cp.is_none());

        let checkpoints = store.get_run_checkpoints(&run_id).await.unwrap();
        assert!(checkpoints.is_empty());
    }

    #[tokio::test]
    async fn test_upsert_keeps_identity_and_single_row() {
        let (store, _tmp_dir) = test_store().await;
        let run_id = uuid::Uuid::new_v4().to_string();

        store
            .upsert_phase_checkpoint(
                &run_id,
                "topo",
                "alpha",
                "user-1",
                "",
                "completed",
                Some("first"),
                None,
                0,
            )
            .await
            .unwrap();
        let first = store
            .get_phase_checkpoint(&run_id, "alpha")
            .await
            .unwrap()
            .unwrap();

        store
            .upsert_phase_checkpoint(
                &run_id,
                "topo",
                "alpha",
                "user-1",
                "",
                "failed",
                None,
                Some("boom"),
                1,
            )
            .await
            .unwrap();
        let second = store
            .get_phase_checkpoint(&run_id, "alpha")
            .await
            .unwrap()
            .unwrap();

        // The conflict path updates in place: identity and creation time stay.
        assert_eq!(second.id, first.id);
        assert_eq!(second.created_at, first.created_at);
        assert_eq!(second.status, "failed");
        // Passing `None` overwrites (clears) the previously stored output.
        assert!(second.output.is_none());
        assert_eq!(second.error_message.as_deref(), Some("boom"));
        assert_eq!(second.attempt, 1);

        let checkpoints = store.get_run_checkpoints(&run_id).await.unwrap();
        assert_eq!(checkpoints.len(), 1);
    }

    #[tokio::test]
    async fn test_clear_only_affects_target_run() {
        let (store, _tmp_dir) = test_store().await;
        let run_a = uuid::Uuid::new_v4().to_string();
        let run_b = uuid::Uuid::new_v4().to_string();

        for run_id in [&run_a, &run_b] {
            store
                .upsert_phase_checkpoint(
                    run_id,
                    "topo",
                    "phase-1",
                    "user-1",
                    "",
                    "completed",
                    None,
                    None,
                    0,
                )
                .await
                .unwrap();
        }

        store.clear_run_checkpoints(&run_a).await.unwrap();

        assert!(store.get_run_checkpoints(&run_a).await.unwrap().is_empty());
        assert_eq!(store.get_run_checkpoints(&run_b).await.unwrap().len(), 1);
    }
}