polars_core/frame/group_by/aggregations/boolean.rs

1use arrow::bitmap::bitmask::BitMask;
2
3use super::*;
4use crate::chunked_array::cast::CastOptions;
5
6pub fn _agg_helper_idx_bool<F>(groups: &GroupsIdx, f: F) -> Series
7where
8    F: Fn((IdxSize, &IdxVec)) -> Option<bool> + Send + Sync,
9{
10    let ca: BooleanChunked = POOL.install(|| groups.into_par_iter().map(f).collect());
11    ca.into_series()
12}
13
14pub fn _agg_helper_slice_bool<F>(groups: &[[IdxSize; 2]], f: F) -> Series
15where
16    F: Fn([IdxSize; 2]) -> Option<bool> + Send + Sync,
17{
18    let ca: BooleanChunked = POOL.install(|| groups.par_iter().copied().map(f).collect());
19    ca.into_series()
20}
21
#[cfg(feature = "bitwise")]
impl BooleanChunked {
    /// Per-group bitwise AND that skips nulls; this is `agg_all` with
    /// `ignore_nulls = true`.
    ///
    /// # Safety
    ///
    /// Every index and every `[start, length]` slice in `groups` must be in bounds for
    /// `self`.
    pub(crate) unsafe fn agg_and(&self, groups: &GroupsType) -> BooleanChunked {
        self.agg_all(groups, true)
    }

    /// Per-group bitwise OR that skips nulls; this is `agg_any` with
    /// `ignore_nulls = true`.
    ///
    /// # Safety
    ///
    /// Every index and every `[start, length]` slice in `groups` must be in bounds for
    /// `self`.
    pub(crate) unsafe fn agg_or(&self, groups: &GroupsType) -> BooleanChunked {
        self.agg_any(groups, true)
    }

    /// Per-group bitwise XOR that skips nulls: `true` iff the group holds an odd number
    /// of non-null `true` values. Empty or all-null groups yield `false`.
    ///
    /// Nulls are always ignored, so `bool_agg` never selects the Kleene kernels and they
    /// are left as `unreachable!()`.
    ///
    /// # Safety
    ///
    /// Every index and every `[start, length]` slice in `groups` must be in bounds for
    /// `self`.
    pub(crate) unsafe fn agg_xor(&self, groups: &GroupsType) -> BooleanChunked {
        self.bool_agg(
            groups,
            true,
            // No nulls: parity of the set bits at the group's row indices.
            |values, idxs| {
                idxs.iter().fold(false, |parity, &i| {
                    // SAFETY: the caller guarantees every group index is in bounds.
                    parity ^ unsafe { values.get_bit_unchecked(i as usize) }
                })
            },
            // With nulls: only bits that are both valid and set contribute, and an odd
            // count yields `true`, matching the slice kernel with validity below.
            |values, validity, idxs| {
                idxs.iter().fold(false, |parity, &i| {
                    let i = i as usize;
                    // SAFETY: the caller guarantees every group index is in bounds.
                    let valid_true =
                        unsafe { validity.get_bit_unchecked(i) & values.get_bit_unchecked(i) };
                    parity ^ valid_true
                })
            },
            |_, _, _| unreachable!(),
            |values, start, length| {
                // SAFETY: the caller guarantees `start..start + length` is in bounds.
                let values = unsafe { values.sliced_unchecked(start as usize, length as usize) };
                values.set_bits() % 2 == 1
            },
            |values, validity, start, length| {
                // SAFETY: the caller guarantees `start..start + length` is in bounds.
                let values = unsafe { values.sliced_unchecked(start as usize, length as usize) };
                let validity =
                    unsafe { validity.sliced_unchecked(start as usize, length as usize) };
                values.num_intersections_with(validity) % 2 == 1
            },
            |_, _, _, _| unreachable!(),
        )
    }
}
74
75impl BooleanChunked {
76    pub(crate) unsafe fn agg_min(&self, groups: &GroupsType) -> Series {
77        // faster paths
78        match (self.is_sorted_flag(), self.null_count()) {
79            (IsSorted::Ascending, 0) => {
80                return self.clone().into_series().agg_first(groups);
81            },
82            (IsSorted::Descending, 0) => {
83                return self.clone().into_series().agg_last(groups);
84            },
85            _ => {},
86        }
87        let ca_self = self.rechunk();
88        let arr = ca_self.downcast_iter().next().unwrap();
89        let no_nulls = arr.null_count() == 0;
90        match groups {
91            GroupsType::Idx(groups) => _agg_helper_idx_bool(groups, |(first, idx)| {
92                debug_assert!(idx.len() <= self.len());
93                if idx.is_empty() {
94                    None
95                } else if idx.len() == 1 {
96                    arr.get(first as usize)
97                } else if no_nulls {
98                    take_min_bool_iter_unchecked_no_nulls(arr, idx2usize(idx))
99                } else {
100                    take_min_bool_iter_unchecked_nulls(arr, idx2usize(idx), idx.len() as IdxSize)
101                }
102            }),
103            GroupsType::Slice {
104                groups: groups_slice,
105                ..
106            } => _agg_helper_slice_bool(groups_slice, |[first, len]| {
107                debug_assert!(len <= self.len() as IdxSize);
108                match len {
109                    0 => None,
110                    1 => self.get(first as usize),
111                    _ => {
112                        let arr_group = _slice_from_offsets(self, first, len);
113                        arr_group.min()
114                    },
115                }
116            }),
117        }
118    }
119    pub(crate) unsafe fn agg_max(&self, groups: &GroupsType) -> Series {
120        // faster paths
121        match (self.is_sorted_flag(), self.null_count()) {
122            (IsSorted::Ascending, 0) => {
123                return self.clone().into_series().agg_last(groups);
124            },
125            (IsSorted::Descending, 0) => {
126                return self.clone().into_series().agg_first(groups);
127            },
128            _ => {},
129        }
130
131        let ca_self = self.rechunk();
132        let arr = ca_self.downcast_iter().next().unwrap();
133        let no_nulls = arr.null_count() == 0;
134        match groups {
135            GroupsType::Idx(groups) => _agg_helper_idx_bool(groups, |(first, idx)| {
136                debug_assert!(idx.len() <= self.len());
137                if idx.is_empty() {
138                    None
139                } else if idx.len() == 1 {
140                    self.get(first as usize)
141                } else if no_nulls {
142                    take_max_bool_iter_unchecked_no_nulls(arr, idx2usize(idx))
143                } else {
144                    take_max_bool_iter_unchecked_nulls(arr, idx2usize(idx), idx.len() as IdxSize)
145                }
146            }),
147            GroupsType::Slice {
148                groups: groups_slice,
149                ..
150            } => _agg_helper_slice_bool(groups_slice, |[first, len]| {
151                debug_assert!(len <= self.len() as IdxSize);
152                match len {
153                    0 => None,
154                    1 => self.get(first as usize),
155                    _ => {
156                        let arr_group = _slice_from_offsets(self, first, len);
157                        arr_group.max()
158                    },
159                }
160            }),
161        }
162    }
163
164    pub(crate) unsafe fn agg_sum(&self, groups: &GroupsType) -> Series {
165        self.cast_with_options(&IDX_DTYPE, CastOptions::Overflowing)
166            .unwrap()
167            .agg_sum(groups)
168    }
169
170    /// # Safety
171    ///
172    /// Groups should be in correct.
173    #[expect(clippy::too_many_arguments)]
174    unsafe fn bool_agg(
175        &self,
176        groups: &GroupsType,
177        ignore_nulls: bool,
178
179        idx_no_valid: impl Fn(BitMask, &[IdxSize]) -> bool + Send + Sync,
180        idx_validity: impl Fn(BitMask, BitMask, &[IdxSize]) -> bool + Send + Sync,
181        idx_kleene: impl Fn(BitMask, BitMask, &[IdxSize]) -> Option<bool> + Send + Sync,
182
183        slice_no_valid: impl Fn(BitMask, IdxSize, IdxSize) -> bool + Send + Sync,
184        slice_validity: impl Fn(BitMask, BitMask, IdxSize, IdxSize) -> bool + Send + Sync,
185        slice_kleene: impl Fn(BitMask, BitMask, IdxSize, IdxSize) -> Option<bool> + Send + Sync,
186    ) -> BooleanChunked {
187        let name = self.name().clone();
188        let values = self.rechunk();
189        let values = values.downcast_as_array();
190
191        let ca: BooleanChunked = POOL.install(|| {
192            let validity = values
193                .validity()
194                .filter(|v| v.unset_bits() > 0)
195                .map(BitMask::from_bitmap);
196            let values = BitMask::from_bitmap(values.values());
197
198            if !ignore_nulls && let Some(validity) = validity {
199                match groups {
200                    GroupsType::Idx(idx) => idx
201                        .into_par_iter()
202                        .map(|(_, idx)| idx_kleene(values, validity, idx))
203                        .collect(),
204                    GroupsType::Slice {
205                        groups,
206                        overlapping: _,
207                    } => groups
208                        .into_par_iter()
209                        .map(|[start, length]| slice_kleene(values, validity, *start, *length))
210                        .collect(),
211                }
212            } else {
213                match groups {
214                    GroupsType::Idx(idx) => match validity {
215                        None => idx
216                            .into_par_iter()
217                            .map(|(_, idx)| idx_no_valid(values, idx))
218                            .collect(),
219                        Some(validity) => idx
220                            .into_par_iter()
221                            .map(|(_, idx)| idx_validity(values, validity, idx))
222                            .collect(),
223                    },
224                    GroupsType::Slice {
225                        groups,
226                        overlapping: _,
227                    } => match validity {
228                        None => groups
229                            .into_par_iter()
230                            .map(|[start, length]| slice_no_valid(values, *start, *length))
231                            .collect(),
232                        Some(validity) => groups
233                            .into_par_iter()
234                            .map(|[start, length]| {
235                                slice_validity(values, validity, *start, *length)
236                            })
237                            .collect(),
238                    },
239                }
240            }
241        });
242        ca.with_name(name)
243    }
244
245    /// # Safety
246    ///
247    /// Groups should be in correct.
248    pub unsafe fn agg_any(&self, groups: &GroupsType, ignore_nulls: bool) -> BooleanChunked {
249        self.bool_agg(
250            groups,
251            ignore_nulls,
252            |values, idxs| {
253                idxs.iter()
254                    .any(|i| unsafe { values.get_bit_unchecked(*i as usize) })
255            },
256            |values, validity, idxs| {
257                idxs.iter().any(|i| unsafe {
258                    validity.get_bit_unchecked(*i as usize) & values.get_bit_unchecked(*i as usize)
259                })
260            },
261            |values, validity, idxs| {
262                let mut saw_null = false;
263                for i in idxs.iter() {
264                    let is_valid = unsafe { validity.get_bit_unchecked(*i as usize) };
265                    let is_true = unsafe { values.get_bit_unchecked(*i as usize) };
266
267                    if is_valid & is_true {
268                        return Some(true);
269                    }
270                    saw_null |= !is_valid;
271                }
272                (!saw_null).then_some(false)
273            },
274            |values, start, length| {
275                unsafe { values.sliced_unchecked(start as usize, length as usize) }.leading_zeros()
276                    < length as usize
277            },
278            |values, validity, start, length| {
279                let values = unsafe { values.sliced_unchecked(start as usize, length as usize) };
280                let validity =
281                    unsafe { validity.sliced_unchecked(start as usize, length as usize) };
282                values.intersects_with(validity)
283            },
284            |values, validity, start, length| {
285                let values = unsafe { values.sliced_unchecked(start as usize, length as usize) };
286                let validity =
287                    unsafe { validity.sliced_unchecked(start as usize, length as usize) };
288
289                if values.intersects_with(validity) {
290                    Some(true)
291                } else if validity.unset_bits() == 0 {
292                    Some(false)
293                } else {
294                    None
295                }
296            },
297        )
298    }
299
300    /// # Safety
301    ///
302    /// Groups should be in correct.
303    pub unsafe fn agg_all(&self, groups: &GroupsType, ignore_nulls: bool) -> BooleanChunked {
304        self.bool_agg(
305            groups,
306            ignore_nulls,
307            |values, idxs| {
308                idxs.iter()
309                    .all(|i| unsafe { values.get_bit_unchecked(*i as usize) })
310            },
311            |values, validity, idxs| {
312                idxs.iter().all(|i| unsafe {
313                    !validity.get_bit_unchecked(*i as usize) | values.get_bit_unchecked(*i as usize)
314                })
315            },
316            |values, validity, idxs| {
317                let mut saw_null = false;
318                for i in idxs.iter() {
319                    let is_valid = unsafe { validity.get_bit_unchecked(*i as usize) };
320                    let is_true = unsafe { values.get_bit_unchecked(*i as usize) };
321
322                    if is_valid & !is_true {
323                        return Some(false);
324                    }
325                    saw_null |= !is_valid;
326                }
327                (!saw_null).then_some(true)
328            },
329            |values, start, length| {
330                let values = unsafe { values.sliced_unchecked(start as usize, length as usize) };
331                values.unset_bits() == 0
332            },
333            |values, validity, start, length| {
334                let values = unsafe { values.sliced_unchecked(start as usize, length as usize) };
335                let validity =
336                    unsafe { validity.sliced_unchecked(start as usize, length as usize) };
337                values.num_intersections_with(validity) == validity.set_bits()
338            },
339            |values, validity, start, length| {
340                let values = unsafe { values.sliced_unchecked(start as usize, length as usize) };
341                let validity =
342                    unsafe { validity.sliced_unchecked(start as usize, length as usize) };
343
344                let num_non_nulls = validity.set_bits();
345
346                if values.num_intersections_with(validity) < num_non_nulls {
347                    Some(false)
348                } else if num_non_nulls < values.len() {
349                    None
350                } else {
351                    Some(true)
352                }
353            },
354        )
355    }
356}