Skip to main content

dsi_bitstream/utils/
count.rs

1/*
2 * SPDX-FileCopyrightText: 2023 Sebastiano Vigna
3 *
4 * SPDX-License-Identifier: Apache-2.0 OR LGPL-2.1-or-later
5 */
6
7use crate::{
8    prelude::{
9        DeltaRead, DeltaWrite, GammaRead, GammaWrite, OmegaRead, OmegaWrite, PiRead, PiWrite,
10        ZetaRead, ZetaWrite, len_delta, len_gamma, len_omega, len_pi, len_zeta,
11    },
12    traits::*,
13};
14#[cfg(feature = "mem_dbg")]
15use mem_dbg::{MemDbg, MemSize};
16
17/// A wrapper around a [`BitWrite`] that keeps track of the number
18/// of bits written and optionally prints on standard error the
19/// operations performed on the stream.
20#[derive(Debug, Clone)]
21#[cfg_attr(feature = "mem_dbg", derive(MemDbg, MemSize))]
22pub struct CountBitWriter<E: Endianness, BW: BitWrite<E>, const PRINT: bool = false> {
23    bit_write: BW,
24    /// The number of bits written so far on the underlying [`BitWrite`].
25    pub bits_written: usize,
26    _marker: core::marker::PhantomData<E>,
27}
28
29impl<E: Endianness, BW: BitWrite<E>, const PRINT: bool> CountBitWriter<E, BW, PRINT> {
30    #[must_use]
31    pub const fn new(bit_write: BW) -> Self {
32        Self {
33            bit_write,
34            bits_written: 0,
35            _marker: core::marker::PhantomData,
36        }
37    }
38
39    /// Consumes this writer and returns the underlying [`BitWrite`].
40    #[must_use]
41    pub fn into_inner(self) -> BW {
42        self.bit_write
43    }
44}
45
46impl<E: Endianness, BW: BitWrite<E>, const PRINT: bool> BitWrite<E>
47    for CountBitWriter<E, BW, PRINT>
48{
49    type Error = <BW as BitWrite<E>>::Error;
50
51    fn write_bits(&mut self, value: u64, num_bits: usize) -> Result<usize, Self::Error> {
52        self.bit_write.write_bits(value, num_bits).inspect(|x| {
53            self.bits_written += *x;
54            if PRINT {
55                #[cfg(feature = "std")]
56                eprintln!(
57                    "write_bits({:#016x}, {}) = {} (total = {})",
58                    value, num_bits, x, self.bits_written
59                );
60            }
61        })
62    }
63
64    fn write_unary(&mut self, n: u64) -> Result<usize, Self::Error> {
65        self.bit_write.write_unary(n).inspect(|x| {
66            self.bits_written += *x;
67            if PRINT {
68                #[cfg(feature = "std")]
69                eprintln!("write_unary({}) = {} (total = {})", n, x, self.bits_written);
70            }
71        })
72    }
73
74    fn flush(&mut self) -> Result<usize, Self::Error> {
75        self.bit_write.flush().inspect(|x| {
76            self.bits_written += *x;
77            if PRINT {
78                #[cfg(feature = "std")]
79                eprintln!("flush() = {} (total = {})", x, self.bits_written);
80            }
81        })
82    }
83}
84
85impl<E: Endianness, BW: BitWrite<E> + GammaWrite<E>, const PRINT: bool> GammaWrite<E>
86    for CountBitWriter<E, BW, PRINT>
87{
88    fn write_gamma(&mut self, n: u64) -> Result<usize, BW::Error> {
89        self.bit_write.write_gamma(n).inspect(|x| {
90            self.bits_written += *x;
91            if PRINT {
92                #[cfg(feature = "std")]
93                eprintln!("write_gamma({}) = {} (total = {})", n, x, self.bits_written);
94            }
95        })
96    }
97}
98
99impl<E: Endianness, BW: BitWrite<E> + DeltaWrite<E>, const PRINT: bool> DeltaWrite<E>
100    for CountBitWriter<E, BW, PRINT>
101{
102    fn write_delta(&mut self, n: u64) -> Result<usize, BW::Error> {
103        self.bit_write.write_delta(n).inspect(|x| {
104            self.bits_written += *x;
105            if PRINT {
106                #[cfg(feature = "std")]
107                eprintln!("write_delta({}) = {} (total = {})", n, x, self.bits_written);
108            }
109        })
110    }
111}
112
113impl<E: Endianness, BW: BitWrite<E> + ZetaWrite<E>, const PRINT: bool> ZetaWrite<E>
114    for CountBitWriter<E, BW, PRINT>
115{
116    fn write_zeta(&mut self, n: u64, k: usize) -> Result<usize, BW::Error> {
117        self.bit_write.write_zeta(n, k).inspect(|x| {
118            self.bits_written += *x;
119            if PRINT {
120                #[cfg(feature = "std")]
121                eprintln!(
122                    "write_zeta({}, {}) = {} (total = {})",
123                    n, k, x, self.bits_written
124                );
125            }
126        })
127    }
128
129    fn write_zeta3(&mut self, n: u64) -> Result<usize, BW::Error> {
130        self.bit_write.write_zeta3(n).inspect(|x| {
131            self.bits_written += *x;
132            if PRINT {
133                #[cfg(feature = "std")]
134                eprintln!("write_zeta3({}) = {} (total = {})", n, x, self.bits_written);
135            }
136        })
137    }
138}
139
140impl<E: Endianness, BW: BitWrite<E> + OmegaWrite<E>, const PRINT: bool> OmegaWrite<E>
141    for CountBitWriter<E, BW, PRINT>
142{
143    fn write_omega(&mut self, n: u64) -> Result<usize, BW::Error> {
144        self.bit_write.write_omega(n).inspect(|x| {
145            self.bits_written += *x;
146            if PRINT {
147                #[cfg(feature = "std")]
148                eprintln!("write_omega({}) = {} (total = {})", n, x, self.bits_written);
149            }
150        })
151    }
152}
153
154impl<E: Endianness, BW: BitWrite<E> + PiWrite<E>, const PRINT: bool> PiWrite<E>
155    for CountBitWriter<E, BW, PRINT>
156{
157    fn write_pi(&mut self, n: u64, k: usize) -> Result<usize, BW::Error> {
158        self.bit_write.write_pi(n, k).inspect(|x| {
159            self.bits_written += *x;
160            if PRINT {
161                #[cfg(feature = "std")]
162                eprintln!(
163                    "write_pi({}, {}) = {} (total = {})",
164                    n, k, x, self.bits_written
165                );
166            }
167        })
168    }
169
170    fn write_pi2(&mut self, n: u64) -> Result<usize, BW::Error> {
171        self.bit_write.write_pi2(n).inspect(|x| {
172            self.bits_written += *x;
173            if PRINT {
174                #[cfg(feature = "std")]
175                eprintln!("write_pi2({}) = {} (total = {})", n, x, self.bits_written);
176            }
177        })
178    }
179}
180
181impl<E: Endianness, BW: BitWrite<E> + BitSeek, const PRINT: bool> BitSeek
182    for CountBitWriter<E, BW, PRINT>
183{
184    type Error = <BW as BitSeek>::Error;
185
186    fn bit_pos(&mut self) -> Result<u64, Self::Error> {
187        self.bit_write.bit_pos()
188    }
189
190    fn set_bit_pos(&mut self, bit_pos: u64) -> Result<(), Self::Error> {
191        self.bit_write.set_bit_pos(bit_pos)
192    }
193}
194
195/// A wrapper around a [`BitRead`] that keeps track of the number
196/// of bits read and optionally prints on standard error the
197/// operations performed on the stream.
198#[derive(Debug, Clone)]
199#[cfg_attr(feature = "mem_dbg", derive(MemDbg, MemSize))]
200pub struct CountBitReader<E: Endianness, BR: BitRead<E>, const PRINT: bool = false> {
201    bit_read: BR,
202    /// The number of bits read (or skipped) so far from the underlying [`BitRead`].
203    pub bits_read: usize,
204    _marker: core::marker::PhantomData<E>,
205}
206
207impl<E: Endianness, BR: BitRead<E>, const PRINT: bool> CountBitReader<E, BR, PRINT> {
208    #[must_use]
209    pub const fn new(bit_read: BR) -> Self {
210        Self {
211            bit_read,
212            bits_read: 0,
213            _marker: core::marker::PhantomData,
214        }
215    }
216
217    /// Consumes this reader and returns the underlying [`BitRead`].
218    #[must_use]
219    pub fn into_inner(self) -> BR {
220        self.bit_read
221    }
222}
223
224impl<E: Endianness, BR: BitRead<E>, const PRINT: bool> BitRead<E> for CountBitReader<E, BR, PRINT> {
225    type Error = <BR as BitRead<E>>::Error;
226    type PeekWord = BR::PeekWord;
227    const PEEK_BITS: usize = BR::PEEK_BITS;
228
229    fn read_bits(&mut self, num_bits: usize) -> Result<u64, Self::Error> {
230        self.bit_read.read_bits(num_bits).inspect(|x| {
231            let _ = x;
232            self.bits_read += num_bits;
233            if PRINT {
234                #[cfg(feature = "std")]
235                eprintln!(
236                    "read_bits({}) = {:#016x} (total = {})",
237                    num_bits, x, self.bits_read
238                );
239            }
240        })
241    }
242
243    fn read_unary(&mut self) -> Result<u64, Self::Error> {
244        self.bit_read.read_unary().inspect(|x| {
245            self.bits_read += *x as usize + 1;
246            if PRINT {
247                #[cfg(feature = "std")]
248                eprintln!("read_unary() = {} (total = {})", x, self.bits_read);
249            }
250        })
251    }
252
253    fn peek_bits(&mut self, n_bits: usize) -> Result<Self::PeekWord, Self::Error> {
254        self.bit_read.peek_bits(n_bits)
255    }
256
257    fn skip_bits(&mut self, n_bits: usize) -> Result<(), Self::Error> {
258        self.bits_read += n_bits;
259        if PRINT {
260            #[cfg(feature = "std")]
261            eprintln!("skip_bits({}) (total = {})", n_bits, self.bits_read);
262        }
263        self.bit_read.skip_bits(n_bits)
264    }
265
266    fn skip_bits_after_peek(&mut self, n: usize) {
267        self.bit_read.skip_bits_after_peek(n)
268    }
269}
270
271impl<E: Endianness, BR: BitRead<E> + GammaRead<E>, const PRINT: bool> GammaRead<E>
272    for CountBitReader<E, BR, PRINT>
273{
274    fn read_gamma(&mut self) -> Result<u64, BR::Error> {
275        self.bit_read.read_gamma().inspect(|x| {
276            self.bits_read += len_gamma(*x);
277            if PRINT {
278                #[cfg(feature = "std")]
279                eprintln!("read_gamma() = {} (total = {})", x, self.bits_read);
280            }
281        })
282    }
283}
284
285impl<E: Endianness, BR: BitRead<E> + DeltaRead<E>, const PRINT: bool> DeltaRead<E>
286    for CountBitReader<E, BR, PRINT>
287{
288    fn read_delta(&mut self) -> Result<u64, BR::Error> {
289        self.bit_read.read_delta().inspect(|x| {
290            self.bits_read += len_delta(*x);
291            if PRINT {
292                #[cfg(feature = "std")]
293                eprintln!("read_delta() = {} (total = {})", x, self.bits_read);
294            }
295        })
296    }
297}
298
299impl<E: Endianness, BR: BitRead<E> + ZetaRead<E>, const PRINT: bool> ZetaRead<E>
300    for CountBitReader<E, BR, PRINT>
301{
302    fn read_zeta(&mut self, k: usize) -> Result<u64, BR::Error> {
303        self.bit_read.read_zeta(k).inspect(|x| {
304            self.bits_read += len_zeta(*x, k);
305            if PRINT {
306                #[cfg(feature = "std")]
307                eprintln!("read_zeta({}) = {} (total = {})", k, x, self.bits_read);
308            }
309        })
310    }
311
312    fn read_zeta3(&mut self) -> Result<u64, BR::Error> {
313        self.bit_read.read_zeta3().inspect(|x| {
314            self.bits_read += len_zeta(*x, 3);
315            if PRINT {
316                #[cfg(feature = "std")]
317                eprintln!("read_zeta3() = {} (total = {})", x, self.bits_read);
318            }
319        })
320    }
321}
322
323impl<E: Endianness, BR: BitRead<E> + OmegaRead<E>, const PRINT: bool> OmegaRead<E>
324    for CountBitReader<E, BR, PRINT>
325{
326    fn read_omega(&mut self) -> Result<u64, BR::Error> {
327        self.bit_read.read_omega().inspect(|x| {
328            self.bits_read += len_omega(*x);
329            if PRINT {
330                #[cfg(feature = "std")]
331                eprintln!("read_omega() = {} (total = {})", x, self.bits_read);
332            }
333        })
334    }
335}
336
337impl<E: Endianness, BR: BitRead<E> + PiRead<E>, const PRINT: bool> PiRead<E>
338    for CountBitReader<E, BR, PRINT>
339{
340    fn read_pi(&mut self, k: usize) -> Result<u64, BR::Error> {
341        self.bit_read.read_pi(k).inspect(|x| {
342            self.bits_read += len_pi(*x, k);
343            if PRINT {
344                #[cfg(feature = "std")]
345                eprintln!("read_pi({}) = {} (total = {})", k, x, self.bits_read);
346            }
347        })
348    }
349
350    fn read_pi2(&mut self) -> Result<u64, BR::Error> {
351        self.bit_read.read_pi2().inspect(|x| {
352            self.bits_read += len_pi(*x, 2);
353            if PRINT {
354                #[cfg(feature = "std")]
355                eprintln!("read_pi2() = {} (total = {})", x, self.bits_read);
356            }
357        })
358    }
359}
360
361impl<E: Endianness, BR: BitRead<E> + BitSeek, const PRINT: bool> BitSeek
362    for CountBitReader<E, BR, PRINT>
363{
364    type Error = <BR as BitSeek>::Error;
365
366    fn bit_pos(&mut self) -> Result<u64, Self::Error> {
367        self.bit_read.bit_pos()
368    }
369
370    fn set_bit_pos(&mut self, bit_pos: u64) -> Result<(), Self::Error> {
371        self.bit_read.set_bit_pos(bit_pos)
372    }
373}
374
#[cfg(test)]
#[cfg(feature = "std")]
mod tests {
    use super::*;
    use crate::prelude::*;

    /// Round-trips a sequence of codes through a `CountBitWriter` and a
    /// `CountBitReader`, checking the running bit totals after every call.
    #[test]
    fn test_count() -> Result<(), Box<dyn core::error::Error + Send + Sync + 'static>> {
        let mut buffer = <Vec<u64>>::new();

        // ---- Writing side ----
        let inner_writer = <BufBitWriter<LE, _>>::new(MemWordWriterVec::new(&mut buffer));
        let mut writer = CountBitWriter::<_, _, true>::new(inner_writer);

        writer.write_unary(5)?;
        assert_eq!(writer.bits_written, 6);
        writer.write_unary(100)?;
        assert_eq!(writer.bits_written, 107);
        writer.write_bits(1, 20)?;
        assert_eq!(writer.bits_written, 127);
        writer.write_bits(1, 33)?;
        assert_eq!(writer.bits_written, 160);
        writer.write_gamma(2)?;
        assert_eq!(writer.bits_written, 163);
        writer.write_delta(1)?;
        assert_eq!(writer.bits_written, 167);
        writer.write_zeta(0, 4)?;
        assert_eq!(writer.bits_written, 171);
        writer.write_zeta3(0)?;
        assert_eq!(writer.bits_written, 174);

        // From here on, expected totals are derived from the `len_*` helpers
        // and remembered for the reading side.
        writer.write_omega(3)?;
        assert_eq!(writer.bits_written, 174 + len_omega(3));
        let after_omega = writer.bits_written;
        writer.write_pi(5, 3)?;
        assert_eq!(writer.bits_written, after_omega + len_pi(5, 3));
        let after_pi = writer.bits_written;
        writer.write_pi2(7)?;
        assert_eq!(writer.bits_written, after_pi + len_pi(7, 2));
        let after_pi2 = writer.bits_written;

        writer.flush()?;
        // Dropping the writer releases its mutable borrow of `buffer`.
        drop(writer);

        // ---- Reading side ----
        let inner_reader = <BufBitReader<LE, _>>::new(MemWordReader::<u64, _>::new(&buffer));
        let mut reader = CountBitReader::<_, _, true>::new(inner_reader);

        assert_eq!(reader.peek_bits(5)?, 0);
        assert_eq!(reader.read_unary()?, 5);
        assert_eq!(reader.bits_read, 6);
        assert_eq!(reader.read_unary()?, 100);
        assert_eq!(reader.bits_read, 107);
        assert_eq!(reader.read_bits(20)?, 1);
        assert_eq!(reader.bits_read, 127);
        // The 33-bit field is skipped rather than read.
        reader.skip_bits(33)?;
        assert_eq!(reader.bits_read, 160);
        assert_eq!(reader.read_gamma()?, 2);
        assert_eq!(reader.bits_read, 163);
        assert_eq!(reader.read_delta()?, 1);
        assert_eq!(reader.bits_read, 167);
        assert_eq!(reader.read_zeta(4)?, 0);
        assert_eq!(reader.bits_read, 171);
        assert_eq!(reader.read_zeta3()?, 0);
        assert_eq!(reader.bits_read, 174);
        assert_eq!(reader.read_omega()?, 3);
        assert_eq!(reader.bits_read, after_omega);
        assert_eq!(reader.read_pi(3)?, 5);
        assert_eq!(reader.bits_read, after_pi);
        assert_eq!(reader.read_pi2()?, 7);
        assert_eq!(reader.bits_read, after_pi2);

        Ok(())
    }
}