1use std::convert::identity;
2
3use arrow::array::{Array, BooleanArray, PrimitiveArray};
4use arrow::bitmap::{binary_fold, intersects_with};
5use arrow::datatypes::ArrowDataType;
6use arrow::legacy::utils::CustomIterTools;
7
8pub trait BitwiseKernel {
9 type Scalar;
10
11 fn count_ones(&self) -> PrimitiveArray<u32>;
12 fn count_zeros(&self) -> PrimitiveArray<u32>;
13
14 fn leading_ones(&self) -> PrimitiveArray<u32>;
15 fn leading_zeros(&self) -> PrimitiveArray<u32>;
16
17 fn trailing_ones(&self) -> PrimitiveArray<u32>;
18 fn trailing_zeros(&self) -> PrimitiveArray<u32>;
19
20 fn reduce_and(&self) -> Option<Self::Scalar>;
21 fn reduce_or(&self) -> Option<Self::Scalar>;
22 fn reduce_xor(&self) -> Option<Self::Scalar>;
23
24 fn bit_and(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar;
25 fn bit_or(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar;
26 fn bit_xor(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar;
27}
28
/// Builds a `UInt32` array by calling the integer method `$method`
/// (e.g. `count_ones`) on the bit representation `$to_bits(v)` of every value
/// of `$arr`. The validity of `$arr` is copied onto the result.
///
/// Values sitting under null slots are still processed; their results are
/// hidden by the copied validity.
macro_rules! bit_count_kernel {
    ($arr:expr, $to_bits:expr, $method:ident) => {
        PrimitiveArray::new(
            ArrowDataType::UInt32,
            $arr.values_iter()
                .map(|&v| $to_bits(v).$method())
                .collect_trusted::<Vec<_>>()
                .into(),
            $arr.validity().cloned(),
        )
    };
}

/// Folds the non-null values of `$arr` with the binary closure `$op`, applied
/// to their bit representation (`$to_bits`), and converts the result back with
/// `$from_bits`. Evaluates to `None` when there are no non-null values.
macro_rules! bit_reduce_kernel {
    ($arr:expr, $to_bits:expr, $from_bits:expr, $op:expr) => {
        if !$arr.has_nulls() {
            // No nulls: every stored value participates, no validity mask needed.
            $arr.values_iter()
                .copied()
                .map($to_bits)
                .reduce($op)
                .map($from_bits)
        } else {
            $arr.non_null_values_iter()
                .map($to_bits)
                .reduce($op)
                .map($from_bits)
        }
    };
}

/// Implements [`BitwiseKernel`] for `PrimitiveArray<$T>` for each
/// `($T, $to_bits, $from_bits)` triple.
///
/// `$to_bits` maps a `$T` to an integer whose bit operations define the
/// kernels, and `$from_bits` maps such an integer back to `$T`. Integer types
/// pass `identity` for both; float types pass their `to_bits`/`from_bits` pair
/// so the kernels act on the raw IEEE-754 bit pattern.
macro_rules! impl_bitwise_kernel {
    ($(($T:ty, $to_bits:expr, $from_bits:expr)),+ $(,)?) => {
        $(
            impl BitwiseKernel for PrimitiveArray<$T> {
                type Scalar = $T;

                #[inline(never)]
                fn count_ones(&self) -> PrimitiveArray<u32> {
                    bit_count_kernel!(self, $to_bits, count_ones)
                }

                #[inline(never)]
                fn count_zeros(&self) -> PrimitiveArray<u32> {
                    bit_count_kernel!(self, $to_bits, count_zeros)
                }

                #[inline(never)]
                fn leading_ones(&self) -> PrimitiveArray<u32> {
                    bit_count_kernel!(self, $to_bits, leading_ones)
                }

                #[inline(never)]
                fn leading_zeros(&self) -> PrimitiveArray<u32> {
                    bit_count_kernel!(self, $to_bits, leading_zeros)
                }

                #[inline(never)]
                fn trailing_ones(&self) -> PrimitiveArray<u32> {
                    bit_count_kernel!(self, $to_bits, trailing_ones)
                }

                #[inline(never)]
                fn trailing_zeros(&self) -> PrimitiveArray<u32> {
                    bit_count_kernel!(self, $to_bits, trailing_zeros)
                }

                #[inline(never)]
                fn reduce_and(&self) -> Option<Self::Scalar> {
                    bit_reduce_kernel!(self, $to_bits, $from_bits, |a, b| a & b)
                }

                #[inline(never)]
                fn reduce_or(&self) -> Option<Self::Scalar> {
                    bit_reduce_kernel!(self, $to_bits, $from_bits, |a, b| a | b)
                }

                #[inline(never)]
                fn reduce_xor(&self) -> Option<Self::Scalar> {
                    bit_reduce_kernel!(self, $to_bits, $from_bits, |a, b| a ^ b)
                }

                fn bit_and(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
                    $from_bits($to_bits(lhs) & $to_bits(rhs))
                }
                fn bit_or(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
                    $from_bits($to_bits(lhs) | $to_bits(rhs))
                }
                fn bit_xor(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
                    $from_bits($to_bits(lhs) ^ $to_bits(rhs))
                }
            }
        )+
    };
}
147
148impl_bitwise_kernel! {
149 (i8, identity, identity),
150 (i16, identity, identity),
151 (i32, identity, identity),
152 (i64, identity, identity),
153 (u8, identity, identity),
154 (u16, identity, identity),
155 (u32, identity, identity),
156 (u64, identity, identity),
157 (f32, f32::to_bits, f32::from_bits),
158 (f64, f64::to_bits, f64::from_bits),
159}
160
161#[cfg(feature = "dtype-u128")]
162impl_bitwise_kernel! {
163 (u128, identity, identity),
164}
165
166#[cfg(feature = "dtype-i128")]
167impl_bitwise_kernel! {
168 (i128, identity, identity),
169}
170
171impl BitwiseKernel for BooleanArray {
172 type Scalar = bool;
173
174 #[inline(never)]
175 fn count_ones(&self) -> PrimitiveArray<u32> {
176 PrimitiveArray::new(
177 ArrowDataType::UInt32,
178 self.values_iter()
179 .map(u32::from)
180 .collect_trusted::<Vec<_>>()
181 .into(),
182 self.validity().cloned(),
183 )
184 }
185
186 #[inline(never)]
187 fn count_zeros(&self) -> PrimitiveArray<u32> {
188 PrimitiveArray::new(
189 ArrowDataType::UInt32,
190 self.values_iter()
191 .map(|v| u32::from(!v))
192 .collect_trusted::<Vec<_>>()
193 .into(),
194 self.validity().cloned(),
195 )
196 }
197
198 #[inline(always)]
199 fn leading_ones(&self) -> PrimitiveArray<u32> {
200 self.count_ones()
201 }
202
203 #[inline(always)]
204 fn leading_zeros(&self) -> PrimitiveArray<u32> {
205 self.count_zeros()
206 }
207
208 #[inline(always)]
209 fn trailing_ones(&self) -> PrimitiveArray<u32> {
210 self.count_ones()
211 }
212
213 #[inline(always)]
214 fn trailing_zeros(&self) -> PrimitiveArray<u32> {
215 self.count_zeros()
216 }
217
218 fn reduce_and(&self) -> Option<Self::Scalar> {
219 if self.len() == self.null_count() {
220 None
221 } else if !self.has_nulls() {
222 Some(self.values().unset_bits() == 0)
223 } else {
224 let false_found = binary_fold(
225 self.values(),
226 self.validity().unwrap(),
227 |lhs, rhs| (!lhs & rhs) != 0,
228 false,
229 |a, b| a || b,
230 );
231 Some(!false_found)
232 }
233 }
234
235 fn reduce_or(&self) -> Option<Self::Scalar> {
236 if self.len() == self.null_count() {
237 None
238 } else if !self.has_nulls() {
239 Some(self.values().set_bits() > 0)
240 } else {
241 Some(intersects_with(self.values(), self.validity().unwrap()))
242 }
243 }
244
245 fn reduce_xor(&self) -> Option<Self::Scalar> {
246 if self.len() == self.null_count() {
247 None
248 } else if !self.has_nulls() {
249 Some(self.values().set_bits() % 2 == 1)
250 } else {
251 let nonnull_parity = binary_fold(
252 self.values(),
253 self.validity().unwrap(),
254 |lhs, rhs| lhs & rhs,
255 0,
256 |a, b| a ^ b,
257 );
258 Some(nonnull_parity.count_ones() % 2 == 1)
259 }
260 }
261
262 fn bit_and(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
263 lhs & rhs
264 }
265 fn bit_or(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
266 lhs | rhs
267 }
268 fn bit_xor(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
269 lhs ^ rhs
270 }
271}