use std::ops::Shr;
use crate::core::{Domain, Function, HintMt, HintTt, Measure, Measurement, Metric, PrivacyRelation, StabilityRelation, Transformation};
use crate::dom::PairDomain;
use crate::error::Fallible;
use std::fmt::Debug;
/// Discussion page linked from every intermediate-mismatch error message.
const ERROR_URL: &str = "https://github.com/opendp/opendp/discussions/297";

/// Build the error message used when two chained operators disagree on an
/// intermediate structure (for example a domain or a metric).
///
/// * `mode` - name of the mismatched structure, e.g. `"domain"` or `"metric"`;
///   it is pluralized in the message by appending `s`.
/// * `struct1` - the output structure of the first (inner) operator.
/// * `struct2` - the input structure of the second (outer) operator.
///
/// Both structures are rendered with `Debug`. If the two renderings are
/// identical, the difference must lie in something `Debug` does not show
/// (types or parameters). In that case the shared rendering is printed once
/// instead of as two identical lines.
fn mismatch_message<T1: Debug, T2: Debug>(mode: &str, struct1: &T1, struct2: &T2) -> String {
    let str1 = format!("{:?}", struct1);
    let str2 = format!("{:?}", struct2);
    let explanation = if str1 == str2 {
        format!(
            "\n The structure of the intermediate {mode}s are the same, but the types or parameters differ.\n shared_{mode}: {str1}\n",
            mode = mode,
            str1 = str1
        )
    } else {
        format!(
            "\n output_{mode}: {struct1}\n input_{mode}: {struct2}\n",
            mode = mode,
            struct1 = str1,
            struct2 = str2
        )
    };
    // Tail expression instead of a trailing `return` (clippy::needless_return).
    format!("Intermediate {}s don't match. See {}{}", mode, ERROR_URL, explanation)
}
/// Chain `transformation0` followed by `measurement1` into a single measurement.
///
/// The outer (later) operator is passed first, mirroring the composition
/// `measurement1 ∘ transformation0`: on invocation `transformation0` runs first
/// and its output is fed to `measurement1`.
/// The result takes its input domain and metric from `transformation0`, and
/// its output domain and measure from `measurement1`.
///
/// `hint` is forwarded unchanged to `PrivacyRelation::make_chain`. Its exact
/// role is defined there (NOTE(review): presumably it helps relate the
/// intermediate distance; confirm against `PrivacyRelation::make_chain`).
///
/// # Errors
/// * `DomainMismatch` if `transformation0.output_domain != measurement1.input_domain`.
/// * `MetricMismatch` if `transformation0.output_metric != measurement1.input_metric`.
pub fn make_chain_mt<DI, DX, DO, MI, MX, MO>(
    measurement1: &Measurement<DX, DO, MX, MO>,
    transformation0: &Transformation<DI, DX, MI, MX>,
    hint: Option<&HintMt<MI, MO, MX>>,
) -> Fallible<Measurement<DI, DO, MI, MO>>
where
    DI: 'static + Domain,
    DX: 'static + Domain,
    DO: 'static + Domain,
    MI: 'static + Metric,
    MX: 'static + Metric,
    MO: 'static + Measure,
{
    // The intermediate domain is checked before the intermediate metric; only
    // the first disagreement is reported.
    let (domain_out0, domain_in1) = (&transformation0.output_domain, &measurement1.input_domain);
    if domain_out0 != domain_in1 {
        return fallible!(DomainMismatch, mismatch_message("domain", domain_out0, domain_in1));
    }

    let (metric_out0, metric_in1) = (&transformation0.output_metric, &measurement1.input_metric);
    if metric_out0 != metric_in1 {
        return fallible!(MetricMismatch, mismatch_message("metric", metric_out0, metric_in1));
    }

    let chained_function = Function::make_chain(&measurement1.function, &transformation0.function);
    let chained_relation = PrivacyRelation::make_chain(
        &measurement1.privacy_relation,
        &transformation0.stability_relation,
        hint,
    );

    Ok(Measurement::new(
        transformation0.input_domain.clone(),
        measurement1.output_domain.clone(),
        chained_function,
        transformation0.input_metric.clone(),
        measurement1.output_measure.clone(),
        chained_relation,
    ))
}
/// Chain `transformation0` followed by `transformation1` into a single transformation.
///
/// The outer (later) operator is passed first, mirroring the composition
/// `transformation1 ∘ transformation0`: on invocation `transformation0` runs
/// first and its output is fed to `transformation1`.
/// The result takes its input domain and metric from `transformation0`, and
/// its output domain and metric from `transformation1`.
///
/// `hint` is forwarded unchanged to `StabilityRelation::make_chain`. Its exact
/// role is defined there (NOTE(review): presumably it helps relate the
/// intermediate distance; confirm against `StabilityRelation::make_chain`).
///
/// # Errors
/// * `DomainMismatch` if `transformation0.output_domain != transformation1.input_domain`.
/// * `MetricMismatch` if `transformation0.output_metric != transformation1.input_metric`.
pub fn make_chain_tt<DI, DX, DO, MI, MX, MO>(
    transformation1: &Transformation<DX, DO, MX, MO>,
    transformation0: &Transformation<DI, DX, MI, MX>,
    hint: Option<&HintTt<MI, MO, MX>>,
) -> Fallible<Transformation<DI, DO, MI, MO>>
where
    DI: 'static + Domain,
    DX: 'static + Domain,
    DO: 'static + Domain,
    MI: 'static + Metric,
    MX: 'static + Metric,
    MO: 'static + Metric,
{
    // The intermediate domain is checked before the intermediate metric; only
    // the first disagreement is reported.
    let (domain_out0, domain_in1) = (&transformation0.output_domain, &transformation1.input_domain);
    if domain_out0 != domain_in1 {
        return fallible!(DomainMismatch, mismatch_message("domain", domain_out0, domain_in1));
    }

    let (metric_out0, metric_in1) = (&transformation0.output_metric, &transformation1.input_metric);
    if metric_out0 != metric_in1 {
        return fallible!(MetricMismatch, mismatch_message("metric", metric_out0, metric_in1));
    }

    let chained_function = Function::make_chain(&transformation1.function, &transformation0.function);
    let chained_relation = StabilityRelation::make_chain(
        &transformation1.stability_relation,
        &transformation0.stability_relation,
        hint,
    );

    Ok(Transformation::new(
        transformation0.input_domain.clone(),
        transformation1.output_domain.clone(),
        chained_function,
        transformation0.input_metric.clone(),
        transformation1.output_metric.clone(),
        chained_relation,
    ))
}
/// Compose two measurements over the same input into one measurement whose
/// output is the pair of both outputs (a `PairDomain<DO0, DO1>`).
///
/// Both measurements must agree on input domain, input metric and output
/// measure. The composed function is built by `Function::make_basic_composition`.
///
/// NOTE(review): the privacy relation of the result is a placeholder that
/// returns `false` for every `(d_in, d_out)` pair. Presumably this means no
/// privacy guarantee can ever be verified for the composed measurement.
/// Confirm, and replace it with a real composition rule.
///
/// # Errors
/// * `DomainMismatch` if the input domains differ.
/// * `MetricMismatch` if the input metrics differ.
/// * `MeasureMismatch` if the output measures differ.
pub fn make_basic_composition<DI, DO0, DO1, MI, MO>(measurement0: &Measurement<DI, DO0, MI, MO>, measurement1: &Measurement<DI, DO1, MI, MO>) -> Fallible<Measurement<DI, PairDomain<DO0, DO1>, MI, MO>>
where
    DI: 'static + Domain,
    DO0: 'static + Domain,
    DO1: 'static + Domain,
    MI: 'static + Metric,
    MO: 'static + Measure,
{
    // Guard clauses, checked in order; the first mismatch wins.
    if measurement0.input_domain != measurement1.input_domain {
        return fallible!(DomainMismatch, "Input domain mismatch");
    }
    if measurement0.input_metric != measurement1.input_metric {
        return fallible!(MetricMismatch, "Input metric mismatch");
    }
    if measurement0.output_measure != measurement1.output_measure {
        return fallible!(MeasureMismatch, "Output measure mismatch");
    }

    let paired_output_domain = PairDomain::new(
        measurement0.output_domain.clone(),
        measurement1.output_domain.clone(),
    );
    let composed_function = Function::make_basic_composition(&measurement0.function, &measurement1.function);

    Ok(Measurement::new(
        measurement0.input_domain.clone(),
        paired_output_domain,
        composed_function,
        measurement0.input_metric.clone(),
        measurement0.output_measure.clone(),
        // Placeholder relation: never holds (see NOTE(review) above).
        PrivacyRelation::new(|_i, _o| false),
    ))
}
#[cfg(test)]
mod tests {
    use crate::core::*;
    use crate::dist::{L1Distance, MaxDivergence};
    use crate::dom::AllDomain;
    use crate::error::ExplainUnwrap;
    use super::*;

    /// First stage shared by both chaining tests. It casts a `u8` to `i32`
    /// after adding one, and uses L1 distances with a stability constant of 1.
    fn make_increment_u8_to_i32() -> Transformation<AllDomain<u8>, AllDomain<i32>, L1Distance<i32>, L1Distance<i32>> {
        Transformation::new(
            AllDomain::<u8>::new(),
            AllDomain::<i32>::new(),
            Function::new(|a: &u8| (a + 1) as i32),
            L1Distance::<i32>::default(),
            L1Distance::<i32>::default(),
            StabilityRelation::new_from_constant(1),
        )
    }

    #[test]
    fn test_make_chain_mt() {
        let transformation0 = make_increment_u8_to_i32();
        let measurement1 = Measurement::new(
            AllDomain::<i32>::new(),
            AllDomain::<f64>::new(),
            Function::new(|a: &i32| (a + 1) as f64),
            L1Distance::<i32>::default(),
            MaxDivergence::default(),
            PrivacyRelation::new(|_d_in: &i32, _d_out: &f64| true),
        );
        let chain = make_chain_mt(&measurement1, &transformation0, None).unwrap_test();
        // 99 -> 100 (transformation0) -> 101.0 (measurement1)
        assert_eq!(chain.invoke(&99_u8).unwrap_test(), 101.0);
    }

    #[test]
    fn test_make_chain_tt() {
        let transformation0 = make_increment_u8_to_i32();
        let transformation1 = Transformation::new(
            AllDomain::<i32>::new(),
            AllDomain::<f64>::new(),
            Function::new(|a: &i32| (a + 1) as f64),
            L1Distance::<i32>::default(),
            L1Distance::<i32>::default(),
            StabilityRelation::new_from_constant(1),
        );
        let chain = make_chain_tt(&transformation1, &transformation0, None).unwrap_test();
        // 99 -> 100 (transformation0) -> 101.0 (transformation1)
        assert_eq!(chain.invoke(&99_u8).unwrap_test(), 101.0);
    }

    #[test]
    fn test_make_basic_composition() {
        let measurement0 = Measurement::new(
            AllDomain::<i32>::new(),
            AllDomain::<f32>::new(),
            Function::new(|arg: &i32| (arg + 1) as f32),
            L1Distance::<i32>::default(),
            MaxDivergence::default(),
            PrivacyRelation::new(|_d_in: &i32, _d_out: &f64| true),
        );
        let measurement1 = Measurement::new(
            AllDomain::<i32>::new(),
            AllDomain::<f64>::new(),
            Function::new(|arg: &i32| (arg - 1) as f64),
            L1Distance::<i32>::default(),
            MaxDivergence::default(),
            PrivacyRelation::new(|_d_in: &i32, _d_out: &f64| true),
        );
        let composition = make_basic_composition(&measurement0, &measurement1).unwrap_test();
        // Both measurements see the same input; outputs are paired in order.
        assert_eq!(composition.invoke(&99).unwrap_test(), (100_f32, 98_f64));
    }
}
/// Enables `transformation >> measurement`. The two are chained with
/// [`make_chain_mt`] and no hint; the transformation runs first.
impl<DI, DX, DO, MI, MX, MO> Shr<Measurement<DX, DO, MX, MO>> for Transformation<DI, DX, MI, MX>
where
    DI: 'static + Domain,
    DX: 'static + Domain,
    DO: 'static + Domain,
    MI: 'static + Metric,
    MX: 'static + Metric,
    MO: 'static + Measure,
{
    type Output = Fallible<Measurement<DI, DO, MI, MO>>;

    fn shr(self, rhs: Measurement<DX, DO, MX, MO>) -> Self::Output {
        // `make_chain_mt` takes the outer (later) operator first.
        let (inner, outer) = (&self, &rhs);
        make_chain_mt(outer, inner, None)
    }
}
/// Enables `fallible_transformation >> measurement`, so a chain that was
/// already built with `>>` can be extended. An `Err` on the left is returned
/// as-is. Otherwise the transformation is chained with [`make_chain_mt`] and
/// no hint.
impl<DI, DX, DO, MI, MX, MO> Shr<Measurement<DX, DO, MX, MO>> for Fallible<Transformation<DI, DX, MI, MX>>
where
    DI: 'static + Domain,
    DX: 'static + Domain,
    DO: 'static + Domain,
    MI: 'static + Metric,
    MX: 'static + Metric,
    MO: 'static + Measure,
{
    type Output = Fallible<Measurement<DI, DO, MI, MO>>;

    fn shr(self, rhs: Measurement<DX, DO, MX, MO>) -> Self::Output {
        self.and_then(|transformation| make_chain_mt(&rhs, &transformation, None))
    }
}
/// Enables `transformation0 >> transformation1`. The two are chained with
/// [`make_chain_tt`] and no hint; the left-hand side runs first.
impl<DI, DX, DO, MI, MX, MO> Shr<Transformation<DX, DO, MX, MO>> for Transformation<DI, DX, MI, MX>
where
    DI: 'static + Domain,
    DX: 'static + Domain,
    DO: 'static + Domain,
    MI: 'static + Metric,
    MX: 'static + Metric,
    MO: 'static + Metric,
{
    type Output = Fallible<Transformation<DI, DO, MI, MO>>;

    fn shr(self, rhs: Transformation<DX, DO, MX, MO>) -> Self::Output {
        // `make_chain_tt` takes the outer (later) operator first.
        let (inner, outer) = (&self, &rhs);
        make_chain_tt(outer, inner, None)
    }
}
/// Enables `fallible_transformation >> transformation`, so a chain that was
/// already built with `>>` can be extended. An `Err` on the left is returned
/// as-is. Otherwise the transformations are chained with [`make_chain_tt`]
/// and no hint.
impl<DI, DX, DO, MI, MX, MO> Shr<Transformation<DX, DO, MX, MO>> for Fallible<Transformation<DI, DX, MI, MX>>
where
    DI: 'static + Domain,
    DX: 'static + Domain,
    DO: 'static + Domain,
    MI: 'static + Metric,
    MX: 'static + Metric,
    MO: 'static + Metric,
{
    type Output = Fallible<Transformation<DI, DO, MI, MO>>;

    fn shr(self, rhs: Transformation<DX, DO, MX, MO>) -> Self::Output {
        self.and_then(|transformation| make_chain_tt(&rhs, &transformation, None))
    }
}
#[cfg(test)]
mod tests_shr {
    use crate::meas::geometric::make_base_geometric;
    use crate::trans::{make_bounded_sum, make_cast_default, make_clamp, make_split_lines};
    use super::*;

    /// Builds a transformation pipeline ending in a measurement using `>>`.
    /// The test succeeds only if every link's intermediate domain and metric
    /// line up.
    #[test]
    fn test_shr() -> Fallible<()> {
        let pipeline = make_split_lines()?
            >> make_cast_default()?
            >> make_clamp((0, 1))?
            >> make_bounded_sum((0, 1))?
            >> make_base_geometric(1., Some((0, 10)))?;
        pipeline.map(|_| ())
    }
}