Skip to main content

datafusion_spark/function/math/
modulus.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{Scalar, new_null_array};
19use arrow::compute::kernels::numeric::add;
20use arrow::compute::kernels::{
21    cmp::{eq, lt},
22    numeric::rem,
23    zip::zip,
24};
25use arrow::datatypes::DataType;
26use datafusion_common::{Result, ScalarValue, assert_eq_or_internal_err};
27use datafusion_expr::{
28    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
29};
30
31/// Computes `rem(left, right)` with divide-by-zero handling.
32/// In ANSI mode, any zero divisor causes an error.
33/// In legacy mode (ANSI off), zero divisors are replaced with NULL before
34/// computing the remainder, so those positions return NULL while others
35/// compute normally.
36fn try_rem(
37    left: &arrow::array::ArrayRef,
38    right: &arrow::array::ArrayRef,
39    enable_ansi_mode: bool,
40) -> Result<arrow::array::ArrayRef> {
41    if enable_ansi_mode {
42        Ok(rem(left, right)?)
43    } else {
44        // In legacy mode, null out zero divisors so that division by zero
45        // returns NULL instead of erroring (integers) or returning NaN (floats).
46        let zero = ScalarValue::new_zero(right.data_type())?.to_array()?;
47        let zero = Scalar::new(zero);
48        let null = Scalar::new(new_null_array(right.data_type(), 1));
49        let is_zero = eq(right, &zero)?;
50        let safe_right = zip(&is_zero, &null, right)?;
51        Ok(rem(left, &safe_right)?)
52    }
53}
54
55/// Spark-compatible `mod` function
56/// In ANSI mode, division by zero throws an error.
57/// In legacy mode, division by zero returns NULL (Spark behavior).
58pub fn spark_mod(
59    args: &[ColumnarValue],
60    enable_ansi_mode: bool,
61) -> Result<ColumnarValue> {
62    assert_eq_or_internal_err!(args.len(), 2, "mod expects exactly two arguments");
63    let args = ColumnarValue::values_to_arrays(args)?;
64    let result = try_rem(&args[0], &args[1], enable_ansi_mode)?;
65    Ok(ColumnarValue::Array(result))
66}
67
68/// Spark-compatible `pmod` function
69/// In ANSI mode, division by zero throws an error.
70/// In legacy mode, division by zero returns NULL (Spark behavior).
71pub fn spark_pmod(
72    args: &[ColumnarValue],
73    enable_ansi_mode: bool,
74) -> Result<ColumnarValue> {
75    assert_eq_or_internal_err!(args.len(), 2, "pmod expects exactly two arguments");
76    let args = ColumnarValue::values_to_arrays(args)?;
77    let left = &args[0];
78    let right = &args[1];
79    let zero = ScalarValue::new_zero(left.data_type())?.to_array_of_size(left.len())?;
80    let result = try_rem(left, right, enable_ansi_mode)?;
81    let neg = lt(&result, &zero)?;
82    let plus = zip(&neg, right, &zero)?;
83    let result = add(&plus, &result)?;
84    let result = try_rem(&result, right, enable_ansi_mode)?;
85    Ok(ColumnarValue::Array(result))
86}
87
88/// SparkMod implements the Spark-compatible modulo function
89#[derive(Debug, PartialEq, Eq, Hash)]
90pub struct SparkMod {
91    signature: Signature,
92}
93
94impl Default for SparkMod {
95    fn default() -> Self {
96        Self::new()
97    }
98}
99
100impl SparkMod {
101    pub fn new() -> Self {
102        Self {
103            signature: Signature::numeric(2, Volatility::Immutable),
104        }
105    }
106}
107
108impl ScalarUDFImpl for SparkMod {
109    fn name(&self) -> &str {
110        "mod"
111    }
112
113    fn signature(&self) -> &Signature {
114        &self.signature
115    }
116
117    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
118        assert_eq_or_internal_err!(
119            arg_types.len(),
120            2,
121            "mod expects exactly two arguments"
122        );
123
124        // Return the same type as the first argument for simplicity
125        // Arrow's rem function handles type promotion internally
126        Ok(arg_types[0].clone())
127    }
128
129    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
130        spark_mod(&args.args, args.config_options.execution.enable_ansi_mode)
131    }
132}
133
134/// SparkMod implements the Spark-compatible modulo function
135#[derive(Debug, PartialEq, Eq, Hash)]
136pub struct SparkPmod {
137    signature: Signature,
138}
139
140impl Default for SparkPmod {
141    fn default() -> Self {
142        Self::new()
143    }
144}
145
146impl SparkPmod {
147    pub fn new() -> Self {
148        Self {
149            signature: Signature::numeric(2, Volatility::Immutable),
150        }
151    }
152}
153
154impl ScalarUDFImpl for SparkPmod {
155    fn name(&self) -> &str {
156        "pmod"
157    }
158
159    fn signature(&self) -> &Signature {
160        &self.signature
161    }
162
163    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
164        assert_eq_or_internal_err!(
165            arg_types.len(),
166            2,
167            "pmod expects exactly two arguments"
168        );
169
170        // Return the same type as the first argument for simplicity
171        // Arrow's rem function handles type promotion internally
172        Ok(arg_types[0].clone())
173    }
174
175    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
176        spark_pmod(&args.args, args.config_options.execution.enable_ansi_mode)
177    }
178}
179
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use super::*;
    use arrow::array::*;
    use datafusion_common::ScalarValue;

    /// Signature shared by `spark_mod` and `spark_pmod`.
    type ModFn =
        fn(&[ColumnarValue], bool) -> datafusion_common::Result<ColumnarValue>;

    /// Wraps a concrete Arrow array as a function argument.
    fn arg(values: impl Array + 'static) -> ColumnarValue {
        ColumnarValue::Array(Arc::new(values))
    }

    /// Invokes `func` and returns the array it produced; panics if the call
    /// fails or yields a scalar.
    fn eval(func: ModFn, args: &[ColumnarValue], enable_ansi_mode: bool) -> ArrayRef {
        match func(args, enable_ansi_mode).unwrap() {
            ColumnarValue::Array(array) => array,
            ColumnarValue::Scalar(_) => panic!("Expected array result"),
        }
    }

    // ----------------------------------------------------------------- mod

    #[test]
    fn test_mod_int32() {
        let args = [
            arg(Int32Array::from(vec![Some(10), Some(7), Some(15), None])),
            arg(Int32Array::from(vec![Some(3), Some(2), Some(4), Some(5)])),
        ];
        let result = eval(spark_mod, &args, false);
        let out = result.as_any().downcast_ref::<Int32Array>().unwrap();

        assert_eq!(out.value(0), 1); // 10 % 3
        assert_eq!(out.value(1), 1); // 7 % 2
        assert_eq!(out.value(2), 3); // 15 % 4
        assert!(out.is_null(3)); // NULL % 5
    }

    #[test]
    fn test_mod_int64() {
        let args = [
            arg(Int64Array::from(vec![Some(100), Some(50), Some(200)])),
            arg(Int64Array::from(vec![Some(30), Some(25), Some(60)])),
        ];
        let result = eval(spark_mod, &args, false);
        let out = result.as_any().downcast_ref::<Int64Array>().unwrap();

        assert_eq!(out.value(0), 10); // 100 % 30
        assert_eq!(out.value(1), 0); // 50 % 25
        assert_eq!(out.value(2), 20); // 200 % 60
    }

    #[test]
    fn test_mod_float64() {
        let dividends = Float64Array::from(vec![
            Some(10.5),
            Some(7.2),
            Some(15.8),
            Some(f64::NAN),
            Some(f64::INFINITY),
            Some(5.0),
            Some(5.0),
            Some(f64::NAN),
            Some(f64::INFINITY),
            Some(10.5),
            Some(15.8),
        ]);
        let divisors = Float64Array::from(vec![
            Some(3.0),
            Some(2.5),
            Some(4.2),
            Some(2.0),
            Some(2.0),
            Some(f64::NAN),
            Some(f64::INFINITY),
            Some(f64::INFINITY),
            Some(f64::NAN),
            Some(0.0),
            Some(0.0),
        ]);
        let result = eval(spark_mod, &[arg(dividends), arg(divisors)], false);
        let out = result.as_any().downcast_ref::<Float64Array>().unwrap();

        // Ordinary operands
        assert!((out.value(0) - 1.5).abs() < f64::EPSILON); // 10.5 % 3.0
        assert!((out.value(1) - 2.2).abs() < f64::EPSILON); // 7.2 % 2.5
        assert!((out.value(2) - 3.2).abs() < f64::EPSILON); // 15.8 % 4.2
        // IEEE 754 special values
        assert!(out.value(3).is_nan()); // nan % 2.0
        assert!(out.value(4).is_nan()); // inf % 2.0
        assert!(out.value(5).is_nan()); // 5.0 % nan
        assert!((out.value(6) - 5.0).abs() < f64::EPSILON); // 5.0 % inf
        assert!(out.value(7).is_nan()); // nan % inf
        assert!(out.value(8).is_nan()); // inf % nan
        // Zero divisors become NULL in legacy mode
        assert!(out.is_null(9)); // 10.5 % 0.0
        assert!(out.is_null(10)); // 15.8 % 0.0
    }

    #[test]
    fn test_mod_float32() {
        let dividends = Float32Array::from(vec![
            Some(10.5),
            Some(7.2),
            Some(15.8),
            Some(f32::NAN),
            Some(f32::INFINITY),
            Some(5.0),
            Some(5.0),
            Some(f32::NAN),
            Some(f32::INFINITY),
            Some(10.5),
            Some(15.8),
        ]);
        let divisors = Float32Array::from(vec![
            Some(3.0),
            Some(2.5),
            Some(4.2),
            Some(2.0),
            Some(2.0),
            Some(f32::NAN),
            Some(f32::INFINITY),
            Some(f32::INFINITY),
            Some(f32::NAN),
            Some(0.0),
            Some(0.0),
        ]);
        let result = eval(spark_mod, &[arg(dividends), arg(divisors)], false);
        let out = result.as_any().downcast_ref::<Float32Array>().unwrap();

        // Ordinary operands
        assert!((out.value(0) - 1.5).abs() < f32::EPSILON); // 10.5 % 3.0
        assert!((out.value(1) - 2.2).abs() < f32::EPSILON * 3.0); // 7.2 % 2.5
        assert!((out.value(2) - 3.2).abs() < f32::EPSILON * 10.0); // 15.8 % 4.2
        // IEEE 754 special values
        assert!(out.value(3).is_nan()); // nan % 2.0
        assert!(out.value(4).is_nan()); // inf % 2.0
        assert!(out.value(5).is_nan()); // 5.0 % nan
        assert!((out.value(6) - 5.0).abs() < f32::EPSILON); // 5.0 % inf
        assert!(out.value(7).is_nan()); // nan % inf
        assert!(out.value(8).is_nan()); // inf % nan
        // Zero divisors become NULL in legacy mode
        assert!(out.is_null(9)); // 10.5 % 0.0
        assert!(out.is_null(10)); // 15.8 % 0.0
    }

    #[test]
    fn test_mod_scalar() {
        let args = [
            arg(Int32Array::from(vec![Some(10), Some(7), Some(15)])),
            ColumnarValue::Scalar(ScalarValue::Int32(Some(3))),
        ];
        let result = eval(spark_mod, &args, false);
        let out = result.as_any().downcast_ref::<Int32Array>().unwrap();

        assert_eq!(out.value(0), 1); // 10 % 3
        assert_eq!(out.value(1), 1); // 7 % 3
        assert_eq!(out.value(2), 0); // 15 % 3
    }

    #[test]
    fn test_mod_wrong_arg_count() {
        let only_one = [arg(Int32Array::from(vec![Some(10)]))];
        assert!(spark_mod(&only_one, false).is_err());
    }

    #[test]
    fn test_mod_zero_division_legacy() {
        // ANSI off: a zero divisor yields NULL for that row only.
        let args = [
            arg(Int32Array::from(vec![Some(10), Some(7), Some(15)])),
            arg(Int32Array::from(vec![Some(0), Some(2), Some(4)])),
        ];
        let result = eval(spark_mod, &args, false);
        let out = result.as_any().downcast_ref::<Int32Array>().unwrap();

        assert!(out.is_null(0)); // 10 % 0
        assert_eq!(out.value(1), 1); // 7 % 2
        assert_eq!(out.value(2), 3); // 15 % 4
    }

    #[test]
    fn test_mod_zero_division_ansi() {
        // ANSI on: any zero divisor fails the whole call.
        let args = [
            arg(Int32Array::from(vec![Some(10), Some(7), Some(15)])),
            arg(Int32Array::from(vec![Some(0), Some(2), Some(4)])),
        ];
        assert!(spark_mod(&args, true).is_err());
    }

    // ---------------------------------------------------------------- pmod

    #[test]
    fn test_pmod_int32() {
        let args = [
            arg(Int32Array::from(vec![Some(10), Some(-7), Some(15), Some(-15), None])),
            arg(Int32Array::from(vec![Some(3), Some(3), Some(4), Some(4), Some(5)])),
        ];
        let result = eval(spark_pmod, &args, false);
        let out = result.as_any().downcast_ref::<Int32Array>().unwrap();

        assert_eq!(out.value(0), 1); // 10 pmod 3
        assert_eq!(out.value(1), 2); // -7 pmod 3 (positive remainder)
        assert_eq!(out.value(2), 3); // 15 pmod 4
        assert_eq!(out.value(3), 1); // -15 pmod 4 (positive remainder)
        assert!(out.is_null(4)); // NULL pmod 5
    }

    #[test]
    fn test_pmod_int64() {
        let args = [
            arg(Int64Array::from(vec![Some(100), Some(-50), Some(200), Some(-200)])),
            arg(Int64Array::from(vec![Some(30), Some(30), Some(60), Some(60)])),
        ];
        let result = eval(spark_pmod, &args, false);
        let out = result.as_any().downcast_ref::<Int64Array>().unwrap();

        assert_eq!(out.value(0), 10); // 100 pmod 30
        assert_eq!(out.value(1), 10); // -50 pmod 30 (positive remainder)
        assert_eq!(out.value(2), 20); // 200 pmod 60
        assert_eq!(out.value(3), 40); // -200 pmod 60 (positive remainder)
    }

    #[test]
    fn test_pmod_float64() {
        let dividends = Float64Array::from(vec![
            Some(10.5),
            Some(-7.2),
            Some(15.8),
            Some(-15.8),
            Some(f64::NAN),
            Some(f64::INFINITY),
            Some(5.0),
            Some(-5.0),
            Some(10.5),
            Some(-7.2),
        ]);
        let divisors = Float64Array::from(vec![
            Some(3.0),
            Some(3.0),
            Some(4.2),
            Some(4.2),
            Some(2.0),
            Some(2.0),
            Some(f64::INFINITY),
            Some(f64::INFINITY),
            Some(0.0),
            Some(0.0),
        ]);
        let result = eval(spark_pmod, &[arg(dividends), arg(divisors)], false);
        let out = result.as_any().downcast_ref::<Float64Array>().unwrap();

        // Ordinary operands
        assert!((out.value(0) - 1.5).abs() < f64::EPSILON); // 10.5 pmod 3.0
        assert!((out.value(1) - 1.8).abs() < f64::EPSILON * 3.0); // -7.2 pmod 3.0
        assert!((out.value(2) - 3.2).abs() < f64::EPSILON * 3.0); // 15.8 pmod 4.2
        assert!((out.value(3) - 1.0).abs() < f64::EPSILON * 3.0); // -15.8 pmod 4.2
        // IEEE 754 special values
        assert!(out.value(4).is_nan()); // nan pmod 2.0
        assert!(out.value(5).is_nan()); // inf pmod 2.0
        assert!((out.value(6) - 5.0).abs() < f64::EPSILON); // 5.0 pmod inf
        assert!(out.value(7).is_nan()); // -5.0 pmod inf
        // Zero divisors become NULL in legacy mode
        assert!(out.is_null(8)); // 10.5 pmod 0.0
        assert!(out.is_null(9)); // -7.2 pmod 0.0
    }

    #[test]
    fn test_pmod_float32() {
        let dividends = Float32Array::from(vec![
            Some(10.5),
            Some(-7.2),
            Some(15.8),
            Some(-15.8),
            Some(f32::NAN),
            Some(f32::INFINITY),
            Some(5.0),
            Some(-5.0),
            Some(10.5),
            Some(-7.2),
        ]);
        let divisors = Float32Array::from(vec![
            Some(3.0),
            Some(3.0),
            Some(4.2),
            Some(4.2),
            Some(2.0),
            Some(2.0),
            Some(f32::INFINITY),
            Some(f32::INFINITY),
            Some(0.0),
            Some(0.0),
        ]);
        let result = eval(spark_pmod, &[arg(dividends), arg(divisors)], false);
        let out = result.as_any().downcast_ref::<Float32Array>().unwrap();

        // Ordinary operands
        assert!((out.value(0) - 1.5).abs() < f32::EPSILON); // 10.5 pmod 3.0
        assert!((out.value(1) - 1.8).abs() < f32::EPSILON * 3.0); // -7.2 pmod 3.0
        assert!((out.value(2) - 3.2).abs() < f32::EPSILON * 10.0); // 15.8 pmod 4.2
        assert!((out.value(3) - 1.0).abs() < f32::EPSILON * 10.0); // -15.8 pmod 4.2
        // IEEE 754 special values
        assert!(out.value(4).is_nan()); // nan pmod 2.0
        assert!(out.value(5).is_nan()); // inf pmod 2.0
        assert!((out.value(6) - 5.0).abs() < f32::EPSILON * 10.0); // 5.0 pmod inf
        assert!(out.value(7).is_nan()); // -5.0 pmod inf
        // Zero divisors become NULL in legacy mode
        assert!(out.is_null(8)); // 10.5 pmod 0.0
        assert!(out.is_null(9)); // -7.2 pmod 0.0
    }

    #[test]
    fn test_pmod_scalar() {
        let args = [
            arg(Int32Array::from(vec![Some(10), Some(-7), Some(15), Some(-15)])),
            ColumnarValue::Scalar(ScalarValue::Int32(Some(3))),
        ];
        let result = eval(spark_pmod, &args, false);
        let out = result.as_any().downcast_ref::<Int32Array>().unwrap();

        assert_eq!(out.value(0), 1); // 10 pmod 3
        assert_eq!(out.value(1), 2); // -7 pmod 3 (positive remainder)
        assert_eq!(out.value(2), 0); // 15 pmod 3
        assert_eq!(out.value(3), 0); // -15 pmod 3
    }

    #[test]
    fn test_pmod_wrong_arg_count() {
        let only_one = [arg(Int32Array::from(vec![Some(10)]))];
        assert!(spark_pmod(&only_one, false).is_err());
    }

    #[test]
    fn test_pmod_zero_division_legacy() {
        // ANSI off: a zero divisor yields NULL for that row only.
        let args = [
            arg(Int32Array::from(vec![Some(10), Some(-7), Some(15)])),
            arg(Int32Array::from(vec![Some(0), Some(0), Some(4)])),
        ];
        let result = eval(spark_pmod, &args, false);
        let out = result.as_any().downcast_ref::<Int32Array>().unwrap();

        assert!(out.is_null(0)); // 10 pmod 0
        assert!(out.is_null(1)); // -7 pmod 0
        assert_eq!(out.value(2), 3); // 15 pmod 4
    }

    #[test]
    fn test_pmod_zero_division_ansi() {
        // ANSI on: any zero divisor fails the whole call.
        let args = [
            arg(Int32Array::from(vec![Some(10), Some(-7), Some(15)])),
            arg(Int32Array::from(vec![Some(0), Some(0), Some(4)])),
        ];
        assert!(spark_pmod(&args, true).is_err());
    }

    #[test]
    fn test_pmod_negative_divisor() {
        // A negative divisor keeps the divisor's sign, like regular mod.
        let args = [
            arg(Int32Array::from(vec![Some(10), Some(-7), Some(15)])),
            arg(Int32Array::from(vec![Some(-3), Some(-3), Some(-4)])),
        ];
        let result = eval(spark_pmod, &args, false);
        let out = result.as_any().downcast_ref::<Int32Array>().unwrap();

        assert_eq!(out.value(0), 1); // 10 pmod -3
        assert_eq!(out.value(1), -1); // -7 pmod -3
        assert_eq!(out.value(2), 3); // 15 pmod -4
    }

    #[test]
    fn test_pmod_edge_cases() {
        let dividends = Int32Array::from(vec![
            Some(0),
            Some(-1),
            Some(1),
            Some(-5),
            Some(5),
            Some(-6),
            Some(6),
        ]);
        let divisors = Int32Array::from(vec![Some(5); 7]);
        let result = eval(spark_pmod, &[arg(dividends), arg(divisors)], false);
        let out = result.as_any().downcast_ref::<Int32Array>().unwrap();

        let expected = [0, 4, 1, 0, 0, 4, 1]; // each dividend pmod 5
        for (row, want) in expected.into_iter().enumerate() {
            assert_eq!(out.value(row), want);
        }
    }
}