nuclear_router/bitset/fixed_bitset.rs

1#![allow(unsafe_code)]
2
3use super::table::TABLE;
4use std::{mem, slice};
5
/// Plain fixed-size storage that a bit set can view as a flat byte buffer.
///
/// The provided methods reinterpret `Self` as `size_of::<Self>()` raw bytes, so every
/// bit of the value is addressable as an individual flag.
///
/// # Safety
///
/// Implementors must guarantee that:
///
/// * `Self` contains no padding bytes (every byte of the value is initialized), and
/// * every bit pattern, including all-zero and all-one, is a valid value of `Self`.
///
/// [`as_bytes`](BitStorage::as_bytes) exposes every byte of the value for reading.
/// [`as_bytes_mut`](BitStorage::as_bytes_mut) allows arbitrary writes to them.
/// Unsigned integers and arrays of unsigned integers satisfy both requirements.
pub unsafe trait BitStorage: Sized {
    /// Number of bits available in this storage (`size_of::<Self>() * 8`).
    fn bit_size() -> usize {
        mem::size_of::<Self>() * 8
    }

    /// Borrows the storage as a read-only byte slice in native memory order.
    fn as_bytes(&self) -> &[u8] {
        // SAFETY: `self` points to `size_of::<Self>()` readable bytes for the lifetime of
        // the borrow, `u8` has alignment 1, and the trait contract rules out padding
        // (uninitialized) bytes.
        unsafe { slice::from_raw_parts(self as *const Self as *const u8, mem::size_of::<Self>()) }
    }

    /// Borrows the storage as a mutable byte slice in native memory order.
    fn as_bytes_mut(&mut self) -> &mut [u8] {
        // SAFETY: `self` is a unique borrow of `size_of::<Self>()` initialized bytes, `u8`
        // has alignment 1, and the trait contract guarantees that any bytes written still
        // form a valid `Self`.
        unsafe {
            slice::from_raw_parts_mut(self as *mut Self as *mut u8, mem::size_of::<Self>())
        }
    }
}
18
19unsafe impl BitStorage for [u128; 4] {}
20unsafe impl BitStorage for [u128; 2] {}
21unsafe impl BitStorage for u128 {}
22unsafe impl BitStorage for u64 {}
23unsafe impl BitStorage for u32 {}
24unsafe impl BitStorage for u16 {}
25unsafe impl BitStorage for u8 {}
26
27#[derive(Debug, Clone)]
28pub struct FixedBitSet<S: BitStorage> {
29    buf: S,
30}
31
32impl<S: BitStorage> FixedBitSet<S> {
33    pub fn zero() -> Self {
34        Self {
35            buf: unsafe { mem::zeroed() },
36        }
37    }
38
39    pub fn one() -> Self {
40        let mut set = Self {
41            buf: unsafe { mem::MaybeUninit::uninit().assume_init() },
42        };
43        set.buf
44            .as_bytes_mut()
45            .iter_mut()
46            .for_each(|x| *x = u8::max_value());
47        set
48    }
49}
50
51impl<S: BitStorage> FixedBitSet<S> {
52    pub fn intersect_with(&mut self, other: &Self) {
53        self.buf
54            .as_bytes_mut()
55            .iter_mut()
56            .zip(other.buf.as_bytes().iter())
57            .for_each(|(lhs, rhs)| *lhs &= rhs)
58    }
59
60    pub fn union_with(&mut self, other: &Self) {
61        self.buf
62            .as_bytes_mut()
63            .iter_mut()
64            .zip(other.buf.as_bytes().iter())
65            .for_each(|(lhs, rhs)| *lhs |= rhs)
66    }
67
68    pub fn set(&mut self, index: usize, bit: bool) {
69        let idx = index / 8;
70        let offset: u8 = (index % 8) as _;
71        let mask = (bit as u8) << offset;
72        let bytes = self.buf.as_bytes_mut();
73        let pos: &mut u8 = match bytes.get_mut(idx) {
74            Some(pos) => pos,
75            None => panic!(
76                "bitset index out of bound: index = {}, bound = {}",
77                index,
78                S::bit_size()
79            ),
80        };
81        *pos |= mask
82    }
83
84    pub fn iter_ones(&self) -> impl Iterator<Item = usize> + '_ {
85        self.buf
86            .as_bytes()
87            .iter()
88            .enumerate()
89            .flat_map(|(i, &x)| TABLE[x as usize].iter().map(move |&j| i * 8 + j))
90    }
91}