1use std::collections::HashMap;
2
/// Category of a literal expression or pattern.
///
/// Only the category is recorded; the literal's value is not part of this type. Like
/// [`PlaceholderKind`], this is a fieldless tag, so it is `Copy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LiteralKind {
    /// Integer literal.
    Int,
    /// Floating-point literal.
    Float,
    /// String literal.
    Str,
    /// Byte-string literal.
    ByteStr,
    /// C-string literal.
    CStr,
    /// Byte literal.
    Byte,
    /// Character literal.
    Char,
    /// Boolean literal.
    Bool,
}
15
/// Namespace an anonymized identifier belongs to.
///
/// Placeholder indices are allocated independently for each kind, so `Variable` index `0`
/// and `Function` index `0` refer to different placeholders.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
pub enum PlaceholderKind {
    /// Value binding.
    Variable,
    /// Function name.
    Function,
    /// Type name.
    Type,
    /// Lifetime name.
    Lifetime,
    /// Loop or block label.
    Label,
}
25
/// Operator carried by a `NodeKind::BinaryOp` node, including compound assignments.
///
/// A fieldless tag like [`PlaceholderKind`], so it is `Copy` as well.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinOpKind {
    // Arithmetic: `+`, `-`, `*`, `/`, `%`.
    Add,
    Sub,
    Mul,
    Div,
    Rem,
    // Logical `&&` and `||` (distinct from the bitwise `BitAnd` / `BitOr`).
    And,
    Or,
    // Bitwise `^`, `&`, `|` and shifts `<<`, `>>`.
    BitXor,
    BitAnd,
    BitOr,
    Shl,
    Shr,
    // Comparisons: `==`, `<`, `<=`, `!=`, `>=`, `>`.
    Eq,
    Lt,
    Le,
    Ne,
    Ge,
    Gt,
    // Compound assignments: `+=`, `-=`, `*=`, `/=`, `%=`, `^=`, `&=`, `|=`, `<<=`, `>>=`.
    AddAssign,
    SubAssign,
    MulAssign,
    DivAssign,
    RemAssign,
    BitXorAssign,
    BitAndAssign,
    BitOrAssign,
    ShlAssign,
    ShrAssign,
    /// Operator without a dedicated variant above (assumed fallback — confirm against the
    /// code that produces these nodes).
    Other,
}
59
/// Operator carried by a `NodeKind::UnaryOp` node.
///
/// A fieldless tag like [`PlaceholderKind`], so it is `Copy` as well.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnOpKind {
    /// Dereference `*`.
    Deref,
    /// Logical or bitwise negation `!`.
    Not,
    /// Arithmetic negation `-`.
    Neg,
    /// Operator without a dedicated variant above (assumed fallback — confirm against the
    /// code that produces these nodes).
    Other,
}
68
69#[derive(Debug, Clone, PartialEq, Eq, Hash)]
72pub enum NodeKind {
73 Block,
75 LetBinding,
76 Semi,
77 Paren,
78
79 Literal(LiteralKind),
81 Placeholder(PlaceholderKind, usize),
82
83 BinaryOp(BinOpKind),
85 UnaryOp(UnOpKind),
86 Range,
87
88 Call,
90 MethodCall,
91 FieldAccess,
92 Index,
93 Path,
94
95 Closure,
97 FnSignature,
98
99 Return,
101 Break,
102 Continue,
103 Assign,
104
105 Reference {
107 mutable: bool,
108 },
109
110 Tuple,
112 Array,
113 Repeat,
114
115 Cast,
117 StructInit,
118
119 Await,
121 Try,
122
123 If,
125 Match,
126 MatchArm,
127 Loop,
128 While,
129 ForLoop,
130 LetExpr,
131
132 PatWild,
134 PatPlaceholder(PlaceholderKind, usize),
135 PatTuple,
136 PatStruct,
137 PatOr,
138 PatLiteral,
139 PatReference {
140 mutable: bool,
141 },
142 PatSlice,
143 PatRest,
144 PatRange,
145
146 TypePlaceholder(PlaceholderKind, usize),
148 TypeReference {
149 mutable: bool,
150 },
151 TypeTuple,
152 TypeSlice,
153 TypeArray,
154 TypePath,
155 TypeImplTrait,
156 TypeInfer,
157 TypeUnit,
158 TypeNever,
159
160 FieldValue,
162
163 MacroCall {
165 name: String,
166 },
167
168 Opaque,
170
171 None,
174}
175
176#[derive(Debug, Clone, PartialEq, Eq, Hash)]
200pub struct NormalizedNode {
201 pub kind: NodeKind,
202 pub children: Vec<Self>,
203}
204
205impl NormalizedNode {
206 #[must_use]
208 pub const fn leaf(kind: NodeKind) -> Self {
209 Self {
210 kind,
211 children: vec![],
212 }
213 }
214
215 #[must_use]
217 pub const fn with_children(kind: NodeKind, children: Vec<Self>) -> Self {
218 Self { kind, children }
219 }
220
221 #[must_use]
223 pub const fn none() -> Self {
224 Self::leaf(NodeKind::None)
225 }
226
227 pub fn opt(node: Option<Self>) -> Self {
229 node.unwrap_or_else(Self::none)
230 }
231
232 #[must_use]
234 pub const fn is_none(&self) -> bool {
235 matches!(self.kind, NodeKind::None)
236 }
237}
238
239pub struct NormalizationContext {
241 mappings: HashMap<(String, PlaceholderKind), usize>,
243 counters: HashMap<PlaceholderKind, usize>,
245}
246
247impl NormalizationContext {
248 #[must_use]
249 pub fn new() -> Self {
250 Self {
251 mappings: HashMap::new(),
252 counters: HashMap::new(),
253 }
254 }
255
256 pub fn placeholder(&mut self, name: &str, kind: PlaceholderKind) -> usize {
258 let key = (name.to_string(), kind);
259 if let Some(&idx) = self.mappings.get(&key) {
260 return idx;
261 }
262 let counter = self.counters.entry(kind).or_insert(0);
263 let idx = *counter;
264 *counter += 1;
265 self.mappings.insert(key, idx);
266 idx
267 }
268}
269
270impl Default for NormalizationContext {
271 fn default() -> Self {
272 Self::new()
273 }
274}
275
276fn collect_placeholder_order(
281 node: &NormalizedNode,
282 order: &mut Vec<(PlaceholderKind, usize)>,
283 seen: &mut std::collections::HashSet<(PlaceholderKind, usize)>,
284) {
285 match &node.kind {
286 NodeKind::Placeholder(kind, idx)
287 | NodeKind::PatPlaceholder(kind, idx)
288 | NodeKind::TypePlaceholder(kind, idx) => {
289 if seen.insert((*kind, *idx)) {
290 order.push((*kind, *idx));
291 }
292 }
293 _ => {}
294 }
295 for child in &node.children {
296 collect_placeholder_order(child, order, seen);
297 }
298}
299
300fn apply_reindex(
302 node: &NormalizedNode,
303 mapping: &HashMap<(PlaceholderKind, usize), usize>,
304) -> NormalizedNode {
305 let kind = match &node.kind {
306 NodeKind::Placeholder(kind, idx) => {
307 let new_idx = mapping.get(&(*kind, *idx)).copied().unwrap_or(*idx);
308 NodeKind::Placeholder(*kind, new_idx)
309 }
310 NodeKind::PatPlaceholder(kind, idx) => {
311 let new_idx = mapping.get(&(*kind, *idx)).copied().unwrap_or(*idx);
312 NodeKind::PatPlaceholder(*kind, new_idx)
313 }
314 NodeKind::TypePlaceholder(kind, idx) => {
315 let new_idx = mapping.get(&(*kind, *idx)).copied().unwrap_or(*idx);
316 NodeKind::TypePlaceholder(*kind, new_idx)
317 }
318 other => other.clone(),
319 };
320 let children = node
321 .children
322 .iter()
323 .map(|c| apply_reindex(c, mapping))
324 .collect();
325 NormalizedNode { kind, children }
326}
327
328#[must_use]
332pub fn reindex_placeholders(node: &NormalizedNode) -> NormalizedNode {
333 let mut order = Vec::new();
334 let mut seen = std::collections::HashSet::new();
335 collect_placeholder_order(node, &mut order, &mut seen);
336
337 let mut counters: HashMap<PlaceholderKind, usize> = HashMap::new();
339 let mut mapping: HashMap<(PlaceholderKind, usize), usize> = HashMap::new();
340 for (kind, old_idx) in order {
341 let counter = counters.entry(kind).or_insert(0);
342 mapping.insert((kind, old_idx), *counter);
343 *counter += 1;
344 }
345
346 apply_reindex(node, &mapping)
347}
348
349pub fn count_nodes(node: &NormalizedNode) -> usize {
352 if node.is_none() {
353 return 0;
354 }
355 1 + node.children.iter().map(count_nodes).sum::<usize>()
356}
357
#[cfg(test)]
mod tests {
    use super::*;

    /// Expression-position `Variable` placeholder leaf.
    fn var(idx: usize) -> NormalizedNode {
        NormalizedNode::leaf(NodeKind::Placeholder(PlaceholderKind::Variable, idx))
    }

    /// Pattern-position `Variable` placeholder leaf.
    fn pat_var(idx: usize) -> NormalizedNode {
        NormalizedNode::leaf(NodeKind::PatPlaceholder(PlaceholderKind::Variable, idx))
    }

    /// Integer literal leaf.
    fn int_lit() -> NormalizedNode {
        NormalizedNode::leaf(NodeKind::Literal(LiteralKind::Int))
    }

    /// `lhs + rhs` node.
    fn add(lhs: NormalizedNode, rhs: NormalizedNode) -> NormalizedNode {
        NormalizedNode::with_children(NodeKind::BinaryOp(BinOpKind::Add), vec![lhs, rhs])
    }

    /// Block containing a `LetBinding` of `pat_var(bound)` to `var(operand) + <int>`,
    /// followed by a use of `var(bound)`. The binding's second and fourth children are
    /// `None` sentinels.
    fn let_then_use(bound: usize, operand: usize) -> NormalizedNode {
        NormalizedNode::with_children(
            NodeKind::Block,
            vec![
                NormalizedNode::with_children(
                    NodeKind::LetBinding,
                    vec![
                        pat_var(bound),
                        NormalizedNode::none(),
                        add(var(operand), int_lit()),
                        NormalizedNode::none(),
                    ],
                ),
                var(bound),
            ],
        )
    }

    #[test]
    fn reindex_remaps_from_zero() {
        let reindexed = reindex_placeholders(&add(var(5), var(8)));
        assert_eq!(reindexed, add(var(0), var(1)));
    }

    #[test]
    fn reindex_preserves_same_placeholder_identity() {
        let reindexed = reindex_placeholders(&add(var(3), var(3)));
        assert_eq!(reindexed, add(var(0), var(0)));
    }

    #[test]
    fn reindex_makes_equivalent_subtrees_equal() {
        let subtree1 = let_then_use(2, 0);
        let subtree2 = let_then_use(7, 5);

        assert_ne!(subtree1, subtree2);
        assert_eq!(
            reindex_placeholders(&subtree1),
            reindex_placeholders(&subtree2)
        );
    }

    #[test]
    fn reindex_handles_multiple_placeholder_kinds() {
        // `func(v, v as Ty)` with the given placeholder indices.
        let call = |func: usize, v: usize, ty: usize| {
            NormalizedNode::with_children(
                NodeKind::Call,
                vec![
                    NormalizedNode::leaf(NodeKind::Placeholder(PlaceholderKind::Function, func)),
                    var(v),
                    NormalizedNode::with_children(
                        NodeKind::Cast,
                        vec![
                            var(v),
                            NormalizedNode::leaf(NodeKind::TypePlaceholder(
                                PlaceholderKind::Type,
                                ty,
                            )),
                        ],
                    ),
                ],
            )
        };
        assert_eq!(reindex_placeholders(&call(3, 5, 2)), call(0, 0, 0));
    }

    #[test]
    fn count_nodes_skips_none_sentinels() {
        let node = NormalizedNode::with_children(
            NodeKind::If,
            vec![
                var(0),
                NormalizedNode::with_children(NodeKind::Block, vec![]),
                NormalizedNode::none(),
            ],
        );
        assert_eq!(count_nodes(&node), 3);
    }

    #[test]
    fn context_assigns_sequential_indices() {
        let mut ctx = NormalizationContext::new();
        for (expected, name) in ["x", "y", "z"].iter().enumerate() {
            assert_eq!(ctx.placeholder(name, PlaceholderKind::Variable), expected);
        }
    }

    #[test]
    fn context_returns_same_index_for_same_name() {
        let mut ctx = NormalizationContext::new();
        let (first, second) = (
            ctx.placeholder("x", PlaceholderKind::Variable),
            ctx.placeholder("x", PlaceholderKind::Variable),
        );
        assert_eq!(first, second);
        assert_eq!(first, 0);
    }

    #[test]
    fn context_per_kind_counters_are_independent() {
        let mut ctx = NormalizationContext::new();
        let kinds = [
            PlaceholderKind::Variable,
            PlaceholderKind::Function,
            PlaceholderKind::Type,
        ];
        let indices: Vec<usize> = kinds.iter().map(|&k| ctx.placeholder("foo", k)).collect();
        assert_eq!(indices[0], 0);
        assert_eq!(indices[1], 0);
        assert_eq!(indices[2], 0);
    }

    #[test]
    fn context_same_name_different_kind_are_distinct() {
        let mut ctx = NormalizationContext::new();
        ctx.placeholder("x", PlaceholderKind::Variable);
        ctx.placeholder("x", PlaceholderKind::Function);

        assert_eq!(ctx.placeholder("y", PlaceholderKind::Variable), 1);
        assert_eq!(ctx.placeholder("y", PlaceholderKind::Function), 1);
    }

    #[test]
    fn count_nodes_basic() {
        assert_eq!(count_nodes(&add(var(0), int_lit())), 3);
    }
}