datafusion_functions/encoding/
inner.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Encoding expressions
19
20use arrow::{
21    array::{
22        Array, ArrayRef, BinaryArray, GenericByteArray, OffsetSizeTrait, StringArray,
23    },
24    datatypes::{ByteArrayType, DataType},
25};
26use arrow_buffer::{Buffer, OffsetBufferBuilder};
27use base64::{
28    engine::{DecodePaddingMode, GeneralPurpose, GeneralPurposeConfig},
29    Engine as _,
30};
31use datafusion_common::{
32    cast::{as_generic_binary_array, as_generic_string_array},
33    not_impl_err, plan_err,
34    utils::take_function_args,
35};
36use datafusion_common::{exec_err, internal_datafusion_err, ScalarValue};
37use datafusion_common::{DataFusionError, Result};
38use datafusion_expr::{ColumnarValue, Documentation};
39use std::sync::Arc;
40use std::{fmt, str::FromStr};
41
42use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
43use datafusion_macros::user_doc;
44use std::any::Any;
45
46// Allow padding characters, but don't require them, and don't generate them.
47const BASE64_ENGINE: GeneralPurpose = GeneralPurpose::new(
48    &base64::alphabet::STANDARD,
49    GeneralPurposeConfig::new()
50        .with_encode_padding(false)
51        .with_decode_padding_mode(DecodePaddingMode::Indifferent),
52);
53
54#[user_doc(
55    doc_section(label = "Binary String Functions"),
56    description = "Encode binary data into a textual representation.",
57    syntax_example = "encode(expression, format)",
58    argument(
59        name = "expression",
60        description = "Expression containing string or binary data"
61    ),
62    argument(
63        name = "format",
64        description = "Supported formats are: `base64`, `hex`"
65    ),
66    related_udf(name = "decode")
67)]
68#[derive(Debug, PartialEq, Eq, Hash)]
69pub struct EncodeFunc {
70    signature: Signature,
71}
72
73impl Default for EncodeFunc {
74    fn default() -> Self {
75        Self::new()
76    }
77}
78
79impl EncodeFunc {
80    pub fn new() -> Self {
81        Self {
82            signature: Signature::user_defined(Volatility::Immutable),
83        }
84    }
85}
86
87impl ScalarUDFImpl for EncodeFunc {
88    fn as_any(&self) -> &dyn Any {
89        self
90    }
91    fn name(&self) -> &str {
92        "encode"
93    }
94
95    fn signature(&self) -> &Signature {
96        &self.signature
97    }
98
99    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
100        use DataType::*;
101
102        Ok(match arg_types[0] {
103            Utf8 => Utf8,
104            LargeUtf8 => LargeUtf8,
105            Utf8View => Utf8,
106            Binary => Utf8,
107            LargeBinary => LargeUtf8,
108            Null => Null,
109            _ => {
110                return plan_err!(
111                    "The encode function can only accept Utf8 or Binary or Null."
112                );
113            }
114        })
115    }
116
117    fn invoke_with_args(
118        &self,
119        args: datafusion_expr::ScalarFunctionArgs,
120    ) -> Result<ColumnarValue> {
121        encode(&args.args)
122    }
123
124    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
125        let [expression, format] = take_function_args(self.name(), arg_types)?;
126
127        if format != &DataType::Utf8 {
128            return Err(DataFusionError::Plan("2nd argument should be Utf8".into()));
129        }
130
131        match expression {
132            DataType::Utf8 | DataType::Utf8View | DataType::Null => {
133                Ok(vec![DataType::Utf8; 2])
134            }
135            DataType::LargeUtf8 => Ok(vec![DataType::LargeUtf8, DataType::Utf8]),
136            DataType::Binary => Ok(vec![DataType::Binary, DataType::Utf8]),
137            DataType::LargeBinary => Ok(vec![DataType::LargeBinary, DataType::Utf8]),
138            _ => plan_err!(
139                "1st argument should be Utf8 or Binary or Null, got {:?}",
140                arg_types[0]
141            ),
142        }
143    }
144
145    fn documentation(&self) -> Option<&Documentation> {
146        self.doc()
147    }
148}
149
150#[user_doc(
151    doc_section(label = "Binary String Functions"),
152    description = "Decode binary data from textual representation in string.",
153    syntax_example = "decode(expression, format)",
154    argument(
155        name = "expression",
156        description = "Expression containing encoded string data"
157    ),
158    argument(name = "format", description = "Same arguments as [encode](#encode)"),
159    related_udf(name = "encode")
160)]
161#[derive(Debug, PartialEq, Eq, Hash)]
162pub struct DecodeFunc {
163    signature: Signature,
164}
165
166impl Default for DecodeFunc {
167    fn default() -> Self {
168        Self::new()
169    }
170}
171
172impl DecodeFunc {
173    pub fn new() -> Self {
174        Self {
175            signature: Signature::user_defined(Volatility::Immutable),
176        }
177    }
178}
179
180impl ScalarUDFImpl for DecodeFunc {
181    fn as_any(&self) -> &dyn Any {
182        self
183    }
184    fn name(&self) -> &str {
185        "decode"
186    }
187
188    fn signature(&self) -> &Signature {
189        &self.signature
190    }
191
192    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
193        Ok(arg_types[0].to_owned())
194    }
195
196    fn invoke_with_args(
197        &self,
198        args: datafusion_expr::ScalarFunctionArgs,
199    ) -> Result<ColumnarValue> {
200        decode(&args.args)
201    }
202
203    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
204        if arg_types.len() != 2 {
205            return plan_err!(
206                "{} expects to get 2 arguments, but got {}",
207                self.name(),
208                arg_types.len()
209            );
210        }
211
212        if arg_types[1] != DataType::Utf8 {
213            return plan_err!("2nd argument should be Utf8");
214        }
215
216        match arg_types[0] {
217            DataType::Utf8 | DataType::Utf8View | DataType::Binary | DataType::Null => {
218                Ok(vec![DataType::Binary, DataType::Utf8])
219            }
220            DataType::LargeUtf8 | DataType::LargeBinary => {
221                Ok(vec![DataType::LargeBinary, DataType::Utf8])
222            }
223            _ => plan_err!(
224                "1st argument should be Utf8 or Binary or Null, got {:?}",
225                arg_types[0]
226            ),
227        }
228    }
229
230    fn documentation(&self) -> Option<&Documentation> {
231        self.doc()
232    }
233}
234
/// Textual encodings supported by the `encode` and `decode` functions.
#[derive(Clone, Copy, Debug)]
enum Encoding {
    /// Base64 using the module's shared `BASE64_ENGINE`.
    Base64,
    /// Hexadecimal via the `hex` crate.
    Hex,
}
240
241fn encode_process(value: &ColumnarValue, encoding: Encoding) -> Result<ColumnarValue> {
242    match value {
243        ColumnarValue::Array(a) => match a.data_type() {
244            DataType::Utf8 => encoding.encode_utf8_array::<i32>(a.as_ref()),
245            DataType::LargeUtf8 => encoding.encode_utf8_array::<i64>(a.as_ref()),
246            DataType::Utf8View => encoding.encode_utf8_array::<i32>(a.as_ref()),
247            DataType::Binary => encoding.encode_binary_array::<i32>(a.as_ref()),
248            DataType::LargeBinary => encoding.encode_binary_array::<i64>(a.as_ref()),
249            other => exec_err!(
250                "Unsupported data type {other:?} for function encode({encoding})"
251            ),
252        },
253        ColumnarValue::Scalar(scalar) => {
254            match scalar {
255                ScalarValue::Utf8(a) => {
256                    Ok(encoding.encode_scalar(a.as_ref().map(|s: &String| s.as_bytes())))
257                }
258                ScalarValue::LargeUtf8(a) => Ok(encoding
259                    .encode_large_scalar(a.as_ref().map(|s: &String| s.as_bytes()))),
260                ScalarValue::Utf8View(a) => {
261                    Ok(encoding.encode_scalar(a.as_ref().map(|s: &String| s.as_bytes())))
262                }
263                ScalarValue::Binary(a) => Ok(
264                    encoding.encode_scalar(a.as_ref().map(|v: &Vec<u8>| v.as_slice()))
265                ),
266                ScalarValue::LargeBinary(a) => Ok(encoding
267                    .encode_large_scalar(a.as_ref().map(|v: &Vec<u8>| v.as_slice()))),
268                other => exec_err!(
269                    "Unsupported data type {other:?} for function encode({encoding})"
270                ),
271            }
272        }
273    }
274}
275
276fn decode_process(value: &ColumnarValue, encoding: Encoding) -> Result<ColumnarValue> {
277    match value {
278        ColumnarValue::Array(a) => match a.data_type() {
279            DataType::Utf8 => encoding.decode_utf8_array::<i32>(a.as_ref()),
280            DataType::LargeUtf8 => encoding.decode_utf8_array::<i64>(a.as_ref()),
281            DataType::Utf8View => encoding.decode_utf8_array::<i32>(a.as_ref()),
282            DataType::Binary => encoding.decode_binary_array::<i32>(a.as_ref()),
283            DataType::LargeBinary => encoding.decode_binary_array::<i64>(a.as_ref()),
284            other => exec_err!(
285                "Unsupported data type {other:?} for function decode({encoding})"
286            ),
287        },
288        ColumnarValue::Scalar(scalar) => {
289            match scalar {
290                ScalarValue::Utf8(a) => {
291                    encoding.decode_scalar(a.as_ref().map(|s: &String| s.as_bytes()))
292                }
293                ScalarValue::LargeUtf8(a) => encoding
294                    .decode_large_scalar(a.as_ref().map(|s: &String| s.as_bytes())),
295                ScalarValue::Utf8View(a) => {
296                    encoding.decode_scalar(a.as_ref().map(|s: &String| s.as_bytes()))
297                }
298                ScalarValue::Binary(a) => {
299                    encoding.decode_scalar(a.as_ref().map(|v: &Vec<u8>| v.as_slice()))
300                }
301                ScalarValue::LargeBinary(a) => encoding
302                    .decode_large_scalar(a.as_ref().map(|v: &Vec<u8>| v.as_slice())),
303                other => exec_err!(
304                    "Unsupported data type {other:?} for function decode({encoding})"
305                ),
306            }
307        }
308    }
309}
310
311fn hex_encode(input: &[u8]) -> String {
312    hex::encode(input)
313}
314
315fn base64_encode(input: &[u8]) -> String {
316    BASE64_ENGINE.encode(input)
317}
318
319fn hex_decode(input: &[u8], buf: &mut [u8]) -> Result<usize> {
320    // only write input / 2 bytes to buf
321    let out_len = input.len() / 2;
322    let buf = &mut buf[..out_len];
323    hex::decode_to_slice(input, buf)
324        .map_err(|e| internal_datafusion_err!("Failed to decode from hex: {e}"))?;
325    Ok(out_len)
326}
327
328fn base64_decode(input: &[u8], buf: &mut [u8]) -> Result<usize> {
329    BASE64_ENGINE
330        .decode_slice(input, buf)
331        .map_err(|e| internal_datafusion_err!("Failed to decode from base64: {e}"))
332}
333
/// Encodes every non-null element of a byte or string array with `$METHOD`
/// (a `fn(&[u8]) -> String`) and evaluates to an `ArrayRef` of strings.
///
/// `$INPUT` must be a `&GenericByteArray<_>`. The resulting string array uses
/// the same offset width as the input, so `LargeUtf8` / `LargeBinary` inputs
/// yield a `LargeUtf8` array and 32-bit inputs yield a `Utf8` array. Collecting
/// unconditionally into a `StringArray` would return `Utf8` for large inputs,
/// which disagrees with the type reported by `EncodeFunc::return_type`.
macro_rules! encode_to_array {
    ($METHOD: ident, $INPUT:expr) => {{
        // A local generic function lets the output offset type be taken from
        // the input array's `ByteArrayType` without changing the call sites.
        fn collect_encoded<B: ByteArrayType>(input: &GenericByteArray<B>) -> ArrayRef {
            let encoded: arrow::array::GenericStringArray<B::Offset> = input
                .iter()
                .map(|item| item.map(|bytes| $METHOD(AsRef::<[u8]>::as_ref(bytes))))
                .collect();
            Arc::new(encoded)
        }
        collect_encoded($INPUT)
    }};
}
343
344fn decode_to_array<F, T: ByteArrayType>(
345    method: F,
346    input: &GenericByteArray<T>,
347    conservative_upper_bound_size: usize,
348) -> Result<ArrayRef>
349where
350    F: Fn(&[u8], &mut [u8]) -> Result<usize>,
351{
352    let mut values = vec![0; conservative_upper_bound_size];
353    let mut offsets = OffsetBufferBuilder::new(input.len());
354    let mut total_bytes_decoded = 0;
355    for v in input {
356        if let Some(v) = v {
357            let cursor = &mut values[total_bytes_decoded..];
358            let decoded = method(v.as_ref(), cursor)?;
359            total_bytes_decoded += decoded;
360            offsets.push_length(decoded);
361        } else {
362            offsets.push_length(0);
363        }
364    }
365    // We reserved an upper bound size for the values buffer, but we only use the actual size
366    values.truncate(total_bytes_decoded);
367    let binary_array = BinaryArray::try_new(
368        offsets.finish(),
369        Buffer::from_vec(values),
370        input.nulls().cloned(),
371    )?;
372    Ok(Arc::new(binary_array))
373}
374
375impl Encoding {
376    fn encode_scalar(self, value: Option<&[u8]>) -> ColumnarValue {
377        ColumnarValue::Scalar(match self {
378            Self::Base64 => ScalarValue::Utf8(value.map(|v| BASE64_ENGINE.encode(v))),
379            Self::Hex => ScalarValue::Utf8(value.map(hex::encode)),
380        })
381    }
382
383    fn encode_large_scalar(self, value: Option<&[u8]>) -> ColumnarValue {
384        ColumnarValue::Scalar(match self {
385            Self::Base64 => {
386                ScalarValue::LargeUtf8(value.map(|v| BASE64_ENGINE.encode(v)))
387            }
388            Self::Hex => ScalarValue::LargeUtf8(value.map(hex::encode)),
389        })
390    }
391
392    fn encode_binary_array<T>(self, value: &dyn Array) -> Result<ColumnarValue>
393    where
394        T: OffsetSizeTrait,
395    {
396        let input_value = as_generic_binary_array::<T>(value)?;
397        let array: ArrayRef = match self {
398            Self::Base64 => encode_to_array!(base64_encode, input_value),
399            Self::Hex => encode_to_array!(hex_encode, input_value),
400        };
401        Ok(ColumnarValue::Array(array))
402    }
403
404    fn encode_utf8_array<T>(self, value: &dyn Array) -> Result<ColumnarValue>
405    where
406        T: OffsetSizeTrait,
407    {
408        let input_value = as_generic_string_array::<T>(value)?;
409        let array: ArrayRef = match self {
410            Self::Base64 => encode_to_array!(base64_encode, input_value),
411            Self::Hex => encode_to_array!(hex_encode, input_value),
412        };
413        Ok(ColumnarValue::Array(array))
414    }
415
416    fn decode_scalar(self, value: Option<&[u8]>) -> Result<ColumnarValue> {
417        let value = match value {
418            Some(value) => value,
419            None => return Ok(ColumnarValue::Scalar(ScalarValue::Binary(None))),
420        };
421
422        let out = match self {
423            Self::Base64 => BASE64_ENGINE.decode(value).map_err(|e| {
424                internal_datafusion_err!("Failed to decode value using base64: {e}")
425            })?,
426            Self::Hex => hex::decode(value).map_err(|e| {
427                internal_datafusion_err!("Failed to decode value using hex: {e}")
428            })?,
429        };
430
431        Ok(ColumnarValue::Scalar(ScalarValue::Binary(Some(out))))
432    }
433
434    fn decode_large_scalar(self, value: Option<&[u8]>) -> Result<ColumnarValue> {
435        let value = match value {
436            Some(value) => value,
437            None => return Ok(ColumnarValue::Scalar(ScalarValue::LargeBinary(None))),
438        };
439
440        let out = match self {
441            Self::Base64 => BASE64_ENGINE.decode(value).map_err(|e| {
442                internal_datafusion_err!("Failed to decode value using base64: {e}")
443            })?,
444            Self::Hex => hex::decode(value).map_err(|e| {
445                internal_datafusion_err!("Failed to decode value using hex: {e}")
446            })?,
447        };
448
449        Ok(ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(out))))
450    }
451
452    fn decode_binary_array<T>(self, value: &dyn Array) -> Result<ColumnarValue>
453    where
454        T: OffsetSizeTrait,
455    {
456        let input_value = as_generic_binary_array::<T>(value)?;
457        let array = self.decode_byte_array(input_value)?;
458        Ok(ColumnarValue::Array(array))
459    }
460
461    fn decode_utf8_array<T>(self, value: &dyn Array) -> Result<ColumnarValue>
462    where
463        T: OffsetSizeTrait,
464    {
465        let input_value = as_generic_string_array::<T>(value)?;
466        let array = self.decode_byte_array(input_value)?;
467        Ok(ColumnarValue::Array(array))
468    }
469
470    fn decode_byte_array<T: ByteArrayType>(
471        &self,
472        input_value: &GenericByteArray<T>,
473    ) -> Result<ArrayRef> {
474        match self {
475            Self::Base64 => {
476                let upper_bound =
477                    base64::decoded_len_estimate(input_value.values().len());
478                decode_to_array(base64_decode, input_value, upper_bound)
479            }
480            Self::Hex => {
481                // Calculate the upper bound for decoded byte size
482                // For hex encoding, each pair of hex characters (2 bytes) represents 1 byte when decoded
483                // So the upper bound is half the length of the input values.
484                let upper_bound = input_value.values().len() / 2;
485                decode_to_array(hex_decode, input_value, upper_bound)
486            }
487        }
488    }
489}
490
491impl fmt::Display for Encoding {
492    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
493        write!(f, "{}", format!("{self:?}").to_lowercase())
494    }
495}
496
497impl FromStr for Encoding {
498    type Err = DataFusionError;
499    fn from_str(name: &str) -> Result<Encoding> {
500        Ok(match name {
501            "base64" => Self::Base64,
502            "hex" => Self::Hex,
503            _ => {
504                let options = [Self::Base64, Self::Hex]
505                    .iter()
506                    .map(|i| i.to_string())
507                    .collect::<Vec<_>>()
508                    .join(", ");
509                return plan_err!(
510                    "There is no built-in encoding named '{name}', currently supported encodings are: {options}"
511                );
512            }
513        })
514    }
515}
516
517/// Encodes the given data, accepts Binary, LargeBinary, Utf8, Utf8View or LargeUtf8 and returns a [`ColumnarValue`].
518/// Second argument is the encoding to use.
519/// Standard encodings are base64 and hex.
520fn encode(args: &[ColumnarValue]) -> Result<ColumnarValue> {
521    let [expression, format] = take_function_args("encode", args)?;
522
523    let encoding = match format {
524        ColumnarValue::Scalar(scalar) => match scalar.try_as_str() {
525            Some(Some(method)) => method.parse::<Encoding>(),
526            _ => not_impl_err!(
527                "Second argument to encode must be non null constant string: Encode using dynamically decided method is not yet supported. Got {scalar:?}"
528            ),
529        },
530        ColumnarValue::Array(_) => not_impl_err!(
531            "Second argument to encode must be a constant: Encode using dynamically decided method is not yet supported"
532        ),
533    }?;
534    encode_process(expression, encoding)
535}
536
537/// Decodes the given data, accepts Binary, LargeBinary, Utf8, Utf8View or LargeUtf8 and returns a [`ColumnarValue`].
538/// Second argument is the encoding to use.
539/// Standard encodings are base64 and hex.
540fn decode(args: &[ColumnarValue]) -> Result<ColumnarValue> {
541    let [expression, format] = take_function_args("decode", args)?;
542
543    let encoding = match format {
544        ColumnarValue::Scalar(scalar) => match scalar.try_as_str() {
545            Some(Some(method))=> method.parse::<Encoding>(),
546            _ => not_impl_err!(
547                "Second argument to decode must be a non null constant string: Decode using dynamically decided method is not yet supported. Got {scalar:?}"
548            ),
549        },
550        ColumnarValue::Array(_) => not_impl_err!(
551            "Second argument to decode must be a utf8 constant: Decode using dynamically decided method is not yet supported"
552        ),
553    }?;
554    decode_process(expression, encoding)
555}