use log::trace;
use std::fmt::Debug;
use rand::distr::Distribution;
use rand::prelude::*;
use rand_distr::Exp1;
use rand_xoshiro::Xoshiro256PlusPlus;
use std::collections::HashMap;
use std::hash::{BuildHasher, BuildHasherDefault, Hash, Hasher};
use crate::fyshuffle::*;
use crate::maxvaluetrack::*;
use crate::weightedset::*;
pub struct ProbMinHash2<D, H>
where
D: Copy + Eq + Hash + Debug,
H: Hasher + Default,
{
m: usize,
initobj: D,
b_hasher: BuildHasherDefault<H>,
maxvaluetracker: MaxValueTracker<f64>,
permut_generator: FYshuffle,
betas: Vec<f64>,
signature: Vec<D>,
}
impl<D, H> ProbMinHash2<D, H>
where
    D: Copy + Eq + Hash + Debug,
    H: Hasher + Default,
{
    /// Creates an empty sketch whose signature has `nbhash` slots, each set to `initobj`.
    ///
    /// A slot keeps `initobj` until some item is hashed into it.
    ///
    /// NOTE(review): `nbhash == 0` is never exercised in this file. Presumably at least one
    /// slot is required; confirm against `MaxValueTracker::new` / `FYshuffle::new`.
    pub fn new(nbhash: usize, initobj: D) -> Self {
        let h_signature = vec![initobj; nbhash];
        // betas[i] = m / (m - i - 1) scales the exponential gap added after the (i + 1)-th
        // slot of the permutation has been offered a value: m/(m-1), m/(m-2), ...
        // The last entry is m / 0.0 = +inf. `hash_item` asserts `i < m` right after
        // consuming a beta, so reading that entry would trip the assertion.
        let betas: Vec<f64> = (0..nbhash)
            .map(|x| (nbhash as f64) / (nbhash - x - 1) as f64)
            .collect();
        ProbMinHash2 {
            m: nbhash,
            initobj,
            b_hasher: BuildHasherDefault::<H>::default(),
            maxvaluetracker: MaxValueTracker::new(nbhash),
            permut_generator: FYshuffle::new(nbhash),
            betas,
            signature: h_signature,
        }
    }

    /// Hashes one item `id` with weight `weight` into the signature.
    ///
    /// The item's hash (built with `H`) seeds a `Xoshiro256PlusPlus` generator, and the slot
    /// permutation is reset first. The values and slots produced for a given `(id, weight)`
    /// are therefore the same on every call and in every sketch using the same `H`.
    ///
    /// Starting from `Exp(1) / weight`, increasing values `h` are generated. Each one is offered
    /// to the next slot `k` of a random permutation of slot indices. It replaces the slot's
    /// entry when it is below that slot's stored value. Generation stops once `h` reaches the
    /// largest stored value over all slots, because no later (larger) `h` can win any slot.
    /// The structure appears to follow Ertl's ProbMinHash2 algorithm.
    ///
    /// # Panics
    /// - Panics if `weight` is not strictly positive. This includes NaN.
    /// - Panics (internal consistency check) if an `m`-th gap is ever added for one item.
    pub fn hash_item(&mut self, id: D, weight: f64) {
        assert!(weight > 0.);
        trace!("hash_item : id {:?} weight {} ", id, weight);
        let winv: f64 = 1. / weight;
        let id_hash: u64 = self.b_hasher.hash_one(&id);
        let mut rng = Xoshiro256PlusPlus::seed_from_u64(id_hash);
        self.permut_generator.reset();
        let mut i = 0;
        let x: f64 = Exp1.sample(&mut rng);
        let mut h: f64 = winv * x;
        let mut qmax = self.maxvaluetracker.get_max_value();
        while h < qmax {
            let k = self.permut_generator.next(&mut rng);
            if h < self.maxvaluetracker.get_value(k) {
                self.signature[k] = id;
                self.maxvaluetracker.update(k, h);
                qmax = self.maxvaluetracker.get_max_value();
                // The loop condition would catch this as well. Breaking here avoids
                // drawing one more exponential that could not win anything.
                if h >= qmax {
                    break;
                }
            }
            let x: f64 = Exp1.sample(&mut rng);
            h += winv * self.betas[i] * x;
            i += 1;
            assert!(i < self.m);
        }
    }

    /// Hashes every remaining item of `data`, using the weight `data` reports for it.
    ///
    /// `data` is drained: it is advanced with `Iterator::next` until exhausted.
    ///
    /// # Panics
    /// Panics if a reported weight is not strictly positive (see [`Self::hash_item`]).
    pub fn hash_wset<T>(&mut self, data: &mut T)
    where
        T: WeightedSet<Object = D> + Iterator<Item = D>,
    {
        // A `for` loop would hold the mutable borrow of `data` for the whole body, but
        // `get_weight` needs `data` again. So items are pulled one at a time.
        while let Some(obj) = data.next() {
            let weight = data.get_weight(&obj);
            self.hash_item(obj, weight);
        }
    }

    /// Hashes every `(key, weight)` entry of `data`.
    ///
    /// Any `HashMap` is accepted, whatever `BuildHasher` it was built with.
    ///
    /// The type parameter `Hidx` is not used by the method. It cannot be inferred, so callers
    /// name it explicitly (`hash_weigthed_hashmap::<X>(..)`). It is kept so those calls still
    /// compile.
    ///
    /// # Panics
    /// Panics if any weight is not strictly positive (see [`Self::hash_item`]).
    pub fn hash_weigthed_hashmap<Hidx>(&mut self, data: &HashMap<D, f64, impl BuildHasher>) {
        for (key, weight) in data {
            trace!(" retrieved key {:?} ", key);
            self.hash_item(*key, *weight);
        }
    }

    /// Returns the current signature.
    ///
    /// For each of the `m` slots, this is the item that currently holds it, or `initobj` if
    /// no item has been hashed into that slot yet.
    pub fn get_signature(&self) -> &Vec<D> {
        &self.signature
    }

    /// Makes the sketch reusable for a new set.
    ///
    /// Every slot is set back to `initobj`, and both the max-value tracker and the
    /// permutation generator are reset.
    pub fn reset(&mut self) {
        self.signature.fill(self.initobj);
        self.maxvaluetracker.reset();
        self.permut_generator.reset();
    }
}
#[cfg(test)]
mod tests {
    use log::*;
    use fnv::FnvHasher;
    use crate::jaccard::*;

    #[allow(dead_code)]
    fn log_init_test() {
        let _ = env_logger::builder().is_test(true).try_init();
    }

    use super::*;

    /// Exact probability Jaccard similarity J_P of two non-negative weight vectors of equal
    /// length.
    ///
    /// J_P is the sum, over indices `i` where both weights are positive, of
    /// `1 / sum_j max(wa[j] / wa[i], wb[j] / wb[i])`.
    fn exact_probability_jaccard(wa: &[f64], wb: &[f64]) -> f64 {
        assert_eq!(wa.len(), wb.len());
        wa.iter()
            .zip(wb)
            .filter(|&(&a_i, &b_i)| a_i > 0. && b_i > 0.)
            .map(|(&a_i, &b_i)| {
                let den: f64 = wa
                    .iter()
                    .zip(wb)
                    .map(|(&a_j, &b_j)| (a_j / a_i).max(b_j / b_i))
                    .sum();
                1. / den
            })
            .sum()
    }

    /// Builds a `ProbMinHash2` sketch with `nbhash` slots from a weight vector.
    ///
    /// Item `i` is hashed with weight `weights[i]`. Items whose weight is not positive are
    /// skipped, because `hash_item` rejects them.
    fn sketch_weights(weights: &[f64], nbhash: usize) -> ProbMinHash2<usize, FnvHasher> {
        let mut sketch = ProbMinHash2::<usize, FnvHasher>::new(nbhash, 0);
        for (i, &w) in weights.iter().enumerate() {
            if w > 0. {
                sketch.hash_item(i, w);
            }
        }
        sketch
    }

    #[test]
    fn test_probminhash2_count_intersection_unequal_weights() {
        log_init_test();
        println!("test_probminhash2_count_intersection_unequal_weights");
        debug!("test_probminhash2_count_intersection_unequal_weights");
        let set_size: usize = 100;
        let nbhash = 50;
        // wa: weight 2i on items 0..70. Item 0 gets weight 0 and is skipped.
        // wb: weight i^4 on items 50..100. The two vectors overlap on items 50..70.
        let wa: Vec<f64> = (0..set_size)
            .map(|i| if i < 70 { 2. * i as f64 } else { 0. })
            .collect();
        let wb: Vec<f64> = (0..set_size)
            .map(|i| if i < 50 { 0. } else { (i as f64).powi(4) })
            .collect();
        let jp_exact = exact_probability_jaccard(&wa, &wb);
        trace!("Jp = {} ", jp_exact);
        trace!("\n\n hashing wa");
        let waprobhash = sketch_weights(&wa, nbhash);
        trace!("\n\n hashing wb");
        let wbprobhash = sketch_weights(&wb, nbhash);
        let jp_estimate =
            compute_probminhash_jaccard(waprobhash.get_signature(), wbprobhash.get_signature());
        info!(
            "jp exact = {jp_exact:.3} , jp estimate {jp_estimate:.3} ",
            jp_exact = jp_exact,
            jp_estimate = jp_estimate
        );
        assert!(jp_estimate > 0.);
    }

    #[test]
    fn test_probminhash2_count_intersection_equal_weights() {
        log_init_test();
        println!("test_probminhash2_count_intersection_equal_weights");
        debug!("test_probminhash2_count_intersection_equal_weights");
        let set_size: usize = 100;
        let nbhash = 50;
        // wa: weight 1 on items 0..70. wb: weight 1 on items 50..100.
        let wa: Vec<f64> = (0..set_size)
            .map(|i| if i < 70 { 1. } else { 0. })
            .collect();
        let wb: Vec<f64> = (0..set_size)
            .map(|i| if i < 50 { 0. } else { 1. })
            .collect();
        let jp_exact = exact_probability_jaccard(&wa, &wb);
        // Items 50..70 are shared, and each of the 100 items has weight 1 in at least one
        // vector. So every shared item contributes 1/100, giving J_P = 20/100.
        assert!(
            (jp_exact - 0.2).abs() < 1e-12,
            "unexpected exact Jp {}",
            jp_exact
        );
        trace!("Jp = {} ", jp_exact);
        trace!("\n\n hashing wa");
        let waprobhash = sketch_weights(&wa, nbhash);
        trace!("\n\n hashing wb");
        let wbprobhash = sketch_weights(&wb, nbhash);
        let jp_estimate =
            compute_probminhash_jaccard(waprobhash.get_signature(), wbprobhash.get_signature());
        info!(
            "jp exact = {jp_exact:.3} , jp estimate {jp_estimate:.3} ",
            jp_exact = jp_exact,
            jp_estimate = jp_estimate
        );
        assert!(jp_estimate > 0.);
    }
}