datafusion_spark/function/math/modulus.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::compute::kernels::numeric::add;
use arrow::compute::kernels::{cmp::lt, numeric::rem, zip::zip};
use arrow::datatypes::DataType;
use datafusion_common::{internal_err, Result, ScalarValue};
use datafusion_expr::{
    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};
use std::any::Any;

/// Spark-compatible `mod` function
/// Delegates directly to Arrow's `rem` kernel for the modulo operation
pub fn spark_mod(args: &[ColumnarValue]) -> Result<ColumnarValue> {
    if args.len() != 2 {
        return internal_err!("mod expects exactly two arguments");
    }
    let args = ColumnarValue::values_to_arrays(args)?;
    let result = rem(&args[0], &args[1])?;
    Ok(ColumnarValue::Array(result))
}

/// Spark-compatible `pmod` function
/// Computes the remainder with Arrow's `rem` kernel, then adjusts negative
/// remainders by adding the divisor and taking the remainder again, so a
/// positive divisor always yields a non-negative result (Spark's `pmod` semantics)
pub fn spark_pmod(args: &[ColumnarValue]) -> Result<ColumnarValue> {
    if args.len() != 2 {
        return internal_err!("pmod expects exactly two arguments");
    }
    let args = ColumnarValue::values_to_arrays(args)?;
    let left = &args[0];
    let right = &args[1];
    let zero = ScalarValue::new_zero(left.data_type())?.to_array_of_size(left.len())?;
    // r = left % right (may be negative when the operands' signs differ)
    let result = rem(left, right)?;
    // Mark the rows where the remainder came out negative
    let neg = lt(&result, &zero)?;
    // Add the divisor only to those rows, then reduce once more with `rem`
    let plus = zip(&neg, right, &zero)?;
    let result = add(&plus, &result)?;
    let result = rem(&result, right)?;
    Ok(ColumnarValue::Array(result))
}
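
// The kernel pipeline above is the array form of the following scalar
// formulation (an illustrative sketch only, not part of the public API):
//
//     fn pmod(a: i32, n: i32) -> i32 {
//         let r = a % n;
//         if r < 0 { (r + n) % n } else { r }
//     }
//
// For example, pmod(-7, 3) == 2, whereas -7 % 3 == -1.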

/// SparkMod implements the Spark-compatible modulo function
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkMod {
    signature: Signature,
}

impl Default for SparkMod {
    fn default() -> Self {
        Self::new()
    }
}

impl SparkMod {
    pub fn new() -> Self {
        Self {
            signature: Signature::numeric(2, Volatility::Immutable),
        }
    }
}

impl ScalarUDFImpl for SparkMod {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "mod"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        if arg_types.len() != 2 {
            return internal_err!("mod expects exactly two arguments");
        }

        // Return the same type as the first argument for simplicity
        // Arrow's rem function handles type promotion internally
        Ok(arg_types[0].clone())
    }

    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        spark_mod(&args.args)
    }
}

/// SparkPmod implements the Spark-compatible positive modulo (`pmod`) function
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkPmod {
    signature: Signature,
}

impl Default for SparkPmod {
    fn default() -> Self {
        Self::new()
    }
}

impl SparkPmod {
    pub fn new() -> Self {
        Self {
            signature: Signature::numeric(2, Volatility::Immutable),
        }
    }
}

impl ScalarUDFImpl for SparkPmod {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "pmod"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        if arg_types.len() != 2 {
            return internal_err!("pmod expects exactly two arguments");
        }

        // Return the same type as the first argument for simplicity
        // Arrow's rem function handles type promotion internally
        Ok(arg_types[0].clone())
    }

    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        spark_pmod(&args.args)
    }
}
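
// A minimal usage sketch. It assumes the top-level `datafusion` crate (which
// provides `SessionContext`) is available to the caller; this module only
// depends on `datafusion_expr`, so the snippet is kept as a comment rather
// than compiled here:
//
//     use datafusion_expr::ScalarUDF;
//
//     let ctx = datafusion::prelude::SessionContext::new();
//     ctx.register_udf(ScalarUDF::new_from_impl(SparkMod::new()));
//     ctx.register_udf(ScalarUDF::new_from_impl(SparkPmod::new()));
//     // e.g. `SELECT pmod(-7, 3)` is then expected to return 2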

#[cfg(test)]
mod test {
    use std::sync::Arc;

    use super::*;
    use arrow::array::*;
    use datafusion_common::ScalarValue;

    #[test]
    fn test_mod_int32() {
        let left = Int32Array::from(vec![Some(10), Some(7), Some(15), None]);
        let right = Int32Array::from(vec![Some(3), Some(2), Some(4), Some(5)]);

        let left_value = ColumnarValue::Array(Arc::new(left));
        let right_value = ColumnarValue::Array(Arc::new(right));

        let result = spark_mod(&[left_value, right_value]).unwrap();

        if let ColumnarValue::Array(result_array) = result {
            let result_int32 =
                result_array.as_any().downcast_ref::<Int32Array>().unwrap();
            assert_eq!(result_int32.value(0), 1); // 10 % 3 = 1
            assert_eq!(result_int32.value(1), 1); // 7 % 2 = 1
            assert_eq!(result_int32.value(2), 3); // 15 % 4 = 3
            assert!(result_int32.is_null(3)); // None % 5 = None
        } else {
            panic!("Expected array result");
        }
    }

    #[test]
    fn test_mod_int64() {
        let left = Int64Array::from(vec![Some(100), Some(50), Some(200)]);
        let right = Int64Array::from(vec![Some(30), Some(25), Some(60)]);

        let left_value = ColumnarValue::Array(Arc::new(left));
        let right_value = ColumnarValue::Array(Arc::new(right));

        let result = spark_mod(&[left_value, right_value]).unwrap();

        if let ColumnarValue::Array(result_array) = result {
            let result_int64 =
                result_array.as_any().downcast_ref::<Int64Array>().unwrap();
            assert_eq!(result_int64.value(0), 10); // 100 % 30 = 10
            assert_eq!(result_int64.value(1), 0); // 50 % 25 = 0
            assert_eq!(result_int64.value(2), 20); // 200 % 60 = 20
        } else {
            panic!("Expected array result");
        }
    }

    #[test]
    fn test_mod_float64() {
        let left = Float64Array::from(vec![
            Some(10.5),
            Some(7.2),
            Some(15.8),
            Some(f64::NAN),
            Some(f64::INFINITY),
            Some(5.0),
            Some(5.0),
            Some(f64::NAN),
            Some(f64::INFINITY),
        ]);
        let right = Float64Array::from(vec![
            Some(3.0),
            Some(2.5),
            Some(4.2),
            Some(2.0),
            Some(2.0),
            Some(f64::NAN),
            Some(f64::INFINITY),
            Some(f64::INFINITY),
            Some(f64::NAN),
        ]);

        let left_value = ColumnarValue::Array(Arc::new(left));
        let right_value = ColumnarValue::Array(Arc::new(right));

        let result = spark_mod(&[left_value, right_value]).unwrap();

        if let ColumnarValue::Array(result_array) = result {
            let result_float64 = result_array
                .as_any()
                .downcast_ref::<Float64Array>()
                .unwrap();
            // Regular cases
            assert!((result_float64.value(0) - 1.5).abs() < f64::EPSILON); // 10.5 % 3.0 = 1.5
            assert!((result_float64.value(1) - 2.2).abs() < f64::EPSILON); // 7.2 % 2.5 = 2.2
            assert!((result_float64.value(2) - 3.2).abs() < f64::EPSILON); // 15.8 % 4.2 = 3.2
            // nan % 2.0 = nan
            assert!(result_float64.value(3).is_nan());
            // inf % 2.0 = nan (IEEE 754)
            assert!(result_float64.value(4).is_nan());
            // 5.0 % nan = nan
            assert!(result_float64.value(5).is_nan());
            // 5.0 % inf = 5.0
            assert!((result_float64.value(6) - 5.0).abs() < f64::EPSILON);
            // nan % inf = nan
            assert!(result_float64.value(7).is_nan());
            // inf % nan = nan
            assert!(result_float64.value(8).is_nan());
        } else {
            panic!("Expected array result");
        }
    }

    #[test]
    fn test_mod_float32() {
        let left = Float32Array::from(vec![
            Some(10.5),
            Some(7.2),
            Some(15.8),
            Some(f32::NAN),
            Some(f32::INFINITY),
            Some(5.0),
            Some(5.0),
            Some(f32::NAN),
            Some(f32::INFINITY),
        ]);
        let right = Float32Array::from(vec![
            Some(3.0),
            Some(2.5),
            Some(4.2),
            Some(2.0),
            Some(2.0),
            Some(f32::NAN),
            Some(f32::INFINITY),
            Some(f32::INFINITY),
            Some(f32::NAN),
        ]);

        let left_value = ColumnarValue::Array(Arc::new(left));
        let right_value = ColumnarValue::Array(Arc::new(right));

        let result = spark_mod(&[left_value, right_value]).unwrap();

        if let ColumnarValue::Array(result_array) = result {
            let result_float32 = result_array
                .as_any()
                .downcast_ref::<Float32Array>()
                .unwrap();
            // Regular cases
            assert!((result_float32.value(0) - 1.5).abs() < f32::EPSILON); // 10.5 % 3.0 = 1.5
            assert!((result_float32.value(1) - 2.2).abs() < f32::EPSILON * 3.0); // 7.2 % 2.5 = 2.2
            assert!((result_float32.value(2) - 3.2).abs() < f32::EPSILON * 10.0); // 15.8 % 4.2 = 3.2
            // nan % 2.0 = nan
            assert!(result_float32.value(3).is_nan());
            // inf % 2.0 = nan (IEEE 754)
            assert!(result_float32.value(4).is_nan());
            // 5.0 % nan = nan
            assert!(result_float32.value(5).is_nan());
            // 5.0 % inf = 5.0
            assert!((result_float32.value(6) - 5.0).abs() < f32::EPSILON);
            // nan % inf = nan
            assert!(result_float32.value(7).is_nan());
            // inf % nan = nan
            assert!(result_float32.value(8).is_nan());
        } else {
            panic!("Expected array result");
        }
    }

    #[test]
    fn test_mod_scalar() {
        let left = Int32Array::from(vec![Some(10), Some(7), Some(15)]);
        let right_value = ColumnarValue::Scalar(ScalarValue::Int32(Some(3)));

        let left_value = ColumnarValue::Array(Arc::new(left));

        let result = spark_mod(&[left_value, right_value]).unwrap();

        if let ColumnarValue::Array(result_array) = result {
            let result_int32 =
                result_array.as_any().downcast_ref::<Int32Array>().unwrap();
            assert_eq!(result_int32.value(0), 1); // 10 % 3 = 1
            assert_eq!(result_int32.value(1), 1); // 7 % 3 = 1
            assert_eq!(result_int32.value(2), 0); // 15 % 3 = 0
        } else {
            panic!("Expected array result");
        }
    }

    #[test]
    fn test_mod_wrong_arg_count() {
        let left = Int32Array::from(vec![Some(10)]);
        let left_value = ColumnarValue::Array(Arc::new(left));

        let result = spark_mod(&[left_value]);
        assert!(result.is_err());
    }

    #[test]
    fn test_mod_zero_division() {
        let left = Int32Array::from(vec![Some(10), Some(7), Some(15)]);
        let right = Int32Array::from(vec![Some(0), Some(2), Some(4)]);

        let left_value = ColumnarValue::Array(Arc::new(left));
        let right_value = ColumnarValue::Array(Arc::new(right));

        let result = spark_mod(&[left_value, right_value]);
        assert!(result.is_err()); // Division by zero should error
    }

    // PMOD tests
    #[test]
    fn test_pmod_int32() {
        let left = Int32Array::from(vec![Some(10), Some(-7), Some(15), Some(-15), None]);
        let right = Int32Array::from(vec![Some(3), Some(3), Some(4), Some(4), Some(5)]);

        let left_value = ColumnarValue::Array(Arc::new(left));
        let right_value = ColumnarValue::Array(Arc::new(right));

        let result = spark_pmod(&[left_value, right_value]).unwrap();

        if let ColumnarValue::Array(result_array) = result {
            let result_int32 =
                result_array.as_any().downcast_ref::<Int32Array>().unwrap();
            assert_eq!(result_int32.value(0), 1); // 10 pmod 3 = 1
            assert_eq!(result_int32.value(1), 2); // -7 pmod 3 = 2 (positive remainder)
            assert_eq!(result_int32.value(2), 3); // 15 pmod 4 = 3
            assert_eq!(result_int32.value(3), 1); // -15 pmod 4 = 1 (positive remainder)
            assert!(result_int32.is_null(4)); // None pmod 5 = None
        } else {
            panic!("Expected array result");
        }
    }

    #[test]
    fn test_pmod_int64() {
        let left = Int64Array::from(vec![Some(100), Some(-50), Some(200), Some(-200)]);
        let right = Int64Array::from(vec![Some(30), Some(30), Some(60), Some(60)]);

        let left_value = ColumnarValue::Array(Arc::new(left));
        let right_value = ColumnarValue::Array(Arc::new(right));

        let result = spark_pmod(&[left_value, right_value]).unwrap();

        if let ColumnarValue::Array(result_array) = result {
            let result_int64 =
                result_array.as_any().downcast_ref::<Int64Array>().unwrap();
            assert_eq!(result_int64.value(0), 10); // 100 pmod 30 = 10
            assert_eq!(result_int64.value(1), 10); // -50 pmod 30 = 10 (positive remainder)
            assert_eq!(result_int64.value(2), 20); // 200 pmod 60 = 20
            assert_eq!(result_int64.value(3), 40); // -200 pmod 60 = 40 (positive remainder)
        } else {
            panic!("Expected array result");
        }
    }

    #[test]
    fn test_pmod_float64() {
        let left = Float64Array::from(vec![
            Some(10.5),
            Some(-7.2),
            Some(15.8),
            Some(-15.8),
            Some(f64::NAN),
            Some(f64::INFINITY),
            Some(5.0),
            Some(-5.0),
        ]);
        let right = Float64Array::from(vec![
            Some(3.0),
            Some(3.0),
            Some(4.2),
            Some(4.2),
            Some(2.0),
            Some(2.0),
            Some(f64::INFINITY),
            Some(f64::INFINITY),
        ]);

        let left_value = ColumnarValue::Array(Arc::new(left));
        let right_value = ColumnarValue::Array(Arc::new(right));

        let result = spark_pmod(&[left_value, right_value]).unwrap();

        if let ColumnarValue::Array(result_array) = result {
            let result_float64 = result_array
                .as_any()
                .downcast_ref::<Float64Array>()
                .unwrap();
            // Regular cases
            assert!((result_float64.value(0) - 1.5).abs() < f64::EPSILON); // 10.5 pmod 3.0 = 1.5
            assert!((result_float64.value(1) - 1.8).abs() < f64::EPSILON * 3.0); // -7.2 pmod 3.0 = 1.8 (positive)
            assert!((result_float64.value(2) - 3.2).abs() < f64::EPSILON * 3.0); // 15.8 pmod 4.2 = 3.2
            assert!((result_float64.value(3) - 1.0).abs() < f64::EPSILON * 3.0); // -15.8 pmod 4.2 = 1.0 (positive)
            // nan pmod 2.0 = nan
            assert!(result_float64.value(4).is_nan());
            // inf pmod 2.0 = nan (IEEE 754)
            assert!(result_float64.value(5).is_nan());
            // 5.0 pmod inf = 5.0
            assert!((result_float64.value(6) - 5.0).abs() < f64::EPSILON);
            // -5.0 pmod inf = NaN
            assert!(result_float64.value(7).is_nan());
        } else {
            panic!("Expected array result");
        }
    }

    #[test]
    fn test_pmod_float32() {
        let left = Float32Array::from(vec![
            Some(10.5),
            Some(-7.2),
            Some(15.8),
            Some(-15.8),
            Some(f32::NAN),
            Some(f32::INFINITY),
            Some(5.0),
            Some(-5.0),
        ]);
        let right = Float32Array::from(vec![
            Some(3.0),
            Some(3.0),
            Some(4.2),
            Some(4.2),
            Some(2.0),
            Some(2.0),
            Some(f32::INFINITY),
            Some(f32::INFINITY),
        ]);

        let left_value = ColumnarValue::Array(Arc::new(left));
        let right_value = ColumnarValue::Array(Arc::new(right));

        let result = spark_pmod(&[left_value, right_value]).unwrap();

        if let ColumnarValue::Array(result_array) = result {
            let result_float32 = result_array
                .as_any()
                .downcast_ref::<Float32Array>()
                .unwrap();
            // Regular cases
            assert!((result_float32.value(0) - 1.5).abs() < f32::EPSILON); // 10.5 pmod 3.0 = 1.5
            assert!((result_float32.value(1) - 1.8).abs() < f32::EPSILON * 3.0); // -7.2 pmod 3.0 = 1.8 (positive)
            assert!((result_float32.value(2) - 3.2).abs() < f32::EPSILON * 10.0); // 15.8 pmod 4.2 = 3.2
            assert!((result_float32.value(3) - 1.0).abs() < f32::EPSILON * 10.0); // -15.8 pmod 4.2 = 1.0 (positive)
            // nan pmod 2.0 = nan
            assert!(result_float32.value(4).is_nan());
            // inf pmod 2.0 = nan (IEEE 754)
            assert!(result_float32.value(5).is_nan());
            // 5.0 pmod inf = 5.0
            assert!((result_float32.value(6) - 5.0).abs() < f32::EPSILON * 10.0);
            // -5.0 pmod inf = NaN
            assert!(result_float32.value(7).is_nan());
        } else {
            panic!("Expected array result");
        }
    }

    #[test]
    fn test_pmod_scalar() {
        let left = Int32Array::from(vec![Some(10), Some(-7), Some(15), Some(-15)]);
        let right_value = ColumnarValue::Scalar(ScalarValue::Int32(Some(3)));

        let left_value = ColumnarValue::Array(Arc::new(left));

        let result = spark_pmod(&[left_value, right_value]).unwrap();

        if let ColumnarValue::Array(result_array) = result {
            let result_int32 =
                result_array.as_any().downcast_ref::<Int32Array>().unwrap();
            assert_eq!(result_int32.value(0), 1); // 10 pmod 3 = 1
            assert_eq!(result_int32.value(1), 2); // -7 pmod 3 = 2 (positive remainder)
            assert_eq!(result_int32.value(2), 0); // 15 pmod 3 = 0
            assert_eq!(result_int32.value(3), 0); // -15 pmod 3 = 0 (positive remainder)
        } else {
            panic!("Expected array result");
        }
    }

    #[test]
    fn test_pmod_wrong_arg_count() {
        let left = Int32Array::from(vec![Some(10)]);
        let left_value = ColumnarValue::Array(Arc::new(left));

        let result = spark_pmod(&[left_value]);
        assert!(result.is_err());
    }

    #[test]
    fn test_pmod_zero_division() {
        let left = Int32Array::from(vec![Some(10), Some(-7), Some(15)]);
        let right = Int32Array::from(vec![Some(0), Some(0), Some(4)]);

        let left_value = ColumnarValue::Array(Arc::new(left));
        let right_value = ColumnarValue::Array(Arc::new(right));

        let result = spark_pmod(&[left_value, right_value]);
        assert!(result.is_err()); // Division by zero should error
    }

    #[test]
    fn test_pmod_negative_divisor() {
        // PMOD with negative divisor should still work like regular mod
        let left = Int32Array::from(vec![Some(10), Some(-7), Some(15)]);
        let right = Int32Array::from(vec![Some(-3), Some(-3), Some(-4)]);

        let left_value = ColumnarValue::Array(Arc::new(left));
        let right_value = ColumnarValue::Array(Arc::new(right));

        let result = spark_pmod(&[left_value, right_value]).unwrap();

        if let ColumnarValue::Array(result_array) = result {
            let result_int32 =
                result_array.as_any().downcast_ref::<Int32Array>().unwrap();
            assert_eq!(result_int32.value(0), 1); // 10 pmod -3 = 1
            assert_eq!(result_int32.value(1), -1); // -7 pmod -3 = -1
            assert_eq!(result_int32.value(2), 3); // 15 pmod -4 = 3
        } else {
            panic!("Expected array result");
        }
    }

    #[test]
    fn test_pmod_edge_cases() {
        // Test edge cases for PMOD
        let left = Int32Array::from(vec![
            Some(0),  // 0 pmod 5 = 0
            Some(-1), // -1 pmod 5 = 4
            Some(1),  // 1 pmod 5 = 1
            Some(-5), // -5 pmod 5 = 0
            Some(5),  // 5 pmod 5 = 0
            Some(-6), // -6 pmod 5 = 4
            Some(6),  // 6 pmod 5 = 1
        ]);
        let right = Int32Array::from(vec![
            Some(5),
            Some(5),
            Some(5),
            Some(5),
            Some(5),
            Some(5),
            Some(5),
        ]);

        let left_value = ColumnarValue::Array(Arc::new(left));
        let right_value = ColumnarValue::Array(Arc::new(right));

        let result = spark_pmod(&[left_value, right_value]).unwrap();

        if let ColumnarValue::Array(result_array) = result {
            let result_int32 =
                result_array.as_any().downcast_ref::<Int32Array>().unwrap();
            assert_eq!(result_int32.value(0), 0); // 0 pmod 5 = 0
            assert_eq!(result_int32.value(1), 4); // -1 pmod 5 = 4
            assert_eq!(result_int32.value(2), 1); // 1 pmod 5 = 1
            assert_eq!(result_int32.value(3), 0); // -5 pmod 5 = 0
            assert_eq!(result_int32.value(4), 0); // 5 pmod 5 = 0
            assert_eq!(result_int32.value(5), 4); // -6 pmod 5 = 4
            assert_eq!(result_int32.value(6), 1); // 6 pmod 5 = 1
        } else {
            panic!("Expected array result");
        }
    }
}