Skip to main content

datafusion_spark/function/math/
negative.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::types::*;
19use arrow::array::*;
20use arrow::datatypes::{DataType, IntervalDayTime, IntervalMonthDayNano, IntervalUnit};
21use bigdecimal::num_traits::WrappingNeg;
22use datafusion_common::utils::take_function_args;
23use datafusion_common::{Result, ScalarValue, exec_err, not_impl_err};
24use datafusion_expr::{
25    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
26    Volatility,
27};
28use std::any::Any;
29use std::sync::Arc;
30
31/// Spark-compatible `negative` expression
32/// <https://spark.apache.org/docs/latest/api/sql/index.html#negative>
33///
34/// Returns the negation of input (equivalent to unary minus)
35/// Returns NULL if input is NULL, returns NaN if input is NaN.
36///
37/// ANSI mode support:
38///  - When ANSI mode is disabled (`spark.sql.ansi.enabled=false`), negating the minimal
39///    value of a signed integer wraps around. For example: negative(i32::MIN) returns
40///    i32::MIN (wraps instead of error).
41///  - When ANSI mode is enabled (`spark.sql.ansi.enabled=true`), overflow conditions
42///    throw an ARITHMETIC_OVERFLOW error instead of wrapping.
43///
44#[derive(Debug, PartialEq, Eq, Hash)]
45pub struct SparkNegative {
46    signature: Signature,
47}
48
49impl Default for SparkNegative {
50    fn default() -> Self {
51        Self::new()
52    }
53}
54
55impl SparkNegative {
56    pub fn new() -> Self {
57        Self {
58            signature: Signature {
59                type_signature: TypeSignature::OneOf(vec![
60                    // Numeric types: signed integers, float, decimals
61                    TypeSignature::Numeric(1),
62                    // Interval types: YearMonth, DayTime, MonthDayNano
63                    TypeSignature::Uniform(
64                        1,
65                        vec![
66                            DataType::Interval(IntervalUnit::YearMonth),
67                            DataType::Interval(IntervalUnit::DayTime),
68                            DataType::Interval(IntervalUnit::MonthDayNano),
69                        ],
70                    ),
71                ]),
72                volatility: Volatility::Immutable,
73                parameter_names: None,
74            },
75        }
76    }
77}
78
79impl ScalarUDFImpl for SparkNegative {
80    fn as_any(&self) -> &dyn Any {
81        self
82    }
83
84    fn name(&self) -> &str {
85        "negative"
86    }
87
88    fn signature(&self) -> &Signature {
89        &self.signature
90    }
91
92    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
93        Ok(arg_types[0].clone())
94    }
95
96    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
97        spark_negative(&args.args, args.config_options.execution.enable_ansi_mode)
98    }
99}
100
/// Negates every value of an integer-backed primitive array.
///
/// `$array` is viewed as a `PrimitiveArray<$type>`. When `$enable_ansi_mode`
/// is true the negation is checked and an overflow aborts with an execution
/// error naming `$type_name`; otherwise the negation wraps.
macro_rules! impl_integer_array_negative {
    ($array:expr, $type:ty, $type_name:expr, $enable_ansi_mode:expr) => {{
        let values = $array.as_primitive::<$type>();
        let negated: PrimitiveArray<$type> = if $enable_ansi_mode {
            values.try_unary(|x| match x.checked_neg() {
                Some(flipped) => Ok(flipped),
                None => exec_err!("{} overflow on negative({x})", $type_name),
            })?
        } else {
            values.unary(|x| x.wrapping_neg())
        };
        Ok(ColumnarValue::Array(Arc::new(negated)))
    }};
}
119
/// Negates every value of a floating point primitive array.
///
/// `$array` is viewed as a `PrimitiveArray<$type>` and each value has its
/// sign flipped with the unary minus operator.
macro_rules! impl_float_array_negative {
    ($array:expr, $type:ty) => {{
        let floats = $array.as_primitive::<$type>();
        let negated: PrimitiveArray<$type> = floats.unary(|value| -value);
        Ok(ColumnarValue::Array(Arc::new(negated)))
    }};
}
128
/// Negates every value of a decimal primitive array.
///
/// `$array` is viewed as a `PrimitiveArray<$type>`. When `$enable_ansi_mode`
/// is true the negation is checked and an overflow aborts with an execution
/// error naming `$type_name`; otherwise the negation wraps.
///
/// The result always carries the data type (precision and scale) of the
/// input array, in both modes.
macro_rules! impl_decimal_array_negative {
    ($array:expr, $type:ty, $type_name:expr, $enable_ansi_mode:expr) => {{
        let array = $array.as_primitive::<$type>();
        let negated: PrimitiveArray<$type> = if $enable_ansi_mode {
            array.try_unary(|x| match x.checked_neg() {
                Some(flipped) => Ok(flipped),
                None => exec_err!("{} overflow on negative({x})", $type_name),
            })?
        } else {
            array.unary(|x| x.wrapping_neg())
        };
        // Re-apply the input's precision and scale for BOTH branches.
        // Previously only the ANSI branch did this, so the wrapping branch
        // could return an array whose decimal type differed from the input's.
        let result = negated.with_data_type(array.data_type().clone());
        Ok(ColumnarValue::Array(Arc::new(result)))
    }};
}
149
/// Negates a single non-null integer scalar.
///
/// `$v` is the value to negate and `$variant` the `ScalarValue` variant used
/// to wrap the result. When `$enable_ansi_mode` is true an overflow aborts
/// with an execution error naming `$type_name`; otherwise the negation wraps.
macro_rules! impl_integer_scalar_negative {
    ($v:expr, $type_name:expr, $variant:ident, $enable_ansi_mode:expr) => {{
        let negated = if $enable_ansi_mode {
            let checked: Result<_> = match $v.checked_neg() {
                Some(flipped) => Ok(flipped),
                None => exec_err!("{} overflow on negative({})", $type_name, $v),
            };
            checked?
        } else {
            $v.wrapping_neg()
        };
        Ok(ColumnarValue::Scalar(ScalarValue::$variant(Some(negated))))
    }};
}
165
/// Negates a single non-null decimal scalar.
///
/// `$v` is the unscaled value to negate; `$precision` and `$scale` are
/// dereferenced and passed through unchanged to the `ScalarValue::$variant`
/// result. When `$enable_ansi_mode` is true an overflow aborts with an
/// execution error naming `$type_name`; otherwise the negation wraps.
macro_rules! impl_decimal_scalar_negative {
    ($v:expr, $precision:expr, $scale:expr, $type_name:expr, $variant:ident, $enable_ansi_mode:expr) => {{
        let negated = if $enable_ansi_mode {
            let checked: Result<_> = match $v.checked_neg() {
                Some(flipped) => Ok(flipped),
                None => exec_err!("{} overflow on negative({})", $type_name, $v),
            };
            checked?
        } else {
            $v.wrapping_neg()
        };
        let scalar = ScalarValue::$variant(Some(negated), *$precision, *$scale);
        Ok(ColumnarValue::Scalar(scalar))
    }};
}
185
186/// Core implementation of Spark's negative function
187fn spark_negative(
188    args: &[ColumnarValue],
189    enable_ansi_mode: bool,
190) -> Result<ColumnarValue> {
191    let [arg] = take_function_args("negative", args)?;
192
193    match arg {
194        ColumnarValue::Array(array) => match array.data_type() {
195            DataType::Null => Ok(arg.clone()),
196
197            // Signed integers - use checked negation in ANSI mode, wrapping in legacy mode
198            DataType::Int8 => {
199                impl_integer_array_negative!(array, Int8Type, "Int8", enable_ansi_mode)
200            }
201            DataType::Int16 => {
202                impl_integer_array_negative!(array, Int16Type, "Int16", enable_ansi_mode)
203            }
204            DataType::Int32 => {
205                impl_integer_array_negative!(array, Int32Type, "Int32", enable_ansi_mode)
206            }
207            DataType::Int64 => {
208                impl_integer_array_negative!(array, Int64Type, "Int64", enable_ansi_mode)
209            }
210
211            // Floating point - simple negation (no overflow possible)
212            DataType::Float16 => impl_float_array_negative!(array, Float16Type),
213            DataType::Float32 => impl_float_array_negative!(array, Float32Type),
214            DataType::Float64 => impl_float_array_negative!(array, Float64Type),
215
216            // Decimal types - use checked negation in ANSI mode, wrapping in legacy mode
217            DataType::Decimal32(_, _) => impl_decimal_array_negative!(
218                array,
219                Decimal32Type,
220                "Decimal32",
221                enable_ansi_mode
222            ),
223            DataType::Decimal64(_, _) => impl_decimal_array_negative!(
224                array,
225                Decimal64Type,
226                "Decimal64",
227                enable_ansi_mode
228            ),
229            DataType::Decimal128(_, _) => impl_decimal_array_negative!(
230                array,
231                Decimal128Type,
232                "Decimal128",
233                enable_ansi_mode
234            ),
235            DataType::Decimal256(_, _) => impl_decimal_array_negative!(
236                array,
237                Decimal256Type,
238                "Decimal256",
239                enable_ansi_mode
240            ),
241
242            // interval type - use checked negation in ANSI mode, wrapping in legacy mode
243            DataType::Interval(IntervalUnit::YearMonth) => {
244                impl_integer_array_negative!(
245                    array,
246                    IntervalYearMonthType,
247                    "IntervalYearMonth",
248                    enable_ansi_mode
249                )
250            }
251            DataType::Interval(IntervalUnit::DayTime) => {
252                let array = array.as_primitive::<IntervalDayTimeType>();
253                let result: PrimitiveArray<IntervalDayTimeType> = if enable_ansi_mode {
254                    array.try_unary(|x| {
255                        let days = x.days.checked_neg().ok_or_else(|| {
256                            (exec_err!(
257                                "IntervalDayTime overflow on negative (days: {})",
258                                x.days
259                            ) as Result<(), _>)
260                                .unwrap_err()
261                        })?;
262                        let milliseconds =
263                            x.milliseconds.checked_neg().ok_or_else(|| {
264                                (exec_err!(
265                                "IntervalDayTime overflow on negative (milliseconds: {})",
266                                x.milliseconds
267                            ) as Result<(), _>)
268                                .unwrap_err()
269                            })?;
270                        Ok::<_, arrow::error::ArrowError>(IntervalDayTime {
271                            days,
272                            milliseconds,
273                        })
274                    })?
275                } else {
276                    array.unary(|x| IntervalDayTime {
277                        days: x.days.wrapping_neg(),
278                        milliseconds: x.milliseconds.wrapping_neg(),
279                    })
280                };
281                Ok(ColumnarValue::Array(Arc::new(result)))
282            }
283            DataType::Interval(IntervalUnit::MonthDayNano) => {
284                let array = array.as_primitive::<IntervalMonthDayNanoType>();
285                let result: PrimitiveArray<IntervalMonthDayNanoType> = if enable_ansi_mode
286                {
287                    array.try_unary(|x| {
288                        let months = x.months.checked_neg().ok_or_else(|| {
289                            (exec_err!(
290                                "IntervalMonthDayNano overflow on negative (months: {})",
291                                x.months
292                            ) as Result<(), _>)
293                                .unwrap_err()
294                        })?;
295                        let days = x.days.checked_neg().ok_or_else(|| {
296                            (exec_err!(
297                                "IntervalMonthDayNano overflow on negative (days: {})",
298                                x.days
299                            ) as Result<(), _>)
300                                .unwrap_err()
301                        })?;
302                        let nanoseconds = x.nanoseconds.checked_neg().ok_or_else(|| {
303                            (exec_err!(
304                                "IntervalMonthDayNano overflow on negative (nanoseconds: {})",
305                                x.nanoseconds
306                            ) as Result<(), _>)
307                                .unwrap_err()
308                        })?;
309                        Ok::<_, arrow::error::ArrowError>(IntervalMonthDayNano {
310                            months,
311                            days,
312                            nanoseconds,
313                        })
314                    })?
315                } else {
316                    array.unary(|x| IntervalMonthDayNano {
317                        months: x.months.wrapping_neg(),
318                        days: x.days.wrapping_neg(),
319                        nanoseconds: x.nanoseconds.wrapping_neg(),
320                    })
321                };
322                Ok(ColumnarValue::Array(Arc::new(result)))
323            }
324
325            dt => not_impl_err!("Not supported datatype for Spark negative(): {dt}"),
326        },
327        ColumnarValue::Scalar(sv) => match sv {
328            ScalarValue::Null => Ok(arg.clone()),
329            _ if sv.is_null() => Ok(arg.clone()),
330
331            // Signed integers - use checked negation in ANSI mode, wrapping in legacy mode
332            ScalarValue::Int8(Some(v)) => {
333                impl_integer_scalar_negative!(v, "Int8", Int8, enable_ansi_mode)
334            }
335            ScalarValue::Int16(Some(v)) => {
336                impl_integer_scalar_negative!(v, "Int16", Int16, enable_ansi_mode)
337            }
338            ScalarValue::Int32(Some(v)) => {
339                impl_integer_scalar_negative!(v, "Int32", Int32, enable_ansi_mode)
340            }
341            ScalarValue::Int64(Some(v)) => {
342                impl_integer_scalar_negative!(v, "Int64", Int64, enable_ansi_mode)
343            }
344
345            // Floating point - simple negation
346            ScalarValue::Float16(Some(v)) => {
347                Ok(ColumnarValue::Scalar(ScalarValue::Float16(Some(-v))))
348            }
349            ScalarValue::Float32(Some(v)) => {
350                Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(-v))))
351            }
352            ScalarValue::Float64(Some(v)) => {
353                Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(-v))))
354            }
355
356            // Decimal types - use checked negation in ANSI mode, wrapping in legacy mode
357            ScalarValue::Decimal32(Some(v), precision, scale) => {
358                impl_decimal_scalar_negative!(
359                    v,
360                    precision,
361                    scale,
362                    "Decimal32",
363                    Decimal32,
364                    enable_ansi_mode
365                )
366            }
367            ScalarValue::Decimal64(Some(v), precision, scale) => {
368                impl_decimal_scalar_negative!(
369                    v,
370                    precision,
371                    scale,
372                    "Decimal64",
373                    Decimal64,
374                    enable_ansi_mode
375                )
376            }
377            ScalarValue::Decimal128(Some(v), precision, scale) => {
378                impl_decimal_scalar_negative!(
379                    v,
380                    precision,
381                    scale,
382                    "Decimal128",
383                    Decimal128,
384                    enable_ansi_mode
385                )
386            }
387            ScalarValue::Decimal256(Some(v), precision, scale) => {
388                impl_decimal_scalar_negative!(
389                    v,
390                    precision,
391                    scale,
392                    "Decimal256",
393                    Decimal256,
394                    enable_ansi_mode
395                )
396            }
397
398            //interval type - use checked negation in ANSI mode, wrapping in legacy mode
399            ScalarValue::IntervalYearMonth(Some(v)) => {
400                impl_integer_scalar_negative!(
401                    v,
402                    "IntervalYearMonth",
403                    IntervalYearMonth,
404                    enable_ansi_mode
405                )
406            }
407            ScalarValue::IntervalDayTime(Some(v)) => {
408                let result = if enable_ansi_mode {
409                    let days = v.days.checked_neg().ok_or_else(|| {
410                        (exec_err!(
411                            "IntervalDayTime overflow on negative (days: {})",
412                            v.days
413                        ) as Result<(), _>)
414                            .unwrap_err()
415                    })?;
416                    let milliseconds = v.milliseconds.checked_neg().ok_or_else(|| {
417                        (exec_err!(
418                            "IntervalDayTime overflow on negative (milliseconds: {})",
419                            v.milliseconds
420                        ) as Result<(), _>)
421                            .unwrap_err()
422                    })?;
423                    IntervalDayTime { days, milliseconds }
424                } else {
425                    IntervalDayTime {
426                        days: v.days.wrapping_neg(),
427                        milliseconds: v.milliseconds.wrapping_neg(),
428                    }
429                };
430                Ok(ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(
431                    result,
432                ))))
433            }
434            ScalarValue::IntervalMonthDayNano(Some(v)) => {
435                let result = if enable_ansi_mode {
436                    let months = v.months.checked_neg().ok_or_else(|| {
437                        (exec_err!(
438                            "IntervalMonthDayNano overflow on negative (months: {})",
439                            v.months
440                        ) as Result<(), _>)
441                            .unwrap_err()
442                    })?;
443                    let days = v.days.checked_neg().ok_or_else(|| {
444                        (exec_err!(
445                            "IntervalMonthDayNano overflow on negative (days: {})",
446                            v.days
447                        ) as Result<(), _>)
448                            .unwrap_err()
449                    })?;
450                    let nanoseconds = v.nanoseconds.checked_neg().ok_or_else(|| {
451                        (exec_err!(
452                            "IntervalMonthDayNano overflow on negative (nanoseconds: {})",
453                            v.nanoseconds
454                        ) as Result<(), _>)
455                            .unwrap_err()
456                    })?;
457                    IntervalMonthDayNano {
458                        months,
459                        days,
460                        nanoseconds,
461                    }
462                } else {
463                    IntervalMonthDayNano {
464                        months: v.months.wrapping_neg(),
465                        days: v.days.wrapping_neg(),
466                        nanoseconds: v.nanoseconds.wrapping_neg(),
467                    }
468                };
469                Ok(ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano(
470                    Some(result),
471                )))
472            }
473
474            dt => not_impl_err!("Not supported datatype for Spark negative(): {dt}"),
475        },
476    }
477}