dsi_bitstream/impls/
bit_reader.rs

1/*
2 * SPDX-FileCopyrightText: 2023 Tommaso Fontana
3 * SPDX-FileCopyrightText: 2023 Inria
4 * SPDX-FileCopyrightText: 2023 Sebastiano Vigna
5 *
6 * SPDX-License-Identifier: Apache-2.0 OR LGPL-2.1-or-later
7 */
8
9use core::convert::Infallible;
10use core::error::Error;
11#[cfg(feature = "mem_dbg")]
12use mem_dbg::{MemDbg, MemSize};
13
14use crate::codes::params::{DefaultReadParams, ReadParams};
15use crate::traits::*;
16
17/// An implementation of [`BitRead`] for a [`WordRead`] with word `u64` and of
18/// [`BitSeek`] for a [`WordSeek`].
19///
20/// This implementation accesses randomly the underlying [`WordRead`] without
21/// any buffering. It is usually slower than
22/// [`BufBitReader`](crate::impls::BufBitReader).
23///
24/// The peek word is `u32`. The value returned by
25/// [`peek_bits`](crate::traits::BitRead::peek_bits) contains at least 32 bits
26/// (extended with zeros beyond end of stream), that is, a full peek word.
27///
28/// The additional type parameter `RP` is used to select the parameters for the
29/// instantanous codes, but the casual user should be happy with the default
30/// value. See [`ReadParams`] for more details.
31///
32/// For additional flexibility, this structures implements [`std::io::Read`].
33/// Note that because of coherence rules it is not possible to implement
34/// [`std::io::Read`] for a generic [`BitRead`].
35
36#[derive(Debug, Clone)]
37#[cfg_attr(feature = "mem_dbg", derive(MemDbg, MemSize))]
38pub struct BitReader<E: Endianness, WR, RP: ReadParams = DefaultReadParams> {
39    /// The stream which we will read words from.
40    data: WR,
41    /// The index of the current bit.
42    bit_index: u64,
43    _marker: core::marker::PhantomData<(E, RP)>,
44}
45
46impl<E: Endianness, WR, RP: ReadParams> BitReader<E, WR, RP> {
47    pub fn new(data: WR) -> Self {
48        #[cfg(feature = "std")]
49        check_tables(32);
50        Self {
51            data,
52            bit_index: 0,
53            _marker: core::marker::PhantomData,
54        }
55    }
56}
57
58impl<
59    E: Error + Send + Sync + 'static,
60    WR: WordRead<Error = E, Word = u64> + WordSeek<Error = E>,
61    RP: ReadParams,
62> BitRead<BE> for BitReader<BE, WR, RP>
63{
64    type Error = <WR as WordRead>::Error;
65    type PeekWord = u32;
66
67    #[inline]
68    fn skip_bits(&mut self, n_bits: usize) -> Result<(), Self::Error> {
69        self.bit_index += n_bits as u64;
70        Ok(())
71    }
72
73    #[inline]
74    fn read_bits(&mut self, n_bits: usize) -> Result<u64, Self::Error> {
75        if n_bits == 0 {
76            return Ok(0);
77        }
78
79        assert!(n_bits <= 64);
80
81        self.data.set_word_pos(self.bit_index / 64)?;
82        let in_word_offset = (self.bit_index % 64) as usize;
83
84        let res = if (in_word_offset + n_bits) <= 64 {
85            // single word access
86            let word = self.data.read_word()?.to_be();
87            (word << in_word_offset) >> (64 - n_bits)
88        } else {
89            // double word access
90            let high_word = self.data.read_word()?.to_be();
91            let low_word = self.data.read_word()?.to_be();
92            let shamt1 = 64 - n_bits;
93            let shamt2 = 128 - in_word_offset - n_bits;
94            ((high_word << in_word_offset) >> shamt1) | (low_word >> shamt2)
95        };
96        self.bit_index += n_bits as u64;
97        Ok(res)
98    }
99
100    #[inline]
101    fn peek_bits(&mut self, n_bits: usize) -> Result<u32, Self::Error> {
102        if n_bits == 0 {
103            return Ok(0);
104        }
105
106        assert!(n_bits <= 32);
107
108        self.data.set_word_pos(self.bit_index / 64)?;
109        let in_word_offset = (self.bit_index % 64) as usize;
110
111        let res = if (in_word_offset + n_bits) <= 64 {
112            // single word access
113            let word = self.data.read_word()?.to_be();
114            (word << in_word_offset) >> (64 - n_bits)
115        } else {
116            // double word access
117            let high_word = self.data.read_word()?.to_be();
118            let low_word = self.data.read_word()?.to_be();
119            let shamt1 = 64 - n_bits;
120            let shamt2 = 128 - in_word_offset - n_bits;
121            ((high_word << in_word_offset) >> shamt1) | (low_word >> shamt2)
122        };
123        Ok(res as u32)
124    }
125
126    #[inline]
127    fn read_unary(&mut self) -> Result<u64, Self::Error> {
128        self.data.set_word_pos(self.bit_index / 64)?;
129        let in_word_offset = self.bit_index % 64;
130        let mut bits_in_word = 64 - in_word_offset;
131        let mut total = 0;
132
133        let mut word = self.data.read_word()?.to_be();
134        word <<= in_word_offset;
135        loop {
136            let zeros = word.leading_zeros() as u64;
137            // the unary code fits in the word
138            if zeros < bits_in_word {
139                self.bit_index += total + zeros + 1;
140                return Ok(total + zeros);
141            }
142            total += bits_in_word;
143            bits_in_word = 64;
144            word = self.data.read_word()?.to_be();
145        }
146    }
147
148    #[inline(always)]
149    fn skip_bits_after_peek(&mut self, n: usize) {
150        self.bit_index += n as u64;
151    }
152}
153
154impl<WR: WordSeek, RP: ReadParams> BitSeek for BitReader<LE, WR, RP> {
155    type Error = Infallible;
156
157    fn bit_pos(&mut self) -> Result<u64, Self::Error> {
158        Ok(self.bit_index)
159    }
160
161    fn set_bit_pos(&mut self, bit_index: u64) -> Result<(), Self::Error> {
162        self.bit_index = bit_index;
163        Ok(())
164    }
165}
166
167impl<WR: WordSeek, RP: ReadParams> BitSeek for BitReader<BE, WR, RP> {
168    type Error = Infallible;
169
170    fn bit_pos(&mut self) -> Result<u64, Self::Error> {
171        Ok(self.bit_index)
172    }
173
174    fn set_bit_pos(&mut self, bit_index: u64) -> Result<(), Self::Error> {
175        self.bit_index = bit_index;
176        Ok(())
177    }
178}
179
180impl<
181    E: Error + Send + Sync + 'static,
182    WR: WordRead<Error = E, Word = u64> + WordSeek<Error = E>,
183    RP: ReadParams,
184> BitRead<LE> for BitReader<LE, WR, RP>
185{
186    type Error = <WR as WordRead>::Error;
187    type PeekWord = u32;
188
189    #[inline]
190    fn skip_bits(&mut self, n_bits: usize) -> Result<(), Self::Error> {
191        self.bit_index += n_bits as u64;
192        Ok(())
193    }
194
195    #[inline]
196    fn read_bits(&mut self, n_bits: usize) -> Result<u64, Self::Error> {
197        #[cfg(feature = "checks")]
198        assert!(n_bits <= 64);
199
200        if n_bits == 0 {
201            return Ok(0);
202        }
203
204        self.data.set_word_pos(self.bit_index / 64)?;
205        let in_word_offset = (self.bit_index % 64) as usize;
206
207        let res = if (in_word_offset + n_bits) <= 64 {
208            // single word access
209            let word = self.data.read_word()?.to_le();
210            let shamt = 64 - n_bits;
211            (word << (shamt - in_word_offset)) >> shamt
212        } else {
213            // double word access
214            let low_word = self.data.read_word()?.to_le();
215            let high_word = self.data.read_word()?.to_le();
216            let shamt1 = 128 - in_word_offset - n_bits;
217            let shamt2 = 64 - n_bits;
218            ((high_word << shamt1) >> shamt2) | (low_word >> in_word_offset)
219        };
220        self.bit_index += n_bits as u64;
221        Ok(res)
222    }
223
224    #[inline]
225    fn peek_bits(&mut self, n_bits: usize) -> Result<u32, Self::Error> {
226        if n_bits == 0 {
227            return Ok(0);
228        }
229
230        assert!(n_bits <= 32);
231
232        self.data.set_word_pos(self.bit_index / 64)?;
233        let in_word_offset = (self.bit_index % 64) as usize;
234
235        let res = if (in_word_offset + n_bits) <= 64 {
236            // single word access
237            let word = self.data.read_word()?.to_le();
238            let shamt = 64 - n_bits;
239            (word << (shamt - in_word_offset)) >> shamt
240        } else {
241            // double word access
242            let low_word = self.data.read_word()?.to_le();
243            let high_word = self.data.read_word()?.to_le();
244            let shamt1 = 128 - in_word_offset - n_bits;
245            let shamt2 = 64 - n_bits;
246            ((high_word << shamt1) >> shamt2) | (low_word >> in_word_offset)
247        };
248        Ok(res as u32)
249    }
250
251    #[inline]
252    fn read_unary(&mut self) -> Result<u64, Self::Error> {
253        self.data.set_word_pos(self.bit_index / 64)?;
254        let in_word_offset = self.bit_index % 64;
255        let mut bits_in_word = 64 - in_word_offset;
256        let mut total = 0;
257
258        let mut word = self.data.read_word()?.to_le();
259        word >>= in_word_offset;
260        loop {
261            let zeros = word.trailing_zeros() as u64;
262            // the unary code fits in the word
263            if zeros < bits_in_word {
264                self.bit_index += total + zeros + 1;
265                return Ok(total + zeros);
266            }
267            total += bits_in_word;
268            bits_in_word = 64;
269            word = self.data.read_word()?.to_le();
270        }
271    }
272
273    #[inline(always)]
274    fn skip_bits_after_peek(&mut self, n: usize) {
275        self.bit_index += n as u64;
276    }
277}
278
#[cfg(feature = "std")]
impl<E, WR, RP> std::io::Read for BitReader<LE, WR, RP>
where
    E: Error + Send + Sync + 'static,
    WR: WordRead<Error = E, Word = u64> + WordSeek<Error = E>,
    RP: ReadParams,
{
    /// Fills `buf` with the next `8 * buf.len()` bits of the stream, packed
    /// into bytes in little-endian order.
    ///
    /// Bits are consumed with one 64-bit read per 8-byte chunk, followed by a
    /// single shorter read for the trailing bytes. A failed `read_bits` does
    /// not advance the bit position. If it fails after some bytes were
    /// already produced, the number of bytes produced so far is returned and
    /// the failing read is retried by the next call. An error is returned
    /// only when no byte could be read, as required by the
    /// [`std::io::Read::read`] contract. Errors have kind
    /// [`std::io::ErrorKind::UnexpectedEof`] and carry the underlying error as
    /// their source.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        // Report partial progress if there is any; otherwise surface the error.
        let on_error = |produced: usize, err: E| {
            if produced > 0 {
                Ok(produced)
            } else {
                Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, err))
            }
        };

        let mut filled = 0;
        let mut chunks = buf.chunks_exact_mut(8);
        for chunk in &mut chunks {
            match self.read_bits(64) {
                Ok(word) => chunk.copy_from_slice(&word.to_le_bytes()),
                Err(err) => return on_error(filled, err),
            }
            filled += 8;
        }

        let tail = chunks.into_remainder();
        if !tail.is_empty() {
            let n_bytes = tail.len();
            match self.read_bits(n_bytes * 8) {
                // The value is right-aligned, so its low bytes come first.
                Ok(word) => tail.copy_from_slice(&word.to_le_bytes()[..n_bytes]),
                Err(err) => return on_error(filled, err),
            }
            filled += n_bytes;
        }

        Ok(filled)
    }
}
307
#[cfg(feature = "std")]
impl<E, WR, RP> std::io::Read for BitReader<BE, WR, RP>
where
    E: Error + Send + Sync + 'static,
    WR: WordRead<Error = E, Word = u64> + WordSeek<Error = E>,
    RP: ReadParams,
{
    /// Fills `buf` with the next `8 * buf.len()` bits of the stream, packed
    /// into bytes in big-endian order.
    ///
    /// Bits are consumed with one 64-bit read per 8-byte chunk, followed by a
    /// single shorter read for the trailing bytes. A failed `read_bits` does
    /// not advance the bit position. If it fails after some bytes were
    /// already produced, the number of bytes produced so far is returned and
    /// the failing read is retried by the next call. An error is returned
    /// only when no byte could be read, as required by the
    /// [`std::io::Read::read`] contract. Errors have kind
    /// [`std::io::ErrorKind::UnexpectedEof`] and carry the underlying error as
    /// their source.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        // Report partial progress if there is any; otherwise surface the error.
        let on_error = |produced: usize, err: E| {
            if produced > 0 {
                Ok(produced)
            } else {
                Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, err))
            }
        };

        let mut filled = 0;
        let mut chunks = buf.chunks_exact_mut(8);
        for chunk in &mut chunks {
            match self.read_bits(64) {
                Ok(word) => chunk.copy_from_slice(&word.to_be_bytes()),
                Err(err) => return on_error(filled, err),
            }
            filled += 8;
        }

        let tail = chunks.into_remainder();
        if !tail.is_empty() {
            let n_bytes = tail.len();
            match self.read_bits(n_bytes * 8) {
                // The value is right-aligned, so the meaningful bytes are the last ones.
                Ok(word) => tail.copy_from_slice(&word.to_be_bytes()[8 - n_bytes..]),
                Err(err) => return on_error(filled, err),
            }
            filled += n_bytes;
        }

        Ok(filled)
    }
}