1use arrow::array::{Scalar, new_null_array};
19use arrow::compute::kernels::numeric::add;
20use arrow::compute::kernels::{
21 cmp::{eq, lt},
22 numeric::rem,
23 zip::zip,
24};
25use arrow::datatypes::DataType;
26use datafusion_common::{Result, ScalarValue, assert_eq_or_internal_err};
27use datafusion_expr::{
28 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
29};
30use std::any::Any;
31
32fn try_rem(
37 left: &arrow::array::ArrayRef,
38 right: &arrow::array::ArrayRef,
39 enable_ansi_mode: bool,
40) -> Result<arrow::array::ArrayRef> {
41 match rem(left, right) {
42 Ok(result) => Ok(result),
43 Err(arrow::error::ArrowError::DivideByZero) if !enable_ansi_mode => {
44 let zero = ScalarValue::new_zero(right.data_type())?.to_array()?;
47 let zero = Scalar::new(zero);
48 let null = Scalar::new(new_null_array(right.data_type(), 1));
49 let is_zero = eq(right, &zero)?;
50 let safe_right = zip(&is_zero, &null, right)?;
51 Ok(rem(left, &safe_right)?)
52 }
53 Err(e) => Err(e.into()),
54 }
55}
56
57pub fn spark_mod(
61 args: &[ColumnarValue],
62 enable_ansi_mode: bool,
63) -> Result<ColumnarValue> {
64 assert_eq_or_internal_err!(args.len(), 2, "mod expects exactly two arguments");
65 let args = ColumnarValue::values_to_arrays(args)?;
66 let result = try_rem(&args[0], &args[1], enable_ansi_mode)?;
67 Ok(ColumnarValue::Array(result))
68}
69
70pub fn spark_pmod(
74 args: &[ColumnarValue],
75 enable_ansi_mode: bool,
76) -> Result<ColumnarValue> {
77 assert_eq_or_internal_err!(args.len(), 2, "pmod expects exactly two arguments");
78 let args = ColumnarValue::values_to_arrays(args)?;
79 let left = &args[0];
80 let right = &args[1];
81 let zero = ScalarValue::new_zero(left.data_type())?.to_array_of_size(left.len())?;
82 let result = try_rem(left, right, enable_ansi_mode)?;
83 let neg = lt(&result, &zero)?;
84 let plus = zip(&neg, right, &zero)?;
85 let result = add(&plus, &result)?;
86 let result = try_rem(&result, right, enable_ansi_mode)?;
87 Ok(ColumnarValue::Array(result))
88}
89
90#[derive(Debug, PartialEq, Eq, Hash)]
92pub struct SparkMod {
93 signature: Signature,
94}
95
96impl Default for SparkMod {
97 fn default() -> Self {
98 Self::new()
99 }
100}
101
102impl SparkMod {
103 pub fn new() -> Self {
104 Self {
105 signature: Signature::numeric(2, Volatility::Immutable),
106 }
107 }
108}
109
110impl ScalarUDFImpl for SparkMod {
111 fn as_any(&self) -> &dyn Any {
112 self
113 }
114
115 fn name(&self) -> &str {
116 "mod"
117 }
118
119 fn signature(&self) -> &Signature {
120 &self.signature
121 }
122
123 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
124 assert_eq_or_internal_err!(
125 arg_types.len(),
126 2,
127 "mod expects exactly two arguments"
128 );
129
130 Ok(arg_types[0].clone())
133 }
134
135 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
136 spark_mod(&args.args, args.config_options.execution.enable_ansi_mode)
137 }
138}
139
140#[derive(Debug, PartialEq, Eq, Hash)]
142pub struct SparkPmod {
143 signature: Signature,
144}
145
146impl Default for SparkPmod {
147 fn default() -> Self {
148 Self::new()
149 }
150}
151
152impl SparkPmod {
153 pub fn new() -> Self {
154 Self {
155 signature: Signature::numeric(2, Volatility::Immutable),
156 }
157 }
158}
159
160impl ScalarUDFImpl for SparkPmod {
161 fn as_any(&self) -> &dyn Any {
162 self
163 }
164
165 fn name(&self) -> &str {
166 "pmod"
167 }
168
169 fn signature(&self) -> &Signature {
170 &self.signature
171 }
172
173 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
174 assert_eq_or_internal_err!(
175 arg_types.len(),
176 2,
177 "pmod expects exactly two arguments"
178 );
179
180 Ok(arg_types[0].clone())
183 }
184
185 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
186 spark_pmod(&args.args, args.config_options.execution.enable_ansi_mode)
187 }
188}
189
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use super::*;
    use arrow::array::*;
    use datafusion_common::ScalarValue;

    /// Wraps a concrete arrow array as an array-valued function argument.
    fn column<A: Array + 'static>(array: A) -> ColumnarValue {
        ColumnarValue::Array(Arc::new(array))
    }

    /// Extracts the array from a function result, panicking on a scalar.
    fn expect_array(value: ColumnarValue) -> ArrayRef {
        match value {
            ColumnarValue::Array(array) => array,
            ColumnarValue::Scalar(_) => panic!("Expected array result"),
        }
    }

    /// Downcasts a dynamically typed array to the concrete array type `A`.
    fn typed<A: Array + 'static>(array: &ArrayRef) -> &A {
        array.as_any().downcast_ref::<A>().unwrap()
    }

    /// Runs `spark_mod` in non-ANSI mode and returns the resulting array.
    fn legacy_mod(left: ColumnarValue, right: ColumnarValue) -> ArrayRef {
        expect_array(spark_mod(&[left, right], false).unwrap())
    }

    /// Runs `spark_pmod` in non-ANSI mode and returns the resulting array.
    fn legacy_pmod(left: ColumnarValue, right: ColumnarValue) -> ArrayRef {
        expect_array(spark_pmod(&[left, right], false).unwrap())
    }

    /// Asserts that two `f64` values differ by less than `tolerance`.
    fn assert_close_f64(actual: f64, expected: f64, tolerance: f64) {
        assert!((actual - expected).abs() < tolerance);
    }

    /// Asserts that two `f32` values differ by less than `tolerance`.
    fn assert_close_f32(actual: f32, expected: f32, tolerance: f32) {
        assert!((actual - expected).abs() < tolerance);
    }

    #[test]
    fn test_mod_int32() {
        let dividend = Int32Array::from(vec![Some(10), Some(7), Some(15), None]);
        let divisor = Int32Array::from(vec![Some(3), Some(2), Some(4), Some(5)]);

        let output = legacy_mod(column(dividend), column(divisor));
        let actual = typed::<Int32Array>(&output);

        assert_eq!(actual.value(0), 1);
        assert_eq!(actual.value(1), 1);
        assert_eq!(actual.value(2), 3);
        // A NULL dividend yields a NULL result.
        assert!(actual.is_null(3));
    }

    #[test]
    fn test_mod_int64() {
        let dividend = Int64Array::from(vec![Some(100), Some(50), Some(200)]);
        let divisor = Int64Array::from(vec![Some(30), Some(25), Some(60)]);

        let output = legacy_mod(column(dividend), column(divisor));
        let actual = typed::<Int64Array>(&output);

        assert_eq!(actual.value(0), 10);
        assert_eq!(actual.value(1), 0);
        assert_eq!(actual.value(2), 20);
    }

    #[test]
    fn test_mod_float64() {
        let dividend = Float64Array::from(vec![
            Some(10.5),
            Some(7.2),
            Some(15.8),
            Some(f64::NAN),
            Some(f64::INFINITY),
            Some(5.0),
            Some(5.0),
            Some(f64::NAN),
            Some(f64::INFINITY),
        ]);
        let divisor = Float64Array::from(vec![
            Some(3.0),
            Some(2.5),
            Some(4.2),
            Some(2.0),
            Some(2.0),
            Some(f64::NAN),
            Some(f64::INFINITY),
            Some(f64::INFINITY),
            Some(f64::NAN),
        ]);

        let output = legacy_mod(column(dividend), column(divisor));
        let actual = typed::<Float64Array>(&output);

        assert_close_f64(actual.value(0), 1.5, f64::EPSILON);
        assert_close_f64(actual.value(1), 2.2, f64::EPSILON);
        assert_close_f64(actual.value(2), 3.2, f64::EPSILON);
        // NaN % x, inf % x and x % NaN are all NaN.
        for index in 3..=5 {
            assert!(actual.value(index).is_nan());
        }
        // finite % inf is the dividend itself.
        assert_close_f64(actual.value(6), 5.0, f64::EPSILON);
        // NaN % inf and inf % NaN are NaN.
        for index in 7..=8 {
            assert!(actual.value(index).is_nan());
        }
    }

    #[test]
    fn test_mod_float32() {
        let dividend = Float32Array::from(vec![
            Some(10.5),
            Some(7.2),
            Some(15.8),
            Some(f32::NAN),
            Some(f32::INFINITY),
            Some(5.0),
            Some(5.0),
            Some(f32::NAN),
            Some(f32::INFINITY),
        ]);
        let divisor = Float32Array::from(vec![
            Some(3.0),
            Some(2.5),
            Some(4.2),
            Some(2.0),
            Some(2.0),
            Some(f32::NAN),
            Some(f32::INFINITY),
            Some(f32::INFINITY),
            Some(f32::NAN),
        ]);

        let output = legacy_mod(column(dividend), column(divisor));
        let actual = typed::<Float32Array>(&output);

        assert_close_f32(actual.value(0), 1.5, f32::EPSILON);
        assert_close_f32(actual.value(1), 2.2, f32::EPSILON * 3.0);
        assert_close_f32(actual.value(2), 3.2, f32::EPSILON * 10.0);
        // NaN % x, inf % x and x % NaN are all NaN.
        for index in 3..=5 {
            assert!(actual.value(index).is_nan());
        }
        // finite % inf is the dividend itself.
        assert_close_f32(actual.value(6), 5.0, f32::EPSILON);
        // NaN % inf and inf % NaN are NaN.
        for index in 7..=8 {
            assert!(actual.value(index).is_nan());
        }
    }

    #[test]
    fn test_mod_scalar() {
        let dividend = Int32Array::from(vec![Some(10), Some(7), Some(15)]);
        let divisor = ColumnarValue::Scalar(ScalarValue::Int32(Some(3)));

        let output = legacy_mod(column(dividend), divisor);
        let actual = typed::<Int32Array>(&output);

        assert_eq!(actual.value(0), 1);
        assert_eq!(actual.value(1), 1);
        assert_eq!(actual.value(2), 0);
    }

    #[test]
    fn test_mod_wrong_arg_count() {
        let only_argument = column(Int32Array::from(vec![Some(10)]));

        let outcome = spark_mod(&[only_argument], false);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_mod_zero_division_legacy() {
        let dividend = Int32Array::from(vec![Some(10), Some(7), Some(15)]);
        let divisor = Int32Array::from(vec![Some(0), Some(2), Some(4)]);

        let output = legacy_mod(column(dividend), column(divisor));
        let actual = typed::<Int32Array>(&output);

        // Non-ANSI mode turns a zero divisor into NULL.
        assert!(actual.is_null(0));
        assert_eq!(actual.value(1), 1);
        assert_eq!(actual.value(2), 3);
    }

    #[test]
    fn test_mod_zero_division_ansi() {
        let dividend = Int32Array::from(vec![Some(10), Some(7), Some(15)]);
        let divisor = Int32Array::from(vec![Some(0), Some(2), Some(4)]);

        // ANSI mode reports the zero divisor as an error.
        let outcome = spark_mod(&[column(dividend), column(divisor)], true);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_pmod_int32() {
        let dividend =
            Int32Array::from(vec![Some(10), Some(-7), Some(15), Some(-15), None]);
        let divisor =
            Int32Array::from(vec![Some(3), Some(3), Some(4), Some(4), Some(5)]);

        let output = legacy_pmod(column(dividend), column(divisor));
        let actual = typed::<Int32Array>(&output);

        assert_eq!(actual.value(0), 1);
        assert_eq!(actual.value(1), 2);
        assert_eq!(actual.value(2), 3);
        assert_eq!(actual.value(3), 1);
        // A NULL dividend yields a NULL result.
        assert!(actual.is_null(4));
    }

    #[test]
    fn test_pmod_int64() {
        let dividend =
            Int64Array::from(vec![Some(100), Some(-50), Some(200), Some(-200)]);
        let divisor = Int64Array::from(vec![Some(30), Some(30), Some(60), Some(60)]);

        let output = legacy_pmod(column(dividend), column(divisor));
        let actual = typed::<Int64Array>(&output);

        assert_eq!(actual.value(0), 10);
        assert_eq!(actual.value(1), 10);
        assert_eq!(actual.value(2), 20);
        assert_eq!(actual.value(3), 40);
    }

    #[test]
    fn test_pmod_float64() {
        let dividend = Float64Array::from(vec![
            Some(10.5),
            Some(-7.2),
            Some(15.8),
            Some(-15.8),
            Some(f64::NAN),
            Some(f64::INFINITY),
            Some(5.0),
            Some(-5.0),
        ]);
        let divisor = Float64Array::from(vec![
            Some(3.0),
            Some(3.0),
            Some(4.2),
            Some(4.2),
            Some(2.0),
            Some(2.0),
            Some(f64::INFINITY),
            Some(f64::INFINITY),
        ]);

        let output = legacy_pmod(column(dividend), column(divisor));
        let actual = typed::<Float64Array>(&output);

        assert_close_f64(actual.value(0), 1.5, f64::EPSILON);
        assert_close_f64(actual.value(1), 1.8, f64::EPSILON * 3.0);
        assert_close_f64(actual.value(2), 3.2, f64::EPSILON * 3.0);
        assert_close_f64(actual.value(3), 1.0, f64::EPSILON * 3.0);
        // NaN and infinite dividends produce NaN.
        assert!(actual.value(4).is_nan());
        assert!(actual.value(5).is_nan());
        // A positive finite dividend with an infinite divisor is unchanged.
        assert_close_f64(actual.value(6), 5.0, f64::EPSILON);
        // A negative finite dividend with an infinite divisor produces NaN.
        assert!(actual.value(7).is_nan());
    }

    #[test]
    fn test_pmod_float32() {
        let dividend = Float32Array::from(vec![
            Some(10.5),
            Some(-7.2),
            Some(15.8),
            Some(-15.8),
            Some(f32::NAN),
            Some(f32::INFINITY),
            Some(5.0),
            Some(-5.0),
        ]);
        let divisor = Float32Array::from(vec![
            Some(3.0),
            Some(3.0),
            Some(4.2),
            Some(4.2),
            Some(2.0),
            Some(2.0),
            Some(f32::INFINITY),
            Some(f32::INFINITY),
        ]);

        let output = legacy_pmod(column(dividend), column(divisor));
        let actual = typed::<Float32Array>(&output);

        assert_close_f32(actual.value(0), 1.5, f32::EPSILON);
        assert_close_f32(actual.value(1), 1.8, f32::EPSILON * 3.0);
        assert_close_f32(actual.value(2), 3.2, f32::EPSILON * 10.0);
        assert_close_f32(actual.value(3), 1.0, f32::EPSILON * 10.0);
        // NaN and infinite dividends produce NaN.
        assert!(actual.value(4).is_nan());
        assert!(actual.value(5).is_nan());
        // A positive finite dividend with an infinite divisor is unchanged.
        assert_close_f32(actual.value(6), 5.0, f32::EPSILON * 10.0);
        // A negative finite dividend with an infinite divisor produces NaN.
        assert!(actual.value(7).is_nan());
    }

    #[test]
    fn test_pmod_scalar() {
        let dividend = Int32Array::from(vec![Some(10), Some(-7), Some(15), Some(-15)]);
        let divisor = ColumnarValue::Scalar(ScalarValue::Int32(Some(3)));

        let output = legacy_pmod(column(dividend), divisor);
        let actual = typed::<Int32Array>(&output);

        assert_eq!(actual.value(0), 1);
        assert_eq!(actual.value(1), 2);
        assert_eq!(actual.value(2), 0);
        assert_eq!(actual.value(3), 0);
    }

    #[test]
    fn test_pmod_wrong_arg_count() {
        let only_argument = column(Int32Array::from(vec![Some(10)]));

        let outcome = spark_pmod(&[only_argument], false);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_pmod_zero_division_legacy() {
        let dividend = Int32Array::from(vec![Some(10), Some(-7), Some(15)]);
        let divisor = Int32Array::from(vec![Some(0), Some(0), Some(4)]);

        let output = legacy_pmod(column(dividend), column(divisor));
        let actual = typed::<Int32Array>(&output);

        // Non-ANSI mode turns a zero divisor into NULL.
        assert!(actual.is_null(0));
        assert!(actual.is_null(1));
        assert_eq!(actual.value(2), 3);
    }

    #[test]
    fn test_pmod_zero_division_ansi() {
        let dividend = Int32Array::from(vec![Some(10), Some(-7), Some(15)]);
        let divisor = Int32Array::from(vec![Some(0), Some(0), Some(4)]);

        // ANSI mode reports the zero divisor as an error.
        let outcome = spark_pmod(&[column(dividend), column(divisor)], true);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_pmod_negative_divisor() {
        let dividend = Int32Array::from(vec![Some(10), Some(-7), Some(15)]);
        let divisor = Int32Array::from(vec![Some(-3), Some(-3), Some(-4)]);

        let output = legacy_pmod(column(dividend), column(divisor));
        let actual = typed::<Int32Array>(&output);

        assert_eq!(actual.value(0), 1);
        // With a negative divisor the result may stay negative.
        assert_eq!(actual.value(1), -1);
        assert_eq!(actual.value(2), 3);
    }

    #[test]
    fn test_pmod_edge_cases() {
        let dividend = Int32Array::from(vec![
            Some(0),
            Some(-1),
            Some(1),
            Some(-5),
            Some(5),
            Some(-6),
            Some(6),
        ]);
        let divisor = Int32Array::from(vec![Some(5); 7]);

        let output = legacy_pmod(column(dividend), column(divisor));
        let actual = typed::<Int32Array>(&output);

        assert_eq!(actual.value(0), 0);
        assert_eq!(actual.value(1), 4);
        assert_eq!(actual.value(2), 1);
        assert_eq!(actual.value(3), 0);
        assert_eq!(actual.value(4), 0);
        assert_eq!(actual.value(5), 4);
        assert_eq!(actual.value(6), 1);
    }
}