use hopper_runtime::error::ProgramError;
/// Fixed-capacity bit set that operates in place on a borrowed, mutable byte buffer.
///
/// Bit `i` lives in byte `i / 8` at bit position `i % 8`, counting from the
/// least-significant bit, so the capacity is always `8 * data.len()` bits.
/// The buffer is never resized. Every bit-indexed operation rejects indices at or
/// beyond the capacity with [`ProgramError::InvalidArgument`] and leaves the buffer untouched.
pub struct BitSet<'a> {
    // Backing storage; owned by the caller, mutated in place.
    data: &'a mut [u8],
}

impl<'a> BitSet<'a> {
    /// Wraps `data` as a bit set without copying or clearing it; existing bits are preserved.
    #[inline(always)]
    pub fn from_bytes(data: &'a mut [u8]) -> Self {
        Self { data }
    }

    /// Number of addressable bits (`8 * data.len()`).
    #[inline(always)]
    pub fn capacity(&self) -> usize {
        self.data.len() * 8
    }

    /// Returns whether bit `index` is set.
    ///
    /// # Errors
    /// [`ProgramError::InvalidArgument`] if `index >= self.capacity()`.
    #[inline(always)]
    pub fn get(&self, index: usize) -> Result<bool, ProgramError> {
        let byte = self.byte_at(index / 8)?;
        Ok((byte & Self::bit_mask(index)) != 0)
    }

    /// Sets bit `index` to 1.
    ///
    /// # Errors
    /// [`ProgramError::InvalidArgument`] if `index >= self.capacity()`; the buffer is not modified.
    #[inline(always)]
    pub fn set(&mut self, index: usize) -> Result<(), ProgramError> {
        let mask = Self::bit_mask(index);
        *self.byte_at_mut(index / 8)? |= mask;
        Ok(())
    }

    /// Clears bit `index` to 0.
    ///
    /// # Errors
    /// [`ProgramError::InvalidArgument`] if `index >= self.capacity()`; the buffer is not modified.
    #[inline(always)]
    pub fn clear(&mut self, index: usize) -> Result<(), ProgramError> {
        let mask = Self::bit_mask(index);
        *self.byte_at_mut(index / 8)? &= !mask;
        Ok(())
    }

    /// Flips bit `index`.
    ///
    /// # Errors
    /// [`ProgramError::InvalidArgument`] if `index >= self.capacity()`; the buffer is not modified.
    #[inline(always)]
    pub fn toggle(&mut self, index: usize) -> Result<(), ProgramError> {
        let mask = Self::bit_mask(index);
        *self.byte_at_mut(index / 8)? ^= mask;
        Ok(())
    }

    /// Total number of set bits across the whole buffer.
    #[inline]
    pub fn count_ones(&self) -> usize {
        self.data
            .iter()
            .map(|byte| byte.count_ones() as usize)
            .sum()
    }

    /// Total number of clear bits across the whole buffer (`capacity() - count_ones()`).
    #[inline]
    pub fn count_zeros(&self) -> usize {
        self.capacity() - self.count_ones()
    }

    /// Verifies that every bit set in `required` is also set in the byte at `byte_offset`.
    ///
    /// `byte_offset` is a byte index into the buffer, not a bit index. An empty mask
    /// (`required == 0`) always passes.
    ///
    /// # Errors
    /// - [`ProgramError::InvalidArgument`] if `byte_offset >= data.len()`.
    /// - [`ProgramError::InvalidAccountData`] if any bit of `required` is clear in that byte.
    #[inline]
    pub fn check_flags(&self, byte_offset: usize, required: u8) -> Result<(), ProgramError> {
        if (self.byte_at(byte_offset)? & required) == required {
            Ok(())
        } else {
            Err(ProgramError::InvalidAccountData)
        }
    }

    /// Verifies that at least one bit of `any_of` is set in the byte at `byte_offset`.
    ///
    /// `byte_offset` is a byte index into the buffer, not a bit index. An empty mask
    /// (`any_of == 0`) always fails, because no bit of it can be set.
    ///
    /// # Errors
    /// - [`ProgramError::InvalidArgument`] if `byte_offset >= data.len()`.
    /// - [`ProgramError::InvalidAccountData`] if none of the bits in `any_of` are set.
    #[inline]
    pub fn check_any_flag(&self, byte_offset: usize, any_of: u8) -> Result<(), ProgramError> {
        if (self.byte_at(byte_offset)? & any_of) != 0 {
            Ok(())
        } else {
            Err(ProgramError::InvalidAccountData)
        }
    }

    /// Minimum buffer length, in bytes, needed to hold `num_bits` bits (rounds up).
    #[inline(always)]
    pub const fn required_bytes(num_bits: usize) -> usize {
        num_bits.div_ceil(8)
    }

    /// Single-bit mask selecting bit `index` within its byte (LSB-first ordering).
    #[inline(always)]
    fn bit_mask(index: usize) -> u8 {
        1u8 << (index % 8)
    }

    /// Copies out the byte at `offset`, or fails with `InvalidArgument` when it is out of range.
    #[inline(always)]
    fn byte_at(&self, offset: usize) -> Result<u8, ProgramError> {
        self.data
            .get(offset)
            .copied()
            .ok_or(ProgramError::InvalidArgument)
    }

    /// Mutable counterpart of `byte_at`; the single bounds check shared by all mutators.
    #[inline(always)]
    fn byte_at_mut(&mut self, offset: usize) -> Result<&mut u8, ProgramError> {
        self.data
            .get_mut(offset)
            .ok_or(ProgramError::InvalidArgument)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn set_get_clear() {
        let mut buf = [0u8; 4];
        let mut bs = BitSet::from_bytes(&mut buf);
        assert!(!bs.get(0).unwrap());
        bs.set(0).unwrap();
        assert!(bs.get(0).unwrap());
        bs.clear(0).unwrap();
        assert!(!bs.get(0).unwrap());
    }

    /// Pins the on-buffer layout: bit `i` is bit `i % 8` (LSB-first) of byte `i / 8`.
    #[test]
    fn bit_layout_is_lsb_first() {
        let mut buf = [0u8; 2];
        let mut bs = BitSet::from_bytes(&mut buf);
        bs.set(0).unwrap();
        bs.set(9).unwrap();
        bs.set(15).unwrap();
        assert_eq!(buf, [0b0000_0001, 0b1000_0010]);
    }

    #[test]
    fn toggle() {
        let mut buf = [0u8; 1];
        let mut bs = BitSet::from_bytes(&mut buf);
        bs.toggle(3).unwrap();
        assert!(bs.get(3).unwrap());
        bs.toggle(3).unwrap();
        assert!(!bs.get(3).unwrap());
    }

    #[test]
    fn count_ones() {
        let mut buf = [0b1010_0101u8, 0b1111_0000];
        let bs = BitSet::from_bytes(&mut buf);
        assert_eq!(bs.count_ones(), 4 + 4);
        assert_eq!(bs.count_zeros(), 16 - 8);
    }

    #[test]
    fn out_of_bounds() {
        let mut buf = [0u8; 1];
        let mut bs = BitSet::from_bytes(&mut buf);
        assert!(bs.get(7).is_ok());
        assert!(bs.get(8).is_err());
        assert!(bs.set(8).is_err());
        assert!(bs.clear(8).is_err());
        assert!(bs.toggle(usize::MAX).is_err());
        // Failed mutations must leave the buffer untouched.
        assert_eq!(buf, [0]);
    }

    #[test]
    fn flag_checks() {
        let mut buf = [0b0000_0101u8];
        let bs = BitSet::from_bytes(&mut buf);
        assert!(bs.check_flags(0, 0b0000_0101).is_ok());
        // An empty required mask is trivially satisfied.
        assert!(bs.check_flags(0, 0).is_ok());
        assert!(matches!(
            bs.check_flags(0, 0b0000_0111),
            Err(ProgramError::InvalidAccountData)
        ));
        assert!(matches!(
            bs.check_flags(1, 0),
            Err(ProgramError::InvalidArgument)
        ));
        assert!(bs.check_any_flag(0, 0b0000_1100).is_ok());
        assert!(matches!(
            bs.check_any_flag(0, 0b0000_1000),
            Err(ProgramError::InvalidAccountData)
        ));
        assert!(matches!(
            bs.check_any_flag(1, 0xFF),
            Err(ProgramError::InvalidArgument)
        ));
    }

    #[test]
    fn required_bytes_rounds_up() {
        assert_eq!(BitSet::required_bytes(0), 0);
        assert_eq!(BitSet::required_bytes(1), 1);
        assert_eq!(BitSet::required_bytes(8), 1);
        assert_eq!(BitSet::required_bytes(9), 2);
    }
}