1use opendp_derive::bootstrap;
2use polars_plan::dsl::{AggExpr, Expr, FunctionExpr};
3
4use crate::{
5 core::{Metric, MetricSpace, Transformation},
6 domains::{ExprDomain, OuterMetric, WildExprDomain},
7 error::Fallible,
8 metrics::{L0InfDistance, L01InfDistance, LInfDistance, LpDistance},
9 polars::get_disabled_features_message,
10};
11
12use super::traits::UnboundedMetric;
13
14#[cfg(feature = "ffi")]
15mod ffi;
16
17#[cfg(feature = "contrib")]
18mod expr_alias;
19
20#[cfg(feature = "contrib")]
21mod expr_binary;
22
23#[cfg(feature = "contrib")]
24mod expr_boolean_function;
25
26#[cfg(feature = "contrib")]
27mod expr_cast;
28
29#[cfg(feature = "contrib")]
30mod expr_clip;
31
32#[cfg(feature = "contrib")]
33mod expr_col;
34
35#[cfg(feature = "contrib")]
36mod expr_count;
37
38#[cfg(feature = "contrib")]
39mod expr_cut;
40
41#[cfg(feature = "contrib")]
42pub(crate) mod expr_discrete_quantile_score;
43
44#[cfg(feature = "contrib")]
45pub(crate) mod expr_drop_nan_or_null;
46
47#[cfg(feature = "contrib")]
48mod expr_fill_nan;
49
50#[cfg(feature = "contrib")]
51mod expr_fill_null;
52
53#[cfg(feature = "contrib")]
54mod expr_filter;
55
56#[cfg(feature = "contrib")]
57mod expr_len;
58
59#[cfg(feature = "contrib")]
60mod expr_lit;
61
62#[cfg(feature = "contrib")]
63pub(crate) mod expr_replace;
64
65#[cfg(feature = "contrib")]
66mod expr_replace_strict;
67
68#[cfg(feature = "contrib")]
69mod expr_sum;
70
71#[cfg(feature = "contrib")]
72mod expr_to_physical;
73
74#[cfg(feature = "contrib")]
75mod namespace_arr;
76
77#[cfg(feature = "contrib")]
78mod namespace_dt;
79
80#[cfg(feature = "contrib")]
81mod namespace_str;
82
83#[bootstrap(
84 features("contrib"),
85 arguments(output_metric(c_type = "AnyMetric *", rust_type = b"null")),
86 generics(MI(suppress), MO(suppress))
87)]
88pub fn make_stable_expr<MI: 'static + Metric, MO: 'static + Metric>(
95 input_domain: WildExprDomain,
96 input_metric: MI,
97 expr: Expr,
98) -> Fallible<Transformation<WildExprDomain, MI, ExprDomain, MO>>
99where
100 Expr: StableExpr<MI, MO>,
101 (WildExprDomain, MI): MetricSpace,
102 (ExprDomain, MO): MetricSpace,
103{
104 expr.make_stable(input_domain, input_metric)
105}
106
/// Conversion of a value into a [`Transformation`] from [`WildExprDomain`] to [`ExprDomain`].
///
/// `MI` is the metric on the input space and `MO` is the metric on the output space.
/// Implementations in this module are provided for [`Expr`], one per supported
/// `(MI, MO)` combination.
pub trait StableExpr<MI: Metric, MO: Metric> {
    /// Consume `self` and build the corresponding transformation.
    ///
    /// # Arguments
    /// * `input_domain` - The domain of the input data.
    /// * `input_metric` - The metric on the input data.
    ///
    /// # Errors
    /// Returns an error when no transformation can be constructed for `self`.
    fn make_stable(
        self,
        input_domain: WildExprDomain,
        input_metric: MI,
    ) -> Fallible<Transformation<WildExprDomain, MI, ExprDomain, MO>>;
}
114
115impl<M: OuterMetric> StableExpr<M, M> for Expr
116where
117 M::InnerMetric: UnboundedMetric,
118 M::Distance: Clone,
119 (WildExprDomain, M): MetricSpace,
120 (ExprDomain, M): MetricSpace,
121{
122 fn make_stable(
123 self,
124 input_domain: WildExprDomain,
125 input_metric: M,
126 ) -> Fallible<Transformation<WildExprDomain, M, ExprDomain, M>> {
127 if expr_fill_nan::match_fill_nan(&self).is_some() {
128 return expr_fill_nan::make_expr_fill_nan(input_domain, input_metric, self);
129 }
130
131 use Expr::*;
132 use FunctionExpr::*;
133 match self {
134 #[cfg(feature = "contrib")]
135 Alias(_, _) => expr_alias::make_expr_alias(input_domain, input_metric, self),
136
137 #[cfg(feature = "contrib")]
138 Expr::BinaryExpr { .. } => {
139 expr_binary::make_expr_binary(input_domain, input_metric, self)
140 }
141
142 #[cfg(feature = "contrib")]
143 Function {
144 function: Boolean(_),
145 ..
146 } => {
147 return expr_boolean_function::make_expr_boolean_function(
148 input_domain,
149 input_metric,
150 self,
151 );
152 }
153
154 #[cfg(feature = "contrib")]
155 Cast { .. } => expr_cast::make_expr_cast(input_domain, input_metric, self),
156
157 #[cfg(feature = "contrib")]
158 Function {
159 function: Clip { .. },
160 ..
161 } => expr_clip::make_expr_clip(input_domain, input_metric, self),
162
163 #[cfg(feature = "contrib")]
164 Function {
165 function: DropNans | DropNulls,
166 ..
167 } => {
168 expr_drop_nan_or_null::make_expr_drop_nan_or_null(input_domain, input_metric, self)
169 }
170
171 #[cfg(feature = "contrib")]
172 Function {
173 function: FillNull { .. },
174 ..
175 } => expr_fill_null::make_expr_fill_null(input_domain, input_metric, self),
176
177 #[cfg(feature = "contrib")]
178 Filter { .. } => expr_filter::make_expr_filter(input_domain, input_metric, self),
179
180 #[cfg(feature = "contrib")]
181 Column(_) => expr_col::make_expr_col(input_domain, input_metric, self),
182
183 #[cfg(feature = "contrib")]
184 Function {
185 function: Cut { .. },
186 ..
187 } => expr_cut::make_expr_cut(input_domain, input_metric, self),
188
189 #[cfg(feature = "contrib")]
190 Literal(_) => expr_lit::make_expr_lit(input_domain, input_metric, self),
191
192 #[cfg(feature = "contrib")]
193 Function {
194 function: ToPhysical,
195 ..
196 } => expr_to_physical::make_expr_to_physical(input_domain, input_metric, self),
197
198 #[cfg(feature = "contrib")]
199 Function {
200 function: Replace, ..
201 } => expr_replace::make_expr_replace(input_domain, input_metric, self),
202
203 #[cfg(feature = "contrib")]
204 Function {
205 function: ReplaceStrict { .. },
206 ..
207 } => expr_replace_strict::make_expr_replace_strict(input_domain, input_metric, self),
208
209 #[cfg(feature = "contrib")]
210 Function {
211 function: FunctionExpr::ArrayExpr(_),
212 ..
213 } => namespace_arr::make_namespace_arr(input_domain, input_metric, self),
214
215 #[cfg(feature = "contrib")]
216 Function {
217 function: FunctionExpr::TemporalExpr(_),
218 ..
219 } => namespace_dt::make_namespace_dt(input_domain, input_metric, self),
220
221 #[cfg(feature = "contrib")]
222 Function {
223 function: FunctionExpr::StringExpr(_),
224 ..
225 } => namespace_str::make_namespace_str(input_domain, input_metric, self),
226
227 expr => fallible!(
228 MakeTransformation,
229 "Expr is not recognized at this time: {:?}. {}If you would like to see this supported, please file an issue.",
230 expr,
231 get_disabled_features_message()
232 ),
233 }
234 }
235}
236
237impl<MI, const P: usize> StableExpr<L01InfDistance<MI>, LpDistance<P, f64>> for Expr
238where
239 MI: 'static + UnboundedMetric,
240{
241 fn make_stable(
242 self,
243 input_domain: WildExprDomain,
244 input_metric: L01InfDistance<MI>,
245 ) -> Fallible<Transformation<WildExprDomain, L01InfDistance<MI>, ExprDomain, LpDistance<P, f64>>>
246 {
247 use Expr::*;
248 match self {
249 #[cfg(feature = "contrib")]
250 Agg(AggExpr::Count(_, _) | AggExpr::NUnique(_))
251 | Function {
252 function: FunctionExpr::NullCount,
253 ..
254 } => expr_count::make_expr_count(input_domain, input_metric, self),
255
256 #[cfg(feature = "contrib")]
257 Agg(AggExpr::Sum(_)) => expr_sum::make_expr_sum(input_domain, input_metric, self),
258
259 #[cfg(feature = "contrib")]
260 Len => expr_len::make_expr_len(input_domain, input_metric, self),
261
262 expr => fallible!(
263 MakeTransformation,
264 "Expr is not recognized at this time: {:?}. {}If you would like to see this supported, please file an issue.",
265 expr,
266 get_disabled_features_message()
267 ),
268 }
269 }
270}
271
272impl<MI> StableExpr<L01InfDistance<MI>, L0InfDistance<LInfDistance<f64>>> for Expr
273where
274 MI: 'static + UnboundedMetric,
275{
276 fn make_stable(
277 self,
278 input_domain: WildExprDomain,
279 input_metric: L01InfDistance<MI>,
280 ) -> Fallible<
281 Transformation<
282 WildExprDomain,
283 L01InfDistance<MI>,
284 ExprDomain,
285 L0InfDistance<LInfDistance<f64>>,
286 >,
287 > {
288 if expr_discrete_quantile_score::match_discrete_quantile_score(&self)?.is_some() {
289 return expr_discrete_quantile_score::make_expr_discrete_quantile_score(
290 input_domain,
291 input_metric,
292 self,
293 );
294 }
295 match self {
296 expr => fallible!(
297 MakeTransformation,
298 "Expr is not recognized at this time: {:?}. {}If you would like to see this supported, please file an issue.",
299 expr,
300 get_disabled_features_message()
301 ),
302 }
303 }
304}
305
306#[cfg(test)]
307pub mod test_helper;