dsi_bitstream/utils/
count.rs

1/*
2 * SPDX-FileCopyrightText: 2023 Sebastiano Vigna
3 *
4 * SPDX-License-Identifier: Apache-2.0 OR LGPL-2.1-or-later
5 */
6
7use crate::{
8    prelude::{
9        len_delta, len_gamma, len_zeta, DeltaRead, DeltaWrite, GammaRead, GammaWrite, ZetaRead,
10        ZetaWrite,
11    },
12    traits::*,
13};
14#[cfg(feature = "mem_dbg")]
15use mem_dbg::{MemDbg, MemSize};
16
/// A wrapper around a [`BitWrite`] that keeps track of the number of
/// bits written and optionally prints on standard error the operations performed on the stream.
///
/// Every operation is delegated to the wrapped writer; the counter is
/// increased by the bit count that the wrapped writer reports for each
/// successful operation. Tracing is controlled at compile time by the
/// `PRINT` const parameter (default `false`) and is emitted only when the
/// `std` feature is enabled.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "mem_dbg", derive(MemDbg, MemSize))]
pub struct CountBitWriter<E: Endianness, BW: BitWrite<E>, const PRINT: bool = false> {
    /// The wrapped writer that actually performs the operations.
    bit_write: BW,
    /// The number of bits written so far on the underlying [`BitWrite`].
    pub bits_written: usize,
    /// Zero-sized marker binding the endianness parameter `E` to the type.
    _marker: core::marker::PhantomData<E>,
}
27
28impl<E: Endianness, BW: BitWrite<E>, const PRINT: bool> CountBitWriter<E, BW, PRINT> {
29    pub fn new(bit_write: BW) -> Self {
30        Self {
31            bit_write,
32            bits_written: 0,
33            _marker: core::marker::PhantomData,
34        }
35    }
36
37    pub fn into_inner(self) -> BW {
38        self.bit_write
39    }
40}
41
42impl<E: Endianness, BW: BitWrite<E>, const PRINT: bool> BitWrite<E>
43    for CountBitWriter<E, BW, PRINT>
44{
45    type Error = <BW as BitWrite<E>>::Error;
46
47    fn write_bits(&mut self, value: u64, n_bits: usize) -> Result<usize, Self::Error> {
48        self.bit_write.write_bits(value, n_bits).inspect(|x| {
49            self.bits_written += *x;
50            if PRINT {
51                #[cfg(feature = "std")]
52                eprintln!(
53                    "write_bits({:#016x}, {}) = {} (total = {})",
54                    value, n_bits, x, self.bits_written
55                );
56            }
57        })
58    }
59
60    fn write_unary(&mut self, value: u64) -> Result<usize, Self::Error> {
61        self.bit_write.write_unary(value).inspect(|x| {
62            self.bits_written += *x;
63            if PRINT {
64                #[cfg(feature = "std")]
65                eprintln!(
66                    "write_unary({}) = {} (total = {})",
67                    value, x, self.bits_written
68                );
69            }
70        })
71    }
72
73    fn flush(&mut self) -> Result<usize, Self::Error> {
74        self.bit_write.flush().inspect(|x| {
75            self.bits_written += *x;
76            if PRINT {
77                #[cfg(feature = "std")]
78                eprintln!("flush() = {} (total = {})", x, self.bits_written);
79            }
80        })
81    }
82}
83
84impl<E: Endianness, BW: BitWrite<E> + GammaWrite<E>, const PRINT: bool> GammaWrite<E>
85    for CountBitWriter<E, BW, PRINT>
86{
87    fn write_gamma(&mut self, value: u64) -> Result<usize, BW::Error> {
88        self.bit_write.write_gamma(value).inspect(|x| {
89            self.bits_written += *x;
90            if PRINT {
91                #[cfg(feature = "std")]
92                eprintln!(
93                    "write_gamma({}) = {} (total = {})",
94                    value, x, self.bits_written
95                );
96            }
97        })
98    }
99}
100
101impl<E: Endianness, BW: BitWrite<E> + DeltaWrite<E>, const PRINT: bool> DeltaWrite<E>
102    for CountBitWriter<E, BW, PRINT>
103{
104    fn write_delta(&mut self, value: u64) -> Result<usize, BW::Error> {
105        self.bit_write.write_delta(value).inspect(|x| {
106            self.bits_written += *x;
107            if PRINT {
108                #[cfg(feature = "std")]
109                eprintln!(
110                    "write_delta({}) = {} (total = {})",
111                    value, x, self.bits_written
112                );
113            }
114        })
115    }
116}
117
118impl<E: Endianness, BW: BitWrite<E> + ZetaWrite<E>, const PRINT: bool> ZetaWrite<E>
119    for CountBitWriter<E, BW, PRINT>
120{
121    fn write_zeta(&mut self, value: u64, k: usize) -> Result<usize, BW::Error> {
122        self.bit_write.write_zeta(value, k).inspect(|x| {
123            self.bits_written += *x;
124            if PRINT {
125                #[cfg(feature = "std")]
126                eprintln!(
127                    "write_zeta({}, {}) = {} (total = {})",
128                    value, x, k, self.bits_written
129                );
130            }
131        })
132    }
133
134    fn write_zeta3(&mut self, value: u64) -> Result<usize, BW::Error> {
135        self.bit_write.write_zeta3(value).inspect(|x| {
136            self.bits_written += *x;
137            if PRINT {
138                #[cfg(feature = "std")]
139                eprintln!(
140                    "write_zeta({}) = {} (total = {})",
141                    value, x, self.bits_written
142                );
143            }
144        })
145    }
146}
147
148impl<E: Endianness, BR: BitWrite<E> + BitSeek, const PRINT: bool> BitSeek
149    for CountBitWriter<E, BR, PRINT>
150{
151    type Error = <BR as BitSeek>::Error;
152
153    fn bit_pos(&mut self) -> Result<u64, Self::Error> {
154        self.bit_write.bit_pos()
155    }
156
157    fn set_bit_pos(&mut self, bit_pos: u64) -> Result<(), Self::Error> {
158        self.bit_write.set_bit_pos(bit_pos)
159    }
160}
161
/// A wrapper around a [`BitRead`] that keeps track of the number of
/// bits read and optionally prints on standard error the operations performed on the stream.
///
/// Every operation is delegated to the wrapped reader. Tracing is
/// controlled at compile time by the `PRINT` const parameter (default
/// `false`) and is emitted only when the `std` feature is enabled.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "mem_dbg", derive(MemDbg, MemSize))]
pub struct CountBitReader<E: Endianness, BR: BitRead<E>, const PRINT: bool = false> {
    /// The wrapped reader that actually performs the operations.
    bit_read: BR,
    /// The number of bits read (or skipped) so far from the underlying [`BitRead`].
    pub bits_read: usize,
    /// Zero-sized marker binding the endianness parameter `E` to the type.
    _marker: core::marker::PhantomData<E>,
}
172
173impl<E: Endianness, BR: BitRead<E>, const PRINT: bool> CountBitReader<E, BR, PRINT> {
174    pub fn new(bit_read: BR) -> Self {
175        Self {
176            bit_read,
177            bits_read: 0,
178            _marker: core::marker::PhantomData,
179        }
180    }
181
182    pub fn into_inner(self) -> BR {
183        self.bit_read
184    }
185}
186
187impl<E: Endianness, BR: BitRead<E>, const PRINT: bool> BitRead<E> for CountBitReader<E, BR, PRINT> {
188    type Error = <BR as BitRead<E>>::Error;
189    type PeekWord = BR::PeekWord;
190
191    fn read_bits(&mut self, n_bits: usize) -> Result<u64, Self::Error> {
192        self.bit_read.read_bits(n_bits).inspect(|x| {
193            let _ = x;
194            self.bits_read += n_bits;
195            if PRINT {
196                #[cfg(feature = "std")]
197                eprintln!(
198                    "read_bits({}) = {:#016x} (total = {})",
199                    n_bits, x, self.bits_read
200                );
201            }
202        })
203    }
204
205    fn read_unary(&mut self) -> Result<u64, Self::Error> {
206        self.bit_read.read_unary().inspect(|x| {
207            self.bits_read += *x as usize + 1;
208            if PRINT {
209                #[cfg(feature = "std")]
210                eprintln!("read_unary() = {} (total = {})", x, self.bits_read);
211            }
212        })
213    }
214
215    fn peek_bits(&mut self, n_bits: usize) -> Result<Self::PeekWord, Self::Error> {
216        self.bit_read.peek_bits(n_bits)
217    }
218
219    fn skip_bits(&mut self, n_bits: usize) -> Result<(), Self::Error> {
220        self.bits_read += n_bits;
221        if PRINT {
222            #[cfg(feature = "std")]
223            eprintln!("skip_bits({}) (total = {})", n_bits, self.bits_read);
224        }
225        self.bit_read.skip_bits(n_bits)
226    }
227
228    fn skip_bits_after_peek(&mut self, n: usize) {
229        self.bit_read.skip_bits_after_peek(n)
230    }
231}
232
233impl<E: Endianness, BR: BitRead<E> + GammaRead<E>, const PRINT: bool> GammaRead<E>
234    for CountBitReader<E, BR, PRINT>
235{
236    fn read_gamma(&mut self) -> Result<u64, BR::Error> {
237        self.bit_read.read_gamma().inspect(|x| {
238            self.bits_read += len_gamma(*x);
239            if PRINT {
240                #[cfg(feature = "std")]
241                eprintln!("read_gamma() = {} (total = {})", x, self.bits_read);
242            }
243        })
244    }
245}
246
247impl<E: Endianness, BR: BitRead<E> + DeltaRead<E>, const PRINT: bool> DeltaRead<E>
248    for CountBitReader<E, BR, PRINT>
249{
250    fn read_delta(&mut self) -> Result<u64, BR::Error> {
251        self.bit_read.read_delta().inspect(|x| {
252            self.bits_read += len_delta(*x);
253            if PRINT {
254                #[cfg(feature = "std")]
255                eprintln!("read_delta() = {} (total = {})", x, self.bits_read);
256            }
257        })
258    }
259}
260
261impl<E: Endianness, BR: BitRead<E> + ZetaRead<E>, const PRINT: bool> ZetaRead<E>
262    for CountBitReader<E, BR, PRINT>
263{
264    fn read_zeta(&mut self, k: usize) -> Result<u64, BR::Error> {
265        self.bit_read.read_zeta(k).inspect(|x| {
266            self.bits_read += len_zeta(*x, k);
267            if PRINT {
268                #[cfg(feature = "std")]
269                eprintln!("read_zeta({}) = {} (total = {})", k, x, self.bits_read);
270            }
271        })
272    }
273
274    fn read_zeta3(&mut self) -> Result<u64, BR::Error> {
275        self.bit_read.read_zeta3().inspect(|x| {
276            self.bits_read += len_zeta(*x, 3);
277            if PRINT {
278                #[cfg(feature = "std")]
279                eprintln!("read_zeta3() = {} (total = {})", x, self.bits_read);
280            }
281        })
282    }
283}
284
285impl<E: Endianness, BR: BitRead<E> + BitSeek, const PRINT: bool> BitSeek
286    for CountBitReader<E, BR, PRINT>
287{
288    type Error = <BR as BitSeek>::Error;
289
290    fn bit_pos(&mut self) -> Result<u64, Self::Error> {
291        self.bit_read.bit_pos()
292    }
293
294    fn set_bit_pos(&mut self, bit_pos: u64) -> Result<(), Self::Error> {
295        self.bit_read.set_bit_pos(bit_pos)
296    }
297}
298
#[cfg(test)]
#[cfg(feature = "std")]
#[test]
fn test_count() -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
    use crate::prelude::*;

    let mut words = <Vec<u64>>::new();

    // Write phase: the writer lives in its own scope so that its mutable
    // borrow of `words` ends before the buffer is read back.
    {
        let inner = <BufBitWriter<LE, _>>::new(MemWordWriterVec::new(&mut words));
        let mut writer = CountBitWriter::<_, _, true>::new(inner);

        // After each operation, check the running total of written bits.
        writer.write_unary(5)?;
        assert_eq!(writer.bits_written, 6);
        writer.write_unary(100)?;
        assert_eq!(writer.bits_written, 107);
        writer.write_bits(1, 20)?;
        assert_eq!(writer.bits_written, 127);
        writer.write_bits(1, 33)?;
        assert_eq!(writer.bits_written, 160);
        writer.write_gamma(2)?;
        assert_eq!(writer.bits_written, 163);
        writer.write_delta(1)?;
        assert_eq!(writer.bits_written, 167);
        writer.write_zeta(0, 4)?;
        assert_eq!(writer.bits_written, 171);
        writer.write_zeta3(0)?;
        assert_eq!(writer.bits_written, 174);
        writer.flush()?;
    }

    // Read phase: replay the same sequence and check that the read counter
    // goes through the same totals as the write counter did.
    let inner = <BufBitReader<LE, _>>::new(MemWordReader::<u64, _>::new(&words));
    let mut reader = CountBitReader::<_, _, true>::new(inner);

    // The peek is not followed by a counter check; the next assertion on
    // `bits_read` shows that only the unary read has been counted.
    assert_eq!(reader.peek_bits(5)?, 0);
    assert_eq!(reader.read_unary()?, 5);
    assert_eq!(reader.bits_read, 6);
    assert_eq!(reader.read_unary()?, 100);
    assert_eq!(reader.bits_read, 107);
    assert_eq!(reader.read_bits(20)?, 1);
    assert_eq!(reader.bits_read, 127);
    // The 33-bit value is skipped rather than read; it is counted anyway.
    reader.skip_bits(33)?;
    assert_eq!(reader.bits_read, 160);
    assert_eq!(reader.read_gamma()?, 2);
    assert_eq!(reader.bits_read, 163);
    assert_eq!(reader.read_delta()?, 1);
    assert_eq!(reader.bits_read, 167);
    assert_eq!(reader.read_zeta(4)?, 0);
    assert_eq!(reader.bits_read, 171);
    assert_eq!(reader.read_zeta3()?, 0);
    assert_eq!(reader.bits_read, 174);

    Ok(())
}