use opendp_derive::bootstrap;
use std::{cell::RefCell, fmt::Debug, rc::Rc};
use crate::{
combinators::{Adaptivity, Composability, assert_elements_match},
core::{Domain, Function, Measurement, Metric, MetricSpace, PrivacyMap},
error::Fallible,
interactive::{Answer, Query, Queryable, Wrapper},
traits::ProductOrd,
};
#[cfg(test)]
mod test;
#[cfg(feature = "ffi")]
mod ffi;
use super::CompositionMeasure;
#[bootstrap(
    features("contrib"),
    arguments(
        d_in(rust_type = "$get_distance_type(input_metric)", c_type = "AnyObject *"),
        d_mids(rust_type = "Vec<QO>", c_type = "AnyObject *"),
        output_measure(c_type = "AnyMeasure *", rust_type = b"null")
    ),
    generics(DI(suppress), TO(suppress), MI(suppress), MO(suppress)),
    derived_types(QO = "$get_distance_type(output_measure)")
)]
/// Construct a queryable that interactively composes measurements.
///
/// Invoking the returned measurement yields a queryable whose queries are themselves
/// measurements. Every query is invoked on the same (cloned) input dataset. Queries are
/// charged against `d_mids` in the order given: the first query must fit within `d_mids[0]`,
/// the second within `d_mids[1]`, and so on. Once all entries are used, further queries fail
/// with "out of queries". A budget entry is only consumed if the query's invocation succeeds.
///
/// Each query must share this compositor's input domain, input metric and output measure.
/// It must also pass `check(d_in, d_mid)` for its budget entry; otherwise it is rejected.
///
/// If `output_measure` reports only sequential composability under adaptivity,
/// the compositor wraps any queryable released by a query. Such a released queryable may only
/// be interacted with while it is the most recent release of this compositor.
///
/// The privacy map returns the composition of all `d_mids`. It rejects any `d_in` greater
/// than the `d_in` passed to this constructor.
///
/// # Arguments
/// * `input_domain` - domain of valid input datasets; every query must share it
/// * `input_metric` - metric on the input domain; every query must share it
/// * `output_measure` - how privacy is measured; every query must share it
/// * `d_in` - maximum distance between adjacent input datasets
/// * `d_mids` - maximum privacy expenditure of each query, in the order queries will be made
///
/// # Generics
/// * `DI` - Input Domain.
/// * `MI` - Input Metric.
/// * `MO` - Output Measure.
/// * `TO` - Output Type.
pub fn make_adaptive_composition<
    DI: Domain + 'static,
    MI: Metric + 'static,
    MO: CompositionMeasure + 'static,
    TO: 'static,
>(
    input_domain: DI,
    input_metric: MI,
    output_measure: MO,
    d_in: MI::Distance,
    mut d_mids: Vec<MO::Distance>,
) -> Fallible<Measurement<DI, MI, MO, Queryable<Measurement<DI, MI, MO, TO>, TO>>>
where
    DI::Carrier: 'static + Clone,
    MI::Distance: 'static + ProductOrd + Clone,
    MO::Distance: 'static + ProductOrd + Clone + Debug,
    (DI, MI): MetricSpace,
{
    if d_mids.is_empty() {
        return fallible!(MakeMeasurement, "d_mids must have at least one element");
    }

    // Store budgets back-to-front so that `last()`/`pop()` hand them out in the
    // order the caller listed them.
    d_mids.reverse();
    let d_out = output_measure.compose(d_mids.clone())?;

    // Only measures that compose sequentially under adaptivity need released
    // queryables to be locked once a newer query has been answered.
    let require_sequentiality = matches!(
        output_measure.composability(Adaptivity::Adaptive)?,
        Composability::Sequential
    );

    // Kept separately because `d_in` itself is moved into the function closure.
    let d_in_constructor = d_in.clone();

    Measurement::new(
        input_domain.clone(),
        input_metric.clone(),
        output_measure.clone(),
        Function::new_fallible(move |data: &DI::Carrier| {
            // Each invocation gets its own copy of the state, so separate
            // invocations do not share a budget.
            let input_domain = input_domain.clone();
            let input_metric = input_metric.clone();
            let output_measure = output_measure.clone();
            let d_in = d_in.clone();
            let data = data.clone();
            let mut d_mids = d_mids.clone();

            Queryable::new(move |self_, query: Query<Measurement<DI, MI, MO, TO>>| {
                // Internal message sent by a released child queryable.
                // The child asks whether it may still be used. The payload is the
                // index the child's budget had in `d_mids` when it was released.
                struct AskPermission(usize);

                if let Query::External(meas) = query {
                    assert_elements_match!(DomainMismatch, input_domain, meas.input_domain);
                    assert_elements_match!(MetricMismatch, input_metric, meas.input_metric);
                    assert_elements_match!(MeasureMismatch, output_measure, meas.output_measure);

                    let d_mid =
                        d_mids.last().ok_or_else(|| err!(FailedFunction, "out of queries"))?;

                    if !meas.check(&d_in, d_mid)? {
                        return fallible!(
                            FailedFunction,
                            "insufficient budget for query: {:?} > {:?}",
                            meas.map(&d_in)?,
                            d_mid
                        );
                    }

                    // Stays `false` while `meas` is being invoked, so the pre-hook does not
                    // consult the parent during that time. It flips to `true` once the answer
                    // has been produced.
                    let enforce_sequentiality = Rc::new(RefCell::new(false));
                    let seq_wrapper = require_sequentiality.then(|| {
                        // `d_mids` is non-empty here (`last()` succeeded above).
                        let child_id = d_mids.len() - 1;
                        let mut self_ = self_.clone();
                        Wrapper::new_recursive_pre_hook(enclose!(
                            enforce_sequentiality,
                            move || {
                                if *enforce_sequentiality.borrow() {
                                    self_.eval_internal(&AskPermission(child_id))
                                } else {
                                    Ok(())
                                }
                            }
                        ))
                    });

                    let answer = meas.invoke_wrap(&data, seq_wrapper)?;
                    *enforce_sequentiality.borrow_mut() = true;
                    // The budget is consumed only after a successful invocation.
                    d_mids.pop();
                    return Ok(Answer::External(answer));
                }

                if let Query::Internal(query) = query {
                    if let Some(AskPermission(id)) = query.downcast_ref() {
                        // After the child's release was popped, `d_mids.len()` equals its id.
                        // Any later external query pops again, so a mismatch means the
                        // child is no longer the most recent release.
                        if *id != d_mids.len() {
                            return fallible!(
                                FailedFunction,
                                "Adaptive compositor has received a new query. To satisfy the sequentiality constraint of adaptive composition, only the most recent release from the parent compositor may be interacted with."
                            );
                        }
                        return Ok(Answer::internal(()));
                    }
                }

                fallible!(FailedFunction, "unrecognized query: {:?}", query)
            })
        }),
        PrivacyMap::new_fallible(move |d_in_map: &MI::Distance| {
            // The per-query checks were made at `d_in_constructor`, so the composed
            // bound only holds for input distances up to that value.
            if d_in_map.total_gt(&d_in_constructor)? {
                fallible!(
                    RelationDebug,
                    "d_in from the privacy map must be no greater than the d_in passed into the constructor"
                )
            } else {
                Ok(d_out.clone())
            }
        }),
    )
}
#[bootstrap(
    features("contrib"),
    arguments(
        d_in(rust_type = "$get_distance_type(input_metric)", c_type = "AnyObject *"),
        d_mids(rust_type = "Vec<QO>", c_type = "AnyObject *"),
        output_measure(c_type = "AnyMeasure *", rust_type = b"null")
    ),
    generics(DI(suppress), TO(suppress), MI(suppress), MO(suppress)),
    derived_types(QO = "$get_distance_type(output_measure)")
)]
/// Deprecated alias of `make_adaptive_composition`.
///
/// Forwards every argument unchanged to `make_adaptive_composition` and returns its result.
/// The trait bounds mirror that function's bounds, so any call valid for the new name is
/// also valid here.
///
/// # Arguments
/// * `input_domain` - domain of valid input datasets; every query must share it
/// * `input_metric` - metric on the input domain; every query must share it
/// * `output_measure` - how privacy is measured; every query must share it
/// * `d_in` - maximum distance between adjacent input datasets
/// * `d_mids` - maximum privacy expenditure of each query, in the order queries will be made
///
/// # Generics
/// * `DI` - Input Domain.
/// * `MI` - Input Metric.
/// * `MO` - Output Measure.
/// * `TO` - Output Type.
#[deprecated(
    since = "0.14.0",
    note = "This function has been renamed: use :py:func:`~opendp.combinators.make_adaptive_composition` instead."
)]
pub fn make_sequential_composition<
    DI: Domain + 'static,
    MI: Metric + 'static,
    MO: CompositionMeasure + 'static,
    TO: 'static,
>(
    input_domain: DI,
    input_metric: MI,
    output_measure: MO,
    d_in: MI::Distance,
    d_mids: Vec<MO::Distance>,
) -> Fallible<Measurement<DI, MI, MO, Queryable<Measurement<DI, MI, MO, TO>, TO>>>
where
    DI::Carrier: 'static + Clone,
    // Previously also required `Send + Sync`. That was stricter than the function this
    // alias forwards to and is not needed by it.
    MI::Distance: 'static + ProductOrd + Clone,
    MO::Distance: 'static + ProductOrd + Clone + Debug,
    (DI, MI): MetricSpace,
{
    make_adaptive_composition(input_domain, input_metric, output_measure, d_in, d_mids)
}