// datafusion_comet_spark_expr/datetime_funcs/date_trunc.rs
use arrow::array::new_null_array;
use arrow::datatypes::DataType;
use datafusion::common::{
    utils::take_function_args, DataFusionError, Result, ScalarValue, ScalarValue::Utf8,
};
use datafusion::logical_expr::{
    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};
use std::any::Any;

use crate::kernels::temporal::{date_trunc_array_fmt_dyn, date_trunc_dyn};

27#[derive(Debug)]
28pub struct SparkDateTrunc {
29    signature: Signature,
30    aliases: Vec<String>,
31}
32
33impl SparkDateTrunc {
34    pub fn new() -> Self {
35        Self {
36            signature: Signature::exact(
37                vec![DataType::Date32, DataType::Utf8],
38                Volatility::Immutable,
39            ),
40            aliases: vec![],
41        }
42    }
43}
44
45impl Default for SparkDateTrunc {
46    fn default() -> Self {
47        Self::new()
48    }
49}
50
51impl ScalarUDFImpl for SparkDateTrunc {
52    fn as_any(&self) -> &dyn Any {
53        self
54    }
55
56    fn name(&self) -> &str {
57        "date_trunc"
58    }
59
60    fn signature(&self) -> &Signature {
61        &self.signature
62    }
63
64    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
65        Ok(DataType::Date32)
66    }
67
68    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
69        let [date, format] = take_function_args(self.name(), args.args)?;
70        match (date, format) {
71            (ColumnarValue::Array(date), ColumnarValue::Scalar(Utf8(Some(format)))) => {
72                let result = date_trunc_dyn(&date, format)?;
73                Ok(ColumnarValue::Array(result))
74            }
75            (ColumnarValue::Array(date), ColumnarValue::Array(formats)) => {
76                let result = date_trunc_array_fmt_dyn(&date, &formats)?;
77                Ok(ColumnarValue::Array(result))
78            }
79            _ => Err(DataFusionError::Execution(
80                "Invalid input to function DateTrunc. Expected (PrimitiveArray<Date32>, Scalar) or \
81                    (PrimitiveArray<Date32>, StringArray)".to_string(),
82            )),
83        }
84    }
85
86    fn aliases(&self) -> &[String] {
87        &self.aliases
88    }
89}