datafusion_comet_spark_expr/comet_scalar_funcs.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::hash_funcs::*;
use crate::map_funcs::spark_map_sort;
use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mul, checked_sub};
use crate::math_funcs::modulo_expr::spark_modulo;
use crate::{
    spark_array_repeat, spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div,
    spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal,
    spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value,
    SparkBitwiseCount, SparkBitwiseGet, SparkBitwiseNot, SparkDateTrunc, SparkStringSpace,
};
use arrow::datatypes::DataType;
use datafusion::common::{DataFusionError, Result as DataFusionResult};
use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::{
    ScalarFunctionArgs, ScalarFunctionImplementation, ScalarUDF, ScalarUDFImpl, Signature,
    Volatility,
};
use datafusion::physical_plan::ColumnarValue;
use std::any::Any;
use std::fmt::Debug;
use std::sync::Arc;

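/// Builds a [`CometScalarFunction`] wrapped in a [`ScalarUDF`]. Three forms are
/// supported:
/// - `($name, $func, $data_type)`: `$data_type` is both the return type and an
///   extra argument passed to `$func` on every invocation.
/// - `($name, $func, without $data_type)`: `$func` is used as-is; `$data_type`
///   only determines the return type.
/// - `($name, $func, without $data_type, $fail_on_error)`: like the previous
///   form, but `$fail_on_error` is captured and forwarded to `$func`.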
macro_rules! make_comet_scalar_udf {
    ($name:expr, $func:ident, $data_type:ident) => {{
        let scalar_func = CometScalarFunction::new(
            $name.to_string(),
            Signature::variadic_any(Volatility::Immutable),
            $data_type.clone(),
            Arc::new(move |args| $func(args, &$data_type)),
        );
        Ok(Arc::new(ScalarUDF::new_from_impl(scalar_func)))
    }};
    ($name:expr, $func:expr, without $data_type:ident) => {{
        let scalar_func = CometScalarFunction::new(
            $name.to_string(),
            Signature::variadic_any(Volatility::Immutable),
            $data_type,
            $func,
        );
        Ok(Arc::new(ScalarUDF::new_from_impl(scalar_func)))
    }};
    ($name:expr, $func:ident, without $data_type:ident, $fail_on_error:ident) => {{
        let scalar_func = CometScalarFunction::new(
            $name.to_string(),
            Signature::variadic_any(Volatility::Immutable),
            $data_type,
            Arc::new(move |args| $func(args, $fail_on_error)),
        );
        Ok(Arc::new(ScalarUDF::new_from_impl(scalar_func)))
    }};
}

/// Create a physical scalar function.
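///
/// `fun_name` is matched against the Comet-provided functions below; any other
/// name is resolved through `registry`. `fail_on_error` is currently only
/// consulted by the `spark_modulo` branch.
///
/// # Example (sketch; `registry` is any [`FunctionRegistry`] implementation)
/// ```ignore
/// let udf = create_comet_physical_fun("ceil", DataType::Float64, registry, None)?;
/// ```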
pub fn create_comet_physical_fun(
    fun_name: &str,
    data_type: DataType,
    registry: &dyn FunctionRegistry,
    fail_on_error: Option<bool>,
) -> Result<Arc<ScalarUDF>, DataFusionError> {
    match fun_name {
        "ceil" => {
            make_comet_scalar_udf!("ceil", spark_ceil, data_type)
        }
        "floor" => {
            make_comet_scalar_udf!("floor", spark_floor, data_type)
        }
        "read_side_padding" => {
            let func = Arc::new(spark_read_side_padding);
            make_comet_scalar_udf!("read_side_padding", func, without data_type)
        }
        "rpad" => {
            let func = Arc::new(spark_rpad);
            make_comet_scalar_udf!("rpad", func, without data_type)
        }
        "round" => {
            make_comet_scalar_udf!("round", spark_round, data_type)
        }
        "unscaled_value" => {
            let func = Arc::new(spark_unscaled_value);
            make_comet_scalar_udf!("unscaled_value", func, without data_type)
        }
        "make_decimal" => {
            make_comet_scalar_udf!("make_decimal", spark_make_decimal, data_type)
        }
        "hex" => {
            let func = Arc::new(spark_hex);
            make_comet_scalar_udf!("hex", func, without data_type)
        }
        "unhex" => {
            let func = Arc::new(spark_unhex);
            make_comet_scalar_udf!("unhex", func, without data_type)
        }
        "decimal_div" => {
            make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type)
        }
        "decimal_integral_div" => {
            make_comet_scalar_udf!(
                "decimal_integral_div",
                spark_decimal_integral_div,
                data_type
            )
        }
        "checked_add" => {
            make_comet_scalar_udf!("checked_add", checked_add, data_type)
        }
        "checked_sub" => {
            make_comet_scalar_udf!("checked_sub", checked_sub, data_type)
        }
        "checked_mul" => {
            make_comet_scalar_udf!("checked_mul", checked_mul, data_type)
        }
        "checked_div" => {
            make_comet_scalar_udf!("checked_div", checked_div, data_type)
        }
        "murmur3_hash" => {
            let func = Arc::new(spark_murmur3_hash);
            make_comet_scalar_udf!("murmur3_hash", func, without data_type)
        }
        "xxhash64" => {
            let func = Arc::new(spark_xxhash64);
            make_comet_scalar_udf!("xxhash64", func, without data_type)
        }
        "isnan" => {
            let func = Arc::new(spark_isnan);
            make_comet_scalar_udf!("isnan", func, without data_type)
        }
        "date_add" => {
            let func = Arc::new(spark_date_add);
            make_comet_scalar_udf!("date_add", func, without data_type)
        }
        "date_sub" => {
            let func = Arc::new(spark_date_sub);
            make_comet_scalar_udf!("date_sub", func, without data_type)
        }
        "array_repeat" => {
            let func = Arc::new(spark_array_repeat);
            make_comet_scalar_udf!("array_repeat", func, without data_type)
        }
        "spark_modulo" => {
            let func = Arc::new(spark_modulo);
            let fail_on_error = fail_on_error.unwrap_or(false);
            make_comet_scalar_udf!("spark_modulo", func, without data_type, fail_on_error)
        }
        "map_sort" => {
            let func = Arc::new(spark_map_sort);
            make_comet_scalar_udf!("spark_map_sort", func, without data_type)
        }
        _ => registry.udf(fun_name).map_err(|e| {
            DataFusionError::Execution(format!(
                "Function {fun_name} not found in the registry: {e}",
            ))
        }),
    }
}

fn all_scalar_functions() -> Vec<Arc<ScalarUDF>> {
    vec![
        Arc::new(ScalarUDF::new_from_impl(SparkBitwiseNot::default())),
        Arc::new(ScalarUDF::new_from_impl(SparkBitwiseCount::default())),
        Arc::new(ScalarUDF::new_from_impl(SparkBitwiseGet::default())),
        Arc::new(ScalarUDF::new_from_impl(SparkDateTrunc::default())),
        Arc::new(ScalarUDF::new_from_impl(SparkStringSpace::default())),
    ]
}

/// Registers all custom Comet UDFs with the given registry.
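///
/// # Example (sketch; assumes a `SessionState`, which implements [`FunctionRegistry`])
/// ```ignore
/// let mut state = datafusion::prelude::SessionContext::new().state();
/// register_all_comet_functions(&mut state)?;
/// ```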
pub fn register_all_comet_functions(registry: &mut dyn FunctionRegistry) -> DataFusionResult<()> {
    // This will override existing UDFs with the same name
    all_scalar_functions()
        .into_iter()
        .try_for_each(|udf| registry.register_udf(udf).map(|_| ()))?;

    Ok(())
}

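/// A scalar UDF defined by a name, a signature, a fixed return type, and a
/// function pointer supplied at construction time.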
struct CometScalarFunction {
    name: String,
    signature: Signature,
    data_type: DataType,
    func: ScalarFunctionImplementation,
}

impl Debug for CometScalarFunction {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CometScalarFunction")
            .field("name", &self.name)
            .field("signature", &self.signature)
            .field("data_type", &self.data_type)
            .finish()
    }
}

impl CometScalarFunction {
    fn new(
        name: String,
        signature: Signature,
        data_type: DataType,
        func: ScalarFunctionImplementation,
    ) -> Self {
        Self {
            name,
            signature,
            data_type,
            func,
        }
    }
}

impl ScalarUDFImpl for CometScalarFunction {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        self.name.as_str()
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _: &[DataType]) -> DataFusionResult<DataType> {
        Ok(self.data_type.clone())
    }

    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult<ColumnarValue> {
        (self.func)(&args.args)
    }
}
246}