use std::io::{self, Read, Write};
/// Four-byte magic prefix identifying a serialized index.
const MAGIC: &[u8; 4] = b"SQRY";
/// Current on-disk format version; bump when the layout changes.
const FORMAT_VERSION: u32 = 1;
/// Default zstd compression level used by callers that don't pick one.
pub const DEFAULT_COMPRESSION_LEVEL: i32 = 3;
/// Default ceiling for decompressed index size (500 MiB).
pub const DEFAULT_MAX_UNCOMPRESSED_SIZE: u64 = 500 * 1024 * 1024;
/// Lower clamp bound for the configurable size limit (1 MiB).
const MIN_MAX_UNCOMPRESSED_SIZE: u64 = 1024 * 1024;
/// Upper clamp bound for the configurable size limit (2 GiB).
const MAX_MAX_UNCOMPRESSED_SIZE: u64 = 2 * 1024 * 1024 * 1024;
/// Returns the maximum allowed uncompressed index size in bytes.
///
/// Reads the `SQRY_MAX_INDEX_SIZE` environment variable; an unset or
/// unparsable value falls back to [`DEFAULT_MAX_UNCOMPRESSED_SIZE`].
/// The result is always clamped to the `[1 MiB, 2 GiB]` range so a
/// misconfigured environment cannot disable bomb protection entirely.
#[must_use]
pub fn max_uncompressed_size() -> u64 {
    let configured = match std::env::var("SQRY_MAX_INDEX_SIZE") {
        Ok(raw) => raw.parse::<u64>().unwrap_or(DEFAULT_MAX_UNCOMPRESSED_SIZE),
        Err(_) => DEFAULT_MAX_UNCOMPRESSED_SIZE,
    };
    configured.clamp(MIN_MAX_UNCOMPRESSED_SIZE, MAX_MAX_UNCOMPRESSED_SIZE)
}
/// Compression codec used for the index payload.
///
/// The discriminant value is the byte stored in the serialized header,
/// so existing values must never be renumbered.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum CompressionFormat {
    /// Payload is stored verbatim.
    None = 0,
    /// Payload is a zstd-compressed stream.
    Zstd = 1,
}
impl CompressionFormat {
    /// Decodes the header's format byte back into a variant.
    ///
    /// Returns [`CompressionError::UnsupportedCompression`] for any byte
    /// that doesn't correspond to a known codec.
    fn from_u8(value: u8) -> Result<Self, CompressionError> {
        if value == Self::None as u8 {
            Ok(Self::None)
        } else if value == Self::Zstd as u8 {
            Ok(Self::Zstd)
        } else {
            Err(CompressionError::UnsupportedCompression(value))
        }
    }
}
/// Errors produced while compressing, decompressing, or (de)serializing
/// a [`CompressedIndex`].
#[derive(Debug, thiserror::Error)]
pub enum CompressionError {
    /// Underlying I/O failure from the zstd encoder/decoder.
    #[error("I/O error: {0}")]
    Io(#[from] io::Error),
    /// Header carried a format byte we don't recognize.
    #[error("Unsupported compression format: {0}")]
    UnsupportedCompression(u8),
    /// First four bytes were not the `SQRY` magic.
    #[error("Invalid magic bytes, expected SQRY")]
    InvalidMagic,
    /// Index was written by a newer sqry than this binary.
    #[error("Index version {index_version} is too new for sqry {sqry_version}, please upgrade")]
    IndexVersionTooNew {
        index_version: u32,
        sqry_version: &'static str,
    },
    /// Version field is not a value any sqry release ever wrote (e.g. 0).
    #[error("Invalid index version: {0}")]
    InvalidIndexVersion(u32),
    /// Input shorter than the fixed 21-byte header.
    #[error("Invalid header size: expected at least 21 bytes, got {0}")]
    InvalidHeaderSize(usize),
    /// Decoded payload size disagrees with the header's size field.
    #[error("Decompressed size mismatch: expected {expected}, got {actual}")]
    SizeMismatch {
        expected: u64,
        actual: u64,
    },
    /// Claimed or actual output size exceeds [`max_uncompressed_size`].
    #[error("Decompression bomb detected: uncompressed size {size} exceeds maximum {max}")]
    DecompressionBomb {
        size: u64,
        max: u64,
    },
}
/// An index payload together with the header metadata needed to
/// serialize it to disk and safely decompress it later.
#[derive(Debug, Clone)]
pub struct CompressedIndex {
    // On-disk format version (see FORMAT_VERSION).
    version: u32,
    // Codec used for `data`.
    compression: CompressionFormat,
    // zstd level used at compression time (0 when uncompressed).
    level: i32,
    // Size of the original payload in bytes, as recorded in the header.
    uncompressed_size: u64,
    // The (possibly compressed) payload bytes.
    data: Vec<u8>,
}
impl CompressedIndex {
    /// Compresses `data` with zstd at the given `level` and wraps it
    /// with format metadata (version, codec, original size).
    ///
    /// # Errors
    /// Returns [`CompressionError::Io`] if the zstd encoder fails.
    pub fn compress(data: &[u8], level: i32) -> Result<Self, CompressionError> {
        let mut encoder = zstd::Encoder::new(Vec::new(), level)?;
        encoder.write_all(data)?;
        let compressed = encoder.finish()?;
        Ok(Self {
            version: FORMAT_VERSION,
            compression: CompressionFormat::Zstd,
            level,
            uncompressed_size: data.len() as u64,
            data: compressed,
        })
    }
    /// Wraps `data` verbatim ([`CompressionFormat::None`], level 0).
    #[must_use]
    pub fn uncompressed(data: &[u8]) -> Self {
        Self {
            version: FORMAT_VERSION,
            compression: CompressionFormat::None,
            level: 0,
            uncompressed_size: data.len() as u64,
            data: data.to_vec(),
        }
    }
    /// Recovers the original payload bytes, enforcing the size limit
    /// from [`max_uncompressed_size`] as decompression-bomb protection.
    ///
    /// # Errors
    /// - [`CompressionError::DecompressionBomb`] if the claimed or actual
    ///   output size exceeds the limit.
    /// - [`CompressionError::SizeMismatch`] if the decoded size disagrees
    ///   with the header's `uncompressed_size` field.
    /// - [`CompressionError::Io`] if zstd decoding fails.
    pub fn decompress(&self) -> Result<Vec<u8>, CompressionError> {
        let max_size = max_uncompressed_size();
        // Reject hostile headers before doing any decoding work.
        if self.uncompressed_size > max_size {
            return Err(CompressionError::DecompressionBomb {
                size: self.uncompressed_size,
                max: max_size,
            });
        }
        match self.compression {
            CompressionFormat::None => {
                let stored = self.data.len() as u64;
                if stored > max_size {
                    return Err(CompressionError::DecompressionBomb {
                        size: stored,
                        max: max_size,
                    });
                }
                // Consistency check: the stored payload must match the
                // size recorded in the header (mirrors the zstd path),
                // so a corrupted header is caught instead of ignored.
                if stored != self.uncompressed_size {
                    return Err(CompressionError::SizeMismatch {
                        expected: self.uncompressed_size,
                        actual: stored,
                    });
                }
                Ok(self.data.clone())
            }
            CompressionFormat::Zstd => {
                let decoder = zstd::Decoder::new(&self.data[..])?;
                // Cap the stream at max_size + 1 bytes so a lying header
                // cannot make us buffer unbounded output; one extra byte
                // lets us distinguish "exactly at limit" from "over it".
                let mut limited = decoder.take(max_size + 1);
                let mut decompressed = Vec::new();
                limited.read_to_end(&mut decompressed)?;
                let actual_size = decompressed.len() as u64;
                // Bomb check must come first: a bomb whose header claims
                // a small size gets truncated by `take` above, and
                // checking the mismatch first would mislabel it as a
                // mere SizeMismatch instead of a DecompressionBomb.
                if actual_size > max_size {
                    return Err(CompressionError::DecompressionBomb {
                        size: actual_size,
                        max: max_size,
                    });
                }
                if actual_size != self.uncompressed_size {
                    return Err(CompressionError::SizeMismatch {
                        expected: self.uncompressed_size,
                        actual: actual_size,
                    });
                }
                Ok(decompressed)
            }
        }
    }
    /// Serializes header + payload into the on-disk layout:
    /// magic (4) | version (4 LE) | format (1) | level (4 LE) |
    /// uncompressed_size (8 LE) | payload — a fixed 21-byte header.
    #[must_use]
    pub fn serialize(&self) -> Vec<u8> {
        let mut buffer = Vec::with_capacity(21 + self.data.len());
        buffer.extend_from_slice(MAGIC);
        buffer.extend_from_slice(&self.version.to_le_bytes());
        buffer.push(self.compression as u8);
        buffer.extend_from_slice(&self.level.to_le_bytes());
        buffer.extend_from_slice(&self.uncompressed_size.to_le_bytes());
        buffer.extend_from_slice(&self.data);
        buffer
    }
    /// Parses the layout written by [`serialize`](Self::serialize),
    /// validating magic, version, and compression format.
    ///
    /// # Errors
    /// - [`CompressionError::InvalidHeaderSize`] if shorter than 21 bytes.
    /// - [`CompressionError::InvalidMagic`] on a wrong magic prefix.
    /// - [`CompressionError::InvalidIndexVersion`] /
    ///   [`CompressionError::IndexVersionTooNew`] on bad version fields.
    /// - [`CompressionError::UnsupportedCompression`] on an unknown codec.
    pub fn deserialize(data: &[u8]) -> Result<Self, CompressionError> {
        if data.len() < 21 {
            return Err(CompressionError::InvalidHeaderSize(data.len()));
        }
        if &data[0..4] != MAGIC {
            return Err(CompressionError::InvalidMagic);
        }
        // The try_into calls below cannot fail after the length check;
        // map_err keeps the code panic-free regardless.
        let version = u32::from_le_bytes(
            data[4..8]
                .try_into()
                .map_err(|_| CompressionError::InvalidHeaderSize(data.len()))?,
        );
        match version {
            // Version 0 was never written by any release.
            0 => return Err(CompressionError::InvalidIndexVersion(0)),
            FORMAT_VERSION => {
            }
            // Newer than us: tell the user to upgrade rather than
            // misreading an unknown layout.
            v if v > FORMAT_VERSION => {
                return Err(CompressionError::IndexVersionTooNew {
                    index_version: v,
                    sqry_version: env!("CARGO_PKG_VERSION"),
                });
            }
            // Unreachable while FORMAT_VERSION == 1; kept so the match
            // stays exhaustive when new versions are added.
            _ => {
                return Err(CompressionError::InvalidIndexVersion(version));
            }
        }
        let compression = CompressionFormat::from_u8(data[8])?;
        let level = i32::from_le_bytes(
            data[9..13]
                .try_into()
                .map_err(|_| CompressionError::InvalidHeaderSize(data.len()))?,
        );
        let uncompressed_size = u64::from_le_bytes(
            data[13..21]
                .try_into()
                .map_err(|_| CompressionError::InvalidHeaderSize(data.len()))?,
        );
        let index_data = data[21..].to_vec();
        Ok(Self {
            version,
            compression,
            level,
            uncompressed_size,
            data: index_data,
        })
    }
    /// Codec used for the stored payload.
    #[must_use]
    pub fn compression(&self) -> CompressionFormat {
        self.compression
    }
    /// Original payload size in bytes, as recorded in the header.
    #[must_use]
    pub fn uncompressed_size(&self) -> u64 {
        self.uncompressed_size
    }
    /// Size of the stored (possibly compressed) payload in bytes.
    #[must_use]
    pub fn compressed_size(&self) -> usize {
        self.data.len()
    }
    /// Ratio of uncompressed to compressed size (> 1.0 means the data
    /// shrank). Returns 1.0 for an empty payload to avoid dividing by
    /// zero.
    #[must_use]
    pub fn compression_ratio(&self) -> f64 {
        if self.data.is_empty() {
            return 1.0;
        }
        Self::to_f64_lossy_u64(self.uncompressed_size) / Self::to_f64_lossy_usize(self.data.len())
    }
    // Narrow helpers that isolate the (intentional) lossy integer→float
    // casts so the clippy allow doesn't blanket a whole function.
    #[inline]
    #[allow(clippy::cast_precision_loss)] fn to_f64_lossy_u64(value: u64) -> f64 {
        value as f64
    }
    #[inline]
    #[allow(clippy::cast_precision_loss)] fn to_f64_lossy_usize(value: usize) -> f64 {
        value as f64
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Round-trip through compress/decompress preserves the payload.
    #[test]
    fn test_compress_decompress_roundtrip() {
        let original = b"test data for compression";
        let compressed = CompressedIndex::compress(original, DEFAULT_COMPRESSION_LEVEL).unwrap();
        let decompressed = compressed.decompress().unwrap();
        assert_eq!(original, &decompressed[..]);
    }
    // Full round-trip including the on-disk header format.
    #[test]
    fn test_serialize_deserialize_roundtrip() {
        let original = b"test data for serialization";
        let compressed = CompressedIndex::compress(original, 3).unwrap();
        let serialized = compressed.serialize();
        let deserialized = CompressedIndex::deserialize(&serialized).unwrap();
        let decompressed = deserialized.decompress().unwrap();
        assert_eq!(original, &decompressed[..]);
    }
    // Highly repetitive input must actually shrink under zstd.
    #[test]
    fn test_compression_reduces_size() {
        let original = vec![b'a'; 10000];
        let compressed = CompressedIndex::compress(&original, 3).unwrap();
        assert!(
            compressed.compressed_size() < original.len(),
            "Compressed size {} should be less than original size {}",
            compressed.compressed_size(),
            original.len()
        );
    }
    #[test]
    fn test_compression_ratio() {
        let original = vec![b'x'; 1000];
        let compressed = CompressedIndex::compress(&original, 3).unwrap();
        let ratio = compressed.compression_ratio();
        assert!(
            ratio > 1.0,
            "Compression ratio should be > 1.0 for compressible data"
        );
    }
    // The CompressionFormat::None path must round-trip unchanged.
    #[test]
    fn test_uncompressed_roundtrip() {
        let original = b"uncompressed test data";
        let uncompressed = CompressedIndex::uncompressed(original);
        let decompressed = uncompressed.decompress().unwrap();
        assert_eq!(original, &decompressed[..]);
        assert_eq!(uncompressed.compression(), CompressionFormat::None);
    }
    #[test]
    fn test_magic_bytes_in_header() {
        let original = b"test";
        let compressed = CompressedIndex::compress(original, 3).unwrap();
        let serialized = compressed.serialize();
        assert_eq!(&serialized[0..4], b"SQRY");
    }
    #[test]
    fn test_invalid_magic_bytes() {
        let mut invalid_data = vec![0u8; 21];
        invalid_data[0..4].copy_from_slice(b"XXXX"); let result = CompressedIndex::deserialize(&invalid_data);
        assert!(matches!(result, Err(CompressionError::InvalidMagic)));
    }
    // Inputs shorter than the 21-byte header are rejected up front.
    #[test]
    fn test_header_too_small() {
        let too_small = b"SQRY123"; let result = CompressedIndex::deserialize(too_small);
        assert!(matches!(
            result,
            Err(CompressionError::InvalidHeaderSize(7))
        ));
    }
    // Byte 8 is the format byte; 99 maps to no known codec.
    #[test]
    fn test_unsupported_compression_format() {
        let mut data = vec![0u8; 21];
        data[0..4].copy_from_slice(b"SQRY");
        data[4..8].copy_from_slice(&1u32.to_le_bytes()); data[8] = 99;
        let result = CompressedIndex::deserialize(&data);
        assert!(matches!(
            result,
            Err(CompressionError::UnsupportedCompression(99))
        ));
    }
    #[test]
    fn test_future_version_error() {
        let mut data = vec![0u8; 21];
        data[0..4].copy_from_slice(b"SQRY");
        data[4..8].copy_from_slice(&999u32.to_le_bytes());
        let result = CompressedIndex::deserialize(&data);
        assert!(matches!(
            result,
            Err(CompressionError::IndexVersionTooNew { .. })
        ));
    }
    #[test]
    fn test_zero_version_error() {
        let mut data = vec![0u8; 21];
        data[0..4].copy_from_slice(b"SQRY");
        data[4..8].copy_from_slice(&0u32.to_le_bytes());
        let result = CompressedIndex::deserialize(&data);
        assert!(matches!(
            result,
            Err(CompressionError::InvalidIndexVersion(0))
        ));
    }
    #[test]
    fn test_compression_metadata() {
        let original = vec![b'y'; 5000];
        let compressed = CompressedIndex::compress(&original, 5).unwrap();
        assert_eq!(compressed.uncompressed_size(), 5000);
        assert_eq!(compressed.compression(), CompressionFormat::Zstd);
        assert!(compressed.compressed_size() < 5000);
    }
    #[test]
    fn test_empty_data_compression() {
        let original = b"";
        let compressed = CompressedIndex::compress(original, 3).unwrap();
        let decompressed = compressed.decompress().unwrap();
        assert_eq!(original, &decompressed[..]);
        assert_eq!(compressed.uncompressed_size(), 0);
    }
    #[test]
    fn test_large_data_compression() {
        let original = vec![b'z'; 1_000_000];
        let compressed = CompressedIndex::compress(&original, 3).unwrap();
        let decompressed = compressed.decompress().unwrap();
        assert_eq!(original, decompressed);
        assert!(
            compressed.compressed_size() < 100_000,
            "Expected < 100KB compressed, got {}",
            compressed.compressed_size()
        );
    }
    // Bytes 13..21 of the header hold the little-endian uncompressed
    // size; patching them simulates a hostile/corrupted header.
    #[test]
    fn test_decompression_bomb_protection_blocks_oversized() {
        let original = vec![b'a'; 1_000_000]; let compressed = CompressedIndex::compress(&original, 3).unwrap();
        let mut serialized = compressed.serialize();
        let fake_size = 600u64 * 1024 * 1024; serialized[13..21].copy_from_slice(&fake_size.to_le_bytes());
        let corrupted = CompressedIndex::deserialize(&serialized).unwrap();
        let result = corrupted.decompress();
        assert!(
            matches!(result, Err(CompressionError::DecompressionBomb { .. })),
            "Should reject oversized decompression claim"
        );
    }
    // A claimed size exactly at the 500MB default limit must not be
    // treated as a bomb (it may fail differently, e.g. SizeMismatch).
    #[test]
    fn test_decompression_bomb_protection_allows_at_limit() {
        let original = vec![b'b'; 100_000]; let mut compressed = CompressedIndex::compress(&original, 3).unwrap();
        let exact_limit = 500u64 * 1024 * 1024;
        compressed.uncompressed_size = exact_limit;
        let serialized = compressed.serialize();
        let deserialized = CompressedIndex::deserialize(&serialized).unwrap();
        let result = deserialized.decompress();
        assert!(
            !matches!(result, Err(CompressionError::DecompressionBomb { .. })),
            "Should not reject data exactly at limit as decompression bomb"
        );
    }
    #[test]
    fn test_decompression_bomb_protection_blocks_one_over_limit() {
        let original = vec![b'c'; 100_000]; let compressed = CompressedIndex::compress(&original, 3).unwrap();
        let mut serialized = compressed.serialize();
        let over_limit = (500u64 * 1024 * 1024) + 1; serialized[13..21].copy_from_slice(&over_limit.to_le_bytes());
        let corrupted = CompressedIndex::deserialize(&serialized).unwrap();
        let result = corrupted.decompress();
        assert!(
            matches!(result, Err(CompressionError::DecompressionBomb { .. })),
            "Should reject data exceeding limit by even 1 byte"
        );
    }
    #[test]
    fn test_decompression_enforces_streaming_limit() {
        let original = vec![b'd'; 200_000]; let compressed = CompressedIndex::compress(&original, 3).unwrap();
        let result = compressed.decompress();
        assert!(result.is_ok(), "Decompression within limit should succeed");
    }
    // Local shadows of the module constants pin their expected values.
    #[test]
    fn test_max_uncompressed_size_clamping_enforces_minimum() {
        const MIN_MAX_UNCOMPRESSED_SIZE: u64 = 1024 * 1024; const MAX_MAX_UNCOMPRESSED_SIZE: u64 = 2 * 1024 * 1024 * 1024;
        assert_eq!(MIN_MAX_UNCOMPRESSED_SIZE, 1_048_576, "MIN should be 1MB");
        assert_eq!(
            MAX_MAX_UNCOMPRESSED_SIZE, 2_147_483_648,
            "MAX should be 2GB"
        );
        let default_size = max_uncompressed_size();
        assert!(
            default_size >= MIN_MAX_UNCOMPRESSED_SIZE,
            "Default {default_size} should be >= MIN {MIN_MAX_UNCOMPRESSED_SIZE}"
        );
        assert!(
            default_size <= MAX_MAX_UNCOMPRESSED_SIZE,
            "Default {default_size} should be <= MAX {MAX_MAX_UNCOMPRESSED_SIZE}"
        );
    }
    #[test]
    fn test_max_uncompressed_size_default_is_500mb() {
        let default = max_uncompressed_size();
        assert!(
            default >= 500 * 1024 * 1024 || std::env::var("SQRY_MAX_INDEX_SIZE").is_ok(),
            "Default should be 500MB or env var should be set"
        );
    }
    #[test]
    fn test_decompression_bomb_error_includes_sizes() {
        let original = vec![b'e'; 100_000];
        let compressed = CompressedIndex::compress(&original, 3).unwrap();
        let mut serialized = compressed.serialize();
        let oversized = 600u64 * 1024 * 1024; serialized[13..21].copy_from_slice(&oversized.to_le_bytes());
        let corrupted = CompressedIndex::deserialize(&serialized).unwrap();
        match corrupted.decompress() {
            Err(CompressionError::DecompressionBomb { size, max }) => {
                assert_eq!(size, oversized, "Error should report actual claimed size");
                assert!(max > 0, "Error should report max limit");
                assert!(size > max, "Error should show size exceeds max");
            }
            other => panic!("Expected DecompressionBomb error, got {other:?}"),
        }
    }
    #[test]
    fn test_compression_format_from_u8() {
        assert!(matches!(
            CompressionFormat::from_u8(0),
            Ok(CompressionFormat::None)
        ));
        assert!(matches!(
            CompressionFormat::from_u8(1),
            Ok(CompressionFormat::Zstd)
        ));
        assert!(matches!(
            CompressionFormat::from_u8(99),
            Err(CompressionError::UnsupportedCompression(99))
        ));
    }
}