use std::any::Any;
use std::cmp::Ordering;
use std::collections::HashMap;
use std::mem::size_of_val;
use std::sync::Arc;
use datafusion::arrow::array::{Array, ArrayRef, AsArray, ListArray, MapArray, StructArray};
use datafusion::arrow::buffer::OffsetBuffer;
use datafusion::arrow::compute::SortOptions;
use datafusion::arrow::datatypes::{DataType, Field, Fields};
use datafusion::common::utils::compare_rows;
use datafusion::common::{exec_err, Result, ScalarValue};
use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion::logical_expr::utils::AggregateOrderSensitivity;
use datafusion::logical_expr::{
Accumulator, AggregateUDF, AggregateUDFImpl, Signature, Volatility,
};
use crate::struct_expansion::map_data_type;
/// Builds the `hamelin_multimap_agg` aggregate UDF, ready to be
/// registered with a DataFusion session context.
pub fn multimap_agg_udaf() -> AggregateUDF {
    AggregateUDF::new_from_impl(MultimapAggUdaf::default())
}
/// Aggregate UDF that collects `(key, value)` pairs into a map from each
/// distinct non-null key to the list of values observed for that key.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct MultimapAggUdaf {
    // Accepts exactly two arguments of any type (key, value); see `new`.
    signature: Signature,
}
impl Default for MultimapAggUdaf {
fn default() -> Self {
Self::new()
}
}
impl MultimapAggUdaf {
    /// Creates the UDAF with a signature accepting exactly two arguments
    /// (key, value) of any data type.
    pub fn new() -> Self {
        let signature = Signature::any(2, Volatility::Immutable);
        Self { signature }
    }
}
impl AggregateUDFImpl for MultimapAggUdaf {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "hamelin_multimap_agg"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Result type is `Map<key_type, List<value_type>>`, built via
    /// `map_data_type` so entry/field naming matches the rest of the crate.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        let value_list_type =
            DataType::List(Arc::new(Field::new("item", arg_types[1].clone(), true)));
        Ok(map_data_type(arg_types[0].clone(), value_list_type))
    }

    /// Intermediate state is flat, parallel list columns: all keys, all
    /// values, and — when an ORDER BY is present — one extra list per
    /// ordering expression.
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Arc<Field>>> {
        let key_type = args.input_fields[0].data_type().clone();
        let value_type = args.input_fields[1].data_type().clone();
        let mut fields = vec![
            Arc::new(Field::new_list(
                "keys",
                Field::new("item", key_type, true),
                true,
            )),
            Arc::new(Field::new_list(
                "values",
                Field::new("item", value_type, true),
                true,
            )),
        ];
        for (i, ord_field) in args.ordering_fields.iter().enumerate() {
            fields.push(Arc::new(Field::new_list(
                format!("ordering_{}", i),
                Field::new("item", ord_field.data_type().clone(), true),
                true,
            )));
        }
        Ok(fields)
    }

    /// Chooses the order-insensitive accumulator when no ORDER BY is
    /// attached, otherwise the buffering, order-sensitive variant.
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        // `data_type` already yields a DataFusionError; propagate it with `?`
        // instead of double-wrapping it in `DataFusionError::External`, which
        // obscured the original error kind from callers.
        let key_type = acc_args.exprs[0].data_type(acc_args.schema)?;
        let value_type = acc_args.exprs[1].data_type(acc_args.schema)?;
        if acc_args.order_bys.is_empty() {
            Ok(Box::new(MultimapAggAccumulator::new(key_type, value_type)))
        } else {
            let sort_options: Vec<SortOptions> =
                acc_args.order_bys.iter().map(|e| e.options).collect();
            let ordering_types: Vec<DataType> = acc_args
                .order_bys
                .iter()
                .map(|e| e.expr.data_type(acc_args.schema))
                .collect::<Result<_>>()?;
            Ok(Box::new(OrderSensitiveMultimapAggAccumulator::new(
                key_type,
                value_type,
                ordering_types,
                sort_options,
            )))
        }
    }

    /// Ordering affects only the arrangement of values inside the map, not
    /// correctness, so advertise a soft requirement.
    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
        AggregateOrderSensitivity::SoftRequirement
    }

    fn with_beneficial_ordering(
        self: Arc<Self>,
        _beneficial_ordering: bool,
    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
        Ok(Some(Arc::new(Self {
            signature: self.signature.clone(),
        })))
    }
}
/// Accumulator for `hamelin_multimap_agg` without an ORDER BY clause:
/// values are grouped under each key in arrival order.
#[derive(Debug)]
struct MultimapAggAccumulator {
    // Distinct non-null keys, in order of first appearance.
    keys: Vec<ScalarValue>,
    // values[i] holds every value seen for keys[i], in arrival order.
    values: Vec<Vec<ScalarValue>>,
    // Key -> index into `keys`/`values`, for O(1) bucket lookup.
    key_indices: HashMap<ScalarValue, usize>,
    // Arrow types of the key and value columns, needed to build arrays.
    key_type: DataType,
    value_type: DataType,
}
impl MultimapAggAccumulator {
    /// Creates an empty accumulator for the given key and value types.
    fn new(key_type: DataType, value_type: DataType) -> Self {
        Self {
            key_type,
            value_type,
            keys: Vec::new(),
            values: Vec::new(),
            key_indices: HashMap::new(),
        }
    }
}
impl Accumulator for MultimapAggAccumulator {
    /// Ingests one batch of (key, value) columns, appending each value to
    /// the bucket for its key. Rows with a null key are skipped; null
    /// values are kept.
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if values.len() < 2 {
            return exec_err!(
                "multimap_agg expects at least 2 arguments, got {}",
                values.len()
            );
        }
        let keys_array = &values[0];
        let values_array = &values[1];
        for i in 0..keys_array.len() {
            let key = ScalarValue::try_from_array(keys_array, i)?;
            // Null keys contribute nothing to the final map.
            if key.is_null() {
                continue;
            }
            let value = ScalarValue::try_from_array(values_array, i)?;
            match self.key_indices.get(&key) {
                Some(&idx) => {
                    self.values[idx].push(value);
                }
                None => {
                    // First sighting: record first-appearance order in `keys`
                    // and remember the bucket index for subsequent rows.
                    let idx = self.keys.len();
                    self.keys.push(key.clone());
                    self.values.push(vec![value]);
                    self.key_indices.insert(key, idx);
                }
            }
        }
        Ok(())
    }

    /// Merges partial states produced by `state()`: two parallel list
    /// columns (keys, values), one row per partial accumulator. Each row's
    /// flattened entries are folded in exactly like fresh input rows.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        if states.len() != 2 {
            return exec_err!("multimap_agg merge expects 2 state arrays");
        }
        let keys_list = states[0].as_list::<i32>();
        let values_list = states[1].as_list::<i32>();
        for row in 0..keys_list.len() {
            // A null state row carries no entries.
            if keys_list.is_null(row) || values_list.is_null(row) {
                continue;
            }
            let keys_array = keys_list.value(row);
            let values_array = values_list.value(row);
            for i in 0..keys_array.len() {
                let key = ScalarValue::try_from_array(&keys_array, i)?;
                if key.is_null() {
                    continue;
                }
                let value = ScalarValue::try_from_array(&values_array, i)?;
                match self.key_indices.get(&key) {
                    Some(&idx) => {
                        self.values[idx].push(value);
                    }
                    None => {
                        let idx = self.keys.len();
                        self.keys.push(key.clone());
                        self.values.push(vec![value]);
                        self.key_indices.insert(key, idx);
                    }
                }
            }
        }
        Ok(())
    }

    /// Serializes the accumulator into two parallel single-row lists: the
    /// key is repeated once per value so that `keys[i]` pairs with
    /// `values[i]`. `merge_batch` relies on exactly this layout.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        let key_field = Arc::new(Field::new("item", self.key_type.clone(), true));
        let value_field = Arc::new(Field::new("item", self.value_type.clone(), true));
        if self.keys.is_empty() {
            // Emit one empty list per state column so the row count stays 1;
            // `iter_to_array` cannot infer a type from an empty iterator.
            let empty_keys = datafusion::arrow::array::new_empty_array(&self.key_type);
            let empty_values = datafusion::arrow::array::new_empty_array(&self.value_type);
            return Ok(vec![
                ScalarValue::List(Arc::new(ListArray::new(
                    key_field,
                    OffsetBuffer::from_lengths([0]),
                    empty_keys,
                    None,
                ))),
                ScalarValue::List(Arc::new(ListArray::new(
                    value_field,
                    OffsetBuffer::from_lengths([0]),
                    empty_values,
                    None,
                ))),
            ]);
        }
        // Flatten buckets: repeat each key once per value it owns.
        let mut flat_keys: Vec<ScalarValue> = Vec::new();
        let mut flat_values: Vec<ScalarValue> = Vec::new();
        for (key, vals) in self.keys.iter().zip(self.values.iter()) {
            for val in vals {
                flat_keys.push(key.clone());
                flat_values.push(val.clone());
            }
        }
        let num_entries = flat_keys.len();
        let keys_array = ScalarValue::iter_to_array(flat_keys.into_iter())?;
        let values_array = ScalarValue::iter_to_array(flat_values.into_iter())?;
        Ok(vec![
            ScalarValue::List(Arc::new(ListArray::new(
                key_field,
                OffsetBuffer::from_lengths([num_entries]),
                keys_array,
                None,
            ))),
            ScalarValue::List(Arc::new(ListArray::new(
                value_field,
                OffsetBuffer::from_lengths([num_entries]),
                values_array,
                None,
            ))),
        ])
    }

    /// Produces the final map scalar; keys appear in first-arrival order.
    fn evaluate(&mut self) -> Result<ScalarValue> {
        build_multimap_scalar(&self.keys, &self.values, &self.key_type, &self.value_type)
    }

    /// Approximate in-memory footprint. Counts container capacities only;
    /// heap data owned by individual ScalarValues (e.g. string payloads)
    /// is not included.
    fn size(&self) -> usize {
        size_of_val(self)
            + self.keys.capacity() * std::mem::size_of::<ScalarValue>()
            + self.values.capacity() * std::mem::size_of::<Vec<ScalarValue>>()
            + self
                .values
                .iter()
                .map(|v| v.capacity() * std::mem::size_of::<ScalarValue>())
                .sum::<usize>()
            + self.key_indices.capacity()
                * (std::mem::size_of::<ScalarValue>() + std::mem::size_of::<usize>())
    }
}
/// Accumulator for `hamelin_multimap_agg` with an ORDER BY clause. Rows
/// are buffered unsorted (including null keys); sorting, null-key
/// filtering, and grouping are deferred to `evaluate`.
#[derive(Debug)]
struct OrderSensitiveMultimapAggAccumulator {
    // Buffered key per input row, parallel to `values` and `ordering_values`.
    keys: Vec<ScalarValue>,
    // Buffered value per input row.
    values: Vec<ScalarValue>,
    // Per-row ordering tuple: one ScalarValue per ORDER BY expression.
    ordering_values: Vec<Vec<ScalarValue>>,
    // Arrow types of the key and value columns.
    key_type: DataType,
    value_type: DataType,
    // Types of the ordering columns, parallel to `sort_options`.
    ordering_types: Vec<DataType>,
    // Sort direction / null placement per ordering column.
    sort_options: Vec<SortOptions>,
}
impl OrderSensitiveMultimapAggAccumulator {
    /// Creates an empty accumulator; `ordering_types` and `sort_options`
    /// are parallel, one entry per ORDER BY expression.
    fn new(
        key_type: DataType,
        value_type: DataType,
        ordering_types: Vec<DataType>,
        sort_options: Vec<SortOptions>,
    ) -> Self {
        Self {
            keys: Vec::new(),
            values: Vec::new(),
            ordering_values: Vec::new(),
            key_type,
            value_type,
            ordering_types,
            sort_options,
        }
    }

    /// Stably sorts the buffered rows by the ordering columns, then groups
    /// values under each distinct non-null key. Keys are returned in order
    /// of first occurrence after the sort; within a key, values keep their
    /// sorted order (the stable sort preserves arrival order on ties).
    fn sort_and_group(&mut self) -> Result<(Vec<ScalarValue>, Vec<Vec<ScalarValue>>)> {
        if self.keys.is_empty() {
            return Ok((Vec::new(), Vec::new()));
        }
        let mut indices: Vec<usize> = (0..self.keys.len()).collect();
        let sort_options = &self.sort_options;
        let ordering_values = &self.ordering_values;
        // `sort_by` cannot propagate errors, so stash the first failure
        // from `compare_rows` and surface it after the sort finishes.
        let mut sort_error: Option<datafusion::common::DataFusionError> = None;
        indices.sort_by(|&a, &b| {
            if sort_error.is_some() {
                return Ordering::Equal;
            }
            match compare_rows(&ordering_values[a], &ordering_values[b], sort_options) {
                Ok(ord) => ord,
                Err(e) => {
                    sort_error = Some(e);
                    Ordering::Equal
                }
            }
        });
        if let Some(e) = sort_error {
            return Err(e);
        }
        let mut key_to_values: HashMap<ScalarValue, Vec<ScalarValue>> = HashMap::new();
        let mut key_order: Vec<ScalarValue> = Vec::new();
        for idx in indices {
            let key = &self.keys[idx];
            // Null keys contribute nothing to the final map.
            if key.is_null() {
                continue;
            }
            // One lookup on the hot (repeated-key) path; the previous
            // contains_key + insert + get_mut sequence hashed the key three
            // times and silently discarded the Option from `.map(...)`.
            match key_to_values.get_mut(key) {
                Some(vals) => vals.push(self.values[idx].clone()),
                None => {
                    key_order.push(key.clone());
                    key_to_values.insert(key.clone(), vec![self.values[idx].clone()]);
                }
            }
        }
        // Drain the map in first-seen key order.
        let result_values: Vec<Vec<ScalarValue>> = key_order
            .iter()
            .map(|k| key_to_values.remove(k).unwrap_or_default())
            .collect();
        Ok((key_order, result_values))
    }
}
impl Accumulator for OrderSensitiveMultimapAggAccumulator {
    /// Buffers one batch of rows. DataFusion appends the ORDER BY
    /// expression columns after the two arguments, so `values` is laid out
    /// as [keys, values, ordering...]. Rows are stored unsorted, including
    /// null keys; filtering and sorting happen in `evaluate`.
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if values.len() < 2 {
            return exec_err!(
                "multimap_agg expects at least 2 arguments, got {}",
                values.len()
            );
        }
        let keys_array = &values[0];
        let values_array = &values[1];
        let ordering_arrays = &values[2..];
        for i in 0..keys_array.len() {
            let key = ScalarValue::try_from_array(keys_array, i)?;
            let value = ScalarValue::try_from_array(values_array, i)?;
            let mut row_ordering = Vec::with_capacity(ordering_arrays.len());
            for ord_arr in ordering_arrays {
                row_ordering.push(ScalarValue::try_from_array(ord_arr, i)?);
            }
            self.keys.push(key);
            self.values.push(value);
            self.ordering_values.push(row_ordering);
        }
        Ok(())
    }

    /// Merges partial states produced by `state()`: parallel list columns
    /// [keys, values, ordering_0, ...], one row per partial accumulator.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        let num_ordering_cols = self.ordering_types.len();
        let expected_state_len = 2 + num_ordering_cols;
        if states.len() != expected_state_len {
            return exec_err!(
                "multimap_agg merge expects {} state arrays (2 + {} ordering), got {}",
                expected_state_len,
                num_ordering_cols,
                states.len()
            );
        }
        let keys_list = states[0].as_list::<i32>();
        let values_list = states[1].as_list::<i32>();
        let ordering_lists: Vec<_> = states[2..].iter().map(|s| s.as_list::<i32>()).collect();
        for row in 0..keys_list.len() {
            // A null state row carries no entries.
            if keys_list.is_null(row) || values_list.is_null(row) {
                continue;
            }
            let keys_array = keys_list.value(row);
            let values_array = values_list.value(row);
            let ordering_arrays: Vec<_> =
                ordering_lists.iter().map(|list| list.value(row)).collect();
            for i in 0..keys_array.len() {
                let key = ScalarValue::try_from_array(&keys_array, i)?;
                let value = ScalarValue::try_from_array(&values_array, i)?;
                let mut row_ordering = Vec::with_capacity(num_ordering_cols);
                for ord_arr in &ordering_arrays {
                    row_ordering.push(ScalarValue::try_from_array(ord_arr, i)?);
                }
                self.keys.push(key);
                self.values.push(value);
                self.ordering_values.push(row_ordering);
            }
        }
        Ok(())
    }

    /// Serializes the row buffer into parallel single-row lists — keys,
    /// values, and one list per ordering column — matching `state_fields`
    /// and the layout `merge_batch` expects.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        let num_entries = self.keys.len();
        let key_field = Arc::new(Field::new("item", self.key_type.clone(), true));
        let value_field = Arc::new(Field::new("item", self.value_type.clone(), true));
        if self.keys.is_empty() {
            // Emit one empty list per state column so the row count stays 1;
            // `iter_to_array` cannot infer a type from an empty iterator.
            let empty_keys = datafusion::arrow::array::new_empty_array(&self.key_type);
            let empty_values = datafusion::arrow::array::new_empty_array(&self.value_type);
            let mut state = vec![
                ScalarValue::List(Arc::new(ListArray::new(
                    key_field,
                    OffsetBuffer::from_lengths([0]),
                    empty_keys,
                    None,
                ))),
                ScalarValue::List(Arc::new(ListArray::new(
                    value_field,
                    OffsetBuffer::from_lengths([0]),
                    empty_values,
                    None,
                ))),
            ];
            for ord_type in &self.ordering_types {
                let ord_field = Arc::new(Field::new("item", ord_type.clone(), true));
                let empty_ord = datafusion::arrow::array::new_empty_array(ord_type);
                state.push(ScalarValue::List(Arc::new(ListArray::new(
                    ord_field,
                    OffsetBuffer::from_lengths([0]),
                    empty_ord,
                    None,
                ))));
            }
            return Ok(state);
        }
        let keys_array = ScalarValue::iter_to_array(self.keys.iter().cloned())?;
        let values_array = ScalarValue::iter_to_array(self.values.iter().cloned())?;
        let mut state = vec![
            ScalarValue::List(Arc::new(ListArray::new(
                key_field,
                OffsetBuffer::from_lengths([keys_array.len()]),
                keys_array,
                None,
            ))),
            ScalarValue::List(Arc::new(ListArray::new(
                value_field,
                OffsetBuffer::from_lengths([values_array.len()]),
                values_array,
                None,
            ))),
        ];
        // Transpose row-major ordering tuples into one column per
        // ordering expression.
        for (col_idx, ord_type) in self.ordering_types.iter().enumerate() {
            let ord_field = Arc::new(Field::new("item", ord_type.clone(), true));
            let col_values: Vec<ScalarValue> = self
                .ordering_values
                .iter()
                .map(|row| row[col_idx].clone())
                .collect();
            let ord_array = ScalarValue::iter_to_array(col_values.into_iter())?;
            state.push(ScalarValue::List(Arc::new(ListArray::new(
                ord_field,
                OffsetBuffer::from_lengths([num_entries]),
                ord_array,
                None,
            ))));
        }
        Ok(state)
    }

    /// Sorts and groups the buffered rows, then builds the final map
    /// scalar; keys appear in order of first occurrence after sorting.
    fn evaluate(&mut self) -> Result<ScalarValue> {
        let (keys, grouped_values) = self.sort_and_group()?;
        build_multimap_scalar(&keys, &grouped_values, &self.key_type, &self.value_type)
    }

    /// Approximate in-memory footprint. Counts container capacities only;
    /// heap data owned by individual ScalarValues is not included.
    fn size(&self) -> usize {
        size_of_val(self)
            + self.keys.capacity() * std::mem::size_of::<ScalarValue>()
            + self.values.capacity() * std::mem::size_of::<ScalarValue>()
            // Count both the outer vector's element slots and each row's
            // inner buffer; the previous version omitted the outer
            // `Vec<ScalarValue>` headers and undercounted memory use.
            + self.ordering_values.capacity() * std::mem::size_of::<Vec<ScalarValue>>()
            + self
                .ordering_values
                .iter()
                .map(|v| v.capacity() * std::mem::size_of::<ScalarValue>())
                .sum::<usize>()
    }
}
/// Assembles the final `Map<key, List<value>>` scalar from per-key grouped
/// values. `keys` and `grouped_values` are parallel; returns a null map
/// scalar when there are no keys at all.
fn build_multimap_scalar(
    keys: &[ScalarValue],
    grouped_values: &[Vec<ScalarValue>],
    key_type: &DataType,
    value_type: &DataType,
) -> Result<ScalarValue> {
    let value_list_type = DataType::List(Arc::new(Field::new("item", value_type.clone(), true)));
    if keys.is_empty() {
        return ScalarValue::try_new_null(&map_data_type(key_type.clone(), value_list_type));
    }
    let keys_array = ScalarValue::iter_to_array(keys.iter().cloned())?;
    // Flatten every group's values into one run; list offsets are derived
    // directly from the group lengths.
    let flattened: Vec<ScalarValue> = grouped_values.iter().flatten().cloned().collect();
    let flat_values = if flattened.is_empty() {
        // `iter_to_array` cannot infer a type from an empty iterator.
        datafusion::arrow::array::new_empty_array(value_type)
    } else {
        ScalarValue::iter_to_array(flattened.into_iter())?
    };
    let inner_field = Arc::new(Field::new("item", value_type.clone(), true));
    let values_list_array = ListArray::new(
        inner_field,
        OffsetBuffer::from_lengths(grouped_values.iter().map(|group| group.len())),
        flat_values,
        None,
    );
    // Each map entry is a (non-null key, value list) struct.
    let entry_fields = Fields::from(vec![
        Field::new("key", key_type.clone(), false),
        Field::new("value", value_list_type, true),
    ]);
    let entries = StructArray::new(
        entry_fields,
        vec![keys_array, Arc::new(values_list_array)],
        None,
    );
    // A single map row spanning every entry; `false` = keys are not sorted.
    let map_array = MapArray::new(
        Arc::new(Field::new(
            "entries",
            entries.data_type().clone(),
            false,
        )),
        OffsetBuffer::from_lengths([keys.len()]),
        entries,
        None,
        false,
    );
    ScalarValue::try_from_array(&map_array, 0)
}