use fnv;
use super::ast::ExprKind::*;
use super::ast::LiteralKind::*;
use super::ast::*;
use crate::error::*;
use std::collections::hash_map::Entry;
use std::hash::{Hash, Hasher};
use std::fmt;
/// Hashing that abstracts away the names of bound symbols.
pub trait HashIgnoringSymbols {
    /// Returns a hash of `self` in which bound symbol names do not participate.
    ///
    /// # Errors
    ///
    /// Implementations may fail; the `Expr` implementation in this module
    /// reports a compile error for a reference to an undefined symbol.
    fn hash_ignoring_symbols(&self) -> WeldResult<u64>;
}
impl HashIgnoringSymbols for Expr {
    /// Hashes this expression via `ExprHash`, so consistently renamed bound
    /// symbols yield the same value.
    ///
    /// # Errors
    ///
    /// Propagates any error from `ExprHash::from` (e.g. an undefined symbol).
    fn hash_ignoring_symbols(&self) -> WeldResult<u64> {
        ExprHash::from(self).map(|signature| signature.value())
    }
}
/// Accumulated FNV hash of an expression tree, built by `ExprHash::from`.
struct ExprHash {
    /// Running FNV state; `finish()` yields the current digest.
    hasher: fnv::FnvHasher,
}
impl fmt::Debug for ExprHash {
    /// Renders as `ExprHash(<digest>)` using the hasher's current digest.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let digest: u64 = self.hasher.finish();
        f.write_fmt(format_args!("ExprHash({})", digest))
    }
}
impl ExprHash {
fn hash_expr<'a>(
&mut self,
expr: &'a Expr,
symbol_positions: &mut fnv::FnvHashMap<&'a Symbol, Vec<i32>>,
max_id: &mut i32,
) -> WeldResult<()> {
let mut finished_subexpressions = false;
expr.ty.hash(&mut self.hasher);
expr.kind.name().hash(&mut self.hasher);
match expr.kind {
Literal(ref kind) => match *kind {
BoolLiteral(v) => v.hash(&mut self.hasher),
I8Literal(v) => v.hash(&mut self.hasher),
I16Literal(v) => v.hash(&mut self.hasher),
I32Literal(v) => v.hash(&mut self.hasher),
I64Literal(v) => v.hash(&mut self.hasher),
U8Literal(v) => v.hash(&mut self.hasher),
U16Literal(v) => v.hash(&mut self.hasher),
U32Literal(v) => v.hash(&mut self.hasher),
U64Literal(v) => v.hash(&mut self.hasher),
F32Literal(v) => v.hash(&mut self.hasher),
F64Literal(v) => v.hash(&mut self.hasher),
StringLiteral(ref v) => v.hash(&mut self.hasher),
},
Ident(ref sym) => {
match symbol_positions.entry(sym) {
Entry::Occupied(ref ent) => {
ent.get().hash(&mut self.hasher);
}
_ => {
return compile_err!("Undefined symbol {}", sym);
}
}
}
BinOp { ref kind, .. } => {
kind.hash(&mut self.hasher);
}
UnaryOp { ref kind, .. } => {
kind.hash(&mut self.hasher);
}
Cast { ref kind, .. } => {
kind.hash(&mut self.hasher);
}
GetField { ref index, .. } => {
index.hash(&mut self.hasher);
}
Let {
ref name,
ref value,
ref body,
} => {
self.hash_expr(value, symbol_positions, max_id)?;
{
let entry = symbol_positions.entry(name).or_insert_with(Vec::new);
entry.push(*max_id);
*max_id += 1;
} self.hash_expr(body, symbol_positions, max_id)?;
let entry = symbol_positions.entry(name).or_insert_with(Vec::new);
let _ = entry.pop();
finished_subexpressions = true;
}
Lambda {
ref params,
ref body,
} => {
for param in params.iter() {
let entry = symbol_positions.entry(¶m.name).or_insert_with(Vec::new);
entry.push(*max_id);
*max_id += 1;
}
self.hash_expr(body, symbol_positions, max_id)?;
for param in params.iter() {
let entry = symbol_positions.entry(¶m.name).or_insert_with(Vec::new);
entry.pop();
}
finished_subexpressions = true;
}
CUDF {
ref sym_name,
ref return_ty,
..
} => {
sym_name.hash(&mut self.hasher);
return_ty.hash(&mut self.hasher);
}
Deserialize { ref value_ty, .. } => {
value_ty.hash(&mut self.hasher);
}
For { ref iters, .. } => {
for iter in iters.iter() {
iter.kind.hash(&mut self.hasher);
}
}
Negate(_)
| Not(_)
| Assert(_)
| Broadcast(_)
| Serialize(_)
| ToVec { .. }
| MakeStruct { .. }
| MakeVector { .. }
| Zip { .. }
| Length { .. }
| Lookup { .. }
| OptLookup { .. }
| KeyExists { .. }
| Slice { .. }
| Sort { .. }
| If { .. }
| Iterate { .. }
| Select { .. }
| Apply { .. }
| NewBuilder(_)
| Merge { .. }
| Res { .. } => {}
}
if !finished_subexpressions {
for child in expr.children() {
self.hash_expr(child, symbol_positions, max_id)?;
}
}
Ok(())
}
pub fn value(&self) -> u64 {
self.hasher.finish()
}
pub fn from(expr: &Expr) -> WeldResult<ExprHash> {
let mut sig = ExprHash {
hasher: fnv::FnvHasher::default(),
};
let mut symbol_positions = fnv::FnvHashMap::default();
let mut max_id = 0;
sig.hash_expr(expr, &mut symbol_positions, &mut max_id)?;
Ok(sig)
}
}
impl PartialEq for ExprHash {
    /// Two `ExprHash` values are equal exactly when their digests match.
    fn eq(&self, other: &ExprHash) -> bool {
        let (lhs, rhs) = (self.value(), other.value());
        lhs == rhs
    }
}
#[cfg(test)]
use crate::syntax::parser::*;
#[test]
fn test_compare_same() {
    // Parsing the same source twice must produce equal hashes.
    let source = "|| let a = 1; let b = 1; a";
    let first = parse_expr(source).unwrap();
    let second = parse_expr(source).unwrap();
    let lhs = ExprHash::from(&first).unwrap();
    let rhs = ExprHash::from(&second).unwrap();
    assert_eq!(lhs, rhs);
}
#[test]
fn test_compare_different_symbols() {
    // Consistently renaming every bound symbol must not change the hash.
    let original = parse_expr("|| let a = 1; let b = 1; a").unwrap();
    let renamed = parse_expr("|| let c = 1; let d = 1; c").unwrap();
    let lhs = ExprHash::from(&original).unwrap();
    let rhs = ExprHash::from(&renamed).unwrap();
    assert_eq!(lhs, rhs);
}
#[test]
fn test_compare_different_symbols_ne() {
    // Returning a different binding is a real structural change.
    let returns_first = parse_expr("|| let a = 1; let b = 1; a").unwrap();
    let returns_second = parse_expr("|| let c = 1; let d = 1; d").unwrap();
    let lhs = ExprHash::from(&returns_first).unwrap();
    let rhs = ExprHash::from(&returns_second).unwrap();
    assert!(lhs != rhs);
}
#[test]
fn test_lambda() {
    // A `let` shadowing a lambda parameter, plus an unused renamed binding.
    let with_b = parse_expr("|a: i32| let a = 1; let b = 1; a").unwrap();
    let with_c = parse_expr("|a: i32| let a = 1; let c = 1; a").unwrap();
    let lhs = ExprHash::from(&with_b).unwrap();
    let rhs = ExprHash::from(&with_c).unwrap();
    assert_eq!(lhs, rhs);
}