datafusion_comet_spark_expr/math_funcs/
negative.rs1use crate::arithmetic_overflow_error;
19use crate::SparkError;
20use arrow::array::RecordBatch;
21use arrow::datatypes::IntervalDayTime;
22use arrow::datatypes::{DataType, Schema};
23use arrow::{compute::kernels::numeric::neg_wrapping, datatypes::IntervalDayTimeType};
24use datafusion::common::{DataFusionError, Result, ScalarValue};
25use datafusion::logical_expr::sort_properties::ExprProperties;
26use datafusion::{
27 logical_expr::{interval_arithmetic::Interval, ColumnarValue},
28 physical_expr::PhysicalExpr,
29};
30use std::fmt::Formatter;
31use std::hash::Hash;
32use std::{any::Any, sync::Arc};
33
34pub fn create_negate_expr(
35 expr: Arc<dyn PhysicalExpr>,
36 fail_on_error: bool,
37) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> {
38 Ok(Arc::new(NegativeExpr::new(expr, fail_on_error)))
39}
40
/// Physical expression that computes the arithmetic negation `-arg` of its child.
///
/// When `fail_on_error` is false, integer and interval results are computed with wrapping
/// negation. When it is true, `evaluate` first checks its supported integer/interval inputs
/// and returns an arithmetic-overflow error instead of producing a wrapped value.
///
/// `PartialEq` and `Hash` are implemented by hand (over both fields) because the child is a
/// trait object; `Eq` is derived on top of that `PartialEq`.
#[derive(Debug, Eq)]
pub struct NegativeExpr {
    /// The expression whose result is negated.
    arg: Arc<dyn PhysicalExpr>,
    /// If true, overflowing negations raise an error; otherwise they wrap.
    fail_on_error: bool,
}
48
49impl Hash for NegativeExpr {
50 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
51 self.arg.hash(state);
52 self.fail_on_error.hash(state);
53 }
54}
55
56impl PartialEq for NegativeExpr {
57 fn eq(&self, other: &Self) -> bool {
58 self.arg.eq(&other.arg) && self.fail_on_error.eq(&other.fail_on_error)
59 }
60}
61
/// Returns early from the enclosing function with an arithmetic-overflow error if any
/// non-null element of `$array` equals `$min_val`.
///
/// * `$array` – an array reference exposing `as_any()`; it is downcast to `$array_type`,
///   and a type mismatch panics because the caller already matched on the data type.
/// * `$min_val` – the value of the element type whose negation is not representable.
/// * `$type_name` – forwarded to `arithmetic_overflow_error`. For `"byte"` and `"short"` the
///   message names the offending value instead (e.g. `-128 caused`).
///
/// Must be expanded inside a function returning `Result<_, E>` where
/// `arithmetic_overflow_error(..).into()` converts into `E`.
macro_rules! check_overflow {
    ($array:expr, $array_type:ty, $min_val:expr, $type_name:expr) => {{
        let typed_array = $array
            .as_any()
            .downcast_ref::<$array_type>()
            .expect(concat!(stringify!($array_type), " expected"));
        // Only valid slots are inspected. Arrow leaves the value behind a null slot
        // unspecified (e.g. `nullif` keeps the original value), so a null that happens to
        // hold `$min_val` must not be reported as an overflow.
        if let Some(value) = typed_array.iter().flatten().find(|v| *v == $min_val) {
            if $type_name == "byte" || $type_name == "short" {
                let value = format!("{:?} caused", value);
                return Err(arithmetic_overflow_error(value.as_str()).into());
            }
            return Err(arithmetic_overflow_error($type_name).into());
        }
    }};
}
79
80impl NegativeExpr {
81 pub fn new(arg: Arc<dyn PhysicalExpr>, fail_on_error: bool) -> Self {
83 Self { arg, fail_on_error }
84 }
85
86 pub fn arg(&self) -> &Arc<dyn PhysicalExpr> {
88 &self.arg
89 }
90}
91
92impl std::fmt::Display for NegativeExpr {
93 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
94 write!(f, "(- {})", self.arg)
95 }
96}
97
98impl PhysicalExpr for NegativeExpr {
99 fn as_any(&self) -> &dyn Any {
101 self
102 }
103
104 fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
105 self.arg.data_type(input_schema)
106 }
107
108 fn nullable(&self, input_schema: &Schema) -> Result<bool> {
109 self.arg.nullable(input_schema)
110 }
111
112 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
113 let arg = self.arg.evaluate(batch)?;
114
115 match arg {
118 ColumnarValue::Array(array) => {
119 if self.fail_on_error {
120 match array.data_type() {
121 DataType::Int8 => {
122 check_overflow!(array, arrow::array::Int8Array, i8::MIN, "byte")
123 }
124 DataType::Int16 => {
125 check_overflow!(array, arrow::array::Int16Array, i16::MIN, "short")
126 }
127 DataType::Int32 => {
128 check_overflow!(array, arrow::array::Int32Array, i32::MIN, "integer")
129 }
130 DataType::Int64 => {
131 check_overflow!(array, arrow::array::Int64Array, i64::MIN, "long")
132 }
133 DataType::Interval(value) => match value {
134 arrow::datatypes::IntervalUnit::YearMonth => check_overflow!(
135 array,
136 arrow::array::IntervalYearMonthArray,
137 i32::MIN,
138 "interval"
139 ),
140 arrow::datatypes::IntervalUnit::DayTime => check_overflow!(
141 array,
142 arrow::array::IntervalDayTimeArray,
143 IntervalDayTime::MIN,
144 "interval"
145 ),
146 arrow::datatypes::IntervalUnit::MonthDayNano => {
147 }
149 },
150 _ => {
151 }
153 }
154 }
155 let result = neg_wrapping(array.as_ref())?;
156 Ok(ColumnarValue::Array(result))
157 }
158 ColumnarValue::Scalar(scalar) => {
159 if self.fail_on_error {
160 match scalar {
161 ScalarValue::Int8(value) => {
162 if value == Some(i8::MIN) {
163 return Err(arithmetic_overflow_error(" caused").into());
164 }
165 }
166 ScalarValue::Int16(value) => {
167 if value == Some(i16::MIN) {
168 return Err(arithmetic_overflow_error(" caused").into());
169 }
170 }
171 ScalarValue::Int32(value) => {
172 if value == Some(i32::MIN) {
173 return Err(arithmetic_overflow_error("integer").into());
174 }
175 }
176 ScalarValue::Int64(value) => {
177 if value == Some(i64::MIN) {
178 return Err(arithmetic_overflow_error("long").into());
179 }
180 }
181 ScalarValue::IntervalDayTime(value) => {
182 let (days, ms) =
183 IntervalDayTimeType::to_parts(value.unwrap_or_default());
184 if days == i32::MIN || ms == i32::MIN {
185 return Err(arithmetic_overflow_error("interval").into());
186 }
187 }
188 ScalarValue::IntervalYearMonth(value) => {
189 if value == Some(i32::MIN) {
190 return Err(arithmetic_overflow_error("interval").into());
191 }
192 }
193 _ => {
194 }
196 }
197 }
198 Ok(ColumnarValue::Scalar((scalar.arithmetic_negate())?))
199 }
200 }
201 }
202
203 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
204 vec![&self.arg]
205 }
206
207 fn with_new_children(
208 self: Arc<Self>,
209 children: Vec<Arc<dyn PhysicalExpr>>,
210 ) -> Result<Arc<dyn PhysicalExpr>> {
211 Ok(Arc::new(NegativeExpr::new(
212 Arc::clone(&children[0]),
213 self.fail_on_error,
214 )))
215 }
216
217 fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
221 Interval::try_new(
222 children[0].upper().arithmetic_negate()?,
223 children[0].lower().arithmetic_negate()?,
224 )
225 }
226
227 fn propagate_constraints(
230 &self,
231 interval: &Interval,
232 children: &[&Interval],
233 ) -> Result<Option<Vec<Interval>>> {
234 let child_interval = children[0];
235
236 if child_interval.lower() == &ScalarValue::Int32(Some(i32::MIN))
237 || child_interval.upper() == &ScalarValue::Int32(Some(i32::MIN))
238 || child_interval.lower() == &ScalarValue::Int64(Some(i64::MIN))
239 || child_interval.upper() == &ScalarValue::Int64(Some(i64::MIN))
240 {
241 return Err(SparkError::ArithmeticOverflow {
242 from_type: "long".to_string(),
243 }
244 .into());
245 }
246
247 let negated_interval = Interval::try_new(
248 interval.upper().arithmetic_negate()?,
249 interval.lower().arithmetic_negate()?,
250 )?;
251
252 Ok(child_interval
253 .intersect(negated_interval)?
254 .map(|result| vec![result]))
255 }
256
257 fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
259 let properties = children[0].clone().with_order(children[0].sort_properties);
260 Ok(properties)
261 }
262
263 fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
264 unimplemented!()
265 }
266}