use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::expr::{Case as DFCase, Cast};
use datafusion::logical_expr::{lit, Expr as DFExpr, Operator};
use datafusion_functions::math::expr_fn as math_fn;
use hamelin_lib::func::defs::{
Abs, Cbrt, Ceil, Degrees, Euler, Exp, Floor, Ln, Log, Log10, Log2, Pi, Pow, Radians, Round1,
Round2, Sign, Sqrt, Truncate, WidthBucket2, WidthBucket4,
};
use super::DataFusionTranslationRegistry;
pub fn register(registry: &mut DataFusionTranslationRegistry) {
registry.register::<Abs>(|mut params| {
let x = params.take()?.expr;
Ok(math_fn::abs(x))
});
registry.register::<Cbrt>(|mut params| {
let x = params.take()?.expr;
Ok(math_fn::cbrt(x))
});
registry.register::<Ceil>(|mut params| {
let x = params.take()?.expr;
Ok(math_fn::ceil(x))
});
registry.register::<Degrees>(|mut params| {
let x = params.take()?.expr;
Ok(math_fn::degrees(x))
});
registry.register::<Euler>(|_params| Ok(datafusion::logical_expr::lit(std::f64::consts::E)));
registry.register::<Exp>(|mut params| {
let x = params.take()?.expr;
Ok(math_fn::exp(x))
});
registry.register::<Floor>(|mut params| {
let x = params.take()?.expr;
Ok(math_fn::floor(x))
});
registry.register::<Ln>(|mut params| {
let x = params.take()?.expr;
Ok(math_fn::ln(x))
});
registry.register::<Log>(|mut params| {
let base = params.take()?.expr;
let x = params.take()?.expr;
Ok(math_fn::log(base, x))
});
registry.register::<Log10>(|mut params| {
let x = params.take()?.expr;
Ok(math_fn::log10(x))
});
registry.register::<Log2>(|mut params| {
let x = params.take()?.expr;
Ok(math_fn::log2(x))
});
registry.register::<Pi>(|_params| Ok(math_fn::pi()));
registry.register::<Pow>(|mut params| {
let x = params.take()?.expr;
let p = params.take()?.expr;
Ok(math_fn::power(x, p))
});
registry.register::<Radians>(|mut params| {
let x = params.take()?.expr;
Ok(math_fn::radians(x))
});
registry.register::<Round1>(|mut params| {
let x = params.take()?.expr;
Ok(math_fn::round(vec![x]))
});
registry.register::<Round2>(|mut params| {
let x = params.take()?.expr;
let d = params.take()?.expr;
Ok(math_fn::round(vec![x, d]))
});
registry.register::<Sign>(|mut params| {
let x = params.take()?.expr;
Ok(math_fn::signum(x))
});
registry.register::<Sqrt>(|mut params| {
let x = params.take()?.expr;
Ok(math_fn::sqrt(x))
});
registry.register::<Truncate>(|mut params| {
let x = params.take()?.expr;
Ok(math_fn::trunc(vec![x]))
});
registry.register::<WidthBucket4>(|mut params| {
let x = params.take()?.expr;
let bound1 = params.take()?.expr;
let bound2 = params.take()?.expr;
let n = params.take()?.expr;
let cond_below = DFExpr::BinaryExpr(datafusion::logical_expr::BinaryExpr {
left: Box::new(x.clone()),
op: Operator::Lt,
right: Box::new(bound1.clone()),
});
let cond_above = DFExpr::BinaryExpr(datafusion::logical_expr::BinaryExpr {
left: Box::new(x.clone()),
op: Operator::GtEq,
right: Box::new(bound2.clone()),
});
let n_plus_one = DFExpr::BinaryExpr(datafusion::logical_expr::BinaryExpr {
left: Box::new(n.clone()),
op: Operator::Plus,
right: Box::new(lit(1i64)),
});
let x_minus_bound1 = DFExpr::BinaryExpr(datafusion::logical_expr::BinaryExpr {
left: Box::new(x),
op: Operator::Minus,
right: Box::new(bound1.clone()),
});
let range = DFExpr::BinaryExpr(datafusion::logical_expr::BinaryExpr {
left: Box::new(bound2),
op: Operator::Minus,
right: Box::new(bound1),
});
let x_minus_bound1_f64 =
DFExpr::Cast(Cast::new(Box::new(x_minus_bound1), DataType::Float64));
let normalized = DFExpr::BinaryExpr(datafusion::logical_expr::BinaryExpr {
left: Box::new(x_minus_bound1_f64),
op: Operator::Divide,
right: Box::new(range),
});
let scaled = DFExpr::BinaryExpr(datafusion::logical_expr::BinaryExpr {
left: Box::new(normalized),
op: Operator::Multiply,
right: Box::new(n),
});
let bucket = DFExpr::BinaryExpr(datafusion::logical_expr::BinaryExpr {
left: Box::new(math_fn::floor(scaled)),
op: Operator::Plus,
right: Box::new(lit(1i64)),
});
Ok(DFExpr::Case(DFCase {
expr: None,
when_then_expr: vec![
(Box::new(cond_below), Box::new(lit(0i64))),
(Box::new(cond_above), Box::new(n_plus_one)),
],
else_expr: Some(Box::new(bucket)),
}))
});
registry.register::<WidthBucket2>(|mut params| {
let x = params.take()?.expr;
let bins = params.take()?.expr;
Ok(crate::udf::width_bucket_array_udf().call(vec![x, bins]))
});
}