1use std::convert::identity;
2
3use arrow::array::{Array, BooleanArray, PrimitiveArray};
4use arrow::bitmap::{binary_fold, intersects_with};
5use arrow::datatypes::ArrowDataType;
6use arrow::legacy::utils::CustomIterTools;
7use polars_utils::float16::pf16;
8
9pub trait BitwiseKernel {
10 type Scalar;
11
12 fn count_ones(&self) -> PrimitiveArray<u32>;
13 fn count_zeros(&self) -> PrimitiveArray<u32>;
14
15 fn leading_ones(&self) -> PrimitiveArray<u32>;
16 fn leading_zeros(&self) -> PrimitiveArray<u32>;
17
18 fn trailing_ones(&self) -> PrimitiveArray<u32>;
19 fn trailing_zeros(&self) -> PrimitiveArray<u32>;
20
21 fn reduce_and(&self) -> Option<Self::Scalar>;
22 fn reduce_or(&self) -> Option<Self::Scalar>;
23 fn reduce_xor(&self) -> Option<Self::Scalar>;
24
25 fn bit_and(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar;
26 fn bit_or(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar;
27 fn bit_xor(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar;
28}
29
// Implements `BitwiseKernel` for `PrimitiveArray<$T>` once per
// `($T, $to_bits, $from_bits)` tuple.
//
// `$to_bits` maps a `$T` to an integer whose bit methods (`count_ones`,
// `leading_zeros`, `&`, `|`, `^`, ...) do the work, and `$from_bits` maps such
// an integer back to `$T`. Integer types pass `identity` for both; float types
// pass their `to_bits` / `from_bits` so the kernels act on the raw bit
// representation instead of the numeric value.
macro_rules! impl_bitwise_kernel {
    ($(($T:ty, $to_bits:expr, $from_bits:expr)),+ $(,)?) => {
        $(
            impl BitwiseKernel for PrimitiveArray<$T> {
                type Scalar = $T;

                // The per-element kernels below map every value slot (null
                // slots included, since masked values are still readable) and
                // then reuse the input validity, so null positions stay null.

                #[inline(never)]
                fn count_ones(&self) -> PrimitiveArray<u32> {
                    PrimitiveArray::new(
                        ArrowDataType::UInt32,
                        self.values_iter()
                            .map(|&v| $to_bits(v).count_ones())
                            .collect_trusted::<Vec<_>>()
                            .into(),
                        self.validity().cloned(),
                    )
                }

                #[inline(never)]
                fn count_zeros(&self) -> PrimitiveArray<u32> {
                    PrimitiveArray::new(
                        ArrowDataType::UInt32,
                        self.values_iter()
                            .map(|&v| $to_bits(v).count_zeros())
                            .collect_trusted::<Vec<_>>()
                            .into(),
                        self.validity().cloned(),
                    )
                }

                #[inline(never)]
                fn leading_ones(&self) -> PrimitiveArray<u32> {
                    PrimitiveArray::new(
                        ArrowDataType::UInt32,
                        self.values_iter()
                            .map(|&v| $to_bits(v).leading_ones())
                            .collect_trusted::<Vec<_>>()
                            .into(),
                        self.validity().cloned(),
                    )
                }

                #[inline(never)]
                fn leading_zeros(&self) -> PrimitiveArray<u32> {
                    PrimitiveArray::new(
                        ArrowDataType::UInt32,
                        self.values_iter()
                            .map(|&v| $to_bits(v).leading_zeros())
                            .collect_trusted::<Vec<_>>()
                            .into(),
                        self.validity().cloned(),
                    )
                }

                #[inline(never)]
                fn trailing_ones(&self) -> PrimitiveArray<u32> {
                    PrimitiveArray::new(
                        ArrowDataType::UInt32,
                        self.values_iter()
                            .map(|&v| $to_bits(v).trailing_ones())
                            .collect_trusted::<Vec<_>>()
                            .into(),
                        self.validity().cloned(),
                    )
                }

                #[inline(never)]
                fn trailing_zeros(&self) -> PrimitiveArray<u32> {
                    PrimitiveArray::new(
                        ArrowDataType::UInt32,
                        // Iterate with `values_iter()` like every sibling kernel.
                        self.values_iter()
                            .map(|&v| $to_bits(v).trailing_zeros())
                            .collect_trusted::<Vec<_>>()
                            .into(),
                        self.validity().cloned(),
                    )
                }

                // The reductions skip null slots. `Iterator::reduce` yields
                // `None` for an empty iterator, so an empty or all-null array
                // reduces to `None`. Without nulls, the cheaper plain value
                // iterator is used.

                #[inline(never)]
                fn reduce_and(&self) -> Option<Self::Scalar> {
                    if !self.has_nulls() {
                        self.values_iter().copied().map($to_bits).reduce(|a, b| a & b).map($from_bits)
                    } else {
                        self.non_null_values_iter().map($to_bits).reduce(|a, b| a & b).map($from_bits)
                    }
                }

                #[inline(never)]
                fn reduce_or(&self) -> Option<Self::Scalar> {
                    if !self.has_nulls() {
                        self.values_iter().copied().map($to_bits).reduce(|a, b| a | b).map($from_bits)
                    } else {
                        self.non_null_values_iter().map($to_bits).reduce(|a, b| a | b).map($from_bits)
                    }
                }

                #[inline(never)]
                fn reduce_xor(&self) -> Option<Self::Scalar> {
                    if !self.has_nulls() {
                        self.values_iter().copied().map($to_bits).reduce(|a, b| a ^ b).map($from_bits)
                    } else {
                        self.non_null_values_iter().map($to_bits).reduce(|a, b| a ^ b).map($from_bits)
                    }
                }

                // Scalar operators: round-trip through the bit representation
                // so float types combine their raw bits.

                fn bit_and(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
                    $from_bits($to_bits(lhs) & $to_bits(rhs))
                }

                fn bit_or(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
                    $from_bits($to_bits(lhs) | $to_bits(rhs))
                }

                fn bit_xor(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
                    $from_bits($to_bits(lhs) ^ $to_bits(rhs))
                }
            }
        )+
    };
}
148
149impl_bitwise_kernel! {
150 (i8, identity, identity),
151 (i16, identity, identity),
152 (i32, identity, identity),
153 (i64, identity, identity),
154 (u8, identity, identity),
155 (u16, identity, identity),
156 (u32, identity, identity),
157 (u64, identity, identity),
158 (pf16, pf16::to_bits, pf16::from_bits),
159 (f32, f32::to_bits, f32::from_bits),
160 (f64, f64::to_bits, f64::from_bits),
161}
162
// 128-bit unsigned support is only compiled in with the `dtype-u128` feature.
#[cfg(feature = "dtype-u128")]
impl_bitwise_kernel! { (u128, identity, identity) }
167
// 128-bit signed support is only compiled in with the `dtype-i128` feature.
#[cfg(feature = "dtype-i128")]
impl_bitwise_kernel! { (i128, identity, identity) }
172
173impl BitwiseKernel for BooleanArray {
174 type Scalar = bool;
175
176 #[inline(never)]
177 fn count_ones(&self) -> PrimitiveArray<u32> {
178 PrimitiveArray::new(
179 ArrowDataType::UInt32,
180 self.values_iter()
181 .map(u32::from)
182 .collect_trusted::<Vec<_>>()
183 .into(),
184 self.validity().cloned(),
185 )
186 }
187
188 #[inline(never)]
189 fn count_zeros(&self) -> PrimitiveArray<u32> {
190 PrimitiveArray::new(
191 ArrowDataType::UInt32,
192 self.values_iter()
193 .map(|v| u32::from(!v))
194 .collect_trusted::<Vec<_>>()
195 .into(),
196 self.validity().cloned(),
197 )
198 }
199
200 #[inline(always)]
201 fn leading_ones(&self) -> PrimitiveArray<u32> {
202 self.count_ones()
203 }
204
205 #[inline(always)]
206 fn leading_zeros(&self) -> PrimitiveArray<u32> {
207 self.count_zeros()
208 }
209
210 #[inline(always)]
211 fn trailing_ones(&self) -> PrimitiveArray<u32> {
212 self.count_ones()
213 }
214
215 #[inline(always)]
216 fn trailing_zeros(&self) -> PrimitiveArray<u32> {
217 self.count_zeros()
218 }
219
220 fn reduce_and(&self) -> Option<Self::Scalar> {
221 if self.len() == self.null_count() {
222 None
223 } else if !self.has_nulls() {
224 Some(self.values().unset_bits() == 0)
225 } else {
226 let false_found = binary_fold(
227 self.values(),
228 self.validity().unwrap(),
229 |lhs, rhs| (!lhs & rhs) != 0,
230 false,
231 |a, b| a || b,
232 );
233 Some(!false_found)
234 }
235 }
236
237 fn reduce_or(&self) -> Option<Self::Scalar> {
238 if self.len() == self.null_count() {
239 None
240 } else if !self.has_nulls() {
241 Some(self.values().set_bits() > 0)
242 } else {
243 Some(intersects_with(self.values(), self.validity().unwrap()))
244 }
245 }
246
247 fn reduce_xor(&self) -> Option<Self::Scalar> {
248 if self.len() == self.null_count() {
249 None
250 } else if !self.has_nulls() {
251 Some(self.values().set_bits() % 2 == 1)
252 } else {
253 let nonnull_parity = binary_fold(
254 self.values(),
255 self.validity().unwrap(),
256 |lhs, rhs| lhs & rhs,
257 0,
258 |a, b| a ^ b,
259 );
260 Some(nonnull_parity.count_ones() % 2 == 1)
261 }
262 }
263
264 fn bit_and(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
265 lhs & rhs
266 }
267 fn bit_or(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
268 lhs | rhs
269 }
270 fn bit_xor(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
271 lhs ^ rhs
272 }
273}