Skip to main content

codemod_core/pattern/
inferrer.rs

1//! Pattern inference engine.
2//!
3//! This module implements the core algorithm that derives a transformation
4//! pattern from one or more before/after example pairs. The process works as
5//! follows:
6//!
7//! 1. Parse both the *before* and *after* code into tree-sitter ASTs.
8//! 2. Walk the two trees in parallel, performing a **structural diff**.
9//! 3. Where leaf nodes differ (identifiers, literals, etc.), extract them as
10//!    [`PatternVar`](super::PatternVar)s.
11//! 4. Where the structure is identical, preserve it verbatim in the template.
12//! 5. Assemble the `before_template` and `after_template` and compute a
13//!    confidence score.
14
15use std::collections::HashMap;
16
17use tree_sitter::{Node, Parser, Tree};
18
19use super::{Pattern, PatternVar};
20use crate::error::CodemodError;
21use crate::language::LanguageAdapter;
22
23// ---------------------------------------------------------------------------
24// Internal helper: a flattened representation of AST nodes for diffing
25// ---------------------------------------------------------------------------
26
/// Owned copy of one tree-sitter node, detached from the parse tree.
///
/// Snapshots are stored in a flat `Vec<NodeSnapshot>`; parent/child links
/// are expressed as indices into that vector instead of references.
#[derive(Debug, Clone)]
struct NodeSnapshot {
    /// Grammar kind reported by tree-sitter (e.g. `"identifier"`,
    /// `"call_expression"`).
    kind: String,
    /// Source text covered by the node's byte range.
    text: String,
    /// `true` when tree-sitter reports the node as named.
    #[allow(dead_code)]
    is_named: bool,
    /// Positions of this node's child snapshots within the owning
    /// `Vec<NodeSnapshot>`.
    children: Vec<usize>,
    /// Distance from the root node, which sits at depth 0.
    #[allow(dead_code)]
    depth: usize,
}
43
/// Outcome of comparing one *before* node with one *after* node.
#[derive(Debug)]
enum DiffKind {
    /// Both nodes span exactly the same text.
    Same,
    /// Two childless nodes of the same kind whose text differs; the pair
    /// becomes a pattern variable.
    Changed {
        /// Text of the node on the *before* side.
        before_text: String,
        /// Text of the node on the *after* side.
        after_text: String,
        /// Node kind shared by both sides.
        node_kind: String,
    },
    /// The nodes differ in a way a single leaf substitution cannot express;
    /// their children have to be inspected instead.
    Structural,
}
59
/// Selects the side of an example pair whose template is being rendered.
#[derive(Debug, Clone, Copy)]
enum TemplateSource {
    /// Render the template from the *before* snippet.
    Before,
    /// Render the template from the *after* snippet.
    After,
}
66
67// ---------------------------------------------------------------------------
68// Public API
69// ---------------------------------------------------------------------------
70
71/// The pattern inference engine.
72///
73/// Given one or more before/after example pairs, the inferrer derives a
74/// [`Pattern`] that captures the transformation.
75pub struct PatternInferrer {
76    language: Box<dyn LanguageAdapter>,
77}
78
79impl PatternInferrer {
80    /// Creates a new inferrer backed by the given language adapter.
81    pub fn new(language: Box<dyn LanguageAdapter>) -> Self {
82        Self { language }
83    }
84
85    // -----------------------------------------------------------------
86    // Public entry points
87    // -----------------------------------------------------------------
88
89    /// Infer a pattern from a single before/after example pair.
90    ///
91    /// # Errors
92    ///
93    /// Returns [`CodemodError::PatternInference`] if the ASTs cannot be
94    /// compared or no meaningful pattern can be derived.
95    pub fn infer_from_example(&self, before: &str, after: &str) -> crate::Result<Pattern> {
96        let before_tree = self.parse(before)?;
97        let after_tree = self.parse(after)?;
98
99        // Flatten trees into snapshot vectors for easier traversal.
100        let before_snaps = Self::flatten_tree(&before_tree, before);
101        let after_snaps = Self::flatten_tree(&after_tree, after);
102
103        // Structural diff starting from the root nodes.
104        let mut var_counter: usize = 0;
105        let mut variables: Vec<PatternVar> = Vec::new();
106        // Map from (before_text, after_text) -> variable name, used to
107        // ensure the *same* textual change gets the *same* variable.
108        let mut var_map: HashMap<(String, String), String> = HashMap::new();
109
110        let before_template = self.build_template(
111            &before_snaps,
112            &after_snaps,
113            0,
114            0,
115            before,
116            TemplateSource::Before,
117            &mut var_counter,
118            &mut variables,
119            &mut var_map,
120        );
121
122        let after_template = self.build_template(
123            &before_snaps,
124            &after_snaps,
125            0,
126            0,
127            after,
128            TemplateSource::After,
129            &mut var_counter,
130            &mut variables,
131            &mut var_map,
132        );
133
134        let confidence = Self::compute_confidence(&variables, &before_template, &after_template);
135
136        let pattern = Pattern::new(
137            before_template,
138            after_template,
139            variables,
140            self.language.name().to_string(),
141            confidence,
142        );
143
144        Ok(pattern)
145    }
146
147    /// Infer a pattern from multiple before/after example pairs.
148    ///
149    /// The algorithm infers a pattern from the first pair and then validates
150    /// it against subsequent pairs, refining the confidence score.
151    ///
152    /// # Errors
153    ///
154    /// Returns [`CodemodError::PatternInference`] if no consistent pattern
155    /// can be derived across the supplied examples.
156    pub fn infer_from_examples(&self, examples: &[(String, String)]) -> crate::Result<Pattern> {
157        if examples.is_empty() {
158            return Err(CodemodError::PatternInference(
159                "At least one example pair is required".into(),
160            ));
161        }
162
163        // Infer from the first pair.
164        let mut pattern = self.infer_from_example(&examples[0].0, &examples[0].1)?;
165
166        if examples.len() == 1 {
167            return Ok(pattern);
168        }
169
170        // Cross-validate against subsequent pairs and adjust confidence.
171        let mut confirmed: usize = 1;
172        for (before, after) in &examples[1..] {
173            match self.infer_from_example(before, after) {
174                Ok(other) => {
175                    if Self::patterns_compatible(&pattern, &other) {
176                        confirmed += 1;
177                    } else {
178                        log::warn!("Example pair produced an incompatible pattern — skipping");
179                    }
180                }
181                Err(e) => {
182                    log::warn!("Failed to infer from example pair: {e}");
183                }
184            }
185        }
186
187        // Confidence is boosted proportionally to the number of confirmed
188        // examples.
189        let cross_factor = confirmed as f64 / examples.len() as f64;
190        pattern.confidence = (pattern.confidence * 0.6 + cross_factor * 0.4).min(1.0);
191
192        Ok(pattern)
193    }
194
195    // -----------------------------------------------------------------
196    // Parsing helpers
197    // -----------------------------------------------------------------
198
199    /// Parse the given source code into a tree-sitter [`Tree`].
200    fn parse(&self, source: &str) -> crate::Result<Tree> {
201        let mut parser = Parser::new();
202        parser
203            .set_language(&self.language.language())
204            .map_err(|e| CodemodError::Parse(format!("Failed to set language: {e}")))?;
205        parser
206            .parse(source, None)
207            .ok_or_else(|| CodemodError::Parse("tree-sitter returned no tree".into()))
208    }
209
210    // -----------------------------------------------------------------
211    // Tree flattening
212    // -----------------------------------------------------------------
213
214    /// Flatten a tree-sitter tree into a `Vec<NodeSnapshot>`. Index `0` is
215    /// always the root node.
216    fn flatten_tree(tree: &Tree, source: &str) -> Vec<NodeSnapshot> {
217        let mut snaps = Vec::new();
218        Self::flatten_node(tree.root_node(), source, &mut snaps, 0);
219        snaps
220    }
221
222    /// Recursively flatten a node and its children.
223    fn flatten_node(
224        node: Node,
225        source: &str,
226        snaps: &mut Vec<NodeSnapshot>,
227        depth: usize,
228    ) -> usize {
229        let idx = snaps.len();
230        // Push a placeholder that we will fill in with child indices.
231        snaps.push(NodeSnapshot {
232            kind: node.kind().to_string(),
233            text: source[node.byte_range()].to_string(),
234            is_named: node.is_named(),
235            children: Vec::new(),
236            depth,
237        });
238
239        let mut child_indices = Vec::new();
240        let child_count = node.named_child_count();
241        for i in 0..child_count {
242            if let Some(child) = node.named_child(i) {
243                let child_idx = Self::flatten_node(child, source, snaps, depth + 1);
244                child_indices.push(child_idx);
245            }
246        }
247
248        snaps[idx].children = child_indices;
249        idx
250    }
251
252    // -----------------------------------------------------------------
253    // Structural diff / template building
254    // -----------------------------------------------------------------
255
256    /// Build a template string for one side of the diff.
257    ///
258    /// Walks both snapshot trees in parallel. Where the trees agree, the
259    /// original text is emitted verbatim. Where they disagree at a leaf
260    /// level, a `$variable` placeholder is emitted.
261    #[allow(clippy::too_many_arguments)]
262    fn build_template(
263        &self,
264        before_snaps: &[NodeSnapshot],
265        after_snaps: &[NodeSnapshot],
266        before_idx: usize,
267        after_idx: usize,
268        source: &str,
269        side: TemplateSource,
270        var_counter: &mut usize,
271        variables: &mut Vec<PatternVar>,
272        var_map: &mut HashMap<(String, String), String>,
273    ) -> String {
274        // Guard against out-of-bounds.
275        if before_idx >= before_snaps.len() || after_idx >= after_snaps.len() {
276            return source.to_string();
277        }
278
279        let b_snap = &before_snaps[before_idx];
280        let a_snap = &after_snaps[after_idx];
281
282        match self.diff_nodes(b_snap, a_snap) {
283            DiffKind::Same => {
284                // Trees agree — return original text from the requested side.
285                match side {
286                    TemplateSource::Before => b_snap.text.clone(),
287                    TemplateSource::After => a_snap.text.clone(),
288                }
289            }
290            DiffKind::Changed {
291                before_text,
292                after_text,
293                node_kind,
294            } => {
295                // Leaf-level change — introduce or reuse a variable.
296                let key = (before_text.clone(), after_text.clone());
297                let var_name = if let Some(name) = var_map.get(&key) {
298                    name.clone()
299                } else {
300                    *var_counter += 1;
301                    let name = format!("$var{}", *var_counter);
302                    var_map.insert(key, name.clone());
303                    variables.push(PatternVar {
304                        name: name.clone(),
305                        node_type: Some(node_kind),
306                    });
307                    name
308                };
309                var_name
310            }
311            DiffKind::Structural => {
312                // The structures diverge. If there are children we can still
313                // try to walk them in parallel; otherwise fall back to the
314                // entire text as a variable.
315                if b_snap.children.is_empty() && a_snap.children.is_empty() {
316                    let key = (b_snap.text.clone(), a_snap.text.clone());
317                    let var_name = if let Some(name) = var_map.get(&key) {
318                        name.clone()
319                    } else {
320                        *var_counter += 1;
321                        let name = format!("$var{}", *var_counter);
322                        var_map.insert(key, name.clone());
323                        variables.push(PatternVar {
324                            name: name.clone(),
325                            node_type: Some(b_snap.kind.clone()),
326                        });
327                        name
328                    };
329                    return var_name;
330                }
331
332                // Attempt parallel walk of children and reconstruct the
333                // template using the source text as a scaffold.
334                self.build_template_from_children(
335                    before_snaps,
336                    after_snaps,
337                    b_snap,
338                    a_snap,
339                    source,
340                    side,
341                    var_counter,
342                    variables,
343                    var_map,
344                )
345            }
346        }
347    }
348
349    /// Reconstruct a template by walking the children of two differing
350    /// structural nodes in parallel and stitching the results into the
351    /// original source text.
352    #[allow(clippy::too_many_arguments)]
353    fn build_template_from_children(
354        &self,
355        before_snaps: &[NodeSnapshot],
356        after_snaps: &[NodeSnapshot],
357        b_snap: &NodeSnapshot,
358        a_snap: &NodeSnapshot,
359        _source: &str,
360        side: TemplateSource,
361        var_counter: &mut usize,
362        variables: &mut Vec<PatternVar>,
363        var_map: &mut HashMap<(String, String), String>,
364    ) -> String {
365        let base_snap = match side {
366            TemplateSource::Before => b_snap,
367            TemplateSource::After => a_snap,
368        };
369        let base_text = &base_snap.text;
370
371        // Walk the minimum number of children present on both sides.
372        let min_children = b_snap.children.len().min(a_snap.children.len());
373        if min_children == 0 {
374            return base_text.clone();
375        }
376
377        let mut result = base_text.clone();
378        // We replace child texts from last to first to preserve byte offsets
379        // within `result`.
380        let mut replacements: Vec<(String, String)> = Vec::new();
381
382        for i in 0..min_children {
383            let b_child_idx = b_snap.children[i];
384            let a_child_idx = a_snap.children[i];
385
386            let child_template = self.build_template(
387                before_snaps,
388                after_snaps,
389                b_child_idx,
390                a_child_idx,
391                match side {
392                    TemplateSource::Before => &before_snaps[b_child_idx].text,
393                    TemplateSource::After => &after_snaps[a_child_idx].text,
394                },
395                side,
396                var_counter,
397                variables,
398                var_map,
399            );
400
401            let original_child_text = match side {
402                TemplateSource::Before => &before_snaps[b_child_idx].text,
403                TemplateSource::After => &after_snaps[a_child_idx].text,
404            };
405
406            if child_template != *original_child_text {
407                replacements.push((original_child_text.clone(), child_template));
408            }
409        }
410
411        // Apply replacements. We do a simple first-occurrence replacement for
412        // each pair. For identical child texts this is a best-effort heuristic.
413        for (old, new) in replacements.iter().rev() {
414            if let Some(pos) = result.rfind(old.as_str()) {
415                result.replace_range(pos..pos + old.len(), new);
416            }
417        }
418
419        result
420    }
421
422    /// Compare two node snapshots and classify the relationship.
423    fn diff_nodes(&self, before: &NodeSnapshot, after: &NodeSnapshot) -> DiffKind {
424        // Exact text match — trivially the same.
425        if before.text == after.text {
426            return DiffKind::Same;
427        }
428
429        // Both are leaf nodes of the same kind — treat as a variable change.
430        if before.children.is_empty() && after.children.is_empty() && before.kind == after.kind {
431            return DiffKind::Changed {
432                before_text: before.text.clone(),
433                after_text: after.text.clone(),
434                node_kind: before.kind.clone(),
435            };
436        }
437
438        // Same node kind with children — structural diff is needed.
439        if before.kind == after.kind {
440            return DiffKind::Structural;
441        }
442
443        // Completely different kinds — treat as a structural change.
444        DiffKind::Structural
445    }
446
447    // -----------------------------------------------------------------
448    // Confidence computation
449    // -----------------------------------------------------------------
450
451    /// Compute a heuristic confidence score for the inferred pattern.
452    ///
453    /// Factors considered:
454    /// - Number of variables (fewer -> higher confidence).
455    /// - Ratio of template text that is *fixed* vs. *variable*.
456    fn compute_confidence(
457        variables: &[PatternVar],
458        before_template: &str,
459        _after_template: &str,
460    ) -> f64 {
461        if before_template.is_empty() {
462            return 0.0;
463        }
464
465        let total_len = before_template.len() as f64;
466        let var_len: f64 = variables.iter().map(|v| v.name.len() as f64).sum();
467
468        // Fixed ratio: how much of the template is literal code.
469        let fixed_ratio = 1.0 - (var_len / total_len).min(1.0);
470
471        // Penalty for too many variables.
472        let var_penalty = 1.0 / (1.0 + variables.len() as f64 * 0.15);
473
474        (fixed_ratio * 0.7 + var_penalty * 0.3).clamp(0.0, 1.0)
475    }
476
477    // -----------------------------------------------------------------
478    // Cross-example compatibility check
479    // -----------------------------------------------------------------
480
481    /// Check whether two independently-inferred patterns are "compatible",
482    /// meaning they have the same variable count and the same fixed template
483    /// skeleton.
484    fn patterns_compatible(a: &Pattern, b: &Pattern) -> bool {
485        // Same number of variables is a strong signal.
486        if a.variables.len() != b.variables.len() {
487            return false;
488        }
489
490        // Strip variable placeholders and compare the skeletons.
491        let skeleton_a = Self::strip_variables(&a.before_template);
492        let skeleton_b = Self::strip_variables(&b.before_template);
493
494        skeleton_a == skeleton_b
495    }
496
497    /// Replace all `$varN` placeholders with a fixed sentinel so that two
498    /// templates can be compared structurally.
499    fn strip_variables(template: &str) -> String {
500        let mut result = String::with_capacity(template.len());
501        let mut chars = template.chars().peekable();
502        while let Some(ch) = chars.next() {
503            if ch == '$' {
504                // Skip the variable name.
505                result.push_str("$$");
506                while let Some(&next) = chars.peek() {
507                    if next.is_alphanumeric() || next == '_' {
508                        chars.next();
509                    } else {
510                        break;
511                    }
512                }
513            } else {
514                result.push(ch);
515            }
516        }
517        result
518    }
519}
520
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a pattern variable fixture with an optional node type.
    fn var(name: &str, node_type: Option<&str>) -> PatternVar {
        PatternVar {
            name: name.into(),
            node_type: node_type.map(String::from),
        }
    }

    /// Builds a pattern fixture for the `"stub"` language.
    fn stub_pattern(
        before: &str,
        after: &str,
        variables: Vec<PatternVar>,
        confidence: f64,
    ) -> Pattern {
        Pattern::new(
            before.into(),
            after.into(),
            variables,
            "stub".into(),
            confidence,
        )
    }

    #[test]
    fn test_strip_variables() {
        let stripped = PatternInferrer::strip_variables("foo($var1, $var2)");
        assert_eq!(stripped, "foo($$, $$)");
    }

    #[test]
    fn test_compute_confidence_no_variables() {
        let vars: Vec<PatternVar> = Vec::new();
        let conf = PatternInferrer::compute_confidence(&vars, "println!(\"hello\")", "");
        // Without variables the whole template is fixed text.
        assert!(conf > 0.9, "expected high confidence, got {conf}");
    }

    #[test]
    fn test_compute_confidence_with_variables() {
        let vars = vec![
            var("$var1", Some("identifier")),
            var("$var2", Some("identifier")),
        ];
        let conf = PatternInferrer::compute_confidence(&vars, "foo($var1, $var2)", "");
        assert!(
            conf > 0.0 && conf < 1.0,
            "expected moderate confidence, got {conf}"
        );
    }

    #[test]
    fn test_patterns_compatible_same() {
        // Same templates and variables; only the confidence differs.
        let a = stub_pattern("foo($var1)", "bar($var1)", vec![var("$var1", None)], 0.9);
        let b = stub_pattern("foo($var1)", "bar($var1)", vec![var("$var1", None)], 0.8);
        assert!(PatternInferrer::patterns_compatible(&a, &b));
    }

    #[test]
    fn test_patterns_incompatible_different_var_count() {
        let a = stub_pattern("foo($var1)", "bar($var1)", vec![var("$var1", None)], 0.9);
        let b = stub_pattern(
            "foo($var1, $var2)",
            "bar($var1, $var2)",
            vec![var("$var1", None), var("$var2", None)],
            0.8,
        );
        assert!(!PatternInferrer::patterns_compatible(&a, &b));
    }
}