// datafusion_comet_spark_expr/datetime_funcs/date_trunc.rs

use arrow::datatypes::DataType;
use datafusion::common::{
    utils::take_function_args,
    DataFusionError,
    Result,
    ScalarValue::{self, Utf8},
};
use datafusion::logical_expr::{
    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};
use std::any::Any;

use crate::kernels::temporal::{date_trunc_array_fmt_dyn, date_trunc_dyn};
26
27#[derive(Debug)]
28pub struct SparkDateTrunc {
29 signature: Signature,
30 aliases: Vec<String>,
31}
32
33impl SparkDateTrunc {
34 pub fn new() -> Self {
35 Self {
36 signature: Signature::exact(
37 vec![DataType::Date32, DataType::Utf8],
38 Volatility::Immutable,
39 ),
40 aliases: vec![],
41 }
42 }
43}
44
45impl Default for SparkDateTrunc {
46 fn default() -> Self {
47 Self::new()
48 }
49}
50
51impl ScalarUDFImpl for SparkDateTrunc {
52 fn as_any(&self) -> &dyn Any {
53 self
54 }
55
56 fn name(&self) -> &str {
57 "date_trunc"
58 }
59
60 fn signature(&self) -> &Signature {
61 &self.signature
62 }
63
64 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
65 Ok(DataType::Date32)
66 }
67
68 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
69 let [date, format] = take_function_args(self.name(), args.args)?;
70 match (date, format) {
71 (ColumnarValue::Array(date), ColumnarValue::Scalar(Utf8(Some(format)))) => {
72 let result = date_trunc_dyn(&date, format)?;
73 Ok(ColumnarValue::Array(result))
74 }
75 (ColumnarValue::Array(date), ColumnarValue::Array(formats)) => {
76 let result = date_trunc_array_fmt_dyn(&date, &formats)?;
77 Ok(ColumnarValue::Array(result))
78 }
79 _ => Err(DataFusionError::Execution(
80 "Invalid input to function DateTrunc. Expected (PrimitiveArray<Date32>, Scalar) or \
81 (PrimitiveArray<Date32>, StringArray)".to_string(),
82 )),
83 }
84 }
85
86 fn aliases(&self) -> &[String] {
87 &self.aliases
88 }
89}