hamelin_datafusion 0.7.8

Translate Hamelin TypedAST to DataFusion LogicalPlans
Documentation
//! Sliding window array_agg UDAF for DataFusion.
//!
//! Implements `array_agg` with `retract_batch` support for sliding window frames.
//! DataFusion's built-in `array_agg` doesn't support sliding windows because it
//! doesn't implement `retract_batch`.

use std::any::Any;
use std::sync::Arc;

use datafusion::arrow::array::{Array, ArrayRef, ListArray};
use datafusion::arrow::buffer::OffsetBuffer;
use datafusion::arrow::compute::concat;
use datafusion::arrow::datatypes::{DataType, Field};
use datafusion::common::{Result, ScalarValue};
use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion::logical_expr::{
    Accumulator, AggregateUDF, AggregateUDFImpl, Signature, Volatility,
};

// ============================================================================
// sliding_array_agg(value) -> Array<value_type>
// Like array_agg but supports sliding windows via retract_batch.
// ============================================================================

/// Aggregate UDF implementing `array_agg` with retraction support.
///
/// Unlike DataFusion's built-in `array_agg`, the accumulator created by this
/// UDF implements `retract_batch`, so it can be evaluated over sliding window
/// frames.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SlidingArrayAggUdaf {
    /// Exactly one argument of any type; `Volatility::Immutable`.
    signature: Signature,
}

impl SlidingArrayAggUdaf {
    /// Builds the UDAF definition with a one-argument, any-type signature.
    pub fn new() -> Self {
        let signature = Signature::any(1, Volatility::Immutable);
        Self { signature }
    }
}

impl Default for SlidingArrayAggUdaf {
    /// Equivalent to [`SlidingArrayAggUdaf::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl AggregateUDFImpl for SlidingArrayAggUdaf {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// SQL-visible name of the function.
    fn name(&self) -> &str {
        "hamelin_sliding_array_agg"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Result type: a `List` whose nullable `item` field has the argument's type.
    ///
    /// Indexing `arg_types[0]` relies on the `Signature::any(1, ..)` declared
    /// in `new`, which guarantees exactly one argument.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        let element_type = arg_types[0].clone();
        Ok(DataType::List(Arc::new(Field::new(
            "item",
            element_type,
            true,
        ))))
    }

    /// Intermediate state: a single `List` column named `values` holding the
    /// buffered input values (same item type as the input).
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Arc<Field>>> {
        let element_type = args.input_fields[0].data_type().clone();
        Ok(vec![Arc::new(Field::new_list(
            "values",
            Field::new("item", element_type, true),
            true,
        ))])
    }

    /// Creates an accumulator for the argument's data type, resolved against
    /// the input schema.
    ///
    /// # Errors
    /// Any error from resolving the argument's data type is propagated
    /// unchanged (it is already a `DataFusionError`; wrapping it in
    /// `DataFusionError::External` would only obscure its original variant).
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        let element_type = acc_args.exprs[0].data_type(acc_args.schema)?;
        Ok(Box::new(SlidingArrayAggAccumulator::new(element_type)))
    }
}

/// Returns the `hamelin_sliding_array_agg` function wrapped in DataFusion's
/// [`AggregateUDF`] handle.
pub fn sliding_array_agg_udaf() -> AggregateUDF {
    SlidingArrayAggUdaf::default().into()
}

// ============================================================================
// Accumulator implementation
// ============================================================================

#[derive(Debug)]
struct SlidingArrayAggAccumulator {
    /// Buffered array chunks, consolidated lazily to avoid O(n²) concat.
    chunks: Vec<ArrayRef>,
    /// Sum of all chunk lengths (avoids recomputing).
    total_len: usize,
    /// Element type for the output array.
    element_type: DataType,
}

impl SlidingArrayAggAccumulator {
    fn new(element_type: DataType) -> Self {
        Self {
            chunks: Vec::new(),
            total_len: 0,
            element_type,
        }
    }

    /// Helper to create an empty array of the element type.
    fn empty_array(&self) -> Result<ArrayRef> {
        Ok(datafusion::arrow::array::new_empty_array(
            &self.element_type,
        ))
    }

    /// Consolidate all buffered chunks into a single array.
    fn consolidate(&self) -> Result<ArrayRef> {
        match self.chunks.len() {
            0 => self.empty_array(),
            1 => Ok(self.chunks[0].clone()),
            _ => {
                let refs: Vec<&dyn Array> = self.chunks.iter().map(|a| a.as_ref()).collect();
                Ok(concat(&refs)?)
            }
        }
    }
}

impl Accumulator for SlidingArrayAggAccumulator {
    /// Appends the batch's values as a new buffered chunk.
    ///
    /// Only the first column is read (the UDAF takes one argument). Empty
    /// batches are ignored, so no zero-length chunk is ever stored.
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let Some(batch) = values.first().filter(|a| !a.is_empty()) else {
            return Ok(());
        };

        self.total_len += batch.len();
        self.chunks.push(Arc::clone(batch));

        Ok(())
    }

    /// Removes the oldest `values[0].len()` buffered values (FIFO order).
    ///
    /// Only the number of retracted rows is used; their contents are not
    /// inspected. Leading chunks that are retracted entirely are dropped, and a
    /// partially retracted chunk is re-sliced (zero-copy). The surviving values
    /// are therefore never copied here.
    ///
    /// # Errors
    /// Returns an internal error if more values are retracted than are buffered.
    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let retract_count = match values.first() {
            Some(batch) if !batch.is_empty() => batch.len(),
            _ => return Ok(()),
        };

        if retract_count > self.total_len {
            return Err(datafusion::common::DataFusionError::Internal(format!(
                "retract_batch asked to remove {} values but only {} exist",
                retract_count, self.total_len
            )));
        }

        // Count the leading chunks covered entirely by the retraction;
        // `remainder` is what must still be cut from the first surviving chunk.
        let mut remainder = retract_count;
        let mut whole_chunks = 0;
        for chunk in &self.chunks {
            if chunk.len() > remainder {
                break;
            }
            remainder -= chunk.len();
            whole_chunks += 1;
        }
        self.chunks.drain(..whole_chunks);

        if remainder > 0 {
            // retract_count < total_len here (checked above), so a chunk longer
            // than `remainder` is guaranteed to survive at the front.
            let head = &mut self.chunks[0];
            *head = head.slice(remainder, head.len() - remainder);
        }

        self.total_len -= retract_count;

        Ok(())
    }

    fn supports_retract_batch(&self) -> bool {
        true
    }

    /// Returns all buffered values as a one-row `List` scalar.
    ///
    /// An accumulator with no values yields an empty list (not NULL). When
    /// several chunks had to be concatenated, the result is cached back as
    /// the only chunk. Repeated calls (e.g. once per window row) then don't
    /// re-concatenate every previously buffered chunk.
    fn evaluate(&mut self) -> Result<ScalarValue> {
        let values = self.consolidate()?;
        if self.chunks.len() > 1 {
            self.chunks = vec![Arc::clone(&values)];
        }

        // Wrap the values in a ListArray holding a single list (our result).
        let field = Arc::new(Field::new("item", self.element_type.clone(), true));
        let offsets = OffsetBuffer::from_lengths([values.len()]);
        let list_array = ListArray::new(field, offsets, values, None);

        Ok(ScalarValue::List(Arc::new(list_array)))
    }

    /// Serializes the state as a single `List` value: the same value
    /// `evaluate` produces.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        self.evaluate().map(|v| vec![v])
    }

    /// Appends the buffered values of other accumulators' states.
    ///
    /// Each non-null row of the `ListArray` state column holds one partial
    /// accumulator's values. Null rows are skipped, and empty lists are
    /// ignored by `update_batch`.
    ///
    /// # Errors
    /// Returns an internal error if the state column is not a `ListArray`.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        let Some(state) = states.first().filter(|a| !a.is_empty()) else {
            return Ok(());
        };

        let list_array = state
            .as_any()
            .downcast_ref::<ListArray>()
            .ok_or_else(|| {
                datafusion::common::DataFusionError::Internal(
                    "Expected ListArray for state".to_string(),
                )
            })?;

        for partial in list_array.iter().flatten() {
            self.update_batch(&[partial])?;
        }

        Ok(())
    }

    /// Approximate memory footprint in bytes.
    ///
    /// Counts the struct itself, the chunk vector's heap allocation and each
    /// chunk's Arrow buffers. Sliced chunks report their full underlying
    /// buffers, so this may over-estimate.
    fn size(&self) -> usize {
        let buffered: usize = self
            .chunks
            .iter()
            .map(|a| a.get_array_memory_size())
            .sum();

        std::mem::size_of_val(self)
            + self.chunks.capacity() * std::mem::size_of::<ArrayRef>()
            + buffered
    }
}