Skip to main content

dk_engine/conflict/
ast_merge.rs

1//! AST-aware smart merge: operates at the SYMBOL level, not the line level.
2//!
3//! If Agent A modifies `fn_a` and Agent B modifies `fn_b` in the same file,
4//! that is NOT a conflict even if line numbers shifted. Only same-symbol
5//! modifications across sessions are TRUE conflicts.
6
7use std::collections::{BTreeMap, BTreeSet};
8use std::path::Path;
9
10use dk_core::{Error, Result, Symbol};
11
12use crate::parser::ParserRegistry;
13
/// Outcome of a symbol-level merge.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MergeStatus {
    /// No symbol was edited incompatibly by both sides; the merged
    /// content can be used as-is.
    Clean,
    /// One or more symbols were changed by both sides and need
    /// manual resolution.
    Conflict,
}
22
23/// The result of a three-way AST-level merge.
24#[derive(Debug)]
25pub struct MergeResult {
26    pub status: MergeStatus,
27    pub merged_content: String,
28    pub conflicts: Vec<SymbolConflict>,
29}
30
/// Describes one symbol that both sides of the merge changed.
#[derive(Debug)]
pub struct SymbolConflict {
    /// Qualified name identifying the conflicting symbol.
    pub qualified_name: String,
    /// Textual form of the symbol's kind.
    pub kind: String,
    /// Source text of the symbol on side A.
    pub version_a: String,
    /// Source text of the symbol on side B.
    pub version_b: String,
    /// Source text of the symbol in the common ancestor.
    pub base: String,
}
40
/// Source text belonging to a single named symbol.
#[derive(Clone, Debug)]
struct SymbolSpan {
    /// Qualified name used as the merge key for this symbol.
    qualified_name: String,
    /// Textual form of the symbol's kind.
    kind: String,
    /// Complete source text of the symbol, with any doc comment that
    /// belongs to it.
    text: String,
    /// Position of the symbol in its originating symbol list; kept so
    /// file order can be reconstructed.
    _order: usize,
}
51
/// One import statement lifted out of a source file.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
struct ImportLine {
    /// The statement exactly as written in the source line
    /// (for example `use std::io;`).
    text: String,
}
58
59/// Extract the import block (all contiguous use/import statements at the top)
60/// and the list of top-level symbol spans from parsed source.
61fn extract_spans(
62    source: &str,
63    symbols: &[Symbol],
64) -> (Vec<ImportLine>, Vec<SymbolSpan>) {
65    let bytes = source.as_bytes();
66    let mut import_lines = Vec::new();
67    let mut symbol_spans = Vec::new();
68
69    // Collect symbol byte ranges so we can identify import lines
70    // (lines that fall outside any symbol span).
71    let mut symbol_ranges: Vec<(usize, usize)> = symbols
72        .iter()
73        .map(|s| (s.span.start_byte as usize, s.span.end_byte as usize))
74        .collect();
75    symbol_ranges.sort_by_key(|r| r.0);
76
77    // Extract import lines: lines in the source that are NOT inside any symbol span
78    // and look like import/use statements.
79    for line in source.lines() {
80        let trimmed = line.trim();
81        if trimmed.starts_with("use ")
82            || trimmed.starts_with("import ")
83            || trimmed.starts_with("from ")
84        {
85            // Check this line isn't inside a symbol span
86            let line_start = line.as_ptr() as usize - bytes.as_ptr() as usize;
87            let inside_symbol = symbol_ranges
88                .iter()
89                .any(|(start, end)| line_start >= *start && line_start < *end);
90            if !inside_symbol {
91                import_lines.push(ImportLine {
92                    text: line.to_string(),
93                });
94            }
95        }
96    }
97
98    // Extract symbol spans. Prepend doc comments to the symbol text so
99    // that comment changes are tracked per-symbol (e.g. adding a doc
100    // comment to a route handler is attributed to that handler's symbol,
101    // not silently dropped during AST merge).
102    for (order, sym) in symbols.iter().enumerate() {
103        let start = sym.span.start_byte as usize;
104        let end = sym.span.end_byte as usize;
105        if end <= bytes.len() {
106            let body = String::from_utf8_lossy(&bytes[start..end]).to_string();
107            // Only prepend when the doc text is outside the symbol byte span.
108            // TypeScript stores the full "// …" text as a sibling; Rust strips
109            // the "///" prefix; Python embeds the docstring inside the body.
110            let text = match &sym.doc_comment {
111                Some(doc) if !doc.is_empty() && !body.contains(doc.as_str()) => {
112                    format!("{doc}\n{body}")
113                }
114                _ => body,
115            };
116            symbol_spans.push(SymbolSpan {
117                qualified_name: sym.qualified_name.clone(),
118                kind: sym.kind.to_string(),
119                text,
120                _order: order,
121            });
122        }
123    }
124
125    (import_lines, symbol_spans)
126}
127
128/// Perform a three-way AST-level merge.
129///
130/// - `file_path`: used to select the tree-sitter parser by extension.
131/// - `base`: the common ancestor content.
132/// - `version_a`: one agent's modified content.
133/// - `version_b`: another agent's modified content.
134///
135/// Returns an error if the file extension is not supported by any parser.
136pub fn ast_merge(
137    registry: &ParserRegistry,
138    file_path: &str,
139    base: &str,
140    version_a: &str,
141    version_b: &str,
142) -> Result<MergeResult> {
143    let path = Path::new(file_path);
144
145    if !registry.supports_file(path) {
146        return Err(Error::UnsupportedLanguage(format!(
147            "AST merge not supported for file: {file_path}"
148        )));
149    }
150
151    // Parse all three versions
152    let base_analysis = registry.parse_file(path, base.as_bytes())?;
153    let a_analysis = registry.parse_file(path, version_a.as_bytes())?;
154    let b_analysis = registry.parse_file(path, version_b.as_bytes())?;
155
156    // Extract spans
157    let (base_imports, base_spans) = extract_spans(base, &base_analysis.symbols);
158    let (a_imports, a_spans) = extract_spans(version_a, &a_analysis.symbols);
159    let (b_imports, b_spans) = extract_spans(version_b, &b_analysis.symbols);
160
161    // Build lookup maps: qualified_name -> SymbolSpan
162    let base_map: BTreeMap<&str, &SymbolSpan> =
163        base_spans.iter().map(|s| (s.qualified_name.as_str(), s)).collect();
164    let a_map: BTreeMap<&str, &SymbolSpan> =
165        a_spans.iter().map(|s| (s.qualified_name.as_str(), s)).collect();
166    let b_map: BTreeMap<&str, &SymbolSpan> =
167        b_spans.iter().map(|s| (s.qualified_name.as_str(), s)).collect();
168
169    // Collect all unique symbol names across the three versions
170    let all_names: BTreeSet<&str> = base_map
171        .keys()
172        .chain(a_map.keys())
173        .chain(b_map.keys())
174        .copied()
175        .collect();
176
177    let mut merged_symbols: Vec<SymbolSpan> = Vec::new();
178    let mut conflicts: Vec<SymbolConflict> = Vec::new();
179    let mut order_counter: usize = 0;
180
181    for name in &all_names {
182        let in_base = base_map.get(name);
183        let in_a = a_map.get(name);
184        let in_b = b_map.get(name);
185
186        let a_modified = match (in_base, in_a) {
187            (Some(base_s), Some(a_s)) => base_s.text != a_s.text,
188            (None, Some(_)) => true,  // new in A
189            (Some(_), None) => true,  // deleted by A
190            (None, None) => false,
191        };
192
193        let b_modified = match (in_base, in_b) {
194            (Some(base_s), Some(b_s)) => base_s.text != b_s.text,
195            (None, Some(_)) => true,  // new in B
196            (Some(_), None) => true,  // deleted by B
197            (None, None) => false,
198        };
199
200        match (a_modified, b_modified) {
201            (false, false) => {
202                // Neither modified — take base version
203                if let Some(base_s) = in_base {
204                    merged_symbols.push(SymbolSpan {
205                        qualified_name: base_s.qualified_name.clone(),
206                        kind: base_s.kind.clone(),
207                        text: base_s.text.clone(),
208                        _order: order_counter,
209                    });
210                    order_counter += 1;
211                }
212            }
213            (true, false) => {
214                // Only A modified
215                if let Some(a_s) = in_a {
216                    // A modified or added
217                    merged_symbols.push(SymbolSpan {
218                        qualified_name: a_s.qualified_name.clone(),
219                        kind: a_s.kind.clone(),
220                        text: a_s.text.clone(),
221                        _order: order_counter,
222                    });
223                    order_counter += 1;
224                }
225                // else: A deleted — don't include
226            }
227            (false, true) => {
228                // Only B modified
229                if let Some(b_s) = in_b {
230                    // B modified or added
231                    merged_symbols.push(SymbolSpan {
232                        qualified_name: b_s.qualified_name.clone(),
233                        kind: b_s.kind.clone(),
234                        text: b_s.text.clone(),
235                        _order: order_counter,
236                    });
237                    order_counter += 1;
238                }
239                // else: B deleted — don't include
240            }
241            (true, true) => {
242                // Both modified — check specifics
243                match (in_base, in_a, in_b) {
244                    (None, Some(a_s), Some(b_s)) => {
245                        // Both added a symbol with the same name → CONFLICT
246                        conflicts.push(SymbolConflict {
247                            qualified_name: name.to_string(),
248                            kind: a_s.kind.clone(),
249                            version_a: a_s.text.clone(),
250                            version_b: b_s.text.clone(),
251                            base: String::new(),
252                        });
253                        // Include A's version as placeholder in merged output
254                        merged_symbols.push(SymbolSpan {
255                            qualified_name: a_s.qualified_name.clone(),
256                            kind: a_s.kind.clone(),
257                            text: a_s.text.clone(),
258                            _order: order_counter,
259                        });
260                        order_counter += 1;
261                    }
262                    (Some(base_s), Some(a_s), Some(b_s)) => {
263                        if a_s.text == b_s.text {
264                            // Both made the same change — no conflict
265                            merged_symbols.push(SymbolSpan {
266                                qualified_name: a_s.qualified_name.clone(),
267                                kind: a_s.kind.clone(),
268                                text: a_s.text.clone(),
269                                _order: order_counter,
270                            });
271                            order_counter += 1;
272                        } else {
273                            // Both modified differently → TRUE CONFLICT
274                            conflicts.push(SymbolConflict {
275                                qualified_name: name.to_string(),
276                                kind: base_s.kind.clone(),
277                                version_a: a_s.text.clone(),
278                                version_b: b_s.text.clone(),
279                                base: base_s.text.clone(),
280                            });
281                            // Include A's version as placeholder
282                            merged_symbols.push(SymbolSpan {
283                                qualified_name: a_s.qualified_name.clone(),
284                                kind: a_s.kind.clone(),
285                                text: a_s.text.clone(),
286                                _order: order_counter,
287                            });
288                            order_counter += 1;
289                        }
290                    }
291                    (Some(base_s), None, Some(b_s)) => {
292                        // A deleted, B modified → CONFLICT
293                        conflicts.push(SymbolConflict {
294                            qualified_name: name.to_string(),
295                            kind: base_s.kind.clone(),
296                            version_a: String::new(),
297                            version_b: b_s.text.clone(),
298                            base: base_s.text.clone(),
299                        });
300                        // Include B's version as placeholder
301                        merged_symbols.push(SymbolSpan {
302                            qualified_name: b_s.qualified_name.clone(),
303                            kind: b_s.kind.clone(),
304                            text: b_s.text.clone(),
305                            _order: order_counter,
306                        });
307                        order_counter += 1;
308                    }
309                    (Some(base_s), Some(a_s), None) => {
310                        // B deleted, A modified → CONFLICT
311                        conflicts.push(SymbolConflict {
312                            qualified_name: name.to_string(),
313                            kind: base_s.kind.clone(),
314                            version_a: a_s.text.clone(),
315                            version_b: String::new(),
316                            base: base_s.text.clone(),
317                        });
318                        // Include A's version as placeholder
319                        merged_symbols.push(SymbolSpan {
320                            qualified_name: a_s.qualified_name.clone(),
321                            kind: a_s.kind.clone(),
322                            text: a_s.text.clone(),
323                            _order: order_counter,
324                        });
325                        order_counter += 1;
326                    }
327                    (Some(_), None, None) => {
328                        // Both deleted — agree, don't include
329                    }
330                    _ => {}
331                }
332            }
333        }
334    }
335
336    // Merge imports additively (union, deduplicated)
337    let mut merged_import_set: BTreeSet<String> = BTreeSet::new();
338    for imp in base_imports.iter().chain(a_imports.iter()).chain(b_imports.iter()) {
339        merged_import_set.insert(imp.text.clone());
340    }
341
342    // Reconstruct the file
343    let mut output = String::new();
344
345    // Imports first
346    if !merged_import_set.is_empty() {
347        for imp in &merged_import_set {
348            output.push_str(imp);
349            output.push('\n');
350        }
351        output.push('\n');
352    }
353
354    // Then symbols, joined with double newlines
355    let symbol_texts: Vec<&str> = merged_symbols.iter().map(|s| s.text.as_str()).collect();
356    output.push_str(&symbol_texts.join("\n\n"));
357
358    // Ensure trailing newline
359    if !output.ends_with('\n') {
360        output.push('\n');
361    }
362
363    let status = if conflicts.is_empty() {
364        MergeStatus::Clean
365    } else {
366        MergeStatus::Conflict
367    };
368
369    Ok(MergeResult {
370        status,
371        merged_content: output,
372        conflicts,
373    })
374}
375
#[cfg(test)]
mod tests {
    use super::*;

    /// Each `MergeStatus` variant equals itself and differs from the other.
    #[test]
    fn test_merge_status_eq() {
        let clean = MergeStatus::Clean;
        let conflict = MergeStatus::Conflict;

        assert_eq!(clean, MergeStatus::Clean);
        assert_eq!(conflict, MergeStatus::Conflict);
        assert_ne!(clean, conflict);
    }
}
386}