Skip to main content

sedona_testing/
benchmark_util.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17use std::{fmt::Debug, sync::Arc, vec};
18
19use arrow_array::{ArrayRef, Float64Array, Int64Array};
20use arrow_schema::DataType;
21
22use datafusion_common::{exec_datafusion_err, Result, ScalarValue};
23use datafusion_expr::{AggregateUDF, ScalarUDF};
24use geo_types::Rect;
25use rand::{distr::Uniform, rngs::StdRng, Rng, RngExt, SeedableRng};
26
27use sedona_common::sedona_internal_err;
28use sedona_geometry::types::GeometryTypeId;
29use sedona_schema::datatypes::{SedonaType, RASTER, WKB_GEOMETRY};
30use sedona_schema::raster::BandDataType;
31
32use crate::{
33    datagen::RandomPartitionedDataBuilder,
34    rasters::generate_tiled_rasters,
35    testers::{AggregateUdfTester, ScalarUdfTester},
36};
37
/// Rows per batch used by default; this matches DataFusion's default batch size.
pub const ROWS_PER_BATCH: usize = 8_192;

/// Rows per batch used for "tiny" benchmark runs.
pub const ROWS_PER_BATCH_TINY: usize = 1_024;

/// Number of batches used for "small" benchmark runs.
///
/// Picked so that most benchmarks behave well under criterion's default
/// settings (a 5s target time and 100 samples).
pub const NUM_BATCHES_SMALL: usize = 16;

/// Number of batches used for "tiny" benchmark runs.
///
/// A single batch: just enough to check that a benchmark actually runs.
pub const NUM_BATCHES_TINY: usize = 1;
54
#[cfg(feature = "criterion")]
pub mod benchmark {
    use super::*;
    use criterion::Criterion;
    use sedona_expr::function_set::FunctionSet;

    /// Benchmark a [ScalarUDF] using [Criterion]
    ///
    /// When built with the criterion feature, provides utilities for running a
    /// basic benchmark on a [ScalarUDF] given [BenchmarkArgs]. The input data is
    /// built once, before measurement starts, with the size chosen by
    /// [Config::default]:
    ///
    /// - 16 batches of 8192 rows (==131,072 rows) in release builds. This was
    ///   chosen so that most benchmarks run nicely with criterion defaults
    ///   (target 5s, 100 samples).
    /// - A single batch of 1024 rows when debug assertions are enabled, so that
    ///   benchmarks can be checked quickly.
    ///
    /// The benchmark is labelled with [BenchmarkData::make_label].
    ///
    /// # Panics
    ///
    /// Panics if `name` is not found in `functions`, if the input data cannot be
    /// built, or if invoking the function fails.
    pub fn scalar(
        c: &mut Criterion,
        functions: &FunctionSet,
        lib: &str,
        name: &str,
        config: impl Into<BenchmarkArgs>,
    ) {
        let not_found_err = format!("{name} was not found in function set");
        let udf: ScalarUDF = functions
            .scalar_udf(name)
            .expect(&not_found_err)
            .clone()
            .into();
        let data = build_default_data(config.into());
        c.bench_function(&data.make_label(lib, name), |b| {
            b.iter(|| data.invoke_scalar(&udf).unwrap())
        });
    }

    /// Benchmark a [AggregateUDF] using [Criterion]
    ///
    /// When built with the criterion feature, provides utilities for running a
    /// basic benchmark on a [AggregateUDF] given [BenchmarkArgs]. This
    /// shares the default batch configuration with [scalar]. Because
    /// aggregate functions can be invoked with varying combinations of
    /// accumulation and merging of states, they should also be benchmarked
    /// at a higher level. This benchmark primarily checks the accumulator.
    ///
    /// # Panics
    ///
    /// Panics if `name` is not found in `functions`, if the input data cannot be
    /// built, or if invoking the aggregate fails (including when `config` is not
    /// [BenchmarkArgs::Array]).
    pub fn aggregate(
        c: &mut Criterion,
        functions: &FunctionSet,
        lib: &str,
        name: &str,
        config: impl Into<BenchmarkArgs>,
    ) {
        let not_found_err = format!("{name} was not found in function set");
        let udf: AggregateUDF = functions
            .aggregate_udf(name)
            .expect(&not_found_err)
            .clone()
            .into();
        let data = build_default_data(config.into());
        c.bench_function(&data.make_label(lib, name), |b| {
            b.iter(|| data.invoke_aggregate(&udf).unwrap())
        });
    }

    /// Build the input described by `args` using the [Config::default] data size
    fn build_default_data(args: BenchmarkArgs) -> BenchmarkData {
        let size = Config::default();
        args.build_data(size.num_batches(), size.rows_per_batch())
            .expect("failed to build benchmark data")
    }

    /// Size presets for the data generated for a benchmark
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum Config {
        /// [NUM_BATCHES_TINY] batch(es) of [ROWS_PER_BATCH_TINY] rows
        Tiny,
        /// [NUM_BATCHES_SMALL] batches of [ROWS_PER_BATCH] rows
        Small,
    }

    impl Default for Config {
        /// [Config::Tiny] when debug assertions are enabled, [Config::Small] otherwise
        fn default() -> Self {
            if cfg!(debug_assertions) {
                Self::Tiny
            } else {
                Self::Small
            }
        }
    }

    impl Config {
        /// Number of batches generated for this preset
        fn num_batches(&self) -> usize {
            match self {
                Config::Tiny => NUM_BATCHES_TINY,
                Config::Small => NUM_BATCHES_SMALL,
            }
        }

        /// Number of rows in each generated batch for this preset
        fn rows_per_batch(&self) -> usize {
            match self {
                Config::Tiny => ROWS_PER_BATCH_TINY,
                Config::Small => ROWS_PER_BATCH,
            }
        }
    }
}
157
158/// Specification for benchmark arguments
159///
160/// This provides a concise definition of function input based on a
161/// combination of scalar/array arguments each specified by a [BenchmarkArgSpec].
162#[derive(Debug, Clone)]
163pub enum BenchmarkArgs {
164    /// Invoke a unary function with array input
165    Array(BenchmarkArgSpec),
166    /// Invoke a binary function with scalar and array input
167    ScalarArray(BenchmarkArgSpec, BenchmarkArgSpec),
168    /// Invoke a binary function with array and scalar input
169    ArrayScalar(BenchmarkArgSpec, BenchmarkArgSpec),
170    /// Invoke a binary function with two arrays
171    ArrayArray(BenchmarkArgSpec, BenchmarkArgSpec),
172    /// Invoke a function with an array and two scalar inputs
173    ArrayScalarScalar(BenchmarkArgSpec, BenchmarkArgSpec, BenchmarkArgSpec),
174    /// Invoke a ternary function with two arrays and a scalar
175    ArrayArrayScalar(BenchmarkArgSpec, BenchmarkArgSpec, BenchmarkArgSpec),
176    /// Invoke a ternary function with three arrays
177    ArrayArrayArray(BenchmarkArgSpec, BenchmarkArgSpec, BenchmarkArgSpec),
178    /// Invoke a quaternary function with four arrays
179    ArrayArrayArrayArray(
180        BenchmarkArgSpec,
181        BenchmarkArgSpec,
182        BenchmarkArgSpec,
183        BenchmarkArgSpec,
184    ),
185}
186
187impl From<BenchmarkArgSpec> for BenchmarkArgs {
188    fn from(value: BenchmarkArgSpec) -> Self {
189        BenchmarkArgs::Array(value)
190    }
191}
192
193impl BenchmarkArgs {
194    /// Calculate the [SedonaType]s of the input arguments
195    fn sedona_types(&self) -> Vec<SedonaType> {
196        self.specs().iter().map(|col| col.sedona_type()).collect()
197    }
198
199    /// Build [BenchmarkData] with the specified number of batches
200    pub fn build_data(&self, num_batches: usize, rows_per_batch: usize) -> Result<BenchmarkData> {
201        let array_configs = match self {
202            BenchmarkArgs::Array(_)
203            | BenchmarkArgs::ArrayArray(_, _)
204            | BenchmarkArgs::ArrayArrayScalar(_, _, _)
205            | BenchmarkArgs::ArrayArrayArray(_, _, _)
206            | BenchmarkArgs::ArrayArrayArrayArray(_, _, _, _) => self.specs(),
207            BenchmarkArgs::ScalarArray(_, col)
208            | BenchmarkArgs::ArrayScalar(col, _)
209            | BenchmarkArgs::ArrayScalarScalar(col, _, _) => {
210                vec![col.clone()]
211            }
212        };
213        let scalar_configs = match self {
214            BenchmarkArgs::ScalarArray(col, _)
215            | BenchmarkArgs::ArrayScalar(_, col)
216            | BenchmarkArgs::ArrayArrayScalar(_, _, col) => {
217                vec![col.clone()]
218            }
219            BenchmarkArgs::ArrayScalarScalar(_, col0, col1) => {
220                vec![col0.clone(), col1.clone()]
221            }
222            _ => vec![],
223        };
224
225        let arrays = array_configs
226            .iter()
227            .enumerate()
228            .map(|(i, col)| col.build_arrays(i, num_batches, rows_per_batch))
229            .collect::<Result<Vec<_>>>()?;
230
231        let scalars = scalar_configs
232            .iter()
233            .enumerate()
234            .map(|(i, col)| col.build_scalar(i))
235            .collect::<Result<Vec<_>>>()?;
236
237        Ok(BenchmarkData {
238            config: self.clone(),
239            num_batches,
240            arrays,
241            scalars,
242        })
243    }
244
245    fn specs(&self) -> Vec<BenchmarkArgSpec> {
246        match self {
247            BenchmarkArgs::Array(col) => vec![col.clone()],
248            BenchmarkArgs::ScalarArray(col0, col1)
249            | BenchmarkArgs::ArrayScalar(col0, col1)
250            | BenchmarkArgs::ArrayArray(col0, col1) => {
251                vec![col0.clone(), col1.clone()]
252            }
253            BenchmarkArgs::ArrayScalarScalar(col0, col1, col2)
254            | BenchmarkArgs::ArrayArrayScalar(col0, col1, col2)
255            | BenchmarkArgs::ArrayArrayArray(col0, col1, col2) => {
256                vec![col0.clone(), col1.clone(), col2.clone()]
257            }
258            BenchmarkArgs::ArrayArrayArrayArray(col0, col1, col2, col3) => {
259                vec![col0.clone(), col1.clone(), col2.clone(), col3.clone()]
260            }
261        }
262    }
263}
264
265/// Specification of a single argument to a function
266///
267/// Geometries are generated using the [RandomPartitionedDataBuilder], which offers
268/// more specific options for generating random geometries.
269#[derive(Clone)]
270pub enum BenchmarkArgSpec {
271    /// Randomly generated point input
272    Point,
273    /// Randomly generated linestring input with a specified number of vertices
274    LineString(usize),
275    /// Randomly generated polygon input with a specified number of vertices
276    Polygon(usize),
277    /// Randomly generated polygon with hole input with a specified number of vertices
278    PolygonWithHole(usize),
279    /// Randomly generated linestring input with a specified number of vertices
280    MultiPoint(usize),
281    /// Randomly generated integer input with a given range of values
282    Int64(i64, i64),
283    /// Randomly generated floating point input with a given range of values
284    Float64(f64, f64),
285    /// Randomly generated integer input with a given range of values
286    Int32(i32, i32),
287    /// A transformation of any of the above based on a [ScalarUDF] accepting
288    /// a single argument
289    Transformed(Box<BenchmarkArgSpec>, ScalarUDF),
290    /// A string that will be a constant
291    String(String),
292    /// Randomly generated raster input with a specified width, height
293    Raster(usize, usize),
294}
295
296// Custom implementation of Debug because otherwise the output of Transformed()
297// is excessively verbose
298impl Debug for BenchmarkArgSpec {
299    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300        match self {
301            Self::Point => write!(f, "Point"),
302            Self::LineString(arg0) => f.debug_tuple("LineString").field(arg0).finish(),
303            Self::Polygon(arg0) => f.debug_tuple("Polygon").field(arg0).finish(),
304            Self::PolygonWithHole(arg0) => f.debug_tuple("PolygonWithHole").field(arg0).finish(),
305            Self::MultiPoint(arg0) => f.debug_tuple("MultiPoint").field(arg0).finish(),
306            Self::Int64(arg0, arg1) => f.debug_tuple("Int64").field(arg0).field(arg1).finish(),
307            Self::Float64(arg0, arg1) => f.debug_tuple("Float64").field(arg0).field(arg1).finish(),
308            Self::Int32(arg0, arg1) => f.debug_tuple("Int32").field(arg0).field(arg1).finish(),
309            Self::Transformed(inner, t) => write!(f, "{}({:?})", t.name(), inner),
310            Self::String(s) => write!(f, "String({s})"),
311            Self::Raster(w, h) => f.debug_tuple("Raster").field(w).field(h).finish(),
312        }
313    }
314}
315
316impl BenchmarkArgSpec {
317    /// The [SedonaType] of this argument
318    pub fn sedona_type(&self) -> SedonaType {
319        match self {
320            BenchmarkArgSpec::Point
321            | BenchmarkArgSpec::Polygon(_)
322            | BenchmarkArgSpec::PolygonWithHole(_)
323            | BenchmarkArgSpec::LineString(_)
324            | BenchmarkArgSpec::MultiPoint(_) => WKB_GEOMETRY,
325            BenchmarkArgSpec::Int64(_, _) => SedonaType::Arrow(DataType::Int64),
326            BenchmarkArgSpec::Float64(_, _) => SedonaType::Arrow(DataType::Float64),
327            BenchmarkArgSpec::Int32(_, _) => SedonaType::Arrow(DataType::Int32),
328            BenchmarkArgSpec::Transformed(inner, t) => {
329                let tester = ScalarUdfTester::new(t.clone(), vec![inner.sedona_type()]);
330                tester.return_type().unwrap()
331            }
332            BenchmarkArgSpec::String(_) => SedonaType::Arrow(DataType::Utf8),
333            BenchmarkArgSpec::Raster(_, _) => RASTER,
334        }
335    }
336
337    /// Build a [ScalarValue] for this argument
338    ///
339    /// This currently builds the same non-null scalar for each unique value
340    /// of i (the argument number).
341    pub fn build_scalar(&self, i: usize) -> Result<ScalarValue> {
342        let array = self.build_arrays(i, 1, 1)?;
343        ScalarValue::try_from_array(&array[0], 0)
344    }
345
346    /// Build a column of num_batches arrays
347    ///
348    /// This currently builds the same column for each unique value of i (the argument
349    /// number). The batch size is currently fixed to 8192 (the DataFusion default).
350    pub fn build_arrays(
351        &self,
352        i: usize,
353        num_batches: usize,
354        rows_per_batch: usize,
355    ) -> Result<Vec<ArrayRef>> {
356        match self {
357            BenchmarkArgSpec::Point => self.build_geometry(
358                i,
359                GeometryTypeId::Point,
360                num_batches,
361                1,
362                1,
363                rows_per_batch,
364                None,
365            ),
366            BenchmarkArgSpec::LineString(vertex_count) => self.build_geometry(
367                i,
368                GeometryTypeId::LineString,
369                num_batches,
370                *vertex_count,
371                1,
372                rows_per_batch,
373                None,
374            ),
375            BenchmarkArgSpec::Polygon(vertex_count) => self.build_geometry(
376                i,
377                GeometryTypeId::Polygon,
378                num_batches,
379                *vertex_count,
380                1,
381                rows_per_batch,
382                None,
383            ),
384            BenchmarkArgSpec::PolygonWithHole(vertex_count) => self.build_geometry(
385                i,
386                GeometryTypeId::Polygon,
387                num_batches,
388                *vertex_count,
389                1,
390                rows_per_batch,
391                // Currently only a single interior ring is possible.
392                Some(1.0),
393            ),
394            BenchmarkArgSpec::MultiPoint(part_count) => self.build_geometry(
395                i,
396                GeometryTypeId::MultiPoint,
397                num_batches,
398                1,
399                *part_count,
400                rows_per_batch,
401                None,
402            ),
403            BenchmarkArgSpec::Int64(lo, hi) => {
404                let mut rng = self.rng(i);
405                let dist = Uniform::new(lo, hi)
406                    .map_err(|e| exec_datafusion_err!("Invalid Int64 range [{lo}, {hi}): {e}"))?;
407                (0..num_batches)
408                    .map(|_| -> Result<ArrayRef> {
409                        let int64_array: Int64Array =
410                            (0..rows_per_batch).map(|_| rng.sample(dist)).collect();
411                        Ok(Arc::new(int64_array))
412                    })
413                    .collect()
414            }
415            BenchmarkArgSpec::Float64(lo, hi) => {
416                let mut rng = self.rng(i);
417                let dist = Uniform::new(lo, hi)
418                    .map_err(|e| exec_datafusion_err!("Invalid Float64 range [{lo}, {hi}): {e}"))?;
419                (0..num_batches)
420                    .map(|_| -> Result<ArrayRef> {
421                        let float64_array: Float64Array =
422                            (0..rows_per_batch).map(|_| rng.sample(dist)).collect();
423                        Ok(Arc::new(float64_array))
424                    })
425                    .collect()
426            }
427            BenchmarkArgSpec::Int32(lo, hi) => {
428                let mut rng = self.rng(i);
429                let dist = Uniform::new(lo, hi)
430                    .map_err(|e| exec_datafusion_err!("Invalid Int32 range [{lo}, {hi}): {e}"))?;
431                (0..num_batches)
432                    .map(|_| -> Result<ArrayRef> {
433                        let int32_array: arrow_array::Int32Array =
434                            (0..rows_per_batch).map(|_| rng.sample(dist)).collect();
435                        Ok(Arc::new(int32_array))
436                    })
437                    .collect()
438            }
439            BenchmarkArgSpec::Transformed(inner, t) => {
440                let inner_type = inner.sedona_type();
441                let inner_arrays = inner.build_arrays(i, num_batches, rows_per_batch)?;
442                let tester = ScalarUdfTester::new(t.clone(), vec![inner_type]);
443                inner_arrays
444                    .into_iter()
445                    .map(|array| tester.invoke_array(array))
446                    .collect::<Result<Vec<_>>>()
447            }
448            BenchmarkArgSpec::String(s) => {
449                let string_array = (0..num_batches)
450                    .map(|_| {
451                        let array = arrow_array::StringArray::from_iter_values(
452                            std::iter::repeat_n(s, rows_per_batch),
453                        );
454                        Ok(Arc::new(array) as ArrayRef)
455                    })
456                    .collect::<Result<Vec<_>>>()?;
457                Ok(string_array)
458            }
459            BenchmarkArgSpec::Raster(width, height) => {
460                let mut arrays = vec![];
461                for _ in 0..num_batches {
462                    let tile_size = (*width, *height);
463                    let tile_count = (rows_per_batch, 1);
464                    let raster = generate_tiled_rasters(
465                        tile_size,
466                        tile_count,
467                        BandDataType::UInt8,
468                        Some(43),
469                    )?;
470                    arrays.push(Arc::new(raster) as ArrayRef);
471                }
472                Ok(arrays)
473            }
474        }
475    }
476
477    #[allow(clippy::too_many_arguments)]
478    fn build_geometry(
479        &self,
480        i: usize,
481        geom_type: GeometryTypeId,
482        num_batches: usize,
483        vertex_count: usize,
484        num_parts_count: usize,
485        rows_per_batch: usize,
486        polygon_hole_rate: Option<f64>,
487    ) -> Result<Vec<ArrayRef>> {
488        let builder = RandomPartitionedDataBuilder::new()
489            .num_partitions(1)
490            .rows_per_batch(rows_per_batch)
491            .batches_per_partition(num_batches)
492            // Use a random geometry range that is also not unrealistic for geography
493            .bounds(Rect::new((-10.0, -10.0), (10.0, 10.0)))
494            .size_range((0.1, 2.0))
495            .vertices_per_linestring_range((vertex_count, vertex_count))
496            .num_parts_range((num_parts_count, num_parts_count))
497            .geometry_type(geom_type)
498            .polygon_hole_rate(polygon_hole_rate.unwrap_or_default())
499            // Currently just use WKB_GEOMETRY (we can generate a view type with
500            // Transformed)
501            .sedona_type(WKB_GEOMETRY);
502
503        builder
504            .partition_reader(self.rng(i), 0)
505            .map(|batch| -> Result<ArrayRef> { Ok(batch?.column(2).clone()) })
506            .collect()
507    }
508
509    fn rng(&self, i: usize) -> impl Rng {
510        StdRng::seed_from_u64(42 + i as u64)
511    }
512}
513
514/// Fully resolved data ready for running a benchmark
515///
516/// This struct contains the fully built data (such that benchmarks do not
517/// measure the time required to build the data) and has methods for invoking
518/// functions on it.
519pub struct BenchmarkData {
520    config: BenchmarkArgs,
521    num_batches: usize,
522    arrays: Vec<Vec<ArrayRef>>,
523    scalars: Vec<ScalarValue>,
524}
525
526impl BenchmarkData {
527    /// Create a label based on the library, function name, and configuration
528    pub fn make_label(&self, lib: &str, name: &str) -> String {
529        format!("{lib}-{name}-{:?}", self.config)
530    }
531
532    /// Invoke a scalar function on this data
533    pub fn invoke_scalar(&self, udf: &ScalarUDF) -> Result<()> {
534        let tester = ScalarUdfTester::new(udf.clone(), self.config.sedona_types().clone());
535
536        match self.config {
537            BenchmarkArgs::Array(_) => {
538                for i in 0..self.num_batches {
539                    tester.invoke_array(self.arrays[0][i].clone())?;
540                }
541            }
542            BenchmarkArgs::ScalarArray(_, _) => {
543                let scalar = &self.scalars[0];
544                for i in 0..self.num_batches {
545                    tester.invoke_scalar_array(scalar.clone(), self.arrays[0][i].clone())?;
546                }
547            }
548            BenchmarkArgs::ArrayScalar(_, _) => {
549                let scalar = &self.scalars[0];
550                for i in 0..self.num_batches {
551                    tester.invoke_array_scalar(self.arrays[0][i].clone(), scalar.clone())?;
552                }
553            }
554            BenchmarkArgs::ArrayArray(_, _) => {
555                for i in 0..self.num_batches {
556                    tester
557                        .invoke_array_array(self.arrays[0][i].clone(), self.arrays[1][i].clone())?;
558                }
559            }
560            BenchmarkArgs::ArrayScalarScalar(_, _, _) => {
561                let scalar0 = &self.scalars[0];
562                let scalar1 = &self.scalars[1];
563                for i in 0..self.num_batches {
564                    tester.invoke_array_scalar_scalar(
565                        self.arrays[0][i].clone(),
566                        scalar0.clone(),
567                        scalar1.clone(),
568                    )?;
569                }
570            }
571            BenchmarkArgs::ArrayArrayScalar(_, _, _) => {
572                for i in 0..self.num_batches {
573                    tester.invoke_array_array_scalar(
574                        self.arrays[0][i].clone(),
575                        self.arrays[1][i].clone(),
576                        self.scalars[0].clone(),
577                    )?;
578                }
579            }
580            BenchmarkArgs::ArrayArrayArray(_, _, _) => {
581                for i in 0..self.num_batches {
582                    tester.invoke_arrays(vec![
583                        self.arrays[0][i].clone(),
584                        self.arrays[1][i].clone(),
585                        self.arrays[2][i].clone(),
586                    ])?;
587                }
588            }
589            BenchmarkArgs::ArrayArrayArrayArray(_, _, _, _) => {
590                for i in 0..self.num_batches {
591                    tester.invoke_arrays(vec![
592                        self.arrays[0][i].clone(),
593                        self.arrays[1][i].clone(),
594                        self.arrays[2][i].clone(),
595                        self.arrays[3][i].clone(),
596                    ])?;
597                }
598            }
599        }
600
601        Ok(())
602    }
603
604    /// Invoke an aggregate function on this data
605    pub fn invoke_aggregate(&self, udf: &AggregateUDF) -> Result<ScalarValue> {
606        if !matches!(self.config, BenchmarkArgs::Array(_)) {
607            return sedona_internal_err!(
608                "invoke_aggregate() not implemented for {:?}",
609                self.config
610            );
611        }
612
613        let tester = AggregateUdfTester::new(udf.clone(), self.config.sedona_types().clone());
614        tester.aggregate(&self.arrays[0])
615    }
616}
617
618#[cfg(test)]
619mod test {
620    use arrow_array::{Array, StructArray};
621    use datafusion_common::cast::as_binary_array;
622    use datafusion_expr::{ColumnarValue, SimpleScalarUDF};
623    use geo_traits::Dimensions;
624    use rstest::rstest;
625    use sedona_geometry::{analyze::analyze_geometry, types::GeometryTypeAndDimensions};
626
627    use super::*;
628
629    #[test]
630    fn arg_spec_scalar() {
631        let spec = BenchmarkArgSpec::Point;
632        assert_eq!(spec.sedona_type(), WKB_GEOMETRY);
633
634        let scalar = spec.build_scalar(0).unwrap();
635
636        // Make sure this is deterministic
637        assert_eq!(spec.build_scalar(0).unwrap(), scalar);
638
639        // Make sure we generate different scalars for different columns
640        assert_ne!(spec.build_scalar(1).unwrap(), scalar);
641
642        if let ScalarValue::Binary(Some(wkb_bytes)) = scalar {
643            let wkb = wkb::reader::read_wkb(&wkb_bytes).unwrap();
644            let analysis = analyze_geometry(&wkb).unwrap();
645            assert_eq!(analysis.point_count, 1);
646            assert_eq!(
647                analysis.geometry_type,
648                GeometryTypeAndDimensions::new(GeometryTypeId::Point, Dimensions::Xy)
649            )
650        } else {
651            unreachable!("Unexpected scalar output {scalar}")
652        }
653    }
654
655    #[rstest]
656    fn arg_spec_geometry(
657        #[values(
658            (BenchmarkArgSpec::Point, GeometryTypeId::Point, 1),
659            (BenchmarkArgSpec::LineString(10), GeometryTypeId::LineString, 10),
660            (BenchmarkArgSpec::Polygon(10), GeometryTypeId::Polygon, 11),
661            (BenchmarkArgSpec::MultiPoint(10), GeometryTypeId::MultiPoint, 10),
662        )]
663        config: (BenchmarkArgSpec, GeometryTypeId, i64),
664    ) {
665        let (spec, geometry_type, point_count) = config;
666        assert_eq!(spec.sedona_type(), WKB_GEOMETRY);
667
668        let arrays = spec.build_arrays(0, 2, ROWS_PER_BATCH).unwrap();
669        assert_eq!(arrays.len(), 2);
670
671        // Make sure this is deterministic
672        assert_eq!(spec.build_arrays(0, 2, ROWS_PER_BATCH).unwrap(), arrays);
673
674        // Make sure we generate different arrays for different argument numbers
675        assert_ne!(spec.build_arrays(1, 2, ROWS_PER_BATCH).unwrap(), arrays);
676
677        for array in arrays {
678            assert_eq!(array.data_type(), WKB_GEOMETRY.storage_type());
679            assert_eq!(array.len(), ROWS_PER_BATCH);
680
681            let binary_array = as_binary_array(&array).unwrap();
682            assert_eq!(binary_array.null_count(), 0);
683
684            for wkb_bytes in binary_array {
685                let wkb = wkb::reader::read_wkb(wkb_bytes.unwrap()).unwrap();
686                let analysis = analyze_geometry(&wkb).unwrap();
687                assert_eq!(analysis.point_count, point_count);
688                assert_eq!(
689                    analysis.geometry_type,
690                    GeometryTypeAndDimensions::new(geometry_type, Dimensions::Xy)
691                )
692            }
693        }
694    }
695
696    #[test]
697    fn arg_spec_float() {
698        let spec = BenchmarkArgSpec::Float64(1.0, 2.0);
699        assert_eq!(spec.sedona_type(), SedonaType::Arrow(DataType::Float64));
700
701        let arrays = spec.build_arrays(0, 2, ROWS_PER_BATCH).unwrap();
702        assert_eq!(arrays.len(), 2);
703
704        // Make sure this is deterministic
705        assert_eq!(spec.build_arrays(0, 2, ROWS_PER_BATCH).unwrap(), arrays);
706
707        // Make sure we generate different arrays for different argument numbers
708        assert_ne!(spec.build_arrays(1, 2, ROWS_PER_BATCH).unwrap(), arrays);
709
710        for array in arrays {
711            assert_eq!(array.data_type(), &DataType::Float64);
712            assert_eq!(array.len(), ROWS_PER_BATCH);
713            assert_eq!(array.null_count(), 0);
714        }
715    }
716
717    #[test]
718    fn arg_spec_int() {
719        let spec = BenchmarkArgSpec::Int32(1, 10);
720        assert_eq!(spec.sedona_type(), SedonaType::Arrow(DataType::Int32));
721        let arrays = spec.build_arrays(0, 2, ROWS_PER_BATCH).unwrap();
722        assert_eq!(arrays.len(), 2);
723        // Make sure this is deterministic
724        assert_eq!(spec.build_arrays(0, 2, ROWS_PER_BATCH).unwrap(), arrays);
725        // Make sure we generate different arrays for different argument numbers
726        assert_ne!(spec.build_arrays(1, 2, ROWS_PER_BATCH).unwrap(), arrays);
727        for array in arrays {
728            assert_eq!(array.data_type(), &DataType::Int32);
729            assert_eq!(array.len(), ROWS_PER_BATCH);
730            assert_eq!(array.null_count(), 0);
731        }
732    }
733
734    #[test]
735    fn arg_spec_transformed() {
736        let udf = SimpleScalarUDF::new(
737            "float32",
738            vec![DataType::Float64],
739            DataType::Float32,
740            datafusion_expr::Volatility::Immutable,
741            Arc::new(|args| -> Result<ColumnarValue> { args[0].cast_to(&DataType::Float32, None) }),
742        );
743
744        let spec =
745            BenchmarkArgSpec::Transformed(BenchmarkArgSpec::Float64(1.0, 2.0).into(), udf.into());
746        assert_eq!(spec.sedona_type(), SedonaType::Arrow(DataType::Float32));
747
748        assert_eq!(format!("{spec:?}"), "float32(Float64(1.0, 2.0))");
749        let arrays = spec.build_arrays(0, 2, ROWS_PER_BATCH).unwrap();
750        assert_eq!(arrays.len(), 2);
751
752        // Make sure this is deterministic
753        assert_eq!(spec.build_arrays(0, 2, ROWS_PER_BATCH).unwrap(), arrays);
754
755        // Make sure we generate different arrays for different argument numbers
756        assert_ne!(spec.build_arrays(1, 2, ROWS_PER_BATCH).unwrap(), arrays);
757
758        for array in arrays {
759            assert_eq!(array.data_type(), &DataType::Float32);
760            assert_eq!(array.len(), ROWS_PER_BATCH);
761            assert_eq!(array.null_count(), 0);
762        }
763    }
764
765    #[test]
766    fn args_array() {
767        let spec = BenchmarkArgs::Array(BenchmarkArgSpec::Point);
768        assert_eq!(spec.sedona_types(), [WKB_GEOMETRY]);
769
770        let data = spec.build_data(2, ROWS_PER_BATCH).unwrap();
771        assert_eq!(data.num_batches, 2);
772        assert_eq!(data.arrays.len(), 1);
773        assert_eq!(data.scalars.len(), 0);
774
775        assert_eq!(data.arrays[0].len(), 2);
776        assert_eq!(WKB_GEOMETRY.storage_type(), data.arrays[0][0].data_type());
777    }
778
779    #[test]
780    fn args_array_scalar() {
781        let spec = BenchmarkArgs::ArrayScalar(
782            BenchmarkArgSpec::Point,
783            BenchmarkArgSpec::Float64(1.0, 2.0),
784        );
785        assert_eq!(
786            spec.sedona_types(),
787            [WKB_GEOMETRY, SedonaType::Arrow(DataType::Float64)]
788        );
789
790        let data = spec.build_data(2, ROWS_PER_BATCH).unwrap();
791        assert_eq!(data.num_batches, 2);
792
793        assert_eq!(data.arrays.len(), 1);
794        assert_eq!(data.arrays[0].len(), 2);
795        assert_eq!(WKB_GEOMETRY.storage_type(), data.arrays[0][0].data_type());
796
797        assert_eq!(data.scalars.len(), 1);
798        assert_eq!(data.scalars[0].data_type(), DataType::Float64);
799    }
800
801    #[test]
802    fn args_scalar_array() {
803        let spec = BenchmarkArgs::ScalarArray(
804            BenchmarkArgSpec::Point,
805            BenchmarkArgSpec::Float64(1.0, 2.0),
806        );
807        assert_eq!(
808            spec.sedona_types(),
809            [WKB_GEOMETRY, SedonaType::Arrow(DataType::Float64)]
810        );
811
812        let data = spec.build_data(2, ROWS_PER_BATCH).unwrap();
813        assert_eq!(data.num_batches, 2);
814
815        assert_eq!(data.scalars.len(), 1);
816        assert_eq!(WKB_GEOMETRY.storage_type(), &data.scalars[0].data_type());
817
818        assert_eq!(data.arrays.len(), 1);
819        assert_eq!(data.arrays[0].len(), 2);
820        assert_eq!(data.arrays[0][0].data_type(), &DataType::Float64);
821    }
822
823    #[test]
824    fn args_array_array() {
825        let spec =
826            BenchmarkArgs::ArrayArray(BenchmarkArgSpec::Point, BenchmarkArgSpec::Float64(1.0, 2.0));
827        assert_eq!(
828            spec.sedona_types(),
829            [WKB_GEOMETRY, SedonaType::Arrow(DataType::Float64)]
830        );
831
832        let data = spec.build_data(2, ROWS_PER_BATCH).unwrap();
833        assert_eq!(data.num_batches, 2);
834        assert_eq!(data.arrays.len(), 2);
835        assert_eq!(data.scalars.len(), 0);
836
837        assert_eq!(data.arrays[0].len(), 2);
838        assert_eq!(WKB_GEOMETRY.storage_type(), data.arrays[0][0].data_type());
839
840        assert_eq!(data.arrays[1].len(), 2);
841        assert_eq!(data.arrays[1][0].data_type(), &DataType::Float64);
842    }
843
844    #[test]
845    fn args_array_scalar_scalar() {
846        let spec = BenchmarkArgs::ArrayScalarScalar(
847            BenchmarkArgSpec::Point,
848            BenchmarkArgSpec::Float64(1.0, 2.0),
849            BenchmarkArgSpec::String("test".to_string()),
850        );
851        assert_eq!(
852            spec.sedona_types(),
853            [
854                WKB_GEOMETRY,
855                SedonaType::Arrow(DataType::Float64),
856                SedonaType::Arrow(DataType::Utf8)
857            ]
858        );
859
860        let data = spec.build_data(2, ROWS_PER_BATCH).unwrap();
861        assert_eq!(data.num_batches, 2);
862        assert_eq!(data.arrays.len(), 1);
863        assert_eq!(data.scalars.len(), 2);
864        assert_eq!(data.arrays[0].len(), 2);
865        assert_eq!(WKB_GEOMETRY.storage_type(), data.arrays[0][0].data_type());
866        assert_eq!(data.scalars[0].data_type(), DataType::Float64);
867        assert_eq!(data.scalars[1].data_type(), DataType::Utf8);
868    }
869
870    #[test]
871    fn args_array_array_scalar() {
872        let spec = BenchmarkArgs::ArrayArrayScalar(
873            BenchmarkArgSpec::Point,
874            BenchmarkArgSpec::Point,
875            BenchmarkArgSpec::Float64(1.0, 2.0),
876        );
877        assert_eq!(
878            spec.sedona_types(),
879            [
880                WKB_GEOMETRY,
881                WKB_GEOMETRY,
882                SedonaType::Arrow(DataType::Float64)
883            ]
884        );
885
886        let data = spec.build_data(2, ROWS_PER_BATCH).unwrap();
887        assert_eq!(data.num_batches, 2);
888        assert_eq!(data.arrays.len(), 3);
889        assert_eq!(data.scalars.len(), 1);
890        assert_eq!(data.arrays[0].len(), 2);
891        assert_eq!(WKB_GEOMETRY.storage_type(), data.arrays[0][0].data_type());
892        assert_eq!(data.arrays[1].len(), 2);
893        assert_eq!(WKB_GEOMETRY.storage_type(), data.arrays[1][0].data_type());
894
895        assert_eq!(data.scalars[0].data_type(), DataType::Float64);
896    }
897
898    #[test]
899    fn args_array_array_array() {
900        let spec = BenchmarkArgs::ArrayArrayArray(
901            BenchmarkArgSpec::Point,
902            BenchmarkArgSpec::Point,
903            BenchmarkArgSpec::Float64(1.0, 2.0),
904        );
905        assert_eq!(
906            spec.sedona_types(),
907            [
908                WKB_GEOMETRY,
909                WKB_GEOMETRY,
910                SedonaType::Arrow(DataType::Float64)
911            ]
912        );
913
914        let data = spec.build_data(2, ROWS_PER_BATCH).unwrap();
915        assert_eq!(data.num_batches, 2);
916        assert_eq!(data.arrays.len(), 3);
917        assert_eq!(data.scalars.len(), 0);
918        assert_eq!(data.arrays[0].len(), 2);
919        assert_eq!(WKB_GEOMETRY.storage_type(), data.arrays[0][0].data_type());
920        assert_eq!(data.arrays[1].len(), 2);
921        assert_eq!(WKB_GEOMETRY.storage_type(), data.arrays[1][0].data_type());
922        assert_eq!(data.arrays[2].len(), 2);
923        assert_eq!(data.arrays[2][0].data_type(), &DataType::Float64);
924    }
925
926    #[test]
927    fn args_array_array_array_array() {
928        let spec = BenchmarkArgs::ArrayArrayArrayArray(
929            BenchmarkArgSpec::Float64(1.0, 2.0),
930            BenchmarkArgSpec::Float64(3.0, 4.0),
931            BenchmarkArgSpec::Float64(5.0, 6.0),
932            BenchmarkArgSpec::Float64(7.0, 8.0),
933        );
934        assert_eq!(
935            spec.sedona_types(),
936            [
937                SedonaType::Arrow(DataType::Float64),
938                SedonaType::Arrow(DataType::Float64),
939                SedonaType::Arrow(DataType::Float64),
940                SedonaType::Arrow(DataType::Float64)
941            ]
942        );
943
944        let data = spec.build_data(2, ROWS_PER_BATCH).unwrap();
945        assert_eq!(data.num_batches, 2);
946        assert_eq!(data.arrays.len(), 4);
947        assert_eq!(data.scalars.len(), 0);
948        assert_eq!(data.arrays[0].len(), 2);
949        assert_eq!(data.arrays[0][0].data_type(), &DataType::Float64);
950        assert_eq!(data.arrays[1].len(), 2);
951        assert_eq!(data.arrays[1][0].data_type(), &DataType::Float64);
952        assert_eq!(data.arrays[2].len(), 2);
953        assert_eq!(data.arrays[2][0].data_type(), &DataType::Float64);
954        assert_eq!(data.arrays[3].len(), 2);
955        assert_eq!(data.arrays[3][0].data_type(), &DataType::Float64);
956    }
957
958    #[test]
959    fn arg_spec_raster() {
960        use sedona_raster::array::RasterStructArray;
961        use sedona_raster::traits::RasterRef;
962
963        let spec = BenchmarkArgSpec::Raster(10, 5);
964        assert_eq!(spec.sedona_type(), RASTER);
965        let data = spec.build_arrays(0, 2, ROWS_PER_BATCH).unwrap();
966        assert_eq!(data.len(), 2);
967        assert_eq!(data[0].data_type(), RASTER.storage_type());
968
969        let raster_array = data[0].as_any().downcast_ref::<StructArray>().unwrap();
970        let rasters = RasterStructArray::new(raster_array);
971        assert_eq!(rasters.len(), ROWS_PER_BATCH);
972        let raster = rasters.get(0).unwrap();
973        let metadata = raster.metadata();
974        assert_eq!(metadata.width(), 10);
975        assert_eq!(metadata.height(), 5);
976    }
977}