Skip to main content

engram/sync/conflict/
merge.rs

//! Three-way merge implementation
2
3use serde::{Deserialize, Serialize};
4use std::collections::{HashMap, HashSet};
5
6/// Result of a three-way merge
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct MergeResult {
9    /// The merged content
10    pub content: String,
11    /// Whether the merge was successful (no conflicts)
12    pub success: bool,
13    /// Conflict markers in the output (if any)
14    pub has_conflict_markers: bool,
15    /// Lines that had conflicts
16    pub conflict_lines: Vec<usize>,
17    /// Statistics about the merge
18    pub stats: MergeStats,
19}
20
21/// Statistics about a merge operation
22#[derive(Debug, Clone, Default, Serialize, Deserialize)]
23pub struct MergeStats {
24    /// Lines from base that were kept
25    pub base_kept: usize,
26    /// Lines added by local
27    pub local_added: usize,
28    /// Lines added by remote
29    pub remote_added: usize,
30    /// Lines deleted by local
31    pub local_deleted: usize,
32    /// Lines deleted by remote
33    pub remote_deleted: usize,
34    /// Lines with conflicts
35    pub conflicts: usize,
36}
37
/// Three-way merger for text, tag lists and metadata maps.
///
/// The only state it carries is the three strings used to delimit a conflict
/// block in merged text, so it is cheap to clone and compare.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ThreeWayMerge {
    /// Line written before the local side of a conflict.
    local_marker: String,
    /// Line written after the remote side of a conflict.
    remote_marker: String,
    /// Line written between the local and remote sides of a conflict.
    separator: String,
}
47
48impl Default for ThreeWayMerge {
49    fn default() -> Self {
50        Self::new()
51    }
52}
53
54impl ThreeWayMerge {
55    /// Create a new three-way merger
56    pub fn new() -> Self {
57        Self {
58            local_marker: "<<<<<<< LOCAL".to_string(),
59            remote_marker: ">>>>>>> REMOTE".to_string(),
60            separator: "=======".to_string(),
61        }
62    }
63
64    /// Set custom conflict markers
65    pub fn with_markers(
66        mut self,
67        local: impl Into<String>,
68        separator: impl Into<String>,
69        remote: impl Into<String>,
70    ) -> Self {
71        self.local_marker = local.into();
72        self.separator = separator.into();
73        self.remote_marker = remote.into();
74        self
75    }
76
77    /// Perform three-way merge
78    pub fn merge(&self, base: &str, local: &str, remote: &str) -> MergeResult {
79        let base_lines: Vec<&str> = base.lines().collect();
80        let local_lines: Vec<&str> = local.lines().collect();
81        let remote_lines: Vec<&str> = remote.lines().collect();
82
83        let mut result = Vec::new();
84        let mut stats = MergeStats::default();
85        let mut conflict_lines = Vec::new();
86        let mut has_conflicts = false;
87
88        let max_len = base_lines
89            .len()
90            .max(local_lines.len())
91            .max(remote_lines.len());
92        let mut base_idx = 0;
93        let mut local_idx = 0;
94        let mut remote_idx = 0;
95
96        while base_idx < base_lines.len()
97            || local_idx < local_lines.len()
98            || remote_idx < remote_lines.len()
99        {
100            let base_line = base_lines.get(base_idx);
101            let local_line = local_lines.get(local_idx);
102            let remote_line = remote_lines.get(remote_idx);
103
104            match (base_line, local_line, remote_line) {
105                // All three match - keep it
106                (Some(b), Some(l), Some(r)) if b == l && l == r => {
107                    result.push((*l).to_string());
108                    stats.base_kept += 1;
109                    base_idx += 1;
110                    local_idx += 1;
111                    remote_idx += 1;
112                }
113
114                // Local changed, remote unchanged - take local
115                (Some(b), Some(l), Some(r)) if b == r && b != l => {
116                    result.push((*l).to_string());
117                    stats.local_added += 1;
118                    base_idx += 1;
119                    local_idx += 1;
120                    remote_idx += 1;
121                }
122
123                // Remote changed, local unchanged - take remote
124                (Some(b), Some(l), Some(r)) if b == l && b != r => {
125                    result.push((*r).to_string());
126                    stats.remote_added += 1;
127                    base_idx += 1;
128                    local_idx += 1;
129                    remote_idx += 1;
130                }
131
132                // Both changed differently - conflict!
133                (Some(_), Some(l), Some(r)) if l != r => {
134                    has_conflicts = true;
135                    stats.conflicts += 1;
136                    conflict_lines.push(result.len());
137
138                    result.push(self.local_marker.clone());
139                    result.push((*l).to_string());
140                    result.push(self.separator.clone());
141                    result.push((*r).to_string());
142                    result.push(self.remote_marker.clone());
143
144                    base_idx += 1;
145                    local_idx += 1;
146                    remote_idx += 1;
147                }
148
149                // Both changed to same - take it
150                (Some(_), Some(l), Some(r)) if l == r => {
151                    result.push((*l).to_string());
152                    stats.local_added += 1;
153                    base_idx += 1;
154                    local_idx += 1;
155                    remote_idx += 1;
156                }
157
158                // Local added line (past base)
159                (None, Some(l), None) => {
160                    result.push((*l).to_string());
161                    stats.local_added += 1;
162                    local_idx += 1;
163                }
164
165                // Remote added line (past base)
166                (None, None, Some(r)) => {
167                    result.push((*r).to_string());
168                    stats.remote_added += 1;
169                    remote_idx += 1;
170                }
171
172                // Both added different lines - conflict
173                (None, Some(l), Some(r)) if l != r => {
174                    has_conflicts = true;
175                    stats.conflicts += 1;
176                    conflict_lines.push(result.len());
177
178                    result.push(self.local_marker.clone());
179                    result.push((*l).to_string());
180                    result.push(self.separator.clone());
181                    result.push((*r).to_string());
182                    result.push(self.remote_marker.clone());
183
184                    local_idx += 1;
185                    remote_idx += 1;
186                }
187
188                // Both added same line
189                (None, Some(l), Some(r)) if l == r => {
190                    result.push((*l).to_string());
191                    stats.local_added += 1;
192                    local_idx += 1;
193                    remote_idx += 1;
194                }
195
196                // Local deleted (remote unchanged)
197                (Some(_), None, Some(r)) if base_lines.get(base_idx) == Some(r) => {
198                    stats.local_deleted += 1;
199                    base_idx += 1;
200                    remote_idx += 1;
201                }
202
203                // Remote deleted (local unchanged)
204                (Some(_), Some(l), None) if base_lines.get(base_idx) == Some(l) => {
205                    stats.remote_deleted += 1;
206                    base_idx += 1;
207                    local_idx += 1;
208                }
209
210                // Fallback: take what we have
211                _ => {
212                    if let Some(l) = local_line {
213                        result.push((*l).to_string());
214                    }
215                    if let Some(r) = remote_line {
216                        if local_line.map(|l| l != r).unwrap_or(true) {
217                            result.push((*r).to_string());
218                        }
219                    }
220                    break; // Prevent infinite loop
221                }
222            }
223
224            // Safety: prevent infinite loops
225            if base_idx + local_idx + remote_idx > max_len * 3 + 10 {
226                break;
227            }
228        }
229
230        MergeResult {
231            content: result.join("\n"),
232            success: !has_conflicts,
233            has_conflict_markers: has_conflicts,
234            conflict_lines,
235            stats,
236        }
237    }
238
239    /// Merge tags by taking union
240    pub fn merge_tags(&self, base: &[String], local: &[String], remote: &[String]) -> Vec<String> {
241        let mut result: HashSet<String> = HashSet::new();
242
243        // Start with base
244        result.extend(base.iter().cloned());
245
246        // Add local additions
247        for tag in local {
248            result.insert(tag.clone());
249        }
250
251        // Add remote additions
252        for tag in remote {
253            result.insert(tag.clone());
254        }
255
256        // Remove tags deleted by both
257        let local_set: HashSet<_> = local.iter().collect();
258        let remote_set: HashSet<_> = remote.iter().collect();
259
260        result
261            .into_iter()
262            .filter(|tag| {
263                // Keep if: in local OR in remote OR (in base AND (in local OR in remote))
264                local_set.contains(tag) || remote_set.contains(tag)
265            })
266            .collect()
267    }
268
269    /// Merge metadata by combining with local preference for conflicts
270    /// Now works with HashMap directly
271    pub fn merge_metadata_map(
272        &self,
273        base: Option<&HashMap<String, serde_json::Value>>,
274        local: &HashMap<String, serde_json::Value>,
275        remote: &HashMap<String, serde_json::Value>,
276    ) -> HashMap<String, serde_json::Value> {
277        let mut result = HashMap::new();
278
279        // If local and remote are the same, return local
280        if local == remote {
281            return local.clone();
282        }
283
284        // Start with base if present, otherwise start fresh
285        if let Some(base_map) = base {
286            result = base_map.clone();
287        }
288
289        // Apply local changes
290        for (k, v) in local {
291            let base_value = base.and_then(|b| b.get(k));
292            if base_value != Some(v) {
293                result.insert(k.clone(), v.clone());
294            }
295        }
296
297        // Apply remote changes (if not conflicting with local)
298        for (k, v) in remote {
299            let base_value = base.and_then(|b| b.get(k));
300            let local_value = local.get(k);
301            // Only apply remote change if local didn't change this key from base
302            if base_value != Some(v) && local_value == base_value {
303                result.insert(k.clone(), v.clone());
304            }
305        }
306
307        // Add any keys that are only in local or remote (not in base)
308        for (k, v) in local {
309            if !result.contains_key(k) {
310                result.insert(k.clone(), v.clone());
311            }
312        }
313        for (k, v) in remote {
314            if !result.contains_key(k) {
315                result.insert(k.clone(), v.clone());
316            }
317        }
318
319        result
320    }
321}
322
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned tag list from string literals.
    fn tags(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// Edits on different lines are both applied without a conflict.
    #[test]
    fn test_merge_no_conflict() {
        let merger = ThreeWayMerge::new();

        let base = "Line 1\nLine 2\nLine 3";
        let local = "Line 1 modified\nLine 2\nLine 3";
        let remote = "Line 1\nLine 2\nLine 3 modified";

        let result = merger.merge(base, local, remote);
        assert!(result.success);
        assert!(!result.has_conflict_markers);
        assert!(result.content.contains("Line 1 modified"));
        assert!(result.content.contains("Line 3 modified"));
    }

    /// Different edits to the same line produce a marked conflict block.
    #[test]
    fn test_merge_with_conflict() {
        let merger = ThreeWayMerge::new();

        let base = "Line 1\nLine 2";
        let local = "Local change\nLine 2";
        let remote = "Remote change\nLine 2";

        let result = merger.merge(base, local, remote);
        assert!(!result.success);
        assert!(result.has_conflict_markers);
        for expected in [
            "<<<<<<< LOCAL",
            "Local change",
            "=======",
            "Remote change",
            ">>>>>>> REMOTE",
        ] {
            assert!(result.content.contains(expected));
        }
    }

    /// The same edit on both sides is taken once and is not a conflict.
    #[test]
    fn test_merge_both_same_change() {
        let merger = ThreeWayMerge::new();

        let result = merger.merge("Original", "Same change", "Same change");
        assert!(result.success);
        assert_eq!(result.content, "Same change");
    }

    /// Three identical inputs come back unchanged.
    #[test]
    fn test_merge_identical_inputs() {
        let merger = ThreeWayMerge::new();

        let result = merger.merge("a\nb", "a\nb", "a\nb");
        assert!(result.success);
        assert_eq!(result.content, "a\nb");
        assert_eq!(result.stats.base_kept, 2);
        assert!(result.conflict_lines.is_empty());
    }

    /// Custom markers delimit the conflict block, and the conflict's position is reported.
    #[test]
    fn test_merge_with_custom_markers() {
        let merger = ThreeWayMerge::new().with_markers("<<", "==", ">>");

        let result = merger.merge("x", "L", "R");
        assert!(!result.success);
        assert_eq!(result.content, "<<\nL\n==\nR\n>>");
        assert_eq!(result.conflict_lines, vec![0]);
        assert_eq!(result.stats.conflicts, 1);
    }

    /// Tags added on either side all end up in the merged list.
    #[test]
    fn test_merge_tags() {
        let merger = ThreeWayMerge::new();

        let base = tags(&["tag1", "tag2"]);
        let local = tags(&["tag1", "tag2", "local_tag"]);
        let remote = tags(&["tag1", "tag2", "remote_tag"]);

        let result = merger.merge_tags(&base, &local, &remote);
        for expected in ["tag1", "tag2", "local_tag", "remote_tag"] {
            assert!(result.contains(&expected.to_string()));
        }
    }

    /// A tag that both sides removed does not survive the merge.
    #[test]
    fn test_merge_tags_removed_by_both() {
        let merger = ThreeWayMerge::new();

        let result = merger.merge_tags(&tags(&["a", "b"]), &tags(&["a"]), &tags(&["a"]));
        assert_eq!(result, tags(&["a"]));
    }

    /// Each side's change to a different key is applied.
    #[test]
    fn test_merge_metadata_map() {
        let merger = ThreeWayMerge::new();

        let mut base = HashMap::new();
        base.insert("a".to_string(), serde_json::json!(1));
        base.insert("b".to_string(), serde_json::json!(2));

        let mut local = HashMap::new();
        local.insert("a".to_string(), serde_json::json!(10)); // Changed a
        local.insert("b".to_string(), serde_json::json!(2));

        let mut remote = HashMap::new();
        remote.insert("a".to_string(), serde_json::json!(1));
        remote.insert("b".to_string(), serde_json::json!(20)); // Changed b

        let result = merger.merge_metadata_map(Some(&base), &local, &remote);
        assert_eq!(result.get("a"), Some(&serde_json::json!(10))); // Local's change
        assert_eq!(result.get("b"), Some(&serde_json::json!(20))); // Remote's change
    }

    /// Without a base, keys are unioned and local wins when both sides set a key.
    #[test]
    fn test_merge_metadata_map_without_base() {
        let merger = ThreeWayMerge::new();

        let mut local = HashMap::new();
        local.insert("a".to_string(), serde_json::json!(1));

        let mut remote = HashMap::new();
        remote.insert("a".to_string(), serde_json::json!(2));
        remote.insert("b".to_string(), serde_json::json!(3));

        let result = merger.merge_metadata_map(None, &local, &remote);
        assert_eq!(result.len(), 2);
        assert_eq!(result.get("a"), Some(&serde_json::json!(1)));
        assert_eq!(result.get("b"), Some(&serde_json::json!(3)));
    }
}
416}