use crate::{
core::{Domain, Function, Measurement, Metric, MetricSpace, PrivacyMap},
error::Fallible,
measures::MaxDivergence,
traits::{CastInternalRational, InfLn, InfMul, InfSub, samplers::sample_geometric_exp_fast},
};
use dashu::integer::UBig;
use opendp_derive::bootstrap;
use std::{fmt::Debug, ops::Neg};
#[cfg(feature = "ffi")]
mod ffi;
#[cfg(test)]
mod test;
#[bootstrap(
features("contrib"),
arguments(measurement(rust_type = "AnyMeasurement"),),
generics(DI(suppress), MI(suppress), TO(suppress))
)]
/// Select a private candidate whose score reaches a threshold.
///
/// Wraps `measurement`, whose releases are `(score, candidate)` pairs, and invokes it
/// repeatedly on the same input. The first release with `score >= threshold` is returned
/// as `Some((score, candidate))`. A NaN score never compares `>=`, so it always counts as
/// below the threshold.
///
/// When `stop_probability` is positive, each invocation of the returned measurement draws
/// a fresh random cap on the number of evaluations (always at least one). If the cap is
/// used up before any score reaches `threshold`, `None` is released. When
/// `stop_probability` is zero there is no cap: evaluation continues until some score
/// reaches `threshold`, and so never ends if no score ever does.
///
/// The privacy loss of the returned measurement is twice the privacy loss of `measurement`.
///
/// Construction fails if `stop_probability` is not in `[0, 1)` or `threshold` is not finite.
/// Errors raised by `measurement` or by the random sampler propagate at invocation time.
///
/// # Arguments
/// * `measurement` - A measurement that releases a 2-tuple of (score, candidate).
/// * `stop_probability` - Controls the random cap on the number of evaluations. Must be in `[0, 1)`; zero disables the cap.
/// * `threshold` - A release is accepted as soon as its score is greater than or equal to this value. Must be finite.
///
/// # Returns
/// A measurement that releases the first output of `measurement` whose score is at least `threshold`, or `None`.
pub fn make_select_private_candidate<
    DI: 'static + Domain,
    MI: 'static + Metric,
    TO: 'static + Debug,
>(
    measurement: Measurement<DI, MI, MaxDivergence, (f64, TO)>,
    stop_probability: f64,
    threshold: f64,
) -> Fallible<Measurement<DI, MI, MaxDivergence, Option<(f64, TO)>>>
where
    (DI, MI): MetricSpace,
{
    // `Range::contains` is false for NaN, so NaN is rejected here too.
    let stop_probability_ok = (0f64..1f64).contains(&stop_probability);
    if !stop_probability_ok {
        return fallible!(MakeMeasurement, "stop_probability must be in [0, 1)");
    }
    if !threshold.is_finite() {
        return fallible!(MakeMeasurement, "threshold must be finite");
    }

    // Scale handed to `sample_geometric_exp_fast` when drawing the evaluation budget:
    // -1 / ln(1 - p), with `1 - p` rounded down and its logarithm rounded up.
    // NOTE(review): presumably chosen so that each below-threshold evaluation ends the run
    // with probability `stop_probability` — confirm against the sampler's definition.
    // A zero stop probability means "never stop early", so no scale is computed.
    let budget_scale = (stop_probability > 0.0)
        .then(|| -> Fallible<_> {
            let log_continue = 1f64.neg_inf_sub(&stop_probability)?.inf_ln()?;
            Ok(log_continue.recip().neg().into_rational()?)
        })
        .transpose()?;

    let inner_function = measurement.function.clone();
    let inner_map = measurement.privacy_map.clone();

    Measurement::new(
        measurement.input_domain.clone(),
        measurement.input_metric.clone(),
        measurement.output_measure.clone(),
        Function::new_fallible(move |arg| {
            // A fresh budget per invocation. The `+ 1` guarantees at least one evaluation;
            // `None` means the number of evaluations is unbounded.
            let mut budget = match &budget_scale {
                Some(scale) => Some(sample_geometric_exp_fast(scale.clone())? + UBig::ONE),
                None => None,
            };

            loop {
                let (score, candidate) = inner_function.eval(arg)?;
                if score >= threshold {
                    return Ok(Some((score, candidate)));
                }

                // Spend one unit of budget on the rejected candidate; give up once it is empty.
                if let Some(left) = &mut budget {
                    *left -= UBig::ONE;
                    if *left == UBig::ZERO {
                        return Ok(None);
                    }
                }
            }
        }),
        // NOTE(review): the factor of 2 presumably follows the known-threshold,
        // random-stopping analysis of Liu & Talwar (STOC 2019) — confirm.
        PrivacyMap::new_fallible(move |d_in| inner_map.eval(d_in)?.inf_mul(&2.0)),
    )
}