use std::path::{Path, PathBuf};
use ainl_contracts::FeatureSnapshot;
use chrono::Utc;
use thiserror::Error;
/// Captured result of running an external program through a [`ShellRunner`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShellOutput {
    /// Exit status of the program; the git helpers in this module treat `0` as success.
    pub exit_code: i32,
    /// Text the program wrote to standard output, as captured by the runner.
    pub stdout: String,
    /// Text the program wrote to standard error, as captured by the runner.
    pub stderr: String,
}
/// Abstraction over process execution, so the git helpers in this module can be
/// driven either by a real process runner or by canned responses.
pub trait ShellRunner {
    /// Runs `program` with `args`, using `cwd` as the working directory.
    ///
    /// The helpers in this module inspect [`ShellOutput::exit_code`] themselves.
    /// Implementations are therefore presumably expected to return `Ok` for a program
    /// that ran but exited non-zero, and to reserve `Err` for failures of the runner
    /// itself.
    fn run(
        &self,
        cwd: &Path,
        program: &str,
        args: &[&str],
    ) -> Result<ShellOutput, GitSnapshotError>;
}
/// Errors produced by the git snapshot helpers in this module.
#[derive(Debug, Error)]
pub enum GitSnapshotError {
    /// Failure reported by a [`ShellRunner`] implementation itself (the command's own
    /// exit status is carried separately, in `ShellOutput`).
    #[error("shell: {0}")]
    Shell(String),

    /// `git rev-parse --show-toplevel` exited non-zero or printed nothing for this path.
    #[error("not a git repository at {0}")]
    NotARepo(PathBuf),

    /// A git command ran but exited with a non-zero status; `stderr` is its raw output.
    #[error("git command failed (exit {exit_code}): {stderr}")]
    CommandFailed { exit_code: i32, stderr: String },

    /// `git stash create` exited successfully but printed no commit id.
    #[error("missing stash sha in output")]
    MissingStashSha,
}
/// Finds the root directory of the git working tree that contains `path`.
///
/// Runs `git rev-parse --show-toplevel` with `path` as the working directory and
/// returns its whitespace-trimmed stdout as a [`PathBuf`].
///
/// # Errors
///
/// * Any error returned by `shell` itself is propagated unchanged.
/// * [`GitSnapshotError::NotARepo`] (carrying `path`) when git exits non-zero or
///   prints only whitespace.
pub fn resolve_repo_toplevel(
    shell: &dyn ShellRunner,
    path: &Path,
) -> Result<PathBuf, GitSnapshotError> {
    let output = shell.run(path, "git", &["rev-parse", "--show-toplevel"])?;
    let toplevel = output.stdout.trim();
    if output.exit_code == 0 && !toplevel.is_empty() {
        Ok(PathBuf::from(toplevel))
    } else {
        Err(GitSnapshotError::NotARepo(path.to_path_buf()))
    }
}
/// Records the current uncommitted state of the repository that contains `path`.
///
/// The repository root is resolved with [`resolve_repo_toplevel`]. Then
/// `git stash create` and `git rev-parse HEAD` are run from that root. The returned
/// snapshot holds:
///
/// * the repository root;
/// * the stash commit id;
/// * the `HEAD` commit id;
/// * the capture time (`Utc::now()`).
///
/// NOTE: per the git documentation, `git stash create` builds a stash commit without
/// storing it in the ref namespace, and prints nothing when there are no local
/// changes. That empty output is reported as [`GitSnapshotError::MissingStashSha`].
///
/// # Errors
///
/// * Errors from [`resolve_repo_toplevel`] (e.g. `NotARepo`) or from `shell`.
/// * `CommandFailed` when `git stash create` or `git rev-parse HEAD` exits non-zero.
/// * `MissingStashSha` when `git stash create` prints only whitespace.
pub fn create_snapshot(
    shell: &dyn ShellRunner,
    path: &Path,
) -> Result<FeatureSnapshot, GitSnapshotError> {
    /// Runs `git <args>` in `cwd` and returns its trimmed stdout when the exit code is 0.
    fn git_stdout(
        shell: &dyn ShellRunner,
        cwd: &Path,
        args: &[&str],
    ) -> Result<String, GitSnapshotError> {
        let output = shell.run(cwd, "git", args)?;
        match output.exit_code {
            0 => Ok(output.stdout.trim().to_string()),
            exit_code => Err(GitSnapshotError::CommandFailed {
                exit_code,
                stderr: output.stderr,
            }),
        }
    }

    let repo_toplevel = resolve_repo_toplevel(shell, path)?;

    let stash_sha = git_stdout(shell, &repo_toplevel, &["stash", "create"])?;
    if stash_sha.is_empty() {
        return Err(GitSnapshotError::MissingStashSha);
    }

    let head_sha = git_stdout(shell, &repo_toplevel, &["rev-parse", "HEAD"])?;

    Ok(FeatureSnapshot {
        repo_toplevel,
        stash_sha,
        head_sha,
        taken_at: Utc::now(),
    })
}
/// Outcome of re-applying a snapshot's stash commit with `git stash apply`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SnapshotApplyResult {
    /// `true` when `git stash apply` exited with status 0.
    pub applied: bool,
    /// Conflict or diagnostic lines taken from git's output when the apply failed;
    /// empty when `applied` is `true`.
    pub conflicts: Vec<String>,
}
/// Extracts human-readable conflict descriptions from the output of a failed
/// `git stash apply`.
///
/// Lines from `stdout` are examined first, then lines from `stderr`. Every line whose
/// trimmed text mentions "conflict" in any letter case is returned trimmed, in order.
/// This covers git messages such as `CONFLICT (content): ...` and
/// `Merge conflict in ...`.
///
/// If no such line exists, the trimmed non-empty streams are returned as a single
/// entry, joined by a newline, so a failure still carries git's diagnostic text.
/// The result is empty only when both streams are blank.
fn parse_apply_conflicts(stdout: &str, stderr: &str) -> Vec<String> {
    let conflicts: Vec<String> = stdout
        .lines()
        .chain(stderr.lines())
        .map(str::trim)
        // Case-insensitive match; this also subsumes "CONFLICT" and "Merge conflict".
        .filter(|line| line.to_ascii_lowercase().contains("conflict"))
        .map(String::from)
        .collect();
    if !conflicts.is_empty() {
        return conflicts;
    }

    // Join with an explicit separator. Formatting `{stdout}{stderr}` directly would
    // fuse the last stdout line with the first stderr line whenever stdout lacks a
    // trailing newline.
    let combined = [stdout.trim(), stderr.trim()]
        .iter()
        .copied()
        .filter(|part| !part.is_empty())
        .collect::<Vec<&str>>()
        .join("\n");
    if combined.is_empty() {
        Vec::new()
    } else {
        vec![combined]
    }
}
/// Re-applies the stash commit recorded in `snapshot` to the working tree at
/// `snapshot.repo_toplevel`, by running `git stash apply <stash_sha>`.
///
/// A non-zero exit from git is not treated as an error. It yields `applied: false`,
/// with the conflict/diagnostic lines extracted from git's stdout and stderr, so
/// callers can surface them. A zero exit yields `applied: true` and no conflicts.
///
/// # Errors
///
/// Only errors returned by `shell` itself are propagated.
pub fn apply_snapshot(
    shell: &dyn ShellRunner,
    snapshot: &FeatureSnapshot,
) -> Result<SnapshotApplyResult, GitSnapshotError> {
    let args = ["stash", "apply", snapshot.stash_sha.as_str()];
    let output = shell.run(&snapshot.repo_toplevel, "git", &args)?;

    let applied = output.exit_code == 0;
    let conflicts = if applied {
        Vec::new()
    } else {
        parse_apply_conflicts(&output.stdout, &output.stderr)
    };

    Ok(SnapshotApplyResult { applied, conflicts })
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// `ShellRunner` that answers from a fixed table keyed by the full command line
    /// (`"<program> <args joined by spaces>"`); unregistered commands yield a `Shell` error.
    ///
    /// The table is only written while building the mock and only read via `&self`
    /// afterwards, so no interior mutability (the former `Mutex`) is needed.
    struct MockShell {
        responses: HashMap<String, ShellOutput>,
    }

    impl MockShell {
        fn new() -> Self {
            Self {
                responses: HashMap::new(),
            }
        }

        /// Registers the canned output returned for the command line `key`.
        fn when(mut self, key: &str, out: ShellOutput) -> Self {
            self.responses.insert(key.into(), out);
            self
        }
    }

    impl ShellRunner for MockShell {
        fn run(
            &self,
            _cwd: &Path,
            program: &str,
            args: &[&str],
        ) -> Result<ShellOutput, GitSnapshotError> {
            let key = format!("{program} {}", args.join(" "));
            self.responses
                .get(&key)
                .cloned()
                .ok_or_else(|| GitSnapshotError::Shell(format!("no mock for {key}")))
        }
    }

    /// Shorthand for building a `ShellOutput`.
    fn output(exit_code: i32, stdout: &str, stderr: &str) -> ShellOutput {
        ShellOutput {
            exit_code,
            stdout: stdout.into(),
            stderr: stderr.into(),
        }
    }

    /// Snapshot whose stash sha matches the `git stash apply deadbeef` mocks below.
    fn sample_snapshot() -> FeatureSnapshot {
        FeatureSnapshot {
            repo_toplevel: PathBuf::from("/repo"),
            stash_sha: "deadbeef".into(),
            head_sha: "abc".into(),
            taken_at: chrono::Utc::now(),
        }
    }

    #[test]
    fn resolve_toplevel() {
        let shell = MockShell::new().when(
            "git rev-parse --show-toplevel",
            output(0, "/repo\n", ""),
        );
        let top = resolve_repo_toplevel(&shell, Path::new("/repo/src")).unwrap();
        assert_eq!(top, PathBuf::from("/repo"));
    }

    #[test]
    fn resolve_toplevel_rejects_non_zero_exit() {
        let shell = MockShell::new().when(
            "git rev-parse --show-toplevel",
            output(128, "", "fatal: not a git repository\n"),
        );
        match resolve_repo_toplevel(&shell, Path::new("/tmp/x")) {
            Err(GitSnapshotError::NotARepo(p)) => assert_eq!(p, PathBuf::from("/tmp/x")),
            other => panic!("expected NotARepo, got {other:?}"),
        }
    }

    #[test]
    fn resolve_toplevel_rejects_empty_output() {
        let shell = MockShell::new().when("git rev-parse --show-toplevel", output(0, "  \n", ""));
        assert!(matches!(
            resolve_repo_toplevel(&shell, Path::new("/repo")),
            Err(GitSnapshotError::NotARepo(_))
        ));
    }

    #[test]
    fn create_snapshot_roundtrip_fields() {
        let shell = MockShell::new()
            .when("git rev-parse --show-toplevel", output(0, "/repo\n", ""))
            .when("git stash create", output(0, "deadbeef\n", ""))
            .when("git rev-parse HEAD", output(0, "cafebabe\n", ""));
        let snap = create_snapshot(&shell, Path::new("/repo")).unwrap();
        assert_eq!(snap.repo_toplevel, PathBuf::from("/repo"));
        assert_eq!(snap.stash_sha, "deadbeef");
        assert_eq!(snap.head_sha, "cafebabe");
    }

    #[test]
    fn create_snapshot_clean_tree_is_missing_stash_sha() {
        // `git stash create` prints nothing when there is nothing to stash.
        let shell = MockShell::new()
            .when("git rev-parse --show-toplevel", output(0, "/repo\n", ""))
            .when("git stash create", output(0, "\n", ""));
        assert!(matches!(
            create_snapshot(&shell, Path::new("/repo")),
            Err(GitSnapshotError::MissingStashSha)
        ));
    }

    #[test]
    fn create_snapshot_surfaces_failed_head_lookup() {
        let shell = MockShell::new()
            .when("git rev-parse --show-toplevel", output(0, "/repo\n", ""))
            .when("git stash create", output(0, "deadbeef\n", ""))
            .when("git rev-parse HEAD", output(128, "", "fatal: bad revision\n"));
        match create_snapshot(&shell, Path::new("/repo")) {
            Err(GitSnapshotError::CommandFailed { exit_code, stderr }) => {
                assert_eq!(exit_code, 128);
                assert_eq!(stderr, "fatal: bad revision\n");
            }
            Err(other) => panic!("expected CommandFailed, got {other:?}"),
            Ok(_) => panic!("expected CommandFailed, got a snapshot"),
        }
    }

    #[test]
    fn apply_snapshot_reports_conflicts_without_error() {
        let shell = MockShell::new().when(
            "git stash apply deadbeef",
            output(1, "", "error: patch failed: CONFLICT (content): file.txt\n"),
        );
        let result = apply_snapshot(&shell, &sample_snapshot()).unwrap();
        assert!(!result.applied);
        assert_eq!(
            result.conflicts,
            vec!["error: patch failed: CONFLICT (content): file.txt".to_string()]
        );
    }

    #[test]
    fn apply_snapshot_success_has_no_conflicts() {
        let shell =
            MockShell::new().when("git stash apply deadbeef", output(0, "On branch main\n", ""));
        assert_eq!(
            apply_snapshot(&shell, &sample_snapshot()).unwrap(),
            SnapshotApplyResult {
                applied: true,
                conflicts: Vec::new(),
            }
        );
    }

    #[test]
    fn apply_snapshot_non_conflict_failure_keeps_message() {
        let shell = MockShell::new().when(
            "git stash apply deadbeef",
            output(1, "", "error: deadbeef is not a valid reference\n"),
        );
        let result = apply_snapshot(&shell, &sample_snapshot()).unwrap();
        assert!(!result.applied);
        assert_eq!(
            result.conflicts,
            vec!["error: deadbeef is not a valid reference".to_string()]
        );
    }
}