1use crate::{
8 prelude::{
9 DeltaRead, DeltaWrite, GammaRead, GammaWrite, OmegaRead, OmegaWrite, PiRead, PiWrite,
10 ZetaRead, ZetaWrite, len_delta, len_gamma, len_omega, len_pi, len_zeta,
11 },
12 traits::*,
13};
14#[cfg(feature = "mem_dbg")]
15use mem_dbg::{MemDbg, MemSize};
16
17#[derive(Debug, Clone)]
21#[cfg_attr(feature = "mem_dbg", derive(MemDbg, MemSize))]
22pub struct CountBitWriter<E: Endianness, BW: BitWrite<E>, const PRINT: bool = false> {
23 bit_write: BW,
24 pub bits_written: usize,
26 _marker: core::marker::PhantomData<E>,
27}
28
29impl<E: Endianness, BW: BitWrite<E>, const PRINT: bool> CountBitWriter<E, BW, PRINT> {
30 #[must_use]
31 pub const fn new(bit_write: BW) -> Self {
32 Self {
33 bit_write,
34 bits_written: 0,
35 _marker: core::marker::PhantomData,
36 }
37 }
38
39 #[must_use]
41 pub fn into_inner(self) -> BW {
42 self.bit_write
43 }
44}
45
46impl<E: Endianness, BW: BitWrite<E>, const PRINT: bool> BitWrite<E>
47 for CountBitWriter<E, BW, PRINT>
48{
49 type Error = <BW as BitWrite<E>>::Error;
50
51 fn write_bits(&mut self, value: u64, num_bits: usize) -> Result<usize, Self::Error> {
52 self.bit_write.write_bits(value, num_bits).inspect(|x| {
53 self.bits_written += *x;
54 if PRINT {
55 #[cfg(feature = "std")]
56 eprintln!(
57 "write_bits({:#016x}, {}) = {} (total = {})",
58 value, num_bits, x, self.bits_written
59 );
60 }
61 })
62 }
63
64 fn write_unary(&mut self, n: u64) -> Result<usize, Self::Error> {
65 self.bit_write.write_unary(n).inspect(|x| {
66 self.bits_written += *x;
67 if PRINT {
68 #[cfg(feature = "std")]
69 eprintln!("write_unary({}) = {} (total = {})", n, x, self.bits_written);
70 }
71 })
72 }
73
74 fn flush(&mut self) -> Result<usize, Self::Error> {
75 self.bit_write.flush().inspect(|x| {
76 self.bits_written += *x;
77 if PRINT {
78 #[cfg(feature = "std")]
79 eprintln!("flush() = {} (total = {})", x, self.bits_written);
80 }
81 })
82 }
83}
84
85impl<E: Endianness, BW: BitWrite<E> + GammaWrite<E>, const PRINT: bool> GammaWrite<E>
86 for CountBitWriter<E, BW, PRINT>
87{
88 fn write_gamma(&mut self, n: u64) -> Result<usize, BW::Error> {
89 self.bit_write.write_gamma(n).inspect(|x| {
90 self.bits_written += *x;
91 if PRINT {
92 #[cfg(feature = "std")]
93 eprintln!("write_gamma({}) = {} (total = {})", n, x, self.bits_written);
94 }
95 })
96 }
97}
98
99impl<E: Endianness, BW: BitWrite<E> + DeltaWrite<E>, const PRINT: bool> DeltaWrite<E>
100 for CountBitWriter<E, BW, PRINT>
101{
102 fn write_delta(&mut self, n: u64) -> Result<usize, BW::Error> {
103 self.bit_write.write_delta(n).inspect(|x| {
104 self.bits_written += *x;
105 if PRINT {
106 #[cfg(feature = "std")]
107 eprintln!("write_delta({}) = {} (total = {})", n, x, self.bits_written);
108 }
109 })
110 }
111}
112
113impl<E: Endianness, BW: BitWrite<E> + ZetaWrite<E>, const PRINT: bool> ZetaWrite<E>
114 for CountBitWriter<E, BW, PRINT>
115{
116 fn write_zeta(&mut self, n: u64, k: usize) -> Result<usize, BW::Error> {
117 self.bit_write.write_zeta(n, k).inspect(|x| {
118 self.bits_written += *x;
119 if PRINT {
120 #[cfg(feature = "std")]
121 eprintln!(
122 "write_zeta({}, {}) = {} (total = {})",
123 n, k, x, self.bits_written
124 );
125 }
126 })
127 }
128
129 fn write_zeta3(&mut self, n: u64) -> Result<usize, BW::Error> {
130 self.bit_write.write_zeta3(n).inspect(|x| {
131 self.bits_written += *x;
132 if PRINT {
133 #[cfg(feature = "std")]
134 eprintln!("write_zeta3({}) = {} (total = {})", n, x, self.bits_written);
135 }
136 })
137 }
138}
139
140impl<E: Endianness, BW: BitWrite<E> + OmegaWrite<E>, const PRINT: bool> OmegaWrite<E>
141 for CountBitWriter<E, BW, PRINT>
142{
143 fn write_omega(&mut self, n: u64) -> Result<usize, BW::Error> {
144 self.bit_write.write_omega(n).inspect(|x| {
145 self.bits_written += *x;
146 if PRINT {
147 #[cfg(feature = "std")]
148 eprintln!("write_omega({}) = {} (total = {})", n, x, self.bits_written);
149 }
150 })
151 }
152}
153
154impl<E: Endianness, BW: BitWrite<E> + PiWrite<E>, const PRINT: bool> PiWrite<E>
155 for CountBitWriter<E, BW, PRINT>
156{
157 fn write_pi(&mut self, n: u64, k: usize) -> Result<usize, BW::Error> {
158 self.bit_write.write_pi(n, k).inspect(|x| {
159 self.bits_written += *x;
160 if PRINT {
161 #[cfg(feature = "std")]
162 eprintln!(
163 "write_pi({}, {}) = {} (total = {})",
164 n, k, x, self.bits_written
165 );
166 }
167 })
168 }
169
170 fn write_pi2(&mut self, n: u64) -> Result<usize, BW::Error> {
171 self.bit_write.write_pi2(n).inspect(|x| {
172 self.bits_written += *x;
173 if PRINT {
174 #[cfg(feature = "std")]
175 eprintln!("write_pi2({}) = {} (total = {})", n, x, self.bits_written);
176 }
177 })
178 }
179}
180
181impl<E: Endianness, BW: BitWrite<E> + BitSeek, const PRINT: bool> BitSeek
182 for CountBitWriter<E, BW, PRINT>
183{
184 type Error = <BW as BitSeek>::Error;
185
186 fn bit_pos(&mut self) -> Result<u64, Self::Error> {
187 self.bit_write.bit_pos()
188 }
189
190 fn set_bit_pos(&mut self, bit_pos: u64) -> Result<(), Self::Error> {
191 self.bit_write.set_bit_pos(bit_pos)
192 }
193}
194
195#[derive(Debug, Clone)]
199#[cfg_attr(feature = "mem_dbg", derive(MemDbg, MemSize))]
200pub struct CountBitReader<E: Endianness, BR: BitRead<E>, const PRINT: bool = false> {
201 bit_read: BR,
202 pub bits_read: usize,
204 _marker: core::marker::PhantomData<E>,
205}
206
207impl<E: Endianness, BR: BitRead<E>, const PRINT: bool> CountBitReader<E, BR, PRINT> {
208 #[must_use]
209 pub const fn new(bit_read: BR) -> Self {
210 Self {
211 bit_read,
212 bits_read: 0,
213 _marker: core::marker::PhantomData,
214 }
215 }
216
217 #[must_use]
219 pub fn into_inner(self) -> BR {
220 self.bit_read
221 }
222}
223
224impl<E: Endianness, BR: BitRead<E>, const PRINT: bool> BitRead<E> for CountBitReader<E, BR, PRINT> {
225 type Error = <BR as BitRead<E>>::Error;
226 type PeekWord = BR::PeekWord;
227 const PEEK_BITS: usize = BR::PEEK_BITS;
228
229 fn read_bits(&mut self, num_bits: usize) -> Result<u64, Self::Error> {
230 self.bit_read.read_bits(num_bits).inspect(|x| {
231 let _ = x;
232 self.bits_read += num_bits;
233 if PRINT {
234 #[cfg(feature = "std")]
235 eprintln!(
236 "read_bits({}) = {:#016x} (total = {})",
237 num_bits, x, self.bits_read
238 );
239 }
240 })
241 }
242
243 fn read_unary(&mut self) -> Result<u64, Self::Error> {
244 self.bit_read.read_unary().inspect(|x| {
245 self.bits_read += *x as usize + 1;
246 if PRINT {
247 #[cfg(feature = "std")]
248 eprintln!("read_unary() = {} (total = {})", x, self.bits_read);
249 }
250 })
251 }
252
253 fn peek_bits(&mut self, n_bits: usize) -> Result<Self::PeekWord, Self::Error> {
254 self.bit_read.peek_bits(n_bits)
255 }
256
257 fn skip_bits(&mut self, n_bits: usize) -> Result<(), Self::Error> {
258 self.bits_read += n_bits;
259 if PRINT {
260 #[cfg(feature = "std")]
261 eprintln!("skip_bits({}) (total = {})", n_bits, self.bits_read);
262 }
263 self.bit_read.skip_bits(n_bits)
264 }
265
266 fn skip_bits_after_peek(&mut self, n: usize) {
267 self.bit_read.skip_bits_after_peek(n)
268 }
269}
270
271impl<E: Endianness, BR: BitRead<E> + GammaRead<E>, const PRINT: bool> GammaRead<E>
272 for CountBitReader<E, BR, PRINT>
273{
274 fn read_gamma(&mut self) -> Result<u64, BR::Error> {
275 self.bit_read.read_gamma().inspect(|x| {
276 self.bits_read += len_gamma(*x);
277 if PRINT {
278 #[cfg(feature = "std")]
279 eprintln!("read_gamma() = {} (total = {})", x, self.bits_read);
280 }
281 })
282 }
283}
284
285impl<E: Endianness, BR: BitRead<E> + DeltaRead<E>, const PRINT: bool> DeltaRead<E>
286 for CountBitReader<E, BR, PRINT>
287{
288 fn read_delta(&mut self) -> Result<u64, BR::Error> {
289 self.bit_read.read_delta().inspect(|x| {
290 self.bits_read += len_delta(*x);
291 if PRINT {
292 #[cfg(feature = "std")]
293 eprintln!("read_delta() = {} (total = {})", x, self.bits_read);
294 }
295 })
296 }
297}
298
299impl<E: Endianness, BR: BitRead<E> + ZetaRead<E>, const PRINT: bool> ZetaRead<E>
300 for CountBitReader<E, BR, PRINT>
301{
302 fn read_zeta(&mut self, k: usize) -> Result<u64, BR::Error> {
303 self.bit_read.read_zeta(k).inspect(|x| {
304 self.bits_read += len_zeta(*x, k);
305 if PRINT {
306 #[cfg(feature = "std")]
307 eprintln!("read_zeta({}) = {} (total = {})", k, x, self.bits_read);
308 }
309 })
310 }
311
312 fn read_zeta3(&mut self) -> Result<u64, BR::Error> {
313 self.bit_read.read_zeta3().inspect(|x| {
314 self.bits_read += len_zeta(*x, 3);
315 if PRINT {
316 #[cfg(feature = "std")]
317 eprintln!("read_zeta3() = {} (total = {})", x, self.bits_read);
318 }
319 })
320 }
321}
322
323impl<E: Endianness, BR: BitRead<E> + OmegaRead<E>, const PRINT: bool> OmegaRead<E>
324 for CountBitReader<E, BR, PRINT>
325{
326 fn read_omega(&mut self) -> Result<u64, BR::Error> {
327 self.bit_read.read_omega().inspect(|x| {
328 self.bits_read += len_omega(*x);
329 if PRINT {
330 #[cfg(feature = "std")]
331 eprintln!("read_omega() = {} (total = {})", x, self.bits_read);
332 }
333 })
334 }
335}
336
337impl<E: Endianness, BR: BitRead<E> + PiRead<E>, const PRINT: bool> PiRead<E>
338 for CountBitReader<E, BR, PRINT>
339{
340 fn read_pi(&mut self, k: usize) -> Result<u64, BR::Error> {
341 self.bit_read.read_pi(k).inspect(|x| {
342 self.bits_read += len_pi(*x, k);
343 if PRINT {
344 #[cfg(feature = "std")]
345 eprintln!("read_pi({}) = {} (total = {})", k, x, self.bits_read);
346 }
347 })
348 }
349
350 fn read_pi2(&mut self) -> Result<u64, BR::Error> {
351 self.bit_read.read_pi2().inspect(|x| {
352 self.bits_read += len_pi(*x, 2);
353 if PRINT {
354 #[cfg(feature = "std")]
355 eprintln!("read_pi2() = {} (total = {})", x, self.bits_read);
356 }
357 })
358 }
359}
360
361impl<E: Endianness, BR: BitRead<E> + BitSeek, const PRINT: bool> BitSeek
362 for CountBitReader<E, BR, PRINT>
363{
364 type Error = <BR as BitSeek>::Error;
365
366 fn bit_pos(&mut self) -> Result<u64, Self::Error> {
367 self.bit_read.bit_pos()
368 }
369
370 fn set_bit_pos(&mut self, bit_pos: u64) -> Result<(), Self::Error> {
371 self.bit_read.set_bit_pos(bit_pos)
372 }
373}
374
#[cfg(test)]
#[cfg(feature = "std")]
mod tests {
    use super::*;
    use crate::prelude::*;

    /// Round-trips a sequence of codes through `CountBitWriter` and
    /// `CountBitReader`. After every operation it checks that both counters
    /// match the expected running totals.
    #[test]
    fn test_count() -> Result<(), Box<dyn core::error::Error + Send + Sync + 'static>> {
        let mut words = <Vec<u64>>::new();

        // Write phase. The writer is confined to this block so that it is
        // dropped, and its mutable borrow of `words` ends, before the read
        // phase starts. The block yields the totals the read phase checks.
        let (omega_end, pi_end, pi2_end) = {
            let inner = <BufBitWriter<LE, _>>::new(MemWordWriterVec::new(&mut words));
            let mut writer = CountBitWriter::<_, _, true>::new(inner);

            // Unary values: 5 -> 6 bits, 100 -> 101 bits.
            writer.write_unary(5)?;
            assert_eq!(writer.bits_written, 6);
            writer.write_unary(100)?;
            assert_eq!(writer.bits_written, 107);

            // Fixed-width fields.
            writer.write_bits(1, 20)?;
            assert_eq!(writer.bits_written, 127);
            writer.write_bits(1, 33)?;
            assert_eq!(writer.bits_written, 160);

            // Codes with hard-coded expected lengths.
            writer.write_gamma(2)?;
            assert_eq!(writer.bits_written, 163);
            writer.write_delta(1)?;
            assert_eq!(writer.bits_written, 167);
            writer.write_zeta(0, 4)?;
            assert_eq!(writer.bits_written, 171);
            writer.write_zeta3(0)?;
            assert_eq!(writer.bits_written, 174);

            // Codes checked against the `len_*` helpers.
            writer.write_omega(3)?;
            assert_eq!(writer.bits_written, 174 + len_omega(3));
            let omega_end = writer.bits_written;
            writer.write_pi(5, 3)?;
            assert_eq!(writer.bits_written, omega_end + len_pi(5, 3));
            let pi_end = writer.bits_written;
            writer.write_pi2(7)?;
            assert_eq!(writer.bits_written, pi_end + len_pi(7, 2));
            let pi2_end = writer.bits_written;

            writer.flush()?;
            (omega_end, pi_end, pi2_end)
        };

        // Read phase: decode the same sequence and check that `bits_read`
        // tracks the write-side totals step for step.
        let inner = <BufBitReader<LE, _>>::new(MemWordReader::<u64, _>::new(&words));
        let mut reader = CountBitReader::<_, _, true>::new(inner);

        assert_eq!(reader.peek_bits(5)?, 0);
        assert_eq!(reader.read_unary()?, 5);
        assert_eq!(reader.bits_read, 6);
        assert_eq!(reader.read_unary()?, 100);
        assert_eq!(reader.bits_read, 107);

        assert_eq!(reader.read_bits(20)?, 1);
        assert_eq!(reader.bits_read, 127);
        reader.skip_bits(33)?;
        assert_eq!(reader.bits_read, 160);

        assert_eq!(reader.read_gamma()?, 2);
        assert_eq!(reader.bits_read, 163);
        assert_eq!(reader.read_delta()?, 1);
        assert_eq!(reader.bits_read, 167);
        assert_eq!(reader.read_zeta(4)?, 0);
        assert_eq!(reader.bits_read, 171);
        assert_eq!(reader.read_zeta3()?, 0);
        assert_eq!(reader.bits_read, 174);

        assert_eq!(reader.read_omega()?, 3);
        assert_eq!(reader.bits_read, omega_end);
        assert_eq!(reader.read_pi(3)?, 5);
        assert_eq!(reader.bits_read, pi_end);
        assert_eq!(reader.read_pi2()?, 7);
        assert_eq!(reader.bits_read, pi2_end);

        Ok(())
    }
}