Skip to main content

datafusion_functions/string/
ascii.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{ArrayRef, AsArray, Int32Array, StringArrayType};
19use arrow::datatypes::DataType;
20use arrow::error::ArrowError;
21use datafusion_common::types::logical_string;
22use datafusion_common::utils::take_function_args;
23use datafusion_common::{Result, ScalarValue, internal_err};
24use datafusion_expr::{ColumnarValue, Documentation, TypeSignatureClass};
25use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
26use datafusion_expr_common::signature::Coercion;
27use datafusion_macros::user_doc;
28use std::sync::Arc;
29
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Returns the first Unicode scalar value of a string.",
    syntax_example = "ascii(str)",
    sql_example = r#"```sql
> select ascii('abc');
+--------------------+
| ascii(Utf8("abc")) |
+--------------------+
| 97                 |
+--------------------+
> select ascii('🚀');
+-------------------+
| ascii(Utf8("🚀")) |
+-------------------+
| 128640            |
+-------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    related_udf(name = "chr")
)]
/// Scalar UDF implementing the SQL `ascii(str)` function.
///
/// Yields the code point of the first character of its string argument as an
/// `Int32`; an empty string yields `0` and a null input yields null.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct AsciiFunc {
    /// Argument signature for the function: a single string argument,
    /// `Volatility::Immutable`.
    signature: Signature,
}
55
56impl Default for AsciiFunc {
57    fn default() -> Self {
58        Self::new()
59    }
60}
61
62impl AsciiFunc {
63    pub fn new() -> Self {
64        Self {
65            signature: Signature::coercible(
66                vec![Coercion::new_exact(TypeSignatureClass::Native(
67                    logical_string(),
68                ))],
69                Volatility::Immutable,
70            ),
71        }
72    }
73}
74
75impl ScalarUDFImpl for AsciiFunc {
76    fn name(&self) -> &str {
77        "ascii"
78    }
79
80    fn signature(&self) -> &Signature {
81        &self.signature
82    }
83
84    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
85        Ok(DataType::Int32)
86    }
87
88    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
89        let [arg] = take_function_args(self.name(), args.args)?;
90
91        match arg {
92            ColumnarValue::Scalar(scalar) => {
93                if scalar.is_null() {
94                    return Ok(ColumnarValue::Scalar(ScalarValue::Int32(None)));
95                }
96
97                match scalar {
98                    ScalarValue::Utf8(Some(s))
99                    | ScalarValue::LargeUtf8(Some(s))
100                    | ScalarValue::Utf8View(Some(s)) => {
101                        let result = s.chars().next().map_or(0, |c| c as i32);
102                        Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(result))))
103                    }
104                    _ => {
105                        internal_err!(
106                            "Unexpected data type {:?} for function ascii",
107                            scalar.data_type()
108                        )
109                    }
110                }
111            }
112            ColumnarValue::Array(array) => Ok(ColumnarValue::Array(ascii(&[array])?)),
113        }
114    }
115
116    fn documentation(&self) -> Option<&Documentation> {
117        self.doc()
118    }
119}
120
121fn calculate_ascii<'a, V>(array: &V) -> Result<ArrayRef, ArrowError>
122where
123    V: StringArrayType<'a, Item = &'a str>,
124{
125    let values: Vec<_> = (0..array.len())
126        .map(|i| {
127            if array.is_null(i) {
128                0
129            } else {
130                let s = array.value(i);
131                s.chars().next().map_or(0, |c| c as i32)
132            }
133        })
134        .collect();
135
136    let array = Int32Array::new(values.into(), array.nulls().cloned());
137
138    Ok(Arc::new(array))
139}
140
141/// Returns the numeric code of the first character of the argument.
142pub fn ascii(args: &[ArrayRef]) -> Result<ArrayRef> {
143    match args[0].data_type() {
144        DataType::Utf8 => {
145            let string_array = args[0].as_string::<i32>();
146            Ok(calculate_ascii(&string_array)?)
147        }
148        DataType::LargeUtf8 => {
149            let string_array = args[0].as_string::<i64>();
150            Ok(calculate_ascii(&string_array)?)
151        }
152        DataType::Utf8View => {
153            let string_array = args[0].as_string_view();
154            Ok(calculate_ascii(&string_array)?)
155        }
156        _ => internal_err!("Unsupported data type"),
157    }
158}
159
// Unit tests for the scalar evaluation path of `AsciiFunc`.
#[cfg(test)]
mod tests {
    use crate::string::ascii::AsciiFunc;
    use crate::utils::test::test_function;
    use arrow::array::{Array, Int32Array};
    use arrow::datatypes::DataType::Int32;
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    // Checks one input/expected pair three times, wrapping the same input as
    // a `Utf8`, `LargeUtf8` and `Utf8View` scalar, so every supported string
    // representation is exercised. The actual invocation and comparison are
    // done by the project's `test_function!` macro (defined elsewhere).
    macro_rules! test_ascii {
        ($INPUT:expr, $EXPECTED:expr) => {
            // Utf8 scalar
            test_function!(
                AsciiFunc::new(),
                vec![ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))],
                $EXPECTED,
                i32,
                Int32,
                Int32Array
            );

            // LargeUtf8 scalar
            test_function!(
                AsciiFunc::new(),
                vec![ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))],
                $EXPECTED,
                i32,
                Int32,
                Int32Array
            );

            // Utf8View scalar
            test_function!(
                AsciiFunc::new(),
                vec![ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))],
                $EXPECTED,
                i32,
                Int32,
                Int32Array
            );
        };
    }

    // Covers plain ASCII, the empty string (expected 0), a multi-byte
    // character, control characters and a null input (expected null).
    #[test]
    fn test_functions() -> Result<()> {
        test_ascii!(Some(String::from("x")), Ok(Some(120)));
        test_ascii!(Some(String::from("a")), Ok(Some(97)));
        test_ascii!(Some(String::from("")), Ok(Some(0)));
        test_ascii!(Some(String::from("🚀")), Ok(Some(128640)));
        test_ascii!(Some(String::from("\n")), Ok(Some(10)));
        test_ascii!(Some(String::from("\t")), Ok(Some(9)));
        test_ascii!(None, Ok(None));
        Ok(())
    }
}