datafusion_comet_spark_expr/string_funcs/chr.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::{any::Any, sync::Arc};
19
20use arrow::{
21    array::{ArrayRef, StringArray},
22    datatypes::{
23        DataType,
24        DataType::{Int64, Utf8},
25    },
26};
27
28use datafusion::common::{cast::as_int64_array, exec_err, Result, ScalarValue};
29use datafusion::logical_expr::{
30    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
31};
32
33fn chr(args: &[ArrayRef]) -> Result<ArrayRef> {
34    let integer_array = as_int64_array(&args[0])?;
35
36    // first map is the iterator, second is for the `Option<_>`
37    let result = integer_array
38        .iter()
39        .map(|integer: Option<i64>| {
40            integer
41                .map(|integer| {
42                    if integer < 0 {
43                        return Ok("".to_string()); // Return empty string for negative integers
44                    }
45                    match core::char::from_u32((integer % 256) as u32) {
46                        Some(ch) => Ok(ch.to_string()),
47                        None => {
48                            exec_err!("requested character not compatible for encoding.")
49                        }
50                    }
51                })
52                .transpose()
53        })
54        .collect::<Result<StringArray>>()?;
55
56    Ok(Arc::new(result) as ArrayRef)
57}
58
59/// Spark-compatible `chr` expression
60#[derive(Debug)]
61pub struct SparkChrFunc {
62    signature: Signature,
63}
64
65impl Default for SparkChrFunc {
66    fn default() -> Self {
67        Self::new()
68    }
69}
70
71impl SparkChrFunc {
72    pub fn new() -> Self {
73        Self {
74            signature: Signature::uniform(1, vec![Int64], Volatility::Immutable),
75        }
76    }
77}
78
79impl ScalarUDFImpl for SparkChrFunc {
80    fn as_any(&self) -> &dyn Any {
81        self
82    }
83
84    fn name(&self) -> &str {
85        "chr"
86    }
87
88    fn signature(&self) -> &Signature {
89        &self.signature
90    }
91
92    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
93        Ok(Utf8)
94    }
95
96    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
97        spark_chr(&args.args)
98    }
99}
100
101/// Returns the ASCII character having the binary equivalent to the input expression.
102/// E.g., chr(65) = 'A'.
103/// Compatible with Apache Spark's Chr function
104fn spark_chr(args: &[ColumnarValue]) -> Result<ColumnarValue> {
105    let array = args[0].clone();
106    match array {
107        ColumnarValue::Array(array) => {
108            let array = chr(&[array])?;
109            Ok(ColumnarValue::Array(array))
110        }
111        ColumnarValue::Scalar(ScalarValue::Int64(Some(value))) => {
112            if value < 0 {
113                Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(
114                    "".to_string(),
115                ))))
116            } else {
117                match core::char::from_u32((value % 256) as u32) {
118                    Some(ch) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(
119                        ch.to_string(),
120                    )))),
121                    None => exec_err!("requested character was incompatible for encoding."),
122                }
123            }
124        }
125        _ => exec_err!("The argument must be an Int64 array or scalar."),
126    }
127}