datafusion_comet_spark_expr/datetime_funcs/
timestamp_trunc.rs1use crate::utils::array_with_timezone;
19use arrow::datatypes::{DataType, Schema, TimeUnit::Microsecond};
20use arrow::record_batch::RecordBatch;
21use datafusion::common::{DataFusionError, ScalarValue::Utf8};
22use datafusion::logical_expr::ColumnarValue;
23use datafusion::physical_expr::PhysicalExpr;
24use std::hash::Hash;
25use std::{
26 any::Any,
27 fmt::{Debug, Display, Formatter},
28 sync::Arc,
29};
30
31use crate::kernels::temporal::{timestamp_trunc_array_fmt_dyn, timestamp_trunc_dyn};
32
33#[derive(Debug, Eq)]
34pub struct TimestampTruncExpr {
35 child: Arc<dyn PhysicalExpr>,
37 format: Arc<dyn PhysicalExpr>,
39 timezone: String,
46}
47
48impl Hash for TimestampTruncExpr {
49 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
50 self.child.hash(state);
51 self.format.hash(state);
52 self.timezone.hash(state);
53 }
54}
55impl PartialEq for TimestampTruncExpr {
56 fn eq(&self, other: &Self) -> bool {
57 self.child.eq(&other.child)
58 && self.format.eq(&other.format)
59 && self.timezone.eq(&other.timezone)
60 }
61}
62
63impl TimestampTruncExpr {
64 pub fn new(
65 child: Arc<dyn PhysicalExpr>,
66 format: Arc<dyn PhysicalExpr>,
67 timezone: String,
68 ) -> Self {
69 TimestampTruncExpr {
70 child,
71 format,
72 timezone,
73 }
74 }
75}
76
77impl Display for TimestampTruncExpr {
78 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
79 write!(
80 f,
81 "TimestampTrunc [child:{}, format:{}, timezone: {}]",
82 self.child, self.format, self.timezone
83 )
84 }
85}
86
87impl PhysicalExpr for TimestampTruncExpr {
88 fn as_any(&self) -> &dyn Any {
89 self
90 }
91
92 fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
93 unimplemented!()
94 }
95
96 fn data_type(&self, input_schema: &Schema) -> datafusion::common::Result<DataType> {
97 match self.child.data_type(input_schema)? {
98 DataType::Dictionary(key_type, _) => Ok(DataType::Dictionary(
99 key_type,
100 Box::new(DataType::Timestamp(Microsecond, None)),
101 )),
102 _ => Ok(DataType::Timestamp(Microsecond, None)),
103 }
104 }
105
106 fn nullable(&self, _: &Schema) -> datafusion::common::Result<bool> {
107 Ok(true)
108 }
109
110 fn evaluate(&self, batch: &RecordBatch) -> datafusion::common::Result<ColumnarValue> {
111 let timestamp = self.child.evaluate(batch)?;
112 let format = self.format.evaluate(batch)?;
113 let tz = self.timezone.clone();
114 match (timestamp, format) {
115 (ColumnarValue::Array(ts), ColumnarValue::Scalar(Utf8(Some(format)))) => {
116 let ts = array_with_timezone(
117 ts,
118 tz.clone(),
119 Some(&DataType::Timestamp(Microsecond, Some(tz.into()))),
120 )?;
121 let result = timestamp_trunc_dyn(&ts, format)?;
122 Ok(ColumnarValue::Array(result))
123 }
124 (ColumnarValue::Array(ts), ColumnarValue::Array(formats)) => {
125 let ts = array_with_timezone(
126 ts,
127 tz.clone(),
128 Some(&DataType::Timestamp(Microsecond, Some(tz.into()))),
129 )?;
130 let result = timestamp_trunc_array_fmt_dyn(&ts, &formats)?;
131 Ok(ColumnarValue::Array(result))
132 }
133 _ => Err(DataFusionError::Execution(
134 "Invalid input to function TimestampTrunc. \
135 Expected (PrimitiveArray<TimestampMicrosecondType>, Scalar, String) or \
136 (PrimitiveArray<TimestampMicrosecondType>, StringArray, String)"
137 .to_string(),
138 )),
139 }
140 }
141
142 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
143 vec![&self.child]
144 }
145
146 fn with_new_children(
147 self: Arc<Self>,
148 children: Vec<Arc<dyn PhysicalExpr>>,
149 ) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> {
150 Ok(Arc::new(TimestampTruncExpr::new(
151 Arc::clone(&children[0]),
152 Arc::clone(&self.format),
153 self.timezone.clone(),
154 )))
155 }
156}