datafusion_spark/function/datetime/
time_trunc.rs1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::datatypes::{DataType, Field, FieldRef};
22use datafusion_common::types::logical_string;
23use datafusion_common::{Result, ScalarValue, internal_err, plan_err};
24use datafusion_expr::expr::ScalarFunction;
25use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
26use datafusion_expr::{
27 Coercion, ColumnarValue, Expr, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl,
28 Signature, TypeSignatureClass, Volatility,
29};
30
31#[derive(Debug, PartialEq, Eq, Hash)]
34pub struct SparkTimeTrunc {
35 signature: Signature,
36}
37
38impl Default for SparkTimeTrunc {
39 fn default() -> Self {
40 Self::new()
41 }
42}
43
44impl SparkTimeTrunc {
45 pub fn new() -> Self {
46 Self {
47 signature: Signature::coercible(
48 vec![
49 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
50 Coercion::new_exact(TypeSignatureClass::Time),
51 ],
52 Volatility::Immutable,
53 ),
54 }
55 }
56}
57
58impl ScalarUDFImpl for SparkTimeTrunc {
59 fn as_any(&self) -> &dyn Any {
60 self
61 }
62
63 fn name(&self) -> &str {
64 "time_trunc"
65 }
66
67 fn signature(&self) -> &Signature {
68 &self.signature
69 }
70
71 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
72 internal_err!("return_field_from_args should be used instead")
73 }
74
75 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
76 let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
77
78 Ok(Arc::new(Field::new(
79 self.name(),
80 args.arg_fields[1].data_type().clone(),
81 nullable,
82 )))
83 }
84
85 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
86 internal_err!(
87 "spark time_trunc should have been simplified to standard date_trunc"
88 )
89 }
90
91 fn simplify(
92 &self,
93 args: Vec<Expr>,
94 _info: &SimplifyContext,
95 ) -> Result<ExprSimplifyResult> {
96 let fmt_expr = &args[0];
97
98 let fmt = match fmt_expr.as_literal() {
99 Some(ScalarValue::Utf8(Some(v)))
100 | Some(ScalarValue::Utf8View(Some(v)))
101 | Some(ScalarValue::LargeUtf8(Some(v))) => v.to_lowercase(),
102 _ => {
103 return plan_err!(
104 "First argument of `TIME_TRUNC` must be non-null scalar Utf8"
105 );
106 }
107 };
108
109 if !matches!(
110 fmt.as_str(),
111 "hour" | "minute" | "second" | "millisecond" | "microsecond"
112 ) {
113 return plan_err!(
114 "The format argument of `TIME_TRUNC` must be one of: hour, minute, second, millisecond, microsecond"
115 );
116 }
117
118 Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction(
119 ScalarFunction::new_udf(datafusion_functions::datetime::date_trunc(), args),
120 )))
121 }
122}