datafusion_comet_spark_expr/math_funcs/modulo_expr.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::{create_comet_physical_fun, IfExpr};
19use crate::{divide_by_zero_error, Cast, EvalMode, SparkCastOptions};
20use arrow::compute::kernels::numeric::rem;
21use arrow::datatypes::*;
22use datafusion::common::{exec_err, internal_err, DataFusionError, Result, ScalarValue};
23use datafusion::execution::FunctionRegistry;
24use datafusion::physical_expr::expressions::{lit, BinaryExpr};
25use datafusion::physical_expr::ScalarFunctionExpr;
26use datafusion::physical_expr_common::datum::{apply, apply_cmp_for_nested};
27use datafusion::{
28    logical_expr::{ColumnarValue, Operator},
29    physical_expr::PhysicalExpr,
30};
31use std::cmp::max;
32use std::sync::Arc;
33
34/// Spark-compliant modulo function. If `fail_on_error` is true, then this function computes modulo
35/// in ANSI mode and returns an error on division by zero, otherwise it returns `NULL` for such
36/// cases.
37pub fn spark_modulo(args: &[ColumnarValue], fail_on_error: bool) -> Result<ColumnarValue> {
38    if args.len() != 2 {
39        return exec_err!("modulo expects exactly two arguments");
40    }
41
42    let lhs = &args[0];
43    let rhs = &args[1];
44
45    let left_data_type = lhs.data_type();
46    let right_data_type = rhs.data_type();
47
48    if left_data_type.is_nested() {
49        if right_data_type != left_data_type {
50            return internal_err!("Type mismatch for spark modulo operation");
51        }
52        return apply_cmp_for_nested(Operator::Modulo, lhs, rhs);
53    }
54
55    match apply(lhs, rhs, rem) {
56        Ok(result) => Ok(result),
57        Err(e) if e.to_string().contains("Divide by zero") && fail_on_error => {
58            // Return Spark-compliant divide by zero error.
59            Err(divide_by_zero_error().into())
60        }
61        Err(e) => Err(e),
62    }
63}
64
65pub fn create_modulo_expr(
66    left: Arc<dyn PhysicalExpr>,
67    right: Arc<dyn PhysicalExpr>,
68    data_type: DataType,
69    input_schema: SchemaRef,
70    fail_on_error: bool,
71    registry: &dyn FunctionRegistry,
72) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> {
73    // For non-ANSI mode, wrap the right expression such that any zero value is replaced with `NULL`
74    // to prevent divide by zero error.
75    let right_non_ansi_safe = if !fail_on_error {
76        null_if_zero_primitive(right, &input_schema)?
77    } else {
78        right
79    };
80
81    // If the data type is `Decimal128` and the (scale + integral part) exceeds the maximum allowed
82    // for `Decimal128`, then cast both operands to `Decimal256` before creating the modulo scalar
83    // expression, otherwise, create the modulo scalar expression directly.
84    match (
85        left.data_type(&input_schema),
86        right_non_ansi_safe.data_type(&input_schema),
87    ) {
88        (Ok(DataType::Decimal128(p1, s1)), Ok(DataType::Decimal128(p2, s2)))
89            if max(s1, s2) as u8 + max(p1 - s1 as u8, p2 - s2 as u8) > DECIMAL128_MAX_PRECISION =>
90        {
91            let left_256 = Arc::new(Cast::new(
92                left,
93                DataType::Decimal256(p1, s1),
94                SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
95            ));
96            let right_256 = Arc::new(Cast::new(
97                right_non_ansi_safe,
98                DataType::Decimal256(p2, s2),
99                SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
100            ));
101
102            let modulo_scalar_func = create_modulo_scalar_function(
103                left_256,
104                right_256,
105                &data_type,
106                registry,
107                fail_on_error,
108            )?;
109
110            Ok(Arc::new(Cast::new(
111                modulo_scalar_func,
112                data_type,
113                SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
114            )))
115        }
116        _ => create_modulo_scalar_function(
117            left,
118            right_non_ansi_safe,
119            &data_type,
120            registry,
121            fail_on_error,
122        ),
123    }
124}
125
126fn null_if_zero_primitive(
127    expression: Arc<dyn PhysicalExpr>,
128    input_schema: &Schema,
129) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> {
130    let expr_data_type = expression.data_type(input_schema)?;
131
132    if is_primitive_datatype(&expr_data_type) {
133        let zero = match expr_data_type {
134            DataType::Int8 => ScalarValue::Int8(Some(0)),
135            DataType::Int16 => ScalarValue::Int16(Some(0)),
136            DataType::Int32 => ScalarValue::Int32(Some(0)),
137            DataType::Int64 => ScalarValue::Int64(Some(0)),
138            DataType::UInt8 => ScalarValue::UInt8(Some(0)),
139            DataType::UInt16 => ScalarValue::UInt16(Some(0)),
140            DataType::UInt32 => ScalarValue::UInt32(Some(0)),
141            DataType::UInt64 => ScalarValue::UInt64(Some(0)),
142            DataType::Float32 => ScalarValue::Float32(Some(0.0)),
143            DataType::Float64 => ScalarValue::Float64(Some(0.0)),
144            DataType::Decimal128(s, p) => ScalarValue::Decimal128(Some(0), s, p),
145            DataType::Decimal256(s, p) => ScalarValue::Decimal256(Some(i256::from(0)), s, p),
146            _ => return Ok(expression),
147        };
148
149        // Create an expression like - `if (eval(expr) == Literal(0)) then NULL else eval(expr)`.
150        // This expression evaluates to null for rows with zero values to prevent divide by zero
151        // error.
152        let eq_expr = Arc::new(BinaryExpr::new(
153            Arc::<dyn PhysicalExpr>::clone(&expression),
154            Operator::Eq,
155            lit(zero),
156        ));
157        let null_literal = lit(ScalarValue::try_new_null(&expr_data_type)?);
158        let if_expr = Arc::new(IfExpr::new(eq_expr, null_literal, expression));
159        Ok(if_expr)
160    } else {
161        Ok(expression)
162    }
163}
164
165fn is_primitive_datatype(dt: &DataType) -> bool {
166    matches!(
167        dt,
168        DataType::Int8
169            | DataType::Int16
170            | DataType::Int32
171            | DataType::Int64
172            | DataType::UInt8
173            | DataType::UInt16
174            | DataType::UInt32
175            | DataType::UInt64
176            | DataType::Float32
177            | DataType::Float64
178            | DataType::Decimal128(_, _)
179            | DataType::Decimal256(_, _)
180    )
181}
182
183fn create_modulo_scalar_function(
184    left: Arc<dyn PhysicalExpr>,
185    right: Arc<dyn PhysicalExpr>,
186    data_type: &DataType,
187    registry: &dyn FunctionRegistry,
188    fail_on_error: bool,
189) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> {
190    let func_name = "spark_modulo";
191    let modulo_expr =
192        create_comet_physical_fun(func_name, data_type.clone(), registry, Some(fail_on_error))?;
193    Ok(Arc::new(ScalarFunctionExpr::new(
194        func_name,
195        modulo_expr,
196        vec![left, right],
197        Arc::new(Field::new(func_name, data_type.clone(), true)),
198    )))
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{
        Array, ArrayRef, Decimal128Array, Decimal128Builder, Int32Array, PrimitiveArray,
        RecordBatch,
    };
    use datafusion::logical_expr::ColumnarValue;
    use datafusion::physical_expr::expressions::{Column, Literal};
    use datafusion::prelude::SessionContext;

    /// Runs `test_fn` twice: once with `fail_on_error = true` (ANSI mode) and once with
    /// `fail_on_error = false` (non-ANSI mode).
    fn with_fail_on_error<F: Fn(bool)>(test_fn: F) {
        for fail_on_error in [true, false] {
            test_fn(fail_on_error);
        }
    }

    /// Builds a schema of non-nullable `Int32` columns with the given names, in order.
    fn int32_schema(names: &[&str]) -> SchemaRef {
        let fields: Vec<Field> = names
            .iter()
            .map(|name| Field::new(*name, DataType::Int32, false))
            .collect();
        Arc::new(Schema::new(fields))
    }

    /// Calls `create_modulo_expr` with a fresh `SessionContext` state as the function registry.
    fn build_modulo_expr(
        left: Arc<dyn PhysicalExpr>,
        right: Arc<dyn PhysicalExpr>,
        data_type: DataType,
        schema: SchemaRef,
        fail_on_error: bool,
    ) -> Arc<dyn PhysicalExpr> {
        let session_ctx = SessionContext::new();
        let state = session_ctx.state();
        create_modulo_expr(left, right, data_type, schema, fail_on_error, &state)
            .expect("failed to create modulo expression")
    }

    /// Evaluates `expr` against `batch` and checks the outcome.
    ///
    /// When `should_fail` is `true`, evaluation must fail with Spark's `[DIVIDE_BY_ZERO]`
    /// error. Otherwise it must produce an array matching `expected_result` slot by slot:
    /// same nullity, and equal values where non-null.
    pub fn verify_result<T>(
        expr: Arc<dyn PhysicalExpr>,
        batch: RecordBatch,
        should_fail: bool,
        expected_result: Option<Arc<PrimitiveArray<T>>>,
    ) where
        T: ArrowPrimitiveType,
    {
        let actual_result = expr.evaluate(&batch);

        if should_fail {
            match actual_result {
                Err(error) => assert!(
                    error
                        .to_string()
                        .contains("[DIVIDE_BY_ZERO] Division by zero"),
                    "Error message did not match. Actual message: {error}"
                ),
                Ok(value) => panic!("Expected error, but got: {value:?}"),
            }
            return;
        }

        match (actual_result, expected_result) {
            (Ok(ColumnarValue::Array(ref actual)), Some(expected)) => {
                assert_eq!(actual.len(), expected.len(), "Array length mismatch");

                let actual_arr = actual
                    .as_any()
                    .downcast_ref::<PrimitiveArray<T>>()
                    .expect("result array has an unexpected type");

                // `expected` is already a `PrimitiveArray<T>`, so no downcast is needed.
                for (i, (actual_value, expected_value)) in
                    actual_arr.iter().zip(expected.iter()).enumerate()
                {
                    assert_eq!(
                        actual_value.is_none(),
                        expected_value.is_none(),
                        "Nullity mismatch at index {i}"
                    );
                    assert_eq!(
                        actual_value, expected_value,
                        "Mismatch at index {i}, actual {actual_value:?}, expected {expected_value:?}"
                    );
                }
            }
            (actual, expected) => {
                panic!("Actual: {actual:?}, expected: {expected:?}");
            }
        }
    }

    #[test]
    fn test_modulo_basic_int() {
        with_fail_on_error(|fail_on_error| {
            let schema = int32_schema(&["a", "b"]);

            let a_array = Arc::new(Int32Array::from(vec![3, 2, i32::MIN]));
            let b_array = Arc::new(Int32Array::from(vec![1, 5, -1]));
            let batch = RecordBatch::try_new(Arc::clone(&schema), vec![a_array, b_array]).unwrap();

            let modulo_expr = build_modulo_expr(
                Arc::new(Column::new("a", 0)),
                Arc::new(Column::new("b", 1)),
                DataType::Int32,
                schema,
                fail_on_error,
            );

            // No divisor is zero, so neither mode should fail.
            let expected_result = Arc::new(Int32Array::from(vec![0, 2, 0]));
            verify_result(modulo_expr, batch, false, Some(expected_result));
        })
    }

    #[test]
    fn test_modulo_basic_decimal() {
        with_fail_on_error(|fail_on_error| {
            let decimal_type = DataType::Decimal128(18, 4);
            let schema = Arc::new(Schema::new(vec![
                Field::new("a", decimal_type.clone(), false),
                Field::new("b", decimal_type.clone(), false),
            ]));

            // Unscaled values must fit the declared precision of 18 digits (i.e. be < 10^18).
            let mut a_builder =
                Decimal128Builder::with_capacity(2).with_data_type(decimal_type.clone());
            a_builder.append_value(300_000_000_000_000_000);
            a_builder.append_value(200_000_000_000_000_000);
            let a_array: ArrayRef = Arc::new(a_builder.finish());

            let mut b_builder =
                Decimal128Builder::with_capacity(2).with_data_type(decimal_type.clone());
            b_builder.append_value(100_000_000_000_000_000);
            b_builder.append_value(500_000_000_000_000_000);
            let b_array: ArrayRef = Arc::new(b_builder.finish());

            let batch = RecordBatch::try_new(Arc::clone(&schema), vec![a_array, b_array]).unwrap();

            let modulo_expr = build_modulo_expr(
                Arc::new(Column::new("a", 0)),
                Arc::new(Column::new("b", 1)),
                decimal_type,
                schema,
                fail_on_error,
            );

            // No divisor is zero, so neither mode should fail.
            let expected_result = Arc::new(Decimal128Array::from(vec![
                Some(0),
                Some(200_000_000_000_000_000),
            ]));
            verify_result(modulo_expr, batch, false, Some(expected_result));
        })
    }

    #[test]
    fn test_modulo_divide_by_zero_int() {
        with_fail_on_error(|fail_on_error| {
            let schema = int32_schema(&["a", "b"]);

            let a_array = Arc::new(Int32Array::from(vec![3]));
            let b_array = Arc::new(Int32Array::from(vec![0]));
            let batch = RecordBatch::try_new(Arc::clone(&schema), vec![a_array, b_array]).unwrap();

            let modulo_expr = build_modulo_expr(
                Arc::new(Column::new("a", 0)),
                Arc::new(Column::new("b", 1)),
                DataType::Int32,
                schema,
                fail_on_error,
            );

            // ANSI mode must fail with divide-by-zero; non-ANSI mode must yield NULL.
            let expected_result = Arc::new(Int32Array::from(vec![None]));
            verify_result(modulo_expr, batch, fail_on_error, Some(expected_result));
        })
    }

    #[test]
    fn test_division_by_zero_with_complex_int_expr() {
        with_fail_on_error(|fail_on_error| {
            let schema = int32_schema(&["a", "b", "c"]);

            let a_array = Arc::new(Int32Array::from(vec![3, 0]));
            let b_array = Arc::new(Int32Array::from(vec![2, 4]));
            let c_array = Arc::new(Int32Array::from(vec![4, 5]));
            let batch =
                RecordBatch::try_new(Arc::clone(&schema), vec![a_array, b_array, c_array]).unwrap();

            let left_expr = Arc::new(BinaryExpr::new(
                Arc::new(Column::new("a", 0)),
                Operator::Divide,
                Arc::new(Column::new("b", 1)),
            ));
            let right_expr = Arc::new(BinaryExpr::new(
                Arc::new(Literal::new(ScalarValue::Int32(Some(0)))),
                Operator::Divide,
                Arc::new(Column::new("c", 2)),
            ));

            // Computes (a / b) % (0 / c); the divisor is always zero.
            let modulo_expr = build_modulo_expr(
                left_expr,
                right_expr,
                DataType::Int32,
                schema,
                fail_on_error,
            );

            // ANSI mode must fail with divide-by-zero; non-ANSI mode must yield NULLs.
            let expected_result = Arc::new(Int32Array::from(vec![None, None]));
            verify_result(modulo_expr, batch, fail_on_error, Some(expected_result));
        })
    }
}