1use std::{iter::zip, sync::Arc};
18
19use arrow_array::{ArrayRef, BooleanArray, RecordBatch};
20use arrow_schema::{DataType, FieldRef, Schema};
21use datafusion_common::{
22 arrow::compute::kernels::concat::concat, config::ConfigOptions, Result, ScalarValue,
23};
24use datafusion_expr::{
25 function::{AccumulatorArgs, StateFieldsArgs},
26 Accumulator, AggregateUDF, ColumnarValue, EmitTo, Expr, GroupsAccumulator, Literal,
27 ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF,
28};
29use datafusion_physical_expr::{expressions::Column, PhysicalExpr};
30use sedona_common::{sedona_internal_err, SedonaOptions};
31use sedona_schema::datatypes::SedonaType;
32
33use crate::{
34 compare::assert_scalar_equal,
35 create::{create_array, create_scalar},
36};
37
38pub struct AggregateUdfTester {
52 udf: AggregateUDF,
53 arg_types: Vec<SedonaType>,
54 mock_schema: Schema,
55 mock_exprs: Vec<Arc<dyn PhysicalExpr>>,
56}
57
58impl AggregateUdfTester {
59 pub fn new(udf: AggregateUDF, arg_types: Vec<SedonaType>) -> Self {
61 let arg_fields = arg_types
62 .iter()
63 .map(|sedona_type| sedona_type.to_storage_field("", true).map(Arc::new))
64 .collect::<Result<Vec<_>>>()
65 .unwrap();
66 let mock_schema = Schema::new(arg_fields);
67
68 let mock_exprs = (0..arg_types.len())
69 .map(|i| -> Arc<dyn PhysicalExpr> { Arc::new(Column::new("col", i)) })
70 .collect::<Vec<_>>();
71 Self {
72 udf,
73 arg_types,
74 mock_schema,
75 mock_exprs,
76 }
77 }
78
79 pub fn return_type(&self) -> Result<SedonaType> {
81 let out_field = self.udf.return_field(&self.mock_schema.fields)?;
82 SedonaType::from_storage_field(&out_field)
83 }
84
85 pub fn aggregate_wkt(&self, batches: Vec<Vec<Option<&str>>>) -> Result<ScalarValue> {
87 let batches_array = batches
88 .into_iter()
89 .map(|batch| create_array(&batch, &self.arg_types[0]))
90 .collect::<Vec<_>>();
91 self.aggregate(&batches_array)
92 }
93
94 pub fn aggregate(&self, batches: &Vec<ArrayRef>) -> Result<ScalarValue> {
101 let state_schema = Arc::new(Schema::new(self.state_fields()?));
102 let mut state_accumulator = self.new_accumulator()?;
103
104 for batch in batches {
105 let mut batch_accumulator = self.new_accumulator()?;
106 batch_accumulator.update_batch(std::slice::from_ref(batch))?;
107 let state_batch_of_one = RecordBatch::try_new(
108 state_schema.clone(),
109 batch_accumulator
110 .state()?
111 .into_iter()
112 .map(|v| v.to_array())
113 .collect::<Result<Vec<_>>>()?,
114 )?;
115 state_accumulator.merge_batch(state_batch_of_one.columns())?;
116 }
117
118 state_accumulator.evaluate()
119 }
120
121 pub fn aggregate_groups(
129 &self,
130 batches: &Vec<ArrayRef>,
131 group_indices: Vec<usize>,
132 opt_filter: Option<&Vec<bool>>,
133 emit_sizes: Vec<usize>,
134 ) -> Result<ArrayRef> {
135 let state_schema = Arc::new(Schema::new(self.state_fields()?));
136 let mut state_accumulator = self.new_groups_accumulator()?;
137 let total_num_groups = group_indices.iter().max().unwrap_or(&0) + 1;
138
139 let total_input_rows: usize = batches.iter().map(|a| a.len()).sum();
141 assert_eq!(total_input_rows, group_indices.len());
142 if let Some(filter) = opt_filter {
143 assert_eq!(total_input_rows, filter.len());
144 }
145 if !emit_sizes.is_empty() {
146 assert_eq!(emit_sizes.iter().sum::<usize>(), total_num_groups);
147 }
148
149 let mut offset = 0;
150 for batch in batches {
151 let mut batch_accumulator = self.new_groups_accumulator()?;
152 let opt_filter_array = opt_filter.map(|filter_vec| {
153 filter_vec[offset..(offset + batch.len())]
154 .iter()
155 .collect::<BooleanArray>()
156 });
157 batch_accumulator.update_batch(
158 std::slice::from_ref(batch),
159 &group_indices[offset..(offset + batch.len())],
160 opt_filter_array.as_ref(),
161 total_num_groups,
162 )?;
163 offset += batch.len();
164
165 let state_batch = RecordBatch::try_new(
168 state_schema.clone(),
169 batch_accumulator.state(datafusion_expr::EmitTo::All)?,
170 )?;
171 state_accumulator.merge_batch(
172 state_batch.columns(),
173 &(0..total_num_groups).collect::<Vec<_>>(),
174 None,
175 total_num_groups,
176 )?;
177 }
178
179 if emit_sizes.is_empty() {
180 state_accumulator.evaluate(datafusion_expr::EmitTo::All)
181 } else {
182 let arrays = emit_sizes
183 .iter()
184 .map(|emit_size| state_accumulator.evaluate(EmitTo::First(*emit_size)))
185 .collect::<Result<Vec<_>>>()?;
186 let arrays_ref = arrays.iter().map(|a| a.as_ref()).collect::<Vec<_>>();
187 Ok(concat(&arrays_ref)?)
188 }
189 }
190
191 fn new_accumulator(&self) -> Result<Box<dyn Accumulator>> {
192 let accumulator_args = self.accumulator_args()?;
193 self.udf.accumulator(accumulator_args)
194 }
195
196 fn new_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
197 assert!(self
198 .udf
199 .groups_accumulator_supported(self.accumulator_args()?));
200 self.udf.create_groups_accumulator(self.accumulator_args()?)
201 }
202
203 fn accumulator_args(&self) -> Result<AccumulatorArgs<'_>> {
204 Ok(AccumulatorArgs {
205 return_field: self.udf.return_field(self.mock_schema.fields())?,
206 schema: &self.mock_schema,
207 ignore_nulls: true,
208 order_bys: &[],
209 is_reversed: false,
210 name: "",
211 is_distinct: false,
212 exprs: &self.mock_exprs,
213 expr_fields: &[],
214 })
215 }
216
217 fn state_fields(&self) -> Result<Vec<FieldRef>> {
218 let state_field_args = StateFieldsArgs {
219 name: "",
220 input_fields: self.mock_schema.fields(),
221 return_field: self.udf.return_field(self.mock_schema.fields())?,
222 ordering_fields: &[],
223 is_distinct: false,
224 };
225 self.udf.state_fields(state_field_args)
226 }
227}
228
229pub struct ScalarUdfTester {
241 udf: ScalarUDF,
242 arg_types: Vec<SedonaType>,
243 config_options: Arc<ConfigOptions>,
244}
245
246impl ScalarUdfTester {
247 pub fn new(udf: ScalarUDF, arg_types: Vec<SedonaType>) -> Self {
249 let mut config_options = ConfigOptions::default();
250 let sedona_options = SedonaOptions::default();
251 config_options.extensions.insert(sedona_options);
252 Self {
253 udf,
254 arg_types,
255 config_options: Arc::new(config_options),
256 }
257 }
258
259 pub fn config_options(&self) -> &ConfigOptions {
264 &self.config_options
265 }
266
267 pub fn config_options_mut(&mut self) -> &mut ConfigOptions {
272 Arc::get_mut(&mut self.config_options).expect("ConfigOptions is shared")
274 }
275
276 pub fn sedona_options(&self) -> &SedonaOptions {
278 self.config_options
279 .extensions
280 .get::<SedonaOptions>()
281 .expect("SedonaOptions does not exist")
282 }
283
284 pub fn sedona_options_mut(&mut self) -> &mut SedonaOptions {
286 self.config_options_mut()
287 .extensions
288 .get_mut::<SedonaOptions>()
289 .expect("SedonaOptions does not exist")
290 }
291
292 pub fn assert_return_type(&self, data_type: impl TryInto<SedonaType>) {
298 let expected = match data_type.try_into() {
299 Ok(t) => t,
300 Err(_) => panic!("Failed to convert to SedonaType"),
301 };
302 assert_eq!(self.return_type().unwrap(), expected)
303 }
304
305 pub fn assert_scalar_result_equals(&self, actual: impl Literal, expected: impl Literal) {
310 self.assert_scalar_result_equals_inner(actual, expected, None);
311 }
312
313 pub fn assert_scalar_result_equals_with_return_type(
317 &self,
318 actual: impl Literal,
319 expected: impl Literal,
320 return_type: SedonaType,
321 ) {
322 self.assert_scalar_result_equals_inner(actual, expected, Some(return_type));
323 }
324
325 fn assert_scalar_result_equals_inner(
326 &self,
327 actual: impl Literal,
328 expected: impl Literal,
329 return_type: Option<SedonaType>,
330 ) {
331 let return_type = return_type.unwrap_or_else(|| self.return_type().unwrap());
332 let actual = Self::scalar_lit(actual, &return_type).unwrap();
333 let expected = Self::scalar_lit(expected, &return_type).unwrap();
334 assert_scalar_equal(&actual, &expected);
335 }
336
337 pub fn return_type(&self) -> Result<SedonaType> {
339 let scalar_arguments = vec![None; self.arg_types.len()];
340 self.return_type_with_scalars_inner(&scalar_arguments)
341 }
342
343 pub fn return_type_with_scalar(&self, arg0: Option<impl Literal>) -> Result<SedonaType> {
347 let scalar_arguments = vec![arg0
348 .map(|x| Self::scalar_lit(x, &self.arg_types[0]))
349 .transpose()?];
350 self.return_type_with_scalars_inner(&scalar_arguments)
351 }
352
353 pub fn return_type_with_scalar_scalar(
357 &self,
358 arg0: Option<impl Literal>,
359 arg1: Option<impl Literal>,
360 ) -> Result<SedonaType> {
361 let scalar_arguments = vec![
362 arg0.map(|x| Self::scalar_lit(x, &self.arg_types[0]))
363 .transpose()?,
364 arg1.map(|x| Self::scalar_lit(x, &self.arg_types[1]))
365 .transpose()?,
366 ];
367 self.return_type_with_scalars_inner(&scalar_arguments)
368 }
369
370 pub fn return_type_with_scalar_scalar_scalar(
374 &self,
375 arg0: Option<impl Literal>,
376 arg1: Option<impl Literal>,
377 arg2: Option<impl Literal>,
378 ) -> Result<SedonaType> {
379 let scalar_arguments = vec![
380 arg0.map(|x| Self::scalar_lit(x, &self.arg_types[0]))
381 .transpose()?,
382 arg1.map(|x| Self::scalar_lit(x, &self.arg_types[1]))
383 .transpose()?,
384 arg2.map(|x| Self::scalar_lit(x, &self.arg_types[2]))
385 .transpose()?,
386 ];
387 self.return_type_with_scalars_inner(&scalar_arguments)
388 }
389
390 fn return_type_with_scalars_inner(
391 &self,
392 scalar_arguments: &[Option<ScalarValue>],
393 ) -> Result<SedonaType> {
394 let arg_fields = self
395 .arg_types
396 .iter()
397 .map(|sedona_type| sedona_type.to_storage_field("", true).map(Arc::new))
398 .collect::<Result<Vec<_>>>()?;
399
400 let scalar_arguments_ref: Vec<Option<&ScalarValue>> =
401 scalar_arguments.iter().map(|x| x.as_ref()).collect();
402 let args = ReturnFieldArgs {
403 arg_fields: &arg_fields,
404 scalar_arguments: &scalar_arguments_ref,
405 };
406 let return_field = self.udf.return_field_from_args(args)?;
407 SedonaType::from_storage_field(&return_field)
408 }
409
410 pub fn invoke_scalar(&self, arg: impl Literal) -> Result<ScalarValue> {
412 let scalar_arg = Self::scalar_lit(arg, &self.arg_types[0])?;
413
414 let return_type = self
416 .return_type_with_scalars_inner(&[Some(scalar_arg.clone())])
417 .ok();
418
419 let args = vec![ColumnarValue::Scalar(scalar_arg)];
420 if let ColumnarValue::Scalar(scalar) = self.invoke_with_return_type(args, return_type)? {
421 Ok(scalar)
422 } else {
423 sedona_internal_err!("Expected scalar result from scalar invoke")
424 }
425 }
426
427 pub fn invoke_wkb_scalar(&self, wkt_value: Option<&str>) -> Result<ScalarValue> {
429 self.invoke_scalar(create_scalar(wkt_value, &self.arg_types[0]))
430 }
431
432 pub fn invoke_scalar_scalar<T0: Literal, T1: Literal>(
434 &self,
435 arg0: T0,
436 arg1: T1,
437 ) -> Result<ScalarValue> {
438 let scalar_arg0 = Self::scalar_lit(arg0, &self.arg_types[0])?;
439 let scalar_arg1 = Self::scalar_lit(arg1, &self.arg_types[1])?;
440
441 let return_type = self
443 .return_type_with_scalars_inner(&[Some(scalar_arg0.clone()), Some(scalar_arg1.clone())])
444 .ok();
445
446 let args = vec![
447 ColumnarValue::Scalar(scalar_arg0),
448 ColumnarValue::Scalar(scalar_arg1),
449 ];
450 if let ColumnarValue::Scalar(scalar) = self.invoke_with_return_type(args, return_type)? {
451 Ok(scalar)
452 } else {
453 sedona_internal_err!("Expected scalar result from binary scalar invoke")
454 }
455 }
456
457 pub fn invoke_scalar_scalar_scalar<T0: Literal, T1: Literal, T2: Literal>(
459 &self,
460 arg0: T0,
461 arg1: T1,
462 arg2: T2,
463 ) -> Result<ScalarValue> {
464 let scalar_arg0 = Self::scalar_lit(arg0, &self.arg_types[0])?;
465 let scalar_arg1 = Self::scalar_lit(arg1, &self.arg_types[1])?;
466 let scalar_arg2 = Self::scalar_lit(arg2, &self.arg_types[2])?;
467
468 let return_type = self
470 .return_type_with_scalars_inner(&[
471 Some(scalar_arg0.clone()),
472 Some(scalar_arg1.clone()),
473 Some(scalar_arg2.clone()),
474 ])
475 .ok();
476
477 let args = vec![
478 ColumnarValue::Scalar(scalar_arg0),
479 ColumnarValue::Scalar(scalar_arg1),
480 ColumnarValue::Scalar(scalar_arg2),
481 ];
482 if let ColumnarValue::Scalar(scalar) = self.invoke_with_return_type(args, return_type)? {
483 Ok(scalar)
484 } else {
485 sedona_internal_err!("Expected scalar result from binary scalar invoke")
486 }
487 }
488
489 pub fn invoke_wkb_array(&self, wkb_values: Vec<Option<&str>>) -> Result<ArrayRef> {
491 self.invoke_array(create_array(&wkb_values, &self.arg_types[0]))
492 }
493
494 pub fn invoke_wkb_array_scalar(
496 &self,
497 wkb_values: Vec<Option<&str>>,
498 arg: impl Literal,
499 ) -> Result<ArrayRef> {
500 let wkb_array = create_array(&wkb_values, &self.arg_types[0]);
501 self.invoke_arrays_scalar(vec![wkb_array], arg)
502 }
503
504 pub fn invoke_array(&self, array: ArrayRef) -> Result<ArrayRef> {
506 self.invoke_arrays(vec![array])
507 }
508
509 pub fn invoke_array_scalar(&self, array: ArrayRef, arg: impl Literal) -> Result<ArrayRef> {
511 self.invoke_arrays_scalar(vec![array], arg)
512 }
513
514 pub fn invoke_array_scalar_scalar(
516 &self,
517 array: ArrayRef,
518 arg0: impl Literal,
519 arg1: impl Literal,
520 ) -> Result<ArrayRef> {
521 self.invoke_arrays_scalar_scalar(vec![array], arg0, arg1)
522 }
523
524 pub fn invoke_scalar_array(&self, arg: impl Literal, array: ArrayRef) -> Result<ArrayRef> {
526 self.invoke_scalar_arrays(arg, vec![array])
527 }
528
529 pub fn invoke_array_array(&self, array0: ArrayRef, array1: ArrayRef) -> Result<ArrayRef> {
531 self.invoke_arrays(vec![array0, array1])
532 }
533
534 pub fn invoke_array_array_scalar(
536 &self,
537 array0: ArrayRef,
538 array1: ArrayRef,
539 arg: impl Literal,
540 ) -> Result<ArrayRef> {
541 self.invoke_arrays_scalar(vec![array0, array1], arg)
542 }
543
544 fn invoke_scalar_arrays(&self, arg: impl Literal, arrays: Vec<ArrayRef>) -> Result<ArrayRef> {
545 let mut args = zip(arrays, &self.arg_types)
546 .map(|(array, sedona_type)| {
547 ColumnarValue::Array(array).cast_to(sedona_type.storage_type(), None)
548 })
549 .collect::<Result<Vec<_>>>()?;
550 let index = args.len();
551 args.insert(0, Self::scalar_arg(arg, &self.arg_types[index])?);
552
553 if let ColumnarValue::Array(array) = self.invoke(args)? {
554 Ok(array)
555 } else {
556 sedona_internal_err!("Expected array result from scalar/array invoke")
557 }
558 }
559
560 fn invoke_arrays_scalar(&self, arrays: Vec<ArrayRef>, arg: impl Literal) -> Result<ArrayRef> {
561 let mut args = zip(arrays, &self.arg_types)
562 .map(|(array, sedona_type)| {
563 ColumnarValue::Array(array).cast_to(sedona_type.storage_type(), None)
564 })
565 .collect::<Result<Vec<_>>>()?;
566 let index = args.len();
567 args.push(Self::scalar_arg(arg, &self.arg_types[index])?);
568
569 if let ColumnarValue::Array(array) = self.invoke(args)? {
570 Ok(array)
571 } else {
572 sedona_internal_err!("Expected array result from array/scalar invoke")
573 }
574 }
575
576 fn invoke_arrays_scalar_scalar(
577 &self,
578 arrays: Vec<ArrayRef>,
579 arg0: impl Literal,
580 arg1: impl Literal,
581 ) -> Result<ArrayRef> {
582 let mut args = zip(arrays, &self.arg_types)
583 .map(|(array, sedona_type)| {
584 ColumnarValue::Array(array).cast_to(sedona_type.storage_type(), None)
585 })
586 .collect::<Result<Vec<_>>>()?;
587 let index = args.len();
588 args.push(Self::scalar_arg(arg0, &self.arg_types[index])?);
589 args.push(Self::scalar_arg(arg1, &self.arg_types[index + 1])?);
590
591 if let ColumnarValue::Array(array) = self.invoke(args)? {
592 Ok(array)
593 } else {
594 sedona_internal_err!("Expected array result from array/scalar invoke")
595 }
596 }
597
598 pub fn invoke_arrays(&self, arrays: Vec<ArrayRef>) -> Result<ArrayRef> {
600 let args = zip(arrays, &self.arg_types)
601 .map(|(array, sedona_type)| {
602 ColumnarValue::Array(array).cast_to(sedona_type.storage_type(), None)
603 })
604 .collect::<Result<_>>()?;
605
606 if let ColumnarValue::Array(array) = self.invoke(args)? {
607 Ok(array)
608 } else {
609 sedona_internal_err!("Expected array result from array invoke")
610 }
611 }
612
613 pub fn invoke(&self, args: Vec<ColumnarValue>) -> Result<ColumnarValue> {
614 let scalar_args = args
615 .iter()
616 .map(|arg| match arg {
617 ColumnarValue::Array(_) => None,
618 ColumnarValue::Scalar(scalar_value) => Some(scalar_value.clone()),
619 })
620 .collect::<Vec<_>>();
621
622 let return_type = self.return_type_with_scalars_inner(&scalar_args)?;
623 self.invoke_with_return_type(args, Some(return_type))
624 }
625
626 pub fn invoke_with_return_type(
627 &self,
628 args: Vec<ColumnarValue>,
629 return_type: Option<SedonaType>,
630 ) -> Result<ColumnarValue> {
631 assert_eq!(args.len(), self.arg_types.len(), "Unexpected arg length");
632
633 let mut number_rows = 1;
634 for arg in &args {
635 match arg {
636 ColumnarValue::Array(array) => {
637 number_rows = array.len();
638 break;
639 }
640 _ => continue,
641 }
642 }
643
644 let return_type = match return_type {
645 Some(return_type) => return_type,
646 None => self.return_type()?,
647 };
648
649 let args = ScalarFunctionArgs {
650 args,
651 arg_fields: self.arg_fields(),
652 number_rows,
653 return_field: return_type.to_storage_field("", true)?.into(),
654 config_options: Arc::clone(&self.config_options),
655 };
656
657 self.udf.invoke_with_args(args)
658 }
659
660 fn scalar_arg(arg: impl Literal, sedona_type: &SedonaType) -> Result<ColumnarValue> {
661 Ok(ColumnarValue::Scalar(Self::scalar_lit(arg, sedona_type)?))
662 }
663
664 fn scalar_lit(arg: impl Literal, sedona_type: &SedonaType) -> Result<ScalarValue> {
665 if let Expr::Literal(scalar, _) = arg.lit() {
666 let is_geometry_or_geography = match sedona_type {
667 SedonaType::Wkb(_, _) | SedonaType::WkbView(_, _) => true,
668 SedonaType::Arrow(DataType::Struct(fields))
669 if fields.iter().map(|f| f.name()).collect::<Vec<_>>()
670 == vec!["item", "crs"] =>
671 {
672 true
673 }
674 _ => false,
675 };
676
677 if is_geometry_or_geography {
678 if let ScalarValue::Utf8(expected_wkt) = scalar {
679 Ok(create_scalar(expected_wkt.as_deref(), sedona_type))
680 } else if &scalar.data_type() == sedona_type.storage_type() {
681 Ok(scalar)
682 } else if scalar.is_null() {
683 Ok(create_scalar(None, sedona_type))
684 } else {
685 sedona_internal_err!("Can't interpret scalar {scalar} as type {sedona_type}")
686 }
687 } else {
688 scalar.cast_to(sedona_type.storage_type())
689 }
690 } else {
691 sedona_internal_err!("Can't use test scalar invoke where .lit() returns non-literal")
692 }
693 }
694
695 fn arg_fields(&self) -> Vec<FieldRef> {
696 self.arg_types
697 .iter()
698 .map(|data_type| data_type.to_storage_field("", false).map(Arc::new))
699 .collect::<Result<Vec<_>>>()
700 .unwrap()
701 }
702}