use crate::{
core::{Function, StabilityMap, Transformation},
domains::{AtomDomain, VectorDomain},
error::Fallible,
metrics::{AbsoluteDistance, IntDistance, SymmetricDistance},
traits::{
AlertingAbs, ExactIntCast, InfAdd, InfCast, InfMul, InfSub, ProductOrd, samplers::Shuffle,
},
};
use dashu::integer::IBig;
use num::{One, Zero};
use opendp_derive::bootstrap;
use super::{Float, Pairwise, Sequential, SumRelaxation};
#[cfg(feature = "ffi")]
mod ffi;
#[bootstrap(
    features("contrib"),
    arguments(bounds(rust_type = "(T, T)")),
    generics(S(default = "Pairwise<T>")),
    returns(c_type = "FfiResult<AnyTransformation *>"),
    derived_types(T = "$get_atom_or_infer(S, get_first(bounds))")
)]
/// Make a Transformation that computes the sum of bounded floats over at most `size_limit` records.
///
/// When the input holds more than `size_limit` records, a copy of the input is shuffled
/// and only the first `size_limit` records of that shuffled copy are summed.
///
/// Construction fails when the summation algorithm `S` reports that summing
/// `size_limit` values within `bounds` could overflow.
///
/// The stability map evaluates `d_in * max(U - L, |L|, U) + relaxation`,
/// where `(L, U)` are the bounds and `relaxation` is supplied by `S`.
///
/// # Arguments
/// * `size_limit` - Upper bound on the number of records that are summed.
/// * `bounds` - Tuple of lower and upper bounds for data in the input domain.
///
/// # Generics
/// * `S` - Summation algorithm to use over some data type `T` (`T` is shorthand for `S::Item`)
pub fn make_bounded_float_checked_sum<S>(
    size_limit: usize,
    bounds: (S::Item, S::Item),
) -> Fallible<
    Transformation<
        VectorDomain<AtomDomain<S::Item>>,
        SymmetricDistance,
        AtomDomain<S::Item>,
        AbsoluteDistance<S::Item>,
    >,
>
where
    S: UncheckedSum,
    S::Item: 'static + Float,
{
    // Refuse to build the transformation if the summation itself may overflow.
    let may_overflow = S::can_float_sum_overflow(size_limit, bounds)?;
    if may_overflow {
        return fallible!(
            MakeTransformation,
            "potential for overflow when computing function. You could resolve this by choosing tighter clipping bounds."
        );
    }

    let (lower, upper) = bounds;

    // Sensitivity of the exact sum: the widest of the range and the bound magnitudes.
    let range = upper.inf_sub(&lower)?;
    let magnitude = lower.alerting_abs()?.total_max(upper)?;
    let ideal_sensitivity = range.total_max(magnitude)?;

    // Extra slack supplied by the summation algorithm; added on top of the
    // ideal sensitivity in the stability map below.
    let relaxation = S::relaxation(size_limit, lower, upper)?;

    let input_domain = VectorDomain::new(AtomDomain::new_closed(bounds)?);

    let function = Function::new_fallible(move |arg: &Vec<S::Item>| {
        let mut records = arg.clone();
        // Only oversized inputs are shuffled, so the records that get dropped
        // are not determined by their position in the input.
        if records.len() > size_limit {
            records.shuffle()?;
        }
        // No-op when the input already fits within `size_limit`.
        records.truncate(size_limit);
        Ok(S::unchecked_sum(records.as_slice()))
    });

    let stability_map = StabilityMap::new_fallible(move |d_in: &IntDistance| {
        // d_out = d_in * ideal_sensitivity + relaxation, via the `inf_*` arithmetic helpers.
        let changed = S::Item::inf_cast(*d_in)?;
        changed.inf_mul(&ideal_sensitivity)?.inf_add(&relaxation)
    });

    Transformation::new(
        input_domain,
        SymmetricDistance,
        AtomDomain::new_non_nan(),
        AbsoluteDistance::default(),
        function,
        stability_map,
    )
}
#[bootstrap(
    features("contrib"),
    arguments(bounds(rust_type = "(T, T)")),
    generics(S(default = "Pairwise<T>")),
    returns(c_type = "FfiResult<AnyTransformation *>"),
    derived_types(T = "$get_atom_or_infer(S, get_first(bounds))")
)]
/// Make a Transformation that computes the sum of bounded floats with known dataset size.
///
/// The input domain is restricted to vectors of exactly `size` records, and every
/// record of the input is summed.
///
/// Construction fails when the summation algorithm `S` reports that summing
/// `size` values within `bounds` could overflow.
///
/// The stability map evaluates `(d_in / 2) * (U - L) + relaxation`,
/// where `(L, U)` are the bounds and `relaxation` is supplied by `S`.
///
/// # Arguments
/// * `size` - Number of records in input data.
/// * `bounds` - Tuple of lower and upper bounds for data in the input domain.
///
/// # Generics
/// * `S` - Summation algorithm to use over some data type `T` (`T` is shorthand for `S::Item`)
pub fn make_sized_bounded_float_checked_sum<S>(
    size: usize,
    bounds: (S::Item, S::Item),
) -> Fallible<
    Transformation<
        VectorDomain<AtomDomain<S::Item>>,
        SymmetricDistance,
        AtomDomain<S::Item>,
        AbsoluteDistance<S::Item>,
    >,
>
where
    S: UncheckedSum,
    S::Item: 'static + Float,
{
    // Refuse to build the transformation if the summation itself may overflow.
    let may_overflow = S::can_float_sum_overflow(size, bounds)?;
    if may_overflow {
        return fallible!(
            MakeTransformation,
            "potential for overflow when computing function. You could resolve this by choosing tighter clipping bounds."
        );
    }

    let (lower, upper) = bounds;

    // Sensitivity of the exact sum per pair of changes: the width of the bounds.
    let ideal_sensitivity = upper.inf_sub(&lower)?;

    // Extra slack supplied by the summation algorithm; added on top of the
    // ideal sensitivity in the stability map below.
    let relaxation = S::relaxation(size, lower, upper)?;

    let input_domain = VectorDomain::new(AtomDomain::new_closed(bounds)?).with_size(size);

    let function = Function::new(move |arg: &Vec<S::Item>| S::unchecked_sum(arg.as_slice()));

    let stability_map = StabilityMap::new_fallible(move |d_in: &IntDistance| {
        // `d_in / 2` is computed before the cast to float.
        // NOTE(review): presumably halved because replacing one record in a
        // fixed-size dataset counts as a symmetric distance of two — confirm.
        let substitutions = S::Item::inf_cast(d_in / 2)?;
        substitutions
            .inf_mul(&ideal_sensitivity)?
            .inf_add(&relaxation)
    });

    Transformation::new(
        input_domain,
        SymmetricDistance,
        AtomDomain::new_non_nan(),
        AbsoluteDistance::default(),
        function,
        stability_map,
    )
}
#[doc(hidden)]
/// A summation algorithm that adds up a slice of `Self::Item` values.
///
/// Implementors must also provide `SumRelaxation` and `CanFloatSumOverflow`;
/// the constructors in this module call `can_float_sum_overflow` before building
/// a transformation that uses `unchecked_sum`.
pub trait UncheckedSum: SumRelaxation + CanFloatSumOverflow {
/// Sum the values in `arg` and return the total.
///
/// The implementations in this module perform no overflow checking of their own.
fn unchecked_sum(arg: &[Self::Item]) -> Self::Item;
}
impl<T: Float> UncheckedSum for Sequential<T> {
    /// Sum `arg` by feeding its values, in slice order, to `T`'s iterator `sum`.
    fn unchecked_sum(arg: &[T]) -> T {
        let values = arg.iter().copied();
        values.sum()
    }
}
impl<T: Float> UncheckedSum for Pairwise<T> {
    /// Sum `arg` recursively: split the slice at its midpoint (`len / 2`),
    /// sum each half, then add the two partial sums.
    ///
    /// An empty slice sums to `T::zero()` and a single element is returned as-is.
    fn unchecked_sum(arg: &[T]) -> T {
        match arg {
            [] => T::zero(),
            [only] => *only,
            _ => {
                // The left half gets the smaller share when the length is odd.
                let (left, right) = arg.split_at(arg.len() / 2);
                Self::unchecked_sum(left) + Self::unchecked_sum(right)
            }
        }
    }
}
#[doc(hidden)]
/// Overflow analysis for a summation algorithm.
pub trait CanFloatSumOverflow: SumRelaxation {
/// Report whether summing `size` values that lie within `bounds` may overflow.
///
/// `bounds` is a `(lower, upper)` pair. Returns `Ok(true)` when overflow cannot
/// be ruled out and `Ok(false)` when it can; errors from the underlying
/// arithmetic helpers are propagated.
fn can_float_sum_overflow(size: usize, bounds: (Self::Item, Self::Item)) -> Fallible<bool>;
}
impl<T: Float> CanFloatSumOverflow for Sequential<T> {
    /// Report whether a sequential sum of `size` values within `(lower, upper)` may overflow.
    ///
    /// Two checks are applied to `mag = max(|lower|, upper)`:
    /// 1. if `mag < 2^(EXPONENT_BIAS - MANTISSA_BITS - 1)`, overflow is ruled out;
    /// 2. otherwise, overflow is reported exactly when multiplying `mag`, rounded up
    ///    to a power of two, by `size` fails.
    fn can_float_sum_overflow(size: usize, (lower, upper): (T, T)) -> Fallible<bool> {
        let count = T::inf_cast(size)?;
        let magnitude = lower.alerting_abs()?.total_max(upper)?;

        // Check 1: compare the bound magnitude against a fixed power-of-two threshold.
        // NOTE(review): 2^(bias - mantissa_bits - 1) is presumably half the ulp of the
        // largest finite value, so smaller addends cannot push the accumulator past it.
        let two = T::one() + T::one();
        let threshold_exponent =
            T::exact_int_cast(T::EXPONENT_BIAS - T::MANTISSA_BITS - T::Bits::one())?;
        if magnitude < two.powf(threshold_exponent) {
            return Ok(false);
        }

        // Check 2: a failed `inf_mul` is taken as the signal that `size` values of
        // this magnitude could overflow.
        let rounded = round_up_to_nearest_power_of_two(magnitude)?;
        let product = rounded.inf_mul(&count);
        Ok(product.is_err())
    }
}
impl<T: Float> CanFloatSumOverflow for Pairwise<T> {
    /// Report whether a pairwise sum of `size` values within `(lower, upper)` may overflow.
    ///
    /// Two checks are applied to `mag = max(|lower|, upper)`:
    /// 1. if `mag < 2^(EXPONENT_BIAS - MANTISSA_BITS) / size`, overflow is ruled out;
    /// 2. otherwise, overflow is reported exactly when multiplying `mag`, rounded up
    ///    to a power of two, by `size` fails.
    fn can_float_sum_overflow(size: usize, (lower, upper): (T, T)) -> Fallible<bool> {
        let count = T::inf_cast(size)?;
        let magnitude = lower.alerting_abs()?.total_max(upper)?;

        // Check 1: the threshold is 2^(bias - mantissa_bits) divided by the number
        // of summands, using the `neg_inf_div` helper.
        // NOTE(review): 2^(bias - mantissa_bits) is presumably the ulp of the largest
        // finite value — confirm against the `Float` trait constants.
        let two = T::one() + T::one();
        let largest_ulp = two.powf(T::exact_int_cast(T::EXPONENT_BIAS - T::MANTISSA_BITS)?);
        let threshold = largest_ulp.neg_inf_div(&count)?;
        if magnitude < threshold {
            return Ok(false);
        }

        // Check 2: a failed `inf_mul` is taken as the signal that `size` values of
        // this magnitude could overflow.
        let rounded = round_up_to_nearest_power_of_two(magnitude)?;
        let product = rounded.inf_mul(&count);
        Ok(product.is_err())
    }
}
/// Compute a power of two from the exponent of `x`: `2^e` when the mantissa bits
/// of `x` are all zero, and `2^(e + 1)` otherwise, where `e` is the unbiased exponent.
///
/// Returns an error when `x` has its sign bit set, or when `inf_powi` fails.
fn round_up_to_nearest_power_of_two<T>(x: T) -> Fallible<T>
where
    T: ExactIntCast<T::Bits> + Float,
{
    // NOTE(review): the message below names `get_smallest_greater_or_equal_power_of_two`,
    // which is not this function's name; the text is kept unchanged.
    if x.is_sign_negative() {
        return fallible!(
            FailedFunction,
            "get_smallest_greater_or_equal_power_of_two must have a positive argument"
        );
    }

    // Work in arbitrary-precision integers so the exponent arithmetic is exact.
    let bias: IBig = T::EXPONENT_BIAS.into();
    let raw_exponent: IBig = x.raw_exponent().into();
    let floor_pow = raw_exponent - bias;

    // Any set mantissa bit means `x` lies strictly above 2^floor_pow, so bump by one.
    let mantissa_is_zero = x.mantissa().is_zero();
    let round_up = if mantissa_is_zero {
        IBig::ZERO
    } else {
        IBig::ONE
    };

    let two = T::one() + T::one();
    two.inf_powi(floor_pow + round_up)
}
#[cfg(test)]
mod test;