use super::{Bitmask, MaskError};
/// Packs a slice of booleans into bytes, most-significant bit first.
///
/// Element `i` is stored in bit `7 - i % 8` of byte `i / 8`, so the output
/// holds exactly `bits.len().div_ceil(8)` bytes. Unused low-order bits of the
/// final byte are left as zero.
///
/// # Errors
///
/// Returns [`MaskError::AllocationFailed`] (carrying the requested byte count
/// and the allocator's message) if the output buffer cannot be reserved.
pub fn pack(bits: &[bool]) -> Result<Vec<u8>, MaskError> {
    let byte_count = bits.len().div_ceil(8);
    let mut packed: Vec<u8> = Vec::new();
    // Reserve fallibly up front so an allocation failure is reported to the
    // caller as an error value.
    if let Err(err) = packed.try_reserve_exact(byte_count) {
        return Err(MaskError::AllocationFailed {
            bytes: byte_count,
            reason: err.to_string(),
        });
    }
    // Each chunk of up to 8 flags becomes one byte. A short final chunk
    // simply contributes fewer set bits, leaving its tail bits clear.
    packed.extend(bits.chunks(8).map(|chunk| {
        chunk
            .iter()
            .enumerate()
            .fold(0u8, |byte, (pos, &set)| byte | (u8::from(set) << (7 - pos)))
    }));
    Ok(packed)
}
/// Unpacks `n_elements` booleans from MSB-first packed `bytes`.
///
/// This reverses [`pack`]: element `i` is read from bit `7 - i % 8` of byte
/// `i / 8`. Any bits past `n_elements` in the final byte are ignored,
/// whatever their value.
///
/// # Errors
///
/// - [`MaskError::LengthMismatch`] if `bytes.len()` is not exactly
///   `n_elements.div_ceil(8)`.
/// - Any error propagated from `super::try_reserve_mask` while reserving the
///   output mask.
pub fn unpack(bytes: &[u8], n_elements: usize) -> Result<Bitmask, MaskError> {
    let byte_count = n_elements.div_ceil(8);
    let actual = bytes.len();
    if actual == byte_count {
        let mut mask: Vec<bool> = Vec::new();
        super::try_reserve_mask(&mut mask, n_elements)?;
        // The length check above guarantees `idx / 8 < bytes.len()` for every
        // `idx < n_elements`.
        mask.extend((0..n_elements).map(|idx| {
            let shift = 7 - idx % 8;
            ((bytes[idx / 8] >> shift) & 1) != 0
        }));
        Ok(mask)
    } else {
        Err(MaskError::LengthMismatch {
            expected: byte_count,
            actual,
        })
    }
}
/// Returns how many entries of `bits` are `true`.
pub fn popcount(bits: &[bool]) -> usize {
    // `true` converts to 1 and `false` to 0, so the sum is the set count.
    bits.iter().map(|&set| usize::from(set)).sum()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pack_empty() {
        assert!(pack(&[]).unwrap().is_empty());
    }

    #[test]
    fn pack_single_bit_set() {
        assert_eq!(pack(&[true]).unwrap(), vec![0b1000_0000]);
        assert_eq!(pack(&[false]).unwrap(), vec![0b0000_0000]);
    }

    #[test]
    fn pack_full_byte_msb_first() {
        let bits = [true, true, true, true, false, false, false, false];
        assert_eq!(pack(&bits).unwrap(), vec![0b1111_0000]);
    }

    #[test]
    fn pack_partial_final_byte_zero_filled() {
        let bits = [
            true, false, true, false, true, false, true, false, true, false,
        ];
        assert_eq!(pack(&bits).unwrap(), vec![0b1010_1010, 0b1000_0000]);
    }

    #[test]
    fn pack_multi_byte() {
        let bits: Vec<bool> = (0..16).map(|i| i % 2 == 0).collect();
        assert_eq!(pack(&bits).unwrap(), vec![0xAA, 0xAA]);
    }

    #[test]
    fn unpack_roundtrip() {
        // Includes 0 (empty mask) and sizes on both sides of byte boundaries.
        for n in [0usize, 1, 7, 8, 9, 15, 16, 17, 64, 65, 256, 1000] {
            let bits: Vec<bool> = (0..n).map(|i| i % 3 == 0 || i % 7 == 1).collect();
            let packed = pack(&bits).unwrap();
            assert_eq!(packed.len(), n.div_ceil(8));
            let unpacked = unpack(&packed, n).unwrap();
            assert_eq!(unpacked, bits, "roundtrip failed for n={n}");
        }
    }

    #[test]
    fn unpack_length_mismatch() {
        let packed = vec![0xFF; 3];
        let err = unpack(&packed, 32).unwrap_err();
        assert!(matches!(
            err,
            MaskError::LengthMismatch {
                expected: 4,
                actual: 3
            }
        ));
    }

    #[test]
    fn unpack_rejects_extra_bytes() {
        // An over-long buffer must be rejected, not silently truncated.
        let packed = vec![0x00; 2];
        let err = unpack(&packed, 8).unwrap_err();
        assert!(matches!(
            err,
            MaskError::LengthMismatch {
                expected: 1,
                actual: 2
            }
        ));
    }

    #[test]
    fn unpack_lenient_on_trailing_bits() {
        // 0xFC sets bits beyond n_elements = 10; they must be ignored.
        let packed = vec![0xFF, 0xFC];
        let unpacked = unpack(&packed, 10).unwrap();
        assert_eq!(unpacked, vec![true; 10]);
    }

    #[test]
    fn popcount_various() {
        assert_eq!(popcount(&[]), 0);
        assert_eq!(popcount(&[false; 10]), 0);
        assert_eq!(popcount(&[true; 10]), 10);
        assert_eq!(popcount(&[true, false, true, true, false, true]), 4);
    }
}