Skip to main content

datafusion_spark/function/string/
base64.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use arrow::datatypes::DataType;
21use datafusion_common::arrow::datatypes::{Field, FieldRef};
22use datafusion_common::types::{NativeType, logical_string};
23use datafusion_common::utils::take_function_args;
24use datafusion_common::{Result, exec_err, internal_err};
25use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
26use datafusion_expr::{Coercion, Expr, ReturnFieldArgs, TypeSignatureClass, lit};
27use datafusion_expr::{
28    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
29};
30use datafusion_functions::expr_fn::{decode, encode};
31
32/// Apache Spark base64 uses padded base64 encoding.
33/// <https://spark.apache.org/docs/latest/api/sql/index.html#base64>
34#[derive(Debug, PartialEq, Eq, Hash)]
35pub struct SparkBase64 {
36    signature: Signature,
37}
38
39impl Default for SparkBase64 {
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45impl SparkBase64 {
46    pub fn new() -> Self {
47        Self {
48            signature: Signature::coercible(
49                vec![Coercion::new_implicit(
50                    TypeSignatureClass::Binary,
51                    vec![TypeSignatureClass::Native(logical_string())],
52                    NativeType::Binary,
53                )],
54                Volatility::Immutable,
55            ),
56        }
57    }
58}
59
60impl ScalarUDFImpl for SparkBase64 {
61    fn name(&self) -> &str {
62        "base64"
63    }
64
65    fn signature(&self) -> &Signature {
66        &self.signature
67    }
68
69    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
70        internal_err!("return_type should not be called for {}", self.name())
71    }
72
73    fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result<FieldRef> {
74        let [bin] = take_function_args(self.name(), args.arg_fields)?;
75        let return_type = match bin.data_type() {
76            DataType::LargeBinary => DataType::LargeUtf8,
77            _ => DataType::Utf8,
78        };
79        Ok(Arc::new(Field::new(
80            self.name(),
81            return_type,
82            bin.is_nullable(),
83        )))
84    }
85
86    fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
87        exec_err!(
88            "invoke should not be called on a simplified {} function",
89            self.name()
90        )
91    }
92
93    fn simplify(
94        &self,
95        args: Vec<Expr>,
96        _info: &SimplifyContext,
97    ) -> Result<ExprSimplifyResult> {
98        let [bin] = take_function_args(self.name(), args)?;
99        Ok(ExprSimplifyResult::Simplified(encode(
100            bin,
101            lit("base64pad"),
102        )))
103    }
104}
105
106/// <https://spark.apache.org/docs/latest/api/sql/index.html#unbase64>
107#[derive(Debug, PartialEq, Eq, Hash)]
108pub struct SparkUnBase64 {
109    signature: Signature,
110}
111
112impl Default for SparkUnBase64 {
113    fn default() -> Self {
114        Self::new()
115    }
116}
117
118impl SparkUnBase64 {
119    pub fn new() -> Self {
120        Self {
121            signature: Signature::coercible(
122                vec![Coercion::new_implicit(
123                    TypeSignatureClass::Binary,
124                    vec![TypeSignatureClass::Native(logical_string())],
125                    NativeType::Binary,
126                )],
127                Volatility::Immutable,
128            ),
129        }
130    }
131}
132
133impl ScalarUDFImpl for SparkUnBase64 {
134    fn name(&self) -> &str {
135        "unbase64"
136    }
137
138    fn signature(&self) -> &Signature {
139        &self.signature
140    }
141
142    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
143        internal_err!("return_type should not be called for {}", self.name())
144    }
145
146    fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result<FieldRef> {
147        let [str] = take_function_args(self.name(), args.arg_fields)?;
148        let return_type = match str.data_type() {
149            DataType::LargeBinary => DataType::LargeBinary,
150            _ => DataType::Binary,
151        };
152        Ok(Arc::new(Field::new(
153            self.name(),
154            return_type,
155            str.is_nullable(),
156        )))
157    }
158
159    fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
160        exec_err!("{} should have been simplified", self.name())
161    }
162
163    fn simplify(
164        &self,
165        args: Vec<Expr>,
166        _info: &SimplifyContext,
167    ) -> Result<ExprSimplifyResult> {
168        let [bin] = take_function_args(self.name(), args)?;
169        Ok(ExprSimplifyResult::Simplified(decode(
170            bin,
171            lit("base64pad"),
172        )))
173    }
174}