1use std::sync::Arc;
19
20use arrow::array::{
21 Array, ArrayRef, AsArray, DurationMicrosecondBuilder, PrimitiveArray,
22};
23use arrow::datatypes::TimeUnit::Microsecond;
24use arrow::datatypes::{DataType, Field, FieldRef, Float64Type, Int32Type};
25use datafusion_common::types::{NativeType, logical_float64, logical_int32};
26use datafusion_common::{
27 DataFusionError, Result, ScalarValue, internal_err, plan_datafusion_err,
28};
29use datafusion_expr::{
30 Coercion, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl,
31 Signature, TypeSignature, TypeSignatureClass, Volatility,
32};
33use datafusion_functions::utils::make_scalar_function;
34
35#[derive(Debug, PartialEq, Eq, Hash)]
36pub struct SparkMakeDtInterval {
37 signature: Signature,
38}
39
40impl Default for SparkMakeDtInterval {
41 fn default() -> Self {
42 Self::new()
43 }
44}
45
46impl SparkMakeDtInterval {
47 pub fn new() -> Self {
48 let int32 = Coercion::new_implicit(
49 TypeSignatureClass::Native(logical_int32()),
50 vec![TypeSignatureClass::Integer],
51 NativeType::Int32,
52 );
53
54 let float64 = Coercion::new_implicit(
55 TypeSignatureClass::Native(logical_float64()),
56 vec![TypeSignatureClass::Numeric],
57 NativeType::Float64,
58 );
59
60 let variants = vec![
61 TypeSignature::Nullary,
62 TypeSignature::Coercible(vec![int32.clone()]),
64 TypeSignature::Coercible(vec![int32.clone(), int32.clone()]),
66 TypeSignature::Coercible(vec![int32.clone(), int32.clone(), int32.clone()]),
68 TypeSignature::Coercible(vec![
70 int32.clone(),
71 int32.clone(),
72 int32.clone(),
73 float64,
74 ]),
75 ];
76
77 Self {
78 signature: Signature::one_of(variants, Volatility::Immutable),
79 }
80 }
81}
82
83impl ScalarUDFImpl for SparkMakeDtInterval {
84 fn name(&self) -> &str {
85 "make_dt_interval"
86 }
87
88 fn signature(&self) -> &Signature {
89 &self.signature
90 }
91
92 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
99 internal_err!("return_field_from_args should be used instead")
100 }
101
102 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
103 let has_non_finite_secs = args
104 .scalar_arguments
105 .get(3)
106 .and_then(|arg| {
107 arg.map(|scalar| match scalar {
108 ScalarValue::Float64(Some(v)) => !v.is_finite(),
109 ScalarValue::Float32(Some(v)) => !v.is_finite(),
110 _ => false,
111 })
112 })
113 .unwrap_or(false);
114 let nullable =
115 has_non_finite_secs || args.arg_fields.iter().any(|f| f.is_nullable());
116 Ok(Arc::new(Field::new(
117 self.name(),
118 DataType::Duration(Microsecond),
119 nullable,
120 )))
121 }
122
123 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
124 if args.args.is_empty() {
125 return Ok(ColumnarValue::Scalar(ScalarValue::DurationMicrosecond(
126 Some(0),
127 )));
128 }
129 if args.args.len() > 4 {
130 return Err(DataFusionError::Execution(format!(
131 "make_dt_interval expects between 0 and 4 arguments, got {}",
132 args.args.len()
133 )));
134 }
135 make_scalar_function(make_dt_interval_kernel, vec![])(&args.args)
136 }
137}
138
139fn make_dt_interval_kernel(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
140 let n_rows = args[0].len();
141 let days = args[0]
142 .as_primitive_opt::<Int32Type>()
143 .ok_or_else(|| plan_datafusion_err!("make_dt_interval arg[0] must be Int32"))?;
144 let hours: Option<&PrimitiveArray<Int32Type>> = args
145 .get(1)
146 .map(|a| {
147 a.as_primitive_opt::<Int32Type>().ok_or_else(|| {
148 plan_datafusion_err!("make_dt_interval arg[1] must be Int32")
149 })
150 })
151 .transpose()?;
152 let mins: Option<&PrimitiveArray<Int32Type>> = args
153 .get(2)
154 .map(|a| {
155 a.as_primitive_opt::<Int32Type>().ok_or_else(|| {
156 plan_datafusion_err!("make_dt_interval arg[2] must be Int32")
157 })
158 })
159 .transpose()?;
160 let secs: Option<&PrimitiveArray<Float64Type>> = args
161 .get(3)
162 .map(|a| {
163 a.as_primitive_opt::<Float64Type>().ok_or_else(|| {
164 plan_datafusion_err!("make_dt_interval arg[3] must be Float64")
165 })
166 })
167 .transpose()?;
168 let mut builder = DurationMicrosecondBuilder::with_capacity(n_rows);
169
170 for i in 0..n_rows {
171 let any_null_present = days.is_null(i)
173 || hours.as_ref().is_some_and(|a| a.is_null(i))
174 || mins.as_ref().is_some_and(|a| a.is_null(i))
175 || secs
176 .as_ref()
177 .is_some_and(|a| a.is_null(i) || !a.value(i).is_finite());
178
179 if any_null_present {
180 builder.append_null();
181 continue;
182 }
183
184 let d = days.value(i);
186 let h = hours.as_ref().map_or(0, |a| a.value(i));
187 let mi = mins.as_ref().map_or(0, |a| a.value(i));
188 let s = secs.as_ref().map_or(0.0, |a| a.value(i));
189
190 match make_interval_dt_nano(d, h, mi, s) {
191 Some(v) => builder.append_value(v),
192 None => {
193 builder.append_null();
194 continue;
195 }
196 }
197 }
198
199 Ok(Arc::new(builder.finish()))
200}
/// Combines day/hour/minute/second components into a signed count of
/// **microseconds**, despite the `_nano` suffix. Microseconds are the unit of
/// the `Duration(Microsecond)` values produced by `make_dt_interval`.
///
/// Components may be negative and are simply summed, so `(1, -24, 0, 0.0)`
/// yields `0`. The fractional part of `sec` is rounded half away from zero to
/// the nearest microsecond. A fraction that rounds to a full second is
/// carried into the whole seconds.
///
/// All arithmetic is done in `i64`. For example, two million days is
/// representable even though its minute count exceeds `i32::MAX`.
///
/// Returns `None` when `sec` is NaN or infinite, or when the result does not
/// fit in an `i64`.
fn make_interval_dt_nano(day: i32, hour: i32, min: i32, sec: f64) -> Option<i64> {
    const HOURS_PER_DAY: i64 = 24;
    const MINS_PER_HOUR: i64 = 60;
    const SECS_PER_MINUTE: i64 = 60;
    const MICROS_PER_SEC: i64 = 1_000_000;
    // 2^63: integral f64 values in [-2^63, 2^63) convert to i64 exactly.
    const I64_LIMIT: f64 = -(i64::MIN as f64);

    if !sec.is_finite() {
        return None;
    }

    // Widen before scaling. An i32 minute total overflows at roughly 1.5
    // million days, far below the i64 microsecond range.
    let total_mins = i64::from(day)
        .checked_mul(HOURS_PER_DAY)?
        .checked_add(i64::from(hour))?
        .checked_mul(MINS_PER_HOUR)?
        .checked_add(i64::from(min))?;

    let whole = sec.trunc();
    // `as` would saturate out-of-range values and leave a huge (possibly
    // i64::MIN) "fractional" remainder; such inputs cannot yield a valid result.
    if !(-I64_LIMIT..I64_LIMIT).contains(&whole) {
        return None;
    }
    let mut sec_whole = whole as i64;
    // |sec - whole| < 1, so the rounded fraction lies in [-1e6, 1e6].
    let mut frac_us = ((sec - whole) * MICROS_PER_SEC as f64).round() as i64;
    if frac_us >= MICROS_PER_SEC {
        frac_us -= MICROS_PER_SEC;
        sec_whole = sec_whole.checked_add(1)?;
    } else if frac_us <= -MICROS_PER_SEC {
        frac_us += MICROS_PER_SEC;
        sec_whole = sec_whole.checked_sub(1)?;
    }

    total_mins
        .checked_mul(SECS_PER_MINUTE)?
        .checked_add(sec_whole)?
        .checked_mul(MICROS_PER_SEC)?
        .checked_add(frac_us)
}
239
#[cfg(test)]
mod tests {
    use arrow::array::{DurationMicrosecondArray, Float64Array, Int32Array};
    use arrow::datatypes::DataType::Duration;
    use arrow::datatypes::TimeUnit::Microsecond;
    use datafusion_common::internal_datafusion_err;

    use super::*;

    /// Wraps optional `i32` values as an `Int32` argument array.
    fn int32s(values: Vec<Option<i32>>) -> ArrayRef {
        Arc::new(Int32Array::from(values))
    }

    /// Wraps optional `f64` values as a `Float64` argument array.
    fn float64s(values: Vec<Option<f64>>) -> ArrayRef {
        Arc::new(Float64Array::from(values))
    }

    /// Calls the kernel directly and downcasts its output to microsecond durations.
    fn run_make_dt_interval(arrs: Vec<ArrayRef>) -> Result<DurationMicrosecondArray> {
        let out = make_dt_interval_kernel(&arrs)?;
        out.as_any()
            .downcast_ref::<DurationMicrosecondArray>()
            .cloned()
            .ok_or_else(|| internal_datafusion_err!("expected DurationMicrosecondArray"))
    }

    /// Argument fields for all four components, sharing one nullability flag.
    fn component_fields(nullable: bool) -> Vec<FieldRef> {
        vec![
            Arc::new(Field::new("days", DataType::Int32, nullable)),
            Arc::new(Field::new("hours", DataType::Int32, nullable)),
            Arc::new(Field::new("mins", DataType::Int32, nullable)),
            Arc::new(Field::new("secs", DataType::Float64, nullable)),
        ]
    }

    /// Invokes the UDF end to end with nullable argument fields.
    fn invoke_make_dt_interval_with_args(
        args: Vec<ColumnarValue>,
        number_rows: usize,
    ) -> Result<ColumnarValue, DataFusionError> {
        let arg_fields = args
            .iter()
            .map(|arg| Field::new("a", arg.data_type(), true).into())
            .collect::<Vec<_>>();
        let args = ScalarFunctionArgs {
            args,
            arg_fields,
            number_rows,
            return_field: Field::new("f", Duration(Microsecond), true).into(),
            config_options: Arc::new(Default::default()),
        };
        SparkMakeDtInterval::new().invoke_with_args(args)
    }

    /// Asserts a two-row, NULL-free result whose rows are both zero.
    fn assert_two_zero_rows(out: &DurationMicrosecondArray) {
        assert_eq!(out.len(), 2);
        assert_eq!(out.null_count(), 0);
        assert_eq!(out.value(0), 0_i64);
        assert_eq!(out.value(1), 0_i64);
    }

    #[test]
    fn nulls_propagate_per_row() -> Result<()> {
        let out = run_make_dt_interval(vec![
            int32s(vec![None, Some(2), Some(3), Some(4), Some(5), Some(6), Some(7)]),
            int32s(vec![Some(1), None, Some(3), Some(4), Some(5), Some(6), Some(7)]),
            int32s(vec![Some(1), Some(2), None, Some(4), Some(5), Some(6), Some(7)]),
            float64s(vec![
                Some(1.0),
                Some(2.0),
                Some(3.0),
                None,
                Some(f64::NAN),
                Some(f64::INFINITY),
                Some(f64::NEG_INFINITY),
            ]),
        ])?;

        for row in 0..out.len() {
            assert!(out.is_null(row), "row {row} should be NULL");
        }
        Ok(())
    }

    #[test]
    fn return_field_respects_nullability() -> Result<()> {
        let udf = SparkMakeDtInterval::new();
        let nullable_fields = component_fields(true);
        let non_nullable_fields = component_fields(false);

        // Nullable inputs produce a nullable output of the expected type.
        let out = udf.return_field_from_args(ReturnFieldArgs {
            arg_fields: &nullable_fields,
            scalar_arguments: &[None, None, None, None],
        })?;
        assert!(out.is_nullable());
        assert_eq!(out.data_type(), &Duration(Microsecond));

        // Non-nullable column inputs produce a non-nullable output.
        let out = udf.return_field_from_args(ReturnFieldArgs {
            arg_fields: &non_nullable_fields,
            scalar_arguments: &[None, None, None, None],
        })?;
        assert!(!out.is_nullable());

        // A literal NaN for `secs` forces a nullable output.
        let nan_secs = ScalarValue::Float64(Some(f64::NAN));
        let out = udf.return_field_from_args(ReturnFieldArgs {
            arg_fields: &non_nullable_fields,
            scalar_arguments: &[None, None, None, Some(&nan_secs)],
        })?;
        assert!(out.is_nullable());

        // No arguments at all: the constant zero duration is never NULL.
        let out = udf.return_field_from_args(ReturnFieldArgs {
            arg_fields: &[],
            scalar_arguments: &[],
        })?;
        assert!(!out.is_nullable());

        Ok(())
    }

    #[test]
    fn days_overflow_should_be_null() -> Result<()> {
        let out = run_make_dt_interval(vec![
            int32s(vec![Some(i32::MAX)]),
            int32s(vec![Some(1)]),
            int32s(vec![Some(1)]),
            float64s(vec![Some(1.0)]),
        ])?;

        for row in 0..out.len() {
            assert!(out.is_null(row), "row {row} should be NULL");
        }
        Ok(())
    }

    #[test]
    fn zero_args_returns_zero_duration() -> Result<()> {
        let number_rows: usize = 3;

        let arr = invoke_make_dt_interval_with_args(vec![], number_rows)?
            .into_array(number_rows)?;
        let arr = arr
            .as_any()
            .downcast_ref::<DurationMicrosecondArray>()
            .ok_or_else(|| internal_datafusion_err!("expected DurationMicrosecondArray"))?;

        assert_eq!(arr.len(), number_rows);
        for row in 0..number_rows {
            assert!(!arr.is_null(row));
            assert_eq!(arr.value(row), 0_i64);
        }
        Ok(())
    }

    #[test]
    fn one_day_minus_24_hours_equals_zero() -> Result<()> {
        let out = run_make_dt_interval(vec![
            int32s(vec![Some(1), Some(-1)]),
            int32s(vec![Some(-24), Some(24)]),
            int32s(vec![Some(0), Some(0)]),
            float64s(vec![Some(0.0), Some(0.0)]),
        ])?;
        assert_two_zero_rows(&out);
        Ok(())
    }

    #[test]
    fn one_hour_minus_60_mins_equals_zero() -> Result<()> {
        let out = run_make_dt_interval(vec![
            int32s(vec![Some(0), Some(0)]),
            int32s(vec![Some(-1), Some(1)]),
            int32s(vec![Some(60), Some(-60)]),
            float64s(vec![Some(0.0), Some(0.0)]),
        ])?;
        assert_two_zero_rows(&out);
        Ok(())
    }

    #[test]
    fn one_mins_minus_60_secs_equals_zero() -> Result<()> {
        let out = run_make_dt_interval(vec![
            int32s(vec![Some(0), Some(0)]),
            int32s(vec![Some(0), Some(0)]),
            int32s(vec![Some(-1), Some(1)]),
            float64s(vec![Some(60.0), Some(-60.0)]),
        ])?;
        assert_two_zero_rows(&out);
        Ok(())
    }

    #[test]
    fn frac_carries_up_to_next_second_positive() -> Result<()> {
        let out = run_make_dt_interval(vec![
            int32s(vec![Some(0), Some(0)]),
            int32s(vec![Some(0), Some(0)]),
            int32s(vec![Some(0), Some(0)]),
            float64s(vec![Some(0.999_999_5), Some(0.999_999_4)]),
        ])?;

        assert_eq!(out.len(), 2);
        assert_eq!(out.value(0), 1_000_000);
        assert_eq!(out.value(1), 999_999);
        Ok(())
    }

    #[test]
    fn frac_carries_down_to_prev_second_negative() -> Result<()> {
        let out = run_make_dt_interval(vec![
            int32s(vec![Some(0), Some(0)]),
            int32s(vec![Some(0), Some(0)]),
            int32s(vec![Some(0), Some(0)]),
            float64s(vec![Some(-0.999_999_5), Some(-0.999_999_4)]),
        ])?;

        assert_eq!(out.len(), 2);
        assert_eq!(out.value(0), -1_000_000);
        assert_eq!(out.value(1), -999_999);
        Ok(())
    }

    #[test]
    fn no_more_than_4_params() -> Result<()> {
        let args = vec![
            ColumnarValue::Scalar(ScalarValue::Int32(Some(1))),
            ColumnarValue::Scalar(ScalarValue::Int32(Some(2))),
            ColumnarValue::Scalar(ScalarValue::Int32(Some(3))),
            ColumnarValue::Scalar(ScalarValue::Float64(Some(4.0))),
            ColumnarValue::Scalar(ScalarValue::Int32(Some(5))),
        ];

        let res = invoke_make_dt_interval_with_args(args, 1);

        assert!(
            matches!(res, Err(DataFusionError::Execution(_))),
            "make_dt_interval should return execution error for more than 4 arguments"
        );
        Ok(())
    }
}