// datafusion_spark/function/datetime/next_day.rs

use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{ArrayRef, AsArray, Date32Array, StringArrayType, new_null_array};
22use arrow::datatypes::{DataType, Date32Type, Field, FieldRef};
23use chrono::{Datelike, Duration, Weekday};
24use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
25use datafusion_expr::{
26 ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
27 Volatility,
28};
29
/// Spark-compatible `next_day(start_date, day_of_week)` scalar UDF.
///
/// Returns the first date strictly later than `start_date` that falls on the
/// weekday named by `day_of_week` (full name or 2/3-letter abbreviation, any
/// case); yields null for an unrecognized weekday name.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkNextDay {
    // Declared as exact (Date32, Utf8); see `SparkNextDay::new`. Note that
    // `invoke_with_args` also tolerates LargeUtf8/Utf8View values at runtime.
    signature: Signature,
}
35
impl Default for SparkNextDay {
    // Delegates to `new` so the UDF can be constructed via `Default`
    // (e.g. by generic registration code).
    fn default() -> Self {
        Self::new()
    }
}
41
impl SparkNextDay {
    /// Creates the UDF with an exact `(Date32, Utf8)` signature and
    /// `Immutable` volatility (same inputs always give the same output,
    /// enabling constant folding).
    pub fn new() -> Self {
        Self {
            signature: Signature::exact(
                vec![DataType::Date32, DataType::Utf8],
                Volatility::Immutable,
            ),
        }
    }
}
52
impl ScalarUDFImpl for SparkNextDay {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "next_day"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    // Never called: this UDF reports its return type through
    // `return_field_from_args` so it can force nullability on the result.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        internal_err!("return_field_from_args should be used instead")
    }

    // The result field is always nullable regardless of the input fields'
    // nullability: an unrecognized day-of-week string produces null.
    fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result<FieldRef> {
        Ok(Arc::new(Field::new(self.name(), DataType::Date32, true)))
    }

    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let ScalarFunctionArgs { args, .. } = args;
        // Exactly two arguments: the start date and the target weekday name.
        let [date, day_of_week] = args.as_slice() else {
            return exec_err!(
                "Spark `next_day` function requires 2 arguments, got {}",
                args.len()
            );
        };

        // Dispatch on the scalar/array shape of each argument.
        match (date, day_of_week) {
            // Case 1: scalar date + scalar weekday name.
            (ColumnarValue::Scalar(date), ColumnarValue::Scalar(day_of_week)) => {
                match (date, day_of_week) {
                    (
                        ScalarValue::Date32(days),
                        ScalarValue::Utf8(day_of_week)
                        | ScalarValue::LargeUtf8(day_of_week)
                        | ScalarValue::Utf8View(day_of_week),
                    ) => {
                        if let Some(days) = days {
                            if let Some(day_of_week) = day_of_week {
                                Ok(ColumnarValue::Scalar(ScalarValue::Date32(
                                    spark_next_day(*days, day_of_week.as_str()),
                                )))
                            } else {
                                // Null weekday name -> null result.
                                Ok(ColumnarValue::Scalar(ScalarValue::Date32(None)))
                            }
                        } else {
                            // Null date -> null result.
                            Ok(ColumnarValue::Scalar(ScalarValue::Date32(None)))
                        }
                    }
                    _ => exec_err!(
                        "Spark `next_day` function: first arg must be date, second arg must be string. Got {args:?}"
                    ),
                }
            }
            // Case 2: array of dates + one scalar weekday name.
            (ColumnarValue::Array(date_array), ColumnarValue::Scalar(day_of_week)) => {
                match (date_array.data_type(), day_of_week) {
                    (
                        DataType::Date32,
                        ScalarValue::Utf8(day_of_week)
                        | ScalarValue::LargeUtf8(day_of_week)
                        | ScalarValue::Utf8View(day_of_week),
                    ) => {
                        if let Some(day_of_week) = day_of_week {
                            // `unary_opt` maps each non-null date; rows where
                            // `spark_next_day` returns None become null.
                            let result: Date32Array = date_array
                                .as_primitive::<Date32Type>()
                                .unary_opt(|days| {
                                    spark_next_day(days, day_of_week.as_str())
                                })
                                .with_data_type(DataType::Date32);
                            Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef))
                        } else {
                            // Null weekday name -> every output row is null.
                            Ok(ColumnarValue::Array(Arc::new(new_null_array(
                                &DataType::Date32,
                                date_array.len(),
                            ))))
                        }
                    }
                    _ => exec_err!(
                        "Spark `next_day` function: first arg must be date, second arg must be string. Got {args:?}"
                    ),
                }
            }
            // Case 3: array of dates + array of weekday names, paired row-wise.
            (
                ColumnarValue::Array(date_array),
                ColumnarValue::Array(day_of_week_array),
            ) => {
                let result = match (date_array.data_type(), day_of_week_array.data_type())
                {
                    (
                        DataType::Date32,
                        DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View,
                    ) => {
                        let date_array: &Date32Array =
                            date_array.as_primitive::<Date32Type>();
                        // Downcast to the concrete string array type, then
                        // share one generic row-wise implementation.
                        match day_of_week_array.data_type() {
                            DataType::Utf8 => {
                                let day_of_week_array =
                                    day_of_week_array.as_string::<i32>();
                                process_next_day_arrays(date_array, day_of_week_array)
                            }
                            DataType::LargeUtf8 => {
                                let day_of_week_array =
                                    day_of_week_array.as_string::<i64>();
                                process_next_day_arrays(date_array, day_of_week_array)
                            }
                            DataType::Utf8View => {
                                let day_of_week_array =
                                    day_of_week_array.as_string_view();
                                process_next_day_arrays(date_array, day_of_week_array)
                            }
                            other => {
                                exec_err!(
                                    "Spark `next_day` function: second arg must be string. Got {other:?}"
                                )
                            }
                        }
                    }
                    (left, right) => {
                        exec_err!(
                            "Spark `next_day` function: first arg must be date, second arg must be string. Got {left:?}, {right:?}"
                        )
                    }
                }?;
                Ok(ColumnarValue::Array(result))
            }
            // Scalar date + array of weekday names is not supported.
            _ => exec_err!("Unsupported args {args:?} for Spark function `next_day`"),
        }
    }
}
190
191fn process_next_day_arrays<'a, S>(
192 date_array: &Date32Array,
193 day_of_week_array: &'a S,
194) -> Result<ArrayRef>
195where
196 &'a S: StringArrayType<'a>,
197{
198 let result = date_array
199 .iter()
200 .zip(day_of_week_array.iter())
201 .map(|(days, day_of_week)| {
202 if let Some(days) = days {
203 if let Some(day_of_week) = day_of_week {
204 spark_next_day(days, day_of_week)
205 } else {
206 None
209 }
210 } else {
211 None
212 }
213 })
214 .collect::<Date32Array>();
215 Ok(Arc::new(result) as ArrayRef)
216}
217
218fn spark_next_day(days: i32, day_of_week: &str) -> Option<i32> {
219 let date = Date32Type::to_naive_date(days);
220
221 let day_of_week = day_of_week.trim().to_uppercase();
222 let day_of_week = match day_of_week.as_str() {
223 "MO" | "MON" | "MONDAY" => Some("MONDAY"),
224 "TU" | "TUE" | "TUESDAY" => Some("TUESDAY"),
225 "WE" | "WED" | "WEDNESDAY" => Some("WEDNESDAY"),
226 "TH" | "THU" | "THURSDAY" => Some("THURSDAY"),
227 "FR" | "FRI" | "FRIDAY" => Some("FRIDAY"),
228 "SA" | "SAT" | "SATURDAY" => Some("SATURDAY"),
229 "SU" | "SUN" | "SUNDAY" => Some("SUNDAY"),
230 _ => {
231 None
234 }
235 };
236
237 if let Some(day_of_week) = day_of_week {
238 let day_of_week = day_of_week.parse::<Weekday>();
239 match day_of_week {
240 Ok(day_of_week) => Some(Date32Type::from_naive_date(
241 date + Duration::days(
242 (7 - date.weekday().days_since(day_of_week)) as i64,
243 ),
244 )),
245 Err(_) => {
246 None
249 }
250 }
251 } else {
252 None
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_expr::ReturnFieldArgs;

    /// `return_type` is intentionally unimplemented: the nullability-aware
    /// `return_field_from_args` must be used instead.
    #[test]
    fn return_type_is_not_used() {
        let err = SparkNextDay::new()
            .return_type(&[DataType::Date32, DataType::Utf8])
            .unwrap_err();
        assert!(
            err.to_string()
                .contains("return_field_from_args should be used instead")
        );
    }

    /// The output is nullable even when both inputs are non-nullable,
    /// because an unrecognized weekday name yields null.
    #[test]
    fn next_day_is_always_nullable() {
        let udf = SparkNextDay::new();
        let arg_fields: [FieldRef; 2] = [
            Arc::new(Field::new("start_date", DataType::Date32, false)),
            Arc::new(Field::new("day_of_week", DataType::Utf8, false)),
        ];

        let field = udf
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &arg_fields,
                scalar_arguments: &[None, None],
            })
            .unwrap();

        assert_eq!(field.data_type(), &DataType::Date32);
        assert!(field.is_nullable());
    }
}
291}