1use super::Store;
22use crate::error::MemoryError;
23
/// Raw tuple decoded from the checkpoint `SELECT` queries.
///
/// Element order follows the selected column list exactly; `output` and
/// `error_message` are decoded as `Option<String>`.
type CheckpointRow = (
    String,         // id
    String,         // run_id
    String,         // topology_name
    String,         // phase_name
    String,         // sender_id
    String,         // project
    String,         // status
    Option<String>, // output
    Option<String>, // error_message
    i64,            // attempt
    String,         // created_at
    String,         // updated_at
);
38
/// Persisted state of a single phase within a topology run.
///
/// At most one checkpoint exists per `(run_id, phase_name)` pair: writing the
/// same phase again through `Store::upsert_phase_checkpoint` overwrites it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PhaseCheckpoint {
    /// Row identifier; a UUID v4 generated when the row is first inserted and
    /// kept unchanged by later upserts.
    pub id: String,
    /// Identifier of the run this phase belongs to.
    pub run_id: String,
    /// Name of the topology the run is executing.
    pub topology_name: String,
    /// Name of the phase within the topology.
    pub phase_name: String,
    /// Identifier of the sender associated with the run.
    pub sender_id: String,
    /// Project scope of the run (may be the empty string).
    pub project: String,
    /// Free-form status string, e.g. `"in_progress"`, `"completed"`, `"failed"`.
    pub status: String,
    /// Output produced by the phase, if any.
    pub output: Option<String>,
    /// Error text recorded for the phase, if any.
    pub error_message: Option<String>,
    /// Attempt counter exactly as supplied by the caller.
    pub attempt: i64,
    /// Creation timestamp as stored by SQLite (presumably set by a column
    /// default — TODO confirm against the schema).
    pub created_at: String,
    /// Last-update timestamp; refreshed to `datetime('now')` on every upsert
    /// that hits an existing row.
    pub updated_at: String,
}
67
68impl Store {
69 #[allow(clippy::too_many_arguments)]
74 pub async fn upsert_phase_checkpoint(
75 &self,
76 run_id: &str,
77 topology_name: &str,
78 phase_name: &str,
79 sender_id: &str,
80 project: &str,
81 status: &str,
82 output: Option<&str>,
83 error_message: Option<&str>,
84 attempt: i64,
85 ) -> Result<(), MemoryError> {
86 let id = uuid::Uuid::new_v4().to_string();
87 sqlx::query(
88 "INSERT INTO phase_checkpoints \
89 (id, run_id, topology_name, phase_name, sender_id, project, \
90 status, output, error_message, attempt) \
91 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) \
92 ON CONFLICT (run_id, phase_name) DO UPDATE SET \
93 topology_name = excluded.topology_name, \
94 sender_id = excluded.sender_id, \
95 project = excluded.project, \
96 status = excluded.status, \
97 output = excluded.output, \
98 error_message = excluded.error_message, \
99 attempt = excluded.attempt, \
100 updated_at = datetime('now')",
101 )
102 .bind(&id)
103 .bind(run_id)
104 .bind(topology_name)
105 .bind(phase_name)
106 .bind(sender_id)
107 .bind(project)
108 .bind(status)
109 .bind(output)
110 .bind(error_message)
111 .bind(attempt)
112 .execute(&self.pool)
113 .await
114 .map_err(|e| MemoryError::sqlite("upsert phase checkpoint", e))?;
115 Ok(())
116 }
117
118 pub async fn get_phase_checkpoint(
120 &self,
121 run_id: &str,
122 phase_name: &str,
123 ) -> Result<Option<PhaseCheckpoint>, MemoryError> {
124 let row: Option<CheckpointRow> = sqlx::query_as(
125 "SELECT id, run_id, topology_name, phase_name, sender_id, project, \
126 status, output, error_message, attempt, created_at, updated_at \
127 FROM phase_checkpoints WHERE run_id = ? AND phase_name = ?",
128 )
129 .bind(run_id)
130 .bind(phase_name)
131 .fetch_optional(&self.pool)
132 .await
133 .map_err(|e| MemoryError::sqlite("get phase checkpoint", e))?;
134
135 Ok(row.map(
136 |(
137 id,
138 run_id,
139 topology_name,
140 phase_name,
141 sender_id,
142 project,
143 status,
144 output,
145 error_message,
146 attempt,
147 created_at,
148 updated_at,
149 )| PhaseCheckpoint {
150 id,
151 run_id,
152 topology_name,
153 phase_name,
154 sender_id,
155 project,
156 status,
157 output,
158 error_message,
159 attempt,
160 created_at,
161 updated_at,
162 },
163 ))
164 }
165
166 pub async fn get_run_checkpoints(
171 &self,
172 run_id: &str,
173 ) -> Result<Vec<PhaseCheckpoint>, MemoryError> {
174 let rows: Vec<CheckpointRow> = sqlx::query_as(
175 "SELECT id, run_id, topology_name, phase_name, sender_id, project, \
176 status, output, error_message, attempt, created_at, updated_at \
177 FROM phase_checkpoints WHERE run_id = ? ORDER BY created_at ASC",
178 )
179 .bind(run_id)
180 .fetch_all(&self.pool)
181 .await
182 .map_err(|e| MemoryError::sqlite("get run checkpoints", e))?;
183
184 Ok(rows
185 .into_iter()
186 .map(
187 |(
188 id,
189 run_id,
190 topology_name,
191 phase_name,
192 sender_id,
193 project,
194 status,
195 output,
196 error_message,
197 attempt,
198 created_at,
199 updated_at,
200 )| PhaseCheckpoint {
201 id,
202 run_id,
203 topology_name,
204 phase_name,
205 sender_id,
206 project,
207 status,
208 output,
209 error_message,
210 attempt,
211 created_at,
212 updated_at,
213 },
214 )
215 .collect())
216 }
217
218 pub async fn clear_run_checkpoints(&self, run_id: &str) -> Result<(), MemoryError> {
223 sqlx::query("DELETE FROM phase_checkpoints WHERE run_id = ?")
224 .bind(run_id)
225 .execute(&self.pool)
226 .await
227 .map_err(|e| MemoryError::sqlite("clear run checkpoints", e))?;
228 Ok(())
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use kernex_core::config::MemoryConfig;

    /// Build a `Store` backed by a fresh SQLite file in a temporary directory.
    ///
    /// The `TempDir` is returned so the caller keeps it alive for the test's
    /// duration; dropping it removes the database file.
    async fn test_store() -> (Store, tempfile::TempDir) {
        let tmp_dir = tempfile::TempDir::new().unwrap();
        let db_path = tmp_dir.path().join("checkpoints.db");
        let config = MemoryConfig {
            db_path: db_path.to_str().unwrap().to_string(),
            ..Default::default()
        };
        let store = Store::new(&config).await.unwrap();
        (store, tmp_dir)
    }

    #[tokio::test]
    async fn test_upsert_and_get_checkpoint() {
        let (store, _tmp_dir) = test_store().await;
        let run_id = uuid::Uuid::new_v4().to_string();

        store
            .upsert_phase_checkpoint(
                &run_id,
                "my-pipeline",
                "phase-1",
                "user-1",
                "",
                "completed",
                Some("phase output"),
                None,
                0,
            )
            .await
            .unwrap();

        let cp = store
            .get_phase_checkpoint(&run_id, "phase-1")
            .await
            .unwrap()
            .unwrap();

        assert_eq!(cp.run_id, run_id);
        assert_eq!(cp.topology_name, "my-pipeline");
        assert_eq!(cp.phase_name, "phase-1");
        assert_eq!(cp.status, "completed");
        assert_eq!(cp.output.as_deref(), Some("phase output"));
        assert!(cp.error_message.is_none());
    }

    #[tokio::test]
    async fn test_get_missing_checkpoint_returns_none() {
        let (store, _tmp_dir) = test_store().await;
        let run_id = uuid::Uuid::new_v4().to_string();

        let cp = store
            .get_phase_checkpoint(&run_id, "phase-1")
            .await
            .unwrap();

        assert!(cp.is_none());
    }

    #[tokio::test]
    async fn test_upsert_updates_existing() {
        let (store, _tmp_dir) = test_store().await;
        let run_id = uuid::Uuid::new_v4().to_string();

        store
            .upsert_phase_checkpoint(
                &run_id,
                "topo",
                "phase-a",
                "user-1",
                "",
                "in_progress",
                None,
                None,
                0,
            )
            .await
            .unwrap();

        let first = store
            .get_phase_checkpoint(&run_id, "phase-a")
            .await
            .unwrap()
            .unwrap();

        store
            .upsert_phase_checkpoint(
                &run_id,
                "topo",
                "phase-a",
                "user-1",
                "",
                "completed",
                Some("done"),
                None,
                0,
            )
            .await
            .unwrap();

        let cp = store
            .get_phase_checkpoint(&run_id, "phase-a")
            .await
            .unwrap()
            .unwrap();

        assert_eq!(cp.status, "completed");
        assert_eq!(cp.output.as_deref(), Some("done"));
        // The conflict path must update in place: same row id, no duplicate row.
        assert_eq!(cp.id, first.id);
        let all = store.get_run_checkpoints(&run_id).await.unwrap();
        assert_eq!(all.len(), 1);
    }

    #[tokio::test]
    async fn test_get_run_checkpoints_ordered() {
        let (store, _tmp_dir) = test_store().await;
        let run_id = uuid::Uuid::new_v4().to_string();

        for phase in &["phase-1", "phase-2", "phase-3"] {
            store
                .upsert_phase_checkpoint(
                    &run_id,
                    "topo",
                    phase,
                    "user-1",
                    "",
                    "completed",
                    None,
                    None,
                    0,
                )
                .await
                .unwrap();
        }

        let checkpoints = store.get_run_checkpoints(&run_id).await.unwrap();
        assert_eq!(checkpoints.len(), 3);
        assert_eq!(checkpoints[0].phase_name, "phase-1");
        assert_eq!(checkpoints[1].phase_name, "phase-2");
        assert_eq!(checkpoints[2].phase_name, "phase-3");
    }

    #[tokio::test]
    async fn test_clear_run_checkpoints() {
        let (store, _tmp_dir) = test_store().await;
        let run_id = uuid::Uuid::new_v4().to_string();

        store
            .upsert_phase_checkpoint(
                &run_id,
                "topo",
                "phase-1",
                "user-1",
                "",
                "completed",
                None,
                None,
                0,
            )
            .await
            .unwrap();

        store.clear_run_checkpoints(&run_id).await.unwrap();

        let checkpoints = store.get_run_checkpoints(&run_id).await.unwrap();
        assert!(checkpoints.is_empty());
    }

    #[tokio::test]
    async fn test_clear_only_affects_target_run() {
        let (store, _tmp_dir) = test_store().await;
        let run_a = uuid::Uuid::new_v4().to_string();
        let run_b = uuid::Uuid::new_v4().to_string();

        for run_id in [&run_a, &run_b] {
            store
                .upsert_phase_checkpoint(
                    run_id,
                    "topo",
                    "phase-1",
                    "user-1",
                    "",
                    "completed",
                    None,
                    None,
                    0,
                )
                .await
                .unwrap();
        }

        store.clear_run_checkpoints(&run_a).await.unwrap();

        assert!(store.get_run_checkpoints(&run_a).await.unwrap().is_empty());
        let remaining = store.get_run_checkpoints(&run_b).await.unwrap();
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].run_id, run_b);
    }

    #[tokio::test]
    async fn test_failed_checkpoint_stores_error() {
        let (store, _tmp_dir) = test_store().await;
        let run_id = uuid::Uuid::new_v4().to_string();

        store
            .upsert_phase_checkpoint(
                &run_id,
                "topo",
                "phase-1",
                "user-1",
                "proj-a",
                "failed",
                None,
                Some("provider timeout"),
                1,
            )
            .await
            .unwrap();

        let cp = store
            .get_phase_checkpoint(&run_id, "phase-1")
            .await
            .unwrap()
            .unwrap();

        assert_eq!(cp.status, "failed");
        assert_eq!(cp.error_message.as_deref(), Some("provider timeout"));
        assert_eq!(cp.attempt, 1);
        assert_eq!(cp.project, "proj-a");
    }
}