datafusion_spark/function/bitwise/
bit_get.rs1use std::mem::size_of;
19use std::sync::Arc;
20
21use arrow::array::{
22 Array, ArrayRef, ArrowPrimitiveType, AsArray, Int8Array, Int32Array, PrimitiveArray,
23 downcast_integer_array,
24};
25use arrow::compute::try_binary;
26use arrow::datatypes::{ArrowNativeType, DataType, Field, FieldRef, Int8Type, Int32Type};
27use datafusion_common::types::{NativeType, logical_int32};
28use datafusion_common::utils::take_function_args;
29use datafusion_common::{Result, internal_err};
30use datafusion_expr::{
31 Coercion, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl,
32 Signature, TypeSignatureClass, Volatility,
33};
34use datafusion_functions::utils::make_scalar_function;
35
36#[derive(Debug, PartialEq, Eq, Hash)]
37pub struct SparkBitGet {
38 signature: Signature,
39 aliases: Vec<String>,
40}
41
42impl Default for SparkBitGet {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48impl SparkBitGet {
49 pub fn new() -> Self {
50 Self {
51 signature: Signature::coercible(
52 vec![
53 Coercion::new_exact(TypeSignatureClass::Integer),
54 Coercion::new_implicit(
55 TypeSignatureClass::Native(logical_int32()),
56 vec![TypeSignatureClass::Integer],
57 NativeType::Int32,
58 ),
59 ],
60 Volatility::Immutable,
61 ),
62 aliases: vec!["getbit".to_string()],
63 }
64 }
65}
66
67impl ScalarUDFImpl for SparkBitGet {
68 fn name(&self) -> &str {
69 "bit_get"
70 }
71
72 fn aliases(&self) -> &[String] {
73 &self.aliases
74 }
75
76 fn signature(&self) -> &Signature {
77 &self.signature
78 }
79
80 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
81 internal_err!("return_field_from_args should be used instead")
82 }
83
84 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
85 let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
87 Ok(Arc::new(Field::new(self.name(), DataType::Int8, nullable)))
88 }
89
90 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
91 make_scalar_function(spark_bit_get, vec![])(&args.args)
92 }
93}
94
95fn spark_bit_get_inner<T: ArrowPrimitiveType>(
96 value: &PrimitiveArray<T>,
97 pos: &Int32Array,
98) -> Result<Int8Array> {
99 let bit_length = (size_of::<T::Native>() * 8) as i32;
100
101 let result: PrimitiveArray<Int8Type> = try_binary(value, pos, |value, pos| {
102 if pos < 0 || pos >= bit_length {
103 return Err(arrow::error::ArrowError::ComputeError(format!(
104 "bit_get: position {pos} is out of bounds. Expected pos < {bit_length} and pos >= 0"
105 )));
106 }
107 Ok(((value.to_i64().unwrap() >> pos) & 1) as i8)
108 })?;
109 Ok(result)
110}
111
112fn spark_bit_get(args: &[ArrayRef]) -> Result<ArrayRef> {
113 let [value, position] = take_function_args("bit_get", args)?;
114 let pos_arg = position.as_primitive::<Int32Type>();
115 let ret = downcast_integer_array!(
116 value => spark_bit_get_inner(value, pos_arg),
117 DataType::Null => Ok(Int8Array::new_null(value.len())),
118 d => internal_err!("Unsupported datatype for bit_get: {d}"),
119 )?;
120 Ok(Arc::new(ret))
121}
122
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `Int32` `value` / `pos` argument fields with the requested
    /// nullability and resolves the `bit_get` output field for them.
    fn resolve_output_field(value_nullable: bool, pos_nullable: bool) -> FieldRef {
        let arg_fields = [
            Arc::new(Field::new("value", DataType::Int32, value_nullable)),
            Arc::new(Field::new("pos", DataType::Int32, pos_nullable)),
        ];
        SparkBitGet::new()
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &arg_fields,
                scalar_arguments: &[None, None],
            })
            .unwrap()
    }

    #[test]
    fn test_bit_get_nullability_non_nullable_inputs() {
        let out_field = resolve_output_field(false, false);

        assert_eq!(out_field.data_type(), &DataType::Int8);
        assert!(!out_field.is_nullable());
    }

    #[test]
    fn test_bit_get_nullability_nullable_inputs() {
        // Only the value is nullable; that alone must make the output nullable.
        let out_field = resolve_output_field(true, false);

        assert_eq!(out_field.data_type(), &DataType::Int8);
        assert!(out_field.is_nullable());
    }
}