use std::any;
use std::any::Any;
use crate::core::{Domain, Function, Measure, Measurement, Metric, MetricSpace};
use crate::error::*;
impl<TI: 'static, TO: 'static> Function<TI, TO> {
    /// Converts this function into one whose output is type-erased.
    ///
    /// The returned function evaluates `self` on its argument and, on success,
    /// boxes the output as a `Box<dyn Any>`. An error from the wrapped function
    /// is returned as-is.
    pub fn into_poly(self) -> Function<TI, Box<dyn Any>> {
        Function::new_fallible(move |arg: &TI| -> Fallible<Box<dyn Any>> {
            let output = self.eval(arg)?;
            Ok(Box::new(output) as Box<dyn Any>)
        })
    }
}
impl<TI> Function<TI, Box<dyn Any>> {
    /// Evaluates the type-erased function on `arg` and recovers a concrete `TO`.
    ///
    /// # Errors
    ///
    /// Propagates any error produced by evaluating the function. Returns a
    /// `FailedCast` error, naming the requested type, when the boxed output
    /// cannot be downcast to `TO`.
    pub fn eval_poly<TO: 'static>(&self, arg: &TI) -> Fallible<TO> {
        let erased = self.eval(arg)?;
        match erased.downcast::<TO>() {
            // Move the value out of its box to hand back an owned `TO`.
            Ok(concrete) => Ok(*concrete),
            Err(_) => Err(err!(
                FailedCast,
                "Failed downcast of eval_poly result to {}",
                any::type_name::<TO>()
            )),
        }
    }
}
impl<DI, MI, MO, TO> Measurement<DI, MI, MO, TO>
where
    DI: 'static + Domain,
    DI::Carrier: 'static,
    MI: 'static + Metric,
    MO: 'static + Measure,
    TO: 'static,
    (DI, MI): MetricSpace,
{
    /// Converts this measurement into one whose function output is type-erased.
    ///
    /// The input domain, input metric, output measure and privacy map are
    /// cloned from `self`; only the function is replaced, by its
    /// `Function::into_poly` counterpart returning `Box<dyn Any>`.
    ///
    /// # Panics
    ///
    /// Panics if `Measurement::new` returns an error when rebuilding from
    /// these components.
    pub fn into_poly(self) -> Measurement<DI, MI, MO, Box<dyn Any>> {
        let input_domain = self.input_domain.clone();
        let input_metric = self.input_metric.clone();
        let output_measure = self.output_measure.clone();
        let poly_function = self.function.clone().into_poly();
        let privacy_map = self.privacy_map.clone();

        let rebuilt = Measurement::new(
            input_domain,
            input_metric,
            output_measure,
            poly_function,
            privacy_map,
        );
        rebuilt.expect("compatibility check already passed")
    }
}
#[cfg(all(test, feature = "untrusted"))]
mod tests {
    use crate::domains::AtomDomain;
    use crate::error::*;
    use crate::measurements;
    use crate::measures::MaxDivergence;
    use crate::metrics::AbsoluteDistance;

    /// Checks that a measurement converted with `into_poly` yields the same
    /// output as the plain one, and that requesting the wrong output type
    /// fails with `FailedCast`.
    #[test]
    fn test_poly_measurement() -> Fallible<()> {
        let domain = AtomDomain::<f64>::new_non_nan();
        let metric = AbsoluteDistance::<u32>::default();
        // Scale 0.0: the assertions below expect the output to equal the input.
        let plain =
            measurements::make_laplace::<_, _, MaxDivergence>(domain, metric, 0.0, None)?;

        let input = 100.;
        assert_eq!(plain.invoke(&input)?, input);

        let poly = plain.into_poly();

        // Downcasting to the real output type succeeds.
        let recovered = poly.function.eval_poly::<f64>(&input)?;
        assert_eq!(recovered, input);

        // Downcasting to an unrelated type reports a failed cast.
        let mismatch = poly.function.eval_poly::<i32>(&input);
        let failure = mismatch.err().unwrap_test();
        assert_eq!(failure.variant, ErrorVariant::FailedCast);

        Ok(())
    }
}