use polars::chunked_array::cast::CastOptions;
use polars::prelude::*;
use polars_plan::dsl::Expr;
use crate::core::{Function, MetricSpace, StabilityMap, Transformation};
use crate::domains::{ExprDomain, OuterMetric, WildExprDomain};
use crate::error::*;
use crate::metrics::MicrodataMetric;
use super::StableExpr;
#[cfg(test)]
mod test;
/// Make a Transformation that returns a `cast(dtype)` expression for a LazyFrame.
///
/// The cast is always executed non-strictly. A cast given with
/// [`CastOptions::Strict`] is rewritten to [`CastOptions::NonStrict`], and other
/// options are passed through unchanged. Under a non-strict cast, values that cannot
/// be converted become null, so the output column is always marked as nullable.
///
/// # Arguments
/// * `input_domain` - Expr domain
/// * `input_metric` - The metric under which neighboring LazyFrames are compared
/// * `expr` - The cast expression
///
/// # Errors
/// Returns an error if:
/// * `expr` is not a cast expression;
/// * the target dtype is not a literal dtype;
/// * the input expression cannot be made stable;
/// * or the output column domain rejects the target dtype.
pub fn make_expr_cast<M: OuterMetric>(
    input_domain: WildExprDomain,
    input_metric: M,
    expr: Expr,
) -> Fallible<Transformation<WildExprDomain, M, ExprDomain, M>>
where
    M::InnerMetric: MicrodataMetric,
    M::Distance: Clone,
    (WildExprDomain, M): MetricSpace,
    (ExprDomain, M): MetricSpace,
    Expr: StableExpr<M, M>,
{
    let Expr::Cast {
        expr: input,
        dtype: to_type,
        mut options,
    } = expr
    else {
        return fallible!(MakeTransformation, "expected cast expression");
    };

    // A strict cast raises as soon as any value fails to convert. Whether that
    // happens depends on the data itself.
    // NOTE(review): presumably the cast is downgraded instead of rejected so that
    // query failure cannot depend on individual records. Confirm intent.
    if matches!(options, CastOptions::Strict) {
        options = CastOptions::NonStrict;
    }

    // Validate the (cheap) target dtype before recursively building the input
    // transformation.
    let to_dtype = to_type
        .as_literal()
        .ok_or_else(|| {
            err!(
                MakeTransformation,
                "cast expression only supports literal dtype"
            )
        })?
        .clone();

    let t_prior = input
        .as_ref()
        .clone()
        .make_stable(input_domain, input_metric)?;
    let (middle_domain, middle_metric) = t_prior.output_space();

    let mut output_domain = middle_domain.clone();
    let data_column = &mut output_domain.column;
    data_column.set_dtype(to_dtype)?;
    // Non-strict casts map values that fail to convert to null, so nulls may appear
    // even if the input column was non-nullable. This could be tightened for casts
    // that can never fail; until then, users need to impute after casting.
    data_column.nullable = true;

    t_prior
        >> Transformation::new(
            middle_domain,
            middle_metric.clone(),
            output_domain,
            middle_metric,
            Function::then_expr(move |expr| Expr::Cast {
                expr: Arc::new(expr),
                dtype: to_type.clone(),
                options,
            }),
            // Casting is row-by-row, so the distance between neighbors is unchanged.
            StabilityMap::new(Clone::clone),
        )?
}