use bytes::Bytes;
use crate::codec::BitVector;
use crate::codec::succinct::{SuccinctBitVector, WaveletTree};
/// Four-byte tag written at the very start of every packed wavelet-tree buffer
/// and checked first by the decoder.
const MAGIC: &[u8; 4] = b"WTRE";
/// Format version stored in header byte 4; the decoder rejects any other value.
const VERSION: u8 = 1;
/// Fixed header length in bytes: magic (4) + version (1) + zero padding (3) +
/// height as u32 (4) + zero padding (4) + sigma (8) + sequence length (8) +
/// symbol count (8).
const HEADER_SIZE: usize = 40;
/// Errors reported while decoding a packed (`WTRE`) wavelet-tree buffer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PackedWaveletError {
/// The buffer is shorter than the fixed-size header.
TruncatedHeader,
/// The first four bytes are not the `WTRE` magic tag.
BadMagic,
/// The header carries a format version other than the supported one; the
/// payload is the version byte that was found.
UnsupportedVersion(u8),
/// The buffer ends before the named region is complete.
Truncated {
/// Name of the region being read when the data ran out (the decoder uses
/// "symbols", "level header", "level data" and "level bits").
region: &'static str,
},
/// A size field does not fit in `usize`, or offset arithmetic on it overflowed.
SizeOverflow,
/// A level declares a bit count that differs from the sequence length in
/// the header.
BitCountMismatch {
/// Zero-based index of the offending level.
level: usize,
/// Sequence length taken from the header.
expected: u64,
/// Bit count stored in the level's own header.
actual: u64,
},
/// The decoded parts were rejected when assembling the `WaveletTree`.
InvariantViolation(crate::codec::succinct::WaveletInvariantError),
}
/// Human-readable messages for each decoding failure.
impl std::fmt::Display for PackedWaveletError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Fixed messages need no formatting machinery.
            Self::TruncatedHeader => f.write_str("packed wavelet header truncated"),
            Self::BadMagic => f.write_str("packed wavelet bad magic (expected 'WTRE')"),
            Self::SizeOverflow => f.write_str("packed wavelet size field overflows usize"),
            // Messages that interpolate payload data.
            Self::UnsupportedVersion(found) => {
                write!(f, "packed wavelet unsupported version {found}")
            }
            Self::Truncated { region: where_ } => {
                write!(f, "packed wavelet truncated in {where_}")
            }
            Self::BitCountMismatch {
                level: lvl,
                expected: want,
                actual: got,
            } => write!(
                f,
                "packed wavelet bit count mismatch at level {lvl}: expected {want}, got {got}"
            ),
            Self::InvariantViolation(cause) => {
                write!(f, "packed wavelet invariant violation: {cause}")
            }
        }
    }
}
// Marker impl: relies on the `Debug` and `Display` impls above and overrides
// none of the trait's default methods.
impl std::error::Error for PackedWaveletError {}
/// Encodes `tree` into the packed `WTRE` byte layout.
///
/// The output is, in order (all integers little-endian):
///
/// 1. a 40-byte header: magic, version byte, 3 zero bytes, height as `u32`,
///    4 zero bytes, sigma, sequence length and symbol count (each 8 bytes);
/// 2. every symbol as 8 bytes;
/// 3. for each level: bit count (8 bytes), word count (8 bytes), then the
///    level's raw data bytes.
///
/// A height that does not fit in `u32` is written as `u32::MAX`.
#[must_use]
pub fn serialize_wavelet_tree(tree: &WaveletTree) -> Vec<u8> {
    let alphabet = tree.symbols_slice();
    let level_vectors = tree.levels_slice();

    // Work out the exact output size up front so the buffer is allocated once.
    let levels_size: usize = level_vectors
        .iter()
        .map(|level| 16 + level.inner().data_bytes().len())
        .sum();
    let mut out = Vec::with_capacity(HEADER_SIZE + alphabet.len() * 8 + levels_size);

    // --- Header -----------------------------------------------------------
    let height_field = u32::try_from(tree.height()).unwrap_or(u32::MAX);
    out.extend_from_slice(MAGIC);
    out.push(VERSION);
    out.extend_from_slice(&[0u8; 3]);
    out.extend_from_slice(&height_field.to_le_bytes());
    out.extend_from_slice(&[0u8; 4]);
    out.extend_from_slice(&tree.sigma().to_le_bytes());
    out.extend_from_slice(&(tree.len() as u64).to_le_bytes());
    out.extend_from_slice(&(alphabet.len() as u64).to_le_bytes());

    // --- Symbol table -----------------------------------------------------
    out.extend(alphabet.iter().flat_map(|symbol| symbol.to_le_bytes()));

    // --- Levels -----------------------------------------------------------
    for level in level_vectors {
        let bits = level.inner();
        let raw = bits.data_bytes();
        out.extend_from_slice(&(bits.len() as u64).to_le_bytes());
        // NOTE(review): the word count is `len / 8`, which presumes the raw
        // data is always a whole number of 8-byte words — confirm against
        // `BitVector::data_bytes`.
        out.extend_from_slice(&((raw.len() / 8) as u64).to_le_bytes());
        out.extend_from_slice(raw);
    }

    out
}
pub fn deserialize_wavelet_tree(data: Bytes) -> Result<WaveletTree, PackedWaveletError> {
if data.len() < HEADER_SIZE {
return Err(PackedWaveletError::TruncatedHeader);
}
if &data[0..4] != MAGIC {
return Err(PackedWaveletError::BadMagic);
}
let version = data[4];
if version != VERSION {
return Err(PackedWaveletError::UnsupportedVersion(version));
}
let height_raw = u32::from_le_bytes(data[8..12].try_into().expect("4-byte slice"));
let sigma = u64::from_le_bytes(data[16..24].try_into().expect("8-byte slice"));
let len_raw = u64::from_le_bytes(data[24..32].try_into().expect("8-byte slice"));
let symbol_count_raw = u64::from_le_bytes(data[32..40].try_into().expect("8-byte slice"));
let height = usize::try_from(height_raw).map_err(|_| PackedWaveletError::SizeOverflow)?;
let len_usize = usize::try_from(len_raw).map_err(|_| PackedWaveletError::SizeOverflow)?;
let symbol_count =
usize::try_from(symbol_count_raw).map_err(|_| PackedWaveletError::SizeOverflow)?;
let mut cursor = HEADER_SIZE;
let symbols_bytes = symbol_count
.checked_mul(8)
.ok_or(PackedWaveletError::SizeOverflow)?;
let symbols_end = cursor
.checked_add(symbols_bytes)
.ok_or(PackedWaveletError::SizeOverflow)?;
if symbols_end > data.len() {
return Err(PackedWaveletError::Truncated { region: "symbols" });
}
let mut symbols: Vec<u64> = Vec::with_capacity(symbol_count);
for i in 0..symbol_count {
let off = cursor + i * 8;
let chunk: [u8; 8] = data[off..off + 8].try_into().expect("8-byte slice");
symbols.push(u64::from_le_bytes(chunk));
}
cursor = symbols_end;
let mut levels: Vec<SuccinctBitVector> = Vec::with_capacity(height);
for level_idx in 0..height {
let level_header_end = cursor
.checked_add(16)
.ok_or(PackedWaveletError::SizeOverflow)?;
if level_header_end > data.len() {
return Err(PackedWaveletError::Truncated {
region: "level header",
});
}
let bit_count =
u64::from_le_bytes(data[cursor..cursor + 8].try_into().expect("8-byte slice"));
let word_count = u64::from_le_bytes(
data[cursor + 8..cursor + 16]
.try_into()
.expect("8-byte slice"),
);
cursor = level_header_end;
if bit_count != len_raw {
return Err(PackedWaveletError::BitCountMismatch {
level: level_idx,
expected: len_raw,
actual: bit_count,
});
}
let level_bytes = usize::try_from(
word_count
.checked_mul(8)
.ok_or(PackedWaveletError::SizeOverflow)?,
)
.map_err(|_| PackedWaveletError::SizeOverflow)?;
let level_data_end = cursor
.checked_add(level_bytes)
.ok_or(PackedWaveletError::SizeOverflow)?;
if level_data_end > data.len() {
return Err(PackedWaveletError::Truncated {
region: "level data",
});
}
let level_slice = data.slice(cursor..level_data_end);
cursor = level_data_end;
let bv = BitVector::from_mmap(level_slice, len_usize).map_err(|_| {
PackedWaveletError::Truncated {
region: "level bits",
}
})?;
levels.push(SuccinctBitVector::from_bitvec(bv));
}
WaveletTree::from_packed_parts(levels, height, sigma, len_usize, symbols)
.map_err(PackedWaveletError::InvariantViolation)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a wavelet tree over `seq`.
    fn build_tree(seq: &[u64]) -> WaveletTree {
        WaveletTree::new(seq)
    }

    /// Serializes `tree` and decodes the result, panicking if decoding fails.
    fn roundtrip(tree: &WaveletTree) -> WaveletTree {
        let bytes = serialize_wavelet_tree(tree);
        deserialize_wavelet_tree(Bytes::from(bytes)).expect("deserialize")
    }

    /// Asserts that both trees agree on length, sigma, height and every
    /// `access(i)` result.
    fn assert_trees_equal(orig: &WaveletTree, restored: &WaveletTree) {
        assert_eq!(orig.len(), restored.len());
        assert_eq!(orig.sigma(), restored.sigma());
        assert_eq!(orig.height(), restored.height());
        for i in 0..orig.len() {
            assert_eq!(
                orig.access(i),
                restored.access(i),
                "access mismatch at position {i}"
            );
        }
    }

    #[test]
    fn alix_packed_wavelet_roundtrip_small() {
        let tree = build_tree(&[1u64, 3, 2, 1, 2, 3, 1, 2]);
        assert_trees_equal(&tree, &roundtrip(&tree));
    }

    #[test]
    fn gus_packed_wavelet_roundtrip_large() {
        let seq: Vec<u64> = (0..1024u64).map(|i| (i * 7) % 16).collect();
        let tree = build_tree(&seq);
        assert_trees_equal(&tree, &roundtrip(&tree));
    }

    #[test]
    fn vincent_packed_wavelet_empty() {
        let tree = WaveletTree::new(&[]);
        let restored = roundtrip(&tree);
        assert_eq!(restored.len(), 0);
        assert!(restored.is_empty());
    }

    #[test]
    fn jules_packed_wavelet_single_symbol() {
        let tree = build_tree(&[42u64; 16]);
        assert_trees_equal(&tree, &roundtrip(&tree));
    }

    #[test]
    fn mia_packed_wavelet_bad_magic_rejected() {
        // A header-sized buffer of zeros passes the length check but not the
        // magic check.
        let bad = Bytes::from(vec![0u8; HEADER_SIZE]);
        assert_eq!(
            deserialize_wavelet_tree(bad).unwrap_err(),
            PackedWaveletError::BadMagic
        );
    }

    #[test]
    fn shosanna_packed_wavelet_truncated_header_rejected() {
        // Correct magic, but nothing after it.
        let short = Bytes::from(vec![b'W', b'T', b'R', b'E']);
        assert_eq!(
            deserialize_wavelet_tree(short).unwrap_err(),
            PackedWaveletError::TruncatedHeader
        );
    }

    #[test]
    fn beatrix_packed_wavelet_unsupported_version_rejected() {
        let mut buf = vec![0u8; HEADER_SIZE];
        buf[..4].copy_from_slice(MAGIC);
        buf[4] = 99;
        assert_eq!(
            deserialize_wavelet_tree(Bytes::from(buf)).unwrap_err(),
            PackedWaveletError::UnsupportedVersion(99)
        );
    }

    #[test]
    fn hans_packed_wavelet_size_smaller_than_bincode() {
        let seq: Vec<u64> = (0..2048u64).map(|i| (i * 11) % 64).collect();
        let tree = build_tree(&seq);
        let v2_bytes = serialize_wavelet_tree(&tree);
        let v1_bytes = bincode::serde::encode_to_vec(&tree, bincode::config::standard())
            .expect("bincode encode");
        eprintln!(
            "v1 bincode: {} bytes, v2 packed: {} bytes",
            v1_bytes.len(),
            v2_bytes.len()
        );
        assert!(
            v2_bytes.len() < v1_bytes.len(),
            "v2 packed must be smaller than v1 bincode (v1={}, v2={})",
            v1_bytes.len(),
            v2_bytes.len()
        );
    }

    #[test]
    fn django_packed_wavelet_zero_copy_per_level() {
        let tree = build_tree(&[1u64, 2, 3, 4, 5, 6, 7, 8]);
        let source = Bytes::from(serialize_wavelet_tree(&tree));

        // Hold a second handle so the source allocation stays alive for the
        // whole test, even if the decoder were to copy and drop its input;
        // otherwise the address comparison below would be meaningless.
        let keep_alive = source.clone();
        let source_start = keep_alive.as_ptr() as usize;
        let source_end = source_start + keep_alive.len();

        let restored = deserialize_wavelet_tree(source).expect("deserialize");
        for (idx, level) in restored.levels_slice().iter().enumerate() {
            let level_bytes = level.inner().data_bytes();
            let level_start = level_bytes.as_ptr() as usize;
            let level_end = level_start + level_bytes.len();
            // Compare address ranges directly rather than subtracting
            // pointers: a level stored below the source buffer must fail this
            // assertion, not trip an integer-underflow panic first.
            assert!(
                source_start <= level_start && level_end <= source_end,
                "level {idx}: inner BitVector should be inside source allocation; \
                 level={level_start:#x}..{level_end:#x}, source={source_start:#x}..{source_end:#x}"
            );
        }
    }
}