1use std::path::{Path, PathBuf};
4
5use ainl_contracts::FeatureSnapshot;
6use chrono::Utc;
7use thiserror::Error;
8
/// Captured result of running one external program through a [`ShellRunner`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShellOutput {
    /// Exit status reported by the runner; callers in this module treat `0` as success.
    pub exit_code: i32,
    /// Everything the program wrote to standard output (untrimmed).
    pub stdout: String,
    /// Everything the program wrote to standard error (untrimmed).
    pub stderr: String,
}
16
/// Runs external programs on behalf of the snapshot functions.
///
/// Abstracted behind a trait so git invocations can be replaced with canned responses
/// in tests (see `MockShell` in this file's test module).
pub trait ShellRunner {
    /// Runs `program` with `args` in the working directory `cwd` and returns its captured
    /// exit code, stdout and stderr.
    ///
    /// Callers in this module inspect [`ShellOutput::exit_code`] themselves, so a program
    /// that runs but exits non-zero is expected to come back as `Ok`.
    /// NOTE(review): presumably `Err` (e.g. [`GitSnapshotError::Shell`]) is reserved for
    /// failing to run the program at all — confirm against the production implementation.
    fn run(&self, cwd: &Path, program: &str, args: &[&str]) -> Result<ShellOutput, GitSnapshotError>;
}
21
/// Errors produced while taking or restoring a git working-tree snapshot.
#[derive(Debug, Error)]
pub enum GitSnapshotError {
    /// The shell runner itself failed. The test mock uses this for commands it has no
    /// canned response for; NOTE(review): real runners presumably use it for spawn
    /// failures — confirm.
    #[error("shell: {0}")]
    Shell(String),
    /// `git rev-parse --show-toplevel` exited non-zero or printed nothing for this path.
    #[error("not a git repository at {0}")]
    NotARepo(PathBuf),
    /// A git command ran but exited with a non-zero status.
    #[error("git command failed (exit {exit_code}): {stderr}")]
    CommandFailed {
        /// Exit status reported by the runner.
        exit_code: i32,
        /// The command's captured standard error.
        stderr: String,
    },
    /// `git stash create` succeeded but printed no object name (git does this when
    /// there are no local changes to record).
    #[error("missing stash sha in output")]
    MissingStashSha,
}
37
38pub fn resolve_repo_toplevel(
40 shell: &dyn ShellRunner,
41 path: &Path,
42) -> Result<PathBuf, GitSnapshotError> {
43 let out = shell.run(path, "git", &["rev-parse", "--show-toplevel"])?;
44 if out.exit_code != 0 {
45 return Err(GitSnapshotError::NotARepo(path.to_path_buf()));
46 }
47 let top = out.stdout.trim().to_string();
48 if top.is_empty() {
49 return Err(GitSnapshotError::NotARepo(path.to_path_buf()));
50 }
51 Ok(PathBuf::from(top))
52}
53
54pub fn create_snapshot(
56 shell: &dyn ShellRunner,
57 path: &Path,
58) -> Result<FeatureSnapshot, GitSnapshotError> {
59 let repo_toplevel = resolve_repo_toplevel(shell, path)?;
60 let stash_out = shell.run(&repo_toplevel, "git", &["stash", "create"])?;
61 if stash_out.exit_code != 0 {
62 return Err(GitSnapshotError::CommandFailed {
63 exit_code: stash_out.exit_code,
64 stderr: stash_out.stderr,
65 });
66 }
67 let stash_sha = stash_out.stdout.trim().to_string();
68 if stash_sha.is_empty() {
69 return Err(GitSnapshotError::MissingStashSha);
70 }
71 let head_out = shell.run(&repo_toplevel, "git", &["rev-parse", "HEAD"])?;
72 if head_out.exit_code != 0 {
73 return Err(GitSnapshotError::CommandFailed {
74 exit_code: head_out.exit_code,
75 stderr: head_out.stderr,
76 });
77 }
78 let head_sha = head_out.stdout.trim().to_string();
79 Ok(FeatureSnapshot {
80 repo_toplevel,
81 stash_sha,
82 head_sha,
83 taken_at: Utc::now(),
84 })
85}
86
/// Outcome of re-applying a snapshot with [`apply_snapshot`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SnapshotApplyResult {
    /// `true` when `git stash apply` exited with status 0.
    pub applied: bool,
    /// Empty when `applied` is `true`. Otherwise it holds diagnostic text extracted from
    /// git's output: lines mentioning a conflict or, if none do, git's trimmed output.
    pub conflicts: Vec<String>,
}
93
/// Extracts human-readable failure details from the output of a failed `git stash apply`.
///
/// Every line of `stdout` followed by every line of `stderr` is trimmed. Lines that
/// mention a conflict (`CONFLICT` or `conflict`) are returned in order.
///
/// If no line mentions a conflict, the failure had some other cause. In that case the
/// whole trimmed stdout and stderr are returned as a single entry, joined by a newline so
/// the two streams are never fused together. An empty vector is returned only when git
/// printed nothing but whitespace.
fn parse_apply_conflicts(stdout: &str, stderr: &str) -> Vec<String> {
    // "Merge conflict ..." is already covered by the lowercase "conflict" test.
    let conflicts: Vec<String> = stdout
        .lines()
        .chain(stderr.lines())
        .map(str::trim)
        .filter(|line| line.contains("CONFLICT") || line.contains("conflict"))
        .map(str::to_owned)
        .collect();
    if !conflicts.is_empty() {
        return conflicts;
    }

    // Keep the streams separated: plain concatenation would glue stdout's last line to
    // stderr's first line whenever stdout lacks a trailing newline.
    let parts: Vec<&str> = [stdout.trim(), stderr.trim()]
        .iter()
        .copied()
        .filter(|part| !part.is_empty())
        .collect();
    if parts.is_empty() {
        Vec::new()
    } else {
        vec![parts.join("\n")]
    }
}
116
117pub fn apply_snapshot(
122 shell: &dyn ShellRunner,
123 snapshot: &FeatureSnapshot,
124) -> Result<SnapshotApplyResult, GitSnapshotError> {
125 let out = shell.run(
126 &snapshot.repo_toplevel,
127 "git",
128 &["stash", "apply", snapshot.stash_sha.as_str()],
129 )?;
130 if out.exit_code != 0 {
131 return Ok(SnapshotApplyResult {
132 applied: false,
133 conflicts: parse_apply_conflicts(&out.stdout, &out.stderr),
134 });
135 }
136 Ok(SnapshotApplyResult {
137 applied: true,
138 conflicts: Vec::new(),
139 })
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Canned-response [`ShellRunner`]: maps `"<program> <args...>"` to a fixed output.
    ///
    /// `run` only reads the map, so no interior mutability is needed.
    struct MockShell {
        responses: HashMap<String, ShellOutput>,
    }

    impl MockShell {
        fn new() -> Self {
            Self {
                responses: HashMap::new(),
            }
        }

        /// Registers `out` as the response for the space-joined command line `key`.
        fn when(mut self, key: &str, out: ShellOutput) -> Self {
            self.responses.insert(key.into(), out);
            self
        }
    }

    impl ShellRunner for MockShell {
        fn run(
            &self,
            _cwd: &Path,
            program: &str,
            args: &[&str],
        ) -> Result<ShellOutput, GitSnapshotError> {
            let key = format!("{program} {}", args.join(" "));
            self.responses
                .get(&key)
                .cloned()
                .ok_or_else(|| GitSnapshotError::Shell(format!("no mock for {key}")))
        }
    }

    /// A successful (exit 0) output with the given stdout and empty stderr.
    fn ok(stdout: &str) -> ShellOutput {
        ShellOutput {
            exit_code: 0,
            stdout: stdout.into(),
            stderr: String::new(),
        }
    }

    /// A failed output with the given exit code and stderr and empty stdout.
    fn failed(exit_code: i32, stderr: &str) -> ShellOutput {
        ShellOutput {
            exit_code,
            stdout: String::new(),
            stderr: stderr.into(),
        }
    }

    /// A snapshot rooted at `/repo` referring to the given stash commit.
    fn snapshot_with_stash(stash_sha: &str) -> FeatureSnapshot {
        FeatureSnapshot {
            repo_toplevel: PathBuf::from("/repo"),
            stash_sha: stash_sha.into(),
            head_sha: "abc".into(),
            taken_at: chrono::Utc::now(),
        }
    }

    #[test]
    fn resolve_toplevel() {
        let shell = MockShell::new().when("git rev-parse --show-toplevel", ok("/repo\n"));
        let top = resolve_repo_toplevel(&shell, Path::new("/repo/src")).unwrap();
        assert_eq!(top, PathBuf::from("/repo"));
    }

    #[test]
    fn resolve_toplevel_rejects_non_repo() {
        let shell = MockShell::new().when(
            "git rev-parse --show-toplevel",
            failed(128, "fatal: not a git repository"),
        );
        assert!(matches!(
            resolve_repo_toplevel(&shell, Path::new("/tmp/outside")),
            Err(GitSnapshotError::NotARepo(ref p)) if p.as_path() == Path::new("/tmp/outside")
        ));
    }

    #[test]
    fn resolve_toplevel_rejects_blank_output() {
        let shell = MockShell::new().when("git rev-parse --show-toplevel", ok("  \n"));
        assert!(matches!(
            resolve_repo_toplevel(&shell, Path::new("/tmp/x")),
            Err(GitSnapshotError::NotARepo(_))
        ));
    }

    #[test]
    fn apply_snapshot_reports_conflicts_without_error() {
        let shell = MockShell::new().when(
            "git stash apply deadbeef",
            failed(1, "error: patch failed: CONFLICT (content): file.txt\n"),
        );
        let result = apply_snapshot(&shell, &snapshot_with_stash("deadbeef")).unwrap();
        assert_eq!(
            result,
            SnapshotApplyResult {
                applied: false,
                conflicts: vec!["error: patch failed: CONFLICT (content): file.txt".to_string()],
            }
        );
    }

    #[test]
    fn apply_snapshot_success_has_no_conflicts() {
        let shell = MockShell::new().when("git stash apply deadbeef", ok(""));
        let result = apply_snapshot(&shell, &snapshot_with_stash("deadbeef")).unwrap();
        assert_eq!(
            result,
            SnapshotApplyResult {
                applied: true,
                conflicts: Vec::new(),
            }
        );
    }

    #[test]
    fn create_snapshot_roundtrip_fields() {
        let shell = MockShell::new()
            .when("git rev-parse --show-toplevel", ok("/repo\n"))
            .when("git stash create", ok("deadbeef\n"))
            .when("git rev-parse HEAD", ok("cafebabe\n"));
        let snap = create_snapshot(&shell, Path::new("/repo")).unwrap();
        assert_eq!(snap.repo_toplevel, PathBuf::from("/repo"));
        assert_eq!(snap.stash_sha, "deadbeef");
        assert_eq!(snap.head_sha, "cafebabe");
    }

    #[test]
    fn create_snapshot_clean_tree_reports_missing_stash() {
        // `git stash create` prints nothing when there is nothing to stash.
        let shell = MockShell::new()
            .when("git rev-parse --show-toplevel", ok("/repo\n"))
            .when("git stash create", ok(""));
        assert!(matches!(
            create_snapshot(&shell, Path::new("/repo")),
            Err(GitSnapshotError::MissingStashSha)
        ));
    }

    #[test]
    fn create_snapshot_surfaces_git_failure() {
        let shell = MockShell::new()
            .when("git rev-parse --show-toplevel", ok("/repo\n"))
            .when("git stash create", failed(1, "boom"));
        assert!(matches!(
            create_snapshot(&shell, Path::new("/repo")),
            Err(GitSnapshotError::CommandFailed { exit_code: 1, ref stderr }) if stderr == "boom"
        ));
    }
}