use std::any::Any;
use std::sync::Arc;
use arrow::array::{ArrayRef, AsArray, Date32Array, StringArrayType};
use arrow::datatypes::{DataType, Date32Type, Field, FieldRef};
use chrono::{Datelike, Duration, Weekday};
use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
use datafusion_expr::{
ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
Volatility,
};
/// Spark-compatible `next_day(date, day_of_week)` scalar UDF: returns the
/// first date strictly after `date` that falls on the named day of the week.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkNextDay {
    // Exact (Date32, Utf8) signature; built once in `SparkNextDay::new`.
    signature: Signature,
}
impl Default for SparkNextDay {
fn default() -> Self {
Self::new()
}
}
impl SparkNextDay {
    /// Creates the UDF with an exact, immutable `(Date32, Utf8)` signature.
    pub fn new() -> Self {
        let signature = Signature::exact(
            vec![DataType::Date32, DataType::Utf8],
            Volatility::Immutable,
        );
        Self { signature }
    }
}
// Spark-compatible `next_day(date, day_of_week)`: returns the first date
// strictly later than the input that falls on the requested weekday.
// NULL dates, NULL weekday strings, and unrecognized weekday names all
// produce NULL rather than an error (Spark's non-ANSI behavior).
impl ScalarUDFImpl for SparkNextDay {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "next_day"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    // Intentionally unreachable: `return_field_from_args` below is the
    // authoritative source of the return type, so reaching this is a bug.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        internal_err!("return_field_from_args should be used instead")
    }

    // The result is always a *nullable* Date32, regardless of the input
    // fields' nullability, because a bad weekday string yields NULL.
    fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result<FieldRef> {
        Ok(Arc::new(Field::new(self.name(), DataType::Date32, true)))
    }

    // Dispatches over the scalar/array combinations of the two arguments.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let ScalarFunctionArgs { args, .. } = args;
        // Exactly two arguments: the date and the weekday name.
        let [date, day_of_week] = args.as_slice() else {
            return exec_err!(
                "Spark `next_day` function requires 2 arguments, got {}",
                args.len()
            );
        };
        match (date, day_of_week) {
            // Scalar date + scalar weekday -> scalar result.
            (ColumnarValue::Scalar(date), ColumnarValue::Scalar(day_of_week)) => {
                match (date, day_of_week) {
                    (
                        ScalarValue::Date32(days),
                        // Any Arrow string flavor is accepted for the weekday.
                        ScalarValue::Utf8(day_of_week)
                        | ScalarValue::LargeUtf8(day_of_week)
                        | ScalarValue::Utf8View(day_of_week),
                    ) => {
                        if let Some(days) = days {
                            if let Some(day_of_week) = day_of_week {
                                Ok(ColumnarValue::Scalar(ScalarValue::Date32(
                                    spark_next_day(*days, day_of_week.as_str()),
                                )))
                            } else {
                                // NULL weekday -> NULL result.
                                Ok(ColumnarValue::Scalar(ScalarValue::Date32(None)))
                            }
                        } else {
                            // NULL date -> NULL result.
                            Ok(ColumnarValue::Scalar(ScalarValue::Date32(None)))
                        }
                    }
                    _ => exec_err!(
                        "Spark `next_day` function: first arg must be date, second arg must be string. Got {args:?}"
                    ),
                }
            }
            // Array of dates + scalar weekday -> array result.
            (ColumnarValue::Array(date_array), ColumnarValue::Scalar(day_of_week)) => {
                match (date_array.data_type(), day_of_week) {
                    (
                        DataType::Date32,
                        ScalarValue::Utf8(day_of_week)
                        | ScalarValue::LargeUtf8(day_of_week)
                        | ScalarValue::Utf8View(day_of_week),
                    ) => {
                        if let Some(day_of_week) = day_of_week {
                            // `unary_opt` maps each non-NULL date through the
                            // kernel, producing NULL where the input is NULL
                            // or the kernel returns None.
                            let result: Date32Array = date_array
                                .as_primitive::<Date32Type>()
                                .unary_opt(|days| {
                                    spark_next_day(days, day_of_week.as_str())
                                })
                                .with_data_type(DataType::Date32);
                            Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef))
                        } else {
                            // NULL weekday scalar: every row would be NULL,
                            // so return a scalar NULL for broadcasting.
                            Ok(ColumnarValue::Scalar(ScalarValue::Date32(None)))
                        }
                    }
                    _ => exec_err!(
                        "Spark `next_day` function: first arg must be date, second arg must be string. Got {args:?}"
                    ),
                }
            }
            // Array of dates + array of weekday names -> row-wise result.
            (
                ColumnarValue::Array(date_array),
                ColumnarValue::Array(day_of_week_array),
            ) => {
                let result = match (date_array.data_type(), day_of_week_array.data_type())
                {
                    (
                        DataType::Date32,
                        DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View,
                    ) => {
                        let date_array: &Date32Array =
                            date_array.as_primitive::<Date32Type>();
                        // Re-dispatch on the concrete string layout so the
                        // generic helper sees a typed string array.
                        match day_of_week_array.data_type() {
                            DataType::Utf8 => {
                                let day_of_week_array =
                                    day_of_week_array.as_string::<i32>();
                                process_next_day_arrays(date_array, day_of_week_array)
                            }
                            DataType::LargeUtf8 => {
                                let day_of_week_array =
                                    day_of_week_array.as_string::<i64>();
                                process_next_day_arrays(date_array, day_of_week_array)
                            }
                            DataType::Utf8View => {
                                let day_of_week_array =
                                    day_of_week_array.as_string_view();
                                process_next_day_arrays(date_array, day_of_week_array)
                            }
                            other => {
                                exec_err!(
                                    "Spark `next_day` function: second arg must be string. Got {other:?}"
                                )
                            }
                        }
                    }
                    (left, right) => {
                        exec_err!(
                            "Spark `next_day` function: first arg must be date, second arg must be string. Got {left:?}, {right:?}"
                        )
                    }
                }?;
                Ok(ColumnarValue::Array(result))
            }
            // Remaining combination (scalar date + array weekday) is unsupported.
            _ => exec_err!("Unsupported args {args:?} for Spark function `next_day`"),
        }
    }
}
/// Applies `spark_next_day` row-wise over a Date32 column zipped with a
/// weekday-name string column of any Arrow string layout. A NULL on either
/// side of a row yields NULL in the output.
fn process_next_day_arrays<'a, S>(
    date_array: &Date32Array,
    day_of_week_array: &'a S,
) -> Result<ArrayRef>
where
    &'a S: StringArrayType<'a>,
{
    let values: Date32Array = date_array
        .iter()
        .zip(day_of_week_array.iter())
        .map(|pair| match pair {
            // Only compute when both the date and the weekday are present.
            (Some(days), Some(dow)) => spark_next_day(days, dow),
            _ => None,
        })
        .collect();
    Ok(Arc::new(values) as ArrayRef)
}
fn spark_next_day(days: i32, day_of_week: &str) -> Option<i32> {
let date = Date32Type::to_naive_date_opt(days)?;
let day_of_week = day_of_week.trim().to_uppercase();
let day_of_week = match day_of_week.as_str() {
"MO" | "MON" | "MONDAY" => Some("MONDAY"),
"TU" | "TUE" | "TUESDAY" => Some("TUESDAY"),
"WE" | "WED" | "WEDNESDAY" => Some("WEDNESDAY"),
"TH" | "THU" | "THURSDAY" => Some("THURSDAY"),
"FR" | "FRI" | "FRIDAY" => Some("FRIDAY"),
"SA" | "SAT" | "SATURDAY" => Some("SATURDAY"),
"SU" | "SUN" | "SUNDAY" => Some("SUNDAY"),
_ => {
None
}
};
if let Some(day_of_week) = day_of_week {
let day_of_week = day_of_week.parse::<Weekday>();
match day_of_week {
Ok(day_of_week) => Some(Date32Type::from_naive_date(
date + Duration::days(
(7 - date.weekday().days_since(day_of_week)) as i64,
),
)),
Err(_) => {
None
}
}
} else {
None
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_expr::ReturnFieldArgs;

    /// `return_type` must never be consulted: the UDF reports its output
    /// field via `return_field_from_args` instead.
    #[test]
    fn return_type_is_not_used() {
        let err = SparkNextDay::new()
            .return_type(&[DataType::Date32, DataType::Utf8])
            .unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("return_field_from_args should be used instead"));
    }

    /// Even with both inputs declared non-nullable, the output field stays
    /// nullable (an unrecognized weekday string produces NULL).
    #[test]
    fn next_day_is_always_nullable() {
        let udf = SparkNextDay::new();
        let arg_fields: [FieldRef; 2] = [
            Arc::new(Field::new("start_date", DataType::Date32, false)),
            Arc::new(Field::new("day_of_week", DataType::Utf8, false)),
        ];
        let out = udf
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &arg_fields,
                scalar_arguments: &[None, None],
            })
            .unwrap();
        assert_eq!(out.data_type(), &DataType::Date32);
        assert!(out.is_nullable());
    }
}