1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
use core::convert::TryInto;
/// Reads bits from a byte slice back to front.
///
/// Zstandard encodes some types of data in a way that the data must be read
/// back to front to decode it properly. `BitReaderReversed` provides a
/// convenient interface to do that.
///
/// The last byte of the source is read first, and within each byte the most
/// significant bit is read first. Reading past the start of the source is not
/// an error: zero bits are returned instead and
/// [`bits_remaining`](BitReaderReversed::bits_remaining) becomes negative.
pub struct BitReaderReversed<'s> {
    /// Offset into `source` of the lowest byte that has been loaded into
    /// `bit_container` so far. Starts at `source.len()` (nothing loaded) and
    /// reaches zero once the whole source has been loaded.
    index: usize,
    /// How many bits have been consumed from `bit_container`, counted from its
    /// most significant bit. A value of 64 means the container is exhausted.
    bits_consumed: u8,
    /// How many bits that were accounted for are not backed by `source`, i.e. the
    /// zero padding that is shifted into `bit_container` once the start of the
    /// source has been reached. Stays zero until the whole source has been loaded.
    extra_bits: usize,
    /// The source data to read from.
    source: &'s [u8],
    /// The reader doesn't read directly from the source, it reads bits from here, and the container
    /// is "refilled" as it's emptied.
    bit_container: u64,
}

impl<'s> BitReaderReversed<'s> {
    /// How many bits are left to read by the reader.
    ///
    /// The result becomes negative once more bits have been consumed than the
    /// source contains; its magnitude is then the number of zero bits that were
    /// handed out in place of real data.
    pub fn bits_remaining(&self) -> isize {
        self.index as isize * 8 + (64 - self.bits_consumed as isize) - self.extra_bits as isize
    }

    /// Creates a reader positioned at the end of `source`.
    ///
    /// No bytes are loaded yet: the container starts out marked as fully
    /// consumed, so the first `get_bits` call that asks for at least one bit
    /// triggers a refill.
    pub fn new(source: &'s [u8]) -> BitReaderReversed<'s> {
        BitReaderReversed {
            index: source.len(),
            bits_consumed: 64,
            source,
            bit_container: 0,
            extra_bits: 0,
        }
    }

    /// We refill the container in full bytes, shifting the still unread portion to the left, and
    /// filling the lower bits with new data.
    ///
    /// Does nothing if less than one full byte of the container has been consumed.
    /// Afterwards fewer than 8 bits of the container are marked as consumed, so at
    /// least 56 bits can be read without another refill.
    #[cold]
    fn refill(&mut self) {
        let bytes_consumed = self.bits_consumed as usize / 8;
        if bytes_consumed == 0 {
            return;
        }

        if self.index >= bytes_consumed {
            // We can safely move the window contained in `bit_container` down by `bytes_consumed`.
            // If the reader wasn't byte aligned, the byte that was partially read is now in the
            // highest order bits of the `bit_container`.
            self.index -= bytes_consumed;
            // Only whole bytes were dropped, so the bits already consumed from the partially
            // read byte stay consumed.
            self.bits_consumed &= 7;
            let window = &self.source[self.index..][..8];
            // The slice is exactly 8 bytes long, so the conversion cannot fail.
            self.bit_container = u64::from_le_bytes(window.try_into().unwrap());
        } else if self.index > 0 {
            // Fewer unread bytes are left than were consumed: read the last portion of
            // the source into the `bit_container`.
            if self.source.len() >= 8 {
                self.bit_container = u64::from_le_bytes(self.source[..8].try_into().unwrap());
            } else {
                // Sources shorter than the container are zero-extended towards the high bytes.
                let mut value = [0u8; 8];
                value[..self.source.len()].copy_from_slice(self.source);
                self.bit_container = u64::from_le_bytes(value);
            }

            // The window only moved down by `index` bytes; whatever else was consumed is
            // still sitting in the top of the freshly loaded container.
            self.bits_consumed -= 8 * self.index as u8;
            self.index = 0;

            // Shift the consumed bits out. The zeroes entering at the bottom are not
            // backed by the source, which is what `extra_bits` keeps track of.
            self.bit_container <<= self.bits_consumed;
            self.extra_bits += self.bits_consumed as usize;
            self.bits_consumed = 0;
        } else if self.bits_consumed < 64 {
            // Shift out already used bits and fill up with zeroes
            self.bit_container <<= self.bits_consumed;
            self.extra_bits += self.bits_consumed as usize;
            self.bits_consumed = 0;
        } else {
            // All useful bits have already been read and the whole container has been consumed
            // (shifting a `u64` by 64 would overflow), all we now do is return zeroes.
            self.extra_bits += self.bits_consumed as usize;
            self.bits_consumed = 0;
            self.bit_container = 0;
        }

        // Assert that at least `56 = 64 - 8` bits are available to read.
        debug_assert!(self.bits_consumed < 8);
    }

    /// Read `n` number of bits from the source. Will read at most 56 bits.
    /// If there are no more bits to be read from the source zero bits will be returned instead.
    ///
    /// The bits are returned in the low `n` bits of the result, first-read bit
    /// most significant.
    #[inline(always)]
    pub fn get_bits(&mut self, n: u8) -> u64 {
        if self.bits_consumed + n > 64 {
            self.refill();
        }

        let value = self.peek_bits(n);
        self.consume(n);
        value
    }

    /// Get the next `n` bits from the source without consuming them.
    /// Caller is responsible for making sure that `n` many bits have been refilled.
    ///
    /// This only reads the reader's state, so a shared reference is sufficient.
    /// If fewer than `n` unconsumed bits are in the container the shift amount
    /// computation underflows, which panics when overflow checks are enabled.
    #[inline(always)]
    pub fn peek_bits(&self, n: u8) -> u64 {
        if n == 0 {
            return 0;
        }

        let mask = (1u64 << n) - 1u64;
        // Unconsumed bits start at the top of the container; move the wanted
        // `n` bits down to the bottom.
        let shift_by = 64 - self.bits_consumed - n;
        (self.bit_container >> shift_by) & mask
    }

    /// Get the next `n1` `n2` and `n3` bits from the source without consuming them.
    /// Caller is responsible for making sure that `sum` many bits have been refilled,
    /// and that `sum` is the sum of `n1`, `n2` and `n3`.
    ///
    /// This only reads the reader's state, so a shared reference is sufficient.
    #[inline(always)]
    pub fn peek_bits_triple(&self, sum: u8, n1: u8, n2: u8, n3: u8) -> (u64, u64, u64) {
        if sum == 0 {
            return (0, 0, 0);
        }

        // all_three contains bits like this: |XXXX..XXX111122223333|
        // Where XXX are already consumed bytes, 1/2/3 are bits of the respective value
        // Lower bits are to the right
        let all_three = self.bit_container >> (64 - self.bits_consumed - sum);

        let mask1 = (1u64 << n1) - 1u64;
        let shift_by1 = n3 + n2;
        let val1 = (all_three >> shift_by1) & mask1;

        let mask2 = (1u64 << n2) - 1u64;
        let shift_by2 = n3;
        let val2 = (all_three >> shift_by2) & mask2;

        let mask3 = (1u64 << n3) - 1u64;
        let val3 = all_three & mask3;

        (val1, val2, val3)
    }

    /// Consume `n` bits from the source.
    ///
    /// Only advances the position inside the container; it never refills. In
    /// debug builds this asserts that no more bits than the container holds
    /// have been consumed.
    #[inline(always)]
    pub fn consume(&mut self, n: u8) {
        self.bits_consumed += n;
        debug_assert!(self.bits_consumed <= 64);
    }

    /// Same as calling get_bits three times but slightly more performant
    ///
    /// When all three values fit into the 56 bits that are guaranteed after a
    /// refill they are extracted with a single refill; otherwise this falls
    /// back to three separate `get_bits` calls.
    #[inline(always)]
    pub fn get_bits_triple(&mut self, n1: u8, n2: u8, n3: u8) -> (u64, u64, u64) {
        let sum = n1 + n2 + n3;
        if sum <= 56 {
            self.refill();

            let triple = self.peek_bits_triple(sum, n1, n2, n3);
            self.consume(sum);
            return triple;
        }

        (self.get_bits(n1), self.get_bits(n2), self.get_bits(n3))
    }
}
#[cfg(test)]
mod test {
    /// Drains a two byte source bit by bit and keeps reading past its start.
    #[test]
    fn it_works() {
        let data = [0b10101010, 0b01010101];
        let mut br = super::BitReaderReversed::new(&data);

        // `(bits requested, expected value)` in the order the reads happen.
        // The source is read back to front, so `0b01010101` comes out first.
        let reads: [(u8, u64); 8] = [
            (1, 0),
            (1, 1),
            (1, 0),
            (4, 0b1010),
            (4, 0b1101),
            (4, 0b0101),
            // Last 0 from source, three zeroes filled in
            (4, 0b0000),
            // All zeroes filled in
            (4, 0b0000),
        ];
        for &(width, expected) in reads.iter() {
            assert_eq!(br.get_bits(width), expected);
        }

        // 23 bits were requested from a 16 bit source.
        assert_eq!(br.bits_remaining(), -7);
    }
}