use dashu::{integer::IBig, rational::RBig};
use opendp_derive::proven;
use crate::{
core::{Domain, Function, Measure, Measurement, Metric, MetricSpace, PrivacyMap},
domains::{AtomDomain, VectorDomain},
error::Fallible,
traits::samplers::{sample_discrete_gaussian, sample_discrete_laplace},
};
#[cfg(test)]
mod test;
pub(crate) mod nature;
mod distribution;
pub use distribution::*;
/// Constructor trait for measurements built from a noise distribution.
///
/// The implementor is consumed and combined with an input space, a
/// `(domain, metric)` pair, to produce a [`Measurement`] whose output type is
/// the carrier type of the input domain.
pub trait MakeNoise<DI, MI, MO>
where
    DI: Domain,
    MI: Metric,
    MO: Measure,
    (DI, MI): MetricSpace,
{
    /// Consumes `self` and builds the measurement over `input_space`.
    ///
    /// # Errors
    ///
    /// Returns an error if the measurement cannot be constructed.
    fn make_noise(self, input_space: (DI, MI)) -> Fallible<Measurement<DI, MI, MO, DI::Carrier>>;
}
/// Provides the privacy map of a sampleable noise distribution for a
/// particular pairing of input metric `MI` and output measure `MO`.
pub trait NoisePrivacyMap<MI, MO>: Sample
where
    MI: Metric,
    MO: Measure,
{
    /// Builds the [`PrivacyMap`] relating distances under `input_metric` to
    /// distances under `output_measure` for this distribution.
    ///
    /// # Errors
    ///
    /// Returns an error if no privacy map can be produced for the given
    /// metric and measure.
    fn noise_privacy_map(
        &self,
        input_metric: &MI,
        output_measure: &MO,
    ) -> Fallible<PrivacyMap<MI, MO>>;
}
/// A family of integer-valued noise distributions indexed by the const
/// parameter `P`.
///
/// In this file, `ZExpFamily<1>` samples via `sample_discrete_laplace` and
/// `ZExpFamily<2>` samples via `sample_discrete_gaussian`.
#[derive(Clone)]
pub struct ZExpFamily<const P: usize> {
/// Scale parameter handed to the underlying sampler.
pub scale: RBig,
}
/// A distribution over big integers that can be sampled around a given shift.
///
/// Implementors must be `'static`, cloneable, and shareable across threads.
pub trait Sample: 'static + Clone + Send + Sync {
/// Draws one sample relative to `shift`.
///
/// The implementations in this file return `shift` plus freshly drawn noise.
/// Returns an error if sampling fails.
fn sample(&self, shift: &IBig) -> Fallible<IBig>;
}
#[proven(proof_path = "measurements/noise/Sample_for_ZExpFamily1.tex")]
impl Sample for ZExpFamily<1> {
    /// Adds noise drawn by `sample_discrete_laplace`, parameterized by
    /// `self.scale`, to `shift`.
    ///
    /// Any error raised by the sampler is propagated to the caller.
    fn sample(&self, shift: &IBig) -> Fallible<IBig> {
        // The sampler takes the scale by value, so hand it a copy.
        let scale = self.scale.clone();
        let noise = sample_discrete_laplace(scale)?;
        Ok(shift + noise)
    }
}
#[proven(proof_path = "measurements/noise/Sample_for_ZExpFamily2.tex")]
impl Sample for ZExpFamily<2> {
    /// Adds noise drawn by `sample_discrete_gaussian`, parameterized by
    /// `self.scale`, to `shift`.
    ///
    /// Any error raised by the sampler is propagated to the caller.
    fn sample(&self, shift: &IBig) -> Fallible<IBig> {
        // The sampler takes the scale by value, so hand it a copy.
        let scale = self.scale.clone();
        let noise = sample_discrete_gaussian(scale)?;
        Ok(shift + noise)
    }
}
/// Builds a measurement over vectors of big integers from any sampleable
/// distribution `RV` that also supplies a privacy map for `(MI, MO)`.
#[proven(proof_path = "measurements/noise/MakeNoise_IBig_for_RV.tex")]
impl<MI: Metric, MO: 'static + Measure, RV: Sample>
    MakeNoise<VectorDomain<AtomDomain<IBig>>, MI, MO> for RV
where
    (VectorDomain<AtomDomain<IBig>>, MI): MetricSpace,
    RV: NoisePrivacyMap<MI, MO>,
{
    /// Consumes the distribution and returns a measurement whose function
    /// passes each element of the input vector as the `shift` to
    /// [`Sample::sample`], collecting the results in input order.
    ///
    /// The output measure is `MO::default()`, and the privacy map comes from
    /// [`NoisePrivacyMap::noise_privacy_map`].
    ///
    /// # Errors
    ///
    /// Returns an error if the privacy map cannot be built or if
    /// `Measurement::new` rejects the arguments. Sampling errors surface when
    /// the measurement's function is invoked.
    fn make_noise(
        self,
        (input_domain, input_metric): (VectorDomain<AtomDomain<IBig>>, MI),
    ) -> Fallible<Measurement<VectorDomain<AtomDomain<IBig>>, MI, MO, Vec<IBig>>> {
        let output_measure = MO::default();
        // Build the privacy map while `self` is only borrowed; the returned map
        // holds no borrow, so the distribution can then be moved into the
        // closure without an extra clone.
        let privacy_map = self.noise_privacy_map(&input_metric, &output_measure)?;
        let distribution = self;
        Measurement::new(
            input_domain,
            input_metric,
            output_measure,
            Function::new_fallible(move |arg: &Vec<IBig>| {
                // Collecting into the fallible return type stops at the first
                // failed sample.
                arg.iter().map(|x_i| distribution.sample(x_i)).collect()
            }),
            privacy_map,
        )
    }
}