// datafusion_spark/function/datetime/next_day.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{ArrayRef, AsArray, Date32Array, StringArrayType, new_null_array};
22use arrow::datatypes::{DataType, Date32Type, Field, FieldRef};
23use chrono::{Datelike, Duration, Weekday};
24use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
25use datafusion_expr::{
26    ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
27    Volatility,
28};
29
/// Spark-compatible `next_day(start_date, day_of_week)` scalar UDF.
///
/// <https://spark.apache.org/docs/latest/api/sql/index.html#next_day>
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkNextDay {
    // Exact (Date32, Utf8) signature with Immutable volatility; see `new()`.
    signature: Signature,
}
35
impl Default for SparkNextDay {
    /// Delegates to [`SparkNextDay::new`].
    fn default() -> Self {
        Self::new()
    }
}
41
impl SparkNextDay {
    /// Creates the UDF with an exact signature of
    /// (`Date32`, `Utf8`) and `Immutable` volatility.
    pub fn new() -> Self {
        Self {
            signature: Signature::exact(
                vec![DataType::Date32, DataType::Utf8],
                Volatility::Immutable,
            ),
        }
    }
}
52
impl ScalarUDFImpl for SparkNextDay {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "next_day"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Not used: nullability must be derived per-call, so the return field is
    /// produced by `return_field_from_args` instead.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        internal_err!("return_field_from_args should be used instead")
    }

    fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result<FieldRef> {
        // Spark marks next_day as always nullable because invalid day_of_week values
        // can yield NULL even when inputs are non-null.
        Ok(Arc::new(Field::new(self.name(), DataType::Date32, true)))
    }

    /// Evaluates `next_day(date, day_of_week)` for every supported
    /// scalar/array combination of the two arguments, returning `Date32`
    /// values (NULL for NULL inputs or an unrecognized day-of-week string).
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let ScalarFunctionArgs { args, .. } = args;
        // Exactly two arguments: the start date and the day-of-week name.
        let [date, day_of_week] = args.as_slice() else {
            return exec_err!(
                "Spark `next_day` function requires 2 arguments, got {}",
                args.len()
            );
        };

        match (date, day_of_week) {
            // Scalar date + scalar day-of-week -> scalar Date32 result.
            (ColumnarValue::Scalar(date), ColumnarValue::Scalar(day_of_week)) => {
                match (date, day_of_week) {
                    (
                        ScalarValue::Date32(days),
                        // All three string scalar encodings are accepted.
                        ScalarValue::Utf8(day_of_week)
                        | ScalarValue::LargeUtf8(day_of_week)
                        | ScalarValue::Utf8View(day_of_week),
                    ) => {
                        if let Some(days) = days {
                            if let Some(day_of_week) = day_of_week {
                                Ok(ColumnarValue::Scalar(ScalarValue::Date32(
                                    spark_next_day(*days, day_of_week.as_str()),
                                )))
                            } else {
                                // TODO: if spark.sql.ansi.enabled is false,
                                //  returns NULL instead of an error for a malformed dayOfWeek.
                                Ok(ColumnarValue::Scalar(ScalarValue::Date32(None)))
                            }
                        } else {
                            // NULL date -> NULL result.
                            Ok(ColumnarValue::Scalar(ScalarValue::Date32(None)))
                        }
                    }
                    _ => exec_err!(
                        "Spark `next_day` function: first arg must be date, second arg must be string. Got {args:?}"
                    ),
                }
            }
            // Array of dates + scalar day-of-week -> element-wise array result.
            (ColumnarValue::Array(date_array), ColumnarValue::Scalar(day_of_week)) => {
                match (date_array.data_type(), day_of_week) {
                    (
                        DataType::Date32,
                        ScalarValue::Utf8(day_of_week)
                        | ScalarValue::LargeUtf8(day_of_week)
                        | ScalarValue::Utf8View(day_of_week),
                    ) => {
                        if let Some(day_of_week) = day_of_week {
                            // `unary_opt` maps each non-null date; a `None`
                            // from `spark_next_day` becomes a NULL slot.
                            let result: Date32Array = date_array
                                .as_primitive::<Date32Type>()
                                .unary_opt(|days| {
                                    spark_next_day(days, day_of_week.as_str())
                                })
                                .with_data_type(DataType::Date32);
                            Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef))
                        } else {
                            // TODO: if spark.sql.ansi.enabled is false,
                            //  returns NULL instead of an error for a malformed dayOfWeek.
                            Ok(ColumnarValue::Array(Arc::new(new_null_array(
                                &DataType::Date32,
                                date_array.len(),
                            ))))
                        }
                    }
                    _ => exec_err!(
                        "Spark `next_day` function: first arg must be date, second arg must be string. Got {args:?}"
                    ),
                }
            }
            // Array of dates + array of day-of-week strings -> zipped result.
            (
                ColumnarValue::Array(date_array),
                ColumnarValue::Array(day_of_week_array),
            ) => {
                let result = match (date_array.data_type(), day_of_week_array.data_type())
                {
                    (
                        DataType::Date32,
                        DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View,
                    ) => {
                        let date_array: &Date32Array =
                            date_array.as_primitive::<Date32Type>();
                        // Dispatch on the concrete string layout to obtain a
                        // typed view, then share one generic implementation.
                        match day_of_week_array.data_type() {
                            DataType::Utf8 => {
                                let day_of_week_array =
                                    day_of_week_array.as_string::<i32>();
                                process_next_day_arrays(date_array, day_of_week_array)
                            }
                            DataType::LargeUtf8 => {
                                let day_of_week_array =
                                    day_of_week_array.as_string::<i64>();
                                process_next_day_arrays(date_array, day_of_week_array)
                            }
                            DataType::Utf8View => {
                                let day_of_week_array =
                                    day_of_week_array.as_string_view();
                                process_next_day_arrays(date_array, day_of_week_array)
                            }
                            other => {
                                exec_err!(
                                    "Spark `next_day` function: second arg must be string. Got {other:?}"
                                )
                            }
                        }
                    }
                    (left, right) => {
                        exec_err!(
                            "Spark `next_day` function: first arg must be date, second arg must be string. Got {left:?}, {right:?}"
                        )
                    }
                }?;
                Ok(ColumnarValue::Array(result))
            }
            // e.g. Scalar date + Array day-of-week is not supported.
            _ => exec_err!("Unsupported args {args:?} for Spark function `next_day`"),
        }
    }
}
190
191fn process_next_day_arrays<'a, S>(
192    date_array: &Date32Array,
193    day_of_week_array: &'a S,
194) -> Result<ArrayRef>
195where
196    &'a S: StringArrayType<'a>,
197{
198    let result = date_array
199        .iter()
200        .zip(day_of_week_array.iter())
201        .map(|(days, day_of_week)| {
202            if let Some(days) = days {
203                if let Some(day_of_week) = day_of_week {
204                    spark_next_day(days, day_of_week)
205                } else {
206                    // TODO: if spark.sql.ansi.enabled is false,
207                    //  returns NULL instead of an error for a malformed dayOfWeek.
208                    None
209                }
210            } else {
211                None
212            }
213        })
214        .collect::<Date32Array>();
215    Ok(Arc::new(result) as ArrayRef)
216}
217
218fn spark_next_day(days: i32, day_of_week: &str) -> Option<i32> {
219    let date = Date32Type::to_naive_date(days);
220
221    let day_of_week = day_of_week.trim().to_uppercase();
222    let day_of_week = match day_of_week.as_str() {
223        "MO" | "MON" | "MONDAY" => Some("MONDAY"),
224        "TU" | "TUE" | "TUESDAY" => Some("TUESDAY"),
225        "WE" | "WED" | "WEDNESDAY" => Some("WEDNESDAY"),
226        "TH" | "THU" | "THURSDAY" => Some("THURSDAY"),
227        "FR" | "FRI" | "FRIDAY" => Some("FRIDAY"),
228        "SA" | "SAT" | "SATURDAY" => Some("SATURDAY"),
229        "SU" | "SUN" | "SUNDAY" => Some("SUNDAY"),
230        _ => {
231            // TODO: if spark.sql.ansi.enabled is false,
232            //  returns NULL instead of an error for a malformed dayOfWeek.
233            None
234        }
235    };
236
237    if let Some(day_of_week) = day_of_week {
238        let day_of_week = day_of_week.parse::<Weekday>();
239        match day_of_week {
240            Ok(day_of_week) => Some(Date32Type::from_naive_date(
241                date + Duration::days(
242                    (7 - date.weekday().days_since(day_of_week)) as i64,
243                ),
244            )),
245            Err(_) => {
246                // TODO: if spark.sql.ansi.enabled is false,
247                //  returns NULL instead of an error for a malformed dayOfWeek.
248                None
249            }
250        }
251    } else {
252        None
253    }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_expr::ReturnFieldArgs;

    // `return_type` is intentionally unimplemented; the UDF routes return
    // field computation through `return_field_from_args`.
    #[test]
    fn return_type_is_not_used() {
        let func = SparkNextDay::new();
        let err = func
            .return_type(&[DataType::Date32, DataType::Utf8])
            .unwrap_err();
        assert!(
            err.to_string()
                .contains("return_field_from_args should be used instead")
        );
    }

    // Even with non-nullable inputs, the return field must be nullable
    // because a malformed day_of_week yields NULL.
    #[test]
    fn next_day_is_always_nullable() {
        let func = SparkNextDay::new();
        let date_field: FieldRef =
            Arc::new(Field::new("start_date", DataType::Date32, false));
        let day_field: FieldRef =
            Arc::new(Field::new("day_of_week", DataType::Utf8, false));

        let field = func
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &[Arc::clone(&date_field), Arc::clone(&day_field)],
                scalar_arguments: &[None, None],
            })
            .unwrap();

        assert_eq!(field.data_type(), &DataType::Date32);
        assert!(field.is_nullable());
    }
}