1use arrow::array::types::*;
19use arrow::array::*;
20use arrow::datatypes::{DataType, IntervalDayTime, IntervalMonthDayNano, IntervalUnit};
21use bigdecimal::num_traits::WrappingNeg;
22use datafusion_common::utils::take_function_args;
23use datafusion_common::{Result, ScalarValue, exec_err, not_impl_err};
24use datafusion_expr::{
25 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
26 Volatility,
27};
28use std::sync::Arc;
29
30#[derive(Debug, PartialEq, Eq, Hash)]
44pub struct SparkNegative {
45 signature: Signature,
46}
47
48impl Default for SparkNegative {
49 fn default() -> Self {
50 Self::new()
51 }
52}
53
54impl SparkNegative {
55 pub fn new() -> Self {
56 Self {
57 signature: Signature {
58 type_signature: TypeSignature::OneOf(vec![
59 TypeSignature::Numeric(1),
61 TypeSignature::Uniform(
63 1,
64 vec![
65 DataType::Interval(IntervalUnit::YearMonth),
66 DataType::Interval(IntervalUnit::DayTime),
67 DataType::Interval(IntervalUnit::MonthDayNano),
68 ],
69 ),
70 ]),
71 volatility: Volatility::Immutable,
72 parameter_names: None,
73 },
74 }
75 }
76}
77
78impl ScalarUDFImpl for SparkNegative {
79 fn name(&self) -> &str {
80 "negative"
81 }
82
83 fn signature(&self) -> &Signature {
84 &self.signature
85 }
86
87 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
88 Ok(arg_types[0].clone())
89 }
90
91 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
92 spark_negative(&args.args, args.config_options.execution.enable_ansi_mode)
93 }
94}
95
/// Negates every value of a primitive integer array and wraps the result in
/// `ColumnarValue::Array`.
///
/// Arguments, in order: the input array expression, its Arrow primitive type,
/// the type name used in error messages, and the ANSI-mode flag.
///
/// With ANSI mode enabled a value whose negation overflows produces an
/// execution error (propagated with `?`, so the macro must be expanded inside
/// a function returning `Result`); otherwise the value wraps around.
macro_rules! impl_integer_array_negative {
    ($array:expr, $type:ty, $type_name:expr, $enable_ansi_mode:expr) => {{
        let input = $array.as_primitive::<$type>();
        let negated: PrimitiveArray<$type> = if $enable_ansi_mode {
            input.try_unary(|x| match x.checked_neg() {
                Some(value) => Ok(value),
                None => exec_err!("{} overflow on negative({x})", $type_name),
            })?
        } else {
            input.unary(|x| x.wrapping_neg())
        };
        Ok(ColumnarValue::Array(Arc::new(negated)))
    }};
}
114
/// Negates every value of a primitive floating point array and wraps the
/// result in `ColumnarValue::Array`.
///
/// Arguments, in order: the input array expression and its Arrow primitive
/// type. No ANSI-mode flag is taken; the negation is a plain unary minus.
macro_rules! impl_float_array_negative {
    ($array:expr, $type:ty) => {{
        let input = $array.as_primitive::<$type>();
        let negated: PrimitiveArray<$type> = input.unary(|value| -value);
        Ok(ColumnarValue::Array(Arc::new(negated)))
    }};
}
123
/// Negates every value of a primitive decimal array and wraps the result in
/// `ColumnarValue::Array`.
///
/// Arguments, in order: the input array expression, its Arrow decimal type,
/// the type name used in error messages, and the ANSI-mode flag.
///
/// With ANSI mode enabled a value whose negation overflows produces an
/// execution error (propagated with `?`, so the macro must be expanded inside
/// a function returning `Result`); otherwise the value wraps around.
///
/// The output always carries the data type (precision and scale) of the
/// input array, in both modes.
macro_rules! impl_decimal_array_negative {
    ($array:expr, $type:ty, $type_name:expr, $enable_ansi_mode:expr) => {{
        let array = $array.as_primitive::<$type>();
        let negated: PrimitiveArray<$type> = if $enable_ansi_mode {
            array.try_unary(|x| match x.checked_neg() {
                Some(value) => Ok(value),
                None => exec_err!("{} overflow on negative({x})", $type_name),
            })?
        } else {
            array.unary(|x| x.wrapping_neg())
        };
        // `unary` and `try_unary` both build the output with the decimal
        // type's default precision and scale, so the input's data type has to
        // be restored on BOTH branches, not only the ANSI one.
        let result = negated.with_data_type(array.data_type().clone());
        Ok(ColumnarValue::Array(Arc::new(result)))
    }};
}
144
/// Negates a single integer value and wraps it in the given `ScalarValue`
/// variant inside `ColumnarValue::Scalar`.
///
/// Arguments, in order: the value expression, the type name used in error
/// messages, the `ScalarValue` variant to build, and the ANSI-mode flag.
///
/// With ANSI mode enabled an overflowing negation produces an execution error
/// (propagated with `?`, so the macro must be expanded inside a function
/// returning `Result`); otherwise the value wraps around.
macro_rules! impl_integer_scalar_negative {
    ($v:expr, $type_name:expr, $variant:ident, $enable_ansi_mode:expr) => {{
        let negated = if $enable_ansi_mode {
            $v.checked_neg().map_or_else(
                || exec_err!("{} overflow on negative({})", $type_name, $v),
                Ok,
            )?
        } else {
            $v.wrapping_neg()
        };
        Ok(ColumnarValue::Scalar(ScalarValue::$variant(Some(negated))))
    }};
}
160
/// Negates a single decimal value and rebuilds the given `ScalarValue`
/// variant with the original precision and scale, inside
/// `ColumnarValue::Scalar`.
///
/// Arguments, in order: the value expression, references to the precision and
/// the scale (both are dereferenced), the type name used in error messages,
/// the `ScalarValue` variant to build, and the ANSI-mode flag.
///
/// With ANSI mode enabled an overflowing negation produces an execution error
/// (propagated with `?`, so the macro must be expanded inside a function
/// returning `Result`); otherwise the value wraps around.
macro_rules! impl_decimal_scalar_negative {
    ($v:expr, $precision:expr, $scale:expr, $type_name:expr, $variant:ident, $enable_ansi_mode:expr) => {{
        let negated = if $enable_ansi_mode {
            $v.checked_neg().map_or_else(
                || exec_err!("{} overflow on negative({})", $type_name, $v),
                Ok,
            )?
        } else {
            $v.wrapping_neg()
        };
        let scalar = ScalarValue::$variant(Some(negated), *$precision, *$scale);
        Ok(ColumnarValue::Scalar(scalar))
    }};
}
180
181fn spark_negative(
183 args: &[ColumnarValue],
184 enable_ansi_mode: bool,
185) -> Result<ColumnarValue> {
186 let [arg] = take_function_args("negative", args)?;
187
188 match arg {
189 ColumnarValue::Array(array) => match array.data_type() {
190 DataType::Null => Ok(arg.clone()),
191
192 DataType::Int8 => {
194 impl_integer_array_negative!(array, Int8Type, "Int8", enable_ansi_mode)
195 }
196 DataType::Int16 => {
197 impl_integer_array_negative!(array, Int16Type, "Int16", enable_ansi_mode)
198 }
199 DataType::Int32 => {
200 impl_integer_array_negative!(array, Int32Type, "Int32", enable_ansi_mode)
201 }
202 DataType::Int64 => {
203 impl_integer_array_negative!(array, Int64Type, "Int64", enable_ansi_mode)
204 }
205
206 DataType::Float16 => impl_float_array_negative!(array, Float16Type),
208 DataType::Float32 => impl_float_array_negative!(array, Float32Type),
209 DataType::Float64 => impl_float_array_negative!(array, Float64Type),
210
211 DataType::Decimal32(_, _) => impl_decimal_array_negative!(
213 array,
214 Decimal32Type,
215 "Decimal32",
216 enable_ansi_mode
217 ),
218 DataType::Decimal64(_, _) => impl_decimal_array_negative!(
219 array,
220 Decimal64Type,
221 "Decimal64",
222 enable_ansi_mode
223 ),
224 DataType::Decimal128(_, _) => impl_decimal_array_negative!(
225 array,
226 Decimal128Type,
227 "Decimal128",
228 enable_ansi_mode
229 ),
230 DataType::Decimal256(_, _) => impl_decimal_array_negative!(
231 array,
232 Decimal256Type,
233 "Decimal256",
234 enable_ansi_mode
235 ),
236
237 DataType::Interval(IntervalUnit::YearMonth) => {
239 impl_integer_array_negative!(
240 array,
241 IntervalYearMonthType,
242 "IntervalYearMonth",
243 enable_ansi_mode
244 )
245 }
246 DataType::Interval(IntervalUnit::DayTime) => {
247 let array = array.as_primitive::<IntervalDayTimeType>();
248 let result: PrimitiveArray<IntervalDayTimeType> = if enable_ansi_mode {
249 array.try_unary(|x| {
250 let days = x.days.checked_neg().ok_or_else(|| {
251 (exec_err!(
252 "IntervalDayTime overflow on negative (days: {})",
253 x.days
254 ) as Result<(), _>)
255 .unwrap_err()
256 })?;
257 let milliseconds =
258 x.milliseconds.checked_neg().ok_or_else(|| {
259 (exec_err!(
260 "IntervalDayTime overflow on negative (milliseconds: {})",
261 x.milliseconds
262 ) as Result<(), _>)
263 .unwrap_err()
264 })?;
265 Ok::<_, arrow::error::ArrowError>(IntervalDayTime {
266 days,
267 milliseconds,
268 })
269 })?
270 } else {
271 array.unary(|x| IntervalDayTime {
272 days: x.days.wrapping_neg(),
273 milliseconds: x.milliseconds.wrapping_neg(),
274 })
275 };
276 Ok(ColumnarValue::Array(Arc::new(result)))
277 }
278 DataType::Interval(IntervalUnit::MonthDayNano) => {
279 let array = array.as_primitive::<IntervalMonthDayNanoType>();
280 let result: PrimitiveArray<IntervalMonthDayNanoType> = if enable_ansi_mode
281 {
282 array.try_unary(|x| {
283 let months = x.months.checked_neg().ok_or_else(|| {
284 (exec_err!(
285 "IntervalMonthDayNano overflow on negative (months: {})",
286 x.months
287 ) as Result<(), _>)
288 .unwrap_err()
289 })?;
290 let days = x.days.checked_neg().ok_or_else(|| {
291 (exec_err!(
292 "IntervalMonthDayNano overflow on negative (days: {})",
293 x.days
294 ) as Result<(), _>)
295 .unwrap_err()
296 })?;
297 let nanoseconds = x.nanoseconds.checked_neg().ok_or_else(|| {
298 (exec_err!(
299 "IntervalMonthDayNano overflow on negative (nanoseconds: {})",
300 x.nanoseconds
301 ) as Result<(), _>)
302 .unwrap_err()
303 })?;
304 Ok::<_, arrow::error::ArrowError>(IntervalMonthDayNano {
305 months,
306 days,
307 nanoseconds,
308 })
309 })?
310 } else {
311 array.unary(|x| IntervalMonthDayNano {
312 months: x.months.wrapping_neg(),
313 days: x.days.wrapping_neg(),
314 nanoseconds: x.nanoseconds.wrapping_neg(),
315 })
316 };
317 Ok(ColumnarValue::Array(Arc::new(result)))
318 }
319
320 dt => not_impl_err!("Not supported datatype for Spark negative(): {dt}"),
321 },
322 ColumnarValue::Scalar(sv) => match sv {
323 ScalarValue::Null => Ok(arg.clone()),
324 _ if sv.is_null() => Ok(arg.clone()),
325
326 ScalarValue::Int8(Some(v)) => {
328 impl_integer_scalar_negative!(v, "Int8", Int8, enable_ansi_mode)
329 }
330 ScalarValue::Int16(Some(v)) => {
331 impl_integer_scalar_negative!(v, "Int16", Int16, enable_ansi_mode)
332 }
333 ScalarValue::Int32(Some(v)) => {
334 impl_integer_scalar_negative!(v, "Int32", Int32, enable_ansi_mode)
335 }
336 ScalarValue::Int64(Some(v)) => {
337 impl_integer_scalar_negative!(v, "Int64", Int64, enable_ansi_mode)
338 }
339
340 ScalarValue::Float16(Some(v)) => {
342 Ok(ColumnarValue::Scalar(ScalarValue::Float16(Some(-v))))
343 }
344 ScalarValue::Float32(Some(v)) => {
345 Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(-v))))
346 }
347 ScalarValue::Float64(Some(v)) => {
348 Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(-v))))
349 }
350
351 ScalarValue::Decimal32(Some(v), precision, scale) => {
353 impl_decimal_scalar_negative!(
354 v,
355 precision,
356 scale,
357 "Decimal32",
358 Decimal32,
359 enable_ansi_mode
360 )
361 }
362 ScalarValue::Decimal64(Some(v), precision, scale) => {
363 impl_decimal_scalar_negative!(
364 v,
365 precision,
366 scale,
367 "Decimal64",
368 Decimal64,
369 enable_ansi_mode
370 )
371 }
372 ScalarValue::Decimal128(Some(v), precision, scale) => {
373 impl_decimal_scalar_negative!(
374 v,
375 precision,
376 scale,
377 "Decimal128",
378 Decimal128,
379 enable_ansi_mode
380 )
381 }
382 ScalarValue::Decimal256(Some(v), precision, scale) => {
383 impl_decimal_scalar_negative!(
384 v,
385 precision,
386 scale,
387 "Decimal256",
388 Decimal256,
389 enable_ansi_mode
390 )
391 }
392
393 ScalarValue::IntervalYearMonth(Some(v)) => {
395 impl_integer_scalar_negative!(
396 v,
397 "IntervalYearMonth",
398 IntervalYearMonth,
399 enable_ansi_mode
400 )
401 }
402 ScalarValue::IntervalDayTime(Some(v)) => {
403 let result = if enable_ansi_mode {
404 let days = v.days.checked_neg().ok_or_else(|| {
405 (exec_err!(
406 "IntervalDayTime overflow on negative (days: {})",
407 v.days
408 ) as Result<(), _>)
409 .unwrap_err()
410 })?;
411 let milliseconds = v.milliseconds.checked_neg().ok_or_else(|| {
412 (exec_err!(
413 "IntervalDayTime overflow on negative (milliseconds: {})",
414 v.milliseconds
415 ) as Result<(), _>)
416 .unwrap_err()
417 })?;
418 IntervalDayTime { days, milliseconds }
419 } else {
420 IntervalDayTime {
421 days: v.days.wrapping_neg(),
422 milliseconds: v.milliseconds.wrapping_neg(),
423 }
424 };
425 Ok(ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(
426 result,
427 ))))
428 }
429 ScalarValue::IntervalMonthDayNano(Some(v)) => {
430 let result = if enable_ansi_mode {
431 let months = v.months.checked_neg().ok_or_else(|| {
432 (exec_err!(
433 "IntervalMonthDayNano overflow on negative (months: {})",
434 v.months
435 ) as Result<(), _>)
436 .unwrap_err()
437 })?;
438 let days = v.days.checked_neg().ok_or_else(|| {
439 (exec_err!(
440 "IntervalMonthDayNano overflow on negative (days: {})",
441 v.days
442 ) as Result<(), _>)
443 .unwrap_err()
444 })?;
445 let nanoseconds = v.nanoseconds.checked_neg().ok_or_else(|| {
446 (exec_err!(
447 "IntervalMonthDayNano overflow on negative (nanoseconds: {})",
448 v.nanoseconds
449 ) as Result<(), _>)
450 .unwrap_err()
451 })?;
452 IntervalMonthDayNano {
453 months,
454 days,
455 nanoseconds,
456 }
457 } else {
458 IntervalMonthDayNano {
459 months: v.months.wrapping_neg(),
460 days: v.days.wrapping_neg(),
461 nanoseconds: v.nanoseconds.wrapping_neg(),
462 }
463 };
464 Ok(ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano(
465 Some(result),
466 )))
467 }
468
469 dt => not_impl_err!("Not supported datatype for Spark negative(): {dt}"),
470 },
471 }
472}