use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use dashu::float::FBig;
use dashu::float::round::mode::{Down, Up, Zero};
use num::{ToPrimitive, Zero as _};
use opendp_derive::bootstrap;
use crate::core::{Function, Measurement, MetricSpace, PrivacyMap};
use crate::domains::{AtomDomain, MapDomain};
use crate::error::Fallible;
use crate::interactive::Queryable;
use crate::measures::MaxDivergence;
use crate::metrics::{AbsoluteDistance, L01InfDistance};
use crate::traits::samplers::{fill_bytes, sample_bernoulli_float};
use crate::traits::{Hashable, InfCast, InfMul, Integer, ToFloatRounded};
use std::collections::hash_map::DefaultHasher;
#[cfg(test)]
mod test;

#[cfg(feature = "ffi")]
mod ffi;

/// Default value for the `alpha` parameter when not supplied (see `make_alp_state`).
const ALPHA_DEFAULT: u32 = 4;
/// Default multiplier used when sizing the projection (see `make_alp_state`).
const SIZE_FACTOR_DEFAULT: u32 = 50;

/// Domain of sparse histograms: hashmaps from keys `K` to counts `C`.
type SparseDomain<K, C> = MapDomain<AtomDomain<K>, AtomDomain<C>>;
/// The ALP projection is released as a vector of bits.
type BitVector = Vec<bool>;
/// A sampled hash function mapping a key to a bucket index.
type HashFunction<K> = Arc<dyn Fn(&K) -> usize + Send + Sync>;
#[derive(Clone)]
#[doc(hidden)]
pub struct AlpState<K> {
alpha: f64,
scale: f64,
hashers: Vec<HashFunction<K>>,
z: BitVector,
}
/// Reduce an arbitrary hashable value to a 64-bit digest via the std hasher.
fn pre_hash<K: Hash>(x: K) -> u64 {
    let mut state = DefaultHasher::new();
    x.hash(&mut state);
    state.finish()
}
/// Multiply-shift universal hashing: maps `x` into an `l`-bit index as
/// `(a * x + b) >> (64 - l)`, with wrapping arithmetic.
fn hash(x: u64, a: u64, b: u64, l: u32) -> usize {
    let mixed = a.wrapping_mul(x).wrapping_add(b);
    (mixed >> (64 - l)) as usize
}
/// Sample a hash function from the multiply-shift family uniformly at random.
///
/// The multiplier is forced odd (invertible mod 2^64); the returned closure
/// maps keys into `l`-bit bucket indices.
fn sample_hash_function<K: Hash>(l: u32) -> Fallible<Arc<dyn Fn(&K) -> usize + Send + Sync>> {
    let mut bytes = [0u8; 8];
    fill_bytes(&mut bytes)?;
    let multiplier = u64::from_ne_bytes(bytes) | 1u64;
    fill_bytes(&mut bytes)?;
    let offset = u64::from_ne_bytes(bytes);
    Ok(Arc::new(move |key: &K| {
        hash(pre_hash(key), multiplier, offset, l)
    }))
}
/// Smallest exponent `e` such that `2^e >= x` (i.e. ceil(log2(x)); 0 when x <= 1).
fn exponent_next_power_of_two(x: u64) -> u32 {
    if x <= 1 {
        0
    } else {
        // standard ceil-log2 identity for x >= 2
        64 - (x - 1).leading_zeros()
    }
}
/// Randomized rounding of `x * scale / alpha` to a nonnegative integer.
///
/// Computes `r = x * (scale / alpha)` with conservatively-rounded big-float
/// arithmetic, then returns `floor(r) + Bernoulli(fract(r))` so the result is
/// an unbiased rounding of `r`.
fn scale_and_round<CI>(x: CI, alpha: f64, scale: f64) -> Fallible<usize>
where
    CI: Integer + ToPrimitive,
{
    // form scale / alpha with downward-conservative rounding
    let mut scale = FBig::<Down>::neg_inf_cast(scale)?;
    scale /= FBig::<Down>::inf_cast(alpha)?;
    // Reduce the working precision to keep the multiplication below cheap:
    // 53 mantissa bits adjusted by the value's binary exponent, never below 1.
    // BUG FIX: the precision argument is an integer count of bits, so the lower
    // clamp must be the integer `1` — the previous `.max(FBig::ONE).to_f64_rounded()`
    // compared an i32 against a big-float and could not type-check.
    // NOTE(review): assumes `scale.exp()` yields the binary exponent as an
    // integer — confirm against the dashu API.
    scale = scale
        .clone()
        .with_precision((f64::MANTISSA_DIGITS as i32 - scale.exp()).max(1) as usize)
        .value();
    // negative counts clamp to zero; counts beyond u64::MAX saturate
    let r = FBig::from(x.max(CI::zero()).to_u64().unwrap_or(u64::MAX))
        .with_precision(64)
        .value()
        * scale;
    let floored = f64::inf_cast(r.clone().floor())? as usize;
    // probabilistically round up according to the fractional part
    match sample_bernoulli_float(f64::inf_cast(r.fract())?, false)? {
        true => Ok(floored + 1),
        false => Ok(floored),
    }
}
/// Bit-flip probability `p = 1 / (alpha + 2)` used for randomized response.
///
/// The denominator is rounded down and the reciprocal rounded up, so the
/// returned probability is never under-estimated.
fn compute_prob(alpha: f64) -> f64 {
    let denominator: FBig<Down> =
        FBig::<Down>::neg_inf_cast(alpha).expect("impl is infallible") + 2;
    let probability = FBig::<Up>::ONE / denominator.with_rounding();
    f64::inf_cast(probability).expect("impl is infallible")
}
/// Returns true when `scale / alpha < 2^-52`, i.e. the ratio is too small to
/// be represented usefully in f64 arithmetic (exact big-float comparison).
fn are_parameters_invalid(alpha: f64, scale: f64) -> bool {
    let scale_exact = FBig::<Zero>::inf_cast(scale).expect("impl is infallible");
    let alpha_exact = FBig::<Zero>::neg_inf_cast(alpha).expect("impl is infallible");
    // scale * 2^52 < alpha  <=>  scale / alpha < 2^-52
    scale_exact * (1i64 << 52) < alpha_exact
}
/// Project the sparse histogram `x` into a noisy bit vector of length
/// `projection_size`.
///
/// Each key's (scaled, probabilistically rounded) count selects that many
/// hashers; the buckets they hit are set. Every bit is then flipped
/// independently with probability `1 / (alpha + 2)` (randomized response).
fn compute_projection<K, CI>(
    x: &HashMap<K, CI>,
    hashers: &[HashFunction<K>],
    alpha: f64,
    scale: f64,
    projection_size: usize,
) -> Fallible<BitVector>
where
    CI: Integer + ToPrimitive,
{
    let mut z = vec![false; projection_size];
    for (k, v) in x.iter() {
        // unary length for this key's count
        // (alpha and scale are Copy f64s — the previous `.clone()`s were redundant)
        let round = scale_and_round(v.clone(), alpha, scale)?;
        for h_i in hashers.iter().take(round) {
            z[h_i(k) % projection_size] = true;
        }
    }
    // flip every bit independently with probability p
    let p = compute_prob(alpha);
    z.iter()
        .map(|b| sample_bernoulli_float(p, false).map(|flip| b ^ flip))
        .collect()
}
/// Estimate the value encoded by a noisy unary bit vector.
///
/// Scores a prefix sum of +1 per set bit and -1 per unset bit, then returns the
/// mean index at which the prefix sum attains its maximum — the most likely
/// number of leading ones before the noise was applied. Returns 0.0 for an
/// empty vector (the prefix sum peaks at index 0).
fn estimate_unary(v: &[bool]) -> f64 {
    // idiom: borrow a slice rather than `&Vec<bool>` — callers coerce for free
    let mut prefix_sum = Vec::with_capacity(v.len() + 1);
    prefix_sum.push(0);
    for &bit in v {
        let step = if bit { 1 } else { -1 };
        prefix_sum.push(prefix_sum.last().unwrap() + step);
    }
    let high = *prefix_sum.iter().max().unwrap();
    // average all indices attaining the maximum to reduce variance
    let peaks = prefix_sum
        .iter()
        .enumerate()
        .filter_map(|(idx, &height)| if height == high { Some(idx) } else { None })
        .collect::<Vec<_>>();
    peaks.iter().sum::<usize>() as f64 / peaks.len() as f64
}
/// Decode the count estimate for `key` from the projection stored in `state`.
///
/// Reads the projection bit selected by each hasher for this key, decodes the
/// resulting unary vector, and rescales by `alpha / scale` to undo the scaling
/// applied when the projection was built.
fn compute_estimate<K>(state: &AlpState<K>, key: &K) -> f64 {
    let bits = (state.hashers.iter())
        .map(|h_i| state.z[h_i(key) % state.z.len()])
        .collect::<Vec<_>>();
    // `state.alpha` is already f64 — the redundant `as f64` cast was removed
    estimate_unary(&bits) * state.alpha / state.scale
}
/// Construct a Measurement that releases an [`AlpState`]: a randomized bit-vector
/// projection of a sparse integer histogram, built with the caller-supplied hashers.
///
/// # Arguments
/// * `input_domain` - domain of hashmaps from keys `K` to integer counts `CI`
/// * `input_metric` - (L0, L1, L∞) sensitivity metric over histograms
/// * `scale` - privacy-loss parameter; the privacy map returns `l1 * scale`
/// * `alpha` - approximation parameter; sets the flip probability `1/(alpha+2)`
/// * `projection_size` - number of bits in the released projection
/// * `hashers` - hash functions mapping keys to bucket indices
///
/// # Errors
/// Rejects NaN-admitting value domains, non-positive `scale`/`alpha`, a zero
/// `projection_size`, and a `scale/alpha` ratio below `2^-52`.
pub fn make_alp_state_with_hashers<K, CI>(
    input_domain: SparseDomain<K, CI>,
    input_metric: L01InfDistance<AbsoluteDistance<CI>>,
    scale: f64,
    alpha: f64,
    projection_size: usize,
    hashers: Vec<HashFunction<K>>,
) -> Fallible<
    Measurement<
        SparseDomain<K, CI>,
        L01InfDistance<AbsoluteDistance<CI>>,
        MaxDivergence,
        AlpState<K>,
    >,
>
where
    K: 'static + Hashable,
    CI: 'static + Integer + ToPrimitive,
    f64: InfCast<CI>,
    (SparseDomain<K, CI>, L01InfDistance<AbsoluteDistance<CI>>): MetricSpace,
{
    // parameter validation: all of these are preconditions of the privacy proof
    if input_domain.value_domain.nan() {
        return fallible!(MakeMeasurement, "value domain must be non-nan");
    }
    if scale.is_sign_negative() || scale.is_zero() {
        return fallible!(MakeMeasurement, "scale ({}) must be positive", scale);
    }
    if alpha.is_sign_negative() || alpha.is_zero() {
        return fallible!(MakeMeasurement, "alpha ({}) must be positive", alpha);
    }
    if projection_size == 0 {
        return fallible!(MakeMeasurement, "projection_size must be positive");
    }
    // rejects ratios too small to survive f64 arithmetic (see are_parameters_invalid)
    if are_parameters_invalid(alpha, scale) {
        return fallible!(
            MakeMeasurement,
            "scale divided by alpha must be above 2^-52"
        );
    }
    Measurement::new(
        input_domain,
        input_metric,
        MaxDivergence,
        // the function releases the full AlpState so estimates can be decoded later
        Function::new_fallible(move |x: &HashMap<K, CI>| {
            let z = compute_projection(x, &hashers, alpha, scale, projection_size)?;
            Ok(AlpState {
                alpha,
                scale,
                hashers: hashers.clone(),
                z,
            })
        }),
        // epsilon = l1 * scale; only the L1 component of the distance is used
        PrivacyMap::new_fallible(move |(_l0, l1, _li)| f64::inf_cast(*l1)?.inf_mul(&scale)),
    )
}
/// Construct an ALP Measurement, automatically sampling the hash functions and
/// sizing the projection from the data limits, then delegating to
/// [`make_alp_state_with_hashers`].
///
/// # Arguments
/// * `input_domain` - domain of sparse key→count maps
/// * `input_metric` - (L0, L1, L∞) sensitivity metric
/// * `scale` - privacy-loss parameter passed through to the privacy map
/// * `total_limit` - bound on the total of all counts; sizes the projection
/// * `value_limit` - bound on any single count; defaults to the domain's upper bound
/// * `size_factor` - optional projection-size multiplier (default 50)
/// * `alpha` - optional approximation parameter (default 4)
///
/// # Errors
/// Fails when no value limit is available, when the limits cannot be converted
/// to f64, or when the delegated constructor rejects the parameters.
pub fn make_alp_state<K, CI>(
    input_domain: SparseDomain<K, CI>,
    input_metric: L01InfDistance<AbsoluteDistance<CI>>,
    scale: f64,
    total_limit: CI,
    value_limit: Option<CI>,
    size_factor: Option<u32>,
    alpha: Option<u32>,
) -> Fallible<
    Measurement<
        SparseDomain<K, CI>,
        L01InfDistance<AbsoluteDistance<CI>>,
        MaxDivergence,
        AlpState<K>,
    >,
>
where
    K: 'static + Hashable,
    CI: 'static + Integer + InfCast<f64> + ToPrimitive,
    f64: InfCast<CI> + InfCast<u32>,
    (SparseDomain<K, CI>, L01InfDistance<AbsoluteDistance<CI>>): MetricSpace,
{
    // largest count any single key may take: explicit argument, else the
    // upper bound of the value domain
    let value_limit: f64 = value_limit
        .or_else(|| {
            (input_domain.value_domain.bounds.as_ref())
                .and_then(|b| b.upper())
                .cloned()
        })
        .ok_or_else(|| {
            err!(
                MakeMeasurement,
                "value_limit is required when data is unbounded"
            )
        })?
        .to_f64()
        .ok_or_else(|| err!(MakeMeasurement, "failed to parse value_limit"))?;
    let total_limit: f64 = total_limit
        .to_f64()
        .ok_or_else(|| err!(MakeMeasurement, "failed to parse total_limit"))?;
    let size_factor = size_factor.unwrap_or(SIZE_FACTOR_DEFAULT) as f64;
    let alpha = alpha.unwrap_or(ALPHA_DEFAULT) as f64;
    // `scale` and `alpha` are both f64 here, so the division is infallible.
    // (Previously this went through a redundant `.to_f64()` whose error arm
    // also used the wrong variant, MakeTransformation, in a measurement
    // constructor.)
    let quotient = scale / alpha;
    // m: number of hash functions — enough to encode the largest count in unary
    let m = usize::inf_cast(value_limit * quotient)?;
    // projection size is the next power of two of size_factor * total * quotient
    let exp = exponent_next_power_of_two((size_factor * total_limit * quotient) as u64);
    let hashers = (0..m)
        .map(|_| sample_hash_function(exp))
        .collect::<Fallible<Vec<HashFunction<K>>>>()?;
    make_alp_state_with_hashers(input_domain, input_metric, scale, alpha, 1 << exp, hashers)
}
/// Postprocessor wrapping a released [`AlpState`] in a queryable that answers
/// per-key count estimates via [`compute_estimate`].
pub fn post_alp_state_to_queryable<K>() -> Function<AlpState<K>, Queryable<K, f64>>
where
    K: 'static + Clone,
{
    Function::new(move |state: &AlpState<K>| {
        // clone so the queryable owns the state independently of the release
        let owned = state.clone();
        Queryable::new_raw_external(move |key: &K| Ok(compute_estimate(&owned, key)))
    })
}
#[bootstrap(
    features("contrib"),
    arguments(
        input_domain(c_type = "AnyDomain *"),
        input_metric(c_type = "AnyMetric *"),
        total_limit(c_type = "void *"),
        value_limit(c_type = "void *", default = b"null"),
        size_factor(default = 50),
        alpha(default = 4),
    ),
    generics(K(suppress), CI(suppress)),
    derived_types(CI = "$get_value_type(get_carrier_type(input_domain))")
)]
/// Construct the full ALP mechanism: builds the projection measurement via
/// `make_alp_state` and postprocesses the released state into a queryable that
/// answers per-key count estimates.
///
/// # Arguments
/// * `input_domain` - domain of hashmaps from keys `K` to integer counts `CI`
/// * `input_metric` - (L0, L1, L∞) sensitivity metric over histograms
/// * `scale` - privacy-loss parameter (see the privacy map of the inner measurement)
/// * `total_limit` - bound on the total of all counts
/// * `value_limit` - bound on any single count; defaults to the domain's upper bound
/// * `size_factor` - optional projection-size multiplier (default 50)
/// * `alpha` - optional approximation parameter (default 4)
pub fn make_alp_queryable<K, CI>(
    input_domain: MapDomain<AtomDomain<K>, AtomDomain<CI>>,
    input_metric: L01InfDistance<AbsoluteDistance<CI>>,
    scale: f64,
    total_limit: CI,
    value_limit: Option<CI>,
    size_factor: Option<u32>,
    alpha: Option<u32>,
) -> Fallible<
    Measurement<
        MapDomain<AtomDomain<K>, AtomDomain<CI>>,
        L01InfDistance<AbsoluteDistance<CI>>,
        MaxDivergence,
        Queryable<K, f64>,
    >,
>
where
    K: 'static + Hashable,
    CI: 'static + Integer + InfCast<f64> + ToPrimitive,
    f64: InfCast<CI>,
    (
        MapDomain<AtomDomain<K>, AtomDomain<CI>>,
        L01InfDistance<AbsoluteDistance<CI>>,
    ): MetricSpace,
{
    // `>>` composes the measurement with the postprocessor
    make_alp_state(
        input_domain,
        input_metric,
        scale,
        total_limit,
        value_limit,
        size_factor,
        alpha,
    )? >> post_alp_state_to_queryable()
}