datafusion_spark/function/datetime/
trunc.rs1use std::sync::Arc;
19
20use arrow::datatypes::{DataType, Field, FieldRef, TimeUnit};
21use datafusion_common::types::{NativeType, logical_date, logical_string};
22use datafusion_common::utils::take_function_args;
23use datafusion_common::{Result, ScalarValue, internal_err, plan_err};
24use datafusion_expr::expr::ScalarFunction;
25use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
26use datafusion_expr::{
27 Coercion, ColumnarValue, Expr, ExprSchemable, ReturnFieldArgs, ScalarFunctionArgs,
28 ScalarUDFImpl, Signature, TypeSignatureClass, Volatility,
29};
30
31#[derive(Debug, PartialEq, Eq, Hash)]
35pub struct SparkTrunc {
36 signature: Signature,
37}
38
39impl Default for SparkTrunc {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45impl SparkTrunc {
46 pub fn new() -> Self {
47 Self {
48 signature: Signature::coercible(
49 vec![
50 Coercion::new_implicit(
51 TypeSignatureClass::Native(logical_date()),
52 vec![TypeSignatureClass::Native(logical_string())],
53 NativeType::Date,
54 ),
55 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
56 ],
57 Volatility::Immutable,
58 ),
59 }
60 }
61}
62
63impl ScalarUDFImpl for SparkTrunc {
64 fn name(&self) -> &str {
65 "trunc"
66 }
67
68 fn signature(&self) -> &Signature {
69 &self.signature
70 }
71
72 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
73 internal_err!("return_field_from_args should be used instead")
74 }
75
76 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
77 let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
78
79 Ok(Arc::new(Field::new(
80 self.name(),
81 args.arg_fields[0].data_type().clone(),
82 nullable,
83 )))
84 }
85
86 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
87 internal_err!("spark trunc should have been simplified to standard date_trunc")
88 }
89
90 fn simplify(
91 &self,
92 args: Vec<Expr>,
93 info: &SimplifyContext,
94 ) -> Result<ExprSimplifyResult> {
95 let [dt_expr, fmt_expr] = take_function_args(self.name(), args)?;
96
97 let fmt = match fmt_expr.as_literal() {
98 Some(ScalarValue::Utf8(Some(v)))
99 | Some(ScalarValue::Utf8View(Some(v)))
100 | Some(ScalarValue::LargeUtf8(Some(v))) => v.to_lowercase(),
101 _ => {
102 return plan_err!(
103 "Second argument of `TRUNC` must be non-null scalar Utf8"
104 );
105 }
106 };
107
108 let fmt = match fmt.as_str() {
110 "yy" | "yyyy" => "year",
111 "mm" | "mon" => "month",
112 "year" | "month" | "day" | "week" | "quarter" => fmt.as_str(),
113 _ => {
114 return plan_err!(
115 "The format argument of `TRUNC` must be one of: year, yy, yyyy, month, mm, mon, day, week, quarter."
116 );
117 }
118 };
119 let return_type = dt_expr.get_type(info.schema())?;
120
121 let fmt_expr = Expr::Literal(ScalarValue::new_utf8(fmt), None);
122
123 Ok(ExprSimplifyResult::Simplified(
125 Expr::ScalarFunction(ScalarFunction::new_udf(
126 datafusion_functions::datetime::date_trunc(),
127 vec![
128 fmt_expr,
129 dt_expr.cast_to(
130 &DataType::Timestamp(TimeUnit::Nanosecond, None),
131 info.schema(),
132 )?,
133 ],
134 ))
135 .cast_to(&return_type, info.schema())?,
136 ))
137 }
138}