1use crate::iter::BlockIter;
2use crate::range_compat::*;
3use crate::storage::{Address, BlockType};
4use crate::traits::{get_masked_block, BitSliceable, Bits, BitsMut};
5
6use core::marker::PhantomData;
7use core::{cmp, fmt, hash, ptr};
8
/// Describes which bits of the underlying block storage a slice covers,
/// relative to the slice's base block pointer.
#[derive(Copy, Clone, Debug)]
struct SliceSpan {
    /// Bit offset of the slice's first bit within the first block. The
    /// `from_raw_parts` constructors take it from `Address::bit_offset`.
    offset: u8,
    /// Length of the slice in bits.
    len: u64,
    /// Number of leading blocks that `find_block` hands out as whole,
    /// directly addressed blocks (`BlockAddress::FullBlockAt`). It is computed
    /// by `SliceSpan::new`, and is zero whenever `offset` is non-zero.
    aligned_blocks: usize,
}
/// Location of a block-sized read or write within the underlying storage.
#[derive(Copy, Clone, Debug)]
enum BlockAddress {
    /// The whole block at this block index, accessed with a single read/write.
    FullBlockAt(usize),
    /// `count` bits starting at the given bit address; these may straddle the
    /// addressed block and the one after it.
    SomeBitsAt(Address, usize),
}
28
29impl SliceSpan {
30 fn new<Block: BlockType>(offset: u8, bit_len: u64) -> Self {
31 SliceSpan {
32 offset,
33 len: bit_len,
34 aligned_blocks: if offset == 0 {
35 Block::ceil_div_nbits(bit_len)
36 } else {
37 0
38 },
39 }
40 }
41
42 fn from_block_len<Block: BlockType>(block_len: usize) -> Self {
43 Self::new::<Block>(0, Block::mul_nbits(block_len))
44 }
45
46 fn block_len<Block: BlockType>(&self) -> usize {
47 Block::ceil_div_nbits(self.len)
48 }
49
50 fn find_block<Block: BlockType>(&self, position: usize) -> Option<BlockAddress> {
51 if position < self.aligned_blocks {
52 Some(BlockAddress::FullBlockAt(position))
53 } else if position < self.block_len::<Block>() {
54 let start = Block::mul_nbits(position) + u64::from(self.offset);
55 let address = Address::new::<Block>(start);
56 let count = Block::block_bits(self.len, position);
57 Some(BlockAddress::SomeBitsAt(address, count))
58 } else {
59 None
60 }
61 }
62
63 fn find_bits<Block: BlockType>(&self, position: u64, count: usize) -> Option<BlockAddress> {
64 if position + (count as u64) <= self.len {
65 let address = Address::new::<Block>(position + u64::from(self.offset));
66 if count == Block::nbits() && address.bit_offset == 0 {
67 Some(BlockAddress::FullBlockAt(address.block_index))
68 } else {
69 Some(BlockAddress::SomeBitsAt(address, count))
70 }
71 } else {
72 None
73 }
74 }
75
76 fn find_bit<Block: BlockType>(&self, position: u64) -> Option<Address> {
77 if position < self.len {
78 Some(Address::new::<Block>(position + self.offset as u64))
79 } else {
80 None
81 }
82 }
83}
84
impl BlockAddress {
    /// Reads the block, or the run of bits, designated by this address.
    ///
    /// For `SomeBitsAt(address, count)`, the `count` bits starting at
    /// `address` are returned in the low bits of the result. They may be
    /// gathered from `bits[block_index]` and `bits[block_index + 1]`.
    ///
    /// # Safety
    ///
    /// `bits` must be valid for reads of every block the address touches:
    /// - `FullBlockAt(position)`: the block at `position`.
    /// - `SomeBitsAt`: the block at `block_index`, and also the next block
    ///   when the bits cross a block boundary.
    ///
    /// NOTE(review): `count` is presumably at most `Block::nbits()`; larger
    /// values would make the second shift/extraction exceed the block width.
    unsafe fn read<Block: BlockType>(self, bits: *const Block) -> Block {
        match self {
            BlockAddress::FullBlockAt(position) => ptr::read(bits.add(position)),

            BlockAddress::SomeBitsAt(address, count) => {
                let offset = address.bit_offset;
                let ptr1 = bits.add(address.block_index);
                let block1 = ptr::read(ptr1);

                // The first block contributes the bits from `offset` upward;
                // `shift2` is how many of those there are, and also where the
                // second block's bits land in the result.
                let shift1 = offset;
                let shift2 = Block::nbits() - shift1;

                let bits_size1 = cmp::min(shift2, count);
                let chunk1 = block1.get_bits(shift1, bits_size1);

                // Any remaining bits come from the start of the next block.
                let bits_size2 = count - bits_size1;
                if bits_size2 == 0 {
                    return chunk1;
                }

                let block2 = ptr::read(ptr1.offset(1));
                let chunk2 = block2.get_bits(0, bits_size2);
                chunk1 | (chunk2 << shift2)
            }
        }
    }

    /// Writes `value` to the block, or the run of bits, designated by this
    /// address.
    ///
    /// For `SomeBitsAt(address, count)`, the low `count` bits of `value` are
    /// spliced into the touched block(s) with a read/`with_bits`/write
    /// sequence.
    ///
    /// # Safety
    ///
    /// `bits` must be valid for reads and writes of every block the address
    /// touches (the same blocks as for `read`).
    unsafe fn write<Block: BlockType>(self, bits: *mut Block, value: Block) {
        match self {
            BlockAddress::FullBlockAt(position) => ptr::write(bits.add(position), value),

            BlockAddress::SomeBitsAt(address, count) => {
                let offset = address.bit_offset;
                let ptr1 = bits.add(address.block_index);

                // Bits of `value` below `shift2` go into the first block at
                // `offset`; the rest go into the start of the next block.
                let shift1 = offset;
                let shift2 = Block::nbits() - shift1;

                let bits_size1 = cmp::min(count, shift2);
                let old_block1 = ptr::read(ptr1);
                let new_block1 = old_block1.with_bits(shift1, bits_size1, value);
                ptr::write(ptr1, new_block1);

                let bits_size2 = count - bits_size1;
                if bits_size2 == 0 {
                    return;
                }
                let ptr2 = ptr1.offset(1);
                let old_block2 = ptr::read(ptr2);
                let new_block2 = old_block2.with_bits(0, bits_size2, value >> shift2);
                ptr::write(ptr2, new_block2);
            }
        }
    }
}
156
/// An immutable view of a run of bits inside borrowed block storage.
///
/// The view may begin and end at arbitrary bit positions. It is `Copy`, and
/// its `'a` lifetime ties it to the storage it reads from.
#[derive(Copy, Clone)]
pub struct BitSlice<'a, Block> {
    /// Pointer to the first block containing bits of the slice.
    bits: *const Block,
    /// Bit offset and bit length of the slice relative to `bits`.
    span: SliceSpan,
    /// Carries the `'a` borrow of the underlying storage.
    marker: PhantomData<&'a ()>,
}
179
/// A mutable view of a run of bits inside borrowed block storage.
///
/// The view may begin and end at arbitrary bit positions. Writes through it
/// modify the underlying blocks in place.
pub struct BitSliceMut<'a, Block> {
    /// Pointer to the first block containing bits of the slice.
    bits: *mut Block,
    /// Bit offset and bit length of the slice relative to `bits`.
    span: SliceSpan,
    /// Carries the `'a` mutable borrow of the underlying storage.
    marker: PhantomData<&'a mut ()>,
}
204
205impl<'a, Block: BlockType> BitSlice<'a, Block> {
206 pub fn from_slice(blocks: &'a [Block]) -> Self {
224 BitSlice {
225 bits: blocks.as_ptr(),
226 span: SliceSpan::from_block_len::<Block>(blocks.len()),
227 marker: PhantomData,
228 }
229 }
230
231 pub unsafe fn from_raw_parts(bits: *const Block, offset: u64, len: u64) -> Self {
243 let address = Address::new::<Block>(offset);
244 BitSlice {
245 bits: bits.add(address.block_index),
246 span: SliceSpan::new::<Block>(address.bit_offset as u8, len),
247 marker: PhantomData,
248 }
249 }
250
251 pub fn len(&self) -> u64 {
265 self.span.len
266 }
267
268 pub fn is_empty(&self) -> bool {
283 self.len() == 0
284 }
285}
286
287impl<'a, Block: BlockType> BitSliceMut<'a, Block> {
288 pub fn from_slice(blocks: &mut [Block]) -> Self {
293 BitSliceMut {
294 bits: blocks.as_mut_ptr(),
295 span: SliceSpan::from_block_len::<Block>(blocks.len()),
296 marker: PhantomData,
297 }
298 }
299
300 pub unsafe fn from_raw_parts(bits: *mut Block, offset: u64, len: u64) -> Self {
312 let address = Address::new::<Block>(offset);
313 BitSliceMut {
314 bits: bits.add(address.block_index),
315 span: SliceSpan::new::<Block>(address.bit_offset as u8, len),
316 marker: PhantomData,
317 }
318 }
319
320 pub fn len(&self) -> u64 {
322 self.span.len
323 }
324
325 pub fn is_empty(&self) -> bool {
327 self.len() == 0
328 }
329
330 pub fn as_bit_slice(&self) -> BitSlice<'a, Block> {
332 BitSlice {
333 bits: self.bits,
334 span: self.span,
335 marker: PhantomData,
336 }
337 }
338}
339
340impl<'a, 'b, Block: BlockType> From<&'b BitSliceMut<'a, Block>> for BitSlice<'a, Block> {
341 fn from(slice: &'b BitSliceMut<'a, Block>) -> Self {
342 slice.as_bit_slice()
343 }
344}
345
346impl<'a, Block: BlockType> From<&'a [Block]> for BitSlice<'a, Block> {
347 fn from(slice: &'a [Block]) -> Self {
348 BitSlice::from_slice(slice)
349 }
350}
351
352impl<'a, Block: BlockType> From<&'a mut [Block]> for BitSliceMut<'a, Block> {
353 fn from(slice: &'a mut [Block]) -> Self {
354 BitSliceMut::from_slice(slice)
355 }
356}
357
358unsafe fn get_raw_bit<Block: BlockType>(bits: *const Block, address: Address) -> bool {
362 let ptr = bits.add(address.block_index);
363 let block = ptr::read(ptr);
364 block.get_bit(address.bit_offset)
365}
366
367unsafe fn set_raw_bit<Block: BlockType>(bits: *mut Block, address: Address, value: bool) {
371 let ptr = bits.add(address.block_index);
372 let old_block = ptr::read(ptr);
373 let new_block = old_block.with_bit(address.bit_offset, value);
374 ptr::write(ptr, new_block);
375}
376
377impl<Block: BlockType> Bits for BitSlice<'_, Block> {
378 type Block = Block;
379
380 fn bit_len(&self) -> u64 {
381 self.len()
382 }
383
384 fn get_bit(&self, position: u64) -> bool {
385 let address = self
386 .span
387 .find_bit::<Block>(position)
388 .expect("BitSlice::get_bit: out of bounds");
389 unsafe { get_raw_bit(self.bits, address) }
390 }
391
392 fn get_block(&self, position: usize) -> Block {
393 get_masked_block(self, position)
394 }
395
396 fn get_raw_block(&self, position: usize) -> Block {
397 let block_addr = self
398 .span
399 .find_block::<Block>(position)
400 .expect("BitSlice::get_block: out of bounds");
401 unsafe { block_addr.read(self.bits) }
402 }
403
404 fn get_bits(&self, start: u64, count: usize) -> Self::Block {
405 let block_addr = self
406 .span
407 .find_bits::<Block>(start, count)
408 .expect("BitSlice::get_bits: out of bounds");
409 unsafe { block_addr.read(self.bits) }
410 }
411}
412
413impl<Block: BlockType> Bits for BitSliceMut<'_, Block> {
414 type Block = Block;
415
416 fn bit_len(&self) -> u64 {
417 self.len()
418 }
419
420 fn get_bit(&self, position: u64) -> bool {
421 let address = self
422 .span
423 .find_bit::<Block>(position)
424 .expect("BitSliceMut::get_bit: out of bounds");
425 unsafe { get_raw_bit(self.bits, address) }
426 }
427
428 fn get_block(&self, position: usize) -> Block {
429 let block_addr = self
430 .span
431 .find_block::<Block>(position)
432 .expect("BitSliceMut::get_block: out of bounds");
433 unsafe { block_addr.read(self.bits) }
434 }
435
436 fn get_bits(&self, start: u64, count: usize) -> Self::Block {
437 let block_addr = self
438 .span
439 .find_bits::<Block>(start, count)
440 .expect("BitSliceMut::get_bits: out of bounds");
441 unsafe { block_addr.read(self.bits) }
442 }
443}
444
445impl<Block: BlockType> BitsMut for BitSliceMut<'_, Block> {
446 fn set_bit(&mut self, position: u64, value: bool) {
447 let address = self
448 .span
449 .find_bit::<Block>(position)
450 .expect("BitSliceMut::set_bit: out of bounds");
451 unsafe {
452 set_raw_bit(self.bits, address, value);
453 }
454 }
455
456 fn set_block(&mut self, position: usize, value: Block) {
457 let block_addr = self
458 .span
459 .find_block::<Block>(position)
460 .expect("BitSliceMut::set_block: out of bounds");
461 unsafe {
462 block_addr.write(self.bits, value);
463 }
464 }
465
466 fn set_bits(&mut self, start: u64, count: usize, value: Self::Block) {
467 let block_addr = self
468 .span
469 .find_bits::<Block>(start, count)
470 .expect("BitSliceMut::set_bits: out of bounds");
471 unsafe {
472 block_addr.write(self.bits, value);
473 }
474 }
475}
476
// Generates `Index<u64>` impls (bit indexing, e.g. `slice[3]`) for both slice
// types via the crate's `impl_index_from_bits!` macro. The exact expansion
// lives in that macro's definition.
impl_index_from_bits! {
    impl['a, Block: BlockType] Index<u64> for BitSlice<'a, Block>;
    impl['a, Block: BlockType] Index<u64> for BitSliceMut<'a, Block>;
}
481
482impl<Block: BlockType> BitSliceable<Range<u64>> for BitSlice<'_, Block> {
483 type Slice = Self;
484
485 fn bit_slice(self, range: Range<u64>) -> Self {
486 assert!(range.start <= range.end, "BitSlice::slice: bad range");
487 assert!(range.end <= self.span.len, "BitSlice::slice: out of bounds");
488
489 unsafe {
490 BitSlice::from_raw_parts(
491 self.bits,
492 range.start + u64::from(self.span.offset),
493 range.end - range.start,
494 )
495 }
496 }
497}
498
499impl<Block: BlockType> BitSliceable<Range<u64>> for BitSliceMut<'_, Block> {
500 type Slice = Self;
501
502 fn bit_slice(self, range: Range<u64>) -> Self {
503 assert!(range.start <= range.end, "BitSliceMut::slice: bad range");
504 assert!(
505 range.end <= self.span.len,
506 "BitSliceMut::slice: out of bounds"
507 );
508
509 unsafe {
510 BitSliceMut::from_raw_parts(
511 self.bits,
512 range.start + u64::from(self.span.offset),
513 range.end - range.start,
514 )
515 }
516 }
517}
518
519impl<Block: BlockType> BitSliceable<RangeInclusive<u64>> for BitSlice<'_, Block> {
520 type Slice = Self;
521
522 fn bit_slice(self, range: RangeInclusive<u64>) -> Self {
523 let (start, end) = get_inclusive_bounds(range).expect("BitSlice::slice: bad range");
524 assert!(end < self.span.len, "BitSlice::slice: out of bounds");
525
526 unsafe {
527 BitSlice::from_raw_parts(
528 self.bits,
529 start + u64::from(self.span.offset),
530 end - start + 1,
531 )
532 }
533 }
534}
535
536impl<Block: BlockType> BitSliceable<RangeInclusive<u64>> for BitSliceMut<'_, Block> {
537 type Slice = Self;
538
539 fn bit_slice(self, range: RangeInclusive<u64>) -> Self {
540 let (start, end) = get_inclusive_bounds(range).expect("BitSliceMut::slice: bad range");
541 assert!(end < self.span.len, "BitSliceMut::slice: out of bounds");
542
543 unsafe {
544 BitSliceMut::from_raw_parts(
545 self.bits,
546 start + u64::from(self.span.offset),
547 end - start + 1,
548 )
549 }
550 }
551}
552
553impl<Block: BlockType> BitSliceable<RangeFrom<u64>> for BitSlice<'_, Block> {
554 type Slice = Self;
555
556 fn bit_slice(self, range: RangeFrom<u64>) -> Self {
557 let len = self.span.len;
558 self.bit_slice(range.start..len)
559 }
560}
561
562impl<Block: BlockType> BitSliceable<RangeFrom<u64>> for BitSliceMut<'_, Block> {
563 type Slice = Self;
564
565 fn bit_slice(self, range: RangeFrom<u64>) -> Self {
566 let len = self.span.len;
567 self.bit_slice(range.start..len)
568 }
569}
570
571impl<Block: BlockType> BitSliceable<RangeTo<u64>> for BitSlice<'_, Block> {
572 type Slice = Self;
573
574 fn bit_slice(self, range: RangeTo<u64>) -> Self {
575 self.bit_slice(0..range.end)
576 }
577}
578
579impl<Block: BlockType> BitSliceable<RangeTo<u64>> for BitSliceMut<'_, Block> {
580 type Slice = Self;
581
582 fn bit_slice(self, range: RangeTo<u64>) -> Self {
583 self.bit_slice(0..range.end)
584 }
585}
586
587impl<Block: BlockType> BitSliceable<RangeToInclusive<u64>> for BitSlice<'_, Block> {
588 type Slice = Self;
589
590 fn bit_slice(self, range: RangeToInclusive<u64>) -> Self {
591 self.bit_slice(0..range.end + 1)
592 }
593}
594
595impl<Block: BlockType> BitSliceable<RangeToInclusive<u64>> for BitSliceMut<'_, Block> {
596 type Slice = Self;
597
598 fn bit_slice(self, range: RangeToInclusive<u64>) -> Self {
599 self.bit_slice(0..range.end + 1)
600 }
601}
602
603impl<Block: BlockType> BitSliceable<RangeFull> for BitSlice<'_, Block> {
604 type Slice = Self;
605
606 fn bit_slice(self, _: RangeFull) -> Self {
607 self
608 }
609}
610
611impl<Block: BlockType> BitSliceable<RangeFull> for BitSliceMut<'_, Block> {
612 type Slice = Self;
613
614 fn bit_slice(self, _: RangeFull) -> Self {
615 self
616 }
617}
618
619impl<'a, Block, R> BitSliceable<R> for &'a [Block]
620where
621 Block: BlockType,
622 BitSlice<'a, Block>: BitSliceable<R, Block = Block, Slice = BitSlice<'a, Block>>,
623{
624 type Slice = BitSlice<'a, Block>;
625
626 fn bit_slice(self, range: R) -> Self::Slice {
627 BitSlice::from_slice(self).bit_slice(range)
628 }
629}
630
631impl<'a, Block, R> BitSliceable<R> for &'a mut [Block]
632where
633 Block: BlockType,
634 BitSliceMut<'a, Block>: BitSliceable<R, Block = Block, Slice = BitSliceMut<'a, Block>>,
635{
636 type Slice = BitSliceMut<'a, Block>;
637
638 fn bit_slice(self, range: R) -> Self::Slice {
639 BitSliceMut::from_slice(self).bit_slice(range)
640 }
641}
642
643impl<Other: Bits> PartialEq<Other> for BitSlice<'_, Other::Block> {
644 fn eq(&self, other: &Other) -> bool {
645 BlockIter::new(self) == BlockIter::new(other)
646 }
647}
648
649impl<Block: BlockType> Eq for BitSlice<'_, Block> {}
650
651impl<Block: BlockType> PartialOrd for BitSlice<'_, Block> {
652 fn partial_cmp(&self, other: &BitSlice<Block>) -> Option<cmp::Ordering> {
653 Some(self.cmp(other))
654 }
655}
656
657impl<Block: BlockType> Ord for BitSlice<'_, Block> {
658 fn cmp(&self, other: &Self) -> cmp::Ordering {
659 let iter1 = BlockIter::new(*self);
660 let iter2 = BlockIter::new(*other);
661 (iter1).cmp(iter2)
662 }
663}
664
665impl<Other: Bits> PartialEq<Other> for BitSliceMut<'_, Other::Block> {
666 fn eq(&self, other: &Other) -> bool {
667 BlockIter::new(self) == BlockIter::new(other)
668 }
669}
670
671impl<Block: BlockType> Eq for BitSliceMut<'_, Block> {}
672
673impl<Block: BlockType> PartialOrd for BitSliceMut<'_, Block> {
674 fn partial_cmp(&self, other: &BitSliceMut<Block>) -> Option<cmp::Ordering> {
675 Some(self.cmp(other))
676 }
677}
678
679impl<Block: BlockType> Ord for BitSliceMut<'_, Block> {
680 fn cmp(&self, other: &Self) -> cmp::Ordering {
681 self.as_bit_slice().cmp(&other.as_bit_slice())
682 }
683}
684
685impl<Block: BlockType + hash::Hash> hash::Hash for BitSlice<'_, Block> {
686 fn hash<H: hash::Hasher>(&self, state: &mut H) {
687 state.write_u64(self.bit_len());
688 for block in BlockIter::new(self) {
689 block.hash(state);
690 }
691 }
692}
693
694impl<Block: BlockType + hash::Hash> hash::Hash for BitSliceMut<'_, Block> {
695 fn hash<H: hash::Hasher>(&self, state: &mut H) {
696 self.as_bit_slice().hash(state);
697 }
698}
699
700impl<Block: BlockType> fmt::Debug for BitSlice<'_, Block> {
701 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
702 write!(f, "bit_vec![")?;
703 if !self.is_empty() {
704 write!(f, "{}", self.get_bit(0))?;
705 }
706 for i in 1..self.span.len {
707 write!(f, ", {}", self.get_bit(i))?;
708 }
709 write!(f, "]")
710 }
711}
712
713impl<Block: BlockType> fmt::Debug for BitSliceMut<'_, Block> {
714 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
715 self.as_bit_slice().fmt(f)
716 }
717}
718
#[cfg(test)]
mod test {
    use super::*;
    use crate::BitVec;
    use alloc::format;

    #[test]
    fn bit_slice_from_slice() {
        let mut storage = [0b0000_1111u8];
        {
            let mut view = BitSliceMut::from_slice(&mut storage);
            assert_eq!(view.get_block(0), 0b0000_1111);
            view.set_bit(1, false);
            assert_eq!(view.get_block(0), 0b0000_1101);
        }
        assert_eq!(storage[0], 0b0000_1101);
    }

    #[test]
    fn bit_slice_index() {
        let mut storage = [0b0000_1111u8];
        {
            let shared = BitSlice::from_slice(&storage);
            assert!(shared[3]);
            assert!(!shared[4]);
        }
        {
            let exclusive = BitSliceMut::from_slice(&mut storage);
            assert!(exclusive[3]);
            assert!(!exclusive[4]);
        }
    }

    #[test]
    fn bit_slice_update_across_blocks() {
        let mut bv: BitVec<u8> = bit_vec![ true; 20 ];
        bv.set_bit(3, false);
        bv.set_bit(7, false);

        {
            let mut window: BitSliceMut<u8> = (&mut bv).bit_slice(4..12);
            window.set_bit(1, false);
            window.set_bit(5, false);
        }

        // Bits 3 and 7 were cleared directly; 5 and 9 through the window.
        let expected = [true, true, true, false, true, false, true, false, true, false];
        for (position, &bit) in expected.iter().enumerate() {
            assert_eq!(bv[position as u64], bit);
        }
    }

    #[test]
    fn debug_for_bit_slice() {
        let storage = [0b0011_0101u8];
        let view = BitSlice::from_slice(&storage);
        assert_eq!(
            format!("{:?}", view),
            "bit_vec![true, false, true, false, true, true, false, false]"
        );
    }

    #[test]
    fn range_to_inclusive() {
        let storage = [0b0011_0101u8];
        let prefix = storage.bit_slice(..=4);
        assert_eq!(prefix.len(), 5);
    }
}