datafusion_spark/function/datetime/
next_day.rs1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{new_null_array, ArrayRef, AsArray, Date32Array, StringArrayType};
22use arrow::datatypes::{DataType, Date32Type};
23use chrono::{Datelike, Duration, Weekday};
24use datafusion_common::{exec_err, Result, ScalarValue};
25use datafusion_expr::{
26 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
27};
28
29#[derive(Debug, PartialEq, Eq, Hash)]
31pub struct SparkNextDay {
32 signature: Signature,
33}
34
35impl Default for SparkNextDay {
36 fn default() -> Self {
37 Self::new()
38 }
39}
40
41impl SparkNextDay {
42 pub fn new() -> Self {
43 Self {
44 signature: Signature::exact(
45 vec![DataType::Date32, DataType::Utf8],
46 Volatility::Immutable,
47 ),
48 }
49 }
50}
51
52impl ScalarUDFImpl for SparkNextDay {
53 fn as_any(&self) -> &dyn Any {
54 self
55 }
56
57 fn name(&self) -> &str {
58 "next_day"
59 }
60
61 fn signature(&self) -> &Signature {
62 &self.signature
63 }
64
65 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
66 Ok(DataType::Date32)
67 }
68
69 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
70 let ScalarFunctionArgs { args, .. } = args;
71 let [date, day_of_week] = args.as_slice() else {
72 return exec_err!(
73 "Spark `next_day` function requires 2 arguments, got {}",
74 args.len()
75 );
76 };
77
78 match (date, day_of_week) {
79 (ColumnarValue::Scalar(date), ColumnarValue::Scalar(day_of_week)) => {
80 match (date, day_of_week) {
81 (ScalarValue::Date32(days), ScalarValue::Utf8(day_of_week) | ScalarValue::LargeUtf8(day_of_week) | ScalarValue::Utf8View(day_of_week)) => {
82 if let Some(days) = days {
83 if let Some(day_of_week) = day_of_week {
84 Ok(ColumnarValue::Scalar(ScalarValue::Date32(
85 spark_next_day(*days, day_of_week.as_str()),
86 )))
87 } else {
88 Ok(ColumnarValue::Scalar(ScalarValue::Date32(None)))
91 }
92 } else {
93 Ok(ColumnarValue::Scalar(ScalarValue::Date32(None)))
94 }
95 }
96 _ => exec_err!("Spark `next_day` function: first arg must be date, second arg must be string. Got {args:?}"),
97 }
98 }
99 (ColumnarValue::Array(date_array), ColumnarValue::Scalar(day_of_week)) => {
100 match (date_array.data_type(), day_of_week) {
101 (DataType::Date32, ScalarValue::Utf8(day_of_week) | ScalarValue::LargeUtf8(day_of_week) | ScalarValue::Utf8View(day_of_week)) => {
102 if let Some(day_of_week) = day_of_week {
103 let result: Date32Array = date_array
104 .as_primitive::<Date32Type>()
105 .unary_opt(|days| spark_next_day(days, day_of_week.as_str()))
106 .with_data_type(DataType::Date32);
107 Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef))
108 } else {
109 Ok(ColumnarValue::Array(Arc::new(new_null_array(&DataType::Date32, date_array.len()))))
112 }
113 }
114 _ => exec_err!("Spark `next_day` function: first arg must be date, second arg must be string. Got {args:?}"),
115 }
116 }
117 (
118 ColumnarValue::Array(date_array),
119 ColumnarValue::Array(day_of_week_array),
120 ) => {
121 let result = match (date_array.data_type(), day_of_week_array.data_type())
122 {
123 (
124 DataType::Date32,
125 DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View,
126 ) => {
127 let date_array: &Date32Array =
128 date_array.as_primitive::<Date32Type>();
129 match day_of_week_array.data_type() {
130 DataType::Utf8 => {
131 let day_of_week_array =
132 day_of_week_array.as_string::<i32>();
133 process_next_day_arrays(date_array, day_of_week_array)
134 }
135 DataType::LargeUtf8 => {
136 let day_of_week_array =
137 day_of_week_array.as_string::<i64>();
138 process_next_day_arrays(date_array, day_of_week_array)
139 }
140 DataType::Utf8View => {
141 let day_of_week_array =
142 day_of_week_array.as_string_view();
143 process_next_day_arrays(date_array, day_of_week_array)
144 }
145 other => {
146 exec_err!("Spark `next_day` function: second arg must be string. Got {other:?}")
147 }
148 }
149 }
150 (left, right) => {
151 exec_err!(
152 "Spark `next_day` function: first arg must be date, second arg must be string. Got {left:?}, {right:?}"
153 )
154 }
155 }?;
156 Ok(ColumnarValue::Array(result))
157 }
158 _ => exec_err!("Unsupported args {args:?} for Spark function `next_day`"),
159 }
160 }
161}
162
163fn process_next_day_arrays<'a, S>(
164 date_array: &Date32Array,
165 day_of_week_array: &'a S,
166) -> Result<ArrayRef>
167where
168 &'a S: StringArrayType<'a>,
169{
170 let result = date_array
171 .iter()
172 .zip(day_of_week_array.iter())
173 .map(|(days, day_of_week)| {
174 if let Some(days) = days {
175 if let Some(day_of_week) = day_of_week {
176 spark_next_day(days, day_of_week)
177 } else {
178 None
181 }
182 } else {
183 None
184 }
185 })
186 .collect::<Date32Array>();
187 Ok(Arc::new(result) as ArrayRef)
188}
189
190fn spark_next_day(days: i32, day_of_week: &str) -> Option<i32> {
191 let date = Date32Type::to_naive_date(days);
192
193 let day_of_week = day_of_week.trim().to_uppercase();
194 let day_of_week = match day_of_week.as_str() {
195 "MO" | "MON" | "MONDAY" => Some("MONDAY"),
196 "TU" | "TUE" | "TUESDAY" => Some("TUESDAY"),
197 "WE" | "WED" | "WEDNESDAY" => Some("WEDNESDAY"),
198 "TH" | "THU" | "THURSDAY" => Some("THURSDAY"),
199 "FR" | "FRI" | "FRIDAY" => Some("FRIDAY"),
200 "SA" | "SAT" | "SATURDAY" => Some("SATURDAY"),
201 "SU" | "SUN" | "SUNDAY" => Some("SUNDAY"),
202 _ => {
203 None
206 }
207 };
208
209 if let Some(day_of_week) = day_of_week {
210 let day_of_week = day_of_week.parse::<Weekday>();
211 match day_of_week {
212 Ok(day_of_week) => Some(Date32Type::from_naive_date(
213 date + Duration::days(
214 (7 - date.weekday().days_since(day_of_week)) as i64,
215 ),
216 )),
217 Err(_) => {
218 None
221 }
222 }
223 } else {
224 None
225 }
226}