use arrow::array::builder::IntervalDayTimeBuilder;
use arrow::array::types::{Int16Type, Int32Type, Int8Type};
use arrow::array::{Array, Datum};
use arrow::array::{ArrayRef, AsArray};
use arrow::compute::kernels::numeric::{add, sub};
use arrow::datatypes::DataType;
use arrow::datatypes::IntervalDayTime;
use arrow::error::ArrowError;
use datafusion::common::{DataFusionError, ScalarValue};
use datafusion::physical_expr_common::datum;
use datafusion::physical_plan::ColumnarValue;
use std::sync::Arc;
/// Combines `$start` with a scalar day count through the arithmetic kernel `$op`.
///
/// `$days` is a reference to a primitive integer that converts losslessly into `i32`
/// (`i8`, `i16` or `i32` at the call sites in this file). The count is wrapped in an
/// `IntervalDayTime` scalar with zero milliseconds and handed, together with `$start`
/// and `$op`, to `datum::apply`.
macro_rules! scalar_date_arithmetic {
    ($start:expr, $days:expr, $op:expr) => {{
        let whole_days = i32::from(*$days);
        let day_interval = IntervalDayTime::new(whole_days, 0);
        let rhs = ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(day_interval)));
        datum::apply($start, &rhs, $op)
    }};
}
/// Appends one `IntervalDayTime` per element of the integer array `$days` to
/// `$interval_builder`, treating each element as a whole number of days.
///
/// `$intType` is the Arrow primitive type `$days` is downcast to via `as_primitive`.
/// A null element is appended as a null interval, so the builder receives exactly one
/// entry (value or null) per input element, in order.
macro_rules! array_date_arithmetic {
    ($days:expr, $interval_builder:expr, $intType:ty) => {{
        let typed_days = $days.as_primitive::<$intType>();
        for maybe_day in typed_days.iter() {
            let interval = maybe_day.map(|day| IntervalDayTime::new(i32::from(day), 0));
            $interval_builder.append_option(interval);
        }
    }};
}
fn spark_date_arithmetic(
args: &[ColumnarValue],
op: impl Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, ArrowError>,
) -> Result<ColumnarValue, DataFusionError> {
let start = &args[0];
match &args[1] {
ColumnarValue::Scalar(ScalarValue::Int8(Some(days))) => {
scalar_date_arithmetic!(start, days, op)
}
ColumnarValue::Scalar(ScalarValue::Int16(Some(days))) => {
scalar_date_arithmetic!(start, days, op)
}
ColumnarValue::Scalar(ScalarValue::Int32(Some(days))) => {
scalar_date_arithmetic!(start, days, op)
}
ColumnarValue::Array(days) => {
let mut interval_builder = IntervalDayTimeBuilder::with_capacity(days.len());
match days.data_type() {
DataType::Int8 => {
array_date_arithmetic!(days, interval_builder, Int8Type)
}
DataType::Int16 => {
array_date_arithmetic!(days, interval_builder, Int16Type)
}
DataType::Int32 => {
array_date_arithmetic!(days, interval_builder, Int32Type)
}
_ => {
return Err(DataFusionError::Internal(format!(
"Unsupported data types {args:?} for date arithmetic.",
)))
}
}
let interval_cv = ColumnarValue::Array(Arc::new(interval_builder.finish()));
datum::apply(start, &interval_cv, op)
}
_ => Err(DataFusionError::Internal(format!(
"Unsupported data types {args:?} for date arithmetic.",
))),
}
}
/// Spark `date_add`: adds the day count in `args[1]` to the start value in `args[0]`.
///
/// Delegates to `spark_date_arithmetic` with Arrow's `add` kernel. The day count is
/// applied as an `IntervalDayTime` of whole days. The start value is presumably a
/// `Date32` scalar or array; that is not checked here (NOTE(review): confirm against
/// the planner side).
///
/// # Errors
///
/// Propagates errors from `spark_date_arithmetic`, e.g. for an unsupported day-count
/// data type, as well as errors from the `add` kernel itself.
pub fn spark_date_add(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
    spark_date_arithmetic(args, add)
}
/// Spark `date_sub`: subtracts the day count in `args[1]` from the start value in `args[0]`.
///
/// Delegates to `spark_date_arithmetic` with Arrow's `sub` kernel. The day count is
/// applied as an `IntervalDayTime` of whole days. The start value is presumably a
/// `Date32` scalar or array; that is not checked here (NOTE(review): confirm against
/// the planner side).
///
/// # Errors
///
/// Propagates errors from `spark_date_arithmetic`, e.g. for an unsupported day-count
/// data type, as well as errors from the `sub` kernel itself.
pub fn spark_date_sub(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
    spark_date_arithmetic(args, sub)
}