Skip to main content

opendp/transformations/make_stable_lazyframe/filter/
mod.rs

1use crate::core::{Function, StabilityMap, Transformation};
2use crate::domains::{Context, DslPlanDomain, WildExprDomain};
3use crate::error::*;
4use crate::metrics::FrameDistance;
5use crate::transformations::StableExpr;
6use crate::transformations::traits::UnboundedMetric;
7use polars::prelude::*;
8
9use super::StableDslPlan;
10
11#[cfg(test)]
12mod test;
13
14/// Transformation for creating a stable LazyFrame filter.
15///
16/// # Arguments
17/// * `input_domain` - The domain of the input LazyFrame.
18/// * `input_metric` - The metric of the input LazyFrame.
19/// * `plan` - The LazyFrame to transform.
20pub fn make_stable_filter<MI: UnboundedMetric, MO: UnboundedMetric>(
21    input_domain: DslPlanDomain,
22    input_metric: FrameDistance<MI>,
23    plan: DslPlan,
24) -> Fallible<Transformation<DslPlanDomain, FrameDistance<MI>, DslPlanDomain, FrameDistance<MO>>>
25where
26    DslPlan: StableDslPlan<FrameDistance<MI>, FrameDistance<MO>>,
27{
28    let DslPlan::Filter { input, predicate } = plan else {
29        return fallible!(MakeTransformation, "Expected filter in logical plan");
30    };
31
32    let t_prior = input
33        .as_ref()
34        .clone()
35        .make_stable(input_domain, input_metric)?;
36    let (middle_domain, middle_metric) = t_prior.output_space();
37
38    let expr_domain = WildExprDomain {
39        columns: middle_domain.series_domains.clone(),
40        context: Context::RowByRow,
41    };
42
43    let mut output_domain = middle_domain.clone();
44    let t_pred = predicate
45        .clone()
46        .make_stable(expr_domain, middle_metric.clone())?;
47
48    let pred_dtype = t_pred.output_domain.column.dtype();
49
50    if !pred_dtype.is_bool() {
51        return fallible!(
52            MakeTransformation,
53            "Expected predicate to return a boolean value, got: {:?}",
54            pred_dtype
55        );
56    }
57    let function = t_pred.function.clone();
58
59    output_domain.margins.iter_mut().for_each(|m| {
60        // After filtering you no longer know partition lengths or keys.
61        m.invariant = None;
62    });
63
64    t_prior
65        >> Transformation::new(
66            middle_domain,
67            middle_metric.clone(),
68            output_domain,
69            middle_metric,
70            Function::new_fallible(move |plan: &DslPlan| {
71                Ok(DslPlan::Filter {
72                    input: Arc::new(plan.clone()),
73                    predicate: function.eval(plan)?.expr,
74                })
75            }),
76            StabilityMap::new(Clone::clone),
77        )?
78}