use std::any::Any;
use std::sync::Arc;
use datafusion::arrow::array::{Array, ArrayRef, ListArray};
use datafusion::arrow::buffer::OffsetBuffer;
use datafusion::arrow::compute::concat;
use datafusion::arrow::datatypes::{DataType, Field};
use datafusion::common::{Result, ScalarValue};
use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion::logical_expr::{
Accumulator, AggregateUDF, AggregateUDFImpl, Signature, Volatility,
};
/// Aggregate UDF `hamelin_sliding_array_agg`.
///
/// Collects every value of its single argument (nulls included) into a list.
/// Its accumulator supports `retract_batch`, so values can be removed again.
/// Removal is oldest-first, which is what sliding window frames need.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SlidingArrayAggUdaf {
    /// Exactly one argument of any type; the function is immutable.
    signature: Signature,
}

impl SlidingArrayAggUdaf {
    /// Builds the UDF with a one-argument, any-type, immutable signature.
    pub fn new() -> Self {
        let signature = Signature::any(1, Volatility::Immutable);
        Self { signature }
    }
}

impl Default for SlidingArrayAggUdaf {
    /// Same as [`SlidingArrayAggUdaf::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl AggregateUDFImpl for SlidingArrayAggUdaf {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Name under which the aggregate is exposed.
    fn name(&self) -> &str {
        "hamelin_sliding_array_agg"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The result is a `List` whose nullable `item` field has the argument's type.
    ///
    /// Indexing `arg_types[0]` relies on the one-argument signature.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        let element_type = arg_types[0].clone();
        Ok(DataType::List(Arc::new(Field::new(
            "item",
            element_type,
            true,
        ))))
    }

    /// Partial state is a single nullable list column named `values`.
    ///
    /// Its nullable `item` elements have the input type. This is the shape the
    /// accumulator's `state` emits and `merge_batch` reads back.
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Arc<Field>>> {
        let element_type = args.input_fields[0].data_type().clone();
        Ok(vec![Arc::new(Field::new_list(
            "values",
            Field::new("item", element_type, true),
            true,
        ))])
    }

    /// Creates a fresh accumulator typed by the argument expression.
    ///
    /// # Errors
    /// Propagates the error from resolving the argument's data type against
    /// the input schema, unchanged.
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        // `data_type` already returns a DataFusion `Result`. Wrapping the error in
        // `External` would hide its original variant, so propagate it directly.
        let element_type = acc_args.exprs[0].data_type(acc_args.schema)?;
        Ok(Box::new(SlidingArrayAggAccumulator::new(element_type)))
    }
}
/// Returns the [`SlidingArrayAggUdaf`] wrapped as an [`AggregateUDF`], ready to register.
pub fn sliding_array_agg_udaf() -> AggregateUDF {
    SlidingArrayAggUdaf::default().into()
}
#[derive(Debug)]
struct SlidingArrayAggAccumulator {
chunks: Vec<ArrayRef>,
total_len: usize,
element_type: DataType,
}
impl SlidingArrayAggAccumulator {
fn new(element_type: DataType) -> Self {
Self {
chunks: Vec::new(),
total_len: 0,
element_type,
}
}
fn empty_array(&self) -> Result<ArrayRef> {
Ok(datafusion::arrow::array::new_empty_array(
&self.element_type,
))
}
fn consolidate(&self) -> Result<ArrayRef> {
match self.chunks.len() {
0 => self.empty_array(),
1 => Ok(self.chunks[0].clone()),
_ => {
let refs: Vec<&dyn Array> = self.chunks.iter().map(|a| a.as_ref()).collect();
Ok(concat(&refs)?)
}
}
}
}
impl Accumulator for SlidingArrayAggAccumulator {
    /// Appends the first argument column to the buffer.
    ///
    /// Missing or zero-length input is ignored, so `chunks` never holds an
    /// empty array.
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let Some(batch) = values.first() else {
            return Ok(());
        };
        if batch.is_empty() {
            return Ok(());
        }
        self.total_len += batch.len();
        self.chunks.push(Arc::clone(batch));
        Ok(())
    }

    /// Removes the oldest `values[0].len()` buffered values.
    ///
    /// Only the length of `values[0]` is used, not its contents. NOTE(review):
    /// this assumes DataFusion retracts rows in the order they were added.
    ///
    /// Fully retracted leading chunks are dropped, and at most one chunk is
    /// zero-copy sliced. Unlike re-concatenating the whole buffer, this keeps a
    /// retract from copying every remaining value.
    ///
    /// # Errors
    /// Returns an internal error if asked to remove more values than are buffered.
    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let retract_count = match values.first() {
            Some(batch) if !batch.is_empty() => batch.len(),
            _ => return Ok(()),
        };
        if retract_count > self.total_len {
            return Err(datafusion::common::DataFusionError::Internal(format!(
                "retract_batch asked to remove {} values but only {} exist",
                retract_count, self.total_len
            )));
        }

        // Count the leading chunks that are retracted in full.
        let mut still_to_remove = retract_count;
        let mut whole_chunks = 0;
        for chunk in &self.chunks {
            if chunk.len() > still_to_remove {
                break;
            }
            still_to_remove -= chunk.len();
            whole_chunks += 1;
        }
        self.chunks.drain(..whole_chunks);

        // Trim the front of the new first chunk; `slice` does not copy data.
        if still_to_remove > 0 {
            let head = self.chunks.first_mut().ok_or_else(|| {
                datafusion::common::DataFusionError::Internal(
                    "retract_batch: buffered chunks shorter than recorded total_len"
                        .to_string(),
                )
            })?;
            *head = head.slice(still_to_remove, head.len() - still_to_remove);
        }

        self.total_len -= retract_count;
        Ok(())
    }

    fn supports_retract_batch(&self) -> bool {
        true
    }

    /// Returns every buffered value, oldest first, as a single-row `List` scalar.
    ///
    /// # Errors
    /// Propagates concatenation errors. If the buffered values do not match
    /// `element_type`, an arrow error is returned instead of a panic.
    fn evaluate(&mut self) -> Result<ScalarValue> {
        let values = self.consolidate()?;
        let field = Arc::new(Field::new("item", self.element_type.clone(), true));
        let offsets = OffsetBuffer::from_lengths([values.len()]);
        let list_array = ListArray::try_new(field, offsets, values, None)?;
        Ok(ScalarValue::List(Arc::new(list_array)))
    }

    /// The partial state is the same list that `evaluate` produces.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        self.evaluate().map(|v| vec![v])
    }

    /// Appends every non-null list row of the incoming state column.
    ///
    /// Rows are appended in order. Empty lists are skipped by `update_batch`.
    ///
    /// # Errors
    /// Returns an internal error if the state column is not a `ListArray`.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        let Some(state) = states.first() else {
            return Ok(());
        };
        if state.is_empty() {
            return Ok(());
        }
        let list_array = state
            .as_any()
            .downcast_ref::<ListArray>()
            .ok_or_else(|| {
                datafusion::common::DataFusionError::Internal(
                    "Expected ListArray for state".to_string(),
                )
            })?;
        // `iter` yields `None` for null rows; `flatten` skips them.
        for value in list_array.iter().flatten() {
            self.update_batch(&[value])?;
        }
        Ok(())
    }

    /// Approximate memory footprint in bytes.
    ///
    /// Covers the struct itself, the chunk vector's allocation, the buffered
    /// arrays, and any heap data owned by `element_type`.
    fn size(&self) -> usize {
        std::mem::size_of_val(self)
            + std::mem::size_of::<ArrayRef>() * self.chunks.capacity()
            + self
                .chunks
                .iter()
                .map(|a| a.get_array_memory_size())
                .sum::<usize>()
            + self.element_type.size()
            - std::mem::size_of_val(&self.element_type)
    }
}