datafusion_comet_spark_expr/math_funcs/internal/checkoverflow.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::datatypes::{DataType, Schema};
use arrow::{
    array::{as_primitive_array, Array, ArrayRef, Decimal128Array},
    datatypes::{Decimal128Type, DecimalType},
    record_batch::RecordBatch,
};
use datafusion::common::{DataFusionError, ScalarValue};
use datafusion::logical_expr::ColumnarValue;
use datafusion::physical_expr::PhysicalExpr;
use std::hash::Hash;
use std::{
    any::Any,
    fmt::{Display, Formatter},
    sync::Arc,
};

/// Comet's implementation of the Spark `CheckOverflow` expression. Spark's `CheckOverflow`
/// rounds decimals to a given scale and checks whether they fit in a given precision. Since the
/// `cast` kernel already rounds decimals, Comet's `CheckOverflow` only checks whether the
/// decimals fit in the precision.
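///
/// An illustrative usage sketch (not from the upstream source and not compiled as a doctest;
/// `child` is a placeholder for any [`PhysicalExpr`] that produces `Decimal128` values):
///
/// ```ignore
/// // Null out (rather than reject) values that do not fit in Decimal128(5, 2).
/// let expr = CheckOverflow::new(child, DataType::Decimal128(5, 2), false);
/// ```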
#[derive(Debug, Eq)]
pub struct CheckOverflow {
    /// The child expression that produces the decimal values to check.
    pub child: Arc<dyn PhysicalExpr>,
    /// The target type; must be `DataType::Decimal128(precision, scale)`.
    pub data_type: DataType,
    /// When true (Spark ANSI mode), overflow raises an error instead of producing null.
    pub fail_on_error: bool,
}

impl Hash for CheckOverflow {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.child.hash(state);
        self.data_type.hash(state);
        self.fail_on_error.hash(state);
    }
}

impl PartialEq for CheckOverflow {
    fn eq(&self, other: &Self) -> bool {
        self.child.eq(&other.child)
            && self.data_type.eq(&other.data_type)
            && self.fail_on_error.eq(&other.fail_on_error)
    }
}

impl CheckOverflow {
    pub fn new(child: Arc<dyn PhysicalExpr>, data_type: DataType, fail_on_error: bool) -> Self {
        Self {
            child,
            data_type,
            fail_on_error,
        }
    }
}

impl Display for CheckOverflow {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "CheckOverflow [datatype: {}, fail_on_error: {}, child: {}]",
            self.data_type, self.fail_on_error, self.child
        )
    }
}

impl PhysicalExpr for CheckOverflow {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
        unimplemented!()
    }

    fn data_type(&self, _: &Schema) -> datafusion::common::Result<DataType> {
        Ok(self.data_type.clone())
    }

    fn nullable(&self, _: &Schema) -> datafusion::common::Result<bool> {
        Ok(true)
    }

    fn evaluate(&self, batch: &RecordBatch) -> datafusion::common::Result<ColumnarValue> {
        let arg = self.child.evaluate(batch)?;
        match arg {
            ColumnarValue::Array(array)
                if matches!(array.data_type(), DataType::Decimal128(_, _)) =>
            {
                let (precision, scale) = match &self.data_type {
                    DataType::Decimal128(p, s) => (p, s),
                    dt => {
                        return Err(DataFusionError::Execution(format!(
                            "CheckOverflow expects only Decimal128, but got {:?}",
                            dt
                        )))
                    }
                };

                let decimal_array = as_primitive_array::<Decimal128Type>(&array);

                let casted_array = if self.fail_on_error {
                    // Return an error if any value overflows the target precision
                    decimal_array.validate_decimal_precision(*precision)?;
                    decimal_array
                } else {
                    // Values that overflow the target precision become null
                    &decimal_array.null_if_overflow_precision(*precision)
                };

                let new_array = Decimal128Array::from(casted_array.into_data())
                    .with_precision_and_scale(*precision, *scale)
                    .map(|a| Arc::new(a) as ArrayRef)?;

                Ok(ColumnarValue::Array(new_array))
            }
            ColumnarValue::Scalar(ScalarValue::Decimal128(v, precision, scale)) => {
                // `fail_on_error` is only true when ANSI mode is enabled, which we don't support yet
                // (the Java side simply falls back to Spark when it is enabled)
                assert!(
                    !self.fail_on_error,
                    "fail_on_error (ANSI mode) is not supported yet"
                );

                let new_v: Option<i128> = v.and_then(|v| {
                    Decimal128Type::validate_decimal_precision(v, precision)
                        .map(|_| v)
                        .ok()
                });

                Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
                    new_v, precision, scale,
                )))
            }
            v => Err(DataFusionError::Execution(format!(
                "CheckOverflow's child expression should be a decimal array, but found {:?}",
                v
            ))),
        }
    }

    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        vec![&self.child]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
        Ok(Arc::new(CheckOverflow::new(
            Arc::clone(&children[0]),
            self.data_type.clone(),
            self.fail_on_error,
        )))
    }
}
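
// A minimal test sketch added for illustration (not part of the upstream file): it exercises the
// documented behavior that, with `fail_on_error == false`, Decimal128 values exceeding the target
// precision are replaced with null while fitting values are preserved. It assumes the standard
// arrow/datafusion building blocks (`Field`, `RecordBatch::try_new`, the `Column` physical
// expression, and `ColumnarValue::into_array`).
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::Field;
    use datafusion::physical_expr::expressions::Column;

    #[test]
    fn overflow_becomes_null_when_fail_on_error_is_false() -> datafusion::common::Result<()> {
        // Input column typed as Decimal128(10, 2): raw values 12345 (= 123.45) and
        // 99_999_999 (= 999999.99), plus a null.
        let input = Decimal128Array::from(vec![Some(12345_i128), Some(99_999_999_i128), None])
            .with_precision_and_scale(10, 2)?;
        let schema = Arc::new(Schema::new(vec![Field::new(
            "c0",
            DataType::Decimal128(10, 2),
            true,
        )]));
        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(input) as ArrayRef])?;

        // Target type Decimal128(5, 2): 123.45 fits, 999999.99 overflows the precision.
        let expr = CheckOverflow::new(
            Arc::new(Column::new("c0", 0)),
            DataType::Decimal128(5, 2),
            false,
        );

        let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
        let result = as_primitive_array::<Decimal128Type>(&result);

        assert_eq!(result.value(0), 12345_i128); // fits -> unchanged
        assert!(result.is_null(1)); // overflows -> null
        assert!(result.is_null(2)); // input null stays null
        Ok(())
    }
}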