datafusion_comet_spark_expr/
comet_scalar_funcs.rs1use crate::hash_funcs::*;
19use crate::{
20 spark_array_repeat, spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div,
21 spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal,
22 spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value,
23 SparkBitwiseCount, SparkBitwiseGet, SparkBitwiseNot, SparkChrFunc, SparkDateTrunc,
24};
25use arrow::datatypes::DataType;
26use datafusion::common::{DataFusionError, Result as DataFusionResult};
27use datafusion::execution::FunctionRegistry;
28use datafusion::logical_expr::{
29 ScalarFunctionArgs, ScalarFunctionImplementation, ScalarUDF, ScalarUDFImpl, Signature,
30 Volatility,
31};
32use datafusion::physical_plan::ColumnarValue;
33use std::any::Any;
34use std::fmt::Debug;
35use std::sync::Arc;
36
/// Builds an `Ok(Arc<ScalarUDF>)` wrapping a [`CometScalarFunction`] that accepts
/// any number of arguments of any type and is marked immutable.
///
/// Two invocation forms are supported:
/// * `make_comet_scalar_udf!(name, kernel_fn, data_type)`: `kernel_fn` is called as
///   `kernel_fn(args, &data_type)`, so it receives the return type as an extra argument.
/// * `make_comet_scalar_udf!(name, kernel_expr, without data_type)`: `kernel_expr` is
///   used directly as the implementation, and `data_type` only becomes the declared
///   return type.
macro_rules! make_comet_scalar_udf {
    ($name:expr, $func:ident, $data_type:ident) => {{
        // Clone for the declared return type; the original value is moved into the kernel.
        let return_type = $data_type.clone();
        let kernel: ScalarFunctionImplementation =
            Arc::new(move |args| $func(args, &$data_type));
        let udf_impl = CometScalarFunction::new(
            $name.to_string(),
            Signature::variadic_any(Volatility::Immutable),
            return_type,
            kernel,
        );
        Ok(Arc::new(ScalarUDF::new_from_impl(udf_impl)))
    }};
    ($name:expr, $func:expr, without $data_type:ident) => {{
        let kernel: ScalarFunctionImplementation = $func;
        let udf_impl = CometScalarFunction::new(
            $name.to_string(),
            Signature::variadic_any(Volatility::Immutable),
            $data_type,
            kernel,
        );
        Ok(Arc::new(ScalarUDF::new_from_impl(udf_impl)))
    }};
}
57
58pub fn create_comet_physical_fun(
60 fun_name: &str,
61 data_type: DataType,
62 registry: &dyn FunctionRegistry,
63) -> Result<Arc<ScalarUDF>, DataFusionError> {
64 match fun_name {
65 "ceil" => {
66 make_comet_scalar_udf!("ceil", spark_ceil, data_type)
67 }
68 "floor" => {
69 make_comet_scalar_udf!("floor", spark_floor, data_type)
70 }
71 "read_side_padding" => {
72 let func = Arc::new(spark_read_side_padding);
73 make_comet_scalar_udf!("read_side_padding", func, without data_type)
74 }
75 "rpad" => {
76 let func = Arc::new(spark_rpad);
77 make_comet_scalar_udf!("rpad", func, without data_type)
78 }
79 "round" => {
80 make_comet_scalar_udf!("round", spark_round, data_type)
81 }
82 "unscaled_value" => {
83 let func = Arc::new(spark_unscaled_value);
84 make_comet_scalar_udf!("unscaled_value", func, without data_type)
85 }
86 "make_decimal" => {
87 make_comet_scalar_udf!("make_decimal", spark_make_decimal, data_type)
88 }
89 "hex" => {
90 let func = Arc::new(spark_hex);
91 make_comet_scalar_udf!("hex", func, without data_type)
92 }
93 "unhex" => {
94 let func = Arc::new(spark_unhex);
95 make_comet_scalar_udf!("unhex", func, without data_type)
96 }
97 "decimal_div" => {
98 make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type)
99 }
100 "decimal_integral_div" => {
101 make_comet_scalar_udf!(
102 "decimal_integral_div",
103 spark_decimal_integral_div,
104 data_type
105 )
106 }
107 "murmur3_hash" => {
108 let func = Arc::new(spark_murmur3_hash);
109 make_comet_scalar_udf!("murmur3_hash", func, without data_type)
110 }
111 "xxhash64" => {
112 let func = Arc::new(spark_xxhash64);
113 make_comet_scalar_udf!("xxhash64", func, without data_type)
114 }
115 "isnan" => {
116 let func = Arc::new(spark_isnan);
117 make_comet_scalar_udf!("isnan", func, without data_type)
118 }
119 "sha224" => {
120 let func = Arc::new(spark_sha224);
121 make_comet_scalar_udf!("sha224", func, without data_type)
122 }
123 "sha256" => {
124 let func = Arc::new(spark_sha256);
125 make_comet_scalar_udf!("sha256", func, without data_type)
126 }
127 "sha384" => {
128 let func = Arc::new(spark_sha384);
129 make_comet_scalar_udf!("sha384", func, without data_type)
130 }
131 "sha512" => {
132 let func = Arc::new(spark_sha512);
133 make_comet_scalar_udf!("sha512", func, without data_type)
134 }
135 "date_add" => {
136 let func = Arc::new(spark_date_add);
137 make_comet_scalar_udf!("date_add", func, without data_type)
138 }
139 "date_sub" => {
140 let func = Arc::new(spark_date_sub);
141 make_comet_scalar_udf!("date_sub", func, without data_type)
142 }
143 "array_repeat" => {
144 let func = Arc::new(spark_array_repeat);
145 make_comet_scalar_udf!("array_repeat", func, without data_type)
146 }
147 _ => registry.udf(fun_name).map_err(|e| {
148 DataFusionError::Execution(format!(
149 "Function {fun_name} not found in the registry: {e}",
150 ))
151 }),
152 }
153}
154
155fn all_scalar_functions() -> Vec<Arc<ScalarUDF>> {
156 vec![
157 Arc::new(ScalarUDF::new_from_impl(SparkChrFunc::default())),
158 Arc::new(ScalarUDF::new_from_impl(SparkBitwiseNot::default())),
159 Arc::new(ScalarUDF::new_from_impl(SparkBitwiseCount::default())),
160 Arc::new(ScalarUDF::new_from_impl(SparkBitwiseGet::default())),
161 Arc::new(ScalarUDF::new_from_impl(SparkDateTrunc::default())),
162 ]
163}
164
165pub fn register_all_comet_functions(registry: &mut dyn FunctionRegistry) -> DataFusionResult<()> {
167 all_scalar_functions()
169 .into_iter()
170 .try_for_each(|udf| registry.register_udf(udf).map(|_| ()))?;
171
172 Ok(())
173}
174
175struct CometScalarFunction {
176 name: String,
177 signature: Signature,
178 data_type: DataType,
179 func: ScalarFunctionImplementation,
180}
181
182impl Debug for CometScalarFunction {
183 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
184 f.debug_struct("CometScalarFunction")
185 .field("name", &self.name)
186 .field("signature", &self.signature)
187 .field("data_type", &self.data_type)
188 .finish()
189 }
190}
191
192impl CometScalarFunction {
193 fn new(
194 name: String,
195 signature: Signature,
196 data_type: DataType,
197 func: ScalarFunctionImplementation,
198 ) -> Self {
199 Self {
200 name,
201 signature,
202 data_type,
203 func,
204 }
205 }
206}
207
208impl ScalarUDFImpl for CometScalarFunction {
209 fn as_any(&self) -> &dyn Any {
210 self
211 }
212
213 fn name(&self) -> &str {
214 self.name.as_str()
215 }
216
217 fn signature(&self) -> &Signature {
218 &self.signature
219 }
220
221 fn return_type(&self, _: &[DataType]) -> DataFusionResult<DataType> {
222 Ok(self.data_type.clone())
223 }
224
225 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult<ColumnarValue> {
226 (self.func)(&args.args)
227 }
228}