use std::collections::HashMap;
/// Kind tag carried by `NodeKind::Literal`.
///
/// NOTE(review): the variant names suggest they mirror Rust literal forms (integer, float,
/// string, byte string, C string, byte, char, bool) — confirm against the code that builds
/// these nodes.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum LiteralKind {
Int,
Float,
Str,
ByteStr,
CStr,
Byte,
Char,
Bool,
}
/// Category of a placeholder.
///
/// Paired with an index in `NodeKind::Placeholder`, `NodeKind::PatPlaceholder` and
/// `NodeKind::TypePlaceholder`. It is also the key of the per-kind counters in
/// `NormalizationContext` and `reindex_placeholders`, so every kind is numbered
/// independently, starting from zero.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PlaceholderKind {
Variable,
Function,
Type,
Lifetime,
Label,
}
/// Operator tag carried by `NodeKind::BinaryOp`.
///
/// NOTE(review): the variants look like Rust's binary and compound-assignment operators,
/// with `Other` as a catch-all — confirm against the code that builds these nodes.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum BinOpKind {
Add,
Sub,
Mul,
Div,
Rem,
And,
Or,
BitXor,
BitAnd,
BitOr,
Shl,
Shr,
Eq,
Lt,
Le,
Ne,
Ge,
Gt,
AddAssign,
SubAssign,
MulAssign,
DivAssign,
RemAssign,
BitXorAssign,
BitAndAssign,
BitOrAssign,
ShlAssign,
ShrAssign,
Other,
}
/// Operator tag carried by `NodeKind::UnaryOp`.
///
/// NOTE(review): `Other` is presumably a catch-all for operators without a dedicated
/// variant — confirm against the code that builds these nodes.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum UnOpKind {
Deref,
Not,
Neg,
Other,
}
/// Tag stored in `NormalizedNode::kind`, identifying what a tree node represents.
///
/// Most variants are plain tags; a few carry a payload (literal kind, operator kind,
/// mutability flag, macro name, or a placeholder kind plus index).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum NodeKind {
Block,
LetBinding,
Semi,
Paren,
/// A literal, tagged with its `LiteralKind`.
Literal(LiteralKind),
/// A placeholder identified by kind and index; the index is rewritten by
/// `reindex_placeholders`.
Placeholder(PlaceholderKind, usize),
/// A binary operation, tagged with its `BinOpKind`.
BinaryOp(BinOpKind),
/// A unary operation, tagged with its `UnOpKind`.
UnaryOp(UnOpKind),
Range,
Call,
MethodCall,
FieldAccess,
Index,
Path,
Closure,
FnSignature,
Return,
Break,
Continue,
Assign,
Reference {
mutable: bool,
},
Tuple,
Array,
Repeat,
Cast,
StructInit,
Await,
Try,
If,
Match,
MatchArm,
Loop,
While,
ForLoop,
LetExpr,
PatWild,
/// A placeholder identified by kind and index; shares its `(kind, index)` numbering with
/// `Placeholder` and `TypePlaceholder` when `reindex_placeholders` renumbers a tree.
PatPlaceholder(PlaceholderKind, usize),
PatTuple,
PatStruct,
PatOr,
PatLiteral,
PatReference {
mutable: bool,
},
PatSlice,
PatRest,
PatRange,
/// A placeholder identified by kind and index; shares its `(kind, index)` numbering with
/// `Placeholder` and `PatPlaceholder` when `reindex_placeholders` renumbers a tree.
TypePlaceholder(PlaceholderKind, usize),
TypeReference {
mutable: bool,
},
TypeTuple,
TypeSlice,
TypeArray,
TypePath,
TypeImplTrait,
TypeInfer,
TypeUnit,
TypeNever,
FieldValue,
/// A macro invocation; unlike placeholders, the name is stored verbatim.
MacroCall {
name: String,
},
Opaque,
/// Sentinel produced by `NormalizedNode::none()`; `count_nodes` counts such a node
/// (and anything beneath it) as zero.
None,
}
/// A node of the normalized tree: a [`NodeKind`] tag plus its ordered children.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct NormalizedNode {
    /// What this node represents.
    pub kind: NodeKind,
    /// Child nodes, in order.
    pub children: Vec<Self>,
}

impl NormalizedNode {
    /// Creates a node of the given kind with no children.
    #[must_use]
    pub const fn leaf(kind: NodeKind) -> Self {
        Self {
            kind,
            children: Vec::new(),
        }
    }

    /// Creates a node of the given kind that owns `children`.
    #[must_use]
    pub const fn with_children(kind: NodeKind, children: Vec<Self>) -> Self {
        Self { kind, children }
    }

    /// Creates the childless [`NodeKind::None`] sentinel node.
    #[must_use]
    pub const fn none() -> Self {
        Self::leaf(NodeKind::None)
    }

    /// Unwraps an optional node, substituting the [`NodeKind::None`] sentinel when absent.
    #[must_use]
    pub fn opt(node: Option<Self>) -> Self {
        node.unwrap_or_else(Self::none)
    }

    /// Returns `true` when this node's kind is [`NodeKind::None`].
    #[must_use]
    pub const fn is_none(&self) -> bool {
        matches!(self.kind, NodeKind::None)
    }
}
/// Hands out stable placeholder indices for names.
///
/// Each `(name, kind)` pair is given an index the first time it is requested and keeps that
/// index on every later request. Indices come from a separate zero-based counter per
/// [`PlaceholderKind`], so the same name may have the same index under different kinds.
#[derive(Debug, Default)]
pub struct NormalizationContext {
    /// Index already assigned to each `(name, kind)` pair.
    mappings: HashMap<(String, PlaceholderKind), usize>,
    /// Next unassigned index for each placeholder kind.
    counters: HashMap<PlaceholderKind, usize>,
}

impl NormalizationContext {
    /// Creates an empty context in which no placeholder has been assigned yet.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the index of `name` within `kind`, allocating the next free index for that
    /// kind if the pair has not been seen before.
    pub fn placeholder(&mut self, name: &str, kind: PlaceholderKind) -> usize {
        // A single entry lookup covers both the hit and the miss path, so the key is
        // hashed once instead of once for `get` and again for `insert`.
        match self.mappings.entry((name.to_owned(), kind)) {
            std::collections::hash_map::Entry::Occupied(existing) => *existing.get(),
            std::collections::hash_map::Entry::Vacant(slot) => {
                let counter = self.counters.entry(kind).or_default();
                let idx = *counter;
                *counter += 1;
                slot.insert(idx);
                idx
            }
        }
    }
}
/// Walks `node` in pre-order and records each distinct `(kind, index)` placeholder key in the
/// order it is first encountered.
///
/// `seen` deduplicates keys across the whole walk; `order` receives a key only the first time
/// it is inserted into `seen`. Expression, pattern and type placeholders share one key space.
fn collect_placeholder_order(
    node: &NormalizedNode,
    order: &mut Vec<(PlaceholderKind, usize)>,
    seen: &mut std::collections::HashSet<(PlaceholderKind, usize)>,
) {
    if let NodeKind::Placeholder(kind, idx)
    | NodeKind::PatPlaceholder(kind, idx)
    | NodeKind::TypePlaceholder(kind, idx) = &node.kind
    {
        let key = (*kind, *idx);
        if seen.insert(key) {
            order.push(key);
        }
    }

    // Children are visited after the node itself, left to right.
    node.children
        .iter()
        .for_each(|child| collect_placeholder_order(child, order, seen));
}
/// Builds a copy of `node` in which every placeholder index is replaced according to
/// `mapping`.
///
/// Expression, pattern and type placeholders are all looked up by `(kind, index)`. A key
/// missing from `mapping` keeps its original index, and every other node kind is cloned
/// unchanged. Children are rebuilt recursively in their original order.
fn apply_reindex(
    node: &NormalizedNode,
    mapping: &HashMap<(PlaceholderKind, usize), usize>,
) -> NormalizedNode {
    // Shared lookup: fall back to the existing index when no replacement is registered.
    let remap = |kind: PlaceholderKind, idx: usize| -> usize {
        mapping.get(&(kind, idx)).copied().unwrap_or(idx)
    };

    let kind = match &node.kind {
        NodeKind::Placeholder(pk, idx) => NodeKind::Placeholder(*pk, remap(*pk, *idx)),
        NodeKind::PatPlaceholder(pk, idx) => NodeKind::PatPlaceholder(*pk, remap(*pk, *idx)),
        NodeKind::TypePlaceholder(pk, idx) => NodeKind::TypePlaceholder(*pk, remap(*pk, *idx)),
        other => other.clone(),
    };

    let mut children = Vec::with_capacity(node.children.len());
    for child in &node.children {
        children.push(apply_reindex(child, mapping));
    }

    NormalizedNode { kind, children }
}
/// Returns a copy of `node` whose placeholder indices are renumbered from zero.
///
/// Placeholders are numbered per [`PlaceholderKind`] in the order their `(kind, index)` key
/// first appears in a pre-order walk, so two trees that differ only in the concrete indices
/// they use compare equal after reindexing. Repeated occurrences of a key map to the same
/// new index.
#[must_use]
pub fn reindex_placeholders(node: &NormalizedNode) -> NormalizedNode {
    let mut first_seen = Vec::new();
    let mut visited = std::collections::HashSet::new();
    collect_placeholder_order(node, &mut first_seen, &mut visited);

    // Next free index for each kind, advanced as keys are assigned.
    let mut next_free: HashMap<PlaceholderKind, usize> = HashMap::new();
    let mapping: HashMap<(PlaceholderKind, usize), usize> = first_seen
        .into_iter()
        .map(|(kind, old_idx)| {
            let slot = next_free.entry(kind).or_insert(0);
            let fresh = *slot;
            *slot += 1;
            ((kind, old_idx), fresh)
        })
        .collect();

    apply_reindex(node, &mapping)
}
/// Counts the nodes in the tree rooted at `node`, ignoring `NodeKind::None` sentinels.
///
/// A node for which `is_none()` is true contributes zero and its children are not visited;
/// every other node contributes one plus the counts of its children.
#[must_use]
pub fn count_nodes(node: &NormalizedNode) -> usize {
    if node.is_none() {
        0
    } else {
        node.children
            .iter()
            .fold(1, |total, child| total + count_nodes(child))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Expression-position placeholder leaf of the given kind and index.
    fn placeholder(kind: PlaceholderKind, idx: usize) -> NormalizedNode {
        NormalizedNode::leaf(NodeKind::Placeholder(kind, idx))
    }

    /// Variable placeholder leaf.
    fn var(idx: usize) -> NormalizedNode {
        placeholder(PlaceholderKind::Variable, idx)
    }

    /// Integer literal leaf.
    fn int_literal() -> NormalizedNode {
        NormalizedNode::leaf(NodeKind::Literal(LiteralKind::Int))
    }

    /// `lhs + rhs` as a binary-op node.
    fn add(lhs: NormalizedNode, rhs: NormalizedNode) -> NormalizedNode {
        NormalizedNode::with_children(NodeKind::BinaryOp(BinOpKind::Add), vec![lhs, rhs])
    }

    /// A block that binds variable `bound` to `operand + <int>` and then uses `bound`.
    fn let_then_use(bound: usize, operand: usize) -> NormalizedNode {
        let binding = NormalizedNode::with_children(
            NodeKind::LetBinding,
            vec![
                NormalizedNode::leaf(NodeKind::PatPlaceholder(PlaceholderKind::Variable, bound)),
                NormalizedNode::none(),
                add(var(operand), int_literal()),
                NormalizedNode::none(),
            ],
        );
        NormalizedNode::with_children(NodeKind::Block, vec![binding, var(bound)])
    }

    /// A call whose children are a function placeholder, a variable, and a cast of that
    /// same variable to a type placeholder.
    fn call_with_cast(func: usize, arg: usize, ty: usize) -> NormalizedNode {
        let cast = NormalizedNode::with_children(
            NodeKind::Cast,
            vec![
                var(arg),
                NormalizedNode::leaf(NodeKind::TypePlaceholder(PlaceholderKind::Type, ty)),
            ],
        );
        NormalizedNode::with_children(
            NodeKind::Call,
            vec![placeholder(PlaceholderKind::Function, func), var(arg), cast],
        )
    }

    #[test]
    fn reindex_remaps_from_zero() {
        let reindexed = reindex_placeholders(&add(var(5), var(8)));
        assert_eq!(reindexed, add(var(0), var(1)));
    }

    #[test]
    fn reindex_preserves_same_placeholder_identity() {
        let reindexed = reindex_placeholders(&add(var(3), var(3)));
        assert_eq!(reindexed, add(var(0), var(0)));
    }

    #[test]
    fn reindex_makes_equivalent_subtrees_equal() {
        let subtree1 = let_then_use(2, 0);
        let subtree2 = let_then_use(7, 5);
        assert_ne!(subtree1, subtree2);
        assert_eq!(
            reindex_placeholders(&subtree1),
            reindex_placeholders(&subtree2)
        );
    }

    #[test]
    fn reindex_handles_multiple_placeholder_kinds() {
        let reindexed = reindex_placeholders(&call_with_cast(3, 5, 2));
        assert_eq!(reindexed, call_with_cast(0, 0, 0));
    }

    #[test]
    fn count_nodes_skips_none_sentinels() {
        let node = NormalizedNode::with_children(
            NodeKind::If,
            vec![
                var(0),
                NormalizedNode::with_children(NodeKind::Block, vec![]),
                NormalizedNode::none(),
            ],
        );
        assert_eq!(count_nodes(&node), 3);
    }

    #[test]
    fn context_assigns_sequential_indices() {
        let mut ctx = NormalizationContext::new();
        assert_eq!(ctx.placeholder("x", PlaceholderKind::Variable), 0);
        assert_eq!(ctx.placeholder("y", PlaceholderKind::Variable), 1);
        assert_eq!(ctx.placeholder("z", PlaceholderKind::Variable), 2);
    }

    #[test]
    fn context_returns_same_index_for_same_name() {
        let mut ctx = NormalizationContext::new();
        let first = ctx.placeholder("x", PlaceholderKind::Variable);
        let second = ctx.placeholder("x", PlaceholderKind::Variable);
        assert_eq!(first, second);
        assert_eq!(first, 0);
    }

    #[test]
    fn context_per_kind_counters_are_independent() {
        let mut ctx = NormalizationContext::new();
        let var_idx = ctx.placeholder("foo", PlaceholderKind::Variable);
        let fn_idx = ctx.placeholder("foo", PlaceholderKind::Function);
        let type_idx = ctx.placeholder("foo", PlaceholderKind::Type);
        assert_eq!(var_idx, 0);
        assert_eq!(fn_idx, 0);
        assert_eq!(type_idx, 0);
    }

    #[test]
    fn context_same_name_different_kind_are_distinct() {
        let mut ctx = NormalizationContext::new();
        ctx.placeholder("x", PlaceholderKind::Variable);
        ctx.placeholder("x", PlaceholderKind::Function);
        let y_var = ctx.placeholder("y", PlaceholderKind::Variable);
        assert_eq!(y_var, 1);
        let y_fn = ctx.placeholder("y", PlaceholderKind::Function);
        assert_eq!(y_fn, 1);
    }

    #[test]
    fn count_nodes_basic() {
        assert_eq!(count_nodes(&add(var(0), int_literal())), 3);
    }
}