use super::{Bitmask, MaskError};
/// Run-length encodes a sequence of bits.
///
/// The output starts with one byte holding the value of the first bit
/// (`1` for `true`, `0` for `false`), followed by the length of every run of
/// equal consecutive bits, each written as a ULEB128 integer. Because runs
/// alternate in value, only the first value is stored explicitly.
///
/// An empty input produces an empty output (no start byte at all).
pub fn encode(bits: &[bool]) -> Vec<u8> {
    let Some((&first, tail)) = bits.split_first() else {
        return Vec::new();
    };

    // Initial capacity is only a guess; the vector grows if more is needed.
    let mut encoded = Vec::with_capacity(1 + bits.len() / 4);
    encoded.push(u8::from(first));

    let mut run_value = first;
    let mut run_len: u64 = 1;
    for &bit in tail {
        if bit != run_value {
            // The run just ended: flush its length and start counting the
            // next run from zero (the increment below accounts for `bit`).
            write_uleb128(&mut encoded, run_len);
            run_value = bit;
            run_len = 0;
        }
        run_len += 1;
    }

    // The final run is never terminated by a value change, so flush it here.
    write_uleb128(&mut encoded, run_len);
    encoded
}
/// Decodes a run-length-encoded payload into a mask of exactly `n_elements`
/// bits.
///
/// The payload layout is one start byte (`0x00` or `0x01`) giving the value
/// of the first run, followed by one ULEB128 run length per run; the value
/// flips after every run. A zero-element mask must have an empty payload.
///
/// # Errors
///
/// Returns [`MaskError::Rle`] when:
/// - `n_elements` is zero but `bytes` is not empty, or vice versa;
/// - the start byte is neither `0x00` nor `0x01`;
/// - a run length is zero, does not fit in `usize`, or is malformed ULEB128;
/// - the runs cover more or fewer than `n_elements` bits (this also rejects
///   trailing bytes after the mask is complete).
///
/// Any error returned by `super::try_reserve_mask` is propagated unchanged.
pub fn decode(bytes: &[u8], n_elements: usize) -> Result<Bitmask, MaskError> {
    if n_elements == 0 {
        if !bytes.is_empty() {
            return Err(MaskError::Rle(format!(
                "zero-element mask must have empty payload; got {} bytes",
                bytes.len()
            )));
        }
        return Ok(Vec::new());
    }

    let Some((&start_byte, mut rest)) = bytes.split_first() else {
        return Err(MaskError::Rle(format!(
            "empty RLE payload but {n_elements} elements declared"
        )));
    };

    let mut value = match start_byte {
        0 => false,
        1 => true,
        other => {
            return Err(MaskError::Rle(format!(
                "invalid start_bit byte {other:#04x} (must be 0x00 or 0x01)"
            )));
        }
    };

    // Space for the declared element count is requested through the parent
    // module's helper before any run is expanded.
    let mut mask: Vec<bool> = Vec::new();
    super::try_reserve_mask(&mut mask, n_elements)?;

    while !rest.is_empty() {
        let (run, consumed) = read_uleb128(rest)?;
        rest = &rest[consumed..];

        if run == 0 {
            return Err(MaskError::Rle(
                "zero-length run is not permitted by the RLE format".to_string(),
            ));
        }
        let Ok(run_usize) = usize::try_from(run) else {
            return Err(MaskError::Rle(format!(
                "run count {run} exceeds usize — malformed or truncated"
            )));
        };

        // `decoded` never exceeds `n_elements` (enforced right here on every
        // iteration), so the subtraction cannot underflow.
        let decoded = mask.len();
        if run_usize > n_elements - decoded {
            return Err(MaskError::Rle(format!(
                "run overruns element count: decoded {}..{} but only {n_elements} declared",
                decoded,
                decoded.saturating_add(run_usize)
            )));
        }

        mask.extend(std::iter::repeat_n(value, run_usize));
        value = !value;
    }

    if mask.len() != n_elements {
        return Err(MaskError::Rle(format!(
            "RLE decoded {} elements but {n_elements} declared",
            mask.len()
        )));
    }
    Ok(mask)
}
/// Appends `value` to `out` as an unsigned LEB128 integer.
///
/// The value is emitted in 7-bit groups, least significant group first.
/// Every group except the last has its high bit (`0x80`) set to signal that
/// more groups follow. Zero is written as the single byte `0x00`.
fn write_uleb128(out: &mut Vec<u8>, mut value: u64) {
    // While more than seven significant bits remain, emit a continuation group.
    while value >= 0x80 {
        // Truncating cast is intentional: only the low seven bits are kept.
        out.push((value as u8 & 0x7F) | 0x80);
        value >>= 7;
    }
    // `value` is now below 0x80, so the cast is lossless and the high bit is
    // clear, marking the final group.
    out.push(value as u8);
}
/// Decodes one unsigned LEB128 integer from the front of `bytes`.
///
/// On success returns the decoded value together with the number of bytes
/// consumed; any bytes after the terminating group are left untouched.
///
/// # Errors
///
/// Returns [`MaskError::Rle`] when:
/// - the encoding does not fit in a `u64`, either because it is longer than
///   ten groups or because the tenth group carries bits above bit 63;
/// - `bytes` ends (or is empty) before a group with a clear high bit is seen.
fn read_uleb128(bytes: &[u8]) -> Result<(u64, usize), MaskError> {
    // A u64 holds at most ceil(64 / 7) = 10 seven-bit groups.
    const MAX_GROUPS: usize = 10;

    let overflow =
        || MaskError::Rle("ULEB128 integer overflows u64 — malformed payload".to_string());

    let mut value: u64 = 0;
    let mut shift: u32 = 0;
    for (i, &b) in bytes.iter().enumerate() {
        if i >= MAX_GROUPS {
            return Err(overflow());
        }
        let chunk = u64::from(b & 0x7F);
        // `shift` is at most 63 here, so the shift itself is always valid,
        // but the tenth group lands on bit 63 and only its lowest bit fits.
        // Shifting back and comparing detects payload bits that would be
        // silently discarded, which would otherwise let an out-of-range
        // encoding decode to a wrong (truncated) value.
        let shifted = chunk << shift;
        if shifted >> shift != chunk {
            return Err(overflow());
        }
        value |= shifted;
        if (b & 0x80) == 0 {
            return Ok((value, i + 1));
        }
        shift += 7;
    }
    Err(MaskError::Rle(
        "truncated ULEB128: missing terminator byte".to_string(),
    ))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// State of the small deterministic generator used by the sweep test.
    type State = u64;

    /// Creates generator state from `seed`, substituting 1 for 0 because a
    /// zero state would stay zero under the xor/shift steps in `rng_next`.
    fn rand_state(seed: u64) -> State {
        if seed == 0 {
            1
        } else {
            seed
        }
    }

    /// Advances the xorshift-style generator and returns the new state.
    fn rng_next(s: &mut State) -> u64 {
        let mut x = *s;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        *s = x;
        x
    }

    /// True when `decode` rejects the payload with the RLE error variant.
    fn rejected(bytes: &[u8], n_elements: usize) -> bool {
        matches!(decode(bytes, n_elements), Err(MaskError::Rle(_)))
    }

    /// An empty mask encodes to an empty payload.
    #[test]
    fn encode_empty() {
        assert!(encode(&[]).is_empty());
    }

    /// An empty payload with zero declared elements decodes to an empty mask.
    #[test]
    fn decode_empty() {
        let mask = decode(&[], 0).unwrap();
        assert!(mask.is_empty());
    }

    /// An empty payload cannot satisfy a non-zero declared element count.
    #[test]
    fn decode_empty_mismatched() {
        assert!(rejected(&[], 10));
    }

    /// A single `false` bit: start byte 0, one run of length 1.
    #[test]
    fn single_bit_zero() {
        let payload = encode(&[false]);
        assert_eq!(payload, vec![0x00, 0x01]);
        assert_eq!(decode(&payload, 1).unwrap(), vec![false]);
    }

    /// A single `true` bit: start byte 1, one run of length 1.
    #[test]
    fn single_bit_one() {
        let payload = encode(&[true]);
        assert_eq!(payload, vec![0x01, 0x01]);
        assert_eq!(decode(&payload, 1).unwrap(), vec![true]);
    }

    /// 1000 zeros collapse to the start byte plus a two-byte ULEB128 length.
    #[test]
    fn all_zeros_encodes_as_single_run() {
        let mask = vec![false; 1000];
        let payload = encode(&mask);
        assert_eq!(payload, vec![0x00, 0xE8, 0x07]);
        assert_eq!(decode(&payload, mask.len()).unwrap(), mask);
    }

    /// 64 ones collapse to the start byte plus a one-byte ULEB128 length.
    #[test]
    fn all_ones_encodes_as_single_run() {
        let mask = vec![true; 64];
        let payload = encode(&mask);
        assert_eq!(payload, vec![0x01, 0x40]);
        assert_eq!(decode(&payload, mask.len()).unwrap(), mask);
    }

    /// Alternating bits produce one run per bit: start byte + 8 length bytes.
    #[test]
    fn alternating_bits_is_worst_case() {
        let mask: Vec<bool> = [true, false].into_iter().cycle().take(8).collect();
        let payload = encode(&mask);
        assert_eq!(payload.len(), 9);
        assert_eq!(decode(&payload, mask.len()).unwrap(), mask);
    }

    /// A mask made of a few long runs compresses to a handful of bytes.
    #[test]
    fn clustered_mask_round_trips() {
        let mut mask = Vec::new();
        for (value, len) in [(false, 100), (true, 50), (false, 200), (true, 50)] {
            mask.extend(std::iter::repeat_n(value, len));
        }
        let payload = encode(&mask);
        assert!(
            payload.len() < 10,
            "clustered mask should be tiny: got {} bytes",
            payload.len()
        );
        assert_eq!(decode(&payload, mask.len()).unwrap(), mask);
    }

    /// Mixed short and long runs (including multi-byte lengths) round-trip.
    #[test]
    fn large_mask_roundtrip() {
        let run_lengths = [50, 100, 5, 1000, 200, 7, 42, 98765];
        let mut mask = Vec::new();
        // Runs alternate starting with `false`, so odd positions are `true`.
        for (idx, len) in run_lengths.into_iter().enumerate() {
            mask.extend(std::iter::repeat_n(idx % 2 == 1, len));
        }
        let payload = encode(&mask);
        assert_eq!(decode(&payload, mask.len()).unwrap(), mask);
    }

    /// Only 0x00 and 0x01 are valid start bytes.
    #[test]
    fn decode_invalid_start_bit() {
        assert!(rejected(&[0x42, 0x01], 1));
    }

    /// A run length whose last byte still has the continuation bit set.
    #[test]
    fn decode_truncated_varint() {
        assert!(rejected(&[0x00, 0x80], 1));
    }

    /// A run of 10 cannot fit into a 3-element mask.
    #[test]
    fn decode_run_overruns_declared_count() {
        assert!(rejected(&[0x00, 0x0A], 3));
    }

    /// Runs totalling 3 bits cannot fill a 5-element mask.
    #[test]
    fn decode_run_undershoots_declared_count() {
        let payload = vec![0x00, 0x02, 0x01];
        assert!(rejected(&payload, 5));
    }

    /// Values around ULEB128 byte-count boundaries survive write then read.
    #[test]
    fn uleb128_multi_byte_boundary() {
        let cases = [
            0x80_u64,
            0x3FFF,
            0x4000,
            0xFFFF,
            0x10_0000,
            u64::from(u32::MAX),
            u64::MAX,
        ];
        for expected in cases {
            let mut encoded = Vec::new();
            write_uleb128(&mut encoded, expected);
            let (decoded, used) = read_uleb128(&encoded).unwrap();
            assert_eq!(decoded, expected);
            assert_eq!(used, encoded.len());
        }
    }

    /// Pseudo-random masks of assorted sizes round-trip unchanged.
    #[test]
    fn sweep_random_masks_round_trip() {
        let mut rng = rand_state(0xDEADBEEF_CAFEBABE);
        for size in [0, 1, 2, 7, 8, 9, 63, 64, 100, 1_000, 10_000] {
            let mut mask: Vec<bool> = Vec::with_capacity(size);
            for _ in 0..size {
                mask.push(rng_next(&mut rng) % 2 == 1);
            }
            let payload = encode(&mask);
            let restored = decode(&payload, size).unwrap();
            assert_eq!(restored, mask, "size={size}");
        }
    }
}