#[cfg(feature = "ffi")]
mod ffi;
mod shr;
use crate::core::{
Domain, Function, Measure, Measurement, Metric, MetricSpace, PrivacyMap, StabilityMap,
Transformation,
};
use crate::error::{Error, ErrorVariant, Fallible};
use std::fmt::Debug;
/// Discussion link appended to every message built by `mismatch_error` when chained
/// components disagree on an intermediate domain, metric or measure.
const ERROR_URL: &str = "https://github.com/opendp/opendp/discussions/297";
/// Return early from the enclosing function with a mismatch error when two values differ.
///
/// `$variant` names an `ErrorVariant` (e.g. `DomainMismatch`). `$v1` and `$v2` are compared
/// with `!=`. On mismatch, both are passed by reference to `mismatch_error`, so the message can
/// show them. The enclosing function must return a `Result` whose error type matches what
/// `mismatch_error` returns.
///
/// Each operand expression is evaluated exactly once, mirroring `std::assert_eq!`. Expressions
/// with side effects or a non-trivial cost are therefore safe to pass.
macro_rules! assert_elements_match {
    ($variant:ident, $v1:expr, $v2:expr) => {{
        // Borrow each operand once up front. The borrowed temporaries are lifetime-extended
        // by the `let`, so they stay alive through both the comparison and error construction.
        let (v1, v2) = (&$v1, &$v2);
        if v1 != v2 {
            return Err($crate::combinators::mismatch_error(
                $crate::error::ErrorVariant::$variant,
                v1,
                v2,
            ));
        }
    }};
}
pub(crate) use assert_elements_match;
/// Build the `Error` reported when two chained components disagree on an intermediate
/// domain, metric or measure.
///
/// `struct1` is rendered as the "output" side and `struct2` as the "input" side. If their
/// `Debug` renderings are identical, the message instead notes that only hidden parameters
/// differ and prints the shared rendering once.
///
/// # Panics
/// Panics (via `unimplemented!`) if `variant` is not `DomainMismatch`, `MetricMismatch` or
/// `MeasureMismatch`.
pub(crate) fn mismatch_error<T: Debug>(variant: ErrorVariant, struct1: &T, struct2: &T) -> Error {
    let (left, right) = (format!("{:?}", struct1), format!("{:?}", struct2));

    // Human-readable noun for the kind of component that failed to line up.
    let mode = match variant {
        ErrorVariant::DomainMismatch => "domain",
        ErrorVariant::MetricMismatch => "metric",
        ErrorVariant::MeasureMismatch => "measure",
        _ => unimplemented!("unrecognized error variant"),
    };

    // Printing both sides is only informative when their Debug renderings actually differ.
    let explanation = if left != right {
        format!("\n output_{mode}: {left}\n input_{mode}: {right}\n")
    } else {
        format!("\n The structure of the intermediate {mode}s are the same, but the parameters differ.\n shared_{mode}: {left}\n")
    };

    let message = format!(
        "Intermediate {mode}s don't match. See {url}{explanation}",
        url = ERROR_URL
    );

    Error {
        variant,
        message: Some(message),
        backtrace: err!(@backtrace),
    }
}
/// Chain a transformation into a measurement, producing a single measurement.
///
/// `transformation0` maps `DI` to `DX`, and `measurement1` consumes `DX` and releases `TO`.
/// The result therefore runs `transformation0` first. It has these parts:
/// - the input domain and metric of `transformation0`;
/// - the output measure of `measurement1`;
/// - a privacy map chained from `measurement1`'s privacy map and `transformation0`'s
///   stability map.
///
/// # Errors
/// Fails in any of these cases:
/// - with `DomainMismatch` if `transformation0.output_domain != measurement1.input_domain`;
/// - with `MetricMismatch` if `transformation0.output_metric != measurement1.input_metric`;
/// - with any error returned by `Measurement::new`.
pub fn make_chain_mt<DI, DX, TO, MI, MX, MO>(
    measurement1: &Measurement<DX, MX, MO, TO>,
    transformation0: &Transformation<DI, MI, DX, MX>,
) -> Fallible<Measurement<DI, MI, MO, TO>>
where
    DI: 'static + Domain,
    DX: 'static + Domain,
    TO: 'static,
    MI: 'static + Metric,
    MX: 'static + Metric,
    MO: 'static + Measure,
    (DI, MI): MetricSpace,
    (DX, MX): MetricSpace,
{
    // Both halves must agree on the intermediate space before they can be composed.
    assert_elements_match!(
        DomainMismatch,
        transformation0.output_domain,
        measurement1.input_domain
    );
    assert_elements_match!(
        MetricMismatch,
        transformation0.output_metric,
        measurement1.input_metric
    );

    let chained_function = Function::make_chain(&measurement1.function, &transformation0.function);
    let chained_privacy_map =
        PrivacyMap::make_chain(&measurement1.privacy_map, &transformation0.stability_map);

    Measurement::new(
        transformation0.input_domain.clone(),
        transformation0.input_metric.clone(),
        measurement1.output_measure.clone(),
        chained_function,
        chained_privacy_map,
    )
}
/// Chain two transformations into one, with `transformation0` applied first.
///
/// `transformation0` maps `DI` to `DX` and `transformation1` maps `DX` to `DO`, so the result
/// maps `DI` to `DO`. It has these parts:
/// - the input domain and metric of `transformation0`;
/// - the output domain and metric of `transformation1`;
/// - a stability map chained from both stability maps.
///
/// # Errors
/// Fails in any of these cases:
/// - with `DomainMismatch` if `transformation0.output_domain != transformation1.input_domain`;
/// - with `MetricMismatch` if `transformation0.output_metric != transformation1.input_metric`;
/// - with any error returned by `Transformation::new`.
pub fn make_chain_tt<DI, DX, DO, MI, MX, MO>(
    transformation1: &Transformation<DX, MX, DO, MO>,
    transformation0: &Transformation<DI, MI, DX, MX>,
) -> Fallible<Transformation<DI, MI, DO, MO>>
where
    DI: 'static + Domain,
    DX: 'static + Domain,
    DO: 'static + Domain,
    MI: 'static + Metric,
    MX: 'static + Metric,
    MO: 'static + Metric,
    (DI, MI): MetricSpace,
    (DX, MX): MetricSpace,
    (DO, MO): MetricSpace,
{
    // `inner` runs first; `outer` consumes whatever `inner` produces.
    let (inner, outer) = (transformation0, transformation1);

    // The intermediate space produced by `inner` must be exactly the one `outer` expects.
    assert_elements_match!(DomainMismatch, inner.output_domain, outer.input_domain);
    assert_elements_match!(MetricMismatch, inner.output_metric, outer.input_metric);

    let function = Function::make_chain(&outer.function, &inner.function);
    let stability_map = StabilityMap::make_chain(&outer.stability_map, &inner.stability_map);

    Transformation::new(
        inner.input_domain.clone(),
        inner.input_metric.clone(),
        outer.output_domain.clone(),
        outer.output_metric.clone(),
        function,
        stability_map,
    )
}
/// Postprocess the release of `measurement0` with `postprocess1`.
///
/// `measurement0` releases a `TX`, and `postprocess1` maps `TX` to `TO`. The result keeps
/// `measurement0`'s input domain, input metric, output measure and privacy map. Only its
/// function changes, chained so that `postprocess1` consumes `measurement0`'s output. No
/// domain or metric compatibility checks are needed, because the postprocessor is typed
/// directly on `TX`.
///
/// # Errors
/// Returns any error produced by `Measurement::new`.
pub fn make_chain_pm<DI, TX, TO, MI, MO>(
    postprocess1: &Function<TX, TO>,
    measurement0: &Measurement<DI, MI, MO, TX>,
) -> Fallible<Measurement<DI, MI, MO, TO>>
where
    DI: 'static + Domain,
    TX: 'static,
    TO: 'static,
    MI: 'static + Metric,
    MO: 'static + Measure,
    (DI, MI): MetricSpace,
{
    let postprocessed = Function::make_chain(postprocess1, &measurement0.function);

    Measurement::new(
        measurement0.input_domain.clone(),
        measurement0.input_metric.clone(),
        measurement0.output_measure.clone(),
        postprocessed,
        // Reused unchanged: this relies on postprocessing not weakening the privacy guarantee.
        measurement0.privacy_map.clone(),
    )
}
#[cfg(test)]
mod test;