polars_compute/bitwise/mod.rs

use std::convert::identity;

use arrow::array::{Array, BooleanArray, PrimitiveArray};
use arrow::bitmap::{binary_fold, intersects_with};
use arrow::datatypes::ArrowDataType;
use arrow::legacy::utils::CustomIterTools;
use arrow::types::NativeType;
7
8pub trait BitwiseKernel {
9    type Scalar;
10
11    fn count_ones(&self) -> PrimitiveArray<u32>;
12    fn count_zeros(&self) -> PrimitiveArray<u32>;
13
14    fn leading_ones(&self) -> PrimitiveArray<u32>;
15    fn leading_zeros(&self) -> PrimitiveArray<u32>;
16
17    fn trailing_ones(&self) -> PrimitiveArray<u32>;
18    fn trailing_zeros(&self) -> PrimitiveArray<u32>;
19
20    fn reduce_and(&self) -> Option<Self::Scalar>;
21    fn reduce_or(&self) -> Option<Self::Scalar>;
22    fn reduce_xor(&self) -> Option<Self::Scalar>;
23
24    fn bit_and(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar;
25    fn bit_or(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar;
26    fn bit_xor(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar;
27}
28
29macro_rules! impl_bitwise_kernel {
30    ($(($T:ty, $to_bits:expr, $from_bits:expr)),+ $(,)?) => {
31        $(
32        impl BitwiseKernel for PrimitiveArray<$T> {
33            type Scalar = $T;
34
35            #[inline(never)]
36            fn count_ones(&self) -> PrimitiveArray<u32> {
37                PrimitiveArray::new(
38                    ArrowDataType::UInt32,
39                    self.values_iter()
40                        .map(|&v| $to_bits(v).count_ones())
41                        .collect_trusted::<Vec<_>>()
42                        .into(),
43                    self.validity().cloned(),
44                )
45            }
46
47            #[inline(never)]
48            fn count_zeros(&self) -> PrimitiveArray<u32> {
49                PrimitiveArray::new(
50                    ArrowDataType::UInt32,
51                    self.values_iter()
52                        .map(|&v| $to_bits(v).count_zeros())
53                        .collect_trusted::<Vec<_>>()
54                        .into(),
55                    self.validity().cloned(),
56                )
57            }
58
59            #[inline(never)]
60            fn leading_ones(&self) -> PrimitiveArray<u32> {
61                PrimitiveArray::new(
62                    ArrowDataType::UInt32,
63                    self.values_iter()
64                        .map(|&v| $to_bits(v).leading_ones())
65                        .collect_trusted::<Vec<_>>()
66                        .into(),
67                    self.validity().cloned(),
68                )
69            }
70
71            #[inline(never)]
72            fn leading_zeros(&self) -> PrimitiveArray<u32> {
73                PrimitiveArray::new(
74                    ArrowDataType::UInt32,
75                    self.values_iter()
76                        .map(|&v| $to_bits(v).leading_zeros())
77                        .collect_trusted::<Vec<_>>()
78                        .into(),
79                    self.validity().cloned(),
80                )
81            }
82
83            #[inline(never)]
84            fn trailing_ones(&self) -> PrimitiveArray<u32> {
85                PrimitiveArray::new(
86                    ArrowDataType::UInt32,
87                    self.values_iter()
88                        .map(|&v| $to_bits(v).trailing_ones())
89                        .collect_trusted::<Vec<_>>()
90                        .into(),
91                    self.validity().cloned(),
92                )
93            }
94
95            #[inline(never)]
96            fn trailing_zeros(&self) -> PrimitiveArray<u32> {
97                PrimitiveArray::new(
98                    ArrowDataType::UInt32,
99                    self.values().iter()
100                        .map(|&v| $to_bits(v).trailing_zeros())
101                        .collect_trusted::<Vec<_>>()
102                        .into(),
103                    self.validity().cloned(),
104                )
105            }
106
107            #[inline(never)]
108            fn reduce_and(&self) -> Option<Self::Scalar> {
109                if !self.has_nulls() {
110                    self.values_iter().copied().map($to_bits).reduce(|a, b| a & b).map($from_bits)
111                } else {
112                    self.non_null_values_iter().map($to_bits).reduce(|a, b| a & b).map($from_bits)
113                }
114            }
115
116            #[inline(never)]
117            fn reduce_or(&self) -> Option<Self::Scalar> {
118                if !self.has_nulls() {
119                    self.values_iter().copied().map($to_bits).reduce(|a, b| a | b).map($from_bits)
120                } else {
121                    self.non_null_values_iter().map($to_bits).reduce(|a, b| a | b).map($from_bits)
122                }
123            }
124
125            #[inline(never)]
126            fn reduce_xor(&self) -> Option<Self::Scalar> {
127                if !self.has_nulls() {
128                    self.values_iter().copied().map($to_bits).reduce(|a, b| a ^ b).map($from_bits)
129                } else {
130                    self.non_null_values_iter().map($to_bits).reduce(|a, b| a ^ b).map($from_bits)
131                }
132            }
133
134            fn bit_and(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
135                $from_bits($to_bits(lhs) & $to_bits(rhs))
136            }
137            fn bit_or(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
138                $from_bits($to_bits(lhs) | $to_bits(rhs))
139            }
140            fn bit_xor(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
141                $from_bits($to_bits(lhs) ^ $to_bits(rhs))
142            }
143        }
144        )+
145    };
146}
147
148impl_bitwise_kernel! {
149    (i8, identity, identity),
150    (i16, identity, identity),
151    (i32, identity, identity),
152    (i64, identity, identity),
153    (u8, identity, identity),
154    (u16, identity, identity),
155    (u32, identity, identity),
156    (u64, identity, identity),
157    (f32, f32::to_bits, f32::from_bits),
158    (f64, f64::to_bits, f64::from_bits),
159}
160
161#[cfg(feature = "dtype-u128")]
162impl_bitwise_kernel! {
163    (u128, identity, identity),
164}
165
166#[cfg(feature = "dtype-i128")]
167impl_bitwise_kernel! {
168    (i128, identity, identity),
169}
170
171impl BitwiseKernel for BooleanArray {
172    type Scalar = bool;
173
174    #[inline(never)]
175    fn count_ones(&self) -> PrimitiveArray<u32> {
176        PrimitiveArray::new(
177            ArrowDataType::UInt32,
178            self.values_iter()
179                .map(u32::from)
180                .collect_trusted::<Vec<_>>()
181                .into(),
182            self.validity().cloned(),
183        )
184    }
185
186    #[inline(never)]
187    fn count_zeros(&self) -> PrimitiveArray<u32> {
188        PrimitiveArray::new(
189            ArrowDataType::UInt32,
190            self.values_iter()
191                .map(|v| u32::from(!v))
192                .collect_trusted::<Vec<_>>()
193                .into(),
194            self.validity().cloned(),
195        )
196    }
197
198    #[inline(always)]
199    fn leading_ones(&self) -> PrimitiveArray<u32> {
200        self.count_ones()
201    }
202
203    #[inline(always)]
204    fn leading_zeros(&self) -> PrimitiveArray<u32> {
205        self.count_zeros()
206    }
207
208    #[inline(always)]
209    fn trailing_ones(&self) -> PrimitiveArray<u32> {
210        self.count_ones()
211    }
212
213    #[inline(always)]
214    fn trailing_zeros(&self) -> PrimitiveArray<u32> {
215        self.count_zeros()
216    }
217
218    fn reduce_and(&self) -> Option<Self::Scalar> {
219        if self.len() == self.null_count() {
220            None
221        } else if !self.has_nulls() {
222            Some(self.values().unset_bits() == 0)
223        } else {
224            let false_found = binary_fold(
225                self.values(),
226                self.validity().unwrap(),
227                |lhs, rhs| (!lhs & rhs) != 0,
228                false,
229                |a, b| a || b,
230            );
231            Some(!false_found)
232        }
233    }
234
235    fn reduce_or(&self) -> Option<Self::Scalar> {
236        if self.len() == self.null_count() {
237            None
238        } else if !self.has_nulls() {
239            Some(self.values().set_bits() > 0)
240        } else {
241            Some(intersects_with(self.values(), self.validity().unwrap()))
242        }
243    }
244
245    fn reduce_xor(&self) -> Option<Self::Scalar> {
246        if self.len() == self.null_count() {
247            None
248        } else if !self.has_nulls() {
249            Some(self.values().set_bits() % 2 == 1)
250        } else {
251            let nonnull_parity = binary_fold(
252                self.values(),
253                self.validity().unwrap(),
254                |lhs, rhs| lhs & rhs,
255                0,
256                |a, b| a ^ b,
257            );
258            Some(nonnull_parity.count_ones() % 2 == 1)
259        }
260    }
261
262    fn bit_and(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
263        lhs & rhs
264    }
265    fn bit_or(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
266        lhs | rhs
267    }
268    fn bit_xor(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
269        lhs ^ rhs
270    }
271}