datafusion_spark/function/datetime/make_dt_interval.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{
22    Array, ArrayRef, AsArray, DurationMicrosecondBuilder, PrimitiveArray,
23};
24use arrow::datatypes::TimeUnit::Microsecond;
25use arrow::datatypes::{DataType, Float64Type, Int32Type};
26use datafusion_common::{
27    exec_err, plan_datafusion_err, DataFusionError, Result, ScalarValue,
28};
29use datafusion_expr::{
30    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
31};
32use datafusion_functions::utils::make_scalar_function;
33
/// Spark-compatible `make_dt_interval(days, hours, mins, secs)` scalar UDF.
///
/// Accepts between 0 and 4 arguments (see `coerce_types`); missing trailing
/// arguments default to zero in the kernel. The result is a
/// `Duration(Microsecond)` value — see `return_type` for the rationale.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkMakeDtInterval {
    // User-defined signature; coercion rules live in `coerce_types`.
    signature: Signature,
}
38
impl Default for SparkMakeDtInterval {
    // Delegates to `new()`; the UDF has no other configuration.
    fn default() -> Self {
        Self::new()
    }
}
44
impl SparkMakeDtInterval {
    /// Creates the UDF with a user-defined, immutable signature: argument
    /// count and type coercion are handled in `ScalarUDFImpl::coerce_types`.
    pub fn new() -> Self {
        Self {
            signature: Signature::user_defined(Volatility::Immutable),
        }
    }
}
52
53impl ScalarUDFImpl for SparkMakeDtInterval {
54    fn as_any(&self) -> &dyn Any {
55        self
56    }
57
58    fn name(&self) -> &str {
59        "make_dt_interval"
60    }
61
62    fn signature(&self) -> &Signature {
63        &self.signature
64    }
65
66    /// Note the return type is `DataType::Duration(TimeUnit::Microsecond)` and not `DataType::Interval(DayTime)` as you might expect.
67    /// This is because `DataType::Interval(DayTime)` has precision only to the millisecond, whilst Spark's `DayTimeIntervalType` has
68    /// precision to the microsecond. We use `DataType::Duration(TimeUnit::Microsecond)` in order to not lose any precision. See the
69    /// [Sail compatibility doc] for reference.
70    ///
71    /// [Sail compatibility doc]: https://github.com/lakehq/sail/blob/dc5368daa24d40a7758a299e1ba8fc985cb29108/docs/guide/dataframe/data-types/compatibility.md?plain=1#L260
72    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
73        Ok(DataType::Duration(Microsecond))
74    }
75
76    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
77        if args.args.is_empty() {
78            return Ok(ColumnarValue::Scalar(ScalarValue::DurationMicrosecond(
79                Some(0),
80            )));
81        }
82        make_scalar_function(make_dt_interval_kernel, vec![])(&args.args)
83    }
84
85    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
86        if arg_types.len() > 4 {
87            return exec_err!(
88                "make_dt_interval expects between 0 and 4 arguments, got {}",
89                arg_types.len()
90            );
91        }
92
93        Ok((0..arg_types.len())
94            .map(|i| {
95                if i == 3 {
96                    DataType::Float64
97                } else {
98                    DataType::Int32
99                }
100            })
101            .collect())
102    }
103}
104
105fn make_dt_interval_kernel(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
106    let n_rows = args[0].len();
107    let days = args[0]
108        .as_primitive_opt::<Int32Type>()
109        .ok_or_else(|| plan_datafusion_err!("make_dt_interval arg[0] must be Int32"))?;
110    let hours: Option<&PrimitiveArray<Int32Type>> = args
111        .get(1)
112        .map(|a| {
113            a.as_primitive_opt::<Int32Type>().ok_or_else(|| {
114                plan_datafusion_err!("make_dt_interval arg[1] must be Int32")
115            })
116        })
117        .transpose()?;
118    let mins: Option<&PrimitiveArray<Int32Type>> = args
119        .get(2)
120        .map(|a| {
121            a.as_primitive_opt::<Int32Type>().ok_or_else(|| {
122                plan_datafusion_err!("make_dt_interval arg[2] must be Int32")
123            })
124        })
125        .transpose()?;
126    let secs: Option<&PrimitiveArray<Float64Type>> = args
127        .get(3)
128        .map(|a| {
129            a.as_primitive_opt::<Float64Type>().ok_or_else(|| {
130                plan_datafusion_err!("make_dt_interval arg[3] must be Float64")
131            })
132        })
133        .transpose()?;
134    let mut builder = DurationMicrosecondBuilder::with_capacity(n_rows);
135
136    for i in 0..n_rows {
137        // if one column is NULL → result NULL
138        let any_null_present = days.is_null(i)
139            || hours.as_ref().is_some_and(|a| a.is_null(i))
140            || mins.as_ref().is_some_and(|a| a.is_null(i))
141            || secs
142                .as_ref()
143                .is_some_and(|a| a.is_null(i) || !a.value(i).is_finite());
144
145        if any_null_present {
146            builder.append_null();
147            continue;
148        }
149
150        // default values 0 or 0.0
151        let d = days.value(i);
152        let h = hours.as_ref().map_or(0, |a| a.value(i));
153        let mi = mins.as_ref().map_or(0, |a| a.value(i));
154        let s = secs.as_ref().map_or(0.0, |a| a.value(i));
155
156        match make_interval_dt_nano(d, h, mi, s) {
157            Some(v) => builder.append_value(v),
158            None => {
159                builder.append_null();
160                continue;
161            }
162        }
163    }
164
165    Ok(Arc::new(builder.finish()))
166}
/// Combines days/hours/minutes plus fractional `sec` into a single signed
/// count of microseconds, returning `None` if any intermediate step
/// overflows.
///
/// NOTE: despite the `_nano` suffix (kept for caller compatibility), the
/// result is in *microseconds*, matching the `Duration(Microsecond)` output
/// of `make_dt_interval`.
fn make_interval_dt_nano(day: i32, hour: i32, min: i32, sec: f64) -> Option<i64> {
    const HOURS_PER_DAY: i32 = 24;
    const MINS_PER_HOUR: i32 = 60;
    const SECS_PER_MINUTE: i64 = 60;
    const MICROS_PER_SEC: i64 = 1_000_000;

    // Fold days into hours, then hours into minutes, with an i32 overflow
    // check at every step.
    let total_hours: i32 = day
        .checked_mul(HOURS_PER_DAY)
        .and_then(|v| v.checked_add(hour))?;

    let total_mins: i32 = total_hours
        .checked_mul(MINS_PER_HOUR)
        .and_then(|v| v.checked_add(min))?;

    // Split seconds into a whole part and a rounded fractional-microsecond
    // part. `as i64` saturates for huge (but finite) floats; the checked
    // arithmetic below then reports that as overflow (`None`).
    let mut sec_whole: i64 = sec.trunc() as i64;
    let sec_frac: f64 = sec - (sec_whole as f64);
    let mut frac_us: i64 = (sec_frac * (MICROS_PER_SEC as f64)).round() as i64;

    // If rounding produced a full second, carry it into the whole part.
    // Compare both signs explicitly rather than using `frac_us.abs()`:
    // `abs()` panics in debug builds when the saturating cast above produced
    // `i64::MIN` (e.g. sec = -1e30), whereas we want to degrade to `None`.
    if frac_us >= MICROS_PER_SEC {
        frac_us -= MICROS_PER_SEC;
        sec_whole = sec_whole.checked_add(1)?;
    } else if frac_us <= -MICROS_PER_SEC {
        frac_us += MICROS_PER_SEC;
        sec_whole = sec_whole.checked_sub(1)?;
    }

    let total_secs: i64 = (total_mins as i64)
        .checked_mul(SECS_PER_MINUTE)
        .and_then(|v| v.checked_add(sec_whole))?;

    // Final conversion to microseconds, again overflow-checked.
    total_secs
        .checked_mul(MICROS_PER_SEC)
        .and_then(|v| v.checked_add(frac_us))
}
205
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{DurationMicrosecondArray, Float64Array, Int32Array};
    use arrow::datatypes::DataType::Duration;
    use arrow::datatypes::Field;
    use arrow::datatypes::TimeUnit::Microsecond;
    use datafusion_common::{internal_datafusion_err, DataFusionError, Result};
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};

    use super::*;

    /// Convenience wrapper: runs the kernel directly on column arrays.
    fn run_make_dt_interval(arrs: Vec<ArrayRef>) -> Result<ArrayRef> {
        make_dt_interval_kernel(&arrs)
    }

    #[test]
    fn nulls_propagate_per_row() -> Result<()> {
        // Each row carries exactly one "bad" input (NULL, NaN, or ±inf),
        // so every output row must be NULL.
        let days = Arc::new(Int32Array::from(vec![
            None,
            Some(2),
            Some(3),
            Some(4),
            Some(5),
            Some(6),
            Some(7),
        ])) as ArrayRef;

        let hours = Arc::new(Int32Array::from(vec![
            Some(1),
            None,
            Some(3),
            Some(4),
            Some(5),
            Some(6),
            Some(7),
        ])) as ArrayRef;

        let mins = Arc::new(Int32Array::from(vec![
            Some(1),
            Some(2),
            None,
            Some(4),
            Some(5),
            Some(6),
            Some(7),
        ])) as ArrayRef;

        let secs = Arc::new(Float64Array::from(vec![
            Some(1.0),
            Some(2.0),
            Some(3.0),
            None,
            Some(f64::NAN),
            Some(f64::INFINITY),
            Some(f64::NEG_INFINITY),
        ])) as ArrayRef;

        let out = run_make_dt_interval(vec![days, hours, mins, secs])?;
        let out = out
            .as_any()
            .downcast_ref::<DurationMicrosecondArray>()
            .ok_or_else(|| {
                internal_datafusion_err!("expected DurationMicrosecondArray")
            })?;

        for i in 0..out.len() {
            assert!(out.is_null(i), "row {i} should be NULL");
        }
        Ok(())
    }

    #[test]
    fn days_overflow_should_be_null() -> Result<()> {
        // i32::MAX days overflows the days → hours conversion → NULL

        let days = Arc::new(Int32Array::from(vec![Some(i32::MAX)])) as ArrayRef;

        let hours = Arc::new(Int32Array::from(vec![Some(1)])) as ArrayRef;

        let mins = Arc::new(Int32Array::from(vec![Some(1)])) as ArrayRef;

        let secs = Arc::new(Float64Array::from(vec![Some(1.0)])) as ArrayRef;

        let out = run_make_dt_interval(vec![days, hours, mins, secs])?;
        let out = out
            .as_any()
            .downcast_ref::<DurationMicrosecondArray>()
            .ok_or_else(|| {
                internal_datafusion_err!("expected DurationMicrosecondArray")
            })?;

        for i in 0..out.len() {
            assert!(out.is_null(i), "row {i} should be NULL");
        }

        Ok(())
    }

    /// Invokes the full UDF (not just the kernel) with the given arguments.
    fn invoke_make_dt_interval_with_args(
        args: Vec<ColumnarValue>,
        number_rows: usize,
    ) -> Result<ColumnarValue, DataFusionError> {
        let arg_fields = args
            .iter()
            .map(|arg| Field::new("a", arg.data_type(), true).into())
            .collect::<Vec<_>>();
        let args = ScalarFunctionArgs {
            args,
            arg_fields,
            number_rows,
            return_field: Field::new("f", Duration(Microsecond), true).into(),
            config_options: Arc::new(Default::default()),
        };
        SparkMakeDtInterval::new().invoke_with_args(args)
    }

    #[test]
    fn zero_args_returns_zero_duration() -> Result<()> {
        let number_rows: usize = 3;

        let res: ColumnarValue = invoke_make_dt_interval_with_args(vec![], number_rows)?;
        let arr = res.into_array(number_rows)?;
        let arr = arr
            .as_any()
            .downcast_ref::<DurationMicrosecondArray>()
            .ok_or_else(|| {
                internal_datafusion_err!("expected DurationMicrosecondArray")
            })?;

        assert_eq!(arr.len(), number_rows);
        for i in 0..number_rows {
            assert!(!arr.is_null(i));
            assert_eq!(arr.value(i), 0_i64);
        }
        Ok(())
    }

    #[test]
    fn one_day_minus_24_hours_equals_zero() -> Result<()> {
        let arr_days = Arc::new(Int32Array::from(vec![Some(1), Some(-1)])) as ArrayRef;
        let arr_hours = Arc::new(Int32Array::from(vec![Some(-24), Some(24)])) as ArrayRef;
        let arr_mins = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef;
        let arr_secs =
            Arc::new(Float64Array::from(vec![Some(0.0), Some(0.0)])) as ArrayRef;

        let out = run_make_dt_interval(vec![arr_days, arr_hours, arr_mins, arr_secs])?;
        let out = out
            .as_any()
            .downcast_ref::<DurationMicrosecondArray>()
            .ok_or_else(|| {
                internal_datafusion_err!("expected DurationMicrosecondArray")
            })?;

        assert_eq!(out.len(), 2);
        assert_eq!(out.null_count(), 0);
        assert_eq!(out.value(0), 0_i64);
        assert_eq!(out.value(1), 0_i64);
        Ok(())
    }

    #[test]
    fn one_hour_minus_60_mins_equals_zero() -> Result<()> {
        let arr_days = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef;
        let arr_hours = Arc::new(Int32Array::from(vec![Some(-1), Some(1)])) as ArrayRef;
        let arr_mins = Arc::new(Int32Array::from(vec![Some(60), Some(-60)])) as ArrayRef;
        let arr_secs =
            Arc::new(Float64Array::from(vec![Some(0.0), Some(0.0)])) as ArrayRef;

        let out = run_make_dt_interval(vec![arr_days, arr_hours, arr_mins, arr_secs])?;
        let out = out
            .as_any()
            .downcast_ref::<DurationMicrosecondArray>()
            .ok_or_else(|| {
                internal_datafusion_err!("expected DurationMicrosecondArray")
            })?;

        assert_eq!(out.len(), 2);
        assert_eq!(out.null_count(), 0);
        assert_eq!(out.value(0), 0_i64);
        assert_eq!(out.value(1), 0_i64);
        Ok(())
    }

    #[test]
    fn one_mins_minus_60_secs_equals_zero() -> Result<()> {
        let arr_days = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef;
        let arr_hours = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef;
        let arr_mins = Arc::new(Int32Array::from(vec![Some(-1), Some(1)])) as ArrayRef;
        let arr_secs =
            Arc::new(Float64Array::from(vec![Some(60.0), Some(-60.0)])) as ArrayRef;

        let out = run_make_dt_interval(vec![arr_days, arr_hours, arr_mins, arr_secs])?;
        let out = out
            .as_any()
            .downcast_ref::<DurationMicrosecondArray>()
            .ok_or_else(|| {
                internal_datafusion_err!("expected DurationMicrosecondArray")
            })?;

        assert_eq!(out.len(), 2);
        assert_eq!(out.null_count(), 0);
        assert_eq!(out.value(0), 0_i64);
        assert_eq!(out.value(1), 0_i64);
        Ok(())
    }

    #[test]
    fn frac_carries_up_to_next_second_positive() -> Result<()> {
        // 0.9999995s → 1_000_000 µs (carry a +1s)
        let days = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef;
        let hours = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef;
        let mins = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef;
        let secs = Arc::new(Float64Array::from(vec![
            Some(0.999_999_5),
            Some(0.999_999_4),
        ])) as ArrayRef;

        let out = run_make_dt_interval(vec![days, hours, mins, secs])?;
        let out = out
            .as_any()
            .downcast_ref::<DurationMicrosecondArray>()
            .ok_or_else(|| {
                internal_datafusion_err!("expected DurationMicrosecondArray")
            })?;

        assert_eq!(out.len(), 2);
        assert_eq!(out.value(0), 1_000_000);
        assert_eq!(out.value(1), 999_999);
        Ok(())
    }

    #[test]
    fn frac_carries_down_to_prev_second_negative() -> Result<()> {
        // -0.9999995s → -1_000_000 µs (carry a −1s)
        let days = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef;
        let hours = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef;
        let mins = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef;
        let secs = Arc::new(Float64Array::from(vec![
            Some(-0.999_999_5),
            Some(-0.999_999_4),
        ])) as ArrayRef;

        let out = run_make_dt_interval(vec![days, hours, mins, secs])?;
        let out = out
            .as_any()
            .downcast_ref::<DurationMicrosecondArray>()
            .ok_or_else(|| {
                internal_datafusion_err!("expected DurationMicrosecondArray")
            })?;

        assert_eq!(out.len(), 2);
        assert_eq!(out.value(0), -1_000_000);
        assert_eq!(out.value(1), -999_999);
        Ok(())
    }

    #[test]
    fn no_more_than_4_params() -> Result<()> {
        let udf = SparkMakeDtInterval::new();

        let arg_types = vec![
            DataType::Int32,
            DataType::Int32,
            DataType::Int32,
            DataType::Float64,
            DataType::Int32,
        ];

        let res = udf.coerce_types(&arg_types);

        assert!(
            matches!(res, Err(DataFusionError::Execution(_))),
            "make_dt_interval should return execution error for too many arguments"
        );

        Ok(())
    }
}