datafusion_comet_spark_expr/datetime_funcs/extract_date_part.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
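
//! Spark-compatible `hour`, `minute`, and `second` scalar UDFs, generated by a shared
//! macro that normalizes the input to a microsecond timestamp in the session timezone
//! before extracting the requested field.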

use crate::utils::array_with_timezone;
use arrow::compute::{date_part, DatePart};
use arrow::datatypes::{DataType, TimeUnit::Microsecond};
use datafusion::common::{internal_datafusion_err, DataFusionError};
use datafusion::logical_expr::{
    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};
use std::{any::Any, fmt::Debug};

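/// Generates a Spark-compatible `ScalarUDFImpl` named `$fn_name` that extracts the
/// date/time field identified by the `DatePart` variant `$date_part_variant` from a
/// timestamp column, converting the input to the session timezone first.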
macro_rules! extract_date_part {
    ($struct_name:ident, $fn_name:expr, $date_part_variant:ident) => {
        #[derive(Debug)]
        pub struct $struct_name {
            signature: Signature,
            aliases: Vec<String>,
            timezone: String,
        }

        impl $struct_name {
            pub fn new(timezone: String) -> Self {
                Self {
                    signature: Signature::user_defined(Volatility::Immutable),
                    aliases: vec![],
                    timezone,
                }
            }
        }

        impl ScalarUDFImpl for $struct_name {
            fn as_any(&self) -> &dyn Any {
                self
            }

            fn name(&self) -> &str {
                $fn_name
            }

            fn signature(&self) -> &Signature {
                &self.signature
            }

            fn return_type(&self, arg_types: &[DataType]) -> datafusion::common::Result<DataType> {
                // Spark returns a 32-bit integer for these extractions; preserve
                // dictionary encoding when the input is dictionary-encoded.
                Ok(match &arg_types[0] {
                    DataType::Dictionary(_, _) => {
                        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Int32))
                    }
                    _ => DataType::Int32,
                })
            }

            fn invoke_with_args(
                &self,
                args: ScalarFunctionArgs,
            ) -> datafusion::common::Result<ColumnarValue> {
                let args: [ColumnarValue; 1] = args.args.try_into().map_err(|_| {
                    internal_datafusion_err!(concat!($fn_name, " expects exactly one argument"))
                })?;

                match args {
                    [ColumnarValue::Array(array)] => {
                        // Normalize the input to a microsecond timestamp in the session
                        // timezone before extracting the requested field.
                        let array = array_with_timezone(
                            array,
                            self.timezone.clone(),
                            Some(&DataType::Timestamp(
                                Microsecond,
                                Some(self.timezone.clone().into()),
                            )),
                        )?;
                        let result = date_part(&array, DatePart::$date_part_variant)?;
                        Ok(ColumnarValue::Array(result))
                    }
                    // Literal inputs are expected to be constant-folded on the Spark JVM
                    // side, so a scalar reaching this point indicates a planning bug.
                    _ => Err(DataFusionError::Execution(
                        concat!($fn_name, "(scalar) should be folded on the Spark JVM side.")
                            .to_string(),
                    )),
                }
            }

            fn aliases(&self) -> &[String] {
                &self.aliases
            }
        }
    };
}

extract_date_part!(SparkHour, "hour", Hour);
extract_date_part!(SparkMinute, "minute", Minute);
extract_date_part!(SparkSecond, "second", Second);
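
// A minimal sketch of how one of the generated UDFs can be exercised; it assumes only
// the constructors above and `ScalarUDF::new_from_impl` from `datafusion::logical_expr`.
// Comet's planner normally wires these impls into physical scalar-function expressions.
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::logical_expr::ScalarUDF;

    #[test]
    fn hour_udf_reports_spark_name_and_return_type() -> datafusion::common::Result<()> {
        // The session timezone is supplied per expression, mirroring Spark semantics.
        let hour = SparkHour::new("UTC".to_string());
        assert_eq!(hour.name(), "hour");

        // Timestamp inputs produce Int32, matching Spark's `Hour` expression.
        let return_type =
            hour.return_type(&[DataType::Timestamp(Microsecond, Some("UTC".into()))])?;
        assert_eq!(return_type, DataType::Int32);

        // Wrapping the impl in a `ScalarUDF` lets DataFusion planners treat it like
        // any other scalar function.
        let udf = ScalarUDF::new_from_impl(SparkHour::new("UTC".to_string()));
        assert_eq!(udf.name(), "hour");
        Ok(())
    }
}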