Skip to main content

opendp/transformations/make_stable_expr/
mod.rs

1use opendp_derive::bootstrap;
2use polars_plan::dsl::{AggExpr, Expr, FunctionExpr};
3
4use crate::{
5    core::{Metric, MetricSpace, Transformation},
6    domains::{ExprDomain, OuterMetric, WildExprDomain},
7    error::Fallible,
8    metrics::{L0InfDistance, L01InfDistance, LInfDistance, LpDistance},
9    polars::get_disabled_features_message,
10};
11
12use super::traits::UnboundedMetric;
13
14#[cfg(feature = "ffi")]
15mod ffi;
16
17#[cfg(feature = "contrib")]
18mod expr_alias;
19
20#[cfg(feature = "contrib")]
21mod expr_binary;
22
23#[cfg(feature = "contrib")]
24mod expr_boolean_function;
25
26#[cfg(feature = "contrib")]
27mod expr_cast;
28
29#[cfg(feature = "contrib")]
30mod expr_clip;
31
32#[cfg(feature = "contrib")]
33mod expr_col;
34
35#[cfg(feature = "contrib")]
36mod expr_count;
37
38#[cfg(feature = "contrib")]
39mod expr_cut;
40
41#[cfg(feature = "contrib")]
42pub(crate) mod expr_discrete_quantile_score;
43
44#[cfg(feature = "contrib")]
45pub(crate) mod expr_drop_nan_or_null;
46
47#[cfg(feature = "contrib")]
48mod expr_fill_nan;
49
50#[cfg(feature = "contrib")]
51mod expr_fill_null;
52
53#[cfg(feature = "contrib")]
54mod expr_filter;
55
56#[cfg(feature = "contrib")]
57mod expr_len;
58
59#[cfg(feature = "contrib")]
60mod expr_lit;
61
62#[cfg(feature = "contrib")]
63pub(crate) mod expr_replace;
64
65#[cfg(feature = "contrib")]
66mod expr_replace_strict;
67
68#[cfg(feature = "contrib")]
69mod expr_sum;
70
71#[cfg(feature = "contrib")]
72mod expr_to_physical;
73
74#[cfg(feature = "contrib")]
75mod namespace_arr;
76
77#[cfg(feature = "contrib")]
78mod namespace_dt;
79
80#[cfg(feature = "contrib")]
81mod namespace_str;
82
83#[bootstrap(
84    features("contrib"),
85    arguments(output_metric(c_type = "AnyMetric *", rust_type = b"null")),
86    generics(MI(suppress), MO(suppress))
87)]
88/// Create a stable transformation from an [`Expr`].
89///
90/// # Arguments
91/// * `input_domain` - The domain of the input data.
92/// * `input_metric` - How to measure distances between neighboring input data sets.
93/// * `expr` - The expression to be analyzed for stability.
94pub fn make_stable_expr<MI: 'static + Metric, MO: 'static + Metric>(
95    input_domain: WildExprDomain,
96    input_metric: MI,
97    expr: Expr,
98) -> Fallible<Transformation<WildExprDomain, MI, ExprDomain, MO>>
99where
100    Expr: StableExpr<MI, MO>,
101    (WildExprDomain, MI): MetricSpace,
102    (ExprDomain, MO): MetricSpace,
103{
104    expr.make_stable(input_domain, input_metric)
105}
106
107pub trait StableExpr<MI: Metric, MO: Metric> {
108    fn make_stable(
109        self,
110        input_domain: WildExprDomain,
111        input_metric: MI,
112    ) -> Fallible<Transformation<WildExprDomain, MI, ExprDomain, MO>>;
113}
114
115impl<M: OuterMetric> StableExpr<M, M> for Expr
116where
117    M::InnerMetric: UnboundedMetric,
118    M::Distance: Clone,
119    (WildExprDomain, M): MetricSpace,
120    (ExprDomain, M): MetricSpace,
121{
122    fn make_stable(
123        self,
124        input_domain: WildExprDomain,
125        input_metric: M,
126    ) -> Fallible<Transformation<WildExprDomain, M, ExprDomain, M>> {
127        if expr_fill_nan::match_fill_nan(&self).is_some() {
128            return expr_fill_nan::make_expr_fill_nan(input_domain, input_metric, self);
129        }
130
131        use Expr::*;
132        use FunctionExpr::*;
133        match self {
134            #[cfg(feature = "contrib")]
135            Alias(_, _) => expr_alias::make_expr_alias(input_domain, input_metric, self),
136
137            #[cfg(feature = "contrib")]
138            Expr::BinaryExpr { .. } => {
139                expr_binary::make_expr_binary(input_domain, input_metric, self)
140            }
141
142            #[cfg(feature = "contrib")]
143            Function {
144                function: Boolean(_),
145                ..
146            } => {
147                return expr_boolean_function::make_expr_boolean_function(
148                    input_domain,
149                    input_metric,
150                    self,
151                );
152            }
153
154            #[cfg(feature = "contrib")]
155            Cast { .. } => expr_cast::make_expr_cast(input_domain, input_metric, self),
156
157            #[cfg(feature = "contrib")]
158            Function {
159                function: Clip { .. },
160                ..
161            } => expr_clip::make_expr_clip(input_domain, input_metric, self),
162
163            #[cfg(feature = "contrib")]
164            Function {
165                function: DropNans | DropNulls,
166                ..
167            } => {
168                expr_drop_nan_or_null::make_expr_drop_nan_or_null(input_domain, input_metric, self)
169            }
170
171            #[cfg(feature = "contrib")]
172            Function {
173                function: FillNull { .. },
174                ..
175            } => expr_fill_null::make_expr_fill_null(input_domain, input_metric, self),
176
177            #[cfg(feature = "contrib")]
178            Filter { .. } => expr_filter::make_expr_filter(input_domain, input_metric, self),
179
180            #[cfg(feature = "contrib")]
181            Column(_) => expr_col::make_expr_col(input_domain, input_metric, self),
182
183            #[cfg(feature = "contrib")]
184            Function {
185                function: Cut { .. },
186                ..
187            } => expr_cut::make_expr_cut(input_domain, input_metric, self),
188
189            #[cfg(feature = "contrib")]
190            Literal(_) => expr_lit::make_expr_lit(input_domain, input_metric, self),
191
192            #[cfg(feature = "contrib")]
193            Function {
194                function: ToPhysical,
195                ..
196            } => expr_to_physical::make_expr_to_physical(input_domain, input_metric, self),
197
198            #[cfg(feature = "contrib")]
199            Function {
200                function: Replace, ..
201            } => expr_replace::make_expr_replace(input_domain, input_metric, self),
202
203            #[cfg(feature = "contrib")]
204            Function {
205                function: ReplaceStrict { .. },
206                ..
207            } => expr_replace_strict::make_expr_replace_strict(input_domain, input_metric, self),
208
209            #[cfg(feature = "contrib")]
210            Function {
211                function: FunctionExpr::ArrayExpr(_),
212                ..
213            } => namespace_arr::make_namespace_arr(input_domain, input_metric, self),
214
215            #[cfg(feature = "contrib")]
216            Function {
217                function: FunctionExpr::TemporalExpr(_),
218                ..
219            } => namespace_dt::make_namespace_dt(input_domain, input_metric, self),
220
221            #[cfg(feature = "contrib")]
222            Function {
223                function: FunctionExpr::StringExpr(_),
224                ..
225            } => namespace_str::make_namespace_str(input_domain, input_metric, self),
226
227            expr => fallible!(
228                MakeTransformation,
229                "Expr is not recognized at this time: {:?}. {}If you would like to see this supported, please file an issue.",
230                expr,
231                get_disabled_features_message()
232            ),
233        }
234    }
235}
236
237impl<MI, const P: usize> StableExpr<L01InfDistance<MI>, LpDistance<P, f64>> for Expr
238where
239    MI: 'static + UnboundedMetric,
240{
241    fn make_stable(
242        self,
243        input_domain: WildExprDomain,
244        input_metric: L01InfDistance<MI>,
245    ) -> Fallible<Transformation<WildExprDomain, L01InfDistance<MI>, ExprDomain, LpDistance<P, f64>>>
246    {
247        use Expr::*;
248        match self {
249            #[cfg(feature = "contrib")]
250            Agg(AggExpr::Count(_, _) | AggExpr::NUnique(_))
251            | Function {
252                function: FunctionExpr::NullCount,
253                ..
254            } => expr_count::make_expr_count(input_domain, input_metric, self),
255
256            #[cfg(feature = "contrib")]
257            Agg(AggExpr::Sum(_)) => expr_sum::make_expr_sum(input_domain, input_metric, self),
258
259            #[cfg(feature = "contrib")]
260            Len => expr_len::make_expr_len(input_domain, input_metric, self),
261
262            expr => fallible!(
263                MakeTransformation,
264                "Expr is not recognized at this time: {:?}. {}If you would like to see this supported, please file an issue.",
265                expr,
266                get_disabled_features_message()
267            ),
268        }
269    }
270}
271
272impl<MI> StableExpr<L01InfDistance<MI>, L0InfDistance<LInfDistance<f64>>> for Expr
273where
274    MI: 'static + UnboundedMetric,
275{
276    fn make_stable(
277        self,
278        input_domain: WildExprDomain,
279        input_metric: L01InfDistance<MI>,
280    ) -> Fallible<
281        Transformation<
282            WildExprDomain,
283            L01InfDistance<MI>,
284            ExprDomain,
285            L0InfDistance<LInfDistance<f64>>,
286        >,
287    > {
288        if expr_discrete_quantile_score::match_discrete_quantile_score(&self)?.is_some() {
289            return expr_discrete_quantile_score::make_expr_discrete_quantile_score(
290                input_domain,
291                input_metric,
292                self,
293            );
294        }
295        match self {
296            expr => fallible!(
297                MakeTransformation,
298                "Expr is not recognized at this time: {:?}. {}If you would like to see this supported, please file an issue.",
299                expr,
300                get_disabled_features_message()
301            ),
302        }
303    }
304}
305
306#[cfg(test)]
307pub mod test_helper;