dsi_bitstream/utils/
count.rs

/*
 * SPDX-FileCopyrightText: 2023 Sebastiano Vigna
 *
 * SPDX-License-Identifier: Apache-2.0 OR LGPL-2.1-or-later
 */
6
7use crate::{
8    prelude::{
9        len_delta, len_gamma, len_zeta, DeltaRead, DeltaWrite, GammaRead, GammaWrite, ZetaRead,
10        ZetaWrite,
11    },
12    traits::*,
13};
14#[cfg(feature = "mem_dbg")]
15use mem_dbg::{MemDbg, MemSize};
16
17/// A wrapper around a [`BitWrite`] that keeps track of the number of
18/// bits written and optionally prints on standard error the operations performed on the stream.
19#[derive(Debug, Clone)]
20#[cfg_attr(feature = "mem_dbg", derive(MemDbg, MemSize))]
21pub struct CountBitWriter<E: Endianness, BW: BitWrite<E>, const PRINT: bool = false> {
22    bit_write: BW,
23    /// The number of bits written so far on the underlying [`BitWrite`].
24    pub bits_written: usize,
25    _marker: core::marker::PhantomData<E>,
26}
27
28impl<E: Endianness, BW: BitWrite<E>, const PRINT: bool> CountBitWriter<E, BW, PRINT> {
29    pub fn new(bit_write: BW) -> Self {
30        Self {
31            bit_write,
32            bits_written: 0,
33            _marker: core::marker::PhantomData,
34        }
35    }
36
37    pub fn into_inner(self) -> BW {
38        self.bit_write
39    }
40}
41
42impl<E: Endianness, BW: BitWrite<E>, const PRINT: bool> BitWrite<E>
43    for CountBitWriter<E, BW, PRINT>
44{
45    type Error = <BW as BitWrite<E>>::Error;
46
47    fn write_bits(&mut self, value: u64, n_bits: usize) -> Result<usize, Self::Error> {
48        self.bit_write.write_bits(value, n_bits).inspect(|x| {
49            self.bits_written += *x;
50            if PRINT {
51                eprintln!(
52                    "write_bits({:#016x}, {}) = {} (total = {})",
53                    value, n_bits, x, self.bits_written
54                );
55            }
56        })
57    }
58
59    fn write_unary(&mut self, value: u64) -> Result<usize, Self::Error> {
60        self.bit_write.write_unary(value).inspect(|x| {
61            self.bits_written += *x;
62            if PRINT {
63                eprintln!(
64                    "write_unary({}) = {} (total = {})",
65                    value, x, self.bits_written
66                );
67            }
68        })
69    }
70
71    fn flush(&mut self) -> Result<usize, Self::Error> {
72        self.bit_write.flush().inspect(|x| {
73            self.bits_written += *x;
74            if PRINT {
75                eprintln!("flush() = {} (total = {})", x, self.bits_written);
76            }
77        })
78    }
79}
80
81impl<E: Endianness, BW: BitWrite<E> + GammaWrite<E>, const PRINT: bool> GammaWrite<E>
82    for CountBitWriter<E, BW, PRINT>
83{
84    fn write_gamma(&mut self, value: u64) -> Result<usize, BW::Error> {
85        self.bit_write.write_gamma(value).inspect(|x| {
86            self.bits_written += *x;
87            if PRINT {
88                eprintln!(
89                    "write_gamma({}) = {} (total = {})",
90                    value, x, self.bits_written
91                );
92            }
93        })
94    }
95}
96
97impl<E: Endianness, BW: BitWrite<E> + DeltaWrite<E>, const PRINT: bool> DeltaWrite<E>
98    for CountBitWriter<E, BW, PRINT>
99{
100    fn write_delta(&mut self, value: u64) -> Result<usize, BW::Error> {
101        self.bit_write.write_delta(value).inspect(|x| {
102            self.bits_written += *x;
103            if PRINT {
104                eprintln!(
105                    "write_delta({}) = {} (total = {})",
106                    value, x, self.bits_written
107                );
108            }
109        })
110    }
111}
112
113impl<E: Endianness, BW: BitWrite<E> + ZetaWrite<E>, const PRINT: bool> ZetaWrite<E>
114    for CountBitWriter<E, BW, PRINT>
115{
116    fn write_zeta(&mut self, value: u64, k: usize) -> Result<usize, BW::Error> {
117        self.bit_write.write_zeta(value, k).inspect(|x| {
118            self.bits_written += *x;
119            if PRINT {
120                eprintln!(
121                    "write_zeta({}, {}) = {} (total = {})",
122                    value, x, k, self.bits_written
123                );
124            }
125        })
126    }
127
128    fn write_zeta3(&mut self, value: u64) -> Result<usize, BW::Error> {
129        self.bit_write.write_zeta3(value).inspect(|x| {
130            self.bits_written += *x;
131            if PRINT {
132                eprintln!(
133                    "write_zeta({}) = {} (total = {})",
134                    value, x, self.bits_written
135                );
136            }
137        })
138    }
139}
140
141impl<E: Endianness, BR: BitWrite<E> + BitSeek, const PRINT: bool> BitSeek
142    for CountBitWriter<E, BR, PRINT>
143{
144    type Error = <BR as BitSeek>::Error;
145
146    fn bit_pos(&mut self) -> Result<u64, Self::Error> {
147        self.bit_write.bit_pos()
148    }
149
150    fn set_bit_pos(&mut self, bit_pos: u64) -> Result<(), Self::Error> {
151        self.bit_write.set_bit_pos(bit_pos)
152    }
153}
154
155/// A wrapper around a [`BitRead`] that keeps track of the number of
156/// bits read and optionally prints on standard error the operations performed on the stream.
157#[derive(Debug, Clone)]
158#[cfg_attr(feature = "mem_dbg", derive(MemDbg, MemSize))]
159pub struct CountBitReader<E: Endianness, BR: BitRead<E>, const PRINT: bool = false> {
160    bit_read: BR,
161    /// The number of bits read (or skipped) so far from the underlying [`BitRead`].
162    pub bits_read: usize,
163    _marker: core::marker::PhantomData<E>,
164}
165
166impl<E: Endianness, BR: BitRead<E>, const PRINT: bool> CountBitReader<E, BR, PRINT> {
167    pub fn new(bit_read: BR) -> Self {
168        Self {
169            bit_read,
170            bits_read: 0,
171            _marker: core::marker::PhantomData,
172        }
173    }
174
175    pub fn into_inner(self) -> BR {
176        self.bit_read
177    }
178}
179
180impl<E: Endianness, BR: BitRead<E>, const PRINT: bool> BitRead<E> for CountBitReader<E, BR, PRINT> {
181    type Error = <BR as BitRead<E>>::Error;
182    type PeekWord = BR::PeekWord;
183
184    fn read_bits(&mut self, n_bits: usize) -> Result<u64, Self::Error> {
185        self.bit_read.read_bits(n_bits).inspect(|x| {
186            self.bits_read += n_bits;
187            if PRINT {
188                eprintln!(
189                    "read_bits({}) = {:#016x} (total = {})",
190                    n_bits, x, self.bits_read
191                );
192            }
193        })
194    }
195
196    fn read_unary(&mut self) -> Result<u64, Self::Error> {
197        self.bit_read.read_unary().inspect(|x| {
198            self.bits_read += *x as usize + 1;
199            if PRINT {
200                eprintln!("read_unary() = {} (total = {})", x, self.bits_read);
201            }
202        })
203    }
204
205    fn peek_bits(&mut self, n_bits: usize) -> Result<Self::PeekWord, Self::Error> {
206        self.bit_read.peek_bits(n_bits)
207    }
208
209    fn skip_bits(&mut self, n_bits: usize) -> Result<(), Self::Error> {
210        self.bits_read += n_bits;
211        if PRINT {
212            eprintln!("skip_bits({}) (total = {})", n_bits, self.bits_read);
213        }
214        self.bit_read.skip_bits(n_bits)
215    }
216
217    fn skip_bits_after_peek(&mut self, n: usize) {
218        self.bit_read.skip_bits_after_peek(n)
219    }
220}
221
222impl<E: Endianness, BR: BitRead<E> + GammaRead<E>, const PRINT: bool> GammaRead<E>
223    for CountBitReader<E, BR, PRINT>
224{
225    fn read_gamma(&mut self) -> Result<u64, BR::Error> {
226        self.bit_read.read_gamma().inspect(|x| {
227            self.bits_read += len_gamma(*x);
228            if PRINT {
229                eprintln!("read_gamma() = {} (total = {})", x, self.bits_read);
230            }
231        })
232    }
233}
234
235impl<E: Endianness, BR: BitRead<E> + DeltaRead<E>, const PRINT: bool> DeltaRead<E>
236    for CountBitReader<E, BR, PRINT>
237{
238    fn read_delta(&mut self) -> Result<u64, BR::Error> {
239        self.bit_read.read_delta().inspect(|x| {
240            self.bits_read += len_delta(*x);
241            if PRINT {
242                eprintln!("read_delta() = {} (total = {})", x, self.bits_read);
243            }
244        })
245    }
246}
247
248impl<E: Endianness, BR: BitRead<E> + ZetaRead<E>, const PRINT: bool> ZetaRead<E>
249    for CountBitReader<E, BR, PRINT>
250{
251    fn read_zeta(&mut self, k: usize) -> Result<u64, BR::Error> {
252        self.bit_read.read_zeta(k).inspect(|x| {
253            self.bits_read += len_zeta(*x, k);
254            if PRINT {
255                eprintln!("read_zeta({}) = {} (total = {})", k, x, self.bits_read);
256            }
257        })
258    }
259
260    fn read_zeta3(&mut self) -> Result<u64, BR::Error> {
261        self.bit_read.read_zeta3().inspect(|x| {
262            self.bits_read += len_zeta(*x, 3);
263            if PRINT {
264                eprintln!("read_zeta3() = {} (total = {})", x, self.bits_read);
265            }
266        })
267    }
268}
269
270impl<E: Endianness, BR: BitRead<E> + BitSeek, const PRINT: bool> BitSeek
271    for CountBitReader<E, BR, PRINT>
272{
273    type Error = <BR as BitSeek>::Error;
274
275    fn bit_pos(&mut self) -> Result<u64, Self::Error> {
276        self.bit_read.bit_pos()
277    }
278
279    fn set_bit_pos(&mut self, bit_pos: u64) -> Result<(), Self::Error> {
280        self.bit_read.set_bit_pos(bit_pos)
281    }
282}
283
#[cfg(test)]
#[test]
fn test_count() -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
    use crate::prelude::*;

    let mut words = <Vec<u64>>::new();

    // Write phase: after every call the running total must grow by exactly
    // the length of the code just emitted.
    let inner_writer = <BufBitWriter<LE, _>>::new(MemWordWriterVec::new(&mut words));
    let mut writer = CountBitWriter::<_, _, true>::new(inner_writer);

    writer.write_unary(5)?; // the unary code of 5 occupies 6 bits
    assert_eq!(writer.bits_written, 6);
    writer.write_unary(100)?;
    assert_eq!(writer.bits_written, 107);
    writer.write_bits(1, 20)?;
    assert_eq!(writer.bits_written, 127);
    writer.write_bits(1, 33)?;
    assert_eq!(writer.bits_written, 160);
    writer.write_gamma(2)?;
    assert_eq!(writer.bits_written, 163);
    writer.write_delta(1)?;
    assert_eq!(writer.bits_written, 167);
    writer.write_zeta(0, 4)?;
    assert_eq!(writer.bits_written, 171);
    writer.write_zeta3(0)?;
    assert_eq!(writer.bits_written, 174);
    writer.flush()?;
    // End the mutable borrow of `words` so it can be read back below.
    drop(writer);

    // Read phase: decode the same sequence and check the totals line up.
    let inner_reader = <BufBitReader<LE, _>>::new(MemWordReader::<u64, _>::new(&words));
    let mut reader = CountBitReader::<_, _, true>::new(inner_reader);

    // Peeking must not advance the counter (checked by the next total).
    assert_eq!(reader.peek_bits(5)?, 0);
    assert_eq!(reader.read_unary()?, 5);
    assert_eq!(reader.bits_read, 6);
    assert_eq!(reader.read_unary()?, 100);
    assert_eq!(reader.bits_read, 107);
    assert_eq!(reader.read_bits(20)?, 1);
    assert_eq!(reader.bits_read, 127);
    reader.skip_bits(33)?;
    assert_eq!(reader.bits_read, 160);
    assert_eq!(reader.read_gamma()?, 2);
    assert_eq!(reader.bits_read, 163);
    assert_eq!(reader.read_delta()?, 1);
    assert_eq!(reader.bits_read, 167);
    assert_eq!(reader.read_zeta(4)?, 0);
    assert_eq!(reader.bits_read, 171);
    assert_eq!(reader.read_zeta3()?, 0);
    assert_eq!(reader.bits_read, 174);

    Ok(())
}