datafusion_comet_spark_expr/datetime_funcs/
timestamp_trunc.rs1use crate::utils::array_with_timezone;
19use arrow::record_batch::RecordBatch;
20use arrow_schema::{DataType, Schema, TimeUnit::Microsecond};
21use datafusion::logical_expr::ColumnarValue;
22use datafusion_common::{DataFusionError, ScalarValue::Utf8};
23use datafusion_physical_expr::PhysicalExpr;
24use std::hash::Hash;
25use std::{
26 any::Any,
27 fmt::{Debug, Display, Formatter},
28 sync::Arc,
29};
30
31use crate::kernels::temporal::{timestamp_trunc_array_fmt_dyn, timestamp_trunc_dyn};
32
33#[derive(Debug, Eq)]
34pub struct TimestampTruncExpr {
35 child: Arc<dyn PhysicalExpr>,
37 format: Arc<dyn PhysicalExpr>,
39 timezone: String,
46}
47
48impl Hash for TimestampTruncExpr {
49 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
50 self.child.hash(state);
51 self.format.hash(state);
52 self.timezone.hash(state);
53 }
54}
55impl PartialEq for TimestampTruncExpr {
56 fn eq(&self, other: &Self) -> bool {
57 self.child.eq(&other.child)
58 && self.format.eq(&other.format)
59 && self.timezone.eq(&other.timezone)
60 }
61}
62
63impl TimestampTruncExpr {
64 pub fn new(
65 child: Arc<dyn PhysicalExpr>,
66 format: Arc<dyn PhysicalExpr>,
67 timezone: String,
68 ) -> Self {
69 TimestampTruncExpr {
70 child,
71 format,
72 timezone,
73 }
74 }
75}
76
77impl Display for TimestampTruncExpr {
78 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
79 write!(
80 f,
81 "TimestampTrunc [child:{}, format:{}, timezone: {}]",
82 self.child, self.format, self.timezone
83 )
84 }
85}
86
87impl PhysicalExpr for TimestampTruncExpr {
88 fn as_any(&self) -> &dyn Any {
89 self
90 }
91
92 fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result<DataType> {
93 match self.child.data_type(input_schema)? {
94 DataType::Dictionary(key_type, _) => Ok(DataType::Dictionary(
95 key_type,
96 Box::new(DataType::Timestamp(Microsecond, None)),
97 )),
98 _ => Ok(DataType::Timestamp(Microsecond, None)),
99 }
100 }
101
102 fn nullable(&self, _: &Schema) -> datafusion_common::Result<bool> {
103 Ok(true)
104 }
105
106 fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result<ColumnarValue> {
107 let timestamp = self.child.evaluate(batch)?;
108 let format = self.format.evaluate(batch)?;
109 let tz = self.timezone.clone();
110 match (timestamp, format) {
111 (ColumnarValue::Array(ts), ColumnarValue::Scalar(Utf8(Some(format)))) => {
112 let ts = array_with_timezone(
113 ts,
114 tz.clone(),
115 Some(&DataType::Timestamp(Microsecond, Some(tz.into()))),
116 )?;
117 let result = timestamp_trunc_dyn(&ts, format)?;
118 Ok(ColumnarValue::Array(result))
119 }
120 (ColumnarValue::Array(ts), ColumnarValue::Array(formats)) => {
121 let ts = array_with_timezone(
122 ts,
123 tz.clone(),
124 Some(&DataType::Timestamp(Microsecond, Some(tz.into()))),
125 )?;
126 let result = timestamp_trunc_array_fmt_dyn(&ts, &formats)?;
127 Ok(ColumnarValue::Array(result))
128 }
129 _ => Err(DataFusionError::Execution(
130 "Invalid input to function TimestampTrunc. \
131 Expected (PrimitiveArray<TimestampMicrosecondType>, Scalar, String) or \
132 (PrimitiveArray<TimestampMicrosecondType>, StringArray, String)"
133 .to_string(),
134 )),
135 }
136 }
137
138 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
139 vec![&self.child]
140 }
141
142 fn with_new_children(
143 self: Arc<Self>,
144 children: Vec<Arc<dyn PhysicalExpr>>,
145 ) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> {
146 Ok(Arc::new(TimestampTruncExpr::new(
147 Arc::clone(&children[0]),
148 Arc::clone(&self.format),
149 self.timezone.clone(),
150 )))
151 }
152}