datafusion_comet_spark_expr/math_funcs/negative.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use crate::arithmetic_overflow_error;
19use crate::SparkError;
20use arrow::array::RecordBatch;
21use arrow::datatypes::IntervalDayTime;
22use arrow::datatypes::{DataType, Schema};
23use arrow::{compute::kernels::numeric::neg_wrapping, datatypes::IntervalDayTimeType};
24use datafusion::common::{DataFusionError, Result, ScalarValue};
25use datafusion::logical_expr::sort_properties::ExprProperties;
26use datafusion::{
27    logical_expr::{interval_arithmetic::Interval, ColumnarValue},
28    physical_expr::PhysicalExpr,
29};
30use std::fmt::Formatter;
31use std::hash::Hash;
32use std::{any::Any, sync::Arc};
33
34pub fn create_negate_expr(
35    expr: Arc<dyn PhysicalExpr>,
36    fail_on_error: bool,
37) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> {
38    Ok(Arc::new(NegativeExpr::new(expr, fail_on_error)))
39}
40
41/// Negative expression
42#[derive(Debug, Eq)]
43pub struct NegativeExpr {
44    /// Input expression
45    arg: Arc<dyn PhysicalExpr>,
46    fail_on_error: bool,
47}
48
49impl Hash for NegativeExpr {
50    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
51        self.arg.hash(state);
52        self.fail_on_error.hash(state);
53    }
54}
55
56impl PartialEq for NegativeExpr {
57    fn eq(&self, other: &Self) -> bool {
58        self.arg.eq(&other.arg) && self.fail_on_error.eq(&other.fail_on_error)
59    }
60}
61
/// Returns an `ARITHMETIC_OVERFLOW` error from the enclosing function if any non-null
/// element of `$array` (downcast to `$array_type`) equals `$min_val`, a value whose
/// negation overflows.
///
/// For `"byte"` and `"short"` the message names the offending value (e.g. `-128 caused`);
/// other types report `$type_name`. Null slots are skipped: Arrow leaves the values behind
/// null entries unspecified, so they must not trigger an error.
macro_rules! check_overflow {
    ($array:expr, $array_type:ty, $min_val:expr, $type_name:expr) => {{
        let typed_array = $array
            .as_any()
            .downcast_ref::<$array_type>()
            .expect(concat!(stringify!($array_type), " expected"));
        // `iter()` yields `None` for null slots, so `flatten` visits only valid values.
        let overflows = typed_array.iter().flatten().any(|value| value == $min_val);
        if overflows {
            if $type_name == "byte" || $type_name == "short" {
                let value = format!("{:?} caused", $min_val);
                return Err(arithmetic_overflow_error(value.as_str()).into());
            }
            return Err(arithmetic_overflow_error($type_name).into());
        }
    }};
}
79
80impl NegativeExpr {
81    /// Create new not expression
82    pub fn new(arg: Arc<dyn PhysicalExpr>, fail_on_error: bool) -> Self {
83        Self { arg, fail_on_error }
84    }
85
86    /// Get the input expression
87    pub fn arg(&self) -> &Arc<dyn PhysicalExpr> {
88        &self.arg
89    }
90}
91
92impl std::fmt::Display for NegativeExpr {
93    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
94        write!(f, "(- {})", self.arg)
95    }
96}
97
98impl PhysicalExpr for NegativeExpr {
99    /// Return a reference to Any that can be used for downcasting
100    fn as_any(&self) -> &dyn Any {
101        self
102    }
103
104    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
105        self.arg.data_type(input_schema)
106    }
107
108    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
109        self.arg.nullable(input_schema)
110    }
111
112    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
113        let arg = self.arg.evaluate(batch)?;
114
115        // overflow checks only apply in ANSI mode
116        // datatypes supported are byte, short, integer, long, float, interval
117        match arg {
118            ColumnarValue::Array(array) => {
119                if self.fail_on_error {
120                    match array.data_type() {
121                        DataType::Int8 => {
122                            check_overflow!(array, arrow::array::Int8Array, i8::MIN, "byte")
123                        }
124                        DataType::Int16 => {
125                            check_overflow!(array, arrow::array::Int16Array, i16::MIN, "short")
126                        }
127                        DataType::Int32 => {
128                            check_overflow!(array, arrow::array::Int32Array, i32::MIN, "integer")
129                        }
130                        DataType::Int64 => {
131                            check_overflow!(array, arrow::array::Int64Array, i64::MIN, "long")
132                        }
133                        DataType::Interval(value) => match value {
134                            arrow::datatypes::IntervalUnit::YearMonth => check_overflow!(
135                                array,
136                                arrow::array::IntervalYearMonthArray,
137                                i32::MIN,
138                                "interval"
139                            ),
140                            arrow::datatypes::IntervalUnit::DayTime => check_overflow!(
141                                array,
142                                arrow::array::IntervalDayTimeArray,
143                                IntervalDayTime::MIN,
144                                "interval"
145                            ),
146                            arrow::datatypes::IntervalUnit::MonthDayNano => {
147                                // Overflow checks are not supported
148                            }
149                        },
150                        _ => {
151                            // Overflow checks are not supported for other datatypes
152                        }
153                    }
154                }
155                let result = neg_wrapping(array.as_ref())?;
156                Ok(ColumnarValue::Array(result))
157            }
158            ColumnarValue::Scalar(scalar) => {
159                if self.fail_on_error {
160                    match scalar {
161                        ScalarValue::Int8(value) => {
162                            if value == Some(i8::MIN) {
163                                return Err(arithmetic_overflow_error(" caused").into());
164                            }
165                        }
166                        ScalarValue::Int16(value) => {
167                            if value == Some(i16::MIN) {
168                                return Err(arithmetic_overflow_error(" caused").into());
169                            }
170                        }
171                        ScalarValue::Int32(value) => {
172                            if value == Some(i32::MIN) {
173                                return Err(arithmetic_overflow_error("integer").into());
174                            }
175                        }
176                        ScalarValue::Int64(value) => {
177                            if value == Some(i64::MIN) {
178                                return Err(arithmetic_overflow_error("long").into());
179                            }
180                        }
181                        ScalarValue::IntervalDayTime(value) => {
182                            let (days, ms) =
183                                IntervalDayTimeType::to_parts(value.unwrap_or_default());
184                            if days == i32::MIN || ms == i32::MIN {
185                                return Err(arithmetic_overflow_error("interval").into());
186                            }
187                        }
188                        ScalarValue::IntervalYearMonth(value) => {
189                            if value == Some(i32::MIN) {
190                                return Err(arithmetic_overflow_error("interval").into());
191                            }
192                        }
193                        _ => {
194                            // Overflow checks are not supported for other datatypes
195                        }
196                    }
197                }
198                Ok(ColumnarValue::Scalar((scalar.arithmetic_negate())?))
199            }
200        }
201    }
202
203    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
204        vec![&self.arg]
205    }
206
207    fn with_new_children(
208        self: Arc<Self>,
209        children: Vec<Arc<dyn PhysicalExpr>>,
210    ) -> Result<Arc<dyn PhysicalExpr>> {
211        Ok(Arc::new(NegativeExpr::new(
212            Arc::clone(&children[0]),
213            self.fail_on_error,
214        )))
215    }
216
217    /// Given the child interval of a NegativeExpr, it calculates the NegativeExpr's interval.
218    /// It replaces the upper and lower bounds after multiplying them with -1.
219    /// Ex: `(a, b]` => `[-b, -a)`
220    fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
221        Interval::try_new(
222            children[0].upper().arithmetic_negate()?,
223            children[0].lower().arithmetic_negate()?,
224        )
225    }
226
227    /// Returns a new [`Interval`] of a NegativeExpr  that has the existing `interval` given that
228    /// given the input interval is known to be `children`.
229    fn propagate_constraints(
230        &self,
231        interval: &Interval,
232        children: &[&Interval],
233    ) -> Result<Option<Vec<Interval>>> {
234        let child_interval = children[0];
235
236        if child_interval.lower() == &ScalarValue::Int32(Some(i32::MIN))
237            || child_interval.upper() == &ScalarValue::Int32(Some(i32::MIN))
238            || child_interval.lower() == &ScalarValue::Int64(Some(i64::MIN))
239            || child_interval.upper() == &ScalarValue::Int64(Some(i64::MIN))
240        {
241            return Err(SparkError::ArithmeticOverflow {
242                from_type: "long".to_string(),
243            }
244            .into());
245        }
246
247        let negated_interval = Interval::try_new(
248            interval.upper().arithmetic_negate()?,
249            interval.lower().arithmetic_negate()?,
250        )?;
251
252        Ok(child_interval
253            .intersect(negated_interval)?
254            .map(|result| vec![result]))
255    }
256
257    /// The ordering of a [`NegativeExpr`] is simply the reverse of its child.
258    fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
259        let properties = children[0].clone().with_order(children[0].sort_properties);
260        Ok(properties)
261    }
262
263    fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
264        unimplemented!()
265    }
266}