polars_arrow/bitmap/utils/chunk_iterator/
mod.rs

1mod chunks_exact;
2mod merge;
3
4pub use chunks_exact::BitChunksExact;
5pub(crate) use merge::merge_reversed;
6
7use crate::trusted_len::TrustedLen;
8pub use crate::types::BitChunk;
9use crate::types::BitChunkIter;
10
11/// Trait representing an exact iterator over bytes in [`BitChunk`].
12pub trait BitChunkIterExact<B: BitChunk>: TrustedLen<Item = B> {
13    /// The remainder of the iterator.
14    fn remainder(&self) -> B;
15
16    /// The number of items in the remainder
17    fn remainder_len(&self) -> usize;
18
19    /// An iterator over individual items of the remainder
20    #[inline]
21    fn remainder_iter(&self) -> BitChunkIter<B> {
22        BitChunkIter::new(self.remainder(), self.remainder_len())
23    }
24}
25
26/// This struct is used to efficiently iterate over bit masks by loading bytes on
27/// the stack with alignments of `uX`. This allows efficient iteration over bitmaps.
28#[derive(Debug)]
29pub struct BitChunks<'a, T: BitChunk> {
30    chunk_iterator: std::slice::ChunksExact<'a, u8>,
31    current: T,
32    remainder_bytes: &'a [u8],
33    last_chunk: T,
34    remaining: usize,
35    /// offset inside a byte
36    bit_offset: usize,
37    len: usize,
38    phantom: std::marker::PhantomData<T>,
39}
40
41/// writes `bytes` into `dst`.
42#[inline]
43fn copy_with_merge<T: BitChunk>(dst: &mut T::Bytes, bytes: &[u8], bit_offset: usize) {
44    bytes
45        .windows(2)
46        .chain(std::iter::once([bytes[bytes.len() - 1], 0].as_ref()))
47        .take(size_of::<T>())
48        .enumerate()
49        .for_each(|(i, w)| {
50            let val = merge_reversed(w[0], w[1], bit_offset);
51            dst[i] = val;
52        });
53}
54
55impl<'a, T: BitChunk> BitChunks<'a, T> {
56    /// Creates a [`BitChunks`].
57    pub fn new(slice: &'a [u8], offset: usize, len: usize) -> Self {
58        assert!(offset + len <= slice.len() * 8);
59
60        let slice = &slice[offset / 8..];
61        let bit_offset = offset % 8;
62        let size_of = size_of::<T>();
63
64        let bytes_len = len / 8;
65        let bytes_upper_len = (len + bit_offset + 7) / 8;
66        let mut chunks = slice[..bytes_len].chunks_exact(size_of);
67
68        let remainder = &slice[bytes_len - chunks.remainder().len()..bytes_upper_len];
69
70        let remainder_bytes = if chunks.len() == 0 { slice } else { remainder };
71
72        let last_chunk = remainder_bytes
73            .first()
74            .map(|first| {
75                let mut last = T::zero().to_ne_bytes();
76                last[0] = *first;
77                T::from_ne_bytes(last)
78            })
79            .unwrap_or_else(T::zero);
80
81        let remaining = chunks.size_hint().0;
82
83        let current = chunks
84            .next()
85            .map(|x| match x.try_into() {
86                Ok(a) => T::from_ne_bytes(a),
87                Err(_) => unreachable!(),
88            })
89            .unwrap_or_else(T::zero);
90
91        Self {
92            chunk_iterator: chunks,
93            len,
94            current,
95            remaining,
96            remainder_bytes,
97            last_chunk,
98            bit_offset,
99            phantom: std::marker::PhantomData,
100        }
101    }
102
103    #[inline]
104    fn load_next(&mut self) {
105        self.current = match self.chunk_iterator.next().unwrap().try_into() {
106            Ok(a) => T::from_ne_bytes(a),
107            Err(_) => unreachable!(),
108        };
109    }
110
111    /// Returns the remainder [`BitChunk`].
112    pub fn remainder(&self) -> T {
113        // remaining bytes may not fit in `size_of::<T>()`. We complement
114        // them to fit by allocating T and writing to it byte by byte
115        let mut remainder = T::zero().to_ne_bytes();
116
117        let remainder = match (self.remainder_bytes.is_empty(), self.bit_offset == 0) {
118            (true, _) => remainder,
119            (false, true) => {
120                // all remaining bytes
121                self.remainder_bytes
122                    .iter()
123                    .take(size_of::<T>())
124                    .enumerate()
125                    .for_each(|(i, val)| remainder[i] = *val);
126
127                remainder
128            },
129            (false, false) => {
130                // all remaining bytes
131                copy_with_merge::<T>(&mut remainder, self.remainder_bytes, self.bit_offset);
132                remainder
133            },
134        };
135        T::from_ne_bytes(remainder)
136    }
137
138    /// Returns the remainder bits in [`BitChunks::remainder`].
139    pub fn remainder_len(&self) -> usize {
140        self.len - (size_of::<T>() * ((self.len / 8) / size_of::<T>()) * 8)
141    }
142}
143
144impl<T: BitChunk> Iterator for BitChunks<'_, T> {
145    type Item = T;
146
147    #[inline]
148    fn next(&mut self) -> Option<T> {
149        if self.remaining == 0 {
150            return None;
151        }
152
153        let current = self.current;
154        let combined = if self.bit_offset == 0 {
155            // fast case where there is no offset. In this case, there is bit-alignment
156            // at byte boundary and thus the bytes correspond exactly.
157            if self.remaining >= 2 {
158                self.load_next();
159            }
160            current
161        } else {
162            let next = if self.remaining >= 2 {
163                // case where `next` is complete and thus we can take it all
164                self.load_next();
165                self.current
166            } else {
167                // case where the `next` is incomplete and thus we take the remaining
168                self.last_chunk
169            };
170            merge_reversed(current, next, self.bit_offset)
171        };
172
173        self.remaining -= 1;
174        Some(combined)
175    }
176
177    #[inline]
178    fn size_hint(&self) -> (usize, Option<usize>) {
179        // it contains always one more than the chunk_iterator, which is the last
180        // one where the remainder is merged into current.
181        (self.remaining, Some(self.remaining))
182    }
183}
184
185impl<T: BitChunk> BitChunkIterExact<T> for BitChunks<'_, T> {
186    #[inline]
187    fn remainder(&self) -> T {
188        self.remainder()
189    }
190
191    #[inline]
192    fn remainder_len(&self) -> usize {
193        self.remainder_len()
194    }
195}
196
197impl<T: BitChunk> ExactSizeIterator for BitChunks<'_, T> {
198    #[inline]
199    fn len(&self) -> usize {
200        self.chunk_iterator.len()
201    }
202}
203
204unsafe impl<T: BitChunk> TrustedLen for BitChunks<'_, T> {}