Skip to main content

datafusion_spark/function/math/
negative.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use arrow::array::types::*;
19use arrow::array::*;
20use arrow::datatypes::{DataType, IntervalDayTime, IntervalMonthDayNano, IntervalUnit};
21use bigdecimal::num_traits::WrappingNeg;
22use datafusion_common::utils::take_function_args;
23use datafusion_common::{Result, ScalarValue, exec_err, not_impl_err};
24use datafusion_expr::{
25    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
26    Volatility,
27};
28use std::sync::Arc;
29
30/// Spark-compatible `negative` expression
31/// <https://spark.apache.org/docs/latest/api/sql/index.html#negative>
32///
33/// Returns the negation of input (equivalent to unary minus)
34/// Returns NULL if input is NULL, returns NaN if input is NaN.
35///
36/// ANSI mode support:
37///  - When ANSI mode is disabled (`spark.sql.ansi.enabled=false`), negating the minimal
38///    value of a signed integer wraps around. For example: negative(i32::MIN) returns
39///    i32::MIN (wraps instead of error).
40///  - When ANSI mode is enabled (`spark.sql.ansi.enabled=true`), overflow conditions
41///    throw an ARITHMETIC_OVERFLOW error instead of wrapping.
42///
43#[derive(Debug, PartialEq, Eq, Hash)]
44pub struct SparkNegative {
45    signature: Signature,
46}
47
48impl Default for SparkNegative {
49    fn default() -> Self {
50        Self::new()
51    }
52}
53
54impl SparkNegative {
55    pub fn new() -> Self {
56        Self {
57            signature: Signature {
58                type_signature: TypeSignature::OneOf(vec![
59                    // Numeric types: signed integers, float, decimals
60                    TypeSignature::Numeric(1),
61                    // Interval types: YearMonth, DayTime, MonthDayNano
62                    TypeSignature::Uniform(
63                        1,
64                        vec![
65                            DataType::Interval(IntervalUnit::YearMonth),
66                            DataType::Interval(IntervalUnit::DayTime),
67                            DataType::Interval(IntervalUnit::MonthDayNano),
68                        ],
69                    ),
70                ]),
71                volatility: Volatility::Immutable,
72                parameter_names: None,
73            },
74        }
75    }
76}
77
78impl ScalarUDFImpl for SparkNegative {
79    fn name(&self) -> &str {
80        "negative"
81    }
82
83    fn signature(&self) -> &Signature {
84        &self.signature
85    }
86
87    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
88        Ok(arg_types[0].clone())
89    }
90
91    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
92        spark_negative(&args.args, args.config_options.execution.enable_ansi_mode)
93    }
94}
95
/// Negates a primitive array whose native type is a signed integer.
///
/// * `$array` - array to negate; downcast to `PrimitiveArray<$type>`
/// * `$type` - arrow primitive type of the array
/// * `$type_name` - type label used in the overflow error message
/// * `$enable_ansi_mode` - `true` selects checked negation (overflow becomes
///   an execution error), `false` selects wrapping negation
macro_rules! impl_integer_array_negative {
    ($array:expr, $type:ty, $type_name:expr, $enable_ansi_mode:expr) => {{
        let array = $array.as_primitive::<$type>();
        let result: PrimitiveArray<$type> = match $enable_ansi_mode {
            // ANSI mode: a value with no representable negation is an error.
            true => array.try_unary(|x| match x.checked_neg() {
                Some(negated) => Ok(negated),
                None => exec_err!("{} overflow on negative({x})", $type_name),
            })?,
            // Legacy mode: the minimum value wraps back onto itself.
            false => array.unary(|x| x.wrapping_neg()),
        };
        Ok(ColumnarValue::Array(Arc::new(result)))
    }};
}
114
/// Negates a floating point primitive array element by element.
///
/// * `$array` - array to negate; downcast to `PrimitiveArray<$type>`
/// * `$type` - arrow primitive type of the array
///
/// No ANSI flag is taken: the sign is flipped with unary minus in both modes.
macro_rules! impl_float_array_negative {
    ($array:expr, $type:ty) => {{
        let negated: PrimitiveArray<$type> =
            $array.as_primitive::<$type>().unary(|value| -value);
        Ok(ColumnarValue::Array(Arc::new(negated)))
    }};
}
123
/// Negates a decimal primitive array while keeping its precision and scale.
///
/// * `$array` - array to negate; downcast to `PrimitiveArray<$type>`
/// * `$type` - arrow decimal primitive type of the array
/// * `$type_name` - type label used in the overflow error message
/// * `$enable_ansi_mode` - `true` selects checked negation (overflow becomes
///   an execution error), `false` selects wrapping negation
macro_rules! impl_decimal_array_negative {
    ($array:expr, $type:ty, $type_name:expr, $enable_ansi_mode:expr) => {{
        let array = $array.as_primitive::<$type>();
        let negated: PrimitiveArray<$type> = if $enable_ansi_mode {
            array.try_unary(|x| match x.checked_neg() {
                Some(negated) => Ok(negated),
                None => exec_err!("{} overflow on negative({x})", $type_name),
            })?
        } else {
            array.unary(|x| x.wrapping_neg())
        };
        // `unary` and `try_unary` build their output with the type's default
        // data type, so the input's precision and scale must be restored on
        // BOTH branches; otherwise the result no longer matches the declared
        // return type.
        let result = negated.with_data_type(array.data_type().clone());
        Ok(ColumnarValue::Array(Arc::new(result)))
    }};
}
144
/// Negates a non-null integer scalar and re-wraps it in the same
/// `ScalarValue` variant.
///
/// * `$v` - the integer value to negate
/// * `$type_name` - type label used in the overflow error message
/// * `$variant` - `ScalarValue` variant used to wrap the result
/// * `$enable_ansi_mode` - `true` selects checked negation (overflow returns
///   an execution error from the enclosing function), `false` selects
///   wrapping negation
macro_rules! impl_integer_scalar_negative {
    ($v:expr, $type_name:expr, $variant:ident, $enable_ansi_mode:expr) => {{
        let negated = if $enable_ansi_mode {
            match $v.checked_neg() {
                Some(negated) => negated,
                None => {
                    return exec_err!("{} overflow on negative({})", $type_name, $v);
                }
            }
        } else {
            $v.wrapping_neg()
        };
        Ok(ColumnarValue::Scalar(ScalarValue::$variant(Some(negated))))
    }};
}
160
/// Negates a non-null decimal scalar, carrying its precision and scale over
/// to the result unchanged.
///
/// * `$v` - the unscaled decimal value to negate
/// * `$precision`, `$scale` - references to the scalar's precision and scale
/// * `$type_name` - type label used in the overflow error message
/// * `$variant` - `ScalarValue` variant used to wrap the result
/// * `$enable_ansi_mode` - `true` selects checked negation (overflow returns
///   an execution error from the enclosing function), `false` selects
///   wrapping negation
macro_rules! impl_decimal_scalar_negative {
    ($v:expr, $precision:expr, $scale:expr, $type_name:expr, $variant:ident, $enable_ansi_mode:expr) => {{
        let negated = if $enable_ansi_mode {
            match $v.checked_neg() {
                Some(negated) => negated,
                None => {
                    return exec_err!("{} overflow on negative({})", $type_name, $v);
                }
            }
        } else {
            $v.wrapping_neg()
        };
        let scalar = ScalarValue::$variant(Some(negated), *$precision, *$scale);
        Ok(ColumnarValue::Scalar(scalar))
    }};
}
180
181/// Core implementation of Spark's negative function
182fn spark_negative(
183    args: &[ColumnarValue],
184    enable_ansi_mode: bool,
185) -> Result<ColumnarValue> {
186    let [arg] = take_function_args("negative", args)?;
187
188    match arg {
189        ColumnarValue::Array(array) => match array.data_type() {
190            DataType::Null => Ok(arg.clone()),
191
192            // Signed integers - use checked negation in ANSI mode, wrapping in legacy mode
193            DataType::Int8 => {
194                impl_integer_array_negative!(array, Int8Type, "Int8", enable_ansi_mode)
195            }
196            DataType::Int16 => {
197                impl_integer_array_negative!(array, Int16Type, "Int16", enable_ansi_mode)
198            }
199            DataType::Int32 => {
200                impl_integer_array_negative!(array, Int32Type, "Int32", enable_ansi_mode)
201            }
202            DataType::Int64 => {
203                impl_integer_array_negative!(array, Int64Type, "Int64", enable_ansi_mode)
204            }
205
206            // Floating point - simple negation (no overflow possible)
207            DataType::Float16 => impl_float_array_negative!(array, Float16Type),
208            DataType::Float32 => impl_float_array_negative!(array, Float32Type),
209            DataType::Float64 => impl_float_array_negative!(array, Float64Type),
210
211            // Decimal types - use checked negation in ANSI mode, wrapping in legacy mode
212            DataType::Decimal32(_, _) => impl_decimal_array_negative!(
213                array,
214                Decimal32Type,
215                "Decimal32",
216                enable_ansi_mode
217            ),
218            DataType::Decimal64(_, _) => impl_decimal_array_negative!(
219                array,
220                Decimal64Type,
221                "Decimal64",
222                enable_ansi_mode
223            ),
224            DataType::Decimal128(_, _) => impl_decimal_array_negative!(
225                array,
226                Decimal128Type,
227                "Decimal128",
228                enable_ansi_mode
229            ),
230            DataType::Decimal256(_, _) => impl_decimal_array_negative!(
231                array,
232                Decimal256Type,
233                "Decimal256",
234                enable_ansi_mode
235            ),
236
237            // interval type - use checked negation in ANSI mode, wrapping in legacy mode
238            DataType::Interval(IntervalUnit::YearMonth) => {
239                impl_integer_array_negative!(
240                    array,
241                    IntervalYearMonthType,
242                    "IntervalYearMonth",
243                    enable_ansi_mode
244                )
245            }
246            DataType::Interval(IntervalUnit::DayTime) => {
247                let array = array.as_primitive::<IntervalDayTimeType>();
248                let result: PrimitiveArray<IntervalDayTimeType> = if enable_ansi_mode {
249                    array.try_unary(|x| {
250                        let days = x.days.checked_neg().ok_or_else(|| {
251                            (exec_err!(
252                                "IntervalDayTime overflow on negative (days: {})",
253                                x.days
254                            ) as Result<(), _>)
255                                .unwrap_err()
256                        })?;
257                        let milliseconds =
258                            x.milliseconds.checked_neg().ok_or_else(|| {
259                                (exec_err!(
260                                "IntervalDayTime overflow on negative (milliseconds: {})",
261                                x.milliseconds
262                            ) as Result<(), _>)
263                                .unwrap_err()
264                            })?;
265                        Ok::<_, arrow::error::ArrowError>(IntervalDayTime {
266                            days,
267                            milliseconds,
268                        })
269                    })?
270                } else {
271                    array.unary(|x| IntervalDayTime {
272                        days: x.days.wrapping_neg(),
273                        milliseconds: x.milliseconds.wrapping_neg(),
274                    })
275                };
276                Ok(ColumnarValue::Array(Arc::new(result)))
277            }
278            DataType::Interval(IntervalUnit::MonthDayNano) => {
279                let array = array.as_primitive::<IntervalMonthDayNanoType>();
280                let result: PrimitiveArray<IntervalMonthDayNanoType> = if enable_ansi_mode
281                {
282                    array.try_unary(|x| {
283                        let months = x.months.checked_neg().ok_or_else(|| {
284                            (exec_err!(
285                                "IntervalMonthDayNano overflow on negative (months: {})",
286                                x.months
287                            ) as Result<(), _>)
288                                .unwrap_err()
289                        })?;
290                        let days = x.days.checked_neg().ok_or_else(|| {
291                            (exec_err!(
292                                "IntervalMonthDayNano overflow on negative (days: {})",
293                                x.days
294                            ) as Result<(), _>)
295                                .unwrap_err()
296                        })?;
297                        let nanoseconds = x.nanoseconds.checked_neg().ok_or_else(|| {
298                            (exec_err!(
299                                "IntervalMonthDayNano overflow on negative (nanoseconds: {})",
300                                x.nanoseconds
301                            ) as Result<(), _>)
302                                .unwrap_err()
303                        })?;
304                        Ok::<_, arrow::error::ArrowError>(IntervalMonthDayNano {
305                            months,
306                            days,
307                            nanoseconds,
308                        })
309                    })?
310                } else {
311                    array.unary(|x| IntervalMonthDayNano {
312                        months: x.months.wrapping_neg(),
313                        days: x.days.wrapping_neg(),
314                        nanoseconds: x.nanoseconds.wrapping_neg(),
315                    })
316                };
317                Ok(ColumnarValue::Array(Arc::new(result)))
318            }
319
320            dt => not_impl_err!("Not supported datatype for Spark negative(): {dt}"),
321        },
322        ColumnarValue::Scalar(sv) => match sv {
323            ScalarValue::Null => Ok(arg.clone()),
324            _ if sv.is_null() => Ok(arg.clone()),
325
326            // Signed integers - use checked negation in ANSI mode, wrapping in legacy mode
327            ScalarValue::Int8(Some(v)) => {
328                impl_integer_scalar_negative!(v, "Int8", Int8, enable_ansi_mode)
329            }
330            ScalarValue::Int16(Some(v)) => {
331                impl_integer_scalar_negative!(v, "Int16", Int16, enable_ansi_mode)
332            }
333            ScalarValue::Int32(Some(v)) => {
334                impl_integer_scalar_negative!(v, "Int32", Int32, enable_ansi_mode)
335            }
336            ScalarValue::Int64(Some(v)) => {
337                impl_integer_scalar_negative!(v, "Int64", Int64, enable_ansi_mode)
338            }
339
340            // Floating point - simple negation
341            ScalarValue::Float16(Some(v)) => {
342                Ok(ColumnarValue::Scalar(ScalarValue::Float16(Some(-v))))
343            }
344            ScalarValue::Float32(Some(v)) => {
345                Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(-v))))
346            }
347            ScalarValue::Float64(Some(v)) => {
348                Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(-v))))
349            }
350
351            // Decimal types - use checked negation in ANSI mode, wrapping in legacy mode
352            ScalarValue::Decimal32(Some(v), precision, scale) => {
353                impl_decimal_scalar_negative!(
354                    v,
355                    precision,
356                    scale,
357                    "Decimal32",
358                    Decimal32,
359                    enable_ansi_mode
360                )
361            }
362            ScalarValue::Decimal64(Some(v), precision, scale) => {
363                impl_decimal_scalar_negative!(
364                    v,
365                    precision,
366                    scale,
367                    "Decimal64",
368                    Decimal64,
369                    enable_ansi_mode
370                )
371            }
372            ScalarValue::Decimal128(Some(v), precision, scale) => {
373                impl_decimal_scalar_negative!(
374                    v,
375                    precision,
376                    scale,
377                    "Decimal128",
378                    Decimal128,
379                    enable_ansi_mode
380                )
381            }
382            ScalarValue::Decimal256(Some(v), precision, scale) => {
383                impl_decimal_scalar_negative!(
384                    v,
385                    precision,
386                    scale,
387                    "Decimal256",
388                    Decimal256,
389                    enable_ansi_mode
390                )
391            }
392
393            //interval type - use checked negation in ANSI mode, wrapping in legacy mode
394            ScalarValue::IntervalYearMonth(Some(v)) => {
395                impl_integer_scalar_negative!(
396                    v,
397                    "IntervalYearMonth",
398                    IntervalYearMonth,
399                    enable_ansi_mode
400                )
401            }
402            ScalarValue::IntervalDayTime(Some(v)) => {
403                let result = if enable_ansi_mode {
404                    let days = v.days.checked_neg().ok_or_else(|| {
405                        (exec_err!(
406                            "IntervalDayTime overflow on negative (days: {})",
407                            v.days
408                        ) as Result<(), _>)
409                            .unwrap_err()
410                    })?;
411                    let milliseconds = v.milliseconds.checked_neg().ok_or_else(|| {
412                        (exec_err!(
413                            "IntervalDayTime overflow on negative (milliseconds: {})",
414                            v.milliseconds
415                        ) as Result<(), _>)
416                            .unwrap_err()
417                    })?;
418                    IntervalDayTime { days, milliseconds }
419                } else {
420                    IntervalDayTime {
421                        days: v.days.wrapping_neg(),
422                        milliseconds: v.milliseconds.wrapping_neg(),
423                    }
424                };
425                Ok(ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(
426                    result,
427                ))))
428            }
429            ScalarValue::IntervalMonthDayNano(Some(v)) => {
430                let result = if enable_ansi_mode {
431                    let months = v.months.checked_neg().ok_or_else(|| {
432                        (exec_err!(
433                            "IntervalMonthDayNano overflow on negative (months: {})",
434                            v.months
435                        ) as Result<(), _>)
436                            .unwrap_err()
437                    })?;
438                    let days = v.days.checked_neg().ok_or_else(|| {
439                        (exec_err!(
440                            "IntervalMonthDayNano overflow on negative (days: {})",
441                            v.days
442                        ) as Result<(), _>)
443                            .unwrap_err()
444                    })?;
445                    let nanoseconds = v.nanoseconds.checked_neg().ok_or_else(|| {
446                        (exec_err!(
447                            "IntervalMonthDayNano overflow on negative (nanoseconds: {})",
448                            v.nanoseconds
449                        ) as Result<(), _>)
450                            .unwrap_err()
451                    })?;
452                    IntervalMonthDayNano {
453                        months,
454                        days,
455                        nanoseconds,
456                    }
457                } else {
458                    IntervalMonthDayNano {
459                        months: v.months.wrapping_neg(),
460                        days: v.days.wrapping_neg(),
461                        nanoseconds: v.nanoseconds.wrapping_neg(),
462                    }
463                };
464                Ok(ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano(
465                    Some(result),
466                )))
467            }
468
469            dt => not_impl_err!("Not supported datatype for Spark negative(): {dt}"),
470        },
471    }
472}