datafusion_spark/function/datetime/
date_trunc.rs1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::datatypes::{DataType, Field, FieldRef, TimeUnit};
22use datafusion_common::types::{NativeType, logical_string};
23use datafusion_common::utils::take_function_args;
24use datafusion_common::{Result, ScalarValue, internal_err, plan_err};
25use datafusion_expr::expr::ScalarFunction;
26use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
27use datafusion_expr::{
28 Coercion, ColumnarValue, Expr, ExprSchemable, ReturnFieldArgs, ScalarFunctionArgs,
29 ScalarUDFImpl, Signature, TypeSignatureClass, Volatility,
30};
31
32#[derive(Debug, PartialEq, Eq, Hash)]
36pub struct SparkDateTrunc {
37 signature: Signature,
38}
39
40impl Default for SparkDateTrunc {
41 fn default() -> Self {
42 Self::new()
43 }
44}
45
46impl SparkDateTrunc {
47 pub fn new() -> Self {
48 Self {
49 signature: Signature::coercible(
50 vec![
51 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
52 Coercion::new_implicit(
53 TypeSignatureClass::Timestamp,
54 vec![TypeSignatureClass::Native(logical_string())],
55 NativeType::Timestamp(TimeUnit::Microsecond, None),
56 ),
57 ],
58 Volatility::Immutable,
59 ),
60 }
61 }
62}
63
64impl ScalarUDFImpl for SparkDateTrunc {
65 fn as_any(&self) -> &dyn Any {
66 self
67 }
68
69 fn name(&self) -> &str {
70 "date_trunc"
71 }
72
73 fn signature(&self) -> &Signature {
74 &self.signature
75 }
76
77 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
78 internal_err!("return_field_from_args should be used instead")
79 }
80
81 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
82 let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
83
84 Ok(Arc::new(Field::new(
85 self.name(),
86 args.arg_fields[1].data_type().clone(),
87 nullable,
88 )))
89 }
90
91 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
92 internal_err!(
93 "spark date_trunc should have been simplified to standard date_trunc"
94 )
95 }
96
97 fn simplify(
98 &self,
99 args: Vec<Expr>,
100 info: &SimplifyContext,
101 ) -> Result<ExprSimplifyResult> {
102 let [fmt_expr, ts_expr] = take_function_args(self.name(), args)?;
103
104 let fmt = match fmt_expr.as_literal() {
105 Some(ScalarValue::Utf8(Some(v)))
106 | Some(ScalarValue::Utf8View(Some(v)))
107 | Some(ScalarValue::LargeUtf8(Some(v))) => v.to_lowercase(),
108 _ => {
109 return plan_err!(
110 "First argument of `DATE_TRUNC` must be non-null scalar Utf8"
111 );
112 }
113 };
114
115 let fmt = match fmt.as_str() {
117 "yy" | "yyyy" => "year",
118 "mm" | "mon" => "month",
119 "dd" => "day",
120 other => other,
121 };
122
123 let session_tz = info.config_options().execution.time_zone.clone();
124 let ts_type = ts_expr.get_type(info.schema())?;
125
126 let ts_expr = match (&ts_type, fmt) {
133 (_, "second" | "millisecond" | "microsecond") => ts_expr,
135
136 (DataType::Timestamp(unit, tz), _) => {
138 let ts_expr = match &session_tz {
139 Some(session_tz) => ts_expr.cast_to(
140 &DataType::Timestamp(
141 TimeUnit::Microsecond,
142 Some(Arc::from(session_tz.as_str())),
143 ),
144 info.schema(),
145 )?,
146 None => ts_expr,
147 };
148 Expr::ScalarFunction(ScalarFunction::new_udf(
149 datafusion_functions::datetime::to_local_time(),
150 vec![ts_expr],
151 ))
152 .cast_to(&DataType::Timestamp(*unit, tz.clone()), info.schema())?
153 }
154
155 _ => {
156 return plan_err!(
157 "Second argument of `DATE_TRUNC` must be Timestamp, got {}",
158 ts_type
159 );
160 }
161 };
162
163 let fmt_expr = Expr::Literal(ScalarValue::new_utf8(fmt), None);
164
165 Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction(
166 ScalarFunction::new_udf(
167 datafusion_functions::datetime::date_trunc(),
168 vec![fmt_expr, ts_expr],
169 ),
170 )))
171 }
172}