bitbuffer/
writebuffer.rs

1use crate::Endianness;
2use std::cmp::min;
3use std::marker::PhantomData;
4use std::ops::{Index, IndexMut, Range};
5
/// Byte storage behind a [`WriteBuffer`].
///
/// `Vec` appends to a caller-provided growable vector; `Slice` writes into a fixed,
/// caller-provided buffer of which only the first `length` bytes have been written.
enum WriteData<'a> {
    Vec(&'a mut Vec<u8>),
    Slice { data: &'a mut [u8], length: usize },
}

impl WriteData<'_> {
    /// The bytes written so far.
    ///
    /// For the `Slice` variant this excludes the unwritten tail of `data`, so that
    /// indexing past the written end panics for both variants instead of silently
    /// exposing stale bytes of the slice.
    fn written(&self) -> &[u8] {
        match self {
            WriteData::Vec(vec) => vec.as_slice(),
            WriteData::Slice { data, length } => &data[..*length],
        }
    }

    /// Mutable view of the bytes written so far (see [`WriteData::written`]).
    fn written_mut(&mut self) -> &mut [u8] {
        match self {
            WriteData::Vec(vec) => vec.as_mut_slice(),
            WriteData::Slice { data, length } => &mut data[..*length],
        }
    }

    /// Removes and returns the last written byte, or `None` if nothing was written.
    fn pop(&mut self) -> Option<u8> {
        match self {
            WriteData::Vec(vec) => vec.pop(),
            WriteData::Slice { data, length } if *length > 0 => {
                *length -= 1;
                Some(data[*length])
            }
            WriteData::Slice { .. } => None,
        }
    }

    /// Appends all bytes of `other`.
    ///
    /// # Panics
    ///
    /// For the `Slice` variant, panics if `other` does not fit in the remaining space.
    fn extend_from_slice(&mut self, other: &[u8]) {
        match self {
            WriteData::Vec(vec) => vec.extend_from_slice(other),
            WriteData::Slice { data, length } => {
                let end = *length + other.len();
                data[*length..end].copy_from_slice(other);
                *length = end;
            }
        }
    }

    /// Appends a single byte.
    ///
    /// # Panics
    ///
    /// For the `Slice` variant, panics if the slice is already full.
    fn push(&mut self, byte: u8) {
        match self {
            WriteData::Vec(vec) => vec.push(byte),
            WriteData::Slice { data, length } => {
                data[*length] = byte;
                *length += 1;
            }
        }
    }

    /// Mutable reference to the last written byte, or `None` if nothing was written.
    fn last_mut(&mut self) -> Option<&mut u8> {
        self.written_mut().last_mut()
    }
}

// All indexing is relative to the written bytes only; out-of-range indices panic
// for both variants.
impl Index<usize> for WriteData<'_> {
    type Output = u8;

    fn index(&self, index: usize) -> &Self::Output {
        &self.written()[index]
    }
}

impl IndexMut<usize> for WriteData<'_> {
    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
        &mut self.written_mut()[index]
    }
}

impl Index<Range<usize>> for WriteData<'_> {
    type Output = [u8];

    fn index(&self, index: Range<usize>) -> &Self::Output {
        &self.written()[index]
    }
}

impl IndexMut<Range<usize>> for WriteData<'_> {
    fn index_mut(&mut self, index: Range<usize>) -> &mut Self::Output {
        &mut self.written_mut()[index]
    }
}
93
94pub struct WriteBuffer<'a, E: Endianness> {
95    bit_len: usize,
96    bytes: WriteData<'a>,
97    endianness: PhantomData<E>,
98}
99
100impl<'a, E: Endianness> WriteBuffer<'a, E> {
101    pub fn new(bytes: &'a mut Vec<u8>, _endianness: E) -> Self {
102        WriteBuffer {
103            bit_len: 0,
104            bytes: WriteData::Vec(bytes),
105            endianness: PhantomData,
106        }
107    }
108    pub fn for_slice(bytes: &'a mut [u8], _endianness: E) -> Self {
109        WriteBuffer {
110            bit_len: 0,
111            bytes: WriteData::Slice {
112                data: bytes,
113                length: 0,
114            },
115            endianness: PhantomData,
116        }
117    }
118
119    /// The number of written bits in the buffer
120    pub fn bit_len(&self) -> usize {
121        self.bit_len
122    }
123
124    pub fn push_non_fit_bits<I>(&mut self, bits: I, count: usize)
125    where
126        I: ExactSizeIterator,
127        I: DoubleEndedIterator<Item = (usize, u8)>,
128    {
129        let mut remaining = count;
130        for (chunk, chunk_size) in bits {
131            if remaining > 0 {
132                let bits = min(remaining, chunk_size as usize);
133                self.push_bits(chunk, bits);
134                remaining -= bits
135            }
136        }
137    }
138
139    /// Push up to an usize worth of bits
140    pub fn push_bits(&mut self, bits: usize, count: usize) {
141        if count == 0 {
142            return;
143        }
144
145        // ensure there are no stray bits
146        let bits = bits & (usize::MAX >> (usize::BITS as usize - count));
147
148        let bit_offset = self.bit_len & 7;
149
150        debug_assert!(count <= usize::BITS as usize - bit_offset);
151
152        let last_written_byte = if bit_offset > 0 {
153            self.bytes.pop().unwrap_or(0)
154        } else {
155            0
156        };
157        let merged_byte_count = (count + bit_offset + 7) / 8;
158
159        if E::is_le() {
160            let merged = last_written_byte as usize | (bits << bit_offset);
161            self.bytes
162                .extend_from_slice(&merged.to_le_bytes()[0..merged_byte_count]);
163        } else {
164            let merged = ((last_written_byte as usize) << (usize::BITS as usize - 8))
165                | (bits << (usize::BITS as usize - bit_offset - count));
166            self.bytes
167                .extend_from_slice(&merged.to_be_bytes()[0..merged_byte_count]);
168        }
169        self.bit_len += count;
170    }
171
172    pub fn set_at(&mut self, pos: usize, bits: u64, count: usize) {
173        debug_assert!(count < 64 - 8);
174
175        let bit_offset = pos & 7;
176        let byte_pos = pos / 8;
177        let byte_count = (count + bit_offset + 7) / 8;
178
179        let mut old = [0; 8];
180        old[0..byte_count].copy_from_slice(&self.bytes[byte_pos..byte_pos + byte_count]);
181
182        let old = u64::from_le_bytes(old);
183        let merged = old | (bits << bit_offset);
184        let merged = merged.to_le_bytes();
185        self.bytes[byte_pos..byte_pos + byte_count].copy_from_slice(&merged[0..byte_count]);
186    }
187
188    pub fn extends_from_slice(&mut self, slice: &[u8]) {
189        debug_assert_eq!(0, self.bit_len & 7);
190        self.bytes.extend_from_slice(slice);
191        self.bit_len += slice.len() * 8
192    }
193
194    pub fn push_bool(&mut self, val: bool) {
195        let val = val as u8;
196        let bit_offset = self.bit_len() % 8;
197        let shift = if E::is_le() {
198            bit_offset
199        } else {
200            7 - bit_offset
201        };
202        if bit_offset == 0 {
203            self.bytes.push(val << shift);
204        } else {
205            *self.bytes.last_mut().unwrap() |= val << shift;
206        }
207        self.bit_len += 1;
208    }
209}