use std::cmp::Ordering;
/// Structural key describing an expression, so that equivalent expressions
/// can be detected by comparing or hashing their keys (e.g. as a `HashMap` key).
///
/// `normalize_binop` canonicalizes the operand order of commutative binary
/// operators, aiming to give `a + b` and `b + a` the same `BinOp` key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum HashKey {
/// A literal constant, identified by its type name plus its textual representation.
Const {
type_name: String,
repr: String,
},
/// A variable identified by the number `vn`.
/// NOTE(review): `vn` is presumably a value number assigned elsewhere — confirm.
VarVN {
vn: usize,
},
/// A bare identifier, compared by its name string.
Name {
name: String,
},
/// A binary operation `left op right`.
///
/// When built through `normalize_binop`, `commutative` equals
/// `is_commutative(op)`. For commutative operators, the operands are swapped
/// whenever `left` would compare greater than `right` under `compare_hash_keys`.
BinOp {
op: String,
left: Box<HashKey>,
right: Box<HashKey>,
commutative: bool,
},
/// A unary operator `op` applied to `operand`.
UnaryOp {
op: String,
operand: Box<HashKey>,
},
/// A boolean operator `op` applied to `operands`, kept in the order given.
BoolOp {
op: String,
operands: Vec<HashKey>,
},
/// A comparison represented as a list of string parts.
/// NOTE(review): the exact layout of `parts` (e.g. interleaved operands and
/// operators) is not visible here — confirm against the producer.
Compare {
parts: Vec<String>,
},
/// A call expression. Two `Call` keys are equal only when their
/// `unique_id`s match.
/// NOTE(review): presumably a fresh id is allocated per call so calls are
/// never merged — confirm.
Call {
unique_id: usize,
},
/// Attribute access on `value`, presumably `value.attr` — confirm.
Attribute {
value: Box<HashKey>,
attr: String,
},
/// Subscript of `value` by `slice`, presumably `value[slice]` — confirm.
Subscript {
value: Box<HashKey>,
slice: Box<HashKey>,
},
/// An opaque key equal only to another `Unique` with the same `id`.
Unique {
id: usize,
},
}
/// Reports whether `op` names a binary operator treated as commutative, i.e.
/// one whose operands `normalize_binop` is allowed to reorder.
///
/// Accepted spellings:
/// - AST-style names: `"Add"`, `"Mult"`, `"BitOr"`, `"BitAnd"`, `"BitXor"`;
/// - symbols: `"+"`, `"*"`, `"|"`, `"&"`, `"^"`, `"=="`, `"!="`, `"&&"`, `"||"`;
/// - the keywords `"and"` and `"or"`.
///
/// Matching is exact and case-sensitive. Every other string (e.g. `"Sub"` or
/// `"-"`) yields `false`.
///
/// NOTE(review): where `and`/`or` evaluate to one of their operands (e.g.
/// Python, where `1 and 2` is `2` but `2 and 1` is `1`), these operators are
/// value-commutative only for boolean operands — confirm callers rely on that.
pub fn is_commutative(op: &str) -> bool {
    const COMMUTATIVE_OPS: &[&str] = &[
        // AST operator names.
        "Add", "Mult", "BitOr", "BitAnd", "BitXor",
        // Symbolic arithmetic / bitwise operators.
        "+", "*", "|", "&", "^",
        // Equality tests.
        "==", "!=",
        // Logical operators: keyword spelling, then C-style spelling.
        "and", "or", "&&", "||",
    ];
    COMMUTATIVE_OPS.contains(&op)
}
/// Builds the `HashKey::BinOp` key for `left op right`.
///
/// If `op` is commutative (per `is_commutative`) and `left` orders after
/// `right` under `compare_hash_keys`, the operands are swapped. This makes the
/// operand order canonical with respect to that ordering. When the comparison
/// reports `Equal`, and for every non-commutative operator, the operands keep
/// the order given.
///
/// The returned key records in its `commutative` field whether `op` was
/// treated as commutative.
pub fn normalize_binop(op: &str, left: HashKey, right: HashKey) -> HashKey {
    let commutative = is_commutative(op);
    // `&&` short-circuits, so operands of non-commutative operators are never
    // compared at all.
    let swap = commutative && compare_hash_keys(&left, &right) == Ordering::Greater;
    let (first, second) = if swap { (right, left) } else { (left, right) };
    HashKey::BinOp {
        op: String::from(op),
        left: Box::new(first),
        right: Box::new(second),
        commutative,
    }
}
/// Total order over `HashKey`s, used to canonicalize the operand order of
/// commutative binary operations in `normalize_binop`.
///
/// Keys of different variants are ordered by `discriminant_order`. Keys of the
/// same variant are compared field by field, recursing into nested keys. As a
/// result, this returns `Ordering::Equal` exactly when `a == b`.
///
/// That consistency with `Eq` is essential. `normalize_binop` keeps the given
/// operand order on `Equal`, so two distinct keys that compared `Equal` would
/// make `x op y` and `y op x` produce different keys.
fn compare_hash_keys(a: &HashKey, b: &HashKey) -> Ordering {
    match (a, b) {
        (
            HashKey::Const { type_name: t1, repr: r1 },
            HashKey::Const { type_name: t2, repr: r2 },
        ) => t1.cmp(t2).then_with(|| r1.cmp(r2)),
        (HashKey::VarVN { vn: v1 }, HashKey::VarVN { vn: v2 }) => v1.cmp(v2),
        (HashKey::Name { name: n1 }, HashKey::Name { name: n2 }) => n1.cmp(n2),
        (
            HashKey::BinOp { op: o1, left: l1, right: r1, commutative: c1 },
            HashKey::BinOp { op: o2, left: l2, right: r2, commutative: c2 },
        ) => o1
            .cmp(o2)
            .then_with(|| compare_hash_keys(l1, l2))
            .then_with(|| compare_hash_keys(r1, r2))
            // `commutative` participates in the derived `Eq`, so it must break
            // ties here too.
            .then_with(|| c1.cmp(c2)),
        (
            HashKey::UnaryOp { op: o1, operand: x1 },
            HashKey::UnaryOp { op: o2, operand: x2 },
        ) => o1.cmp(o2).then_with(|| compare_hash_keys(x1, x2)),
        (
            HashKey::BoolOp { op: o1, operands: xs1 },
            HashKey::BoolOp { op: o2, operands: xs2 },
        ) => o1.cmp(o2).then_with(|| compare_key_slices(xs1, xs2)),
        (HashKey::Compare { parts: p1 }, HashKey::Compare { parts: p2 }) => p1.cmp(p2),
        (HashKey::Call { unique_id: u1 }, HashKey::Call { unique_id: u2 }) => u1.cmp(u2),
        (
            HashKey::Attribute { value: v1, attr: a1 },
            HashKey::Attribute { value: v2, attr: a2 },
        ) => compare_hash_keys(v1, v2).then_with(|| a1.cmp(a2)),
        (
            HashKey::Subscript { value: v1, slice: s1 },
            HashKey::Subscript { value: v2, slice: s2 },
        ) => compare_hash_keys(v1, v2).then_with(|| compare_hash_keys(s1, s2)),
        (HashKey::Unique { id: u1 }, HashKey::Unique { id: u2 }) => u1.cmp(u2),
        // Every same-variant pair is handled above, so this arm only ever sees
        // keys of different variants.
        _ => discriminant_order(a).cmp(&discriminant_order(b)),
    }
}

/// Lexicographic comparison of two key sequences, element-wise via
/// `compare_hash_keys`. A strict prefix orders before the longer sequence.
fn compare_key_slices(xs: &[HashKey], ys: &[HashKey]) -> Ordering {
    xs.iter()
        .zip(ys)
        .map(|(x, y)| compare_hash_keys(x, y))
        .find(|&ord| ord != Ordering::Equal)
        .unwrap_or_else(|| xs.len().cmp(&ys.len()))
}
fn discriminant_order(key: &HashKey) -> u8 {
match key {
HashKey::Const { .. } => 0,
HashKey::VarVN { .. } => 1,
HashKey::Name { .. } => 2,
HashKey::BinOp { .. } => 3,
HashKey::UnaryOp { .. } => 4,
HashKey::BoolOp { .. } => 5,
HashKey::Compare { .. } => 6,
HashKey::Call { .. } => 7,
HashKey::Attribute { .. } => 8,
HashKey::Subscript { .. } => 9,
HashKey::Unique { .. } => 10,
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for operator commutativity, `HashKey` equality and hashing,
    //! and operand normalization in `normalize_binop`.
    use super::*;
    use std::collections::HashMap;

    /// Shorthand for a `HashKey::Name` holding `ident`.
    fn name(ident: &str) -> HashKey {
        HashKey::Name {
            name: ident.to_string(),
        }
    }

    /// Shorthand for an `int`-typed `HashKey::Const` with textual value `repr`.
    fn int_const(repr: &str) -> HashKey {
        HashKey::Const {
            type_name: "int".to_string(),
            repr: repr.to_string(),
        }
    }

    /// Normalizes both `lhs op rhs` and `rhs op lhs` and returns the two keys.
    fn both_orders(op: &str, lhs: &HashKey, rhs: &HashKey) -> (HashKey, HashKey) {
        let forward = normalize_binop(op, lhs.clone(), rhs.clone());
        let backward = normalize_binop(op, rhs.clone(), lhs.clone());
        (forward, backward)
    }

    // ---- is_commutative: AST operator names ----

    #[test]
    fn test_is_commutative_add() {
        assert!(is_commutative("Add"));
    }

    #[test]
    fn test_is_commutative_mult() {
        assert!(is_commutative("Mult"));
    }

    #[test]
    fn test_is_commutative_bitor() {
        assert!(is_commutative("BitOr"));
    }

    #[test]
    fn test_is_commutative_bitand() {
        assert!(is_commutative("BitAnd"));
    }

    #[test]
    fn test_is_commutative_bitxor() {
        assert!(is_commutative("BitXor"));
    }

    #[test]
    fn test_is_not_commutative_sub() {
        assert!(!is_commutative("Sub"));
    }

    #[test]
    fn test_is_not_commutative_div() {
        assert!(!is_commutative("Div"));
    }

    #[test]
    fn test_is_not_commutative_mod() {
        assert!(!is_commutative("Mod"));
    }

    #[test]
    fn test_is_not_commutative_pow() {
        assert!(!is_commutative("Pow"));
    }

    #[test]
    fn test_is_not_commutative_lshift() {
        assert!(!is_commutative("LShift"));
    }

    #[test]
    fn test_is_not_commutative_rshift() {
        assert!(!is_commutative("RShift"));
    }

    // ---- is_commutative: symbolic spellings ----

    #[test]
    fn test_is_commutative_raw_plus() {
        assert!(is_commutative("+"));
    }

    #[test]
    fn test_is_commutative_raw_star() {
        assert!(is_commutative("*"));
    }

    #[test]
    fn test_is_commutative_raw_eq() {
        assert!(is_commutative("=="));
    }

    #[test]
    fn test_is_commutative_raw_neq() {
        assert!(is_commutative("!="));
    }

    #[test]
    fn test_is_commutative_raw_pipe() {
        assert!(is_commutative("|"));
    }

    #[test]
    fn test_is_commutative_raw_ampersand() {
        assert!(is_commutative("&"));
    }

    #[test]
    fn test_is_commutative_raw_caret() {
        assert!(is_commutative("^"));
    }

    #[test]
    fn test_is_commutative_c_style_and() {
        assert!(is_commutative("&&"));
    }

    #[test]
    fn test_is_commutative_c_style_or() {
        assert!(is_commutative("||"));
    }

    #[test]
    fn test_is_not_commutative_raw_minus() {
        assert!(!is_commutative("-"));
    }

    #[test]
    fn test_is_not_commutative_raw_slash() {
        assert!(!is_commutative("/"));
    }

    #[test]
    fn test_is_not_commutative_raw_percent() {
        assert!(!is_commutative("%"));
    }

    // ---- normalize_binop: operand ordering ----

    #[test]
    fn test_hash_key_commutative_normalization_add() {
        let (xy, yx) = both_orders("Add", &name("x"), &name("y"));
        assert_eq!(xy, yx, "x + y and y + x should be equal");
    }

    #[test]
    fn test_hash_key_commutative_normalization_mult() {
        let (ab, ba) = both_orders("Mult", &name("a"), &name("b"));
        assert_eq!(ab, ba, "a * b and b * a should be equal");
    }

    #[test]
    fn test_hash_key_commutative_normalization_bitor() {
        let (ab, ba) = both_orders("BitOr", &name("a"), &name("b"));
        assert_eq!(ab, ba, "a | b and b | a should be equal");
    }

    #[test]
    fn test_hash_key_commutative_normalization_bitand() {
        let (ab, ba) = both_orders("BitAnd", &name("a"), &name("b"));
        assert_eq!(ab, ba, "a & b and b & a should be equal");
    }

    #[test]
    fn test_hash_key_commutative_normalization_bitxor() {
        let (ab, ba) = both_orders("BitXor", &name("a"), &name("b"));
        assert_eq!(ab, ba, "a ^ b and b ^ a should be equal");
    }

    #[test]
    fn test_hash_key_non_commutative_order_preserved_sub() {
        let (ab, ba) = both_orders("Sub", &name("a"), &name("b"));
        assert_ne!(ab, ba, "a - b and b - a should be different");
    }

    #[test]
    fn test_hash_key_non_commutative_order_preserved_div() {
        let (ab, ba) = both_orders("Div", &name("a"), &name("b"));
        assert_ne!(ab, ba, "a / b and b / a should be different");
    }

    // ---- HashKey equality ----

    #[test]
    fn test_hash_key_const_equality() {
        let (c1, c2, c3) = (int_const("42"), int_const("42"), int_const("43"));
        assert_eq!(c1, c2);
        assert_ne!(c1, c3);
    }

    #[test]
    fn test_hash_key_var_vn_equality() {
        let vn = |n| HashKey::VarVN { vn: n };
        assert_eq!(vn(1), vn(1));
        assert_ne!(vn(1), vn(2));
    }

    #[test]
    fn test_hash_key_name_equality() {
        assert_eq!(name("x"), name("x"));
        assert_ne!(name("x"), name("y"));
    }

    #[test]
    fn test_hash_key_call_always_different() {
        let id_one = HashKey::Call { unique_id: 1 };
        assert_ne!(id_one, HashKey::Call { unique_id: 2 });
        assert_eq!(id_one, HashKey::Call { unique_id: 1 });
    }

    #[test]
    fn test_hash_key_unique_always_different() {
        let (first, second) = (HashKey::Unique { id: 1 }, HashKey::Unique { id: 2 });
        assert_ne!(first, second);
    }

    #[test]
    fn test_nested_commutative_normalization() {
        let (ab, ba) = both_orders("Add", &name("a"), &name("b"));
        let c = name("c");
        let abc = normalize_binop("Add", ab, c.clone());
        let bac = normalize_binop("Add", ba, c);
        assert_eq!(abc, bac, "Nested commutative expressions should normalize");
    }

    #[test]
    fn test_different_operators_not_equal() {
        let (a, b) = (name("a"), name("b"));
        let add = normalize_binop("Add", a.clone(), b.clone());
        let sub = normalize_binop("Sub", a.clone(), b.clone());
        let mult = normalize_binop("Mult", a, b);
        assert_ne!(add, sub);
        assert_ne!(add, mult);
        assert_ne!(sub, mult);
    }

    // ---- HashKey as a map key ----

    #[test]
    fn test_hash_key_as_hashmap_key() {
        let mut map: HashMap<HashKey, usize> = HashMap::new();
        map.insert(name("x"), 1);
        assert_eq!(map.get(&name("x")), Some(&1));
    }

    #[test]
    fn test_commutative_keys_hash_same() {
        let (ab, ba) = both_orders("Add", &name("a"), &name("b"));
        let mut map: HashMap<HashKey, usize> = HashMap::new();
        map.insert(ab, 1);
        assert_eq!(map.get(&ba), Some(&1));
    }
}