hopper_core/collections/
bit_set.rs1use hopper_runtime::error::ProgramError;
7
/// A fixed-capacity set of bits stored in a borrowed, mutable byte slice.
///
/// The set does not own its storage. Every mutation writes directly into the
/// slice handed to `BitSet::from_bytes`, so the bytes can be inspected by the
/// owner again once the `BitSet` is dropped.
#[derive(Debug)]
pub struct BitSet<'a> {
    data: &'a mut [u8],
}
16
17impl<'a> BitSet<'a> {
18 #[inline(always)]
20 pub fn from_bytes(data: &'a mut [u8]) -> Self {
21 Self { data }
22 }
23
24 #[inline(always)]
26 pub fn capacity(&self) -> usize {
27 self.data.len() * 8
28 }
29
30 #[inline(always)]
32 pub fn get(&self, index: usize) -> Result<bool, ProgramError> {
33 let byte_idx = index / 8;
34 let bit_idx = index % 8;
35 if byte_idx >= self.data.len() {
36 return Err(ProgramError::InvalidArgument);
37 }
38 Ok((self.data[byte_idx] >> bit_idx) & 1 == 1)
39 }
40
41 #[inline(always)]
43 pub fn set(&mut self, index: usize) -> Result<(), ProgramError> {
44 let byte_idx = index / 8;
45 let bit_idx = index % 8;
46 if byte_idx >= self.data.len() {
47 return Err(ProgramError::InvalidArgument);
48 }
49 self.data[byte_idx] |= 1 << bit_idx;
50 Ok(())
51 }
52
53 #[inline(always)]
55 pub fn clear(&mut self, index: usize) -> Result<(), ProgramError> {
56 let byte_idx = index / 8;
57 let bit_idx = index % 8;
58 if byte_idx >= self.data.len() {
59 return Err(ProgramError::InvalidArgument);
60 }
61 self.data[byte_idx] &= !(1 << bit_idx);
62 Ok(())
63 }
64
65 #[inline(always)]
67 pub fn toggle(&mut self, index: usize) -> Result<(), ProgramError> {
68 let byte_idx = index / 8;
69 let bit_idx = index % 8;
70 if byte_idx >= self.data.len() {
71 return Err(ProgramError::InvalidArgument);
72 }
73 self.data[byte_idx] ^= 1 << bit_idx;
74 Ok(())
75 }
76
77 #[inline]
79 pub fn count_ones(&self) -> usize {
80 let mut count = 0usize;
81 for &byte in self.data.iter() {
82 count += byte.count_ones() as usize;
83 }
84 count
85 }
86
87 #[inline]
89 pub fn count_zeros(&self) -> usize {
90 self.capacity() - self.count_ones()
91 }
92
93 #[inline]
95 pub fn check_flags(&self, byte_offset: usize, required: u8) -> Result<(), ProgramError> {
96 if byte_offset >= self.data.len() {
97 return Err(ProgramError::InvalidArgument);
98 }
99 if self.data[byte_offset] & required != required {
100 return Err(ProgramError::InvalidAccountData);
101 }
102 Ok(())
103 }
104
105 #[inline]
107 pub fn check_any_flag(&self, byte_offset: usize, any_of: u8) -> Result<(), ProgramError> {
108 if byte_offset >= self.data.len() {
109 return Err(ProgramError::InvalidArgument);
110 }
111 if self.data[byte_offset] & any_of == 0 {
112 return Err(ProgramError::InvalidAccountData);
113 }
114 Ok(())
115 }
116
117 #[inline(always)]
119 pub const fn required_bytes(num_bits: usize) -> usize {
120 num_bits.div_ceil(8)
121 }
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn set_get_clear() {
        let mut buf = [0u8; 4];
        let mut bs = BitSet::from_bytes(&mut buf);

        assert!(!bs.get(0).unwrap());
        bs.set(0).unwrap();
        assert!(bs.get(0).unwrap());
        bs.clear(0).unwrap();
        assert!(!bs.get(0).unwrap());
    }

    #[test]
    fn toggle() {
        let mut buf = [0u8; 1];
        let mut bs = BitSet::from_bytes(&mut buf);

        bs.toggle(3).unwrap();
        assert!(bs.get(3).unwrap());
        bs.toggle(3).unwrap();
        assert!(!bs.get(3).unwrap());
    }

    #[test]
    fn count_ones() {
        let mut buf = [0b1010_0101u8, 0b1111_0000];
        let bs = BitSet::from_bytes(&mut buf);
        assert_eq!(bs.count_ones(), 4 + 4);
    }

    #[test]
    fn count_zeros_and_capacity() {
        let mut buf = [0b1010_0101u8, 0b1111_0000];
        let bs = BitSet::from_bytes(&mut buf);
        assert_eq!(bs.capacity(), 16);
        assert_eq!(bs.count_zeros(), 16 - 8);
    }

    #[test]
    fn bit_layout_is_lsb_first() {
        let mut buf = [0u8; 2];
        {
            let mut bs = BitSet::from_bytes(&mut buf);
            bs.set(0).unwrap();
            bs.set(9).unwrap();
            bs.set(15).unwrap();
        }
        // Bit i maps to byte i / 8, bit position i % 8 from the LSB.
        assert_eq!(buf, [0b0000_0001, 0b1000_0010]);
    }

    #[test]
    fn out_of_bounds() {
        let mut buf = [0u8; 1];
        let mut bs = BitSet::from_bytes(&mut buf);

        assert!(bs.get(7).is_ok());
        assert!(bs.get(8).is_err());
        assert!(bs.set(8).is_err());
        assert!(bs.clear(8).is_err());
        assert!(bs.toggle(8).is_err());
        assert!(matches!(bs.get(8), Err(ProgramError::InvalidArgument)));
    }

    #[test]
    fn flag_checks() {
        let mut buf = [0b0000_0101u8];
        let bs = BitSet::from_bytes(&mut buf);

        assert!(bs.check_flags(0, 0b0000_0101).is_ok());
        assert!(bs.check_flags(0, 0).is_ok());
        assert!(matches!(
            bs.check_flags(0, 0b0000_0111),
            Err(ProgramError::InvalidAccountData)
        ));
        assert!(matches!(
            bs.check_flags(1, 0b0000_0001),
            Err(ProgramError::InvalidArgument)
        ));

        assert!(bs.check_any_flag(0, 0b0000_0110).is_ok());
        assert!(matches!(
            bs.check_any_flag(0, 0b0000_0010),
            Err(ProgramError::InvalidAccountData)
        ));
        assert!(matches!(
            bs.check_any_flag(1, 0b0000_0001),
            Err(ProgramError::InvalidArgument)
        ));
    }

    #[test]
    fn required_bytes_rounds_up() {
        assert_eq!(BitSet::required_bytes(0), 0);
        assert_eq!(BitSet::required_bytes(1), 1);
        assert_eq!(BitSet::required_bytes(8), 1);
        assert_eq!(BitSet::required_bytes(9), 2);
    }
}