Skip to main content

mars_agents/merge/
mod.rs

//! Three-way merge using the `git merge-file` CLI.
//!
//! Wraps `git merge-file -p` to produce git-standard conflict markers,
//! which IDEs (VS Code, JetBrains) recognize and offer "Accept Current/
//! Incoming/Both" actions for.
//!
//! Shelling out to `git merge-file` keeps merge behavior consistent with git
//! itself. Since mars is inherently a git-based tool, assuming `git` is on
//! PATH is safe.
10
11use std::io::Write;
12use std::path::Path;
13use std::process::Command;
14
15use crate::error::MarsError;
16
/// Result of a three-way merge performed by [`merge_content`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MergeResult {
    /// The merged bytes exactly as printed by `git merge-file -p`; may
    /// contain conflict markers.
    pub content: Vec<u8>,
    /// Whether `git merge-file` reported at least one conflict.
    pub has_conflicts: bool,
    /// Number of conflict regions. Approximate: this counts lines of
    /// `content` that start with `<<<<<<<`, so markers that were already
    /// present in the inputs are counted as well.
    pub conflict_count: usize,
}
27
/// Labels for the three sides of a merge.
///
/// Each label is passed to `git merge-file` via `-L` and printed after the
/// corresponding conflict marker in the merged output.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MergeLabels {
    /// Label for the common ancestor, e.g. "base (mars installed)".
    pub base: String,
    /// Label for the user's on-disk copy, e.g. "local".
    pub local: String,
    /// Label for the incoming upstream version, e.g. "meridian-base@v0.6.0".
    pub theirs: String,
}
38
39/// Perform three-way merge using `git merge-file`.
40///
41/// Inputs:
42/// - `base`: what mars installed last time (from cache)
43/// - `local`: current file on disk (user's copy)
44/// - `theirs`: new source content (upstream update)
45///
46/// Output: merged content, possibly with git conflict markers.
47///
48/// `git merge-file` exit codes:
49/// - 0 = clean merge
50/// - positive = number of conflicts
51/// - negative = error
52pub fn merge_content(
53    base: &[u8],
54    local: &[u8],
55    theirs: &[u8],
56    labels: &MergeLabels,
57) -> Result<MergeResult, MarsError> {
58    let dir = tempfile::TempDir::new()?;
59
60    let base_path = dir.path().join("base");
61    let local_path = dir.path().join("local");
62    let theirs_path = dir.path().join("theirs");
63
64    write_file(&base_path, base)?;
65    write_file(&local_path, local)?;
66    write_file(&theirs_path, theirs)?;
67
68    // git merge-file -p -L <local-label> -L <base-label> -L <theirs-label>
69    //   <local-file> <base-file> <theirs-file>
70    //
71    // Note: label order is local, base, theirs (matching file order).
72    // The -p flag writes merged output to stdout instead of modifying the file.
73    let output = Command::new("git")
74        .arg("merge-file")
75        .arg("-p")
76        .arg("-L")
77        .arg(&labels.local)
78        .arg("-L")
79        .arg(&labels.base)
80        .arg("-L")
81        .arg(&labels.theirs)
82        .arg(&local_path)
83        .arg(&base_path)
84        .arg(&theirs_path)
85        .output()
86        .map_err(|e| MarsError::Source {
87            source_name: "merge".to_string(),
88            message: format!("failed to run `git merge-file`: {e} — is git installed and in PATH?"),
89        })?;
90
91    let exit_code = output.status.code().unwrap_or(-1);
92
93    // Negative exit code = error (not a conflict)
94    if exit_code < 0 {
95        return Err(MarsError::Source {
96            source_name: "merge".to_string(),
97            message: format!(
98                "git merge-file failed (exit {}): {}",
99                exit_code,
100                String::from_utf8_lossy(&output.stderr)
101            ),
102        });
103    }
104
105    let content = output.stdout;
106    let has_conflicts = exit_code > 0;
107    let conflict_count = count_conflict_markers(&content);
108
109    Ok(MergeResult {
110        content,
111        has_conflicts,
112        conflict_count,
113    })
114}
115
/// Report whether `content` contains an unresolved merge conflict.
///
/// A conflict is recognised by any line (including the first) that begins
/// with the seven-character opening marker `<<<<<<<`. Occurrences in the
/// middle of a line are ignored.
pub fn has_conflict_markers(content: &[u8]) -> bool {
    const OPEN_MARKER: &[u8] = b"<<<<<<<";

    // Every segment produced by splitting on '\n' is a line start, so a
    // prefix test per segment covers both the first line and all later ones.
    content
        .split(|&byte| byte == b'\n')
        .any(|line| line.starts_with(OPEN_MARKER))
}
128
129/// Check whether a file on disk contains unresolved conflict markers.
130pub fn file_has_conflict_markers(path: &Path) -> bool {
131    // Intentionally treat unreadable files as "no markers" so list/resolve views stay conservative.
132    std::fs::read(path)
133        .map(|content| has_conflict_markers(&content))
134        .unwrap_or(false)
135}
136
/// Count conflict regions in `content`.
///
/// Each line (including the first) that begins with the opening marker
/// `<<<<<<<` is counted as one region; mid-line occurrences are ignored.
fn count_conflict_markers(content: &[u8]) -> usize {
    content
        .split(|&byte| byte == b'\n')
        .filter(|line| line.starts_with(b"<<<<<<<"))
        .count()
}
155
156/// Helper to write bytes to a file.
157fn write_file(path: &std::path::Path, content: &[u8]) -> Result<(), MarsError> {
158    let mut file = std::fs::File::create(path)?;
159    file.write_all(content)?;
160    Ok(())
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    /// Labels shared by the merge tests; the distinctive `theirs` label lets
    /// tests check that it is copied into the closing marker.
    fn labels() -> MergeLabels {
        MergeLabels {
            base: "base (last sync)".to_string(),
            local: "local".to_string(),
            theirs: "meridian-base@v0.6.0".to_string(),
        }
    }

    // === Clean merge tests ===

    #[test]
    fn all_three_identical() {
        let content = b"line 1\nline 2\nline 3\n";
        let result = merge_content(content, content, content, &labels()).unwrap();
        assert!(!result.has_conflicts);
        assert_eq!(result.conflict_count, 0);
        assert_eq!(result.content, content);
    }

    #[test]
    fn theirs_changed_local_same_as_base() {
        let base = b"line 1\nline 2\nline 3\n";
        let local = b"line 1\nline 2\nline 3\n";
        let theirs = b"line 1\nline 2 modified\nline 3\n";

        let result = merge_content(base, local, theirs, &labels()).unwrap();
        assert!(!result.has_conflicts);
        assert_eq!(result.content, theirs);
    }

    #[test]
    fn local_changed_theirs_same_as_base() {
        let base = b"line 1\nline 2\nline 3\n";
        let local = b"line 1\nline 2 local edit\nline 3\n";
        let theirs = b"line 1\nline 2\nline 3\n";

        let result = merge_content(base, local, theirs, &labels()).unwrap();
        assert!(!result.has_conflicts);
        assert_eq!(result.content, local);
    }

    #[test]
    fn non_overlapping_changes_merge_cleanly() {
        let base = b"line 1\nline 2\nline 3\nline 4\nline 5\n";
        let local = b"line 1 local\nline 2\nline 3\nline 4\nline 5\n";
        let theirs = b"line 1\nline 2\nline 3\nline 4\nline 5 theirs\n";

        let result = merge_content(base, local, theirs, &labels()).unwrap();
        assert!(!result.has_conflicts);
        let merged = String::from_utf8(result.content).unwrap();
        assert!(merged.contains("line 1 local"));
        assert!(merged.contains("line 5 theirs"));
    }

    // === Conflict tests ===

    #[test]
    fn overlapping_changes_produce_conflict() {
        let base = b"line 1\nline 2\nline 3\n";
        let local = b"line 1\nlocal change\nline 3\n";
        let theirs = b"line 1\ntheirs change\nline 3\n";

        let result = merge_content(base, local, theirs, &labels()).unwrap();
        assert!(result.has_conflicts);
        assert!(result.conflict_count >= 1);
    }

    #[test]
    fn conflict_markers_match_git_format() {
        let base = b"same\nconflict line\nsame\n";
        let local = b"same\nlocal version\nsame\n";
        let theirs = b"same\ntheirs version\nsame\n";

        let result = merge_content(base, local, theirs, &labels()).unwrap();
        assert!(result.has_conflicts);

        let merged = String::from_utf8(result.content).unwrap();
        assert!(merged.contains("<<<<<<<"), "should have opening marker");
        assert!(merged.contains("======="), "should have separator");
        assert!(merged.contains(">>>>>>>"), "should have closing marker");
    }

    #[test]
    fn labels_appear_in_conflict_markers() {
        let base = b"conflict\n";
        let local = b"local version\n";
        let theirs = b"theirs version\n";

        let result = merge_content(base, local, theirs, &labels()).unwrap();
        assert!(result.has_conflicts);
        let merged = String::from_utf8(result.content).unwrap();
        // Match whole marker lines: the bare word "local" also occurs in the
        // file content "local version" and would prove nothing about labels.
        assert!(
            merged.contains("<<<<<<< local\n"),
            "local label should follow the opening marker: {merged}"
        );
        assert!(
            merged.contains(">>>>>>> meridian-base@v0.6.0\n"),
            "theirs label should follow the closing marker: {merged}"
        );
    }

    #[test]
    fn multiple_conflict_regions() {
        // Use more spacing between conflicting regions so git treats them separately
        let base = b"a\nb\nc\nd\ne\nf\ng\nh\ni\nj\n";
        let local = b"a-local\nb\nc\nd\ne\nf\ng\nh\ni-local\nj\n";
        let theirs = b"a-theirs\nb\nc\nd\ne\nf\ng\nh\ni-theirs\nj\n";

        let result = merge_content(base, local, theirs, &labels()).unwrap();
        assert!(result.has_conflicts);
        assert!(
            result.conflict_count >= 2,
            "should have at least 2 conflicts, got {}",
            result.conflict_count
        );
    }

    // === Edge cases ===

    #[test]
    fn empty_base_with_different_content() {
        let base = b"";
        let local = b"local content\n";
        let theirs = b"theirs content\n";

        // Both sides added different content to an empty base — this is a conflict.
        let result = merge_content(base, local, theirs, &labels()).unwrap();
        assert!(result.has_conflicts);
    }

    #[test]
    fn empty_base_same_additions() {
        let base = b"";
        let local = b"same content\n";
        let theirs = b"same content\n";

        let result = merge_content(base, local, theirs, &labels()).unwrap();
        assert!(!result.has_conflicts);
        assert_eq!(result.content, b"same content\n");
    }

    #[test]
    fn all_empty() {
        let result = merge_content(b"", b"", b"", &labels()).unwrap();
        assert!(!result.has_conflicts);
        assert!(result.content.is_empty());
    }

    // === has_conflict_markers tests ===

    #[test]
    fn has_conflict_markers_detects_markers() {
        let content = b"before\n<<<<<<< local\nlocal\n=======\ntheirs\n>>>>>>> theirs\nafter\n";
        assert!(has_conflict_markers(content));
    }

    #[test]
    fn has_conflict_markers_at_start_of_file() {
        let content = b"<<<<<<< local\nlocal\n=======\ntheirs\n>>>>>>> theirs\n";
        assert!(has_conflict_markers(content));
    }

    #[test]
    fn has_conflict_markers_no_markers() {
        let content = b"normal content\nno conflicts here\n";
        assert!(!has_conflict_markers(content));
    }

    #[test]
    fn has_conflict_markers_empty_content() {
        assert!(!has_conflict_markers(b""));
    }

    #[test]
    fn has_conflict_markers_partial_marker_not_detected() {
        // A six-character "<<<<<<" at the start of a line must not be detected;
        // the opening marker is seven characters (`<<<<<<<`).
        let content = b"<<<<<< stuff\n";
        assert!(!has_conflict_markers(content));
    }

    #[test]
    fn has_conflict_markers_in_middle_of_line_not_detected() {
        // Marker must be at start of line
        let content = b"text <<<<<<< not a real marker\n";
        assert!(!has_conflict_markers(content));
    }

    // === count_conflict_markers tests ===

    #[test]
    fn count_zero_conflicts() {
        assert_eq!(count_conflict_markers(b"no conflicts"), 0);
    }

    #[test]
    fn count_one_conflict() {
        let content = b"before\n<<<<<<< local\nlocal\n=======\ntheirs\n>>>>>>> theirs\nafter\n";
        assert_eq!(count_conflict_markers(content), 1);
    }

    #[test]
    fn count_multiple_conflicts() {
        let content =
            b"<<<<<<< a\nx\n=======\ny\n>>>>>>> b\nok\n<<<<<<< a\np\n=======\nq\n>>>>>>> b\n";
        assert_eq!(count_conflict_markers(content), 2);
    }

    #[test]
    fn count_ignores_mid_line_and_short_markers() {
        assert_eq!(count_conflict_markers(b"text <<<<<<< x\n<<<<<< y\n"), 0);
    }

    #[test]
    fn count_marker_on_last_line_without_newline() {
        assert_eq!(count_conflict_markers(b"a\n<<<<<<<"), 1);
    }
}