1use arrow::compute::kernels::numeric::add;
19use arrow::compute::kernels::{cmp::lt, numeric::rem, zip::zip};
20use arrow::datatypes::DataType;
21use datafusion_common::{DataFusionError, Result, ScalarValue};
22use datafusion_expr::{
23 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
24};
25use std::any::Any;
26
27pub fn spark_mod(args: &[ColumnarValue]) -> Result<ColumnarValue> {
30 if args.len() != 2 {
31 return Err(DataFusionError::Internal(
32 "mod expects exactly two arguments".to_string(),
33 ));
34 }
35 let args = ColumnarValue::values_to_arrays(args)?;
36 let result = rem(&args[0], &args[1])?;
37 Ok(ColumnarValue::Array(result))
38}
39
40pub fn spark_pmod(args: &[ColumnarValue]) -> Result<ColumnarValue> {
43 if args.len() != 2 {
44 return Err(DataFusionError::Internal(
45 "pmod expects exactly two arguments".to_string(),
46 ));
47 }
48 let args = ColumnarValue::values_to_arrays(args)?;
49 let left = &args[0];
50 let right = &args[1];
51 let zero = ScalarValue::new_zero(left.data_type())?.to_array_of_size(left.len())?;
52 let result = rem(left, right)?;
53 let neg = lt(&result, &zero)?;
54 let plus = zip(&neg, right, &zero)?;
55 let result = add(&plus, &result)?;
56 let result = rem(&result, right)?;
57 Ok(ColumnarValue::Array(result))
58}
59
60#[derive(Debug, PartialEq, Eq, Hash)]
62pub struct SparkMod {
63 signature: Signature,
64}
65
66impl Default for SparkMod {
67 fn default() -> Self {
68 Self::new()
69 }
70}
71
72impl SparkMod {
73 pub fn new() -> Self {
74 Self {
75 signature: Signature::numeric(2, Volatility::Immutable),
76 }
77 }
78}
79
80impl ScalarUDFImpl for SparkMod {
81 fn as_any(&self) -> &dyn Any {
82 self
83 }
84
85 fn name(&self) -> &str {
86 "mod"
87 }
88
89 fn signature(&self) -> &Signature {
90 &self.signature
91 }
92
93 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
94 if arg_types.len() != 2 {
95 return Err(DataFusionError::Internal(
96 "mod expects exactly two arguments".to_string(),
97 ));
98 }
99
100 Ok(arg_types[0].clone())
103 }
104
105 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
106 spark_mod(&args.args)
107 }
108}
109
110#[derive(Debug, PartialEq, Eq, Hash)]
112pub struct SparkPmod {
113 signature: Signature,
114}
115
116impl Default for SparkPmod {
117 fn default() -> Self {
118 Self::new()
119 }
120}
121
122impl SparkPmod {
123 pub fn new() -> Self {
124 Self {
125 signature: Signature::numeric(2, Volatility::Immutable),
126 }
127 }
128}
129
130impl ScalarUDFImpl for SparkPmod {
131 fn as_any(&self) -> &dyn Any {
132 self
133 }
134
135 fn name(&self) -> &str {
136 "pmod"
137 }
138
139 fn signature(&self) -> &Signature {
140 &self.signature
141 }
142
143 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
144 if arg_types.len() != 2 {
145 return Err(DataFusionError::Internal(
146 "pmod expects exactly two arguments".to_string(),
147 ));
148 }
149
150 Ok(arg_types[0].clone())
153 }
154
155 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
156 spark_pmod(&args.args)
157 }
158}
159
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use super::*;
    use arrow::array::*;
    use datafusion_common::ScalarValue;

    /// Wraps a concrete arrow array as an array-valued `ColumnarValue`.
    fn column<A: Array + 'static>(values: A) -> ColumnarValue {
        ColumnarValue::Array(Arc::new(values))
    }

    /// Unwraps the array from a kernel result, panicking if it is a scalar.
    fn into_array(value: ColumnarValue) -> ArrayRef {
        match value {
            ColumnarValue::Array(array) => array,
            _ => panic!("Expected array result"),
        }
    }

    // ---- mod ----------------------------------------------------------

    #[test]
    fn test_mod_int32() {
        let dividend = column(Int32Array::from(vec![Some(10), Some(7), Some(15), None]));
        let divisor = column(Int32Array::from(vec![Some(3), Some(2), Some(4), Some(5)]));

        let output = into_array(spark_mod(&[dividend, divisor]).unwrap());
        let remainders = output.as_any().downcast_ref::<Int32Array>().unwrap();

        for &(row, expected) in &[(0, 1), (1, 1), (2, 3)] {
            assert_eq!(remainders.value(row), expected);
        }
        // A null dividend yields a null remainder.
        assert!(remainders.is_null(3));
    }

    #[test]
    fn test_mod_int64() {
        let dividend = column(Int64Array::from(vec![Some(100), Some(50), Some(200)]));
        let divisor = column(Int64Array::from(vec![Some(30), Some(25), Some(60)]));

        let output = into_array(spark_mod(&[dividend, divisor]).unwrap());
        let remainders = output.as_any().downcast_ref::<Int64Array>().unwrap();

        for &(row, expected) in &[(0, 10), (1, 0), (2, 20)] {
            assert_eq!(remainders.value(row), expected);
        }
    }

    #[test]
    fn test_mod_float64() {
        let dividend = column(Float64Array::from(vec![
            Some(10.5),
            Some(7.2),
            Some(15.8),
            Some(f64::NAN),
            Some(f64::INFINITY),
            Some(5.0),
            Some(5.0),
            Some(f64::NAN),
            Some(f64::INFINITY),
        ]));
        let divisor = column(Float64Array::from(vec![
            Some(3.0),
            Some(2.5),
            Some(4.2),
            Some(2.0),
            Some(2.0),
            Some(f64::NAN),
            Some(f64::INFINITY),
            Some(f64::INFINITY),
            Some(f64::NAN),
        ]));

        let output = into_array(spark_mod(&[dividend, divisor]).unwrap());
        let remainders = output.as_any().downcast_ref::<Float64Array>().unwrap();

        // (row, expected value, absolute tolerance)
        let finite: [(usize, f64, f64); 4] = [
            (0, 1.5, f64::EPSILON),
            (1, 2.2, f64::EPSILON),
            (2, 3.2, f64::EPSILON),
            (6, 5.0, f64::EPSILON),
        ];
        for &(row, expected, tolerance) in &finite {
            assert!((remainders.value(row) - expected).abs() < tolerance);
        }
        // Rows involving NaN, or an infinite dividend, produce NaN.
        for row in [3, 4, 5, 7, 8] {
            assert!(remainders.value(row).is_nan());
        }
    }

    #[test]
    fn test_mod_float32() {
        let dividend = column(Float32Array::from(vec![
            Some(10.5),
            Some(7.2),
            Some(15.8),
            Some(f32::NAN),
            Some(f32::INFINITY),
            Some(5.0),
            Some(5.0),
            Some(f32::NAN),
            Some(f32::INFINITY),
        ]));
        let divisor = column(Float32Array::from(vec![
            Some(3.0),
            Some(2.5),
            Some(4.2),
            Some(2.0),
            Some(2.0),
            Some(f32::NAN),
            Some(f32::INFINITY),
            Some(f32::INFINITY),
            Some(f32::NAN),
        ]));

        let output = into_array(spark_mod(&[dividend, divisor]).unwrap());
        let remainders = output.as_any().downcast_ref::<Float32Array>().unwrap();

        // (row, expected value, absolute tolerance)
        let finite: [(usize, f32, f32); 4] = [
            (0, 1.5, f32::EPSILON),
            (1, 2.2, f32::EPSILON * 3.0),
            (2, 3.2, f32::EPSILON * 10.0),
            (6, 5.0, f32::EPSILON),
        ];
        for &(row, expected, tolerance) in &finite {
            assert!((remainders.value(row) - expected).abs() < tolerance);
        }
        // Rows involving NaN, or an infinite dividend, produce NaN.
        for row in [3, 4, 5, 7, 8] {
            assert!(remainders.value(row).is_nan());
        }
    }

    #[test]
    fn test_mod_scalar() {
        let dividend = column(Int32Array::from(vec![Some(10), Some(7), Some(15)]));
        let divisor = ColumnarValue::Scalar(ScalarValue::Int32(Some(3)));

        let output = into_array(spark_mod(&[dividend, divisor]).unwrap());
        let remainders = output.as_any().downcast_ref::<Int32Array>().unwrap();

        for &(row, expected) in &[(0, 1), (1, 1), (2, 0)] {
            assert_eq!(remainders.value(row), expected);
        }
    }

    #[test]
    fn test_mod_wrong_arg_count() {
        let only_arg = column(Int32Array::from(vec![Some(10)]));

        assert!(spark_mod(&[only_arg]).is_err());
    }

    #[test]
    fn test_mod_zero_division() {
        let dividend = column(Int32Array::from(vec![Some(10), Some(7), Some(15)]));
        let divisor = column(Int32Array::from(vec![Some(0), Some(2), Some(4)]));

        // An integer zero divisor is surfaced as an error.
        assert!(spark_mod(&[dividend, divisor]).is_err());
    }

    // ---- pmod ---------------------------------------------------------

    #[test]
    fn test_pmod_int32() {
        let dividend = column(Int32Array::from(vec![
            Some(10),
            Some(-7),
            Some(15),
            Some(-15),
            None,
        ]));
        let divisor = column(Int32Array::from(vec![
            Some(3),
            Some(3),
            Some(4),
            Some(4),
            Some(5),
        ]));

        let output = into_array(spark_pmod(&[dividend, divisor]).unwrap());
        let remainders = output.as_any().downcast_ref::<Int32Array>().unwrap();

        for &(row, expected) in &[(0, 1), (1, 2), (2, 3), (3, 1)] {
            assert_eq!(remainders.value(row), expected);
        }
        // A null dividend yields a null remainder.
        assert!(remainders.is_null(4));
    }

    #[test]
    fn test_pmod_int64() {
        let dividend = column(Int64Array::from(vec![
            Some(100),
            Some(-50),
            Some(200),
            Some(-200),
        ]));
        let divisor = column(Int64Array::from(vec![
            Some(30),
            Some(30),
            Some(60),
            Some(60),
        ]));

        let output = into_array(spark_pmod(&[dividend, divisor]).unwrap());
        let remainders = output.as_any().downcast_ref::<Int64Array>().unwrap();

        for &(row, expected) in &[(0, 10), (1, 10), (2, 20), (3, 40)] {
            assert_eq!(remainders.value(row), expected);
        }
    }

    #[test]
    fn test_pmod_float64() {
        let dividend = column(Float64Array::from(vec![
            Some(10.5),
            Some(-7.2),
            Some(15.8),
            Some(-15.8),
            Some(f64::NAN),
            Some(f64::INFINITY),
            Some(5.0),
            Some(-5.0),
        ]));
        let divisor = column(Float64Array::from(vec![
            Some(3.0),
            Some(3.0),
            Some(4.2),
            Some(4.2),
            Some(2.0),
            Some(2.0),
            Some(f64::INFINITY),
            Some(f64::INFINITY),
        ]));

        let output = into_array(spark_pmod(&[dividend, divisor]).unwrap());
        let remainders = output.as_any().downcast_ref::<Float64Array>().unwrap();

        // (row, expected value, absolute tolerance)
        let finite: [(usize, f64, f64); 5] = [
            (0, 1.5, f64::EPSILON),
            (1, 1.8, f64::EPSILON * 3.0),
            (2, 3.2, f64::EPSILON * 3.0),
            (3, 1.0, f64::EPSILON * 3.0),
            (6, 5.0, f64::EPSILON),
        ];
        for &(row, expected, tolerance) in &finite {
            assert!((remainders.value(row) - expected).abs() < tolerance);
        }
        // NaN / infinite dividends, and a negative dividend with an infinite
        // divisor, all end up as NaN.
        for row in [4, 5, 7] {
            assert!(remainders.value(row).is_nan());
        }
    }

    #[test]
    fn test_pmod_float32() {
        let dividend = column(Float32Array::from(vec![
            Some(10.5),
            Some(-7.2),
            Some(15.8),
            Some(-15.8),
            Some(f32::NAN),
            Some(f32::INFINITY),
            Some(5.0),
            Some(-5.0),
        ]));
        let divisor = column(Float32Array::from(vec![
            Some(3.0),
            Some(3.0),
            Some(4.2),
            Some(4.2),
            Some(2.0),
            Some(2.0),
            Some(f32::INFINITY),
            Some(f32::INFINITY),
        ]));

        let output = into_array(spark_pmod(&[dividend, divisor]).unwrap());
        let remainders = output.as_any().downcast_ref::<Float32Array>().unwrap();

        // (row, expected value, absolute tolerance)
        let finite: [(usize, f32, f32); 5] = [
            (0, 1.5, f32::EPSILON),
            (1, 1.8, f32::EPSILON * 3.0),
            (2, 3.2, f32::EPSILON * 10.0),
            (3, 1.0, f32::EPSILON * 10.0),
            (6, 5.0, f32::EPSILON * 10.0),
        ];
        for &(row, expected, tolerance) in &finite {
            assert!((remainders.value(row) - expected).abs() < tolerance);
        }
        // NaN / infinite dividends, and a negative dividend with an infinite
        // divisor, all end up as NaN.
        for row in [4, 5, 7] {
            assert!(remainders.value(row).is_nan());
        }
    }

    #[test]
    fn test_pmod_scalar() {
        let dividend = column(Int32Array::from(vec![
            Some(10),
            Some(-7),
            Some(15),
            Some(-15),
        ]));
        let divisor = ColumnarValue::Scalar(ScalarValue::Int32(Some(3)));

        let output = into_array(spark_pmod(&[dividend, divisor]).unwrap());
        let remainders = output.as_any().downcast_ref::<Int32Array>().unwrap();

        for &(row, expected) in &[(0, 1), (1, 2), (2, 0), (3, 0)] {
            assert_eq!(remainders.value(row), expected);
        }
    }

    #[test]
    fn test_pmod_wrong_arg_count() {
        let only_arg = column(Int32Array::from(vec![Some(10)]));

        assert!(spark_pmod(&[only_arg]).is_err());
    }

    #[test]
    fn test_pmod_zero_division() {
        let dividend = column(Int32Array::from(vec![Some(10), Some(-7), Some(15)]));
        let divisor = column(Int32Array::from(vec![Some(0), Some(0), Some(4)]));

        // An integer zero divisor is surfaced as an error.
        assert!(spark_pmod(&[dividend, divisor]).is_err());
    }

    #[test]
    fn test_pmod_negative_divisor() {
        let dividend = column(Int32Array::from(vec![Some(10), Some(-7), Some(15)]));
        let divisor = column(Int32Array::from(vec![Some(-3), Some(-3), Some(-4)]));

        let output = into_array(spark_pmod(&[dividend, divisor]).unwrap());
        let remainders = output.as_any().downcast_ref::<Int32Array>().unwrap();

        // With a negative divisor the result may itself be negative.
        for &(row, expected) in &[(0, 1), (1, -1), (2, 3)] {
            assert_eq!(remainders.value(row), expected);
        }
    }

    #[test]
    fn test_pmod_edge_cases() {
        let dividend = column(Int32Array::from(vec![
            Some(0),
            Some(-1),
            Some(1),
            Some(-5),
            Some(5),
            Some(-6),
            Some(6),
        ]));
        let divisor = column(Int32Array::from(vec![Some(5); 7]));

        let output = into_array(spark_pmod(&[dividend, divisor]).unwrap());
        let remainders = output.as_any().downcast_ref::<Int32Array>().unwrap();

        let expected_rows = [(0, 0), (1, 4), (2, 1), (3, 0), (4, 0), (5, 4), (6, 1)];
        for &(row, expected) in &expected_rows {
            assert_eq!(remainders.value(row), expected);
        }
    }
}