polars_arrow/bitmap/utils/chunk_iterator/
mod.rs

1mod chunks_exact;
2mod merge;
3
4pub use chunks_exact::BitChunksExact;
5pub(crate) use merge::merge_reversed;
6
7use crate::trusted_len::TrustedLen;
8pub use crate::types::BitChunk;
9use crate::types::BitChunkIter;
10
11/// Trait representing an exact iterator over bytes in [`BitChunk`].
12pub trait BitChunkIterExact<B: BitChunk>: TrustedLen<Item = B> {
13    /// The remainder of the iterator.
14    fn remainder(&self) -> B;
15
16    /// The number of items in the remainder
17    fn remainder_len(&self) -> usize;
18
19    /// An iterator over individual items of the remainder
20    #[inline]
21    fn remainder_iter(&self) -> BitChunkIter<B> {
22        BitChunkIter::new(self.remainder(), self.remainder_len())
23    }
24}
25
/// This struct is used to efficiently iterate over bit masks by loading bytes on
/// the stack with alignments of `uX`. This allows efficient iteration over bitmaps.
///
/// Iteration yields one `T` per complete group of `size_of::<T>()` bytes; the
/// bits that do not fill a complete `T` are available through `remainder`.
#[derive(Debug)]
pub struct BitChunks<'a, T: BitChunk> {
    /// Iterator over the complete `size_of::<T>()`-byte chunks of the input.
    chunk_iterator: std::slice::ChunksExact<'a, u8>,
    /// The most recently loaded complete chunk (zero when there was none to load).
    current: T,
    /// The bytes from which the remainder is assembled.
    remainder_bytes: &'a [u8],
    /// The first byte of `remainder_bytes` placed in byte 0 of a zeroed `T`
    /// (zero when `remainder_bytes` is empty); merged with the final complete
    /// chunk when `bit_offset != 0`.
    last_chunk: T,
    /// Number of chunks the iterator has yet to yield.
    remaining: usize,
    /// offset inside a byte
    bit_offset: usize,
    /// Total number of bits covered, as passed to `new`.
    len: usize,
    /// Marker tying the struct to the chunk type `T`.
    phantom: std::marker::PhantomData<T>,
}
40
41/// writes `bytes` into `dst`.
42#[inline]
43fn copy_with_merge<T: BitChunk>(dst: &mut T::Bytes, bytes: &[u8], bit_offset: usize) {
44    bytes
45        .windows(2)
46        .chain(std::iter::once([bytes[bytes.len() - 1], 0].as_ref()))
47        .take(size_of::<T>())
48        .enumerate()
49        .for_each(|(i, w)| {
50            let val = merge_reversed(w[0], w[1], bit_offset);
51            dst[i] = val;
52        });
53}
54
55impl<'a, T: BitChunk> BitChunks<'a, T> {
56    /// Creates a [`BitChunks`].
57    pub fn new(slice: &'a [u8], offset: usize, len: usize) -> Self {
58        assert!(offset + len <= slice.len() * 8);
59
60        let slice = &slice[offset / 8..];
61        let bit_offset = offset % 8;
62        let size_of = size_of::<T>();
63
64        let bytes_len = len / 8;
65        let bytes_upper_len = (len + bit_offset).div_ceil(8);
66        let mut chunks = slice[..bytes_len].chunks_exact(size_of);
67
68        let remainder = &slice[bytes_len - chunks.remainder().len()..bytes_upper_len];
69
70        let remainder_bytes = if chunks.len() == 0 { slice } else { remainder };
71
72        let last_chunk = remainder_bytes
73            .first()
74            .map(|first| {
75                let mut last = T::zero().to_ne_bytes();
76                last[0] = *first;
77                T::from_ne_bytes(last)
78            })
79            .unwrap_or_else(T::zero);
80
81        let remaining = chunks.size_hint().0;
82
83        let current = chunks
84            .next()
85            .map(|x| match x.try_into() {
86                Ok(a) => T::from_ne_bytes(a),
87                Err(_) => unreachable!(),
88            })
89            .unwrap_or_else(T::zero);
90
91        Self {
92            chunk_iterator: chunks,
93            len,
94            current,
95            remaining,
96            remainder_bytes,
97            last_chunk,
98            bit_offset,
99            phantom: std::marker::PhantomData,
100        }
101    }
102
103    #[inline]
104    fn load_next(&mut self) {
105        self.current = match self.chunk_iterator.next().unwrap().try_into() {
106            Ok(a) => T::from_ne_bytes(a),
107            Err(_) => unreachable!(),
108        };
109    }
110
111    /// Returns the remainder [`BitChunk`].
112    pub fn remainder(&self) -> T {
113        // remaining bytes may not fit in `size_of::<T>()`. We complement
114        // them to fit by allocating T and writing to it byte by byte
115        let mut remainder = T::zero().to_ne_bytes();
116
117        let remainder = match (self.remainder_bytes.is_empty(), self.bit_offset == 0) {
118            (true, _) => remainder,
119            (false, true) => {
120                // all remaining bytes
121                self.remainder_bytes
122                    .iter()
123                    .take(size_of::<T>())
124                    .enumerate()
125                    .for_each(|(i, val)| remainder[i] = *val);
126
127                remainder
128            },
129            (false, false) => {
130                // all remaining bytes
131                copy_with_merge::<T>(&mut remainder, self.remainder_bytes, self.bit_offset);
132                remainder
133            },
134        };
135        let mut remainder = T::from_ne_bytes(remainder);
136        let mask = (T::one() << self.remainder_len()) - T::one();
137        remainder &= mask;
138        remainder
139    }
140
141    /// Returns the remainder bits in [`BitChunks::remainder`].
142    pub fn remainder_len(&self) -> usize {
143        self.len - (size_of::<T>() * ((self.len / 8) / size_of::<T>()) * 8)
144    }
145}
146
147impl<T: BitChunk> Iterator for BitChunks<'_, T> {
148    type Item = T;
149
150    #[inline]
151    fn next(&mut self) -> Option<T> {
152        if self.remaining == 0 {
153            return None;
154        }
155
156        let current = self.current;
157        let combined = if self.bit_offset == 0 {
158            // fast case where there is no offset. In this case, there is bit-alignment
159            // at byte boundary and thus the bytes correspond exactly.
160            if self.remaining >= 2 {
161                self.load_next();
162            }
163            current
164        } else {
165            let next = if self.remaining >= 2 {
166                // case where `next` is complete and thus we can take it all
167                self.load_next();
168                self.current
169            } else {
170                // case where the `next` is incomplete and thus we take the remaining
171                self.last_chunk
172            };
173            merge_reversed(current, next, self.bit_offset)
174        };
175
176        self.remaining -= 1;
177        Some(combined)
178    }
179
180    #[inline]
181    fn size_hint(&self) -> (usize, Option<usize>) {
182        // it contains always one more than the chunk_iterator, which is the last
183        // one where the remainder is merged into current.
184        (self.remaining, Some(self.remaining))
185    }
186}
187
188impl<T: BitChunk> BitChunkIterExact<T> for BitChunks<'_, T> {
189    #[inline]
190    fn remainder(&self) -> T {
191        self.remainder()
192    }
193
194    #[inline]
195    fn remainder_len(&self) -> usize {
196        self.remainder_len()
197    }
198}
199
200impl<T: BitChunk> ExactSizeIterator for BitChunks<'_, T> {
201    #[inline]
202    fn len(&self) -> usize {
203        self.chunk_iterator.len()
204    }
205}
206
// SAFETY: `size_hint` returns `(remaining, Some(remaining))`, and `next` yields
// `Some` exactly while `remaining > 0`, decrementing it once per yielded item,
// so the reported length is exact.
unsafe impl<T: BitChunk> TrustedLen for BitChunks<'_, T> {}