ralph_workflow/git_helpers/
rebase.rs1#![deny(unsafe_code)]
14
15use std::io;
16use std::path::Path;
17
18fn git2_to_io_error(err: &git2::Error) -> io::Error {
20 io::Error::other(err.to_string())
21}
22
/// Outcome of a rebase attempt performed by [`rebase_onto`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum RebaseResult {
    /// `git rebase` exited successfully.
    Success,
    /// The rebase stopped because of conflicts.
    ///
    /// Holds the paths of conflicted files when they could be determined;
    /// the list may be empty.
    Conflicts(Vec<String>),
    /// No rebase was performed (nothing to do, or rebasing was skipped).
    NoOp,
}
33
34pub fn rebase_onto(upstream_branch: &str) -> io::Result<RebaseResult> {
64 use std::process::Command;
65
66 let repo = git2::Repository::discover(".").map_err(|e| git2_to_io_error(&e))?;
68
69 match repo.head() {
70 Ok(_) => {}
71 Err(ref e) if e.code() == git2::ErrorCode::UnbornBranch => {
72 return Ok(RebaseResult::NoOp);
74 }
75 Err(e) => return Err(git2_to_io_error(&e)),
76 }
77
78 let upstream_object = repo.revparse_single(upstream_branch).map_err(|_| {
80 io::Error::new(
81 io::ErrorKind::NotFound,
82 format!("Upstream branch '{upstream_branch}' not found"),
83 )
84 })?;
85
86 let upstream_commit = upstream_object
87 .peel_to_commit()
88 .map_err(|e| git2_to_io_error(&e))?;
89
90 let head = repo.head().map_err(|e| git2_to_io_error(&e))?;
92 let head_commit = head.peel_to_commit().map_err(|e| git2_to_io_error(&e))?;
93
94 if repo
96 .graph_descendant_of(head_commit.id(), upstream_commit.id())
97 .map_err(|e| git2_to_io_error(&e))?
98 {
99 return Ok(RebaseResult::NoOp);
101 }
102
103 match repo.merge_base(head_commit.id(), upstream_commit.id()) {
106 Err(e)
107 if e.class() == git2::ErrorClass::Reference
108 && e.code() == git2::ErrorCode::NotFound =>
109 {
110 return Ok(RebaseResult::NoOp);
112 }
113 Err(e) => return Err(git2_to_io_error(&e)),
114 Ok(_) => {}
115 }
116
117 let branch_name = head.shorthand().ok_or_else(|| {
119 io::Error::new(
120 io::ErrorKind::NotFound,
121 "Could not determine branch name from HEAD",
122 )
123 })?;
124
125 if branch_name == "main" || branch_name == "master" {
126 return Ok(RebaseResult::NoOp);
127 }
128
129 let output = Command::new("git")
131 .args(["rebase", upstream_branch])
132 .output();
133
134 match output {
135 Ok(result) => {
136 if result.status.success() {
137 Ok(RebaseResult::Success)
138 } else {
139 let stderr = String::from_utf8_lossy(&result.stderr);
140 if stderr.contains("Conflict")
142 || stderr.contains("conflict")
143 || stderr.contains("Resolve")
144 {
145 Ok(RebaseResult::Conflicts(vec![]))
147 } else if stderr.contains("up to date") {
148 Ok(RebaseResult::NoOp)
149 } else {
150 Err(io::Error::other(format!("Rebase failed: {stderr}")))
151 }
152 }
153 }
154 Err(e) => Err(io::Error::other(format!(
155 "Failed to execute git rebase: {e}"
156 ))),
157 }
158}
159
160pub fn abort_rebase() -> io::Result<()> {
171 use std::process::Command;
172
173 let repo = git2::Repository::discover(".").map_err(|e| git2_to_io_error(&e))?;
174
175 let state = repo.state();
177 if state != git2::RepositoryState::Rebase
178 && state != git2::RepositoryState::RebaseMerge
179 && state != git2::RepositoryState::RebaseInteractive
180 {
181 return Err(io::Error::new(
182 io::ErrorKind::InvalidInput,
183 "No rebase in progress",
184 ));
185 }
186
187 let output = Command::new("git").args(["rebase", "--abort"]).output();
189
190 match output {
191 Ok(result) => {
192 if result.status.success() {
193 Ok(())
194 } else {
195 let stderr = String::from_utf8_lossy(&result.stderr);
196 Err(io::Error::other(format!(
197 "Failed to abort rebase: {stderr}"
198 )))
199 }
200 }
201 Err(e) => Err(io::Error::other(format!(
202 "Failed to execute git rebase --abort: {e}"
203 ))),
204 }
205}
206
207pub fn get_conflicted_files() -> io::Result<Vec<String>> {
217 let repo = git2::Repository::discover(".").map_err(|e| git2_to_io_error(&e))?;
218 let index = repo.index().map_err(|e| git2_to_io_error(&e))?;
219
220 let mut conflicted_files = Vec::new();
221
222 if !index.has_conflicts() {
224 return Ok(conflicted_files);
225 }
226
227 let conflicts = index.conflicts().map_err(|e| git2_to_io_error(&e))?;
229
230 for conflict in conflicts {
231 let conflict = conflict.map_err(|e| git2_to_io_error(&e))?;
232 if let Some(our_entry) = conflict.our {
234 if let Ok(path) = std::str::from_utf8(&our_entry.path) {
235 let path_str = path.to_string();
236 if !conflicted_files.contains(&path_str) {
237 conflicted_files.push(path_str);
238 }
239 }
240 }
241 }
242
243 Ok(conflicted_files)
244}
245
/// Extracts every conflict section from the file at `path`.
///
/// A section starts at a line whose leading-whitespace-trimmed text begins
/// with `<<<<<<<`, includes everything up to and including the next
/// `=======` line, then everything up to and including the next `>>>>>>>`
/// line. If a closing line is missing, the section runs to end of file.
/// Lines are kept verbatim (including any leading whitespace), joined with
/// `\n`. Sections are separated by a blank line (`\n\n`).
///
/// Returns an empty string when the file contains no conflict markers.
///
/// # Errors
///
/// Returns an [`io::Error`] if the file cannot be read or is not valid UTF-8.
pub fn get_conflict_markers_for_file(path: &Path) -> io::Result<String> {
    use std::fs;

    let content = fs::read_to_string(path)?;
    let is_marker = |line: &str, marker: &str| line.trim_start().starts_with(marker);

    let mut lines = content.lines();
    let mut conflict_sections: Vec<String> = Vec::new();

    while let Some(line) = lines.next() {
        if !is_marker(line, "<<<<<<<") {
            continue;
        }

        let mut section = vec![line];

        // "Ours" side (plus any diff3 base block) through the separator.
        for line in lines.by_ref() {
            section.push(line);
            if is_marker(line, "=======") {
                break;
            }
        }

        // "Theirs" side through the closing marker.
        for line in lines.by_ref() {
            section.push(line);
            if is_marker(line, ">>>>>>>") {
                break;
            }
        }

        conflict_sections.push(section.join("\n"));
    }

    // Joining an empty list yields "", covering the no-conflict case.
    Ok(conflict_sections.join("\n\n"))
}
314
315pub fn continue_rebase() -> io::Result<()> {
328 use std::process::Command;
329
330 let repo = git2::Repository::discover(".").map_err(|e| git2_to_io_error(&e))?;
331
332 let state = repo.state();
334 if state != git2::RepositoryState::Rebase
335 && state != git2::RepositoryState::RebaseMerge
336 && state != git2::RepositoryState::RebaseInteractive
337 {
338 return Err(io::Error::new(
339 io::ErrorKind::InvalidInput,
340 "No rebase in progress",
341 ));
342 }
343
344 let conflicted = get_conflicted_files()?;
346 if !conflicted.is_empty() {
347 return Err(io::Error::new(
348 io::ErrorKind::InvalidInput,
349 format!(
350 "Cannot continue rebase: {} file(s) still have conflicts",
351 conflicted.len()
352 ),
353 ));
354 }
355
356 let output = Command::new("git").args(["rebase", "--continue"]).output();
358
359 match output {
360 Ok(result) => {
361 if result.status.success() {
362 Ok(())
363 } else {
364 let stderr = String::from_utf8_lossy(&result.stderr);
365 Err(io::Error::other(format!(
366 "Failed to continue rebase: {stderr}"
367 )))
368 }
369 }
370 Err(e) => Err(io::Error::other(format!(
371 "Failed to execute git rebase --continue: {e}"
372 ))),
373 }
374}
375
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::path::PathBuf;

    /// Writes `content` to a per-process, per-tag file in the system temp dir.
    fn write_temp(tag: &str, content: &str) -> PathBuf {
        let path = std::env::temp_dir().join(format!(
            "rebase_rs_test_{}_{tag}.txt",
            std::process::id()
        ));
        fs::write(&path, content).expect("temp dir should be writable");
        path
    }

    #[test]
    fn test_rebase_result_variants_exist() {
        let _ = RebaseResult::Success;
        let _ = RebaseResult::NoOp;
        let _ = RebaseResult::Conflicts(vec![]);
    }

    #[test]
    fn test_rebase_onto_returns_result() {
        // NOTE(review): depends on the test's working directory being a git
        // repository with at least one commit.
        let result = rebase_onto("nonexistent_branch_that_does_not_exist");
        assert!(result.is_err());
    }

    #[test]
    fn test_get_conflicted_files_returns_result() {
        // NOTE(review): depends on the test's working directory being a git
        // repository.
        let result = get_conflicted_files();
        assert!(result.is_ok());
    }

    #[test]
    fn test_conflict_markers_extracts_sections() {
        let path = write_temp(
            "sections",
            "a\n<<<<<<< HEAD\nours\n=======\ntheirs\n>>>>>>> topic\nb\n",
        );
        let out = get_conflict_markers_for_file(&path);
        let _ = fs::remove_file(&path);
        assert_eq!(
            out.unwrap(),
            "<<<<<<< HEAD\nours\n=======\ntheirs\n>>>>>>> topic"
        );
    }

    #[test]
    fn test_conflict_markers_empty_without_markers() {
        let path = write_temp("clean", "no conflicts here\n");
        let out = get_conflict_markers_for_file(&path);
        let _ = fs::remove_file(&path);
        assert_eq!(out.unwrap(), "");
    }
}