use std::collections::HashMap;
use std::process::Command;
use anyhow::{Context, Result};
use git2::{Oid, Repository};
use tracing::debug;
use crate::data::amendments::{Amendment, AmendmentFile};
use crate::git::SHORT_HASH_LEN;
pub struct AmendmentHandler {
repo: Repository,
}
impl AmendmentHandler {
pub fn new() -> Result<Self> {
let repo = Repository::open(".").context("Failed to open git repository")?;
Ok(Self { repo })
}
pub fn apply_amendments(&self, yaml_file: &str) -> Result<()> {
let amendment_file = AmendmentFile::load_from_file(yaml_file)?;
self.perform_safety_checks(&amendment_file)?;
let amendments = self.organize_amendments(&amendment_file.amendments)?;
if amendments.is_empty() {
println!("No valid amendments found to apply.");
return Ok(());
}
if amendments.len() == 1 && self.is_head_commit(&amendments[0].0)? {
println!(
"Amending HEAD commit: {}",
&amendments[0].0[..SHORT_HASH_LEN]
);
self.amend_head_commit(&amendments[0].1)?;
} else {
println!(
"Amending {} commits using interactive rebase",
amendments.len()
);
self.amend_via_rebase(amendments)?;
}
println!("✅ Amendment operations completed successfully");
Ok(())
}
fn perform_safety_checks(&self, amendment_file: &AmendmentFile) -> Result<()> {
crate::utils::preflight::check_working_directory_clean()
.context("Cannot amend commits with uncommitted changes")?;
for amendment in &amendment_file.amendments {
self.validate_commit_amendable(&amendment.commit)?;
}
Ok(())
}
fn validate_commit_amendable(&self, commit_hash: &str) -> Result<()> {
let oid = Oid::from_str(commit_hash)
.with_context(|| format!("Invalid commit hash: {commit_hash}"))?;
let _commit = self
.repo
.find_commit(oid)
.with_context(|| format!("Commit not found: {commit_hash}"))?;
Ok(())
}
fn organize_amendments(&self, amendments: &[Amendment]) -> Result<Vec<(String, String)>> {
let mut valid_amendments = Vec::new();
let mut commit_depths = HashMap::new();
for amendment in amendments {
if let Ok(depth) = self.get_commit_depth_from_head(&amendment.commit) {
commit_depths.insert(amendment.commit.clone(), depth);
valid_amendments.push((amendment.commit.clone(), amendment.message.clone()));
} else {
println!(
"Warning: Skipping invalid commit {}",
&amendment.commit[..SHORT_HASH_LEN]
);
}
}
valid_amendments.sort_by_key(|(commit, _)| commit_depths.get(commit).copied().unwrap_or(0));
valid_amendments.reverse();
Ok(valid_amendments)
}
fn get_commit_depth_from_head(&self, commit_hash: &str) -> Result<usize> {
let target_oid = Oid::from_str(commit_hash)?;
let mut revwalk = self.repo.revwalk()?;
revwalk.push_head()?;
for (depth, oid_result) in revwalk.enumerate() {
let oid = oid_result?;
if oid == target_oid {
return Ok(depth);
}
}
anyhow::bail!("Commit {commit_hash} not found in current branch history");
}
fn is_head_commit(&self, commit_hash: &str) -> Result<bool> {
let head_oid = self.repo.head()?.target().context("HEAD has no target")?;
let target_oid = Oid::from_str(commit_hash)?;
Ok(head_oid == target_oid)
}
fn amend_head_commit(&self, new_message: &str) -> Result<()> {
let head_commit = self.repo.head()?.peel_to_commit()?;
let output = Command::new("git")
.args(["commit", "--amend", "--message", new_message])
.output()
.context("Failed to execute git commit --amend")?;
if !output.status.success() {
let error_msg = String::from_utf8_lossy(&output.stderr);
anyhow::bail!("Failed to amend HEAD commit: {error_msg}");
}
let new_head = self.repo.head()?.peel_to_commit()?;
println!(
"✅ Amended HEAD commit {} -> {}",
&head_commit.id().to_string()[..SHORT_HASH_LEN],
&new_head.id().to_string()[..SHORT_HASH_LEN]
);
Ok(())
}
fn amend_via_rebase(&self, amendments: Vec<(String, String)>) -> Result<()> {
if amendments.is_empty() {
return Ok(());
}
println!("Amending commits individually in reverse order (newest to oldest)");
let mut sorted_amendments = amendments;
sorted_amendments
.sort_by_key(|(hash, _)| self.get_commit_depth_from_head(hash).unwrap_or(usize::MAX));
for (commit_hash, new_message) in sorted_amendments {
let depth = self.get_commit_depth_from_head(&commit_hash)?;
if depth == 0 {
println!("Amending HEAD commit: {}", &commit_hash[..SHORT_HASH_LEN]);
self.amend_head_commit(&new_message)?;
} else {
println!(
"Amending commit at depth {}: {}",
depth,
&commit_hash[..SHORT_HASH_LEN]
);
self.amend_single_commit_via_rebase(&commit_hash, &new_message)?;
}
}
Ok(())
}
fn amend_single_commit_via_rebase(&self, commit_hash: &str, new_message: &str) -> Result<()> {
let base_commit = format!("{commit_hash}^");
let temp_dir = tempfile::tempdir()?;
let sequence_file = temp_dir.path().join("rebase-sequence");
let mut sequence_content = String::new();
let commit_list_output = Command::new("git")
.args(["rev-list", "--reverse", &format!("{base_commit}..HEAD")])
.output()
.context("Failed to get commit list for rebase")?;
if !commit_list_output.status.success() {
anyhow::bail!("Failed to generate commit list for rebase");
}
let commit_list = String::from_utf8_lossy(&commit_list_output.stdout);
for line in commit_list.lines() {
let commit = line.trim();
if commit.is_empty() {
continue;
}
let subject_output = Command::new("git")
.args(["log", "--format=%s", "-n", "1", commit])
.output()
.context("Failed to get commit subject")?;
let subject = String::from_utf8_lossy(&subject_output.stdout)
.trim()
.to_string();
if commit.starts_with(&commit_hash[..commit.len().min(commit_hash.len())]) {
sequence_content.push_str(&format!("edit {commit} {subject}\n"));
} else {
sequence_content.push_str(&format!("pick {commit} {subject}\n"));
}
}
std::fs::write(&sequence_file, sequence_content)?;
println!(
"Starting interactive rebase to amend commit: {}",
&commit_hash[..SHORT_HASH_LEN]
);
let rebase_result = Command::new("git")
.args(["rebase", "-i", &base_commit])
.env(
"GIT_SEQUENCE_EDITOR",
format!("cp {}", sequence_file.display()),
)
.env("GIT_EDITOR", "true") .output()
.context("Failed to start interactive rebase")?;
if !rebase_result.status.success() {
let error_msg = String::from_utf8_lossy(&rebase_result.stderr);
if let Err(e) = Command::new("git").args(["rebase", "--abort"]).output() {
debug!("Rebase abort during cleanup failed: {e}");
}
anyhow::bail!("Interactive rebase failed: {error_msg}");
}
let repo_state = self.repo.state();
if repo_state == git2::RepositoryState::RebaseInteractive {
let current_commit_output = Command::new("git")
.args(["rev-parse", "HEAD"])
.output()
.context("Failed to get current commit during rebase")?;
let current_commit = String::from_utf8_lossy(¤t_commit_output.stdout)
.trim()
.to_string();
if current_commit
.starts_with(&commit_hash[..current_commit.len().min(commit_hash.len())])
{
let amend_result = Command::new("git")
.args(["commit", "--amend", "-m", new_message])
.output()
.context("Failed to amend commit during rebase")?;
if !amend_result.status.success() {
let error_msg = String::from_utf8_lossy(&amend_result.stderr);
if let Err(e) = Command::new("git").args(["rebase", "--abort"]).output() {
debug!("Rebase abort during cleanup failed: {e}");
}
anyhow::bail!("Failed to amend commit: {error_msg}");
}
println!("✅ Amended commit: {}", &commit_hash[..SHORT_HASH_LEN]);
let continue_result = Command::new("git")
.args(["rebase", "--continue"])
.output()
.context("Failed to continue rebase")?;
if !continue_result.status.success() {
let error_msg = String::from_utf8_lossy(&continue_result.stderr);
if let Err(e) = Command::new("git").args(["rebase", "--abort"]).output() {
debug!("Rebase abort during cleanup failed: {e}");
}
anyhow::bail!("Failed to continue rebase: {error_msg}");
}
println!("✅ Rebase completed successfully");
} else {
if let Err(e) = Command::new("git").args(["rebase", "--abort"]).output() {
debug!("Rebase abort during cleanup failed: {e}");
}
anyhow::bail!(
"Unexpected commit during rebase. Expected {}, got {}",
&commit_hash[..SHORT_HASH_LEN],
¤t_commit[..SHORT_HASH_LEN]
);
}
} else if repo_state != git2::RepositoryState::Clean {
anyhow::bail!("Repository in unexpected state after rebase: {repo_state:?}");
}
Ok(())
}
}