1use std::collections::{BTreeMap, HashSet};
8use std::path::Path;
9
10use dk_core::{Error, Result, Symbol};
11
12use crate::parser::ParserRegistry;
13
/// Overall outcome of a symbol-level three-way merge.
///
/// The type is a fieldless enum, so it is `Copy` and `Hash` in addition to
/// being comparable; callers can pass it by value and use it as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MergeStatus {
    /// No symbol conflict was recorded during the merge.
    Clean,
    /// At least one symbol conflict was recorded during the merge.
    Conflict,
}
22
23#[derive(Debug)]
25pub struct MergeResult {
26 pub status: MergeStatus,
27 pub merged_content: String,
28 pub conflicts: Vec<SymbolConflict>,
29}
30
/// A single symbol that the two sides of a merge changed incompatibly.
///
/// The three text fields use the empty string to mean "the symbol does not
/// exist in that version" (it was added after the base, or deleted on that
/// side).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SymbolConflict {
    /// Qualified name that identifies the symbol across the three versions.
    pub qualified_name: String,
    /// Symbol kind rendered as text.
    pub kind: String,
    /// Text of the symbol in version A, or empty if A does not have it.
    pub version_a: String,
    /// Text of the symbol in version B, or empty if B does not have it.
    pub version_b: String,
    /// Text of the symbol in the common base, or empty if the base does not have it.
    pub base: String,
}
40
/// Source text of one symbol, keyed by its qualified name.
#[derive(Debug, Clone)]
struct SymbolSpan {
    /// Name used to match the same symbol across file versions.
    qualified_name: String,
    /// Symbol kind rendered as text (compared against `"variable"` when the
    /// merged output is laid out).
    kind: String,
    /// The symbol's source text.
    text: String,
    /// Position index assigned when the span is created. Stored but not read
    /// back anywhere in this file, hence the leading underscore.
    _order: usize,
}
51
/// One import statement, kept as the verbatim source line it came from.
///
/// Equality, ordering and hashing are all derived from the line text.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct ImportLine {
    /// The full line as it appeared in the source, indentation included.
    text: String,
}
58
59fn extract_spans(
62 source: &str,
63 symbols: &[Symbol],
64) -> (Vec<ImportLine>, Vec<SymbolSpan>) {
65 let bytes = source.as_bytes();
66 let mut import_lines = Vec::new();
67 let mut symbol_spans = Vec::new();
68
69 let mut symbol_ranges: Vec<(usize, usize)> = symbols
72 .iter()
73 .map(|s| (s.span.start_byte as usize, s.span.end_byte as usize))
74 .collect();
75 symbol_ranges.sort_by_key(|r| r.0);
76
77 for line in source.lines() {
80 let trimmed = line.trim();
81 if trimmed.starts_with("use ")
82 || trimmed.starts_with("import ")
83 || trimmed.starts_with("from ")
84 {
85 let line_start = line.as_ptr() as usize - bytes.as_ptr() as usize;
87 let inside_symbol = symbol_ranges
88 .iter()
89 .any(|(start, end)| line_start >= *start && line_start < *end);
90 if !inside_symbol {
91 import_lines.push(ImportLine {
92 text: line.to_string(),
93 });
94 }
95 }
96 }
97
98 for (order, sym) in symbols.iter().enumerate() {
103 let start = sym.span.start_byte as usize;
104 let end = sym.span.end_byte as usize;
105 if end <= bytes.len() {
106 let extended_end = if end < bytes.len() {
110 let rest = &source[end..];
111 if let Some(nl) = rest.find('\n') {
112 let trailing = rest[..nl].trim();
113 if trailing.starts_with('#') || trailing.is_empty() {
114 end + nl
116 } else {
117 end
118 }
119 } else {
120 source.len() }
122 } else {
123 end
124 };
125 let body = String::from_utf8_lossy(&bytes[start..extended_end]).to_string();
126 let text = match &sym.doc_comment {
130 Some(doc) if !doc.is_empty() && !body.contains(doc.as_str()) => {
131 format!("{doc}\n{body}")
132 }
133 _ => body,
134 };
135 symbol_spans.push(SymbolSpan {
136 qualified_name: sym.qualified_name.clone(),
137 kind: sym.kind.to_string(),
138 text,
139 _order: order,
140 });
141 }
142 }
143
144 (import_lines, symbol_spans)
145}
146
147pub fn ast_merge(
156 registry: &ParserRegistry,
157 file_path: &str,
158 base: &str,
159 version_a: &str,
160 version_b: &str,
161) -> Result<MergeResult> {
162 let path = Path::new(file_path);
163
164 if !registry.supports_file(path) {
165 return Err(Error::UnsupportedLanguage(format!(
166 "AST merge not supported for file: {file_path}"
167 )));
168 }
169
170 let base_analysis = registry.parse_file(path, base.as_bytes())?;
172 let a_analysis = registry.parse_file(path, version_a.as_bytes())?;
173 let b_analysis = registry.parse_file(path, version_b.as_bytes())?;
174
175 let (base_imports, base_spans) = extract_spans(base, &base_analysis.symbols);
177 let (a_imports, a_spans) = extract_spans(version_a, &a_analysis.symbols);
178 let (b_imports, b_spans) = extract_spans(version_b, &b_analysis.symbols);
179
180 let base_map: BTreeMap<&str, &SymbolSpan> =
182 base_spans.iter().map(|s| (s.qualified_name.as_str(), s)).collect();
183 let a_map: BTreeMap<&str, &SymbolSpan> =
184 a_spans.iter().map(|s| (s.qualified_name.as_str(), s)).collect();
185 let b_map: BTreeMap<&str, &SymbolSpan> =
186 b_spans.iter().map(|s| (s.qualified_name.as_str(), s)).collect();
187
188 let mut all_names: Vec<&str> = Vec::new();
191 let mut seen: HashSet<&str> = HashSet::new();
192
193 for span in &base_spans {
195 let name = span.qualified_name.as_str();
196 if seen.insert(name) {
197 all_names.push(name);
198 }
199 }
200 for span in &a_spans {
202 let name = span.qualified_name.as_str();
203 if seen.insert(name) {
204 all_names.push(name);
205 }
206 }
207 for span in &b_spans {
209 let name = span.qualified_name.as_str();
210 if seen.insert(name) {
211 all_names.push(name);
212 }
213 }
214
215 let mut merged_symbols: Vec<SymbolSpan> = Vec::new();
216 let mut conflicts: Vec<SymbolConflict> = Vec::new();
217 let mut order_counter: usize = 0;
218
219 for name in &all_names {
220 let in_base = base_map.get(name);
221 let in_a = a_map.get(name);
222 let in_b = b_map.get(name);
223
224 let a_modified = match (in_base, in_a) {
225 (Some(base_s), Some(a_s)) => base_s.text != a_s.text,
226 (None, Some(_)) => true, (Some(_), None) => true, (None, None) => false,
229 };
230
231 let b_modified = match (in_base, in_b) {
232 (Some(base_s), Some(b_s)) => base_s.text != b_s.text,
233 (None, Some(_)) => true, (Some(_), None) => true, (None, None) => false,
236 };
237
238 match (a_modified, b_modified) {
239 (false, false) => {
240 if let Some(base_s) = in_base {
242 merged_symbols.push(SymbolSpan {
243 qualified_name: base_s.qualified_name.clone(),
244 kind: base_s.kind.clone(),
245 text: base_s.text.clone(),
246 _order: order_counter,
247 });
248 order_counter += 1;
249 }
250 }
251 (true, false) => {
252 if let Some(a_s) = in_a {
254 merged_symbols.push(SymbolSpan {
256 qualified_name: a_s.qualified_name.clone(),
257 kind: a_s.kind.clone(),
258 text: a_s.text.clone(),
259 _order: order_counter,
260 });
261 order_counter += 1;
262 }
263 }
265 (false, true) => {
266 if let Some(b_s) = in_b {
268 merged_symbols.push(SymbolSpan {
270 qualified_name: b_s.qualified_name.clone(),
271 kind: b_s.kind.clone(),
272 text: b_s.text.clone(),
273 _order: order_counter,
274 });
275 order_counter += 1;
276 }
277 }
279 (true, true) => {
280 match (in_base, in_a, in_b) {
282 (None, Some(a_s), Some(b_s)) => {
283 conflicts.push(SymbolConflict {
285 qualified_name: name.to_string(),
286 kind: a_s.kind.clone(),
287 version_a: a_s.text.clone(),
288 version_b: b_s.text.clone(),
289 base: String::new(),
290 });
291 merged_symbols.push(SymbolSpan {
293 qualified_name: a_s.qualified_name.clone(),
294 kind: a_s.kind.clone(),
295 text: a_s.text.clone(),
296 _order: order_counter,
297 });
298 order_counter += 1;
299 }
300 (Some(base_s), Some(a_s), Some(b_s)) => {
301 if a_s.text == b_s.text {
302 merged_symbols.push(SymbolSpan {
304 qualified_name: a_s.qualified_name.clone(),
305 kind: a_s.kind.clone(),
306 text: a_s.text.clone(),
307 _order: order_counter,
308 });
309 order_counter += 1;
310 } else {
311 conflicts.push(SymbolConflict {
313 qualified_name: name.to_string(),
314 kind: base_s.kind.clone(),
315 version_a: a_s.text.clone(),
316 version_b: b_s.text.clone(),
317 base: base_s.text.clone(),
318 });
319 merged_symbols.push(SymbolSpan {
321 qualified_name: a_s.qualified_name.clone(),
322 kind: a_s.kind.clone(),
323 text: a_s.text.clone(),
324 _order: order_counter,
325 });
326 order_counter += 1;
327 }
328 }
329 (Some(base_s), None, Some(b_s)) => {
330 conflicts.push(SymbolConflict {
332 qualified_name: name.to_string(),
333 kind: base_s.kind.clone(),
334 version_a: String::new(),
335 version_b: b_s.text.clone(),
336 base: base_s.text.clone(),
337 });
338 merged_symbols.push(SymbolSpan {
340 qualified_name: b_s.qualified_name.clone(),
341 kind: b_s.kind.clone(),
342 text: b_s.text.clone(),
343 _order: order_counter,
344 });
345 order_counter += 1;
346 }
347 (Some(base_s), Some(a_s), None) => {
348 conflicts.push(SymbolConflict {
350 qualified_name: name.to_string(),
351 kind: base_s.kind.clone(),
352 version_a: a_s.text.clone(),
353 version_b: String::new(),
354 base: base_s.text.clone(),
355 });
356 merged_symbols.push(SymbolSpan {
358 qualified_name: a_s.qualified_name.clone(),
359 kind: a_s.kind.clone(),
360 text: a_s.text.clone(),
361 _order: order_counter,
362 });
363 order_counter += 1;
364 }
365 (Some(_), None, None) => {
366 }
368 _ => {}
369 }
370 }
371 }
372 }
373
374 let mut merged_imports: Vec<String> = Vec::new();
376 let mut import_seen: HashSet<String> = HashSet::new();
377 for imp in base_imports.iter().chain(a_imports.iter()).chain(b_imports.iter()) {
379 if import_seen.insert(imp.text.clone()) {
380 merged_imports.push(imp.text.clone());
381 }
382 }
383
384 let mut output = String::new();
386
387 if !merged_imports.is_empty() {
389 for imp in &merged_imports {
390 output.push_str(imp);
391 output.push('\n');
392 }
393 output.push('\n');
394 }
395
396 for (i, sym) in merged_symbols.iter().enumerate() {
400 if i > 0 {
401 let prev_is_var = merged_symbols[i - 1].kind == "variable";
402 let curr_is_var = sym.kind == "variable";
403 if prev_is_var && curr_is_var {
404 output.push('\n');
405 } else {
406 output.push_str("\n\n");
407 }
408 }
409 output.push_str(&sym.text);
410 }
411
412 if !output.ends_with('\n') {
414 output.push('\n');
415 }
416
417 let status = if conflicts.is_empty() {
418 MergeStatus::Clean
419 } else {
420 MergeStatus::Conflict
421 };
422
423 Ok(MergeResult {
424 status,
425 merged_content: output,
426 conflicts,
427 })
428}
429
#[cfg(test)]
mod tests {
    use super::*;

    /// Each `MergeStatus` variant equals itself and differs from the other.
    #[test]
    fn test_merge_status_eq() {
        let clean = MergeStatus::Clean;
        let conflict = MergeStatus::Conflict;

        assert_eq!(clean, MergeStatus::Clean);
        assert_eq!(conflict, MergeStatus::Conflict);
        assert_ne!(clean, conflict);
    }
}