#![deny(unsafe_code)]
use crate::workspace::Workspace;
use std::fs;
use std::io;
use std::path::Path;
/// Relative directory that holds agent state files. The `std::fs`-based
/// functions resolve it against the process working directory; the
/// `_with_workspace` variants pass it to the `Workspace` instead.
const AGENT_DIR: &str = ".agent";
/// File name of the rebase checkpoint inside `AGENT_DIR`.
const REBASE_CHECKPOINT_FILE: &str = "rebase_checkpoint.json";

/// Returns the relative path of the rebase checkpoint file,
/// i.e. `.agent/rebase_checkpoint.json`.
#[must_use]
pub fn rebase_checkpoint_path() -> String {
    [AGENT_DIR, REBASE_CHECKPOINT_FILE].join("/")
}

/// Returns the relative path of the checkpoint backup file: the checkpoint
/// path with a `.bak` suffix (`.agent/rebase_checkpoint.json.bak`).
#[must_use]
pub fn rebase_checkpoint_backup_path() -> String {
    let mut path = rebase_checkpoint_path();
    path.push_str(".bak");
    path
}
/// Phase of a rebase as recorded in a `RebaseCheckpoint`.
///
/// Checkpoints start in `NotStarted` (both `RebaseCheckpoint::new` and the
/// `Default` impl use it). It is also the only phase in which checkpoint
/// validation accepts an empty upstream branch.
///
/// The enum is (de)serialized by serde using the variant names, so in JSON a
/// variant appears as e.g. `"ConflictDetected"`. Renaming a variant breaks
/// loading of existing checkpoint files.
/// NOTE(review): serde's derived deserializer also accepts variant indices,
/// so reordering variants may matter for non-JSON formats.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum RebasePhase {
NotStarted,
PreRebaseCheck,
RebaseInProgress,
ConflictDetected,
ConflictResolutionInProgress,
CompletingRebase,
RebaseComplete,
RebaseAborted,
}
impl RebasePhase {
    /// Maximum number of recovery attempts allowed while in this phase.
    ///
    /// Budgets are:
    /// - 1 for the pre-rebase check;
    /// - 2 for the rebase itself and for completion;
    /// - 5 for conflict resolution;
    /// - 3 for every other phase.
    ///
    /// Only compiled for tests or with the `test-utils` feature.
    #[cfg(any(test, feature = "test-utils"))]
    #[must_use]
    pub const fn max_recovery_attempts(&self) -> u32 {
        match *self {
            Self::PreRebaseCheck => 1,
            Self::RebaseInProgress | Self::CompletingRebase => 2,
            Self::NotStarted
            | Self::ConflictDetected
            | Self::RebaseComplete
            | Self::RebaseAborted => 3,
            Self::ConflictResolutionInProgress => 5,
        }
    }
}
/// Persisted state of a rebase. The save functions in this module write it
/// as pretty-printed JSON to `.agent/rebase_checkpoint.json`.
///
/// serde serializes fields in declaration order, so reordering fields
/// changes the JSON layout (though not its meaning).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct RebaseCheckpoint {
// Current phase of the rebase.
pub phase: RebasePhase,
// Upstream branch name given to `RebaseCheckpoint::new`; validation rejects
// an empty value unless `phase` is `RebasePhase::NotStarted`.
pub upstream_branch: String,
// Files recorded as conflicted; `with_conflicted_file` never adds duplicates.
pub conflicted_files: Vec<String>,
// Files recorded as resolved; validation requires each entry to also
// appear in `conflicted_files`.
pub resolved_files: Vec<String>,
// Total number of errors recorded via `with_error`; not reset on phase change.
pub error_count: u32,
// Message of the most recently recorded error, if any.
pub last_error: Option<String>,
// RFC 3339 timestamp. It is set on creation and refreshed by `with_phase`
// and `with_error` (UTC when produced by this module). It must parse as
// RFC 3339 for the checkpoint to pass validation.
pub timestamp: String,
// Errors recorded since the phase last changed; `with_phase` resets it to 0
// when the phase differs. `#[serde(default)]` lets JSON written without
// this field load with a value of 0.
#[serde(default)]
pub phase_error_count: u32,
}
impl Default for RebaseCheckpoint {
fn default() -> Self {
Self {
phase: RebasePhase::NotStarted,
upstream_branch: String::new(),
conflicted_files: Vec::new(),
resolved_files: Vec::new(),
error_count: 0,
last_error: None,
timestamp: chrono::Utc::now().to_rfc3339(),
phase_error_count: 0,
}
}
}
impl RebaseCheckpoint {
    /// Creates a checkpoint for a rebase onto `upstream_branch`.
    ///
    /// The checkpoint starts in `RebasePhase::NotStarted` with no files, zero
    /// error counters, no last error, and the current UTC time (RFC 3339) as
    /// its timestamp.
    #[must_use]
    pub fn new(upstream_branch: String) -> Self {
        Self {
            upstream_branch,
            phase: RebasePhase::NotStarted,
            conflicted_files: Vec::new(),
            resolved_files: Vec::new(),
            last_error: None,
            error_count: 0,
            phase_error_count: 0,
            timestamp: chrono::Utc::now().to_rfc3339(),
        }
    }

    /// Moves the checkpoint to `phase` and refreshes the timestamp.
    ///
    /// `phase_error_count` is reset to zero only when `phase` differs from
    /// the current phase; `error_count` is never reset here.
    #[must_use]
    pub fn with_phase(mut self, phase: RebasePhase) -> Self {
        if self.phase != phase {
            self.phase_error_count = 0;
        }
        self.phase = phase;
        self.timestamp = chrono::Utc::now().to_rfc3339();
        self
    }

    /// Records `file` as conflicted, ignoring it if already recorded.
    /// The timestamp is left unchanged.
    #[must_use]
    pub fn with_conflicted_file(mut self, file: String) -> Self {
        if !self.conflicted_files.contains(&file) {
            self.conflicted_files.push(file);
        }
        self
    }

    /// Records `file` as resolved, ignoring it if already recorded.
    ///
    /// The file is not required to be in `conflicted_files`; a checkpoint
    /// built that way will fail checkpoint validation. The timestamp is left
    /// unchanged.
    #[must_use]
    pub fn with_resolved_file(mut self, file: String) -> Self {
        if !self.resolved_files.contains(&file) {
            self.resolved_files.push(file);
        }
        self
    }

    /// Records an error and refreshes the timestamp.
    ///
    /// Both the total and the per-phase error counters are incremented,
    /// saturating at `u32::MAX`. `error` becomes `last_error`.
    #[must_use]
    pub fn with_error(mut self, error: String) -> Self {
        self.error_count = self.error_count.saturating_add(1);
        self.phase_error_count = self.phase_error_count.saturating_add(1);
        self.last_error = Some(error);
        self.timestamp = chrono::Utc::now().to_rfc3339();
        self
    }

    /// True when every conflicted file also appears in `resolved_files`.
    /// Vacuously true when no conflicts were recorded.
    #[must_use]
    pub fn all_conflicts_resolved(&self) -> bool {
        self.unresolved_conflict_count() == 0
    }

    /// Number of conflicted files that are not (yet) in `resolved_files`.
    #[must_use]
    pub fn unresolved_conflict_count(&self) -> usize {
        let resolved = &self.resolved_files;
        self.conflicted_files
            .iter()
            .filter(|&file| !resolved.contains(file))
            .count()
    }
}
/// Serializes `checkpoint` as pretty JSON and stores it at
/// `.agent/rebase_checkpoint.json`, creating `.agent` if needed.
///
/// Steps:
/// 1. Copy any existing checkpoint to the backup path (best effort; the
///    result is ignored).
/// 2. Write the JSON to `<checkpoint>.tmp`.
/// 3. Rename the temp file over the checkpoint file.
/// 4. If no checkpoint existed before this call, copy the freshly written
///    one to the backup path (again best effort).
///
/// # Errors
/// Fails if serialization fails, the directory cannot be created, or the
/// temp write or rename fails. On a write/rename failure the temp file is
/// removed before the error is returned.
pub fn save_rebase_checkpoint(checkpoint: &RebaseCheckpoint) -> io::Result<()> {
    let serialized = serde_json::to_string_pretty(checkpoint).map_err(|err| {
        io::Error::new(
            io::ErrorKind::InvalidData,
            format!("Failed to serialize rebase checkpoint: {err}"),
        )
    })?;
    fs::create_dir_all(AGENT_DIR)?;

    let target = rebase_checkpoint_path();
    let had_previous = Path::new(&target).exists();
    // Best effort: keep the previous checkpoint as the backup.
    let _ = backup_checkpoint();

    let staging = format!("{target}.tmp");
    let committed =
        fs::write(&staging, &serialized).and_then(|()| fs::rename(&staging, &target));
    if let Err(err) = committed {
        let _ = fs::remove_file(&staging);
        return Err(err);
    }

    if !had_previous {
        // First save: seed the backup with the checkpoint just written.
        let _ = backup_checkpoint();
    }
    Ok(())
}
/// Loads the rebase checkpoint from `.agent/rebase_checkpoint.json`.
///
/// Returns `Ok(None)` when no checkpoint file exists. Backup recovery is
/// attempted when the file cannot be parsed as JSON, or parses but fails
/// validation. In that case:
/// - if the restore succeeds, its result is returned (the restore yields
///   `Ok(None)` when there is no backup file);
/// - if the restore fails, an `InvalidData` error describing both failures
///   is returned.
///
/// The file is read as raw bytes and parsed with `serde_json::from_slice`.
/// Content that is not valid UTF-8 therefore counts as corruption and goes
/// through backup recovery, instead of surfacing as a read error that
/// bypasses recovery.
///
/// # Errors
/// I/O errors while reading the checkpoint file are returned unchanged.
/// Corruption combined with a failed restore is reported as `InvalidData`.
pub fn load_rebase_checkpoint() -> io::Result<Option<RebaseCheckpoint>> {
    let checkpoint = rebase_checkpoint_path();
    let path = Path::new(&checkpoint);
    if !path.exists() {
        return Ok(None);
    }
    let bytes = fs::read(path)?;
    let loaded_checkpoint: RebaseCheckpoint = match serde_json::from_slice(&bytes) {
        Ok(cp) => cp,
        Err(e) => {
            return restore_from_backup().map_err(|err| {
                io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("Checkpoint corrupted: {e}; backup restore failed: {err}"),
                )
            });
        }
    };
    if let Err(e) = validate_checkpoint(&loaded_checkpoint) {
        return restore_from_backup().map_err(|err| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                format!("Checkpoint validation failed: {e}; backup restore failed: {err}"),
            )
        });
    }
    Ok(Some(loaded_checkpoint))
}
pub fn clear_rebase_checkpoint() -> io::Result<()> {
let checkpoint = rebase_checkpoint_path();
let path = Path::new(&checkpoint);
if path.exists() {
fs::remove_file(path)?;
}
Ok(())
}
#[must_use]
pub fn rebase_checkpoint_exists() -> bool {
Path::new(&rebase_checkpoint_path()).exists()
}
/// Checks that a checkpoint is internally consistent (see
/// `validate_checkpoint_impl` for the rules). Public only in test builds or
/// with the `test-utils` feature.
#[cfg(any(test, feature = "test-utils"))]
pub fn validate_checkpoint(checkpoint: &RebaseCheckpoint) -> io::Result<()> {
    validate_checkpoint_impl(checkpoint)
}

/// Private variant of `validate_checkpoint` used in normal builds.
#[cfg(not(any(test, feature = "test-utils")))]
fn validate_checkpoint(checkpoint: &RebaseCheckpoint) -> io::Result<()> {
    validate_checkpoint_impl(checkpoint)
}

/// Validation rules, checked in this order; the first violation wins:
/// 1. `upstream_branch` must be non-empty unless the phase is `NotStarted`.
/// 2. `timestamp` must parse as RFC 3339.
/// 3. Every entry in `resolved_files` must also be in `conflicted_files`.
///
/// # Errors
/// Returns `InvalidData` describing the first violated rule.
fn validate_checkpoint_impl(checkpoint: &RebaseCheckpoint) -> io::Result<()> {
    let needs_upstream = checkpoint.phase != RebasePhase::NotStarted;
    if needs_upstream && checkpoint.upstream_branch.is_empty() {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "Checkpoint has empty upstream branch",
        ));
    }

    let timestamp_ok = chrono::DateTime::parse_from_rfc3339(&checkpoint.timestamp).is_ok();
    if !timestamp_ok {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "Checkpoint has invalid timestamp format",
        ));
    }

    let orphan = checkpoint
        .resolved_files
        .iter()
        .find(|resolved| !checkpoint.conflicted_files.contains(resolved));
    match orphan {
        Some(resolved) => Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!("Resolved file '{resolved}' not found in conflicted files list"),
        )),
        None => Ok(()),
    }
}
/// Copies the current checkpoint file to the backup path. Does nothing when
/// no checkpoint exists.
///
/// The copy is first written to `<backup>.tmp` and then renamed over the
/// backup. The previous backup is therefore replaced only once a complete new
/// copy exists. The old implementation deleted the backup before copying,
/// which left no backup at all if the copy then failed. A failed copy or
/// rename removes the temp file.
///
/// # Errors
/// Returns any error from `fs::copy` or `fs::rename`.
fn backup_checkpoint() -> io::Result<()> {
    let checkpoint_path = rebase_checkpoint_path();
    let backup_path = rebase_checkpoint_backup_path();
    let checkpoint = Path::new(&checkpoint_path);
    if !checkpoint.exists() {
        return Ok(());
    }
    let temp_path = format!("{backup_path}.tmp");
    let staged = fs::copy(checkpoint, &temp_path)
        .and_then(|_bytes_copied| fs::rename(&temp_path, &backup_path));
    if let Err(e) = staged {
        let _ = fs::remove_file(&temp_path);
        return Err(e);
    }
    Ok(())
}
/// Restores the checkpoint from its backup file.
///
/// Returns `Ok(None)` when no backup exists. Otherwise the backup is parsed
/// and validated, then copied over the checkpoint file, and the parsed
/// checkpoint is returned.
///
/// # Errors
/// Fails if the backup cannot be read, is not valid checkpoint JSON
/// (`InvalidData`), fails validation, or cannot be copied back.
fn restore_from_backup() -> io::Result<Option<RebaseCheckpoint>> {
    let source = rebase_checkpoint_backup_path();
    if !Path::new(&source).exists() {
        return Ok(None);
    }
    let raw = fs::read_to_string(&source)?;
    let restored = serde_json::from_str::<RebaseCheckpoint>(&raw).map_err(|err| {
        io::Error::new(
            io::ErrorKind::InvalidData,
            format!("Failed to parse backup checkpoint: {err}"),
        )
    })?;
    validate_checkpoint(&restored)?;
    fs::copy(&source, rebase_checkpoint_path())?;
    Ok(Some(restored))
}
/// Workspace-backed save of `checkpoint` as pretty JSON to
/// `.agent/rebase_checkpoint.json`.
///
/// Steps:
/// 1. Create `.agent` in the workspace.
/// 2. If a checkpoint already exists, copy it to the backup (best effort).
/// 3. Write the new JSON with `write_atomic`.
/// 4. If no checkpoint existed before, copy the new one to the backup
///    (best effort).
/// 5. If the backup file exists and reads back as blank (only whitespace),
///    remove it (best effort).
///
/// # Errors
/// Fails if serialization, directory creation, or the atomic write fails.
pub fn save_rebase_checkpoint_with_workspace(
    checkpoint: &RebaseCheckpoint,
    workspace: &dyn Workspace,
) -> io::Result<()> {
    let serialized = serde_json::to_string_pretty(checkpoint).map_err(|err| {
        io::Error::new(
            io::ErrorKind::InvalidData,
            format!("Failed to serialize rebase checkpoint: {err}"),
        )
    })?;
    let dir = Path::new(AGENT_DIR);
    let target = dir.join(REBASE_CHECKPOINT_FILE);
    let backup = dir.join(format!("{REBASE_CHECKPOINT_FILE}.bak"));

    workspace.create_dir_all(dir)?;
    let had_previous = workspace.exists(&target);
    if had_previous {
        let _ = backup_checkpoint_with_workspace(workspace);
    }
    workspace.write_atomic(&target, &serialized)?;
    if !had_previous {
        let _ = backup_checkpoint_with_workspace(workspace);
    }

    // A blank backup is useless for recovery; drop it.
    let backup_is_blank = workspace.exists(&backup)
        && matches!(workspace.read(&backup), Ok(text) if text.trim().is_empty());
    if backup_is_blank {
        let _ = workspace.remove(&backup);
    }
    Ok(())
}
/// Workspace-backed load of `.agent/rebase_checkpoint.json`.
///
/// Returns `Ok(None)` when the file is absent. Backup recovery is attempted
/// when the content is not valid checkpoint JSON or fails validation. In
/// that case:
/// - if the restore succeeds, its result is returned (`Ok(None)` when there
///   is no backup);
/// - if the restore fails, an `InvalidData` error describing both failures
///   is returned.
///
/// # Errors
/// Read errors are returned unchanged. Corruption combined with a failed
/// restore is reported as `InvalidData`.
pub fn load_rebase_checkpoint_with_workspace(
    workspace: &dyn Workspace,
) -> io::Result<Option<RebaseCheckpoint>> {
    let target = Path::new(AGENT_DIR).join(REBASE_CHECKPOINT_FILE);
    if !workspace.exists(&target) {
        return Ok(None);
    }
    let raw = workspace.read(&target)?;
    let parsed: RebaseCheckpoint = match serde_json::from_str(&raw) {
        Ok(value) => value,
        Err(parse_err) => {
            return restore_from_backup_with_workspace(workspace).map_err(|restore_err| {
                io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("Checkpoint corrupted: {parse_err}; backup restore failed: {restore_err}",),
                )
            });
        }
    };
    match validate_checkpoint_impl(&parsed) {
        Ok(()) => Ok(Some(parsed)),
        Err(invalid) => restore_from_backup_with_workspace(workspace).map_err(|restore_err| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                format!("Checkpoint validation failed: {invalid}; backup restore failed: {restore_err}",),
            )
        }),
    }
}
/// Removes `.agent/rebase_checkpoint.json` from the workspace if present.
/// A missing file is not an error; the backup file is left untouched.
///
/// # Errors
/// Returns any error from `Workspace::remove`.
pub fn clear_rebase_checkpoint_with_workspace(workspace: &dyn Workspace) -> io::Result<()> {
    let target = Path::new(AGENT_DIR).join(REBASE_CHECKPOINT_FILE);
    if !workspace.exists(&target) {
        return Ok(());
    }
    workspace.remove(&target)?;
    Ok(())
}
/// Reports whether `.agent/rebase_checkpoint.json` exists in the workspace.
/// The file is neither parsed nor validated.
pub fn rebase_checkpoint_exists_with_workspace(workspace: &dyn Workspace) -> bool {
    workspace.exists(&Path::new(AGENT_DIR).join(REBASE_CHECKPOINT_FILE))
}
/// Copies the workspace checkpoint file to its `.bak` path. Does nothing when
/// no checkpoint exists.
///
/// The checkpoint is read before the backup is touched, and the backup is
/// replaced with `write_atomic`. `write_atomic` is also used to overwrite an
/// existing checkpoint in `save_rebase_checkpoint_with_workspace`. As a
/// result, a failed read or write no longer leaves the workspace without a
/// backup. The old code removed the backup first, so a failure afterwards
/// destroyed it.
///
/// # Errors
/// Returns any error from reading the checkpoint or writing the backup.
fn backup_checkpoint_with_workspace(workspace: &dyn Workspace) -> io::Result<()> {
    let checkpoint_path = Path::new(AGENT_DIR).join(REBASE_CHECKPOINT_FILE);
    let backup_path = Path::new(AGENT_DIR).join(format!("{REBASE_CHECKPOINT_FILE}.bak"));
    if !workspace.exists(&checkpoint_path) {
        return Ok(());
    }
    let content = workspace.read(&checkpoint_path)?;
    workspace.write_atomic(&backup_path, &content)?;
    Ok(())
}
/// Restores the workspace checkpoint from its `.bak` file.
///
/// Returns `Ok(None)` when there is no backup. Otherwise the backup is parsed
/// and validated, its exact text is written back to the checkpoint path, and
/// the parsed checkpoint is returned.
///
/// # Errors
/// Fails if the backup cannot be read, is not valid checkpoint JSON
/// (`InvalidData`), fails validation, or cannot be written back.
fn restore_from_backup_with_workspace(
    workspace: &dyn Workspace,
) -> io::Result<Option<RebaseCheckpoint>> {
    let dir = Path::new(AGENT_DIR);
    let backup = dir.join(format!("{REBASE_CHECKPOINT_FILE}.bak"));
    if !workspace.exists(&backup) {
        return Ok(None);
    }
    let raw = workspace.read(&backup)?;
    let restored = serde_json::from_str::<RebaseCheckpoint>(&raw).map_err(|err| {
        io::Error::new(
            io::ErrorKind::InvalidData,
            format!("Failed to parse backup checkpoint: {err}"),
        )
    })?;
    validate_checkpoint_impl(&restored)?;
    workspace.write(&dir.join(REBASE_CHECKPOINT_FILE), &raw)?;
    Ok(Some(restored))
}
// Unit tests for the in-memory builder and validation logic of
// RebaseCheckpoint; none of these tests touch the filesystem.
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_rebase_checkpoint_default() {
let checkpoint = RebaseCheckpoint::default();
assert_eq!(checkpoint.phase, RebasePhase::NotStarted);
assert!(checkpoint.upstream_branch.is_empty());
assert!(checkpoint.conflicted_files.is_empty());
assert!(checkpoint.resolved_files.is_empty());
assert_eq!(checkpoint.error_count, 0);
assert!(checkpoint.last_error.is_none());
}
#[test]
fn test_rebase_checkpoint_new() {
let checkpoint = RebaseCheckpoint::new("main".to_string());
assert_eq!(checkpoint.phase, RebasePhase::NotStarted);
assert_eq!(checkpoint.upstream_branch, "main");
}
#[test]
fn test_rebase_checkpoint_with_phase() {
let checkpoint =
RebaseCheckpoint::new("main".to_string()).with_phase(RebasePhase::RebaseInProgress);
assert_eq!(checkpoint.phase, RebasePhase::RebaseInProgress);
}
// Adding the same conflicted file twice must not create a duplicate entry.
#[test]
fn test_rebase_checkpoint_with_conflicted_file() {
let checkpoint = RebaseCheckpoint::new("main".to_string())
.with_conflicted_file("file1.txt".to_string())
.with_conflicted_file("file2.txt".to_string());
assert_eq!(
checkpoint.conflicted_files.len(),
2,
"Should track both files"
);
assert!(
checkpoint
.conflicted_files
.contains(&"file1.txt".to_string()),
"Should contain file1.txt"
);
assert!(
checkpoint
.conflicted_files
.contains(&"file2.txt".to_string()),
"Should contain file2.txt"
);
let checkpoint = checkpoint.with_conflicted_file("file1.txt".to_string());
assert_eq!(
checkpoint.conflicted_files.len(),
2,
"Should not increase count for duplicate"
);
assert!(
checkpoint
.conflicted_files
.contains(&"file1.txt".to_string()),
"Should still contain file1.txt"
);
}
#[test]
fn test_rebase_checkpoint_with_resolved_file() {
let checkpoint = RebaseCheckpoint::new("main".to_string())
.with_conflicted_file("file1.txt".to_string())
.with_resolved_file("file1.txt".to_string());
assert!(checkpoint.resolved_files.contains(&"file1.txt".to_string()));
}
#[test]
fn test_rebase_checkpoint_with_error() {
let checkpoint =
RebaseCheckpoint::new("main".to_string()).with_error("Test error".to_string());
assert_eq!(checkpoint.error_count, 1);
assert_eq!(checkpoint.last_error, Some("Test error".to_string()));
}
#[test]
fn test_rebase_checkpoint_all_conflicts_resolved() {
let checkpoint = RebaseCheckpoint::new("main".to_string())
.with_conflicted_file("file1.txt".to_string())
.with_conflicted_file("file2.txt".to_string())
.with_resolved_file("file1.txt".to_string())
.with_resolved_file("file2.txt".to_string());
assert!(checkpoint.all_conflicts_resolved());
}
#[test]
fn test_rebase_checkpoint_unresolved_conflict_count() {
let checkpoint = RebaseCheckpoint::new("main".to_string())
.with_conflicted_file("file1.txt".to_string())
.with_conflicted_file("file2.txt".to_string())
.with_resolved_file("file1.txt".to_string());
assert_eq!(checkpoint.unresolved_conflict_count(), 1);
}
#[test]
fn test_rebase_phase_equality() {
assert_eq!(RebasePhase::NotStarted, RebasePhase::NotStarted);
assert_ne!(RebasePhase::NotStarted, RebasePhase::RebaseInProgress);
}
#[test]
fn test_rebase_checkpoint_path() {
let path = rebase_checkpoint_path();
assert!(path.contains(".agent"));
assert!(path.contains("rebase_checkpoint.json"));
}
// Serde round-trip only. The resolved file "src/main.rs" is deliberately
// not in the conflicted list, so this checkpoint would FAIL
// validate_checkpoint; validation is not exercised here.
#[test]
fn test_rebase_checkpoint_serialization() {
let checkpoint = RebaseCheckpoint::new("feature-branch".to_string())
.with_phase(RebasePhase::ConflictResolutionInProgress)
.with_conflicted_file("src/lib.rs".to_string())
.with_resolved_file("src/main.rs".to_string())
.with_error("Test error".to_string());
let json = serde_json::to_string(&checkpoint).unwrap();
assert!(json.contains("feature-branch"));
assert!(json.contains("src/lib.rs"));
let deserialized: RebaseCheckpoint = serde_json::from_str(&json).unwrap();
assert_eq!(deserialized.phase, checkpoint.phase);
assert_eq!(deserialized.upstream_branch, checkpoint.upstream_branch);
}
#[test]
fn test_validate_checkpoint_valid() {
let checkpoint = RebaseCheckpoint::new("main".to_string())
.with_phase(RebasePhase::RebaseInProgress)
.with_conflicted_file("file1.rs".to_string())
.with_resolved_file("file1.rs".to_string());
assert!(validate_checkpoint(&checkpoint).is_ok());
}
// An empty upstream branch is accepted only in the NotStarted phase.
#[test]
fn test_validate_checkpoint_empty_upstream() {
let checkpoint = RebaseCheckpoint::new(String::new()).with_phase(RebasePhase::NotStarted);
assert!(validate_checkpoint(&checkpoint).is_ok());
let checkpoint =
RebaseCheckpoint::new(String::new()).with_phase(RebasePhase::RebaseInProgress);
assert!(validate_checkpoint(&checkpoint).is_err());
}
#[test]
fn test_validate_checkpoint_invalid_timestamp() {
let mut checkpoint = RebaseCheckpoint::new("main".to_string());
checkpoint.timestamp = "invalid-timestamp".to_string();
assert!(validate_checkpoint(&checkpoint).is_err());
}
// with_resolved_file does not require a prior conflict entry, so the builder
// can produce a checkpoint that validation rejects.
#[test]
fn test_validate_checkpoint_resolved_without_conflicted() {
let checkpoint =
RebaseCheckpoint::new("main".to_string()).with_resolved_file("file1.rs".to_string());
assert!(validate_checkpoint(&checkpoint).is_err());
}
}
// Tests for the Workspace-backed persistence functions, run against an
// in-memory workspace. Compiled only when both cfg(test) and the
// "test-utils" feature are enabled.
#[cfg(all(test, feature = "test-utils"))]
mod workspace_tests {
use super::*;
use crate::workspace::MemoryWorkspace;
#[test]
fn test_save_and_load_checkpoint_with_workspace() {
let workspace = MemoryWorkspace::new_test();
let checkpoint = RebaseCheckpoint::new("main".to_string())
.with_phase(RebasePhase::ConflictDetected)
.with_conflicted_file("file1.rs".to_string());
save_rebase_checkpoint_with_workspace(&checkpoint, &workspace).unwrap();
assert!(rebase_checkpoint_exists_with_workspace(&workspace));
let loaded = load_rebase_checkpoint_with_workspace(&workspace)
.unwrap()
.expect("checkpoint should exist after save");
assert_eq!(loaded.phase, RebasePhase::ConflictDetected);
assert_eq!(loaded.upstream_branch, "main");
assert_eq!(
loaded.conflicted_files.len(),
1,
"Should have one conflicted file"
);
assert!(
loaded.conflicted_files.contains(&"file1.rs".to_string()),
"Should contain file1.rs"
);
}
#[test]
fn test_clear_checkpoint_with_workspace() {
let workspace = MemoryWorkspace::new_test();
let checkpoint = RebaseCheckpoint::new("main".to_string());
save_rebase_checkpoint_with_workspace(&checkpoint, &workspace).unwrap();
assert!(rebase_checkpoint_exists_with_workspace(&workspace));
clear_rebase_checkpoint_with_workspace(&workspace).unwrap();
assert!(!rebase_checkpoint_exists_with_workspace(&workspace));
}
#[test]
fn test_load_nonexistent_checkpoint_with_workspace() {
let workspace = MemoryWorkspace::new_test();
let result = load_rebase_checkpoint_with_workspace(&workspace).unwrap();
assert!(result.is_none());
}
// NOTE(review): despite the name, this only asserts that the most recent
// save is what gets loaded; the backup file's contents are not inspected.
#[test]
fn test_checkpoint_backup_with_workspace() {
let workspace = MemoryWorkspace::new_test();
let checkpoint1 =
RebaseCheckpoint::new("main".to_string()).with_phase(RebasePhase::RebaseInProgress);
save_rebase_checkpoint_with_workspace(&checkpoint1, &workspace).unwrap();
let checkpoint2 =
RebaseCheckpoint::new("main".to_string()).with_phase(RebasePhase::RebaseComplete);
save_rebase_checkpoint_with_workspace(&checkpoint2, &workspace).unwrap();
let loaded = load_rebase_checkpoint_with_workspace(&workspace)
.unwrap()
.expect("checkpoint should exist");
assert_eq!(loaded.phase, RebasePhase::RebaseComplete);
}
// The first save into an empty workspace also writes a backup, because the
// checkpoint did not exist beforehand. Corrupting the main file must then
// make the load restore from that backup.
#[test]
fn test_corrupted_checkpoint_restores_from_backup_with_workspace() {
let workspace = MemoryWorkspace::new_test();
let checkpoint_path = Path::new(AGENT_DIR).join(REBASE_CHECKPOINT_FILE);
let backup_path = Path::new(AGENT_DIR).join(format!("{REBASE_CHECKPOINT_FILE}.bak"));
let checkpoint = RebaseCheckpoint::new("main".to_string())
.with_phase(RebasePhase::ConflictDetected)
.with_conflicted_file("file.rs".to_string());
save_rebase_checkpoint_with_workspace(&checkpoint, &workspace).unwrap();
assert!(workspace.exists(&backup_path));
workspace
.write(&checkpoint_path, "corrupted data {{{")
.unwrap();
let loaded = load_rebase_checkpoint_with_workspace(&workspace)
.unwrap()
.expect("should restore from backup");
assert_eq!(loaded.phase, RebasePhase::ConflictDetected);
}
}