datafusion_spark/function/datetime/
date_trunc.rs1use std::sync::Arc;
19
20use arrow::datatypes::{DataType, Field, FieldRef, TimeUnit};
21use datafusion_common::types::{NativeType, logical_string};
22use datafusion_common::utils::take_function_args;
23use datafusion_common::{Result, ScalarValue, internal_err, plan_err};
24use datafusion_expr::expr::ScalarFunction;
25use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
26use datafusion_expr::{
27 Coercion, ColumnarValue, Expr, ExprSchemable, ReturnFieldArgs, ScalarFunctionArgs,
28 ScalarUDFImpl, Signature, TypeSignatureClass, Volatility,
29};
30
31#[derive(Debug, PartialEq, Eq, Hash)]
35pub struct SparkDateTrunc {
36 signature: Signature,
37}
38
39impl Default for SparkDateTrunc {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45impl SparkDateTrunc {
46 pub fn new() -> Self {
47 Self {
48 signature: Signature::coercible(
49 vec![
50 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
51 Coercion::new_implicit(
52 TypeSignatureClass::Timestamp,
53 vec![TypeSignatureClass::Native(logical_string())],
54 NativeType::Timestamp(TimeUnit::Microsecond, None),
55 ),
56 ],
57 Volatility::Immutable,
58 ),
59 }
60 }
61}
62
63impl ScalarUDFImpl for SparkDateTrunc {
64 fn name(&self) -> &str {
65 "date_trunc"
66 }
67
68 fn signature(&self) -> &Signature {
69 &self.signature
70 }
71
72 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
73 internal_err!("return_field_from_args should be used instead")
74 }
75
76 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
77 let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
78
79 Ok(Arc::new(Field::new(
80 self.name(),
81 args.arg_fields[1].data_type().clone(),
82 nullable,
83 )))
84 }
85
86 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
87 internal_err!(
88 "spark date_trunc should have been simplified to standard date_trunc"
89 )
90 }
91
92 fn simplify(
93 &self,
94 args: Vec<Expr>,
95 info: &SimplifyContext,
96 ) -> Result<ExprSimplifyResult> {
97 let [fmt_expr, ts_expr] = take_function_args(self.name(), args)?;
98
99 let fmt = match fmt_expr.as_literal() {
100 Some(ScalarValue::Utf8(Some(v)))
101 | Some(ScalarValue::Utf8View(Some(v)))
102 | Some(ScalarValue::LargeUtf8(Some(v))) => v.to_lowercase(),
103 _ => {
104 return plan_err!(
105 "First argument of `DATE_TRUNC` must be non-null scalar Utf8"
106 );
107 }
108 };
109
110 let fmt = match fmt.as_str() {
112 "yy" | "yyyy" => "year",
113 "mm" | "mon" => "month",
114 "dd" => "day",
115 other => other,
116 };
117
118 let session_tz = info.config_options().execution.time_zone.clone();
119 let ts_type = ts_expr.get_type(info.schema())?;
120
121 let ts_expr = match (&ts_type, fmt) {
128 (_, "second" | "millisecond" | "microsecond") => ts_expr,
130
131 (DataType::Timestamp(unit, tz), _) => {
133 let ts_expr = match &session_tz {
134 Some(session_tz) => ts_expr.cast_to(
135 &DataType::Timestamp(
136 TimeUnit::Microsecond,
137 Some(Arc::from(session_tz.as_str())),
138 ),
139 info.schema(),
140 )?,
141 None => ts_expr,
142 };
143 Expr::ScalarFunction(ScalarFunction::new_udf(
144 datafusion_functions::datetime::to_local_time(),
145 vec![ts_expr],
146 ))
147 .cast_to(&DataType::Timestamp(*unit, tz.clone()), info.schema())?
148 }
149
150 _ => {
151 return plan_err!(
152 "Second argument of `DATE_TRUNC` must be Timestamp, got {}",
153 ts_type
154 );
155 }
156 };
157
158 let fmt_expr = Expr::Literal(ScalarValue::new_utf8(fmt), None);
159
160 Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction(
161 ScalarFunction::new_udf(
162 datafusion_functions::datetime::date_trunc(),
163 vec![fmt_expr, ts_expr],
164 ),
165 )))
166 }
167}