Skip to main content

dk_engine/conflict/
ast_merge.rs

//! AST-aware smart merge: operates at the SYMBOL level, not the line level.
//!
//! If Agent A modifies `fn_a` and Agent B modifies `fn_b` in the same file,
//! that is NOT a conflict even if line numbers shifted. Only differing
//! modifications of the same symbol across sessions are TRUE conflicts.
6
7use std::collections::{BTreeMap, BTreeSet};
8use std::path::Path;
9
10use dk_core::{Error, Result, Symbol};
11
12use crate::parser::ParserRegistry;
13
/// Outcome of an AST-level merge.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MergeStatus {
    /// Every symbol merged without overlapping edits.
    Clean,
    /// At least one symbol was changed on both sides and recorded as a conflict.
    Conflict,
}
22
23/// The result of a three-way AST-level merge.
24#[derive(Debug)]
25pub struct MergeResult {
26    pub status: MergeStatus,
27    pub merged_content: String,
28    pub conflicts: Vec<SymbolConflict>,
29}
30
/// One symbol that both sides changed and that could not be merged automatically.
///
/// Each text field holds that version's full source text for the symbol. A field
/// is the empty string when the symbol is absent from that version (deleted on
/// that side, or not present in the common ancestor).
#[derive(Debug)]
pub struct SymbolConflict {
    /// Qualified name used to match the symbol across versions.
    pub qualified_name: String,
    /// The symbol's kind, rendered as text.
    pub kind: String,
    /// The symbol's text in version A.
    pub version_a: String,
    /// The symbol's text in version B.
    pub version_b: String,
    /// The symbol's text in the common ancestor.
    pub base: String,
}
40
/// The source text of one parser-reported symbol, tagged with its qualified name.
#[derive(Debug, Clone)]
struct SymbolSpan {
    /// Qualified name used to match the symbol across versions.
    qualified_name: String,
    /// The symbol's kind, rendered as text.
    kind: String,
    /// Exactly the source covered by the symbol's byte span (so it includes doc
    /// comments only if the parser's span does).
    text: String,
    /// Position of the symbol in its version's symbol list. It is recorded, but
    /// the merge logic never reads it.
    _order: usize,
}
51
/// A single import/use line, kept verbatim as it appeared in the source.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct ImportLine {
    /// The line's raw text (e.g. `use std::io;`), without its line terminator.
    text: String,
}
58
59/// Extract the import block (all contiguous use/import statements at the top)
60/// and the list of top-level symbol spans from parsed source.
61fn extract_spans(
62    source: &str,
63    symbols: &[Symbol],
64) -> (Vec<ImportLine>, Vec<SymbolSpan>) {
65    let bytes = source.as_bytes();
66    let mut import_lines = Vec::new();
67    let mut symbol_spans = Vec::new();
68
69    // Collect symbol byte ranges so we can identify import lines
70    // (lines that fall outside any symbol span).
71    let mut symbol_ranges: Vec<(usize, usize)> = symbols
72        .iter()
73        .map(|s| (s.span.start_byte as usize, s.span.end_byte as usize))
74        .collect();
75    symbol_ranges.sort_by_key(|r| r.0);
76
77    // Extract import lines: lines in the source that are NOT inside any symbol span
78    // and look like import/use statements.
79    for line in source.lines() {
80        let trimmed = line.trim();
81        if trimmed.starts_with("use ")
82            || trimmed.starts_with("import ")
83            || trimmed.starts_with("from ")
84        {
85            // Check this line isn't inside a symbol span
86            let line_start = line.as_ptr() as usize - bytes.as_ptr() as usize;
87            let inside_symbol = symbol_ranges
88                .iter()
89                .any(|(start, end)| line_start >= *start && line_start < *end);
90            if !inside_symbol {
91                import_lines.push(ImportLine {
92                    text: line.to_string(),
93                });
94            }
95        }
96    }
97
98    // Extract symbol spans
99    for (order, sym) in symbols.iter().enumerate() {
100        let start = sym.span.start_byte as usize;
101        let end = sym.span.end_byte as usize;
102        if end <= bytes.len() {
103            let text = String::from_utf8_lossy(&bytes[start..end]).to_string();
104            symbol_spans.push(SymbolSpan {
105                qualified_name: sym.qualified_name.clone(),
106                kind: sym.kind.to_string(),
107                text,
108                _order: order,
109            });
110        }
111    }
112
113    (import_lines, symbol_spans)
114}
115
116/// Perform a three-way AST-level merge.
117///
118/// - `file_path`: used to select the tree-sitter parser by extension.
119/// - `base`: the common ancestor content.
120/// - `version_a`: one agent's modified content.
121/// - `version_b`: another agent's modified content.
122///
123/// Returns an error if the file extension is not supported by any parser.
124pub fn ast_merge(
125    registry: &ParserRegistry,
126    file_path: &str,
127    base: &str,
128    version_a: &str,
129    version_b: &str,
130) -> Result<MergeResult> {
131    let path = Path::new(file_path);
132
133    if !registry.supports_file(path) {
134        return Err(Error::UnsupportedLanguage(format!(
135            "AST merge not supported for file: {file_path}"
136        )));
137    }
138
139    // Parse all three versions
140    let base_analysis = registry.parse_file(path, base.as_bytes())?;
141    let a_analysis = registry.parse_file(path, version_a.as_bytes())?;
142    let b_analysis = registry.parse_file(path, version_b.as_bytes())?;
143
144    // Extract spans
145    let (base_imports, base_spans) = extract_spans(base, &base_analysis.symbols);
146    let (a_imports, a_spans) = extract_spans(version_a, &a_analysis.symbols);
147    let (b_imports, b_spans) = extract_spans(version_b, &b_analysis.symbols);
148
149    // Build lookup maps: qualified_name -> SymbolSpan
150    let base_map: BTreeMap<&str, &SymbolSpan> =
151        base_spans.iter().map(|s| (s.qualified_name.as_str(), s)).collect();
152    let a_map: BTreeMap<&str, &SymbolSpan> =
153        a_spans.iter().map(|s| (s.qualified_name.as_str(), s)).collect();
154    let b_map: BTreeMap<&str, &SymbolSpan> =
155        b_spans.iter().map(|s| (s.qualified_name.as_str(), s)).collect();
156
157    // Collect all unique symbol names across the three versions
158    let all_names: BTreeSet<&str> = base_map
159        .keys()
160        .chain(a_map.keys())
161        .chain(b_map.keys())
162        .copied()
163        .collect();
164
165    let mut merged_symbols: Vec<SymbolSpan> = Vec::new();
166    let mut conflicts: Vec<SymbolConflict> = Vec::new();
167    let mut order_counter: usize = 0;
168
169    for name in &all_names {
170        let in_base = base_map.get(name);
171        let in_a = a_map.get(name);
172        let in_b = b_map.get(name);
173
174        let a_modified = match (in_base, in_a) {
175            (Some(base_s), Some(a_s)) => base_s.text != a_s.text,
176            (None, Some(_)) => true,  // new in A
177            (Some(_), None) => true,  // deleted by A
178            (None, None) => false,
179        };
180
181        let b_modified = match (in_base, in_b) {
182            (Some(base_s), Some(b_s)) => base_s.text != b_s.text,
183            (None, Some(_)) => true,  // new in B
184            (Some(_), None) => true,  // deleted by B
185            (None, None) => false,
186        };
187
188        match (a_modified, b_modified) {
189            (false, false) => {
190                // Neither modified — take base version
191                if let Some(base_s) = in_base {
192                    merged_symbols.push(SymbolSpan {
193                        qualified_name: base_s.qualified_name.clone(),
194                        kind: base_s.kind.clone(),
195                        text: base_s.text.clone(),
196                        _order: order_counter,
197                    });
198                    order_counter += 1;
199                }
200            }
201            (true, false) => {
202                // Only A modified
203                if let Some(a_s) = in_a {
204                    // A modified or added
205                    merged_symbols.push(SymbolSpan {
206                        qualified_name: a_s.qualified_name.clone(),
207                        kind: a_s.kind.clone(),
208                        text: a_s.text.clone(),
209                        _order: order_counter,
210                    });
211                    order_counter += 1;
212                }
213                // else: A deleted — don't include
214            }
215            (false, true) => {
216                // Only B modified
217                if let Some(b_s) = in_b {
218                    // B modified or added
219                    merged_symbols.push(SymbolSpan {
220                        qualified_name: b_s.qualified_name.clone(),
221                        kind: b_s.kind.clone(),
222                        text: b_s.text.clone(),
223                        _order: order_counter,
224                    });
225                    order_counter += 1;
226                }
227                // else: B deleted — don't include
228            }
229            (true, true) => {
230                // Both modified — check specifics
231                match (in_base, in_a, in_b) {
232                    (None, Some(a_s), Some(b_s)) => {
233                        // Both added a symbol with the same name → CONFLICT
234                        conflicts.push(SymbolConflict {
235                            qualified_name: name.to_string(),
236                            kind: a_s.kind.clone(),
237                            version_a: a_s.text.clone(),
238                            version_b: b_s.text.clone(),
239                            base: String::new(),
240                        });
241                        // Include A's version as placeholder in merged output
242                        merged_symbols.push(SymbolSpan {
243                            qualified_name: a_s.qualified_name.clone(),
244                            kind: a_s.kind.clone(),
245                            text: a_s.text.clone(),
246                            _order: order_counter,
247                        });
248                        order_counter += 1;
249                    }
250                    (Some(base_s), Some(a_s), Some(b_s)) => {
251                        if a_s.text == b_s.text {
252                            // Both made the same change — no conflict
253                            merged_symbols.push(SymbolSpan {
254                                qualified_name: a_s.qualified_name.clone(),
255                                kind: a_s.kind.clone(),
256                                text: a_s.text.clone(),
257                                _order: order_counter,
258                            });
259                            order_counter += 1;
260                        } else {
261                            // Both modified differently → TRUE CONFLICT
262                            conflicts.push(SymbolConflict {
263                                qualified_name: name.to_string(),
264                                kind: base_s.kind.clone(),
265                                version_a: a_s.text.clone(),
266                                version_b: b_s.text.clone(),
267                                base: base_s.text.clone(),
268                            });
269                            // Include A's version as placeholder
270                            merged_symbols.push(SymbolSpan {
271                                qualified_name: a_s.qualified_name.clone(),
272                                kind: a_s.kind.clone(),
273                                text: a_s.text.clone(),
274                                _order: order_counter,
275                            });
276                            order_counter += 1;
277                        }
278                    }
279                    (Some(base_s), None, Some(b_s)) => {
280                        // A deleted, B modified → CONFLICT
281                        conflicts.push(SymbolConflict {
282                            qualified_name: name.to_string(),
283                            kind: base_s.kind.clone(),
284                            version_a: String::new(),
285                            version_b: b_s.text.clone(),
286                            base: base_s.text.clone(),
287                        });
288                        // Include B's version as placeholder
289                        merged_symbols.push(SymbolSpan {
290                            qualified_name: b_s.qualified_name.clone(),
291                            kind: b_s.kind.clone(),
292                            text: b_s.text.clone(),
293                            _order: order_counter,
294                        });
295                        order_counter += 1;
296                    }
297                    (Some(base_s), Some(a_s), None) => {
298                        // B deleted, A modified → CONFLICT
299                        conflicts.push(SymbolConflict {
300                            qualified_name: name.to_string(),
301                            kind: base_s.kind.clone(),
302                            version_a: a_s.text.clone(),
303                            version_b: String::new(),
304                            base: base_s.text.clone(),
305                        });
306                        // Include A's version as placeholder
307                        merged_symbols.push(SymbolSpan {
308                            qualified_name: a_s.qualified_name.clone(),
309                            kind: a_s.kind.clone(),
310                            text: a_s.text.clone(),
311                            _order: order_counter,
312                        });
313                        order_counter += 1;
314                    }
315                    (Some(_), None, None) => {
316                        // Both deleted — agree, don't include
317                    }
318                    _ => {}
319                }
320            }
321        }
322    }
323
324    // Merge imports additively (union, deduplicated)
325    let mut merged_import_set: BTreeSet<String> = BTreeSet::new();
326    for imp in base_imports.iter().chain(a_imports.iter()).chain(b_imports.iter()) {
327        merged_import_set.insert(imp.text.clone());
328    }
329
330    // Reconstruct the file
331    let mut output = String::new();
332
333    // Imports first
334    if !merged_import_set.is_empty() {
335        for imp in &merged_import_set {
336            output.push_str(imp);
337            output.push('\n');
338        }
339        output.push('\n');
340    }
341
342    // Then symbols, joined with double newlines
343    let symbol_texts: Vec<&str> = merged_symbols.iter().map(|s| s.text.as_str()).collect();
344    output.push_str(&symbol_texts.join("\n\n"));
345
346    // Ensure trailing newline
347    if !output.ends_with('\n') {
348        output.push('\n');
349    }
350
351    let status = if conflicts.is_empty() {
352        MergeStatus::Clean
353    } else {
354        MergeStatus::Conflict
355    };
356
357    Ok(MergeResult {
358        status,
359        merged_content: output,
360        conflicts,
361    })
362}
363
#[cfg(test)]
mod tests {
    use super::*;

    /// Equality on `MergeStatus` is reflexive and tells the two variants apart.
    #[test]
    fn test_merge_status_eq() {
        let clean = MergeStatus::Clean;
        let conflict = MergeStatus::Conflict;
        assert_eq!(clean, MergeStatus::Clean);
        assert_eq!(conflict, MergeStatus::Conflict);
        assert_ne!(clean, conflict);
    }
}