ralph_workflow/git_helpers/
rebase_checkpoint.rs1#![deny(unsafe_code)]
7
8use std::fs;
9use std::io;
10use std::path::Path;
11
/// Directory (a relative path) that holds the agent's persistent state files.
const AGENT_DIR: &str = ".agent";

/// File name of the rebase checkpoint inside [`AGENT_DIR`].
const REBASE_CHECKPOINT_FILE: &str = "rebase_checkpoint.json";

/// Returns the relative path of the rebase checkpoint file: `.agent/rebase_checkpoint.json`.
pub fn rebase_checkpoint_path() -> String {
    [AGENT_DIR, REBASE_CHECKPOINT_FILE].join("/")
}

/// Returns the relative path of the checkpoint backup: the checkpoint path with `.bak` appended.
pub fn rebase_checkpoint_backup_path() -> String {
    let mut backup = rebase_checkpoint_path();
    backup.push_str(".bak");
    backup
}

34#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
36pub enum RebasePhase {
37 NotStarted,
39 PreRebaseCheck,
41 RebaseInProgress,
43 ConflictDetected,
45 ConflictResolutionInProgress,
47 CompletingRebase,
49 RebaseComplete,
51 RebaseAborted,
53}
54
55impl RebasePhase {
56 #[cfg(any(test, feature = "test-utils"))]
70 pub fn max_recovery_attempts(&self) -> u32 {
71 match self {
72 RebasePhase::ConflictResolutionInProgress => 5,
73 RebasePhase::ConflictDetected => 3,
74 RebasePhase::RebaseInProgress => 2,
75 RebasePhase::CompletingRebase => 2,
76 RebasePhase::PreRebaseCheck => 1,
77 _ => 3,
78 }
79 }
80}
81
82#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
87pub struct RebaseCheckpoint {
88 pub phase: RebasePhase,
90 pub upstream_branch: String,
92 pub conflicted_files: Vec<String>,
94 pub resolved_files: Vec<String>,
96 pub error_count: u32,
98 pub last_error: Option<String>,
100 pub timestamp: String,
102 #[serde(default)]
104 pub phase_error_count: u32,
105}
106
107impl Default for RebaseCheckpoint {
108 fn default() -> Self {
109 Self {
110 phase: RebasePhase::NotStarted,
111 upstream_branch: String::new(),
112 conflicted_files: Vec::new(),
113 resolved_files: Vec::new(),
114 error_count: 0,
115 last_error: None,
116 timestamp: chrono::Utc::now().to_rfc3339(),
117 phase_error_count: 0,
118 }
119 }
120}
121
122impl RebaseCheckpoint {
123 pub fn new(upstream_branch: String) -> Self {
125 Self {
126 phase: RebasePhase::NotStarted,
127 upstream_branch,
128 conflicted_files: Vec::new(),
129 resolved_files: Vec::new(),
130 error_count: 0,
131 last_error: None,
132 timestamp: chrono::Utc::now().to_rfc3339(),
133 phase_error_count: 0,
134 }
135 }
136
137 pub fn with_phase(mut self, phase: RebasePhase) -> Self {
141 if self.phase != phase {
143 self.phase_error_count = 0;
144 }
145 self.phase = phase;
146 self.timestamp = chrono::Utc::now().to_rfc3339();
147 self
148 }
149
150 pub fn with_conflicted_file(mut self, file: String) -> Self {
152 if !self.conflicted_files.contains(&file) {
153 self.conflicted_files.push(file);
154 }
155 self
156 }
157
158 pub fn with_resolved_file(mut self, file: String) -> Self {
160 if !self.resolved_files.contains(&file) {
161 self.resolved_files.push(file);
162 }
163 self
164 }
165
166 pub fn with_error(mut self, error: String) -> Self {
170 self.error_count += 1;
171 self.phase_error_count += 1;
172 self.last_error = Some(error);
173 self.timestamp = chrono::Utc::now().to_rfc3339();
174 self
175 }
176
177 pub fn all_conflicts_resolved(&self) -> bool {
179 self.conflicted_files
180 .iter()
181 .all(|f| self.resolved_files.contains(f))
182 }
183
184 pub fn unresolved_conflict_count(&self) -> usize {
186 self.conflicted_files
187 .iter()
188 .filter(|f| !self.resolved_files.contains(f))
189 .count()
190 }
191}
192
193pub fn save_rebase_checkpoint(checkpoint: &RebaseCheckpoint) -> io::Result<()> {
205 let json = serde_json::to_string_pretty(checkpoint).map_err(|e| {
206 io::Error::new(
207 io::ErrorKind::InvalidData,
208 format!("Failed to serialize rebase checkpoint: {e}"),
209 )
210 })?;
211
212 fs::create_dir_all(AGENT_DIR)?;
214
215 let checkpoint_existed = Path::new(&rebase_checkpoint_path()).exists();
217
218 let _ = backup_checkpoint();
220
221 let checkpoint_path_str = rebase_checkpoint_path();
223 let temp_path = format!("{checkpoint_path_str}.tmp");
224
225 let write_result = fs::write(&temp_path, &json);
227 if write_result.is_err() {
228 let _ = fs::remove_file(&temp_path);
229 return write_result;
230 }
231
232 let rename_result = fs::rename(&temp_path, &checkpoint_path_str);
233 if rename_result.is_err() {
234 let _ = fs::remove_file(&temp_path);
235 return rename_result;
236 }
237
238 if !checkpoint_existed {
241 let _ = backup_checkpoint();
242 }
243
244 Ok(())
245}
246
247pub fn load_rebase_checkpoint() -> io::Result<Option<RebaseCheckpoint>> {
260 let checkpoint = rebase_checkpoint_path();
261 let path = Path::new(&checkpoint);
262 if !path.exists() {
263 return Ok(None);
264 }
265
266 let content = fs::read_to_string(path)?;
267 let loaded_checkpoint: RebaseCheckpoint = match serde_json::from_str(&content) {
268 Ok(cp) => cp,
269 Err(e) => {
270 eprintln!("Checkpoint corrupted, attempting restore from backup: {e}");
272 return restore_from_backup();
273 }
274 };
275
276 if let Err(e) = validate_checkpoint(&loaded_checkpoint) {
278 eprintln!("Checkpoint validation failed, attempting restore from backup: {e}");
279 return restore_from_backup();
280 }
281
282 Ok(Some(loaded_checkpoint))
283}
284
285pub fn clear_rebase_checkpoint() -> io::Result<()> {
294 let checkpoint = rebase_checkpoint_path();
295 let path = Path::new(&checkpoint);
296 if path.exists() {
297 fs::remove_file(path)?;
298 }
299 Ok(())
300}
301
302pub fn rebase_checkpoint_exists() -> bool {
306 Path::new(&rebase_checkpoint_path()).exists()
307}
308
309#[cfg(any(test, feature = "test-utils"))]
314pub fn validate_checkpoint(checkpoint: &RebaseCheckpoint) -> io::Result<()> {
315 validate_checkpoint_impl(checkpoint)
316}
317
318#[cfg(not(any(test, feature = "test-utils")))]
323fn validate_checkpoint(checkpoint: &RebaseCheckpoint) -> io::Result<()> {
324 validate_checkpoint_impl(checkpoint)
325}
326
327fn validate_checkpoint_impl(checkpoint: &RebaseCheckpoint) -> io::Result<()> {
329 if checkpoint.phase != RebasePhase::NotStarted && checkpoint.upstream_branch.is_empty() {
331 return Err(io::Error::new(
332 io::ErrorKind::InvalidData,
333 "Checkpoint has empty upstream branch",
334 ));
335 }
336
337 if chrono::DateTime::parse_from_rfc3339(&checkpoint.timestamp).is_err() {
339 return Err(io::Error::new(
340 io::ErrorKind::InvalidData,
341 "Checkpoint has invalid timestamp format",
342 ));
343 }
344
345 for resolved in &checkpoint.resolved_files {
347 if !checkpoint.conflicted_files.contains(resolved) {
348 return Err(io::Error::new(
349 io::ErrorKind::InvalidData,
350 format!(
351 "Resolved file '{}' not found in conflicted files list",
352 resolved
353 ),
354 ));
355 }
356 }
357
358 Ok(())
359}
360
361fn backup_checkpoint() -> io::Result<()> {
369 let checkpoint_path = rebase_checkpoint_path();
370 let backup_path = rebase_checkpoint_backup_path();
371 let checkpoint = Path::new(&checkpoint_path);
372 let backup = Path::new(&backup_path);
373
374 if !checkpoint.exists() {
375 return Ok(());
377 }
378
379 if backup.exists() {
381 fs::remove_file(backup)?;
382 }
383
384 fs::copy(checkpoint, backup)?;
386 Ok(())
387}
388
389fn restore_from_backup() -> io::Result<Option<RebaseCheckpoint>> {
395 let backup_path = rebase_checkpoint_backup_path();
396 let backup = Path::new(&backup_path);
397
398 if !backup.exists() {
399 return Ok(None);
400 }
401
402 let content = fs::read_to_string(backup)?;
403 let checkpoint: RebaseCheckpoint = serde_json::from_str(&content).map_err(|e| {
404 io::Error::new(
405 io::ErrorKind::InvalidData,
406 format!("Failed to parse backup checkpoint: {e}"),
407 )
408 })?;
409
410 validate_checkpoint(&checkpoint)?;
412
413 let checkpoint_path = rebase_checkpoint_path();
415 fs::copy(backup, checkpoint_path)?;
416
417 Ok(Some(checkpoint))
418}
419
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rebase_checkpoint_default() {
        let checkpoint = RebaseCheckpoint::default();
        assert_eq!(checkpoint.phase, RebasePhase::NotStarted);
        assert!(checkpoint.upstream_branch.is_empty());
        assert!(checkpoint.conflicted_files.is_empty());
        assert!(checkpoint.resolved_files.is_empty());
        assert_eq!(checkpoint.error_count, 0);
        assert_eq!(checkpoint.phase_error_count, 0);
        assert!(checkpoint.last_error.is_none());
        // A default checkpoint must be accepted by load-time validation.
        assert!(validate_checkpoint(&checkpoint).is_ok());
    }

    #[test]
    fn test_rebase_checkpoint_new() {
        let checkpoint = RebaseCheckpoint::new("main".to_string());
        assert_eq!(checkpoint.phase, RebasePhase::NotStarted);
        assert_eq!(checkpoint.upstream_branch, "main");
    }

    #[test]
    fn test_rebase_checkpoint_with_phase() {
        let checkpoint =
            RebaseCheckpoint::new("main".to_string()).with_phase(RebasePhase::RebaseInProgress);
        assert_eq!(checkpoint.phase, RebasePhase::RebaseInProgress);
    }

    #[test]
    fn test_with_phase_resets_phase_error_count_only_on_change() {
        let checkpoint = RebaseCheckpoint::new("main".to_string())
            .with_phase(RebasePhase::RebaseInProgress)
            .with_error("first".to_string())
            .with_error("second".to_string());
        assert_eq!(checkpoint.error_count, 2);
        assert_eq!(checkpoint.phase_error_count, 2);

        // Re-entering the same phase keeps the per-phase counter.
        let checkpoint = checkpoint.with_phase(RebasePhase::RebaseInProgress);
        assert_eq!(checkpoint.phase_error_count, 2);

        // Moving to a different phase resets only the per-phase counter.
        let checkpoint = checkpoint.with_phase(RebasePhase::ConflictDetected);
        assert_eq!(checkpoint.phase_error_count, 0);
        assert_eq!(checkpoint.error_count, 2);
    }

    #[test]
    fn test_rebase_checkpoint_with_conflicted_file() {
        let checkpoint = RebaseCheckpoint::new("main".to_string())
            .with_conflicted_file("file1.txt".to_string())
            .with_conflicted_file("file2.txt".to_string());
        assert_eq!(checkpoint.conflicted_files.len(), 2);
        // Adding a duplicate is a no-op.
        let checkpoint = checkpoint.with_conflicted_file("file1.txt".to_string());
        assert_eq!(checkpoint.conflicted_files.len(), 2);
    }

    #[test]
    fn test_rebase_checkpoint_with_resolved_file() {
        let checkpoint = RebaseCheckpoint::new("main".to_string())
            .with_conflicted_file("file1.txt".to_string())
            .with_resolved_file("file1.txt".to_string());
        assert!(checkpoint.resolved_files.contains(&"file1.txt".to_string()));
        // Adding a duplicate is a no-op.
        let checkpoint = checkpoint.with_resolved_file("file1.txt".to_string());
        assert_eq!(checkpoint.resolved_files.len(), 1);
    }

    #[test]
    fn test_rebase_checkpoint_with_error() {
        let checkpoint =
            RebaseCheckpoint::new("main".to_string()).with_error("Test error".to_string());
        assert_eq!(checkpoint.error_count, 1);
        assert_eq!(checkpoint.phase_error_count, 1);
        assert_eq!(checkpoint.last_error, Some("Test error".to_string()));
    }

    #[test]
    fn test_rebase_checkpoint_all_conflicts_resolved() {
        let checkpoint = RebaseCheckpoint::new("main".to_string())
            .with_conflicted_file("file1.txt".to_string())
            .with_conflicted_file("file2.txt".to_string())
            .with_resolved_file("file1.txt".to_string())
            .with_resolved_file("file2.txt".to_string());
        assert!(checkpoint.all_conflicts_resolved());

        let partial = RebaseCheckpoint::new("main".to_string())
            .with_conflicted_file("file1.txt".to_string())
            .with_conflicted_file("file2.txt".to_string())
            .with_resolved_file("file1.txt".to_string());
        assert!(!partial.all_conflicts_resolved());

        // No conflicts at all counts as everything resolved.
        assert!(RebaseCheckpoint::new("main".to_string()).all_conflicts_resolved());
    }

    #[test]
    fn test_rebase_checkpoint_unresolved_conflict_count() {
        let checkpoint = RebaseCheckpoint::new("main".to_string())
            .with_conflicted_file("file1.txt".to_string())
            .with_conflicted_file("file2.txt".to_string())
            .with_resolved_file("file1.txt".to_string());
        assert_eq!(checkpoint.unresolved_conflict_count(), 1);
    }

    #[test]
    fn test_rebase_phase_equality() {
        assert_eq!(RebasePhase::NotStarted, RebasePhase::NotStarted);
        assert_ne!(RebasePhase::NotStarted, RebasePhase::RebaseInProgress);
    }

    #[test]
    fn test_max_recovery_attempts_per_phase() {
        assert_eq!(
            RebasePhase::ConflictResolutionInProgress.max_recovery_attempts(),
            5
        );
        assert_eq!(RebasePhase::ConflictDetected.max_recovery_attempts(), 3);
        assert_eq!(RebasePhase::RebaseInProgress.max_recovery_attempts(), 2);
        assert_eq!(RebasePhase::CompletingRebase.max_recovery_attempts(), 2);
        assert_eq!(RebasePhase::PreRebaseCheck.max_recovery_attempts(), 1);
        assert_eq!(RebasePhase::NotStarted.max_recovery_attempts(), 3);
        assert_eq!(RebasePhase::RebaseComplete.max_recovery_attempts(), 3);
        assert_eq!(RebasePhase::RebaseAborted.max_recovery_attempts(), 3);
    }

    #[test]
    fn test_rebase_checkpoint_path() {
        let path = rebase_checkpoint_path();
        assert!(path.contains(".agent"));
        assert!(path.contains("rebase_checkpoint.json"));
    }

    #[test]
    fn test_save_load_rebase_checkpoint() {
        use test_helpers::with_temp_cwd;

        with_temp_cwd(|_dir| {
            let checkpoint = RebaseCheckpoint::new("main".to_string())
                .with_phase(RebasePhase::ConflictDetected)
                .with_conflicted_file("file1.rs".to_string())
                .with_conflicted_file("file2.rs".to_string());

            save_rebase_checkpoint(&checkpoint).unwrap();
            assert!(rebase_checkpoint_exists());

            let loaded = load_rebase_checkpoint()
                .unwrap()
                .expect("checkpoint should exist after save");
            assert_eq!(loaded.phase, RebasePhase::ConflictDetected);
            assert_eq!(loaded.upstream_branch, "main");
            assert_eq!(loaded.conflicted_files.len(), 2);
        });
    }

    #[test]
    fn test_clear_rebase_checkpoint() {
        use test_helpers::with_temp_cwd;

        with_temp_cwd(|_dir| {
            let checkpoint = RebaseCheckpoint::new("main".to_string());
            save_rebase_checkpoint(&checkpoint).unwrap();
            assert!(rebase_checkpoint_exists());

            clear_rebase_checkpoint().unwrap();
            assert!(!rebase_checkpoint_exists());
        });
    }

    #[test]
    fn test_load_nonexistent_rebase_checkpoint() {
        use test_helpers::with_temp_cwd;

        with_temp_cwd(|_dir| {
            let result = load_rebase_checkpoint().unwrap();
            assert!(result.is_none());
            assert!(!rebase_checkpoint_exists());
        });
    }

    #[test]
    fn test_corrupted_checkpoint_without_backup_loads_as_none() {
        use test_helpers::with_temp_cwd;

        with_temp_cwd(|_dir| {
            fs::create_dir_all(AGENT_DIR).unwrap();
            fs::write(rebase_checkpoint_path(), "not json at all").unwrap();

            let result = load_rebase_checkpoint().unwrap();
            assert!(result.is_none());
        });
    }

    #[test]
    fn test_rebase_checkpoint_serialization() {
        let checkpoint = RebaseCheckpoint::new("feature-branch".to_string())
            .with_phase(RebasePhase::ConflictResolutionInProgress)
            .with_conflicted_file("src/lib.rs".to_string())
            .with_resolved_file("src/main.rs".to_string())
            .with_error("Test error".to_string());

        let json = serde_json::to_string(&checkpoint).unwrap();
        assert!(json.contains("feature-branch"));
        assert!(json.contains("src/lib.rs"));

        let deserialized: RebaseCheckpoint = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.phase, checkpoint.phase);
        assert_eq!(deserialized.upstream_branch, checkpoint.upstream_branch);
    }

    #[test]
    fn test_phase_error_count_defaults_when_absent() {
        let json = r#"{
            "phase": "ConflictDetected",
            "upstream_branch": "main",
            "conflicted_files": ["a.rs"],
            "resolved_files": [],
            "error_count": 4,
            "last_error": "boom",
            "timestamp": "2024-01-01T00:00:00Z"
        }"#;
        let checkpoint: RebaseCheckpoint = serde_json::from_str(json).unwrap();
        assert_eq!(checkpoint.phase_error_count, 0);
        assert_eq!(checkpoint.error_count, 4);
        assert_eq!(checkpoint.last_error.as_deref(), Some("boom"));
    }

    #[test]
    fn test_atomic_checkpoint_write() {
        use test_helpers::with_temp_cwd;

        with_temp_cwd(|_dir| {
            let checkpoint1 =
                RebaseCheckpoint::new("main".to_string()).with_phase(RebasePhase::RebaseInProgress);

            save_rebase_checkpoint(&checkpoint1).unwrap();

            assert!(rebase_checkpoint_exists());

            let checkpoint2 = RebaseCheckpoint::new("main".to_string())
                .with_phase(RebasePhase::RebaseComplete)
                .with_conflicted_file("test.rs".to_string());

            save_rebase_checkpoint(&checkpoint2).unwrap();

            let loaded = load_rebase_checkpoint()
                .unwrap()
                .expect("checkpoint should exist");
            assert_eq!(loaded.phase, RebasePhase::RebaseComplete);
            assert_eq!(loaded.conflicted_files.len(), 1);
        });
    }

    #[test]
    fn test_validate_checkpoint_valid() {
        let checkpoint = RebaseCheckpoint::new("main".to_string())
            .with_phase(RebasePhase::RebaseInProgress)
            .with_conflicted_file("file1.rs".to_string())
            .with_resolved_file("file1.rs".to_string());

        assert!(validate_checkpoint(&checkpoint).is_ok());
    }

    #[test]
    fn test_validate_checkpoint_empty_upstream() {
        let checkpoint = RebaseCheckpoint::new("".to_string()).with_phase(RebasePhase::NotStarted);
        assert!(validate_checkpoint(&checkpoint).is_ok());

        let checkpoint =
            RebaseCheckpoint::new("".to_string()).with_phase(RebasePhase::RebaseInProgress);
        assert!(validate_checkpoint(&checkpoint).is_err());
    }

    #[test]
    fn test_validate_checkpoint_invalid_timestamp() {
        let mut checkpoint = RebaseCheckpoint::new("main".to_string());
        checkpoint.timestamp = "invalid-timestamp".to_string();

        assert!(validate_checkpoint(&checkpoint).is_err());
    }

    #[test]
    fn test_validate_checkpoint_resolved_without_conflicted() {
        let checkpoint =
            RebaseCheckpoint::new("main".to_string()).with_resolved_file("file1.rs".to_string());

        assert!(validate_checkpoint(&checkpoint).is_err());
    }

    #[test]
    fn test_checkpoint_backup_and_restore() {
        use test_helpers::with_temp_cwd;

        with_temp_cwd(|_dir| {
            let checkpoint1 = RebaseCheckpoint::new("main".to_string())
                .with_phase(RebasePhase::ConflictDetected)
                .with_conflicted_file("file.rs".to_string());

            save_rebase_checkpoint(&checkpoint1).unwrap();

            let checkpoint_path = rebase_checkpoint_path();
            let backup_path = rebase_checkpoint_backup_path();
            assert!(Path::new(&checkpoint_path).exists());
            assert!(Path::new(&backup_path).exists());

            fs::write(&checkpoint_path, "corrupted data {{{").unwrap();

            let loaded = load_rebase_checkpoint()
                .unwrap()
                .expect("should restore from backup");

            assert_eq!(loaded.phase, RebasePhase::ConflictDetected);
            assert_eq!(loaded.conflicted_files.len(), 1);
        });
    }

    #[test]
    fn test_checkpoint_save_creates_backup() {
        use test_helpers::with_temp_cwd;

        with_temp_cwd(|_dir| {
            let checkpoint1 =
                RebaseCheckpoint::new("main".to_string()).with_phase(RebasePhase::RebaseInProgress);
            save_rebase_checkpoint(&checkpoint1).unwrap();

            let checkpoint2 =
                RebaseCheckpoint::new("main".to_string()).with_phase(RebasePhase::RebaseComplete);
            save_rebase_checkpoint(&checkpoint2).unwrap();

            let backup_path = rebase_checkpoint_backup_path();
            assert!(Path::new(&backup_path).exists());

            let backup_content = fs::read_to_string(&backup_path).unwrap();
            let backup_checkpoint: RebaseCheckpoint =
                serde_json::from_str(&backup_content).unwrap();
            assert_eq!(backup_checkpoint.phase, RebasePhase::RebaseInProgress);
        });
    }

    #[test]
    fn test_checkpoint_validation_failure_triggers_restore() {
        use test_helpers::with_temp_cwd;

        with_temp_cwd(|_dir| {
            let checkpoint1 = RebaseCheckpoint::new("main".to_string())
                .with_phase(RebasePhase::RebaseInProgress)
                .with_conflicted_file("file.rs".to_string());

            save_rebase_checkpoint(&checkpoint1).unwrap();

            let checkpoint_path = rebase_checkpoint_path();
            let corrupted_json = r#"{
                "phase": "RebaseInProgress",
                "upstream_branch": "main",
                "conflicted_files": ["file.rs"],
                "resolved_files": ["not_in_conflicted.rs"],
                "error_count": 0,
                "last_error": null,
                "timestamp": "2024-01-01T00:00:00Z"
            }"#;
            fs::write(&checkpoint_path, corrupted_json).unwrap();

            let loaded = load_rebase_checkpoint()
                .unwrap()
                .expect("should restore from backup");

            assert_eq!(loaded.conflicted_files.len(), 1);
            assert!(!loaded
                .resolved_files
                .contains(&"not_in_conflicted.rs".to_string()));
        });
    }
}