1use arrow::array::types::*;
19use arrow::array::*;
20use arrow::datatypes::{DataType, IntervalDayTime, IntervalMonthDayNano, IntervalUnit};
21use bigdecimal::num_traits::WrappingNeg;
22use datafusion_common::utils::take_function_args;
23use datafusion_common::{Result, ScalarValue, exec_err, not_impl_err};
24use datafusion_expr::{
25 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
26 Volatility,
27};
28use std::any::Any;
29use std::sync::Arc;
30
31#[derive(Debug, PartialEq, Eq, Hash)]
45pub struct SparkNegative {
46 signature: Signature,
47}
48
49impl Default for SparkNegative {
50 fn default() -> Self {
51 Self::new()
52 }
53}
54
55impl SparkNegative {
56 pub fn new() -> Self {
57 Self {
58 signature: Signature {
59 type_signature: TypeSignature::OneOf(vec![
60 TypeSignature::Numeric(1),
62 TypeSignature::Uniform(
64 1,
65 vec![
66 DataType::Interval(IntervalUnit::YearMonth),
67 DataType::Interval(IntervalUnit::DayTime),
68 DataType::Interval(IntervalUnit::MonthDayNano),
69 ],
70 ),
71 ]),
72 volatility: Volatility::Immutable,
73 parameter_names: None,
74 },
75 }
76 }
77}
78
79impl ScalarUDFImpl for SparkNegative {
80 fn as_any(&self) -> &dyn Any {
81 self
82 }
83
84 fn name(&self) -> &str {
85 "negative"
86 }
87
88 fn signature(&self) -> &Signature {
89 &self.signature
90 }
91
92 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
93 Ok(arg_types[0].clone())
94 }
95
96 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
97 spark_negative(&args.args, args.config_options.execution.enable_ansi_mode)
98 }
99}
100
/// Negates every element of a primitive array whose native type is a signed
/// integer (used for `Int8`..`Int64` and `IntervalYearMonth`).
///
/// * `$array` – the input `ArrayRef`, whose data type must match `$type`.
/// * `$type` – the arrow primitive type to view the array as.
/// * `$type_name` – type name used in the overflow error message.
/// * `$enable_ansi_mode` – when true, an element whose negation overflows makes
///   the enclosing function return an execution error; otherwise it wraps.
///
/// Evaluates to `Result<ColumnarValue>` holding the negated array.
macro_rules! impl_integer_array_negative {
    ($array:expr, $type:ty, $type_name:expr, $enable_ansi_mode:expr) => {{
        let input = $array.as_primitive::<$type>();
        let negated: PrimitiveArray<$type> = if !$enable_ansi_mode {
            input.unary(|x| x.wrapping_neg())
        } else {
            input.try_unary(|x| match x.checked_neg() {
                Some(value) => Ok(value),
                None => exec_err!("{} overflow on negative({x})", $type_name),
            })?
        };
        Ok(ColumnarValue::Array(Arc::new(negated)))
    }};
}
119
/// Negates every element of a floating-point primitive array.
///
/// Float negation cannot overflow, so no ANSI-mode check is needed.
///
/// * `$array` – the input `ArrayRef`, whose data type must match `$type`.
/// * `$type` – the arrow float type (`Float16Type`, `Float32Type`, `Float64Type`).
///
/// Evaluates to `Result<ColumnarValue>` holding the negated array.
macro_rules! impl_float_array_negative {
    ($array:expr, $type:ty) => {{
        let input = $array.as_primitive::<$type>();
        let negated: PrimitiveArray<$type> = input.unary(|value| -value);
        Ok(ColumnarValue::Array(Arc::new(negated)))
    }};
}
128
/// Negates every element of a decimal array, preserving its precision and scale.
///
/// * `$array` – the input `ArrayRef`, whose data type must match `$type`.
/// * `$type` – the arrow decimal type (e.g. `Decimal128Type`).
/// * `$type_name` – type name used in the overflow error message.
/// * `$enable_ansi_mode` – when true, an element whose negation overflows the
///   storage type makes the enclosing function return an execution error;
///   otherwise it wraps.
///
/// Evaluates to `Result<ColumnarValue>` holding the negated array.
macro_rules! impl_decimal_array_negative {
    ($array:expr, $type:ty, $type_name:expr, $enable_ansi_mode:expr) => {{
        let array = $array.as_primitive::<$type>();
        let negated: PrimitiveArray<$type> = if $enable_ansi_mode {
            array.try_unary(|x| match x.checked_neg() {
                Some(value) => Ok(value),
                None => exec_err!("{} overflow on negative({x})", $type_name),
            })?
        } else {
            array.unary(|x| x.wrapping_neg())
        };
        // `unary` / `try_unary` stamp the output with the decimal type's
        // default precision/scale. Restore the input's type on BOTH paths so
        // the result matches `return_type` (the argument type).
        let result = negated.with_data_type(array.data_type().clone());
        Ok(ColumnarValue::Array(Arc::new(result)))
    }};
}
149
/// Negates a non-null signed-integer scalar (also used for `IntervalYearMonth`).
///
/// * `$v` – the value, as bound by matching on a `&ScalarValue`.
/// * `$type_name` – type name used in the overflow error message.
/// * `$variant` – the `ScalarValue` variant used to wrap the result.
/// * `$enable_ansi_mode` – when true, overflow returns an execution error from
///   the enclosing function; otherwise the value wraps.
///
/// Evaluates to `Result<ColumnarValue>` holding the negated scalar.
macro_rules! impl_integer_scalar_negative {
    ($v:expr, $type_name:expr, $variant:ident, $enable_ansi_mode:expr) => {{
        let negated = if !$enable_ansi_mode {
            $v.wrapping_neg()
        } else {
            match $v.checked_neg() {
                Some(value) => value,
                None => {
                    return exec_err!("{} overflow on negative({})", $type_name, $v);
                }
            }
        };
        Ok(ColumnarValue::Scalar(ScalarValue::$variant(Some(negated))))
    }};
}
165
/// Negates a non-null decimal scalar, keeping its precision and scale.
///
/// * `$v` – the unscaled value, as bound by matching on a `&ScalarValue`.
/// * `$precision`, `$scale` – references to the scalar's precision and scale.
/// * `$type_name` – type name used in the overflow error message.
/// * `$variant` – the `ScalarValue` decimal variant used to wrap the result.
/// * `$enable_ansi_mode` – when true, overflow returns an execution error from
///   the enclosing function; otherwise the value wraps.
///
/// Evaluates to `Result<ColumnarValue>` holding the negated scalar.
macro_rules! impl_decimal_scalar_negative {
    ($v:expr, $precision:expr, $scale:expr, $type_name:expr, $variant:ident, $enable_ansi_mode:expr) => {{
        let negated = if !$enable_ansi_mode {
            $v.wrapping_neg()
        } else {
            match $v.checked_neg() {
                Some(value) => value,
                None => {
                    return exec_err!("{} overflow on negative({})", $type_name, $v);
                }
            }
        };
        let (precision, scale) = (*$precision, *$scale);
        Ok(ColumnarValue::Scalar(ScalarValue::$variant(
            Some(negated),
            precision,
            scale,
        )))
    }};
}
185
186fn spark_negative(
188 args: &[ColumnarValue],
189 enable_ansi_mode: bool,
190) -> Result<ColumnarValue> {
191 let [arg] = take_function_args("negative", args)?;
192
193 match arg {
194 ColumnarValue::Array(array) => match array.data_type() {
195 DataType::Null => Ok(arg.clone()),
196
197 DataType::Int8 => {
199 impl_integer_array_negative!(array, Int8Type, "Int8", enable_ansi_mode)
200 }
201 DataType::Int16 => {
202 impl_integer_array_negative!(array, Int16Type, "Int16", enable_ansi_mode)
203 }
204 DataType::Int32 => {
205 impl_integer_array_negative!(array, Int32Type, "Int32", enable_ansi_mode)
206 }
207 DataType::Int64 => {
208 impl_integer_array_negative!(array, Int64Type, "Int64", enable_ansi_mode)
209 }
210
211 DataType::Float16 => impl_float_array_negative!(array, Float16Type),
213 DataType::Float32 => impl_float_array_negative!(array, Float32Type),
214 DataType::Float64 => impl_float_array_negative!(array, Float64Type),
215
216 DataType::Decimal32(_, _) => impl_decimal_array_negative!(
218 array,
219 Decimal32Type,
220 "Decimal32",
221 enable_ansi_mode
222 ),
223 DataType::Decimal64(_, _) => impl_decimal_array_negative!(
224 array,
225 Decimal64Type,
226 "Decimal64",
227 enable_ansi_mode
228 ),
229 DataType::Decimal128(_, _) => impl_decimal_array_negative!(
230 array,
231 Decimal128Type,
232 "Decimal128",
233 enable_ansi_mode
234 ),
235 DataType::Decimal256(_, _) => impl_decimal_array_negative!(
236 array,
237 Decimal256Type,
238 "Decimal256",
239 enable_ansi_mode
240 ),
241
242 DataType::Interval(IntervalUnit::YearMonth) => {
244 impl_integer_array_negative!(
245 array,
246 IntervalYearMonthType,
247 "IntervalYearMonth",
248 enable_ansi_mode
249 )
250 }
251 DataType::Interval(IntervalUnit::DayTime) => {
252 let array = array.as_primitive::<IntervalDayTimeType>();
253 let result: PrimitiveArray<IntervalDayTimeType> = if enable_ansi_mode {
254 array.try_unary(|x| {
255 let days = x.days.checked_neg().ok_or_else(|| {
256 (exec_err!(
257 "IntervalDayTime overflow on negative (days: {})",
258 x.days
259 ) as Result<(), _>)
260 .unwrap_err()
261 })?;
262 let milliseconds =
263 x.milliseconds.checked_neg().ok_or_else(|| {
264 (exec_err!(
265 "IntervalDayTime overflow on negative (milliseconds: {})",
266 x.milliseconds
267 ) as Result<(), _>)
268 .unwrap_err()
269 })?;
270 Ok::<_, arrow::error::ArrowError>(IntervalDayTime {
271 days,
272 milliseconds,
273 })
274 })?
275 } else {
276 array.unary(|x| IntervalDayTime {
277 days: x.days.wrapping_neg(),
278 milliseconds: x.milliseconds.wrapping_neg(),
279 })
280 };
281 Ok(ColumnarValue::Array(Arc::new(result)))
282 }
283 DataType::Interval(IntervalUnit::MonthDayNano) => {
284 let array = array.as_primitive::<IntervalMonthDayNanoType>();
285 let result: PrimitiveArray<IntervalMonthDayNanoType> = if enable_ansi_mode
286 {
287 array.try_unary(|x| {
288 let months = x.months.checked_neg().ok_or_else(|| {
289 (exec_err!(
290 "IntervalMonthDayNano overflow on negative (months: {})",
291 x.months
292 ) as Result<(), _>)
293 .unwrap_err()
294 })?;
295 let days = x.days.checked_neg().ok_or_else(|| {
296 (exec_err!(
297 "IntervalMonthDayNano overflow on negative (days: {})",
298 x.days
299 ) as Result<(), _>)
300 .unwrap_err()
301 })?;
302 let nanoseconds = x.nanoseconds.checked_neg().ok_or_else(|| {
303 (exec_err!(
304 "IntervalMonthDayNano overflow on negative (nanoseconds: {})",
305 x.nanoseconds
306 ) as Result<(), _>)
307 .unwrap_err()
308 })?;
309 Ok::<_, arrow::error::ArrowError>(IntervalMonthDayNano {
310 months,
311 days,
312 nanoseconds,
313 })
314 })?
315 } else {
316 array.unary(|x| IntervalMonthDayNano {
317 months: x.months.wrapping_neg(),
318 days: x.days.wrapping_neg(),
319 nanoseconds: x.nanoseconds.wrapping_neg(),
320 })
321 };
322 Ok(ColumnarValue::Array(Arc::new(result)))
323 }
324
325 dt => not_impl_err!("Not supported datatype for Spark negative(): {dt}"),
326 },
327 ColumnarValue::Scalar(sv) => match sv {
328 ScalarValue::Null => Ok(arg.clone()),
329 _ if sv.is_null() => Ok(arg.clone()),
330
331 ScalarValue::Int8(Some(v)) => {
333 impl_integer_scalar_negative!(v, "Int8", Int8, enable_ansi_mode)
334 }
335 ScalarValue::Int16(Some(v)) => {
336 impl_integer_scalar_negative!(v, "Int16", Int16, enable_ansi_mode)
337 }
338 ScalarValue::Int32(Some(v)) => {
339 impl_integer_scalar_negative!(v, "Int32", Int32, enable_ansi_mode)
340 }
341 ScalarValue::Int64(Some(v)) => {
342 impl_integer_scalar_negative!(v, "Int64", Int64, enable_ansi_mode)
343 }
344
345 ScalarValue::Float16(Some(v)) => {
347 Ok(ColumnarValue::Scalar(ScalarValue::Float16(Some(-v))))
348 }
349 ScalarValue::Float32(Some(v)) => {
350 Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(-v))))
351 }
352 ScalarValue::Float64(Some(v)) => {
353 Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(-v))))
354 }
355
356 ScalarValue::Decimal32(Some(v), precision, scale) => {
358 impl_decimal_scalar_negative!(
359 v,
360 precision,
361 scale,
362 "Decimal32",
363 Decimal32,
364 enable_ansi_mode
365 )
366 }
367 ScalarValue::Decimal64(Some(v), precision, scale) => {
368 impl_decimal_scalar_negative!(
369 v,
370 precision,
371 scale,
372 "Decimal64",
373 Decimal64,
374 enable_ansi_mode
375 )
376 }
377 ScalarValue::Decimal128(Some(v), precision, scale) => {
378 impl_decimal_scalar_negative!(
379 v,
380 precision,
381 scale,
382 "Decimal128",
383 Decimal128,
384 enable_ansi_mode
385 )
386 }
387 ScalarValue::Decimal256(Some(v), precision, scale) => {
388 impl_decimal_scalar_negative!(
389 v,
390 precision,
391 scale,
392 "Decimal256",
393 Decimal256,
394 enable_ansi_mode
395 )
396 }
397
398 ScalarValue::IntervalYearMonth(Some(v)) => {
400 impl_integer_scalar_negative!(
401 v,
402 "IntervalYearMonth",
403 IntervalYearMonth,
404 enable_ansi_mode
405 )
406 }
407 ScalarValue::IntervalDayTime(Some(v)) => {
408 let result = if enable_ansi_mode {
409 let days = v.days.checked_neg().ok_or_else(|| {
410 (exec_err!(
411 "IntervalDayTime overflow on negative (days: {})",
412 v.days
413 ) as Result<(), _>)
414 .unwrap_err()
415 })?;
416 let milliseconds = v.milliseconds.checked_neg().ok_or_else(|| {
417 (exec_err!(
418 "IntervalDayTime overflow on negative (milliseconds: {})",
419 v.milliseconds
420 ) as Result<(), _>)
421 .unwrap_err()
422 })?;
423 IntervalDayTime { days, milliseconds }
424 } else {
425 IntervalDayTime {
426 days: v.days.wrapping_neg(),
427 milliseconds: v.milliseconds.wrapping_neg(),
428 }
429 };
430 Ok(ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(
431 result,
432 ))))
433 }
434 ScalarValue::IntervalMonthDayNano(Some(v)) => {
435 let result = if enable_ansi_mode {
436 let months = v.months.checked_neg().ok_or_else(|| {
437 (exec_err!(
438 "IntervalMonthDayNano overflow on negative (months: {})",
439 v.months
440 ) as Result<(), _>)
441 .unwrap_err()
442 })?;
443 let days = v.days.checked_neg().ok_or_else(|| {
444 (exec_err!(
445 "IntervalMonthDayNano overflow on negative (days: {})",
446 v.days
447 ) as Result<(), _>)
448 .unwrap_err()
449 })?;
450 let nanoseconds = v.nanoseconds.checked_neg().ok_or_else(|| {
451 (exec_err!(
452 "IntervalMonthDayNano overflow on negative (nanoseconds: {})",
453 v.nanoseconds
454 ) as Result<(), _>)
455 .unwrap_err()
456 })?;
457 IntervalMonthDayNano {
458 months,
459 days,
460 nanoseconds,
461 }
462 } else {
463 IntervalMonthDayNano {
464 months: v.months.wrapping_neg(),
465 days: v.days.wrapping_neg(),
466 nanoseconds: v.nanoseconds.wrapping_neg(),
467 }
468 };
469 Ok(ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano(
470 Some(result),
471 )))
472 }
473
474 dt => not_impl_err!("Not supported datatype for Spark negative(): {dt}"),
475 },
476 }
477}