datafusion_comet_spark_expr/datetime_funcs/
date_arithmetic.rs1use arrow::array::{ArrayRef, AsArray};
19use arrow::compute::kernels::numeric::{add, sub};
20use arrow::datatypes::IntervalDayTime;
21use arrow_array::builder::IntervalDayTimeBuilder;
22use arrow_array::types::{Int16Type, Int32Type, Int8Type};
23use arrow_array::{Array, Datum};
24use arrow_schema::{ArrowError, DataType};
25use datafusion::physical_expr_common::datum;
26use datafusion::physical_plan::ColumnarValue;
27use datafusion_common::{DataFusionError, ScalarValue};
28use std::sync::Arc;
29
/// Combines `$start` with a scalar interval of `*$days` days (and zero milliseconds) via `$op`.
///
/// `$days` is a reference to an integer that is cast to `i32`. The macro evaluates to the
/// `Result<ColumnarValue, _>` returned by `datum::apply`.
macro_rules! scalar_date_arithmetic {
    ($start:expr, $days:expr, $op:expr) => {{
        let day_count = *$days as i32;
        let interval_scalar =
            ScalarValue::IntervalDayTime(Some(IntervalDayTime::new(day_count, 0)));
        datum::apply($start, &ColumnarValue::Scalar(interval_scalar), $op)
    }};
}
/// Appends one `IntervalDayTime` (N days, zero milliseconds) to `$interval_builder` for every
/// element of `$days` viewed as a primitive array of `$intType`.
///
/// Null elements in `$days` become null entries in the builder.
macro_rules! array_date_arithmetic {
    ($days:expr, $interval_builder:expr, $intType:ty) => {{
        let typed_days = $days.as_primitive::<$intType>();
        for maybe_day in typed_days {
            match maybe_day {
                Some(day) => $interval_builder.append_value(IntervalDayTime::new(day as i32, 0)),
                None => $interval_builder.append_null(),
            }
        }
    }};
}
48
49fn spark_date_arithmetic(
53 args: &[ColumnarValue],
54 op: impl Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, ArrowError>,
55) -> Result<ColumnarValue, DataFusionError> {
56 let start = &args[0];
57 match &args[1] {
58 ColumnarValue::Scalar(ScalarValue::Int8(Some(days))) => {
59 scalar_date_arithmetic!(start, days, op)
60 }
61 ColumnarValue::Scalar(ScalarValue::Int16(Some(days))) => {
62 scalar_date_arithmetic!(start, days, op)
63 }
64 ColumnarValue::Scalar(ScalarValue::Int32(Some(days))) => {
65 scalar_date_arithmetic!(start, days, op)
66 }
67 ColumnarValue::Array(days) => {
68 let mut interval_builder = IntervalDayTimeBuilder::with_capacity(days.len());
69 match days.data_type() {
70 DataType::Int8 => {
71 array_date_arithmetic!(days, interval_builder, Int8Type)
72 }
73 DataType::Int16 => {
74 array_date_arithmetic!(days, interval_builder, Int16Type)
75 }
76 DataType::Int32 => {
77 array_date_arithmetic!(days, interval_builder, Int32Type)
78 }
79 _ => {
80 return Err(DataFusionError::Internal(format!(
81 "Unsupported data types {:?} for date arithmetic.",
82 args,
83 )))
84 }
85 }
86 let interval_cv = ColumnarValue::Array(Arc::new(interval_builder.finish()));
87 datum::apply(start, &interval_cv, op)
88 }
89 _ => Err(DataFusionError::Internal(format!(
90 "Unsupported data types {:?} for date arithmetic.",
91 args,
92 ))),
93 }
94}
95
96pub fn spark_date_add(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
97 spark_date_arithmetic(args, add)
98}
99
100pub fn spark_date_sub(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
101 spark_date_arithmetic(args, sub)
102}