hamelin_datafusion 0.6.13

Translate Hamelin TypedAST to DataFusion LogicalPlans
Documentation
//! DataFusion translations for window functions.
//!
//! Note: The window functions used here come from the `datafusion_functions_window` crate
//! (its `expr_fn` helpers and UDWF handles) and are applied in the context of a window
//! definition. These translations create the function expressions that will be used within
//! window contexts.

use datafusion::common::ScalarValue;
use datafusion::logical_expr::expr::NullTreatment;
use datafusion::logical_expr::expr_fn::ExprFunctionExt;
use datafusion::logical_expr::Expr as DFExpr;
use datafusion_functions_window::expr_fn as window_fn;
use datafusion_functions_window::lead_lag::{lag_udwf, lead_udwf};
use datafusion_functions_window::nth_value::nth_value_udwf;

use hamelin_lib::func::defs::{
    CumeDist, DenseRank, FirstValue, Lag, LastValue, Lead, NthValue, PercentRank, Rank, RowNumber,
};

use super::DataFusionTranslationRegistry;

/// Reads the value out of a non-null boolean literal expression.
///
/// Returns `Some(value)` only when `expr` is a `Literal` wrapping
/// `ScalarValue::Boolean(Some(value))`. Every other expression (including a
/// null boolean literal or a non-literal `ignore_nulls` expression) yields `None`.
fn extract_bool_literal(expr: &DFExpr) -> Option<bool> {
    if let DFExpr::Literal(ScalarValue::Boolean(Some(flag)), _) = expr {
        Some(*flag)
    } else {
        None
    }
}

/// Maps an `ignore_nulls` argument onto a DataFusion [`NullTreatment`].
///
/// * literal `true`  -> `Some(NullTreatment::IgnoreNulls)`
/// * literal `false` -> `Some(NullTreatment::RespectNulls)`
/// * anything that is not a non-null boolean literal -> `None`, so the caller
///   sets no explicit null treatment on the window expression.
fn get_null_treatment(ignore_nulls: &DFExpr) -> Option<NullTreatment> {
    // `?` bails out with `None` when the argument is not a boolean literal.
    let treatment = match extract_bool_literal(ignore_nulls)? {
        true => NullTreatment::IgnoreNulls,
        false => NullTreatment::RespectNulls,
    };
    Some(treatment)
}

pub fn register(registry: &mut DataFusionTranslationRegistry) {
    // row_number() -> row_number()
    registry.register::<RowNumber>(|_params| Ok(window_fn::row_number()));

    // rank() -> rank()
    registry.register::<Rank>(|_params| Ok(window_fn::rank()));

    // dense_rank() -> dense_rank()
    registry.register::<DenseRank>(|_params| Ok(window_fn::dense_rank()));

    // lag(expression, offset, ignore_nulls) -> lag_udwf().call([expression, offset, NULL])
    // with null_treatment. Uses UDWF directly to support expression-based offsets.
    registry.register::<Lag>(|mut params| {
        let expression = params.take()?.expr;
        let offset = params.take()?.expr;
        let ignore_nulls = params.take()?.expr;
        let default = datafusion::logical_expr::lit(ScalarValue::Null);
        let base = lag_udwf().call(vec![expression, offset, default]);
        match get_null_treatment(&ignore_nulls) {
            Some(nt) => base.null_treatment(Some(nt)).build().map_err(Into::into),
            None => Ok(base),
        }
    });

    // lead(expression, offset, ignore_nulls) -> lead_udwf().call([expression, offset, NULL])
    // with null_treatment. Uses UDWF directly to support expression-based offsets.
    registry.register::<Lead>(|mut params| {
        let expression = params.take()?.expr;
        let offset = params.take()?.expr;
        let ignore_nulls = params.take()?.expr;
        let default = datafusion::logical_expr::lit(ScalarValue::Null);
        let base = lead_udwf().call(vec![expression, offset, default]);
        match get_null_treatment(&ignore_nulls) {
            Some(nt) => base.null_treatment(Some(nt)).build().map_err(Into::into),
            None => Ok(base),
        }
    });

    // first_value(expression, ignore_nulls) -> first_value(expression) with null_treatment
    registry.register::<FirstValue>(|mut params| {
        let expression = params.take()?.expr;
        let ignore_nulls = params.take()?.expr;
        let base = window_fn::first_value(expression);
        match get_null_treatment(&ignore_nulls) {
            Some(nt) => base.null_treatment(Some(nt)).build().map_err(Into::into),
            None => Ok(base),
        }
    });

    // last_value(expression, ignore_nulls) -> last_value(expression) with null_treatment
    registry.register::<LastValue>(|mut params| {
        let expression = params.take()?.expr;
        let ignore_nulls = params.take()?.expr;
        let base = window_fn::last_value(expression);
        match get_null_treatment(&ignore_nulls) {
            Some(nt) => base.null_treatment(Some(nt)).build().map_err(Into::into),
            None => Ok(base),
        }
    });

    // nth_value(expression, n, ignore_nulls) -> nth_value_udwf().call([expression, n])
    // with null_treatment. Uses UDWF directly to support expression-based n.
    registry.register::<NthValue>(|mut params| {
        let expression = params.take()?.expr;
        let n = params.take()?.expr;
        let ignore_nulls = params.take()?.expr;
        let base = nth_value_udwf().call(vec![expression, n]);
        match get_null_treatment(&ignore_nulls) {
            Some(nt) => base.null_treatment(Some(nt)).build().map_err(Into::into),
            None => Ok(base),
        }
    });

    // cume_dist() -> cume_dist()
    registry.register::<CumeDist>(|_params| Ok(window_fn::cume_dist()));

    // percent_rank() -> percent_rank()
    registry.register::<PercentRank>(|_params| Ok(window_fn::percent_rank()));
}