Skip to main content

datafusion_spark/function/string/
is_valid_utf8.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::datatypes::{DataType, Field, FieldRef};
19use datafusion_common::{Result, internal_err};
20use datafusion_expr::{
21    ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
22    Volatility,
23};
24
25use arrow::array::{Array, ArrayRef, BooleanArray};
26use arrow::buffer::BooleanBuffer;
27use datafusion_common::cast::{
28    as_binary_array, as_binary_view_array, as_large_binary_array,
29};
30use datafusion_common::utils::take_function_args;
31use datafusion_functions::utils::make_scalar_function;
32
33use std::sync::Arc;
34
/// Spark-compatible `is_valid_utf8` expression
/// <https://spark.apache.org/docs/latest/api/sql/index.html#is_valid_utf8>
///
/// The struct holds nothing but the function's [`Signature`]; the accepted
/// argument types are fixed when the value is built by `SparkIsValidUtf8::new`.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkIsValidUtf8 {
    /// Argument signature handed back by `ScalarUDFImpl::signature`.
    signature: Signature,
}
41
42impl Default for SparkIsValidUtf8 {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl SparkIsValidUtf8 {
49    pub fn new() -> Self {
50        Self {
51            signature: Signature::uniform(
52                1,
53                vec![
54                    DataType::Utf8,
55                    DataType::LargeUtf8,
56                    DataType::Utf8View,
57                    DataType::Binary,
58                    DataType::BinaryView,
59                    DataType::LargeBinary,
60                ],
61                Volatility::Immutable,
62            ),
63        }
64    }
65}
66
67impl ScalarUDFImpl for SparkIsValidUtf8 {
68    fn name(&self) -> &str {
69        "is_valid_utf8"
70    }
71
72    fn signature(&self) -> &Signature {
73        &self.signature
74    }
75
76    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
77        internal_err!("return_field_from_args should be used instead")
78    }
79
80    fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result<FieldRef> {
81        Ok(Arc::new(Field::new(self.name(), DataType::Boolean, true)))
82    }
83
84    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
85        make_scalar_function(spark_is_valid_utf8_inner, vec![])(&args.args)
86    }
87}
88
89fn spark_is_valid_utf8_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
90    let [array] = take_function_args("is_valid_utf8", args)?;
91    match array.data_type() {
92        DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8 => {
93            Ok(Arc::new(BooleanArray::new(
94                BooleanBuffer::new_set(array.len()),
95                array.nulls().cloned(),
96            )))
97        }
98        DataType::Binary => Ok(Arc::new(
99            as_binary_array(array)?
100                .iter()
101                .map(|x| x.map(|y| str::from_utf8(y).is_ok()))
102                .collect::<BooleanArray>(),
103        )),
104        DataType::LargeBinary => Ok(Arc::new(
105            as_large_binary_array(array)?
106                .iter()
107                .map(|x| x.map(|y| str::from_utf8(y).is_ok()))
108                .collect::<BooleanArray>(),
109        )),
110        DataType::BinaryView => Ok(Arc::new(
111            as_binary_view_array(array)?
112                .iter()
113                .map(|x| x.map(|y| str::from_utf8(y).is_ok()))
114                .collect::<BooleanArray>(),
115        )),
116        data_type => {
117            internal_err!("is_valid_utf8 does not support: {data_type}")
118        }
119    }
120}