datafusion_comet_spark_expr/math_funcs/div.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::math_funcs::utils::get_precision_scale;
use arrow::array::{Array, Decimal128Array};
use arrow::datatypes::{DataType, DECIMAL128_MAX_PRECISION};
use arrow::{
    array::{ArrayRef, AsArray},
    datatypes::Decimal128Type,
};
use datafusion::common::DataFusionError;
use datafusion::physical_plan::ColumnarValue;
use num::{BigInt, Signed, ToPrimitive};
use std::sync::Arc;

/// Spark-compatible `Decimal128` division: computes `args[0] / args[1]` and rounds the extra
/// digit of the quotient half away from zero to fit the precision and scale of `data_type`.
pub fn spark_decimal_div(
    args: &[ColumnarValue],
    data_type: &DataType,
) -> Result<ColumnarValue, DataFusionError> {
    spark_decimal_div_internal(args, data_type, false)
}

/// Spark-compatible integral division of `Decimal128` values: returns the quotient truncated
/// toward zero, with no rounding of the final digit.
pub fn spark_decimal_integral_div(
    args: &[ColumnarValue],
    data_type: &DataType,
) -> Result<ColumnarValue, DataFusionError> {
    spark_decimal_div_internal(args, data_type, true)
}

// Let Decimal(p3, s3) be the return type, i.e. Decimal(p1, s1) / Decimal(p2, s2) = Decimal(p3, s3).
// Conversely, Decimal(p1, s1) = Decimal(p2, s2) * Decimal(p3, s3). This means that, to get enough
// scale to match Spark's behavior, we need to widen s1 to s2 + s3 + 1. Since both s2 and s3 can be
// at most 38, the widened s1 can reach 77, which is larger than Decimal256Type::MAX_SCALE and more
// than DataFusion's decimal division can handle. Therefore, this decimal division is implemented
// using BigInt.
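//
// A hypothetical worked example (the numbers are chosen purely for illustration): dividing
// Decimal(38, 10) by Decimal(38, 10) into Decimal(38, 6) requires widening the dividend's scale
// to s2 + s3 + 1 = 17, i.e. multiplying it by 10^(17 - 10) = 10^7. With p1 = 38 the scaled value
// can no longer be represented in an i128, which is exactly the case routed to the BigInt branch
// below.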
fn spark_decimal_div_internal(
    args: &[ColumnarValue],
    data_type: &DataType,
    is_integral_div: bool,
) -> Result<ColumnarValue, DataFusionError> {
    let left = &args[0];
    let right = &args[1];
    let (p3, s3) = get_precision_scale(data_type);

    // Normalize the two inputs to Decimal128 arrays of equal length.
    let (left, right): (ArrayRef, ArrayRef) = match (left, right) {
        (ColumnarValue::Array(l), ColumnarValue::Array(r)) => (Arc::clone(l), Arc::clone(r)),
        (ColumnarValue::Scalar(l), ColumnarValue::Array(r)) => {
            (l.to_array_of_size(r.len())?, Arc::clone(r))
        }
        (ColumnarValue::Array(l), ColumnarValue::Scalar(r)) => {
            (Arc::clone(l), r.to_array_of_size(l.len())?)
        }
        (ColumnarValue::Scalar(l), ColumnarValue::Scalar(r)) => (l.to_array()?, r.to_array()?),
    };
    let left = left.as_primitive::<Decimal128Type>();
    let right = right.as_primitive::<Decimal128Type>();
    let (p1, s1) = get_precision_scale(left.data_type());
    let (p2, s2) = get_precision_scale(right.data_type());

    // Scale the operands so that the integer quotient carries scale s3 + 1, i.e. one extra digit
    // used for rounding: widen the dividend to scale s2 + s3 + 1, or, if s1 is already larger
    // than that, widen the divisor instead.
    let l_exp = ((s2 + s3 + 1) as u32).saturating_sub(s1 as u32);
    let r_exp = (s1 as u32).saturating_sub((s2 + s3 + 1) as u32);
    let result: Decimal128Array = if p1 as u32 + l_exp > DECIMAL128_MAX_PRECISION as u32
        || p2 as u32 + r_exp > DECIMAL128_MAX_PRECISION as u32
    {
        // The scaled operands may overflow an i128, so do the arithmetic with BigInt and
        // saturate to i128::MAX if the final result still does not fit.
        let ten = BigInt::from(10);
        let l_mul = ten.pow(l_exp);
        let r_mul = ten.pow(r_exp);
        let five = BigInt::from(5);
        let zero = BigInt::from(0);
        arrow::compute::kernels::arity::binary(left, right, |l, r| {
            let l = BigInt::from(l) * &l_mul;
            let r = BigInt::from(r) * &r_mul;
            // Guard against division by zero inside the kernel.
            let div = if r.eq(&zero) { zero.clone() } else { &l / &r };
            // The extra scale digit is dropped by the final division by ten; integral division
            // simply truncates, otherwise adding or subtracting 5 first rounds that digit half
            // away from zero.
            let res = if is_integral_div {
                div
            } else if div.is_negative() {
                div - &five
            } else {
                div + &five
            } / &ten;
            res.to_i128().unwrap_or(i128::MAX)
        })?
    } else {
        // Fast path: the scaled operands fit in an i128, so use plain integer arithmetic with
        // the same rounding rules as the BigInt branch above.
        let l_mul = 10_i128.pow(l_exp);
        let r_mul = 10_i128.pow(r_exp);
        arrow::compute::kernels::arity::binary(left, right, |l, r| {
            let l = l * l_mul;
            let r = r * r_mul;
            let div = if r == 0 { 0 } else { l / r };
            let res = if is_integral_div {
                div
            } else if div.is_negative() {
                div - 5
            } else {
                div + 5
            } / 10;
            res.to_i128().unwrap_or(i128::MAX)
        })?
    };
    let result = result.with_data_type(DataType::Decimal128(p3, s3));
    Ok(ColumnarValue::Array(Arc::new(result)))
}
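
// A minimal illustrative test sketch (the values, precisions, and scales below are arbitrary
// assumptions chosen for clarity, not taken from the Spark test suite): it exercises the
// non-BigInt fast path of spark_decimal_div and checks the half-away-from-zero rounding of the
// final digit.
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Decimal128Array;

    #[test]
    fn decimal_div_rounds_half_away_from_zero() -> Result<(), DataFusionError> {
        // 1.00 / 3.00 as Decimal128(10, 2) operands with a Decimal128(10, 4) result type:
        // the exact quotient 0.3333... should round to 0.3333 (unscaled value 3333).
        let left = Decimal128Array::from(vec![100_i128])
            .with_precision_and_scale(10, 2)
            .unwrap();
        let right = Decimal128Array::from(vec![300_i128])
            .with_precision_and_scale(10, 2)
            .unwrap();
        let args = [
            ColumnarValue::Array(Arc::new(left)),
            ColumnarValue::Array(Arc::new(right)),
        ];
        let result = spark_decimal_div(&args, &DataType::Decimal128(10, 4))?;
        let ColumnarValue::Array(result) = result else {
            unreachable!("spark_decimal_div always returns an array")
        };
        assert_eq!(result.as_primitive::<Decimal128Type>().value(0), 3333);
        Ok(())
    }
}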