datafusion_spark/function/datetime/
date_add.rs1use std::sync::Arc;
19
20use arrow::array::ArrayRef;
21use arrow::compute;
22use arrow::datatypes::{DataType, Date32Type, Field, FieldRef};
23use datafusion_common::cast::{
24 as_date32_array, as_int8_array, as_int16_array, as_int32_array,
25};
26use datafusion_common::utils::take_function_args;
27use datafusion_common::{Result, internal_err};
28use datafusion_expr::{
29 ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
30 TypeSignature, Volatility,
31};
32use datafusion_functions::utils::make_scalar_function;
33
34#[derive(Debug, PartialEq, Eq, Hash)]
35pub struct SparkDateAdd {
36 signature: Signature,
37 aliases: Vec<String>,
38}
39
40impl Default for SparkDateAdd {
41 fn default() -> Self {
42 Self::new()
43 }
44}
45
46impl SparkDateAdd {
47 pub fn new() -> Self {
48 Self {
49 signature: Signature::one_of(
50 vec![
51 TypeSignature::Exact(vec![DataType::Date32, DataType::Int8]),
52 TypeSignature::Exact(vec![DataType::Date32, DataType::Int16]),
53 TypeSignature::Exact(vec![DataType::Date32, DataType::Int32]),
54 ],
55 Volatility::Immutable,
56 ),
57 aliases: vec!["dateadd".to_string()],
58 }
59 }
60}
61
62impl ScalarUDFImpl for SparkDateAdd {
63 fn name(&self) -> &str {
64 "date_add"
65 }
66
67 fn aliases(&self) -> &[String] {
68 &self.aliases
69 }
70
71 fn signature(&self) -> &Signature {
72 &self.signature
73 }
74
75 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
76 internal_err!("Use return_field_from_args in this case instead.")
77 }
78
79 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
80 let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
81 Ok(Arc::new(Field::new(
82 self.name(),
83 DataType::Date32,
84 nullable,
85 )))
86 }
87
88 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
89 make_scalar_function(spark_date_add, vec![])(&args.args)
90 }
91}
92
93fn spark_date_add(args: &[ArrayRef]) -> Result<ArrayRef> {
94 let [date_arg, days_arg] = take_function_args("date_add", args)?;
95 let date_array = as_date32_array(date_arg)?;
96 let result = match days_arg.data_type() {
97 DataType::Int8 => {
98 let days_array = as_int8_array(days_arg)?;
99 compute::binary::<_, _, _, Date32Type>(
100 date_array,
101 days_array,
102 |date, days| date.wrapping_add(days as i32),
103 )?
104 }
105 DataType::Int16 => {
106 let days_array = as_int16_array(days_arg)?;
107 compute::binary::<_, _, _, Date32Type>(
108 date_array,
109 days_array,
110 |date, days| date.wrapping_add(days as i32),
111 )?
112 }
113 DataType::Int32 => {
114 let days_array = as_int32_array(days_arg)?;
115 compute::binary::<_, _, _, Date32Type>(
116 date_array,
117 days_array,
118 |date, days| date.wrapping_add(days),
119 )?
120 }
121 _ => {
122 return internal_err!(
123 "Spark `date_add` function: argument must be int8, int16, int32, got {:?}",
124 days_arg.data_type()
125 );
126 }
127 };
128 Ok(Arc::new(result))
129}
130
#[cfg(test)]
mod tests {
    use super::*;

    /// Resolves the output field of `SparkDateAdd` for a (date, days) argument pair.
    fn resolve_return_field(date: Field, days: Field) -> FieldRef {
        let arg_fields = [Arc::new(date), Arc::new(days)];
        SparkDateAdd::new()
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &arg_fields,
                scalar_arguments: &[None, None],
            })
            .unwrap()
    }

    #[test]
    fn test_date_add_non_nullable_inputs() {
        let field = resolve_return_field(
            Field::new("date", DataType::Date32, false),
            Field::new("num", DataType::Int8, false),
        );

        assert_eq!(field.data_type(), &DataType::Date32);
        assert!(!field.is_nullable());
    }

    #[test]
    fn test_date_add_nullable_inputs() {
        let field = resolve_return_field(
            Field::new("date", DataType::Date32, false),
            Field::new("num", DataType::Int16, true),
        );

        assert_eq!(field.data_type(), &DataType::Date32);
        assert!(field.is_nullable());
    }
}