Skip to main content

dupes_core/
node.rs

1use std::collections::HashMap;
2
/// Kinds of literals — preserves type but erases value.
///
/// [`NodeKind::Literal`] carries only this kind, so two literals of the same
/// kind produce equal nodes regardless of their values. Like
/// [`PlaceholderKind`], this is a fieldless enum and is therefore `Copy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LiteralKind {
    Int,
    Float,
    Str,
    ByteStr,
    CStr,
    Byte,
    Char,
    Bool,
}
15
/// What an identifier referred to before it was replaced by a placeholder.
///
/// Each kind has its own index space: placeholder indices are counted
/// separately for every kind.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PlaceholderKind {
    /// The identifier named a variable.
    Variable,
    /// The identifier named a function.
    Function,
    /// The identifier named a type.
    Type,
    /// The identifier was a lifetime.
    Lifetime,
    /// The identifier was a label.
    Label,
}
25
/// Binary operators, including the compound-assignment forms
/// (`AddAssign`, `SubAssign`, ...).
///
/// `Other` is the catch-all variant for operators without a dedicated entry.
/// Like [`PlaceholderKind`], this is a fieldless enum and is therefore `Copy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinOpKind {
    Add,
    Sub,
    Mul,
    Div,
    Rem,
    And,
    Or,
    BitXor,
    BitAnd,
    BitOr,
    Shl,
    Shr,
    Eq,
    Lt,
    Le,
    Ne,
    Ge,
    Gt,
    AddAssign,
    SubAssign,
    MulAssign,
    DivAssign,
    RemAssign,
    BitXorAssign,
    BitAndAssign,
    BitOrAssign,
    ShlAssign,
    ShrAssign,
    Other,
}
59
/// Unary operators.
///
/// `Other` is the catch-all variant for operators without a dedicated entry.
/// Like [`PlaceholderKind`], this is a fieldless enum and is therefore `Copy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnOpKind {
    Deref,
    Not,
    Neg,
    Other,
}
68
69/// The kind of a normalized AST node. Carries only non-child data
70/// (operator kinds, literal kinds, placeholder indices, mutability flags, macro names).
71#[derive(Debug, Clone, PartialEq, Eq, Hash)]
72pub enum NodeKind {
73    // Blocks and statements
74    Block,
75    LetBinding,
76    Semi,
77    Paren,
78
79    // Literals and identifiers
80    Literal(LiteralKind),
81    Placeholder(PlaceholderKind, usize),
82
83    // Operations
84    BinaryOp(BinOpKind),
85    UnaryOp(UnOpKind),
86    Range,
87
88    // Calls and access
89    Call,
90    MethodCall,
91    FieldAccess,
92    Index,
93    Path,
94
95    // Closures and functions
96    Closure,
97    FnSignature,
98
99    // Control flow
100    Return,
101    Break,
102    Continue,
103    Assign,
104
105    // References and pointers
106    Reference {
107        mutable: bool,
108    },
109
110    // Compound types
111    Tuple,
112    Array,
113    Repeat,
114
115    // Type operations
116    Cast,
117    StructInit,
118
119    // Async/error
120    Await,
121    Try,
122
123    // Control flow structures
124    If,
125    Match,
126    MatchArm,
127    Loop,
128    While,
129    ForLoop,
130    LetExpr,
131
132    // Patterns
133    PatWild,
134    PatPlaceholder(PlaceholderKind, usize),
135    PatTuple,
136    PatStruct,
137    PatOr,
138    PatLiteral,
139    PatReference {
140        mutable: bool,
141    },
142    PatSlice,
143    PatRest,
144    PatRange,
145
146    // Types
147    TypePlaceholder(PlaceholderKind, usize),
148    TypeReference {
149        mutable: bool,
150    },
151    TypeTuple,
152    TypeSlice,
153    TypeArray,
154    TypePath,
155    TypeImplTrait,
156    TypeInfer,
157    TypeUnit,
158    TypeNever,
159
160    // Field initializer (name = value)
161    FieldValue,
162
163    // Macro invocations
164    MacroCall {
165        name: String,
166    },
167
168    // Opaque — unsupported constructs
169    Opaque,
170
171    /// Sentinel for absent optional children, ensuring fixed child positions
172    /// for correct zip alignment in similarity comparison.
173    None,
174}
175
176/// A normalized AST node. Uses a data-driven `{ kind, children }` representation
177/// instead of a large enum with differently-shaped variants. This allows generic
178/// traversal algorithms (count_nodes, reindex, count_matching, extract) to work
179/// without exhaustive matching on every variant.
180///
181/// ## Child ordering conventions
182///
183/// - **Fixed with None sentinels** (always same child count):
184///   - `If` -> [condition, then_branch, else_or_None]
185///   - `LetBinding` -> [pattern, type_or_None, init_or_None, diverge_or_None]
186///   - `Range` / `PatRange` -> [from_or_None, to_or_None]
187///   - `MatchArm` -> [pattern, guard_or_None, body]
188/// - **Fixed children first, variable after** (for zip alignment):
189///   - `Call` -> [func, arg0, arg1, ...]
190///   - `MethodCall` -> [receiver, method, arg0, ...]
191///   - `Closure` -> [body, param0, ...]
192///   - `FnSignature` -> [return_type_or_None, param0, ...]
193///   - `Match` -> [expr, arm0, arm1, ...]
194///   - `StructInit` -> [rest_or_None, field0, field1, ...]
195///   - `MacroCall` -> [arg0, arg1, ...]
196/// - **Variable-length (0 or 1)**: `Return`, `Break` -> [] or [value]
197/// - **Homogeneous**: `Block`, `Tuple`, `Array`, `Path`, `PatTuple`, etc. -> [elem0, ...]
198/// - **All other fixed**: e.g. `BinaryOp` -> [left, right], `ForLoop` -> [pat, iter, body]
199#[derive(Debug, Clone, PartialEq, Eq, Hash)]
200pub struct NormalizedNode {
201    pub kind: NodeKind,
202    pub children: Vec<Self>,
203}
204
205impl NormalizedNode {
206    /// Create a leaf node (no children).
207    #[must_use]
208    pub const fn leaf(kind: NodeKind) -> Self {
209        Self {
210            kind,
211            children: vec![],
212        }
213    }
214
215    /// Create a node with children.
216    #[must_use]
217    pub const fn with_children(kind: NodeKind, children: Vec<Self>) -> Self {
218        Self { kind, children }
219    }
220
221    /// Create a None sentinel node.
222    #[must_use]
223    pub const fn none() -> Self {
224        Self::leaf(NodeKind::None)
225    }
226
227    /// Convert an Option<NormalizedNode> to a node, using None sentinel for absent values.
228    pub fn opt(node: Option<Self>) -> Self {
229        node.unwrap_or_else(Self::none)
230    }
231
232    /// Check if this is a None sentinel node.
233    #[must_use]
234    pub const fn is_none(&self) -> bool {
235        matches!(self.kind, NodeKind::None)
236    }
237}
238
239/// Tracks identifier-to-placeholder mappings during normalization.
240pub struct NormalizationContext {
241    /// Maps (identifier_string, kind) -> placeholder index
242    mappings: HashMap<(String, PlaceholderKind), usize>,
243    /// Per-kind counters
244    counters: HashMap<PlaceholderKind, usize>,
245}
246
247impl NormalizationContext {
248    #[must_use]
249    pub fn new() -> Self {
250        Self {
251            mappings: HashMap::new(),
252            counters: HashMap::new(),
253        }
254    }
255
256    /// Get or assign a placeholder index for the given identifier and kind.
257    pub fn placeholder(&mut self, name: &str, kind: PlaceholderKind) -> usize {
258        let key = (name.to_string(), kind);
259        if let Some(&idx) = self.mappings.get(&key) {
260            return idx;
261        }
262        let counter = self.counters.entry(kind).or_insert(0);
263        let idx = *counter;
264        *counter += 1;
265        self.mappings.insert(key, idx);
266        idx
267    }
268}
269
270impl Default for NormalizationContext {
271    fn default() -> Self {
272        Self::new()
273    }
274}
275
276// -- Placeholder re-indexing --------------------------------------------------
277
278/// Collects all placeholder occurrences in depth-first order, building
279/// a mapping from (kind, old_index) -> new_sequential_index.
280fn collect_placeholder_order(
281    node: &NormalizedNode,
282    order: &mut Vec<(PlaceholderKind, usize)>,
283    seen: &mut std::collections::HashSet<(PlaceholderKind, usize)>,
284) {
285    match &node.kind {
286        NodeKind::Placeholder(kind, idx)
287        | NodeKind::PatPlaceholder(kind, idx)
288        | NodeKind::TypePlaceholder(kind, idx) => {
289            if seen.insert((*kind, *idx)) {
290                order.push((*kind, *idx));
291            }
292        }
293        _ => {}
294    }
295    for child in &node.children {
296        collect_placeholder_order(child, order, seen);
297    }
298}
299
300/// Applies the reindex mapping to a node, returning a new node with remapped indices.
301fn apply_reindex(
302    node: &NormalizedNode,
303    mapping: &HashMap<(PlaceholderKind, usize), usize>,
304) -> NormalizedNode {
305    let kind = match &node.kind {
306        NodeKind::Placeholder(kind, idx) => {
307            let new_idx = mapping.get(&(*kind, *idx)).copied().unwrap_or(*idx);
308            NodeKind::Placeholder(*kind, new_idx)
309        }
310        NodeKind::PatPlaceholder(kind, idx) => {
311            let new_idx = mapping.get(&(*kind, *idx)).copied().unwrap_or(*idx);
312            NodeKind::PatPlaceholder(*kind, new_idx)
313        }
314        NodeKind::TypePlaceholder(kind, idx) => {
315            let new_idx = mapping.get(&(*kind, *idx)).copied().unwrap_or(*idx);
316            NodeKind::TypePlaceholder(*kind, new_idx)
317        }
318        other => other.clone(),
319    };
320    let children = node
321        .children
322        .iter()
323        .map(|c| apply_reindex(c, mapping))
324        .collect();
325    NormalizedNode { kind, children }
326}
327
328/// Re-index all placeholders in a sub-tree so that indices start from 0
329/// per kind, assigned by first-occurrence depth-first order.
330/// This allows comparing sub-trees extracted from different function contexts.
331#[must_use]
332pub fn reindex_placeholders(node: &NormalizedNode) -> NormalizedNode {
333    let mut order = Vec::new();
334    let mut seen = std::collections::HashSet::new();
335    collect_placeholder_order(node, &mut order, &mut seen);
336
337    // Build mapping: (kind, old_index) -> new sequential index per kind
338    let mut counters: HashMap<PlaceholderKind, usize> = HashMap::new();
339    let mut mapping: HashMap<(PlaceholderKind, usize), usize> = HashMap::new();
340    for (kind, old_idx) in order {
341        let counter = counters.entry(kind).or_insert(0);
342        mapping.insert((kind, old_idx), *counter);
343        *counter += 1;
344    }
345
346    apply_reindex(node, &mapping)
347}
348
349/// Count the number of nodes in a normalized tree.
350/// None sentinel nodes are not counted.
351pub fn count_nodes(node: &NormalizedNode) -> usize {
352    if node.is_none() {
353        return 0;
354    }
355    1 + node.children.iter().map(count_nodes).sum::<usize>()
356}
357
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn reindex_remaps_from_zero() {
        let node = NormalizedNode::with_children(
            NodeKind::BinaryOp(BinOpKind::Add),
            vec![
                NormalizedNode::leaf(NodeKind::Placeholder(PlaceholderKind::Variable, 5)),
                NormalizedNode::leaf(NodeKind::Placeholder(PlaceholderKind::Variable, 8)),
            ],
        );
        let reindexed = reindex_placeholders(&node);
        let expected = NormalizedNode::with_children(
            NodeKind::BinaryOp(BinOpKind::Add),
            vec![
                NormalizedNode::leaf(NodeKind::Placeholder(PlaceholderKind::Variable, 0)),
                NormalizedNode::leaf(NodeKind::Placeholder(PlaceholderKind::Variable, 1)),
            ],
        );
        assert_eq!(reindexed, expected);
    }

    #[test]
    fn reindex_preserves_same_placeholder_identity() {
        let node = NormalizedNode::with_children(
            NodeKind::BinaryOp(BinOpKind::Add),
            vec![
                NormalizedNode::leaf(NodeKind::Placeholder(PlaceholderKind::Variable, 3)),
                NormalizedNode::leaf(NodeKind::Placeholder(PlaceholderKind::Variable, 3)),
            ],
        );
        let reindexed = reindex_placeholders(&node);
        let expected = NormalizedNode::with_children(
            NodeKind::BinaryOp(BinOpKind::Add),
            vec![
                NormalizedNode::leaf(NodeKind::Placeholder(PlaceholderKind::Variable, 0)),
                NormalizedNode::leaf(NodeKind::Placeholder(PlaceholderKind::Variable, 0)),
            ],
        );
        assert_eq!(reindexed, expected);
    }

    #[test]
    fn reindex_makes_equivalent_subtrees_equal() {
        let subtree1 = NormalizedNode::with_children(
            NodeKind::Block,
            vec![
                NormalizedNode::with_children(
                    NodeKind::LetBinding,
                    vec![
                        NormalizedNode::leaf(NodeKind::PatPlaceholder(
                            PlaceholderKind::Variable,
                            2,
                        )),
                        NormalizedNode::none(),
                        NormalizedNode::with_children(
                            NodeKind::BinaryOp(BinOpKind::Add),
                            vec![
                                NormalizedNode::leaf(NodeKind::Placeholder(
                                    PlaceholderKind::Variable,
                                    0,
                                )),
                                NormalizedNode::leaf(NodeKind::Literal(LiteralKind::Int)),
                            ],
                        ),
                        NormalizedNode::none(),
                    ],
                ),
                NormalizedNode::leaf(NodeKind::Placeholder(PlaceholderKind::Variable, 2)),
            ],
        );
        let subtree2 = NormalizedNode::with_children(
            NodeKind::Block,
            vec![
                NormalizedNode::with_children(
                    NodeKind::LetBinding,
                    vec![
                        NormalizedNode::leaf(NodeKind::PatPlaceholder(
                            PlaceholderKind::Variable,
                            7,
                        )),
                        NormalizedNode::none(),
                        NormalizedNode::with_children(
                            NodeKind::BinaryOp(BinOpKind::Add),
                            vec![
                                NormalizedNode::leaf(NodeKind::Placeholder(
                                    PlaceholderKind::Variable,
                                    5,
                                )),
                                NormalizedNode::leaf(NodeKind::Literal(LiteralKind::Int)),
                            ],
                        ),
                        NormalizedNode::none(),
                    ],
                ),
                NormalizedNode::leaf(NodeKind::Placeholder(PlaceholderKind::Variable, 7)),
            ],
        );

        assert_ne!(subtree1, subtree2);
        assert_eq!(
            reindex_placeholders(&subtree1),
            reindex_placeholders(&subtree2)
        );
    }

    #[test]
    fn reindex_handles_multiple_placeholder_kinds() {
        let node = NormalizedNode::with_children(
            NodeKind::Call,
            vec![
                NormalizedNode::leaf(NodeKind::Placeholder(PlaceholderKind::Function, 3)),
                NormalizedNode::leaf(NodeKind::Placeholder(PlaceholderKind::Variable, 5)),
                NormalizedNode::with_children(
                    NodeKind::Cast,
                    vec![
                        NormalizedNode::leaf(NodeKind::Placeholder(PlaceholderKind::Variable, 5)),
                        NormalizedNode::leaf(NodeKind::TypePlaceholder(PlaceholderKind::Type, 2)),
                    ],
                ),
            ],
        );
        let reindexed = reindex_placeholders(&node);
        let expected = NormalizedNode::with_children(
            NodeKind::Call,
            vec![
                NormalizedNode::leaf(NodeKind::Placeholder(PlaceholderKind::Function, 0)),
                NormalizedNode::leaf(NodeKind::Placeholder(PlaceholderKind::Variable, 0)),
                NormalizedNode::with_children(
                    NodeKind::Cast,
                    vec![
                        NormalizedNode::leaf(NodeKind::Placeholder(PlaceholderKind::Variable, 0)),
                        NormalizedNode::leaf(NodeKind::TypePlaceholder(PlaceholderKind::Type, 0)),
                    ],
                ),
            ],
        );
        assert_eq!(reindexed, expected);
    }

    #[test]
    fn reindex_keeps_kinds_separate_for_same_old_index() {
        // (Function, 4) and (Variable, 4) share an old index but live in
        // different index spaces, so each is renumbered independently.
        let node = NormalizedNode::with_children(
            NodeKind::Call,
            vec![
                NormalizedNode::leaf(NodeKind::Placeholder(PlaceholderKind::Function, 4)),
                NormalizedNode::leaf(NodeKind::Placeholder(PlaceholderKind::Variable, 9)),
                NormalizedNode::leaf(NodeKind::Placeholder(PlaceholderKind::Variable, 4)),
            ],
        );
        let expected = NormalizedNode::with_children(
            NodeKind::Call,
            vec![
                NormalizedNode::leaf(NodeKind::Placeholder(PlaceholderKind::Function, 0)),
                NormalizedNode::leaf(NodeKind::Placeholder(PlaceholderKind::Variable, 0)),
                NormalizedNode::leaf(NodeKind::Placeholder(PlaceholderKind::Variable, 1)),
            ],
        );
        assert_eq!(reindex_placeholders(&node), expected);
    }

    #[test]
    fn reindex_is_idempotent_and_ignores_placeholder_free_trees() {
        let plain = NormalizedNode::with_children(
            NodeKind::BinaryOp(BinOpKind::Mul),
            vec![
                NormalizedNode::leaf(NodeKind::Literal(LiteralKind::Float)),
                NormalizedNode::none(),
            ],
        );
        assert_eq!(reindex_placeholders(&plain), plain);

        let node = NormalizedNode::with_children(
            NodeKind::Tuple,
            vec![
                NormalizedNode::leaf(NodeKind::Placeholder(PlaceholderKind::Variable, 6)),
                NormalizedNode::leaf(NodeKind::TypePlaceholder(PlaceholderKind::Type, 3)),
            ],
        );
        let once = reindex_placeholders(&node);
        assert_eq!(reindex_placeholders(&once), once);
    }

    #[test]
    fn count_nodes_skips_none_sentinels() {
        let node = NormalizedNode::with_children(
            NodeKind::If,
            vec![
                NormalizedNode::leaf(NodeKind::Placeholder(PlaceholderKind::Variable, 0)),
                NormalizedNode::with_children(NodeKind::Block, vec![]),
                NormalizedNode::none(),
            ],
        );
        // If(1) + Placeholder(1) + Block(1) = 3 (None is not counted)
        assert_eq!(count_nodes(&node), 3);
    }

    // -- NormalizedNode helper tests --

    #[test]
    fn none_sentinel_helpers() {
        let sentinel = NormalizedNode::none();
        assert!(sentinel.is_none());
        assert!(sentinel.children.is_empty());
        assert_eq!(NormalizedNode::opt(None), sentinel);

        let leaf = NormalizedNode::leaf(NodeKind::Literal(LiteralKind::Bool));
        assert!(!leaf.is_none());
        assert_eq!(NormalizedNode::opt(Some(leaf.clone())), leaf);
    }

    // -- NormalizationContext tests --

    #[test]
    fn context_assigns_sequential_indices() {
        let mut ctx = NormalizationContext::new();
        assert_eq!(ctx.placeholder("x", PlaceholderKind::Variable), 0);
        assert_eq!(ctx.placeholder("y", PlaceholderKind::Variable), 1);
        assert_eq!(ctx.placeholder("z", PlaceholderKind::Variable), 2);
    }

    #[test]
    fn context_returns_same_index_for_same_name() {
        let mut ctx = NormalizationContext::new();
        let first = ctx.placeholder("x", PlaceholderKind::Variable);
        let second = ctx.placeholder("x", PlaceholderKind::Variable);
        assert_eq!(first, second);
        assert_eq!(first, 0);
    }

    #[test]
    fn context_per_kind_counters_are_independent() {
        let mut ctx = NormalizationContext::new();
        let var_idx = ctx.placeholder("foo", PlaceholderKind::Variable);
        let fn_idx = ctx.placeholder("foo", PlaceholderKind::Function);
        let type_idx = ctx.placeholder("foo", PlaceholderKind::Type);
        // Each kind starts from 0 independently
        assert_eq!(var_idx, 0);
        assert_eq!(fn_idx, 0);
        assert_eq!(type_idx, 0);
    }

    #[test]
    fn context_same_name_different_kind_are_distinct() {
        let mut ctx = NormalizationContext::new();
        ctx.placeholder("x", PlaceholderKind::Variable);
        ctx.placeholder("x", PlaceholderKind::Function);
        // Second variable should get index 1, not 0
        let y_var = ctx.placeholder("y", PlaceholderKind::Variable);
        assert_eq!(y_var, 1);
        let y_fn = ctx.placeholder("y", PlaceholderKind::Function);
        assert_eq!(y_fn, 1);
    }

    // -- count_nodes tests --

    #[test]
    fn count_nodes_basic() {
        let node = NormalizedNode::with_children(
            NodeKind::BinaryOp(BinOpKind::Add),
            vec![
                NormalizedNode::leaf(NodeKind::Placeholder(PlaceholderKind::Variable, 0)),
                NormalizedNode::leaf(NodeKind::Literal(LiteralKind::Int)),
            ],
        );
        assert_eq!(count_nodes(&node), 3);
    }

    #[test]
    fn count_nodes_of_none_root_is_zero() {
        assert_eq!(count_nodes(&NormalizedNode::none()), 0);
    }
}