1use std::collections::HashMap;
16
17use tree_sitter::{Node, Parser, Tree};
18
19use super::{Pattern, PatternVar};
20use crate::error::CodemodError;
21use crate::language::LanguageAdapter;
22
/// Owned copy of one syntax-tree node, stored in a flat `Vec` so it can be
/// inspected without holding on to the tree-sitter `Tree`.
#[derive(Debug, Clone)]
struct NodeSnapshot {
    /// Grammar node kind, copied from `Node::kind()`.
    kind: String,
    /// Source text covered by the node's byte range.
    text: String,
    /// Value of `Node::is_named()`; recorded but not read by the inference code.
    #[allow(dead_code)]
    is_named: bool,
    /// Indices (into the flattened snapshot list) of this node's named children,
    /// in the order they were visited.
    children: Vec<usize>,
    /// Nesting depth at which the node was visited (the root is 0); recorded
    /// but not read by the inference code.
    #[allow(dead_code)]
    depth: usize,
}
43
/// Result of comparing a "before" node snapshot with an "after" node snapshot.
#[derive(Debug)]
enum DiffKind {
    /// Both snapshots cover identical text.
    Same,
    /// Both snapshots are childless nodes of the same kind whose text differs.
    Changed {
        /// Text of the "before" node.
        before_text: String,
        /// Text of the "after" node.
        after_text: String,
        /// Node kind shared by both snapshots.
        node_kind: String,
    },
    /// Any other difference: the text differs and the nodes are not both
    /// childless nodes of the same kind.
    Structural,
}
59
/// Selects which side of an example pair a template is rendered from.
#[derive(Debug, Clone, Copy)]
enum TemplateSource {
    /// Render from the "before" snapshots.
    Before,
    /// Render from the "after" snapshots.
    After,
}
66
/// Infers rewrite patterns from before/after example snippets.
pub struct PatternInferrer {
    /// Adapter that supplies the tree-sitter grammar used for parsing and the
    /// language name stored in the resulting pattern.
    language: Box<dyn LanguageAdapter>,
}
78
79impl PatternInferrer {
80 pub fn new(language: Box<dyn LanguageAdapter>) -> Self {
82 Self { language }
83 }
84
85 pub fn infer_from_example(&self, before: &str, after: &str) -> crate::Result<Pattern> {
96 let before_tree = self.parse(before)?;
97 let after_tree = self.parse(after)?;
98
99 let before_snaps = Self::flatten_tree(&before_tree, before);
101 let after_snaps = Self::flatten_tree(&after_tree, after);
102
103 let mut var_counter: usize = 0;
105 let mut variables: Vec<PatternVar> = Vec::new();
106 let mut var_map: HashMap<(String, String), String> = HashMap::new();
109
110 let before_template = self.build_template(
111 &before_snaps,
112 &after_snaps,
113 0,
114 0,
115 before,
116 TemplateSource::Before,
117 &mut var_counter,
118 &mut variables,
119 &mut var_map,
120 );
121
122 let after_template = self.build_template(
123 &before_snaps,
124 &after_snaps,
125 0,
126 0,
127 after,
128 TemplateSource::After,
129 &mut var_counter,
130 &mut variables,
131 &mut var_map,
132 );
133
134 let confidence = Self::compute_confidence(&variables, &before_template, &after_template);
135
136 let pattern = Pattern::new(
137 before_template,
138 after_template,
139 variables,
140 self.language.name().to_string(),
141 confidence,
142 );
143
144 Ok(pattern)
145 }
146
147 pub fn infer_from_examples(&self, examples: &[(String, String)]) -> crate::Result<Pattern> {
157 if examples.is_empty() {
158 return Err(CodemodError::PatternInference(
159 "At least one example pair is required".into(),
160 ));
161 }
162
163 let mut pattern = self.infer_from_example(&examples[0].0, &examples[0].1)?;
165
166 if examples.len() == 1 {
167 return Ok(pattern);
168 }
169
170 let mut confirmed: usize = 1;
172 for (before, after) in &examples[1..] {
173 match self.infer_from_example(before, after) {
174 Ok(other) => {
175 if Self::patterns_compatible(&pattern, &other) {
176 confirmed += 1;
177 } else {
178 log::warn!("Example pair produced an incompatible pattern — skipping");
179 }
180 }
181 Err(e) => {
182 log::warn!("Failed to infer from example pair: {e}");
183 }
184 }
185 }
186
187 let cross_factor = confirmed as f64 / examples.len() as f64;
190 pattern.confidence = (pattern.confidence * 0.6 + cross_factor * 0.4).min(1.0);
191
192 Ok(pattern)
193 }
194
195 fn parse(&self, source: &str) -> crate::Result<Tree> {
201 let mut parser = Parser::new();
202 parser
203 .set_language(&self.language.language())
204 .map_err(|e| CodemodError::Parse(format!("Failed to set language: {e}")))?;
205 parser
206 .parse(source, None)
207 .ok_or_else(|| CodemodError::Parse("tree-sitter returned no tree".into()))
208 }
209
210 fn flatten_tree(tree: &Tree, source: &str) -> Vec<NodeSnapshot> {
217 let mut snaps = Vec::new();
218 Self::flatten_node(tree.root_node(), source, &mut snaps, 0);
219 snaps
220 }
221
222 fn flatten_node(
224 node: Node,
225 source: &str,
226 snaps: &mut Vec<NodeSnapshot>,
227 depth: usize,
228 ) -> usize {
229 let idx = snaps.len();
230 snaps.push(NodeSnapshot {
232 kind: node.kind().to_string(),
233 text: source[node.byte_range()].to_string(),
234 is_named: node.is_named(),
235 children: Vec::new(),
236 depth,
237 });
238
239 let mut child_indices = Vec::new();
240 let child_count = node.named_child_count();
241 for i in 0..child_count {
242 if let Some(child) = node.named_child(i) {
243 let child_idx = Self::flatten_node(child, source, snaps, depth + 1);
244 child_indices.push(child_idx);
245 }
246 }
247
248 snaps[idx].children = child_indices;
249 idx
250 }
251
252 #[allow(clippy::too_many_arguments)]
262 fn build_template(
263 &self,
264 before_snaps: &[NodeSnapshot],
265 after_snaps: &[NodeSnapshot],
266 before_idx: usize,
267 after_idx: usize,
268 source: &str,
269 side: TemplateSource,
270 var_counter: &mut usize,
271 variables: &mut Vec<PatternVar>,
272 var_map: &mut HashMap<(String, String), String>,
273 ) -> String {
274 if before_idx >= before_snaps.len() || after_idx >= after_snaps.len() {
276 return source.to_string();
277 }
278
279 let b_snap = &before_snaps[before_idx];
280 let a_snap = &after_snaps[after_idx];
281
282 match self.diff_nodes(b_snap, a_snap) {
283 DiffKind::Same => {
284 match side {
286 TemplateSource::Before => b_snap.text.clone(),
287 TemplateSource::After => a_snap.text.clone(),
288 }
289 }
290 DiffKind::Changed {
291 before_text,
292 after_text,
293 node_kind,
294 } => {
295 let key = (before_text.clone(), after_text.clone());
297 let var_name = if let Some(name) = var_map.get(&key) {
298 name.clone()
299 } else {
300 *var_counter += 1;
301 let name = format!("$var{}", *var_counter);
302 var_map.insert(key, name.clone());
303 variables.push(PatternVar {
304 name: name.clone(),
305 node_type: Some(node_kind),
306 });
307 name
308 };
309 var_name
310 }
311 DiffKind::Structural => {
312 if b_snap.children.is_empty() && a_snap.children.is_empty() {
316 let key = (b_snap.text.clone(), a_snap.text.clone());
317 let var_name = if let Some(name) = var_map.get(&key) {
318 name.clone()
319 } else {
320 *var_counter += 1;
321 let name = format!("$var{}", *var_counter);
322 var_map.insert(key, name.clone());
323 variables.push(PatternVar {
324 name: name.clone(),
325 node_type: Some(b_snap.kind.clone()),
326 });
327 name
328 };
329 return var_name;
330 }
331
332 self.build_template_from_children(
335 before_snaps,
336 after_snaps,
337 b_snap,
338 a_snap,
339 source,
340 side,
341 var_counter,
342 variables,
343 var_map,
344 )
345 }
346 }
347 }
348
349 #[allow(clippy::too_many_arguments)]
353 fn build_template_from_children(
354 &self,
355 before_snaps: &[NodeSnapshot],
356 after_snaps: &[NodeSnapshot],
357 b_snap: &NodeSnapshot,
358 a_snap: &NodeSnapshot,
359 _source: &str,
360 side: TemplateSource,
361 var_counter: &mut usize,
362 variables: &mut Vec<PatternVar>,
363 var_map: &mut HashMap<(String, String), String>,
364 ) -> String {
365 let base_snap = match side {
366 TemplateSource::Before => b_snap,
367 TemplateSource::After => a_snap,
368 };
369 let base_text = &base_snap.text;
370
371 let min_children = b_snap.children.len().min(a_snap.children.len());
373 if min_children == 0 {
374 return base_text.clone();
375 }
376
377 let mut result = base_text.clone();
378 let mut replacements: Vec<(String, String)> = Vec::new();
381
382 for i in 0..min_children {
383 let b_child_idx = b_snap.children[i];
384 let a_child_idx = a_snap.children[i];
385
386 let child_template = self.build_template(
387 before_snaps,
388 after_snaps,
389 b_child_idx,
390 a_child_idx,
391 match side {
392 TemplateSource::Before => &before_snaps[b_child_idx].text,
393 TemplateSource::After => &after_snaps[a_child_idx].text,
394 },
395 side,
396 var_counter,
397 variables,
398 var_map,
399 );
400
401 let original_child_text = match side {
402 TemplateSource::Before => &before_snaps[b_child_idx].text,
403 TemplateSource::After => &after_snaps[a_child_idx].text,
404 };
405
406 if child_template != *original_child_text {
407 replacements.push((original_child_text.clone(), child_template));
408 }
409 }
410
411 for (old, new) in replacements.iter().rev() {
414 if let Some(pos) = result.rfind(old.as_str()) {
415 result.replace_range(pos..pos + old.len(), new);
416 }
417 }
418
419 result
420 }
421
422 fn diff_nodes(&self, before: &NodeSnapshot, after: &NodeSnapshot) -> DiffKind {
424 if before.text == after.text {
426 return DiffKind::Same;
427 }
428
429 if before.children.is_empty() && after.children.is_empty() && before.kind == after.kind {
431 return DiffKind::Changed {
432 before_text: before.text.clone(),
433 after_text: after.text.clone(),
434 node_kind: before.kind.clone(),
435 };
436 }
437
438 if before.kind == after.kind {
440 return DiffKind::Structural;
441 }
442
443 DiffKind::Structural
445 }
446
447 fn compute_confidence(
457 variables: &[PatternVar],
458 before_template: &str,
459 _after_template: &str,
460 ) -> f64 {
461 if before_template.is_empty() {
462 return 0.0;
463 }
464
465 let total_len = before_template.len() as f64;
466 let var_len: f64 = variables.iter().map(|v| v.name.len() as f64).sum();
467
468 let fixed_ratio = 1.0 - (var_len / total_len).min(1.0);
470
471 let var_penalty = 1.0 / (1.0 + variables.len() as f64 * 0.15);
473
474 (fixed_ratio * 0.7 + var_penalty * 0.3).clamp(0.0, 1.0)
475 }
476
477 fn patterns_compatible(a: &Pattern, b: &Pattern) -> bool {
485 if a.variables.len() != b.variables.len() {
487 return false;
488 }
489
490 let skeleton_a = Self::strip_variables(&a.before_template);
492 let skeleton_b = Self::strip_variables(&b.before_template);
493
494 skeleton_a == skeleton_b
495 }
496
497 fn strip_variables(template: &str) -> String {
500 let mut result = String::with_capacity(template.len());
501 let mut chars = template.chars().peekable();
502 while let Some(ch) = chars.next() {
503 if ch == '$' {
504 result.push_str("$$");
506 while let Some(&next) = chars.peek() {
507 if next.is_alphanumeric() || next == '_' {
508 chars.next();
509 } else {
510 break;
511 }
512 }
513 } else {
514 result.push(ch);
515 }
516 }
517 result
518 }
519}
520
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a variable with no node-type constraint.
    fn untyped(name: &str) -> PatternVar {
        PatternVar {
            name: name.into(),
            node_type: None,
        }
    }

    /// Builds a variable whose node type is `identifier`.
    fn identifier(name: &str) -> PatternVar {
        PatternVar {
            name: name.into(),
            node_type: Some("identifier".into()),
        }
    }

    /// Builds a pattern for the placeholder language name "stub".
    fn stub_pattern(
        before: &str,
        after: &str,
        variables: Vec<PatternVar>,
        confidence: f64,
    ) -> Pattern {
        Pattern::new(
            before.into(),
            after.into(),
            variables,
            "stub".into(),
            confidence,
        )
    }

    #[test]
    fn test_strip_variables() {
        assert_eq!(
            PatternInferrer::strip_variables("foo($var1, $var2)"),
            "foo($$, $$)"
        );
    }

    #[test]
    fn test_compute_confidence_no_variables() {
        let conf = PatternInferrer::compute_confidence(&[], "println!(\"hello\")", "");
        assert!(conf > 0.9, "expected high confidence, got {conf}");
    }

    #[test]
    fn test_compute_confidence_with_variables() {
        let vars = vec![identifier("$var1"), identifier("$var2")];
        let conf = PatternInferrer::compute_confidence(&vars, "foo($var1, $var2)", "");
        assert!(
            conf > 0.0 && conf < 1.0,
            "expected moderate confidence, got {conf}"
        );
    }

    #[test]
    fn test_patterns_compatible_same() {
        // Confidence differs on purpose; it plays no part in compatibility.
        let a = stub_pattern("foo($var1)", "bar($var1)", vec![untyped("$var1")], 0.9);
        let b = stub_pattern("foo($var1)", "bar($var1)", vec![untyped("$var1")], 0.8);
        assert!(PatternInferrer::patterns_compatible(&a, &b));
    }

    #[test]
    fn test_patterns_incompatible_different_var_count() {
        let a = stub_pattern("foo($var1)", "bar($var1)", vec![untyped("$var1")], 0.9);
        let b = stub_pattern(
            "foo($var1, $var2)",
            "bar($var1, $var2)",
            vec![untyped("$var1"), untyped("$var2")],
            0.8,
        );
        assert!(!PatternInferrer::patterns_compatible(&a, &b));
    }
}