datafusion_spark/function/string/
ascii.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use arrow::datatypes::{DataType, Field, FieldRef};
21use datafusion_common::types::{NativeType, logical_string};
22use datafusion_common::{Result, internal_err};
23use datafusion_expr::{
24    Coercion, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl,
25    Signature, TypeSignatureClass, Volatility,
26};
27use datafusion_functions::string::ascii::ascii;
28use datafusion_functions::utils::make_scalar_function;
29use std::any::Any;
30
31/// Spark compatible version of the [ascii] function. Differs from the [default ascii function]
32/// in that it is more permissive of input types, for example casting numeric input to string
33/// before executing the function (default version doesn't allow numeric input).
34///
35/// [ascii]: https://spark.apache.org/docs/latest/api/sql/index.html#ascii
36/// [default ascii function]: datafusion_functions::string::ascii::AsciiFunc
37#[derive(Debug, PartialEq, Eq, Hash)]
38pub struct SparkAscii {
39    signature: Signature,
40}
41
42impl Default for SparkAscii {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl SparkAscii {
49    pub fn new() -> Self {
50        // Spark's ascii uses ImplicitCastInputTypes with StringType,
51        // which allows numeric types to be implicitly cast to String.
52        // See: https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
53        let string_coercion = Coercion::new_implicit(
54            TypeSignatureClass::Native(logical_string()),
55            vec![TypeSignatureClass::Numeric],
56            NativeType::String,
57        );
58
59        Self {
60            signature: Signature::coercible(vec![string_coercion], Volatility::Immutable),
61        }
62    }
63}
64
65impl ScalarUDFImpl for SparkAscii {
66    fn as_any(&self) -> &dyn Any {
67        self
68    }
69
70    fn name(&self) -> &str {
71        "ascii"
72    }
73
74    fn signature(&self) -> &Signature {
75        &self.signature
76    }
77
78    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
79        internal_err!("return_field_from_args should be used instead")
80    }
81
82    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
83        // ascii returns an Int32 value
84        // The result is nullable only if any of the input arguments is nullable
85        let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
86        Ok(Arc::new(Field::new("ascii", DataType::Int32, nullable)))
87    }
88
89    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
90        make_scalar_function(ascii, vec![])(&args.args)
91    }
92}
93
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_expr::ReturnFieldArgs;

    /// Asks `SparkAscii` for its return field given one `Utf8` argument named
    /// `input` with the requested nullability.
    fn return_field_for(input_nullable: bool) -> FieldRef {
        let arg_fields: [FieldRef; 1] =
            [Arc::new(Field::new("input", DataType::Utf8, input_nullable))];
        let field_args = ReturnFieldArgs {
            arg_fields: &arg_fields,
            scalar_arguments: &[],
        };

        SparkAscii::new().return_field_from_args(field_args).unwrap()
    }

    #[test]
    fn test_return_field_nullable_input() {
        let field = return_field_for(true);

        assert_eq!(field.data_type(), &DataType::Int32);
        assert!(
            field.is_nullable(),
            "Output should be nullable when input is nullable"
        );
    }

    #[test]
    fn test_return_field_non_nullable_input() {
        let field = return_field_for(false);

        assert_eq!(field.data_type(), &DataType::Int32);
        assert!(
            !field.is_nullable(),
            "Output should not be nullable when input is not nullable"
        );
    }
}