use super::CompressionProvider;
use std::io::Read;
/// Little-endian bytes of the zstd dictionary magic number `0xEC30A437`.
///
/// Dictionary bytes that start with this prefix are loaded as formatted zstd dictionaries;
/// anything else is handled as a raw-content dictionary.
const DICT_MAGIC: [u8; 4] = 0xEC30_A437_u32.to_le_bytes();
fn bounded_read(reader: &mut impl Read, capacity: usize) -> crate::Result<Vec<u8>> {
let mut output = vec![0u8; capacity];
let mut filled = 0;
loop {
let dest = output
.get_mut(filled..)
.ok_or(crate::Error::DecompressedSizeTooLarge {
declared: filled as u64,
limit: capacity as u64,
})?;
match reader.read(dest) {
Ok(0) => break,
Ok(n) => filled += n,
Err(e) => return Err(crate::Error::from(e)),
}
}
let mut probe = [0u8; 1];
match reader.read(&mut probe) {
Ok(0) => {}
Ok(_) => {
return Err(crate::Error::DecompressedSizeTooLarge {
declared: (filled + 1) as u64,
limit: capacity as u64,
});
}
Err(e) => return Err(crate::Error::from(e)),
}
output.truncate(filled);
Ok(output)
}
fn bounded_read_into(reader: &mut impl Read, dest: &mut [u8]) -> crate::Result<usize> {
let mut filled = 0;
while filled < dest.len() {
let slot = dest.get_mut(filled..).unwrap_or(&mut []);
match reader.read(slot) {
Ok(0) => break,
Ok(n) => filled += n,
Err(e) => return Err(crate::Error::from(e)),
}
}
if filled == dest.len() {
let mut probe = [0u8; 1];
match reader.read(&mut probe) {
Ok(0) => {}
Ok(_) => {
return Err(crate::Error::DecompressedSizeTooLarge {
declared: (filled as u64).saturating_add(1),
limit: dest.len() as u64,
});
}
Err(e) => return Err(crate::Error::from(e)),
}
}
Ok(filled)
}
fn decode_raw_content_bounded(
decoder: &mut structured_zstd::decoding::FrameDecoder,
cursor: &mut std::io::Cursor<&[u8]>,
capacity: usize,
) -> crate::Result<Vec<u8>> {
use structured_zstd::decoding::BlockDecodingStrategy;
let mut output: Vec<u8> = Vec::new();
loop {
let remaining = capacity.saturating_sub(output.len());
if !decoder.is_finished() {
decoder
.decode_blocks(
&mut *cursor,
BlockDecodingStrategy::UptoBytes(remaining.max(1)),
)
.map_err(|e| crate::Error::Io(crate::io::Error::other(e.to_string())))?;
}
let can = decoder.can_collect();
if can > 0 {
let new_len =
output
.len()
.checked_add(can)
.ok_or(crate::Error::DecompressedSizeTooLarge {
declared: u64::MAX,
limit: capacity as u64,
})?;
if new_len > capacity {
return Err(crate::Error::DecompressedSizeTooLarge {
declared: new_len as u64,
limit: capacity as u64,
});
}
let prev_len = output.len();
output.resize(new_len, 0u8);
let dest = output
.get_mut(prev_len..)
.unwrap_or_else(|| unreachable!("output resized to new_len above"));
decoder.read_exact(dest).map_err(crate::Error::from)?;
}
if decoder.is_finished() && decoder.can_collect() == 0 {
break;
}
}
Ok(output)
}
fn do_decompress_with_dict(
decoder: &mut structured_zstd::decoding::FrameDecoder,
data: &[u8],
raw_content_id: u32,
capacity: usize,
is_raw_content: bool,
) -> crate::Result<Vec<u8>> {
if is_raw_content {
let mut cursor = std::io::Cursor::new(data);
decoder.expect_dict_id(Some(raw_content_id));
decoder.init(&mut cursor).map_err(|e| {
if matches!(
e,
structured_zstd::decoding::errors::FrameDecoderError::UnexpectedDictId { .. }
) {
crate::Error::Decompress(crate::CompressionType::ZstdDict {
level: 0,
dict_id: raw_content_id,
})
} else {
crate::Error::Io(crate::io::Error::other(e.to_string()))
}
})?;
let declared_size = decoder.content_size();
if declared_size > 0 && declared_size > capacity as u64 {
return Err(crate::Error::DecompressedSizeTooLarge {
declared: declared_size,
limit: capacity as u64,
});
}
decoder
.force_dict(raw_content_id)
.map_err(|e| crate::Error::Io(crate::io::Error::other(e.to_string())))?;
decode_raw_content_bounded(decoder, &mut cursor, capacity)
} else {
let mut output = Vec::with_capacity(capacity);
decoder.decode_all_to_vec(data, &mut output).map_err(|e| {
if matches!(
e,
structured_zstd::decoding::errors::FrameDecoderError::TargetTooSmall
) {
crate::Error::DecompressedSizeTooLarge {
declared: capacity as u64 + 1,
limit: capacity as u64,
}
} else {
crate::Error::Io(crate::io::Error::other(e.to_string()))
}
})?;
Ok(output)
}
}
/// Returns the decompressed end offset of every block in the last emitted frame.
///
/// An empty vector means "no layout". That covers four cases:
/// * there is no emit info;
/// * the frame has fewer than two blocks;
/// * some block has no decompressed byte range;
/// * an end offset does not fit in `u32`.
fn inner_block_layout(
    info: Option<&structured_zstd::encoding::frame_emit_info::FrameEmitInfo>,
) -> Vec<u32> {
    match info {
        Some(info) if info.blocks.len() >= 2 => (0..info.blocks.len())
            .map(|block| {
                info.decompressed_byte_range(block)
                    .and_then(|range| u32::try_from(range.end).ok())
            })
            // Stops at the first block without a usable end offset, yielding `None`.
            .collect::<Option<Vec<u32>>>()
            .unwrap_or_default(),
        _ => Vec::new(),
    }
}
/// Runs `f` with this thread's cached `FrameCompressor` for `level`.
///
/// The cached compressor is reused while calls on the same thread keep asking for the
/// same `level`. A different `level` replaces it with a freshly constructed compressor.
fn with_tls_compressor<R>(
    level: i32,
    f: impl FnOnce(&mut structured_zstd::encoding::FrameCompressor) -> R,
) -> R {
    use structured_zstd::encoding::{CompressionLevel, FrameCompressor};
    thread_local! {
        static TLS_COMPRESSOR: std::cell::RefCell<Option<(i32, FrameCompressor)>> =
            const { std::cell::RefCell::new(None) };
    }
    TLS_COMPRESSOR.with(|slot| {
        let mut cached = slot.borrow_mut();
        let reusable = cached
            .as_ref()
            .is_some_and(|(cached_level, _)| *cached_level == level);
        if !reusable {
            let fresh = FrameCompressor::new(CompressionLevel::from_level(level));
            *cached = Some((level, fresh));
        }
        match cached.as_mut() {
            Some((_, compressor)) => f(compressor),
            None => unreachable!("TLS_COMPRESSOR initialised above"),
        }
    })
}
pub struct ZstdProvider;
impl CompressionProvider for ZstdProvider {
fn compress(data: &[u8], level: i32) -> crate::Result<Vec<u8>> {
Ok(with_tls_compressor(level, |compressor| {
compressor.compress_independent_frame(data)
}))
}
fn compress_with_layout(data: &[u8], level: i32) -> crate::Result<(Vec<u8>, Vec<u32>)> {
Ok(with_tls_compressor(level, |compressor| {
let frame = compressor.compress_independent_frame(data);
let layout = inner_block_layout(compressor.last_frame_emit_info());
(frame, layout)
}))
}
fn decompress(data: &[u8], capacity: usize) -> crate::Result<Vec<u8>> {
let mut decoder = structured_zstd::decoding::StreamingDecoder::new(data)
.map_err(|e| crate::Error::Io(crate::io::Error::other(e.to_string())))?;
bounded_read(&mut decoder, capacity)
}
fn compress_with_dict(data: &[u8], level: i32, dict_raw: &[u8]) -> crate::Result<Vec<u8>> {
use structured_zstd::decoding::Dictionary;
use structured_zstd::encoding::{
CompressionLevel, EncoderDictionary, FrameCompressor, MatchGeneratorDriver,
};
type CachedCompressor =
FrameCompressor<std::io::Cursor<Vec<u8>>, Vec<u8>, MatchGeneratorDriver>;
thread_local! {
static TLS_COMPRESSOR: std::cell::RefCell<Option<(u64, i32, CachedCompressor)>> =
const { std::cell::RefCell::new(None) };
}
let dict_key = xxhash_rust::xxh3::xxh3_64(dict_raw);
TLS_COMPRESSOR.with(|cell| {
let mut state = cell.borrow_mut();
if !matches!(&*state, Some((k, l, _)) if *k == dict_key && *l == level) {
let mut compressor = FrameCompressor::new(CompressionLevel::from_level(level));
if dict_raw.starts_with(&DICT_MAGIC) {
compressor
.set_dictionary_from_bytes(dict_raw)
.map_err(|e| crate::Error::Io(crate::io::Error::other(e.to_string())))?;
} else {
#[expect(
clippy::cast_possible_truncation,
reason = "intentional: lower 32 bits of xxh3 as internal dict id"
)]
let id = {
let h = dict_key as u32;
h.max(1) };
let dictionary = Dictionary::from_raw_content(id, dict_raw.to_vec())
.map_err(|e| crate::Error::Io(crate::io::Error::other(e.to_string())))?;
compressor
.set_encoder_dictionary(EncoderDictionary::from_dictionary(dictionary))
.map_err(|e| crate::Error::Io(crate::io::Error::other(e.to_string())))?;
}
*state = Some((dict_key, level, compressor));
}
let Some((_, _, compressor)) = state.as_mut() else {
unreachable!("TLS_COMPRESSOR always initialised above");
};
let src_buf = compressor.take_source().map_or_else(
|| data.to_vec(),
|c| {
let mut v = c.into_inner();
v.clear();
v.extend_from_slice(data);
v
},
);
compressor.set_source_size_hint(data.len() as u64);
compressor.set_source(std::io::Cursor::new(src_buf));
compressor.set_drain(Vec::new());
compressor.compress();
let compressed = compressor
.take_drain()
.unwrap_or_else(|| unreachable!("drain is always set by set_drain() above"));
Ok(compressed)
})
}
fn decompress_with_dict(
data: &[u8],
dict: &crate::compression::ZstdDictionary,
capacity: usize,
) -> crate::Result<Vec<u8>> {
use structured_zstd::decoding::FrameDecoder;
thread_local! {
static TLS_DECODER: std::cell::RefCell<Option<(u64, FrameDecoder)>> =
const { std::cell::RefCell::new(None) };
}
let is_raw_content = !dict.raw().starts_with(&DICT_MAGIC);
TLS_DECODER.with(|cell| {
let mut state = cell.borrow_mut();
if !matches!(&*state, Some((id, _)) if *id == dict.id64()) {
let handle = dict.prepared_handle()?;
let mut decoder = FrameDecoder::new();
decoder
.add_dict_handle(handle)
.map_err(|e| crate::Error::Io(crate::io::Error::other(e.to_string())))?;
*state = Some((dict.id64(), decoder));
}
let Some((_, decoder)) = state.as_mut() else {
unreachable!("TLS_DECODER always initialised above");
};
do_decompress_with_dict(decoder, data, dict.id().max(1), capacity, is_raw_content)
})
}
}
impl ZstdProvider {
    /// Decompresses one zstd frame from `data` directly into `dest` and returns the number
    /// of bytes written.
    ///
    /// A thread-local `FrameDecoder` is reused across calls made on the same thread.
    ///
    /// # Errors
    ///
    /// * `Error::Io` if the streaming decoder cannot be created for `data`.
    /// * `Error::DecompressedSizeTooLarge` if the frame decodes to more than `dest.len()`
    ///   bytes.
    /// * Read errors from the decoder, converted with `crate::Error::from`.
    pub fn decompress_into(data: &[u8], dest: &mut [u8]) -> crate::Result<usize> {
        use structured_zstd::decoding::{FrameDecoder, StreamingDecoder};
        thread_local! {
            static TLS_DECODER: core::cell::RefCell<FrameDecoder> =
                core::cell::RefCell::new(FrameDecoder::new());
        }
        TLS_DECODER.with(|slot| {
            let mut frame_decoder = slot.borrow_mut();
            let source = std::io::Cursor::new(data);
            let mut stream = StreamingDecoder::new_with_decoder(source, &mut *frame_decoder)
                .map_err(|e| crate::Error::Io(crate::io::Error::other(e.to_string())))?;
            let written = bounded_read_into(&mut stream, dest)?;
            Ok(written)
        })
    }
}
// Unit tests for this module, compiled only for `cfg(test)` builds. The
// `expect_used` and `indexing_slicing` clippy lints are expected to fire there
// and are deliberately permitted for test code.
#[cfg(test)]
#[expect(clippy::expect_used, clippy::indexing_slicing, reason = "test code")]
mod tests;