use roaring::RoaringBitmap;
use super::{Bitmask, MaskError};
pub fn encode(bits: &[bool]) -> Result<Vec<u8>, MaskError> {
if bits.len() > u32::MAX as usize {
return Err(MaskError::Malformed(format!(
"roaring mask length {} exceeds u32::MAX addressable positions",
bits.len()
)));
}
let mut bm = RoaringBitmap::new();
for (i, &b) in bits.iter().enumerate() {
if b {
bm.insert(i as u32);
}
}
bm.optimize();
let mut out = Vec::with_capacity(bm.serialized_size());
bm.serialize_into(&mut out)
.map_err(|e| MaskError::Roaring(format!("serialize: {e}")))?;
Ok(out)
}
/// Decodes a serialized roaring bitmap back into a boolean mask.
///
/// `bytes` is read with `RoaringBitmap::deserialize_from`, and `n_elements`
/// is the length of the mask to produce. The result starts out as
/// `n_elements` `false` values, and each bitmap member flips the entry at its
/// index to `true`.
///
/// # Errors
///
/// * [`MaskError::Malformed`] when `n_elements` is greater than `u32::MAX`,
///   or when the bitmap contains a member that is `>= n_elements`.
/// * [`MaskError::Roaring`] when `bytes` cannot be deserialized.
/// * Whatever error `super::try_reserve_mask` reports while reserving space
///   for the output.
pub fn decode(bytes: &[u8], n_elements: usize) -> Result<Bitmask, MaskError> {
    let limit = u32::MAX as usize;
    if n_elements > limit {
        return Err(MaskError::Malformed(format!(
            "roaring mask length {n_elements} exceeds u32::MAX addressable positions"
        )));
    }

    let bitmap = match RoaringBitmap::deserialize_from(bytes) {
        Ok(parsed) => parsed,
        Err(e) => return Err(MaskError::Roaring(format!("deserialize: {e}"))),
    };

    // Reserve through the shared helper before `resize`, so the space for the
    // output is requested by the helper rather than by `resize` itself.
    let mut mask: Vec<bool> = Vec::new();
    super::try_reserve_mask(&mut mask, n_elements)?;
    mask.resize(n_elements, false);

    for member in bitmap.iter() {
        let k = member as usize;
        // `mask` is exactly `n_elements` long, so a missing slot means the
        // bitmap names a position past the declared length.
        let Some(slot) = mask.get_mut(k) else {
            return Err(MaskError::Malformed(format!(
                "roaring mask has key {k} but only {n_elements} elements declared"
            )));
        };
        *slot = true;
    }

    Ok(mask)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `bits`, decodes the bytes again using `bits.len()` as the
    /// declared length, and asserts that the decoded mask equals the input.
    fn assert_round_trip(bits: &[bool]) {
        let encoded = encode(bits).unwrap();
        let decoded = decode(&encoded, bits.len()).unwrap();
        assert_eq!(decoded, bits);
    }

    /// Builds a deterministic mask of `len` entries from a xorshift generator
    /// seeded with `seed`; `pick` maps each generator state to a mask bit.
    fn xorshift_mask(seed: u64, len: usize, mut pick: impl FnMut(u64) -> bool) -> Vec<bool> {
        let mut state = seed;
        let mut mask = Vec::with_capacity(len);
        for _ in 0..len {
            state ^= state << 13;
            state ^= state >> 7;
            state ^= state << 17;
            mask.push(pick(state));
        }
        mask
    }

    /// An empty mask survives the round trip and decodes to an empty mask.
    #[test]
    fn encode_empty() {
        let encoded = encode(&[]).unwrap();
        let decoded = decode(&encoded, 0).unwrap();
        assert!(decoded.is_empty());
    }

    /// Exactly one set bit at the first, middle and last index of masks of
    /// several lengths.
    #[test]
    fn single_set_bit() {
        for n in [1usize, 2, 8, 64, 1000] {
            for idx in [0, n / 2, n - 1] {
                let bits: Vec<bool> = (0..n).map(|position| position == idx).collect();
                let encoded = encode(&bits).unwrap();
                let decoded = decode(&encoded, n).unwrap();
                assert_eq!(decoded, bits, "n={n} idx={idx}");
            }
        }
    }

    /// A mask with no set bits round-trips unchanged.
    #[test]
    fn all_zeros() {
        assert_round_trip(&[false; 1024]);
    }

    /// A mask with every bit set round-trips unchanged.
    #[test]
    fn all_ones() {
        assert_round_trip(&[true; 1024]);
    }

    /// Two runs of 500 set bits inside 7500 entries must encode to fewer than
    /// 100 bytes and still round-trip.
    #[test]
    fn clustered_mask_compresses_well() {
        let runs: [(bool, usize); 5] = [
            (false, 1000),
            (true, 500),
            (false, 5000),
            (true, 500),
            (false, 500),
        ];
        let mut bits = Vec::new();
        for (value, count) in runs {
            bits.extend(std::iter::repeat_n(value, count));
        }

        let encoded = encode(&bits).unwrap();
        assert!(
            encoded.len() < 100,
            "roaring should compress clustered masks well, got {} bytes",
            encoded.len()
        );
        let decoded = decode(&encoded, bits.len()).unwrap();
        assert_eq!(decoded, bits);
    }

    /// Pseudo-random mask where a bit is set when the generator state is a
    /// multiple of 100.
    #[test]
    fn random_sparse_mask() {
        let bits = xorshift_mask(0xABCDEF0123456789, 10_000, |state| {
            state.is_multiple_of(100)
        });
        assert_round_trip(&bits);
    }

    /// Pseudo-random mask where a bit is set when the generator state is odd.
    #[test]
    fn random_dense_mask() {
        let bits = xorshift_mask(0xFEDCBA9876543210, 10_000, |state| (state & 1) == 1);
        assert_round_trip(&bits);
    }

    /// Decoding with a declared length smaller than a stored key is rejected
    /// as malformed.
    #[test]
    fn decode_key_out_of_range_rejected() {
        let mut bits = vec![false; 11];
        bits[10] = true;
        let encoded = encode(&bits).unwrap();
        let outcome = decode(&encoded, 5);
        assert!(matches!(outcome.unwrap_err(), MaskError::Malformed(_)));
    }

    /// Bytes that are not a serialized bitmap surface as a roaring error.
    #[test]
    fn decode_garbage_input_rejected() {
        let garbage = b"this is not a roaring bitmap";
        let outcome = decode(garbage, 100);
        assert!(matches!(outcome.unwrap_err(), MaskError::Roaring(_)));
    }
}