opendp/transformations/make_stable_expr/expr_replace/
mod.rs1use polars::prelude::*;
2use polars_plan::dsl::Expr;
3
4use crate::core::{Function, MetricSpace, StabilityMap, Transformation};
5use crate::domains::{ExprDomain, OuterMetric, WildExprDomain};
6use crate::error::*;
7use crate::transformations::DatasetMetric;
8
9use super::StableExpr;
10
11#[cfg(test)]
12mod test;
13
14pub fn make_expr_replace<M: OuterMetric>(
21 input_domain: WildExprDomain,
22 input_metric: M,
23 expr: Expr,
24) -> Fallible<Transformation<WildExprDomain, ExprDomain, M, M>>
25where
26 M::InnerMetric: DatasetMetric,
27 M::Distance: Clone,
28 (WildExprDomain, M): MetricSpace,
29 (ExprDomain, M): MetricSpace,
30 Expr: StableExpr<M, M>,
31{
32 let Expr::Function {
33 input,
34 function: FunctionExpr::Replace,
35 options,
36 } = expr
37 else {
38 return fallible!(MakeTransformation, "expected replace expression");
39 };
40
41 let [input, old, new] = <[Expr; 3]>::try_from(input)
42 .map_err(|_| err!(MakeTransformation, "replace takes an input, old and new"))?;
43
44 let t_prior = input.make_stable(input_domain, input_metric)?;
45 let (middle_domain, middle_metric) = t_prior.output_space();
46
47 let (Expr::Literal(old_lit), Expr::Literal(new_lit)) = (&old, &new) else {
48 return fallible!(
49 MakeTransformation,
50 "replace: old and new must be literals, but found {:?} and {:?}",
51 old,
52 new
53 );
54 };
55
56 let (old_len, new_len) = (literal_len(old_lit)?, literal_len(new_lit)?);
57 if ![old_len, 1].contains(&new_len) {
58 return fallible!(
59 MakeTransformation,
60 "length of `new` ({}) must match length of `old` ({}) or be broadcastable (1)",
61 new_len,
62 old_len
63 );
64 }
65
66 let dtype = middle_domain.column.dtype();
67 if matches!(dtype, DataType::Categorical(_, _)) {
68 return fallible!(
69 MakeTransformation,
70 "replace cannot be applied to categorical data, because it may trigger a data-dependent CategoricalRemappingWarning in Polars"
71 );
72 }
73
74 let (old_dtype, new_dtype) = (old_lit.get_datatype(), new_lit.get_datatype());
75 if is_cast_fallible(&old_dtype, &dtype) || is_cast_fallible(&new_dtype, &dtype) {
76 return fallible!(
77 MakeTransformation,
78 "replace: old datatype ({}) and new datatype ({}) must be consistent with the input datatype ({})",
79 old_dtype,
80 new_dtype,
81 dtype
82 );
83 }
84
85 let mut output_domain = middle_domain.clone();
86
87 output_domain.column.set_dtype(dtype)?;
89
90 output_domain.column.nullable |= literal_is_nullable(new_lit);
92
93 if literal_is_nullable(old_lit) && !literal_is_nullable(new_lit) {
95 output_domain.column.nullable = false;
96 }
97
98 t_prior
99 >> Transformation::new(
100 middle_domain.clone(),
101 output_domain,
102 Function::then_expr(move |expr| Expr::Function {
103 input: vec![expr.clone(), old.clone(), new.clone()],
104 function: FunctionExpr::Replace,
105 options: options.clone(),
106 }),
107 middle_metric.clone(),
108 middle_metric,
109 StabilityMap::new(Clone::clone),
110 )?
111}
112
113pub(crate) fn literal_len(literal: &LiteralValue) -> Fallible<i64> {
116 Ok(match literal {
117 LiteralValue::Range { low, high, .. } => high.saturating_sub(*low),
118 LiteralValue::Series(s) => s.len() as i64,
119 l if l.is_scalar() => 1,
120 l => {
121 return fallible!(
122 MakeTransformation,
123 "unrecognized literal when determining length: {l:?}"
124 )
125 }
126 })
127}
128
129pub(crate) fn literal_is_nullable(literal: &LiteralValue) -> bool {
132 match literal {
133 LiteralValue::Series(new_series) => new_series.has_nulls(),
134 LiteralValue::Null => true,
135 _ => false,
136 }
137}
138
139pub(crate) fn is_cast_fallible(from: &DataType, to: &DataType) -> bool {
142 if let DataType::Null = from {
143 return false;
144 }
145 if let DataType::Unknown(child) = from {
146 return match child {
147 UnknownKind::Int(v) => {
148 return if let Ok(v) = i64::try_from(*v) {
149 AnyValue::Int64(v).cast(&to).is_null()
150 } else {
151 to != &DataType::UInt64
152 }
153 }
154 UnknownKind::Float => !to.is_float(),
155 UnknownKind::Str => !to.is_string(),
156 UnknownKind::Any => true,
157 };
158 }
159 from != to
160}