1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
use std::any;
use std::any::Any;
use crate::core::{Domain, Function, Measure, Measurement, Metric, Transformation};
use crate::error::*;
/// A domain whose carrier is a type-erased `Box<dyn Any>`.
///
/// This is the output domain produced by the `into_poly` conversions in this
/// module, which box the concrete output of a function, measurement or
/// transformation. `Default` is derived so `PolyDomain::default()` is
/// available alongside the no-argument `new()` constructor.
#[derive(PartialEq, Clone, Default, Debug)]
pub struct PolyDomain {}

impl PolyDomain {
    /// Constructs a new `PolyDomain`; equivalent to `PolyDomain::default()`.
    pub fn new() -> Self {
        PolyDomain {}
    }
}
impl Domain for PolyDomain {
    /// Values are stored type-erased, so any `'static` output can be boxed here.
    type Carrier = Box<dyn Any>;

    /// Accepts every value unconditionally; no membership check is performed.
    fn member(&self, _value: &Self::Carrier) -> Fallible<bool> {
        Ok(true)
    }
}
impl<DI, DO> Function<DI, DO>
where
    DI: 'static + Domain,
    DI::Carrier: 'static,
    DO: 'static + Domain,
    DO::Carrier: 'static,
{
    /// Consumes this function and returns one whose output is boxed as
    /// `Box<dyn Any>`, i.e. a function into [`PolyDomain`].
    ///
    /// Any error returned by the wrapped function's `eval` is propagated unchanged.
    pub fn into_poly(self) -> Function<DI, PolyDomain> {
        Function::new_fallible(
            move |arg: &DI::Carrier| -> Fallible<<PolyDomain as Domain>::Carrier> {
                let output = self.eval(arg)?;
                // Erase the concrete output type; `DO::Carrier: 'static` makes this valid.
                let erased: Box<dyn Any> = Box::new(output);
                Ok(erased)
            },
        )
    }
}
impl<DI: Domain> Function<DI, PolyDomain> {
    /// Evaluates the function and downcasts the boxed output to `T`.
    ///
    /// # Errors
    /// Propagates any error from `eval`, and returns a `FailedCast` error when
    /// the boxed output is not of type `T`.
    pub fn eval_poly<T: 'static>(&self, arg: &DI::Carrier) -> Fallible<T> {
        let erased = self.eval(arg)?;
        match erased.downcast::<T>() {
            Ok(concrete) => Ok(*concrete),
            Err(_) => Err(err!(
                FailedCast,
                "Failed downcast of eval_poly result to {}",
                any::type_name::<T>()
            )),
        }
    }
}
impl<DI, DO, MI, MO> Measurement<DI, DO, MI, MO>
where
    DI: 'static + Domain,
    DI::Carrier: 'static,
    DO: 'static + Domain,
    DO::Carrier: 'static,
    MI: 'static + Metric,
    MO: 'static + Measure,
{
    /// Consumes this measurement and returns an equivalent one whose output
    /// domain is [`PolyDomain`], with the function's output boxed as `Box<dyn Any>`.
    ///
    /// The input domain, input metric, output measure and privacy relation are
    /// carried over unchanged.
    pub fn into_poly(self) -> Measurement<DI, PolyDomain, MI, MO> {
        let poly_function = self.function.into_poly();
        Measurement::new(
            self.input_domain,
            PolyDomain::new(),
            poly_function,
            self.input_metric,
            self.output_measure,
            self.privacy_relation,
        )
    }
}
impl<DI, DO, MI, MO> Transformation<DI, DO, MI, MO>
where
    DI: 'static + Domain,
    DI::Carrier: 'static,
    DO: 'static + Domain,
    DO::Carrier: 'static,
    MI: 'static + Metric,
    MO: 'static + Metric,
{
    /// Consumes this transformation and returns an equivalent one whose output
    /// domain is [`PolyDomain`], with the function's output boxed as `Box<dyn Any>`.
    ///
    /// The input domain, input metric, output metric and stability relation are
    /// carried over unchanged.
    pub fn into_poly(self) -> Transformation<DI, PolyDomain, MI, MO> {
        let poly_function = self.function.into_poly();
        Transformation::new(
            self.input_domain,
            PolyDomain::new(),
            poly_function,
            self.input_metric,
            self.output_metric,
            self.stability_relation,
        )
    }
}
#[cfg(test)]
mod tests {
    use crate::dist::SubstituteDistance;
    use crate::dom::AllDomain;
    use crate::error::*;
    use crate::meas;
    use crate::trans;

    /// A zero-scale Laplace measurement keeps returning its input after
    /// `into_poly`, and downcasting to the wrong type reports `FailedCast`.
    #[test]
    fn test_poly_measurement() -> Fallible<()> {
        let input = 99.9;
        let plain = meas::make_base_laplace::<AllDomain<_>>(0.0)?;
        assert_eq!(plain.function.eval(&input)?, input);

        let poly = plain.into_poly();
        assert_eq!(poly.function.eval_poly::<f64>(&input)?, input);

        let wrong_type = poly.function.eval_poly::<i32>(&input);
        assert_eq!(wrong_type.err().unwrap_test().variant, ErrorVariant::FailedCast);
        Ok(())
    }

    /// The identity transformation keeps returning its input after
    /// `into_poly`, and downcasting to the wrong type reports `FailedCast`.
    #[test]
    fn test_poly_transformation() -> Fallible<()> {
        let input = 99.9;
        let plain = trans::make_identity(AllDomain::new(), SubstituteDistance::default())?;
        assert_eq!(plain.function.eval(&input)?, input);

        let poly = plain.into_poly();
        assert_eq!(poly.function.eval_poly::<f64>(&input)?, input);

        let wrong_type = poly.function.eval_poly::<i32>(&input);
        assert_eq!(wrong_type.err().unwrap_test().variant, ErrorVariant::FailedCast);
        Ok(())
    }
}