use std::{iter::zip, sync::Arc};
use arrow_array::{ArrayRef, BooleanArray, RecordBatch};
use arrow_schema::{DataType, FieldRef, Schema};
use datafusion_common::{
arrow::compute::kernels::concat::concat, config::ConfigOptions, Result, ScalarValue,
};
use datafusion_expr::{
function::{AccumulatorArgs, StateFieldsArgs},
Accumulator, AggregateUDF, ColumnarValue, EmitTo, Expr, GroupsAccumulator, Literal,
ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF,
};
use datafusion_physical_expr::{expressions::Column, PhysicalExpr};
use sedona_common::{sedona_internal_err, SedonaOptions};
use sedona_schema::datatypes::SedonaType;
use crate::{
compare::assert_scalar_equal,
create::{create_array, create_scalar},
};
pub struct AggregateUdfTester {
udf: AggregateUDF,
arg_types: Vec<SedonaType>,
mock_schema: Schema,
mock_exprs: Vec<Arc<dyn PhysicalExpr>>,
}
impl AggregateUdfTester {
pub fn new(udf: AggregateUDF, arg_types: Vec<SedonaType>) -> Self {
let arg_fields = arg_types
.iter()
.map(|sedona_type| sedona_type.to_storage_field("", true).map(Arc::new))
.collect::<Result<Vec<_>>>()
.unwrap();
let mock_schema = Schema::new(arg_fields);
let mock_exprs = (0..arg_types.len())
.map(|i| -> Arc<dyn PhysicalExpr> { Arc::new(Column::new("col", i)) })
.collect::<Vec<_>>();
Self {
udf,
arg_types,
mock_schema,
mock_exprs,
}
}
pub fn return_type(&self) -> Result<SedonaType> {
let out_field = self.udf.return_field(&self.mock_schema.fields)?;
SedonaType::from_storage_field(&out_field)
}
pub fn aggregate_wkt(&self, batches: Vec<Vec<Option<&str>>>) -> Result<ScalarValue> {
let batches_array = batches
.into_iter()
.map(|batch| create_array(&batch, &self.arg_types[0]))
.collect::<Vec<_>>();
self.aggregate(&batches_array)
}
pub fn aggregate(&self, batches: &Vec<ArrayRef>) -> Result<ScalarValue> {
let state_schema = Arc::new(Schema::new(self.state_fields()?));
let mut state_accumulator = self.new_accumulator()?;
for batch in batches {
let mut batch_accumulator = self.new_accumulator()?;
batch_accumulator.update_batch(std::slice::from_ref(batch))?;
let state_batch_of_one = RecordBatch::try_new(
state_schema.clone(),
batch_accumulator
.state()?
.into_iter()
.map(|v| v.to_array())
.collect::<Result<Vec<_>>>()?,
)?;
state_accumulator.merge_batch(state_batch_of_one.columns())?;
}
state_accumulator.evaluate()
}
pub fn aggregate_groups(
&self,
batches: &Vec<ArrayRef>,
group_indices: Vec<usize>,
opt_filter: Option<&Vec<bool>>,
emit_sizes: Vec<usize>,
) -> Result<ArrayRef> {
let state_schema = Arc::new(Schema::new(self.state_fields()?));
let mut state_accumulator = self.new_groups_accumulator()?;
let total_num_groups = group_indices.iter().max().unwrap_or(&0) + 1;
let total_input_rows: usize = batches.iter().map(|a| a.len()).sum();
assert_eq!(total_input_rows, group_indices.len());
if let Some(filter) = opt_filter {
assert_eq!(total_input_rows, filter.len());
}
if !emit_sizes.is_empty() {
assert_eq!(emit_sizes.iter().sum::<usize>(), total_num_groups);
}
let mut offset = 0;
for batch in batches {
let mut batch_accumulator = self.new_groups_accumulator()?;
let opt_filter_array = opt_filter.map(|filter_vec| {
filter_vec[offset..(offset + batch.len())]
.iter()
.collect::<BooleanArray>()
});
batch_accumulator.update_batch(
std::slice::from_ref(batch),
&group_indices[offset..(offset + batch.len())],
opt_filter_array.as_ref(),
total_num_groups,
)?;
offset += batch.len();
let state_batch = RecordBatch::try_new(
state_schema.clone(),
batch_accumulator.state(datafusion_expr::EmitTo::All)?,
)?;
state_accumulator.merge_batch(
state_batch.columns(),
&(0..total_num_groups).collect::<Vec<_>>(),
None,
total_num_groups,
)?;
}
if emit_sizes.is_empty() {
state_accumulator.evaluate(datafusion_expr::EmitTo::All)
} else {
let arrays = emit_sizes
.iter()
.map(|emit_size| state_accumulator.evaluate(EmitTo::First(*emit_size)))
.collect::<Result<Vec<_>>>()?;
let arrays_ref = arrays.iter().map(|a| a.as_ref()).collect::<Vec<_>>();
Ok(concat(&arrays_ref)?)
}
}
fn new_accumulator(&self) -> Result<Box<dyn Accumulator>> {
let accumulator_args = self.accumulator_args()?;
self.udf.accumulator(accumulator_args)
}
fn new_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
assert!(self
.udf
.groups_accumulator_supported(self.accumulator_args()?));
self.udf.create_groups_accumulator(self.accumulator_args()?)
}
fn accumulator_args(&self) -> Result<AccumulatorArgs<'_>> {
Ok(AccumulatorArgs {
return_field: self.udf.return_field(self.mock_schema.fields())?,
schema: &self.mock_schema,
ignore_nulls: true,
order_bys: &[],
is_reversed: false,
name: "",
is_distinct: false,
exprs: &self.mock_exprs,
expr_fields: &[],
})
}
fn state_fields(&self) -> Result<Vec<FieldRef>> {
let state_field_args = StateFieldsArgs {
name: "",
input_fields: self.mock_schema.fields(),
return_field: self.udf.return_field(self.mock_schema.fields())?,
ordering_fields: &[],
is_distinct: false,
};
self.udf.state_fields(state_field_args)
}
}
pub struct ScalarUdfTester {
udf: ScalarUDF,
arg_types: Vec<SedonaType>,
config_options: Arc<ConfigOptions>,
}
impl ScalarUdfTester {
pub fn new(udf: ScalarUDF, arg_types: Vec<SedonaType>) -> Self {
let mut config_options = ConfigOptions::default();
let sedona_options = SedonaOptions::default();
config_options.extensions.insert(sedona_options);
Self {
udf,
arg_types,
config_options: Arc::new(config_options),
}
}
pub fn config_options(&self) -> &ConfigOptions {
&self.config_options
}
pub fn config_options_mut(&mut self) -> &mut ConfigOptions {
Arc::get_mut(&mut self.config_options).expect("ConfigOptions is shared")
}
pub fn sedona_options(&self) -> &SedonaOptions {
self.config_options
.extensions
.get::<SedonaOptions>()
.expect("SedonaOptions does not exist")
}
pub fn sedona_options_mut(&mut self) -> &mut SedonaOptions {
self.config_options_mut()
.extensions
.get_mut::<SedonaOptions>()
.expect("SedonaOptions does not exist")
}
pub fn assert_return_type(&self, data_type: impl TryInto<SedonaType>) {
let expected = match data_type.try_into() {
Ok(t) => t,
Err(_) => panic!("Failed to convert to SedonaType"),
};
assert_eq!(self.return_type().unwrap(), expected)
}
pub fn assert_scalar_result_equals(&self, actual: impl Literal, expected: impl Literal) {
self.assert_scalar_result_equals_inner(actual, expected, None);
}
pub fn assert_scalar_result_equals_with_return_type(
&self,
actual: impl Literal,
expected: impl Literal,
return_type: SedonaType,
) {
self.assert_scalar_result_equals_inner(actual, expected, Some(return_type));
}
fn assert_scalar_result_equals_inner(
&self,
actual: impl Literal,
expected: impl Literal,
return_type: Option<SedonaType>,
) {
let return_type = return_type.unwrap_or_else(|| self.return_type().unwrap());
let actual = Self::scalar_lit(actual, &return_type).unwrap();
let expected = Self::scalar_lit(expected, &return_type).unwrap();
assert_scalar_equal(&actual, &expected);
}
pub fn return_type(&self) -> Result<SedonaType> {
let scalar_arguments = vec![None; self.arg_types.len()];
self.return_type_with_scalars_inner(&scalar_arguments)
}
pub fn return_type_with_scalar(&self, arg0: Option<impl Literal>) -> Result<SedonaType> {
let scalar_arguments = vec![arg0
.map(|x| Self::scalar_lit(x, &self.arg_types[0]))
.transpose()?];
self.return_type_with_scalars_inner(&scalar_arguments)
}
pub fn return_type_with_scalar_scalar(
&self,
arg0: Option<impl Literal>,
arg1: Option<impl Literal>,
) -> Result<SedonaType> {
let scalar_arguments = vec![
arg0.map(|x| Self::scalar_lit(x, &self.arg_types[0]))
.transpose()?,
arg1.map(|x| Self::scalar_lit(x, &self.arg_types[1]))
.transpose()?,
];
self.return_type_with_scalars_inner(&scalar_arguments)
}
pub fn return_type_with_scalar_scalar_scalar(
&self,
arg0: Option<impl Literal>,
arg1: Option<impl Literal>,
arg2: Option<impl Literal>,
) -> Result<SedonaType> {
let scalar_arguments = vec![
arg0.map(|x| Self::scalar_lit(x, &self.arg_types[0]))
.transpose()?,
arg1.map(|x| Self::scalar_lit(x, &self.arg_types[1]))
.transpose()?,
arg2.map(|x| Self::scalar_lit(x, &self.arg_types[2]))
.transpose()?,
];
self.return_type_with_scalars_inner(&scalar_arguments)
}
fn return_type_with_scalars_inner(
&self,
scalar_arguments: &[Option<ScalarValue>],
) -> Result<SedonaType> {
let arg_fields = self
.arg_types
.iter()
.map(|sedona_type| sedona_type.to_storage_field("", true).map(Arc::new))
.collect::<Result<Vec<_>>>()?;
let scalar_arguments_ref: Vec<Option<&ScalarValue>> =
scalar_arguments.iter().map(|x| x.as_ref()).collect();
let args = ReturnFieldArgs {
arg_fields: &arg_fields,
scalar_arguments: &scalar_arguments_ref,
};
let return_field = self.udf.return_field_from_args(args)?;
SedonaType::from_storage_field(&return_field)
}
pub fn invoke_scalar(&self, arg: impl Literal) -> Result<ScalarValue> {
let scalar_arg = Self::scalar_lit(arg, &self.arg_types[0])?;
let return_type = self
.return_type_with_scalars_inner(&[Some(scalar_arg.clone())])
.ok();
let args = vec![ColumnarValue::Scalar(scalar_arg)];
if let ColumnarValue::Scalar(scalar) = self.invoke_with_return_type(args, return_type)? {
Ok(scalar)
} else {
sedona_internal_err!("Expected scalar result from scalar invoke")
}
}
pub fn invoke_wkb_scalar(&self, wkt_value: Option<&str>) -> Result<ScalarValue> {
self.invoke_scalar(create_scalar(wkt_value, &self.arg_types[0]))
}
pub fn invoke_scalar_scalar<T0: Literal, T1: Literal>(
&self,
arg0: T0,
arg1: T1,
) -> Result<ScalarValue> {
let scalar_arg0 = Self::scalar_lit(arg0, &self.arg_types[0])?;
let scalar_arg1 = Self::scalar_lit(arg1, &self.arg_types[1])?;
let return_type = self
.return_type_with_scalars_inner(&[Some(scalar_arg0.clone()), Some(scalar_arg1.clone())])
.ok();
let args = vec![
ColumnarValue::Scalar(scalar_arg0),
ColumnarValue::Scalar(scalar_arg1),
];
if let ColumnarValue::Scalar(scalar) = self.invoke_with_return_type(args, return_type)? {
Ok(scalar)
} else {
sedona_internal_err!("Expected scalar result from binary scalar invoke")
}
}
pub fn invoke_scalar_scalar_scalar<T0: Literal, T1: Literal, T2: Literal>(
&self,
arg0: T0,
arg1: T1,
arg2: T2,
) -> Result<ScalarValue> {
let scalar_arg0 = Self::scalar_lit(arg0, &self.arg_types[0])?;
let scalar_arg1 = Self::scalar_lit(arg1, &self.arg_types[1])?;
let scalar_arg2 = Self::scalar_lit(arg2, &self.arg_types[2])?;
let return_type = self
.return_type_with_scalars_inner(&[
Some(scalar_arg0.clone()),
Some(scalar_arg1.clone()),
Some(scalar_arg2.clone()),
])
.ok();
let args = vec![
ColumnarValue::Scalar(scalar_arg0),
ColumnarValue::Scalar(scalar_arg1),
ColumnarValue::Scalar(scalar_arg2),
];
if let ColumnarValue::Scalar(scalar) = self.invoke_with_return_type(args, return_type)? {
Ok(scalar)
} else {
sedona_internal_err!("Expected scalar result from binary scalar invoke")
}
}
pub fn invoke_wkb_array(&self, wkb_values: Vec<Option<&str>>) -> Result<ArrayRef> {
self.invoke_array(create_array(&wkb_values, &self.arg_types[0]))
}
pub fn invoke_wkb_array_scalar(
&self,
wkb_values: Vec<Option<&str>>,
arg: impl Literal,
) -> Result<ArrayRef> {
let wkb_array = create_array(&wkb_values, &self.arg_types[0]);
self.invoke_arrays_scalar(vec![wkb_array], arg)
}
pub fn invoke_array(&self, array: ArrayRef) -> Result<ArrayRef> {
self.invoke_arrays(vec![array])
}
pub fn invoke_array_scalar(&self, array: ArrayRef, arg: impl Literal) -> Result<ArrayRef> {
self.invoke_arrays_scalar(vec![array], arg)
}
pub fn invoke_array_scalar_scalar(
&self,
array: ArrayRef,
arg0: impl Literal,
arg1: impl Literal,
) -> Result<ArrayRef> {
self.invoke_arrays_scalar_scalar(vec![array], arg0, arg1)
}
pub fn invoke_scalar_array(&self, arg: impl Literal, array: ArrayRef) -> Result<ArrayRef> {
self.invoke_scalar_arrays(arg, vec![array])
}
pub fn invoke_array_array(&self, array0: ArrayRef, array1: ArrayRef) -> Result<ArrayRef> {
self.invoke_arrays(vec![array0, array1])
}
pub fn invoke_array_array_scalar(
&self,
array0: ArrayRef,
array1: ArrayRef,
arg: impl Literal,
) -> Result<ArrayRef> {
self.invoke_arrays_scalar(vec![array0, array1], arg)
}
fn invoke_scalar_arrays(&self, arg: impl Literal, arrays: Vec<ArrayRef>) -> Result<ArrayRef> {
let mut args = zip(arrays, &self.arg_types)
.map(|(array, sedona_type)| {
ColumnarValue::Array(array).cast_to(sedona_type.storage_type(), None)
})
.collect::<Result<Vec<_>>>()?;
let index = args.len();
args.insert(0, Self::scalar_arg(arg, &self.arg_types[index])?);
if let ColumnarValue::Array(array) = self.invoke(args)? {
Ok(array)
} else {
sedona_internal_err!("Expected array result from scalar/array invoke")
}
}
fn invoke_arrays_scalar(&self, arrays: Vec<ArrayRef>, arg: impl Literal) -> Result<ArrayRef> {
let mut args = zip(arrays, &self.arg_types)
.map(|(array, sedona_type)| {
ColumnarValue::Array(array).cast_to(sedona_type.storage_type(), None)
})
.collect::<Result<Vec<_>>>()?;
let index = args.len();
args.push(Self::scalar_arg(arg, &self.arg_types[index])?);
if let ColumnarValue::Array(array) = self.invoke(args)? {
Ok(array)
} else {
sedona_internal_err!("Expected array result from array/scalar invoke")
}
}
fn invoke_arrays_scalar_scalar(
&self,
arrays: Vec<ArrayRef>,
arg0: impl Literal,
arg1: impl Literal,
) -> Result<ArrayRef> {
let mut args = zip(arrays, &self.arg_types)
.map(|(array, sedona_type)| {
ColumnarValue::Array(array).cast_to(sedona_type.storage_type(), None)
})
.collect::<Result<Vec<_>>>()?;
let index = args.len();
args.push(Self::scalar_arg(arg0, &self.arg_types[index])?);
args.push(Self::scalar_arg(arg1, &self.arg_types[index + 1])?);
if let ColumnarValue::Array(array) = self.invoke(args)? {
Ok(array)
} else {
sedona_internal_err!("Expected array result from array/scalar invoke")
}
}
pub fn invoke_arrays(&self, arrays: Vec<ArrayRef>) -> Result<ArrayRef> {
let args = zip(arrays, &self.arg_types)
.map(|(array, sedona_type)| {
ColumnarValue::Array(array).cast_to(sedona_type.storage_type(), None)
})
.collect::<Result<_>>()?;
if let ColumnarValue::Array(array) = self.invoke(args)? {
Ok(array)
} else {
sedona_internal_err!("Expected array result from array invoke")
}
}
pub fn invoke(&self, args: Vec<ColumnarValue>) -> Result<ColumnarValue> {
let scalar_args = args
.iter()
.map(|arg| match arg {
ColumnarValue::Array(_) => None,
ColumnarValue::Scalar(scalar_value) => Some(scalar_value.clone()),
})
.collect::<Vec<_>>();
let return_type = self.return_type_with_scalars_inner(&scalar_args)?;
self.invoke_with_return_type(args, Some(return_type))
}
pub fn invoke_with_return_type(
&self,
args: Vec<ColumnarValue>,
return_type: Option<SedonaType>,
) -> Result<ColumnarValue> {
assert_eq!(args.len(), self.arg_types.len(), "Unexpected arg length");
let mut number_rows = 1;
for arg in &args {
match arg {
ColumnarValue::Array(array) => {
number_rows = array.len();
break;
}
_ => continue,
}
}
let return_type = match return_type {
Some(return_type) => return_type,
None => self.return_type()?,
};
let args = ScalarFunctionArgs {
args,
arg_fields: self.arg_fields(),
number_rows,
return_field: return_type.to_storage_field("", true)?.into(),
config_options: Arc::clone(&self.config_options),
};
self.udf.invoke_with_args(args)
}
fn scalar_arg(arg: impl Literal, sedona_type: &SedonaType) -> Result<ColumnarValue> {
Ok(ColumnarValue::Scalar(Self::scalar_lit(arg, sedona_type)?))
}
fn scalar_lit(arg: impl Literal, sedona_type: &SedonaType) -> Result<ScalarValue> {
if let Expr::Literal(scalar, _) = arg.lit() {
let is_geometry_or_geography = match sedona_type {
SedonaType::Wkb(_, _) | SedonaType::WkbView(_, _) => true,
SedonaType::Arrow(DataType::Struct(fields))
if fields.iter().map(|f| f.name()).collect::<Vec<_>>()
== vec!["item", "crs"] =>
{
true
}
_ => false,
};
if is_geometry_or_geography {
if let ScalarValue::Utf8(expected_wkt) = scalar {
Ok(create_scalar(expected_wkt.as_deref(), sedona_type))
} else if &scalar.data_type() == sedona_type.storage_type() {
Ok(scalar)
} else if scalar.is_null() {
Ok(create_scalar(None, sedona_type))
} else {
sedona_internal_err!("Can't interpret scalar {scalar} as type {sedona_type}")
}
} else {
scalar.cast_to(sedona_type.storage_type())
}
} else {
sedona_internal_err!("Can't use test scalar invoke where .lit() returns non-literal")
}
}
fn arg_fields(&self) -> Vec<FieldRef> {
self.arg_types
.iter()
.map(|data_type| data_type.to_storage_field("", false).map(Arc::new))
.collect::<Result<Vec<_>>>()
.unwrap()
}
}