use std::collections::{HashMap, HashSet};
use std::hash::{Hash, Hasher};
use syn::visit::Visit;
/// One element of the normalized token stream produced from a function body.
///
/// Sequences of these tokens are what `structural_hash` hashes and what
/// `jaccard_similarity` compares.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum NormalizedToken {
/// A keyword such as `"let"`, `"if"` or `"return"`. Also used for synthetic
/// structure markers such as `"tuple"`, `"array"`, `"closure"` and `"struct"`.
Keyword(&'static str),
/// An operator or punctuation marker such as `"+"`, `"=="`, `"="`, `"=>"`,
/// `"?"`, `"[]"` or `".."`.
Operator(&'static str),
/// A local name, replaced by a numeric id assigned in first-seen order so that
/// the concrete spelling of the name does not matter.
Ident(usize),
/// A method call; the method name is kept verbatim.
MethodCall(String),
/// A field name (or tuple index rendered as text), kept verbatim. Also emitted
/// for named fields in struct literals and struct patterns.
FieldAccess(String),
/// Any integer literal; the value is discarded.
IntLit,
/// Any float literal; the value is discarded.
FloatLit,
/// Any string or byte-string literal; the value is discarded.
StrLit,
/// A boolean literal; unlike other literals the value is preserved.
BoolLit(bool),
/// Any char or byte literal; the value is discarded.
CharLit,
/// A macro invocation, identified by the last segment of its path.
MacroCall(String),
/// A statement terminator.
Semi,
}
/// Normalizes every statement in `body` into a flat token sequence.
///
/// Each call starts from a fresh `Normalizer`, so identifier ids begin at 0
/// and are assigned in the order names are first encountered in this block.
pub fn normalize_body(body: &syn::Block) -> Vec<NormalizedToken> {
    let mut normalizer = Normalizer {
        next_ident_id: 0,
        ident_map: HashMap::new(),
        tokens: Vec::new(),
    };
    // `Normalizer` does not override `visit_block`, so this is the default
    // traversal that feeds each statement to `visit_stmt`.
    normalizer.visit_block(body);
    normalizer.tokens
}
/// Normalizes a slice of statements into a flat token sequence.
///
/// Behaves like `normalize_body` but accepts any run of statements, e.g. a
/// sub-range of a block. Identifier ids start at 0 for every call.
pub fn normalize_stmts(stmts: &[syn::Stmt]) -> Vec<NormalizedToken> {
    let mut normalizer = Normalizer {
        next_ident_id: 0,
        ident_map: HashMap::new(),
        tokens: Vec::new(),
    };
    for stmt in stmts {
        normalizer.visit_stmt(stmt);
    }
    normalizer.tokens
}
/// Hashes a normalized token sequence into a single `u64`.
///
/// Equal sequences always produce equal hashes. The value comes from the
/// standard library's `DefaultHasher` created with `new()`.
pub fn structural_hash(tokens: &[NormalizedToken]) -> u64 {
    use std::collections::hash_map::DefaultHasher;

    let mut state = DefaultHasher::new();
    Hash::hash(tokens, &mut state);
    state.finish()
}
pub fn jaccard_similarity(a: &[NormalizedToken], b: &[NormalizedToken]) -> f64 {
if a.is_empty() && b.is_empty() {
return 1.0;
}
if a.is_empty() || b.is_empty() {
return 0.0;
}
let mut counts_a: HashMap<&NormalizedToken, usize> = HashMap::new();
for t in a {
*counts_a.entry(t).or_insert(0) += 1;
}
let mut counts_b: HashMap<&NormalizedToken, usize> = HashMap::new();
for t in b {
*counts_b.entry(t).or_insert(0) += 1;
}
let all_keys: HashSet<&NormalizedToken> =
counts_a.keys().chain(counts_b.keys()).copied().collect();
let mut intersection = 0usize;
let mut union = 0usize;
for key in all_keys {
let ca = counts_a.get(key).copied().unwrap_or(0);
let cb = counts_b.get(key).copied().unwrap_or(0);
intersection += ca.min(cb);
union += ca.max(cb);
}
if union == 0 {
1.0
} else {
intersection as f64 / union as f64
}
}
/// Visitor state used while converting a `syn` AST into `NormalizedToken`s.
struct Normalizer {
/// Output token sequence, appended to in visit order.
tokens: Vec<NormalizedToken>,
/// Maps each identifier name seen so far to the numeric id it was assigned.
ident_map: HashMap<String, usize>,
/// Id that will be given to the next previously unseen identifier name.
next_ident_id: usize,
}
impl Normalizer {
    /// Returns the numeric id for `name`.
    ///
    /// A name that was seen before gets its existing id back; a new name is
    /// recorded and receives the next free id (0, 1, 2, ... in first-seen order).
    fn resolve_ident(&mut self, name: &str) -> usize {
        match self.ident_map.get(name).copied() {
            Some(known) => known,
            None => {
                let fresh = self.next_ident_id;
                self.next_ident_id += 1;
                // Allocate the owned key only on a miss.
                self.ident_map.insert(name.to_owned(), fresh);
                fresh
            }
        }
    }
}
fn bin_op_str(op: &syn::BinOp) -> &'static str {
match op {
syn::BinOp::Add(_) => "+",
syn::BinOp::Sub(_) => "-",
syn::BinOp::Mul(_) => "*",
syn::BinOp::Div(_) => "/",
syn::BinOp::Rem(_) => "%",
syn::BinOp::And(_) => "&&",
syn::BinOp::Or(_) => "||",
syn::BinOp::BitXor(_) => "^",
syn::BinOp::BitAnd(_) => "&",
syn::BinOp::BitOr(_) => "|",
syn::BinOp::Shl(_) => "<<",
syn::BinOp::Shr(_) => ">>",
syn::BinOp::Eq(_) => "==",
syn::BinOp::Lt(_) => "<",
syn::BinOp::Le(_) => "<=",
syn::BinOp::Ne(_) => "!=",
syn::BinOp::Ge(_) => ">=",
syn::BinOp::Gt(_) => ">",
syn::BinOp::AddAssign(_) => "+=",
syn::BinOp::SubAssign(_) => "-=",
syn::BinOp::MulAssign(_) => "*=",
syn::BinOp::DivAssign(_) => "/=",
syn::BinOp::RemAssign(_) => "%=",
syn::BinOp::BitXorAssign(_) => "^=",
syn::BinOp::BitAndAssign(_) => "&=",
syn::BinOp::BitOrAssign(_) => "|=",
syn::BinOp::ShlAssign(_) => "<<=",
syn::BinOp::ShrAssign(_) => ">>=",
_ => "?op",
}
}
fn un_op_str(op: &syn::UnOp) -> &'static str {
match op {
syn::UnOp::Deref(_) => "*",
syn::UnOp::Not(_) => "!",
syn::UnOp::Neg(_) => "-",
_ => "?un",
}
}
/// Token-emitting walk over statements, expressions and patterns.
///
/// Only `visit_stmt`, `visit_expr` and `visit_pat` are overridden. Node kinds
/// without an explicit arm are handed to `syn`'s default traversal with `self`
/// as the visitor.
impl<'ast> Visit<'ast> for Normalizer {
    /// Emits tokens for one statement.
    ///
    /// `let` statements and statement-position macros always end with `Semi`;
    /// expression statements emit `Semi` only when the source has a trailing
    /// semicolon. Nested items are skipped.
    fn visit_stmt(&mut self, stmt: &'ast syn::Stmt) {
        match stmt {
            syn::Stmt::Local(local) => {
                self.push_keyword("let");
                self.visit_pat(&local.pat);
                if let Some(init) = &local.init {
                    self.push_operator("=");
                    self.visit_expr(&init.expr);
                    // `let PAT = EXPR else { ... };`
                    if let Some((_, diverge)) = &init.diverge {
                        self.push_keyword("else");
                        self.visit_expr(diverge);
                    }
                }
                self.tokens.push(NormalizedToken::Semi);
            }
            syn::Stmt::Expr(expr, semi) => {
                self.visit_expr(expr);
                if semi.is_some() {
                    self.tokens.push(NormalizedToken::Semi);
                }
            }
            syn::Stmt::Macro(m) => {
                self.push_macro_call(&m.mac);
                self.tokens.push(NormalizedToken::Semi);
            }
            syn::Stmt::Item(_) => {}
        }
    }

    /// Emits tokens for one expression, recursing into sub-expressions.
    ///
    /// Literal values (except booleans) and local names are abstracted away;
    /// operators, keywords, method names, field names and macro names are
    /// preserved. Block bodies are flattened: their statements are emitted
    /// in order without any delimiter tokens.
    fn visit_expr(&mut self, expr: &'ast syn::Expr) {
        match expr {
            syn::Expr::Lit(e) => self.push_literal(&e.lit),
            syn::Expr::Path(e) => {
                // Only bare single-segment paths are treated as identifiers;
                // multi-segment paths emit nothing.
                if e.path.segments.len() == 1 {
                    let name = e.path.segments[0].ident.to_string();
                    let id = self.resolve_ident(&name);
                    self.tokens.push(NormalizedToken::Ident(id));
                }
            }
            syn::Expr::Binary(e) => {
                self.visit_expr(&e.left);
                self.push_operator(bin_op_str(&e.op));
                self.visit_expr(&e.right);
            }
            syn::Expr::Unary(e) => {
                self.push_operator(un_op_str(&e.op));
                self.visit_expr(&e.expr);
            }
            syn::Expr::Assign(e) => {
                self.visit_expr(&e.left);
                self.push_operator("=");
                self.visit_expr(&e.right);
            }
            syn::Expr::Call(e) => {
                self.visit_expr(&e.func);
                for arg in &e.args {
                    self.visit_expr(arg);
                }
            }
            syn::Expr::MethodCall(e) => {
                self.visit_expr(&e.receiver);
                self.tokens
                    .push(NormalizedToken::MethodCall(e.method.to_string()));
                for arg in &e.args {
                    self.visit_expr(arg);
                }
            }
            syn::Expr::Field(e) => {
                self.visit_expr(&e.base);
                let field_name = match &e.member {
                    syn::Member::Named(ident) => ident.to_string(),
                    syn::Member::Unnamed(idx) => idx.index.to_string(),
                };
                self.tokens.push(NormalizedToken::FieldAccess(field_name));
            }
            syn::Expr::If(e) => {
                self.push_keyword("if");
                self.visit_expr(&e.cond);
                self.visit_block(&e.then_branch);
                if let Some((_, else_branch)) = &e.else_branch {
                    self.push_keyword("else");
                    self.visit_expr(else_branch);
                }
            }
            syn::Expr::Match(e) => {
                self.push_keyword("match");
                self.visit_expr(&e.expr);
                for arm in &e.arms {
                    self.visit_pat(&arm.pat);
                    if let Some((_, guard)) = &arm.guard {
                        self.push_keyword("if");
                        self.visit_expr(guard);
                    }
                    self.push_operator("=>");
                    self.visit_expr(&arm.body);
                }
            }
            syn::Expr::ForLoop(e) => {
                self.push_keyword("for");
                self.visit_pat(&e.pat);
                self.push_keyword("in");
                self.visit_expr(&e.expr);
                self.visit_block(&e.body);
            }
            syn::Expr::While(e) => {
                self.push_keyword("while");
                self.visit_expr(&e.cond);
                self.visit_block(&e.body);
            }
            syn::Expr::Loop(e) => {
                self.push_keyword("loop");
                self.visit_block(&e.body);
            }
            syn::Expr::Block(e) => self.visit_block(&e.block),
            syn::Expr::Return(e) => {
                self.push_keyword("return");
                if let Some(value) = &e.expr {
                    self.visit_expr(value);
                }
            }
            syn::Expr::Break(e) => {
                self.push_keyword("break");
                if let Some(value) = &e.expr {
                    self.visit_expr(value);
                }
            }
            syn::Expr::Continue(_) => self.push_keyword("continue"),
            syn::Expr::Reference(e) => {
                self.push_operator("&");
                if e.mutability.is_some() {
                    self.push_keyword("mut");
                }
                self.visit_expr(&e.expr);
            }
            syn::Expr::Index(e) => {
                self.visit_expr(&e.expr);
                self.push_operator("[]");
                self.visit_expr(&e.index);
            }
            syn::Expr::Tuple(e) => {
                self.push_keyword("tuple");
                for elem in &e.elems {
                    self.visit_expr(elem);
                }
            }
            syn::Expr::Array(e) => {
                self.push_keyword("array");
                for elem in &e.elems {
                    self.visit_expr(elem);
                }
            }
            syn::Expr::Closure(e) => {
                self.push_keyword("closure");
                for input in &e.inputs {
                    self.visit_pat(input);
                }
                self.visit_expr(&e.body);
            }
            syn::Expr::Try(e) => {
                self.visit_expr(&e.expr);
                self.push_operator("?");
            }
            syn::Expr::Await(e) => {
                self.visit_expr(&e.base);
                self.push_keyword("await");
            }
            syn::Expr::Range(e) => {
                if let Some(start) = &e.start {
                    self.visit_expr(start);
                }
                // `a..b` and `a..=b` cover different ranges, so they must
                // not normalize to the same operator.
                self.push_operator(range_op_str(&e.limits));
                if let Some(end) = &e.end {
                    self.visit_expr(end);
                }
            }
            syn::Expr::Cast(e) => {
                // The target type is not emitted.
                self.visit_expr(&e.expr);
                self.push_keyword("as");
            }
            syn::Expr::Paren(e) => self.visit_expr(&e.expr),
            syn::Expr::Repeat(e) => {
                // `[expr; len]` shares the "array" marker with `[a, b, c]`.
                self.push_keyword("array");
                self.visit_expr(&e.expr);
                self.visit_expr(&e.len);
            }
            syn::Expr::Let(e) => {
                self.push_keyword("let");
                self.visit_pat(&e.pat);
                self.push_operator("=");
                self.visit_expr(&e.expr);
            }
            syn::Expr::Struct(e) => {
                self.push_keyword("struct");
                for field in &e.fields {
                    if let syn::Member::Named(ident) = &field.member {
                        self.tokens
                            .push(NormalizedToken::FieldAccess(ident.to_string()));
                    }
                    self.visit_expr(&field.expr);
                }
                if let Some(rest) = &e.rest {
                    self.push_operator("..");
                    self.visit_expr(rest);
                }
            }
            syn::Expr::Yield(e) => {
                self.push_keyword("yield");
                if let Some(value) = &e.expr {
                    self.visit_expr(value);
                }
            }
            syn::Expr::Macro(m) => self.push_macro_call(&m.mac),
            _ => syn::visit::visit_expr(self, expr),
        }
    }

    /// Emits tokens for one pattern, recursing into sub-patterns.
    ///
    /// Binding names are replaced by positional ids and struct field names are
    /// preserved. The path of a tuple-struct pattern is not emitted; such a
    /// pattern shares the `"tuple"` marker with plain tuple patterns.
    fn visit_pat(&mut self, pat: &'ast syn::Pat) {
        match pat {
            syn::Pat::Ident(p) => {
                if p.mutability.is_some() {
                    self.push_keyword("mut");
                }
                let id = self.resolve_ident(&p.ident.to_string());
                self.tokens.push(NormalizedToken::Ident(id));
                if let Some((_, sub)) = &p.subpat {
                    self.push_operator("@");
                    self.visit_pat(sub);
                }
            }
            syn::Pat::Wild(_) => self.push_keyword("_"),
            syn::Pat::Tuple(t) => {
                self.push_keyword("tuple");
                for elem in &t.elems {
                    self.visit_pat(elem);
                }
            }
            syn::Pat::TupleStruct(ts) => {
                self.push_keyword("tuple");
                for elem in &ts.elems {
                    self.visit_pat(elem);
                }
            }
            syn::Pat::Struct(s) => {
                self.push_keyword("struct");
                for field in &s.fields {
                    if let syn::Member::Named(ident) = &field.member {
                        self.tokens
                            .push(NormalizedToken::FieldAccess(ident.to_string()));
                    }
                    self.visit_pat(&field.pat);
                }
                // Mirror struct expressions, which also emit `..` for a
                // trailing rest.
                if s.rest.is_some() {
                    self.push_operator("..");
                }
            }
            syn::Pat::Lit(l) => self.push_literal(&l.lit),
            syn::Pat::Reference(r) => {
                self.push_operator("&");
                if r.mutability.is_some() {
                    self.push_keyword("mut");
                }
                self.visit_pat(&r.pat);
            }
            syn::Pat::Or(o) => {
                for (i, case) in o.cases.iter().enumerate() {
                    if i > 0 {
                        self.push_operator("|");
                    }
                    self.visit_pat(case);
                }
            }
            syn::Pat::Slice(s) => {
                self.push_keyword("array");
                for elem in &s.elems {
                    self.visit_pat(elem);
                }
            }
            syn::Pat::Rest(_) => self.push_operator(".."),
            syn::Pat::Range(r) => {
                if let Some(start) = &r.start {
                    self.visit_expr(start);
                }
                self.push_operator(range_op_str(&r.limits));
                if let Some(end) = &r.end {
                    self.visit_expr(end);
                }
            }
            _ => syn::visit::visit_pat(self, pat),
        }
    }
}

/// Private token-emission helpers shared by the expression and pattern walks.
impl Normalizer {
    /// Appends a `Keyword` token.
    fn push_keyword(&mut self, keyword: &'static str) {
        self.tokens.push(NormalizedToken::Keyword(keyword));
    }

    /// Appends an `Operator` token.
    fn push_operator(&mut self, operator: &'static str) {
        self.tokens.push(NormalizedToken::Operator(operator));
    }

    /// Appends the token for a literal, discarding its value (booleans keep
    /// theirs). Literal kinds not listed here emit nothing.
    fn push_literal(&mut self, lit: &syn::Lit) {
        let token = match lit {
            syn::Lit::Int(_) => NormalizedToken::IntLit,
            syn::Lit::Float(_) => NormalizedToken::FloatLit,
            syn::Lit::Str(_) | syn::Lit::ByteStr(_) => NormalizedToken::StrLit,
            syn::Lit::Bool(b) => NormalizedToken::BoolLit(b.value),
            syn::Lit::Char(_) | syn::Lit::Byte(_) => NormalizedToken::CharLit,
            _ => return,
        };
        self.tokens.push(token);
    }

    /// Appends a `MacroCall` token named after the last segment of the macro
    /// path (empty string if the path has no segments). Macro arguments are
    /// not inspected.
    fn push_macro_call(&mut self, mac: &syn::Macro) {
        let name = mac
            .path
            .segments
            .last()
            .map(|segment| segment.ident.to_string())
            .unwrap_or_default();
        self.tokens.push(NormalizedToken::MacroCall(name));
    }
}

/// Returns `"..="` for an inclusive range and `".."` otherwise.
fn range_op_str(limits: &syn::RangeLimits) -> &'static str {
    if matches!(limits, syn::RangeLimits::Closed(_)) {
        "..="
    } else {
        ".."
    }
}
#[cfg(test)]
mod tests {
    use super::NormalizedToken as Tok;
    use super::*;

    /// Wraps `code` in a dummy function, parses it, and returns the function's block.
    fn parse_body(code: &str) -> syn::Block {
        let wrapped = format!("fn test_fn() {{ {} }}", code);
        let file = syn::parse_file(&wrapped).expect("parse failed");
        match &file.items[0] {
            syn::Item::Fn(func) => (*func.block).clone(),
            _ => unreachable!("wrapped code is always a function"),
        }
    }

    /// Parses `code` as a function body and normalizes it.
    fn tokens_of(code: &str) -> Vec<NormalizedToken> {
        normalize_body(&parse_body(code))
    }

    /// Structural hash of `code` parsed as a function body.
    fn hash_of(code: &str) -> u64 {
        structural_hash(&tokens_of(code))
    }

    /// Asserts that every token in `expected` occurs somewhere in `tokens`.
    fn assert_contains_all(tokens: &[NormalizedToken], expected: &[NormalizedToken]) {
        for token in expected {
            assert!(tokens.contains(token));
        }
    }

    /// Asserts that `actual` is within `f64::EPSILON` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < f64::EPSILON);
    }

    #[test]
    fn test_normalize_empty_body() {
        assert!(tokens_of("").is_empty());
    }

    #[test]
    fn test_normalize_let_binding() {
        assert_contains_all(
            &tokens_of("let x = 1;"),
            &[Tok::Keyword("let"), Tok::Ident(0), Tok::IntLit, Tok::Semi],
        );
    }

    #[test]
    fn test_normalize_same_structure_different_names_same_hash() {
        assert_eq!(hash_of("let x = a + b;"), hash_of("let y = p + q;"));
    }

    #[test]
    fn test_normalize_different_structure_different_hash() {
        assert_ne!(hash_of("let x = a + b;"), hash_of("let x = a * b;"));
    }

    #[test]
    fn test_structural_hash_deterministic() {
        let body = parse_body("let x = foo(a, b);");
        let first = structural_hash(&normalize_body(&body));
        let second = structural_hash(&normalize_body(&body));
        assert_eq!(first, second);
    }

    #[test]
    fn test_jaccard_identical() {
        let tokens = tokens_of("let x = 1;");
        assert_close(jaccard_similarity(&tokens, &tokens), 1.0);
    }

    #[test]
    fn test_jaccard_disjoint() {
        let left = vec![Tok::IntLit, Tok::Keyword("if")];
        let right = vec![Tok::StrLit, Tok::Keyword("for")];
        assert_close(jaccard_similarity(&left, &right), 0.0);
    }

    #[test]
    fn test_jaccard_partial_overlap() {
        let left = vec![Tok::Keyword("let"), Tok::IntLit, Tok::Semi];
        let right = vec![Tok::Keyword("let"), Tok::StrLit, Tok::Semi];
        assert_close(jaccard_similarity(&left, &right), 0.5);
    }

    #[test]
    fn test_jaccard_both_empty() {
        assert_close(jaccard_similarity(&[], &[]), 1.0);
    }

    #[test]
    fn test_jaccard_one_empty() {
        let non_empty = vec![Tok::IntLit];
        assert_close(jaccard_similarity(&non_empty, &[]), 0.0);
    }

    #[test]
    fn test_normalize_if_expression() {
        assert_contains_all(
            &tokens_of("if x > 0 { return true; }"),
            &[
                Tok::Keyword("if"),
                Tok::Operator(">"),
                Tok::Keyword("return"),
                Tok::BoolLit(true),
            ],
        );
    }

    #[test]
    fn test_normalize_method_call_preserves_name() {
        assert_contains_all(
            &tokens_of("x.push(42);"),
            &[Tok::MethodCall("push".to_string()), Tok::IntLit],
        );
    }

    #[test]
    fn test_normalize_field_access_preserves_name() {
        assert_contains_all(
            &tokens_of("let v = self.name;"),
            &[Tok::FieldAccess("name".to_string())],
        );
    }

    #[test]
    fn test_normalize_bool_values_distinct() {
        assert_ne!(hash_of("return true;"), hash_of("return false;"));
    }

    #[test]
    fn test_normalize_stmts_subset() {
        let body = parse_body("let a = 1; let b = 2; let c = 3;");
        let prefix_tokens = normalize_stmts(&body.stmts[..2]);
        let all_tokens = normalize_body(&body);
        assert!(prefix_tokens.len() < all_tokens.len());
        assert_eq!(prefix_tokens[..4], all_tokens[..4]);
    }

    #[test]
    fn test_normalize_for_loop() {
        assert_contains_all(
            &tokens_of("for item in list { process(item); }"),
            &[Tok::Keyword("for"), Tok::Keyword("in")],
        );
    }

    #[test]
    fn test_normalize_match_expression() {
        assert_contains_all(
            &tokens_of("match x { 0 => true, _ => false }"),
            &[
                Tok::Keyword("match"),
                Tok::Operator("=>"),
                Tok::BoolLit(true),
                Tok::BoolLit(false),
                Tok::Keyword("_"),
            ],
        );
    }

    #[test]
    fn test_normalize_closure() {
        assert_contains_all(
            &tokens_of("let f = |x| x + 1;"),
            &[Tok::Keyword("closure"), Tok::Operator("+")],
        );
    }

    #[test]
    fn test_normalize_try_operator() {
        assert_contains_all(&tokens_of("let r = foo()?;"), &[Tok::Operator("?")]);
    }

    #[test]
    fn test_normalize_reference() {
        assert_contains_all(
            &tokens_of("let r = &mut x;"),
            &[Tok::Operator("&"), Tok::Keyword("mut")],
        );
    }

    #[test]
    fn test_normalize_macro_call() {
        assert_contains_all(
            &tokens_of("println!(\"hello\");"),
            &[Tok::MacroCall("println".to_string())],
        );
    }

    #[test]
    fn test_normalize_complex_same_structure() {
        let hash_a =
            hash_of("for item in items { if item.is_valid() { results.push(item.name()); } }");
        let hash_b =
            hash_of("for entry in data { if entry.is_valid() { output.push(entry.name()); } }");
        assert_eq!(hash_a, hash_b);
    }
}