// datafusion_spark/function/datetime/make_dt_interval.rs
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use std::any::Any;
use std::sync::Arc;

use arrow::array::{
    Array, ArrayRef, AsArray, DurationMicrosecondBuilder, PrimitiveArray,
};
use arrow::datatypes::TimeUnit::Microsecond;
use arrow::datatypes::{DataType, Field, FieldRef, Float64Type, Int32Type};
use datafusion_common::types::{NativeType, logical_float64, logical_int32};
use datafusion_common::{
    DataFusionError, Result, ScalarValue, internal_err, plan_datafusion_err,
};
use datafusion_expr::{
    Coercion, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl,
    Signature, TypeSignature, TypeSignatureClass, Volatility,
};
use datafusion_functions::utils::make_scalar_function;

36#[derive(Debug, PartialEq, Eq, Hash)]
37pub struct SparkMakeDtInterval {
38    signature: Signature,
39}
41impl Default for SparkMakeDtInterval {
42    fn default() -> Self {
43        Self::new()
44    }
45}
47impl SparkMakeDtInterval {
48    pub fn new() -> Self {
49        let int32 = Coercion::new_implicit(
50            TypeSignatureClass::Native(logical_int32()),
51            vec![TypeSignatureClass::Integer],
52            NativeType::Int32,
53        );
54
55        let float64 = Coercion::new_implicit(
56            TypeSignatureClass::Native(logical_float64()),
57            vec![TypeSignatureClass::Numeric],
58            NativeType::Float64,
59        );
60
61        let variants = vec![
62            TypeSignature::Nullary,
63            // (days)
64            TypeSignature::Coercible(vec![int32.clone()]),
65            // (days, hours)
66            TypeSignature::Coercible(vec![int32.clone(), int32.clone()]),
67            // (days, hours, minutes)
68            TypeSignature::Coercible(vec![int32.clone(), int32.clone(), int32.clone()]),
69            // (days, hours, minutes, seconds)
70            TypeSignature::Coercible(vec![
71                int32.clone(),
72                int32.clone(),
73                int32.clone(),
74                float64,
75            ]),
76        ];
77
78        Self {
79            signature: Signature::one_of(variants, Volatility::Immutable),
80        }
81    }
82}
84impl ScalarUDFImpl for SparkMakeDtInterval {
85    fn as_any(&self) -> &dyn Any {
86        self
87    }
88
89    fn name(&self) -> &str {
90        "make_dt_interval"
91    }
92
93    fn signature(&self) -> &Signature {
94        &self.signature
95    }
96
97    /// Note the return type is `DataType::Duration(TimeUnit::Microsecond)` and not `DataType::Interval(DayTime)` as you might expect.
98    /// This is because `DataType::Interval(DayTime)` has precision only to the millisecond, whilst Spark's `DayTimeIntervalType` has
99    /// precision to the microsecond. We use `DataType::Duration(TimeUnit::Microsecond)` in order to not lose any precision. See the
100    /// [Sail compatibility doc] for reference.
101    ///
102    /// [Sail compatibility doc]: https://github.com/lakehq/sail/blob/dc5368daa24d40a7758a299e1ba8fc985cb29108/docs/guide/dataframe/data-types/compatibility.md?plain=1#L260
103    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
104        internal_err!("return_field_from_args should be used instead")
105    }
106
107    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
108        let has_non_finite_secs = args
109            .scalar_arguments
110            .get(3)
111            .and_then(|arg| {
112                arg.map(|scalar| match scalar {
113                    ScalarValue::Float64(Some(v)) => !v.is_finite(),
114                    ScalarValue::Float32(Some(v)) => !v.is_finite(),
115                    _ => false,
116                })
117            })
118            .unwrap_or(false);
119        let nullable =
120            has_non_finite_secs || args.arg_fields.iter().any(|f| f.is_nullable());
121        Ok(Arc::new(Field::new(
122            self.name(),
123            DataType::Duration(Microsecond),
124            nullable,
125        )))
126    }
127
128    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
129        if args.args.is_empty() {
130            return Ok(ColumnarValue::Scalar(ScalarValue::DurationMicrosecond(
131                Some(0),
132            )));
133        }
134        if args.args.len() > 4 {
135            return Err(DataFusionError::Execution(format!(
136                "make_dt_interval expects between 0 and 4 arguments, got {}",
137                args.args.len()
138            )));
139        }
140        make_scalar_function(make_dt_interval_kernel, vec![])(&args.args)
141    }
142}
144fn make_dt_interval_kernel(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
145    let n_rows = args[0].len();
146    let days = args[0]
147        .as_primitive_opt::<Int32Type>()
148        .ok_or_else(|| plan_datafusion_err!("make_dt_interval arg[0] must be Int32"))?;
149    let hours: Option<&PrimitiveArray<Int32Type>> = args
150        .get(1)
151        .map(|a| {
152            a.as_primitive_opt::<Int32Type>().ok_or_else(|| {
153                plan_datafusion_err!("make_dt_interval arg[1] must be Int32")
154            })
155        })
156        .transpose()?;
157    let mins: Option<&PrimitiveArray<Int32Type>> = args
158        .get(2)
159        .map(|a| {
160            a.as_primitive_opt::<Int32Type>().ok_or_else(|| {
161                plan_datafusion_err!("make_dt_interval arg[2] must be Int32")
162            })
163        })
164        .transpose()?;
165    let secs: Option<&PrimitiveArray<Float64Type>> = args
166        .get(3)
167        .map(|a| {
168            a.as_primitive_opt::<Float64Type>().ok_or_else(|| {
169                plan_datafusion_err!("make_dt_interval arg[3] must be Float64")
170            })
171        })
172        .transpose()?;
173    let mut builder = DurationMicrosecondBuilder::with_capacity(n_rows);
174
175    for i in 0..n_rows {
176        // if one column is NULL → result NULL
177        let any_null_present = days.is_null(i)
178            || hours.as_ref().is_some_and(|a| a.is_null(i))
179            || mins.as_ref().is_some_and(|a| a.is_null(i))
180            || secs
181                .as_ref()
182                .is_some_and(|a| a.is_null(i) || !a.value(i).is_finite());
183
184        if any_null_present {
185            builder.append_null();
186            continue;
187        }
188
189        // default values 0 or 0.0
190        let d = days.value(i);
191        let h = hours.as_ref().map_or(0, |a| a.value(i));
192        let mi = mins.as_ref().map_or(0, |a| a.value(i));
193        let s = secs.as_ref().map_or(0.0, |a| a.value(i));
194
195        match make_interval_dt_nano(d, h, mi, s) {
196            Some(v) => builder.append_value(v),
197            None => {
198                builder.append_null();
199                continue;
200            }
201        }
202    }
203
204    Ok(Arc::new(builder.finish()))
205}
/// Combines day/hour/minute/second components into a single duration.
///
/// Despite the `_nano` suffix (kept for call-site stability) the result is
/// in **microseconds**. Returns `None` when any intermediate step overflows
/// (i32 for the day/hour/minute folding, i64 for the final total).
fn make_interval_dt_nano(day: i32, hour: i32, min: i32, sec: f64) -> Option<i64> {
    const MICROS_PER_SEC: i64 = 1_000_000;

    // Fold days and hours into minutes with i32 overflow checks.
    let total_mins: i32 = day
        .checked_mul(24)
        .and_then(|h| h.checked_add(hour))?
        .checked_mul(60)
        .and_then(|m| m.checked_add(min))?;

    // Split the float seconds into a whole part and rounded fractional
    // microseconds, carrying into the whole part when rounding reaches a
    // full second in either direction.
    let mut whole_secs: i64 = sec.trunc() as i64;
    let mut frac_micros: i64 =
        ((sec - whole_secs as f64) * (MICROS_PER_SEC as f64)).round() as i64;
    if frac_micros >= MICROS_PER_SEC {
        frac_micros -= MICROS_PER_SEC;
        whole_secs = whole_secs.checked_add(1)?;
    } else if frac_micros <= -MICROS_PER_SEC {
        frac_micros += MICROS_PER_SEC;
        whole_secs = whole_secs.checked_sub(1)?;
    }

    // Total microseconds, overflow-checked in i64 at each step.
    (total_mins as i64)
        .checked_mul(60)
        .and_then(|s| s.checked_add(whole_secs))
        .and_then(|s| s.checked_mul(MICROS_PER_SEC))
        .and_then(|us| us.checked_add(frac_micros))
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{DurationMicrosecondArray, Float64Array, Int32Array};
    use arrow::datatypes::DataType::Duration;
    use arrow::datatypes::{DataType, Field, TimeUnit::Microsecond};
    use datafusion_common::{DataFusionError, Result, internal_datafusion_err};
    use datafusion_expr::{ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs};

    use super::*;

    /// Convenience wrapper: run the kernel directly over raw Arrow arrays.
    fn run_make_dt_interval(arrs: Vec<ArrayRef>) -> Result<ArrayRef> {
        make_dt_interval_kernel(&arrs)
    }

    /// A NULL in any column, or a non-finite seconds value, makes the row NULL.
    #[test]
    fn nulls_propagate_per_row() -> Result<()> {
        let days = Arc::new(Int32Array::from(vec![
            None,
            Some(2),
            Some(3),
            Some(4),
            Some(5),
            Some(6),
            Some(7),
        ])) as ArrayRef;

        let hours = Arc::new(Int32Array::from(vec![
            Some(1),
            None,
            Some(3),
            Some(4),
            Some(5),
            Some(6),
            Some(7),
        ])) as ArrayRef;

        let mins = Arc::new(Int32Array::from(vec![
            Some(1),
            Some(2),
            None,
            Some(4),
            Some(5),
            Some(6),
            Some(7),
        ])) as ArrayRef;

        let secs = Arc::new(Float64Array::from(vec![
            Some(1.0),
            Some(2.0),
            Some(3.0),
            None,
            Some(f64::NAN),
            Some(f64::INFINITY),
            Some(f64::NEG_INFINITY),
        ])) as ArrayRef;

        let out = run_make_dt_interval(vec![days, hours, mins, secs])?;
        let out = out
            .as_any()
            .downcast_ref::<DurationMicrosecondArray>()
            .ok_or_else(|| {
                internal_datafusion_err!("expected DurationMicrosecondArray")
            })?;

        for i in 0..out.len() {
            assert!(out.is_null(i), "row {i} should be NULL");
        }
        Ok(())
    }

    #[test]
    fn return_field_respects_nullability() -> Result<()> {
        let udf = SparkMakeDtInterval::new();

        // All nullable inputs -> nullable output
        let arg_fields = vec![
            Arc::new(Field::new("days", DataType::Int32, true)),
            Arc::new(Field::new("hours", DataType::Int32, true)),
            Arc::new(Field::new("mins", DataType::Int32, true)),
            Arc::new(Field::new("secs", DataType::Float64, true)),
        ];

        let out = udf.return_field_from_args(ReturnFieldArgs {
            arg_fields: &arg_fields,
            scalar_arguments: &[None, None, None, None],
        })?;
        assert!(out.is_nullable());
        assert_eq!(out.data_type(), &Duration(Microsecond));

        // Non-nullable inputs -> non-nullable output
        let non_nullable_arg_fields = vec![
            Arc::new(Field::new("days", DataType::Int32, false)),
            Arc::new(Field::new("hours", DataType::Int32, false)),
            Arc::new(Field::new("mins", DataType::Int32, false)),
            Arc::new(Field::new("secs", DataType::Float64, false)),
        ];

        let out = udf.return_field_from_args(ReturnFieldArgs {
            arg_fields: &non_nullable_arg_fields,
            scalar_arguments: &[None, None, None, None],
        })?;
        assert!(!out.is_nullable());

        // Non-finite secs scalar should force nullable even if fields are non-nullable
        let scalar_values =
            [None, None, None, Some(ScalarValue::Float64(Some(f64::NAN)))];
        let scalar_refs = scalar_values.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
        let out = udf.return_field_from_args(ReturnFieldArgs {
            arg_fields: &non_nullable_arg_fields,
            scalar_arguments: &scalar_refs,
        })?;
        assert!(out.is_nullable());

        // Zero-arg call (defaults) should also be non-nullable
        let out = udf.return_field_from_args(ReturnFieldArgs {
            arg_fields: &[],
            scalar_arguments: &[],
        })?;
        assert!(!out.is_nullable());

        Ok(())
    }

    /// Overflow while folding the components yields NULL rather than an error.
    #[test]
    fn days_overflow_should_be_null() -> Result<()> {
        // days * 24 overflows i32 → the row becomes NULL.
        let days = Arc::new(Int32Array::from(vec![Some(i32::MAX)])) as ArrayRef;

        let hours = Arc::new(Int32Array::from(vec![Some(1)])) as ArrayRef;

        let mins = Arc::new(Int32Array::from(vec![Some(1)])) as ArrayRef;

        let secs = Arc::new(Float64Array::from(vec![Some(1.0)])) as ArrayRef;

        let out = run_make_dt_interval(vec![days, hours, mins, secs])?;
        let out = out
            .as_any()
            .downcast_ref::<DurationMicrosecondArray>()
            .ok_or_else(|| {
                internal_datafusion_err!("expected DurationMicrosecondArray")
            })?;

        for i in 0..out.len() {
            assert!(out.is_null(i), "row {i} should be NULL");
        }

        Ok(())
    }

    /// Builds `ScalarFunctionArgs` around `args` and invokes the UDF.
    fn invoke_make_dt_interval_with_args(
        args: Vec<ColumnarValue>,
        number_rows: usize,
    ) -> Result<ColumnarValue, DataFusionError> {
        let arg_fields = args
            .iter()
            .map(|arg| Field::new("a", arg.data_type(), true).into())
            .collect::<Vec<_>>();
        let args = ScalarFunctionArgs {
            args,
            arg_fields,
            number_rows,
            return_field: Field::new("f", Duration(Microsecond), true).into(),
            config_options: Arc::new(Default::default()),
        };
        SparkMakeDtInterval::new().invoke_with_args(args)
    }

    #[test]
    fn zero_args_returns_zero_duration() -> Result<()> {
        let number_rows: usize = 3;

        let res: ColumnarValue = invoke_make_dt_interval_with_args(vec![], number_rows)?;
        let arr = res.into_array(number_rows)?;
        let arr = arr
            .as_any()
            .downcast_ref::<DurationMicrosecondArray>()
            .ok_or_else(|| {
                internal_datafusion_err!("expected DurationMicrosecondArray")
            })?;

        assert_eq!(arr.len(), number_rows);
        for i in 0..number_rows {
            assert!(!arr.is_null(i));
            assert_eq!(arr.value(i), 0_i64);
        }
        Ok(())
    }

    #[test]
    fn one_day_minus_24_hours_equals_zero() -> Result<()> {
        let arr_days = Arc::new(Int32Array::from(vec![Some(1), Some(-1)])) as ArrayRef;
        let arr_hours = Arc::new(Int32Array::from(vec![Some(-24), Some(24)])) as ArrayRef;
        let arr_mins = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef;
        let arr_secs =
            Arc::new(Float64Array::from(vec![Some(0.0), Some(0.0)])) as ArrayRef;

        let out = run_make_dt_interval(vec![arr_days, arr_hours, arr_mins, arr_secs])?;
        let out = out
            .as_any()
            .downcast_ref::<DurationMicrosecondArray>()
            .ok_or_else(|| {
                internal_datafusion_err!("expected DurationMicrosecondArray")
            })?;

        assert_eq!(out.len(), 2);
        assert_eq!(out.null_count(), 0);
        assert_eq!(out.value(0), 0_i64);
        assert_eq!(out.value(1), 0_i64);
        Ok(())
    }

    #[test]
    fn one_hour_minus_60_mins_equals_zero() -> Result<()> {
        let arr_days = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef;
        let arr_hours = Arc::new(Int32Array::from(vec![Some(-1), Some(1)])) as ArrayRef;
        let arr_mins = Arc::new(Int32Array::from(vec![Some(60), Some(-60)])) as ArrayRef;
        let arr_secs =
            Arc::new(Float64Array::from(vec![Some(0.0), Some(0.0)])) as ArrayRef;

        let out = run_make_dt_interval(vec![arr_days, arr_hours, arr_mins, arr_secs])?;
        let out = out
            .as_any()
            .downcast_ref::<DurationMicrosecondArray>()
            .ok_or_else(|| {
                internal_datafusion_err!("expected DurationMicrosecondArray")
            })?;

        assert_eq!(out.len(), 2);
        assert_eq!(out.null_count(), 0);
        assert_eq!(out.value(0), 0_i64);
        assert_eq!(out.value(1), 0_i64);
        Ok(())
    }

    #[test]
    fn one_mins_minus_60_secs_equals_zero() -> Result<()> {
        let arr_days = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef;
        let arr_hours = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef;
        let arr_mins = Arc::new(Int32Array::from(vec![Some(-1), Some(1)])) as ArrayRef;
        let arr_secs =
            Arc::new(Float64Array::from(vec![Some(60.0), Some(-60.0)])) as ArrayRef;

        let out = run_make_dt_interval(vec![arr_days, arr_hours, arr_mins, arr_secs])?;
        let out = out
            .as_any()
            .downcast_ref::<DurationMicrosecondArray>()
            .ok_or_else(|| {
                internal_datafusion_err!("expected DurationMicrosecondArray")
            })?;

        assert_eq!(out.len(), 2);
        assert_eq!(out.null_count(), 0);
        assert_eq!(out.value(0), 0_i64);
        assert_eq!(out.value(1), 0_i64);
        Ok(())
    }

    #[test]
    fn frac_carries_up_to_next_second_positive() -> Result<()> {
        // 0.9999995s → 1_000_000 µs (carry a +1s)
        let days = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef;
        let hours = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef;
        let mins = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef;
        let secs = Arc::new(Float64Array::from(vec![
            Some(0.999_999_5),
            Some(0.999_999_4),
        ])) as ArrayRef;

        let out = run_make_dt_interval(vec![days, hours, mins, secs])?;
        let out = out
            .as_any()
            .downcast_ref::<DurationMicrosecondArray>()
            .ok_or_else(|| {
                internal_datafusion_err!("expected DurationMicrosecondArray")
            })?;

        assert_eq!(out.len(), 2);
        assert_eq!(out.value(0), 1_000_000);
        assert_eq!(out.value(1), 999_999);
        Ok(())
    }

    #[test]
    fn frac_carries_down_to_prev_second_negative() -> Result<()> {
        // -0.9999995s → -1_000_000 µs (carry a −1s)
        let days = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef;
        let hours = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef;
        let mins = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef;
        let secs = Arc::new(Float64Array::from(vec![
            Some(-0.999_999_5),
            Some(-0.999_999_4),
        ])) as ArrayRef;

        let out = run_make_dt_interval(vec![days, hours, mins, secs])?;
        let out = out
            .as_any()
            .downcast_ref::<DurationMicrosecondArray>()
            .ok_or_else(|| {
                internal_datafusion_err!("expected DurationMicrosecondArray")
            })?;

        assert_eq!(out.len(), 2);
        assert_eq!(out.value(0), -1_000_000);
        assert_eq!(out.value(1), -999_999);
        Ok(())
    }

    #[test]
    fn no_more_than_4_params() -> Result<()> {
        let udf = SparkMakeDtInterval::new();

        // Create args with 5 parameters (exceeds the limit of 4)
        let args = vec![
            ColumnarValue::Scalar(ScalarValue::Int32(Some(1))),
            ColumnarValue::Scalar(ScalarValue::Int32(Some(2))),
            ColumnarValue::Scalar(ScalarValue::Int32(Some(3))),
            ColumnarValue::Scalar(ScalarValue::Float64(Some(4.0))),
            ColumnarValue::Scalar(ScalarValue::Int32(Some(5))),
        ];

        let arg_fields = args
            .iter()
            .map(|arg| Field::new("a", arg.data_type(), true).into())
            .collect::<Vec<_>>();

        let func_args = ScalarFunctionArgs {
            args,
            arg_fields,
            number_rows: 1,
            return_field: Field::new("f", Duration(Microsecond), true).into(),
            config_options: Arc::new(Default::default()),
        };

        let res = udf.invoke_with_args(func_args);

        assert!(
            matches!(res, Err(DataFusionError::Execution(_))),
            "make_dt_interval should return execution error for more than 4 arguments"
        );

        Ok(())
    }
}