1use crate::Endianness;
2use std::cmp::min;
3use std::marker::PhantomData;
4use std::ops::{Index, IndexMut, Range};
5
/// Backing storage for [`WriteBuffer`]: either a growable vector or a
/// fixed-size slice paired with a count of the bytes written into it.
enum WriteData<'a> {
    /// Bytes are appended to a caller-owned vector, which grows on demand.
    Vec(&'a mut Vec<u8>),
    /// Bytes are written into a caller-owned slice from the front;
    /// `length` counts how many bytes have been written so far.
    Slice {
        data: &'a mut [u8],
        length: usize,
    },
}
10
11impl WriteData<'_> {
12 fn pop(&mut self) -> Option<u8> {
13 match self {
14 WriteData::Vec(vec) => vec.pop(),
15 WriteData::Slice { data, length } if *length > 0 => {
16 *length -= 1;
17 Some(data[*length])
18 }
19 _ => None,
20 }
21 }
22
23 fn extend_from_slice(&mut self, other: &[u8]) {
24 match self {
25 WriteData::Vec(vec) => vec.extend_from_slice(other),
26 WriteData::Slice { data, length } => {
27 let end = *length + other.len();
28 let target = &mut data[*length..end];
29 target.copy_from_slice(other);
30 *length += other.len();
31 }
32 }
33 }
34
35 fn push(&mut self, byte: u8) {
36 match self {
37 WriteData::Vec(vec) => vec.push(byte),
38 WriteData::Slice { data, length } => {
39 data[*length] = byte;
40 *length += 1;
41 }
42 }
43 }
44
45 fn last_mut(&mut self) -> Option<&mut u8> {
46 match self {
47 WriteData::Vec(vec) => vec.last_mut(),
48 WriteData::Slice { data, length } if *length > 0 => Some(&mut data[*length - 1]),
49 _ => None,
50 }
51 }
52}
53
54impl Index<usize> for WriteData<'_> {
55 type Output = u8;
56
57 fn index(&self, index: usize) -> &Self::Output {
58 match self {
59 WriteData::Vec(vec) => &vec[index],
60 WriteData::Slice { data, .. } => &data[index],
61 }
62 }
63}
64
65impl IndexMut<usize> for WriteData<'_> {
66 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
67 match self {
68 WriteData::Vec(vec) => &mut vec[index],
69 WriteData::Slice { data, .. } => &mut data[index],
70 }
71 }
72}
73
74impl Index<Range<usize>> for WriteData<'_> {
75 type Output = [u8];
76
77 fn index(&self, index: Range<usize>) -> &Self::Output {
78 match self {
79 WriteData::Vec(vec) => &vec[index],
80 WriteData::Slice { data, .. } => &data[index],
81 }
82 }
83}
84
85impl IndexMut<Range<usize>> for WriteData<'_> {
86 fn index_mut(&mut self, index: Range<usize>) -> &mut Self::Output {
87 match self {
88 WriteData::Vec(vec) => &mut vec[index],
89 WriteData::Slice { data, .. } => &mut data[index],
90 }
91 }
92}
93
94pub struct WriteBuffer<'a, E: Endianness> {
95 bit_len: usize,
96 bytes: WriteData<'a>,
97 endianness: PhantomData<E>,
98}
99
100impl<'a, E: Endianness> WriteBuffer<'a, E> {
101 pub fn new(bytes: &'a mut Vec<u8>, _endianness: E) -> Self {
102 WriteBuffer {
103 bit_len: 0,
104 bytes: WriteData::Vec(bytes),
105 endianness: PhantomData,
106 }
107 }
108 pub fn for_slice(bytes: &'a mut [u8], _endianness: E) -> Self {
109 WriteBuffer {
110 bit_len: 0,
111 bytes: WriteData::Slice {
112 data: bytes,
113 length: 0,
114 },
115 endianness: PhantomData,
116 }
117 }
118
119 pub fn bit_len(&self) -> usize {
121 self.bit_len
122 }
123
124 pub fn push_non_fit_bits<I>(&mut self, bits: I, count: usize)
125 where
126 I: ExactSizeIterator,
127 I: DoubleEndedIterator<Item = (usize, u8)>,
128 {
129 let mut remaining = count;
130 for (chunk, chunk_size) in bits {
131 if remaining > 0 {
132 let bits = min(remaining, chunk_size as usize);
133 self.push_bits(chunk, bits);
134 remaining -= bits
135 }
136 }
137 }
138
139 pub fn push_bits(&mut self, bits: usize, count: usize) {
141 if count == 0 {
142 return;
143 }
144
145 let bits = bits & (usize::MAX >> (usize::BITS as usize - count));
147
148 let bit_offset = self.bit_len & 7;
149
150 debug_assert!(count <= usize::BITS as usize - bit_offset);
151
152 let last_written_byte = if bit_offset > 0 {
153 self.bytes.pop().unwrap_or(0)
154 } else {
155 0
156 };
157 let merged_byte_count = (count + bit_offset + 7) / 8;
158
159 if E::is_le() {
160 let merged = last_written_byte as usize | (bits << bit_offset);
161 self.bytes
162 .extend_from_slice(&merged.to_le_bytes()[0..merged_byte_count]);
163 } else {
164 let merged = ((last_written_byte as usize) << (usize::BITS as usize - 8))
165 | (bits << (usize::BITS as usize - bit_offset - count));
166 self.bytes
167 .extend_from_slice(&merged.to_be_bytes()[0..merged_byte_count]);
168 }
169 self.bit_len += count;
170 }
171
172 pub fn set_at(&mut self, pos: usize, bits: u64, count: usize) {
173 debug_assert!(count < 64 - 8);
174
175 let bit_offset = pos & 7;
176 let byte_pos = pos / 8;
177 let byte_count = (count + bit_offset + 7) / 8;
178
179 let mut old = [0; 8];
180 old[0..byte_count].copy_from_slice(&self.bytes[byte_pos..byte_pos + byte_count]);
181
182 let old = u64::from_le_bytes(old);
183 let merged = old | (bits << bit_offset);
184 let merged = merged.to_le_bytes();
185 self.bytes[byte_pos..byte_pos + byte_count].copy_from_slice(&merged[0..byte_count]);
186 }
187
188 pub fn extends_from_slice(&mut self, slice: &[u8]) {
189 debug_assert_eq!(0, self.bit_len & 7);
190 self.bytes.extend_from_slice(slice);
191 self.bit_len += slice.len() * 8
192 }
193
194 pub fn push_bool(&mut self, val: bool) {
195 let val = val as u8;
196 let bit_offset = self.bit_len() % 8;
197 let shift = if E::is_le() {
198 bit_offset
199 } else {
200 7 - bit_offset
201 };
202 if bit_offset == 0 {
203 self.bytes.push(val << shift);
204 } else {
205 *self.bytes.last_mut().unwrap() |= val << shift;
206 }
207 self.bit_len += 1;
208 }
209}