1use std::collections::{BTreeMap, BTreeSet};
8use std::path::Path;
9
10use dk_core::{Error, Result, Symbol};
11
12use crate::parser::ParserRegistry;
13
/// Summary outcome of an AST-level three-way merge.
///
/// `Conflict` is reported whenever at least one symbol conflict was recorded;
/// otherwise the merge is `Clean`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MergeStatus {
    /// No symbol conflicts were recorded.
    Clean,
    /// One or more symbol conflicts were recorded.
    Conflict,
}
22
23#[derive(Debug)]
25pub struct MergeResult {
26 pub status: MergeStatus,
27 pub merged_content: String,
28 pub conflicts: Vec<SymbolConflict>,
29}
30
/// One symbol whose two merged versions could not be reconciled automatically.
///
/// Each text field holds that version's full source text for the symbol, or
/// an empty string when the symbol is absent from that version.
#[derive(Debug)]
pub struct SymbolConflict {
    /// Fully qualified name used to match the symbol across versions.
    pub qualified_name: String,
    /// The symbol's kind, rendered as a string.
    pub kind: String,
    /// Text of the symbol in version A (empty if A deleted it).
    pub version_a: String,
    /// Text of the symbol in version B (empty if B deleted it).
    pub version_b: String,
    /// Text of the symbol in the common ancestor (empty if it did not exist).
    pub base: String,
}
40
/// Source text of a single symbol extracted from one version of a file.
#[derive(Clone, Debug)]
struct SymbolSpan {
    /// Fully qualified name; the key used to match symbols across versions.
    qualified_name: String,
    /// The symbol's kind, rendered as a string.
    kind: String,
    /// The symbol's source text, possibly prefixed with its doc comment.
    text: String,
    /// Position of the symbol in the list it was extracted from.
    _order: usize,
}
51
/// A single top-level import line, kept verbatim (including indentation).
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
struct ImportLine {
    /// The full text of the line, without its line terminator.
    text: String,
}
58
59fn extract_spans(
62 source: &str,
63 symbols: &[Symbol],
64) -> (Vec<ImportLine>, Vec<SymbolSpan>) {
65 let bytes = source.as_bytes();
66 let mut import_lines = Vec::new();
67 let mut symbol_spans = Vec::new();
68
69 let mut symbol_ranges: Vec<(usize, usize)> = symbols
72 .iter()
73 .map(|s| (s.span.start_byte as usize, s.span.end_byte as usize))
74 .collect();
75 symbol_ranges.sort_by_key(|r| r.0);
76
77 for line in source.lines() {
80 let trimmed = line.trim();
81 if trimmed.starts_with("use ")
82 || trimmed.starts_with("import ")
83 || trimmed.starts_with("from ")
84 {
85 let line_start = line.as_ptr() as usize - bytes.as_ptr() as usize;
87 let inside_symbol = symbol_ranges
88 .iter()
89 .any(|(start, end)| line_start >= *start && line_start < *end);
90 if !inside_symbol {
91 import_lines.push(ImportLine {
92 text: line.to_string(),
93 });
94 }
95 }
96 }
97
98 for (order, sym) in symbols.iter().enumerate() {
103 let start = sym.span.start_byte as usize;
104 let end = sym.span.end_byte as usize;
105 if end <= bytes.len() {
106 let body = String::from_utf8_lossy(&bytes[start..end]).to_string();
107 let text = match &sym.doc_comment {
111 Some(doc) if !doc.is_empty() && !body.contains(doc.as_str()) => {
112 format!("{doc}\n{body}")
113 }
114 _ => body,
115 };
116 symbol_spans.push(SymbolSpan {
117 qualified_name: sym.qualified_name.clone(),
118 kind: sym.kind.to_string(),
119 text,
120 _order: order,
121 });
122 }
123 }
124
125 (import_lines, symbol_spans)
126}
127
128pub fn ast_merge(
137 registry: &ParserRegistry,
138 file_path: &str,
139 base: &str,
140 version_a: &str,
141 version_b: &str,
142) -> Result<MergeResult> {
143 let path = Path::new(file_path);
144
145 if !registry.supports_file(path) {
146 return Err(Error::UnsupportedLanguage(format!(
147 "AST merge not supported for file: {file_path}"
148 )));
149 }
150
151 let base_analysis = registry.parse_file(path, base.as_bytes())?;
153 let a_analysis = registry.parse_file(path, version_a.as_bytes())?;
154 let b_analysis = registry.parse_file(path, version_b.as_bytes())?;
155
156 let (base_imports, base_spans) = extract_spans(base, &base_analysis.symbols);
158 let (a_imports, a_spans) = extract_spans(version_a, &a_analysis.symbols);
159 let (b_imports, b_spans) = extract_spans(version_b, &b_analysis.symbols);
160
161 let base_map: BTreeMap<&str, &SymbolSpan> =
163 base_spans.iter().map(|s| (s.qualified_name.as_str(), s)).collect();
164 let a_map: BTreeMap<&str, &SymbolSpan> =
165 a_spans.iter().map(|s| (s.qualified_name.as_str(), s)).collect();
166 let b_map: BTreeMap<&str, &SymbolSpan> =
167 b_spans.iter().map(|s| (s.qualified_name.as_str(), s)).collect();
168
169 let all_names: BTreeSet<&str> = base_map
171 .keys()
172 .chain(a_map.keys())
173 .chain(b_map.keys())
174 .copied()
175 .collect();
176
177 let mut merged_symbols: Vec<SymbolSpan> = Vec::new();
178 let mut conflicts: Vec<SymbolConflict> = Vec::new();
179 let mut order_counter: usize = 0;
180
181 for name in &all_names {
182 let in_base = base_map.get(name);
183 let in_a = a_map.get(name);
184 let in_b = b_map.get(name);
185
186 let a_modified = match (in_base, in_a) {
187 (Some(base_s), Some(a_s)) => base_s.text != a_s.text,
188 (None, Some(_)) => true, (Some(_), None) => true, (None, None) => false,
191 };
192
193 let b_modified = match (in_base, in_b) {
194 (Some(base_s), Some(b_s)) => base_s.text != b_s.text,
195 (None, Some(_)) => true, (Some(_), None) => true, (None, None) => false,
198 };
199
200 match (a_modified, b_modified) {
201 (false, false) => {
202 if let Some(base_s) = in_base {
204 merged_symbols.push(SymbolSpan {
205 qualified_name: base_s.qualified_name.clone(),
206 kind: base_s.kind.clone(),
207 text: base_s.text.clone(),
208 _order: order_counter,
209 });
210 order_counter += 1;
211 }
212 }
213 (true, false) => {
214 if let Some(a_s) = in_a {
216 merged_symbols.push(SymbolSpan {
218 qualified_name: a_s.qualified_name.clone(),
219 kind: a_s.kind.clone(),
220 text: a_s.text.clone(),
221 _order: order_counter,
222 });
223 order_counter += 1;
224 }
225 }
227 (false, true) => {
228 if let Some(b_s) = in_b {
230 merged_symbols.push(SymbolSpan {
232 qualified_name: b_s.qualified_name.clone(),
233 kind: b_s.kind.clone(),
234 text: b_s.text.clone(),
235 _order: order_counter,
236 });
237 order_counter += 1;
238 }
239 }
241 (true, true) => {
242 match (in_base, in_a, in_b) {
244 (None, Some(a_s), Some(b_s)) => {
245 conflicts.push(SymbolConflict {
247 qualified_name: name.to_string(),
248 kind: a_s.kind.clone(),
249 version_a: a_s.text.clone(),
250 version_b: b_s.text.clone(),
251 base: String::new(),
252 });
253 merged_symbols.push(SymbolSpan {
255 qualified_name: a_s.qualified_name.clone(),
256 kind: a_s.kind.clone(),
257 text: a_s.text.clone(),
258 _order: order_counter,
259 });
260 order_counter += 1;
261 }
262 (Some(base_s), Some(a_s), Some(b_s)) => {
263 if a_s.text == b_s.text {
264 merged_symbols.push(SymbolSpan {
266 qualified_name: a_s.qualified_name.clone(),
267 kind: a_s.kind.clone(),
268 text: a_s.text.clone(),
269 _order: order_counter,
270 });
271 order_counter += 1;
272 } else {
273 conflicts.push(SymbolConflict {
275 qualified_name: name.to_string(),
276 kind: base_s.kind.clone(),
277 version_a: a_s.text.clone(),
278 version_b: b_s.text.clone(),
279 base: base_s.text.clone(),
280 });
281 merged_symbols.push(SymbolSpan {
283 qualified_name: a_s.qualified_name.clone(),
284 kind: a_s.kind.clone(),
285 text: a_s.text.clone(),
286 _order: order_counter,
287 });
288 order_counter += 1;
289 }
290 }
291 (Some(base_s), None, Some(b_s)) => {
292 conflicts.push(SymbolConflict {
294 qualified_name: name.to_string(),
295 kind: base_s.kind.clone(),
296 version_a: String::new(),
297 version_b: b_s.text.clone(),
298 base: base_s.text.clone(),
299 });
300 merged_symbols.push(SymbolSpan {
302 qualified_name: b_s.qualified_name.clone(),
303 kind: b_s.kind.clone(),
304 text: b_s.text.clone(),
305 _order: order_counter,
306 });
307 order_counter += 1;
308 }
309 (Some(base_s), Some(a_s), None) => {
310 conflicts.push(SymbolConflict {
312 qualified_name: name.to_string(),
313 kind: base_s.kind.clone(),
314 version_a: a_s.text.clone(),
315 version_b: String::new(),
316 base: base_s.text.clone(),
317 });
318 merged_symbols.push(SymbolSpan {
320 qualified_name: a_s.qualified_name.clone(),
321 kind: a_s.kind.clone(),
322 text: a_s.text.clone(),
323 _order: order_counter,
324 });
325 order_counter += 1;
326 }
327 (Some(_), None, None) => {
328 }
330 _ => {}
331 }
332 }
333 }
334 }
335
336 let mut merged_import_set: BTreeSet<String> = BTreeSet::new();
338 for imp in base_imports.iter().chain(a_imports.iter()).chain(b_imports.iter()) {
339 merged_import_set.insert(imp.text.clone());
340 }
341
342 let mut output = String::new();
344
345 if !merged_import_set.is_empty() {
347 for imp in &merged_import_set {
348 output.push_str(imp);
349 output.push('\n');
350 }
351 output.push('\n');
352 }
353
354 let symbol_texts: Vec<&str> = merged_symbols.iter().map(|s| s.text.as_str()).collect();
356 output.push_str(&symbol_texts.join("\n\n"));
357
358 if !output.ends_with('\n') {
360 output.push('\n');
361 }
362
363 let status = if conflicts.is_empty() {
364 MergeStatus::Clean
365 } else {
366 MergeStatus::Conflict
367 };
368
369 Ok(MergeResult {
370 status,
371 merged_content: output,
372 conflicts,
373 })
374}
375
#[cfg(test)]
mod tests {
    use super::*;

    /// Equality on `MergeStatus` must tell its two variants apart.
    #[test]
    fn test_merge_status_eq() {
        let clean = MergeStatus::Clean;
        let conflict = MergeStatus::Conflict;
        assert_eq!(clean, MergeStatus::Clean);
        assert_eq!(conflict, MergeStatus::Conflict);
        assert_ne!(clean, conflict);
    }
}