// datafusion_comet_spark_expr/datetime_funcs/date_arithmetic.rs
use arrow::array::builder::IntervalDayTimeBuilder;
19use arrow::array::types::{Int16Type, Int32Type, Int8Type};
20use arrow::array::{Array, Datum};
21use arrow::array::{ArrayRef, AsArray};
22use arrow::compute::kernels::numeric::{add, sub};
23use arrow::datatypes::DataType;
24use arrow::datatypes::IntervalDayTime;
25use arrow::error::ArrowError;
26use datafusion::common::{DataFusionError, ScalarValue};
27use datafusion::physical_expr_common::datum;
28use datafusion::physical_plan::ColumnarValue;
29use std::sync::Arc;
30
/// Applies `$op` to `$start` and a scalar day count.
///
/// `$days` must dereference to an integer; it is cast to `i32` and wrapped in an
/// `IntervalDayTime` with a zero milliseconds component. The macro evaluates to the
/// result of `datum::apply`.
macro_rules! scalar_date_arithmetic {
    ($start:expr, $days:expr, $op:expr) => {{
        // Callers in this file pass `&i8`, `&i16` or `&i32`, so the cast never truncates.
        let day_count = *$days as i32;
        let day_interval = ScalarValue::IntervalDayTime(Some(IntervalDayTime::new(day_count, 0)));
        let rhs = ColumnarValue::Scalar(day_interval);
        datum::apply($start, &rhs, $op)
    }};
}
/// Fills `$interval_builder` with one `IntervalDayTime` slot per element of `$days`.
///
/// `$days` is viewed as a primitive array of `$int_type`. Each non-null element is
/// cast to `i32` and stored as that many days with a zero milliseconds component;
/// null elements are appended as nulls, so the builder ends up with the same length
/// and null positions as the input.
macro_rules! array_date_arithmetic {
    ($days:expr, $interval_builder:expr, $int_type:ty) => {{
        let day_values = $days.as_primitive::<$int_type>();
        for slot in day_values.iter() {
            // `append_option(None)` is the null append, so one call covers both cases.
            $interval_builder.append_option(slot.map(|d| IntervalDayTime::new(d as i32, 0)));
        }
    }};
}
49
50fn spark_date_arithmetic(
54 args: &[ColumnarValue],
55 op: impl Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, ArrowError>,
56) -> Result<ColumnarValue, DataFusionError> {
57 let start = &args[0];
58 match &args[1] {
59 ColumnarValue::Scalar(ScalarValue::Int8(Some(days))) => {
60 scalar_date_arithmetic!(start, days, op)
61 }
62 ColumnarValue::Scalar(ScalarValue::Int16(Some(days))) => {
63 scalar_date_arithmetic!(start, days, op)
64 }
65 ColumnarValue::Scalar(ScalarValue::Int32(Some(days))) => {
66 scalar_date_arithmetic!(start, days, op)
67 }
68 ColumnarValue::Array(days) => {
69 let mut interval_builder = IntervalDayTimeBuilder::with_capacity(days.len());
70 match days.data_type() {
71 DataType::Int8 => {
72 array_date_arithmetic!(days, interval_builder, Int8Type)
73 }
74 DataType::Int16 => {
75 array_date_arithmetic!(days, interval_builder, Int16Type)
76 }
77 DataType::Int32 => {
78 array_date_arithmetic!(days, interval_builder, Int32Type)
79 }
80 _ => {
81 return Err(DataFusionError::Internal(format!(
82 "Unsupported data types {args:?} for date arithmetic.",
83 )))
84 }
85 }
86 let interval_cv = ColumnarValue::Array(Arc::new(interval_builder.finish()));
87 datum::apply(start, &interval_cv, op)
88 }
89 _ => Err(DataFusionError::Internal(format!(
90 "Unsupported data types {args:?} for date arithmetic.",
91 ))),
92 }
93}
94
95pub fn spark_date_add(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
96 spark_date_arithmetic(args, add)
97}
98
99pub fn spark_date_sub(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
100 spark_date_arithmetic(args, sub)
101}