datafusion_comet_spark_expr/datetime_funcs/
timestamp_trunc.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::utils::array_with_timezone;
19use arrow::record_batch::RecordBatch;
20use arrow_schema::{DataType, Schema, TimeUnit::Microsecond};
21use datafusion::logical_expr::ColumnarValue;
22use datafusion_common::{DataFusionError, ScalarValue::Utf8};
23use datafusion_physical_expr::PhysicalExpr;
24use std::hash::Hash;
25use std::{
26    any::Any,
27    fmt::{Debug, Display, Formatter},
28    sync::Arc,
29};
30
31use crate::kernels::temporal::{timestamp_trunc_array_fmt_dyn, timestamp_trunc_dyn};
32
33#[derive(Debug, Eq)]
34pub struct TimestampTruncExpr {
35    /// An array with DataType::Timestamp(TimeUnit::Microsecond, None)
36    child: Arc<dyn PhysicalExpr>,
37    /// Scalar UTF8 string matching the valid values in Spark SQL: https://spark.apache.org/docs/latest/api/sql/index.html#date_trunc
38    format: Arc<dyn PhysicalExpr>,
39    /// String containing a timezone name. The name must be found in the standard timezone
40    /// database (https://en.wikipedia.org/wiki/List_of_tz_database_time_zones). The string is
41    /// later parsed into a chrono::TimeZone.
42    /// Timestamp arrays in this implementation are kept in arrays of UTC timestamps (in micros)
43    /// along with a single value for the associated TimeZone. The timezone offset is applied
44    /// just before any operations on the timestamp
45    timezone: String,
46}
47
48impl Hash for TimestampTruncExpr {
49    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
50        self.child.hash(state);
51        self.format.hash(state);
52        self.timezone.hash(state);
53    }
54}
55impl PartialEq for TimestampTruncExpr {
56    fn eq(&self, other: &Self) -> bool {
57        self.child.eq(&other.child)
58            && self.format.eq(&other.format)
59            && self.timezone.eq(&other.timezone)
60    }
61}
62
63impl TimestampTruncExpr {
64    pub fn new(
65        child: Arc<dyn PhysicalExpr>,
66        format: Arc<dyn PhysicalExpr>,
67        timezone: String,
68    ) -> Self {
69        TimestampTruncExpr {
70            child,
71            format,
72            timezone,
73        }
74    }
75}
76
77impl Display for TimestampTruncExpr {
78    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
79        write!(
80            f,
81            "TimestampTrunc [child:{}, format:{}, timezone: {}]",
82            self.child, self.format, self.timezone
83        )
84    }
85}
86
87impl PhysicalExpr for TimestampTruncExpr {
88    fn as_any(&self) -> &dyn Any {
89        self
90    }
91
92    fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result<DataType> {
93        match self.child.data_type(input_schema)? {
94            DataType::Dictionary(key_type, _) => Ok(DataType::Dictionary(
95                key_type,
96                Box::new(DataType::Timestamp(Microsecond, None)),
97            )),
98            _ => Ok(DataType::Timestamp(Microsecond, None)),
99        }
100    }
101
102    fn nullable(&self, _: &Schema) -> datafusion_common::Result<bool> {
103        Ok(true)
104    }
105
106    fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result<ColumnarValue> {
107        let timestamp = self.child.evaluate(batch)?;
108        let format = self.format.evaluate(batch)?;
109        let tz = self.timezone.clone();
110        match (timestamp, format) {
111            (ColumnarValue::Array(ts), ColumnarValue::Scalar(Utf8(Some(format)))) => {
112                let ts = array_with_timezone(
113                    ts,
114                    tz.clone(),
115                    Some(&DataType::Timestamp(Microsecond, Some(tz.into()))),
116                )?;
117                let result = timestamp_trunc_dyn(&ts, format)?;
118                Ok(ColumnarValue::Array(result))
119            }
120            (ColumnarValue::Array(ts), ColumnarValue::Array(formats)) => {
121                let ts = array_with_timezone(
122                    ts,
123                    tz.clone(),
124                    Some(&DataType::Timestamp(Microsecond, Some(tz.into()))),
125                )?;
126                let result = timestamp_trunc_array_fmt_dyn(&ts, &formats)?;
127                Ok(ColumnarValue::Array(result))
128            }
129            _ => Err(DataFusionError::Execution(
130                "Invalid input to function TimestampTrunc. \
131                    Expected (PrimitiveArray<TimestampMicrosecondType>, Scalar, String) or \
132                    (PrimitiveArray<TimestampMicrosecondType>, StringArray, String)"
133                    .to_string(),
134            )),
135        }
136    }
137
138    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
139        vec![&self.child]
140    }
141
142    fn with_new_children(
143        self: Arc<Self>,
144        children: Vec<Arc<dyn PhysicalExpr>>,
145    ) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> {
146        Ok(Arc::new(TimestampTruncExpr::new(
147            Arc::clone(&children[0]),
148            Arc::clone(&self.format),
149            self.timezone.clone(),
150        )))
151    }
152}