Skip to main content

mars_agents/merge/
mod.rs

//! Three-way merge using the `git merge-file` CLI.
//!
//! Wraps `git merge-file -p` to produce git-standard conflict markers
//! that IDEs (VS Code, JetBrains) recognize and offer an "Accept Current/
//! Incoming/Both" UI for.
//!
//! Uses `git merge-file` via subprocess for consistent merge behavior.
//! Since mars is inherently a git-based tool, `git` being in PATH is a safe
//! assumption.
10
11use std::io::Write;
12use std::process::Command;
13
14use crate::error::MarsError;
15
/// Outcome of a three-way merge performed by `git merge-file`.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare whole results
/// with `==` / `assert_eq!` instead of field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MergeResult {
    /// The merged bytes exactly as git wrote them to stdout; may contain
    /// conflict markers.
    pub content: Vec<u8>,
    /// Whether the merge produced conflict markers.
    pub has_conflicts: bool,
    /// Number of conflict regions (approximate — counts `<<<<<<<` markers
    /// found at the start of a line in `content`).
    pub conflict_count: usize,
}
26
/// Human-readable labels for the three sides of a merge.
///
/// The labels are handed to `git merge-file` via `-L`; in the default conflict
/// style `local` follows `<<<<<<<` and `theirs` follows `>>>>>>>`.
///
/// Derives `PartialEq`/`Eq` so label sets can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MergeLabels {
    /// Label for the common ancestor, e.g., "base (mars installed)"
    pub base: String,
    /// Label for the user's copy, e.g., "local"
    pub local: String,
    /// Label for the incoming version, e.g., "meridian-base@v0.6.0"
    pub theirs: String,
}
37
38/// Perform three-way merge using `git merge-file`.
39///
40/// Inputs:
41/// - `base`: what mars installed last time (from cache)
42/// - `local`: current file on disk (user's copy)
43/// - `theirs`: new source content (upstream update)
44///
45/// Output: merged content, possibly with git conflict markers.
46///
47/// `git merge-file` exit codes:
48/// - 0 = clean merge
49/// - positive = number of conflicts
50/// - negative = error
51pub fn merge_content(
52    base: &[u8],
53    local: &[u8],
54    theirs: &[u8],
55    labels: &MergeLabels,
56) -> Result<MergeResult, MarsError> {
57    let dir = tempfile::TempDir::new()?;
58
59    let base_path = dir.path().join("base");
60    let local_path = dir.path().join("local");
61    let theirs_path = dir.path().join("theirs");
62
63    write_file(&base_path, base)?;
64    write_file(&local_path, local)?;
65    write_file(&theirs_path, theirs)?;
66
67    // git merge-file -p -L <local-label> -L <base-label> -L <theirs-label>
68    //   <local-file> <base-file> <theirs-file>
69    //
70    // Note: label order is local, base, theirs (matching file order).
71    // The -p flag writes merged output to stdout instead of modifying the file.
72    let output = Command::new("git")
73        .arg("merge-file")
74        .arg("-p")
75        .arg("-L")
76        .arg(&labels.local)
77        .arg("-L")
78        .arg(&labels.base)
79        .arg("-L")
80        .arg(&labels.theirs)
81        .arg(&local_path)
82        .arg(&base_path)
83        .arg(&theirs_path)
84        .output()
85        .map_err(|e| MarsError::Source {
86            source_name: "merge".to_string(),
87            message: format!("failed to run `git merge-file`: {e} — is git installed and in PATH?"),
88        })?;
89
90    let exit_code = output.status.code().unwrap_or(-1);
91
92    // Negative exit code = error (not a conflict)
93    if exit_code < 0 {
94        return Err(MarsError::Source {
95            source_name: "merge".to_string(),
96            message: format!(
97                "git merge-file failed (exit {}): {}",
98                exit_code,
99                String::from_utf8_lossy(&output.stderr)
100            ),
101        });
102    }
103
104    let content = output.stdout;
105    let has_conflicts = exit_code > 0;
106    let conflict_count = count_conflict_markers(&content);
107
108    Ok(MergeResult {
109        content,
110        has_conflicts,
111        conflict_count,
112    })
113}
114
/// Report whether `content` still contains an unresolved conflict.
///
/// A conflict is recognised by a line that begins with the seven-character
/// opening marker `<<<<<<<`. Markers that appear in the middle of a line are
/// ignored, as are shorter runs of `<`.
pub fn has_conflict_markers(content: &[u8]) -> bool {
    // Splitting on '\n' yields the first line plus every span that follows a
    // newline — exactly the positions where a marker counts.
    content
        .split(|&byte| byte == b'\n')
        .any(|line| line.starts_with(b"<<<<<<<"))
}
127
/// Count the conflict regions in `content`.
///
/// Every line that begins with the opening marker `<<<<<<<` counts as one
/// region; this covers a marker on the very first line as well as markers
/// that follow a newline.
fn count_conflict_markers(content: &[u8]) -> usize {
    content
        .split(|&byte| byte == b'\n')
        .filter(|line| line.starts_with(b"<<<<<<<"))
        .count()
}
146
147/// Helper to write bytes to a file.
148fn write_file(path: &std::path::Path, content: &[u8]) -> Result<(), MarsError> {
149    let mut file = std::fs::File::create(path)?;
150    file.write_all(content)?;
151    Ok(())
152}
153
/// Unit tests for the merge wrapper and the conflict-marker scanners.
///
/// The `merge_content` tests shell out to the real `git` binary, so `git`
/// must be available in PATH when they run.
#[cfg(test)]
mod tests {
    use super::*;

    /// Label set shared by every merge test.
    fn labels() -> MergeLabels {
        MergeLabels {
            base: "base (last sync)".to_string(),
            local: "local".to_string(),
            theirs: "meridian-base@v0.6.0".to_string(),
        }
    }

    // === Clean merge tests ===

    #[test]
    fn all_three_identical() {
        let content = b"line 1\nline 2\nline 3\n";
        let result = merge_content(content, content, content, &labels()).unwrap();
        assert!(!result.has_conflicts);
        assert_eq!(result.conflict_count, 0);
        assert_eq!(result.content, content);
    }

    #[test]
    fn theirs_changed_local_same_as_base() {
        let base = b"line 1\nline 2\nline 3\n";
        let local = b"line 1\nline 2\nline 3\n";
        let theirs = b"line 1\nline 2 modified\nline 3\n";

        let result = merge_content(base, local, theirs, &labels()).unwrap();
        assert!(!result.has_conflicts);
        assert_eq!(result.content, theirs);
    }

    #[test]
    fn local_changed_theirs_same_as_base() {
        let base = b"line 1\nline 2\nline 3\n";
        let local = b"line 1\nline 2 local edit\nline 3\n";
        let theirs = b"line 1\nline 2\nline 3\n";

        let result = merge_content(base, local, theirs, &labels()).unwrap();
        assert!(!result.has_conflicts);
        assert_eq!(result.content, local);
    }

    #[test]
    fn non_overlapping_changes_merge_cleanly() {
        let base = b"line 1\nline 2\nline 3\nline 4\nline 5\n";
        let local = b"line 1 local\nline 2\nline 3\nline 4\nline 5\n";
        let theirs = b"line 1\nline 2\nline 3\nline 4\nline 5 theirs\n";

        let result = merge_content(base, local, theirs, &labels()).unwrap();
        assert!(!result.has_conflicts);
        let merged = String::from_utf8(result.content).unwrap();
        assert!(merged.contains("line 1 local"));
        assert!(merged.contains("line 5 theirs"));
    }

    // === Conflict tests ===

    #[test]
    fn overlapping_changes_produce_conflict() {
        let base = b"line 1\nline 2\nline 3\n";
        let local = b"line 1\nlocal change\nline 3\n";
        let theirs = b"line 1\ntheirs change\nline 3\n";

        let result = merge_content(base, local, theirs, &labels()).unwrap();
        assert!(result.has_conflicts);
        assert!(result.conflict_count >= 1);
    }

    #[test]
    fn conflict_markers_match_git_format() {
        let base = b"same\nconflict line\nsame\n";
        let local = b"same\nlocal version\nsame\n";
        let theirs = b"same\ntheirs version\nsame\n";

        let result = merge_content(base, local, theirs, &labels()).unwrap();
        assert!(result.has_conflicts);

        let merged = String::from_utf8(result.content).unwrap();
        assert!(merged.contains("<<<<<<<"), "should have opening marker");
        assert!(merged.contains("======="), "should have separator");
        assert!(merged.contains(">>>>>>>"), "should have closing marker");
    }

    #[test]
    fn labels_appear_in_conflict_markers() {
        let base = b"conflict\n";
        let local = b"local version\n";
        let theirs = b"theirs version\n";

        let result = merge_content(base, local, theirs, &labels()).unwrap();
        assert!(result.has_conflicts);

        let merged = String::from_utf8(result.content).unwrap();
        // Match the full marker lines: the bare word "local" also occurs in
        // the merged body ("local version"), so checking for it alone would
        // pass even if the label were dropped.
        assert!(
            merged.contains("<<<<<<< local"),
            "local label should appear on the opening marker: {merged}"
        );
        assert!(
            merged.contains(">>>>>>> meridian-base@v0.6.0"),
            "theirs label should appear on the closing marker: {merged}"
        );
    }

    #[test]
    fn multiple_conflict_regions() {
        // Use more spacing between conflicting regions so git treats them separately
        let base = b"a\nb\nc\nd\ne\nf\ng\nh\ni\nj\n";
        let local = b"a-local\nb\nc\nd\ne\nf\ng\nh\ni-local\nj\n";
        let theirs = b"a-theirs\nb\nc\nd\ne\nf\ng\nh\ni-theirs\nj\n";

        let result = merge_content(base, local, theirs, &labels()).unwrap();
        assert!(result.has_conflicts);
        assert!(
            result.conflict_count >= 2,
            "should have at least 2 conflicts, got {}",
            result.conflict_count
        );
    }

    // === Edge cases ===

    #[test]
    fn empty_base_with_different_content() {
        let base = b"";
        let local = b"local content\n";
        let theirs = b"theirs content\n";

        // Both sides added different content on top of an empty base, which
        // git reports as a conflict.
        let result = merge_content(base, local, theirs, &labels()).unwrap();
        assert!(result.has_conflicts);
    }

    #[test]
    fn empty_base_same_additions() {
        let base = b"";
        let local = b"same content\n";
        let theirs = b"same content\n";

        let result = merge_content(base, local, theirs, &labels()).unwrap();
        assert!(!result.has_conflicts);
        assert_eq!(result.content, b"same content\n");
    }

    #[test]
    fn all_empty() {
        let result = merge_content(b"", b"", b"", &labels()).unwrap();
        assert!(!result.has_conflicts);
        assert!(result.content.is_empty());
    }

    // === has_conflict_markers tests ===

    #[test]
    fn has_conflict_markers_detects_markers() {
        let content = b"before\n<<<<<<< local\nlocal\n=======\ntheirs\n>>>>>>> theirs\nafter\n";
        assert!(has_conflict_markers(content));
    }

    #[test]
    fn has_conflict_markers_at_start_of_file() {
        let content = b"<<<<<<< local\nlocal\n=======\ntheirs\n>>>>>>> theirs\n";
        assert!(has_conflict_markers(content));
    }

    #[test]
    fn has_conflict_markers_no_markers() {
        let content = b"normal content\nno conflicts here\n";
        assert!(!has_conflict_markers(content));
    }

    #[test]
    fn has_conflict_markers_partial_marker_not_detected() {
        // "<<<<<<" (6 chars) shouldn't be detected — needs 7 (`<<<<<<<`)
        let content = b"some <<<<<< stuff\n";
        assert!(!has_conflict_markers(content));
    }

    #[test]
    fn has_conflict_markers_in_middle_of_line_not_detected() {
        // Marker must be at start of line
        let content = b"text <<<<<<< not a real marker\n";
        assert!(!has_conflict_markers(content));
    }

    // === count_conflict_markers tests ===

    #[test]
    fn count_zero_conflicts() {
        assert_eq!(count_conflict_markers(b"no conflicts"), 0);
    }

    #[test]
    fn count_one_conflict() {
        let content = b"before\n<<<<<<< local\nlocal\n=======\ntheirs\n>>>>>>> theirs\nafter\n";
        assert_eq!(count_conflict_markers(content), 1);
    }

    #[test]
    fn count_multiple_conflicts() {
        let content =
            b"<<<<<<< a\nx\n=======\ny\n>>>>>>> b\nok\n<<<<<<< a\np\n=======\nq\n>>>>>>> b\n";
        assert_eq!(count_conflict_markers(content), 2);
    }
}