Skip to main content

dk_engine/conflict/
ast_merge.rs

//! AST-aware smart merge: operates at the SYMBOL level, not the line level.
//!
//! If Agent A modifies `fn_a` and Agent B modifies `fn_b` in the same file,
//! that is NOT a conflict, even if line numbers shifted. Only modifications
//! to the same symbol by both sessions are TRUE conflicts.
6
7use std::collections::{BTreeMap, HashSet};
8use std::path::Path;
9
10use dk_core::{Error, Result, Symbol};
11
12use crate::parser::ParserRegistry;
13
/// Overall verdict of a three-way merge.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MergeStatus {
    /// No symbol-level conflicts were found; the merged content is final.
    Clean,
    /// At least one symbol has conflicting changes from the two sides and
    /// needs resolution.
    Conflict,
}
22
23/// The result of a three-way AST-level merge.
24#[derive(Debug)]
25pub struct MergeResult {
26    pub status: MergeStatus,
27    pub merged_content: String,
28    pub conflicts: Vec<SymbolConflict>,
29}
30
/// A symbol-level conflict: both sides changed the same symbol in different ways.
#[derive(Debug)]
pub struct SymbolConflict {
    /// Qualified name of the conflicting symbol.
    pub qualified_name: String,
    /// String form of the symbol's kind (e.g. `"variable"`).
    pub kind: String,
    /// The symbol's text in version A; empty when A deleted it.
    pub version_a: String,
    /// The symbol's text in version B; empty when B deleted it.
    pub version_b: String,
    /// The symbol's text in the common ancestor; empty when it did not exist
    /// there (both sides added it).
    pub base: String,
}
40
/// One top-level symbol cut out of a source file, identified by its qualified name.
#[derive(Clone, Debug)]
struct SymbolSpan {
    /// Fully qualified name used to match the symbol across versions.
    qualified_name: String,
    /// String form of the symbol kind (e.g. `"variable"`); drives the
    /// blank-line spacing between symbols in the merged output.
    kind: String,
    /// Source text of the symbol, including any doc comment captured with it.
    text: String,
    /// Position in the list the span came from (not read yet; reserved).
    _order: usize,
}
51
/// The verbatim text of one import statement (e.g. `use std::io;`) found
/// outside every symbol span.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
struct ImportLine {
    /// Raw source text exactly as it appeared, leading indentation included.
    text: String,
}
58
59/// Extract the import block (all contiguous use/import statements at the top)
60/// and the list of top-level symbol spans from parsed source.
61fn extract_spans(
62    source: &str,
63    symbols: &[Symbol],
64) -> (Vec<ImportLine>, Vec<SymbolSpan>) {
65    let bytes = source.as_bytes();
66    let mut import_lines = Vec::new();
67    let mut symbol_spans = Vec::new();
68
69    // Collect symbol byte ranges so we can identify import lines
70    // (lines that fall outside any symbol span).
71    let mut symbol_ranges: Vec<(usize, usize)> = symbols
72        .iter()
73        .map(|s| (s.span.start_byte as usize, s.span.end_byte as usize))
74        .collect();
75    symbol_ranges.sort_by_key(|r| r.0);
76
77    // Extract import lines: lines in the source that are NOT inside any symbol span
78    // and look like import/use statements.
79    for line in source.lines() {
80        let trimmed = line.trim();
81        if trimmed.starts_with("use ")
82            || trimmed.starts_with("import ")
83            || trimmed.starts_with("from ")
84        {
85            // Check this line isn't inside a symbol span
86            let line_start = line.as_ptr() as usize - bytes.as_ptr() as usize;
87            let inside_symbol = symbol_ranges
88                .iter()
89                .any(|(start, end)| line_start >= *start && line_start < *end);
90            if !inside_symbol {
91                import_lines.push(ImportLine {
92                    text: line.to_string(),
93                });
94            }
95        }
96    }
97
98    // Extract symbol spans. Prepend doc comments to the symbol text so
99    // that comment changes are tracked per-symbol (e.g. adding a doc
100    // comment to a route handler is attributed to that handler's symbol,
101    // not silently dropped during AST merge).
102    for (order, sym) in symbols.iter().enumerate() {
103        let start = sym.span.start_byte as usize;
104        let end = sym.span.end_byte as usize;
105        if end <= bytes.len() {
106            // Extend the body to include any trailing inline comment on the
107            // same line (tree-sitter Python places `# comment` as a sibling
108            // node after the expression_statement, so end_byte stops before it).
109            let extended_end = if end < bytes.len() {
110                let rest = &source[end..];
111                if let Some(nl) = rest.find('\n') {
112                    let trailing = rest[..nl].trim();
113                    if trailing.starts_with('#') || trailing.is_empty() {
114                        // Include inline comment or whitespace up to newline
115                        end + nl
116                    } else {
117                        end
118                    }
119                } else {
120                    source.len() // last line, no trailing newline
121                }
122            } else {
123                end
124            };
125            let body = String::from_utf8_lossy(&bytes[start..extended_end]).to_string();
126            // Only prepend when the doc text is outside the symbol byte span.
127            // TypeScript stores the full "// …" text as a sibling; Rust strips
128            // the "///" prefix; Python embeds the docstring inside the body.
129            let text = match &sym.doc_comment {
130                Some(doc) if !doc.is_empty() && !body.contains(doc.as_str()) => {
131                    format!("{doc}\n{body}")
132                }
133                _ => body,
134            };
135            symbol_spans.push(SymbolSpan {
136                qualified_name: sym.qualified_name.clone(),
137                kind: sym.kind.to_string(),
138                text,
139                _order: order,
140            });
141        }
142    }
143
144    (import_lines, symbol_spans)
145}
146
147/// Perform a three-way AST-level merge.
148///
149/// - `file_path`: used to select the tree-sitter parser by extension.
150/// - `base`: the common ancestor content.
151/// - `version_a`: one agent's modified content.
152/// - `version_b`: another agent's modified content.
153///
154/// Returns an error if the file extension is not supported by any parser.
155pub fn ast_merge(
156    registry: &ParserRegistry,
157    file_path: &str,
158    base: &str,
159    version_a: &str,
160    version_b: &str,
161) -> Result<MergeResult> {
162    let path = Path::new(file_path);
163
164    if !registry.supports_file(path) {
165        return Err(Error::UnsupportedLanguage(format!(
166            "AST merge not supported for file: {file_path}"
167        )));
168    }
169
170    // Parse all three versions
171    let base_analysis = registry.parse_file(path, base.as_bytes())?;
172    let a_analysis = registry.parse_file(path, version_a.as_bytes())?;
173    let b_analysis = registry.parse_file(path, version_b.as_bytes())?;
174
175    // Extract spans
176    let (base_imports, base_spans) = extract_spans(base, &base_analysis.symbols);
177    let (a_imports, a_spans) = extract_spans(version_a, &a_analysis.symbols);
178    let (b_imports, b_spans) = extract_spans(version_b, &b_analysis.symbols);
179
180    // Build lookup maps: qualified_name -> SymbolSpan
181    let base_map: BTreeMap<&str, &SymbolSpan> =
182        base_spans.iter().map(|s| (s.qualified_name.as_str(), s)).collect();
183    let a_map: BTreeMap<&str, &SymbolSpan> =
184        a_spans.iter().map(|s| (s.qualified_name.as_str(), s)).collect();
185    let b_map: BTreeMap<&str, &SymbolSpan> =
186        b_spans.iter().map(|s| (s.qualified_name.as_str(), s)).collect();
187
188    // Build ordered name list: base order first, then new symbols from A and B.
189    // This preserves the original file layout instead of alphabetizing.
190    let mut all_names: Vec<&str> = Vec::new();
191    let mut seen: HashSet<&str> = HashSet::new();
192
193    // Base symbols in their original order
194    for span in &base_spans {
195        let name = span.qualified_name.as_str();
196        if seen.insert(name) {
197            all_names.push(name);
198        }
199    }
200    // New symbols from A (in their file order)
201    for span in &a_spans {
202        let name = span.qualified_name.as_str();
203        if seen.insert(name) {
204            all_names.push(name);
205        }
206    }
207    // New symbols from B (in their file order)
208    for span in &b_spans {
209        let name = span.qualified_name.as_str();
210        if seen.insert(name) {
211            all_names.push(name);
212        }
213    }
214
215    let mut merged_symbols: Vec<SymbolSpan> = Vec::new();
216    let mut conflicts: Vec<SymbolConflict> = Vec::new();
217    let mut order_counter: usize = 0;
218
219    for name in &all_names {
220        let in_base = base_map.get(name);
221        let in_a = a_map.get(name);
222        let in_b = b_map.get(name);
223
224        let a_modified = match (in_base, in_a) {
225            (Some(base_s), Some(a_s)) => base_s.text != a_s.text,
226            (None, Some(_)) => true,  // new in A
227            (Some(_), None) => true,  // deleted by A
228            (None, None) => false,
229        };
230
231        let b_modified = match (in_base, in_b) {
232            (Some(base_s), Some(b_s)) => base_s.text != b_s.text,
233            (None, Some(_)) => true,  // new in B
234            (Some(_), None) => true,  // deleted by B
235            (None, None) => false,
236        };
237
238        match (a_modified, b_modified) {
239            (false, false) => {
240                // Neither modified — take base version
241                if let Some(base_s) = in_base {
242                    merged_symbols.push(SymbolSpan {
243                        qualified_name: base_s.qualified_name.clone(),
244                        kind: base_s.kind.clone(),
245                        text: base_s.text.clone(),
246                        _order: order_counter,
247                    });
248                    order_counter += 1;
249                }
250            }
251            (true, false) => {
252                // Only A modified
253                if let Some(a_s) = in_a {
254                    // A modified or added
255                    merged_symbols.push(SymbolSpan {
256                        qualified_name: a_s.qualified_name.clone(),
257                        kind: a_s.kind.clone(),
258                        text: a_s.text.clone(),
259                        _order: order_counter,
260                    });
261                    order_counter += 1;
262                }
263                // else: A deleted — don't include
264            }
265            (false, true) => {
266                // Only B modified
267                if let Some(b_s) = in_b {
268                    // B modified or added
269                    merged_symbols.push(SymbolSpan {
270                        qualified_name: b_s.qualified_name.clone(),
271                        kind: b_s.kind.clone(),
272                        text: b_s.text.clone(),
273                        _order: order_counter,
274                    });
275                    order_counter += 1;
276                }
277                // else: B deleted — don't include
278            }
279            (true, true) => {
280                // Both modified — check specifics
281                match (in_base, in_a, in_b) {
282                    (None, Some(a_s), Some(b_s)) => {
283                        // Both added a symbol with the same name → CONFLICT
284                        conflicts.push(SymbolConflict {
285                            qualified_name: name.to_string(),
286                            kind: a_s.kind.clone(),
287                            version_a: a_s.text.clone(),
288                            version_b: b_s.text.clone(),
289                            base: String::new(),
290                        });
291                        // Include A's version as placeholder in merged output
292                        merged_symbols.push(SymbolSpan {
293                            qualified_name: a_s.qualified_name.clone(),
294                            kind: a_s.kind.clone(),
295                            text: a_s.text.clone(),
296                            _order: order_counter,
297                        });
298                        order_counter += 1;
299                    }
300                    (Some(base_s), Some(a_s), Some(b_s)) => {
301                        if a_s.text == b_s.text {
302                            // Both made the same change — no conflict
303                            merged_symbols.push(SymbolSpan {
304                                qualified_name: a_s.qualified_name.clone(),
305                                kind: a_s.kind.clone(),
306                                text: a_s.text.clone(),
307                                _order: order_counter,
308                            });
309                            order_counter += 1;
310                        } else {
311                            // Both modified differently → TRUE CONFLICT
312                            conflicts.push(SymbolConflict {
313                                qualified_name: name.to_string(),
314                                kind: base_s.kind.clone(),
315                                version_a: a_s.text.clone(),
316                                version_b: b_s.text.clone(),
317                                base: base_s.text.clone(),
318                            });
319                            // Include A's version as placeholder
320                            merged_symbols.push(SymbolSpan {
321                                qualified_name: a_s.qualified_name.clone(),
322                                kind: a_s.kind.clone(),
323                                text: a_s.text.clone(),
324                                _order: order_counter,
325                            });
326                            order_counter += 1;
327                        }
328                    }
329                    (Some(base_s), None, Some(b_s)) => {
330                        // A deleted, B modified → CONFLICT
331                        conflicts.push(SymbolConflict {
332                            qualified_name: name.to_string(),
333                            kind: base_s.kind.clone(),
334                            version_a: String::new(),
335                            version_b: b_s.text.clone(),
336                            base: base_s.text.clone(),
337                        });
338                        // Include B's version as placeholder
339                        merged_symbols.push(SymbolSpan {
340                            qualified_name: b_s.qualified_name.clone(),
341                            kind: b_s.kind.clone(),
342                            text: b_s.text.clone(),
343                            _order: order_counter,
344                        });
345                        order_counter += 1;
346                    }
347                    (Some(base_s), Some(a_s), None) => {
348                        // B deleted, A modified → CONFLICT
349                        conflicts.push(SymbolConflict {
350                            qualified_name: name.to_string(),
351                            kind: base_s.kind.clone(),
352                            version_a: a_s.text.clone(),
353                            version_b: String::new(),
354                            base: base_s.text.clone(),
355                        });
356                        // Include A's version as placeholder
357                        merged_symbols.push(SymbolSpan {
358                            qualified_name: a_s.qualified_name.clone(),
359                            kind: a_s.kind.clone(),
360                            text: a_s.text.clone(),
361                            _order: order_counter,
362                        });
363                        order_counter += 1;
364                    }
365                    (Some(_), None, None) => {
366                        // Both deleted — agree, don't include
367                    }
368                    _ => {}
369                }
370            }
371        }
372    }
373
374    // Merge imports additively (union, deduplicated, preserving base order)
375    let mut merged_imports: Vec<String> = Vec::new();
376    let mut import_seen: HashSet<String> = HashSet::new();
377    // Base imports first (original order), then new imports from A and B
378    for imp in base_imports.iter().chain(a_imports.iter()).chain(b_imports.iter()) {
379        if import_seen.insert(imp.text.clone()) {
380            merged_imports.push(imp.text.clone());
381        }
382    }
383
384    // Reconstruct the file
385    let mut output = String::new();
386
387    // Imports first (in preserved order)
388    if !merged_imports.is_empty() {
389        for imp in &merged_imports {
390            output.push_str(imp);
391            output.push('\n');
392        }
393        output.push('\n');
394    }
395
396    // Join symbols with context-aware spacing:
397    // - Double newline (\n\n) before/after functions and classes
398    // - Single newline (\n) between consecutive variables
399    for (i, sym) in merged_symbols.iter().enumerate() {
400        if i > 0 {
401            let prev_is_var = merged_symbols[i - 1].kind == "variable";
402            let curr_is_var = sym.kind == "variable";
403            if prev_is_var && curr_is_var {
404                output.push('\n');
405            } else {
406                output.push_str("\n\n");
407            }
408        }
409        output.push_str(&sym.text);
410    }
411
412    // Ensure trailing newline
413    if !output.ends_with('\n') {
414        output.push('\n');
415    }
416
417    let status = if conflicts.is_empty() {
418        MergeStatus::Clean
419    } else {
420        MergeStatus::Conflict
421    };
422
423    Ok(MergeResult {
424        status,
425        merged_content: output,
426        conflicts,
427    })
428}
429
#[cfg(test)]
mod tests {
    use super::*;

    /// Equality on `MergeStatus` holds within a variant and separates the two variants.
    #[test]
    fn test_merge_status_eq() {
        let clean = MergeStatus::Clean;
        let conflict = MergeStatus::Conflict;
        assert_eq!(clean, MergeStatus::Clean);
        assert_eq!(conflict, MergeStatus::Conflict);
        assert_ne!(clean, conflict);
    }
}