datafusion_comet_spark_expr/datetime_funcs/
date_arithmetic.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{ArrayRef, AsArray};
19use arrow::compute::kernels::numeric::{add, sub};
20use arrow::datatypes::IntervalDayTime;
21use arrow_array::builder::IntervalDayTimeBuilder;
22use arrow_array::types::{Int16Type, Int32Type, Int8Type};
23use arrow_array::{Array, Datum};
24use arrow_schema::{ArrowError, DataType};
25use datafusion::physical_expr_common::datum;
26use datafusion::physical_plan::ColumnarValue;
27use datafusion_common::{DataFusionError, ScalarValue};
28use std::sync::Arc;
29
/// Applies `$op` between `$start` and a scalar day count.
///
/// `$days` must dereference to an integer; it is cast to `i32` and wrapped in a
/// scalar `IntervalDayTime` (zero milliseconds) that serves as the right-hand
/// operand of `datum::apply`.
macro_rules! scalar_date_arithmetic {
    ($start:expr, $days:expr, $op:expr) => {{
        let day_count = *$days as i32;
        let rhs = ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(
            IntervalDayTime::new(day_count, 0),
        )));
        datum::apply($start, &rhs, $op)
    }};
}
/// Fills `$interval_builder` with one `IntervalDayTime` per element of `$days`.
///
/// `$days` is read as a primitive array of `$intType`. Each non-null element is
/// cast to `i32` and becomes an interval of that many days (zero milliseconds);
/// null elements are appended as nulls.
macro_rules! array_date_arithmetic {
    ($days:expr, $interval_builder:expr, $intType:ty) => {{
        for maybe_day in $days.as_primitive::<$intType>().iter() {
            $interval_builder
                .append_option(maybe_day.map(|d| IntervalDayTime::new(d as i32, 0)));
        }
    }};
}
48
49/// Spark-compatible `date_add` and `date_sub` expressions, which assumes days for the second
50/// argument, but we cannot directly add that to a Date32. We generate an IntervalDayTime from the
51/// second argument and use DataFusion's interface to apply Arrow's operators.
52fn spark_date_arithmetic(
53    args: &[ColumnarValue],
54    op: impl Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, ArrowError>,
55) -> Result<ColumnarValue, DataFusionError> {
56    let start = &args[0];
57    match &args[1] {
58        ColumnarValue::Scalar(ScalarValue::Int8(Some(days))) => {
59            scalar_date_arithmetic!(start, days, op)
60        }
61        ColumnarValue::Scalar(ScalarValue::Int16(Some(days))) => {
62            scalar_date_arithmetic!(start, days, op)
63        }
64        ColumnarValue::Scalar(ScalarValue::Int32(Some(days))) => {
65            scalar_date_arithmetic!(start, days, op)
66        }
67        ColumnarValue::Array(days) => {
68            let mut interval_builder = IntervalDayTimeBuilder::with_capacity(days.len());
69            match days.data_type() {
70                DataType::Int8 => {
71                    array_date_arithmetic!(days, interval_builder, Int8Type)
72                }
73                DataType::Int16 => {
74                    array_date_arithmetic!(days, interval_builder, Int16Type)
75                }
76                DataType::Int32 => {
77                    array_date_arithmetic!(days, interval_builder, Int32Type)
78                }
79                _ => {
80                    return Err(DataFusionError::Internal(format!(
81                        "Unsupported data types {:?} for date arithmetic.",
82                        args,
83                    )))
84                }
85            }
86            let interval_cv = ColumnarValue::Array(Arc::new(interval_builder.finish()));
87            datum::apply(start, &interval_cv, op)
88        }
89        _ => Err(DataFusionError::Internal(format!(
90            "Unsupported data types {:?} for date arithmetic.",
91            args,
92        ))),
93    }
94}
95
96pub fn spark_date_add(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
97    spark_date_arithmetic(args, add)
98}
99
100pub fn spark_date_sub(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
101    spark_date_arithmetic(args, sub)
102}