1use arrow::array::{Scalar, new_null_array};
19use arrow::compute::kernels::numeric::add;
20use arrow::compute::kernels::{
21 cmp::{eq, lt},
22 numeric::rem,
23 zip::zip,
24};
25use arrow::datatypes::DataType;
26use datafusion_common::{Result, ScalarValue, assert_eq_or_internal_err};
27use datafusion_expr::{
28 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
29};
30
31fn try_rem(
37 left: &arrow::array::ArrayRef,
38 right: &arrow::array::ArrayRef,
39 enable_ansi_mode: bool,
40) -> Result<arrow::array::ArrayRef> {
41 if enable_ansi_mode {
42 Ok(rem(left, right)?)
43 } else {
44 let zero = ScalarValue::new_zero(right.data_type())?.to_array()?;
47 let zero = Scalar::new(zero);
48 let null = Scalar::new(new_null_array(right.data_type(), 1));
49 let is_zero = eq(right, &zero)?;
50 let safe_right = zip(&is_zero, &null, right)?;
51 Ok(rem(left, &safe_right)?)
52 }
53}
54
55pub fn spark_mod(
59 args: &[ColumnarValue],
60 enable_ansi_mode: bool,
61) -> Result<ColumnarValue> {
62 assert_eq_or_internal_err!(args.len(), 2, "mod expects exactly two arguments");
63 let args = ColumnarValue::values_to_arrays(args)?;
64 let result = try_rem(&args[0], &args[1], enable_ansi_mode)?;
65 Ok(ColumnarValue::Array(result))
66}
67
68pub fn spark_pmod(
72 args: &[ColumnarValue],
73 enable_ansi_mode: bool,
74) -> Result<ColumnarValue> {
75 assert_eq_or_internal_err!(args.len(), 2, "pmod expects exactly two arguments");
76 let args = ColumnarValue::values_to_arrays(args)?;
77 let left = &args[0];
78 let right = &args[1];
79 let zero = ScalarValue::new_zero(left.data_type())?.to_array_of_size(left.len())?;
80 let result = try_rem(left, right, enable_ansi_mode)?;
81 let neg = lt(&result, &zero)?;
82 let plus = zip(&neg, right, &zero)?;
83 let result = add(&plus, &result)?;
84 let result = try_rem(&result, right, enable_ansi_mode)?;
85 Ok(ColumnarValue::Array(result))
86}
87
88#[derive(Debug, PartialEq, Eq, Hash)]
90pub struct SparkMod {
91 signature: Signature,
92}
93
94impl Default for SparkMod {
95 fn default() -> Self {
96 Self::new()
97 }
98}
99
100impl SparkMod {
101 pub fn new() -> Self {
102 Self {
103 signature: Signature::numeric(2, Volatility::Immutable),
104 }
105 }
106}
107
108impl ScalarUDFImpl for SparkMod {
109 fn name(&self) -> &str {
110 "mod"
111 }
112
113 fn signature(&self) -> &Signature {
114 &self.signature
115 }
116
117 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
118 assert_eq_or_internal_err!(
119 arg_types.len(),
120 2,
121 "mod expects exactly two arguments"
122 );
123
124 Ok(arg_types[0].clone())
127 }
128
129 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
130 spark_mod(&args.args, args.config_options.execution.enable_ansi_mode)
131 }
132}
133
134#[derive(Debug, PartialEq, Eq, Hash)]
136pub struct SparkPmod {
137 signature: Signature,
138}
139
140impl Default for SparkPmod {
141 fn default() -> Self {
142 Self::new()
143 }
144}
145
146impl SparkPmod {
147 pub fn new() -> Self {
148 Self {
149 signature: Signature::numeric(2, Volatility::Immutable),
150 }
151 }
152}
153
154impl ScalarUDFImpl for SparkPmod {
155 fn name(&self) -> &str {
156 "pmod"
157 }
158
159 fn signature(&self) -> &Signature {
160 &self.signature
161 }
162
163 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
164 assert_eq_or_internal_err!(
165 arg_types.len(),
166 2,
167 "pmod expects exactly two arguments"
168 );
169
170 Ok(arg_types[0].clone())
173 }
174
175 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
176 spark_pmod(&args.args, args.config_options.execution.enable_ansi_mode)
177 }
178}
179
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use super::*;
    use arrow::array::*;
    use datafusion_common::ScalarValue;

    /// Asserts that `$actual` lies strictly within `$tolerance` of `$expected`.
    macro_rules! assert_close {
        ($actual:expr, $expected:expr, $tolerance:expr) => {
            assert!(($actual - $expected).abs() < $tolerance)
        };
    }

    /// Wraps a concrete Arrow array as an array-valued function argument.
    fn array_arg<A: Array + 'static>(array: A) -> ColumnarValue {
        ColumnarValue::Array(Arc::new(array))
    }

    /// Extracts the array from a result, panicking on anything else.
    fn into_array(value: ColumnarValue) -> ArrayRef {
        match value {
            ColumnarValue::Array(array) => array,
            _ => panic!("Expected array result"),
        }
    }

    /// Downcasts a dynamically typed array to the concrete array type `A`.
    fn typed<A: Array + 'static>(array: &ArrayRef) -> &A {
        array.as_any().downcast_ref::<A>().unwrap()
    }

    /// Runs `spark_mod` with ANSI mode disabled and returns the result array.
    fn eval_mod(left: ColumnarValue, right: ColumnarValue) -> ArrayRef {
        into_array(spark_mod(&[left, right], false).unwrap())
    }

    /// Runs `spark_pmod` with ANSI mode disabled and returns the result array.
    fn eval_pmod(left: ColumnarValue, right: ColumnarValue) -> ArrayRef {
        into_array(spark_pmod(&[left, right], false).unwrap())
    }

    #[test]
    fn test_mod_int32() {
        let dividend =
            array_arg(Int32Array::from(vec![Some(10), Some(7), Some(15), None]));
        let divisor =
            array_arg(Int32Array::from(vec![Some(3), Some(2), Some(4), Some(5)]));

        let output = eval_mod(dividend, divisor);
        let values: &Int32Array = typed(&output);

        for (idx, expected) in [1, 1, 3].iter().enumerate() {
            assert_eq!(values.value(idx), *expected);
        }
        // A NULL dividend yields a NULL result.
        assert!(values.is_null(3));
    }

    #[test]
    fn test_mod_int64() {
        let dividend = array_arg(Int64Array::from(vec![Some(100), Some(50), Some(200)]));
        let divisor = array_arg(Int64Array::from(vec![Some(30), Some(25), Some(60)]));

        let output = eval_mod(dividend, divisor);
        let values: &Int64Array = typed(&output);

        for (idx, expected) in [10, 0, 20].iter().enumerate() {
            assert_eq!(values.value(idx), *expected);
        }
    }

    #[test]
    fn test_mod_float64() {
        let dividend = array_arg(Float64Array::from(vec![
            Some(10.5),
            Some(7.2),
            Some(15.8),
            Some(f64::NAN),
            Some(f64::INFINITY),
            Some(5.0),
            Some(5.0),
            Some(f64::NAN),
            Some(f64::INFINITY),
            Some(10.5),
            Some(15.8),
        ]));
        let divisor = array_arg(Float64Array::from(vec![
            Some(3.0),
            Some(2.5),
            Some(4.2),
            Some(2.0),
            Some(2.0),
            Some(f64::NAN),
            Some(f64::INFINITY),
            Some(f64::INFINITY),
            Some(f64::NAN),
            Some(0.0),
            Some(0.0),
        ]));

        let output = eval_mod(dividend, divisor);
        let values: &Float64Array = typed(&output);

        // Ordinary finite operands.
        assert_close!(values.value(0), 1.5, f64::EPSILON);
        assert_close!(values.value(1), 2.2, f64::EPSILON);
        assert_close!(values.value(2), 3.2, f64::EPSILON);
        // NaN % 2, inf % 2 and 5 % NaN are all NaN.
        for idx in 3..=5 {
            assert!(values.value(idx).is_nan());
        }
        // 5 % inf leaves the dividend untouched.
        assert_close!(values.value(6), 5.0, f64::EPSILON);
        // NaN % inf and inf % NaN are NaN.
        for idx in 7..=8 {
            assert!(values.value(idx).is_nan());
        }
        // A zero divisor produces NULL when ANSI mode is off.
        for idx in 9..=10 {
            assert!(values.is_null(idx));
        }
    }

    #[test]
    fn test_mod_float32() {
        let dividend = array_arg(Float32Array::from(vec![
            Some(10.5),
            Some(7.2),
            Some(15.8),
            Some(f32::NAN),
            Some(f32::INFINITY),
            Some(5.0),
            Some(5.0),
            Some(f32::NAN),
            Some(f32::INFINITY),
            Some(10.5),
            Some(15.8),
        ]));
        let divisor = array_arg(Float32Array::from(vec![
            Some(3.0),
            Some(2.5),
            Some(4.2),
            Some(2.0),
            Some(2.0),
            Some(f32::NAN),
            Some(f32::INFINITY),
            Some(f32::INFINITY),
            Some(f32::NAN),
            Some(0.0),
            Some(0.0),
        ]));

        let output = eval_mod(dividend, divisor);
        let values: &Float32Array = typed(&output);

        // Ordinary finite operands, with tolerances widened for f32 rounding.
        assert_close!(values.value(0), 1.5, f32::EPSILON);
        assert_close!(values.value(1), 2.2, f32::EPSILON * 3.0);
        assert_close!(values.value(2), 3.2, f32::EPSILON * 10.0);
        // NaN % 2, inf % 2 and 5 % NaN are all NaN.
        for idx in 3..=5 {
            assert!(values.value(idx).is_nan());
        }
        // 5 % inf leaves the dividend untouched.
        assert_close!(values.value(6), 5.0, f32::EPSILON);
        // NaN % inf and inf % NaN are NaN.
        for idx in 7..=8 {
            assert!(values.value(idx).is_nan());
        }
        // A zero divisor produces NULL when ANSI mode is off.
        for idx in 9..=10 {
            assert!(values.is_null(idx));
        }
    }

    #[test]
    fn test_mod_scalar() {
        let dividend = array_arg(Int32Array::from(vec![Some(10), Some(7), Some(15)]));
        let divisor = ColumnarValue::Scalar(ScalarValue::Int32(Some(3)));

        let output = eval_mod(dividend, divisor);
        let values: &Int32Array = typed(&output);

        for (idx, expected) in [1, 1, 0].iter().enumerate() {
            assert_eq!(values.value(idx), *expected);
        }
    }

    #[test]
    fn test_mod_wrong_arg_count() {
        let only_arg = array_arg(Int32Array::from(vec![Some(10)]));

        let outcome = spark_mod(&[only_arg], false);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_mod_zero_division_legacy() {
        let dividend = array_arg(Int32Array::from(vec![Some(10), Some(7), Some(15)]));
        let divisor = array_arg(Int32Array::from(vec![Some(0), Some(2), Some(4)]));

        let output = eval_mod(dividend, divisor);
        let values: &Int32Array = typed(&output);

        // With ANSI mode off, division by zero yields NULL rather than an error.
        assert!(values.is_null(0));
        assert_eq!(values.value(1), 1);
        assert_eq!(values.value(2), 3);
    }

    #[test]
    fn test_mod_zero_division_ansi() {
        let dividend = array_arg(Int32Array::from(vec![Some(10), Some(7), Some(15)]));
        let divisor = array_arg(Int32Array::from(vec![Some(0), Some(2), Some(4)]));

        // With ANSI mode on, division by zero is an error.
        let outcome = spark_mod(&[dividend, divisor], true);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_pmod_int32() {
        let dividend = array_arg(Int32Array::from(vec![
            Some(10),
            Some(-7),
            Some(15),
            Some(-15),
            None,
        ]));
        let divisor = array_arg(Int32Array::from(vec![
            Some(3),
            Some(3),
            Some(4),
            Some(4),
            Some(5),
        ]));

        let output = eval_pmod(dividend, divisor);
        let values: &Int32Array = typed(&output);

        for (idx, expected) in [1, 2, 3, 1].iter().enumerate() {
            assert_eq!(values.value(idx), *expected);
        }
        // A NULL dividend yields a NULL result.
        assert!(values.is_null(4));
    }

    #[test]
    fn test_pmod_int64() {
        let dividend = array_arg(Int64Array::from(vec![
            Some(100),
            Some(-50),
            Some(200),
            Some(-200),
        ]));
        let divisor =
            array_arg(Int64Array::from(vec![Some(30), Some(30), Some(60), Some(60)]));

        let output = eval_pmod(dividend, divisor);
        let values: &Int64Array = typed(&output);

        for (idx, expected) in [10, 10, 20, 40].iter().enumerate() {
            assert_eq!(values.value(idx), *expected);
        }
    }

    #[test]
    fn test_pmod_float64() {
        let dividend = array_arg(Float64Array::from(vec![
            Some(10.5),
            Some(-7.2),
            Some(15.8),
            Some(-15.8),
            Some(f64::NAN),
            Some(f64::INFINITY),
            Some(5.0),
            Some(-5.0),
            Some(10.5),
            Some(-7.2),
        ]));
        let divisor = array_arg(Float64Array::from(vec![
            Some(3.0),
            Some(3.0),
            Some(4.2),
            Some(4.2),
            Some(2.0),
            Some(2.0),
            Some(f64::INFINITY),
            Some(f64::INFINITY),
            Some(0.0),
            Some(0.0),
        ]));

        let output = eval_pmod(dividend, divisor);
        let values: &Float64Array = typed(&output);

        // Finite operands, including negative dividends.
        assert_close!(values.value(0), 1.5, f64::EPSILON);
        assert_close!(values.value(1), 1.8, f64::EPSILON * 3.0);
        assert_close!(values.value(2), 3.2, f64::EPSILON * 3.0);
        assert_close!(values.value(3), 1.0, f64::EPSILON * 3.0);
        // NaN and infinite dividends give NaN.
        for idx in 4..=5 {
            assert!(values.value(idx).is_nan());
        }
        // 5 pmod inf keeps the dividend; -5 pmod inf is NaN.
        assert_close!(values.value(6), 5.0, f64::EPSILON);
        assert!(values.value(7).is_nan());
        // A zero divisor produces NULL when ANSI mode is off.
        for idx in 8..=9 {
            assert!(values.is_null(idx));
        }
    }

    #[test]
    fn test_pmod_float32() {
        let dividend = array_arg(Float32Array::from(vec![
            Some(10.5),
            Some(-7.2),
            Some(15.8),
            Some(-15.8),
            Some(f32::NAN),
            Some(f32::INFINITY),
            Some(5.0),
            Some(-5.0),
            Some(10.5),
            Some(-7.2),
        ]));
        let divisor = array_arg(Float32Array::from(vec![
            Some(3.0),
            Some(3.0),
            Some(4.2),
            Some(4.2),
            Some(2.0),
            Some(2.0),
            Some(f32::INFINITY),
            Some(f32::INFINITY),
            Some(0.0),
            Some(0.0),
        ]));

        let output = eval_pmod(dividend, divisor);
        let values: &Float32Array = typed(&output);

        // Finite operands, with tolerances widened for f32 rounding.
        assert_close!(values.value(0), 1.5, f32::EPSILON);
        assert_close!(values.value(1), 1.8, f32::EPSILON * 3.0);
        assert_close!(values.value(2), 3.2, f32::EPSILON * 10.0);
        assert_close!(values.value(3), 1.0, f32::EPSILON * 10.0);
        // NaN and infinite dividends give NaN.
        for idx in 4..=5 {
            assert!(values.value(idx).is_nan());
        }
        // 5 pmod inf keeps the dividend; -5 pmod inf is NaN.
        assert_close!(values.value(6), 5.0, f32::EPSILON * 10.0);
        assert!(values.value(7).is_nan());
        // A zero divisor produces NULL when ANSI mode is off.
        for idx in 8..=9 {
            assert!(values.is_null(idx));
        }
    }

    #[test]
    fn test_pmod_scalar() {
        let dividend = array_arg(Int32Array::from(vec![
            Some(10),
            Some(-7),
            Some(15),
            Some(-15),
        ]));
        let divisor = ColumnarValue::Scalar(ScalarValue::Int32(Some(3)));

        let output = eval_pmod(dividend, divisor);
        let values: &Int32Array = typed(&output);

        for (idx, expected) in [1, 2, 0, 0].iter().enumerate() {
            assert_eq!(values.value(idx), *expected);
        }
    }

    #[test]
    fn test_pmod_wrong_arg_count() {
        let only_arg = array_arg(Int32Array::from(vec![Some(10)]));

        let outcome = spark_pmod(&[only_arg], false);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_pmod_zero_division_legacy() {
        let dividend = array_arg(Int32Array::from(vec![Some(10), Some(-7), Some(15)]));
        let divisor = array_arg(Int32Array::from(vec![Some(0), Some(0), Some(4)]));

        let output = eval_pmod(dividend, divisor);
        let values: &Int32Array = typed(&output);

        // With ANSI mode off, division by zero yields NULL rather than an error.
        for idx in 0..=1 {
            assert!(values.is_null(idx));
        }
        assert_eq!(values.value(2), 3);
    }

    #[test]
    fn test_pmod_zero_division_ansi() {
        let dividend = array_arg(Int32Array::from(vec![Some(10), Some(-7), Some(15)]));
        let divisor = array_arg(Int32Array::from(vec![Some(0), Some(0), Some(4)]));

        // With ANSI mode on, division by zero is an error.
        let outcome = spark_pmod(&[dividend, divisor], true);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_pmod_negative_divisor() {
        let dividend = array_arg(Int32Array::from(vec![Some(10), Some(-7), Some(15)]));
        let divisor = array_arg(Int32Array::from(vec![Some(-3), Some(-3), Some(-4)]));

        let output = eval_pmod(dividend, divisor);
        let values: &Int32Array = typed(&output);

        // With a negative divisor the result may itself be negative.
        for (idx, expected) in [1, -1, 3].iter().enumerate() {
            assert_eq!(values.value(idx), *expected);
        }
    }

    #[test]
    fn test_pmod_edge_cases() {
        let dividend = array_arg(Int32Array::from(vec![
            Some(0),
            Some(-1),
            Some(1),
            Some(-5),
            Some(5),
            Some(-6),
            Some(6),
        ]));
        // The same divisor, 5, for every row.
        let divisor = array_arg(Int32Array::from(vec![Some(5); 7]));

        let output = eval_pmod(dividend, divisor);
        let values: &Int32Array = typed(&output);

        for (idx, expected) in [0, 4, 1, 0, 0, 4, 1].iter().enumerate() {
            assert_eq!(values.value(idx), *expected);
        }
    }
}