use super::BitMap;
use bytes::{Buf, BufMut};
use commonware_codec::{EncodeSize, Error as CodecError, Read, ReadExt, Write};
use thiserror::Error;
/// Errors produced when constructing a [`Prunable`].
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum Error {
/// The requested number of pruned chunks, multiplied by the chunk size in
/// bits, does not fit in a `u64`. Returned by
/// [`Prunable::new_with_pruned_chunks`].
#[error("pruned_chunks * CHUNK_SIZE_BITS overflows u64")]
PrunedChunksOverflow,
}
/// A bitmap of `N`-byte chunks that remembers how many leading chunks have
/// been pruned.
///
/// Bit and chunk indices accepted by the public API are absolute: they count
/// pruned chunks as well, and are translated to indices into the retained
/// `bitmap` by subtracting the pruned prefix.
#[derive(Clone, Debug)]
pub struct Prunable<const N: usize> {
// Bits that are still retained (everything after the pruned prefix).
bitmap: BitMap<N>,
// Number of whole chunks pruned from the front. The constructors and the
// decoder check that `pruned_chunks * CHUNK_SIZE_BITS` fits in a `u64`.
pruned_chunks: usize,
}
impl<const N: usize> Prunable<N> {
pub const CHUNK_SIZE_BITS: u64 = BitMap::<N>::CHUNK_SIZE_BITS;
pub const fn new() -> Self {
Self {
bitmap: BitMap::new(),
pruned_chunks: 0,
}
}
pub fn new_with_pruned_chunks(pruned_chunks: usize) -> Result<Self, Error> {
let pruned_chunks_u64 = pruned_chunks as u64;
pruned_chunks_u64
.checked_mul(Self::CHUNK_SIZE_BITS)
.ok_or(Error::PrunedChunksOverflow)?;
Ok(Self {
bitmap: BitMap::new(),
pruned_chunks,
})
}
#[inline]
pub const fn len(&self) -> u64 {
let pruned_bits = (self.pruned_chunks as u64)
.checked_mul(Self::CHUNK_SIZE_BITS)
.expect("invariant violated: pruned_chunks * CHUNK_SIZE_BITS overflows u64");
pruned_bits
.checked_add(self.bitmap.len())
.expect("invariant violated: pruned_bits + bitmap.len() overflows u64")
}
#[inline]
pub const fn is_empty(&self) -> bool {
self.len() == 0
}
#[inline]
pub const fn is_chunk_aligned(&self) -> bool {
self.len().is_multiple_of(Self::CHUNK_SIZE_BITS)
}
#[inline]
pub fn chunks_len(&self) -> usize {
self.pruned_chunks + self.bitmap.chunks_len()
}
#[inline]
pub const fn pruned_chunks(&self) -> usize {
self.pruned_chunks
}
#[inline]
pub fn complete_chunks(&self) -> usize {
self.pruned_chunks
+ self
.bitmap
.chunks_len()
.saturating_sub(if self.is_chunk_aligned() { 0 } else { 1 })
}
#[inline]
pub const fn pruned_bits(&self) -> u64 {
(self.pruned_chunks as u64)
.checked_mul(Self::CHUNK_SIZE_BITS)
.expect("invariant violated: pruned_chunks * CHUNK_SIZE_BITS overflows u64")
}
#[inline]
pub fn get_bit(&self, bit: u64) -> bool {
let chunk_num = Self::to_chunk_index(bit);
assert!(chunk_num >= self.pruned_chunks, "bit pruned: {bit}");
self.bitmap.get(bit - self.pruned_bits())
}
#[inline]
pub fn get_chunk_containing(&self, bit: u64) -> &[u8; N] {
let chunk_num = Self::to_chunk_index(bit);
assert!(chunk_num >= self.pruned_chunks, "bit pruned: {bit}");
self.bitmap.get_chunk_containing(bit - self.pruned_bits())
}
#[inline]
pub const fn get_bit_from_chunk(chunk: &[u8; N], bit: u64) -> bool {
BitMap::<N>::get_bit_from_chunk(chunk, bit)
}
#[inline]
pub fn last_chunk(&self) -> (&[u8; N], u64) {
self.bitmap.last_chunk()
}
pub fn set_bit(&mut self, bit: u64, value: bool) {
let chunk_num = Self::to_chunk_index(bit);
assert!(chunk_num >= self.pruned_chunks, "bit pruned: {bit}");
self.bitmap.set(bit - self.pruned_bits(), value);
}
pub fn push(&mut self, bit: bool) {
self.bitmap.push(bit);
}
pub fn pop(&mut self) -> bool {
self.bitmap.pop()
}
pub fn push_byte(&mut self, byte: u8) {
self.bitmap.push_byte(byte);
}
pub fn push_chunk(&mut self, chunk: &[u8; N]) {
self.bitmap.push_chunk(chunk);
}
pub fn pop_chunk(&mut self) -> [u8; N] {
self.bitmap.pop_chunk()
}
pub fn prune_to_bit(&mut self, bit: u64) {
assert!(
bit <= self.len(),
"bit {} out of bounds (len: {})",
bit,
self.len()
);
let chunk = Self::to_chunk_index(bit);
if chunk < self.pruned_chunks {
return;
}
let chunks_to_prune = chunk - self.pruned_chunks;
self.bitmap.prune_chunks(chunks_to_prune);
self.pruned_chunks = chunk;
}
#[inline]
pub const fn chunk_byte_bitmask(bit: u64) -> u8 {
BitMap::<N>::chunk_byte_bitmask(bit)
}
#[inline]
pub const fn chunk_byte_offset(bit: u64) -> usize {
BitMap::<N>::chunk_byte_offset(bit)
}
#[inline]
pub fn to_chunk_index(bit: u64) -> usize {
BitMap::<N>::to_chunk_index(bit)
}
#[inline]
pub fn get_chunk(&self, chunk: usize) -> &[u8; N] {
assert!(
chunk >= self.pruned_chunks,
"chunk {chunk} is pruned (pruned_chunks: {})",
self.pruned_chunks
);
self.bitmap.get_chunk(chunk - self.pruned_chunks)
}
pub(super) fn set_chunk_by_index(&mut self, chunk_index: usize, chunk_data: &[u8; N]) {
assert!(
chunk_index >= self.pruned_chunks,
"cannot set pruned chunk {chunk_index} (pruned_chunks: {})",
self.pruned_chunks
);
let bitmap_chunk_idx = chunk_index - self.pruned_chunks;
self.bitmap.set_chunk_by_index(bitmap_chunk_idx, chunk_data);
}
pub(super) fn unprune_chunks(&mut self, chunks: &[[u8; N]]) {
assert!(
chunks.len() <= self.pruned_chunks,
"cannot unprune {} chunks (only {} pruned)",
chunks.len(),
self.pruned_chunks
);
for chunk in chunks.iter() {
self.bitmap.prepend_chunk(chunk);
}
self.pruned_chunks -= chunks.len();
}
}
impl<const N: usize> Default for Prunable<N> {
fn default() -> Self {
Self::new()
}
}
impl<const N: usize> Write for Prunable<N> {
    /// Writes the pruned chunk count (widened to `u64`) followed by the
    /// retained bitmap.
    fn write(&self, buf: &mut impl BufMut) {
        // Widen to u64 so the encoding does not depend on the platform's usize width.
        let pruned_chunks = self.pruned_chunks as u64;
        pruned_chunks.write(buf);
        self.bitmap.write(buf);
    }
}
impl<const N: usize> Read for Prunable<N> {
    // Forwarded unchanged to the `BitMap` decoder.
    // NOTE(review): presumably a cap on the decoded bitmap's bit length - confirm.
    type Cfg = u64;

    /// Decodes a pruned chunk count followed by the retained bitmap.
    ///
    /// # Errors
    ///
    /// Returns `CodecError::Invalid` when the pruned chunk count overflows
    /// when converted to bits, does not fit in a `usize`, or when the total
    /// length (pruned plus retained) overflows a `u64`. Errors from the
    /// underlying reads are propagated unchanged.
    fn read_cfg(buf: &mut impl Buf, max_len: &Self::Cfg) -> Result<Self, CodecError> {
        let raw_pruned = u64::read(buf)?;

        // Reject counts whose bit total cannot be represented.
        let Some(pruned_bits) = raw_pruned.checked_mul(Self::CHUNK_SIZE_BITS) else {
            return Err(CodecError::Invalid(
                "Prunable",
                "pruned_chunks would overflow when computing pruned_bits",
            ));
        };

        // The count is stored as usize in memory, so it must fit on this platform.
        let Ok(pruned_chunks) = usize::try_from(raw_pruned) else {
            return Err(CodecError::Invalid(
                "Prunable",
                "pruned_chunks doesn't fit in usize",
            ));
        };

        let bitmap = BitMap::<N>::read_cfg(buf, max_len)?;

        // `len()` adds these two values, so the sum must be representable.
        if pruned_bits.checked_add(bitmap.len()).is_none() {
            return Err(CodecError::Invalid(
                "Prunable",
                "total bitmap length (pruned + unpruned) would overflow u64",
            ));
        }

        Ok(Self {
            bitmap,
            pruned_chunks,
        })
    }
}
impl<const N: usize> EncodeSize for Prunable<N> {
    /// Encoded size: the pruned chunk count (as `u64`) plus the retained
    /// bitmap, mirroring the order used by `write`.
    fn encode_size(&self) -> usize {
        let count_size = (self.pruned_chunks as u64).encode_size();
        let bitmap_size = self.bitmap.encode_size();
        count_size + bitmap_size
    }
}
#[cfg(feature = "arbitrary")]
impl<const N: usize> arbitrary::Arbitrary<'_> for Prunable<N> {
    /// Builds an arbitrary unpruned bitmap, then prunes it to an arbitrary
    /// bit position in `0..=len`.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        // Generate the inner bitmap first so the input bytes are consumed in a stable order.
        let inner = BitMap::<N>::arbitrary(u)?;
        let mut prunable = Self {
            bitmap: inner,
            pruned_chunks: 0,
        };
        // Any position up to and including the length is a valid pruning target.
        let prune_to = u.int_in_range(0..=prunable.len())?;
        prunable.prune_to_bit(prune_to);
        Ok(prunable)
    }
}
// Unit tests for `Prunable`: construction, bit/byte/chunk mutation, pruning,
// codec round-trips, and chunk-alignment tracking.
#[cfg(test)]
mod tests {
use super::*;
use crate::hex;
use bytes::BytesMut;
use commonware_codec::Encode;
// A freshly created bitmap is empty and has nothing pruned.
#[test]
fn test_new() {
let prunable: Prunable<32> = Prunable::new();
assert_eq!(prunable.len(), 0);
assert_eq!(prunable.pruned_bits(), 0);
assert_eq!(prunable.pruned_chunks(), 0);
assert!(prunable.is_empty());
assert_eq!(prunable.chunks_len(), 0); }
// Starting with one pruned 2-byte chunk: the pruned bits count toward the
// length and the chunk count even though no bits are retained.
#[test]
fn test_new_with_pruned_chunks() {
let prunable: Prunable<2> = Prunable::new_with_pruned_chunks(1).unwrap();
assert_eq!(prunable.len(), 16);
assert_eq!(prunable.pruned_bits(), 16);
assert_eq!(prunable.pruned_chunks(), 1);
assert_eq!(prunable.chunks_len(), 1);
}
// One more chunk than fits in u64 bits must be rejected by the constructor.
#[test]
fn test_new_with_pruned_chunks_overflow() {
let overflowing_pruned_chunks = (u64::MAX / Prunable::<4>::CHUNK_SIZE_BITS) as usize + 1;
let result = Prunable::<4>::new_with_pruned_chunks(overflowing_pruned_chunks);
assert!(matches!(result, Err(Error::PrunedChunksOverflow)));
}
// Pushed bits are readable back at their absolute indices.
#[test]
fn test_push_and_get_bits() {
let mut prunable: Prunable<4> = Prunable::new();
prunable.push(true);
prunable.push(false);
prunable.push(true);
assert_eq!(prunable.len(), 3);
assert!(!prunable.is_empty());
assert!(prunable.get_bit(0));
assert!(!prunable.get_bit(1));
assert!(prunable.get_bit(2));
}
// Pushing a byte appends eight bits: 0xFF sets all of them, 0x00 none.
#[test]
fn test_push_byte() {
let mut prunable: Prunable<4> = Prunable::new();
prunable.push_byte(0xFF);
assert_eq!(prunable.len(), 8);
for i in 0..8 {
assert!(prunable.get_bit(i as u64));
}
prunable.push_byte(0x00);
assert_eq!(prunable.len(), 16);
for i in 8..16 {
assert!(!prunable.get_bit(i as u64));
}
}
// A pushed chunk is stored verbatim and retrievable via any bit it contains.
#[test]
fn test_push_chunk() {
let mut prunable: Prunable<4> = Prunable::new();
let chunk = hex!("0xAABBCCDD");
prunable.push_chunk(&chunk);
assert_eq!(prunable.len(), 32);
let retrieved_chunk = prunable.get_chunk_containing(0);
assert_eq!(retrieved_chunk, &chunk);
}
// A bit can be toggled on and back off in place.
#[test]
fn test_set_bit() {
let mut prunable: Prunable<4> = Prunable::new();
prunable.push(false);
prunable.push(false);
prunable.push(false);
assert!(!prunable.get_bit(1));
prunable.set_bit(1, true);
assert!(prunable.get_bit(1));
prunable.set_bit(1, false);
assert!(!prunable.get_bit(1));
}
// Pruning advances the pruned counters but leaves `len()` unchanged, and
// retained chunks stay addressable by their absolute bit indices.
#[test]
fn test_pruning_basic() {
let mut prunable: Prunable<4> = Prunable::new();
let chunk1 = hex!("0x01020304");
let chunk2 = hex!("0x05060708");
let chunk3 = hex!("0x090A0B0C");
prunable.push_chunk(&chunk1);
prunable.push_chunk(&chunk2);
prunable.push_chunk(&chunk3);
assert_eq!(prunable.len(), 96); assert_eq!(prunable.pruned_chunks(), 0);
// Prune the first chunk.
prunable.prune_to_bit(32);
assert_eq!(prunable.pruned_chunks(), 1);
assert_eq!(prunable.pruned_bits(), 32);
assert_eq!(prunable.len(), 96);
assert_eq!(prunable.get_chunk_containing(32), &chunk2);
assert_eq!(prunable.get_chunk_containing(64), &chunk3);
// Prune the second chunk as well.
prunable.prune_to_bit(64);
assert_eq!(prunable.pruned_chunks(), 2);
assert_eq!(prunable.pruned_bits(), 64);
assert_eq!(prunable.len(), 96);
assert_eq!(prunable.get_chunk_containing(64), &chunk3);
}
// Reading a bit inside a pruned chunk panics.
#[test]
#[should_panic(expected = "bit pruned")]
fn test_get_pruned_bit_panics() {
let mut prunable: Prunable<4> = Prunable::new();
prunable.push_chunk(&[1, 2, 3, 4]);
prunable.push_chunk(&[5, 6, 7, 8]);
prunable.prune_to_bit(32);
prunable.get_bit(0);
}
// Fetching the chunk of a pruned bit panics.
#[test]
#[should_panic(expected = "bit pruned")]
fn test_get_pruned_chunk_panics() {
let mut prunable: Prunable<4> = Prunable::new();
prunable.push_chunk(&[1, 2, 3, 4]);
prunable.push_chunk(&[5, 6, 7, 8]);
prunable.prune_to_bit(32);
prunable.get_chunk_containing(0);
}
// Writing a bit inside a pruned chunk panics.
#[test]
#[should_panic(expected = "bit pruned")]
fn test_set_pruned_bit_panics() {
let mut prunable: Prunable<4> = Prunable::new();
prunable.push_chunk(&[1, 2, 3, 4]);
prunable.push_chunk(&[5, 6, 7, 8]);
prunable.prune_to_bit(32);
prunable.set_bit(0, true);
}
// Pruning past the end (bit 25 of a 24-bit bitmap) panics with the bounds message.
#[test]
#[should_panic(expected = "bit 25 out of bounds (len: 24)")]
fn test_prune_to_bit_out_of_bounds() {
let mut prunable: Prunable<1> = Prunable::new();
prunable.push_byte(1);
prunable.push_byte(2);
prunable.push_byte(3);
prunable.prune_to_bit(25);
}
// Pruning leaves a trailing partially filled chunk intact and readable.
#[test]
fn test_pruning_with_partial_chunk() {
let mut prunable: Prunable<4> = Prunable::new();
prunable.push_chunk(&[0xFF; 4]);
prunable.push_chunk(&[0xAA; 4]);
prunable.push(true);
prunable.push(false);
prunable.push(true);
assert_eq!(prunable.len(), 67);
prunable.prune_to_bit(32);
assert_eq!(prunable.pruned_chunks(), 1);
assert_eq!(prunable.len(), 67);
assert!(prunable.get_bit(64));
assert!(!prunable.get_bit(65));
assert!(prunable.get_bit(66));
}
// Pruning to the current boundary, or to an earlier bit, changes nothing.
#[test]
fn test_prune_idempotent() {
let mut prunable: Prunable<4> = Prunable::new();
prunable.push_chunk(&[1, 2, 3, 4]);
prunable.push_chunk(&[5, 6, 7, 8]);
prunable.prune_to_bit(32);
assert_eq!(prunable.pruned_chunks(), 1);
prunable.prune_to_bit(32);
assert_eq!(prunable.pruned_chunks(), 1);
prunable.prune_to_bit(16);
assert_eq!(prunable.pruned_chunks(), 1);
}
// Appending after pruning keeps absolute indexing consistent.
#[test]
fn test_push_after_pruning() {
let mut prunable: Prunable<4> = Prunable::new();
prunable.push_chunk(&[1, 2, 3, 4]);
prunable.push_chunk(&[5, 6, 7, 8]);
prunable.prune_to_bit(32);
assert_eq!(prunable.len(), 64);
assert_eq!(prunable.pruned_chunks(), 1);
prunable.push_chunk(&[9, 10, 11, 12]);
assert_eq!(prunable.len(), 96);
assert_eq!(prunable.get_chunk_containing(64), &[9, 10, 11, 12]);
}
// Index helpers for 4-byte chunks: 32 bits per chunk, byte offset wraps at
// the chunk boundary, and the bitmask selects the bit within its byte.
#[test]
fn test_chunk_calculations() {
assert_eq!(Prunable::<4>::to_chunk_index(0), 0);
assert_eq!(Prunable::<4>::to_chunk_index(31), 0);
assert_eq!(Prunable::<4>::to_chunk_index(32), 1);
assert_eq!(Prunable::<4>::to_chunk_index(63), 1);
assert_eq!(Prunable::<4>::to_chunk_index(64), 2);
assert_eq!(Prunable::<4>::chunk_byte_offset(0), 0);
assert_eq!(Prunable::<4>::chunk_byte_offset(8), 1);
assert_eq!(Prunable::<4>::chunk_byte_offset(16), 2);
assert_eq!(Prunable::<4>::chunk_byte_offset(24), 3);
assert_eq!(Prunable::<4>::chunk_byte_offset(32), 0);
assert_eq!(Prunable::<4>::chunk_byte_bitmask(0), 0b00000001);
assert_eq!(Prunable::<4>::chunk_byte_bitmask(1), 0b00000010);
assert_eq!(Prunable::<4>::chunk_byte_bitmask(7), 0b10000000);
assert_eq!(Prunable::<4>::chunk_byte_bitmask(8), 0b00000001); }
// Pruning earlier chunks does not alter what `last_chunk` reports.
#[test]
fn test_last_chunk_with_pruning() {
let mut prunable: Prunable<4> = Prunable::new();
prunable.push_chunk(&[1, 2, 3, 4]);
prunable.push_chunk(&[5, 6, 7, 8]);
prunable.push(true);
prunable.push(false);
let (_, next_bit) = prunable.last_chunk();
assert_eq!(next_bit, 2);
// Copy the chunk out so it can be compared after the mutable prune call.
let chunk_data = *prunable.last_chunk().0;
prunable.prune_to_bit(32);
let (chunk2, next_bit2) = prunable.last_chunk();
assert_eq!(next_bit2, 2);
assert_eq!(&chunk_data, chunk2);
}
// The same bit pattern behaves identically for 8-, 16- and 32-byte chunks.
#[test]
fn test_different_chunk_sizes() {
let mut p8: Prunable<8> = Prunable::new();
let mut p16: Prunable<16> = Prunable::new();
let mut p32: Prunable<32> = Prunable::new();
for i in 0..10 {
p8.push(i % 2 == 0);
p16.push(i % 2 == 0);
p32.push(i % 2 == 0);
}
assert_eq!(p8.len(), 10);
assert_eq!(p16.len(), 10);
assert_eq!(p32.len(), 10);
for i in 0..10 {
let expected = i % 2 == 0;
if expected {
assert!(p8.get_bit(i));
assert!(p16.get_bit(i));
assert!(p32.get_bit(i));
} else {
assert!(!p8.get_bit(i));
assert!(!p16.get_bit(i));
assert!(!p32.get_bit(i));
}
}
}
// Bits are read least-significant-first within each byte of a chunk.
#[test]
fn test_get_bit_from_chunk() {
let chunk: [u8; 4] = [0b10101010, 0b11001100, 0b11110000, 0b00001111];
assert!(!Prunable::<4>::get_bit_from_chunk(&chunk, 0));
assert!(Prunable::<4>::get_bit_from_chunk(&chunk, 1));
assert!(!Prunable::<4>::get_bit_from_chunk(&chunk, 2));
assert!(Prunable::<4>::get_bit_from_chunk(&chunk, 3));
assert!(!Prunable::<4>::get_bit_from_chunk(&chunk, 8));
assert!(!Prunable::<4>::get_bit_from_chunk(&chunk, 9));
assert!(Prunable::<4>::get_bit_from_chunk(&chunk, 10));
assert!(Prunable::<4>::get_bit_from_chunk(&chunk, 11));
}
// `get_chunk` takes absolute chunk indices, before and after pruning.
#[test]
fn test_get_chunk() {
let mut prunable: Prunable<4> = Prunable::new();
let chunk1 = hex!("0x11223344");
let chunk2 = hex!("0x55667788");
let chunk3 = hex!("0x99AABBCC");
prunable.push_chunk(&chunk1);
prunable.push_chunk(&chunk2);
prunable.push_chunk(&chunk3);
assert_eq!(prunable.get_chunk(0), &chunk1);
assert_eq!(prunable.get_chunk(1), &chunk2);
assert_eq!(prunable.get_chunk(2), &chunk3);
prunable.prune_to_bit(32);
assert_eq!(prunable.get_chunk(1), &chunk2);
assert_eq!(prunable.get_chunk(2), &chunk3);
}
// `pop` returns bits in reverse push order and shrinks the length by one each time.
#[test]
fn test_pop() {
let mut prunable: Prunable<4> = Prunable::new();
prunable.push(true);
prunable.push(false);
prunable.push(true);
assert_eq!(prunable.len(), 3);
assert!(prunable.pop());
assert_eq!(prunable.len(), 2);
assert!(!prunable.pop());
assert_eq!(prunable.len(), 1);
assert!(prunable.pop());
assert_eq!(prunable.len(), 0);
assert!(prunable.is_empty());
// Repeat across several chunk boundaries.
for i in 0..100 {
prunable.push(i % 3 == 0);
}
assert_eq!(prunable.len(), 100);
for i in (0..100).rev() {
let expected = i % 3 == 0;
assert_eq!(prunable.pop(), expected);
assert_eq!(prunable.len(), i);
}
assert!(prunable.is_empty());
}
// `pop_chunk` removes whole chunks in reverse push order and leaves earlier
// chunks untouched.
#[test]
fn test_pop_chunk() {
let mut prunable: Prunable<4> = Prunable::new();
const CHUNK_SIZE: u64 = Prunable::<4>::CHUNK_SIZE_BITS;
// Single chunk: push then pop.
let chunk1 = hex!("0xAABBCCDD");
prunable.push_chunk(&chunk1);
assert_eq!(prunable.len(), CHUNK_SIZE);
let popped = prunable.pop_chunk();
assert_eq!(popped, chunk1);
assert_eq!(prunable.len(), 0);
assert!(prunable.is_empty());
// Several chunks: popped last-in, first-out.
let chunk2 = hex!("0x11223344");
let chunk3 = hex!("0x55667788");
let chunk4 = hex!("0x99AABBCC");
prunable.push_chunk(&chunk2);
prunable.push_chunk(&chunk3);
prunable.push_chunk(&chunk4);
assert_eq!(prunable.len(), CHUNK_SIZE * 3);
assert_eq!(prunable.pop_chunk(), chunk4);
assert_eq!(prunable.len(), CHUNK_SIZE * 2);
assert_eq!(prunable.pop_chunk(), chunk3);
assert_eq!(prunable.len(), CHUNK_SIZE);
assert_eq!(prunable.pop_chunk(), chunk2);
assert_eq!(prunable.len(), 0);
// Popping the second chunk must not disturb the bits of the first.
prunable = Prunable::new();
let first_chunk = hex!("0xAABBCCDD");
let second_chunk = hex!("0x11223344");
prunable.push_chunk(&first_chunk);
prunable.push_chunk(&second_chunk);
assert_eq!(prunable.pop_chunk(), second_chunk);
assert_eq!(prunable.len(), CHUNK_SIZE);
for i in 0..CHUNK_SIZE {
let byte_idx = (i / 8) as usize;
let bit_idx = i % 8;
let expected = (first_chunk[byte_idx] >> bit_idx) & 1 == 1;
assert_eq!(prunable.get_bit(i), expected);
}
assert_eq!(prunable.pop_chunk(), first_chunk);
assert_eq!(prunable.len(), 0);
}
// Popping a chunk while a partial trailing chunk exists panics.
#[test]
#[should_panic(expected = "cannot pop chunk when not chunk aligned")]
fn test_pop_chunk_not_aligned() {
let mut prunable: Prunable<4> = Prunable::new();
prunable.push_chunk(&[0xFF; 4]);
prunable.push(true);
prunable.pop_chunk();
}
// Popping a chunk from a bitmap holding less than one chunk panics.
#[test]
#[should_panic(expected = "cannot pop chunk: bitmap has fewer than CHUNK_SIZE_BITS bits")]
fn test_pop_chunk_insufficient_bits() {
let mut prunable: Prunable<4> = Prunable::new();
prunable.push(true);
prunable.push(false);
prunable.pop_chunk();
}
// Codec round-trip of an empty bitmap.
#[test]
fn test_write_read_empty() {
let original: Prunable<4> = Prunable::new();
let encoded = original.encode();
let decoded = Prunable::<4>::read_cfg(&mut encoded.as_ref(), &u64::MAX).unwrap();
assert_eq!(decoded.len(), original.len());
assert_eq!(decoded.pruned_chunks(), original.pruned_chunks());
assert!(decoded.is_empty());
}
// Codec round-trip of two full chunks plus three trailing bits.
#[test]
fn test_write_read_non_empty() {
let mut original: Prunable<4> = Prunable::new();
original.push_chunk(&hex!("0xAABBCCDD"));
original.push_chunk(&hex!("0x11223344"));
original.push(true);
original.push(false);
original.push(true);
let encoded = original.encode();
let decoded = Prunable::<4>::read_cfg(&mut encoded.as_ref(), &u64::MAX).unwrap();
assert_eq!(decoded.len(), original.len());
assert_eq!(decoded.pruned_chunks(), original.pruned_chunks());
assert_eq!(decoded.len(), 67);
for i in 0..original.len() {
assert_eq!(decoded.get_bit(i), original.get_bit(i));
}
}
// Codec round-trip preserves the pruned count and the retained chunks.
#[test]
fn test_write_read_with_pruning() {
let mut original: Prunable<4> = Prunable::new();
original.push_chunk(&hex!("0x01020304"));
original.push_chunk(&hex!("0x05060708"));
original.push_chunk(&hex!("0x090A0B0C"));
original.prune_to_bit(32);
assert_eq!(original.pruned_chunks(), 1);
assert_eq!(original.len(), 96);
let encoded = original.encode();
let decoded = Prunable::<4>::read_cfg(&mut encoded.as_ref(), &u64::MAX).unwrap();
assert_eq!(decoded.len(), original.len());
assert_eq!(decoded.pruned_chunks(), original.pruned_chunks());
assert_eq!(decoded.pruned_chunks(), 1);
assert_eq!(decoded.len(), 96);
assert_eq!(decoded.get_chunk_containing(32), &hex!("0x05060708"));
assert_eq!(decoded.get_chunk_containing(64), &hex!("0x090A0B0C"));
}
// Same as above with three of five chunks pruned; only retained bits are compared.
#[test]
fn test_write_read_with_pruning_2() {
let mut original: Prunable<4> = Prunable::new();
for i in 0..5 {
let chunk = [
(i * 4) as u8,
(i * 4 + 1) as u8,
(i * 4 + 2) as u8,
(i * 4 + 3) as u8,
];
original.push_chunk(&chunk);
}
original.prune_to_bit(96); assert_eq!(original.pruned_chunks(), 3);
assert_eq!(original.len(), 160);
let encoded = original.encode();
let decoded = Prunable::<4>::read_cfg(&mut encoded.as_ref(), &u64::MAX).unwrap();
assert_eq!(decoded.len(), original.len());
assert_eq!(decoded.pruned_chunks(), 3);
for i in 96..original.len() {
assert_eq!(decoded.get_bit(i), original.get_bit(i));
}
}
// `encode_size` must equal the actual encoded length (no pruning).
#[test]
fn test_encode_size_matches() {
let mut prunable: Prunable<4> = Prunable::new();
prunable.push_chunk(&[1, 2, 3, 4]);
prunable.push_chunk(&[5, 6, 7, 8]);
prunable.push(true);
let size = prunable.encode_size();
let encoded = prunable.encode();
assert_eq!(size, encoded.len());
}
// `encode_size` must equal the actual encoded length (with pruning).
#[test]
fn test_encode_size_with_pruning() {
let mut prunable: Prunable<4> = Prunable::new();
prunable.push_chunk(&[1, 2, 3, 4]);
prunable.push_chunk(&[5, 6, 7, 8]);
prunable.push_chunk(&[9, 10, 11, 12]);
prunable.prune_to_bit(32);
let size = prunable.encode_size();
let encoded = prunable.encode();
assert_eq!(size, encoded.len());
}
// The decoder config limits what is accepted: 10 bits decode under a limit
// of 100 but are rejected under a limit of 5.
#[test]
fn test_read_max_len_validation() {
let mut original: Prunable<4> = Prunable::new();
for _ in 0..10 {
original.push(true);
}
let encoded = original.encode();
assert!(Prunable::<4>::read_cfg(&mut encoded.as_ref(), &100).is_ok());
let result = Prunable::<4>::read_cfg(&mut encoded.as_ref(), &5);
assert!(result.is_err());
}
// Codec round-trip preserves the length for 8-, 16- and 32-byte chunks.
#[test]
fn test_codec_roundtrip_different_chunk_sizes() {
let mut p8: Prunable<8> = Prunable::new();
let mut p16: Prunable<16> = Prunable::new();
let mut p32: Prunable<32> = Prunable::new();
for i in 0..100 {
let bit = i % 3 == 0;
p8.push(bit);
p16.push(bit);
p32.push(bit);
}
let encoded8 = p8.encode();
let decoded8 = Prunable::<8>::read_cfg(&mut encoded8.as_ref(), &u64::MAX).unwrap();
assert_eq!(decoded8.len(), p8.len());
let encoded16 = p16.encode();
let decoded16 = Prunable::<16>::read_cfg(&mut encoded16.as_ref(), &u64::MAX).unwrap();
assert_eq!(decoded16.len(), p16.len());
let encoded32 = p32.encode();
let decoded32 = Prunable::<32>::read_cfg(&mut encoded32.as_ref(), &u64::MAX).unwrap();
assert_eq!(decoded32.len(), p32.len());
}
// A hand-built encoding whose pruned chunk count overflows when converted
// to bits must be rejected with the matching error message.
#[test]
fn test_read_pruned_chunks_overflow() {
let mut buf = BytesMut::new();
let overflowing_pruned_chunks = (u64::MAX / Prunable::<4>::CHUNK_SIZE_BITS) + 1;
overflowing_pruned_chunks.write(&mut buf);
// Zero-length bitmap payload; decoding should fail before it is read.
0u64.write(&mut buf);
let result = Prunable::<4>::read_cfg(&mut buf.as_ref(), &u64::MAX);
match result {
Err(CodecError::Invalid(type_name, msg)) => {
assert_eq!(type_name, "Prunable");
assert_eq!(
msg,
"pruned_chunks would overflow when computing pruned_bits"
);
}
Ok(_) => panic!("Expected error but got Ok"),
Err(e) => panic!("Expected Invalid error for pruned_bits overflow, got: {e:?}"),
}
}
// A hand-built encoding where the pruned bits alone fit in a u64 but
// pruned + retained bits do not must be rejected.
#[test]
fn test_read_total_length_overflow() {
let mut buf = BytesMut::new();
// Largest pruned count that still fits, then a bitmap one bit longer than
// the remaining headroom.
let max_safe_pruned_chunks = u64::MAX / Prunable::<4>::CHUNK_SIZE_BITS;
let pruned_bits = max_safe_pruned_chunks * Prunable::<4>::CHUNK_SIZE_BITS;
let remaining_space = u64::MAX - pruned_bits;
let bitmap_len = remaining_space + 1;
max_safe_pruned_chunks.write(&mut buf);
bitmap_len.write(&mut buf);
// Zero-filled payload bytes for the chunks needed to hold `bitmap_len` bits.
let num_chunks = bitmap_len.div_ceil(Prunable::<4>::CHUNK_SIZE_BITS);
for _ in 0..(num_chunks * 4) {
0u8.write(&mut buf);
}
let result = Prunable::<4>::read_cfg(&mut buf.as_ref(), &u64::MAX);
match result {
Err(CodecError::Invalid(type_name, msg)) => {
assert_eq!(type_name, "Prunable");
assert_eq!(
msg,
"total bitmap length (pruned + unpruned) would overflow u64"
);
}
Ok(_) => panic!("Expected error but got Ok"),
Err(e) => panic!("Expected Invalid error for total length overflow, got: {e:?}"),
}
}
// `is_chunk_aligned` across pushes, chunk pushes, pruning, pre-pruned
// construction and pops.
#[test]
fn test_is_chunk_aligned() {
// Empty bitmap is aligned.
let prunable: Prunable<4> = Prunable::new();
assert!(prunable.is_chunk_aligned());
// Bit-by-bit pushes: aligned only at multiples of 32.
let mut prunable: Prunable<4> = Prunable::new();
for i in 1..=32 {
prunable.push(i % 2 == 0);
if i == 32 {
assert!(prunable.is_chunk_aligned()); } else {
assert!(!prunable.is_chunk_aligned()); }
}
for i in 33..=64 {
prunable.push(i % 2 == 0);
if i == 64 {
assert!(prunable.is_chunk_aligned()); } else {
assert!(!prunable.is_chunk_aligned()); }
}
// Whole-chunk pushes keep alignment; one extra bit breaks it.
let mut prunable: Prunable<4> = Prunable::new();
assert!(prunable.is_chunk_aligned());
prunable.push_chunk(&[1, 2, 3, 4]);
assert!(prunable.is_chunk_aligned()); prunable.push_chunk(&[5, 6, 7, 8]);
assert!(prunable.is_chunk_aligned()); prunable.push(true);
assert!(!prunable.is_chunk_aligned());
// Pruning does not change alignment because it does not change the length.
let mut prunable: Prunable<4> = Prunable::new();
prunable.push_chunk(&[1, 2, 3, 4]);
prunable.push_chunk(&[5, 6, 7, 8]);
prunable.push_chunk(&[9, 10, 11, 12]);
assert!(prunable.is_chunk_aligned());
prunable.prune_to_bit(32);
assert!(prunable.is_chunk_aligned());
assert_eq!(prunable.len(), 96);
prunable.push(true);
prunable.push(false);
assert!(!prunable.is_chunk_aligned());
prunable.prune_to_bit(64);
assert!(!prunable.is_chunk_aligned());
// Bitmaps constructed with pruned chunks start aligned.
let prunable: Prunable<4> = Prunable::new_with_pruned_chunks(2).unwrap();
assert!(prunable.is_chunk_aligned());
let mut prunable: Prunable<4> = Prunable::new_with_pruned_chunks(1).unwrap();
assert!(prunable.is_chunk_aligned()); prunable.push(true);
assert!(!prunable.is_chunk_aligned());
// Popping from an aligned bitmap breaks alignment until it is empty again.
let mut prunable: Prunable<4> = Prunable::new();
for _ in 0..4 {
prunable.push_byte(0xFF);
}
assert!(prunable.is_chunk_aligned());
prunable.pop();
assert!(!prunable.is_chunk_aligned());
for _ in 0..31 {
prunable.pop();
}
assert!(prunable.is_chunk_aligned()); }
// Codec conformance checks, compiled only when the `arbitrary` feature is enabled.
#[cfg(feature = "arbitrary")]
mod conformance {
use super::*;
use commonware_codec::conformance::CodecConformance;
commonware_conformance::conformance_tests! {
CodecConformance<Prunable<16>>,
}
}
}