datafusion_spark/function/math/
modulus.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::compute::kernels::numeric::add;
19use arrow::compute::kernels::{cmp::lt, numeric::rem, zip::zip};
20use arrow::datatypes::DataType;
21use datafusion_common::{DataFusionError, Result, ScalarValue};
22use datafusion_expr::{
23    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
24};
25use std::any::Any;
26
27/// Spark-compatible `mod` function
28/// This function directly uses Arrow's arithmetic_op function for modulo operations
29pub fn spark_mod(args: &[ColumnarValue]) -> Result<ColumnarValue> {
30    if args.len() != 2 {
31        return Err(DataFusionError::Internal(
32            "mod expects exactly two arguments".to_string(),
33        ));
34    }
35    let args = ColumnarValue::values_to_arrays(args)?;
36    let result = rem(&args[0], &args[1])?;
37    Ok(ColumnarValue::Array(result))
38}
39
40/// Spark-compatible `pmod` function
41/// This function directly uses Arrow's arithmetic_op function for modulo operations
42pub fn spark_pmod(args: &[ColumnarValue]) -> Result<ColumnarValue> {
43    if args.len() != 2 {
44        return Err(DataFusionError::Internal(
45            "pmod expects exactly two arguments".to_string(),
46        ));
47    }
48    let args = ColumnarValue::values_to_arrays(args)?;
49    let left = &args[0];
50    let right = &args[1];
51    let zero = ScalarValue::new_zero(left.data_type())?.to_array_of_size(left.len())?;
52    let result = rem(left, right)?;
53    let neg = lt(&result, &zero)?;
54    let plus = zip(&neg, right, &zero)?;
55    let result = add(&plus, &result)?;
56    let result = rem(&result, right)?;
57    Ok(ColumnarValue::Array(result))
58}
59
60/// SparkMod implements the Spark-compatible modulo function
61#[derive(Debug, PartialEq, Eq, Hash)]
62pub struct SparkMod {
63    signature: Signature,
64}
65
66impl Default for SparkMod {
67    fn default() -> Self {
68        Self::new()
69    }
70}
71
72impl SparkMod {
73    pub fn new() -> Self {
74        Self {
75            signature: Signature::numeric(2, Volatility::Immutable),
76        }
77    }
78}
79
80impl ScalarUDFImpl for SparkMod {
81    fn as_any(&self) -> &dyn Any {
82        self
83    }
84
85    fn name(&self) -> &str {
86        "mod"
87    }
88
89    fn signature(&self) -> &Signature {
90        &self.signature
91    }
92
93    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
94        if arg_types.len() != 2 {
95            return Err(DataFusionError::Internal(
96                "mod expects exactly two arguments".to_string(),
97            ));
98        }
99
100        // Return the same type as the first argument for simplicity
101        // Arrow's rem function handles type promotion internally
102        Ok(arg_types[0].clone())
103    }
104
105    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
106        spark_mod(&args.args)
107    }
108}
109
110/// SparkMod implements the Spark-compatible modulo function
111#[derive(Debug, PartialEq, Eq, Hash)]
112pub struct SparkPmod {
113    signature: Signature,
114}
115
116impl Default for SparkPmod {
117    fn default() -> Self {
118        Self::new()
119    }
120}
121
122impl SparkPmod {
123    pub fn new() -> Self {
124        Self {
125            signature: Signature::numeric(2, Volatility::Immutable),
126        }
127    }
128}
129
130impl ScalarUDFImpl for SparkPmod {
131    fn as_any(&self) -> &dyn Any {
132        self
133    }
134
135    fn name(&self) -> &str {
136        "pmod"
137    }
138
139    fn signature(&self) -> &Signature {
140        &self.signature
141    }
142
143    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
144        if arg_types.len() != 2 {
145            return Err(DataFusionError::Internal(
146                "pmod expects exactly two arguments".to_string(),
147            ));
148        }
149
150        // Return the same type as the first argument for simplicity
151        // Arrow's rem function handles type promotion internally
152        Ok(arg_types[0].clone())
153    }
154
155    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
156        spark_pmod(&args.args)
157    }
158}
159
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use super::*;
    use arrow::array::*;
    use datafusion_common::ScalarValue;

    /// Shared signature of `spark_mod` and `spark_pmod`.
    type ModFn = fn(&[ColumnarValue]) -> Result<ColumnarValue>;

    /// Wraps a concrete Arrow array as a `ColumnarValue::Array` argument.
    fn array_arg<A: Array + 'static>(array: A) -> ColumnarValue {
        ColumnarValue::Array(Arc::new(array))
    }

    /// Calls `func` with `(left, right)` and unwraps both the `Result` and the
    /// array variant of the returned value.
    fn eval_array(func: ModFn, left: ColumnarValue, right: ColumnarValue) -> ArrayRef {
        let ColumnarValue::Array(array) = func(&[left, right]).unwrap() else {
            panic!("Expected array result");
        };
        array
    }

    /// Downcasts a result array to its concrete Arrow array type.
    fn as_typed<T: 'static>(array: &ArrayRef) -> &T {
        array.as_any().downcast_ref::<T>().unwrap()
    }

    /// Asserts `|actual - expected| < tolerance` with a readable failure message.
    fn assert_near_f64(actual: f64, expected: f64, tolerance: f64) {
        assert!(
            (actual - expected).abs() < tolerance,
            "expected {expected} +/- {tolerance}, got {actual}"
        );
    }

    /// Asserts `|actual - expected| < tolerance` with a readable failure message.
    fn assert_near_f32(actual: f32, expected: f32, tolerance: f32) {
        assert!(
            (actual - expected).abs() < tolerance,
            "expected {expected} +/- {tolerance}, got {actual}"
        );
    }

    #[test]
    fn test_mod_int32() {
        let left = Int32Array::from(vec![Some(10), Some(7), Some(15), None]);
        let right = Int32Array::from(vec![Some(3), Some(2), Some(4), Some(5)]);

        let output = eval_array(spark_mod, array_arg(left), array_arg(right));
        let values = as_typed::<Int32Array>(&output);
        assert_eq!(values.value(0), 1); // 10 % 3 = 1
        assert_eq!(values.value(1), 1); // 7 % 2 = 1
        assert_eq!(values.value(2), 3); // 15 % 4 = 3
        assert!(values.is_null(3)); // None % 5 = None
    }

    #[test]
    fn test_mod_int64() {
        let left = Int64Array::from(vec![Some(100), Some(50), Some(200)]);
        let right = Int64Array::from(vec![Some(30), Some(25), Some(60)]);

        let output = eval_array(spark_mod, array_arg(left), array_arg(right));
        let values = as_typed::<Int64Array>(&output);
        assert_eq!(values.value(0), 10); // 100 % 30 = 10
        assert_eq!(values.value(1), 0); // 50 % 25 = 0
        assert_eq!(values.value(2), 20); // 200 % 60 = 20
    }

    #[test]
    fn test_mod_float64() {
        let left = Float64Array::from(vec![
            Some(10.5),
            Some(7.2),
            Some(15.8),
            Some(f64::NAN),
            Some(f64::INFINITY),
            Some(5.0),
            Some(5.0),
            Some(f64::NAN),
            Some(f64::INFINITY),
        ]);
        let right = Float64Array::from(vec![
            Some(3.0),
            Some(2.5),
            Some(4.2),
            Some(2.0),
            Some(2.0),
            Some(f64::NAN),
            Some(f64::INFINITY),
            Some(f64::INFINITY),
            Some(f64::NAN),
        ]);

        let output = eval_array(spark_mod, array_arg(left), array_arg(right));
        let values = as_typed::<Float64Array>(&output);
        // Regular cases
        assert_near_f64(values.value(0), 1.5, f64::EPSILON); // 10.5 % 3.0 = 1.5
        assert_near_f64(values.value(1), 2.2, f64::EPSILON); // 7.2 % 2.5 = 2.2
        assert_near_f64(values.value(2), 3.2, f64::EPSILON); // 15.8 % 4.2 = 3.2
        // nan % 2.0 = nan
        assert!(values.value(3).is_nan());
        // inf % 2.0 = nan (IEEE 754)
        assert!(values.value(4).is_nan());
        // 5.0 % nan = nan
        assert!(values.value(5).is_nan());
        // 5.0 % inf = 5.0
        assert_near_f64(values.value(6), 5.0, f64::EPSILON);
        // nan % inf = nan
        assert!(values.value(7).is_nan());
        // inf % nan = nan
        assert!(values.value(8).is_nan());
    }

    #[test]
    fn test_mod_float32() {
        let left = Float32Array::from(vec![
            Some(10.5),
            Some(7.2),
            Some(15.8),
            Some(f32::NAN),
            Some(f32::INFINITY),
            Some(5.0),
            Some(5.0),
            Some(f32::NAN),
            Some(f32::INFINITY),
        ]);
        let right = Float32Array::from(vec![
            Some(3.0),
            Some(2.5),
            Some(4.2),
            Some(2.0),
            Some(2.0),
            Some(f32::NAN),
            Some(f32::INFINITY),
            Some(f32::INFINITY),
            Some(f32::NAN),
        ]);

        let output = eval_array(spark_mod, array_arg(left), array_arg(right));
        let values = as_typed::<Float32Array>(&output);
        // Regular cases
        assert_near_f32(values.value(0), 1.5, f32::EPSILON); // 10.5 % 3.0 = 1.5
        assert_near_f32(values.value(1), 2.2, f32::EPSILON * 3.0); // 7.2 % 2.5 = 2.2
        assert_near_f32(values.value(2), 3.2, f32::EPSILON * 10.0); // 15.8 % 4.2 = 3.2
        // nan % 2.0 = nan
        assert!(values.value(3).is_nan());
        // inf % 2.0 = nan (IEEE 754)
        assert!(values.value(4).is_nan());
        // 5.0 % nan = nan
        assert!(values.value(5).is_nan());
        // 5.0 % inf = 5.0
        assert_near_f32(values.value(6), 5.0, f32::EPSILON);
        // nan % inf = nan
        assert!(values.value(7).is_nan());
        // inf % nan = nan
        assert!(values.value(8).is_nan());
    }

    #[test]
    fn test_mod_scalar() {
        let left = Int32Array::from(vec![Some(10), Some(7), Some(15)]);
        let right = ColumnarValue::Scalar(ScalarValue::Int32(Some(3)));

        let output = eval_array(spark_mod, array_arg(left), right);
        let values = as_typed::<Int32Array>(&output);
        assert_eq!(values.value(0), 1); // 10 % 3 = 1
        assert_eq!(values.value(1), 1); // 7 % 3 = 1
        assert_eq!(values.value(2), 0); // 15 % 3 = 0
    }

    #[test]
    fn test_mod_wrong_arg_count() {
        let left = array_arg(Int32Array::from(vec![Some(10)]));
        assert!(spark_mod(&[left]).is_err());
    }

    #[test]
    fn test_mod_zero_division() {
        let left = Int32Array::from(vec![Some(10), Some(7), Some(15)]);
        let right = Int32Array::from(vec![Some(0), Some(2), Some(4)]);

        let result = spark_mod(&[array_arg(left), array_arg(right)]);
        assert!(result.is_err()); // Division by zero should error
    }

    // PMOD tests
    #[test]
    fn test_pmod_int32() {
        let left = Int32Array::from(vec![Some(10), Some(-7), Some(15), Some(-15), None]);
        let right = Int32Array::from(vec![Some(3), Some(3), Some(4), Some(4), Some(5)]);

        let output = eval_array(spark_pmod, array_arg(left), array_arg(right));
        let values = as_typed::<Int32Array>(&output);
        assert_eq!(values.value(0), 1); // 10 pmod 3 = 1
        assert_eq!(values.value(1), 2); // -7 pmod 3 = 2 (positive remainder)
        assert_eq!(values.value(2), 3); // 15 pmod 4 = 3
        assert_eq!(values.value(3), 1); // -15 pmod 4 = 1 (positive remainder)
        assert!(values.is_null(4)); // None pmod 5 = None
    }

    #[test]
    fn test_pmod_int64() {
        let left = Int64Array::from(vec![Some(100), Some(-50), Some(200), Some(-200)]);
        let right = Int64Array::from(vec![Some(30), Some(30), Some(60), Some(60)]);

        let output = eval_array(spark_pmod, array_arg(left), array_arg(right));
        let values = as_typed::<Int64Array>(&output);
        assert_eq!(values.value(0), 10); // 100 pmod 30 = 10
        assert_eq!(values.value(1), 10); // -50 pmod 30 = 10 (positive remainder)
        assert_eq!(values.value(2), 20); // 200 pmod 60 = 20
        assert_eq!(values.value(3), 40); // -200 pmod 60 = 40 (positive remainder)
    }

    #[test]
    fn test_pmod_float64() {
        let left = Float64Array::from(vec![
            Some(10.5),
            Some(-7.2),
            Some(15.8),
            Some(-15.8),
            Some(f64::NAN),
            Some(f64::INFINITY),
            Some(5.0),
            Some(-5.0),
        ]);
        let right = Float64Array::from(vec![
            Some(3.0),
            Some(3.0),
            Some(4.2),
            Some(4.2),
            Some(2.0),
            Some(2.0),
            Some(f64::INFINITY),
            Some(f64::INFINITY),
        ]);

        let output = eval_array(spark_pmod, array_arg(left), array_arg(right));
        let values = as_typed::<Float64Array>(&output);
        // Regular cases
        assert_near_f64(values.value(0), 1.5, f64::EPSILON); // 10.5 pmod 3.0 = 1.5
        assert_near_f64(values.value(1), 1.8, f64::EPSILON * 3.0); // -7.2 pmod 3.0 = 1.8 (positive)
        assert_near_f64(values.value(2), 3.2, f64::EPSILON * 3.0); // 15.8 pmod 4.2 = 3.2
        assert_near_f64(values.value(3), 1.0, f64::EPSILON * 3.0); // -15.8 pmod 4.2 = 1.0 (positive)
        // nan pmod 2.0 = nan
        assert!(values.value(4).is_nan());
        // inf pmod 2.0 = nan (IEEE 754)
        assert!(values.value(5).is_nan());
        // 5.0 pmod inf = 5.0
        assert_near_f64(values.value(6), 5.0, f64::EPSILON);
        // -5.0 pmod inf = NaN
        assert!(values.value(7).is_nan());
    }

    #[test]
    fn test_pmod_float32() {
        let left = Float32Array::from(vec![
            Some(10.5),
            Some(-7.2),
            Some(15.8),
            Some(-15.8),
            Some(f32::NAN),
            Some(f32::INFINITY),
            Some(5.0),
            Some(-5.0),
        ]);
        let right = Float32Array::from(vec![
            Some(3.0),
            Some(3.0),
            Some(4.2),
            Some(4.2),
            Some(2.0),
            Some(2.0),
            Some(f32::INFINITY),
            Some(f32::INFINITY),
        ]);

        let output = eval_array(spark_pmod, array_arg(left), array_arg(right));
        let values = as_typed::<Float32Array>(&output);
        // Regular cases
        assert_near_f32(values.value(0), 1.5, f32::EPSILON); // 10.5 pmod 3.0 = 1.5
        assert_near_f32(values.value(1), 1.8, f32::EPSILON * 3.0); // -7.2 pmod 3.0 = 1.8 (positive)
        assert_near_f32(values.value(2), 3.2, f32::EPSILON * 10.0); // 15.8 pmod 4.2 = 3.2
        assert_near_f32(values.value(3), 1.0, f32::EPSILON * 10.0); // -15.8 pmod 4.2 = 1.0 (positive)
        // nan pmod 2.0 = nan
        assert!(values.value(4).is_nan());
        // inf pmod 2.0 = nan (IEEE 754)
        assert!(values.value(5).is_nan());
        // 5.0 pmod inf = 5.0
        assert_near_f32(values.value(6), 5.0, f32::EPSILON * 10.0);
        // -5.0 pmod inf = NaN
        assert!(values.value(7).is_nan());
    }

    #[test]
    fn test_pmod_scalar() {
        let left = Int32Array::from(vec![Some(10), Some(-7), Some(15), Some(-15)]);
        let right = ColumnarValue::Scalar(ScalarValue::Int32(Some(3)));

        let output = eval_array(spark_pmod, array_arg(left), right);
        let values = as_typed::<Int32Array>(&output);
        assert_eq!(values.value(0), 1); // 10 pmod 3 = 1
        assert_eq!(values.value(1), 2); // -7 pmod 3 = 2 (positive remainder)
        assert_eq!(values.value(2), 0); // 15 pmod 3 = 0
        assert_eq!(values.value(3), 0); // -15 pmod 3 = 0 (positive remainder)
    }

    #[test]
    fn test_pmod_wrong_arg_count() {
        let left = array_arg(Int32Array::from(vec![Some(10)]));
        assert!(spark_pmod(&[left]).is_err());
    }

    #[test]
    fn test_pmod_zero_division() {
        let left = Int32Array::from(vec![Some(10), Some(-7), Some(15)]);
        let right = Int32Array::from(vec![Some(0), Some(0), Some(4)]);

        let result = spark_pmod(&[array_arg(left), array_arg(right)]);
        assert!(result.is_err()); // Division by zero should error
    }

    #[test]
    fn test_pmod_null_divisor() {
        // A NULL divisor must yield NULL even though the intermediate
        // `lt`/`zip` steps see a NULL mask entry.
        let left = Int32Array::from(vec![Some(-7), Some(7)]);
        let right = Int32Array::from(vec![None, Some(3)]);

        let output = eval_array(spark_pmod, array_arg(left), array_arg(right));
        let values = as_typed::<Int32Array>(&output);
        assert!(values.is_null(0)); // -7 pmod NULL = NULL
        assert_eq!(values.value(1), 1); // 7 pmod 3 = 1
    }

    #[test]
    fn test_pmod_negative_divisor() {
        // PMOD with negative divisor should still work like regular mod
        let left = Int32Array::from(vec![Some(10), Some(-7), Some(15)]);
        let right = Int32Array::from(vec![Some(-3), Some(-3), Some(-4)]);

        let output = eval_array(spark_pmod, array_arg(left), array_arg(right));
        let values = as_typed::<Int32Array>(&output);
        assert_eq!(values.value(0), 1); // 10 pmod -3 = 1
        assert_eq!(values.value(1), -1); // -7 pmod -3 = -1
        assert_eq!(values.value(2), 3); // 15 pmod -4 = 3
    }

    #[test]
    fn test_pmod_edge_cases() {
        // Test edge cases for PMOD
        let left = Int32Array::from(vec![
            Some(0),  // 0 pmod 5 = 0
            Some(-1), // -1 pmod 5 = 4
            Some(1),  // 1 pmod 5 = 1
            Some(-5), // -5 pmod 5 = 0
            Some(5),  // 5 pmod 5 = 0
            Some(-6), // -6 pmod 5 = 4
            Some(6),  // 6 pmod 5 = 1
        ]);
        let right = Int32Array::from(vec![Some(5); 7]);

        let output = eval_array(spark_pmod, array_arg(left), array_arg(right));
        let values = as_typed::<Int32Array>(&output);
        assert_eq!(values.value(0), 0); // 0 pmod 5 = 0
        assert_eq!(values.value(1), 4); // -1 pmod 5 = 4
        assert_eq!(values.value(2), 1); // 1 pmod 5 = 1
        assert_eq!(values.value(3), 0); // -5 pmod 5 = 0
        assert_eq!(values.value(4), 0); // 5 pmod 5 = 0
        assert_eq!(values.value(5), 4); // -6 pmod 5 = 4
        assert_eq!(values.value(6), 1); // 6 pmod 5 = 1
    }
}