1use std::collections::{BTreeMap, HashSet};
8use std::path::Path;
9
10use dk_core::{Error, Result, Symbol};
11
12use crate::parser::ParserRegistry;
13
/// Overall outcome of [`ast_merge`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MergeStatus {
    /// No symbol was changed incompatibly on both sides.
    Clean,
    /// At least one symbol was changed incompatibly on both sides; see
    /// [`MergeResult::conflicts`].
    Conflict,
}

/// Result of a three-way, symbol-level merge.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MergeResult {
    /// [`MergeStatus::Clean`] when `conflicts` is empty, otherwise
    /// [`MergeStatus::Conflict`].
    pub status: MergeStatus,
    /// The merged file text. For a conflicted symbol this holds one side's
    /// version as a placeholder (no conflict markers), so consult `conflicts`
    /// before trusting it.
    pub merged_content: String,
    /// Every symbol that both sides changed in different ways.
    pub conflicts: Vec<SymbolConflict>,
}

/// A symbol that was changed incompatibly on both sides of a merge.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SymbolConflict {
    /// Qualified name of the conflicting symbol.
    pub qualified_name: String,
    /// Symbol kind, taken from the base version when it exists, otherwise
    /// from side A.
    pub kind: String,
    /// Side A's text, or empty if A deleted the symbol.
    pub version_a: String,
    /// Side B's text, or empty if B deleted the symbol.
    pub version_b: String,
    /// Base text, or empty if the symbol did not exist in the base.
    pub base: String,
}
40
/// The source text of one parsed symbol, captured from a single version of
/// the file being merged.
#[derive(Clone, Debug)]
struct SymbolSpan {
    /// Parser-reported qualified name, used to pair symbols across versions.
    qualified_name: String,
    /// The symbol kind, rendered with `to_string()`.
    kind: String,
    /// The symbol's source slice, possibly prefixed with its doc comment.
    text: String,
    /// Position counter assigned when the span is built; written but never
    /// read by the code in this file.
    _order: usize,
}
51
/// One top-level import statement, kept verbatim so that identical imports
/// coming from different versions deduplicate by exact text.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
struct ImportLine {
    /// The import's source text, without a trailing newline.
    text: String,
}
58
59fn extract_spans(
62 source: &str,
63 symbols: &[Symbol],
64) -> (Vec<ImportLine>, Vec<SymbolSpan>) {
65 let bytes = source.as_bytes();
66 let mut import_lines = Vec::new();
67 let mut symbol_spans = Vec::new();
68
69 let mut symbol_ranges: Vec<(usize, usize)> = symbols
72 .iter()
73 .map(|s| (s.span.start_byte as usize, s.span.end_byte as usize))
74 .collect();
75 symbol_ranges.sort_by_key(|r| r.0);
76
77 for line in source.lines() {
80 let trimmed = line.trim();
81 if trimmed.starts_with("use ")
82 || trimmed.starts_with("import ")
83 || trimmed.starts_with("from ")
84 {
85 let line_start = line.as_ptr() as usize - bytes.as_ptr() as usize;
87 let inside_symbol = symbol_ranges
88 .iter()
89 .any(|(start, end)| line_start >= *start && line_start < *end);
90 if !inside_symbol {
91 import_lines.push(ImportLine {
92 text: line.to_string(),
93 });
94 }
95 }
96 }
97
98 for (order, sym) in symbols.iter().enumerate() {
103 let start = sym.span.start_byte as usize;
104 let end = sym.span.end_byte as usize;
105 if end <= bytes.len() {
106 let body = String::from_utf8_lossy(&bytes[start..end]).to_string();
107 let text = match &sym.doc_comment {
111 Some(doc) if !doc.is_empty() && !body.contains(doc.as_str()) => {
112 format!("{doc}\n{body}")
113 }
114 _ => body,
115 };
116 symbol_spans.push(SymbolSpan {
117 qualified_name: sym.qualified_name.clone(),
118 kind: sym.kind.to_string(),
119 text,
120 _order: order,
121 });
122 }
123 }
124
125 (import_lines, symbol_spans)
126}
127
128pub fn ast_merge(
137 registry: &ParserRegistry,
138 file_path: &str,
139 base: &str,
140 version_a: &str,
141 version_b: &str,
142) -> Result<MergeResult> {
143 let path = Path::new(file_path);
144
145 if !registry.supports_file(path) {
146 return Err(Error::UnsupportedLanguage(format!(
147 "AST merge not supported for file: {file_path}"
148 )));
149 }
150
151 let base_analysis = registry.parse_file(path, base.as_bytes())?;
153 let a_analysis = registry.parse_file(path, version_a.as_bytes())?;
154 let b_analysis = registry.parse_file(path, version_b.as_bytes())?;
155
156 let (base_imports, base_spans) = extract_spans(base, &base_analysis.symbols);
158 let (a_imports, a_spans) = extract_spans(version_a, &a_analysis.symbols);
159 let (b_imports, b_spans) = extract_spans(version_b, &b_analysis.symbols);
160
161 let base_map: BTreeMap<&str, &SymbolSpan> =
163 base_spans.iter().map(|s| (s.qualified_name.as_str(), s)).collect();
164 let a_map: BTreeMap<&str, &SymbolSpan> =
165 a_spans.iter().map(|s| (s.qualified_name.as_str(), s)).collect();
166 let b_map: BTreeMap<&str, &SymbolSpan> =
167 b_spans.iter().map(|s| (s.qualified_name.as_str(), s)).collect();
168
169 let mut all_names: Vec<&str> = Vec::new();
172 let mut seen: HashSet<&str> = HashSet::new();
173
174 for span in &base_spans {
176 let name = span.qualified_name.as_str();
177 if seen.insert(name) {
178 all_names.push(name);
179 }
180 }
181 for span in &a_spans {
183 let name = span.qualified_name.as_str();
184 if seen.insert(name) {
185 all_names.push(name);
186 }
187 }
188 for span in &b_spans {
190 let name = span.qualified_name.as_str();
191 if seen.insert(name) {
192 all_names.push(name);
193 }
194 }
195
196 let mut merged_symbols: Vec<SymbolSpan> = Vec::new();
197 let mut conflicts: Vec<SymbolConflict> = Vec::new();
198 let mut order_counter: usize = 0;
199
200 for name in &all_names {
201 let in_base = base_map.get(name);
202 let in_a = a_map.get(name);
203 let in_b = b_map.get(name);
204
205 let a_modified = match (in_base, in_a) {
206 (Some(base_s), Some(a_s)) => base_s.text != a_s.text,
207 (None, Some(_)) => true, (Some(_), None) => true, (None, None) => false,
210 };
211
212 let b_modified = match (in_base, in_b) {
213 (Some(base_s), Some(b_s)) => base_s.text != b_s.text,
214 (None, Some(_)) => true, (Some(_), None) => true, (None, None) => false,
217 };
218
219 match (a_modified, b_modified) {
220 (false, false) => {
221 if let Some(base_s) = in_base {
223 merged_symbols.push(SymbolSpan {
224 qualified_name: base_s.qualified_name.clone(),
225 kind: base_s.kind.clone(),
226 text: base_s.text.clone(),
227 _order: order_counter,
228 });
229 order_counter += 1;
230 }
231 }
232 (true, false) => {
233 if let Some(a_s) = in_a {
235 merged_symbols.push(SymbolSpan {
237 qualified_name: a_s.qualified_name.clone(),
238 kind: a_s.kind.clone(),
239 text: a_s.text.clone(),
240 _order: order_counter,
241 });
242 order_counter += 1;
243 }
244 }
246 (false, true) => {
247 if let Some(b_s) = in_b {
249 merged_symbols.push(SymbolSpan {
251 qualified_name: b_s.qualified_name.clone(),
252 kind: b_s.kind.clone(),
253 text: b_s.text.clone(),
254 _order: order_counter,
255 });
256 order_counter += 1;
257 }
258 }
260 (true, true) => {
261 match (in_base, in_a, in_b) {
263 (None, Some(a_s), Some(b_s)) => {
264 conflicts.push(SymbolConflict {
266 qualified_name: name.to_string(),
267 kind: a_s.kind.clone(),
268 version_a: a_s.text.clone(),
269 version_b: b_s.text.clone(),
270 base: String::new(),
271 });
272 merged_symbols.push(SymbolSpan {
274 qualified_name: a_s.qualified_name.clone(),
275 kind: a_s.kind.clone(),
276 text: a_s.text.clone(),
277 _order: order_counter,
278 });
279 order_counter += 1;
280 }
281 (Some(base_s), Some(a_s), Some(b_s)) => {
282 if a_s.text == b_s.text {
283 merged_symbols.push(SymbolSpan {
285 qualified_name: a_s.qualified_name.clone(),
286 kind: a_s.kind.clone(),
287 text: a_s.text.clone(),
288 _order: order_counter,
289 });
290 order_counter += 1;
291 } else {
292 conflicts.push(SymbolConflict {
294 qualified_name: name.to_string(),
295 kind: base_s.kind.clone(),
296 version_a: a_s.text.clone(),
297 version_b: b_s.text.clone(),
298 base: base_s.text.clone(),
299 });
300 merged_symbols.push(SymbolSpan {
302 qualified_name: a_s.qualified_name.clone(),
303 kind: a_s.kind.clone(),
304 text: a_s.text.clone(),
305 _order: order_counter,
306 });
307 order_counter += 1;
308 }
309 }
310 (Some(base_s), None, Some(b_s)) => {
311 conflicts.push(SymbolConflict {
313 qualified_name: name.to_string(),
314 kind: base_s.kind.clone(),
315 version_a: String::new(),
316 version_b: b_s.text.clone(),
317 base: base_s.text.clone(),
318 });
319 merged_symbols.push(SymbolSpan {
321 qualified_name: b_s.qualified_name.clone(),
322 kind: b_s.kind.clone(),
323 text: b_s.text.clone(),
324 _order: order_counter,
325 });
326 order_counter += 1;
327 }
328 (Some(base_s), Some(a_s), None) => {
329 conflicts.push(SymbolConflict {
331 qualified_name: name.to_string(),
332 kind: base_s.kind.clone(),
333 version_a: a_s.text.clone(),
334 version_b: String::new(),
335 base: base_s.text.clone(),
336 });
337 merged_symbols.push(SymbolSpan {
339 qualified_name: a_s.qualified_name.clone(),
340 kind: a_s.kind.clone(),
341 text: a_s.text.clone(),
342 _order: order_counter,
343 });
344 order_counter += 1;
345 }
346 (Some(_), None, None) => {
347 }
349 _ => {}
350 }
351 }
352 }
353 }
354
355 let mut merged_imports: Vec<String> = Vec::new();
357 let mut import_seen: HashSet<String> = HashSet::new();
358 for imp in base_imports.iter().chain(a_imports.iter()).chain(b_imports.iter()) {
360 if import_seen.insert(imp.text.clone()) {
361 merged_imports.push(imp.text.clone());
362 }
363 }
364
365 let mut output = String::new();
367
368 if !merged_imports.is_empty() {
370 for imp in &merged_imports {
371 output.push_str(imp);
372 output.push('\n');
373 }
374 output.push('\n');
375 }
376
377 let symbol_texts: Vec<&str> = merged_symbols.iter().map(|s| s.text.as_str()).collect();
379 output.push_str(&symbol_texts.join("\n\n"));
380
381 if !output.ends_with('\n') {
383 output.push('\n');
384 }
385
386 let status = if conflicts.is_empty() {
387 MergeStatus::Clean
388 } else {
389 MergeStatus::Conflict
390 };
391
392 Ok(MergeResult {
393 status,
394 merged_content: output,
395 conflicts,
396 })
397}
398
#[cfg(test)]
mod tests {
    use super::*;

    /// Each `MergeStatus` variant equals itself and differs from every other.
    #[test]
    fn test_merge_status_eq() {
        let variants = [MergeStatus::Clean, MergeStatus::Conflict];
        for (i, lhs) in variants.iter().enumerate() {
            for (j, rhs) in variants.iter().enumerate() {
                assert_eq!(lhs == rhs, i == j, "{lhs:?} vs {rhs:?}");
            }
        }
    }
}