datafusion_comet_spark_expr/math_funcs/
negative.rs1use crate::arithmetic_overflow_error;
19use crate::SparkError;
20use arrow::{compute::kernels::numeric::neg_wrapping, datatypes::IntervalDayTimeType};
21use arrow_array::RecordBatch;
22use arrow_buffer::IntervalDayTime;
23use arrow_schema::{DataType, Schema};
24use datafusion::{
25 logical_expr::{interval_arithmetic::Interval, ColumnarValue},
26 physical_expr::PhysicalExpr,
27};
28use datafusion_common::{DataFusionError, Result, ScalarValue};
29use datafusion_expr::sort_properties::ExprProperties;
30use std::hash::Hash;
31use std::{any::Any, sync::Arc};
32
33pub fn create_negate_expr(
34 expr: Arc<dyn PhysicalExpr>,
35 fail_on_error: bool,
36) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> {
37 Ok(Arc::new(NegativeExpr::new(expr, fail_on_error)))
38}
39
40#[derive(Debug, Eq)]
42pub struct NegativeExpr {
43 arg: Arc<dyn PhysicalExpr>,
45 fail_on_error: bool,
46}
47
48impl Hash for NegativeExpr {
49 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
50 self.arg.hash(state);
51 self.fail_on_error.hash(state);
52 }
53}
54
55impl PartialEq for NegativeExpr {
56 fn eq(&self, other: &Self) -> bool {
57 self.arg.eq(&other.arg) && self.fail_on_error.eq(&other.fail_on_error)
58 }
59}
60
/// Returns an arithmetic-overflow error from the *enclosing function* if any
/// non-null element of `$array` equals `$min_val`.
///
/// * `$array`      - a value exposing `as_any()` (e.g. an `ArrayRef`).
/// * `$array_type` - the concrete array type to downcast to; the macro panics
///   via `expect` if the downcast fails.
/// * `$min_val`    - the value whose negation would overflow.
/// * `$type_name`  - name used in the error. For `"byte"` and `"short"` the
///   error text is built from the offending value (`"<value> caused"`)
///   instead of the type name.
///
/// The macro expands to a block of type `()`; on overflow it executes
/// `return Err(..)`, so it may only be used inside a function whose return
/// type can be built from the overflow error via `.into()`.
macro_rules! check_overflow {
    ($array:expr, $array_type:ty, $min_val:expr, $type_name:expr) => {{
        let typed_array = $array
            .as_any()
            .downcast_ref::<$array_type>()
            .expect(concat!(stringify!($array_type), " expected"));
        // `iter()` yields `None` for null slots and `flatten()` drops them.
        // The backing value of a null slot is unspecified, so inspecting it
        // could raise a spurious overflow error for a row that is NULL.
        for value in typed_array.iter().flatten() {
            if value == $min_val {
                if $type_name == "byte" || $type_name == "short" {
                    let message = format!("{:?} caused", value);
                    return Err(arithmetic_overflow_error(message.as_str()).into());
                }
                return Err(arithmetic_overflow_error($type_name).into());
            }
        }
    }};
}
78
79impl NegativeExpr {
80 pub fn new(arg: Arc<dyn PhysicalExpr>, fail_on_error: bool) -> Self {
82 Self { arg, fail_on_error }
83 }
84
85 pub fn arg(&self) -> &Arc<dyn PhysicalExpr> {
87 &self.arg
88 }
89}
90
91impl std::fmt::Display for NegativeExpr {
92 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
93 write!(f, "(- {})", self.arg)
94 }
95}
96
97impl PhysicalExpr for NegativeExpr {
98 fn as_any(&self) -> &dyn Any {
100 self
101 }
102
103 fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
104 self.arg.data_type(input_schema)
105 }
106
107 fn nullable(&self, input_schema: &Schema) -> Result<bool> {
108 self.arg.nullable(input_schema)
109 }
110
111 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
112 let arg = self.arg.evaluate(batch)?;
113
114 match arg {
117 ColumnarValue::Array(array) => {
118 if self.fail_on_error {
119 match array.data_type() {
120 DataType::Int8 => {
121 check_overflow!(array, arrow::array::Int8Array, i8::MIN, "byte")
122 }
123 DataType::Int16 => {
124 check_overflow!(array, arrow::array::Int16Array, i16::MIN, "short")
125 }
126 DataType::Int32 => {
127 check_overflow!(array, arrow::array::Int32Array, i32::MIN, "integer")
128 }
129 DataType::Int64 => {
130 check_overflow!(array, arrow::array::Int64Array, i64::MIN, "long")
131 }
132 DataType::Interval(value) => match value {
133 arrow::datatypes::IntervalUnit::YearMonth => check_overflow!(
134 array,
135 arrow::array::IntervalYearMonthArray,
136 i32::MIN,
137 "interval"
138 ),
139 arrow::datatypes::IntervalUnit::DayTime => check_overflow!(
140 array,
141 arrow::array::IntervalDayTimeArray,
142 IntervalDayTime::MIN,
143 "interval"
144 ),
145 arrow::datatypes::IntervalUnit::MonthDayNano => {
146 }
148 },
149 _ => {
150 }
152 }
153 }
154 let result = neg_wrapping(array.as_ref())?;
155 Ok(ColumnarValue::Array(result))
156 }
157 ColumnarValue::Scalar(scalar) => {
158 if self.fail_on_error {
159 match scalar {
160 ScalarValue::Int8(value) => {
161 if value == Some(i8::MIN) {
162 return Err(arithmetic_overflow_error(" caused").into());
163 }
164 }
165 ScalarValue::Int16(value) => {
166 if value == Some(i16::MIN) {
167 return Err(arithmetic_overflow_error(" caused").into());
168 }
169 }
170 ScalarValue::Int32(value) => {
171 if value == Some(i32::MIN) {
172 return Err(arithmetic_overflow_error("integer").into());
173 }
174 }
175 ScalarValue::Int64(value) => {
176 if value == Some(i64::MIN) {
177 return Err(arithmetic_overflow_error("long").into());
178 }
179 }
180 ScalarValue::IntervalDayTime(value) => {
181 let (days, ms) =
182 IntervalDayTimeType::to_parts(value.unwrap_or_default());
183 if days == i32::MIN || ms == i32::MIN {
184 return Err(arithmetic_overflow_error("interval").into());
185 }
186 }
187 ScalarValue::IntervalYearMonth(value) => {
188 if value == Some(i32::MIN) {
189 return Err(arithmetic_overflow_error("interval").into());
190 }
191 }
192 _ => {
193 }
195 }
196 }
197 Ok(ColumnarValue::Scalar((scalar.arithmetic_negate())?))
198 }
199 }
200 }
201
202 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
203 vec![&self.arg]
204 }
205
206 fn with_new_children(
207 self: Arc<Self>,
208 children: Vec<Arc<dyn PhysicalExpr>>,
209 ) -> Result<Arc<dyn PhysicalExpr>> {
210 Ok(Arc::new(NegativeExpr::new(
211 Arc::clone(&children[0]),
212 self.fail_on_error,
213 )))
214 }
215
216 fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
220 Interval::try_new(
221 children[0].upper().arithmetic_negate()?,
222 children[0].lower().arithmetic_negate()?,
223 )
224 }
225
226 fn propagate_constraints(
229 &self,
230 interval: &Interval,
231 children: &[&Interval],
232 ) -> Result<Option<Vec<Interval>>> {
233 let child_interval = children[0];
234
235 if child_interval.lower() == &ScalarValue::Int32(Some(i32::MIN))
236 || child_interval.upper() == &ScalarValue::Int32(Some(i32::MIN))
237 || child_interval.lower() == &ScalarValue::Int64(Some(i64::MIN))
238 || child_interval.upper() == &ScalarValue::Int64(Some(i64::MIN))
239 {
240 return Err(SparkError::ArithmeticOverflow {
241 from_type: "long".to_string(),
242 }
243 .into());
244 }
245
246 let negated_interval = Interval::try_new(
247 interval.upper().arithmetic_negate()?,
248 interval.lower().arithmetic_negate()?,
249 )?;
250
251 Ok(child_interval
252 .intersect(negated_interval)?
253 .map(|result| vec![result]))
254 }
255
256 fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
258 let properties = children[0].clone().with_order(children[0].sort_properties);
259 Ok(properties)
260 }
261}