hamelin_datafusion 0.6.12

Translate Hamelin TypedAST to DataFusion LogicalPlans
Documentation
//! Width bucket UDF for DataFusion.
//!
//! Implements `width_bucket(x, bins)` which finds the bucket index for a value
//! given an array of bin boundaries.

use std::any::Any;
use std::sync::Arc;

use datafusion::arrow::array::{
    Array, ArrayRef, AsArray, Float32Array, Float64Array, GenericListArray, Int16Array, Int32Array,
    Int64Array, Int8Array, OffsetSizeTrait, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
};
use datafusion::arrow::datatypes::DataType;
use datafusion::common::{exec_err, Result, ScalarValue};
use datafusion::logical_expr::{
    ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature,
    Volatility,
};

// ============================================================================
// width_bucket(x, bins): Returns the bucket index for value x given bin boundaries.
//
// For bins = [b0, b1, b2, ..., bn], returns:
//   0 if x < b0
//   1 if b0 <= x < b1
//   2 if b1 <= x < b2
//   ...
//   n if b(n-1) <= x < bn
//   n+1 if x >= bn
// ============================================================================

/// DataFusion scalar UDF implementing `hamelin_width_bucket(x, bins)`.
///
/// The struct only carries the function signature; the bucketing logic lives
/// in the `ScalarUDFImpl` implementation for this type.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct WidthBucketArrayUdf {
    /// Signature reported to the planner (built in `new`): any two arguments.
    signature: Signature,
}

impl Default for WidthBucketArrayUdf {
    fn default() -> Self {
        Self::new()
    }
}

impl WidthBucketArrayUdf {
    /// Creates the UDF with an immutable, two-argument signature.
    ///
    /// The signature accepts any two argument types: `x` (numeric) and `bins`
    /// (array of numeric). The concrete types are checked when the function is
    /// invoked rather than by the signature.
    pub fn new() -> Self {
        let accepted_args = TypeSignature::Any(2);
        let signature = Signature::new(accepted_args, Volatility::Immutable);
        Self { signature }
    }
}

impl ScalarUDFImpl for WidthBucketArrayUdf {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "hamelin_width_bucket"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Int64)
    }

    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let args = args.args;
        if args.len() != 2 {
            return exec_err!(
                "width_bucket expects exactly 2 arguments, got {}",
                args.len()
            );
        }

        match (&args[0], &args[1]) {
            // Both scalar
            (ColumnarValue::Scalar(x_scalar), ColumnarValue::Scalar(bins_scalar)) => {
                let x = scalar_to_f64(x_scalar)?;
                let bins = scalar_to_bins(bins_scalar)?;
                let bucket = match bins {
                    Some(ref b) => compute_bucket(x, b),
                    None => None,
                };
                Ok(ColumnarValue::Scalar(ScalarValue::Int64(bucket)))
            }
            // x is array, bins is scalar (common case: same bins for all rows)
            (ColumnarValue::Array(x_array), ColumnarValue::Scalar(bins_scalar)) => {
                let bins = scalar_to_bins(bins_scalar)?;
                let results = compute_bucket_array(x_array, &bins)?;
                Ok(ColumnarValue::Array(results))
            }
            // x is scalar, bins is array (unusual but supported)
            (ColumnarValue::Scalar(x_scalar), ColumnarValue::Array(bins_array)) => {
                let x = scalar_to_f64(x_scalar)?;
                let results =
                    bucket_from_list_column(bins_array, |_i, bins| Ok(compute_bucket(x, bins)))?;
                Ok(ColumnarValue::Array(Arc::new(results) as ArrayRef))
            }
            // Both arrays
            (ColumnarValue::Array(x_array), ColumnarValue::Array(bins_array)) => {
                if x_array.len() != bins_array.len() {
                    return exec_err!("width_bucket x and bins arrays must have same length");
                }

                let results = bucket_from_list_column(bins_array, |i, bins| {
                    let x = get_x_f64(x_array, i)?;
                    Ok(compute_bucket(x, bins))
                })?;
                Ok(ColumnarValue::Array(Arc::new(results) as ArrayRef))
            }
        }
    }
}

/// Outcome of extracting bin boundaries from an array into a caller-provided
/// buffer (see `array_to_bins_into`).
enum BinsResult {
    /// Every element was read, validated and written to the buffer.
    Valid,
    /// The array contained a NULL element, so the bucket result should be NULL.
    /// The buffer contents are not meaningful in this case.
    Null,
}

/// Computes one bucket per row of a bins column.
///
/// Dispatches on the column's list flavour (`List` with 32-bit offsets or
/// `LargeList` with 64-bit offsets) and hands each row's boundaries to `f`
/// together with the row index. Any other data type is rejected with an error.
fn bucket_from_list_column(
    bins_array: &ArrayRef,
    f: impl FnMut(usize, &[f64]) -> Result<Option<i64>>,
) -> Result<Int64Array> {
    let bins_type = bins_array.data_type();
    if matches!(bins_type, DataType::List(_)) {
        return bucket_from_generic_list(bins_array.as_list::<i32>(), f);
    }
    if matches!(bins_type, DataType::LargeList(_)) {
        return bucket_from_generic_list(bins_array.as_list::<i64>(), f);
    }
    exec_err!("width_bucket bins must be an array, got {:?}", bins_type)
}

/// Walks every row of a list array and computes its bucket via `f`.
///
/// For each row the boundaries are extracted and validated into a scratch buffer
/// that is reused across rows, then passed to `f` with the row index. Rows whose
/// list is NULL, or whose list contains a NULL element, produce a NULL bucket
/// without calling `f`. The first error from extraction or from `f` aborts the
/// whole computation.
fn bucket_from_generic_list<O: OffsetSizeTrait>(
    list_array: &GenericListArray<O>,
    mut f: impl FnMut(usize, &[f64]) -> Result<Option<i64>>,
) -> Result<Int64Array> {
    let row_count = list_array.len();
    let mut scratch = Vec::new();
    let mut buckets: Vec<Option<i64>> = Vec::with_capacity(row_count);

    for row in 0..row_count {
        let bucket = if list_array.is_null(row) {
            None
        } else {
            scratch.clear();
            match array_to_bins_into(&list_array.value(row), &mut scratch)? {
                BinsResult::Valid => f(row, &scratch)?,
                BinsResult::Null => None,
            }
        };
        buckets.push(bucket);
    }

    Ok(buckets.into_iter().collect())
}

/// Widens a numeric scalar to `f64`.
///
/// Returns `Ok(None)` for a NULL value of any supported type (including the
/// untyped `ScalarValue::Null`) and an error for non-numeric scalars.
fn scalar_to_f64(scalar: &ScalarValue) -> Result<Option<f64>> {
    let widened = match scalar {
        ScalarValue::Null => None,
        ScalarValue::Float64(value) => *value,
        ScalarValue::Float32(value) => value.map(f64::from),
        ScalarValue::Int8(value) => value.map(f64::from),
        ScalarValue::Int16(value) => value.map(f64::from),
        ScalarValue::Int32(value) => value.map(f64::from),
        // 64-bit integers have no lossless `From` conversion to f64.
        ScalarValue::Int64(value) => value.map(|n| n as f64),
        ScalarValue::UInt8(value) => value.map(f64::from),
        ScalarValue::UInt16(value) => value.map(f64::from),
        ScalarValue::UInt32(value) => value.map(f64::from),
        ScalarValue::UInt64(value) => value.map(|n| n as f64),
        _ => return exec_err!("width_bucket x must be numeric, got {:?}", scalar),
    };
    Ok(widened)
}

/// Convert a scalar list to a vector of f64 bin boundaries.
/// Returns an error if bins contain NaN or are not sorted ascending.
fn scalar_to_bins(scalar: &ScalarValue) -> Result<Option<Vec<f64>>> {
    match scalar {
        ScalarValue::List(arr) => {
            if arr.is_empty() || arr.is_null(0) {
                return Ok(None);
            }
            array_to_bins(&arr.value(0))
        }
        ScalarValue::LargeList(arr) => {
            if arr.is_empty() || arr.is_null(0) {
                return Ok(None);
            }
            array_to_bins(&arr.value(0))
        }
        ScalarValue::Null => Ok(None),
        _ => exec_err!("width_bucket bins must be an array, got {:?}", scalar),
    }
}

/// Owned-vector convenience wrapper around `array_to_bins_into`.
///
/// Returns `Ok(None)` if any element is NULL; errors on NaN, unsorted
/// boundaries, or a non-numeric array.
fn array_to_bins(array: &dyn Array) -> Result<Option<Vec<f64>>> {
    let mut boundaries = Vec::new();
    let outcome = array_to_bins_into(array, &mut boundaries)?;
    Ok(match outcome {
        BinsResult::Valid => Some(boundaries),
        BinsResult::Null => None,
    })
}

/// Extract f64 bin boundaries from an array into `buf`, with validation.
///
/// - If any element is NULL → returns `Ok(BinsResult::Null)`
/// - If any element is NaN → returns an error
/// - If bins are not sorted ascending (equal is OK) → returns an error
fn array_to_bins_into(array: &dyn Array, buf: &mut Vec<f64>) -> Result<BinsResult> {
    macro_rules! extract_bins {
        ($arr:expr) => {{
            buf.reserve($arr.len());
            for i in 0..$arr.len() {
                if $arr.is_null(i) {
                    return Ok(BinsResult::Null);
                }
                let v = $arr.value(i) as f64;
                if v.is_nan() {
                    return exec_err!("width_bucket bin boundaries must not contain NaN");
                }
                buf.push(v);
            }
        }};
    }

    if let Some(arr) = array.as_any().downcast_ref::<Float64Array>() {
        extract_bins!(arr);
    } else if let Some(arr) = array.as_any().downcast_ref::<Float32Array>() {
        extract_bins!(arr);
    } else if let Some(arr) = array.as_any().downcast_ref::<Int64Array>() {
        extract_bins!(arr);
    } else if let Some(arr) = array.as_any().downcast_ref::<Int32Array>() {
        extract_bins!(arr);
    } else if let Some(arr) = array.as_any().downcast_ref::<UInt64Array>() {
        extract_bins!(arr);
    } else if let Some(arr) = array.as_any().downcast_ref::<UInt32Array>() {
        extract_bins!(arr);
    } else if let Some(arr) = array.as_any().downcast_ref::<UInt16Array>() {
        extract_bins!(arr);
    } else if let Some(arr) = array.as_any().downcast_ref::<UInt8Array>() {
        extract_bins!(arr);
    } else {
        return exec_err!(
            "width_bucket bins must be a numeric array, got {:?}",
            array.data_type()
        );
    }

    // Validate sorted ascending (equal is OK)
    for w in buf.windows(2) {
        if w[0] > w[1] {
            return exec_err!(
                "width_bucket bin boundaries must be sorted in ascending order, \
                 found {} followed by {}",
                w[0],
                w[1]
            );
        }
    }

    Ok(BinsResult::Valid)
}

/// Get a single f64 value from an array at the given index.
fn get_x_f64(array: &ArrayRef, i: usize) -> Result<Option<f64>> {
    if array.is_null(i) {
        return Ok(None);
    }
    if let Some(arr) = array.as_any().downcast_ref::<Float64Array>() {
        return Ok(Some(arr.value(i)));
    }
    if let Some(arr) = array.as_any().downcast_ref::<Float32Array>() {
        return Ok(Some(arr.value(i) as f64));
    }
    if let Some(arr) = array.as_any().downcast_ref::<Int64Array>() {
        return Ok(Some(arr.value(i) as f64));
    }
    if let Some(arr) = array.as_any().downcast_ref::<Int32Array>() {
        return Ok(Some(arr.value(i) as f64));
    }
    if let Some(arr) = array.as_any().downcast_ref::<UInt64Array>() {
        return Ok(Some(arr.value(i) as f64));
    }
    if let Some(arr) = array.as_any().downcast_ref::<UInt32Array>() {
        return Ok(Some(arr.value(i) as f64));
    }
    if let Some(arr) = array.as_any().downcast_ref::<UInt16Array>() {
        return Ok(Some(arr.value(i) as f64));
    }
    if let Some(arr) = array.as_any().downcast_ref::<UInt8Array>() {
        return Ok(Some(arr.value(i) as f64));
    }
    exec_err!(
        "width_bucket x must be numeric array, got {:?}",
        array.data_type()
    )
}

/// Compute bucket array from x array with shared bins
fn compute_bucket_array(x_array: &ArrayRef, bins: &Option<Vec<f64>>) -> Result<ArrayRef> {
    let bins = match bins {
        Some(b) => b,
        None => {
            let results: Int64Array = (0..x_array.len()).map(|_| None::<i64>).collect();
            return Ok(Arc::new(results) as ArrayRef);
        }
    };

    macro_rules! map_bucket {
        ($arr:expr) => {{
            let results: Int64Array = $arr
                .iter()
                .map(|v| compute_bucket(v.map(|x| x as f64), bins))
                .collect();
            return Ok(Arc::new(results) as ArrayRef);
        }};
    }

    if let Some(arr) = x_array.as_any().downcast_ref::<Int64Array>() {
        map_bucket!(arr);
    }
    if let Some(arr) = x_array.as_any().downcast_ref::<Float64Array>() {
        let results: Int64Array = arr.iter().map(|v| compute_bucket(v, bins)).collect();
        return Ok(Arc::new(results) as ArrayRef);
    }
    if let Some(arr) = x_array.as_any().downcast_ref::<Int32Array>() {
        map_bucket!(arr);
    }
    if let Some(arr) = x_array.as_any().downcast_ref::<Float32Array>() {
        map_bucket!(arr);
    }
    if let Some(arr) = x_array.as_any().downcast_ref::<UInt64Array>() {
        map_bucket!(arr);
    }
    if let Some(arr) = x_array.as_any().downcast_ref::<UInt32Array>() {
        map_bucket!(arr);
    }
    if let Some(arr) = x_array.as_any().downcast_ref::<UInt16Array>() {
        map_bucket!(arr);
    }
    if let Some(arr) = x_array.as_any().downcast_ref::<UInt8Array>() {
        map_bucket!(arr);
    }

    exec_err!(
        "width_bucket x must be numeric array, got {:?}",
        x_array.data_type()
    )
}

/// Compute the bucket index for a value given bin boundaries.
///
/// `bins` must be sorted ascending and free of NaN (`array_to_bins_into`
/// validates this); duplicate boundaries are allowed.
///
/// Returns:
///   `None` if x is NULL
///   0 if `bins` is empty or x < bins[0]
///   i if bins[i-1] <= x < bins[i]
///   len(bins) if x >= bins[last]
///   len(bins) if x is NaN (overflow bucket, PostgreSQL semantics)
///
/// Put differently, the result is the number of boundaries that are `<= x`.
fn compute_bucket(x: Option<f64>, bins: &[f64]) -> Option<i64> {
    let x = x?;

    // NaN goes to overflow bucket (PostgreSQL: NaN > all non-NaN values).
    // For empty `bins` this is 0, the same as for every other value.
    if x.is_nan() {
        return Some(bins.len() as i64);
    }

    // Count the boundaries that are <= x. `partition_point` is used instead of
    // `binary_search_by` because the latter may return *any* matching index when
    // several boundaries equal x, which would give a bucket that is too small
    // for duplicated boundaries.
    let bucket = bins.partition_point(|&boundary| boundary <= x);
    Some(bucket as i64)
}

/// Builds the `hamelin_width_bucket` function as a DataFusion [`ScalarUDF`].
pub fn width_bucket_array_udf() -> ScalarUDF {
    let implementation = WidthBucketArrayUdf::new();
    ScalarUDF::new_from_impl(implementation)
}