1use crate::{create_comet_physical_fun, IfExpr};
19use crate::{divide_by_zero_error, Cast, EvalMode, SparkCastOptions};
20use arrow::compute::kernels::numeric::rem;
21use arrow::datatypes::*;
22use datafusion::common::{exec_err, internal_err, DataFusionError, Result, ScalarValue};
23use datafusion::execution::FunctionRegistry;
24use datafusion::physical_expr::expressions::{lit, BinaryExpr};
25use datafusion::physical_expr::ScalarFunctionExpr;
26use datafusion::physical_expr_common::datum::{apply, apply_cmp_for_nested};
27use datafusion::{
28 logical_expr::{ColumnarValue, Operator},
29 physical_expr::PhysicalExpr,
30};
31use std::cmp::max;
32use std::sync::Arc;
33
34pub fn spark_modulo(args: &[ColumnarValue], fail_on_error: bool) -> Result<ColumnarValue> {
38 if args.len() != 2 {
39 return exec_err!("modulo expects exactly two arguments");
40 }
41
42 let lhs = &args[0];
43 let rhs = &args[1];
44
45 let left_data_type = lhs.data_type();
46 let right_data_type = rhs.data_type();
47
48 if left_data_type.is_nested() {
49 if right_data_type != left_data_type {
50 return internal_err!("Type mismatch for spark modulo operation");
51 }
52 return apply_cmp_for_nested(Operator::Modulo, lhs, rhs);
53 }
54
55 match apply(lhs, rhs, rem) {
56 Ok(result) => Ok(result),
57 Err(e) if e.to_string().contains("Divide by zero") && fail_on_error => {
58 Err(divide_by_zero_error().into())
60 }
61 Err(e) => Err(e),
62 }
63}
64
65pub fn create_modulo_expr(
66 left: Arc<dyn PhysicalExpr>,
67 right: Arc<dyn PhysicalExpr>,
68 data_type: DataType,
69 input_schema: SchemaRef,
70 fail_on_error: bool,
71 registry: &dyn FunctionRegistry,
72) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> {
73 let right_non_ansi_safe = if !fail_on_error {
76 null_if_zero_primitive(right, &input_schema)?
77 } else {
78 right
79 };
80
81 match (
85 left.data_type(&input_schema),
86 right_non_ansi_safe.data_type(&input_schema),
87 ) {
88 (Ok(DataType::Decimal128(p1, s1)), Ok(DataType::Decimal128(p2, s2)))
89 if max(s1, s2) as u8 + max(p1 - s1 as u8, p2 - s2 as u8) > DECIMAL128_MAX_PRECISION =>
90 {
91 let left_256 = Arc::new(Cast::new(
92 left,
93 DataType::Decimal256(p1, s1),
94 SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
95 ));
96 let right_256 = Arc::new(Cast::new(
97 right_non_ansi_safe,
98 DataType::Decimal256(p2, s2),
99 SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
100 ));
101
102 let modulo_scalar_func = create_modulo_scalar_function(
103 left_256,
104 right_256,
105 &data_type,
106 registry,
107 fail_on_error,
108 )?;
109
110 Ok(Arc::new(Cast::new(
111 modulo_scalar_func,
112 data_type,
113 SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
114 )))
115 }
116 _ => create_modulo_scalar_function(
117 left,
118 right_non_ansi_safe,
119 &data_type,
120 registry,
121 fail_on_error,
122 ),
123 }
124}
125
126fn null_if_zero_primitive(
127 expression: Arc<dyn PhysicalExpr>,
128 input_schema: &Schema,
129) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> {
130 let expr_data_type = expression.data_type(input_schema)?;
131
132 if is_primitive_datatype(&expr_data_type) {
133 let zero = match expr_data_type {
134 DataType::Int8 => ScalarValue::Int8(Some(0)),
135 DataType::Int16 => ScalarValue::Int16(Some(0)),
136 DataType::Int32 => ScalarValue::Int32(Some(0)),
137 DataType::Int64 => ScalarValue::Int64(Some(0)),
138 DataType::UInt8 => ScalarValue::UInt8(Some(0)),
139 DataType::UInt16 => ScalarValue::UInt16(Some(0)),
140 DataType::UInt32 => ScalarValue::UInt32(Some(0)),
141 DataType::UInt64 => ScalarValue::UInt64(Some(0)),
142 DataType::Float32 => ScalarValue::Float32(Some(0.0)),
143 DataType::Float64 => ScalarValue::Float64(Some(0.0)),
144 DataType::Decimal128(s, p) => ScalarValue::Decimal128(Some(0), s, p),
145 DataType::Decimal256(s, p) => ScalarValue::Decimal256(Some(i256::from(0)), s, p),
146 _ => return Ok(expression),
147 };
148
149 let eq_expr = Arc::new(BinaryExpr::new(
153 Arc::<dyn PhysicalExpr>::clone(&expression),
154 Operator::Eq,
155 lit(zero),
156 ));
157 let null_literal = lit(ScalarValue::try_new_null(&expr_data_type)?);
158 let if_expr = Arc::new(IfExpr::new(eq_expr, null_literal, expression));
159 Ok(if_expr)
160 } else {
161 Ok(expression)
162 }
163}
164
165fn is_primitive_datatype(dt: &DataType) -> bool {
166 matches!(
167 dt,
168 DataType::Int8
169 | DataType::Int16
170 | DataType::Int32
171 | DataType::Int64
172 | DataType::UInt8
173 | DataType::UInt16
174 | DataType::UInt32
175 | DataType::UInt64
176 | DataType::Float32
177 | DataType::Float64
178 | DataType::Decimal128(_, _)
179 | DataType::Decimal256(_, _)
180 )
181}
182
183fn create_modulo_scalar_function(
184 left: Arc<dyn PhysicalExpr>,
185 right: Arc<dyn PhysicalExpr>,
186 data_type: &DataType,
187 registry: &dyn FunctionRegistry,
188 fail_on_error: bool,
189) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> {
190 let func_name = "spark_modulo";
191 let modulo_expr =
192 create_comet_physical_fun(func_name, data_type.clone(), registry, Some(fail_on_error))?;
193 Ok(Arc::new(ScalarFunctionExpr::new(
194 func_name,
195 modulo_expr,
196 vec![left, right],
197 Arc::new(Field::new(func_name, data_type.clone(), true)),
198 )))
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{
        Array, ArrayRef, Decimal128Array, Decimal128Builder, Int32Array, PrimitiveArray,
        RecordBatch,
    };
    use datafusion::logical_expr::ColumnarValue;
    use datafusion::physical_expr::expressions::{Column, Literal};
    use datafusion::prelude::SessionContext;

    /// Runs `test_fn` once with `fail_on_error = true` (ANSI) and once with `false` (legacy).
    fn with_fail_on_error<F: Fn(bool)>(test_fn: F) {
        for fail_on_error in [true, false] {
            test_fn(fail_on_error);
        }
    }

    /// Evaluates `expr` against `batch` and checks the outcome.
    ///
    /// If `should_fail`, the evaluation must return an error whose message contains
    /// `[DIVIDE_BY_ZERO] Division by zero`. Otherwise it must return an array that matches
    /// `expected_result` element by element (nullity and value); array types are only
    /// compared through the downcast to `PrimitiveArray<T>`.
    pub fn verify_result<T>(
        expr: Arc<dyn PhysicalExpr>,
        batch: RecordBatch,
        should_fail: bool,
        expected_result: Option<Arc<PrimitiveArray<T>>>,
    ) where
        T: ArrowPrimitiveType,
    {
        let actual_result = expr.evaluate(&batch);

        if should_fail {
            match actual_result {
                Err(error) => {
                    assert!(
                        error
                            .to_string()
                            .contains("[DIVIDE_BY_ZERO] Division by zero"),
                        "Error message did not match. Actual message: {error}"
                    );
                }
                Ok(value) => {
                    panic!("Expected error, but got: {value:?}");
                }
            }
        } else {
            match (actual_result, expected_result) {
                (Ok(ColumnarValue::Array(ref actual)), Some(expected)) => {
                    assert_eq!(actual.len(), expected.len(), "Array length mismatch");

                    let actual_arr = actual.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
                    let expected_arr = expected
                        .as_any()
                        .downcast_ref::<PrimitiveArray<T>>()
                        .unwrap();

                    for i in 0..actual_arr.len() {
                        assert_eq!(
                            actual_arr.is_null(i),
                            expected_arr.is_null(i),
                            "Nullity mismatch at index {i}"
                        );
                        if !actual_arr.is_null(i) {
                            let actual_value = actual_arr.value(i);
                            let expected_value = expected_arr.value(i);
                            assert_eq!(
                                actual_value, expected_value,
                                "Mismatch at index {i}, actual {actual_value:?}, expected {expected_value:?}"
                            );
                        }
                    }
                }
                (actual, expected) => {
                    panic!("Actual: {actual:?}, expected: {expected:?}");
                }
            }
        }
    }

    #[test]
    fn test_modulo_basic_int() {
        with_fail_on_error(|fail_on_error| {
            let schema = Arc::new(Schema::new(vec![
                Field::new("a", DataType::Int32, false),
                Field::new("b", DataType::Int32, false),
            ]));

            // i32::MIN % -1 is mathematically 0 and must not overflow.
            let a_array = Arc::new(Int32Array::from(vec![3, 2, i32::MIN]));
            let b_array = Arc::new(Int32Array::from(vec![1, 5, -1]));
            let batch = RecordBatch::try_new(Arc::clone(&schema), vec![a_array, b_array]).unwrap();

            let left_expr = Arc::new(Column::new("a", 0));
            let right_expr = Arc::new(Column::new("b", 1));

            let session_ctx = SessionContext::new();
            let modulo_expr = create_modulo_expr(
                left_expr,
                right_expr,
                DataType::Int32,
                schema,
                fail_on_error,
                &session_ctx.state(),
            )
            .unwrap();

            let should_fail = false;
            let expected_result = Arc::new(Int32Array::from(vec![0, 2, 0]));
            verify_result(modulo_expr, batch, should_fail, Some(expected_result));
        })
    }

    #[test]
    fn test_modulo_basic_decimal() {
        with_fail_on_error(|fail_on_error| {
            let decimal_type = DataType::Decimal128(18, 4);
            let schema = Arc::new(Schema::new(vec![
                Field::new("a", decimal_type.clone(), false),
                Field::new("b", decimal_type.clone(), false),
            ]));

            // Values are unscaled integers at scale 4 (30_000 == 3.0000); all fit within the
            // declared precision of 18 digits.
            let mut a_builder =
                Decimal128Builder::with_capacity(4).with_data_type(decimal_type.clone());
            a_builder.append_value(30_000); // 3.0000
            a_builder.append_value(20_000); // 2.0000
            a_builder.append_value(55_000); // 5.5000
            a_builder.append_value(-55_000); // -5.5000
            let a_array: ArrayRef = Arc::new(a_builder.finish());

            let mut b_builder =
                Decimal128Builder::with_capacity(4).with_data_type(decimal_type.clone());
            b_builder.append_value(10_000); // 1.0000
            b_builder.append_value(50_000); // 5.0000
            b_builder.append_value(20_000); // 2.0000
            b_builder.append_value(20_000); // 2.0000
            let b_array: ArrayRef = Arc::new(b_builder.finish());

            let batch = RecordBatch::try_new(Arc::clone(&schema), vec![a_array, b_array]).unwrap();

            let left_expr = Arc::new(Column::new("a", 0));
            let right_expr = Arc::new(Column::new("b", 1));

            let session_ctx = SessionContext::new();
            let modulo_expr = create_modulo_expr(
                left_expr,
                right_expr,
                decimal_type,
                schema,
                fail_on_error,
                &session_ctx.state(),
            )
            .unwrap();

            let should_fail = false;
            // The remainder takes the sign of the dividend: -5.5 % 2 == -1.5.
            let expected_result = Arc::new(
                Decimal128Array::from(vec![
                    Some(0),
                    Some(20_000),
                    Some(15_000),
                    Some(-15_000),
                ])
                .with_precision_and_scale(18, 4)
                .unwrap(),
            );
            verify_result(modulo_expr, batch, should_fail, Some(expected_result));
        })
    }

    #[test]
    fn test_modulo_divide_by_zero_int() {
        with_fail_on_error(|fail_on_error| {
            let schema = Arc::new(Schema::new(vec![
                Field::new("a", DataType::Int32, false),
                Field::new("b", DataType::Int32, false),
            ]));

            let a_array = Arc::new(Int32Array::from(vec![3]));
            let b_array = Arc::new(Int32Array::from(vec![0]));
            let batch = RecordBatch::try_new(Arc::clone(&schema), vec![a_array, b_array]).unwrap();

            let left_expr = Arc::new(Column::new("a", 0));
            let right_expr = Arc::new(Column::new("b", 1));

            let session_ctx = SessionContext::new();
            let modulo_expr = create_modulo_expr(
                left_expr,
                right_expr,
                DataType::Int32,
                schema,
                fail_on_error,
                &session_ctx.state(),
            )
            .unwrap();

            // ANSI mode must raise; legacy mode yields NULL.
            let expected_result = Arc::new(Int32Array::from(vec![None]));
            verify_result(modulo_expr, batch, fail_on_error, Some(expected_result));
        })
    }

    #[test]
    fn test_modulo_by_zero_with_complex_int_expr() {
        with_fail_on_error(|fail_on_error| {
            let schema = Arc::new(Schema::new(vec![
                Field::new("a", DataType::Int32, false),
                Field::new("b", DataType::Int32, false),
                Field::new("c", DataType::Int32, false),
            ]));

            let a_array = Arc::new(Int32Array::from(vec![3, 0]));
            let b_array = Arc::new(Int32Array::from(vec![2, 4]));
            let c_array = Arc::new(Int32Array::from(vec![4, 5]));
            let batch =
                RecordBatch::try_new(Arc::clone(&schema), vec![a_array, b_array, c_array]).unwrap();

            // Dividend a / b == [1, 0]; divisor 0 / c == [0, 0], i.e. a computed zero divisor.
            let left_expr = Arc::new(BinaryExpr::new(
                Arc::new(Column::new("a", 0)),
                Operator::Divide,
                Arc::new(Column::new("b", 1)),
            ));
            let right_expr = Arc::new(BinaryExpr::new(
                Arc::new(Literal::new(ScalarValue::Int32(Some(0)))),
                Operator::Divide,
                Arc::new(Column::new("c", 2)),
            ));

            let session_ctx = SessionContext::new();
            let modulo_expr = create_modulo_expr(
                left_expr,
                right_expr,
                DataType::Int32,
                schema,
                fail_on_error,
                &session_ctx.state(),
            )
            .unwrap();

            // ANSI mode must raise; legacy mode yields NULL for every row.
            let expected_result = Arc::new(Int32Array::from(vec![None, None]));
            verify_result(modulo_expr, batch, fail_on_error, Some(expected_result));
        })
    }
}