datafusion_spark/function/datetime/
next_day.rs1use std::sync::Arc;
19
20use arrow::array::{ArrayRef, AsArray, Date32Array, StringArrayType};
21use arrow::datatypes::{DataType, Date32Type, Field, FieldRef};
22use chrono::{Datelike, Duration, Weekday};
23use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
24use datafusion_expr::{
25 ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
26 Volatility,
27};
28
/// Scalar UDF implementing the Spark `next_day` function.
///
/// Takes a `Date32` start date and a string day-of-week name and produces a
/// nullable `Date32`. A NULL input or an unrecognised day name yields NULL.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkNextDay {
    /// Accepted argument types and volatility, built in [`SparkNextDay::new`].
    signature: Signature,
}
34
35impl Default for SparkNextDay {
36 fn default() -> Self {
37 Self::new()
38 }
39}
40
41impl SparkNextDay {
42 pub fn new() -> Self {
43 Self {
44 signature: Signature::exact(
45 vec![DataType::Date32, DataType::Utf8],
46 Volatility::Immutable,
47 ),
48 }
49 }
50}
51
52impl ScalarUDFImpl for SparkNextDay {
53 fn name(&self) -> &str {
54 "next_day"
55 }
56
57 fn signature(&self) -> &Signature {
58 &self.signature
59 }
60
61 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
62 internal_err!("return_field_from_args should be used instead")
63 }
64
65 fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result<FieldRef> {
66 Ok(Arc::new(Field::new(self.name(), DataType::Date32, true)))
69 }
70
71 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
72 let ScalarFunctionArgs { args, .. } = args;
73 let [date, day_of_week] = args.as_slice() else {
74 return exec_err!(
75 "Spark `next_day` function requires 2 arguments, got {}",
76 args.len()
77 );
78 };
79
80 match (date, day_of_week) {
81 (ColumnarValue::Scalar(date), ColumnarValue::Scalar(day_of_week)) => {
82 match (date, day_of_week) {
83 (
84 ScalarValue::Date32(days),
85 ScalarValue::Utf8(day_of_week)
86 | ScalarValue::LargeUtf8(day_of_week)
87 | ScalarValue::Utf8View(day_of_week),
88 ) => {
89 if let Some(days) = days {
90 if let Some(day_of_week) = day_of_week {
91 Ok(ColumnarValue::Scalar(ScalarValue::Date32(
92 spark_next_day(*days, day_of_week.as_str()),
93 )))
94 } else {
95 Ok(ColumnarValue::Scalar(ScalarValue::Date32(None)))
98 }
99 } else {
100 Ok(ColumnarValue::Scalar(ScalarValue::Date32(None)))
101 }
102 }
103 _ => exec_err!(
104 "Spark `next_day` function: first arg must be date, second arg must be string. Got {args:?}"
105 ),
106 }
107 }
108 (ColumnarValue::Array(date_array), ColumnarValue::Scalar(day_of_week)) => {
109 match (date_array.data_type(), day_of_week) {
110 (
111 DataType::Date32,
112 ScalarValue::Utf8(day_of_week)
113 | ScalarValue::LargeUtf8(day_of_week)
114 | ScalarValue::Utf8View(day_of_week),
115 ) => {
116 if let Some(day_of_week) = day_of_week {
117 let result: Date32Array = date_array
118 .as_primitive::<Date32Type>()
119 .unary_opt(|days| {
120 spark_next_day(days, day_of_week.as_str())
121 })
122 .with_data_type(DataType::Date32);
123 Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef))
124 } else {
125 Ok(ColumnarValue::Scalar(ScalarValue::Date32(None)))
128 }
129 }
130 _ => exec_err!(
131 "Spark `next_day` function: first arg must be date, second arg must be string. Got {args:?}"
132 ),
133 }
134 }
135 (
136 ColumnarValue::Array(date_array),
137 ColumnarValue::Array(day_of_week_array),
138 ) => {
139 let result = match (date_array.data_type(), day_of_week_array.data_type())
140 {
141 (
142 DataType::Date32,
143 DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View,
144 ) => {
145 let date_array: &Date32Array =
146 date_array.as_primitive::<Date32Type>();
147 match day_of_week_array.data_type() {
148 DataType::Utf8 => {
149 let day_of_week_array =
150 day_of_week_array.as_string::<i32>();
151 process_next_day_arrays(date_array, day_of_week_array)
152 }
153 DataType::LargeUtf8 => {
154 let day_of_week_array =
155 day_of_week_array.as_string::<i64>();
156 process_next_day_arrays(date_array, day_of_week_array)
157 }
158 DataType::Utf8View => {
159 let day_of_week_array =
160 day_of_week_array.as_string_view();
161 process_next_day_arrays(date_array, day_of_week_array)
162 }
163 other => {
164 exec_err!(
165 "Spark `next_day` function: second arg must be string. Got {other:?}"
166 )
167 }
168 }
169 }
170 (left, right) => {
171 exec_err!(
172 "Spark `next_day` function: first arg must be date, second arg must be string. Got {left:?}, {right:?}"
173 )
174 }
175 }?;
176 Ok(ColumnarValue::Array(result))
177 }
178 _ => exec_err!("Unsupported args {args:?} for Spark function `next_day`"),
179 }
180 }
181}
182
183fn process_next_day_arrays<'a, S>(
184 date_array: &Date32Array,
185 day_of_week_array: &'a S,
186) -> Result<ArrayRef>
187where
188 &'a S: StringArrayType<'a>,
189{
190 let result = date_array
191 .iter()
192 .zip(day_of_week_array.iter())
193 .map(|(days, day_of_week)| {
194 if let Some(days) = days {
195 if let Some(day_of_week) = day_of_week {
196 spark_next_day(days, day_of_week)
197 } else {
198 None
201 }
202 } else {
203 None
204 }
205 })
206 .collect::<Date32Array>();
207 Ok(Arc::new(result) as ArrayRef)
208}
209
210fn spark_next_day(days: i32, day_of_week: &str) -> Option<i32> {
211 let date = Date32Type::to_naive_date_opt(days)?;
212
213 let day_of_week = day_of_week.trim().to_uppercase();
214 let day_of_week = match day_of_week.as_str() {
215 "MO" | "MON" | "MONDAY" => Some("MONDAY"),
216 "TU" | "TUE" | "TUESDAY" => Some("TUESDAY"),
217 "WE" | "WED" | "WEDNESDAY" => Some("WEDNESDAY"),
218 "TH" | "THU" | "THURSDAY" => Some("THURSDAY"),
219 "FR" | "FRI" | "FRIDAY" => Some("FRIDAY"),
220 "SA" | "SAT" | "SATURDAY" => Some("SATURDAY"),
221 "SU" | "SUN" | "SUNDAY" => Some("SUNDAY"),
222 _ => {
223 None
226 }
227 };
228
229 if let Some(day_of_week) = day_of_week {
230 let day_of_week = day_of_week.parse::<Weekday>();
231 match day_of_week {
232 Ok(day_of_week) => Some(Date32Type::from_naive_date(
233 date + Duration::days(
234 (7 - date.weekday().days_since(day_of_week)) as i64,
235 ),
236 )),
237 Err(_) => {
238 None
241 }
242 }
243 } else {
244 None
245 }
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    /// `return_type` must fail and point callers at `return_field_from_args`.
    #[test]
    fn return_type_is_not_used() {
        let udf = SparkNextDay::new();
        let message = udf
            .return_type(&[DataType::Date32, DataType::Utf8])
            .unwrap_err()
            .to_string();
        assert!(message.contains("return_field_from_args should be used instead"));
    }

    /// The output field stays nullable even when both inputs are non-nullable.
    #[test]
    fn next_day_is_always_nullable() {
        let arg_fields: [FieldRef; 2] = [
            Arc::new(Field::new("start_date", DataType::Date32, false)),
            Arc::new(Field::new("day_of_week", DataType::Utf8, false)),
        ];

        let returned = SparkNextDay::new()
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &arg_fields,
                scalar_arguments: &[None, None],
            })
            .unwrap();

        assert_eq!(returned.data_type(), &DataType::Date32);
        assert!(returned.is_nullable());
    }
}