nuclear_router/bitset/
fixed_bitset.rs1#![allow(unsafe_code)]
2
3use super::table::TABLE;
4use std::{mem, slice};
5
/// Fixed-size, plain-old-data storage that a `FixedBitSet` can view as raw bytes.
///
/// # Safety
///
/// Implementors must be plain-old-data:
///
/// * the type has no padding bytes, pointers, or references;
/// * every byte pattern, including all zeros, is a valid value of the type.
///
/// The default methods reinterpret `Self` as a `[u8]` of length
/// `size_of::<Self>()`. The mutable view lets callers write arbitrary bytes,
/// so any pattern they write must still be a valid `Self`.
pub unsafe trait BitStorage: Sized {
    /// Number of bits held by the storage: `size_of::<Self>() * 8`.
    fn bit_size() -> usize {
        mem::size_of::<Self>() * 8
    }

    /// Views the storage as its raw bytes, in memory order.
    ///
    /// Multi-byte integers appear in the target's native byte order.
    fn as_bytes(&self) -> &[u8] {
        // SAFETY: `self` is a valid, aligned reference to `size_of::<Self>()`
        // bytes. `u8` has alignment 1. The trait contract (no padding) means
        // every one of those bytes is initialized. The slice borrows `self`, so
        // it cannot outlive it.
        unsafe { slice::from_raw_parts(self as *const Self as *const u8, mem::size_of::<Self>()) }
    }

    /// Views the storage as its raw bytes, mutably, in memory order.
    fn as_bytes_mut(&mut self) -> &mut [u8] {
        // SAFETY: as in `as_bytes`, and the `&mut self` borrow is exclusive.
        // The trait contract guarantees that any bytes written through the
        // slice still form a valid `Self`.
        unsafe { slice::from_raw_parts_mut(self as *mut Self as *mut u8, mem::size_of::<Self>()) }
    }
}
18
19unsafe impl BitStorage for [u128; 4] {}
20unsafe impl BitStorage for [u128; 2] {}
21unsafe impl BitStorage for u128 {}
22unsafe impl BitStorage for u64 {}
23unsafe impl BitStorage for u32 {}
24unsafe impl BitStorage for u16 {}
25unsafe impl BitStorage for u8 {}
26
27#[derive(Debug, Clone)]
28pub struct FixedBitSet<S: BitStorage> {
29 buf: S,
30}
31
32impl<S: BitStorage> FixedBitSet<S> {
33 pub fn zero() -> Self {
34 Self {
35 buf: unsafe { mem::zeroed() },
36 }
37 }
38
39 pub fn one() -> Self {
40 let mut set = Self {
41 buf: unsafe { mem::MaybeUninit::uninit().assume_init() },
42 };
43 set.buf
44 .as_bytes_mut()
45 .iter_mut()
46 .for_each(|x| *x = u8::max_value());
47 set
48 }
49}
50
51impl<S: BitStorage> FixedBitSet<S> {
52 pub fn intersect_with(&mut self, other: &Self) {
53 self.buf
54 .as_bytes_mut()
55 .iter_mut()
56 .zip(other.buf.as_bytes().iter())
57 .for_each(|(lhs, rhs)| *lhs &= rhs)
58 }
59
60 pub fn union_with(&mut self, other: &Self) {
61 self.buf
62 .as_bytes_mut()
63 .iter_mut()
64 .zip(other.buf.as_bytes().iter())
65 .for_each(|(lhs, rhs)| *lhs |= rhs)
66 }
67
68 pub fn set(&mut self, index: usize, bit: bool) {
69 let idx = index / 8;
70 let offset: u8 = (index % 8) as _;
71 let mask = (bit as u8) << offset;
72 let bytes = self.buf.as_bytes_mut();
73 let pos: &mut u8 = match bytes.get_mut(idx) {
74 Some(pos) => pos,
75 None => panic!(
76 "bitset index out of bound: index = {}, bound = {}",
77 index,
78 S::bit_size()
79 ),
80 };
81 *pos |= mask
82 }
83
84 pub fn iter_ones(&self) -> impl Iterator<Item = usize> + '_ {
85 self.buf
86 .as_bytes()
87 .iter()
88 .enumerate()
89 .flat_map(|(i, &x)| TABLE[x as usize].iter().map(move |&j| i * 8 + j))
90 }
91}