1use std::collections::{BTreeMap, BTreeSet};
8use std::path::Path;
9
10use dk_core::{Error, Result, Symbol};
11
12use crate::parser::ParserRegistry;
13
/// Overall outcome of an [`ast_merge`] call.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MergeStatus {
    /// No [`SymbolConflict`] was recorded; the merged content needs no review.
    Clean,
    /// At least one [`SymbolConflict`] was recorded; see
    /// [`MergeResult::conflicts`].
    Conflict,
}
22
23#[derive(Debug)]
25pub struct MergeResult {
26 pub status: MergeStatus,
27 pub merged_content: String,
28 pub conflicts: Vec<SymbolConflict>,
29}
30
/// A symbol that [`ast_merge`] could not reconcile automatically.
///
/// Each of the three text fields is an empty string when the symbol is absent
/// from that version (deleted on that side, or not present in the base).
#[derive(Debug)]
pub struct SymbolConflict {
    /// Qualified name used to match the symbol across the three versions.
    pub qualified_name: String,
    /// Symbol kind, as a string.
    pub kind: String,
    /// Source text of the symbol in version A.
    pub version_a: String,
    /// Source text of the symbol in version B.
    pub version_b: String,
    /// Source text of the symbol in the common base.
    pub base: String,
}
40
/// A parser-reported symbol paired with the exact source text it covers.
#[derive(Clone, Debug)]
struct SymbolSpan {
    /// Qualified name; the key symbols are matched on during a merge.
    qualified_name: String,
    /// Symbol kind, stringified.
    kind: String,
    /// Source text covered by the symbol's byte span.
    text: String,
    /// Position index of the symbol; assigned but never read.
    _order: usize,
}
51
/// Source text of one import statement, kept exactly as extracted (no
/// trimming).
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
struct ImportLine {
    /// Raw import text.
    text: String,
}
58
59fn extract_spans(
62 source: &str,
63 symbols: &[Symbol],
64) -> (Vec<ImportLine>, Vec<SymbolSpan>) {
65 let bytes = source.as_bytes();
66 let mut import_lines = Vec::new();
67 let mut symbol_spans = Vec::new();
68
69 let mut symbol_ranges: Vec<(usize, usize)> = symbols
72 .iter()
73 .map(|s| (s.span.start_byte as usize, s.span.end_byte as usize))
74 .collect();
75 symbol_ranges.sort_by_key(|r| r.0);
76
77 for line in source.lines() {
80 let trimmed = line.trim();
81 if trimmed.starts_with("use ")
82 || trimmed.starts_with("import ")
83 || trimmed.starts_with("from ")
84 {
85 let line_start = line.as_ptr() as usize - bytes.as_ptr() as usize;
87 let inside_symbol = symbol_ranges
88 .iter()
89 .any(|(start, end)| line_start >= *start && line_start < *end);
90 if !inside_symbol {
91 import_lines.push(ImportLine {
92 text: line.to_string(),
93 });
94 }
95 }
96 }
97
98 for (order, sym) in symbols.iter().enumerate() {
100 let start = sym.span.start_byte as usize;
101 let end = sym.span.end_byte as usize;
102 if end <= bytes.len() {
103 let text = String::from_utf8_lossy(&bytes[start..end]).to_string();
104 symbol_spans.push(SymbolSpan {
105 qualified_name: sym.qualified_name.clone(),
106 kind: sym.kind.to_string(),
107 text,
108 _order: order,
109 });
110 }
111 }
112
113 (import_lines, symbol_spans)
114}
115
116pub fn ast_merge(
125 registry: &ParserRegistry,
126 file_path: &str,
127 base: &str,
128 version_a: &str,
129 version_b: &str,
130) -> Result<MergeResult> {
131 let path = Path::new(file_path);
132
133 if !registry.supports_file(path) {
134 return Err(Error::UnsupportedLanguage(format!(
135 "AST merge not supported for file: {file_path}"
136 )));
137 }
138
139 let base_analysis = registry.parse_file(path, base.as_bytes())?;
141 let a_analysis = registry.parse_file(path, version_a.as_bytes())?;
142 let b_analysis = registry.parse_file(path, version_b.as_bytes())?;
143
144 let (base_imports, base_spans) = extract_spans(base, &base_analysis.symbols);
146 let (a_imports, a_spans) = extract_spans(version_a, &a_analysis.symbols);
147 let (b_imports, b_spans) = extract_spans(version_b, &b_analysis.symbols);
148
149 let base_map: BTreeMap<&str, &SymbolSpan> =
151 base_spans.iter().map(|s| (s.qualified_name.as_str(), s)).collect();
152 let a_map: BTreeMap<&str, &SymbolSpan> =
153 a_spans.iter().map(|s| (s.qualified_name.as_str(), s)).collect();
154 let b_map: BTreeMap<&str, &SymbolSpan> =
155 b_spans.iter().map(|s| (s.qualified_name.as_str(), s)).collect();
156
157 let all_names: BTreeSet<&str> = base_map
159 .keys()
160 .chain(a_map.keys())
161 .chain(b_map.keys())
162 .copied()
163 .collect();
164
165 let mut merged_symbols: Vec<SymbolSpan> = Vec::new();
166 let mut conflicts: Vec<SymbolConflict> = Vec::new();
167 let mut order_counter: usize = 0;
168
169 for name in &all_names {
170 let in_base = base_map.get(name);
171 let in_a = a_map.get(name);
172 let in_b = b_map.get(name);
173
174 let a_modified = match (in_base, in_a) {
175 (Some(base_s), Some(a_s)) => base_s.text != a_s.text,
176 (None, Some(_)) => true, (Some(_), None) => true, (None, None) => false,
179 };
180
181 let b_modified = match (in_base, in_b) {
182 (Some(base_s), Some(b_s)) => base_s.text != b_s.text,
183 (None, Some(_)) => true, (Some(_), None) => true, (None, None) => false,
186 };
187
188 match (a_modified, b_modified) {
189 (false, false) => {
190 if let Some(base_s) = in_base {
192 merged_symbols.push(SymbolSpan {
193 qualified_name: base_s.qualified_name.clone(),
194 kind: base_s.kind.clone(),
195 text: base_s.text.clone(),
196 _order: order_counter,
197 });
198 order_counter += 1;
199 }
200 }
201 (true, false) => {
202 if let Some(a_s) = in_a {
204 merged_symbols.push(SymbolSpan {
206 qualified_name: a_s.qualified_name.clone(),
207 kind: a_s.kind.clone(),
208 text: a_s.text.clone(),
209 _order: order_counter,
210 });
211 order_counter += 1;
212 }
213 }
215 (false, true) => {
216 if let Some(b_s) = in_b {
218 merged_symbols.push(SymbolSpan {
220 qualified_name: b_s.qualified_name.clone(),
221 kind: b_s.kind.clone(),
222 text: b_s.text.clone(),
223 _order: order_counter,
224 });
225 order_counter += 1;
226 }
227 }
229 (true, true) => {
230 match (in_base, in_a, in_b) {
232 (None, Some(a_s), Some(b_s)) => {
233 conflicts.push(SymbolConflict {
235 qualified_name: name.to_string(),
236 kind: a_s.kind.clone(),
237 version_a: a_s.text.clone(),
238 version_b: b_s.text.clone(),
239 base: String::new(),
240 });
241 merged_symbols.push(SymbolSpan {
243 qualified_name: a_s.qualified_name.clone(),
244 kind: a_s.kind.clone(),
245 text: a_s.text.clone(),
246 _order: order_counter,
247 });
248 order_counter += 1;
249 }
250 (Some(base_s), Some(a_s), Some(b_s)) => {
251 if a_s.text == b_s.text {
252 merged_symbols.push(SymbolSpan {
254 qualified_name: a_s.qualified_name.clone(),
255 kind: a_s.kind.clone(),
256 text: a_s.text.clone(),
257 _order: order_counter,
258 });
259 order_counter += 1;
260 } else {
261 conflicts.push(SymbolConflict {
263 qualified_name: name.to_string(),
264 kind: base_s.kind.clone(),
265 version_a: a_s.text.clone(),
266 version_b: b_s.text.clone(),
267 base: base_s.text.clone(),
268 });
269 merged_symbols.push(SymbolSpan {
271 qualified_name: a_s.qualified_name.clone(),
272 kind: a_s.kind.clone(),
273 text: a_s.text.clone(),
274 _order: order_counter,
275 });
276 order_counter += 1;
277 }
278 }
279 (Some(base_s), None, Some(b_s)) => {
280 conflicts.push(SymbolConflict {
282 qualified_name: name.to_string(),
283 kind: base_s.kind.clone(),
284 version_a: String::new(),
285 version_b: b_s.text.clone(),
286 base: base_s.text.clone(),
287 });
288 merged_symbols.push(SymbolSpan {
290 qualified_name: b_s.qualified_name.clone(),
291 kind: b_s.kind.clone(),
292 text: b_s.text.clone(),
293 _order: order_counter,
294 });
295 order_counter += 1;
296 }
297 (Some(base_s), Some(a_s), None) => {
298 conflicts.push(SymbolConflict {
300 qualified_name: name.to_string(),
301 kind: base_s.kind.clone(),
302 version_a: a_s.text.clone(),
303 version_b: String::new(),
304 base: base_s.text.clone(),
305 });
306 merged_symbols.push(SymbolSpan {
308 qualified_name: a_s.qualified_name.clone(),
309 kind: a_s.kind.clone(),
310 text: a_s.text.clone(),
311 _order: order_counter,
312 });
313 order_counter += 1;
314 }
315 (Some(_), None, None) => {
316 }
318 _ => {}
319 }
320 }
321 }
322 }
323
324 let mut merged_import_set: BTreeSet<String> = BTreeSet::new();
326 for imp in base_imports.iter().chain(a_imports.iter()).chain(b_imports.iter()) {
327 merged_import_set.insert(imp.text.clone());
328 }
329
330 let mut output = String::new();
332
333 if !merged_import_set.is_empty() {
335 for imp in &merged_import_set {
336 output.push_str(imp);
337 output.push('\n');
338 }
339 output.push('\n');
340 }
341
342 let symbol_texts: Vec<&str> = merged_symbols.iter().map(|s| s.text.as_str()).collect();
344 output.push_str(&symbol_texts.join("\n\n"));
345
346 if !output.ends_with('\n') {
348 output.push('\n');
349 }
350
351 let status = if conflicts.is_empty() {
352 MergeStatus::Clean
353 } else {
354 MergeStatus::Conflict
355 };
356
357 Ok(MergeResult {
358 status,
359 merged_content: output,
360 conflicts,
361 })
362}
363
#[cfg(test)]
mod tests {
    use super::*;

    /// Derived equality on `MergeStatus`: each variant equals itself and the
    /// two variants differ.
    #[test]
    fn test_merge_status_eq() {
        let (clean, conflict) = (MergeStatus::Clean, MergeStatus::Conflict);
        assert_eq!(clean, MergeStatus::Clean);
        assert_eq!(conflict, MergeStatus::Conflict);
        assert_ne!(clean, conflict);
    }
}