Skip to main content

datafusion_spark/function/bitwise/
bit_get.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use std::mem::size_of;
use std::sync::Arc;

use arrow::array::{
    Array, ArrayRef, ArrowPrimitiveType, AsArray, Int8Array, Int32Array, PrimitiveArray,
    downcast_integer_array,
};
use arrow::compute::try_binary;
use arrow::datatypes::{
    ArrowNativeType, DataType, Field, FieldRef, Int8Type, Int32Type, Int64Type, UInt64Type,
};
use datafusion_common::types::{NativeType, logical_int32};
use datafusion_common::utils::take_function_args;
use datafusion_common::{Result, internal_err};
use datafusion_expr::{
    Coercion, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl,
    Signature, TypeSignatureClass, Volatility,
};
use datafusion_functions::utils::make_scalar_function;

36#[derive(Debug, PartialEq, Eq, Hash)]
37pub struct SparkBitGet {
38    signature: Signature,
39    aliases: Vec<String>,
40}
41
42impl Default for SparkBitGet {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl SparkBitGet {
49    pub fn new() -> Self {
50        Self {
51            signature: Signature::coercible(
52                vec![
53                    Coercion::new_exact(TypeSignatureClass::Integer),
54                    Coercion::new_implicit(
55                        TypeSignatureClass::Native(logical_int32()),
56                        vec![TypeSignatureClass::Integer],
57                        NativeType::Int32,
58                    ),
59                ],
60                Volatility::Immutable,
61            ),
62            aliases: vec!["getbit".to_string()],
63        }
64    }
65}
66
67impl ScalarUDFImpl for SparkBitGet {
68    fn name(&self) -> &str {
69        "bit_get"
70    }
71
72    fn aliases(&self) -> &[String] {
73        &self.aliases
74    }
75
76    fn signature(&self) -> &Signature {
77        &self.signature
78    }
79
80    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
81        internal_err!("return_field_from_args should be used instead")
82    }
83
84    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
85        // Spark derives nullability for BinaryExpression from its children
86        let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
87        Ok(Arc::new(Field::new(self.name(), DataType::Int8, nullable)))
88    }
89
90    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
91        make_scalar_function(spark_bit_get, vec![])(&args.args)
92    }
93}
94
95fn spark_bit_get_inner<T: ArrowPrimitiveType>(
96    value: &PrimitiveArray<T>,
97    pos: &Int32Array,
98) -> Result<Int8Array> {
99    let bit_length = (size_of::<T::Native>() * 8) as i32;
100
101    let result: PrimitiveArray<Int8Type> = try_binary(value, pos, |value, pos| {
102        if pos < 0 || pos >= bit_length {
103            return Err(arrow::error::ArrowError::ComputeError(format!(
104                "bit_get: position {pos} is out of bounds. Expected pos < {bit_length} and pos >= 0"
105            )));
106        }
107        Ok(((value.to_i64().unwrap() >> pos) & 1) as i8)
108    })?;
109    Ok(result)
110}
111
112fn spark_bit_get(args: &[ArrayRef]) -> Result<ArrayRef> {
113    let [value, position] = take_function_args("bit_get", args)?;
114    let pos_arg = position.as_primitive::<Int32Type>();
115    let ret = downcast_integer_array!(
116        value => spark_bit_get_inner(value, pos_arg),
117        DataType::Null => Ok(Int8Array::new_null(value.len())),
118        d => internal_err!("Unsupported datatype for bit_get: {d}"),
119    )?;
120    Ok(Arc::new(ret))
121}
122
#[cfg(test)]
mod tests {
    use super::*;

    /// Resolves the output field of `bit_get` for two `Int32` arguments whose
    /// nullability is given by the flags.
    fn bit_get_return_field(value_nullable: bool, pos_nullable: bool) -> FieldRef {
        let arg_fields = [
            Arc::new(Field::new("value", DataType::Int32, value_nullable)),
            Arc::new(Field::new("pos", DataType::Int32, pos_nullable)),
        ];
        SparkBitGet::new()
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &arg_fields,
                scalar_arguments: &[None, None],
            })
            .unwrap()
    }

    #[test]
    fn test_bit_get_nullability_non_nullable_inputs() {
        let out_field = bit_get_return_field(false, false);

        assert_eq!(out_field.data_type(), &DataType::Int8);
        assert!(!out_field.is_nullable());
    }

    #[test]
    fn test_bit_get_nullability_nullable_inputs() {
        let out_field = bit_get_return_field(true, false);

        assert_eq!(out_field.data_type(), &DataType::Int8);
        assert!(out_field.is_nullable());
    }
}