use std::fmt::Debug;
use std::io;
use aes_gcm::{
aead::{Aead, KeyInit, OsRng},
AeadCore, Aes256Gcm, Nonce,
};
/// Per-page encryption interface.
///
/// Implementations receive the page's byte `offset` within the file, the
/// full page contents, and the page size, and return the transformed page.
pub trait PageCrypto: Send + Sync + Debug + 'static {
/// Encrypts `data` (exactly `page_size` bytes) located at `offset`.
fn encrypt(&self, offset: u64, data: &[u8], page_size: usize) -> io::Result<Vec<u8>>;
/// Decrypts `data` (exactly `page_size` bytes) located at `offset`.
fn decrypt(&self, offset: u64, data: &[u8], page_size: usize) -> io::Result<Vec<u8>>;
/// File offset below which pages are passed through unencrypted (default 0).
fn encryption_start_offset(&self) -> u64 {
0
}
/// Extra bytes per page consumed by the scheme (default 0).
fn overhead(&self) -> usize {
0
}
}
/// AES-256-GCM [`PageCrypto`] that keeps encrypted pages at `page_size` by
/// internally compressing the plaintext to make room for the nonce and tag.
pub struct Aes256GcmPageCrypto {
// Cipher built from the 32-byte key.
cipher: Aes256Gcm,
// Pages below this file offset are passed through unencrypted.
skip_below_offset: u64,
}
impl Clone for Aes256GcmPageCrypto {
    /// Clones the cipher and copies the pass-through offset.
    fn clone(&self) -> Self {
        let cipher = self.cipher.clone();
        Self {
            cipher,
            skip_below_offset: self.skip_below_offset,
        }
    }
}
impl Debug for Aes256GcmPageCrypto {
// Manual impl: the cipher (which holds key material) is not printed.
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Aes256GcmPageCrypto")
.field("skip_below_offset", &self.skip_below_offset)
.finish_non_exhaustive()
}
}
impl Aes256GcmPageCrypto {
    /// Nonce length in bytes (standard 96-bit AES-GCM nonce).
    const NONCE_SIZE: usize = 12;
    /// GCM authentication tag length in bytes.
    const TAG_SIZE: usize = 16;
    /// Total per-page overhead added by encryption: nonce + tag.
    pub const OVERHEAD: usize = Self::NONCE_SIZE + Self::TAG_SIZE;
    /// Plaintext marker: page body was internally zstd-compressed.
    const MAGIC_COMPRESSED: [u8; 2] = [b'E', b'C'];
    /// Plaintext marker: page body stored raw (incompressible).
    const MAGIC_RAW: [u8; 2] = [b'E', b'R'];
    /// Plaintext header size: 2-byte magic + compressed-len u32 + original-len u32.
    const INTERNAL_HEADER_SIZE: usize = 10;

    /// Builds a cipher from a 32-byte key. With `skip_header`, the first
    /// 4096 bytes of the file are left unencrypted.
    pub fn new(key: &[u8; 32], skip_header: bool) -> Self {
        let skip_below_offset = if skip_header { 4096 } else { 0 };
        Self {
            cipher: Aes256Gcm::new(key.into()),
            skip_below_offset,
        }
    }

    /// Builds a cipher that leaves the first `page_size` bytes unencrypted.
    pub fn with_page_size(key: &[u8; 32], page_size: u64) -> Self {
        Self {
            cipher: Aes256Gcm::new(key.into()),
            skip_below_offset: page_size,
        }
    }

    /// Overrides the unencrypted prefix length (builder style).
    pub fn with_skip_below_offset(mut self, offset: u64) -> Self {
        self.skip_below_offset = offset;
        self
    }

    /// Draws a fresh random nonce from the OS RNG.
    fn generate_nonce() -> [u8; Self::NONCE_SIZE] {
        let generated = Aes256Gcm::generate_nonce(&mut OsRng);
        let mut bytes = [0u8; Self::NONCE_SIZE];
        bytes.copy_from_slice(&generated);
        bytes
    }
}
impl PageCrypto for Aes256GcmPageCrypto {
    /// Encrypts a full page without growing it.
    ///
    /// The page is first zstd-compressed so the nonce, tag, and a small
    /// internal header fit inside `page_size`. If the page is incompressible
    /// it is stored raw and the trailing `OVERHEAD + 2` bytes are dropped
    /// (zero-padded on decrypt) — callers must only rely on the first
    /// `page_size - OVERHEAD` bytes of a page.
    ///
    /// # Errors
    /// `InvalidInput` if `data` is not exactly `page_size` bytes, or if
    /// `page_size` cannot hold the overhead plus the internal header.
    /// `Other` if compression or encryption fails.
    fn encrypt(&self, offset: u64, data: &[u8], page_size: usize) -> io::Result<Vec<u8>> {
        if data.len() != page_size {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("Input must be exactly page_size ({} bytes), got {}", page_size, data.len()),
            ));
        }
        // The previous guard only required page_size > OVERHEAD, which let
        // `usable - INTERNAL_HEADER_SIZE` underflow (panic in debug builds,
        // silent corruption in release) for page sizes 29..=37.
        if page_size < Self::OVERHEAD + Self::INTERNAL_HEADER_SIZE {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!(
                    "Page size must be at least {} bytes, got {}",
                    Self::OVERHEAD + Self::INTERNAL_HEADER_SIZE,
                    page_size
                ),
            ));
        }
        // Pages in the (optional) plaintext header region pass through.
        if offset < self.skip_below_offset {
            return Ok(data.to_vec());
        }
        let usable = page_size - Self::OVERHEAD;
        let nonce = Self::generate_nonce();
        // Level 1: fastest setting; we only need some slack, not a max ratio.
        let compressed = zstd::encode_all(data.as_ref(), 1)
            .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("Internal compression failed: {e}")))?;
        let max_compressed_size = usable - Self::INTERNAL_HEADER_SIZE;
        let plaintext: Vec<u8> = if compressed.len() <= max_compressed_size {
            // Compressed layout: "EC" | comp_len u32 LE | orig_len u32 LE | payload | zero pad.
            let mut p = Vec::with_capacity(usable);
            p.extend_from_slice(&Self::MAGIC_COMPRESSED);
            p.extend_from_slice(&(compressed.len() as u32).to_le_bytes());
            p.extend_from_slice(&(data.len() as u32).to_le_bytes());
            p.extend_from_slice(&compressed);
            p.resize(usable, 0);
            p
        } else {
            // Raw layout: "ER" | first `usable - 2` bytes of the page (lossy tail).
            let mut p = Vec::with_capacity(usable);
            p.extend_from_slice(&Self::MAGIC_RAW);
            p.extend_from_slice(&data[..usable - 2]);
            p
        };
        debug_assert_eq!(plaintext.len(), usable);
        let ciphertext_with_tag = self
            .cipher
            .encrypt(Nonce::from_slice(&nonce), plaintext.as_slice())
            .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("Encryption failed: {e}")))?;
        debug_assert_eq!(ciphertext_with_tag.len(), usable + Self::TAG_SIZE);
        let mut output = Vec::with_capacity(page_size);
        output.extend_from_slice(&nonce);
        output.extend_from_slice(&ciphertext_with_tag);
        debug_assert_eq!(output.len(), page_size);
        Ok(output)
    }

    /// Decrypts a page produced by [`Self::encrypt`].
    ///
    /// All-zero pages (zero nonce and zero leading ciphertext) are treated as
    /// never-written pages and returned unchanged.
    ///
    /// # Errors
    /// `InvalidInput` on bad sizes (see [`Self::encrypt`]); `Other` if
    /// authentication/decryption or internal decompression fails.
    fn decrypt(&self, offset: u64, data: &[u8], page_size: usize) -> io::Result<Vec<u8>> {
        if data.len() != page_size {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("Input must be exactly page_size ({} bytes), got {}", page_size, data.len()),
            ));
        }
        // Mirrors encrypt(): guarantees the recovered plaintext is at least
        // INTERNAL_HEADER_SIZE bytes, so the header reads below cannot panic.
        if page_size < Self::OVERHEAD + Self::INTERNAL_HEADER_SIZE {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!(
                    "Page size must be at least {} bytes, got {}",
                    Self::OVERHEAD + Self::INTERNAL_HEADER_SIZE,
                    page_size
                ),
            ));
        }
        if offset < self.skip_below_offset {
            return Ok(data.to_vec());
        }
        let nonce = &data[..Self::NONCE_SIZE];
        let ciphertext_with_tag = &data[Self::NONCE_SIZE..];
        // Heuristic: a never-written page is all zeros; don't try to
        // authenticate it, just hand it back.
        if nonce.iter().all(|&b| b == 0)
            && ciphertext_with_tag.len() >= 8
            && ciphertext_with_tag[..8].iter().all(|&b| b == 0)
        {
            return Ok(data.to_vec());
        }
        let plaintext = self
            .cipher
            .decrypt(Nonce::from_slice(nonce), ciphertext_with_tag)
            .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("Decryption failed: {e}")))?;
        let magic = &plaintext[0..2];
        if magic == Self::MAGIC_COMPRESSED {
            let compressed_len = u32::from_le_bytes([plaintext[2], plaintext[3], plaintext[4], plaintext[5]]) as usize;
            let orig_len = u32::from_le_bytes([plaintext[6], plaintext[7], plaintext[8], plaintext[9]]) as usize;
            let usable = page_size - Self::OVERHEAD;
            let max_compressed_size = usable - Self::INTERNAL_HEADER_SIZE;
            // Only trust the header when its lengths are internally consistent.
            if compressed_len > 0 && compressed_len <= max_compressed_size && orig_len == page_size {
                let compressed = &plaintext[Self::INTERNAL_HEADER_SIZE..Self::INTERNAL_HEADER_SIZE + compressed_len];
                let decompressed = zstd::decode_all(compressed)
                    .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("Internal decompression failed: {e}")))?;
                if decompressed.len() == orig_len {
                    return Ok(decompressed);
                }
            }
        } else if magic == Self::MAGIC_RAW {
            // Raw page: restore the stored prefix and zero-pad the lost tail.
            let usable = page_size - Self::OVERHEAD;
            let mut output = Vec::with_capacity(page_size);
            output.extend_from_slice(&plaintext[2..usable]);
            output.resize(page_size, 0);
            return Ok(output);
        }
        // Unknown magic: return the plaintext zero-padded to a full page.
        let mut output = plaintext;
        output.resize(page_size, 0);
        Ok(output)
    }

    /// Pages below this file offset are stored unencrypted.
    fn encryption_start_offset(&self) -> u64 {
        self.skip_below_offset
    }

    /// Per-page overhead: 12-byte nonce + 16-byte GCM tag.
    fn overhead(&self) -> usize {
        Self::OVERHEAD
    }
}
/// Pass-through [`PageCrypto`] that performs no encryption.
#[derive(Debug, Default, Clone, Copy)]
pub struct NoOpPageCrypto;
impl PageCrypto for NoOpPageCrypto {
    /// Pass-through: returns the page unchanged.
    fn encrypt(&self, _offset: u64, data: &[u8], _page_size: usize) -> io::Result<Vec<u8>> {
        Ok(data.to_owned())
    }
    /// Pass-through: returns the page unchanged.
    fn decrypt(&self, _offset: u64, data: &[u8], _page_size: usize) -> io::Result<Vec<u8>> {
        Ok(data.to_owned())
    }
}
/// Per-page compression interface, mirroring [`PageCrypto`].
pub trait PageCompression: Send + Sync + Debug + 'static {
/// Compresses `data` (exactly `page_size` bytes) located at `offset`.
fn compress(&self, offset: u64, data: &[u8], page_size: usize) -> io::Result<Vec<u8>>;
/// Reverses `compress` for a page located at `offset`.
fn decompress(&self, offset: u64, data: &[u8], page_size: usize) -> io::Result<Vec<u8>>;
/// File offset below which pages are passed through uncompressed (default 0).
fn compression_start_offset(&self) -> u64 {
0
}
}
/// Zstd [`PageCompression`] that keeps compressed pages at `page_size` by
/// zero-padding, and stores incompressible pages verbatim.
#[derive(Debug, Clone)]
pub struct ZstdPageCompression {
// zstd compression level (clamped to 1..=22 by `with_level`).
level: i32,
// Pages below this file offset are passed through unchanged.
skip_below_offset: u64,
}
impl ZstdPageCompression {
    /// Marker prefix for pages stored compressed.
    const MAGIC_COMPRESSED: [u8; 2] = [b'Z', b'S'];
    /// Marker prefix for the legacy explicit-uncompressed page layout.
    const MAGIC_UNCOMPRESSED: [u8; 2] = [b'U', b'C'];
    /// Header layout: 2-byte magic + compressed-len u32 + original-len u32.
    const HEADER_SIZE: usize = 10;

    /// Creates a level-3 compressor. With `skip_header`, the first 4096
    /// bytes of the file are left uncompressed.
    pub fn new(skip_header: bool) -> Self {
        let skip_below_offset = if skip_header { 4096 } else { 0 };
        Self {
            level: 3,
            skip_below_offset,
        }
    }

    /// Creates a level-3 compressor that leaves the first `page_size` bytes untouched.
    pub fn with_page_size(page_size: u64) -> Self {
        Self {
            level: 3,
            skip_below_offset: page_size,
        }
    }

    /// Sets the zstd level, clamped to the valid 1..=22 range (builder style).
    pub fn with_level(mut self, level: i32) -> Self {
        self.level = level.clamp(1, 22);
        self
    }

    /// Overrides the uncompressed prefix length (builder style).
    pub fn with_skip_below_offset(mut self, offset: u64) -> Self {
        self.skip_below_offset = offset;
        self
    }
}
impl PageCompression for ZstdPageCompression {
    /// Compresses one page without growing it.
    ///
    /// If the zstd output plus the 10-byte header fits in `page_size`, the
    /// page is stored as header + payload + zero padding; otherwise it is
    /// returned verbatim (lossless, full size).
    ///
    /// # Errors
    /// `InvalidInput` if `data` is not exactly `page_size` bytes; `Other`
    /// if zstd fails.
    fn compress(&self, offset: u64, data: &[u8], page_size: usize) -> io::Result<Vec<u8>> {
        if data.len() != page_size {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("Input must be exactly page_size ({} bytes), got {}", page_size, data.len()),
            ));
        }
        if offset < self.skip_below_offset {
            return Ok(data.to_vec());
        }
        let compressed = zstd::encode_all(data.as_ref(), self.level)
            .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("Compression failed: {e}")))?;
        // saturating_sub: pages smaller than the header get a zero budget and
        // fall through to the raw path instead of underflowing (debug panic).
        let max_compressed_size = page_size.saturating_sub(Self::HEADER_SIZE);
        if compressed.len() <= max_compressed_size && compressed.len() < data.len() {
            let mut output = Vec::with_capacity(page_size);
            output.extend_from_slice(&Self::MAGIC_COMPRESSED);
            output.extend_from_slice(&(compressed.len() as u32).to_le_bytes());
            output.extend_from_slice(&(data.len() as u32).to_le_bytes());
            output.extend_from_slice(&compressed);
            output.resize(page_size, 0);
            Ok(output)
        } else {
            // Incompressible page: stored raw with no marker; decompress falls
            // through its magic checks and returns it unchanged.
            Ok(data.to_vec())
        }
    }

    /// Reverses [`Self::compress`].
    ///
    /// Pages without a recognized, internally-consistent header (or too short
    /// to carry one) are assumed to have been stored raw and returned as-is.
    fn decompress(&self, offset: u64, data: &[u8], page_size: usize) -> io::Result<Vec<u8>> {
        if data.len() != page_size {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("Input must be exactly page_size ({} bytes), got {}", page_size, data.len()),
            ));
        }
        if offset < self.skip_below_offset {
            return Ok(data.to_vec());
        }
        // Too short for even the 2-byte magic: necessarily a raw page.
        // (Previously the header reads below could panic on tiny pages.)
        if data.len() < 2 {
            return Ok(data.to_vec());
        }
        let magic = &data[0..2];
        if magic == Self::MAGIC_COMPRESSED && data.len() >= Self::HEADER_SIZE {
            let compressed_len = u32::from_le_bytes([data[2], data[3], data[4], data[5]]) as usize;
            let orig_len = u32::from_le_bytes([data[6], data[7], data[8], data[9]]) as usize;
            let max_compressed_size = page_size - Self::HEADER_SIZE;
            // Only trust the header when its lengths are internally consistent;
            // otherwise this is raw data that happens to start with "ZS".
            if compressed_len > 0 && compressed_len <= max_compressed_size && orig_len == page_size {
                let compressed = &data[Self::HEADER_SIZE..Self::HEADER_SIZE + compressed_len];
                match zstd::decode_all(compressed) {
                    Ok(decompressed) if decompressed.len() == orig_len => {
                        let mut output = decompressed;
                        output.resize(page_size, 0);
                        return Ok(output);
                    }
                    // Decode failure: fall through and treat as a raw page.
                    _ => {}
                }
            }
            Ok(data.to_vec())
        } else if magic == Self::MAGIC_UNCOMPRESSED && data.len() >= 6 {
            // Legacy explicit-uncompressed layout: "UC" + orig_len u32 + bytes.
            let orig_len = u32::from_le_bytes([data[2], data[3], data[4], data[5]]) as usize;
            if orig_len == page_size {
                let stored_data = &data[6..];
                let mut output = Vec::with_capacity(page_size);
                output.extend_from_slice(&stored_data[..stored_data.len().min(orig_len)]);
                output.resize(page_size, 0);
                Ok(output)
            } else {
                Ok(data.to_vec())
            }
        } else {
            Ok(data.to_vec())
        }
    }

    /// Pages below this file offset are stored uncompressed.
    fn compression_start_offset(&self) -> u64 {
        self.skip_below_offset
    }
}
/// Pass-through [`PageCompression`] that performs no compression.
#[derive(Debug, Default, Clone, Copy)]
pub struct NoOpPageCompression;
impl PageCompression for NoOpPageCompression {
    /// Pass-through: returns the page unchanged.
    fn compress(&self, _offset: u64, data: &[u8], _page_size: usize) -> io::Result<Vec<u8>> {
        Ok(data.to_owned())
    }
    /// Pass-through: returns the page unchanged.
    fn decompress(&self, _offset: u64, data: &[u8], _page_size: usize) -> io::Result<Vec<u8>> {
        Ok(data.to_owned())
    }
}
use std::sync::Arc;
/// Dictionary-based zstd [`PageCompression`]; its `decompress` also accepts
/// pages written by the plain `ZstdPageCompression` formats for migration.
pub struct ZstdDictPageCompression {
// Shared training dictionary (Arc so clones are cheap).
dict: Arc<[u8]>,
// zstd compression level (clamped to 1..=22 by `with_level`).
level: i32,
// Pages below this file offset are passed through unchanged.
skip_below_offset: u64,
}
impl Clone for ZstdDictPageCompression {
fn clone(&self) -> Self {
Self {
dict: Arc::clone(&self.dict),
level: self.level,
skip_below_offset: self.skip_below_offset,
}
}
}
impl std::fmt::Debug for ZstdDictPageCompression {
// Manual impl: reports the dictionary's size instead of dumping its bytes.
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ZstdDictPageCompression")
.field("dict_size", &self.dict.len())
.field("level", &self.level)
.field("skip_below_offset", &self.skip_below_offset)
.finish()
}
}
impl ZstdDictPageCompression {
    /// Marker prefix for pages compressed with this dictionary.
    const MAGIC_DICT_COMPRESSED: [u8; 2] = [b'Z', b'D'];
    /// Header layout: 2-byte magic + compressed-len u32 + original-len u32.
    const HEADER_SIZE: usize = 10;

    /// Creates a level-3 compressor owning a copy of `dict`. With
    /// `skip_header`, the first 4096 bytes of the file are left uncompressed.
    pub fn new(dict: &[u8], skip_header: bool) -> Self {
        let skip_below_offset = if skip_header { 4096 } else { 0 };
        Self {
            dict: Arc::from(dict),
            level: 3,
            skip_below_offset,
        }
    }

    /// Creates a level-3 compressor that leaves the first `page_size` bytes untouched.
    pub fn with_page_size(dict: &[u8], page_size: u64) -> Self {
        Self {
            dict: Arc::from(dict),
            level: 3,
            skip_below_offset: page_size,
        }
    }

    /// Creates a compressor sharing an already-allocated dictionary.
    pub fn from_arc(dict: Arc<[u8]>, skip_header: bool) -> Self {
        let skip_below_offset = if skip_header { 4096 } else { 0 };
        Self {
            dict,
            level: 3,
            skip_below_offset,
        }
    }

    /// Sets the zstd level, clamped to the valid 1..=22 range (builder style).
    pub fn with_level(mut self, level: i32) -> Self {
        self.level = level.clamp(1, 22);
        self
    }

    /// Overrides the uncompressed prefix length (builder style).
    pub fn with_skip_below_offset(mut self, offset: u64) -> Self {
        self.skip_below_offset = offset;
        self
    }

    /// Borrows the training dictionary bytes.
    pub fn dictionary(&self) -> &[u8] {
        &self.dict
    }
}
impl PageCompression for ZstdDictPageCompression {
    /// Compresses one page with the dictionary without growing it.
    ///
    /// Layout and fallback behavior match [`ZstdPageCompression::compress`],
    /// with the "ZD" magic identifying dictionary-compressed pages.
    ///
    /// # Errors
    /// `InvalidInput` on a size mismatch; `Other` if zstd fails.
    fn compress(&self, offset: u64, data: &[u8], page_size: usize) -> io::Result<Vec<u8>> {
        if data.len() != page_size {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("Input must be exactly page_size ({} bytes), got {}", page_size, data.len()),
            ));
        }
        if offset < self.skip_below_offset {
            return Ok(data.to_vec());
        }
        let mut compressor = zstd::bulk::Compressor::with_dictionary(self.level, &self.dict)
            .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("Failed to create compressor: {e}")))?;
        let compressed = compressor
            .compress(data)
            .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("Compression failed: {e}")))?;
        // saturating_sub: pages smaller than the header get a zero budget and
        // fall through to the raw path instead of underflowing (debug panic).
        let max_compressed_size = page_size.saturating_sub(Self::HEADER_SIZE);
        if compressed.len() <= max_compressed_size && compressed.len() < data.len() {
            let mut output = Vec::with_capacity(page_size);
            output.extend_from_slice(&Self::MAGIC_DICT_COMPRESSED);
            output.extend_from_slice(&(compressed.len() as u32).to_le_bytes());
            output.extend_from_slice(&(data.len() as u32).to_le_bytes());
            output.extend_from_slice(&compressed);
            output.resize(page_size, 0);
            Ok(output)
        } else {
            // Incompressible page: stored raw with no marker.
            Ok(data.to_vec())
        }
    }

    /// Reverses [`Self::compress`], and also decodes the plain "ZS" and
    /// legacy "UC" layouts so databases written without a dictionary migrate.
    ///
    /// Pages without a recognized, internally-consistent header (or too short
    /// to carry one) are assumed to have been stored raw and returned as-is.
    fn decompress(&self, offset: u64, data: &[u8], page_size: usize) -> io::Result<Vec<u8>> {
        if data.len() != page_size {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("Input must be exactly page_size ({} bytes), got {}", page_size, data.len()),
            ));
        }
        if offset < self.skip_below_offset {
            return Ok(data.to_vec());
        }
        // Too short for even the 2-byte magic: necessarily a raw page.
        // (Previously the header reads below could panic on tiny pages.)
        if data.len() < 2 {
            return Ok(data.to_vec());
        }
        let magic = &data[0..2];
        if magic == Self::MAGIC_DICT_COMPRESSED && data.len() >= Self::HEADER_SIZE {
            let compressed_len = u32::from_le_bytes([data[2], data[3], data[4], data[5]]) as usize;
            let orig_len = u32::from_le_bytes([data[6], data[7], data[8], data[9]]) as usize;
            let max_compressed_size = page_size - Self::HEADER_SIZE;
            // Only trust the header when its lengths are internally consistent.
            if compressed_len > 0 && compressed_len <= max_compressed_size && orig_len == page_size {
                let compressed = &data[Self::HEADER_SIZE..Self::HEADER_SIZE + compressed_len];
                let mut decompressor = zstd::bulk::Decompressor::with_dictionary(&self.dict)
                    .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("Failed to create decompressor: {e}")))?;
                match decompressor.decompress(compressed, orig_len) {
                    Ok(decompressed) if decompressed.len() == orig_len => {
                        let mut output = decompressed;
                        output.resize(page_size, 0);
                        return Ok(output);
                    }
                    // Decode failure: fall through and treat as a raw page.
                    _ => {}
                }
            }
            Ok(data.to_vec())
        } else if magic == ZstdPageCompression::MAGIC_COMPRESSED
            && data.len() >= ZstdPageCompression::HEADER_SIZE
        {
            // Dictionary-less "ZS" page written by ZstdPageCompression.
            let compressed_len = u32::from_le_bytes([data[2], data[3], data[4], data[5]]) as usize;
            let orig_len = u32::from_le_bytes([data[6], data[7], data[8], data[9]]) as usize;
            let max_compressed_size = page_size - ZstdPageCompression::HEADER_SIZE;
            if compressed_len > 0 && compressed_len <= max_compressed_size && orig_len == page_size {
                let compressed = &data[ZstdPageCompression::HEADER_SIZE..ZstdPageCompression::HEADER_SIZE + compressed_len];
                match zstd::decode_all(compressed) {
                    Ok(decompressed) if decompressed.len() == orig_len => {
                        let mut output = decompressed;
                        output.resize(page_size, 0);
                        return Ok(output);
                    }
                    _ => {}
                }
            }
            Ok(data.to_vec())
        } else if magic == ZstdPageCompression::MAGIC_UNCOMPRESSED && data.len() >= 6 {
            // Legacy explicit-uncompressed layout: "UC" + orig_len u32 + bytes.
            let orig_len = u32::from_le_bytes([data[2], data[3], data[4], data[5]]) as usize;
            if orig_len == page_size {
                let stored_data = &data[6..];
                let mut output = Vec::with_capacity(page_size);
                output.extend_from_slice(&stored_data[..stored_data.len().min(orig_len)]);
                output.resize(page_size, 0);
                Ok(output)
            } else {
                Ok(data.to_vec())
            }
        } else {
            Ok(data.to_vec())
        }
    }

    /// Pages below this file offset are stored uncompressed.
    fn compression_start_offset(&self) -> u64 {
        self.skip_below_offset
    }
}
/// Helpers for training, persisting, and evaluating zstd dictionaries.
pub struct DictionaryTrainer;

impl DictionaryTrainer {
    /// Default dictionary size (64 KiB).
    pub const DEFAULT_DICT_SIZE: usize = 64 * 1024;
    /// Sample count below which training quality is likely poor (warn only).
    pub const MIN_RECOMMENDED_SAMPLES: usize = 100;
    /// Hard minimum sample count; training with fewer is rejected.
    pub const MIN_REQUIRED_SAMPLES: usize = 10;

    // Shared sample-count validation for both training entry points.
    fn check_sample_count(count: usize) -> io::Result<()> {
        if count < Self::MIN_REQUIRED_SAMPLES {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!(
                    "At least {} samples required for dictionary training, got {}. \
                    Training with too few samples produces poor dictionaries.",
                    Self::MIN_REQUIRED_SAMPLES,
                    count
                ),
            ));
        }
        if count < Self::MIN_RECOMMENDED_SAMPLES {
            #[cfg(feature = "logging")]
            log::warn!(
                "Training dictionary with {} samples (recommended: {}). \
                Results may be suboptimal.",
                count,
                Self::MIN_RECOMMENDED_SAMPLES
            );
        }
        Ok(())
    }

    /// Trains a zstd dictionary from discrete samples.
    ///
    /// # Errors
    /// `InvalidInput` if `samples` is empty or has fewer than
    /// [`Self::MIN_REQUIRED_SAMPLES`] entries; `Other` if training fails.
    pub fn train(samples: &[Vec<u8>], dict_size: usize) -> io::Result<Vec<u8>> {
        if samples.is_empty() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "Cannot train dictionary from empty samples",
            ));
        }
        Self::check_sample_count(samples.len())?;
        zstd::dict::from_samples(samples, dict_size)
            .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("Dictionary training failed: {e}")))
    }

    /// Trains a dictionary from one contiguous buffer split into
    /// `sample_sizes`-sized samples.
    ///
    /// # Errors
    /// Same conditions as [`Self::train`], counted over `sample_sizes`.
    pub fn train_from_continuous(data: &[u8], sample_sizes: &[usize], dict_size: usize) -> io::Result<Vec<u8>> {
        if data.is_empty() || sample_sizes.is_empty() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "Cannot train dictionary from empty data",
            ));
        }
        Self::check_sample_count(sample_sizes.len())?;
        zstd::dict::from_continuous(data, sample_sizes, dict_size)
            .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("Dictionary training failed: {e}")))
    }

    /// Reads a previously saved dictionary from `path`.
    pub fn load_from_file(path: impl AsRef<std::path::Path>) -> io::Result<Vec<u8>> {
        std::fs::read(path)
    }

    /// Writes `dict` to `path`, overwriting any existing file.
    pub fn save_to_file(dict: &[u8], path: impl AsRef<std::path::Path>) -> io::Result<()> {
        std::fs::write(path, dict)
    }

    /// Compresses `samples` with and without `dict` and returns
    /// `(ratio_without, ratio_with)` as compressed/original fractions
    /// (lower is better).
    ///
    /// # Errors
    /// `InvalidInput` if `samples` is empty or contains only zero-length
    /// entries (the ratios would be NaN); `Other` if zstd fails.
    pub fn estimate_improvement(samples: &[Vec<u8>], dict: &[u8], level: i32) -> io::Result<(f64, f64)> {
        if samples.is_empty() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "Cannot estimate with empty samples",
            ));
        }
        let mut total_original = 0usize;
        let mut total_without_dict = 0usize;
        let mut total_with_dict = 0usize;
        let mut compressor_with_dict = zstd::bulk::Compressor::with_dictionary(level, dict)
            .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("Failed to create compressor: {e}")))?;
        for sample in samples {
            total_original += sample.len();
            let compressed = zstd::encode_all(sample.as_slice(), level)
                .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("Compression failed: {e}")))?;
            total_without_dict += compressed.len();
            let compressed_dict = compressor_with_dict
                .compress(sample)
                .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("Compression failed: {e}")))?;
            total_with_dict += compressed_dict.len();
        }
        // Guard the division: all-empty samples previously yielded NaN ratios.
        if total_original == 0 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "Cannot estimate with zero-length samples",
            ));
        }
        let ratio_without = total_without_dict as f64 / total_original as f64;
        let ratio_with = total_with_dict as f64 / total_original as f64;
        Ok((ratio_without, ratio_with))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Encrypt/decrypt round-trip preserves the usable prefix of a page.
    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let key = [0x42u8; 32];
        let crypto = Aes256GcmPageCrypto::new(&key, false);
        let page_size = 4096;
        let usable = page_size - Aes256GcmPageCrypto::OVERHEAD;
        let mut original = vec![0u8; page_size];
        for i in 0..usable.min(256) {
            original[i] = (i % 256) as u8;
        }
        let encrypted = crypto.encrypt(4096, &original, page_size).unwrap();
        assert_eq!(encrypted.len(), page_size);
        assert_ne!(&encrypted[..usable], &original[..usable]);
        let decrypted = crypto.decrypt(4096, &encrypted, page_size).unwrap();
        assert_eq!(&decrypted[..usable], &original[..usable]);
    }

    // Pages below the skip offset pass through unencrypted.
    #[test]
    fn test_skip_header_page() {
        let key = [0x42u8; 32];
        let crypto = Aes256GcmPageCrypto::new(&key, true).with_skip_below_offset(4096);
        let page_size = 4096;
        let original = vec![0x42u8; page_size];
        let header_result = crypto.encrypt(0, &original, page_size).unwrap();
        assert_eq!(header_result, original);
        let data_result = crypto.encrypt(4096, &original, page_size).unwrap();
        assert_ne!(data_result, original);
    }

    // Patterned (compressible) data survives the round trip.
    #[test]
    fn test_random_data_roundtrip() {
        let key = [0x42u8; 32];
        let crypto = Aes256GcmPageCrypto::new(&key, false);
        let page_size = 4096;
        let usable = page_size - Aes256GcmPageCrypto::OVERHEAD;
        let mut original = vec![0u8; page_size];
        for i in 0..page_size {
            original[i] = ((i * 17 + 31) % 256) as u8;
        }
        let encrypted = crypto.encrypt(4096, &original, page_size).unwrap();
        assert_eq!(encrypted.len(), page_size);
        let decrypted = crypto.decrypt(4096, &encrypted, page_size).unwrap();
        assert_eq!(&decrypted[..usable], &original[..usable]);
    }

    // OVERHEAD is nonce (12) + tag (16) = 28 bytes.
    #[test]
    fn test_overhead_constant() {
        assert_eq!(Aes256GcmPageCrypto::OVERHEAD, 28);
        let key = [0x42u8; 32];
        let crypto = Aes256GcmPageCrypto::new(&key, false);
        assert_eq!(crypto.overhead(), 28);
    }

    // Compress/decompress round-trip for compressible data.
    #[test]
    fn test_compress_decompress_roundtrip() {
        let compression = ZstdPageCompression::new(false);
        let page_size = 4096;
        let original: Vec<u8> = (0..page_size).map(|i| (i % 64) as u8).collect();
        let compressed = compression.compress(4096, &original, page_size).unwrap();
        assert_eq!(compressed.len(), page_size);
        assert_eq!(&compressed[0..2], &ZstdPageCompression::MAGIC_COMPRESSED);
        let decompressed = compression.decompress(4096, &compressed, page_size).unwrap();
        assert_eq!(decompressed, original);
    }

    // Pages below the skip offset pass through uncompressed.
    #[test]
    fn test_compression_skip_header() {
        let compression = ZstdPageCompression::new(true).with_skip_below_offset(4096);
        let page_size = 4096;
        let original: Vec<u8> = (0..page_size).map(|i| (i % 64) as u8).collect();
        let header_result = compression.compress(0, &original, page_size).unwrap();
        assert_eq!(header_result, original);
        let data_result = compression.compress(4096, &original, page_size).unwrap();
        assert_ne!(data_result, original);
        assert_eq!(&data_result[0..2], &ZstdPageCompression::MAGIC_COMPRESSED);
    }

    // Varied-but-patterned data round-trips exactly.
    #[test]
    fn test_varied_data_roundtrip() {
        let compression = ZstdPageCompression::new(false);
        let page_size = 4096;
        let original: Vec<u8> = (0..page_size).map(|i| ((i * 17 + 31) % 256) as u8).collect();
        let compressed = compression.compress(4096, &original, page_size).unwrap();
        assert_eq!(compressed.len(), page_size);
        let decompressed = compression.decompress(4096, &compressed, page_size).unwrap();
        assert_eq!(decompressed, original);
    }

    // A page never written by compress() decompresses to itself (migration).
    #[test]
    fn test_raw_page_migration() {
        let compression = ZstdPageCompression::new(false);
        let page_size = 4096;
        let original = vec![0x42u8; page_size];
        let decompressed = compression.decompress(4096, &original, page_size).unwrap();
        assert_eq!(decompressed, original);
    }

    // Dictionary-based compression round-trips and uses the "ZD" magic.
    #[test]
    fn test_dict_compression_roundtrip() {
        let page_size = 4096;
        let samples: Vec<Vec<u8>> = (0..100)
            .map(|i| {
                (0..page_size)
                    .map(|j| ((i + j) % 64) as u8)
                    .collect()
            })
            .collect();
        let dict = DictionaryTrainer::train(&samples, 8192).unwrap();
        assert!(!dict.is_empty());
        let compression = ZstdDictPageCompression::new(&dict, false);
        let original: Vec<u8> = (0..page_size).map(|i| (i % 64) as u8).collect();
        let compressed = compression.compress(4096, &original, page_size).unwrap();
        assert_eq!(compressed.len(), page_size);
        assert_eq!(&compressed[0..2], &ZstdDictPageCompression::MAGIC_DICT_COMPRESSED);
        let decompressed = compression.decompress(4096, &compressed, page_size).unwrap();
        assert_eq!(decompressed, original);
    }

    // Skip offset applies to the dictionary-based compressor too.
    #[test]
    fn test_dict_compression_skip_header() {
        let page_size = 4096;
        let samples: Vec<Vec<u8>> = (0..50)
            .map(|i| vec![(i % 256) as u8; page_size])
            .collect();
        let dict = DictionaryTrainer::train(&samples, 4096).unwrap();
        let compression = ZstdDictPageCompression::new(&dict, true);
        let original: Vec<u8> = (0..page_size).map(|i| (i % 64) as u8).collect();
        let header_result = compression.compress(0, &original, page_size).unwrap();
        assert_eq!(header_result, original);
        let data_result = compression.compress(4096, &original, page_size).unwrap();
        assert_ne!(data_result, original);
        assert_eq!(&data_result[0..2], &ZstdDictPageCompression::MAGIC_DICT_COMPRESSED);
    }

    // estimate_improvement returns sane ratios and the dictionary does not
    // significantly worsen compression on training-like data.
    #[test]
    fn test_dict_trainer_estimate_improvement() {
        let page_size = 4096;
        let samples: Vec<Vec<u8>> = (0..50)
            .map(|i| {
                (0..page_size)
                    .map(|j| ((i * 3 + j) % 64) as u8)
                    .collect()
            })
            .collect();
        let dict = DictionaryTrainer::train(&samples, 8192).unwrap();
        let (ratio_without, ratio_with) = DictionaryTrainer::estimate_improvement(&samples, &dict, 3).unwrap();
        assert!(ratio_without > 0.0 && ratio_without < 1.0);
        assert!(ratio_with > 0.0 && ratio_with < 1.0);
        assert!(ratio_with <= ratio_without * 1.1, "Dictionary should not significantly worsen compression: {} vs {}",
            ratio_with, ratio_without);
    }

    // Dict decompressor reads plain "ZS" pages written without a dictionary.
    // (Fixes a mojibake in the original source: `®ular_compressed`.)
    #[test]
    fn test_dict_backwards_compatible_with_regular() {
        let page_size = 4096;
        let samples: Vec<Vec<u8>> = (0..50)
            .map(|i| vec![(i % 256) as u8; page_size])
            .collect();
        let dict = DictionaryTrainer::train(&samples, 4096).unwrap();
        let dict_compression = ZstdDictPageCompression::new(&dict, false);
        let regular_compression = ZstdPageCompression::new(false);
        let original: Vec<u8> = (0..page_size).map(|i| (i % 64) as u8).collect();
        let regular_compressed = regular_compression.compress(4096, &original, page_size).unwrap();
        let decompressed = dict_compression.decompress(4096, &regular_compressed, page_size).unwrap();
        assert_eq!(decompressed, original);
    }

    // Fresh random nonces: same plaintext yields different ciphertexts.
    #[test]
    fn test_random_nonce_produces_different_ciphertext() {
        let key = [0x42u8; 32];
        let crypto = Aes256GcmPageCrypto::new(&key, false);
        let page_size = 4096;
        let original = vec![0x42u8; page_size];
        let encrypted1 = crypto.encrypt(4096, &original, page_size).unwrap();
        let encrypted2 = crypto.encrypt(4096, &original, page_size).unwrap();
        assert_ne!(&encrypted1[..12], &encrypted2[..12],
            "Random nonces should be different for each encryption");
        assert_ne!(encrypted1, encrypted2,
            "Same plaintext encrypted twice should produce different ciphertexts");
        let decrypted1 = crypto.decrypt(4096, &encrypted1, page_size).unwrap();
        let decrypted2 = crypto.decrypt(4096, &encrypted2, page_size).unwrap();
        let usable = page_size - Aes256GcmPageCrypto::OVERHEAD;
        assert_eq!(&decrypted1[..usable], &original[..usable]);
        assert_eq!(&decrypted2[..usable], &original[..usable]);
    }

    // Rewriting the same page twice must not reuse a nonce.
    #[test]
    fn test_page_update_security() {
        let key = [0x42u8; 32];
        let crypto = Aes256GcmPageCrypto::new(&key, false);
        let page_size = 4096;
        let offset = 4096u64;
        let data_v1 = vec![0x11u8; page_size];
        let encrypted_v1 = crypto.encrypt(offset, &data_v1, page_size).unwrap();
        let data_v2 = vec![0x22u8; page_size];
        let encrypted_v2 = crypto.encrypt(offset, &data_v2, page_size).unwrap();
        assert_ne!(&encrypted_v1[..12], &encrypted_v2[..12],
            "Page updates must use different nonces");
        let decrypted_v1 = crypto.decrypt(offset, &encrypted_v1, page_size).unwrap();
        let decrypted_v2 = crypto.decrypt(offset, &encrypted_v2, page_size).unwrap();
        let usable = page_size - Aes256GcmPageCrypto::OVERHEAD;
        assert_eq!(&decrypted_v1[..usable], &data_v1[..usable]);
        assert_eq!(&decrypted_v2[..usable], &data_v2[..usable]);
    }

    // Zero padding in the compressed layout must not corrupt data whose
    // compressed form can itself end in zeros.
    #[test]
    fn test_compressed_data_ending_in_zeros() {
        let compression = ZstdPageCompression::new(false);
        let page_size = 4096;
        let mut original = vec![0u8; page_size];
        for i in 0..page_size {
            original[i] = (i % 4) as u8;
        }
        let compressed = compression.compress(4096, &original, page_size).unwrap();
        assert_eq!(compressed.len(), page_size);
        let decompressed = compression.decompress(4096, &compressed, page_size).unwrap();
        assert_eq!(decompressed, original,
            "Decompression must correctly handle compressed data that may end in zeros");
    }

    // Raw fallback must preserve every byte, including a non-zero tail.
    #[test]
    fn test_incompressible_data_with_nonzero_trailing_bytes() {
        let compression = ZstdPageCompression::new(false);
        let page_size = 4096;
        let mut original: Vec<u8> = (0..page_size)
            .map(|i| ((i * 17 + 31) ^ (i * 13 + 7)) as u8)
            .collect();
        original[page_size - 1] = 0xDE;
        original[page_size - 2] = 0xAD;
        original[page_size - 3] = 0xBE;
        original[page_size - 4] = 0xEF;
        let compressed = compression.compress(4096, &original, page_size).unwrap();
        assert_eq!(compressed.len(), page_size);
        let decompressed = compression.decompress(4096, &compressed, page_size).unwrap();
        assert_eq!(decompressed, original,
            "Incompressible data must preserve all bytes including trailing non-zeros");
    }

    // Raw pages that coincidentally begin with a magic marker must still
    // round-trip (header sanity checks reject the bogus lengths).
    #[test]
    fn test_data_starting_with_magic_bytes() {
        let compression = ZstdPageCompression::new(false);
        let page_size = 4096;
        let mut original = vec![0x42u8; page_size];
        original[0] = b'Z';
        original[1] = b'S';
        original[2] = 0xFF;
        original[3] = 0xFF;
        original[4] = 0xFF;
        original[5] = 0xFF;
        let compressed = compression.compress(4096, &original, page_size).unwrap();
        let decompressed = compression.decompress(4096, &compressed, page_size).unwrap();
        assert_eq!(decompressed, original,
            "Raw data starting with ZS magic must roundtrip correctly");
        original[0] = b'U';
        original[1] = b'C';
        original[2] = 0x00;
        original[3] = 0x00;
        original[4] = 0x00;
        original[5] = 0x00;
        let compressed = compression.compress(4096, &original, page_size).unwrap();
        let decompressed = compression.decompress(4096, &compressed, page_size).unwrap();
        assert_eq!(decompressed, original,
            "Raw data starting with UC magic must roundtrip correctly");
        original[0] = b'Z';
        original[1] = b'D';
        original[2] = 0xFF;
        original[3] = 0xFF;
        original[4] = 0xFF;
        original[5] = 0xFF;
        let compressed = compression.compress(4096, &original, page_size).unwrap();
        let decompressed = compression.decompress(4096, &compressed, page_size).unwrap();
        assert_eq!(decompressed, original,
            "Raw data starting with ZD magic must roundtrip correctly");
    }

    // Fewer than MIN_REQUIRED_SAMPLES must be rejected with a clear message.
    #[test]
    fn test_dict_training_too_few_samples_error() {
        let samples: Vec<Vec<u8>> = (0..5)
            .map(|i| vec![(i % 256) as u8; 1000])
            .collect();
        let result = DictionaryTrainer::train(&samples, 4096);
        assert!(result.is_err(), "Training with < 10 samples should fail");
        let err = result.unwrap_err();
        assert!(err.to_string().contains("At least 10 samples required"),
            "Error message should mention minimum samples requirement");
    }

    // Exactly MIN_REQUIRED_SAMPLES is accepted.
    #[test]
    fn test_dict_training_minimum_samples_works() {
        let samples: Vec<Vec<u8>> = (0..10)
            .map(|i| {
                (0..1000).map(|j| ((i * 3 + j) % 256) as u8).collect()
            })
            .collect();
        let result = DictionaryTrainer::train(&samples, 4096);
        assert!(result.is_ok(), "Training with exactly 10 samples should work");
    }

    // A cloned crypto instance interoperates with the original.
    #[test]
    fn test_crypto_clone() {
        let key = [0x42u8; 32];
        let crypto = Aes256GcmPageCrypto::new(&key, true);
        let crypto_clone = crypto.clone();
        let page_size = 4096;
        let original = vec![0x42u8; page_size];
        let encrypted1 = crypto.encrypt(4096, &original, page_size).unwrap();
        let encrypted2 = crypto_clone.encrypt(4096, &original, page_size).unwrap();
        let usable = page_size - Aes256GcmPageCrypto::OVERHEAD;
        let decrypted1 = crypto_clone.decrypt(4096, &encrypted1, page_size).unwrap();
        let decrypted2 = crypto.decrypt(4096, &encrypted2, page_size).unwrap();
        assert_eq!(&decrypted1[..usable], &original[..usable]);
        assert_eq!(&decrypted2[..usable], &original[..usable]);
    }

    // A cloned compressor interoperates with the original.
    #[test]
    fn test_compression_clone() {
        let compression = ZstdPageCompression::new(true).with_level(5);
        let compression_clone = compression.clone();
        let page_size = 4096;
        let original: Vec<u8> = (0..page_size).map(|i| (i % 64) as u8).collect();
        let compressed1 = compression.compress(4096, &original, page_size).unwrap();
        let compressed2 = compression_clone.compress(4096, &original, page_size).unwrap();
        let decompressed1 = compression_clone.decompress(4096, &compressed1, page_size).unwrap();
        let decompressed2 = compression.decompress(4096, &compressed2, page_size).unwrap();
        assert_eq!(decompressed1, original);
        assert_eq!(decompressed2, original);
    }

    // with_page_size sets the skip offset on both implementations.
    #[test]
    fn test_with_page_size_constructors() {
        let key = [0x42u8; 32];
        let crypto = Aes256GcmPageCrypto::with_page_size(&key, 8192);
        assert_eq!(crypto.encryption_start_offset(), 8192);
        let compression = ZstdPageCompression::with_page_size(8192);
        assert_eq!(compression.compression_start_offset(), 8192);
    }
}