datafusion_comet_spark_expr/datetime_funcs/timestamp_trunc.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::utils::array_with_timezone;
19use arrow::datatypes::{DataType, Schema, TimeUnit::Microsecond};
20use arrow::record_batch::RecordBatch;
21use datafusion::common::{DataFusionError, ScalarValue::Utf8};
22use datafusion::logical_expr::ColumnarValue;
23use datafusion::physical_expr::PhysicalExpr;
24use std::hash::Hash;
25use std::{
26    any::Any,
27    fmt::{Debug, Display, Formatter},
28    sync::Arc,
29};
30
31use crate::kernels::temporal::{timestamp_trunc_array_fmt_dyn, timestamp_trunc_dyn};
32
33#[derive(Debug, Eq)]
34pub struct TimestampTruncExpr {
35    /// An array with DataType::Timestamp(TimeUnit::Microsecond, None)
36    child: Arc<dyn PhysicalExpr>,
37    /// Scalar UTF8 string matching the valid values in Spark SQL: https://spark.apache.org/docs/latest/api/sql/index.html#date_trunc
38    format: Arc<dyn PhysicalExpr>,
39    /// String containing a timezone name. The name must be found in the standard timezone
40    /// database (https://en.wikipedia.org/wiki/List_of_tz_database_time_zones). The string is
41    /// later parsed into a chrono::TimeZone.
42    /// Timestamp arrays in this implementation are kept in arrays of UTC timestamps (in micros)
43    /// along with a single value for the associated TimeZone. The timezone offset is applied
44    /// just before any operations on the timestamp
45    timezone: String,
46}
47
48impl Hash for TimestampTruncExpr {
49    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
50        self.child.hash(state);
51        self.format.hash(state);
52        self.timezone.hash(state);
53    }
54}
55impl PartialEq for TimestampTruncExpr {
56    fn eq(&self, other: &Self) -> bool {
57        self.child.eq(&other.child)
58            && self.format.eq(&other.format)
59            && self.timezone.eq(&other.timezone)
60    }
61}
62
63impl TimestampTruncExpr {
64    pub fn new(
65        child: Arc<dyn PhysicalExpr>,
66        format: Arc<dyn PhysicalExpr>,
67        timezone: String,
68    ) -> Self {
69        TimestampTruncExpr {
70            child,
71            format,
72            timezone,
73        }
74    }
75}
76
77impl Display for TimestampTruncExpr {
78    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
79        write!(
80            f,
81            "TimestampTrunc [child:{}, format:{}, timezone: {}]",
82            self.child, self.format, self.timezone
83        )
84    }
85}
86
87impl PhysicalExpr for TimestampTruncExpr {
88    fn as_any(&self) -> &dyn Any {
89        self
90    }
91
92    fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
93        unimplemented!()
94    }
95
96    fn data_type(&self, input_schema: &Schema) -> datafusion::common::Result<DataType> {
97        match self.child.data_type(input_schema)? {
98            DataType::Dictionary(key_type, _) => Ok(DataType::Dictionary(
99                key_type,
100                Box::new(DataType::Timestamp(Microsecond, None)),
101            )),
102            _ => Ok(DataType::Timestamp(Microsecond, None)),
103        }
104    }
105
106    fn nullable(&self, _: &Schema) -> datafusion::common::Result<bool> {
107        Ok(true)
108    }
109
110    fn evaluate(&self, batch: &RecordBatch) -> datafusion::common::Result<ColumnarValue> {
111        let timestamp = self.child.evaluate(batch)?;
112        let format = self.format.evaluate(batch)?;
113        let tz = self.timezone.clone();
114        match (timestamp, format) {
115            (ColumnarValue::Array(ts), ColumnarValue::Scalar(Utf8(Some(format)))) => {
116                let ts = array_with_timezone(
117                    ts,
118                    tz.clone(),
119                    Some(&DataType::Timestamp(Microsecond, Some(tz.into()))),
120                )?;
121                let result = timestamp_trunc_dyn(&ts, format)?;
122                Ok(ColumnarValue::Array(result))
123            }
124            (ColumnarValue::Array(ts), ColumnarValue::Array(formats)) => {
125                let ts = array_with_timezone(
126                    ts,
127                    tz.clone(),
128                    Some(&DataType::Timestamp(Microsecond, Some(tz.into()))),
129                )?;
130                let result = timestamp_trunc_array_fmt_dyn(&ts, &formats)?;
131                Ok(ColumnarValue::Array(result))
132            }
133            _ => Err(DataFusionError::Execution(
134                "Invalid input to function TimestampTrunc. \
135                    Expected (PrimitiveArray<TimestampMicrosecondType>, Scalar, String) or \
136                    (PrimitiveArray<TimestampMicrosecondType>, StringArray, String)"
137                    .to_string(),
138            )),
139        }
140    }
141
142    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
143        vec![&self.child]
144    }
145
146    fn with_new_children(
147        self: Arc<Self>,
148        children: Vec<Arc<dyn PhysicalExpr>>,
149    ) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> {
150        Ok(Arc::new(TimestampTruncExpr::new(
151            Arc::clone(&children[0]),
152            Arc::clone(&self.format),
153            self.timezone.clone(),
154        )))
155    }
156}