Skip to main content

kernex_memory/store/
checkpoints.rs

1//! Phase checkpoint storage for resumable pipeline runs.
2//!
3//! Each checkpoint records the status and output of one pipeline phase
4//! within a run. A `run_id` (UUID) groups all phases belonging to the
5//! same pipeline execution, allowing a failed run to be resumed from
6//! the last completed phase.
7//!
//! # Lifecycle
//!
//! The calls below are abbreviated: the real `upsert_phase_checkpoint` also
//! takes `topology_name`, `sender_id`, `project`, and `attempt`.
//!
//! ```text
//! upsert_phase_checkpoint(run_id, phase, "pending",     None,   None)
//! upsert_phase_checkpoint(run_id, phase, "in_progress", None,   None)
//! upsert_phase_checkpoint(run_id, phase, "completed",   output, None)
//! // or
//! upsert_phase_checkpoint(run_id, phase, "failed",      None,   Some(error))
//! ```
17//!
18//! To resume a run, call `get_run_checkpoints` and skip phases whose
19//! `status` is `"completed"`.
20
21use super::Store;
22use kernex_core::error::KernexError;
23
/// Raw `phase_checkpoints` row as decoded by `sqlx::query_as`.
///
/// Positions follow the column order of the `SELECT` lists in this module.
type CheckpointRow = (
    String,         // id
    String,         // run_id
    String,         // topology_name
    String,         // phase_name
    String,         // sender_id
    String,         // project
    String,         // status
    Option<String>, // output
    Option<String>, // error_message
    i64,            // attempt
    String,         // created_at
    String,         // updated_at
);
38
/// A recorded snapshot of one pipeline phase within a run.
///
/// Mirrors one row of the `phase_checkpoints` table. Two checkpoints compare
/// equal only when every field, including the timestamps, is equal.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PhaseCheckpoint {
    /// Unique row identifier (a UUID string generated when the row is first inserted).
    pub id: String,
    /// UUID identifying the enclosing pipeline run.
    pub run_id: String,
    /// Name of the topology (e.g. `"my-pipeline"`).
    pub topology_name: String,
    /// Name of the phase within the topology.
    pub phase_name: String,
    /// Agent / sender identifier that owns this run.
    pub sender_id: String,
    /// Project scope. Empty string = no project.
    pub project: String,
    /// One of `"pending"`, `"in_progress"`, `"completed"`, `"failed"`.
    pub status: String,
    /// Phase output text. `None` until the phase completes successfully.
    pub output: Option<String>,
    /// Error detail. Set only when `status` is `"failed"`.
    pub error_message: Option<String>,
    /// Attempt number for this phase, 0-indexed (the first attempt is `0`).
    pub attempt: i64,
    /// Creation timestamp, as stored by the database.
    pub created_at: String,
    /// Last-update timestamp; refreshed with SQLite `datetime('now')` on every
    /// update of an existing row.
    pub updated_at: String,
}
67
68impl Store {
69    /// Create or update a checkpoint for one phase within a run.
70    ///
71    /// Uses `INSERT OR REPLACE` keyed on `(run_id, phase_name)`, so calling
72    /// this multiple times as the phase progresses is safe and idempotent.
73    #[allow(clippy::too_many_arguments)]
74    pub async fn upsert_phase_checkpoint(
75        &self,
76        run_id: &str,
77        topology_name: &str,
78        phase_name: &str,
79        sender_id: &str,
80        project: &str,
81        status: &str,
82        output: Option<&str>,
83        error_message: Option<&str>,
84        attempt: i64,
85    ) -> Result<(), KernexError> {
86        let id = uuid::Uuid::new_v4().to_string();
87        sqlx::query(
88            "INSERT INTO phase_checkpoints \
89             (id, run_id, topology_name, phase_name, sender_id, project, \
90              status, output, error_message, attempt) \
91             VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) \
92             ON CONFLICT (run_id, phase_name) DO UPDATE SET \
93               topology_name = excluded.topology_name, \
94               sender_id     = excluded.sender_id, \
95               project       = excluded.project, \
96               status        = excluded.status, \
97               output        = excluded.output, \
98               error_message = excluded.error_message, \
99               attempt       = excluded.attempt, \
100               updated_at    = datetime('now')",
101        )
102        .bind(&id)
103        .bind(run_id)
104        .bind(topology_name)
105        .bind(phase_name)
106        .bind(sender_id)
107        .bind(project)
108        .bind(status)
109        .bind(output)
110        .bind(error_message)
111        .bind(attempt)
112        .execute(&self.pool)
113        .await
114        .map_err(|e| KernexError::Store(format!("upsert phase checkpoint: {e}")))?;
115        Ok(())
116    }
117
118    /// Fetch the checkpoint for a specific phase within a run.
119    pub async fn get_phase_checkpoint(
120        &self,
121        run_id: &str,
122        phase_name: &str,
123    ) -> Result<Option<PhaseCheckpoint>, KernexError> {
124        let row: Option<CheckpointRow> = sqlx::query_as(
125            "SELECT id, run_id, topology_name, phase_name, sender_id, project, \
126                    status, output, error_message, attempt, created_at, updated_at \
127             FROM phase_checkpoints WHERE run_id = ? AND phase_name = ?",
128        )
129        .bind(run_id)
130        .bind(phase_name)
131        .fetch_optional(&self.pool)
132        .await
133        .map_err(|e| KernexError::Store(format!("get phase checkpoint: {e}")))?;
134
135        Ok(row.map(
136            |(
137                id,
138                run_id,
139                topology_name,
140                phase_name,
141                sender_id,
142                project,
143                status,
144                output,
145                error_message,
146                attempt,
147                created_at,
148                updated_at,
149            )| PhaseCheckpoint {
150                id,
151                run_id,
152                topology_name,
153                phase_name,
154                sender_id,
155                project,
156                status,
157                output,
158                error_message,
159                attempt,
160                created_at,
161                updated_at,
162            },
163        ))
164    }
165
166    /// Fetch all phase checkpoints for a run, ordered by creation time.
167    ///
168    /// Use this to inspect which phases have already completed when resuming
169    /// a failed run.
170    pub async fn get_run_checkpoints(
171        &self,
172        run_id: &str,
173    ) -> Result<Vec<PhaseCheckpoint>, KernexError> {
174        let rows: Vec<CheckpointRow> = sqlx::query_as(
175            "SELECT id, run_id, topology_name, phase_name, sender_id, project, \
176                    status, output, error_message, attempt, created_at, updated_at \
177             FROM phase_checkpoints WHERE run_id = ? ORDER BY created_at ASC",
178        )
179        .bind(run_id)
180        .fetch_all(&self.pool)
181        .await
182        .map_err(|e| KernexError::Store(format!("get run checkpoints: {e}")))?;
183
184        Ok(rows
185            .into_iter()
186            .map(
187                |(
188                    id,
189                    run_id,
190                    topology_name,
191                    phase_name,
192                    sender_id,
193                    project,
194                    status,
195                    output,
196                    error_message,
197                    attempt,
198                    created_at,
199                    updated_at,
200                )| PhaseCheckpoint {
201                    id,
202                    run_id,
203                    topology_name,
204                    phase_name,
205                    sender_id,
206                    project,
207                    status,
208                    output,
209                    error_message,
210                    attempt,
211                    created_at,
212                    updated_at,
213                },
214            )
215            .collect())
216    }
217
218    /// Delete all checkpoints for a run.
219    ///
220    /// Call this after a pipeline run completes successfully to reclaim space,
221    /// or before re-running a pipeline from scratch.
222    pub async fn clear_run_checkpoints(&self, run_id: &str) -> Result<(), KernexError> {
223        sqlx::query("DELETE FROM phase_checkpoints WHERE run_id = ?")
224            .bind(run_id)
225            .execute(&self.pool)
226            .await
227            .map_err(|e| KernexError::Store(format!("clear run checkpoints: {e}")))?;
228        Ok(())
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use kernex_core::config::MemoryConfig;

    /// Open a store backed by a uniquely named database file in the temp dir,
    /// so tests running in parallel never share state.
    async fn test_store() -> Store {
        let file_name = format!(
            "__kernex_checkpoints_test_{}__{}.db",
            std::process::id(),
            uuid::Uuid::new_v4()
        );
        let db_file = std::env::temp_dir().join(file_name);
        let config = MemoryConfig {
            db_path: db_file.to_str().unwrap().to_string(),
            ..Default::default()
        };
        Store::new(&config).await.unwrap()
    }

    /// A checkpoint written once can be read back with the same field values.
    #[tokio::test]
    async fn test_upsert_and_get_checkpoint() {
        let store = test_store().await;
        let run_id = uuid::Uuid::new_v4().to_string();

        let written = store
            .upsert_phase_checkpoint(
                &run_id,
                "my-pipeline",
                "phase-1",
                "user-1",
                "",
                "completed",
                Some("phase output"),
                None,
                0,
            )
            .await;
        written.unwrap();

        let fetched = store.get_phase_checkpoint(&run_id, "phase-1").await.unwrap();
        let checkpoint = fetched.unwrap();

        assert_eq!(checkpoint.run_id, run_id);
        assert_eq!(checkpoint.topology_name, "my-pipeline");
        assert_eq!(checkpoint.phase_name, "phase-1");
        assert_eq!(checkpoint.status, "completed");
        assert_eq!(checkpoint.output.as_deref(), Some("phase output"));
        assert!(checkpoint.error_message.is_none());
    }

    /// A second upsert for the same `(run_id, phase_name)` overwrites the first.
    #[tokio::test]
    async fn test_upsert_updates_existing() {
        let store = test_store().await;
        let run_id = uuid::Uuid::new_v4().to_string();

        let first = store
            .upsert_phase_checkpoint(
                &run_id,
                "topo",
                "phase-a",
                "user-1",
                "",
                "in_progress",
                None,
                None,
                0,
            )
            .await;
        first.unwrap();

        let second = store
            .upsert_phase_checkpoint(
                &run_id,
                "topo",
                "phase-a",
                "user-1",
                "",
                "completed",
                Some("done"),
                None,
                0,
            )
            .await;
        second.unwrap();

        let fetched = store.get_phase_checkpoint(&run_id, "phase-a").await.unwrap();
        let checkpoint = fetched.unwrap();

        assert_eq!(checkpoint.status, "completed");
        assert_eq!(checkpoint.output.as_deref(), Some("done"));
    }

    /// Checkpoints for a run come back in the order they were recorded.
    #[tokio::test]
    async fn test_get_run_checkpoints_ordered() {
        let store = test_store().await;
        let run_id = uuid::Uuid::new_v4().to_string();
        let phases = ["phase-1", "phase-2", "phase-3"];

        for phase in phases {
            let written = store
                .upsert_phase_checkpoint(
                    &run_id,
                    "topo",
                    phase,
                    "user-1",
                    "",
                    "completed",
                    None,
                    None,
                    0,
                )
                .await;
            written.unwrap();
        }

        let checkpoints = store.get_run_checkpoints(&run_id).await.unwrap();
        // Comparing the whole name list checks both the count and the order.
        let names: Vec<&str> = checkpoints
            .iter()
            .map(|checkpoint| checkpoint.phase_name.as_str())
            .collect();
        assert_eq!(names, phases);
    }

    /// Clearing a run removes every checkpoint recorded for it.
    #[tokio::test]
    async fn test_clear_run_checkpoints() {
        let store = test_store().await;
        let run_id = uuid::Uuid::new_v4().to_string();

        let written = store
            .upsert_phase_checkpoint(
                &run_id,
                "topo",
                "phase-1",
                "user-1",
                "",
                "completed",
                None,
                None,
                0,
            )
            .await;
        written.unwrap();

        store.clear_run_checkpoints(&run_id).await.unwrap();

        let remaining = store.get_run_checkpoints(&run_id).await.unwrap();
        assert!(remaining.is_empty());
    }

    /// A failed phase keeps its error message, attempt number and project.
    #[tokio::test]
    async fn test_failed_checkpoint_stores_error() {
        let store = test_store().await;
        let run_id = uuid::Uuid::new_v4().to_string();

        let written = store
            .upsert_phase_checkpoint(
                &run_id,
                "topo",
                "phase-1",
                "user-1",
                "proj-a",
                "failed",
                None,
                Some("provider timeout"),
                1,
            )
            .await;
        written.unwrap();

        let fetched = store.get_phase_checkpoint(&run_id, "phase-1").await.unwrap();
        let checkpoint = fetched.unwrap();

        assert_eq!(checkpoint.status, "failed");
        assert_eq!(checkpoint.error_message.as_deref(), Some("provider timeout"));
        assert_eq!(checkpoint.attempt, 1);
        assert_eq!(checkpoint.project, "proj-a");
    }
}