use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::rc::Rc;
use num::{Integer, ToPrimitive};
use rug::{Float, float::Round, ops::AddAssignRound, ops::DivAssignRound};
use crate::core::{Measurement, Function, PrivacyRelation};
use crate::dist::{L1Distance, MaxDivergence};
use crate::dom::{AllDomain, MapDomain};
use crate::error::Fallible;
use crate::interactive::Queryable;
use crate::traits::{DistanceConstant, CheckNull, InfCast, CastInternalReal};
use crate::samplers::{fill_bytes, SampleBernoulli};
use std::collections::hash_map::DefaultHasher;
/// Value used for `alpha` by `make_base_alp` when the caller passes `None`.
const ALPHA_DEFAULT: u32 = 4;
/// Value used for `size_factor` by `make_base_alp` when the caller passes `None`.
const SIZE_FACTOR_DEFAULT: u32 = 50;
/// Input domain of the mechanism: maps from keys `K` to counts `C`.
type SparseDomain<K, C> = MapDomain<AllDomain<K>, AllDomain<C>>;
/// The projection produced by the mechanism, stored as one `bool` per slot.
type BitVector = Vec<bool>;
/// An ordered list of shared hash functions, each mapping a key to a `usize`.
type HashFunctions<K> = Vec<Rc<dyn Fn(&K) -> usize>>;
/// Output of the base ALP measurement: the noisy projection together with
/// everything needed to answer point queries against it.
#[derive(Clone)]
pub struct AlpState<K, T> {
    /// The `alpha` parameter the measurement was constructed with.
    alpha: T,
    /// The `scale` parameter the measurement was constructed with.
    scale: T,
    /// Hash functions used to locate the slots that belong to a key.
    h: HashFunctions<K>,
    /// The projection bits returned by `compute_projection`.
    z: BitVector,
}
impl<K, T> CheckNull for AlpState<K, T> {
    /// Always reports `false`; this implementation never treats a state as null.
    fn is_null(&self) -> bool {
        false
    }
}
/// Output domain of the base ALP measurement: any `AlpState<K, T>`.
type AlpDomain<K, T> = AllDomain<AlpState<K, T>>;
/// Multiply-add-shift hash: returns the top `l` bits of `a * x + b` (mod 2^64).
///
/// * `x` - the 64-bit value to hash.
/// * `a`, `b` - the multiplier and offset that select the hash function.
/// * `l` - number of output bits; the result is below `2^l`.
///
/// With `l == 0` the output range has a single element, so the result is `0`.
/// Values of `l` above 64 are treated like 64 (the full 64-bit word is returned).
fn hash(x: u64, a: u64, b: u64, l: u32) -> usize {
    let mixed = a.wrapping_mul(x).wrapping_add(b);
    // A plain `>>` by 64 bits is an overflow (panic in debug builds, wrong
    // value in release builds), so shift with an explicit range check.
    let shift = 64u32.saturating_sub(l);
    mixed.checked_shr(shift).unwrap_or(0) as usize
}
/// Reduces any hashable value to a `u64` by feeding it through a freshly
/// created `DefaultHasher`.
fn pre_hash<K: Hash>(x: K) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(&x, &mut state);
    Hasher::finish(&state)
}
/// Draws a random hash function with `l` output bits.
///
/// Two 64-bit words are read via `fill_bytes`: the first becomes the
/// multiplier (its lowest bit is forced to 1, so it is always odd) and the
/// second the offset. The returned closure pre-hashes the key with
/// `pre_hash` and then applies `hash` with those two words.
///
/// # Errors
/// Propagates any error returned by `fill_bytes`.
fn sample_hash_function<K>(l: u32) -> Fallible<Rc<dyn Fn(&K) -> usize>>
    where K: Clone + Hash {
    let draw_u64 = || -> Fallible<u64> {
        let mut bytes = [0u8; 8];
        fill_bytes(&mut bytes)?;
        Ok(u64::from_ne_bytes(bytes))
    };
    let multiplier = draw_u64()? | 1u64;
    let offset = draw_u64()?;
    Ok(Rc::new(move |key: &K| hash(pre_hash(key), multiplier, offset, l)))
}
/// Returns the smallest `e` such that `2^e >= x`.
///
/// `x == 0` and `x == 1` both yield `0`. When `x` exceeds `2^63` the next
/// power of two does not fit in a `u64`; in that case `64` is returned.
fn exponent_next_power_of_two(x: u64) -> u32 {
    match x.checked_next_power_of_two() {
        // A power of two has exactly one set bit; its position is the exponent.
        Some(power) => power.trailing_zeros(),
        None => 64,
    }
}
/// Multiplies a count by `scale / alpha` and rounds the product to an integer at random.
///
/// Negative `x` is clamped to zero, and a value that `to_u64` cannot represent is
/// treated as zero (`unwrap_or_default`). The product is rounded down to its floor
/// and then incremented by one when a Bernoulli draw, parameterised by the
/// product's fractional part, comes up `true`.
///
/// # Errors
/// Propagates any error returned by `bool::sample_bernoulli`.
///
/// # Panics
/// Panics if `get_exp()` returns `None` for the quotient `scale / alpha`.
fn scale_and_round<C, T>(x : C, alpha: T, scale: T) -> Fallible<usize>
where C: Integer + ToPrimitive,
T: CastInternalReal {
// scalar = scale / alpha, with the division rounded down.
let mut scalar = scale.into_internal();
scalar.div_assign_round(alpha.into_internal(), Round::Down);
// Re-round the quotient (downwards) to max(1, 53 - exponent) bits of precision.
// NOTE(review): presumably this limits the number of fractional bits kept — confirm intent.
scalar.set_prec_round((f64::MANTISSA_DIGITS as i32 - scalar.get_exp().unwrap()).max(1) as u32, Round::Down);
// The product is held at twice the f64 mantissa width (2 * 53 bits).
let r = Float::with_val(f64::MANTISSA_DIGITS * 2, x.max(C::zero()).to_u64().unwrap_or_default()) * scalar;
let floored = f64::from_internal(r.clone().floor()) as usize;
// Randomized rounding driven by the fractional part of the product.
// NOTE(review): the meaning of the second argument (`false`) is not visible in this file.
match bool::sample_bernoulli(f64::from_internal(r.fract()), false)? {
true => Ok(floored + 1),
false => Ok(floored)
}
}
/// Computes `1 / (alpha + 2)` as an `f64`.
///
/// The denominator is rounded down and the division is rounded up, so any
/// rounding error in those two steps pushes the result upwards.
fn compute_prob<T: CastInternalReal>(alpha: T) -> f64 {
    let mut denominator = alpha.into_internal();
    denominator.add_assign_round(2, Round::Down);

    let mut probability = 1f64.into_internal();
    probability.div_assign_round(denominator, Round::Up);

    f64::from_internal(probability)
}
/// Returns `true` when `scale * 2^52 < alpha`, i.e. when `scale / alpha` is
/// below `2^-52`.
///
/// Note the polarity: `true` means the parameters are rejected by the caller.
fn check_parameters<T: CastInternalReal>(alpha: T, scale: T) -> bool {
    let two_pow_52 = Float::with_val(53, 52).exp2();
    let shifted_scale = scale.into_internal() * two_pow_52;
    shifted_scale < alpha.into_internal()
}
/// Builds the noisy bit projection of a sparse histogram.
///
/// For every `(key, count)` entry the count is converted by `scale_and_round`
/// into a number `n`; the first `n` hash functions in `h` (fewer if `h` is
/// shorter) each set the slot `hash(key) % s` to `true`. Afterwards every one
/// of the `s` slots is XOR-ed with an independent Bernoulli draw whose
/// parameter comes from `compute_prob(alpha)`.
///
/// # Errors
/// Propagates errors from `scale_and_round` and `bool::sample_bernoulli`.
///
/// # Panics
/// Panics on the modulo if `s` is zero and a hash function is applied.
fn compute_projection<K, C, T>(x: &HashMap<K, C>, h: &HashFunctions<K>, alpha: T, scale: T, s: usize) -> Fallible<BitVector>
    where C: Clone + Integer + ToPrimitive,
          T: Clone + CastInternalReal {
    let mut bits = vec![false; s];
    for (key, count) in x {
        let ones = scale_and_round(count.clone(), alpha.clone(), scale.clone())?;
        for hasher in h.iter().take(ones) {
            bits[hasher(key) % s] = true;
        }
    }

    let flip_prob = compute_prob(alpha);
    let mut noisy = Vec::with_capacity(bits.len());
    for bit in bits {
        let flip = bool::sample_bernoulli(flip_prob, false)?;
        noisy.push(bit ^ flip);
    }
    Ok(noisy)
}
/// Estimates the length of a noisy unary encoding.
///
/// Walks the bits keeping a running height that goes up by one for `true`
/// and down by one for `false` (starting from height 0 at position 0). The
/// estimate is the average of all positions at which the running height
/// attains its maximum. An empty slice yields `0`.
///
/// # Panics
/// Panics if a position sum or count cannot be converted into `T`.
fn estimate_unary<T>(v: &[bool]) -> T
    where T: num::Float {
    let mut height = 0i64;
    let mut best = 0i64;
    // Position 0 (before any bit is read) has height 0 and is the initial peak.
    let mut peak_position_sum = 0usize;
    let mut peak_count = 1usize;

    for (i, &bit) in v.iter().enumerate() {
        height += if bit { 1 } else { -1 };
        let position = i + 1;
        if height > best {
            // A strictly higher peak discards every earlier one.
            best = height;
            peak_position_sum = position;
            peak_count = 1;
        } else if height == best {
            peak_position_sum += position;
            peak_count += 1;
        }
    }

    T::from(peak_position_sum).unwrap() / T::from(peak_count).unwrap()
}
/// Answers a point query against a stored projection.
///
/// Reads the slot selected by each hash function for `key` (index taken
/// modulo the projection length), feeds the resulting bit sequence to
/// `estimate_unary`, and rescales the result by `alpha / scale`.
///
/// # Panics
/// Panics on the modulo if the projection is empty and a hash function is applied.
fn compute_estimate<K, T>(state: &AlpState<K, T>, key: &K) -> T
    where T: num::Float {
    let slots = state.z.len();
    let bits: Vec<bool> = state.h
        .iter()
        .map(|hasher| state.z[hasher(key) % slots])
        .collect();
    let unary = estimate_unary::<T>(&bits);
    unary * T::from(state.alpha).unwrap() / state.scale
}
/// Builds the base ALP measurement from explicitly supplied hash functions.
///
/// * `alpha` - must be positive and finite.
/// * `scale` - must be positive and finite; also used as the constant of the
///   privacy relation.
/// * `s` - number of slots in the projection; must be non-zero.
/// * `h` - hash functions used to place each key's bits; their outputs are
///   reduced modulo `s`.
///
/// The measurement's function maps a `HashMap<K, C>` to an `AlpState` holding
/// the noisy projection together with `alpha`, `scale` and a clone of `h`.
///
/// # Errors
/// Returns a `MakeMeasurement` error when `alpha` or `scale` is not finite or
/// not positive, when `s` is zero, or when `scale / alpha` is below `2^-52`.
pub fn make_base_alp_with_hashers<K, C, T>(alpha: T, scale: T, s: usize, h: HashFunctions<K>)
    -> Fallible<Measurement<SparseDomain<K, C>,
                            AlpDomain<K, T>,
                            L1Distance<C>, MaxDivergence<T>>>
    where K: 'static + Eq + Hash + CheckNull,
          C: 'static + Clone + Integer + CheckNull + DistanceConstant<C> + InfCast<T> + ToPrimitive,
          T: 'static + num::Float + DistanceConstant<T> + CastInternalReal + InfCast<C>,
          AlpState<K,T> : CheckNull {
    // NaN and infinities slip past the sign checks below but make the
    // measurement panic when it is evaluated, so reject them up front.
    if !alpha.is_finite() {
        return fallible!(MakeMeasurement, "alpha must be finite")
    }
    if !scale.is_finite() {
        return fallible!(MakeMeasurement, "scale must be finite")
    }
    if alpha.is_sign_negative() || alpha.is_zero() {
        return fallible!(MakeMeasurement, "alpha must be positive")
    }
    if scale.is_sign_negative() || scale.is_zero() {
        return fallible!(MakeMeasurement, "scale must be positive")
    }
    if s == 0 {
        return fallible!(MakeMeasurement, "s can not be zero")
    }
    // `check_parameters` returns true when the ratio is too small.
    if check_parameters(alpha, scale) {
        return fallible!(MakeMeasurement, "scale divided by alpha must be above 2^-52")
    }

    Ok(Measurement::new(
        MapDomain { key_domain: AllDomain::new(), value_domain: AllDomain::new()},
        AllDomain::new(),
        Function::new_fallible(move |x: &HashMap<K, C>| {
            let z = compute_projection(x, &h, alpha, scale, s)?;
            Ok(AlpState { alpha, scale, h: h.clone(), z })
        }),
        L1Distance::default(),
        MaxDivergence::default(),
        PrivacyRelation::new_from_constant(scale)
    ))
}
pub fn make_base_alp<K, C, T>(total: usize, size_factor: Option<u32>, alpha: Option<T>, scale: T, beta: C)
-> Fallible<Measurement<SparseDomain<K, C>,
AlpDomain<K, T>,
L1Distance<C>, MaxDivergence<T>>>
where K: 'static + Eq + Hash + Clone + CheckNull,
C: 'static + Clone + Integer + CheckNull + DistanceConstant<C> + InfCast<T> + ToPrimitive,
T: 'static + num::Float + DistanceConstant<T> + CastInternalReal + InfCast<C>,
AlpState<K,T> : CheckNull {
let factor = size_factor.unwrap_or(SIZE_FACTOR_DEFAULT) as f64;
let alpha = alpha.unwrap_or(T::from(ALPHA_DEFAULT).unwrap());
let beta: f64 = T::inf_cast(beta)?.to_f64()
.ok_or_else(|| err!(MakeTransformation, "failed to parse beta"))?;
let quotient = (scale / alpha).to_f64()
.ok_or_else(|| err!(MakeTransformation, "failed to parse scale/alpha"))?;
let m = (beta * quotient).ceil() as usize;
let exp = exponent_next_power_of_two((factor * total as f64 * quotient) as u64);
let h = (0..m).map(|_| sample_hash_function(exp)).collect::<Fallible<HashFunctions<K>>>()?;
make_base_alp_with_hashers(alpha, scale, 1 << exp, h)
}
/// Wraps an `AlpState` in a `Queryable` that answers per-key queries.
///
/// Each query calls `compute_estimate` on the stored state and hands the
/// state back unchanged alongside the estimate.
pub fn post_process<K, T>(state: AlpState<K, T>) -> Queryable<AlpState<K, T>, K, T>
    where T: num::Float {
    Queryable::new(state, move |current: AlpState<K, T>, key: &K| {
        let answer = compute_estimate(&current, key);
        Ok((current, answer))
    })
}
/// Derives a measurement that releases a `Queryable` instead of a raw `AlpState`.
///
/// The new measurement evaluates `m`'s function and passes a successful
/// result through `post_process`. The input domain, input metric, output
/// measure and privacy relation are cloned from `m`.
pub fn make_alp_histogram_post_process<K, C, T>(
    m: &Measurement<SparseDomain<K, C>, AlpDomain<K, T>, L1Distance<C>, MaxDivergence<T>>
) -> Fallible<Measurement<SparseDomain<K, C>, AllDomain<Queryable<AlpState<K, T>, K, T>>, L1Distance<C>, MaxDivergence<T>>>
    where K: 'static + Eq + Hash + CheckNull,
          C: 'static + Clone + CheckNull,
          T: 'static + num::Float,
          HashMap<K,C>: Clone,
          AlpState<K,T>: Clone {
    let inner = m.function.clone();
    let input_domain = m.input_domain.clone();
    let input_metric = m.input_metric.clone();
    let output_measure = m.output_measure.clone();
    let privacy_relation = m.privacy_relation.clone();

    Ok(Measurement::new(
        input_domain,
        AllDomain::new(),
        Function::new_fallible(move |x| inner.eval(x).map(post_process)),
        input_metric,
        output_measure,
        privacy_relation,
    ))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a stand-in hash function that ignores its key and always returns `i`.
    fn idx<T>(i: usize) -> Rc<dyn Fn(&T) -> usize> {
        Rc::new(move |_| i)
    }

    /// Builds `n` stand-in hash functions; the `i`-th one always returns `i`.
    fn index_identify_functions<T>(n: usize) -> HashFunctions<T> {
        (0..n).map(|i| idx(i)).collect::<HashFunctions<T>>()
    }

    #[test]
    fn test_exponent_next_power_of_two() -> Fallible<()> {
        assert_eq!(exponent_next_power_of_two(1u64), 0);
        assert_eq!(exponent_next_power_of_two(2u64), 1);
        assert_eq!(exponent_next_power_of_two(3u64), 2);
        assert_eq!(exponent_next_power_of_two(7u64), 3);
        Ok(())
    }

    #[test]
    fn test_hash() -> Fallible<()> {
        assert_eq!(hash(3, 4, 5, 64), 17);
        assert_eq!(hash(3, 4, 5, 63), 8);
        assert_eq!(hash(1, u64::MAX, 0, 2), 3);
        assert_eq!(hash(1, u64::MAX, 0, 3), 7);
        assert_eq!(hash(4, u64::MAX, 0, 16), (1 << 16) - 1);
        Ok(())
    }

    #[test]
    fn test_sample_hash() -> Fallible<()> {
        let h = sample_hash_function(5)?;
        // Every output of a 5-bit hash must be below 2^5.
        for i in 0u64..20u64 {
            assert!(h(&i) < (1 << 5));
        }
        Ok(())
    }

    #[test]
    fn test_alp_construction() -> Fallible<()> {
        let beta = 10;
        let alp = make_base_alp_with_hashers::<u32, u32, f64>(1., 1.0, beta, index_identify_functions(beta))?;

        assert!(alp.privacy_relation.eval(&1, &1.)?);
        assert!(!alp.privacy_relation.eval(&1, &0.999)?);

        let mut x = HashMap::new();
        x.insert(42, 10);
        alp.function.eval(&x.clone())?;

        // A count far larger than the number of hash functions must still evaluate.
        x.insert(42, 10000);
        alp.function.eval(&x.clone())?;

        Ok(())
    }

    #[test]
    fn test_alp_construction_out_of_range() -> Fallible<()> {
        // The stand-in hash functions return indices up to 19, beyond the 5 slots.
        let s = 5;
        let h = index_identify_functions(20);
        let alp = make_base_alp_with_hashers::<u32, u32, f64>(1., 1.0, s, h)?;

        let mut x = HashMap::new();
        x.insert(42, 3);
        alp.function.eval(&x.clone())?;

        Ok(())
    }

    #[test]
    fn test_estimate_unary() -> Fallible<()> {
        let z1 = vec![true, true, true, false, true, false, false, true];
        assert_eq!(estimate_unary::<f64>(&z1), 4.0);

        let z2 = vec![true, false, false, false, true, false, false, true];
        assert_eq!(estimate_unary::<f64>(&z2), 1.0);

        let z3 = vec![false, true, true, false, false, true, false, true];
        assert_eq!(estimate_unary::<f64>(&z3), 3.0);

        Ok(())
    }

    #[test]
    fn test_compute_estimate() -> Fallible<()> {
        let z1 = vec![true, true, true, false, true, false, false, true];
        assert_eq!(compute_estimate(&AlpState { alpha: 3., scale: 1.0, h: index_identify_functions(8), z: z1 }, &0), 12.0);

        let z2 = vec![true, false, false, false, true, false, false, true];
        assert_eq!(compute_estimate(&AlpState { alpha: 1., scale: 2.0, h: index_identify_functions(8), z: z2 }, &0), 0.5);

        let z3 = vec![false, true, true, false, false, true, false, true];
        assert_eq!(compute_estimate(&AlpState { alpha: 1., scale: 0.5, h: index_identify_functions(8), z: z3 }, &0), 6.0);

        Ok(())
    }

    #[test]
    fn test_construct_and_post_process() -> Fallible<()> {
        let mut x = HashMap::new();
        x.insert(0, 7);
        x.insert(42, 12);
        x.insert(100, 5);

        let alp = make_base_alp::<i32, i32, f64>(24, None, None, 2., 24)?;
        let state = alp.function.eval(&x)?;
        let mut query = post_process(state);

        query.eval(&0)?;
        query.eval(&42)?;
        query.eval(&100)?;
        // A key that was never inserted must still be answerable.
        query.eval(&1000)?;

        Ok(())
    }

    #[test]
    fn test_post_process_measurement() -> Fallible<()> {
        let mut x = HashMap::new();
        x.insert(0, 7);
        x.insert(42, 12);
        x.insert(100, 5);

        let alp = make_base_alp::<i32, i32, f64>(24, None, None, 2., 24)?;
        let wrapped = make_alp_histogram_post_process(&alp)?;

        // Wrapping must leave the privacy relation unchanged.
        assert!(wrapped.privacy_relation.eval(&1, &2.)?);
        assert!(!wrapped.privacy_relation.eval(&1, &1.999)?);

        let mut query = wrapped.function.eval(&x)?;

        query.eval(&0)?;
        query.eval(&42)?;
        query.eval(&100)?;
        query.eval(&1000)?;

        Ok(())
    }
}