dsi_bitstream/impls/
bit_reader.rs

1/*
2 * SPDX-FileCopyrightText: 2023 Tommaso Fontana
3 * SPDX-FileCopyrightText: 2023 Inria
4 * SPDX-FileCopyrightText: 2023 Sebastiano Vigna
5 *
6 * SPDX-License-Identifier: Apache-2.0 OR LGPL-2.1-or-later
7 */
8
9use core::convert::Infallible;
10#[cfg(feature = "mem_dbg")]
11use mem_dbg::{MemDbg, MemSize};
12use std::error::Error;
13
14use crate::codes::params::{DefaultReadParams, ReadParams};
15use crate::traits::*;
16
17/// An implementation of [`BitRead`] for a [`WordRead`] with word `u64` and of
18/// [`BitSeek`] for a [`WordSeek`].
19///
20/// This implementation accesses randomly the underlying [`WordRead`] without
21/// any buffering. It is usually slower than
22/// [`BufBitReader`](crate::impls::BufBitReader).
23///
24/// The peek word is `u32`. The value returned by
25/// [`peek_bits`](crate::traits::BitRead::peek_bits) contains at least 32 bits
26/// (extended with zeros beyond end of stream), that is, a full peek word.
27///
28/// The additional type parameter `RP` is used to select the parameters for the
29/// instantanous codes, but the casual user should be happy with the default
30/// value. See [`ReadParams`] for more details.
31///
32/// For additional flexibility, this structures implements [`std::io::Read`].
33/// Note that because of coherence rules it is not possible to implement
34/// [`std::io::Read`] for a generic [`BitRead`].
35
36#[derive(Debug, Clone)]
37#[cfg_attr(feature = "mem_dbg", derive(MemDbg, MemSize))]
38pub struct BitReader<E: Endianness, WR, RP: ReadParams = DefaultReadParams> {
39    /// The stream which we will read words from.
40    data: WR,
41    /// The index of the current bit.
42    bit_index: u64,
43    _marker: core::marker::PhantomData<(E, RP)>,
44}
45
46impl<E: Endianness, WR, RP: ReadParams> BitReader<E, WR, RP> {
47    pub fn new(data: WR) -> Self {
48        check_tables(32);
49        Self {
50            data,
51            bit_index: 0,
52            _marker: core::marker::PhantomData,
53        }
54    }
55}
56
57impl<
58        E: Error + Send + Sync + 'static,
59        WR: WordRead<Error = E, Word = u64> + WordSeek<Error = E>,
60        RP: ReadParams,
61    > BitRead<BE> for BitReader<BE, WR, RP>
62{
63    type Error = <WR as WordRead>::Error;
64    type PeekWord = u32;
65
66    #[inline]
67    fn skip_bits(&mut self, n_bits: usize) -> Result<(), Self::Error> {
68        self.bit_index += n_bits as u64;
69        Ok(())
70    }
71
72    #[inline]
73    fn read_bits(&mut self, n_bits: usize) -> Result<u64, Self::Error> {
74        if n_bits == 0 {
75            return Ok(0);
76        }
77
78        assert!(n_bits <= 64);
79
80        self.data.set_word_pos(self.bit_index / 64)?;
81        let in_word_offset = (self.bit_index % 64) as usize;
82
83        let res = if (in_word_offset + n_bits) <= 64 {
84            // single word access
85            let word = self.data.read_word()?.to_be();
86            (word << in_word_offset) >> (64 - n_bits)
87        } else {
88            // double word access
89            let high_word = self.data.read_word()?.to_be();
90            let low_word = self.data.read_word()?.to_be();
91            let shamt1 = 64 - n_bits;
92            let shamt2 = 128 - in_word_offset - n_bits;
93            ((high_word << in_word_offset) >> shamt1) | (low_word >> shamt2)
94        };
95        self.bit_index += n_bits as u64;
96        Ok(res)
97    }
98
99    #[inline]
100    fn peek_bits(&mut self, n_bits: usize) -> Result<u32, Self::Error> {
101        if n_bits == 0 {
102            return Ok(0);
103        }
104
105        assert!(n_bits <= 32);
106
107        self.data.set_word_pos(self.bit_index / 64)?;
108        let in_word_offset = (self.bit_index % 64) as usize;
109
110        let res = if (in_word_offset + n_bits) <= 64 {
111            // single word access
112            let word = self.data.read_word()?.to_be();
113            (word << in_word_offset) >> (64 - n_bits)
114        } else {
115            // double word access
116            let high_word = self.data.read_word()?.to_be();
117            let low_word = self.data.read_word()?.to_be();
118            let shamt1 = 64 - n_bits;
119            let shamt2 = 128 - in_word_offset - n_bits;
120            ((high_word << in_word_offset) >> shamt1) | (low_word >> shamt2)
121        };
122        Ok(res as u32)
123    }
124
125    #[inline]
126    fn read_unary(&mut self) -> Result<u64, Self::Error> {
127        self.data.set_word_pos(self.bit_index / 64)?;
128        let in_word_offset = self.bit_index % 64;
129        let mut bits_in_word = 64 - in_word_offset;
130        let mut total = 0;
131
132        let mut word = self.data.read_word()?.to_be();
133        word <<= in_word_offset;
134        loop {
135            let zeros = word.leading_zeros() as u64;
136            // the unary code fits in the word
137            if zeros < bits_in_word {
138                self.bit_index += total + zeros + 1;
139                return Ok(total + zeros);
140            }
141            total += bits_in_word;
142            bits_in_word = 64;
143            word = self.data.read_word()?.to_be();
144        }
145    }
146
147    #[inline(always)]
148    fn skip_bits_after_peek(&mut self, n: usize) {
149        self.bit_index += n as u64;
150    }
151}
152
153impl<WR: WordSeek, RP: ReadParams> BitSeek for BitReader<LE, WR, RP> {
154    type Error = Infallible;
155
156    fn bit_pos(&mut self) -> Result<u64, Self::Error> {
157        Ok(self.bit_index)
158    }
159
160    fn set_bit_pos(&mut self, bit_index: u64) -> Result<(), Self::Error> {
161        self.bit_index = bit_index;
162        Ok(())
163    }
164}
165
166impl<WR: WordSeek, RP: ReadParams> BitSeek for BitReader<BE, WR, RP> {
167    type Error = Infallible;
168
169    fn bit_pos(&mut self) -> Result<u64, Self::Error> {
170        Ok(self.bit_index)
171    }
172
173    fn set_bit_pos(&mut self, bit_index: u64) -> Result<(), Self::Error> {
174        self.bit_index = bit_index;
175        Ok(())
176    }
177}
178
179impl<
180        E: Error + Send + Sync + 'static,
181        WR: WordRead<Error = E, Word = u64> + WordSeek<Error = E>,
182        RP: ReadParams,
183    > BitRead<LE> for BitReader<LE, WR, RP>
184{
185    type Error = <WR as WordRead>::Error;
186    type PeekWord = u32;
187
188    #[inline]
189    fn skip_bits(&mut self, n_bits: usize) -> Result<(), Self::Error> {
190        self.bit_index += n_bits as u64;
191        Ok(())
192    }
193
194    #[inline]
195    fn read_bits(&mut self, n_bits: usize) -> Result<u64, Self::Error> {
196        #[cfg(feature = "checks")]
197        assert!(n_bits <= 64);
198
199        if n_bits == 0 {
200            return Ok(0);
201        }
202
203        self.data.set_word_pos(self.bit_index / 64)?;
204        let in_word_offset = (self.bit_index % 64) as usize;
205
206        let res = if (in_word_offset + n_bits) <= 64 {
207            // single word access
208            let word = self.data.read_word()?.to_le();
209            let shamt = 64 - n_bits;
210            (word << (shamt - in_word_offset)) >> shamt
211        } else {
212            // double word access
213            let low_word = self.data.read_word()?.to_le();
214            let high_word = self.data.read_word()?.to_le();
215            let shamt1 = 128 - in_word_offset - n_bits;
216            let shamt2 = 64 - n_bits;
217            ((high_word << shamt1) >> shamt2) | (low_word >> in_word_offset)
218        };
219        self.bit_index += n_bits as u64;
220        Ok(res)
221    }
222
223    #[inline]
224    fn peek_bits(&mut self, n_bits: usize) -> Result<u32, Self::Error> {
225        if n_bits == 0 {
226            return Ok(0);
227        }
228
229        assert!(n_bits <= 32);
230
231        self.data.set_word_pos(self.bit_index / 64)?;
232        let in_word_offset = (self.bit_index % 64) as usize;
233
234        let res = if (in_word_offset + n_bits) <= 64 {
235            // single word access
236            let word = self.data.read_word()?.to_le();
237            let shamt = 64 - n_bits;
238            (word << (shamt - in_word_offset)) >> shamt
239        } else {
240            // double word access
241            let low_word = self.data.read_word()?.to_le();
242            let high_word = self.data.read_word()?.to_le();
243            let shamt1 = 128 - in_word_offset - n_bits;
244            let shamt2 = 64 - n_bits;
245            ((high_word << shamt1) >> shamt2) | (low_word >> in_word_offset)
246        };
247        Ok(res as u32)
248    }
249
250    #[inline]
251    fn read_unary(&mut self) -> Result<u64, Self::Error> {
252        self.data.set_word_pos(self.bit_index / 64)?;
253        let in_word_offset = self.bit_index % 64;
254        let mut bits_in_word = 64 - in_word_offset;
255        let mut total = 0;
256
257        let mut word = self.data.read_word()?.to_le();
258        word >>= in_word_offset;
259        loop {
260            let zeros = word.trailing_zeros() as u64;
261            // the unary code fits in the word
262            if zeros < bits_in_word {
263                self.bit_index += total + zeros + 1;
264                return Ok(total + zeros);
265            }
266            total += bits_in_word;
267            bits_in_word = 64;
268            word = self.data.read_word()?.to_le();
269        }
270    }
271
272    #[inline(always)]
273    fn skip_bits_after_peek(&mut self, n: usize) {
274        self.bit_index += n as u64;
275    }
276}
277
278impl<
279        E: Error + Send + Sync + 'static,
280        WR: WordRead<Error = E, Word = u64> + WordSeek<Error = E>,
281        RP: ReadParams,
282    > std::io::Read for BitReader<LE, WR, RP>
283{
284    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
285        let mut iter = buf.chunks_exact_mut(8);
286
287        for chunk in &mut iter {
288            let word = self
289                .read_bits(64)
290                .map_err(|_| std::io::ErrorKind::UnexpectedEof)?;
291            chunk.copy_from_slice(&word.to_le_bytes());
292        }
293
294        let rem = iter.into_remainder();
295        if !rem.is_empty() {
296            let word = self
297                .read_bits(rem.len() * 8)
298                .map_err(|_| std::io::ErrorKind::UnexpectedEof)?;
299            rem.copy_from_slice(&word.to_le_bytes()[..rem.len()]);
300        }
301
302        Ok(buf.len())
303    }
304}
305
306impl<
307        E: Error + Send + Sync + 'static,
308        WR: WordRead<Error = E, Word = u64> + WordSeek<Error = E>,
309        RP: ReadParams,
310    > std::io::Read for BitReader<BE, WR, RP>
311{
312    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
313        let mut iter = buf.chunks_exact_mut(8);
314
315        for chunk in &mut iter {
316            let word = self
317                .read_bits(64)
318                .map_err(|_| std::io::ErrorKind::UnexpectedEof)?;
319            chunk.copy_from_slice(&word.to_be_bytes());
320        }
321
322        let rem = iter.into_remainder();
323        if !rem.is_empty() {
324            let word = self
325                .read_bits(rem.len() * 8)
326                .map_err(|_| std::io::ErrorKind::UnexpectedEof)?;
327            rem.copy_from_slice(&word.to_be_bytes()[8 - rem.len()..]);
328        }
329
330        Ok(buf.len())
331    }
332}