datafusion_comet_spark_expr/comet_scalar_funcs.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::hash_funcs::*;
use crate::{
    spark_array_repeat, spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div,
    spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal,
    spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value,
    SparkBitwiseCount, SparkBitwiseGet, SparkBitwiseNot, SparkChrFunc, SparkDateTrunc,
};
use arrow::datatypes::DataType;
use datafusion::common::{DataFusionError, Result as DataFusionResult};
use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::{
    ScalarFunctionArgs, ScalarFunctionImplementation, ScalarUDF, ScalarUDFImpl, Signature,
    Volatility,
};
use datafusion::physical_plan::ColumnarValue;
use std::any::Any;
use std::fmt::Debug;
use std::sync::Arc;

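// Wraps a Spark-compatible function as a DataFusion `ScalarUDF` with a
// variadic, immutable signature. The first arm is for implementations that
// need the result `DataType` threaded through as their last argument (it is
// captured in a closure), e.g. `make_comet_scalar_udf!("ceil", spark_ceil,
// data_type)`. The second (`without`) arm is for implementations that only
// need the type recorded as the UDF's return type, e.g.
// `make_comet_scalar_udf!("hex", func, without data_type)`, where `func` is
// the `Arc`-wrapped implementation.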
macro_rules! make_comet_scalar_udf {
    ($name:expr, $func:ident, $data_type:ident) => {{
        let scalar_func = CometScalarFunction::new(
            $name.to_string(),
            Signature::variadic_any(Volatility::Immutable),
            $data_type.clone(),
            Arc::new(move |args| $func(args, &$data_type)),
        );
        Ok(Arc::new(ScalarUDF::new_from_impl(scalar_func)))
    }};
    ($name:expr, $func:expr, without $data_type:ident) => {{
        let scalar_func = CometScalarFunction::new(
            $name.to_string(),
            Signature::variadic_any(Volatility::Immutable),
            $data_type,
            $func,
        );
        Ok(Arc::new(ScalarUDF::new_from_impl(scalar_func)))
    }};
}

/// Create a physical scalar function for the given Comet function name and
/// return type, falling back to a lookup in `registry` when the name is not a
/// Comet built-in.
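///
/// A minimal usage sketch (not compiled as a doctest); it assumes a
/// `registry: &dyn FunctionRegistry` is in scope, for example a DataFusion
/// `SessionState`, and that `Int64` is the desired result type:
///
/// ```ignore
/// use arrow::datatypes::DataType;
///
/// // Resolve the Spark-compatible `ceil` with an Int64 result type.
/// let ceil = create_comet_physical_fun("ceil", DataType::Int64, registry)?;
/// assert_eq!(ceil.name(), "ceil");
/// ```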
pub fn create_comet_physical_fun(
    fun_name: &str,
    data_type: DataType,
    registry: &dyn FunctionRegistry,
) -> Result<Arc<ScalarUDF>, DataFusionError> {
    match fun_name {
        "ceil" => {
            make_comet_scalar_udf!("ceil", spark_ceil, data_type)
        }
        "floor" => {
            make_comet_scalar_udf!("floor", spark_floor, data_type)
        }
        "read_side_padding" => {
            let func = Arc::new(spark_read_side_padding);
            make_comet_scalar_udf!("read_side_padding", func, without data_type)
        }
        "rpad" => {
            let func = Arc::new(spark_rpad);
            make_comet_scalar_udf!("rpad", func, without data_type)
        }
        "round" => {
            make_comet_scalar_udf!("round", spark_round, data_type)
        }
        "unscaled_value" => {
            let func = Arc::new(spark_unscaled_value);
            make_comet_scalar_udf!("unscaled_value", func, without data_type)
        }
        "make_decimal" => {
            make_comet_scalar_udf!("make_decimal", spark_make_decimal, data_type)
        }
        "hex" => {
            let func = Arc::new(spark_hex);
            make_comet_scalar_udf!("hex", func, without data_type)
        }
        "unhex" => {
            let func = Arc::new(spark_unhex);
            make_comet_scalar_udf!("unhex", func, without data_type)
        }
        "decimal_div" => {
            make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type)
        }
        "decimal_integral_div" => {
            make_comet_scalar_udf!(
                "decimal_integral_div",
                spark_decimal_integral_div,
                data_type
            )
        }
        "murmur3_hash" => {
            let func = Arc::new(spark_murmur3_hash);
            make_comet_scalar_udf!("murmur3_hash", func, without data_type)
        }
        "xxhash64" => {
            let func = Arc::new(spark_xxhash64);
            make_comet_scalar_udf!("xxhash64", func, without data_type)
        }
        "isnan" => {
            let func = Arc::new(spark_isnan);
            make_comet_scalar_udf!("isnan", func, without data_type)
        }
        "sha224" => {
            let func = Arc::new(spark_sha224);
            make_comet_scalar_udf!("sha224", func, without data_type)
        }
        "sha256" => {
            let func = Arc::new(spark_sha256);
            make_comet_scalar_udf!("sha256", func, without data_type)
        }
        "sha384" => {
            let func = Arc::new(spark_sha384);
            make_comet_scalar_udf!("sha384", func, without data_type)
        }
        "sha512" => {
            let func = Arc::new(spark_sha512);
            make_comet_scalar_udf!("sha512", func, without data_type)
        }
        "date_add" => {
            let func = Arc::new(spark_date_add);
            make_comet_scalar_udf!("date_add", func, without data_type)
        }
        "date_sub" => {
            let func = Arc::new(spark_date_sub);
            make_comet_scalar_udf!("date_sub", func, without data_type)
        }
        "array_repeat" => {
            let func = Arc::new(spark_array_repeat);
            make_comet_scalar_udf!("array_repeat", func, without data_type)
        }
        _ => registry.udf(fun_name).map_err(|e| {
            DataFusionError::Execution(format!(
                "Function {fun_name} not found in the registry: {e}"
            ))
        }),
    }
}

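/// The Comet-provided scalar UDFs that are registered by
/// `register_all_comet_functions`.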
fn all_scalar_functions() -> Vec<Arc<ScalarUDF>> {
    vec![
        Arc::new(ScalarUDF::new_from_impl(SparkChrFunc::default())),
        Arc::new(ScalarUDF::new_from_impl(SparkBitwiseNot::default())),
        Arc::new(ScalarUDF::new_from_impl(SparkBitwiseCount::default())),
        Arc::new(ScalarUDF::new_from_impl(SparkBitwiseGet::default())),
        Arc::new(ScalarUDF::new_from_impl(SparkDateTrunc::default())),
    ]
}

/// Registers all Comet-provided scalar UDFs with the given registry.
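///
/// A minimal usage sketch (not compiled as a doctest); it assumes the
/// `datafusion` crate's `SessionContext`, whose `SessionState` implements
/// `FunctionRegistry`:
///
/// ```ignore
/// use datafusion::prelude::SessionContext;
///
/// let ctx = SessionContext::new();
/// let mut state = ctx.state();
/// register_all_comet_functions(&mut state)?;
/// ```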
pub fn register_all_comet_functions(registry: &mut dyn FunctionRegistry) -> DataFusionResult<()> {
    // This will override existing UDFs with the same name
    all_scalar_functions()
        .into_iter()
        .try_for_each(|udf| registry.register_udf(udf).map(|_| ()))?;

    Ok(())
}

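/// Scalar UDF wrapper whose return type is fixed at construction time and
/// whose behavior is supplied as a `ScalarFunctionImplementation` closure.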
struct CometScalarFunction {
    name: String,
    signature: Signature,
    data_type: DataType,
    func: ScalarFunctionImplementation,
}

impl Debug for CometScalarFunction {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CometScalarFunction")
            .field("name", &self.name)
            .field("signature", &self.signature)
            .field("data_type", &self.data_type)
            .finish()
    }
}

impl CometScalarFunction {
    fn new(
        name: String,
        signature: Signature,
        data_type: DataType,
        func: ScalarFunctionImplementation,
    ) -> Self {
        Self {
            name,
            signature,
            data_type,
            func,
        }
    }
}

impl ScalarUDFImpl for CometScalarFunction {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        self.name.as_str()
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _: &[DataType]) -> DataFusionResult<DataType> {
        Ok(self.data_type.clone())
    }

    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult<ColumnarValue> {
        (self.func)(&args.args)
    }
}
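
// A minimal smoke-test sketch; it assumes `SessionContext` from the
// `datafusion` crate is available to the test build and that its
// `SessionState` implements `FunctionRegistry`.
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::DataType;
    use datafusion::prelude::SessionContext;

    #[test]
    fn resolves_comet_functions_and_falls_back_to_registry() {
        let ctx = SessionContext::new();
        let state = ctx.state();

        // "ceil" is handled by a Comet match arm and never consults the registry.
        let ceil = create_comet_physical_fun("ceil", DataType::Int64, &state).unwrap();
        assert_eq!(ceil.name(), "ceil");

        // An unknown name falls through to the registry lookup, which fails
        // when no UDF with that name has been registered.
        assert!(create_comet_physical_fun("not_a_function", DataType::Int64, &state).is_err());
    }
}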