Skip to main content

datafusion_spark/function/datetime/
make_dt_interval.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use arrow::array::{
21    Array, ArrayRef, AsArray, DurationMicrosecondBuilder, PrimitiveArray,
22};
23use arrow::datatypes::TimeUnit::Microsecond;
24use arrow::datatypes::{DataType, Field, FieldRef, Float64Type, Int32Type};
25use datafusion_common::types::{NativeType, logical_float64, logical_int32};
26use datafusion_common::{
27    DataFusionError, Result, ScalarValue, internal_err, plan_datafusion_err,
28};
29use datafusion_expr::{
30    Coercion, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl,
31    Signature, TypeSignature, TypeSignatureClass, Volatility,
32};
33use datafusion_functions::utils::make_scalar_function;
34
35#[derive(Debug, PartialEq, Eq, Hash)]
36pub struct SparkMakeDtInterval {
37    signature: Signature,
38}
39
40impl Default for SparkMakeDtInterval {
41    fn default() -> Self {
42        Self::new()
43    }
44}
45
46impl SparkMakeDtInterval {
47    pub fn new() -> Self {
48        let int32 = Coercion::new_implicit(
49            TypeSignatureClass::Native(logical_int32()),
50            vec![TypeSignatureClass::Integer],
51            NativeType::Int32,
52        );
53
54        let float64 = Coercion::new_implicit(
55            TypeSignatureClass::Native(logical_float64()),
56            vec![TypeSignatureClass::Numeric],
57            NativeType::Float64,
58        );
59
60        let variants = vec![
61            TypeSignature::Nullary,
62            // (days)
63            TypeSignature::Coercible(vec![int32.clone()]),
64            // (days, hours)
65            TypeSignature::Coercible(vec![int32.clone(), int32.clone()]),
66            // (days, hours, minutes)
67            TypeSignature::Coercible(vec![int32.clone(), int32.clone(), int32.clone()]),
68            // (days, hours, minutes, seconds)
69            TypeSignature::Coercible(vec![
70                int32.clone(),
71                int32.clone(),
72                int32.clone(),
73                float64,
74            ]),
75        ];
76
77        Self {
78            signature: Signature::one_of(variants, Volatility::Immutable),
79        }
80    }
81}
82
83impl ScalarUDFImpl for SparkMakeDtInterval {
84    fn name(&self) -> &str {
85        "make_dt_interval"
86    }
87
88    fn signature(&self) -> &Signature {
89        &self.signature
90    }
91
92    /// Note the return type is `DataType::Duration(TimeUnit::Microsecond)` and not `DataType::Interval(DayTime)` as you might expect.
93    /// This is because `DataType::Interval(DayTime)` has precision only to the millisecond, whilst Spark's `DayTimeIntervalType` has
94    /// precision to the microsecond. We use `DataType::Duration(TimeUnit::Microsecond)` in order to not lose any precision. See the
95    /// [Sail compatibility doc] for reference.
96    ///
97    /// [Sail compatibility doc]: https://github.com/lakehq/sail/blob/dc5368daa24d40a7758a299e1ba8fc985cb29108/docs/guide/dataframe/data-types/compatibility.md?plain=1#L260
98    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
99        internal_err!("return_field_from_args should be used instead")
100    }
101
102    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
103        let has_non_finite_secs = args
104            .scalar_arguments
105            .get(3)
106            .and_then(|arg| {
107                arg.map(|scalar| match scalar {
108                    ScalarValue::Float64(Some(v)) => !v.is_finite(),
109                    ScalarValue::Float32(Some(v)) => !v.is_finite(),
110                    _ => false,
111                })
112            })
113            .unwrap_or(false);
114        let nullable =
115            has_non_finite_secs || args.arg_fields.iter().any(|f| f.is_nullable());
116        Ok(Arc::new(Field::new(
117            self.name(),
118            DataType::Duration(Microsecond),
119            nullable,
120        )))
121    }
122
123    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
124        if args.args.is_empty() {
125            return Ok(ColumnarValue::Scalar(ScalarValue::DurationMicrosecond(
126                Some(0),
127            )));
128        }
129        if args.args.len() > 4 {
130            return Err(DataFusionError::Execution(format!(
131                "make_dt_interval expects between 0 and 4 arguments, got {}",
132                args.args.len()
133            )));
134        }
135        make_scalar_function(make_dt_interval_kernel, vec![])(&args.args)
136    }
137}
138
139fn make_dt_interval_kernel(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
140    let n_rows = args[0].len();
141    let days = args[0]
142        .as_primitive_opt::<Int32Type>()
143        .ok_or_else(|| plan_datafusion_err!("make_dt_interval arg[0] must be Int32"))?;
144    let hours: Option<&PrimitiveArray<Int32Type>> = args
145        .get(1)
146        .map(|a| {
147            a.as_primitive_opt::<Int32Type>().ok_or_else(|| {
148                plan_datafusion_err!("make_dt_interval arg[1] must be Int32")
149            })
150        })
151        .transpose()?;
152    let mins: Option<&PrimitiveArray<Int32Type>> = args
153        .get(2)
154        .map(|a| {
155            a.as_primitive_opt::<Int32Type>().ok_or_else(|| {
156                plan_datafusion_err!("make_dt_interval arg[2] must be Int32")
157            })
158        })
159        .transpose()?;
160    let secs: Option<&PrimitiveArray<Float64Type>> = args
161        .get(3)
162        .map(|a| {
163            a.as_primitive_opt::<Float64Type>().ok_or_else(|| {
164                plan_datafusion_err!("make_dt_interval arg[3] must be Float64")
165            })
166        })
167        .transpose()?;
168    let mut builder = DurationMicrosecondBuilder::with_capacity(n_rows);
169
170    for i in 0..n_rows {
171        // if one column is NULL → result NULL
172        let any_null_present = days.is_null(i)
173            || hours.as_ref().is_some_and(|a| a.is_null(i))
174            || mins.as_ref().is_some_and(|a| a.is_null(i))
175            || secs
176                .as_ref()
177                .is_some_and(|a| a.is_null(i) || !a.value(i).is_finite());
178
179        if any_null_present {
180            builder.append_null();
181            continue;
182        }
183
184        // default values 0 or 0.0
185        let d = days.value(i);
186        let h = hours.as_ref().map_or(0, |a| a.value(i));
187        let mi = mins.as_ref().map_or(0, |a| a.value(i));
188        let s = secs.as_ref().map_or(0.0, |a| a.value(i));
189
190        match make_interval_dt_nano(d, h, mi, s) {
191            Some(v) => builder.append_value(v),
192            None => {
193                builder.append_null();
194                continue;
195            }
196        }
197    }
198
199    Ok(Arc::new(builder.finish()))
200}
/// Combines day/hour/minute/second components into a single signed count of
/// **microseconds**.
///
/// Note: despite the `_nano` suffix in the name, the returned unit is
/// microseconds; the name is kept so existing callers continue to compile.
///
/// The fractional part of `sec` is rounded to the nearest microsecond (ties
/// away from zero), carrying into the whole-second part when it rounds to a
/// full second.
///
/// Returns `None` when `sec` is NaN or infinite, or when the total does not
/// fit in an `i64` number of microseconds. All integer components are
/// widened to `i64` before being combined, so a result is only rejected when
/// the final value itself is out of range, not because an intermediate
/// hour/minute count exceeded `i32`.
fn make_interval_dt_nano(day: i32, hour: i32, min: i32, sec: f64) -> Option<i64> {
    const HOURS_PER_DAY: i64 = 24;
    const MINS_PER_HOUR: i64 = 60;
    const SECS_PER_MINUTE: i64 = 60;
    const MICROS_PER_SEC: i64 = 1_000_000;
    // -2^63, exactly representable as f64. Its negation (2^63) is the first
    // value *above* the i64 range.
    const I64_LOWER_BOUND: f64 = i64::MIN as f64;

    if !sec.is_finite() {
        return None;
    }

    // Reject whole-second values outside the i64 range up front; a float->int
    // `as` cast would silently saturate instead of failing.
    let sec_trunc = sec.trunc();
    if sec_trunc < I64_LOWER_BOUND || sec_trunc >= -I64_LOWER_BOUND {
        return None;
    }
    let mut sec_whole = sec_trunc as i64;

    // |sec - sec_trunc| < 1, so the rounded value lies in [-1_000_000, 1_000_000].
    let mut frac_us = ((sec - sec_trunc) * (MICROS_PER_SEC as f64)).round() as i64;

    // Rounding may produce exactly one full second; carry it over.
    if frac_us >= MICROS_PER_SEC {
        frac_us -= MICROS_PER_SEC;
        sec_whole = sec_whole.checked_add(1)?;
    } else if frac_us <= -MICROS_PER_SEC {
        frac_us += MICROS_PER_SEC;
        sec_whole = sec_whole.checked_sub(1)?;
    }

    // i32-sourced values scaled by 24 * 60 * 60 stay far below the i64
    // limits, so plain arithmetic cannot overflow here.
    let total_hours = i64::from(day) * HOURS_PER_DAY + i64::from(hour);
    let total_mins = total_hours * MINS_PER_HOUR + i64::from(min);

    let total_secs = (total_mins * SECS_PER_MINUTE).checked_add(sec_whole)?;

    total_secs
        .checked_mul(MICROS_PER_SEC)?
        .checked_add(frac_us)
}
239
#[cfg(test)]
mod tests {

    use arrow::array::{DurationMicrosecondArray, Float64Array, Int32Array};
    use arrow::datatypes::DataType::Duration;
    use arrow::datatypes::TimeUnit::Microsecond;
    use datafusion_common::internal_datafusion_err;

    use super::*;

    /// Wraps `values` in an `Int32` array reference.
    fn int32_array(values: Vec<Option<i32>>) -> ArrayRef {
        Arc::new(Int32Array::from(values))
    }

    /// Wraps `values` in a `Float64` array reference.
    fn float64_array(values: Vec<Option<f64>>) -> ArrayRef {
        Arc::new(Float64Array::from(values))
    }

    /// Downcasts `arr` to a microsecond duration array, failing the test with
    /// an internal error if it has any other type.
    fn as_duration_micros(arr: &ArrayRef) -> Result<DurationMicrosecondArray> {
        arr.as_any()
            .downcast_ref::<DurationMicrosecondArray>()
            .cloned()
            .ok_or_else(|| internal_datafusion_err!("expected DurationMicrosecondArray"))
    }

    /// Calls the kernel directly and returns its typed output.
    fn run_make_dt_interval(arrs: Vec<ArrayRef>) -> Result<DurationMicrosecondArray> {
        let out = make_dt_interval_kernel(&arrs)?;
        as_duration_micros(&out)
    }

    /// Invokes the UDF through `invoke_with_args`, deriving nullable argument
    /// fields from the supplied values.
    fn invoke_make_dt_interval_with_args(
        args: Vec<ColumnarValue>,
        number_rows: usize,
    ) -> Result<ColumnarValue, DataFusionError> {
        let arg_fields = args
            .iter()
            .map(|arg| Field::new("a", arg.data_type(), true).into())
            .collect::<Vec<_>>();
        let args = ScalarFunctionArgs {
            args,
            arg_fields,
            number_rows,
            return_field: Field::new("f", Duration(Microsecond), true).into(),
            config_options: Arc::new(Default::default()),
        };
        SparkMakeDtInterval::new().invoke_with_args(args)
    }

    #[test]
    fn nulls_propagate_per_row() -> Result<()> {
        // Rows 0-3 each have a NULL in a different column; rows 4-6 have a
        // non-finite seconds value. Every row must come out NULL.
        let days = int32_array(vec![
            None,
            Some(2),
            Some(3),
            Some(4),
            Some(5),
            Some(6),
            Some(7),
        ]);
        let hours = int32_array(vec![
            Some(1),
            None,
            Some(3),
            Some(4),
            Some(5),
            Some(6),
            Some(7),
        ]);
        let mins = int32_array(vec![
            Some(1),
            Some(2),
            None,
            Some(4),
            Some(5),
            Some(6),
            Some(7),
        ]);
        let secs = float64_array(vec![
            Some(1.0),
            Some(2.0),
            Some(3.0),
            None,
            Some(f64::NAN),
            Some(f64::INFINITY),
            Some(f64::NEG_INFINITY),
        ]);

        let out = run_make_dt_interval(vec![days, hours, mins, secs])?;

        // Guard against the loop below passing vacuously.
        assert_eq!(out.len(), 7);
        for i in 0..out.len() {
            assert!(out.is_null(i), "row {i} should be NULL");
        }
        Ok(())
    }

    #[test]
    fn return_field_respects_nullability() -> Result<()> {
        let udf = SparkMakeDtInterval::new();

        // All nullable inputs -> nullable output
        let arg_fields = vec![
            Arc::new(Field::new("days", DataType::Int32, true)),
            Arc::new(Field::new("hours", DataType::Int32, true)),
            Arc::new(Field::new("mins", DataType::Int32, true)),
            Arc::new(Field::new("secs", DataType::Float64, true)),
        ];

        let out = udf.return_field_from_args(ReturnFieldArgs {
            arg_fields: &arg_fields,
            scalar_arguments: &[None, None, None, None],
        })?;
        assert!(out.is_nullable());
        assert_eq!(out.data_type(), &Duration(Microsecond));

        // Non-nullable inputs -> non-nullable output
        let non_nullable_arg_fields = vec![
            Arc::new(Field::new("days", DataType::Int32, false)),
            Arc::new(Field::new("hours", DataType::Int32, false)),
            Arc::new(Field::new("mins", DataType::Int32, false)),
            Arc::new(Field::new("secs", DataType::Float64, false)),
        ];

        let out = udf.return_field_from_args(ReturnFieldArgs {
            arg_fields: &non_nullable_arg_fields,
            scalar_arguments: &[None, None, None, None],
        })?;
        assert!(!out.is_nullable());

        // Non-finite secs scalar should force nullable even if fields are non-nullable
        let scalar_values =
            [None, None, None, Some(ScalarValue::Float64(Some(f64::NAN)))];
        let scalar_refs = scalar_values.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
        let out = udf.return_field_from_args(ReturnFieldArgs {
            arg_fields: &non_nullable_arg_fields,
            scalar_arguments: &scalar_refs,
        })?;
        assert!(out.is_nullable());

        // Zero-arg call (defaults) should also be non-nullable
        let out = udf.return_field_from_args(ReturnFieldArgs {
            arg_fields: &[],
            scalar_arguments: &[],
        })?;
        assert!(!out.is_nullable());

        Ok(())
    }

    #[test]
    fn days_overflow_should_be_null() -> Result<()> {
        // i32::MAX days is far more than fits in an i64 count of
        // microseconds, so the row must be NULL rather than wrap or panic.
        let days = int32_array(vec![Some(i32::MAX)]);
        let hours = int32_array(vec![Some(1)]);
        let mins = int32_array(vec![Some(1)]);
        let secs = float64_array(vec![Some(1.0)]);

        let out = run_make_dt_interval(vec![days, hours, mins, secs])?;

        assert_eq!(out.len(), 1);
        for i in 0..out.len() {
            assert!(out.is_null(i), "row {i} should be NULL");
        }

        Ok(())
    }

    #[test]
    fn zero_args_returns_zero_duration() -> Result<()> {
        let number_rows: usize = 3;

        let res: ColumnarValue = invoke_make_dt_interval_with_args(vec![], number_rows)?;
        let arr = res.into_array(number_rows)?;
        let arr = as_duration_micros(&arr)?;

        assert_eq!(arr.len(), number_rows);
        for i in 0..number_rows {
            assert!(!arr.is_null(i));
            assert_eq!(arr.value(i), 0_i64);
        }
        Ok(())
    }

    #[test]
    fn one_day_minus_24_hours_equals_zero() -> Result<()> {
        let arr_days = int32_array(vec![Some(1), Some(-1)]);
        let arr_hours = int32_array(vec![Some(-24), Some(24)]);
        let arr_mins = int32_array(vec![Some(0), Some(0)]);
        let arr_secs = float64_array(vec![Some(0.0), Some(0.0)]);

        let out = run_make_dt_interval(vec![arr_days, arr_hours, arr_mins, arr_secs])?;

        assert_eq!(out.len(), 2);
        assert_eq!(out.null_count(), 0);
        assert_eq!(out.value(0), 0_i64);
        assert_eq!(out.value(1), 0_i64);
        Ok(())
    }

    #[test]
    fn one_hour_minus_60_mins_equals_zero() -> Result<()> {
        let arr_days = int32_array(vec![Some(0), Some(0)]);
        let arr_hours = int32_array(vec![Some(-1), Some(1)]);
        let arr_mins = int32_array(vec![Some(60), Some(-60)]);
        let arr_secs = float64_array(vec![Some(0.0), Some(0.0)]);

        let out = run_make_dt_interval(vec![arr_days, arr_hours, arr_mins, arr_secs])?;

        assert_eq!(out.len(), 2);
        assert_eq!(out.null_count(), 0);
        assert_eq!(out.value(0), 0_i64);
        assert_eq!(out.value(1), 0_i64);
        Ok(())
    }

    #[test]
    fn one_mins_minus_60_secs_equals_zero() -> Result<()> {
        let arr_days = int32_array(vec![Some(0), Some(0)]);
        let arr_hours = int32_array(vec![Some(0), Some(0)]);
        let arr_mins = int32_array(vec![Some(-1), Some(1)]);
        let arr_secs = float64_array(vec![Some(60.0), Some(-60.0)]);

        let out = run_make_dt_interval(vec![arr_days, arr_hours, arr_mins, arr_secs])?;

        assert_eq!(out.len(), 2);
        assert_eq!(out.null_count(), 0);
        assert_eq!(out.value(0), 0_i64);
        assert_eq!(out.value(1), 0_i64);
        Ok(())
    }

    #[test]
    fn frac_carries_up_to_next_second_positive() -> Result<()> {
        // 0.9999995s → 1_000_000 µs (carry a +1s)
        let days = int32_array(vec![Some(0), Some(0)]);
        let hours = int32_array(vec![Some(0), Some(0)]);
        let mins = int32_array(vec![Some(0), Some(0)]);
        let secs = float64_array(vec![Some(0.999_999_5), Some(0.999_999_4)]);

        let out = run_make_dt_interval(vec![days, hours, mins, secs])?;

        assert_eq!(out.len(), 2);
        assert_eq!(out.value(0), 1_000_000);
        assert_eq!(out.value(1), 999_999);
        Ok(())
    }

    #[test]
    fn frac_carries_down_to_prev_second_negative() -> Result<()> {
        // -0.9999995s → -1_000_000 µs (carry a −1s)
        let days = int32_array(vec![Some(0), Some(0)]);
        let hours = int32_array(vec![Some(0), Some(0)]);
        let mins = int32_array(vec![Some(0), Some(0)]);
        let secs = float64_array(vec![Some(-0.999_999_5), Some(-0.999_999_4)]);

        let out = run_make_dt_interval(vec![days, hours, mins, secs])?;

        assert_eq!(out.len(), 2);
        assert_eq!(out.value(0), -1_000_000);
        assert_eq!(out.value(1), -999_999);
        Ok(())
    }

    #[test]
    fn no_more_than_4_params() -> Result<()> {
        // Five arguments exceed the limit of four.
        let args = vec![
            ColumnarValue::Scalar(ScalarValue::Int32(Some(1))),
            ColumnarValue::Scalar(ScalarValue::Int32(Some(2))),
            ColumnarValue::Scalar(ScalarValue::Int32(Some(3))),
            ColumnarValue::Scalar(ScalarValue::Float64(Some(4.0))),
            ColumnarValue::Scalar(ScalarValue::Int32(Some(5))),
        ];

        let res = invoke_make_dt_interval_with_args(args, 1);

        assert!(
            matches!(res, Err(DataFusionError::Execution(_))),
            "make_dt_interval should return execution error for more than 4 arguments"
        );

        Ok(())
    }
}