1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{
22 Array, ArrayRef, AsArray, DurationMicrosecondBuilder, PrimitiveArray,
23};
24use arrow::datatypes::TimeUnit::Microsecond;
25use arrow::datatypes::{DataType, Field, FieldRef, Float64Type, Int32Type};
26use datafusion_common::types::{NativeType, logical_float64, logical_int32};
27use datafusion_common::{
28 DataFusionError, Result, ScalarValue, internal_err, plan_datafusion_err,
29};
30use datafusion_expr::{
31 Coercion, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl,
32 Signature, TypeSignature, TypeSignatureClass, Volatility,
33};
34use datafusion_functions::utils::make_scalar_function;
35
/// Spark-compatible `make_dt_interval` scalar function.
///
/// Builds a day-time interval (`Duration(Microsecond)`) from up to four
/// arguments: days, hours and minutes as `Int32`, plus fractional seconds
/// as `Float64`.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkMakeDtInterval {
    // Accepted argument signatures: 0 to 4 coercible arguments.
    signature: Signature,
}
40
41impl Default for SparkMakeDtInterval {
42 fn default() -> Self {
43 Self::new()
44 }
45}
46
47impl SparkMakeDtInterval {
48 pub fn new() -> Self {
49 let int32 = Coercion::new_implicit(
50 TypeSignatureClass::Native(logical_int32()),
51 vec![TypeSignatureClass::Integer],
52 NativeType::Int32,
53 );
54
55 let float64 = Coercion::new_implicit(
56 TypeSignatureClass::Native(logical_float64()),
57 vec![TypeSignatureClass::Numeric],
58 NativeType::Float64,
59 );
60
61 let variants = vec![
62 TypeSignature::Nullary,
63 TypeSignature::Coercible(vec![int32.clone()]),
65 TypeSignature::Coercible(vec![int32.clone(), int32.clone()]),
67 TypeSignature::Coercible(vec![int32.clone(), int32.clone(), int32.clone()]),
69 TypeSignature::Coercible(vec![
71 int32.clone(),
72 int32.clone(),
73 int32.clone(),
74 float64,
75 ]),
76 ];
77
78 Self {
79 signature: Signature::one_of(variants, Volatility::Immutable),
80 }
81 }
82}
83
84impl ScalarUDFImpl for SparkMakeDtInterval {
85 fn as_any(&self) -> &dyn Any {
86 self
87 }
88
89 fn name(&self) -> &str {
90 "make_dt_interval"
91 }
92
93 fn signature(&self) -> &Signature {
94 &self.signature
95 }
96
97 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
104 internal_err!("return_field_from_args should be used instead")
105 }
106
107 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
108 let has_non_finite_secs = args
109 .scalar_arguments
110 .get(3)
111 .and_then(|arg| {
112 arg.map(|scalar| match scalar {
113 ScalarValue::Float64(Some(v)) => !v.is_finite(),
114 ScalarValue::Float32(Some(v)) => !v.is_finite(),
115 _ => false,
116 })
117 })
118 .unwrap_or(false);
119 let nullable =
120 has_non_finite_secs || args.arg_fields.iter().any(|f| f.is_nullable());
121 Ok(Arc::new(Field::new(
122 self.name(),
123 DataType::Duration(Microsecond),
124 nullable,
125 )))
126 }
127
128 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
129 if args.args.is_empty() {
130 return Ok(ColumnarValue::Scalar(ScalarValue::DurationMicrosecond(
131 Some(0),
132 )));
133 }
134 if args.args.len() > 4 {
135 return Err(DataFusionError::Execution(format!(
136 "make_dt_interval expects between 0 and 4 arguments, got {}",
137 args.args.len()
138 )));
139 }
140 make_scalar_function(make_dt_interval_kernel, vec![])(&args.args)
141 }
142}
143
144fn make_dt_interval_kernel(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
145 let n_rows = args[0].len();
146 let days = args[0]
147 .as_primitive_opt::<Int32Type>()
148 .ok_or_else(|| plan_datafusion_err!("make_dt_interval arg[0] must be Int32"))?;
149 let hours: Option<&PrimitiveArray<Int32Type>> = args
150 .get(1)
151 .map(|a| {
152 a.as_primitive_opt::<Int32Type>().ok_or_else(|| {
153 plan_datafusion_err!("make_dt_interval arg[1] must be Int32")
154 })
155 })
156 .transpose()?;
157 let mins: Option<&PrimitiveArray<Int32Type>> = args
158 .get(2)
159 .map(|a| {
160 a.as_primitive_opt::<Int32Type>().ok_or_else(|| {
161 plan_datafusion_err!("make_dt_interval arg[2] must be Int32")
162 })
163 })
164 .transpose()?;
165 let secs: Option<&PrimitiveArray<Float64Type>> = args
166 .get(3)
167 .map(|a| {
168 a.as_primitive_opt::<Float64Type>().ok_or_else(|| {
169 plan_datafusion_err!("make_dt_interval arg[3] must be Float64")
170 })
171 })
172 .transpose()?;
173 let mut builder = DurationMicrosecondBuilder::with_capacity(n_rows);
174
175 for i in 0..n_rows {
176 let any_null_present = days.is_null(i)
178 || hours.as_ref().is_some_and(|a| a.is_null(i))
179 || mins.as_ref().is_some_and(|a| a.is_null(i))
180 || secs
181 .as_ref()
182 .is_some_and(|a| a.is_null(i) || !a.value(i).is_finite());
183
184 if any_null_present {
185 builder.append_null();
186 continue;
187 }
188
189 let d = days.value(i);
191 let h = hours.as_ref().map_or(0, |a| a.value(i));
192 let mi = mins.as_ref().map_or(0, |a| a.value(i));
193 let s = secs.as_ref().map_or(0.0, |a| a.value(i));
194
195 match make_interval_dt_nano(d, h, mi, s) {
196 Some(v) => builder.append_value(v),
197 None => {
198 builder.append_null();
199 continue;
200 }
201 }
202 }
203
204 Ok(Arc::new(builder.finish()))
205}
/// Combines day/hour/minute/second components into a total microsecond
/// count. Despite the `_nano` suffix, the result unit is **microseconds**
/// (matching `Duration(Microsecond)` produced by the kernel).
///
/// Day/hour/minute accumulation uses checked `i32` arithmetic before
/// widening, and the fractional part of `sec` is rounded to the nearest
/// microsecond with a carry into the whole-second count when rounding
/// reaches a full second. Returns `None` on any intermediate overflow.
fn make_interval_dt_nano(day: i32, hour: i32, min: i32, sec: f64) -> Option<i64> {
    const HOURS_PER_DAY: i32 = 24;
    const MINS_PER_HOUR: i32 = 60;
    const SECS_PER_MINUTE: i64 = 60;
    const MICROS_PER_SEC: i64 = 1_000_000;

    let hours_total = day.checked_mul(HOURS_PER_DAY)?.checked_add(hour)?;
    let mins_total = hours_total.checked_mul(MINS_PER_HOUR)?.checked_add(min)?;

    // Split seconds into a whole part and a sub-second part in microseconds.
    let mut whole_secs = sec.trunc() as i64;
    let mut micros =
        ((sec - whole_secs as f64) * MICROS_PER_SEC as f64).round() as i64;

    // Rounding the fraction can produce a full ±second; carry it over so the
    // microsecond remainder stays strictly below one second in magnitude.
    if micros >= MICROS_PER_SEC {
        micros -= MICROS_PER_SEC;
        whole_secs = whole_secs.checked_add(1)?;
    } else if micros <= -MICROS_PER_SEC {
        micros += MICROS_PER_SEC;
        whole_secs = whole_secs.checked_sub(1)?;
    }

    (mins_total as i64)
        .checked_mul(SECS_PER_MINUTE)?
        .checked_add(whole_secs)?
        .checked_mul(MICROS_PER_SEC)?
        .checked_add(micros)
}
244
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{DurationMicrosecondArray, Float64Array, Int32Array};
    use arrow::datatypes::DataType::Duration;
    use arrow::datatypes::{DataType, Field, TimeUnit::Microsecond};
    use datafusion_common::{DataFusionError, Result, internal_datafusion_err};
    use datafusion_expr::{ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs};

    use super::*;

    /// Runs the row-wise kernel directly on raw argument arrays.
    fn run_make_dt_interval(arrs: Vec<ArrayRef>) -> Result<ArrayRef> {
        make_dt_interval_kernel(&arrs)
    }

    /// Downcasts a kernel/UDF result to `DurationMicrosecondArray`.
    fn as_duration_array(arr: &ArrayRef) -> Result<&DurationMicrosecondArray> {
        arr.as_any()
            .downcast_ref::<DurationMicrosecondArray>()
            .ok_or_else(|| {
                internal_datafusion_err!("expected DurationMicrosecondArray")
            })
    }

    #[test]
    fn nulls_propagate_per_row() -> Result<()> {
        // Exactly one "bad" value per row: a NULL in each of the four
        // arguments, then NaN, +inf and -inf seconds. All rows must be NULL.
        let days = Arc::new(Int32Array::from(vec![
            None,
            Some(2),
            Some(3),
            Some(4),
            Some(5),
            Some(6),
            Some(7),
        ])) as ArrayRef;

        let hours = Arc::new(Int32Array::from(vec![
            Some(1),
            None,
            Some(3),
            Some(4),
            Some(5),
            Some(6),
            Some(7),
        ])) as ArrayRef;

        let mins = Arc::new(Int32Array::from(vec![
            Some(1),
            Some(2),
            None,
            Some(4),
            Some(5),
            Some(6),
            Some(7),
        ])) as ArrayRef;

        let secs = Arc::new(Float64Array::from(vec![
            Some(1.0),
            Some(2.0),
            Some(3.0),
            None,
            Some(f64::NAN),
            Some(f64::INFINITY),
            Some(f64::NEG_INFINITY),
        ])) as ArrayRef;

        let out = run_make_dt_interval(vec![days, hours, mins, secs])?;
        let out = as_duration_array(&out)?;

        for i in 0..out.len() {
            assert!(out.is_null(i), "row {i} should be NULL");
        }
        Ok(())
    }

    #[test]
    fn return_field_respects_nullability() -> Result<()> {
        let udf = SparkMakeDtInterval::new();

        // Any nullable input field makes the result nullable.
        let arg_fields = vec![
            Arc::new(Field::new("days", DataType::Int32, true)),
            Arc::new(Field::new("hours", DataType::Int32, true)),
            Arc::new(Field::new("mins", DataType::Int32, true)),
            Arc::new(Field::new("secs", DataType::Float64, true)),
        ];

        let out = udf.return_field_from_args(ReturnFieldArgs {
            arg_fields: &arg_fields,
            scalar_arguments: &[None, None, None, None],
        })?;
        assert!(out.is_nullable());
        assert_eq!(out.data_type(), &Duration(Microsecond));

        // All-non-nullable inputs give a non-nullable result.
        let non_nullable_arg_fields = vec![
            Arc::new(Field::new("days", DataType::Int32, false)),
            Arc::new(Field::new("hours", DataType::Int32, false)),
            Arc::new(Field::new("mins", DataType::Int32, false)),
            Arc::new(Field::new("secs", DataType::Float64, false)),
        ];

        let out = udf.return_field_from_args(ReturnFieldArgs {
            arg_fields: &non_nullable_arg_fields,
            scalar_arguments: &[None, None, None, None],
        })?;
        assert!(!out.is_nullable());

        // A literal non-finite seconds argument forces nullability even
        // when every input field is non-nullable.
        let scalar_values =
            [None, None, None, Some(ScalarValue::Float64(Some(f64::NAN)))];
        let scalar_refs = scalar_values.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
        let out = udf.return_field_from_args(ReturnFieldArgs {
            arg_fields: &non_nullable_arg_fields,
            scalar_arguments: &scalar_refs,
        })?;
        assert!(out.is_nullable());

        // Zero arguments: the constant zero interval is never NULL.
        let out = udf.return_field_from_args(ReturnFieldArgs {
            arg_fields: &[],
            scalar_arguments: &[],
        })?;
        assert!(!out.is_nullable());

        Ok(())
    }

    #[test]
    fn days_overflow_should_be_null() -> Result<()> {
        // i32::MAX days overflows the checked hour accumulation; the row
        // must come out NULL rather than error or wrap around.
        let days = Arc::new(Int32Array::from(vec![Some(i32::MAX)])) as ArrayRef;
        let hours = Arc::new(Int32Array::from(vec![Some(1)])) as ArrayRef;
        let mins = Arc::new(Int32Array::from(vec![Some(1)])) as ArrayRef;
        let secs = Arc::new(Float64Array::from(vec![Some(1.0)])) as ArrayRef;

        let out = run_make_dt_interval(vec![days, hours, mins, secs])?;
        let out = as_duration_array(&out)?;

        for i in 0..out.len() {
            assert!(out.is_null(i), "row {i} should be NULL");
        }

        Ok(())
    }

    /// Builds `ScalarFunctionArgs` around `args` and invokes the UDF
    /// end-to-end.
    fn invoke_make_dt_interval_with_args(
        args: Vec<ColumnarValue>,
        number_rows: usize,
    ) -> Result<ColumnarValue, DataFusionError> {
        let arg_fields = args
            .iter()
            .map(|arg| Field::new("a", arg.data_type(), true).into())
            .collect::<Vec<_>>();
        let args = ScalarFunctionArgs {
            args,
            arg_fields,
            number_rows,
            return_field: Field::new("f", Duration(Microsecond), true).into(),
            config_options: Arc::new(Default::default()),
        };
        SparkMakeDtInterval::new().invoke_with_args(args)
    }

    #[test]
    fn zero_args_returns_zero_duration() -> Result<()> {
        let number_rows: usize = 3;

        let res: ColumnarValue = invoke_make_dt_interval_with_args(vec![], number_rows)?;
        let arr = res.into_array(number_rows)?;
        let arr = as_duration_array(&arr)?;

        assert_eq!(arr.len(), number_rows);
        for i in 0..number_rows {
            assert!(!arr.is_null(i));
            assert_eq!(arr.value(i), 0_i64);
        }
        Ok(())
    }

    #[test]
    fn one_day_minus_24_hours_equals_zero() -> Result<()> {
        let arr_days = Arc::new(Int32Array::from(vec![Some(1), Some(-1)])) as ArrayRef;
        let arr_hours = Arc::new(Int32Array::from(vec![Some(-24), Some(24)])) as ArrayRef;
        let arr_mins = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef;
        let arr_secs =
            Arc::new(Float64Array::from(vec![Some(0.0), Some(0.0)])) as ArrayRef;

        let out = run_make_dt_interval(vec![arr_days, arr_hours, arr_mins, arr_secs])?;
        let out = as_duration_array(&out)?;

        assert_eq!(out.len(), 2);
        assert_eq!(out.null_count(), 0);
        assert_eq!(out.value(0), 0_i64);
        assert_eq!(out.value(1), 0_i64);
        Ok(())
    }

    #[test]
    fn one_hour_minus_60_mins_equals_zero() -> Result<()> {
        let arr_days = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef;
        let arr_hours = Arc::new(Int32Array::from(vec![Some(-1), Some(1)])) as ArrayRef;
        let arr_mins = Arc::new(Int32Array::from(vec![Some(60), Some(-60)])) as ArrayRef;
        let arr_secs =
            Arc::new(Float64Array::from(vec![Some(0.0), Some(0.0)])) as ArrayRef;

        let out = run_make_dt_interval(vec![arr_days, arr_hours, arr_mins, arr_secs])?;
        let out = as_duration_array(&out)?;

        assert_eq!(out.len(), 2);
        assert_eq!(out.null_count(), 0);
        assert_eq!(out.value(0), 0_i64);
        assert_eq!(out.value(1), 0_i64);
        Ok(())
    }

    #[test]
    fn one_mins_minus_60_secs_equals_zero() -> Result<()> {
        let arr_days = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef;
        let arr_hours = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef;
        let arr_mins = Arc::new(Int32Array::from(vec![Some(-1), Some(1)])) as ArrayRef;
        let arr_secs =
            Arc::new(Float64Array::from(vec![Some(60.0), Some(-60.0)])) as ArrayRef;

        let out = run_make_dt_interval(vec![arr_days, arr_hours, arr_mins, arr_secs])?;
        let out = as_duration_array(&out)?;

        assert_eq!(out.len(), 2);
        assert_eq!(out.null_count(), 0);
        assert_eq!(out.value(0), 0_i64);
        assert_eq!(out.value(1), 0_i64);
        Ok(())
    }

    #[test]
    fn frac_carries_up_to_next_second_positive() -> Result<()> {
        // 0.9999995 s rounds up to a full second; 0.9999994 s stays below.
        let days = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef;
        let hours = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef;
        let mins = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef;
        let secs = Arc::new(Float64Array::from(vec![
            Some(0.999_999_5),
            Some(0.999_999_4),
        ])) as ArrayRef;

        let out = run_make_dt_interval(vec![days, hours, mins, secs])?;
        let out = as_duration_array(&out)?;

        assert_eq!(out.len(), 2);
        assert_eq!(out.value(0), 1_000_000);
        assert_eq!(out.value(1), 999_999);
        Ok(())
    }

    #[test]
    fn frac_carries_down_to_prev_second_negative() -> Result<()> {
        // Mirror of the positive case: rounding carries toward -1 s.
        let days = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef;
        let hours = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef;
        let mins = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef;
        let secs = Arc::new(Float64Array::from(vec![
            Some(-0.999_999_5),
            Some(-0.999_999_4),
        ])) as ArrayRef;

        let out = run_make_dt_interval(vec![days, hours, mins, secs])?;
        let out = as_duration_array(&out)?;

        assert_eq!(out.len(), 2);
        assert_eq!(out.value(0), -1_000_000);
        assert_eq!(out.value(1), -999_999);
        Ok(())
    }

    #[test]
    fn no_more_than_4_params() -> Result<()> {
        let udf = SparkMakeDtInterval::new();

        let args = vec![
            ColumnarValue::Scalar(ScalarValue::Int32(Some(1))),
            ColumnarValue::Scalar(ScalarValue::Int32(Some(2))),
            ColumnarValue::Scalar(ScalarValue::Int32(Some(3))),
            ColumnarValue::Scalar(ScalarValue::Float64(Some(4.0))),
            ColumnarValue::Scalar(ScalarValue::Int32(Some(5))),
        ];

        let arg_fields = args
            .iter()
            .map(|arg| Field::new("a", arg.data_type(), true).into())
            .collect::<Vec<_>>();

        let func_args = ScalarFunctionArgs {
            args,
            arg_fields,
            number_rows: 1,
            return_field: Field::new("f", Duration(Microsecond), true).into(),
            config_options: Arc::new(Default::default()),
        };

        let res = udf.invoke_with_args(func_args);

        assert!(
            matches!(res, Err(DataFusionError::Execution(_))),
            "make_dt_interval should return execution error for more than 4 arguments"
        );

        Ok(())
    }
}