Skip to main content

dk_engine/conflict/
ast_merge.rs

1//! AST-aware smart merge: operates at the SYMBOL level, not the line level.
2//!
3//! If Agent A modifies `fn_a` and Agent B modifies `fn_b` in the same file,
4//! that is NOT a conflict even if line numbers shifted. Only same-symbol
5//! modifications across sessions are TRUE conflicts.
6
7use std::collections::{BTreeMap, HashSet};
8use std::path::Path;
9
10use dk_core::{Error, Result, Symbol};
11
12use crate::parser::ParserRegistry;
13
/// The overall status of a merge.
///
/// This is a fieldless enum, so it is `Copy` and `Hash`. Callers can read
/// `result.status` by value or use it as a map/set key without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MergeStatus {
    /// All symbols merged cleanly — no overlapping edits.
    Clean,
    /// At least one symbol was changed by both sides in a way that could not
    /// be reconciled automatically.
    Conflict,
}
22
23/// The result of a three-way AST-level merge.
24#[derive(Debug)]
25pub struct MergeResult {
26    pub status: MergeStatus,
27    pub merged_content: String,
28    pub conflicts: Vec<SymbolConflict>,
29}
30
/// A single symbol-level conflict: both sides changed the same symbol in ways
/// that could not be reconciled automatically.
///
/// An empty string stands for "absent":
/// - `base` is empty when both sides added the symbol independently.
/// - `version_a` / `version_b` is empty when that side deleted the symbol.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SymbolConflict {
    /// Qualified name of the conflicting symbol.
    pub qualified_name: String,
    /// The symbol's kind, rendered as text.
    pub kind: String,
    /// Source text of the symbol in version A (empty if A deleted it).
    pub version_a: String,
    /// Source text of the symbol in version B (empty if B deleted it).
    pub version_b: String,
    /// Source text of the symbol in the common ancestor (empty if it did not exist there).
    pub base: String,
}
40
/// The source text of one top-level symbol, identified by its qualified name.
#[derive(Debug, Clone)]
struct SymbolSpan {
    /// Fully qualified symbol name, used to match the symbol across versions.
    qualified_name: String,
    /// The symbol's kind, rendered as text.
    kind: String,
    /// The symbol's full source text, which may include a prepended doc comment.
    text: String,
    /// Position in the list the span was built from; not read anywhere yet
    /// (reserved for future use).
    _order: usize,
}
51
/// One import statement lifted verbatim out of a source file.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct ImportLine {
    /// Raw statement text exactly as it appeared in the source (e.g. `use std::io;`).
    text: String,
}
58
59/// Extract the import block (all contiguous use/import statements at the top)
60/// and the list of top-level symbol spans from parsed source.
61fn extract_spans(
62    source: &str,
63    symbols: &[Symbol],
64) -> (Vec<ImportLine>, Vec<SymbolSpan>) {
65    let bytes = source.as_bytes();
66    let mut import_lines = Vec::new();
67    let mut symbol_spans = Vec::new();
68
69    // Collect symbol byte ranges so we can identify import lines
70    // (lines that fall outside any symbol span).
71    let mut symbol_ranges: Vec<(usize, usize)> = symbols
72        .iter()
73        .map(|s| (s.span.start_byte as usize, s.span.end_byte as usize))
74        .collect();
75    symbol_ranges.sort_by_key(|r| r.0);
76
77    // Extract import lines: lines in the source that are NOT inside any symbol span
78    // and look like import/use statements.
79    for line in source.lines() {
80        let trimmed = line.trim();
81        if trimmed.starts_with("use ")
82            || trimmed.starts_with("import ")
83            || trimmed.starts_with("from ")
84        {
85            // Check this line isn't inside a symbol span
86            let line_start = line.as_ptr() as usize - bytes.as_ptr() as usize;
87            let inside_symbol = symbol_ranges
88                .iter()
89                .any(|(start, end)| line_start >= *start && line_start < *end);
90            if !inside_symbol {
91                import_lines.push(ImportLine {
92                    text: line.to_string(),
93                });
94            }
95        }
96    }
97
98    // Extract symbol spans. Prepend doc comments to the symbol text so
99    // that comment changes are tracked per-symbol (e.g. adding a doc
100    // comment to a route handler is attributed to that handler's symbol,
101    // not silently dropped during AST merge).
102    for (order, sym) in symbols.iter().enumerate() {
103        let start = sym.span.start_byte as usize;
104        let end = sym.span.end_byte as usize;
105        if end <= bytes.len() {
106            let body = String::from_utf8_lossy(&bytes[start..end]).to_string();
107            // Only prepend when the doc text is outside the symbol byte span.
108            // TypeScript stores the full "// …" text as a sibling; Rust strips
109            // the "///" prefix; Python embeds the docstring inside the body.
110            let text = match &sym.doc_comment {
111                Some(doc) if !doc.is_empty() && !body.contains(doc.as_str()) => {
112                    format!("{doc}\n{body}")
113                }
114                _ => body,
115            };
116            symbol_spans.push(SymbolSpan {
117                qualified_name: sym.qualified_name.clone(),
118                kind: sym.kind.to_string(),
119                text,
120                _order: order,
121            });
122        }
123    }
124
125    (import_lines, symbol_spans)
126}
127
128/// Perform a three-way AST-level merge.
129///
130/// - `file_path`: used to select the tree-sitter parser by extension.
131/// - `base`: the common ancestor content.
132/// - `version_a`: one agent's modified content.
133/// - `version_b`: another agent's modified content.
134///
135/// Returns an error if the file extension is not supported by any parser.
136pub fn ast_merge(
137    registry: &ParserRegistry,
138    file_path: &str,
139    base: &str,
140    version_a: &str,
141    version_b: &str,
142) -> Result<MergeResult> {
143    let path = Path::new(file_path);
144
145    if !registry.supports_file(path) {
146        return Err(Error::UnsupportedLanguage(format!(
147            "AST merge not supported for file: {file_path}"
148        )));
149    }
150
151    // Parse all three versions
152    let base_analysis = registry.parse_file(path, base.as_bytes())?;
153    let a_analysis = registry.parse_file(path, version_a.as_bytes())?;
154    let b_analysis = registry.parse_file(path, version_b.as_bytes())?;
155
156    // Extract spans
157    let (base_imports, base_spans) = extract_spans(base, &base_analysis.symbols);
158    let (a_imports, a_spans) = extract_spans(version_a, &a_analysis.symbols);
159    let (b_imports, b_spans) = extract_spans(version_b, &b_analysis.symbols);
160
161    // Build lookup maps: qualified_name -> SymbolSpan
162    let base_map: BTreeMap<&str, &SymbolSpan> =
163        base_spans.iter().map(|s| (s.qualified_name.as_str(), s)).collect();
164    let a_map: BTreeMap<&str, &SymbolSpan> =
165        a_spans.iter().map(|s| (s.qualified_name.as_str(), s)).collect();
166    let b_map: BTreeMap<&str, &SymbolSpan> =
167        b_spans.iter().map(|s| (s.qualified_name.as_str(), s)).collect();
168
169    // Build ordered name list: base order first, then new symbols from A and B.
170    // This preserves the original file layout instead of alphabetizing.
171    let mut all_names: Vec<&str> = Vec::new();
172    let mut seen: HashSet<&str> = HashSet::new();
173
174    // Base symbols in their original order
175    for span in &base_spans {
176        let name = span.qualified_name.as_str();
177        if seen.insert(name) {
178            all_names.push(name);
179        }
180    }
181    // New symbols from A (in their file order)
182    for span in &a_spans {
183        let name = span.qualified_name.as_str();
184        if seen.insert(name) {
185            all_names.push(name);
186        }
187    }
188    // New symbols from B (in their file order)
189    for span in &b_spans {
190        let name = span.qualified_name.as_str();
191        if seen.insert(name) {
192            all_names.push(name);
193        }
194    }
195
196    let mut merged_symbols: Vec<SymbolSpan> = Vec::new();
197    let mut conflicts: Vec<SymbolConflict> = Vec::new();
198    let mut order_counter: usize = 0;
199
200    for name in &all_names {
201        let in_base = base_map.get(name);
202        let in_a = a_map.get(name);
203        let in_b = b_map.get(name);
204
205        let a_modified = match (in_base, in_a) {
206            (Some(base_s), Some(a_s)) => base_s.text != a_s.text,
207            (None, Some(_)) => true,  // new in A
208            (Some(_), None) => true,  // deleted by A
209            (None, None) => false,
210        };
211
212        let b_modified = match (in_base, in_b) {
213            (Some(base_s), Some(b_s)) => base_s.text != b_s.text,
214            (None, Some(_)) => true,  // new in B
215            (Some(_), None) => true,  // deleted by B
216            (None, None) => false,
217        };
218
219        match (a_modified, b_modified) {
220            (false, false) => {
221                // Neither modified — take base version
222                if let Some(base_s) = in_base {
223                    merged_symbols.push(SymbolSpan {
224                        qualified_name: base_s.qualified_name.clone(),
225                        kind: base_s.kind.clone(),
226                        text: base_s.text.clone(),
227                        _order: order_counter,
228                    });
229                    order_counter += 1;
230                }
231            }
232            (true, false) => {
233                // Only A modified
234                if let Some(a_s) = in_a {
235                    // A modified or added
236                    merged_symbols.push(SymbolSpan {
237                        qualified_name: a_s.qualified_name.clone(),
238                        kind: a_s.kind.clone(),
239                        text: a_s.text.clone(),
240                        _order: order_counter,
241                    });
242                    order_counter += 1;
243                }
244                // else: A deleted — don't include
245            }
246            (false, true) => {
247                // Only B modified
248                if let Some(b_s) = in_b {
249                    // B modified or added
250                    merged_symbols.push(SymbolSpan {
251                        qualified_name: b_s.qualified_name.clone(),
252                        kind: b_s.kind.clone(),
253                        text: b_s.text.clone(),
254                        _order: order_counter,
255                    });
256                    order_counter += 1;
257                }
258                // else: B deleted — don't include
259            }
260            (true, true) => {
261                // Both modified — check specifics
262                match (in_base, in_a, in_b) {
263                    (None, Some(a_s), Some(b_s)) => {
264                        // Both added a symbol with the same name → CONFLICT
265                        conflicts.push(SymbolConflict {
266                            qualified_name: name.to_string(),
267                            kind: a_s.kind.clone(),
268                            version_a: a_s.text.clone(),
269                            version_b: b_s.text.clone(),
270                            base: String::new(),
271                        });
272                        // Include A's version as placeholder in merged output
273                        merged_symbols.push(SymbolSpan {
274                            qualified_name: a_s.qualified_name.clone(),
275                            kind: a_s.kind.clone(),
276                            text: a_s.text.clone(),
277                            _order: order_counter,
278                        });
279                        order_counter += 1;
280                    }
281                    (Some(base_s), Some(a_s), Some(b_s)) => {
282                        if a_s.text == b_s.text {
283                            // Both made the same change — no conflict
284                            merged_symbols.push(SymbolSpan {
285                                qualified_name: a_s.qualified_name.clone(),
286                                kind: a_s.kind.clone(),
287                                text: a_s.text.clone(),
288                                _order: order_counter,
289                            });
290                            order_counter += 1;
291                        } else {
292                            // Both modified differently → TRUE CONFLICT
293                            conflicts.push(SymbolConflict {
294                                qualified_name: name.to_string(),
295                                kind: base_s.kind.clone(),
296                                version_a: a_s.text.clone(),
297                                version_b: b_s.text.clone(),
298                                base: base_s.text.clone(),
299                            });
300                            // Include A's version as placeholder
301                            merged_symbols.push(SymbolSpan {
302                                qualified_name: a_s.qualified_name.clone(),
303                                kind: a_s.kind.clone(),
304                                text: a_s.text.clone(),
305                                _order: order_counter,
306                            });
307                            order_counter += 1;
308                        }
309                    }
310                    (Some(base_s), None, Some(b_s)) => {
311                        // A deleted, B modified → CONFLICT
312                        conflicts.push(SymbolConflict {
313                            qualified_name: name.to_string(),
314                            kind: base_s.kind.clone(),
315                            version_a: String::new(),
316                            version_b: b_s.text.clone(),
317                            base: base_s.text.clone(),
318                        });
319                        // Include B's version as placeholder
320                        merged_symbols.push(SymbolSpan {
321                            qualified_name: b_s.qualified_name.clone(),
322                            kind: b_s.kind.clone(),
323                            text: b_s.text.clone(),
324                            _order: order_counter,
325                        });
326                        order_counter += 1;
327                    }
328                    (Some(base_s), Some(a_s), None) => {
329                        // B deleted, A modified → CONFLICT
330                        conflicts.push(SymbolConflict {
331                            qualified_name: name.to_string(),
332                            kind: base_s.kind.clone(),
333                            version_a: a_s.text.clone(),
334                            version_b: String::new(),
335                            base: base_s.text.clone(),
336                        });
337                        // Include A's version as placeholder
338                        merged_symbols.push(SymbolSpan {
339                            qualified_name: a_s.qualified_name.clone(),
340                            kind: a_s.kind.clone(),
341                            text: a_s.text.clone(),
342                            _order: order_counter,
343                        });
344                        order_counter += 1;
345                    }
346                    (Some(_), None, None) => {
347                        // Both deleted — agree, don't include
348                    }
349                    _ => {}
350                }
351            }
352        }
353    }
354
355    // Merge imports additively (union, deduplicated, preserving base order)
356    let mut merged_imports: Vec<String> = Vec::new();
357    let mut import_seen: HashSet<String> = HashSet::new();
358    // Base imports first (original order), then new imports from A and B
359    for imp in base_imports.iter().chain(a_imports.iter()).chain(b_imports.iter()) {
360        if import_seen.insert(imp.text.clone()) {
361            merged_imports.push(imp.text.clone());
362        }
363    }
364
365    // Reconstruct the file
366    let mut output = String::new();
367
368    // Imports first (in preserved order)
369    if !merged_imports.is_empty() {
370        for imp in &merged_imports {
371            output.push_str(imp);
372            output.push('\n');
373        }
374        output.push('\n');
375    }
376
377    // Then symbols, joined with double newlines
378    let symbol_texts: Vec<&str> = merged_symbols.iter().map(|s| s.text.as_str()).collect();
379    output.push_str(&symbol_texts.join("\n\n"));
380
381    // Ensure trailing newline
382    if !output.ends_with('\n') {
383        output.push('\n');
384    }
385
386    let status = if conflicts.is_empty() {
387        MergeStatus::Clean
388    } else {
389        MergeStatus::Conflict
390    };
391
392    Ok(MergeResult {
393        status,
394        merged_content: output,
395        conflicts,
396    })
397}
398
#[cfg(test)]
mod tests {
    use super::*;

    /// Each `MergeStatus` variant equals itself and differs from the other.
    #[test]
    fn test_merge_status_eq() {
        let clean = MergeStatus::Clean;
        let conflict = MergeStatus::Conflict;

        assert_eq!(clean, MergeStatus::Clean);
        assert_eq!(conflict, MergeStatus::Conflict);
        assert_ne!(clean, conflict);
    }
}
409}