Skip to main content

sedona_testing/
testers.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17use std::{iter::zip, sync::Arc};
18
19use arrow_array::{ArrayRef, BooleanArray, RecordBatch};
20use arrow_schema::{DataType, FieldRef, Schema};
21use datafusion_common::{
22    arrow::compute::kernels::concat::concat, config::ConfigOptions, Result, ScalarValue,
23};
24use datafusion_expr::{
25    function::{AccumulatorArgs, StateFieldsArgs},
26    Accumulator, AggregateUDF, ColumnarValue, EmitTo, Expr, GroupsAccumulator, Literal,
27    ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF,
28};
29use datafusion_physical_expr::{expressions::Column, PhysicalExpr};
30use sedona_common::{sedona_internal_err, SedonaOptions};
31use sedona_schema::datatypes::SedonaType;
32
33use crate::{
34    compare::assert_scalar_equal,
35    create::{create_array, create_scalar},
36};
37
38/// Low-level tester for aggregate functions
39///
40/// This struct provides a means by which to run a simple check of an
41/// aggregate implementation by approximating one method DataFusion
42/// might use to perform the aggregation. Whereas DataFusion may arrange
43/// various calls to accumulate, state, and update_batch to optimize
44/// for different cases, this tester is always created by aggregating
45/// states that were in turn created from accumulating one batch.
46///
47/// This is not a replacement for testing at a higher level using
48/// DataFusion's actual aggregate implementation but provides
49/// a useful mechanism to ensure all the pieces of an accumulator
50/// are plugged in.
51pub struct AggregateUdfTester {
52    udf: AggregateUDF,
53    arg_types: Vec<SedonaType>,
54    mock_schema: Schema,
55    mock_exprs: Vec<Arc<dyn PhysicalExpr>>,
56}
57
58impl AggregateUdfTester {
59    /// Create a new tester
60    pub fn new(udf: AggregateUDF, arg_types: Vec<SedonaType>) -> Self {
61        let arg_fields = arg_types
62            .iter()
63            .map(|sedona_type| sedona_type.to_storage_field("", true).map(Arc::new))
64            .collect::<Result<Vec<_>>>()
65            .unwrap();
66        let mock_schema = Schema::new(arg_fields);
67
68        let mock_exprs = (0..arg_types.len())
69            .map(|i| -> Arc<dyn PhysicalExpr> { Arc::new(Column::new("col", i)) })
70            .collect::<Vec<_>>();
71        Self {
72            udf,
73            arg_types,
74            mock_schema,
75            mock_exprs,
76        }
77    }
78
79    /// Compute the return type
80    pub fn return_type(&self) -> Result<SedonaType> {
81        let out_field = self.udf.return_field(&self.mock_schema.fields)?;
82        SedonaType::from_storage_field(&out_field)
83    }
84
85    /// Perform a simple aggregation using WKT as geometry input
86    pub fn aggregate_wkt(&self, batches: Vec<Vec<Option<&str>>>) -> Result<ScalarValue> {
87        let batches_array = batches
88            .into_iter()
89            .map(|batch| create_array(&batch, &self.arg_types[0]))
90            .collect::<Vec<_>>();
91        self.aggregate(&batches_array)
92    }
93
94    /// Perform a simple aggregation
95    ///
96    /// Each batch in batches is accumulated with its own accumulator
97    /// and serialized into its own state, after which the states are accumulated
98    /// in batches of one. This has the effect of testing all the pieces of
99    /// an aggregator in a somewhat configurable/predictable way.
100    pub fn aggregate(&self, batches: &Vec<ArrayRef>) -> Result<ScalarValue> {
101        let state_schema = Arc::new(Schema::new(self.state_fields()?));
102        let mut state_accumulator = self.new_accumulator()?;
103
104        for batch in batches {
105            let mut batch_accumulator = self.new_accumulator()?;
106            batch_accumulator.update_batch(std::slice::from_ref(batch))?;
107            let state_batch_of_one = RecordBatch::try_new(
108                state_schema.clone(),
109                batch_accumulator
110                    .state()?
111                    .into_iter()
112                    .map(|v| v.to_array())
113                    .collect::<Result<Vec<_>>>()?,
114            )?;
115            state_accumulator.merge_batch(state_batch_of_one.columns())?;
116        }
117
118        state_accumulator.evaluate()
119    }
120
121    /// Perform a simple grouped aggregation
122    ///
123    /// Each batch in batches is accumulated with its own groups accumulator
124    /// and serialized into its own state, after which the state resulting
125    /// from each batch is merged into the final groups accumulator. This
126    /// has the effect of testing the pieces of a groups accumulator in a
127    /// predictable/debug-friendly (if artificial) way.
128    pub fn aggregate_groups(
129        &self,
130        batches: &Vec<ArrayRef>,
131        group_indices: Vec<usize>,
132        opt_filter: Option<&Vec<bool>>,
133        emit_sizes: Vec<usize>,
134    ) -> Result<ArrayRef> {
135        let state_schema = Arc::new(Schema::new(self.state_fields()?));
136        let mut state_accumulator = self.new_groups_accumulator()?;
137        let total_num_groups = group_indices.iter().max().unwrap_or(&0) + 1;
138
139        // Check input
140        let total_input_rows: usize = batches.iter().map(|a| a.len()).sum();
141        assert_eq!(total_input_rows, group_indices.len());
142        if let Some(filter) = opt_filter {
143            assert_eq!(total_input_rows, filter.len());
144        }
145        if !emit_sizes.is_empty() {
146            assert_eq!(emit_sizes.iter().sum::<usize>(), total_num_groups);
147        }
148
149        let mut offset = 0;
150        for batch in batches {
151            let mut batch_accumulator = self.new_groups_accumulator()?;
152            let opt_filter_array = opt_filter.map(|filter_vec| {
153                filter_vec[offset..(offset + batch.len())]
154                    .iter()
155                    .collect::<BooleanArray>()
156            });
157            batch_accumulator.update_batch(
158                std::slice::from_ref(batch),
159                &group_indices[offset..(offset + batch.len())],
160                opt_filter_array.as_ref(),
161                total_num_groups,
162            )?;
163            offset += batch.len();
164
165            // For the state accumulator the input is ordered such that
166            // each row is group i for i in (0..total_num_groups)
167            let state_batch = RecordBatch::try_new(
168                state_schema.clone(),
169                batch_accumulator.state(datafusion_expr::EmitTo::All)?,
170            )?;
171            state_accumulator.merge_batch(
172                state_batch.columns(),
173                &(0..total_num_groups).collect::<Vec<_>>(),
174                None,
175                total_num_groups,
176            )?;
177        }
178
179        if emit_sizes.is_empty() {
180            state_accumulator.evaluate(datafusion_expr::EmitTo::All)
181        } else {
182            let arrays = emit_sizes
183                .iter()
184                .map(|emit_size| state_accumulator.evaluate(EmitTo::First(*emit_size)))
185                .collect::<Result<Vec<_>>>()?;
186            let arrays_ref = arrays.iter().map(|a| a.as_ref()).collect::<Vec<_>>();
187            Ok(concat(&arrays_ref)?)
188        }
189    }
190
191    fn new_accumulator(&self) -> Result<Box<dyn Accumulator>> {
192        let accumulator_args = self.accumulator_args()?;
193        self.udf.accumulator(accumulator_args)
194    }
195
196    fn new_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
197        assert!(self
198            .udf
199            .groups_accumulator_supported(self.accumulator_args()?));
200        self.udf.create_groups_accumulator(self.accumulator_args()?)
201    }
202
203    fn accumulator_args(&self) -> Result<AccumulatorArgs<'_>> {
204        Ok(AccumulatorArgs {
205            return_field: self.udf.return_field(self.mock_schema.fields())?,
206            schema: &self.mock_schema,
207            ignore_nulls: true,
208            order_bys: &[],
209            is_reversed: false,
210            name: "",
211            is_distinct: false,
212            exprs: &self.mock_exprs,
213            expr_fields: &[],
214        })
215    }
216
217    fn state_fields(&self) -> Result<Vec<FieldRef>> {
218        let state_field_args = StateFieldsArgs {
219            name: "",
220            input_fields: self.mock_schema.fields(),
221            return_field: self.udf.return_field(self.mock_schema.fields())?,
222            ordering_fields: &[],
223            is_distinct: false,
224        };
225        self.udf.state_fields(state_field_args)
226    }
227}
228
229/// Low-level tester for scalar functions
230///
231/// This struct provides a means by which to run a simple check of an
232/// scalar UDF implementation by simulating how DataFusion might call it.
233///
234/// This is not a replacement for testing at a higher level using DataFusion's
235/// actual implementation but provides a useful mechanism to ensure all the
236/// pieces of an scalar UDF are plugged in.
237///
238/// Note that arguments are always cast to the values passed [Self::new]:
239/// to test different combinations of argument types, use a new tester.
240pub struct ScalarUdfTester {
241    udf: ScalarUDF,
242    arg_types: Vec<SedonaType>,
243    config_options: Arc<ConfigOptions>,
244}
245
246impl ScalarUdfTester {
247    /// Create a new tester
248    pub fn new(udf: ScalarUDF, arg_types: Vec<SedonaType>) -> Self {
249        let mut config_options = ConfigOptions::default();
250        let sedona_options = SedonaOptions::default();
251        config_options.extensions.insert(sedona_options);
252        Self {
253            udf,
254            arg_types,
255            config_options: Arc::new(config_options),
256        }
257    }
258
259    /// Returns the [`ConfigOptions`] used when invoking the UDF.
260    ///
261    /// This is the same structure DataFusion threads through [`ScalarFunctionArgs`].
262    /// Sedona-specific options are stored in `config_options.extensions`.
263    pub fn config_options(&self) -> &ConfigOptions {
264        &self.config_options
265    }
266
267    /// Returns a mutable reference to the [`ConfigOptions`] used when invoking the UDF.
268    ///
269    /// Use this to tweak DataFusion options or to insert/update Sedona options via
270    /// `config_options.extensions` before calling the tester's `invoke_*` helpers.
271    pub fn config_options_mut(&mut self) -> &mut ConfigOptions {
272        // config_options can only be owned by this tester, so it's safe to get a mutable reference.
273        Arc::get_mut(&mut self.config_options).expect("ConfigOptions is shared")
274    }
275
276    /// Returns the [`SedonaOptions`] stored in `config_options.extensions`, if present.
277    pub fn sedona_options(&self) -> &SedonaOptions {
278        self.config_options
279            .extensions
280            .get::<SedonaOptions>()
281            .expect("SedonaOptions does not exist")
282    }
283
284    /// Returns a mutable reference to the [`SedonaOptions`] stored in `config_options.extensions`, if present.
285    pub fn sedona_options_mut(&mut self) -> &mut SedonaOptions {
286        self.config_options_mut()
287            .extensions
288            .get_mut::<SedonaOptions>()
289            .expect("SedonaOptions does not exist")
290    }
291
292    /// Assert the return type of the function for the argument types used
293    /// to construct this tester
294    ///
295    /// Both [SedonaType] or [DataType] objects can be used as the expected
296    /// data type.
297    pub fn assert_return_type(&self, data_type: impl TryInto<SedonaType>) {
298        let expected = match data_type.try_into() {
299            Ok(t) => t,
300            Err(_) => panic!("Failed to convert to SedonaType"),
301        };
302        assert_eq!(self.return_type().unwrap(), expected)
303    }
304
305    /// Assert the result of invoking this function
306    ///
307    /// Both actual and expected are interpreted according to the calculated
308    /// return type (notably, WKT is interpreted as geometry or geography output).
309    pub fn assert_scalar_result_equals(&self, actual: impl Literal, expected: impl Literal) {
310        self.assert_scalar_result_equals_inner(actual, expected, None);
311    }
312
313    /// Assert the result of invoking this function with the return type specified
314    ///
315    /// This is for UDFs implementing `SedonaScalarKernel::return_type_from_args_and_scalars()`.
316    pub fn assert_scalar_result_equals_with_return_type(
317        &self,
318        actual: impl Literal,
319        expected: impl Literal,
320        return_type: SedonaType,
321    ) {
322        self.assert_scalar_result_equals_inner(actual, expected, Some(return_type));
323    }
324
325    fn assert_scalar_result_equals_inner(
326        &self,
327        actual: impl Literal,
328        expected: impl Literal,
329        return_type: Option<SedonaType>,
330    ) {
331        let return_type = return_type.unwrap_or_else(|| self.return_type().unwrap());
332        let actual = Self::scalar_lit(actual, &return_type).unwrap();
333        let expected = Self::scalar_lit(expected, &return_type).unwrap();
334        assert_scalar_equal(&actual, &expected);
335    }
336
337    /// Compute the return type
338    pub fn return_type(&self) -> Result<SedonaType> {
339        let scalar_arguments = vec![None; self.arg_types.len()];
340        self.return_type_with_scalars_inner(&scalar_arguments)
341    }
342
343    /// Compute the return type from one scalar argument
344    ///
345    /// This is for UDFs implementing `SedonaScalarKernel::return_type_from_args_and_scalars()`.
346    pub fn return_type_with_scalar(&self, arg0: Option<impl Literal>) -> Result<SedonaType> {
347        let scalar_arguments = vec![arg0
348            .map(|x| Self::scalar_lit(x, &self.arg_types[0]))
349            .transpose()?];
350        self.return_type_with_scalars_inner(&scalar_arguments)
351    }
352
353    /// Compute the return type from two scalar arguments
354    ///
355    /// This is for UDFs implementing `SedonaScalarKernel::return_type_from_args_and_scalars()`.
356    pub fn return_type_with_scalar_scalar(
357        &self,
358        arg0: Option<impl Literal>,
359        arg1: Option<impl Literal>,
360    ) -> Result<SedonaType> {
361        let scalar_arguments = vec![
362            arg0.map(|x| Self::scalar_lit(x, &self.arg_types[0]))
363                .transpose()?,
364            arg1.map(|x| Self::scalar_lit(x, &self.arg_types[1]))
365                .transpose()?,
366        ];
367        self.return_type_with_scalars_inner(&scalar_arguments)
368    }
369
370    /// Compute the return type from three scalar arguments
371    ///
372    /// This is for UDFs implementing `SedonaScalarKernel::return_type_from_args_and_scalars()`.
373    pub fn return_type_with_scalar_scalar_scalar(
374        &self,
375        arg0: Option<impl Literal>,
376        arg1: Option<impl Literal>,
377        arg2: Option<impl Literal>,
378    ) -> Result<SedonaType> {
379        let scalar_arguments = vec![
380            arg0.map(|x| Self::scalar_lit(x, &self.arg_types[0]))
381                .transpose()?,
382            arg1.map(|x| Self::scalar_lit(x, &self.arg_types[1]))
383                .transpose()?,
384            arg2.map(|x| Self::scalar_lit(x, &self.arg_types[2]))
385                .transpose()?,
386        ];
387        self.return_type_with_scalars_inner(&scalar_arguments)
388    }
389
390    fn return_type_with_scalars_inner(
391        &self,
392        scalar_arguments: &[Option<ScalarValue>],
393    ) -> Result<SedonaType> {
394        let arg_fields = self
395            .arg_types
396            .iter()
397            .map(|sedona_type| sedona_type.to_storage_field("", true).map(Arc::new))
398            .collect::<Result<Vec<_>>>()?;
399
400        let scalar_arguments_ref: Vec<Option<&ScalarValue>> =
401            scalar_arguments.iter().map(|x| x.as_ref()).collect();
402        let args = ReturnFieldArgs {
403            arg_fields: &arg_fields,
404            scalar_arguments: &scalar_arguments_ref,
405        };
406        let return_field = self.udf.return_field_from_args(args)?;
407        SedonaType::from_storage_field(&return_field)
408    }
409
410    /// Invoke this function with a scalar
411    pub fn invoke_scalar(&self, arg: impl Literal) -> Result<ScalarValue> {
412        let scalar_arg = Self::scalar_lit(arg, &self.arg_types[0])?;
413
414        // Some UDF calculate the return type from the input scalar arguments, so try it first.
415        let return_type = self
416            .return_type_with_scalars_inner(&[Some(scalar_arg.clone())])
417            .ok();
418
419        let args = vec![ColumnarValue::Scalar(scalar_arg)];
420        if let ColumnarValue::Scalar(scalar) = self.invoke_with_return_type(args, return_type)? {
421            Ok(scalar)
422        } else {
423            sedona_internal_err!("Expected scalar result from scalar invoke")
424        }
425    }
426
427    /// Invoke this function with a geometry scalar
428    pub fn invoke_wkb_scalar(&self, wkt_value: Option<&str>) -> Result<ScalarValue> {
429        self.invoke_scalar(create_scalar(wkt_value, &self.arg_types[0]))
430    }
431
432    /// Invoke this function with two scalars
433    pub fn invoke_scalar_scalar<T0: Literal, T1: Literal>(
434        &self,
435        arg0: T0,
436        arg1: T1,
437    ) -> Result<ScalarValue> {
438        let scalar_arg0 = Self::scalar_lit(arg0, &self.arg_types[0])?;
439        let scalar_arg1 = Self::scalar_lit(arg1, &self.arg_types[1])?;
440
441        // Some UDF calculate the return type from the input scalar arguments, so try it first.
442        let return_type = self
443            .return_type_with_scalars_inner(&[Some(scalar_arg0.clone()), Some(scalar_arg1.clone())])
444            .ok();
445
446        let args = vec![
447            ColumnarValue::Scalar(scalar_arg0),
448            ColumnarValue::Scalar(scalar_arg1),
449        ];
450        if let ColumnarValue::Scalar(scalar) = self.invoke_with_return_type(args, return_type)? {
451            Ok(scalar)
452        } else {
453            sedona_internal_err!("Expected scalar result from binary scalar invoke")
454        }
455    }
456
457    /// Invoke this function with three scalars
458    pub fn invoke_scalar_scalar_scalar<T0: Literal, T1: Literal, T2: Literal>(
459        &self,
460        arg0: T0,
461        arg1: T1,
462        arg2: T2,
463    ) -> Result<ScalarValue> {
464        let scalar_arg0 = Self::scalar_lit(arg0, &self.arg_types[0])?;
465        let scalar_arg1 = Self::scalar_lit(arg1, &self.arg_types[1])?;
466        let scalar_arg2 = Self::scalar_lit(arg2, &self.arg_types[2])?;
467
468        // Some UDF calculate the return type from the input scalar arguments, so try it first.
469        let return_type = self
470            .return_type_with_scalars_inner(&[
471                Some(scalar_arg0.clone()),
472                Some(scalar_arg1.clone()),
473                Some(scalar_arg2.clone()),
474            ])
475            .ok();
476
477        let args = vec![
478            ColumnarValue::Scalar(scalar_arg0),
479            ColumnarValue::Scalar(scalar_arg1),
480            ColumnarValue::Scalar(scalar_arg2),
481        ];
482        if let ColumnarValue::Scalar(scalar) = self.invoke_with_return_type(args, return_type)? {
483            Ok(scalar)
484        } else {
485            sedona_internal_err!("Expected scalar result from binary scalar invoke")
486        }
487    }
488
489    /// Invoke this function with a geometry array
490    pub fn invoke_wkb_array(&self, wkb_values: Vec<Option<&str>>) -> Result<ArrayRef> {
491        self.invoke_array(create_array(&wkb_values, &self.arg_types[0]))
492    }
493
494    /// Invoke this function with a geometry array and a scalar
495    pub fn invoke_wkb_array_scalar(
496        &self,
497        wkb_values: Vec<Option<&str>>,
498        arg: impl Literal,
499    ) -> Result<ArrayRef> {
500        let wkb_array = create_array(&wkb_values, &self.arg_types[0]);
501        self.invoke_arrays_scalar(vec![wkb_array], arg)
502    }
503
504    /// Invoke this function with an array
505    pub fn invoke_array(&self, array: ArrayRef) -> Result<ArrayRef> {
506        self.invoke_arrays(vec![array])
507    }
508
509    /// Invoke a binary function with an array and a scalar
510    pub fn invoke_array_scalar(&self, array: ArrayRef, arg: impl Literal) -> Result<ArrayRef> {
511        self.invoke_arrays_scalar(vec![array], arg)
512    }
513
514    /// Invoke a binary function with an array, and two scalars
515    pub fn invoke_array_scalar_scalar(
516        &self,
517        array: ArrayRef,
518        arg0: impl Literal,
519        arg1: impl Literal,
520    ) -> Result<ArrayRef> {
521        self.invoke_arrays_scalar_scalar(vec![array], arg0, arg1)
522    }
523
524    /// Invoke a binary function with a scalar and an array
525    pub fn invoke_scalar_array(&self, arg: impl Literal, array: ArrayRef) -> Result<ArrayRef> {
526        self.invoke_scalar_arrays(arg, vec![array])
527    }
528
529    /// Invoke a binary function with two arrays
530    pub fn invoke_array_array(&self, array0: ArrayRef, array1: ArrayRef) -> Result<ArrayRef> {
531        self.invoke_arrays(vec![array0, array1])
532    }
533
534    /// Invoke a binary function with two arrays and a scalar
535    pub fn invoke_array_array_scalar(
536        &self,
537        array0: ArrayRef,
538        array1: ArrayRef,
539        arg: impl Literal,
540    ) -> Result<ArrayRef> {
541        self.invoke_arrays_scalar(vec![array0, array1], arg)
542    }
543
544    fn invoke_scalar_arrays(&self, arg: impl Literal, arrays: Vec<ArrayRef>) -> Result<ArrayRef> {
545        let mut args = zip(arrays, &self.arg_types)
546            .map(|(array, sedona_type)| {
547                ColumnarValue::Array(array).cast_to(sedona_type.storage_type(), None)
548            })
549            .collect::<Result<Vec<_>>>()?;
550        let index = args.len();
551        args.insert(0, Self::scalar_arg(arg, &self.arg_types[index])?);
552
553        if let ColumnarValue::Array(array) = self.invoke(args)? {
554            Ok(array)
555        } else {
556            sedona_internal_err!("Expected array result from scalar/array invoke")
557        }
558    }
559
560    fn invoke_arrays_scalar(&self, arrays: Vec<ArrayRef>, arg: impl Literal) -> Result<ArrayRef> {
561        let mut args = zip(arrays, &self.arg_types)
562            .map(|(array, sedona_type)| {
563                ColumnarValue::Array(array).cast_to(sedona_type.storage_type(), None)
564            })
565            .collect::<Result<Vec<_>>>()?;
566        let index = args.len();
567        args.push(Self::scalar_arg(arg, &self.arg_types[index])?);
568
569        if let ColumnarValue::Array(array) = self.invoke(args)? {
570            Ok(array)
571        } else {
572            sedona_internal_err!("Expected array result from array/scalar invoke")
573        }
574    }
575
576    fn invoke_arrays_scalar_scalar(
577        &self,
578        arrays: Vec<ArrayRef>,
579        arg0: impl Literal,
580        arg1: impl Literal,
581    ) -> Result<ArrayRef> {
582        let mut args = zip(arrays, &self.arg_types)
583            .map(|(array, sedona_type)| {
584                ColumnarValue::Array(array).cast_to(sedona_type.storage_type(), None)
585            })
586            .collect::<Result<Vec<_>>>()?;
587        let index = args.len();
588        args.push(Self::scalar_arg(arg0, &self.arg_types[index])?);
589        args.push(Self::scalar_arg(arg1, &self.arg_types[index + 1])?);
590
591        if let ColumnarValue::Array(array) = self.invoke(args)? {
592            Ok(array)
593        } else {
594            sedona_internal_err!("Expected array result from array/scalar invoke")
595        }
596    }
597
598    // Invoke a function with a set of arrays
599    pub fn invoke_arrays(&self, arrays: Vec<ArrayRef>) -> Result<ArrayRef> {
600        let args = zip(arrays, &self.arg_types)
601            .map(|(array, sedona_type)| {
602                ColumnarValue::Array(array).cast_to(sedona_type.storage_type(), None)
603            })
604            .collect::<Result<_>>()?;
605
606        if let ColumnarValue::Array(array) = self.invoke(args)? {
607            Ok(array)
608        } else {
609            sedona_internal_err!("Expected array result from array invoke")
610        }
611    }
612
613    pub fn invoke(&self, args: Vec<ColumnarValue>) -> Result<ColumnarValue> {
614        let scalar_args = args
615            .iter()
616            .map(|arg| match arg {
617                ColumnarValue::Array(_) => None,
618                ColumnarValue::Scalar(scalar_value) => Some(scalar_value.clone()),
619            })
620            .collect::<Vec<_>>();
621
622        let return_type = self.return_type_with_scalars_inner(&scalar_args)?;
623        self.invoke_with_return_type(args, Some(return_type))
624    }
625
626    pub fn invoke_with_return_type(
627        &self,
628        args: Vec<ColumnarValue>,
629        return_type: Option<SedonaType>,
630    ) -> Result<ColumnarValue> {
631        assert_eq!(args.len(), self.arg_types.len(), "Unexpected arg length");
632
633        let mut number_rows = 1;
634        for arg in &args {
635            match arg {
636                ColumnarValue::Array(array) => {
637                    number_rows = array.len();
638                    break;
639                }
640                _ => continue,
641            }
642        }
643
644        let return_type = match return_type {
645            Some(return_type) => return_type,
646            None => self.return_type()?,
647        };
648
649        let args = ScalarFunctionArgs {
650            args,
651            arg_fields: self.arg_fields(),
652            number_rows,
653            return_field: return_type.to_storage_field("", true)?.into(),
654            config_options: Arc::clone(&self.config_options),
655        };
656
657        self.udf.invoke_with_args(args)
658    }
659
660    fn scalar_arg(arg: impl Literal, sedona_type: &SedonaType) -> Result<ColumnarValue> {
661        Ok(ColumnarValue::Scalar(Self::scalar_lit(arg, sedona_type)?))
662    }
663
664    fn scalar_lit(arg: impl Literal, sedona_type: &SedonaType) -> Result<ScalarValue> {
665        if let Expr::Literal(scalar, _) = arg.lit() {
666            let is_geometry_or_geography = match sedona_type {
667                SedonaType::Wkb(_, _) | SedonaType::WkbView(_, _) => true,
668                SedonaType::Arrow(DataType::Struct(fields))
669                    if fields.iter().map(|f| f.name()).collect::<Vec<_>>()
670                        == vec!["item", "crs"] =>
671                {
672                    true
673                }
674                _ => false,
675            };
676
677            if is_geometry_or_geography {
678                if let ScalarValue::Utf8(expected_wkt) = scalar {
679                    Ok(create_scalar(expected_wkt.as_deref(), sedona_type))
680                } else if &scalar.data_type() == sedona_type.storage_type() {
681                    Ok(scalar)
682                } else if scalar.is_null() {
683                    Ok(create_scalar(None, sedona_type))
684                } else {
685                    sedona_internal_err!("Can't interpret scalar {scalar} as type {sedona_type}")
686                }
687            } else {
688                scalar.cast_to(sedona_type.storage_type())
689            }
690        } else {
691            sedona_internal_err!("Can't use test scalar invoke where .lit() returns non-literal")
692        }
693    }
694
695    fn arg_fields(&self) -> Vec<FieldRef> {
696        self.arg_types
697            .iter()
698            .map(|data_type| data_type.to_storage_field("", false).map(Arc::new))
699            .collect::<Result<Vec<_>>>()
700            .unwrap()
701    }
702}