use crate::seed::SeedRng;
/// Frames payloads behind a 2-byte length prefix and appends a random amount
/// of RNG-generated filler after them.
pub struct PaddingStrategy {
    /// Generator used to pick each frame's padding length and to produce the
    /// filler bytes themselves.
    rng: SeedRng,
    /// Upper bound handed to the RNG when choosing a padding length; a value
    /// of `0` turns padding off entirely.
    max_padding: usize,
}
impl PaddingStrategy {
    /// Creates a strategy that draws padding from `rng`, using `max_padding`
    /// as the bound for each frame's padding length.
    ///
    /// A `max_padding` of `0` disables padding: frames then consist of the
    /// length prefix and the payload only, and the RNG is never consulted.
    pub fn new(rng: SeedRng, max_padding: usize) -> Self {
        Self { rng, max_padding }
    }

    /// Wraps `payload` in a padded frame.
    ///
    /// The frame layout is:
    ///
    /// 1. the payload length as a big-endian `u16`,
    /// 2. the payload bytes, unchanged,
    /// 3. `pad_len` filler bytes taken from successive `rng.next_u64()`
    ///    values, each serialized big-endian (a trailing partial word is
    ///    truncated).
    ///
    /// `pad_len` is `0` when `max_padding` is `0`; otherwise it is
    /// `rng.range(0, max_padding)`.
    /// NOTE(review): whether `max_padding` itself can be drawn depends on the
    /// upper-bound convention of `SeedRng::range`, which is not visible here.
    ///
    /// # Panics
    ///
    /// Panics if `payload` is longer than `u16::MAX` (65535) bytes, because
    /// the length prefix cannot represent it. Previously the length was
    /// silently truncated, producing a frame that [`PaddingStrategy::unpad`]
    /// would decode to the wrong payload.
    pub fn pad(&mut self, payload: &[u8]) -> Vec<u8> {
        // Check the precondition before drawing from the RNG so a rejected
        // payload does not advance the generator state.
        assert!(
            payload.len() <= usize::from(u16::MAX),
            "payload of {} bytes does not fit the 16-bit length prefix",
            payload.len()
        );
        // Lossless: guarded by the assertion above.
        let len_prefix = (payload.len() as u16).to_be_bytes();

        let pad_len = if self.max_padding == 0 {
            0
        } else {
            self.rng.range(0, self.max_padding as u64) as usize
        };

        let mut result = Vec::with_capacity(len_prefix.len() + payload.len() + pad_len);
        result.extend_from_slice(&len_prefix);
        result.extend_from_slice(payload);

        // Reserve the padding region, then overwrite it one RNG word (8 bytes)
        // at a time; the final chunk may be shorter than a full word.
        let fill_start = result.len();
        result.resize(fill_start + pad_len, 0);
        for chunk in result[fill_start..].chunks_mut(8) {
            let word = self.rng.next_u64().to_be_bytes();
            chunk.copy_from_slice(&word[..chunk.len()]);
        }
        result
    }

    /// Extracts the original payload from a frame produced by
    /// [`PaddingStrategy::pad`].
    ///
    /// Reads the big-endian `u16` length prefix and returns that many bytes
    /// following it; anything after the payload (the padding) is ignored.
    ///
    /// Returns `None` if `data` is shorter than the 2-byte prefix or shorter
    /// than the prefix plus the length it declares.
    pub fn unpad(data: &[u8]) -> Option<Vec<u8>> {
        if data.len() < 2 {
            return None;
        }
        let (prefix, body) = data.split_at(2);
        let orig_len = usize::from(u16::from_be_bytes([prefix[0], prefix[1]]));
        // `get` yields `None` when the declared length overruns the frame.
        body.get(..orig_len).map(|payload| payload.to_vec())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a strategy whose 32-byte seed is `byte` repeated.
    fn strategy(byte: u8, max_padding: usize) -> PaddingStrategy {
        PaddingStrategy::new(SeedRng::new([byte; 32]), max_padding)
    }

    /// A padded frame is at least prefix + payload long and decodes back to
    /// the payload.
    #[test]
    fn pad_unpad_roundtrip() {
        let message = b"secret";
        let framed = strategy(0xAB, 64).pad(message);
        assert!(framed.len() >= 2 + message.len());
        assert_eq!(PaddingStrategy::unpad(&framed).unwrap(), message);
    }

    /// A frame declaring two payload bytes but carrying one is rejected.
    #[test]
    fn unpad_rejects_truncated() {
        assert_eq!(PaddingStrategy::unpad(&[0, 2, 1]), None);
    }

    /// An empty payload still yields a frame with a prefix and roundtrips.
    #[test]
    fn empty_payload_roundtrip() {
        let framed = strategy(0x01, 32).pad(b"");
        assert!(framed.len() >= 2);
        assert_eq!(PaddingStrategy::unpad(&framed).unwrap(), b"");
    }

    /// With `max_padding == 0` the frame is exactly prefix + payload.
    #[test]
    fn zero_max_padding_adds_no_extra_bytes() {
        let message = b"hello";
        let framed = strategy(0x02, 0).pad(message);
        assert_eq!(framed.len(), 2 + message.len());
        assert_eq!(PaddingStrategy::unpad(&framed).unwrap(), message);
    }

    /// Two strategies built from the same seed produce identical frames.
    #[test]
    fn deterministic_padding_from_same_seed() {
        let message = b"determinism-check";
        let first = strategy(0x55, 128).pad(message);
        let second = strategy(0x55, 128).pad(message);
        assert_eq!(first, second);
    }

    /// Strategies built from different seeds produce different frames.
    #[test]
    fn different_seed_different_padding() {
        let message = b"diff-check";
        let from_first_seed = strategy(0x11, 128).pad(message);
        let from_second_seed = strategy(0x22, 128).pad(message);
        assert_ne!(from_first_seed, from_second_seed);
    }

    /// Inputs shorter than the 2-byte prefix are rejected.
    #[test]
    fn unpad_rejects_too_short() {
        assert_eq!(PaddingStrategy::unpad(&[0]), None);
        assert_eq!(PaddingStrategy::unpad(&[]), None);
    }

    /// A prefix claiming more bytes than the frame holds is rejected.
    #[test]
    fn unpad_rejects_corrupted_length() {
        let frame = [0, 100, 1, 2, 3];
        assert_eq!(PaddingStrategy::unpad(&frame), None);
    }

    /// A 1000-byte payload roundtrips unchanged.
    #[test]
    fn large_payload_roundtrip() {
        let message: Vec<u8> = (0u8..=255).cycle().take(1000).collect();
        let framed = strategy(0xFF, 256).pad(&message);
        assert!(framed.len() >= 2 + message.len());
        assert_eq!(PaddingStrategy::unpad(&framed).unwrap(), message);
    }

    /// Repeated padding of the same payload yields more than one frame length.
    #[test]
    fn multiple_pads_vary_in_length() {
        let mut padder = strategy(0x33, 128);
        let mut lengths = Vec::with_capacity(20);
        for _ in 0..20 {
            lengths.push(padder.pad(b"test").len());
        }
        let distinct = lengths
            .iter()
            .collect::<std::collections::HashSet<_>>()
            .len();
        assert!(
            distinct > 1,
            "expected varying pad lengths, got {lengths:?}"
        );
    }
}