1use arrow::compute::kernels::numeric::add;
19use arrow::compute::kernels::{cmp::lt, numeric::rem, zip::zip};
20use arrow::datatypes::DataType;
21use datafusion_common::{Result, ScalarValue, assert_eq_or_internal_err};
22use datafusion_expr::{
23 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
24};
25use std::any::Any;
26
27pub fn spark_mod(args: &[ColumnarValue]) -> Result<ColumnarValue> {
30 assert_eq_or_internal_err!(args.len(), 2, "mod expects exactly two arguments");
31 let args = ColumnarValue::values_to_arrays(args)?;
32 let result = rem(&args[0], &args[1])?;
33 Ok(ColumnarValue::Array(result))
34}
35
36pub fn spark_pmod(args: &[ColumnarValue]) -> Result<ColumnarValue> {
39 assert_eq_or_internal_err!(args.len(), 2, "pmod expects exactly two arguments");
40 let args = ColumnarValue::values_to_arrays(args)?;
41 let left = &args[0];
42 let right = &args[1];
43 let zero = ScalarValue::new_zero(left.data_type())?.to_array_of_size(left.len())?;
44 let result = rem(left, right)?;
45 let neg = lt(&result, &zero)?;
46 let plus = zip(&neg, right, &zero)?;
47 let result = add(&plus, &result)?;
48 let result = rem(&result, right)?;
49 Ok(ColumnarValue::Array(result))
50}
51
52#[derive(Debug, PartialEq, Eq, Hash)]
54pub struct SparkMod {
55 signature: Signature,
56}
57
58impl Default for SparkMod {
59 fn default() -> Self {
60 Self::new()
61 }
62}
63
64impl SparkMod {
65 pub fn new() -> Self {
66 Self {
67 signature: Signature::numeric(2, Volatility::Immutable),
68 }
69 }
70}
71
72impl ScalarUDFImpl for SparkMod {
73 fn as_any(&self) -> &dyn Any {
74 self
75 }
76
77 fn name(&self) -> &str {
78 "mod"
79 }
80
81 fn signature(&self) -> &Signature {
82 &self.signature
83 }
84
85 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
86 assert_eq_or_internal_err!(
87 arg_types.len(),
88 2,
89 "mod expects exactly two arguments"
90 );
91
92 Ok(arg_types[0].clone())
95 }
96
97 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
98 spark_mod(&args.args)
99 }
100}
101
102#[derive(Debug, PartialEq, Eq, Hash)]
104pub struct SparkPmod {
105 signature: Signature,
106}
107
108impl Default for SparkPmod {
109 fn default() -> Self {
110 Self::new()
111 }
112}
113
114impl SparkPmod {
115 pub fn new() -> Self {
116 Self {
117 signature: Signature::numeric(2, Volatility::Immutable),
118 }
119 }
120}
121
122impl ScalarUDFImpl for SparkPmod {
123 fn as_any(&self) -> &dyn Any {
124 self
125 }
126
127 fn name(&self) -> &str {
128 "pmod"
129 }
130
131 fn signature(&self) -> &Signature {
132 &self.signature
133 }
134
135 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
136 assert_eq_or_internal_err!(
137 arg_types.len(),
138 2,
139 "pmod expects exactly two arguments"
140 );
141
142 Ok(arg_types[0].clone())
145 }
146
147 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
148 spark_pmod(&args.args)
149 }
150}
151
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use super::*;
    use arrow::array::*;
    use datafusion_common::ScalarValue;

    /// Wraps a concrete Arrow array as an array-valued argument.
    fn array_arg<A: Array + 'static>(array: A) -> ColumnarValue {
        ColumnarValue::Array(Arc::new(array))
    }

    /// Returns the array held by `value`, panicking if it is not array-valued.
    fn expect_array(value: ColumnarValue) -> ArrayRef {
        match value {
            ColumnarValue::Array(array) => array,
            _ => panic!("Expected array result"),
        }
    }

    #[test]
    fn test_mod_int32() {
        let dividend = Int32Array::from(vec![Some(10), Some(7), Some(15), None]);
        let divisor = Int32Array::from(vec![Some(3), Some(2), Some(4), Some(5)]);

        let output =
            expect_array(spark_mod(&[array_arg(dividend), array_arg(divisor)]).unwrap());
        let values = output.as_any().downcast_ref::<Int32Array>().unwrap();

        // 10 % 3, 7 % 2, 15 % 4
        let expected: [i32; 3] = [1, 1, 3];
        for (row, &want) in expected.iter().enumerate() {
            assert_eq!(values.value(row), want);
        }
        // The null dividend produces a null row.
        assert!(values.is_null(3));
    }

    #[test]
    fn test_mod_int64() {
        let dividend = Int64Array::from(vec![Some(100), Some(50), Some(200)]);
        let divisor = Int64Array::from(vec![Some(30), Some(25), Some(60)]);

        let output =
            expect_array(spark_mod(&[array_arg(dividend), array_arg(divisor)]).unwrap());
        let values = output.as_any().downcast_ref::<Int64Array>().unwrap();

        // 100 % 30, 50 % 25, 200 % 60
        let expected: [i64; 3] = [10, 0, 20];
        for (row, &want) in expected.iter().enumerate() {
            assert_eq!(values.value(row), want);
        }
    }

    #[test]
    fn test_mod_float64() {
        let dividend = Float64Array::from(vec![
            Some(10.5),
            Some(7.2),
            Some(15.8),
            Some(f64::NAN),
            Some(f64::INFINITY),
            Some(5.0),
            Some(5.0),
            Some(f64::NAN),
            Some(f64::INFINITY),
        ]);
        let divisor = Float64Array::from(vec![
            Some(3.0),
            Some(2.5),
            Some(4.2),
            Some(2.0),
            Some(2.0),
            Some(f64::NAN),
            Some(f64::INFINITY),
            Some(f64::INFINITY),
            Some(f64::NAN),
        ]);

        let output =
            expect_array(spark_mod(&[array_arg(dividend), array_arg(divisor)]).unwrap());
        let values = output.as_any().downcast_ref::<Float64Array>().unwrap();

        // Finite operands: (expected value, tolerance) for rows 0..=2.
        let finite: [(f64, f64); 3] = [
            (1.5, f64::EPSILON),
            (2.2, f64::EPSILON),
            (3.2, f64::EPSILON),
        ];
        for (row, &(want, tolerance)) in finite.iter().enumerate() {
            assert!((values.value(row) - want).abs() < tolerance);
        }

        // NaN % 2, inf % 2, 5 % NaN
        for &row in [3usize, 4, 5].iter() {
            assert!(values.value(row).is_nan());
        }
        // 5 % inf leaves the dividend unchanged.
        assert!((values.value(6) - 5.0).abs() < f64::EPSILON);
        // NaN % inf, inf % NaN
        for &row in [7usize, 8].iter() {
            assert!(values.value(row).is_nan());
        }
    }

    #[test]
    fn test_mod_float32() {
        let dividend = Float32Array::from(vec![
            Some(10.5),
            Some(7.2),
            Some(15.8),
            Some(f32::NAN),
            Some(f32::INFINITY),
            Some(5.0),
            Some(5.0),
            Some(f32::NAN),
            Some(f32::INFINITY),
        ]);
        let divisor = Float32Array::from(vec![
            Some(3.0),
            Some(2.5),
            Some(4.2),
            Some(2.0),
            Some(2.0),
            Some(f32::NAN),
            Some(f32::INFINITY),
            Some(f32::INFINITY),
            Some(f32::NAN),
        ]);

        let output =
            expect_array(spark_mod(&[array_arg(dividend), array_arg(divisor)]).unwrap());
        let values = output.as_any().downcast_ref::<Float32Array>().unwrap();

        // Finite operands: (expected value, tolerance) for rows 0..=2.
        let finite: [(f32, f32); 3] = [
            (1.5, f32::EPSILON),
            (2.2, f32::EPSILON * 3.0),
            (3.2, f32::EPSILON * 10.0),
        ];
        for (row, &(want, tolerance)) in finite.iter().enumerate() {
            assert!((values.value(row) - want).abs() < tolerance);
        }

        // NaN % 2, inf % 2, 5 % NaN
        for &row in [3usize, 4, 5].iter() {
            assert!(values.value(row).is_nan());
        }
        // 5 % inf leaves the dividend unchanged.
        assert!((values.value(6) - 5.0).abs() < f32::EPSILON);
        // NaN % inf, inf % NaN
        for &row in [7usize, 8].iter() {
            assert!(values.value(row).is_nan());
        }
    }

    #[test]
    fn test_mod_scalar() {
        let dividend = Int32Array::from(vec![Some(10), Some(7), Some(15)]);
        let divisor = ColumnarValue::Scalar(ScalarValue::Int32(Some(3)));

        let output = expect_array(spark_mod(&[array_arg(dividend), divisor]).unwrap());
        let values = output.as_any().downcast_ref::<Int32Array>().unwrap();

        // 10 % 3, 7 % 3, 15 % 3
        let expected: [i32; 3] = [1, 1, 0];
        for (row, &want) in expected.iter().enumerate() {
            assert_eq!(values.value(row), want);
        }
    }

    #[test]
    fn test_mod_wrong_arg_count() {
        let only_arg = array_arg(Int32Array::from(vec![Some(10)]));

        assert!(spark_mod(&[only_arg]).is_err());
    }

    #[test]
    fn test_mod_zero_division() {
        let dividend = Int32Array::from(vec![Some(10), Some(7), Some(15)]);
        let divisor = Int32Array::from(vec![Some(0), Some(2), Some(4)]);

        // An integer zero divisor is reported as an error.
        let outcome = spark_mod(&[array_arg(dividend), array_arg(divisor)]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_pmod_int32() {
        let dividend =
            Int32Array::from(vec![Some(10), Some(-7), Some(15), Some(-15), None]);
        let divisor = Int32Array::from(vec![Some(3), Some(3), Some(4), Some(4), Some(5)]);

        let output =
            expect_array(spark_pmod(&[array_arg(dividend), array_arg(divisor)]).unwrap());
        let values = output.as_any().downcast_ref::<Int32Array>().unwrap();

        // 10 pmod 3, -7 pmod 3, 15 pmod 4, -15 pmod 4
        let expected: [i32; 4] = [1, 2, 3, 1];
        for (row, &want) in expected.iter().enumerate() {
            assert_eq!(values.value(row), want);
        }
        // The null dividend produces a null row.
        assert!(values.is_null(4));
    }

    #[test]
    fn test_pmod_int64() {
        let dividend = Int64Array::from(vec![Some(100), Some(-50), Some(200), Some(-200)]);
        let divisor = Int64Array::from(vec![Some(30), Some(30), Some(60), Some(60)]);

        let output =
            expect_array(spark_pmod(&[array_arg(dividend), array_arg(divisor)]).unwrap());
        let values = output.as_any().downcast_ref::<Int64Array>().unwrap();

        // 100 pmod 30, -50 pmod 30, 200 pmod 60, -200 pmod 60
        let expected: [i64; 4] = [10, 10, 20, 40];
        for (row, &want) in expected.iter().enumerate() {
            assert_eq!(values.value(row), want);
        }
    }

    #[test]
    fn test_pmod_float64() {
        let dividend = Float64Array::from(vec![
            Some(10.5),
            Some(-7.2),
            Some(15.8),
            Some(-15.8),
            Some(f64::NAN),
            Some(f64::INFINITY),
            Some(5.0),
            Some(-5.0),
        ]);
        let divisor = Float64Array::from(vec![
            Some(3.0),
            Some(3.0),
            Some(4.2),
            Some(4.2),
            Some(2.0),
            Some(2.0),
            Some(f64::INFINITY),
            Some(f64::INFINITY),
        ]);

        let output =
            expect_array(spark_pmod(&[array_arg(dividend), array_arg(divisor)]).unwrap());
        let values = output.as_any().downcast_ref::<Float64Array>().unwrap();

        // Finite operands: (expected value, tolerance) for rows 0..=3.
        let finite: [(f64, f64); 4] = [
            (1.5, f64::EPSILON),
            (1.8, f64::EPSILON * 3.0),
            (3.2, f64::EPSILON * 3.0),
            (1.0, f64::EPSILON * 3.0),
        ];
        for (row, &(want, tolerance)) in finite.iter().enumerate() {
            assert!((values.value(row) - want).abs() < tolerance);
        }

        // NaN pmod 2, inf pmod 2
        for &row in [4usize, 5].iter() {
            assert!(values.value(row).is_nan());
        }
        // 5 pmod inf leaves the dividend unchanged.
        assert!((values.value(6) - 5.0).abs() < f64::EPSILON);
        // -5 pmod inf: the correction step adds inf, and inf % inf is NaN.
        assert!(values.value(7).is_nan());
    }

    #[test]
    fn test_pmod_float32() {
        let dividend = Float32Array::from(vec![
            Some(10.5),
            Some(-7.2),
            Some(15.8),
            Some(-15.8),
            Some(f32::NAN),
            Some(f32::INFINITY),
            Some(5.0),
            Some(-5.0),
        ]);
        let divisor = Float32Array::from(vec![
            Some(3.0),
            Some(3.0),
            Some(4.2),
            Some(4.2),
            Some(2.0),
            Some(2.0),
            Some(f32::INFINITY),
            Some(f32::INFINITY),
        ]);

        let output =
            expect_array(spark_pmod(&[array_arg(dividend), array_arg(divisor)]).unwrap());
        let values = output.as_any().downcast_ref::<Float32Array>().unwrap();

        // Finite operands: (expected value, tolerance) for rows 0..=3.
        let finite: [(f32, f32); 4] = [
            (1.5, f32::EPSILON),
            (1.8, f32::EPSILON * 3.0),
            (3.2, f32::EPSILON * 10.0),
            (1.0, f32::EPSILON * 10.0),
        ];
        for (row, &(want, tolerance)) in finite.iter().enumerate() {
            assert!((values.value(row) - want).abs() < tolerance);
        }

        // NaN pmod 2, inf pmod 2
        for &row in [4usize, 5].iter() {
            assert!(values.value(row).is_nan());
        }
        // 5 pmod inf leaves the dividend unchanged.
        assert!((values.value(6) - 5.0).abs() < f32::EPSILON * 10.0);
        // -5 pmod inf: the correction step adds inf, and inf % inf is NaN.
        assert!(values.value(7).is_nan());
    }

    #[test]
    fn test_pmod_scalar() {
        let dividend = Int32Array::from(vec![Some(10), Some(-7), Some(15), Some(-15)]);
        let divisor = ColumnarValue::Scalar(ScalarValue::Int32(Some(3)));

        let output = expect_array(spark_pmod(&[array_arg(dividend), divisor]).unwrap());
        let values = output.as_any().downcast_ref::<Int32Array>().unwrap();

        // 10 pmod 3, -7 pmod 3, 15 pmod 3, -15 pmod 3
        let expected: [i32; 4] = [1, 2, 0, 0];
        for (row, &want) in expected.iter().enumerate() {
            assert_eq!(values.value(row), want);
        }
    }

    #[test]
    fn test_pmod_wrong_arg_count() {
        let only_arg = array_arg(Int32Array::from(vec![Some(10)]));

        assert!(spark_pmod(&[only_arg]).is_err());
    }

    #[test]
    fn test_pmod_zero_division() {
        let dividend = Int32Array::from(vec![Some(10), Some(-7), Some(15)]);
        let divisor = Int32Array::from(vec![Some(0), Some(0), Some(4)]);

        // An integer zero divisor is reported as an error.
        let outcome = spark_pmod(&[array_arg(dividend), array_arg(divisor)]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_pmod_negative_divisor() {
        let dividend = Int32Array::from(vec![Some(10), Some(-7), Some(15)]);
        let divisor = Int32Array::from(vec![Some(-3), Some(-3), Some(-4)]);

        let output =
            expect_array(spark_pmod(&[array_arg(dividend), array_arg(divisor)]).unwrap());
        let values = output.as_any().downcast_ref::<Int32Array>().unwrap();

        // 10 pmod -3, -7 pmod -3, 15 pmod -4
        let expected: [i32; 3] = [1, -1, 3];
        for (row, &want) in expected.iter().enumerate() {
            assert_eq!(values.value(row), want);
        }
    }

    #[test]
    fn test_pmod_edge_cases() {
        let dividend = Int32Array::from(vec![
            Some(0),
            Some(-1),
            Some(1),
            Some(-5),
            Some(5),
            Some(-6),
            Some(6),
        ]);
        let divisor = Int32Array::from(vec![
            Some(5),
            Some(5),
            Some(5),
            Some(5),
            Some(5),
            Some(5),
            Some(5),
        ]);

        let output =
            expect_array(spark_pmod(&[array_arg(dividend), array_arg(divisor)]).unwrap());
        let values = output.as_any().downcast_ref::<Int32Array>().unwrap();

        // 0, -1, 1, -5, 5, -6, 6 (each pmod 5)
        let expected: [i32; 7] = [0, 4, 1, 0, 0, 4, 1];
        for (row, &want) in expected.iter().enumerate() {
            assert_eq!(values.value(row), want);
        }
    }
}