use std::{cell::RefCell, rc::Rc};
use opendp_derive::{bootstrap, proven};
use crate::{
combinators::{Adaptivity, Composability, CompositionMeasure, assert_elements_match},
core::{
Domain, Function, Measurement, Metric, MetricSpace, Odometer, OdometerAnswer,
OdometerQuery, OdometerQueryable, PrivacyMap,
},
error::Fallible,
interactive::{Answer, Query, Queryable, Wrapper},
};
#[cfg(test)]
mod test;
#[cfg(feature = "ffi")]
mod ffi;
#[bootstrap(
    features("contrib"),
    arguments(output_measure(c_type = "AnyMeasure *", rust_type = b"null"),),
    generics(DI(suppress), TO(suppress), MI(suppress), MO(suppress))
)]
/// Construct an odometer that spawns a fully adaptive compositor queryable.
///
/// Each invocation of the returned odometer on a dataset produces a fresh queryable.
/// That queryable accepts measurements whose input domain, input metric and output
/// measure match the ones given here, and it releases their answers. It can also report
/// the privacy loss accumulated so far.
///
/// # Arguments
/// * `input_domain` - indicates the space of valid input datasets
/// * `input_metric` - how distances are measured between members of the input domain
/// * `output_measure` - how privacy is measured
///
/// # Generics
/// * `DI` - Input Domain.
/// * `MI` - Input Metric.
/// * `MO` - Output Measure.
/// * `TO` - Output Type of the composed measurements.
///
/// # Errors
/// Fails if `output_measure` does not support fully adaptive composition, or if
/// `Odometer::new` rejects the given spaces.
pub fn make_fully_adaptive_composition<
    DI: 'static + Domain,
    MI: 'static + Metric,
    MO: 'static + CompositionMeasure,
    TO: 'static,
>(
    input_domain: DI,
    input_metric: MI,
    output_measure: MO,
) -> Fallible<Odometer<DI, MI, MO, Measurement<DI, MI, MO, TO>, TO>>
where
    DI::Carrier: Clone,
    (DI, MI): MetricSpace,
{
    // Refuse measures that cannot be composed fully adaptively before building anything.
    output_measure.composability(Adaptivity::FullyAdaptive)?;

    // The spawning function owns its own copies of the spaces, because it is called
    // once per invocation and each new queryable takes its spaces by value.
    let (domain, metric, measure) = (
        input_domain.clone(),
        input_metric.clone(),
        output_measure.clone(),
    );
    let spawn = Function::new_fallible(move |data: &DI::Carrier| {
        new_fully_adaptive_composition_queryable::<DI, TO, MI, MO>(
            domain.clone(),
            metric.clone(),
            measure.clone(),
            data.clone(),
        )
    });

    Odometer::new(input_domain, input_metric, output_measure, spawn)
}
#[proven(
    proof_path = "combinators/sequential_composition/fully_adaptive/new_fully_adaptive_composition_queryable.tex"
)]
/// Build the queryable that backs one invocation of a fully adaptive composition odometer.
///
/// The returned queryable owns `data` plus a growing list of privacy maps. The list holds
/// one map per measurement successfully invoked through this queryable, in invocation order.
///
/// Queries it answers:
/// * `OdometerQuery::Invoke(meas)` - checks that `meas` has the same input domain, input
///   metric and output measure as this compositor, then invokes `meas` on `data`. It records
///   `meas.privacy_map` and returns the measurement's answer.
/// * `OdometerQuery::PrivacyLoss(d_in)` - evaluates every recorded privacy map at `d_in`
///   and combines the results with `output_measure.compose`.
/// * the internal `AskPermission(id)` query - used by child queryables when sequential
///   composition is required (see below).
///
/// When `output_measure.composability(Adaptivity::FullyAdaptive)` reports
/// `Composability::Sequential`, each measurement is invoked with a recursive pre-hook
/// `Wrapper`. Once that invocation has returned, the hook asks this queryable for permission.
/// Permission is granted only if the child is the most recently invoked measurement, so
/// earlier children are effectively closed.
/// NOTE(review): the hook presumably runs before each query to the wrapped child, and to its
/// descendants, per `new_recursive_pre_hook` — confirm against `Wrapper`.
///
/// # Errors
/// * Propagates the error from `output_measure.composability(...)` if the measure cannot be
///   composed fully adaptively. This happens before the queryable is constructed.
/// * Inside the queryable, errors on:
///   * a domain/metric/measure mismatch (`assert_elements_match!`),
///   * a failed invocation,
///   * a failed privacy-map evaluation or composition,
///   * a permission request from a child that is not the latest,
///   * any other unrecognized internal query.
fn new_fully_adaptive_composition_queryable<
    DI: 'static + Domain,
    TO: 'static,
    MI: 'static + Metric,
    MO: 'static + CompositionMeasure,
>(
    input_domain: DI,
    input_metric: MI,
    output_measure: MO,
    data: DI::Carrier,
) -> Fallible<OdometerQueryable<Measurement<DI, MI, MO, TO>, TO, MI::Distance, MO::Distance>>
where
    (DI, MI): MetricSpace,
{
    // Whether children must be queried strictly in invocation order. The `?` also rejects
    // measures that cannot be composed fully adaptively at all.
    let require_sequentiality = matches!(
        output_measure.composability(Adaptivity::FullyAdaptive)?,
        Composability::Sequential
    );
    // One privacy map per successfully invoked measurement, in invocation order.
    // A child's index into this vec is its id in `AskPermission`.
    let mut privacy_maps: Vec<PrivacyMap<MI, MO>> = vec![];
    Queryable::new(
        move |self_: &OdometerQueryable<
            Measurement<DI, MI, MO, TO>,
            TO,
            MI::Distance,
            MO::Distance,
        >,
              query: Query<OdometerQuery<Measurement<DI, MI, MO, TO>, _>>| {
            // Internal query a child sends before being queried. It carries the child's
            // index into `privacy_maps`.
            struct AskPermission(usize);
            Ok(match query {
                Query::External(OdometerQuery::Invoke(meas)) => {
                    assert_elements_match!(DomainMismatch, &input_domain, &meas.input_domain);
                    assert_elements_match!(MetricMismatch, &input_metric, &meas.input_metric);
                    assert_elements_match!(MeasureMismatch, &output_measure, &meas.output_measure);
                    // The flag stays false while `meas` is being invoked, so the hook does not
                    // ask permission during invocation. At that point this child's map has not
                    // been pushed yet (`privacy_maps.len() == child_id`), so the id check below
                    // would reject it. The flag is switched on once invocation completes.
                    let enforce_sequentiality = Rc::new(RefCell::new(false));
                    // `None` (no wrapper) when sequential composition is not required.
                    let seq_wrapper = require_sequentiality.then(|| {
                        // This child will occupy the next slot in `privacy_maps`.
                        let child_id = privacy_maps.len();
                        // Handle to this same queryable, so the hook can ask it for permission.
                        // NOTE(review): assumes cloning a Queryable shares its state — confirm.
                        let mut self_ = self_.clone();
                        Wrapper::new_recursive_pre_hook(enclose!(
                            enforce_sequentiality,
                            move || {
                                if *enforce_sequentiality.borrow() {
                                    self_.eval_internal(&AskPermission(child_id))?
                                } else {
                                    Ok(())
                                }
                            }
                        ))
                    });
                    // On failure, `?` returns before the privacy map is recorded.
                    let answer = meas.invoke_wrap(&data, seq_wrapper)?;
                    *enforce_sequentiality.borrow_mut() = true;
                    privacy_maps.push(meas.privacy_map.clone());
                    Answer::External(OdometerAnswer::Invoke(answer))
                }
                Query::External(OdometerQuery::PrivacyLoss(d_in)) => {
                    // Privacy loss of each invoked measurement at `d_in`, then composed.
                    let d_mids = (privacy_maps.iter())
                        .map(|m| m.eval(d_in))
                        .collect::<Fallible<_>>()?;
                    let d_out = output_measure.compose(d_mids)?;
                    Answer::External(OdometerAnswer::PrivacyLoss(d_out))
                }
                Query::Internal(query) => {
                    if let Some(AskPermission(id)) = query.downcast_ref() {
                        // Only the most recently invoked child (index `len - 1`) may be queried.
                        if *id + 1 != privacy_maps.len() {
                            return fallible!(
                                FailedFunction,
                                "sequential odometer has received a new query"
                            );
                        }
                        return Ok(Answer::internal(()));
                    }
                    return fallible!(FailedFunction, "query not recognized");
                }
            })
        },
    )
}