1use std::ops::{Deref, DerefMut};
2
3use crate::prelude::*;
4
5use bytes::{Bytes, BytesMut};
6
7use super::util::bytes_needed;
8
/// A growable, mutable buffer addressed in bits, stored in a [`BytesMut`].
///
/// The visible bits are the range `bit_start..bit_start + bit_len` of the bits of `inner`.
///
/// Note that the derived `PartialEq`/`Eq` compare the fields one by one (buffer, offset, length
/// and capacity), not just the visible bits.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct BitsMut {
    /// Backing byte storage.
    pub(crate) inner: BytesMut,
    /// Offset, in bits from the start of `inner`, of the first visible bit.
    pub(crate) bit_start: usize,
    /// Number of visible bits, counted from `bit_start`.
    pub(crate) bit_len: usize,
    /// Capacity of this buffer in bits.  It shrinks by the same amount `bit_start` grows when the
    /// front of the buffer is advanced, so it is counted from `bit_start`.
    pub(crate) capacity: usize,
}
19
20impl BitsMut {
21 pub fn new() -> Self {
24 BitsMut::with_capacity(0)
25 }
26
27 pub fn from_bytes_mut(bytes_mut: BytesMut) -> Self {
28 let capacity = bytes_mut.capacity() * 8;
29 let bit_len = bytes_mut.len() * 8;
30 Self {
31 inner: bytes_mut,
32 bit_start: 0,
33 bit_len,
34 capacity,
35 }
36 }
37
38 pub fn with_capacity(capacity: usize) -> Self {
44 let byte_capacity = bytes_needed(capacity);
45 Self {
46 inner: BytesMut::with_capacity(byte_capacity),
47 bit_start: 0,
48 bit_len: 0,
49 capacity,
50 }
51 }
52
53 pub fn with_capacity_bytes(capacity: usize) -> Self {
59 Self::with_capacity(capacity * 8)
60 }
61
62 pub fn zeroed_bits(len: usize) -> Self {
67 let num_bytes = bytes_needed(len);
68 Self {
69 inner: BytesMut::zeroed(num_bytes),
70 bit_start: 0,
71 bit_len: len,
72 capacity: len,
73 }
74 }
75
76 pub fn zeroed_bytes(len: usize) -> Self {
81 Self::zeroed_bits(len * 8)
82 }
83
84 pub fn freeze(self) -> Bits {
89 Bits {
90 inner: self.inner.freeze(),
91 bit_start: self.bit_start,
92 bit_len: self.bit_len,
93 }
94 }
95
96 pub fn extend_from_bit_slice(&mut self, slice: &BitSlice) {
100 let count = slice.len();
101 self.reserve_bits(count);
102
103 let dest = self.spare_capacity_mut();
104 assert!(dest.len() >= count);
105 dest[..count].copy_from_bitslice(slice);
106
107 self.advance_mut_bits(count);
108 }
109
110 pub fn spare_capacity_mut(&mut self) -> &mut BitSlice {
118 let bit_start = self.bit_start + self.bit_len;
124
125 let spare_uninit = self.inner.spare_capacity_mut();
127
128 let (ptr, len) = if bit_start % 8 == 0 {
132 (spare_uninit.as_mut_ptr() as *mut u8, spare_uninit.len())
133 } else {
134 let ptr = unsafe { spare_uninit.as_mut_ptr().offset(-1) as *mut u8 };
135 (ptr, spare_uninit.len() + 1)
137 };
138
139 let spare_bytes: &mut [u8] = unsafe { std::slice::from_raw_parts_mut(ptr, len) };
140
141 &mut BitSlice::from_slice_mut(spare_bytes)[bit_start % 8..]
143 }
144
145 pub fn set_len_bits(&mut self, len: usize) {
150 self.bit_len = len;
151 unsafe { self.inner.set_len(bytes_needed(len)) };
152 }
153
154 pub fn reserve_bits(&mut self, additional: usize) {
157 let len = self.len_bits();
158 let remainder = self.capacity - len;
159
160 if additional <= remainder {
161 return;
162 }
163 let bytes_needed = bytes_needed(additional);
164 self.inner.reserve(bytes_needed);
165 self.capacity = self.inner.capacity() * 8;
166 }
167
168 pub fn reserve_bytes(&mut self, additional: usize) {
171 self.reserve_bits(additional * 8);
172 }
173
174 pub fn split_to_bits(&mut self, at: usize) -> Self {
179 assert!(
180 at <= self.bit_len,
181 "split_to out of bounds: {:?} must be <= {:?}",
182 at,
183 self.bit_len
184 );
185
186 let mut other = self.clone();
187 self.advance_unchecked_bits(at);
188 other.capacity = at;
189 other.bit_len = at;
190 other
191 }
192
193 pub fn split_to_bytes(&mut self, at: usize) -> Self {
199 self.split_to_bits(at * 8)
200 }
201
202 pub fn split(&mut self) -> Self {
207 self.split_to_bits(self.bit_len)
208 }
209
210 pub fn split_off_bits(&mut self, at: usize) -> Self {
215 assert!(
216 at <= self.capacity,
217 "split_off out of bounds: {:?} must be <= {:?}",
218 at,
219 self.bit_len
220 );
221
222 let mut other = self.clone();
223 other.advance_unchecked_bits(at);
225 self.capacity = at;
226 self.bit_len = std::cmp::min(self.bit_len, at);
227
228 other
229 }
230
231 pub fn split_off_bytes(&mut self, at: usize) -> Self {
237 self.split_off_bits(at * 8)
238 }
239
240 pub fn len_bits(&self) -> usize {
242 self.bit_len
243 }
244
245 pub fn len_bytes(&self) -> usize {
249 self.bit_len / 8
250 }
251
252 pub fn is_empty(&self) -> bool {
254 self.bit_len == 0
255 }
256
257 fn advance_unchecked_bits(&mut self, count: usize) {
259 if count == 0 {
260 return;
261 }
262
263 self.bit_start += count;
264 self.bit_len = self.bit_len.saturating_sub(count);
265 self.capacity -= count;
266 }
267}
268
269impl Default for BitsMut {
270 fn default() -> Self {
271 Self::new()
272 }
273}
274
275impl From<BitVec> for BitsMut {
276 fn from(bv: BitVec) -> Self {
277 let bit_len = bv.len();
283 let aligned: BitVec = bv.iter().by_vals().collect();
284 let bytes = aligned.into_vec();
285
286 Self {
287 inner: BytesMut::from(&bytes[..]),
288 bit_start: 0,
289 bit_len,
290 capacity: bytes.len() * 8,
291 }
292 }
293}
294
295impl From<&BitSlice> for BitsMut {
296 fn from(slice: &BitSlice) -> Self {
297 BitsMut::from(slice.to_bitvec())
298 }
299}
300
301impl From<Vec<u8>> for BitsMut {
302 fn from(vec: Vec<u8>) -> Self {
303 let bit_len = vec.len() * 8;
304 let inner = BytesMut::from(Bytes::from(vec));
307 let byte_capacity = inner.capacity();
308 Self {
309 inner,
310 bit_start: 0,
311 bit_len,
312 capacity: byte_capacity * 8,
313 }
314 }
315}
316
317impl Deref for BitsMut {
318 type Target = BitSlice;
319
320 fn deref(&self) -> &Self::Target {
321 &BitSlice::from_slice(&self.inner)[self.bit_start..self.bit_start + self.bit_len]
322 }
323}
324
325impl DerefMut for BitsMut {
326 fn deref_mut(&mut self) -> &mut Self::Target {
327 &mut BitSlice::from_slice_mut(&mut self.inner)
328 [self.bit_start..self.bit_start + self.bit_len]
329 }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// Splitting at a bit index yields two buffers that can be modified independently.
    #[test]
    fn test_split_to() {
        let mut rest = BitsMut::from(bits![1, 1, 1, 1, 0, 0, 0, 0]);

        let mut front = rest.split_to_bits(4);
        for idx in 0..2 {
            front.set(idx, false);
        }
        assert_eq!(front[..], bits![0, 0, 1, 1]);

        for idx in 0..2 {
            rest.set(idx, true);
        }
        assert_eq!(rest[..], bits![1, 1, 0, 0]);
    }

    /// Byte-granular splits work both on byte-aligned buffers and on buffers whose length is not
    /// a whole number of bytes.
    #[test]
    fn test_split_to_bytes() {
        #[rustfmt::skip]
        let mut rest = BitsMut::from(vec![
            0b1111_1111,
            0b0000_0000,
            0b1010_1010,
            0b0101_0101
        ]);

        let mut front = rest.split_to_bytes(1);
        assert_eq!(front.len_bits(), 8);
        assert_eq!(rest.len_bits(), 24);
        for idx in 0..4 {
            front.set(idx, false);
        }

        for idx in 0..4 {
            rest.set(idx, true);
        }
        assert_eq!(front[..], bits![0, 0, 0, 0, 1, 1, 1, 1]);
        assert_eq!(
            rest[..],
            bits![1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1]
        );

        // Twelve bits: one full byte followed by half a byte.
        let mut twelve_bits = rest.split_to_bits(12);
        let mut first_byte = twelve_bits.split_to_bytes(1);
        assert_eq!(first_byte.len_bits(), 8);
        assert_eq!(twelve_bits.len_bits(), 4);

        for idx in 0..2 {
            first_byte.set(idx, false);
        }
        assert_eq!(first_byte[..], bits![0, 0, 1, 1, 0, 0, 0, 0]);

        twelve_bits.set(0, false);
        twelve_bits.set(1, true);
        assert_eq!(twelve_bits[..], bits![0, 1, 1, 0]);
    }

    /// Splitting off the tail leaves the head and the tail independently modifiable.
    #[test]
    fn test_split_off() {
        let mut front = BitsMut::zeroed_bits(32);

        let mut back = front.split_off_bits(12);
        assert_eq!(front.len_bits(), 12);
        assert_eq!(back.len_bits(), 20);
        for idx in 0..4 {
            front.set(idx, true);
        }
        assert_eq!(front[..], bits![1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]);

        for idx in 0..4 {
            back.set(idx, true);
        }
        assert_eq!(
            back[..],
            bits![1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
        );
    }

    /// Bits written through the spare capacity become visible once the length is raised, also
    /// when the written region starts in the middle of a byte.
    #[test]
    fn test_spare_capacity_mut() {
        let mut buffer = BitsMut::with_capacity(24);
        let unused = buffer.spare_capacity_mut();
        unused.set(0, true);
        buffer.set_len_bits(1);

        let unused = buffer.spare_capacity_mut();
        for (idx, value) in [false, false, true].iter().enumerate() {
            unused.set(idx, *value);
        }
        buffer.set_len_bits(4);

        assert_eq!(&buffer[..], bits![1, 0, 0, 1]);
    }

    /// Extending an empty buffer from a bit slice reproduces that slice.
    #[test]
    fn test_extend_from_slice() {
        let mut buffer = BitsMut::new();
        let expected = bits![0, 1, 1, 0, 1, 1, 0];

        buffer.extend_from_bit_slice(expected);
        assert_eq!(&buffer[..], expected);
    }
}