1use arrow::compute::kernels::numeric::add;
19use arrow::compute::kernels::{cmp::lt, numeric::rem, zip::zip};
20use arrow::datatypes::DataType;
21use datafusion_common::{internal_err, Result, ScalarValue};
22use datafusion_expr::{
23 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
24};
25use std::any::Any;
26
27pub fn spark_mod(args: &[ColumnarValue]) -> Result<ColumnarValue> {
30 if args.len() != 2 {
31 return internal_err!("mod expects exactly two arguments");
32 }
33 let args = ColumnarValue::values_to_arrays(args)?;
34 let result = rem(&args[0], &args[1])?;
35 Ok(ColumnarValue::Array(result))
36}
37
38pub fn spark_pmod(args: &[ColumnarValue]) -> Result<ColumnarValue> {
41 if args.len() != 2 {
42 return internal_err!("pmod expects exactly two arguments");
43 }
44 let args = ColumnarValue::values_to_arrays(args)?;
45 let left = &args[0];
46 let right = &args[1];
47 let zero = ScalarValue::new_zero(left.data_type())?.to_array_of_size(left.len())?;
48 let result = rem(left, right)?;
49 let neg = lt(&result, &zero)?;
50 let plus = zip(&neg, right, &zero)?;
51 let result = add(&plus, &result)?;
52 let result = rem(&result, right)?;
53 Ok(ColumnarValue::Array(result))
54}
55
56#[derive(Debug, PartialEq, Eq, Hash)]
58pub struct SparkMod {
59 signature: Signature,
60}
61
62impl Default for SparkMod {
63 fn default() -> Self {
64 Self::new()
65 }
66}
67
68impl SparkMod {
69 pub fn new() -> Self {
70 Self {
71 signature: Signature::numeric(2, Volatility::Immutable),
72 }
73 }
74}
75
76impl ScalarUDFImpl for SparkMod {
77 fn as_any(&self) -> &dyn Any {
78 self
79 }
80
81 fn name(&self) -> &str {
82 "mod"
83 }
84
85 fn signature(&self) -> &Signature {
86 &self.signature
87 }
88
89 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
90 if arg_types.len() != 2 {
91 return internal_err!("mod expects exactly two arguments");
92 }
93
94 Ok(arg_types[0].clone())
97 }
98
99 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
100 spark_mod(&args.args)
101 }
102}
103
104#[derive(Debug, PartialEq, Eq, Hash)]
106pub struct SparkPmod {
107 signature: Signature,
108}
109
110impl Default for SparkPmod {
111 fn default() -> Self {
112 Self::new()
113 }
114}
115
116impl SparkPmod {
117 pub fn new() -> Self {
118 Self {
119 signature: Signature::numeric(2, Volatility::Immutable),
120 }
121 }
122}
123
124impl ScalarUDFImpl for SparkPmod {
125 fn as_any(&self) -> &dyn Any {
126 self
127 }
128
129 fn name(&self) -> &str {
130 "pmod"
131 }
132
133 fn signature(&self) -> &Signature {
134 &self.signature
135 }
136
137 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
138 if arg_types.len() != 2 {
139 return internal_err!("pmod expects exactly two arguments");
140 }
141
142 Ok(arg_types[0].clone())
145 }
146
147 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
148 spark_pmod(&args.args)
149 }
150}
151
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use super::*;
    use arrow::array::*;
    use datafusion_common::ScalarValue;

    /// Wraps a concrete Arrow array in a `ColumnarValue::Array` argument.
    fn array_arg(array: impl Array + 'static) -> ColumnarValue {
        ColumnarValue::Array(Arc::new(array))
    }

    /// Calls `func(left, right)`, unwraps the result and returns the output
    /// array. Panics if the call failed or produced a scalar.
    fn eval_array(
        func: fn(&[ColumnarValue]) -> datafusion_common::Result<ColumnarValue>,
        left: ColumnarValue,
        right: ColumnarValue,
    ) -> ArrayRef {
        match func(&[left, right]).unwrap() {
            ColumnarValue::Array(array) => array,
            ColumnarValue::Scalar(_) => panic!("Expected array result"),
        }
    }

    /// Downcasts `array` to the concrete array type `A`, panicking on mismatch.
    fn as_typed<A: Array + 'static>(array: &ArrayRef) -> &A {
        array.as_any().downcast_ref::<A>().unwrap()
    }

    #[test]
    fn test_mod_int32() {
        let result = eval_array(
            spark_mod,
            array_arg(Int32Array::from(vec![Some(10), Some(7), Some(15), None])),
            array_arg(Int32Array::from(vec![Some(3), Some(2), Some(4), Some(5)])),
        );
        let out = as_typed::<Int32Array>(&result);
        assert_eq!(out.value(0), 1);
        assert_eq!(out.value(1), 1);
        assert_eq!(out.value(2), 3);
        assert!(out.is_null(3));
    }

    #[test]
    fn test_mod_int64() {
        let result = eval_array(
            spark_mod,
            array_arg(Int64Array::from(vec![Some(100), Some(50), Some(200)])),
            array_arg(Int64Array::from(vec![Some(30), Some(25), Some(60)])),
        );
        let out = as_typed::<Int64Array>(&result);
        assert_eq!(out.value(0), 10);
        assert_eq!(out.value(1), 0);
        assert_eq!(out.value(2), 20);
    }

    #[test]
    fn test_mod_float64() {
        let result = eval_array(
            spark_mod,
            array_arg(Float64Array::from(vec![
                Some(10.5),
                Some(7.2),
                Some(15.8),
                Some(f64::NAN),
                Some(f64::INFINITY),
                Some(5.0),
                Some(5.0),
                Some(f64::NAN),
                Some(f64::INFINITY),
            ])),
            array_arg(Float64Array::from(vec![
                Some(3.0),
                Some(2.5),
                Some(4.2),
                Some(2.0),
                Some(2.0),
                Some(f64::NAN),
                Some(f64::INFINITY),
                Some(f64::INFINITY),
                Some(f64::NAN),
            ])),
        );
        let out = as_typed::<Float64Array>(&result);
        assert!((out.value(0) - 1.5).abs() < f64::EPSILON);
        assert!((out.value(1) - 2.2).abs() < f64::EPSILON);
        assert!((out.value(2) - 3.2).abs() < f64::EPSILON);
        // A finite value modulo infinity keeps the dividend.
        assert!((out.value(6) - 5.0).abs() < f64::EPSILON);
        // NaN operands, an infinite dividend or a NaN divisor all yield NaN.
        for row in [3, 4, 5, 7, 8] {
            assert!(out.value(row).is_nan(), "row {row} should be NaN");
        }
    }

    #[test]
    fn test_mod_float32() {
        let result = eval_array(
            spark_mod,
            array_arg(Float32Array::from(vec![
                Some(10.5),
                Some(7.2),
                Some(15.8),
                Some(f32::NAN),
                Some(f32::INFINITY),
                Some(5.0),
                Some(5.0),
                Some(f32::NAN),
                Some(f32::INFINITY),
            ])),
            array_arg(Float32Array::from(vec![
                Some(3.0),
                Some(2.5),
                Some(4.2),
                Some(2.0),
                Some(2.0),
                Some(f32::NAN),
                Some(f32::INFINITY),
                Some(f32::INFINITY),
                Some(f32::NAN),
            ])),
        );
        let out = as_typed::<Float32Array>(&result);
        assert!((out.value(0) - 1.5).abs() < f32::EPSILON);
        assert!((out.value(1) - 2.2).abs() < f32::EPSILON * 3.0);
        assert!((out.value(2) - 3.2).abs() < f32::EPSILON * 10.0);
        assert!((out.value(6) - 5.0).abs() < f32::EPSILON);
        for row in [3, 4, 5, 7, 8] {
            assert!(out.value(row).is_nan(), "row {row} should be NaN");
        }
    }

    #[test]
    fn test_mod_scalar() {
        let result = eval_array(
            spark_mod,
            array_arg(Int32Array::from(vec![Some(10), Some(7), Some(15)])),
            ColumnarValue::Scalar(ScalarValue::Int32(Some(3))),
        );
        let out = as_typed::<Int32Array>(&result);
        assert_eq!(out.value(0), 1);
        assert_eq!(out.value(1), 1);
        assert_eq!(out.value(2), 0);
    }

    #[test]
    fn test_mod_wrong_arg_count() {
        let only_arg = array_arg(Int32Array::from(vec![Some(10)]));
        assert!(spark_mod(&[only_arg]).is_err());
    }

    #[test]
    fn test_mod_zero_division() {
        let left = array_arg(Int32Array::from(vec![Some(10), Some(7), Some(15)]));
        let right = array_arg(Int32Array::from(vec![Some(0), Some(2), Some(4)]));
        assert!(spark_mod(&[left, right]).is_err());
    }

    #[test]
    fn test_pmod_int32() {
        let result = eval_array(
            spark_pmod,
            array_arg(Int32Array::from(vec![
                Some(10),
                Some(-7),
                Some(15),
                Some(-15),
                None,
            ])),
            array_arg(Int32Array::from(vec![
                Some(3),
                Some(3),
                Some(4),
                Some(4),
                Some(5),
            ])),
        );
        let out = as_typed::<Int32Array>(&result);
        assert_eq!(out.value(0), 1);
        assert_eq!(out.value(1), 2);
        assert_eq!(out.value(2), 3);
        assert_eq!(out.value(3), 1);
        assert!(out.is_null(4));
    }

    #[test]
    fn test_pmod_int64() {
        let result = eval_array(
            spark_pmod,
            array_arg(Int64Array::from(vec![Some(100), Some(-50), Some(200), Some(-200)])),
            array_arg(Int64Array::from(vec![Some(30), Some(30), Some(60), Some(60)])),
        );
        let out = as_typed::<Int64Array>(&result);
        assert_eq!(out.value(0), 10);
        assert_eq!(out.value(1), 10);
        assert_eq!(out.value(2), 20);
        assert_eq!(out.value(3), 40);
    }

    #[test]
    fn test_pmod_float64() {
        let result = eval_array(
            spark_pmod,
            array_arg(Float64Array::from(vec![
                Some(10.5),
                Some(-7.2),
                Some(15.8),
                Some(-15.8),
                Some(f64::NAN),
                Some(f64::INFINITY),
                Some(5.0),
                Some(-5.0),
            ])),
            array_arg(Float64Array::from(vec![
                Some(3.0),
                Some(3.0),
                Some(4.2),
                Some(4.2),
                Some(2.0),
                Some(2.0),
                Some(f64::INFINITY),
                Some(f64::INFINITY),
            ])),
        );
        let out = as_typed::<Float64Array>(&result);
        assert!((out.value(0) - 1.5).abs() < f64::EPSILON);
        assert!((out.value(1) - 1.8).abs() < f64::EPSILON * 3.0);
        assert!((out.value(2) - 3.2).abs() < f64::EPSILON * 3.0);
        assert!((out.value(3) - 1.0).abs() < f64::EPSILON * 3.0);
        assert!((out.value(6) - 5.0).abs() < f64::EPSILON);
        for row in [4, 5, 7] {
            assert!(out.value(row).is_nan(), "row {row} should be NaN");
        }
    }

    #[test]
    fn test_pmod_float32() {
        let result = eval_array(
            spark_pmod,
            array_arg(Float32Array::from(vec![
                Some(10.5),
                Some(-7.2),
                Some(15.8),
                Some(-15.8),
                Some(f32::NAN),
                Some(f32::INFINITY),
                Some(5.0),
                Some(-5.0),
            ])),
            array_arg(Float32Array::from(vec![
                Some(3.0),
                Some(3.0),
                Some(4.2),
                Some(4.2),
                Some(2.0),
                Some(2.0),
                Some(f32::INFINITY),
                Some(f32::INFINITY),
            ])),
        );
        let out = as_typed::<Float32Array>(&result);
        assert!((out.value(0) - 1.5).abs() < f32::EPSILON);
        assert!((out.value(1) - 1.8).abs() < f32::EPSILON * 3.0);
        assert!((out.value(2) - 3.2).abs() < f32::EPSILON * 10.0);
        assert!((out.value(3) - 1.0).abs() < f32::EPSILON * 10.0);
        assert!((out.value(6) - 5.0).abs() < f32::EPSILON * 10.0);
        for row in [4, 5, 7] {
            assert!(out.value(row).is_nan(), "row {row} should be NaN");
        }
    }

    #[test]
    fn test_pmod_scalar() {
        let result = eval_array(
            spark_pmod,
            array_arg(Int32Array::from(vec![Some(10), Some(-7), Some(15), Some(-15)])),
            ColumnarValue::Scalar(ScalarValue::Int32(Some(3))),
        );
        let out = as_typed::<Int32Array>(&result);
        assert_eq!(out.value(0), 1);
        assert_eq!(out.value(1), 2);
        assert_eq!(out.value(2), 0);
        assert_eq!(out.value(3), 0);
    }

    #[test]
    fn test_pmod_wrong_arg_count() {
        let only_arg = array_arg(Int32Array::from(vec![Some(10)]));
        assert!(spark_pmod(&[only_arg]).is_err());
    }

    #[test]
    fn test_pmod_zero_division() {
        let left = array_arg(Int32Array::from(vec![Some(10), Some(-7), Some(15)]));
        let right = array_arg(Int32Array::from(vec![Some(0), Some(0), Some(4)]));
        assert!(spark_pmod(&[left, right]).is_err());
    }

    #[test]
    fn test_pmod_negative_divisor() {
        let result = eval_array(
            spark_pmod,
            array_arg(Int32Array::from(vec![Some(10), Some(-7), Some(15)])),
            array_arg(Int32Array::from(vec![Some(-3), Some(-3), Some(-4)])),
        );
        let out = as_typed::<Int32Array>(&result);
        assert_eq!(out.value(0), 1);
        assert_eq!(out.value(1), -1);
        assert_eq!(out.value(2), 3);
    }

    #[test]
    fn test_pmod_edge_cases() {
        let result = eval_array(
            spark_pmod,
            array_arg(Int32Array::from(vec![
                Some(0),
                Some(-1),
                Some(1),
                Some(-5),
                Some(5),
                Some(-6),
                Some(6),
            ])),
            array_arg(Int32Array::from(vec![Some(5); 7])),
        );
        let out = as_typed::<Int32Array>(&result);
        let expected = [0, 4, 1, 0, 0, 4, 1];
        for (row, want) in expected.into_iter().enumerate() {
            assert_eq!(out.value(row), want, "row {row}");
        }
    }
}