opendp/transformations/make_stable_lazyframe/filter/
mod.rs1use crate::core::{Function, StabilityMap, Transformation};
2use crate::domains::{Context, DslPlanDomain, WildExprDomain};
3use crate::error::*;
4use crate::metrics::FrameDistance;
5use crate::transformations::StableExpr;
6use crate::transformations::traits::UnboundedMetric;
7use polars::prelude::*;
8
9use super::StableDslPlan;
10
11#[cfg(test)]
12mod test;
13
14pub fn make_stable_filter<MI: UnboundedMetric, MO: UnboundedMetric>(
21 input_domain: DslPlanDomain,
22 input_metric: FrameDistance<MI>,
23 plan: DslPlan,
24) -> Fallible<Transformation<DslPlanDomain, FrameDistance<MI>, DslPlanDomain, FrameDistance<MO>>>
25where
26 DslPlan: StableDslPlan<FrameDistance<MI>, FrameDistance<MO>>,
27{
28 let DslPlan::Filter { input, predicate } = plan else {
29 return fallible!(MakeTransformation, "Expected filter in logical plan");
30 };
31
32 let t_prior = input
33 .as_ref()
34 .clone()
35 .make_stable(input_domain, input_metric)?;
36 let (middle_domain, middle_metric) = t_prior.output_space();
37
38 let expr_domain = WildExprDomain {
39 columns: middle_domain.series_domains.clone(),
40 context: Context::RowByRow,
41 };
42
43 let mut output_domain = middle_domain.clone();
44 let t_pred = predicate
45 .clone()
46 .make_stable(expr_domain, middle_metric.clone())?;
47
48 let pred_dtype = t_pred.output_domain.column.dtype();
49
50 if !pred_dtype.is_bool() {
51 return fallible!(
52 MakeTransformation,
53 "Expected predicate to return a boolean value, got: {:?}",
54 pred_dtype
55 );
56 }
57 let function = t_pred.function.clone();
58
59 output_domain.margins.iter_mut().for_each(|m| {
60 m.invariant = None;
62 });
63
64 t_prior
65 >> Transformation::new(
66 middle_domain,
67 middle_metric.clone(),
68 output_domain,
69 middle_metric,
70 Function::new_fallible(move |plan: &DslPlan| {
71 Ok(DslPlan::Filter {
72 input: Arc::new(plan.clone()),
73 predicate: function.eval(plan)?.expr,
74 })
75 }),
76 StabilityMap::new(Clone::clone),
77 )?
78}