ralph_workflow/checkpoint/
state.rs1use chrono::Local;
7use serde::{Deserialize, Serialize};
8use std::fs;
9use std::io;
10use std::path::Path;
11
12const AGENT_DIR: &str = ".agent";
14
15const CHECKPOINT_FILE: &str = "checkpoint.json";
17
18fn checkpoint_path() -> String {
25 format!("{AGENT_DIR}/{CHECKPOINT_FILE}")
26}
27
28#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
33pub enum PipelinePhase {
34 Rebase,
36 Planning,
38 Development,
40 Review,
42 Fix,
44 ReviewAgain,
46 CommitMessage,
48 FinalValidation,
50 Complete,
52}
53
54impl std::fmt::Display for PipelinePhase {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 match self {
57 Self::Rebase => write!(f, "Rebase"),
58 Self::Planning => write!(f, "Planning"),
59 Self::Development => write!(f, "Development"),
60 Self::Review => write!(f, "Review"),
61 Self::Fix => write!(f, "Fix"),
62 Self::ReviewAgain => write!(f, "Verification Review"),
63 Self::CommitMessage => write!(f, "Commit Message Generation"),
64 Self::FinalValidation => write!(f, "Final Validation"),
65 Self::Complete => write!(f, "Complete"),
66 }
67 }
68}
69
70#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct PipelineCheckpoint {
77 pub phase: PipelinePhase,
79 pub iteration: u32,
81 pub total_iterations: u32,
83 pub reviewer_pass: u32,
85 pub total_reviewer_passes: u32,
87 pub timestamp: String,
89 pub developer_agent: String,
91 pub reviewer_agent: String,
93}
94
95impl PipelineCheckpoint {
96 pub fn new(
108 phase: PipelinePhase,
109 iteration: u32,
110 total_iterations: u32,
111 reviewer_pass: u32,
112 total_reviewer_passes: u32,
113 developer_agent: &str,
114 reviewer_agent: &str,
115 ) -> Self {
116 Self {
117 phase,
118 iteration,
119 total_iterations,
120 reviewer_pass,
121 total_reviewer_passes,
122 timestamp: timestamp(),
123 developer_agent: developer_agent.to_string(),
124 reviewer_agent: reviewer_agent.to_string(),
125 }
126 }
127
128 pub fn description(&self) -> String {
133 match self.phase {
134 PipelinePhase::Rebase => "Rebase in progress".to_string(),
135 PipelinePhase::Planning => {
136 format!(
137 "Planning phase, iteration {}/{}",
138 self.iteration, self.total_iterations
139 )
140 }
141 PipelinePhase::Development => {
142 format!(
143 "Development iteration {}/{}",
144 self.iteration, self.total_iterations
145 )
146 }
147 PipelinePhase::Review => "Initial review".to_string(),
148 PipelinePhase::Fix => "Applying fixes".to_string(),
149 PipelinePhase::ReviewAgain => {
150 format!(
151 "Verification review {}/{}",
152 self.reviewer_pass, self.total_reviewer_passes
153 )
154 }
155 PipelinePhase::CommitMessage => "Commit message generation".to_string(),
156 PipelinePhase::FinalValidation => "Final validation".to_string(),
157 PipelinePhase::Complete => "Pipeline complete".to_string(),
158 }
159 }
160}
161
162pub fn timestamp() -> String {
164 Local::now().format("%Y-%m-%d %H:%M:%S").to_string()
165}
166
167pub fn save_checkpoint(checkpoint: &PipelineCheckpoint) -> io::Result<()> {
177 let json = serde_json::to_string_pretty(checkpoint).map_err(|e| {
178 io::Error::new(
179 io::ErrorKind::InvalidData,
180 format!("Failed to serialize checkpoint: {e}"),
181 )
182 })?;
183
184 fs::create_dir_all(AGENT_DIR)?;
186
187 let checkpoint_path_str = checkpoint_path();
189 let temp_path = format!("{checkpoint_path_str}.tmp");
190
191 let write_result = fs::write(&temp_path, &json);
193 if write_result.is_err() {
194 let _ = fs::remove_file(&temp_path);
195 return write_result;
196 }
197
198 let rename_result = fs::rename(&temp_path, &checkpoint_path_str);
199 if rename_result.is_err() {
200 let _ = fs::remove_file(&temp_path);
201 return rename_result;
202 }
203
204 Ok(())
205}
206
207pub fn load_checkpoint() -> io::Result<Option<PipelineCheckpoint>> {
218 let checkpoint = checkpoint_path();
219 let path = Path::new(&checkpoint);
220 if !path.exists() {
221 return Ok(None);
222 }
223
224 let content = fs::read_to_string(path)?;
225 let loaded_checkpoint: PipelineCheckpoint = serde_json::from_str(&content).map_err(|e| {
226 io::Error::new(
227 io::ErrorKind::InvalidData,
228 format!("Failed to parse checkpoint: {e}"),
229 )
230 })?;
231
232 Ok(Some(loaded_checkpoint))
233}
234
235pub fn clear_checkpoint() -> io::Result<()> {
244 let checkpoint = checkpoint_path();
245 let path = Path::new(&checkpoint);
246 if path.exists() {
247 fs::remove_file(path)?;
248 }
249 Ok(())
250}
251
252pub fn checkpoint_exists() -> bool {
256 Path::new(&checkpoint_path()).exists()
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use test_helpers::with_temp_cwd;

    /// Shorthand for a checkpoint that uses the "claude"/"codex" agent pair.
    fn claude_codex(
        phase: PipelinePhase,
        iteration: u32,
        total_iterations: u32,
        reviewer_pass: u32,
        total_reviewer_passes: u32,
    ) -> PipelineCheckpoint {
        PipelineCheckpoint::new(
            phase,
            iteration,
            total_iterations,
            reviewer_pass,
            total_reviewer_passes,
            "claude",
            "codex",
        )
    }

    #[test]
    fn test_timestamp_format() {
        let stamp = timestamp();
        // `YYYY-MM-DD HH:MM:SS` is 19 characters with both separators present.
        assert!(stamp.contains('-'));
        assert!(stamp.contains(':'));
        assert_eq!(stamp.len(), 19);
    }

    #[test]
    fn test_pipeline_phase_display() {
        let expected_labels = [
            (PipelinePhase::Rebase, "Rebase"),
            (PipelinePhase::Planning, "Planning"),
            (PipelinePhase::Development, "Development"),
            (PipelinePhase::Review, "Review"),
            (PipelinePhase::Fix, "Fix"),
            (PipelinePhase::ReviewAgain, "Verification Review"),
            (PipelinePhase::CommitMessage, "Commit Message Generation"),
            (PipelinePhase::FinalValidation, "Final Validation"),
            (PipelinePhase::Complete, "Complete"),
        ];

        for (phase, label) in expected_labels {
            assert_eq!(format!("{}", phase), label);
        }
    }

    #[test]
    fn test_checkpoint_new() {
        let created = claude_codex(PipelinePhase::Development, 2, 5, 0, 2);

        assert_eq!(created.phase, PipelinePhase::Development);
        assert_eq!(created.iteration, 2);
        assert_eq!(created.total_iterations, 5);
        assert_eq!(created.reviewer_pass, 0);
        assert_eq!(created.total_reviewer_passes, 2);
        assert_eq!(created.developer_agent, "claude");
        assert_eq!(created.reviewer_agent, "codex");
        assert!(!created.timestamp.is_empty());
    }

    #[test]
    fn test_checkpoint_description() {
        let developing = claude_codex(PipelinePhase::Development, 3, 5, 0, 2);
        assert_eq!(developing.description(), "Development iteration 3/5");

        let verifying = claude_codex(PipelinePhase::ReviewAgain, 5, 5, 2, 3);
        assert_eq!(verifying.description(), "Verification review 2/3");
    }

    #[test]
    fn test_checkpoint_save_load() {
        with_temp_cwd(|_dir| {
            fs::create_dir_all(".agent").unwrap();

            let saved = claude_codex(PipelinePhase::Review, 5, 5, 1, 2);
            save_checkpoint(&saved).unwrap();
            assert!(checkpoint_exists());

            let restored = load_checkpoint()
                .unwrap()
                .expect("checkpoint should exist after save_checkpoint");
            assert_eq!(restored.phase, PipelinePhase::Review);
            assert_eq!(restored.iteration, 5);
            assert_eq!(restored.developer_agent, "claude");
            assert_eq!(restored.reviewer_agent, "codex");
        });
    }

    #[test]
    fn test_checkpoint_clear() {
        with_temp_cwd(|_dir| {
            fs::create_dir_all(".agent").unwrap();

            let saved = claude_codex(PipelinePhase::Development, 1, 5, 0, 2);
            save_checkpoint(&saved).unwrap();
            assert!(checkpoint_exists());

            clear_checkpoint().unwrap();
            assert!(!checkpoint_exists());
        });
    }

    #[test]
    fn test_load_checkpoint_nonexistent() {
        with_temp_cwd(|_dir| {
            fs::create_dir_all(".agent").unwrap();

            let loaded = load_checkpoint().unwrap();
            assert!(loaded.is_none());
        });
    }

    #[test]
    fn test_checkpoint_serialization() {
        let original =
            PipelineCheckpoint::new(PipelinePhase::Fix, 3, 5, 1, 2, "aider", "opencode");

        let encoded = serde_json::to_string(&original).unwrap();
        for needle in ["Fix", "aider", "opencode"] {
            assert!(encoded.contains(needle));
        }

        let decoded: PipelineCheckpoint = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.phase, original.phase);
        assert_eq!(decoded.iteration, original.iteration);
    }
}