datafusion_comet_spark_expr/datetime_funcs/
date_trunc.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::record_batch::RecordBatch;
19use arrow_schema::{DataType, Schema};
20use datafusion::logical_expr::ColumnarValue;
21use datafusion_common::{DataFusionError, ScalarValue::Utf8};
22use datafusion_physical_expr::PhysicalExpr;
23use std::hash::Hash;
24use std::{
25    any::Any,
26    fmt::{Debug, Display, Formatter},
27    sync::Arc,
28};
29
30use crate::kernels::temporal::{date_trunc_array_fmt_dyn, date_trunc_dyn};
31
32#[derive(Debug, Eq)]
33pub struct DateTruncExpr {
34    /// An array with DataType::Date32
35    child: Arc<dyn PhysicalExpr>,
36    /// Scalar UTF8 string matching the valid values in Spark SQL: https://spark.apache.org/docs/latest/api/sql/index.html#trunc
37    format: Arc<dyn PhysicalExpr>,
38}
39
40impl Hash for DateTruncExpr {
41    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
42        self.child.hash(state);
43        self.format.hash(state);
44    }
45}
46impl PartialEq for DateTruncExpr {
47    fn eq(&self, other: &Self) -> bool {
48        self.child.eq(&other.child) && self.format.eq(&other.format)
49    }
50}
51
52impl DateTruncExpr {
53    pub fn new(child: Arc<dyn PhysicalExpr>, format: Arc<dyn PhysicalExpr>) -> Self {
54        DateTruncExpr { child, format }
55    }
56}
57
58impl Display for DateTruncExpr {
59    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
60        write!(
61            f,
62            "DateTrunc [child:{}, format: {}]",
63            self.child, self.format
64        )
65    }
66}
67
68impl PhysicalExpr for DateTruncExpr {
69    fn as_any(&self) -> &dyn Any {
70        self
71    }
72
73    fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result<DataType> {
74        self.child.data_type(input_schema)
75    }
76
77    fn nullable(&self, _: &Schema) -> datafusion_common::Result<bool> {
78        Ok(true)
79    }
80
81    fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result<ColumnarValue> {
82        let date = self.child.evaluate(batch)?;
83        let format = self.format.evaluate(batch)?;
84        match (date, format) {
85            (ColumnarValue::Array(date), ColumnarValue::Scalar(Utf8(Some(format)))) => {
86                let result = date_trunc_dyn(&date, format)?;
87                Ok(ColumnarValue::Array(result))
88            }
89            (ColumnarValue::Array(date), ColumnarValue::Array(formats)) => {
90                let result = date_trunc_array_fmt_dyn(&date, &formats)?;
91                Ok(ColumnarValue::Array(result))
92            }
93            _ => Err(DataFusionError::Execution(
94                "Invalid input to function DateTrunc. Expected (PrimitiveArray<Date32>, Scalar) or \
95                    (PrimitiveArray<Date32>, StringArray)".to_string(),
96            )),
97        }
98    }
99
100    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
101        vec![&self.child]
102    }
103
104    fn with_new_children(
105        self: Arc<Self>,
106        children: Vec<Arc<dyn PhysicalExpr>>,
107    ) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> {
108        Ok(Arc::new(DateTruncExpr::new(
109            Arc::clone(&children[0]),
110            Arc::clone(&self.format),
111        )))
112    }
113}