use ff::Field;
/// Incremental hash of a multiset over a field.
///
/// The wrapped value is the product of `elem^count` over every element added; removal
/// multiplies by the inverse. Hashes can therefore be combined with multiplication
/// (union) and division (difference) without revisiting the underlying elements.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MultisetHash<F>(pub F);
impl<F: Field> MultisetHash<F> {
    /// Returns the hash of the empty multiset, i.e. the field's multiplicative identity.
    pub fn new() -> Self {
        MultisetHash(F::one())
    }

    /// Returns the hash obtained by adding `count` copies of `elem` to this multiset.
    ///
    /// The hash is multiplied by `elem^count`. `self` is not modified, so the result
    /// must be kept (e.g. `h = h.add(x, 1)`).
    ///
    /// # Panics
    ///
    /// Panics if `elem` is zero and `count > 0`. A zero factor would force the hash to
    /// zero for good, since it can never be divided back out. This mirrors the nonzero
    /// requirement enforced by [`remove`](Self::remove).
    #[must_use]
    pub fn add(&self, elem: F, count: u64) -> Self {
        let term = elem.pow_vartime([count]);
        // Reject zero up front: it would absorb every later factor and make the hash unusable.
        assert!(!bool::from(term.is_zero()), "elements must be nonzero");
        MultisetHash(self.0 * term)
    }

    /// Returns the hash obtained by removing `count` copies of `elem` from this multiset.
    ///
    /// The hash is multiplied by the inverse of `elem^count`, and `self` is not modified.
    /// Nothing checks that the copies were actually present.
    ///
    /// # Panics
    ///
    /// Panics if `elem` is zero and `count > 0`, because zero has no inverse.
    #[must_use]
    pub fn remove(&self, elem: F, count: u64) -> Self {
        let inv_term: Option<F> = elem.pow_vartime([count]).invert().into();
        MultisetHash(self.0 * inv_term.expect("elements must be nonzero"))
    }

    /// Returns the hash of the multiset sum of `self` and `other`: multiplicities of
    /// shared elements are added.
    #[must_use]
    pub fn multiset_union(&self, other: &Self) -> Self {
        MultisetHash(self.0 * other.0)
    }

    /// Returns the hash of `self` with the elements of `other` removed (multiplicities
    /// subtracted), computed as `self / other`.
    ///
    /// # Panics
    ///
    /// Panics if the hash of `other` is zero, which has no inverse.
    #[must_use]
    pub fn multiset_difference(&self, other: &Self) -> Self {
        let inv: Option<F> = other.0.invert().into();
        MultisetHash(self.0 * inv.expect("multiset hash of `other` must be nonzero"))
    }
}

/// The default hash is that of the empty multiset, identical to [`MultisetHash::new`].
impl<F: Field> Default for MultisetHash<F> {
    fn default() -> Self {
        Self::new()
    }
}
/// Wraps a raw field element as a multiset hash, as-is and without any validation.
impl<F: Field> From<F> for MultisetHash<F> {
    fn from(value: F) -> Self {
        Self(value)
    }
}
#[cfg(test)]
mod tests {
    use bls12_381::Scalar;

    use super::*;

    /// Builds a hash by adding every `(element, multiplicity)` pair, in order,
    /// to an empty multiset.
    fn hash_of(items: &[(Scalar, u64)]) -> MultisetHash<Scalar> {
        let mut acc = MultisetHash::new();
        for &(elem, count) in items {
            acc = acc.add(elem, count);
        }
        acc
    }

    #[test]
    fn test_single_ops() {
        let one = Scalar::one();
        let mut mh = MultisetHash::<Scalar>::new();
        assert_eq!(mh.0, one);

        // One copy in, one copy out.
        mh = mh.add(2.into(), 1);
        mh = mh.remove(2.into(), 1);
        assert_eq!(mh.0, one);

        // Four copies in at once, taken out one at a time.
        mh = mh.add(5.into(), 4);
        for _ in 0..4 {
            mh = mh.remove(5.into(), 1);
        }
        assert_eq!(mh.0, one);

        // Twenty-seven copies in one at a time, taken out at once.
        for _ in 0..27 {
            mh = mh.add(3.into(), 1);
        }
        mh = mh.remove(3.into(), 27);
        assert_eq!(mh.0, one);
    }

    #[test]
    fn test_union() {
        let a: Vec<(Scalar, u64)> = vec![(2.into(), 1), (10.into(), 4), (4.into(), 1), (7.into(), 3), (3.into(), 7)];
        let b: Vec<(Scalar, u64)> = vec![(2.into(), 4), (6.into(), 1), (4.into(), 1), (7.into(), 7), (3.into(), 7)];

        let union = hash_of(&a).multiset_union(&hash_of(&b));

        // Feeding both lists into a single hash must agree with combining the two hashes.
        let both: Vec<(Scalar, u64)> = a.iter().chain(b.iter()).copied().collect();
        assert_eq!(union.0, hash_of(&both).0);
    }

    #[test]
    fn test_difference() {
        let a: Vec<(Scalar, u64)> = vec![(50.into(), 1), (10.into(), 4), (4.into(), 1), (7.into(), 3), (3.into(), 7)];
        let b: Vec<(Scalar, u64)> = vec![(2.into(), 4), (6.into(), 1), (4.into(), 1), (7.into(), 7), (3.into(), 7)];

        let difference = hash_of(&a).multiset_difference(&hash_of(&b));

        // (4, 1) and (3, 7) cancel; 7 nets to 3 - 7 = -4; b's 2s and 6 become removals.
        let mut expected = MultisetHash::<Scalar>::new();
        expected = expected.add(50.into(), 1);
        expected = expected.add(10.into(), 4);
        expected = expected.remove(7.into(), 4);
        expected = expected.remove(2.into(), 4);
        expected = expected.remove(6.into(), 1);
        assert_eq!(difference.0, expected.0);
    }
}