// datafusion_comet_spark_expr/math_funcs/internal/checkoverflow.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use arrow::datatypes::{DataType, Schema};
19use arrow::{
20    array::{as_primitive_array, Array, ArrayRef, Decimal128Array},
21    datatypes::{Decimal128Type, DecimalType},
22    record_batch::RecordBatch,
23};
24use datafusion::common::{DataFusionError, ScalarValue};
25use datafusion::logical_expr::ColumnarValue;
26use datafusion::physical_expr::PhysicalExpr;
27use std::hash::Hash;
28use std::{
29    any::Any,
30    fmt::{Display, Formatter},
31    sync::Arc,
32};
33
/// This is from Spark `CheckOverflow` expression. Spark `CheckOverflow` expression rounds decimals
/// to given scale and check if the decimals can fit in given precision. As `cast` kernel rounds
/// decimals already, Comet `CheckOverflow` expression only checks if the decimals can fit in the
/// precision.
#[derive(Debug, Eq)]
pub struct CheckOverflow {
    // Child expression producing the decimal values to be checked.
    pub child: Arc<dyn PhysicalExpr>,
    // Target type; must be `DataType::Decimal128(precision, scale)` —
    // `evaluate` returns an execution error for any other variant.
    pub data_type: DataType,
    // When true, a value exceeding the precision is an error;
    // when false, the overflowing value becomes null.
    pub fail_on_error: bool,
}
44
45impl Hash for CheckOverflow {
46    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
47        self.child.hash(state);
48        self.data_type.hash(state);
49        self.fail_on_error.hash(state);
50    }
51}
52
53impl PartialEq for CheckOverflow {
54    fn eq(&self, other: &Self) -> bool {
55        self.child.eq(&other.child)
56            && self.data_type.eq(&other.data_type)
57            && self.fail_on_error.eq(&other.fail_on_error)
58    }
59}
60
61impl CheckOverflow {
62    pub fn new(child: Arc<dyn PhysicalExpr>, data_type: DataType, fail_on_error: bool) -> Self {
63        Self {
64            child,
65            data_type,
66            fail_on_error,
67        }
68    }
69}
70
71impl Display for CheckOverflow {
72    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
73        write!(
74            f,
75            "CheckOverflow [datatype: {}, fail_on_error: {}, child: {}]",
76            self.data_type, self.fail_on_error, self.child
77        )
78    }
79}
80
impl PhysicalExpr for CheckOverflow {
    fn as_any(&self) -> &dyn Any {
        self
    }

    // SQL rendering of this expression is deliberately not implemented.
    fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
        unimplemented!()
    }

    // The output type is the configured target decimal type, regardless of
    // the input schema.
    fn data_type(&self, _: &Schema) -> datafusion::common::Result<DataType> {
        Ok(self.data_type.clone())
    }

    // Always nullable: in the non-failing mode an overflowing value is
    // replaced with null.
    fn nullable(&self, _: &Schema) -> datafusion::common::Result<bool> {
        Ok(true)
    }

    /// Evaluates the child and checks that each Decimal128 value fits in the
    /// target precision. With `fail_on_error` set, an overflowing value yields
    /// an error; otherwise it becomes null. Non-decimal inputs produce an
    /// execution error.
    fn evaluate(&self, batch: &RecordBatch) -> datafusion::common::Result<ColumnarValue> {
        let arg = self.child.evaluate(batch)?;
        match arg {
            // Array case: only Decimal128 arrays are accepted by this arm.
            ColumnarValue::Array(array)
                if matches!(array.data_type(), DataType::Decimal128(_, _)) =>
            {
                // The configured target type must also be Decimal128.
                let (precision, scale) = match &self.data_type {
                    DataType::Decimal128(p, s) => (p, s),
                    dt => {
                        return Err(DataFusionError::Execution(format!(
                            "CheckOverflow expects only Decimal128, but got {dt:?}"
                        )))
                    }
                };

                let decimal_array = as_primitive_array::<Decimal128Type>(&array);

                let casted_array = if self.fail_on_error {
                    // Returning error if overflow
                    decimal_array.validate_decimal_precision(*precision)?;
                    decimal_array
                } else {
                    // Overflowing gets null value
                    &decimal_array.null_if_overflow_precision(*precision)
                };

                // Rebuild the array so the target precision/scale is attached
                // to the output; `with_precision_and_scale` can itself fail.
                let new_array = Decimal128Array::from(casted_array.into_data())
                    .with_precision_and_scale(*precision, *scale)
                    .map(|a| Arc::new(a) as ArrayRef)?;

                Ok(ColumnarValue::Array(new_array))
            }
            ColumnarValue::Scalar(ScalarValue::Decimal128(v, precision, scale)) => {
                // `fail_on_error` is only true when ANSI is enabled, which we don't support yet
                // (Java side will simply fallback to Spark when it is enabled)
                assert!(
                    !self.fail_on_error,
                    "fail_on_error (ANSI mode) is not supported yet"
                );

                // NOTE: validates against the scalar's own declared precision
                // (from the ScalarValue), not `self.data_type`; an overflowing
                // value is nulled out.
                let new_v: Option<i128> = v.and_then(|v| {
                    Decimal128Type::validate_decimal_precision(v, precision)
                        .map(|_| v)
                        .ok()
                });

                Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
                    new_v, precision, scale,
                )))
            }
            // Any other columnar value (non-decimal array, other scalar kinds)
            // is unsupported.
            v => Err(DataFusionError::Execution(format!(
                "CheckOverflow's child expression should be decimal array, but found {v:?}"
            ))),
        }
    }

    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        vec![&self.child]
    }

    // Rebuilds the expression over a new child, preserving the target type
    // and the overflow behavior flag.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
        Ok(Arc::new(CheckOverflow::new(
            Arc::clone(&children[0]),
            self.data_type.clone(),
            self.fail_on_error,
        )))
    }
}