use std::{cmp::Ordering, iter::zip};
use dashu::{integer::UBig, rational::RBig};
use opendp_derive::{bootstrap, proven};
use crate::{
core::{Function, MetricSpace, StabilityMap, Transformation},
domains::{AtomDomain, VectorDomain},
error::Fallible,
metrics::{IntDistance, LInfDistance},
traits::{AlertingMul, ExactIntCast, InfDiv, Integer, Number, RoundCast},
};
use super::traits::UnboundedMetric;
#[cfg(feature = "ffi")]
mod ffi;
#[cfg(test)]
mod test;
#[bootstrap(
    features("contrib"),
    generics(MI(suppress), TIA(suppress)),
    derived_types(TIA = "$get_atom(get_type(input_domain))")
)]
/// Makes a Transformation that scores each of the `candidates` against the `alpha`-quantile of the input dataset.
///
/// The output has one integer score per candidate. For a candidate `c`, let `lt` and `gt` be the number of records strictly less than and strictly greater than `c`. The score is then `|(1 - alpha) * lt - alpha * gt|`, computed with `alpha` represented as an integer fraction and the counts clamped to a size limit.
///
/// # Arguments
/// * `input_domain` - Domain of input vectors. Elements must not be NaN. When the vector size is known, a smaller stability map is used.
/// * `input_metric` - Metric on input vectors (any `UnboundedMetric`).
/// * `candidates` - Potential quantiles to score. Must be non-empty and strictly increasing.
/// * `alpha` - A value in [0, 1]. Choose 0.5 for the median.
///
/// # Generics
/// * `MI` - Input Metric.
/// * `TIA` - Atomic Input Type. Type of elements in the input vector.
pub fn make_quantile_score_candidates<MI: UnboundedMetric, TIA: Number>(
    input_domain: VectorDomain<AtomDomain<TIA>>,
    input_metric: MI,
    candidates: Vec<TIA>,
    alpha: f64,
) -> Fallible<
    Transformation<
        VectorDomain<AtomDomain<TIA>>,
        MI,
        VectorDomain<AtomDomain<u64>>,
        LInfDistance<u64>,
    >,
>
where
    (VectorDomain<AtomDomain<TIA>>, MI): MetricSpace,
{
    // Scores are only meaningful when every record is comparable to every candidate.
    if input_domain.element_domain.nan() {
        return fallible!(
            MakeTransformation,
            "input_domain members must have non-nan elements"
        );
    }
    check_candidates(&candidates)?;

    // Read everything needed from the input domain up front so it can be moved later.
    let known_size = input_domain.size.is_some();
    let size = input_domain.size.map(u64::exact_int_cast).transpose()?;
    let (alpha_num, alpha_den, size_limit) = score_candidates_constants(size, alpha)?;

    // One score per candidate.
    let output_domain = VectorDomain::<AtomDomain<u64>>::default().with_size(candidates.len());

    let function = Function::new(move |arg: &Vec<TIA>| {
        score_candidates(
            arg.iter().cloned(),
            candidates.clone(),
            alpha_num,
            alpha_den,
            size_limit,
        )
        .collect::<Vec<u64>>()
    });

    let stability_map =
        StabilityMap::new_fallible(score_candidates_map(alpha_num, alpha_den, known_size));

    Transformation::new(
        input_domain,
        input_metric,
        output_domain,
        LInfDistance::default(),
        function,
        stability_map,
    )
}
#[proven(proof_path = "transformations/quantile_score_candidates/check_candidates.tex")]
/// Validates a list of candidates before it is scored.
///
/// # Errors
/// Returns a `MakeTransformation` error if:
/// * `candidates` is empty, or
/// * any pair of adjacent candidates is not strictly increasing. This includes pairs that
///   cannot be compared at all, such as those involving NaN, where `partial_cmp` returns `None`.
pub(crate) fn check_candidates<T: Number>(candidates: &[T]) -> Fallible<()> {
    if candidates.is_empty() {
        return fallible!(MakeTransformation, "candidates must be non-empty");
    }

    // Only `Some(Less)` counts as "strictly increasing". Equal, decreasing and
    // incomparable neighbors are all rejected.
    let strictly_increasing = candidates
        .windows(2)
        .all(|pair| matches!(pair[0].partial_cmp(&pair[1]), Some(Ordering::Less)));

    if !strictly_increasing {
        return fallible!(
            MakeTransformation,
            "candidates must be non-null and strictly increasing"
        );
    }
    Ok(())
}
#[proven(proof_path = "transformations/quantile_score_candidates/score_candidates_constants.tex")]
/// Derives the integer constants used to score candidates.
///
/// `alpha` is represented as the fraction `alpha_num / alpha_den`:
/// * If the exact denominator of `alpha` is below a granularity bound, the exact fraction is used.
/// * Otherwise, a rounded numerator over the bound is used.
///
/// With a known `size`, the bound is `u64::MAX / size`. Without one, it is `10_000`.
///
/// # Arguments
/// * `size` - the dataset size, if known.
/// * `alpha` - the quantile, in [0, 1].
///
/// # Returns
/// `(alpha_num, alpha_den, size_limit)`. On success, `alpha_num <= alpha_den` holds and
/// `size_limit * alpha_den` does not overflow a `u64`.
///
/// # Errors
/// Fails if `alpha` is outside [0, 1] (this includes NaN), or if any intermediate conversion,
/// division or overflow check fails.
pub(crate) fn score_candidates_constants(
    size: Option<u64>,
    alpha: f64,
) -> Fallible<(u64, u64, u64)> {
    if !(0.0..=1.0).contains(&alpha) {
        return fallible!(MakeTransformation, "alpha must be within [0, 1]");
    }

    let (alpha_num_exact, alpha_den_exact) = RBig::try_from(alpha)?.into_parts();

    // The finest denominator that is still safe to use.
    let alpha_den_approx: u64 = match size {
        // size * alpha_den must fit in a u64, so alpha_den may be at most u64::MAX / size.
        Some(size) => u64::MAX.neg_inf_div(&size)?,
        // Without a known size, default to an alpha granularity of 1/10_000.
        None => 10_000,
    };

    let (alpha_num, alpha_den) = if alpha_den_exact < UBig::from(alpha_den_approx) {
        // alpha is nonnegative, so only the magnitude of the numerator is needed.
        (
            u64::try_from(alpha_num_exact.into_parts().1)?,
            u64::try_from(alpha_den_exact)?,
        )
    } else {
        // numer ~= alpha * denom
        let alpha_num_approx = u64::round_cast(alpha * f64::round_cast(alpha_den_approx)?)?;
        (alpha_num_approx, alpha_den_approx)
    };

    let size_limit = match size {
        Some(size) => size,
        // (size_limit * alpha_den) must not overflow.
        None => u64::MAX.neg_inf_div(&alpha_den)?,
    };

    // These two facts bound the products computed in `score_candidates`:
    // (alpha_den - alpha_num) * count and alpha_num * count, with count <= size_limit.
    assert!(alpha_num <= alpha_den);
    size_limit.alerting_mul(&alpha_den)?;

    Ok((alpha_num, alpha_den, size_limit))
}
#[proven(proof_path = "transformations/quantile_score_candidates/score_candidates_map.tex")]
/// Builds the closure used as the stability map for the candidate scores.
///
/// The returned closure maps an input distance `d_in` to an output distance:
/// * if `known_size`: `(d_in / 2) * alpha_den` (integer division),
/// * otherwise: `d_in * max(alpha_num, alpha_den - alpha_num)`.
///
/// The closure fails if `d_in` (or `d_in / 2`) cannot be cast exactly to `T`, or if
/// `alerting_mul` reports a failure on the product.
pub(crate) fn score_candidates_map<T: Integer + ExactIntCast<IntDistance>>(
    alpha_num: T,
    alpha_den: T,
    known_size: bool,
) -> impl Fn(&IntDistance) -> Fallible<T> {
    move |d_in: &IntDistance| {
        if !known_size {
            // Scale the full distance by the larger of the two per-count weights.
            let per_unit = alpha_num.max(alpha_den - alpha_num);
            return T::exact_int_cast(*d_in)?.alerting_mul(&per_unit);
        }
        // With a known size, only half of the distance (rounded down) is scaled, by alpha_den.
        let half_distance = T::exact_int_cast(*d_in / 2)?;
        half_distance.alerting_mul(&alpha_den)
    }
}
#[proven(proof_path = "transformations/quantile_score_candidates/score_candidates.tex")]
/// Scores every candidate against the `alpha_num / alpha_den`-quantile of `x`.
///
/// Let `lt` and `gt` be the number of records in `x` strictly less than and strictly greater
/// than a candidate, each clamped to `size_limit`. The score of that candidate is then
/// `|(alpha_den - alpha_num) * lt - alpha_num * gt|`. It is zero when the candidate
/// splits the data exactly in the ratio `alpha : 1 - alpha`.
///
/// Expects `candidates` to be strictly increasing (binary search via `partition_point`), and
/// the constants to come from `score_candidates_constants` so the products cannot overflow.
pub(crate) fn score_candidates<TIA: PartialOrd>(
    x: impl Iterator<Item = TIA>,
    candidates: Vec<TIA>,
    alpha_num: u64,
    alpha_den: u64,
    size_limit: u64,
) -> impl Iterator<Item = u64> {
    let num_candidates = candidates.len();

    // Counts of records per bin (-inf, c_0), [c_0, c_1), ..., [c_{k-1}, inf).
    let mut right_open = vec![0u64; num_candidates + 1];
    // Counts of records per bin (-inf, c_0], (c_0, c_1], ..., (c_{k-1}, inf).
    let mut left_open = vec![0u64; num_candidates + 1];

    for value in x {
        // Number of candidates strictly below `value`.
        let below = candidates.partition_point(|c| *c < value);
        left_open[below] += 1;
        // Additionally step past a candidate equal to `value`, if any.
        let at_or_below = below + candidates[below..].partition_point(|c| *c == value);
        right_open[at_or_below] += 1;
    }

    // Total record count (taken before the final bins are discarded).
    let total: u64 = left_open.iter().sum();

    // The last bin of each histogram lies above every candidate and never feeds a prefix sum.
    right_open.truncate(num_candidates);
    left_open.truncate(num_candidates);

    // Running prefix sums: #{x < c_j} from right_open and #{x <= c_j} from left_open.
    let (mut seen_lt, mut seen_le) = (0u64, 0u64);
    zip(right_open, left_open).map(move |(ro, lo)| {
        seen_lt += ro;
        seen_le += lo;
        let seen_gt = total - seen_le;
        let weighted_lt = (alpha_den - alpha_num) * seen_lt.min(size_limit);
        let weighted_gt = alpha_num * seen_gt.min(size_limit);
        weighted_lt.abs_diff(weighted_gt)
    })
}