1use super::Store;
22use kernex_core::error::KernexError;
23
/// Raw tuple decoded by `sqlx::query_as` for one `phase_checkpoints` row.
///
/// Element order must match the column list of the `SELECT` statements in this
/// file, because the tuple is decoded positionally.
type CheckpointRow = (
    String,         // id
    String,         // run_id
    String,         // topology_name
    String,         // phase_name
    String,         // sender_id
    String,         // project
    String,         // status
    Option<String>, // output
    Option<String>, // error_message
    i64,            // attempt
    String,         // created_at
    String,         // updated_at
);
38
/// One row of the `phase_checkpoints` table, as returned by the `Store`
/// checkpoint readers in this file.
#[derive(Debug, Clone)]
pub struct PhaseCheckpoint {
    /// Row identifier; the upsert in this file generates a UUID v4 string for
    /// newly inserted rows.
    pub id: String,
    /// Identifier of the run this checkpoint belongs to. Together with
    /// `phase_name` it forms the conflict key used by the upsert.
    pub run_id: String,
    /// Topology name supplied by the caller (e.g. `"my-pipeline"` in the tests).
    pub topology_name: String,
    /// Phase name; second half of the `(run_id, phase_name)` conflict key.
    pub phase_name: String,
    /// Sender identifier supplied by the caller.
    pub sender_id: String,
    /// Project string supplied by the caller; may be empty (the tests pass `""`).
    pub project: String,
    /// Free-form status string. Values used by this file's tests are
    /// `"in_progress"`, `"completed"` and `"failed"`.
    pub status: String,
    /// Phase output text, if any was recorded.
    pub output: Option<String>,
    /// Error text, if any was recorded.
    pub error_message: Option<String>,
    /// Attempt counter supplied by the caller.
    // NOTE(review): the tests use both 0 and 1; whether this is zero- or
    // one-based is not established in this file.
    pub attempt: i64,
    /// Creation timestamp as text. The upsert in this file never writes this
    /// column, so it is presumably filled by a column default (confirm against
    /// the schema).
    pub created_at: String,
    /// Last-update timestamp as text; set to `datetime('now')` whenever the
    /// upsert overwrites an existing row.
    pub updated_at: String,
}
67
68impl Store {
69 #[allow(clippy::too_many_arguments)]
74 pub async fn upsert_phase_checkpoint(
75 &self,
76 run_id: &str,
77 topology_name: &str,
78 phase_name: &str,
79 sender_id: &str,
80 project: &str,
81 status: &str,
82 output: Option<&str>,
83 error_message: Option<&str>,
84 attempt: i64,
85 ) -> Result<(), KernexError> {
86 let id = uuid::Uuid::new_v4().to_string();
87 sqlx::query(
88 "INSERT INTO phase_checkpoints \
89 (id, run_id, topology_name, phase_name, sender_id, project, \
90 status, output, error_message, attempt) \
91 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) \
92 ON CONFLICT (run_id, phase_name) DO UPDATE SET \
93 topology_name = excluded.topology_name, \
94 sender_id = excluded.sender_id, \
95 project = excluded.project, \
96 status = excluded.status, \
97 output = excluded.output, \
98 error_message = excluded.error_message, \
99 attempt = excluded.attempt, \
100 updated_at = datetime('now')",
101 )
102 .bind(&id)
103 .bind(run_id)
104 .bind(topology_name)
105 .bind(phase_name)
106 .bind(sender_id)
107 .bind(project)
108 .bind(status)
109 .bind(output)
110 .bind(error_message)
111 .bind(attempt)
112 .execute(&self.pool)
113 .await
114 .map_err(|e| KernexError::Store(format!("upsert phase checkpoint: {e}")))?;
115 Ok(())
116 }
117
118 pub async fn get_phase_checkpoint(
120 &self,
121 run_id: &str,
122 phase_name: &str,
123 ) -> Result<Option<PhaseCheckpoint>, KernexError> {
124 let row: Option<CheckpointRow> = sqlx::query_as(
125 "SELECT id, run_id, topology_name, phase_name, sender_id, project, \
126 status, output, error_message, attempt, created_at, updated_at \
127 FROM phase_checkpoints WHERE run_id = ? AND phase_name = ?",
128 )
129 .bind(run_id)
130 .bind(phase_name)
131 .fetch_optional(&self.pool)
132 .await
133 .map_err(|e| KernexError::Store(format!("get phase checkpoint: {e}")))?;
134
135 Ok(row.map(
136 |(
137 id,
138 run_id,
139 topology_name,
140 phase_name,
141 sender_id,
142 project,
143 status,
144 output,
145 error_message,
146 attempt,
147 created_at,
148 updated_at,
149 )| PhaseCheckpoint {
150 id,
151 run_id,
152 topology_name,
153 phase_name,
154 sender_id,
155 project,
156 status,
157 output,
158 error_message,
159 attempt,
160 created_at,
161 updated_at,
162 },
163 ))
164 }
165
166 pub async fn get_run_checkpoints(
171 &self,
172 run_id: &str,
173 ) -> Result<Vec<PhaseCheckpoint>, KernexError> {
174 let rows: Vec<CheckpointRow> = sqlx::query_as(
175 "SELECT id, run_id, topology_name, phase_name, sender_id, project, \
176 status, output, error_message, attempt, created_at, updated_at \
177 FROM phase_checkpoints WHERE run_id = ? ORDER BY created_at ASC",
178 )
179 .bind(run_id)
180 .fetch_all(&self.pool)
181 .await
182 .map_err(|e| KernexError::Store(format!("get run checkpoints: {e}")))?;
183
184 Ok(rows
185 .into_iter()
186 .map(
187 |(
188 id,
189 run_id,
190 topology_name,
191 phase_name,
192 sender_id,
193 project,
194 status,
195 output,
196 error_message,
197 attempt,
198 created_at,
199 updated_at,
200 )| PhaseCheckpoint {
201 id,
202 run_id,
203 topology_name,
204 phase_name,
205 sender_id,
206 project,
207 status,
208 output,
209 error_message,
210 attempt,
211 created_at,
212 updated_at,
213 },
214 )
215 .collect())
216 }
217
218 pub async fn clear_run_checkpoints(&self, run_id: &str) -> Result<(), KernexError> {
223 sqlx::query("DELETE FROM phase_checkpoints WHERE run_id = ?")
224 .bind(run_id)
225 .execute(&self.pool)
226 .await
227 .map_err(|e| KernexError::Store(format!("clear run checkpoints: {e}")))?;
228 Ok(())
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use kernex_core::config::MemoryConfig;

    /// Opens a `Store` backed by a uniquely named database file in the OS
    /// temp directory, so tests never share state.
    async fn test_store() -> Store {
        let file_name = format!(
            "__kernex_checkpoints_test_{}__{}.db",
            std::process::id(),
            uuid::Uuid::new_v4()
        );
        let db_file = std::env::temp_dir().join(file_name);
        let config = MemoryConfig {
            db_path: db_file.to_str().unwrap().to_string(),
            ..Default::default()
        };
        Store::new(&config).await.unwrap()
    }

    /// Produces a run id that no other test uses.
    fn fresh_run_id() -> String {
        uuid::Uuid::new_v4().to_string()
    }

    /// Writes a checkpoint using the common fixture values shared by most
    /// tests: topology `"topo"`, sender `"user-1"`, empty project, no error
    /// message, attempt 0.
    async fn save_phase(
        store: &Store,
        run_id: &str,
        phase: &str,
        status: &str,
        output: Option<&str>,
    ) {
        store
            .upsert_phase_checkpoint(run_id, "topo", phase, "user-1", "", status, output, None, 0)
            .await
            .unwrap();
    }

    /// Reads back a checkpoint that the test expects to exist.
    async fn load_phase(store: &Store, run_id: &str, phase: &str) -> PhaseCheckpoint {
        store
            .get_phase_checkpoint(run_id, phase)
            .await
            .unwrap()
            .unwrap()
    }

    // A freshly inserted checkpoint is returned with the values that were written.
    #[tokio::test]
    async fn test_upsert_and_get_checkpoint() {
        let store = test_store().await;
        let run_id = fresh_run_id();

        let written = store
            .upsert_phase_checkpoint(
                &run_id,
                "my-pipeline",
                "phase-1",
                "user-1",
                "",
                "completed",
                Some("phase output"),
                None,
                0,
            )
            .await;
        written.unwrap();

        let cp = load_phase(&store, &run_id, "phase-1").await;

        assert_eq!(cp.run_id, run_id);
        assert_eq!(cp.topology_name, "my-pipeline");
        assert_eq!(cp.phase_name, "phase-1");
        assert_eq!(cp.status, "completed");
        assert_eq!(cp.output.as_deref(), Some("phase output"));
        assert!(cp.error_message.is_none());
    }

    // A second upsert for the same (run, phase) overwrites status and output.
    #[tokio::test]
    async fn test_upsert_updates_existing() {
        let store = test_store().await;
        let run_id = fresh_run_id();

        save_phase(&store, &run_id, "phase-a", "in_progress", None).await;
        save_phase(&store, &run_id, "phase-a", "completed", Some("done")).await;

        let cp = load_phase(&store, &run_id, "phase-a").await;

        assert_eq!(cp.status, "completed");
        assert_eq!(cp.output.as_deref(), Some("done"));
    }

    // Checkpoints for a run come back in the order they were written.
    #[tokio::test]
    async fn test_get_run_checkpoints_ordered() {
        let store = test_store().await;
        let run_id = fresh_run_id();

        for phase in ["phase-1", "phase-2", "phase-3"] {
            save_phase(&store, &run_id, phase, "completed", None).await;
        }

        let checkpoints = store.get_run_checkpoints(&run_id).await.unwrap();
        let phase_names: Vec<&str> = checkpoints
            .iter()
            .map(|cp| cp.phase_name.as_str())
            .collect();

        // Comparing the whole list checks both the count and the order.
        assert_eq!(phase_names, ["phase-1", "phase-2", "phase-3"]);
    }

    // Clearing a run removes all of its checkpoints.
    #[tokio::test]
    async fn test_clear_run_checkpoints() {
        let store = test_store().await;
        let run_id = fresh_run_id();

        save_phase(&store, &run_id, "phase-1", "completed", None).await;

        store.clear_run_checkpoints(&run_id).await.unwrap();

        let remaining = store.get_run_checkpoints(&run_id).await.unwrap();
        assert!(remaining.is_empty());
    }

    // A failed checkpoint keeps its error message, attempt number and project.
    #[tokio::test]
    async fn test_failed_checkpoint_stores_error() {
        let store = test_store().await;
        let run_id = fresh_run_id();

        let written = store
            .upsert_phase_checkpoint(
                &run_id,
                "topo",
                "phase-1",
                "user-1",
                "proj-a",
                "failed",
                None,
                Some("provider timeout"),
                1,
            )
            .await;
        written.unwrap();

        let cp = load_phase(&store, &run_id, "phase-1").await;

        assert_eq!(cp.status, "failed");
        assert_eq!(cp.error_message.as_deref(), Some("provider timeout"));
        assert_eq!(cp.attempt, 1);
        assert_eq!(cp.project, "proj-a");
    }
}