//! Types and functions to work with individual bits.
use std::io::{self, Read, Write};
/// Buffered bit-level writer that wraps an [`io::Write`] byte stream.
///
/// Bits are collected least-significant-bit first in a 128-bit accumulator and
/// handed to the underlying writer as whole little-endian bytes.
pub struct WriteBits<W> {
    /// Underlying byte sink.
    writer: W,
    /// Pending bits; bit 0 is the oldest bit not yet handed to `writer`.
    buffer: u128,
    /// Number of valid bits in `buffer` (0..=128).
    buffer_len: u8,
}

impl<W> WriteBits<W> {
    /// Creates a new `WriteBits` wrapping the given writer.
    pub fn new(write: W) -> Self {
        WriteBits {
            writer: write,
            buffer: 0,
            buffer_len: 0,
        }
    }
}

impl<W> WriteBits<W>
where
    W: Write,
{
    /// Write one bit to the stream.
    ///
    /// Returns `Ok(())` if the bit was written successfully.
    ///
    /// # Errors
    ///
    /// Returns `Err` if the writer returned an error on write, or an error of
    /// kind [`io::ErrorKind::WriteZero`] if the internal buffer is full and the
    /// writer accepts no more bytes.
    pub fn write_bit(&mut self, bit: bool) -> io::Result<()> {
        // A completely full buffer has to shed at least one byte first.
        // The private `flush` already retries on `Interrupted`.
        while self.buffer_len == 128 {
            if self.flush()? == 0 {
                return Err(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "failed to write bit",
                ));
            }
        }
        self.buffer |= u128::from(bit) << self.buffer_len;
        self.buffer_len += 1;
        Ok(())
    }

    /// Write bits from the slice.
    ///
    /// `bit_offset` specifies the bit offset in `buffer` of the first bit to write.
    /// `bit_len` specifies number of bits to write.
    ///
    /// Returns number of bits written.
    /// It is `bit_len` unless the writer stops accepting bytes, in which case
    /// the number of bits taken so far is returned.
    ///
    /// # Errors
    ///
    /// Propagates errors of the underlying writer. Bits that were already
    /// copied into the internal buffer before the failure stay buffered.
    ///
    /// # Panics
    ///
    /// Panics if `buffer` is too short to hold the requested bit range.
    pub fn write_bits(
        &mut self,
        buffer: &[u8],
        bit_offset: usize,
        bit_len: usize,
    ) -> io::Result<usize> {
        let buffer_free = usize::from(128 - self.buffer_len);
        if buffer_free < bit_len {
            // The input does not fit as is. Flush up front when either
            //  * the whole input fits once the buffered whole bytes are gone, or
            //  * not even the bits of the first input byte fit.
            if modulo_8(usize::from(self.buffer_len)) + bit_len <= 128
                || buffer_free < 8 - modulo_8(bit_offset)
            {
                self.flush()?;
            }
        }
        let mut buffer = buffer;
        let mut bit_offset = bit_offset;
        let mut bit_len = bit_len;
        let mut total_bits_written = 0;
        loop {
            let (new_buffer, new_bit_offset, new_bit_len) =
                self.copy_from_buffer(buffer, bit_offset, bit_len);
            total_bits_written += bit_len - new_bit_len;
            buffer = new_buffer;
            bit_offset = new_bit_offset;
            bit_len = new_bit_len;
            if bit_len == 0 {
                // All bits written.
                return Ok(total_bits_written);
            }
            // Input is left over only when the internal buffer is (almost) full.
            debug_assert!(self.buffer_len >= 120);
            if self.flush()? == 0 {
                // Flush stalled, can't write more just yet.
                return Ok(total_bits_written);
            }
        }
    }

    /// Writes exactly `bit_len` bits from `buffer` starting at `bit_offset`.
    ///
    /// Unlike [`write_bits`](Self::write_bits), this retries on partial writes
    /// and returns an error if the writer cannot accept all bits.
    ///
    /// # Errors
    ///
    /// Returns an error of kind [`io::ErrorKind::WriteZero`] when no progress
    /// can be made, and propagates any other error of the underlying writer.
    pub fn write_all_bits(
        &mut self,
        mut buffer: &[u8],
        mut bit_offset: usize,
        mut bit_len: usize,
    ) -> io::Result<()> {
        loop {
            match self.write_bits(buffer, bit_offset, bit_len) {
                Ok(written) if bit_len == written => {
                    return Ok(());
                }
                Ok(0) => {
                    return Err(io::Error::new(
                        io::ErrorKind::WriteZero,
                        "failed to write all bits",
                    ));
                }
                Ok(written) => {
                    // Advance past the consumed bits and try again.
                    buffer = &buffer[(bit_offset + written) / 8..];
                    bit_offset = modulo_8(bit_offset + written);
                    bit_len -= written;
                }
                Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
                Err(e) => return Err(e),
            }
        }
    }

    // Hands the buffered whole bytes to the writer with a single successful
    // `write` call (retrying only on `Interrupted`).
    //
    // Returns the number of bytes the writer accepted. `Ok(0)` means either
    // fewer than 8 bits are buffered or the writer accepted nothing.
    fn flush(&mut self) -> io::Result<usize> {
        if self.buffer_len < 8 {
            return Ok(0);
        }
        let write_bytes = usize::from(self.buffer_len / 8);
        let bytes_written = loop {
            let r = self.writer.write(&self.buffer.to_le_bytes()[..write_bytes]);
            match r {
                Ok(0) => return Ok(0),
                Ok(n) => break n,
                Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
                Err(e) => return Err(e),
            }
        };
        self.buffer_len -= (bytes_written * 8) as u8;
        if self.buffer_len > 0 {
            self.buffer >>= bytes_written * 8;
        } else {
            // Avoids a shift by the full 128-bit width.
            self.buffer = 0;
        }
        Ok(bytes_written)
    }

    /// Flushes any remaining buffered bits to the underlying writer, padding
    /// the last byte with zeros if necessary.
    ///
    /// The writer can keep being used afterwards; it starts again from an
    /// empty buffer.
    ///
    /// # Errors
    ///
    /// Propagates the error of the underlying writer's `write_all`.
    pub fn finish(&mut self) -> io::Result<()> {
        if self.buffer_len > 0 {
            let write_bytes = usize::from(self.buffer_len.div_ceil(8));
            self.writer
                .write_all(&self.buffer.to_le_bytes()[..write_bytes])?;
            // Clear the accumulator together with its length: later writes OR
            // new bits in, so stale bits would corrupt them.
            self.buffer = 0;
            self.buffer_len = 0;
        }
        Ok(())
    }

    // Copies as many bits as currently fit from `buffer` into the internal
    // buffer. Returns the unconsumed input as `(buffer, bit_offset, bit_len)`.
    fn copy_from_buffer<'a>(
        &mut self,
        mut buffer: &'a [u8],
        mut bit_offset: usize,
        mut bit_len: usize,
    ) -> (&'a [u8], usize, usize) {
        let mut buffer_free = 128 - self.buffer_len;
        if buffer_free == 0 {
            return (buffer, bit_offset, bit_len);
        }
        // Skip whole bytes in the input buffer.
        // This ensures that bit_offset is less than 8.
        if bit_offset >= 8 {
            buffer = &buffer[bit_offset / 8..];
            bit_offset = modulo_8(bit_offset);
        }
        // Unaligned head: take the rest of the partially consumed first byte.
        if bit_len > 0 && bit_offset > 0 {
            let copy_len = usize::from(buffer_free).min(8 - bit_offset).min(bit_len);
            // `copy_len <= 7` because `bit_offset >= 1`, so the mask fits in a u8.
            let mask: u8 = (1 << copy_len) - 1;
            let head_bits = (buffer[0] >> bit_offset) & mask;
            self.buffer |= u128::from(head_bits) << self.buffer_len;
            self.buffer_len += copy_len as u8;
            buffer_free -= copy_len as u8;
            bit_len -= copy_len;
            if bit_offset + copy_len >= 8 {
                debug_assert_eq!(bit_offset + copy_len, 8);
                bit_offset = 0;
                buffer = &buffer[1..];
            } else {
                // Either the input or the free space ran out mid-byte.
                bit_offset += copy_len;
                return (buffer, bit_offset, bit_len);
            }
        }
        if buffer_free > 0 && bit_len > 0 {
            debug_assert_eq!(bit_offset, 0);
            let free_bits = usize::from(buffer_free);
            let copy_len = if bit_len > free_bits {
                // If we can't fit whole input, copy only whole bytes.
                // This is required to ensure that `io::Write::write` is able to return exact number of bytes written.
                round_down_8(free_bits)
            } else {
                bit_len
            };
            if copy_len > 0 {
                let byte_count = copy_len.div_ceil(8);
                let mut le_bytes = [0u8; 16];
                le_bytes[..byte_count].copy_from_slice(&buffer[..byte_count]);
                let mut chunk = u128::from_le_bytes(le_bytes);
                if copy_len < 128 {
                    // Drop the bits of the last byte that are not part of the range.
                    chunk &= (1u128 << copy_len) - 1;
                }
                self.buffer |= chunk << self.buffer_len;
                self.buffer_len += copy_len as u8;
                // The slice is not advanced here; the next call skips whole bytes.
                bit_offset += copy_len;
                bit_len -= copy_len;
            }
        }
        (buffer, bit_offset, bit_len)
    }
}

impl<W> Write for WriteBits<W>
where
    W: Write,
{
    /// Writes whole bytes through the bit buffer; returns the number of bytes taken.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.write_bits(buf, 0, buf.len() * 8).map(|n| {
            debug_assert_eq!(modulo_8(n), 0);
            n / 8
        })
    }

    /// Pushes every buffered whole byte to the underlying writer and flushes it.
    ///
    /// Up to 7 trailing bits that do not form a whole byte stay buffered; use
    /// [`WriteBits::finish`] to pad and emit them.
    fn flush(&mut self) -> io::Result<()> {
        // `self.flush()` resolves to the inherent helper, which performs a
        // single write that may be partial, so repeat until no whole byte is left.
        while self.buffer_len >= 8 {
            if self.flush()? == 0 {
                return Err(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "failed to flush buffered bits",
                ));
            }
        }
        self.writer.flush()
    }
}
/// Buffered bit-level reader that wraps an [`io::Read`] byte stream.
///
/// Bytes fetched from the reader are kept in a 128-bit accumulator and handed
/// out least-significant-bit first.
pub struct ReadBits<R> {
    /// Underlying byte source.
    reader: R,
    /// Bits fetched from `reader` but not yet handed out; bit 0 is next.
    buffer: u128,
    /// Number of valid bits in `buffer` (0..=128).
    buffer_len: u8,
}

impl<R> ReadBits<R> {
    /// Creates a new `ReadBits` wrapping the given reader.
    pub fn new(reader: R) -> Self {
        ReadBits {
            reader,
            buffer: 0,
            buffer_len: 0,
        }
    }
}

impl<R> ReadBits<R>
where
    R: Read,
{
    /// Read one bit from the stream.
    ///
    /// Returns `Ok(bit)` if the bit was read successfully.
    ///
    /// # Errors
    ///
    /// Returns `Err` if the reader returned an error on read, or an error of
    /// kind [`io::ErrorKind::UnexpectedEof`] if the reader is exhausted.
    pub fn read_bit(&mut self) -> io::Result<bool> {
        // `fill_buffer` already retries on `Interrupted`.
        while self.buffer_len == 0 {
            if self.fill_buffer(1)? == 0 {
                return Err(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "failed to read bit",
                ));
            }
        }
        let value = (self.buffer & 1) != 0;
        self.buffer >>= 1;
        self.buffer_len -= 1;
        Ok(value)
    }

    /// Read bits into the buffer.
    /// `bit_offset` specifies the bit offset in the buffer to read into.
    /// All bits before `bit_offset` will be preserved.
    /// `bit_len` specifies number of bits to read.
    ///
    /// The bits in the requested range are overwritten with the bits read.
    /// Bits past the end of the range in the last touched byte are not
    /// guaranteed to be preserved.
    ///
    /// Returns number of bits read.
    /// It would be the `bit_len`.
    /// Unless reader is exhausted.
    ///
    /// # Errors
    ///
    /// Propagates errors of the underlying reader. Bits that were already
    /// copied into `buffer` before the failure stay consumed.
    ///
    /// # Panics
    ///
    /// The function will panic if `buffer` doesn't fit bits in range `bit_offset..bit_offset+bit_len`.
    /// This means that `buffer.len()` must be equal or greater than `(bit_offset + bit_len + 7) / 8`.
    ///
    /// Function can also panic if internal reader panics on read.
    pub fn read_bits(
        &mut self,
        buffer: &mut [u8],
        bit_offset: usize,
        bit_len: usize,
    ) -> io::Result<usize> {
        assert!(buffer.len() >= (bit_offset + bit_len).div_ceil(8));
        if bit_len == 0 {
            return Ok(0);
        }
        let mut total_bits_read = 0;
        let mut buffer = buffer;
        let mut bit_offset = bit_offset;
        let mut bit_len = bit_len;
        loop {
            let (new_buffer, new_bit_offset, new_bit_len) =
                self.copy_from_buffer(buffer, bit_offset, bit_len);
            total_bits_read += bit_len - new_bit_len;
            buffer = new_buffer;
            bit_offset = new_bit_offset;
            bit_len = new_bit_len;
            if bit_len == 0 {
                return Ok(total_bits_read);
            }
            if self.fill_buffer(bit_len)? == 0 {
                // Reader is exhausted; report what was read so far.
                return Ok(total_bits_read);
            }
        }
    }

    /// Reads exactly `bit_len` bits into `buffer` starting at `bit_offset`.
    ///
    /// Unlike [`read_bits`](Self::read_bits), this retries on partial reads
    /// and returns an error if the reader is exhausted before all bits are read.
    ///
    /// # Errors
    ///
    /// Returns an error of kind [`io::ErrorKind::UnexpectedEof`] when no more
    /// bits can be read, and propagates any other error of the reader.
    pub fn read_all_bits(
        &mut self,
        mut buffer: &mut [u8],
        mut bit_offset: usize,
        mut bit_len: usize,
    ) -> io::Result<()> {
        loop {
            match self.read_bits(buffer, bit_offset, bit_len) {
                Ok(read) if bit_len == read => {
                    return Ok(());
                }
                Ok(0) => {
                    return Err(io::Error::new(
                        io::ErrorKind::UnexpectedEof,
                        "failed to read all bits",
                    ));
                }
                Ok(read) => {
                    // Advance past the bits already delivered and try again.
                    buffer = &mut buffer[(bit_offset + read) / 8..];
                    bit_offset = modulo_8(bit_offset + read);
                    bit_len -= read;
                }
                Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
                Err(e) => return Err(e),
            }
        }
    }

    // Fill internal buffer with bytes from the reader.
    // Attempts to grow internal buffer to at least `bit_len` bits.
    // Must never be called if internal buffer already has at least `bit_len` bits.
    // May fill less if reader is exhausted or if buffer capacity is reached.
    // Always fills with whole bytes from internal reader.
    // Returns the number of bytes obtained; `Ok(0)` means the reader gave nothing.
    fn fill_buffer(&mut self, bit_len: usize) -> io::Result<usize> {
        debug_assert!(bit_len > usize::from(self.buffer_len));
        // Figure out how many bytes is needed to fill the buffer to `bit_len` bits.
        let desired_byte_len = (bit_len - usize::from(self.buffer_len)).div_ceil(8);
        // Figure out how many bytes we can actually read without overflowing the buffer.
        let max_byte_len = (128 - self.buffer_len) / 8;
        // Use minimum of those two to avoid overflow and read not more than needed.
        let byte_len = desired_byte_len.min(usize::from(max_byte_len));
        let mut buffer = [0; 16];
        let bytes_read = loop {
            let r = self.reader.read(&mut buffer[..byte_len]);
            match r {
                Ok(0) => return Ok(0),
                Ok(n) => break n,
                Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
                Err(e) => return Err(e),
            }
        };
        // New bytes go above the bits that are still pending.
        self.buffer |= u128::from_le_bytes(buffer) << self.buffer_len;
        self.buffer_len += bytes_read as u8 * 8;
        Ok(bytes_read)
    }

    // Moves as many bits as currently available from the internal buffer into
    // `buffer`. Returns the still unfilled destination as
    // `(buffer, bit_offset, bit_len)`.
    fn copy_from_buffer<'a>(
        &mut self,
        mut buffer: &'a mut [u8],
        mut bit_offset: usize,
        mut bit_len: usize,
    ) -> (&'a mut [u8], usize, usize) {
        // Skip whole destination bytes so that `bit_offset` is below 8.
        if bit_offset >= 8 {
            let byte_offset = bit_offset / 8;
            buffer = &mut buffer[byte_offset..];
            bit_offset = modulo_8(bit_offset);
        }
        // Unaligned head: fill the rest of the partially written first byte.
        if self.buffer_len > 0 && bit_len > 0 && bit_offset > 0 {
            let copy_len = usize::from(self.buffer_len)
                .min(8 - bit_offset)
                .min(bit_len);
            // `copy_len <= 7` because `bit_offset >= 1`, so the mask fits in a u8.
            let mask: u8 = (1 << copy_len) - 1;
            let head_bits = (self.buffer.to_le_bytes()[0] & mask) << bit_offset;
            // Overwrite the target bits instead of OR-ing into whatever the
            // caller's buffer held; bits outside the range stay untouched.
            buffer[0] = (buffer[0] & !(mask << bit_offset)) | head_bits;
            self.buffer >>= copy_len;
            self.buffer_len -= copy_len as u8;
            bit_len -= copy_len;
            if bit_offset + copy_len >= 8 {
                debug_assert_eq!(bit_offset + copy_len, 8);
                bit_offset = 0;
                buffer = &mut buffer[1..];
            } else {
                // Either the request or the buffered bits ran out mid-byte.
                bit_offset += copy_len;
                return (buffer, bit_offset, bit_len);
            }
        }
        if self.buffer_len > 0 && bit_len > 0 {
            debug_assert_eq!(bit_offset, 0);
            let available = usize::from(self.buffer_len);
            let copy_len = if bit_len > available {
                // If we can't fill whole buffer, copy only whole bytes.
                // This is required to ensure that `io::Read::read` is able to return exact number of bytes read.
                round_down_8(available)
            } else {
                bit_len
            };
            if copy_len > 0 {
                let takes_everything = copy_len >= 128;
                let mut chunk = self.buffer;
                if !takes_everything {
                    chunk &= (1u128 << copy_len) - 1;
                }
                let le_bytes = chunk.to_le_bytes();
                let byte_count = copy_len.div_ceil(8);
                buffer[..byte_count].copy_from_slice(&le_bytes[..byte_count]);
                if takes_everything {
                    // A shift by the full 128-bit width would overflow
                    // (panic in debug builds, stale bits in release builds).
                    self.buffer = 0;
                } else {
                    self.buffer >>= copy_len;
                }
                self.buffer_len -= copy_len as u8;
                // The slice is not advanced here; the next call skips whole bytes.
                bit_offset += copy_len;
                bit_len -= copy_len;
            }
        }
        (buffer, bit_offset, bit_len)
    }
}

impl<R> Read for ReadBits<R>
where
    R: Read,
{
    /// Reads whole bytes through the bit buffer; returns the number of bytes read.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        self.read_bits(buf, 0, buf.len() * 8).map(|n| {
            debug_assert_eq!(modulo_8(n), 0);
            n / 8
        })
    }
}
/// Creates a [`WriteBits`] wrapper, passes it to `f`, then flushes and
/// returns the result.
///
/// The wrapper is finished (remaining bits padded and written) even when `f`
/// fails.
///
/// # Errors
///
/// Returns the error of `f` if it failed. Otherwise returns the error of
/// finishing the writer, if any. When both fail, the error of `f` is reported
/// because it is the original cause.
pub fn write_bits_scope<W, O>(
    write: W,
    f: impl FnOnce(&mut WriteBits<W>) -> io::Result<O>,
) -> io::Result<O>
where
    W: io::Write,
{
    let mut bits = WriteBits::new(write);
    let outcome = f(&mut bits);
    // Always attempt to finish, but do not let its error hide the closure's.
    let finished = bits.finish();
    let value = outcome?;
    finished?;
    Ok(value)
}
/// Wraps `read` in a fresh [`ReadBits`] and hands it to `f`.
///
/// Whatever `f` returns, value or error, is returned unchanged.
pub fn read_bits_scope<R, O>(
    read: R,
    f: impl FnOnce(&mut ReadBits<R>) -> io::Result<O>,
) -> io::Result<O>
where
    R: io::Read,
{
    f(&mut ReadBits::new(read))
}
/// Packs five bit runs of uneven length back to back and checks the exact
/// bytes that reach the sink once the writer is finished.
#[test]
fn test_writer() {
    // (source bytes, number of bits taken from them starting at bit 0)
    let chunks: [(&[u8], usize); 5] = [
        (&[1, 2, 3, 4], 27),
        (&[5], 3),
        (&[6, 7], 16),
        (&[8, 9, 10], 22),
        (&[11, 12, 13, 14], 30),
    ];
    let expected: [u8; 13] = [
        0x01, 0x02, 0x03, 0xAC, 0xC1, 0x01, 0x42, 0x82, 0xB2, 0xC0, 0xD0, 0xE0, 0,
    ];

    let mut sink = Vec::new();
    let mut bits = WriteBits::new(&mut sink);
    for &(data, bit_len) in &chunks {
        bits.write_bits(data, 0, bit_len).unwrap();
    }
    bits.finish().unwrap();

    assert_eq!(sink[..], expected);
}
/// Unpacks the byte stream produced by the writer test and checks that every
/// bit run comes back as the bytes it was packed from.
#[test]
fn test_reader() {
    // (expected bytes, number of bits to read into a zeroed scratch buffer)
    let chunks: [(&[u8], usize); 5] = [
        (&[1, 2, 3, 4], 27),
        (&[5], 3),
        (&[6, 7], 16),
        (&[8, 9, 10], 22),
        (&[11, 12, 13, 14], 30),
    ];
    let packed: [u8; 13] = [
        0x01, 0x02, 0x03, 0xAC, 0xC1, 0x01, 0x42, 0x82, 0xB2, 0xC0, 0xD0, 0xE0, 0,
    ];

    let mut bits = ReadBits::new(&packed[..]);
    for &(want, bit_len) in &chunks {
        let mut scratch = [0u8; 4];
        bits.read_bits(&mut scratch, 0, bit_len).unwrap();
        assert_eq!(scratch[..want.len()], want[..]);
    }
}
/// Rounds `n` down to the nearest multiple of eight.
fn round_down_8(n: usize) -> usize {
    n - n % 8
}
/// Returns the remainder of `n` divided by eight.
fn modulo_8(n: usize) -> usize {
    n % 8
}
/// Round trip: writes a variable-width index followed by a full byte for each
/// case, then reads everything back and checks that both parts survive.
#[test]
fn test_test() {
    // (bit width of the index, index bytes, value bytes)
    let cases: [(usize, &[u8], &[u8]); 3] =
        [(0, &[0], &[0]), (1, &[1], &[0]), (2, &[0], &[0x1F])];

    let mut encoded = Vec::new();
    {
        let mut bits = WriteBits::new(&mut encoded);
        for &(bit_len, index, value) in &cases {
            bits.write_bits(index, 0, bit_len).unwrap();
            bits.write_bits(value, 0, 8).unwrap();
        }
        bits.finish().unwrap();
    }

    let mut bits = ReadBits::new(&encoded[..]);
    for &(bit_len, index, value) in &cases {
        // One scratch byte per case, reused for the index and then the value.
        let mut scratch = [0u8];
        bits.read_bits(&mut scratch, 0, bit_len).unwrap();
        assert_eq!(scratch[0], index[0]);
        bits.read_bits(&mut scratch, 0, 8).unwrap();
        assert_eq!(scratch[0], value[0]);
    }
}