Skip to main content

datafusion_spark/function/datetime/
next_day.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use arrow::array::{ArrayRef, AsArray, Date32Array, StringArrayType};
21use arrow::datatypes::{DataType, Date32Type, Field, FieldRef};
22use chrono::{Datelike, Duration, Weekday};
23use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
24use datafusion_expr::{
25    ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
26    Volatility,
27};
28
29/// <https://spark.apache.org/docs/latest/api/sql/index.html#next_day>
30#[derive(Debug, PartialEq, Eq, Hash)]
31pub struct SparkNextDay {
32    signature: Signature,
33}
34
35impl Default for SparkNextDay {
36    fn default() -> Self {
37        Self::new()
38    }
39}
40
41impl SparkNextDay {
42    pub fn new() -> Self {
43        Self {
44            signature: Signature::exact(
45                vec![DataType::Date32, DataType::Utf8],
46                Volatility::Immutable,
47            ),
48        }
49    }
50}
51
52impl ScalarUDFImpl for SparkNextDay {
53    fn name(&self) -> &str {
54        "next_day"
55    }
56
57    fn signature(&self) -> &Signature {
58        &self.signature
59    }
60
61    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
62        internal_err!("return_field_from_args should be used instead")
63    }
64
65    fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result<FieldRef> {
66        // Spark marks next_day as always nullable because invalid day_of_week values
67        // can yield NULL even when inputs are non-null.
68        Ok(Arc::new(Field::new(self.name(), DataType::Date32, true)))
69    }
70
71    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
72        let ScalarFunctionArgs { args, .. } = args;
73        let [date, day_of_week] = args.as_slice() else {
74            return exec_err!(
75                "Spark `next_day` function requires 2 arguments, got {}",
76                args.len()
77            );
78        };
79
80        match (date, day_of_week) {
81            (ColumnarValue::Scalar(date), ColumnarValue::Scalar(day_of_week)) => {
82                match (date, day_of_week) {
83                    (
84                        ScalarValue::Date32(days),
85                        ScalarValue::Utf8(day_of_week)
86                        | ScalarValue::LargeUtf8(day_of_week)
87                        | ScalarValue::Utf8View(day_of_week),
88                    ) => {
89                        if let Some(days) = days {
90                            if let Some(day_of_week) = day_of_week {
91                                Ok(ColumnarValue::Scalar(ScalarValue::Date32(
92                                    spark_next_day(*days, day_of_week.as_str()),
93                                )))
94                            } else {
95                                // TODO: if spark.sql.ansi.enabled is false,
96                                //  returns NULL instead of an error for a malformed dayOfWeek.
97                                Ok(ColumnarValue::Scalar(ScalarValue::Date32(None)))
98                            }
99                        } else {
100                            Ok(ColumnarValue::Scalar(ScalarValue::Date32(None)))
101                        }
102                    }
103                    _ => exec_err!(
104                        "Spark `next_day` function: first arg must be date, second arg must be string. Got {args:?}"
105                    ),
106                }
107            }
108            (ColumnarValue::Array(date_array), ColumnarValue::Scalar(day_of_week)) => {
109                match (date_array.data_type(), day_of_week) {
110                    (
111                        DataType::Date32,
112                        ScalarValue::Utf8(day_of_week)
113                        | ScalarValue::LargeUtf8(day_of_week)
114                        | ScalarValue::Utf8View(day_of_week),
115                    ) => {
116                        if let Some(day_of_week) = day_of_week {
117                            let result: Date32Array = date_array
118                                .as_primitive::<Date32Type>()
119                                .unary_opt(|days| {
120                                    spark_next_day(days, day_of_week.as_str())
121                                })
122                                .with_data_type(DataType::Date32);
123                            Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef))
124                        } else {
125                            // TODO: if spark.sql.ansi.enabled is false,
126                            //  returns NULL instead of an error for a malformed dayOfWeek.
127                            Ok(ColumnarValue::Scalar(ScalarValue::Date32(None)))
128                        }
129                    }
130                    _ => exec_err!(
131                        "Spark `next_day` function: first arg must be date, second arg must be string. Got {args:?}"
132                    ),
133                }
134            }
135            (
136                ColumnarValue::Array(date_array),
137                ColumnarValue::Array(day_of_week_array),
138            ) => {
139                let result = match (date_array.data_type(), day_of_week_array.data_type())
140                {
141                    (
142                        DataType::Date32,
143                        DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View,
144                    ) => {
145                        let date_array: &Date32Array =
146                            date_array.as_primitive::<Date32Type>();
147                        match day_of_week_array.data_type() {
148                            DataType::Utf8 => {
149                                let day_of_week_array =
150                                    day_of_week_array.as_string::<i32>();
151                                process_next_day_arrays(date_array, day_of_week_array)
152                            }
153                            DataType::LargeUtf8 => {
154                                let day_of_week_array =
155                                    day_of_week_array.as_string::<i64>();
156                                process_next_day_arrays(date_array, day_of_week_array)
157                            }
158                            DataType::Utf8View => {
159                                let day_of_week_array =
160                                    day_of_week_array.as_string_view();
161                                process_next_day_arrays(date_array, day_of_week_array)
162                            }
163                            other => {
164                                exec_err!(
165                                    "Spark `next_day` function: second arg must be string. Got {other:?}"
166                                )
167                            }
168                        }
169                    }
170                    (left, right) => {
171                        exec_err!(
172                            "Spark `next_day` function: first arg must be date, second arg must be string. Got {left:?}, {right:?}"
173                        )
174                    }
175                }?;
176                Ok(ColumnarValue::Array(result))
177            }
178            _ => exec_err!("Unsupported args {args:?} for Spark function `next_day`"),
179        }
180    }
181}
182
183fn process_next_day_arrays<'a, S>(
184    date_array: &Date32Array,
185    day_of_week_array: &'a S,
186) -> Result<ArrayRef>
187where
188    &'a S: StringArrayType<'a>,
189{
190    let result = date_array
191        .iter()
192        .zip(day_of_week_array.iter())
193        .map(|(days, day_of_week)| {
194            if let Some(days) = days {
195                if let Some(day_of_week) = day_of_week {
196                    spark_next_day(days, day_of_week)
197                } else {
198                    // TODO: if spark.sql.ansi.enabled is false,
199                    //  returns NULL instead of an error for a malformed dayOfWeek.
200                    None
201                }
202            } else {
203                None
204            }
205        })
206        .collect::<Date32Array>();
207    Ok(Arc::new(result) as ArrayRef)
208}
209
210fn spark_next_day(days: i32, day_of_week: &str) -> Option<i32> {
211    let date = Date32Type::to_naive_date_opt(days)?;
212
213    let day_of_week = day_of_week.trim().to_uppercase();
214    let day_of_week = match day_of_week.as_str() {
215        "MO" | "MON" | "MONDAY" => Some("MONDAY"),
216        "TU" | "TUE" | "TUESDAY" => Some("TUESDAY"),
217        "WE" | "WED" | "WEDNESDAY" => Some("WEDNESDAY"),
218        "TH" | "THU" | "THURSDAY" => Some("THURSDAY"),
219        "FR" | "FRI" | "FRIDAY" => Some("FRIDAY"),
220        "SA" | "SAT" | "SATURDAY" => Some("SATURDAY"),
221        "SU" | "SUN" | "SUNDAY" => Some("SUNDAY"),
222        _ => {
223            // TODO: if spark.sql.ansi.enabled is false,
224            //  returns NULL instead of an error for a malformed dayOfWeek.
225            None
226        }
227    };
228
229    if let Some(day_of_week) = day_of_week {
230        let day_of_week = day_of_week.parse::<Weekday>();
231        match day_of_week {
232            Ok(day_of_week) => Some(Date32Type::from_naive_date(
233                date + Duration::days(
234                    (7 - date.weekday().days_since(day_of_week)) as i64,
235                ),
236            )),
237            Err(_) => {
238                // TODO: if spark.sql.ansi.enabled is false,
239                //  returns NULL instead of an error for a malformed dayOfWeek.
240                None
241            }
242        }
243    } else {
244        None
245    }
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    /// `return_type` must refuse to answer and point callers at
    /// `return_field_from_args`.
    #[test]
    fn return_type_is_not_used() {
        let udf = SparkNextDay::new();
        let message = udf
            .return_type(&[DataType::Date32, DataType::Utf8])
            .unwrap_err()
            .to_string();
        assert!(message.contains("return_field_from_args should be used instead"));
    }

    /// Even with non-nullable inputs, the output field is nullable `Date32`.
    #[test]
    fn next_day_is_always_nullable() {
        let udf = SparkNextDay::new();
        let arg_fields: [FieldRef; 2] = [
            Arc::new(Field::new("start_date", DataType::Date32, false)),
            Arc::new(Field::new("day_of_week", DataType::Utf8, false)),
        ];
        let return_args = ReturnFieldArgs {
            arg_fields: &arg_fields,
            scalar_arguments: &[None, None],
        };

        let field = udf.return_field_from_args(return_args).unwrap();

        assert_eq!(field.data_type(), &DataType::Date32);
        assert!(field.is_nullable());
    }
}