#[allow(unused_imports)]
use log::{debug, trace};
use std::fmt::Debug;
use num;
use rand::distr::{Distribution, Uniform};
use rand::prelude::*;
use rand_xoshiro::Xoshiro256PlusPlus;
use sha2::{Digest, Sha512_256};
use indexmap::IndexMap;
use std::collections::HashMap;
use super::sig::Sig;
use crate::exp01::*;
use crate::maxvaluetrack::*;
/// Weighted min-hash sketcher whose per-element random stream is seeded from a
/// SHA-512/256 digest of the element's signature (`Sig::get_sig`), so the sketch of a
/// given weighted set does not depend on the hasher of the container holding it.
///
/// The sketch is a vector of `m` elements of type `D`, readable with `get_signature`.
pub struct ProbMinHash3aSha<D>
where
D: Clone + Eq + Debug + Sig,
{
/// Number of slots in the signature (the `nbhash` given to `new`).
m: usize,
/// Current hash value retained for each slot, queried per slot and for the maximum over slots.
maxvaluetracker: MaxValueTracker<f64>,
/// Random generator built with parameter `ln(m / (m - 1))`; its draws are scaled by
/// the inverse weight of an element to produce candidate hash values.
// NOTE(review): presumably an exponential law restricted to [0, 1) -- confirm in `exp01`.
exp01: ExpRestricted01,
/// Elements deferred to the later rounds of a hashing call: (element, inverse weight,
/// element's generator state). Emptied before a hashing call returns.
to_be_processed: Vec<(D, f64, Xoshiro256PlusPlus)>,
/// The sketch: for each slot, the element currently holding the smallest hash value.
/// Every slot starts as a clone of the `initobj` given to `new`.
signature: Vec<D>,
}
impl<D> ProbMinHash3aSha<D>
where
    D: Clone + Eq + Debug + Sig,
{
    /// Creates a sketcher producing a signature of `nbhash` slots.
    ///
    /// Every slot is initialised with a clone of `initobj`, which therefore acts as the
    /// "nothing sampled yet" marker in the returned signature.
    ///
    /// # Panics
    ///
    /// Panics if `nbhash < 2`: the generator parameter is `ln(nbhash / (nbhash - 1))`,
    /// which is not finite for a single slot.
    pub fn new(nbhash: usize, initobj: D) -> Self {
        assert!(nbhash >= 2);
        let lambda = ((nbhash as f64) / ((nbhash - 1) as f64)).ln();
        ProbMinHash3aSha {
            m: nbhash,
            maxvaluetracker: MaxValueTracker::new(nbhash),
            exp01: ExpRestricted01::new(lambda),
            to_be_processed: Vec::new(),
            signature: vec![initobj; nbhash],
        }
    }

    /// Sketches the weighted set stored in an `IndexMap` (element -> weight).
    ///
    /// The slot values are kept between calls, so successive calls on the same sketcher
    /// accumulate into one signature.
    ///
    /// # Panics
    ///
    /// Panics if a weight cannot be converted to `f64`, is not finite, or is negative.
    pub fn hash_weigthed_idxmap<Hidx, F>(&mut self, data: &IndexMap<D, F, Hidx>)
    where
        Hidx: std::hash::BuildHasher,
        F: num::ToPrimitive + std::fmt::Display,
    {
        self.hash_weighted_pairs(data.iter());
    }

    /// Sketches the weighted set stored in a `HashMap` (element -> weight).
    ///
    /// Same contract as [`Self::hash_weigthed_idxmap`]; only the container differs.
    ///
    /// # Panics
    ///
    /// Panics if a weight cannot be converted to `f64`, is not finite, or is negative.
    pub fn hash_weigthed_hashmap<Hidx, F>(&mut self, data: &HashMap<D, F, Hidx>)
    where
        Hidx: std::hash::BuildHasher,
        F: num::ToPrimitive + std::fmt::Display,
    {
        self.hash_weighted_pairs(data.iter());
    }

    /// Returns the current signature, one element per slot.
    pub fn get_signature(&self) -> &Vec<D> {
        &self.signature
    }

    /// Shared implementation of the public hashing entry points.
    ///
    /// Round 1 visits every `(element, weight)` pair once; elements that may still win a
    /// slot are parked in `to_be_processed` and revisited in rounds 2, 3, ... until none
    /// can improve any slot.
    fn hash_weighted_pairs<'a, I, F>(&mut self, items: I)
    where
        I: Iterator<Item = (&'a D, &'a F)>,
        D: 'a,
        F: num::ToPrimitive + std::fmt::Display + 'a,
    {
        let unif0m = Uniform::<usize>::new(0, self.m).unwrap();
        let mut qmax: f64 = self.maxvaluetracker.get_max_value();

        for (key, weight_t) in items {
            trace!("hash_item : id {:?} weight {} ", key, weight_t);
            let weight = weight_t
                .to_f64()
                .expect("weight is not convertible to f64");
            assert!(
                weight.is_finite() && weight >= 0.,
                "weight must be finite and non-negative, got {}",
                weight
            );
            // A zero weight can never be sampled. Skip it explicitly: `-0.0` passes the
            // check above but has inverse `-inf`, which would compare below every slot
            // value and let a weightless element capture slots.
            if weight == 0. {
                continue;
            }
            let winv = 1. / weight;

            // A fresh digest per element makes the element's random stream depend only
            // on its signature bytes.
            let mut hasher = Sha512_256::new();
            hasher.update(key.get_sig());
            let new_hash = hasher.finalize();
            let hashed_slice = new_hash.as_slice();
            assert_eq!(hashed_slice.len(), 32);
            let mut seed: [u8; 32] = [0; 32];
            seed.copy_from_slice(&hashed_slice[..32]);
            let mut rng = Xoshiro256PlusPlus::from_seed(seed);

            let h = winv * self.exp01.sample(&mut rng);
            qmax = self.maxvaluetracker.get_max_value();
            if h < qmax {
                let k = unif0m.sample(&mut rng);
                assert!(k < self.m);
                if h < self.maxvaluetracker.get_value(k) {
                    self.signature[k] = key.clone();
                    self.maxvaluetracker.update(k, h);
                    qmax = self.maxvaluetracker.get_max_value();
                }
                // Round `i` produces values no smaller than `winv * (i - 1)`, so the
                // element is only worth revisiting if `winv` is below the current maximum.
                if winv < qmax {
                    self.to_be_processed.push((key.clone(), winv, rng));
                }
            }
        }

        // Later rounds: round `i` offsets the draw by `winv * (i - 1)`.
        let mut i = 2;
        while !self.to_be_processed.is_empty() {
            let mut insert_pos = 0;
            trace!(
                " i : {:?} , nb to process : {}",
                i,
                self.to_be_processed.len()
            );
            for j in 0..self.to_be_processed.len() {
                let (key, winv, rng) = &mut self.to_be_processed[j];
                let winv = *winv;
                let mut h = winv * (i - 1) as f64;
                if h < self.maxvaluetracker.get_max_value() {
                    h += winv * self.exp01.sample(rng);
                    let k = unif0m.sample(rng);
                    if h < self.maxvaluetracker.get_value(k) {
                        self.signature[k] = key.clone();
                        self.maxvaluetracker.update(k, h);
                        qmax = self.maxvaluetracker.get_max_value();
                    }
                    if winv * (i as f64) < qmax {
                        // Compact survivors to the front in place. `insert_pos <= j`, so
                        // the entry swapped out to `j` was already visited and is dropped
                        // by the truncate below; no key or generator is cloned.
                        self.to_be_processed.swap(insert_pos, j);
                        insert_pos += 1;
                    }
                }
            }
            self.to_be_processed.truncate(insert_pos);
            i += 1;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jaccard::*;
    use fnv::FnvBuildHasher;
    use indexmap::IndexMap;
    use log::*;

    type FnvIndexMap<K, V> = IndexMap<K, V, FnvBuildHasher>;

    /// Installs `env_logger` in test mode; the result of `try_init` is ignored so
    /// repeated calls are harmless.
    fn log_init_test() {
        let _ = env_logger::builder().is_test(true).try_init();
    }

    /// Builds `nb_slices` byte vectors of `length` bytes each, drawn from a generator
    /// with a fixed seed so every run sees the same objects.
    fn generate_slices(nb_slices: usize, length: usize) -> Vec<Vec<u8>> {
        let mut rng = Xoshiro256PlusPlus::seed_from_u64(237);
        let unif = Uniform::<u8>::new_inclusive(0, 255).unwrap();
        (0..nb_slices)
            .map(|_| (0..length).map(|_| unif.sample(&mut rng)).collect())
            .collect()
    }

    /// Sketches two overlapping weighted sets (objects 0..70 and 50..100 with very
    /// different weight profiles), logs the exact probability Jaccard index next to the
    /// sketch estimate, and checks the estimate is strictly positive.
    #[test]
    fn test_probminhash3asha_count_intersection_unequal_weights() {
        log_init_test();
        println!("test_probminhash3a_count_intersection_unequal_weights");
        debug!("test_probminhash3a_count_intersection_unequal_weights");

        let set_size = 100;
        let nbhash = 2000;
        let objects = generate_slices(set_size, 256);

        // First set: objects 0..70, weight 2 * rank.
        let mut wa: FnvIndexMap<Vec<u8>, f64> =
            FnvIndexMap::with_capacity_and_hasher(70, FnvBuildHasher::default());
        for (rank, object) in objects.iter().enumerate().take(70) {
            *wa.entry(object.clone()).or_insert(0.) += 2. * rank as f64;
        }

        // Second set: objects 50.., weight rank^4.
        let mut wb: FnvIndexMap<Vec<u8>, f64> =
            FnvIndexMap::with_capacity_and_hasher(70, FnvBuildHasher::default());
        for (rank, object) in objects.iter().enumerate().skip(50) {
            wb.entry(object.clone()).or_insert((rank as f64).powi(4));
        }

        trace!("\n\n hashing wa");
        let mut waprobhash = ProbMinHash3aSha::<Vec<u8>>::new(nbhash, vec![0u8; 256]);
        waprobhash.hash_weigthed_idxmap(&wa);

        trace!("\n\n hashing wb");
        let mut wbprobhash = ProbMinHash3aSha::<Vec<u8>>::new(nbhash, vec![0u8; 256]);
        wbprobhash.hash_weigthed_idxmap(&wb);

        let jp_approx =
            compute_probminhash_jaccard(waprobhash.get_signature(), wbprobhash.get_signature());

        // Weights of an object in both sets; an absent object weighs 0.
        let weights_of = |object: &Vec<u8>| -> (f64, f64) {
            (
                *wa.get(object).unwrap_or(&0.),
                *wb.get(object).unwrap_or(&0.),
            )
        };

        // Exact value: sum over shared objects i of 1 / sum_j max(wa_j / wa_i, wb_j / wb_i).
        let mut jp = 0.;
        for pivot in &objects {
            let (wa_i, wb_i) = weights_of(pivot);
            if wa_i > 0. && wb_i > 0. {
                let den = objects.iter().fold(0.0_f64, |acc, other| {
                    let (wa_j, wb_j) = weights_of(other);
                    acc + (wa_j / wa_i).max(wb_j / wb_i)
                });
                jp += 1. / den;
            }
        }

        debug!("Jp = {} ", jp);
        info!(
            "jp exact= {jptheo:.3} , jp estimate = {jp_est:.3} ",
            jptheo = jp,
            jp_est = jp_approx
        );
        assert!(jp_approx > 0.);
    }
}