use std::any::Any;
use std::cmp::Ordering;
use std::collections::HashSet;
use std::mem::size_of_val;
use std::sync::Arc;
use datafusion::arrow::array::{Array, ArrayRef, AsArray, ListArray, MapArray, StructArray};
use datafusion::arrow::buffer::OffsetBuffer;
use datafusion::arrow::compute::SortOptions;
use datafusion::arrow::datatypes::{DataType, Field, Fields};
use datafusion::common::utils::compare_rows;
use datafusion::common::{exec_err, Result, ScalarValue};
use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion::logical_expr::utils::AggregateOrderSensitivity;
use datafusion::logical_expr::{
Accumulator, AggregateUDF, AggregateUDFImpl, Signature, Volatility,
};
use crate::struct_expansion::map_data_type;
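/// Builds the `hamelin_map_agg` aggregate UDF, which collects `(key, value)`
/// pairs into a map, keeping the first value seen for each distinct non-null
/// key. With an `ORDER BY` on the aggregate, "first" is defined by that
/// ordering instead of arrival order.
///
/// A minimal usage sketch (the table `t` and columns `g`, `k`, `v` are
/// hypothetical; assumes a `datafusion::prelude::SessionContext` named `ctx`):
///
/// ```ignore
/// ctx.register_udaf(map_agg_udaf());
/// let df = ctx
///     .sql("SELECT g, hamelin_map_agg(k, v) AS m FROM t GROUP BY g")
///     .await?;
/// ```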
pub fn map_agg_udaf() -> AggregateUDF {
AggregateUDF::new_from_impl(MapAggUdaf::new())
}
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct MapAggUdaf {
signature: Signature,
}
impl Default for MapAggUdaf {
fn default() -> Self {
Self::new()
}
}
impl MapAggUdaf {
pub fn new() -> Self {
Self {
signature: Signature::any(2, Volatility::Immutable),
}
}
}
impl AggregateUDFImpl for MapAggUdaf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"hamelin_map_agg"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
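        // `map_data_type` (from crate::struct_expansion) is assumed to follow
        // the canonical Arrow map layout that `build_map_scalar` emits below,
        // e.g. (Utf8, Int64) -> Map(Struct<"key": Utf8, "value": Int64>, unsorted).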
Ok(map_data_type(arg_types[0].clone(), arg_types[1].clone()))
}
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Arc<Field>>> {
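        // The intermediate state is one row per partial accumulator, holding
        // parallel lists: "keys", "values", and one "ordering_i" list per
        // ORDER BY column (e.g. with a single Int64 ordering expression the
        // state is List(key_type), List(value_type), List(Int64)).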
let key_type = args.input_fields[0].data_type().clone();
let value_type = args.input_fields[1].data_type().clone();
let mut fields = vec![
Arc::new(Field::new_list(
"keys",
Field::new("item", key_type, true),
true,
)),
Arc::new(Field::new_list(
"values",
Field::new("item", value_type, true),
true,
)),
];
for (i, ord_field) in args.ordering_fields.iter().enumerate() {
fields.push(Arc::new(Field::new_list(
format!("ordering_{}", i),
Field::new("item", ord_field.data_type().clone(), true),
true,
)));
}
Ok(fields)
}
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        let key_type = acc_args.exprs[0].data_type(acc_args.schema)?;
        let value_type = acc_args.exprs[1].data_type(acc_args.schema)?;
if acc_args.order_bys.is_empty() {
Ok(Box::new(MapAggAccumulator::new(key_type, value_type)))
} else {
let sort_options: Vec<SortOptions> =
acc_args.order_bys.iter().map(|e| e.options).collect();
            let ordering_types: Vec<DataType> = acc_args
                .order_bys
                .iter()
                .map(|e| e.expr.data_type(acc_args.schema))
                .collect::<Result<_>>()?;
Ok(Box::new(OrderSensitiveMapAggAccumulator::new(
key_type,
value_type,
ordering_types,
sort_options,
)))
}
}
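    /// A `SoftRequirement`: an input ordering changes which value wins for a
    /// duplicated key, but the aggregate still produces a valid map without
    /// it, so the optimizer may drop the requirement.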
fn order_sensitivity(&self) -> AggregateOrderSensitivity {
AggregateOrderSensitivity::SoftRequirement
}
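    /// Ordering-aware behavior is selected per accumulator in `accumulator`,
    /// so nothing needs to be recorded here; the same implementation is
    /// returned whether or not the planner satisfies the soft requirement.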
fn with_beneficial_ordering(
self: Arc<Self>,
_beneficial_ordering: bool,
) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
Ok(Some(Arc::new(Self {
signature: self.signature.clone(),
})))
}
}
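/// Order-insensitive accumulator: keys and values are buffered in arrival
/// order, and `seen_keys` drops duplicate keys as they stream in, so the
/// first value observed for a key wins.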
#[derive(Debug)]
struct MapAggAccumulator {
keys: Vec<ScalarValue>,
values: Vec<ScalarValue>,
    seen_keys: HashSet<ScalarValue>,
key_type: DataType,
value_type: DataType,
}
impl MapAggAccumulator {
fn new(key_type: DataType, value_type: DataType) -> Self {
Self {
keys: Vec::new(),
values: Vec::new(),
            seen_keys: HashSet::new(),
key_type,
value_type,
}
}
}
impl Accumulator for MapAggAccumulator {
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
if values.len() < 2 {
return exec_err!("map_agg expects at least 2 arguments, got {}", values.len());
}
let keys_array = &values[0];
let values_array = &values[1];
for i in 0..keys_array.len() {
let key = ScalarValue::try_from_array(keys_array, i)?;
if key.is_null() {
continue;
}
            // First writer wins: only record a key the accumulator has not
            // seen before (`HashSet::insert` returns true for new keys).
            if self.seen_keys.insert(key.clone()) {
                let value = ScalarValue::try_from_array(values_array, i)?;
                self.keys.push(key);
                self.values.push(value);
            }
}
Ok(())
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        if states.len() != 2 {
            return exec_err!("map_agg merge expects 2 state arrays, got {}", states.len());
}
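        // Each input row is one partial accumulator's state: a list of keys
        // and a parallel list of values, as produced by `state()` below.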
let keys_list = states[0].as_list::<i32>();
let values_list = states[1].as_list::<i32>();
for row in 0..keys_list.len() {
if keys_list.is_null(row) || values_list.is_null(row) {
continue;
}
let keys_array = keys_list.value(row);
let values_array = values_list.value(row);
for i in 0..keys_array.len() {
let key = ScalarValue::try_from_array(&keys_array, i)?;
if key.is_null() {
continue;
}
                if self.seen_keys.insert(key.clone()) {
                    let value = ScalarValue::try_from_array(&values_array, i)?;
                    self.keys.push(key);
                    self.values.push(value);
                }
}
}
Ok(())
}
fn state(&mut self) -> Result<Vec<ScalarValue>> {
let key_field = Arc::new(Field::new("item", self.key_type.clone(), true));
let value_field = Arc::new(Field::new("item", self.value_type.clone(), true));
if self.keys.is_empty() {
let empty_keys = datafusion::arrow::array::new_empty_array(&self.key_type);
let empty_values = datafusion::arrow::array::new_empty_array(&self.value_type);
return Ok(vec![
ScalarValue::List(Arc::new(ListArray::new(
key_field,
OffsetBuffer::from_lengths([0]),
empty_keys,
None,
))),
ScalarValue::List(Arc::new(ListArray::new(
value_field,
OffsetBuffer::from_lengths([0]),
empty_values,
None,
))),
]);
}
let keys_array = ScalarValue::iter_to_array(self.keys.iter().cloned())?;
let values_array = ScalarValue::iter_to_array(self.values.iter().cloned())?;
Ok(vec![
ScalarValue::List(Arc::new(ListArray::new(
key_field,
OffsetBuffer::from_lengths([keys_array.len()]),
keys_array,
None,
))),
ScalarValue::List(Arc::new(ListArray::new(
value_field,
OffsetBuffer::from_lengths([values_array.len()]),
values_array,
None,
))),
])
}
fn evaluate(&mut self) -> Result<ScalarValue> {
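        // Design choice: an accumulator that saw no non-null keys evaluates
        // to a NULL map rather than an empty map.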
if self.keys.is_empty() {
return ScalarValue::try_new_null(&map_data_type(
self.key_type.clone(),
self.value_type.clone(),
));
}
build_map_scalar(&self.keys, &self.values, &self.key_type, &self.value_type)
}
    fn size(&self) -> usize {
        // Include heap data owned by the buffered ScalarValues, not just
        // their inline headers; `size_of_val(self)` already counts the
        // containers themselves.
        size_of_val(self)
            + ScalarValue::size_of_vec(&self.keys)
            - size_of_val(&self.keys)
            + ScalarValue::size_of_vec(&self.values)
            - size_of_val(&self.values)
            + ScalarValue::size_of_hashset(&self.seen_keys)
            - size_of_val(&self.seen_keys)
    }
}
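/// Order-sensitive accumulator: every row is buffered together with its
/// ORDER BY key values, and sorting plus first-key-wins deduplication are
/// deferred to `state()`/`evaluate()`, since merged partial states are not
/// globally ordered.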
#[derive(Debug)]
struct OrderSensitiveMapAggAccumulator {
keys: Vec<ScalarValue>,
values: Vec<ScalarValue>,
ordering_values: Vec<Vec<ScalarValue>>,
key_type: DataType,
value_type: DataType,
ordering_types: Vec<DataType>,
sort_options: Vec<SortOptions>,
}
impl OrderSensitiveMapAggAccumulator {
fn new(
key_type: DataType,
value_type: DataType,
ordering_types: Vec<DataType>,
sort_options: Vec<SortOptions>,
) -> Self {
Self {
keys: Vec::new(),
values: Vec::new(),
ordering_values: Vec::new(),
key_type,
value_type,
ordering_types,
sort_options,
}
}
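    /// Sorts the buffered rows by the ORDER BY columns, then keeps the first
    /// occurrence of each distinct non-null key, so "first" is defined by the
    /// requested ordering rather than by arrival order.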
fn sort_and_deduplicate(
&mut self,
) -> Result<(Vec<ScalarValue>, Vec<ScalarValue>, Vec<Vec<ScalarValue>>)> {
if self.keys.is_empty() {
return Ok((Vec::new(), Vec::new(), Vec::new()));
}
let mut indices: Vec<usize> = (0..self.keys.len()).collect();
let sort_options = &self.sort_options;
let ordering_values = &self.ordering_values;
let mut sort_error: Option<datafusion::common::DataFusionError> = None;
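        // `sort_by` takes an infallible comparator, so stash the first
        // `compare_rows` error and surface it once the sort finishes.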
indices.sort_by(|&a, &b| {
if sort_error.is_some() {
return Ordering::Equal;
}
match compare_rows(&ordering_values[a], &ordering_values[b], sort_options) {
Ok(ord) => ord,
Err(e) => {
sort_error = Some(e);
Ordering::Equal
}
}
});
if let Some(e) = sort_error {
return Err(e);
}
        let mut seen: HashSet<ScalarValue> = HashSet::new();
        let mut result_keys = Vec::new();
        let mut result_values = Vec::new();
        let mut result_orderings = Vec::new();
        for idx in indices {
            let key = &self.keys[idx];
            // Null keys are never emitted into the map.
            if key.is_null() {
                continue;
            }
            // Rows are visited in sorted order, so the first occurrence of
            // each key under the requested ordering wins.
            if seen.insert(key.clone()) {
                result_keys.push(key.clone());
                result_values.push(self.values[idx].clone());
                result_orderings.push(self.ordering_values[idx].clone());
            }
        }
        Ok((result_keys, result_values, result_orderings))
}
}
impl Accumulator for OrderSensitiveMapAggAccumulator {
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
if values.len() < 2 {
return exec_err!("map_agg expects at least 2 arguments, got {}", values.len());
}
let keys_array = &values[0];
let values_array = &values[1];
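        // DataFusion appends the ORDER BY expressions after the declared
        // arguments, so the ordering columns follow the key and value columns.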
let ordering_arrays = &values[2..];
for i in 0..keys_array.len() {
let key = ScalarValue::try_from_array(keys_array, i)?;
let value = ScalarValue::try_from_array(values_array, i)?;
let mut row_ordering = Vec::with_capacity(ordering_arrays.len());
for ord_arr in ordering_arrays {
row_ordering.push(ScalarValue::try_from_array(ord_arr, i)?);
}
self.keys.push(key);
self.values.push(value);
self.ordering_values.push(row_ordering);
}
Ok(())
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
let num_ordering_cols = self.ordering_types.len();
let expected_state_len = 2 + num_ordering_cols;
if states.len() != expected_state_len {
return exec_err!(
"map_agg merge expects {} state arrays (2 + {} ordering), got {}",
expected_state_len,
num_ordering_cols,
states.len()
);
}
let keys_list = states[0].as_list::<i32>();
let values_list = states[1].as_list::<i32>();
let ordering_lists: Vec<_> = states[2..].iter().map(|s| s.as_list::<i32>()).collect();
for row in 0..keys_list.len() {
if keys_list.is_null(row) || values_list.is_null(row) {
continue;
}
let keys_array = keys_list.value(row);
let values_array = values_list.value(row);
let ordering_arrays: Vec<_> =
ordering_lists.iter().map(|list| list.value(row)).collect();
for i in 0..keys_array.len() {
let key = ScalarValue::try_from_array(&keys_array, i)?;
let value = ScalarValue::try_from_array(&values_array, i)?;
let mut row_ordering = Vec::with_capacity(num_ordering_cols);
for ord_arr in &ordering_arrays {
row_ordering.push(ScalarValue::try_from_array(ord_arr, i)?);
}
self.keys.push(key);
self.values.push(value);
self.ordering_values.push(row_ordering);
}
}
Ok(())
}
fn state(&mut self) -> Result<Vec<ScalarValue>> {
let (keys, values, orderings) = self.sort_and_deduplicate()?;
let num_entries = keys.len();
let key_field = Arc::new(Field::new("item", self.key_type.clone(), true));
let value_field = Arc::new(Field::new("item", self.value_type.clone(), true));
if keys.is_empty() {
let empty_keys = datafusion::arrow::array::new_empty_array(&self.key_type);
let empty_values = datafusion::arrow::array::new_empty_array(&self.value_type);
let mut state = vec![
ScalarValue::List(Arc::new(ListArray::new(
key_field,
OffsetBuffer::from_lengths([0]),
empty_keys,
None,
))),
ScalarValue::List(Arc::new(ListArray::new(
value_field,
OffsetBuffer::from_lengths([0]),
empty_values,
None,
))),
];
for ord_type in &self.ordering_types {
let ord_field = Arc::new(Field::new("item", ord_type.clone(), true));
let empty_ord = datafusion::arrow::array::new_empty_array(ord_type);
state.push(ScalarValue::List(Arc::new(ListArray::new(
ord_field,
OffsetBuffer::from_lengths([0]),
empty_ord,
None,
))));
}
return Ok(state);
}
let keys_array = ScalarValue::iter_to_array(keys.into_iter())?;
let values_array = ScalarValue::iter_to_array(values.into_iter())?;
let mut state = vec![
ScalarValue::List(Arc::new(ListArray::new(
key_field,
OffsetBuffer::from_lengths([keys_array.len()]),
keys_array,
None,
))),
ScalarValue::List(Arc::new(ListArray::new(
value_field,
OffsetBuffer::from_lengths([values_array.len()]),
values_array,
None,
))),
];
for (col_idx, ord_type) in self.ordering_types.iter().enumerate() {
let ord_field = Arc::new(Field::new("item", ord_type.clone(), true));
let col_values: Vec<ScalarValue> =
orderings.iter().map(|row| row[col_idx].clone()).collect();
let ord_array = ScalarValue::iter_to_array(col_values.into_iter())?;
state.push(ScalarValue::List(Arc::new(ListArray::new(
ord_field,
OffsetBuffer::from_lengths([num_entries]),
ord_array,
None,
))));
}
Ok(state)
}
fn evaluate(&mut self) -> Result<ScalarValue> {
let (keys, values, _orderings) = self.sort_and_deduplicate()?;
if keys.is_empty() {
return ScalarValue::try_new_null(&map_data_type(
self.key_type.clone(),
self.value_type.clone(),
));
}
build_map_scalar(&keys, &values, &self.key_type, &self.value_type)
}
    fn size(&self) -> usize {
        size_of_val(self)
            + ScalarValue::size_of_vec(&self.keys)
            - size_of_val(&self.keys)
            + ScalarValue::size_of_vec(&self.values)
            - size_of_val(&self.values)
            + self
                .ordering_values
                .iter()
                .map(ScalarValue::size_of_vec)
                .sum::<usize>()
    }
}
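/// Assembles a single `MapArray` row from parallel key/value slices and
/// returns it as a `ScalarValue`. Keys are declared non-nullable (null keys
/// are filtered out by the accumulators) and the map is flagged as unsorted.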
fn build_map_scalar(
keys: &[ScalarValue],
values: &[ScalarValue],
key_type: &DataType,
value_type: &DataType,
) -> Result<ScalarValue> {
let keys_array = ScalarValue::iter_to_array(keys.iter().cloned())?;
let values_array = ScalarValue::iter_to_array(values.iter().cloned())?;
let struct_fields = Fields::from(vec![
Field::new("key", key_type.clone(), false),
Field::new("value", value_type.clone(), true),
]);
let struct_array = StructArray::new(struct_fields, vec![keys_array, values_array], None);
let map_array = MapArray::new(
Arc::new(Field::new(
"entries",
struct_array.data_type().clone(),
false,
)),
OffsetBuffer::from_lengths([keys.len()]),
struct_array,
None,
false,
);
ScalarValue::try_from_array(&map_array, 0)
}
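
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::arrow::array::{Int64Array, StringArray};

    // A minimal sketch of the first-writer-wins and null-key rules, driving
    // the order-insensitive accumulator directly rather than through a plan.
    #[test]
    fn first_value_wins_per_key() -> Result<()> {
        let mut acc = MapAggAccumulator::new(DataType::Utf8, DataType::Int64);
        let keys: ArrayRef = Arc::new(StringArray::from(vec![
            Some("a"),
            Some("b"),
            None,
            Some("a"),
        ]));
        let values: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3, 4]));
        acc.update_batch(&[keys, values])?;
        // The null key is skipped and the duplicate "a" keeps its first value.
        assert_eq!(acc.keys.len(), 2);
        assert_eq!(acc.values[0], ScalarValue::Int64(Some(1)));
        assert_eq!(acc.keys[1], ScalarValue::from("b"));
        Ok(())
    }
}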