datafusion_spark/function/datetime/
trunc.rs1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::datatypes::{DataType, Field, FieldRef, TimeUnit};
22use datafusion_common::types::{NativeType, logical_date, logical_string};
23use datafusion_common::utils::take_function_args;
24use datafusion_common::{Result, ScalarValue, internal_err, plan_err};
25use datafusion_expr::expr::ScalarFunction;
26use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
27use datafusion_expr::{
28 Coercion, ColumnarValue, Expr, ExprSchemable, ReturnFieldArgs, ScalarFunctionArgs,
29 ScalarUDFImpl, Signature, TypeSignatureClass, Volatility,
30};
31
32#[derive(Debug, PartialEq, Eq, Hash)]
36pub struct SparkTrunc {
37 signature: Signature,
38}
39
40impl Default for SparkTrunc {
41 fn default() -> Self {
42 Self::new()
43 }
44}
45
46impl SparkTrunc {
47 pub fn new() -> Self {
48 Self {
49 signature: Signature::coercible(
50 vec![
51 Coercion::new_implicit(
52 TypeSignatureClass::Native(logical_date()),
53 vec![TypeSignatureClass::Native(logical_string())],
54 NativeType::Date,
55 ),
56 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
57 ],
58 Volatility::Immutable,
59 ),
60 }
61 }
62}
63
64impl ScalarUDFImpl for SparkTrunc {
65 fn as_any(&self) -> &dyn Any {
66 self
67 }
68
69 fn name(&self) -> &str {
70 "trunc"
71 }
72
73 fn signature(&self) -> &Signature {
74 &self.signature
75 }
76
77 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
78 internal_err!("return_field_from_args should be used instead")
79 }
80
81 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
82 let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
83
84 Ok(Arc::new(Field::new(
85 self.name(),
86 args.arg_fields[0].data_type().clone(),
87 nullable,
88 )))
89 }
90
91 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
92 internal_err!("spark trunc should have been simplified to standard date_trunc")
93 }
94
95 fn simplify(
96 &self,
97 args: Vec<Expr>,
98 info: &SimplifyContext,
99 ) -> Result<ExprSimplifyResult> {
100 let [dt_expr, fmt_expr] = take_function_args(self.name(), args)?;
101
102 let fmt = match fmt_expr.as_literal() {
103 Some(ScalarValue::Utf8(Some(v)))
104 | Some(ScalarValue::Utf8View(Some(v)))
105 | Some(ScalarValue::LargeUtf8(Some(v))) => v.to_lowercase(),
106 _ => {
107 return plan_err!(
108 "Second argument of `TRUNC` must be non-null scalar Utf8"
109 );
110 }
111 };
112
113 let fmt = match fmt.as_str() {
115 "yy" | "yyyy" => "year",
116 "mm" | "mon" => "month",
117 "year" | "month" | "day" | "week" | "quarter" => fmt.as_str(),
118 _ => {
119 return plan_err!(
120 "The format argument of `TRUNC` must be one of: year, yy, yyyy, month, mm, mon, day, week, quarter."
121 );
122 }
123 };
124 let return_type = dt_expr.get_type(info.schema())?;
125
126 let fmt_expr = Expr::Literal(ScalarValue::new_utf8(fmt), None);
127
128 Ok(ExprSimplifyResult::Simplified(
130 Expr::ScalarFunction(ScalarFunction::new_udf(
131 datafusion_functions::datetime::date_trunc(),
132 vec![
133 fmt_expr,
134 dt_expr.cast_to(
135 &DataType::Timestamp(TimeUnit::Nanosecond, None),
136 info.schema(),
137 )?,
138 ],
139 ))
140 .cast_to(&return_type, info.schema())?,
141 ))
142 }
143}