opendp/transformations/make_stable_expr/expr_replace/
mod.rs

1use polars::prelude::*;
2use polars_plan::dsl::Expr;
3
4use crate::core::{Function, MetricSpace, StabilityMap, Transformation};
5use crate::domains::{ExprDomain, OuterMetric, WildExprDomain};
6use crate::error::*;
7use crate::transformations::DatasetMetric;
8
9use super::StableExpr;
10
11#[cfg(test)]
12mod test;
13
14/// Make a Transformation that returns a `replace(old, new)` expression for a LazyFrame.
15///
16/// # Arguments
17/// * `input_domain` - Expr domain
18/// * `input_metric` - The metric under which neighboring LazyFrames are compared
19/// * `expr` - The replace expression
20pub fn make_expr_replace<M: OuterMetric>(
21    input_domain: WildExprDomain,
22    input_metric: M,
23    expr: Expr,
24) -> Fallible<Transformation<WildExprDomain, ExprDomain, M, M>>
25where
26    M::InnerMetric: DatasetMetric,
27    M::Distance: Clone,
28    (WildExprDomain, M): MetricSpace,
29    (ExprDomain, M): MetricSpace,
30    Expr: StableExpr<M, M>,
31{
32    let Expr::Function {
33        input,
34        function: FunctionExpr::Replace,
35        options,
36    } = expr
37    else {
38        return fallible!(MakeTransformation, "expected replace expression");
39    };
40
41    let [input, old, new] = <[Expr; 3]>::try_from(input)
42        .map_err(|_| err!(MakeTransformation, "replace takes an input, old and new"))?;
43
44    let t_prior = input.make_stable(input_domain, input_metric)?;
45    let (middle_domain, middle_metric) = t_prior.output_space();
46
47    let (Expr::Literal(old_lit), Expr::Literal(new_lit)) = (&old, &new) else {
48        return fallible!(
49            MakeTransformation,
50            "replace: old and new must be literals, but found {:?} and {:?}",
51            old,
52            new
53        );
54    };
55
56    let (old_len, new_len) = (literal_len(old_lit)?, literal_len(new_lit)?);
57    if ![old_len, 1].contains(&new_len) {
58        return fallible!(
59            MakeTransformation,
60            "length of `new` ({}) must match length of `old` ({}) or be broadcastable (1)",
61            new_len,
62            old_len
63        );
64    }
65
66    let dtype = middle_domain.column.dtype();
67    if matches!(dtype, DataType::Categorical(_, _)) {
68        return fallible!(
69            MakeTransformation,
70            "replace cannot be applied to categorical data, because it may trigger a data-dependent CategoricalRemappingWarning in Polars"
71        );
72    }
73
74    let (old_dtype, new_dtype) = (old_lit.get_datatype(), new_lit.get_datatype());
75    if is_cast_fallible(&old_dtype, &dtype) || is_cast_fallible(&new_dtype, &dtype) {
76        return fallible!(
77            MakeTransformation,
78            "replace: old datatype ({}) and new datatype ({}) must be consistent with the input datatype ({})",
79            old_dtype,
80            new_dtype,
81            dtype
82        );
83    }
84
85    let mut output_domain = middle_domain.clone();
86
87    // reset domain descriptors
88    output_domain.column.set_dtype(dtype)?;
89
90    // if replacement can introduce nulls, then set nullable
91    output_domain.column.nullable |= literal_is_nullable(new_lit);
92
93    // if old has null and new does not, then there is a non-null null replacement
94    if literal_is_nullable(old_lit) && !literal_is_nullable(new_lit) {
95        output_domain.column.nullable = false;
96    }
97
98    t_prior
99        >> Transformation::new(
100            middle_domain.clone(),
101            output_domain,
102            Function::then_expr(move |expr| Expr::Function {
103                input: vec![expr.clone(), old.clone(), new.clone()],
104                function: FunctionExpr::Replace,
105                options: options.clone(),
106            }),
107            middle_metric.clone(),
108            middle_metric,
109            StabilityMap::new(Clone::clone),
110        )?
111}
112
113/// # Proof Definition
114/// Returns the length of a literal value.
115pub(crate) fn literal_len(literal: &LiteralValue) -> Fallible<i64> {
116    Ok(match literal {
117        LiteralValue::Range { low, high, .. } => high.saturating_sub(*low),
118        LiteralValue::Series(s) => s.len() as i64,
119        l if l.is_scalar() => 1,
120        l => {
121            return fallible!(
122                MakeTransformation,
123                "unrecognized literal when determining length: {l:?}"
124            )
125        }
126    })
127}
128
129/// # Proof Definition
130/// Returns whether a literal value contains null.
131pub(crate) fn literal_is_nullable(literal: &LiteralValue) -> bool {
132    match literal {
133        LiteralValue::Series(new_series) => new_series.has_nulls(),
134        LiteralValue::Null => true,
135        _ => false,
136    }
137}
138
139/// # Proof Definition
140/// Returns whether casting is fallible between two data types.
141pub(crate) fn is_cast_fallible(from: &DataType, to: &DataType) -> bool {
142    if let DataType::Null = from {
143        return false;
144    }
145    if let DataType::Unknown(child) = from {
146        return match child {
147            UnknownKind::Int(v) => {
148                return if let Ok(v) = i64::try_from(*v) {
149                    AnyValue::Int64(v).cast(&to).is_null()
150                } else {
151                    to != &DataType::UInt64
152                }
153            }
154            UnknownKind::Float => !to.is_float(),
155            UnknownKind::Str => !to.is_string(),
156            UnknownKind::Any => true,
157        };
158    }
159    from != to
160}