Skip to main content

datafusion_spark/function/string/
ascii.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use arrow::datatypes::{DataType, Field, FieldRef};
21use datafusion_common::types::{NativeType, logical_string};
22use datafusion_common::{Result, internal_err};
23use datafusion_expr::{
24    Coercion, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl,
25    Signature, TypeSignatureClass, Volatility,
26};
27use datafusion_functions::string::ascii::ascii;
28use datafusion_functions::utils::make_scalar_function;
29
30/// Spark compatible version of the [ascii] function. Differs from the [default ascii function]
31/// in that it is more permissive of input types, for example casting numeric input to string
32/// before executing the function (default version doesn't allow numeric input).
33///
34/// [ascii]: https://spark.apache.org/docs/latest/api/sql/index.html#ascii
35/// [default ascii function]: datafusion_functions::string::ascii::AsciiFunc
36#[derive(Debug, PartialEq, Eq, Hash)]
37pub struct SparkAscii {
38    signature: Signature,
39}
40
41impl Default for SparkAscii {
42    fn default() -> Self {
43        Self::new()
44    }
45}
46
47impl SparkAscii {
48    pub fn new() -> Self {
49        // Spark's ascii uses ImplicitCastInputTypes with StringType,
50        // which allows numeric types to be implicitly cast to String.
51        // See: https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
52        let string_coercion = Coercion::new_implicit(
53            TypeSignatureClass::Native(logical_string()),
54            vec![TypeSignatureClass::Numeric],
55            NativeType::String,
56        );
57
58        Self {
59            signature: Signature::coercible(vec![string_coercion], Volatility::Immutable),
60        }
61    }
62}
63
64impl ScalarUDFImpl for SparkAscii {
65    fn name(&self) -> &str {
66        "ascii"
67    }
68
69    fn signature(&self) -> &Signature {
70        &self.signature
71    }
72
73    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
74        internal_err!("return_field_from_args should be used instead")
75    }
76
77    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
78        // ascii returns an Int32 value
79        // The result is nullable only if any of the input arguments is nullable
80        let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
81        Ok(Arc::new(Field::new("ascii", DataType::Int32, nullable)))
82    }
83
84    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
85        make_scalar_function(ascii, vec![])(&args.args)
86    }
87}
88
#[cfg(test)]
mod tests {
    use super::*;

    /// Resolves the return field of [`SparkAscii`] for a single `Utf8` argument
    /// named `input` with the requested nullability.
    fn resolve_return_field(input_nullable: bool) -> FieldRef {
        let input = Arc::new(Field::new("input", DataType::Utf8, input_nullable));

        SparkAscii::new()
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &[input],
                scalar_arguments: &[],
            })
            .unwrap()
    }

    #[test]
    fn test_return_field_nullable_input() {
        let output = resolve_return_field(true);

        assert_eq!(output.data_type(), &DataType::Int32);
        assert!(
            output.is_nullable(),
            "Output should be nullable when input is nullable"
        );
    }

    #[test]
    fn test_return_field_non_nullable_input() {
        let output = resolve_return_field(false);

        assert_eq!(output.data_type(), &DataType::Int32);
        assert!(
            !output.is_nullable(),
            "Output should not be nullable when input is not nullable"
        );
    }
}