// datafusion_comet_spark_expr/comet_scalar_funcs.rs
use crate::hash_funcs::*;
19use crate::{
20 spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div, spark_decimal_integral_div,
21 spark_floor, spark_hex, spark_isnan, spark_make_decimal, spark_read_side_padding, spark_round,
22 spark_rpad, spark_unhex, spark_unscaled_value, SparkChrFunc,
23};
24use arrow_schema::DataType;
25use datafusion_common::{DataFusionError, Result as DataFusionResult};
26use datafusion_expr::registry::FunctionRegistry;
27use datafusion_expr::{
28 ColumnarValue, ScalarFunctionImplementation, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
29};
30use std::any::Any;
31use std::fmt::Debug;
32use std::sync::Arc;
33
/// Wraps a scalar function implementation in a `CometScalarFunction` and
/// evaluates to `Ok(Arc<ScalarUDF>)`.
///
/// Both forms register the function with a `variadic_any` signature and
/// `Volatility::Immutable`, and report `$data_type` as the return type.
///
/// * `make_comet_scalar_udf!(name, func, data_type)` — `func` is called as
///   `func(args, &data_type)`, i.e. it receives the return type as a second
///   argument.
/// * `make_comet_scalar_udf!(name, func, without data_type)` — `func` is an
///   already-built `ScalarFunctionImplementation` and is used as-is.
macro_rules! make_comet_scalar_udf {
    ($name:expr, $func:ident, $data_type:ident) => {{
        // Clone before building the closure: the closure takes ownership of
        // `$data_type`, while the UDF needs its own copy for `return_type`.
        let return_type = $data_type.clone();
        let implementation: ScalarFunctionImplementation =
            Arc::new(move |args| $func(args, &$data_type));
        let udf_impl = CometScalarFunction::new(
            $name.to_string(),
            Signature::variadic_any(Volatility::Immutable),
            return_type,
            implementation,
        );
        Ok(Arc::new(ScalarUDF::new_from_impl(udf_impl)))
    }};
    ($name:expr, $func:expr, without $data_type:ident) => {{
        Ok(Arc::new(ScalarUDF::new_from_impl(CometScalarFunction::new(
            $name.to_string(),
            Signature::variadic_any(Volatility::Immutable),
            $data_type,
            $func,
        ))))
    }};
}
54
55pub fn create_comet_physical_fun(
57 fun_name: &str,
58 data_type: DataType,
59 registry: &dyn FunctionRegistry,
60) -> Result<Arc<ScalarUDF>, DataFusionError> {
61 match fun_name {
62 "ceil" => {
63 make_comet_scalar_udf!("ceil", spark_ceil, data_type)
64 }
65 "floor" => {
66 make_comet_scalar_udf!("floor", spark_floor, data_type)
67 }
68 "read_side_padding" => {
69 let func = Arc::new(spark_read_side_padding);
70 make_comet_scalar_udf!("read_side_padding", func, without data_type)
71 }
72 "rpad" => {
73 let func = Arc::new(spark_rpad);
74 make_comet_scalar_udf!("rpad", func, without data_type)
75 }
76 "round" => {
77 make_comet_scalar_udf!("round", spark_round, data_type)
78 }
79 "unscaled_value" => {
80 let func = Arc::new(spark_unscaled_value);
81 make_comet_scalar_udf!("unscaled_value", func, without data_type)
82 }
83 "make_decimal" => {
84 make_comet_scalar_udf!("make_decimal", spark_make_decimal, data_type)
85 }
86 "hex" => {
87 let func = Arc::new(spark_hex);
88 make_comet_scalar_udf!("hex", func, without data_type)
89 }
90 "unhex" => {
91 let func = Arc::new(spark_unhex);
92 make_comet_scalar_udf!("unhex", func, without data_type)
93 }
94 "decimal_div" => {
95 make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type)
96 }
97 "decimal_integral_div" => {
98 make_comet_scalar_udf!(
99 "decimal_integral_div",
100 spark_decimal_integral_div,
101 data_type
102 )
103 }
104 "murmur3_hash" => {
105 let func = Arc::new(spark_murmur3_hash);
106 make_comet_scalar_udf!("murmur3_hash", func, without data_type)
107 }
108 "xxhash64" => {
109 let func = Arc::new(spark_xxhash64);
110 make_comet_scalar_udf!("xxhash64", func, without data_type)
111 }
112 "chr" => Ok(Arc::new(ScalarUDF::new_from_impl(SparkChrFunc::default()))),
113 "isnan" => {
114 let func = Arc::new(spark_isnan);
115 make_comet_scalar_udf!("isnan", func, without data_type)
116 }
117 "sha224" => {
118 let func = Arc::new(spark_sha224);
119 make_comet_scalar_udf!("sha224", func, without data_type)
120 }
121 "sha256" => {
122 let func = Arc::new(spark_sha256);
123 make_comet_scalar_udf!("sha256", func, without data_type)
124 }
125 "sha384" => {
126 let func = Arc::new(spark_sha384);
127 make_comet_scalar_udf!("sha384", func, without data_type)
128 }
129 "sha512" => {
130 let func = Arc::new(spark_sha512);
131 make_comet_scalar_udf!("sha512", func, without data_type)
132 }
133 "date_add" => {
134 let func = Arc::new(spark_date_add);
135 make_comet_scalar_udf!("date_add", func, without data_type)
136 }
137 "date_sub" => {
138 let func = Arc::new(spark_date_sub);
139 make_comet_scalar_udf!("date_sub", func, without data_type)
140 }
141 _ => registry.udf(fun_name).map_err(|e| {
142 DataFusionError::Execution(format!(
143 "Function {fun_name} not found in the registry: {e}",
144 ))
145 }),
146 }
147}
148
149struct CometScalarFunction {
150 name: String,
151 signature: Signature,
152 data_type: DataType,
153 func: ScalarFunctionImplementation,
154}
155
156impl Debug for CometScalarFunction {
157 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158 f.debug_struct("CometScalarFunction")
159 .field("name", &self.name)
160 .field("signature", &self.signature)
161 .field("data_type", &self.data_type)
162 .finish()
163 }
164}
165
166impl CometScalarFunction {
167 fn new(
168 name: String,
169 signature: Signature,
170 data_type: DataType,
171 func: ScalarFunctionImplementation,
172 ) -> Self {
173 Self {
174 name,
175 signature,
176 data_type,
177 func,
178 }
179 }
180}
181
182impl ScalarUDFImpl for CometScalarFunction {
183 fn as_any(&self) -> &dyn Any {
184 self
185 }
186
187 fn name(&self) -> &str {
188 self.name.as_str()
189 }
190
191 fn signature(&self) -> &Signature {
192 &self.signature
193 }
194
195 fn return_type(&self, _: &[DataType]) -> DataFusionResult<DataType> {
196 Ok(self.data_type.clone())
197 }
198
199 fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
200 (self.func)(args)
201 }
202}