// datafusion_spark/function/datetime/next_day.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{new_null_array, ArrayRef, AsArray, Date32Array, StringArrayType};
22use arrow::datatypes::{DataType, Date32Type};
23use chrono::{Datelike, Duration, Weekday};
24use datafusion_common::{exec_err, Result, ScalarValue};
25use datafusion_expr::{
26    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
27};
28
/// Spark-compatible `next_day(start_date, day_of_week)` scalar function:
/// returns the first date strictly later than `start_date` that falls on
/// the named day of the week.
///
/// <https://spark.apache.org/docs/latest/api/sql/index.html#next_day>
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkNextDay {
    // Declares the accepted argument types `(Date32, Utf8)`; built in
    // `SparkNextDay::new`.
    signature: Signature,
}
34
impl Default for SparkNextDay {
    // Delegates to `new()` so `SparkNextDay::default()` and
    // `SparkNextDay::new()` are interchangeable.
    fn default() -> Self {
        Self::new()
    }
}
40
41impl SparkNextDay {
42    pub fn new() -> Self {
43        Self {
44            signature: Signature::exact(
45                vec![DataType::Date32, DataType::Utf8],
46                Volatility::Immutable,
47            ),
48        }
49    }
50}
51
impl ScalarUDFImpl for SparkNextDay {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "next_day"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The result is always `Date32`, regardless of which string flavor
    /// (`Utf8`/`LargeUtf8`/`Utf8View`) carries the day-of-week argument.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Date32)
    }

    /// Evaluates `next_day`, dispatching on the scalar/array shape of the
    /// two arguments:
    /// - scalar date + scalar day-of-week: single scalar result;
    /// - array date + scalar day-of-week: one `unary_opt` pass over the dates;
    /// - array date + array day-of-week: row-wise zip via
    ///   `process_next_day_arrays`;
    /// - scalar date + array day-of-week: falls through to the final
    ///   catch-all error arm.
    ///
    /// A null date or a null day-of-week yields a null result (see the
    /// TODOs: this matches Spark with `spark.sql.ansi.enabled=false`).
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let ScalarFunctionArgs { args, .. } = args;
        // Exactly two arguments are required: the date and the weekday name.
        let [date, day_of_week] = args.as_slice() else {
            return exec_err!(
                "Spark `next_day` function requires 2 arguments, got {}",
                args.len()
            );
        };

        match (date, day_of_week) {
            (ColumnarValue::Scalar(date), ColumnarValue::Scalar(day_of_week)) => {
                match (date, day_of_week) {
                    // Accept any of the three string scalar encodings.
                    (ScalarValue::Date32(days), ScalarValue::Utf8(day_of_week) | ScalarValue::LargeUtf8(day_of_week) | ScalarValue::Utf8View(day_of_week)) => {
                        if let Some(days) = days {
                            if let Some(day_of_week) = day_of_week {
                                Ok(ColumnarValue::Scalar(ScalarValue::Date32(
                                    spark_next_day(*days, day_of_week.as_str()),
                                )))
                            } else {
                                // Null day-of-week -> null result.
                                // TODO: if spark.sql.ansi.enabled is false,
                                //  returns NULL instead of an error for a malformed dayOfWeek.
                                Ok(ColumnarValue::Scalar(ScalarValue::Date32(None)))
                            }
                        } else {
                            // Null date -> null result.
                            Ok(ColumnarValue::Scalar(ScalarValue::Date32(None)))
                        }
                    }
                    _ => exec_err!("Spark `next_day` function: first arg must be date, second arg must be string. Got {args:?}"),
                }
            }
            (ColumnarValue::Array(date_array), ColumnarValue::Scalar(day_of_week)) => {
                match (date_array.data_type(), day_of_week) {
                    (DataType::Date32, ScalarValue::Utf8(day_of_week) | ScalarValue::LargeUtf8(day_of_week) | ScalarValue::Utf8View(day_of_week)) => {
                        if let Some(day_of_week) = day_of_week {
                            // One pass over the dates; `unary_opt` propagates
                            // input nulls and maps `None` results to nulls.
                            let result: Date32Array = date_array
                                .as_primitive::<Date32Type>()
                                .unary_opt(|days| spark_next_day(days, day_of_week.as_str()))
                                .with_data_type(DataType::Date32);
                            Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef))
                        } else {
                            // Null day-of-week scalar -> all-null column.
                            // TODO: if spark.sql.ansi.enabled is false,
                            //  returns NULL instead of an error for a malformed dayOfWeek.
                            Ok(ColumnarValue::Array(Arc::new(new_null_array(&DataType::Date32, date_array.len()))))
                        }
                    }
                    _ => exec_err!("Spark `next_day` function: first arg must be date, second arg must be string. Got {args:?}"),
                }
            }
            (
                ColumnarValue::Array(date_array),
                ColumnarValue::Array(day_of_week_array),
            ) => {
                let result = match (date_array.data_type(), day_of_week_array.data_type())
                {
                    (
                        DataType::Date32,
                        DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View,
                    ) => {
                        let date_array: &Date32Array =
                            date_array.as_primitive::<Date32Type>();
                        // Downcast to the concrete string array type, then
                        // share one generic row-wise kernel.
                        match day_of_week_array.data_type() {
                            DataType::Utf8 => {
                                let day_of_week_array =
                                    day_of_week_array.as_string::<i32>();
                                process_next_day_arrays(date_array, day_of_week_array)
                            }
                            DataType::LargeUtf8 => {
                                let day_of_week_array =
                                    day_of_week_array.as_string::<i64>();
                                process_next_day_arrays(date_array, day_of_week_array)
                            }
                            DataType::Utf8View => {
                                let day_of_week_array =
                                    day_of_week_array.as_string_view();
                                process_next_day_arrays(date_array, day_of_week_array)
                            }
                            other => {
                                exec_err!("Spark `next_day` function: second arg must be string. Got {other:?}")
                            }
                        }
                    }
                    (left, right) => {
                        exec_err!(
                            "Spark `next_day` function: first arg must be date, second arg must be string. Got {left:?}, {right:?}"
                        )
                    }
                }?;
                Ok(ColumnarValue::Array(result))
            }
            _ => exec_err!("Unsupported args {args:?} for Spark function `next_day`"),
        }
    }
}
162
163fn process_next_day_arrays<'a, S>(
164    date_array: &Date32Array,
165    day_of_week_array: &'a S,
166) -> Result<ArrayRef>
167where
168    &'a S: StringArrayType<'a>,
169{
170    let result = date_array
171        .iter()
172        .zip(day_of_week_array.iter())
173        .map(|(days, day_of_week)| {
174            if let Some(days) = days {
175                if let Some(day_of_week) = day_of_week {
176                    spark_next_day(days, day_of_week)
177                } else {
178                    // TODO: if spark.sql.ansi.enabled is false,
179                    //  returns NULL instead of an error for a malformed dayOfWeek.
180                    None
181                }
182            } else {
183                None
184            }
185        })
186        .collect::<Date32Array>();
187    Ok(Arc::new(result) as ArrayRef)
188}
189
190fn spark_next_day(days: i32, day_of_week: &str) -> Option<i32> {
191    let date = Date32Type::to_naive_date(days);
192
193    let day_of_week = day_of_week.trim().to_uppercase();
194    let day_of_week = match day_of_week.as_str() {
195        "MO" | "MON" | "MONDAY" => Some("MONDAY"),
196        "TU" | "TUE" | "TUESDAY" => Some("TUESDAY"),
197        "WE" | "WED" | "WEDNESDAY" => Some("WEDNESDAY"),
198        "TH" | "THU" | "THURSDAY" => Some("THURSDAY"),
199        "FR" | "FRI" | "FRIDAY" => Some("FRIDAY"),
200        "SA" | "SAT" | "SATURDAY" => Some("SATURDAY"),
201        "SU" | "SUN" | "SUNDAY" => Some("SUNDAY"),
202        _ => {
203            // TODO: if spark.sql.ansi.enabled is false,
204            //  returns NULL instead of an error for a malformed dayOfWeek.
205            None
206        }
207    };
208
209    if let Some(day_of_week) = day_of_week {
210        let day_of_week = day_of_week.parse::<Weekday>();
211        match day_of_week {
212            Ok(day_of_week) => Some(Date32Type::from_naive_date(
213                date + Duration::days(
214                    (7 - date.weekday().days_since(day_of_week)) as i64,
215                ),
216            )),
217            Err(_) => {
218                // TODO: if spark.sql.ansi.enabled is false,
219                //  returns NULL instead of an error for a malformed dayOfWeek.
220                None
221            }
222        }
223    } else {
224        None
225    }
226}