Skip to main content

hopper_core/collections/
bit_set.rs

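//! Compact bit array for flags and bitmask operations.
//!
//! Wire layout: raw bytes, each holding 8 bits. Bit `i` lives in byte `i / 8`
//! at bit position `i % 8` (least-significant bit first), so index 0 is bit 0
//! of byte 0.
//! No header -- capacity is derived from the byte slice length (`8 * len` bits).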
5
6use hopper_runtime::error::ProgramError;
7
8/// Compact bit array overlaid on a byte slice.
9///
10/// - O(1) get/set/clear/toggle per bit
11/// - No overhead -- 1 bit per flag
12/// - Used for feature flags, user permission masks, state bitfields
13pub struct BitSet<'a> {
14    data: &'a mut [u8],
15}
16
17impl<'a> BitSet<'a> {
18    /// Overlay a BitSet on a mutable byte slice.
19    #[inline(always)]
20    pub fn from_bytes(data: &'a mut [u8]) -> Self {
21        Self { data }
22    }
23
24    /// Number of bits available.
25    #[inline(always)]
26    pub fn capacity(&self) -> usize {
27        self.data.len() * 8
28    }
29
30    /// Get a bit by index.
31    #[inline(always)]
32    pub fn get(&self, index: usize) -> Result<bool, ProgramError> {
33        let byte_idx = index / 8;
34        let bit_idx = index % 8;
35        if byte_idx >= self.data.len() {
36            return Err(ProgramError::InvalidArgument);
37        }
38        Ok((self.data[byte_idx] >> bit_idx) & 1 == 1)
39    }
40
41    /// Set a bit to 1.
42    #[inline(always)]
43    pub fn set(&mut self, index: usize) -> Result<(), ProgramError> {
44        let byte_idx = index / 8;
45        let bit_idx = index % 8;
46        if byte_idx >= self.data.len() {
47            return Err(ProgramError::InvalidArgument);
48        }
49        self.data[byte_idx] |= 1 << bit_idx;
50        Ok(())
51    }
52
53    /// Clear a bit to 0.
54    #[inline(always)]
55    pub fn clear(&mut self, index: usize) -> Result<(), ProgramError> {
56        let byte_idx = index / 8;
57        let bit_idx = index % 8;
58        if byte_idx >= self.data.len() {
59            return Err(ProgramError::InvalidArgument);
60        }
61        self.data[byte_idx] &= !(1 << bit_idx);
62        Ok(())
63    }
64
65    /// Toggle a bit.
66    #[inline(always)]
67    pub fn toggle(&mut self, index: usize) -> Result<(), ProgramError> {
68        let byte_idx = index / 8;
69        let bit_idx = index % 8;
70        if byte_idx >= self.data.len() {
71            return Err(ProgramError::InvalidArgument);
72        }
73        self.data[byte_idx] ^= 1 << bit_idx;
74        Ok(())
75    }
76
77    /// Count the number of set bits (popcount).
78    #[inline]
79    pub fn count_ones(&self) -> usize {
80        let mut count = 0usize;
81        for &byte in self.data.iter() {
82            count += byte.count_ones() as usize;
83        }
84        count
85    }
86
87    /// Count the number of clear bits.
88    #[inline]
89    pub fn count_zeros(&self) -> usize {
90        self.capacity() - self.count_ones()
91    }
92
93    /// Check if ALL bits in a mask are set (starting at byte offset).
94    #[inline]
95    pub fn check_flags(&self, byte_offset: usize, required: u8) -> Result<(), ProgramError> {
96        if byte_offset >= self.data.len() {
97            return Err(ProgramError::InvalidArgument);
98        }
99        if self.data[byte_offset] & required != required {
100            return Err(ProgramError::InvalidAccountData);
101        }
102        Ok(())
103    }
104
105    /// Check if ANY bit in a mask is set.
106    #[inline]
107    pub fn check_any_flag(&self, byte_offset: usize, any_of: u8) -> Result<(), ProgramError> {
108        if byte_offset >= self.data.len() {
109            return Err(ProgramError::InvalidArgument);
110        }
111        if self.data[byte_offset] & any_of == 0 {
112            return Err(ProgramError::InvalidAccountData);
113        }
114        Ok(())
115    }
116
117    /// Compute the byte size needed for a BitSet with the given number of bits.
118    #[inline(always)]
119    pub const fn required_bytes(num_bits: usize) -> usize {
120        num_bits.div_ceil(8)
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn set_get_clear() {
        let mut buf = [0u8; 4]; // 32 bits
        let mut bs = BitSet::from_bytes(&mut buf);

        assert!(!bs.get(0).unwrap());
        bs.set(0).unwrap();
        assert!(bs.get(0).unwrap());
        bs.clear(0).unwrap();
        assert!(!bs.get(0).unwrap());
    }

    #[test]
    fn toggle() {
        let mut buf = [0u8; 1];
        let mut bs = BitSet::from_bytes(&mut buf);

        bs.toggle(3).unwrap();
        assert!(bs.get(3).unwrap());
        bs.toggle(3).unwrap();
        assert!(!bs.get(3).unwrap());
    }

    #[test]
    fn count_ones() {
        let mut buf = [0b1010_0101u8, 0b1111_0000];
        let bs = BitSet::from_bytes(&mut buf);
        assert_eq!(bs.capacity(), 16);
        assert_eq!(bs.count_ones(), 4 + 4);
        assert_eq!(bs.count_zeros(), 16 - 8);
    }

    /// Index `i` maps to byte `i / 8`, bit `i % 8`, least-significant bit first.
    #[test]
    fn bit_order_is_lsb_first() {
        let mut buf = [0u8; 2];
        let mut bs = BitSet::from_bytes(&mut buf);
        bs.set(0).unwrap();
        bs.set(9).unwrap();
        bs.set(15).unwrap();
        assert_eq!(buf, [0b0000_0001, 0b1000_0010]);
    }

    #[test]
    fn out_of_bounds() {
        let mut buf = [0u8; 1]; // 8 bits
        let mut bs = BitSet::from_bytes(&mut buf);
        assert!(matches!(bs.get(8), Err(ProgramError::InvalidArgument)));
        assert!(matches!(bs.set(8), Err(ProgramError::InvalidArgument)));
        assert!(matches!(bs.clear(8), Err(ProgramError::InvalidArgument)));
        assert!(matches!(bs.toggle(usize::MAX), Err(ProgramError::InvalidArgument)));
        // Failed writes must not touch the data.
        assert_eq!(buf, [0u8]);
    }

    #[test]
    fn empty_slice_has_no_bits() {
        let mut buf: [u8; 0] = [];
        let bs = BitSet::from_bytes(&mut buf);
        assert_eq!(bs.capacity(), 0);
        assert_eq!(bs.count_zeros(), 0);
        assert!(bs.get(0).is_err());
    }

    #[test]
    fn flag_checks() {
        let mut buf = [0b0000_0101u8];
        let bs = BitSet::from_bytes(&mut buf);

        assert!(bs.check_flags(0, 0b101).is_ok());
        assert!(bs.check_flags(0, 0).is_ok());
        assert!(matches!(bs.check_flags(0, 0b111), Err(ProgramError::InvalidAccountData)));
        assert!(matches!(bs.check_flags(1, 0), Err(ProgramError::InvalidArgument)));

        assert!(bs.check_any_flag(0, 0b110).is_ok());
        assert!(matches!(bs.check_any_flag(0, 0b1010), Err(ProgramError::InvalidAccountData)));
        assert!(matches!(bs.check_any_flag(0, 0), Err(ProgramError::InvalidAccountData)));
        assert!(matches!(bs.check_any_flag(1, 0xFF), Err(ProgramError::InvalidArgument)));
    }

    #[test]
    fn required_bytes_rounds_up() {
        assert_eq!(BitSet::required_bytes(0), 0);
        assert_eq!(BitSet::required_bytes(1), 1);
        assert_eq!(BitSet::required_bytes(8), 1);
        assert_eq!(BitSet::required_bytes(9), 2);
        assert_eq!(BitSet::required_bytes(64), 8);
    }
}