datafusion_comet_spark_expr/math_funcs/
negative.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::arithmetic_overflow_error;
19use crate::SparkError;
20use arrow::{compute::kernels::numeric::neg_wrapping, datatypes::IntervalDayTimeType};
21use arrow_array::RecordBatch;
22use arrow_buffer::IntervalDayTime;
23use arrow_schema::{DataType, Schema};
24use datafusion::{
25    logical_expr::{interval_arithmetic::Interval, ColumnarValue},
26    physical_expr::PhysicalExpr,
27};
28use datafusion_common::{DataFusionError, Result, ScalarValue};
29use datafusion_expr::sort_properties::ExprProperties;
30use std::hash::Hash;
31use std::{any::Any, sync::Arc};
32
33pub fn create_negate_expr(
34    expr: Arc<dyn PhysicalExpr>,
35    fail_on_error: bool,
36) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> {
37    Ok(Arc::new(NegativeExpr::new(expr, fail_on_error)))
38}
39
/// Negative expression: a physical expression that arithmetically negates the
/// value produced by its input expression.
#[derive(Debug, Eq)]
pub struct NegativeExpr {
    /// Input expression
    arg: Arc<dyn PhysicalExpr>,
    /// When `true`, `evaluate` inspects the input for values whose negation
    /// would overflow (for the types it knows how to check) and returns an
    /// arithmetic-overflow error instead of a result.
    fail_on_error: bool,
}
47
48impl Hash for NegativeExpr {
49    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
50        self.arg.hash(state);
51        self.fail_on_error.hash(state);
52    }
53}
54
55impl PartialEq for NegativeExpr {
56    fn eq(&self, other: &Self) -> bool {
57        self.arg.eq(&other.arg) && self.fail_on_error.eq(&other.fail_on_error)
58    }
59}
60
/// Returns an arithmetic-overflow error from the *enclosing function* if
/// `$array` contains a non-null element equal to `$min_val`.
///
/// * `$array` - an array reference that can be downcast via `as_any()`.
/// * `$array_type` - the concrete array type to downcast to; the macro panics
///   if the downcast fails, so callers must dispatch on the data type first.
/// * `$min_val` - the value whose negation overflows (e.g. `i32::MIN`).
/// * `$type_name` - type label used in the error. For `"byte"` and `"short"`
///   the error carries `"<value> caused"` instead of the label.
///
/// Null slots are skipped: the value stored behind a null is arbitrary and
/// must not trigger an overflow error.
macro_rules! check_overflow {
    ($array:expr, $array_type:ty, $min_val:expr, $type_name:expr) => {{
        let typed_array = $array
            .as_any()
            .downcast_ref::<$array_type>()
            .expect(concat!(stringify!($array_type), " expected"));
        // `iter()` yields `None` for null slots, so `flatten()` visits only
        // valid values.
        for value in typed_array.iter().flatten() {
            if value == $min_val {
                if $type_name == "byte" || $type_name == "short" {
                    let value = format!("{:?} caused", value);
                    return Err(arithmetic_overflow_error(value.as_str()).into());
                }
                return Err(arithmetic_overflow_error($type_name).into());
            }
        }
    }};
}
78
79impl NegativeExpr {
80    /// Create new not expression
81    pub fn new(arg: Arc<dyn PhysicalExpr>, fail_on_error: bool) -> Self {
82        Self { arg, fail_on_error }
83    }
84
85    /// Get the input expression
86    pub fn arg(&self) -> &Arc<dyn PhysicalExpr> {
87        &self.arg
88    }
89}
90
91impl std::fmt::Display for NegativeExpr {
92    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
93        write!(f, "(- {})", self.arg)
94    }
95}
96
97impl PhysicalExpr for NegativeExpr {
98    /// Return a reference to Any that can be used for downcasting
99    fn as_any(&self) -> &dyn Any {
100        self
101    }
102
103    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
104        self.arg.data_type(input_schema)
105    }
106
107    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
108        self.arg.nullable(input_schema)
109    }
110
111    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
112        let arg = self.arg.evaluate(batch)?;
113
114        // overflow checks only apply in ANSI mode
115        // datatypes supported are byte, short, integer, long, float, interval
116        match arg {
117            ColumnarValue::Array(array) => {
118                if self.fail_on_error {
119                    match array.data_type() {
120                        DataType::Int8 => {
121                            check_overflow!(array, arrow::array::Int8Array, i8::MIN, "byte")
122                        }
123                        DataType::Int16 => {
124                            check_overflow!(array, arrow::array::Int16Array, i16::MIN, "short")
125                        }
126                        DataType::Int32 => {
127                            check_overflow!(array, arrow::array::Int32Array, i32::MIN, "integer")
128                        }
129                        DataType::Int64 => {
130                            check_overflow!(array, arrow::array::Int64Array, i64::MIN, "long")
131                        }
132                        DataType::Interval(value) => match value {
133                            arrow::datatypes::IntervalUnit::YearMonth => check_overflow!(
134                                array,
135                                arrow::array::IntervalYearMonthArray,
136                                i32::MIN,
137                                "interval"
138                            ),
139                            arrow::datatypes::IntervalUnit::DayTime => check_overflow!(
140                                array,
141                                arrow::array::IntervalDayTimeArray,
142                                IntervalDayTime::MIN,
143                                "interval"
144                            ),
145                            arrow::datatypes::IntervalUnit::MonthDayNano => {
146                                // Overflow checks are not supported
147                            }
148                        },
149                        _ => {
150                            // Overflow checks are not supported for other datatypes
151                        }
152                    }
153                }
154                let result = neg_wrapping(array.as_ref())?;
155                Ok(ColumnarValue::Array(result))
156            }
157            ColumnarValue::Scalar(scalar) => {
158                if self.fail_on_error {
159                    match scalar {
160                        ScalarValue::Int8(value) => {
161                            if value == Some(i8::MIN) {
162                                return Err(arithmetic_overflow_error(" caused").into());
163                            }
164                        }
165                        ScalarValue::Int16(value) => {
166                            if value == Some(i16::MIN) {
167                                return Err(arithmetic_overflow_error(" caused").into());
168                            }
169                        }
170                        ScalarValue::Int32(value) => {
171                            if value == Some(i32::MIN) {
172                                return Err(arithmetic_overflow_error("integer").into());
173                            }
174                        }
175                        ScalarValue::Int64(value) => {
176                            if value == Some(i64::MIN) {
177                                return Err(arithmetic_overflow_error("long").into());
178                            }
179                        }
180                        ScalarValue::IntervalDayTime(value) => {
181                            let (days, ms) =
182                                IntervalDayTimeType::to_parts(value.unwrap_or_default());
183                            if days == i32::MIN || ms == i32::MIN {
184                                return Err(arithmetic_overflow_error("interval").into());
185                            }
186                        }
187                        ScalarValue::IntervalYearMonth(value) => {
188                            if value == Some(i32::MIN) {
189                                return Err(arithmetic_overflow_error("interval").into());
190                            }
191                        }
192                        _ => {
193                            // Overflow checks are not supported for other datatypes
194                        }
195                    }
196                }
197                Ok(ColumnarValue::Scalar((scalar.arithmetic_negate())?))
198            }
199        }
200    }
201
202    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
203        vec![&self.arg]
204    }
205
206    fn with_new_children(
207        self: Arc<Self>,
208        children: Vec<Arc<dyn PhysicalExpr>>,
209    ) -> Result<Arc<dyn PhysicalExpr>> {
210        Ok(Arc::new(NegativeExpr::new(
211            Arc::clone(&children[0]),
212            self.fail_on_error,
213        )))
214    }
215
216    /// Given the child interval of a NegativeExpr, it calculates the NegativeExpr's interval.
217    /// It replaces the upper and lower bounds after multiplying them with -1.
218    /// Ex: `(a, b]` => `[-b, -a)`
219    fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
220        Interval::try_new(
221            children[0].upper().arithmetic_negate()?,
222            children[0].lower().arithmetic_negate()?,
223        )
224    }
225
226    /// Returns a new [`Interval`] of a NegativeExpr  that has the existing `interval` given that
227    /// given the input interval is known to be `children`.
228    fn propagate_constraints(
229        &self,
230        interval: &Interval,
231        children: &[&Interval],
232    ) -> Result<Option<Vec<Interval>>> {
233        let child_interval = children[0];
234
235        if child_interval.lower() == &ScalarValue::Int32(Some(i32::MIN))
236            || child_interval.upper() == &ScalarValue::Int32(Some(i32::MIN))
237            || child_interval.lower() == &ScalarValue::Int64(Some(i64::MIN))
238            || child_interval.upper() == &ScalarValue::Int64(Some(i64::MIN))
239        {
240            return Err(SparkError::ArithmeticOverflow {
241                from_type: "long".to_string(),
242            }
243            .into());
244        }
245
246        let negated_interval = Interval::try_new(
247            interval.upper().arithmetic_negate()?,
248            interval.lower().arithmetic_negate()?,
249        )?;
250
251        Ok(child_interval
252            .intersect(negated_interval)?
253            .map(|result| vec![result]))
254    }
255
256    /// The ordering of a [`NegativeExpr`] is simply the reverse of its child.
257    fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
258        let properties = children[0].clone().with_order(children[0].sort_properties);
259        Ok(properties)
260    }
261}