Skip to main content

polars_compute/bitwise/
mod.rs

1use std::convert::identity;
2
3use arrow::array::{Array, BooleanArray, PrimitiveArray};
4use arrow::bitmap::{binary_fold, intersects_with};
5use arrow::datatypes::ArrowDataType;
6use arrow::legacy::utils::CustomIterTools;
7use polars_utils::float16::pf16;
8
9pub trait BitwiseKernel {
10    type Scalar;
11
12    fn count_ones(&self) -> PrimitiveArray<u32>;
13    fn count_zeros(&self) -> PrimitiveArray<u32>;
14
15    fn leading_ones(&self) -> PrimitiveArray<u32>;
16    fn leading_zeros(&self) -> PrimitiveArray<u32>;
17
18    fn trailing_ones(&self) -> PrimitiveArray<u32>;
19    fn trailing_zeros(&self) -> PrimitiveArray<u32>;
20
21    fn reduce_and(&self) -> Option<Self::Scalar>;
22    fn reduce_or(&self) -> Option<Self::Scalar>;
23    fn reduce_xor(&self) -> Option<Self::Scalar>;
24
25    fn bit_and(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar;
26    fn bit_or(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar;
27    fn bit_xor(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar;
28}
29
// Body shared by the six per-element bit-count kernels.
//
// * `$arr`     - the input array expression (`self` at every call site below).
// * `$to_bits` - converts one element into the integer holding its bit pattern.
// * `$method`  - the integer method applied to that pattern (`count_ones`, `leading_zeros`, ...).
//
// Every value slot is mapped, including slots that are masked out by the validity bitmap;
// the input validity is then cloned onto the `UInt32` output unchanged.
macro_rules! bitwise_count_body {
    ($arr:expr, $to_bits:expr, $method:ident) => {
        PrimitiveArray::new(
            ArrowDataType::UInt32,
            $arr.values_iter()
                .map(|&v| $to_bits(v).$method())
                .collect_trusted::<Vec<_>>()
                .into(),
            $arr.validity().cloned(),
        )
    };
}

// Body shared by the three bitwise reductions.
//
// `$op` is the binary operator token (`&`, `|` or `^`) folded over the bit patterns of the
// elements; the folded pattern is converted back with `$from_bits`. Arrays without nulls
// iterate the raw values directly, otherwise only `non_null_values_iter()` is folded. The
// two branches cannot be merged because the iterators have different types. `reduce` yields
// `None` when the chosen iterator is empty.
macro_rules! bitwise_reduce_body {
    ($arr:expr, $to_bits:expr, $from_bits:expr, $op:tt) => {
        if !$arr.has_nulls() {
            $arr.values_iter()
                .copied()
                .map($to_bits)
                .reduce(|a, b| a $op b)
                .map($from_bits)
        } else {
            $arr.non_null_values_iter()
                .map($to_bits)
                .reduce(|a, b| a $op b)
                .map($from_bits)
        }
    };
}

// Implements `BitwiseKernel` for `PrimitiveArray<$T>` for every `($T, $to_bits, $from_bits)`
// triple given.
//
// `$to_bits` maps a `$T` to an integer carrying its bit pattern and `$from_bits` is the
// inverse. All bit operations run on that integer, so types that are not integers themselves
// can be supported by supplying a real conversion pair.
macro_rules! impl_bitwise_kernel {
    ($(($T:ty, $to_bits:expr, $from_bits:expr)),+ $(,)?) => {
        $(
        impl BitwiseKernel for PrimitiveArray<$T> {
            type Scalar = $T;

            #[inline(never)]
            fn count_ones(&self) -> PrimitiveArray<u32> {
                bitwise_count_body!(self, $to_bits, count_ones)
            }

            #[inline(never)]
            fn count_zeros(&self) -> PrimitiveArray<u32> {
                bitwise_count_body!(self, $to_bits, count_zeros)
            }

            #[inline(never)]
            fn leading_ones(&self) -> PrimitiveArray<u32> {
                bitwise_count_body!(self, $to_bits, leading_ones)
            }

            #[inline(never)]
            fn leading_zeros(&self) -> PrimitiveArray<u32> {
                bitwise_count_body!(self, $to_bits, leading_zeros)
            }

            #[inline(never)]
            fn trailing_ones(&self) -> PrimitiveArray<u32> {
                bitwise_count_body!(self, $to_bits, trailing_ones)
            }

            #[inline(never)]
            fn trailing_zeros(&self) -> PrimitiveArray<u32> {
                bitwise_count_body!(self, $to_bits, trailing_zeros)
            }

            #[inline(never)]
            fn reduce_and(&self) -> Option<Self::Scalar> {
                bitwise_reduce_body!(self, $to_bits, $from_bits, &)
            }

            #[inline(never)]
            fn reduce_or(&self) -> Option<Self::Scalar> {
                bitwise_reduce_body!(self, $to_bits, $from_bits, |)
            }

            #[inline(never)]
            fn reduce_xor(&self) -> Option<Self::Scalar> {
                bitwise_reduce_body!(self, $to_bits, $from_bits, ^)
            }

            fn bit_and(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
                $from_bits($to_bits(lhs) & $to_bits(rhs))
            }
            fn bit_or(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
                $from_bits($to_bits(lhs) | $to_bits(rhs))
            }
            fn bit_xor(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
                $from_bits($to_bits(lhs) ^ $to_bits(rhs))
            }
        }
        )+
    };
}
148
149impl_bitwise_kernel! {
150    (i8, identity, identity),
151    (i16, identity, identity),
152    (i32, identity, identity),
153    (i64, identity, identity),
154    (u8, identity, identity),
155    (u16, identity, identity),
156    (u32, identity, identity),
157    (u64, identity, identity),
158    (pf16, pf16::to_bits, pf16::from_bits),
159    (f32, f32::to_bits, f32::from_bits),
160    (f64, f64::to_bits, f64::from_bits),
161}
162
163#[cfg(feature = "dtype-u128")]
164impl_bitwise_kernel! {
165    (u128, identity, identity),
166}
167
168#[cfg(feature = "dtype-i128")]
169impl_bitwise_kernel! {
170    (i128, identity, identity),
171}
172
173impl BitwiseKernel for BooleanArray {
174    type Scalar = bool;
175
176    #[inline(never)]
177    fn count_ones(&self) -> PrimitiveArray<u32> {
178        PrimitiveArray::new(
179            ArrowDataType::UInt32,
180            self.values_iter()
181                .map(u32::from)
182                .collect_trusted::<Vec<_>>()
183                .into(),
184            self.validity().cloned(),
185        )
186    }
187
188    #[inline(never)]
189    fn count_zeros(&self) -> PrimitiveArray<u32> {
190        PrimitiveArray::new(
191            ArrowDataType::UInt32,
192            self.values_iter()
193                .map(|v| u32::from(!v))
194                .collect_trusted::<Vec<_>>()
195                .into(),
196            self.validity().cloned(),
197        )
198    }
199
200    #[inline(always)]
201    fn leading_ones(&self) -> PrimitiveArray<u32> {
202        self.count_ones()
203    }
204
205    #[inline(always)]
206    fn leading_zeros(&self) -> PrimitiveArray<u32> {
207        self.count_zeros()
208    }
209
210    #[inline(always)]
211    fn trailing_ones(&self) -> PrimitiveArray<u32> {
212        self.count_ones()
213    }
214
215    #[inline(always)]
216    fn trailing_zeros(&self) -> PrimitiveArray<u32> {
217        self.count_zeros()
218    }
219
220    fn reduce_and(&self) -> Option<Self::Scalar> {
221        if self.len() == self.null_count() {
222            None
223        } else if !self.has_nulls() {
224            Some(self.values().unset_bits() == 0)
225        } else {
226            let false_found = binary_fold(
227                self.values(),
228                self.validity().unwrap(),
229                |lhs, rhs| (!lhs & rhs) != 0,
230                false,
231                |a, b| a || b,
232            );
233            Some(!false_found)
234        }
235    }
236
237    fn reduce_or(&self) -> Option<Self::Scalar> {
238        if self.len() == self.null_count() {
239            None
240        } else if !self.has_nulls() {
241            Some(self.values().set_bits() > 0)
242        } else {
243            Some(intersects_with(self.values(), self.validity().unwrap()))
244        }
245    }
246
247    fn reduce_xor(&self) -> Option<Self::Scalar> {
248        if self.len() == self.null_count() {
249            None
250        } else if !self.has_nulls() {
251            Some(self.values().set_bits() % 2 == 1)
252        } else {
253            let nonnull_parity = binary_fold(
254                self.values(),
255                self.validity().unwrap(),
256                |lhs, rhs| lhs & rhs,
257                0,
258                |a, b| a ^ b,
259            );
260            Some(nonnull_parity.count_ones() % 2 == 1)
261        }
262    }
263
264    fn bit_and(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
265        lhs & rhs
266    }
267    fn bit_or(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
268        lhs | rhs
269    }
270    fn bit_xor(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
271        lhs ^ rhs
272    }
273}