datafusion_comet_spark_expr/
comet_scalar_funcs.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::hash_funcs::*;
19use crate::{
20    spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div, spark_decimal_integral_div,
21    spark_floor, spark_hex, spark_isnan, spark_make_decimal, spark_read_side_padding, spark_round,
22    spark_rpad, spark_unhex, spark_unscaled_value, SparkChrFunc,
23};
24use arrow_schema::DataType;
25use datafusion_common::{DataFusionError, Result as DataFusionResult};
26use datafusion_expr::registry::FunctionRegistry;
27use datafusion_expr::{
28    ColumnarValue, ScalarFunctionImplementation, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
29};
30use std::any::Any;
31use std::fmt::Debug;
32use std::sync::Arc;
33
/// Builds an `Ok(Arc<ScalarUDF>)` wrapping a [`CometScalarFunction`].
///
/// Two invocation forms are accepted:
///
/// * `make_comet_scalar_udf!(name, func, data_type)` — `func` is the identifier
///   of a function called as `func(args, &data_type)`; the data type is cloned
///   for the UDF's declared return type and then moved into the closure.
/// * `make_comet_scalar_udf!(name, func, without data_type)` — `func` is an
///   expression that is passed through unchanged as the implementation; the
///   data type is only used as the UDF's declared return type.
///
/// Both forms register the UDF with a variadic-any, immutable signature.
macro_rules! make_comet_scalar_udf {
    ($name:expr, $func:ident, $data_type:ident) => {
        Ok(Arc::new(ScalarUDF::new_from_impl(
            CometScalarFunction::new(
                $name.to_string(),
                Signature::variadic_any(Volatility::Immutable),
                // Clone first: the closure below takes ownership of the original.
                $data_type.clone(),
                Arc::new(move |args| $func(args, &$data_type)),
            ),
        )))
    };
    ($name:expr, $func:expr, without $data_type:ident) => {
        Ok(Arc::new(ScalarUDF::new_from_impl(
            CometScalarFunction::new(
                $name.to_string(),
                Signature::variadic_any(Volatility::Immutable),
                $data_type,
                $func,
            ),
        )))
    };
}
54
55/// Create a physical scalar function.
56pub fn create_comet_physical_fun(
57    fun_name: &str,
58    data_type: DataType,
59    registry: &dyn FunctionRegistry,
60) -> Result<Arc<ScalarUDF>, DataFusionError> {
61    match fun_name {
62        "ceil" => {
63            make_comet_scalar_udf!("ceil", spark_ceil, data_type)
64        }
65        "floor" => {
66            make_comet_scalar_udf!("floor", spark_floor, data_type)
67        }
68        "read_side_padding" => {
69            let func = Arc::new(spark_read_side_padding);
70            make_comet_scalar_udf!("read_side_padding", func, without data_type)
71        }
72        "rpad" => {
73            let func = Arc::new(spark_rpad);
74            make_comet_scalar_udf!("rpad", func, without data_type)
75        }
76        "round" => {
77            make_comet_scalar_udf!("round", spark_round, data_type)
78        }
79        "unscaled_value" => {
80            let func = Arc::new(spark_unscaled_value);
81            make_comet_scalar_udf!("unscaled_value", func, without data_type)
82        }
83        "make_decimal" => {
84            make_comet_scalar_udf!("make_decimal", spark_make_decimal, data_type)
85        }
86        "hex" => {
87            let func = Arc::new(spark_hex);
88            make_comet_scalar_udf!("hex", func, without data_type)
89        }
90        "unhex" => {
91            let func = Arc::new(spark_unhex);
92            make_comet_scalar_udf!("unhex", func, without data_type)
93        }
94        "decimal_div" => {
95            make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type)
96        }
97        "decimal_integral_div" => {
98            make_comet_scalar_udf!(
99                "decimal_integral_div",
100                spark_decimal_integral_div,
101                data_type
102            )
103        }
104        "murmur3_hash" => {
105            let func = Arc::new(spark_murmur3_hash);
106            make_comet_scalar_udf!("murmur3_hash", func, without data_type)
107        }
108        "xxhash64" => {
109            let func = Arc::new(spark_xxhash64);
110            make_comet_scalar_udf!("xxhash64", func, without data_type)
111        }
112        "chr" => Ok(Arc::new(ScalarUDF::new_from_impl(SparkChrFunc::default()))),
113        "isnan" => {
114            let func = Arc::new(spark_isnan);
115            make_comet_scalar_udf!("isnan", func, without data_type)
116        }
117        "sha224" => {
118            let func = Arc::new(spark_sha224);
119            make_comet_scalar_udf!("sha224", func, without data_type)
120        }
121        "sha256" => {
122            let func = Arc::new(spark_sha256);
123            make_comet_scalar_udf!("sha256", func, without data_type)
124        }
125        "sha384" => {
126            let func = Arc::new(spark_sha384);
127            make_comet_scalar_udf!("sha384", func, without data_type)
128        }
129        "sha512" => {
130            let func = Arc::new(spark_sha512);
131            make_comet_scalar_udf!("sha512", func, without data_type)
132        }
133        "date_add" => {
134            let func = Arc::new(spark_date_add);
135            make_comet_scalar_udf!("date_add", func, without data_type)
136        }
137        "date_sub" => {
138            let func = Arc::new(spark_date_sub);
139            make_comet_scalar_udf!("date_sub", func, without data_type)
140        }
141        _ => registry.udf(fun_name).map_err(|e| {
142            DataFusionError::Execution(format!(
143                "Function {fun_name} not found in the registry: {e}",
144            ))
145        }),
146    }
147}
148
149struct CometScalarFunction {
150    name: String,
151    signature: Signature,
152    data_type: DataType,
153    func: ScalarFunctionImplementation,
154}
155
156impl Debug for CometScalarFunction {
157    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158        f.debug_struct("CometScalarFunction")
159            .field("name", &self.name)
160            .field("signature", &self.signature)
161            .field("data_type", &self.data_type)
162            .finish()
163    }
164}
165
166impl CometScalarFunction {
167    fn new(
168        name: String,
169        signature: Signature,
170        data_type: DataType,
171        func: ScalarFunctionImplementation,
172    ) -> Self {
173        Self {
174            name,
175            signature,
176            data_type,
177            func,
178        }
179    }
180}
181
182impl ScalarUDFImpl for CometScalarFunction {
183    fn as_any(&self) -> &dyn Any {
184        self
185    }
186
187    fn name(&self) -> &str {
188        self.name.as_str()
189    }
190
191    fn signature(&self) -> &Signature {
192        &self.signature
193    }
194
195    fn return_type(&self, _: &[DataType]) -> DataFusionResult<DataType> {
196        Ok(self.data_type.clone())
197    }
198
199    fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
200        (self.func)(args)
201    }
202}