datafusion_comet_spark_expr/math_funcs/
div.rs1use crate::math_funcs::utils::get_precision_scale;
19use arrow::array::{Array, Decimal128Array};
20use arrow::datatypes::{DataType, DECIMAL128_MAX_PRECISION};
21use arrow::{
22 array::{ArrayRef, AsArray},
23 datatypes::Decimal128Type,
24};
25use datafusion::common::DataFusionError;
26use datafusion::physical_plan::ColumnarValue;
27use num::{BigInt, Signed, ToPrimitive};
28use std::sync::Arc;
29
30pub fn spark_decimal_div(
31 args: &[ColumnarValue],
32 data_type: &DataType,
33) -> Result<ColumnarValue, DataFusionError> {
34 spark_decimal_div_internal(args, data_type, false)
35}
36
37pub fn spark_decimal_integral_div(
38 args: &[ColumnarValue],
39 data_type: &DataType,
40) -> Result<ColumnarValue, DataFusionError> {
41 spark_decimal_div_internal(args, data_type, true)
42}
43
44fn spark_decimal_div_internal(
50 args: &[ColumnarValue],
51 data_type: &DataType,
52 is_integral_div: bool,
53) -> Result<ColumnarValue, DataFusionError> {
54 let left = &args[0];
55 let right = &args[1];
56 let (p3, s3) = get_precision_scale(data_type);
57
58 let (left, right): (ArrayRef, ArrayRef) = match (left, right) {
59 (ColumnarValue::Array(l), ColumnarValue::Array(r)) => (Arc::clone(l), Arc::clone(r)),
60 (ColumnarValue::Scalar(l), ColumnarValue::Array(r)) => {
61 (l.to_array_of_size(r.len())?, Arc::clone(r))
62 }
63 (ColumnarValue::Array(l), ColumnarValue::Scalar(r)) => {
64 (Arc::clone(l), r.to_array_of_size(l.len())?)
65 }
66 (ColumnarValue::Scalar(l), ColumnarValue::Scalar(r)) => (l.to_array()?, r.to_array()?),
67 };
68 let left = left.as_primitive::<Decimal128Type>();
69 let right = right.as_primitive::<Decimal128Type>();
70 let (p1, s1) = get_precision_scale(left.data_type());
71 let (p2, s2) = get_precision_scale(right.data_type());
72
73 let l_exp = ((s2 + s3 + 1) as u32).saturating_sub(s1 as u32);
74 let r_exp = (s1 as u32).saturating_sub((s2 + s3 + 1) as u32);
75 let result: Decimal128Array = if p1 as u32 + l_exp > DECIMAL128_MAX_PRECISION as u32
76 || p2 as u32 + r_exp > DECIMAL128_MAX_PRECISION as u32
77 {
78 let ten = BigInt::from(10);
79 let l_mul = ten.pow(l_exp);
80 let r_mul = ten.pow(r_exp);
81 let five = BigInt::from(5);
82 let zero = BigInt::from(0);
83 arrow::compute::kernels::arity::binary(left, right, |l, r| {
84 let l = BigInt::from(l) * &l_mul;
85 let r = BigInt::from(r) * &r_mul;
86 let div = if r.eq(&zero) { zero.clone() } else { &l / &r };
87 let res = if is_integral_div {
88 div
89 } else if div.is_negative() {
90 div - &five
91 } else {
92 div + &five
93 } / &ten;
94 res.to_i128().unwrap_or(i128::MAX)
95 })?
96 } else {
97 let l_mul = 10_i128.pow(l_exp);
98 let r_mul = 10_i128.pow(r_exp);
99 arrow::compute::kernels::arity::binary(left, right, |l, r| {
100 let l = l * l_mul;
101 let r = r * r_mul;
102 let div = if r == 0 { 0 } else { l / r };
103 let res = if is_integral_div {
104 div
105 } else if div.is_negative() {
106 div - 5
107 } else {
108 div + 5
109 } / 10;
110 res.to_i128().unwrap_or(i128::MAX)
111 })?
112 };
113 let result = result.with_data_type(DataType::Decimal128(p3, s3));
114 Ok(ColumnarValue::Array(Arc::new(result)))
115}