// datafusion_comet_spark_expr/comet_scalar_funcs.rs

use crate::hash_funcs::*;
use crate::map_funcs::spark_map_sort;
use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mul, checked_sub};
use crate::math_funcs::modulo_expr::spark_modulo;
use crate::{
    spark_array_repeat, spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div,
    spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal,
    spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value,
    SparkBitwiseCount, SparkBitwiseGet, SparkBitwiseNot, SparkDateTrunc, SparkStringSpace,
};
use arrow::datatypes::DataType;
use datafusion::common::{DataFusionError, Result as DataFusionResult};
use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::{
    ScalarFunctionArgs, ScalarFunctionImplementation, ScalarUDF, ScalarUDFImpl, Signature,
    Volatility,
};
use datafusion::physical_plan::ColumnarValue;
use std::any::Any;
use std::fmt::Debug;
use std::sync::Arc;
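/// Builds a `CometScalarFunction` and wraps it in a [`ScalarUDF`]. The first arm
/// passes the return `DataType` through to the wrapped function, the `without`
/// arm uses the supplied implementation as-is, and the last arm additionally
/// forwards a `fail_on_error` flag to the wrapped function.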
macro_rules! make_comet_scalar_udf {
    ($name:expr, $func:ident, $data_type:ident) => {{
        let scalar_func = CometScalarFunction::new(
            $name.to_string(),
            Signature::variadic_any(Volatility::Immutable),
            $data_type.clone(),
            Arc::new(move |args| $func(args, &$data_type)),
        );
        Ok(Arc::new(ScalarUDF::new_from_impl(scalar_func)))
    }};
    ($name:expr, $func:expr, without $data_type:ident) => {{
        let scalar_func = CometScalarFunction::new(
            $name.to_string(),
            Signature::variadic_any(Volatility::Immutable),
            $data_type,
            $func,
        );
        Ok(Arc::new(ScalarUDF::new_from_impl(scalar_func)))
    }};
    ($name:expr, $func:ident, without $data_type:ident, $fail_on_error:ident) => {{
        let scalar_func = CometScalarFunction::new(
            $name.to_string(),
            Signature::variadic_any(Volatility::Immutable),
            $data_type,
            Arc::new(move |args| $func(args, $fail_on_error)),
        );
        Ok(Arc::new(ScalarUDF::new_from_impl(scalar_func)))
    }};
}
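/// Resolves a Comet scalar function by name, returning a [`ScalarUDF`] built from
/// the matching Spark-compatible implementation. Names that are not handled here
/// fall back to a lookup in the supplied [`FunctionRegistry`].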
pub fn create_comet_physical_fun(
    fun_name: &str,
    data_type: DataType,
    registry: &dyn FunctionRegistry,
    fail_on_error: Option<bool>,
) -> Result<Arc<ScalarUDF>, DataFusionError> {
    match fun_name {
        "ceil" => {
            make_comet_scalar_udf!("ceil", spark_ceil, data_type)
        }
        "floor" => {
            make_comet_scalar_udf!("floor", spark_floor, data_type)
        }
        "read_side_padding" => {
            let func = Arc::new(spark_read_side_padding);
            make_comet_scalar_udf!("read_side_padding", func, without data_type)
        }
        "rpad" => {
            let func = Arc::new(spark_rpad);
            make_comet_scalar_udf!("rpad", func, without data_type)
        }
        "round" => {
            make_comet_scalar_udf!("round", spark_round, data_type)
        }
        "unscaled_value" => {
            let func = Arc::new(spark_unscaled_value);
            make_comet_scalar_udf!("unscaled_value", func, without data_type)
        }
        "make_decimal" => {
            make_comet_scalar_udf!("make_decimal", spark_make_decimal, data_type)
        }
        "hex" => {
            let func = Arc::new(spark_hex);
            make_comet_scalar_udf!("hex", func, without data_type)
        }
        "unhex" => {
            let func = Arc::new(spark_unhex);
            make_comet_scalar_udf!("unhex", func, without data_type)
        }
        "decimal_div" => {
            make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type)
        }
        "decimal_integral_div" => {
            make_comet_scalar_udf!(
                "decimal_integral_div",
                spark_decimal_integral_div,
                data_type
            )
        }
        "checked_add" => {
            make_comet_scalar_udf!("checked_add", checked_add, data_type)
        }
        "checked_sub" => {
            make_comet_scalar_udf!("checked_sub", checked_sub, data_type)
        }
        "checked_mul" => {
            make_comet_scalar_udf!("checked_mul", checked_mul, data_type)
        }
        "checked_div" => {
            make_comet_scalar_udf!("checked_div", checked_div, data_type)
        }
        "murmur3_hash" => {
            let func = Arc::new(spark_murmur3_hash);
            make_comet_scalar_udf!("murmur3_hash", func, without data_type)
        }
        "xxhash64" => {
            let func = Arc::new(spark_xxhash64);
            make_comet_scalar_udf!("xxhash64", func, without data_type)
        }
        "isnan" => {
            let func = Arc::new(spark_isnan);
            make_comet_scalar_udf!("isnan", func, without data_type)
        }
        "date_add" => {
            let func = Arc::new(spark_date_add);
            make_comet_scalar_udf!("date_add", func, without data_type)
        }
        "date_sub" => {
            let func = Arc::new(spark_date_sub);
            make_comet_scalar_udf!("date_sub", func, without data_type)
        }
        "array_repeat" => {
            let func = Arc::new(spark_array_repeat);
            make_comet_scalar_udf!("array_repeat", func, without data_type)
        }
        "spark_modulo" => {
            let func = Arc::new(spark_modulo);
            let fail_on_error = fail_on_error.unwrap_or(false);
            make_comet_scalar_udf!("spark_modulo", func, without data_type, fail_on_error)
        }
        "map_sort" => {
            let func = Arc::new(spark_map_sort);
            make_comet_scalar_udf!("spark_map_sort", func, without data_type)
        }
        _ => registry.udf(fun_name).map_err(|e| {
            DataFusionError::Execution(format!(
                "Function {fun_name} not found in the registry: {e}"
            ))
        }),
    }
}
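/// The Comet scalar UDFs that are implemented as standalone [`ScalarUDFImpl`]
/// types rather than through `make_comet_scalar_udf!`.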
fn all_scalar_functions() -> Vec<Arc<ScalarUDF>> {
    vec![
        Arc::new(ScalarUDF::new_from_impl(SparkBitwiseNot::default())),
        Arc::new(ScalarUDF::new_from_impl(SparkBitwiseCount::default())),
        Arc::new(ScalarUDF::new_from_impl(SparkBitwiseGet::default())),
        Arc::new(ScalarUDF::new_from_impl(SparkDateTrunc::default())),
        Arc::new(ScalarUDF::new_from_impl(SparkStringSpace::default())),
    ]
}
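/// Registers every UDF from [`all_scalar_functions`] with the given
/// [`FunctionRegistry`].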
pub fn register_all_comet_functions(registry: &mut dyn FunctionRegistry) -> DataFusionResult<()> {
    all_scalar_functions()
        .into_iter()
        .try_for_each(|udf| registry.register_udf(udf).map(|_| ()))?;

    Ok(())
}
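/// A [`ScalarUDFImpl`] that pairs a scalar function implementation with the
/// name, signature, and return type it reports to DataFusion.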
struct CometScalarFunction {
    name: String,
    signature: Signature,
    data_type: DataType,
    func: ScalarFunctionImplementation,
}
impl Debug for CometScalarFunction {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CometScalarFunction")
            .field("name", &self.name)
            .field("signature", &self.signature)
            .field("data_type", &self.data_type)
            .finish()
    }
}

impl CometScalarFunction {
    fn new(
        name: String,
        signature: Signature,
        data_type: DataType,
        func: ScalarFunctionImplementation,
    ) -> Self {
        Self {
            name,
            signature,
            data_type,
            func,
        }
    }
}

impl ScalarUDFImpl for CometScalarFunction {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        self.name.as_str()
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _: &[DataType]) -> DataFusionResult<DataType> {
        Ok(self.data_type.clone())
    }

    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult<ColumnarValue> {
        (self.func)(&args.args)
    }
}
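
// A minimal sketch of how the pieces above fit together, not authoritative test
// coverage from this module. It assumes that
// `datafusion::execution::registry::MemoryFunctionRegistry` is available in the
// DataFusion version in use; any in-memory `FunctionRegistry` would do.
#[cfg(test)]
mod usage_sketch {
    use super::*;
    use datafusion::execution::registry::MemoryFunctionRegistry;

    #[test]
    fn resolve_ceil_udf() -> DataFusionResult<()> {
        // Register the standalone Comet UDFs (bitwise_not, date_trunc, ...).
        let mut registry = MemoryFunctionRegistry::new();
        register_all_comet_functions(&mut registry)?;

        // Names handled by `create_comet_physical_fun` are built directly from the
        // Spark-compatible implementations; unknown names fall back to a registry
        // lookup.
        let udf = create_comet_physical_fun("ceil", DataType::Int64, &registry, None)?;
        assert_eq!(udf.name(), "ceil");
        Ok(())
    }
}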