#[cfg(not(feature = "std"))]
extern crate alloc;
use crate::prelude::*;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
/// zstd compression level used for every frame produced in this module.
///
/// The thread-local context path and the one-shot (`no_std`) path both pass
/// this same level to zstd.
const ZSTD_LEVEL: i32 = 1;
/// Payload-length threshold, in bytes, exposed to the rest of the crate.
///
/// NOTE(review): not referenced in this file; presumably callers skip
/// compression for payloads shorter than this — confirm at the call sites.
pub(crate) const MIN_COMPRESS_LEN: usize = 64;
/// Cheap heuristic guessing whether `data` is not worth handing to zstd.
///
/// Inputs shorter than 32 bytes always yield `false`. Otherwise two probes run:
///
/// 1. If the first 8 bytes span a value range wider than 200, the data is
///    reported as incompressible.
/// 2. Otherwise the distinct byte values among the first 32 bytes are counted
///    with a 256-bit presence set; 28 or more distinct values means
///    incompressible.
///
/// NOTE(review): probe 1 also fires on structured binary data that merely
/// mixes `0x00` and `0xFF` bytes (e.g. small negative integers), which may in
/// fact compress well.
#[inline(always)]
pub(crate) fn looks_incompressible(data: &[u8]) -> bool {
    if data.len() < 32 {
        return false;
    }

    // Probe 1: value spread of the first 8 bytes.
    let head = &data[..8];
    let (min, max) = head
        .iter()
        .fold((head[0], head[0]), |(min, max), &byte| (min.min(byte), max.max(byte)));
    // `max >= min`, so this subtraction cannot underflow.
    if max - min > 200 {
        return true;
    }

    // Probe 2: distinct byte values in the first 32 bytes, one bit per value.
    let mut present = [0u64; 4];
    data[..32]
        .iter()
        .for_each(|&byte| present[usize::from(byte >> 6)] |= 1u64 << (byte & 63));
    let distinct: u32 = present.iter().map(|word| word.count_ones()).sum();
    distinct >= 28
}
/// Per-thread zstd state: a compression context, a decompression context and
/// a reusable output buffer for `compress_and_write`.
#[cfg(feature = "std")]
struct ZstdState {
    cctx: zstd_safe::CCtx<'static>,
    dctx: zstd_safe::DCtx<'static>,
    scratch: Vec<u8>,
}
#[cfg(feature = "std")]
impl ZstdState {
    /// Creates fresh contexts and an empty (unallocated) scratch buffer.
    fn new() -> Self {
        let cctx = zstd_safe::CCtx::create();
        let dctx = zstd_safe::DCtx::create();
        Self {
            cctx,
            dctx,
            scratch: Vec::new(),
        }
    }
}
#[cfg(feature = "std")]
thread_local! {
    // Lazily initialised on first use in each thread; accessed via
    // `RefCell::borrow_mut`, so it must not be re-entered while borrowed.
    static ZSTD_STATE: core::cell::RefCell<ZstdState> =
        core::cell::RefCell::new(ZstdState::new());
}
/// Compresses `input` with zstd and writes it to `writer` as a flagged frame:
/// a varint header whose low bit is `1` (compressed), then the zstd bytes.
///
/// Returns `Ok(Some(n))` with the number of bytes written when the compressed
/// frame (header included) is strictly smaller than the raw frame would be.
/// Otherwise it returns `Ok(None)` without writing anything, so the caller can
/// store `input` uncompressed.
///
/// # Errors
///
/// Returns [`Error::InvalidData`] if zstd reports a compression error, and
/// propagates any error from writing to `writer`.
#[cold]
#[inline(never)]
pub(crate) fn compress_and_write(
    input: &[u8],
    writer: &mut impl crate::io::Write,
) -> Result<Option<usize>> {
    let raw_len = input.len();
    let raw_frame_len = raw_len + flagged_header_len(raw_len, false);
    let bound = zstd_safe::compress_bound(raw_len);
    #[cfg(feature = "std")]
    {
        ZSTD_STATE.with(|cell| {
            let mut guard = cell.borrow_mut();
            let ZstdState { cctx, scratch, .. } = &mut *guard;
            // Grow with zero-filled bytes and never shrink. This keeps
            // `scratch[..bound]` in bounds and fully initialised, and the
            // zeroing cost is paid only when this thread's high-water mark
            // rises. (`Vec::reserve` is relative to `len`, not capacity, so
            // sizing the buffer from its capacity alone is not enough.)
            if scratch.len() < bound {
                scratch.resize(bound, 0);
            }
            let comp_len = cctx
                .compress(&mut scratch[..bound], input, ZSTD_LEVEL)
                .map_err(|_| Error::InvalidData)?;
            if comp_len + flagged_header_len(comp_len, true) >= raw_frame_len {
                return Ok(None);
            }
            Ok(Some(write_flagged_raw(writer, &scratch[..comp_len], 1)?))
        })
    }
    #[cfg(not(feature = "std"))]
    {
        // Zero-filled: `Vec::set_len` over uninitialised bytes is not allowed.
        let mut scratch: Vec<u8> = Vec::new();
        scratch.resize(bound, 0);
        let comp_len = zstd_safe::compress(&mut scratch[..], input, ZSTD_LEVEL)
            .map_err(|_| Error::InvalidData)?;
        if comp_len + flagged_header_len(comp_len, true) >= raw_frame_len {
            return Ok(None);
        }
        Ok(Some(write_flagged_raw(writer, &scratch[..comp_len], 1)?))
    }
}
/// Compresses `input` into a standalone zstd frame at `ZSTD_LEVEL`.
///
/// With the `std` feature, the calling thread's cached compression context is
/// reused; without it, zstd's one-shot `compress` is used.
///
/// # Errors
///
/// Returns [`Error::InvalidData`] if zstd reports an error.
#[inline(never)]
pub fn zstd_compress(input: &[u8]) -> Result<Vec<u8>> {
    let bound = zstd_safe::compress_bound(input.len());
    // Zero-filled rather than `set_len` over uninitialised memory:
    // `Vec::set_len` requires the newly exposed elements to be initialised.
    let mut out: Vec<u8> = Vec::new();
    out.resize(bound, 0);
    #[cfg(feature = "std")]
    let written = ZSTD_STATE
        .with(|cell| {
            cell.borrow_mut()
                .cctx
                .compress(&mut out[..], input, ZSTD_LEVEL)
        })
        .map_err(|_| Error::InvalidData)?;
    #[cfg(not(feature = "std"))]
    let written =
        zstd_safe::compress(&mut out[..], input, ZSTD_LEVEL).map_err(|_| Error::InvalidData)?;
    out.truncate(written);
    Ok(out)
}
/// Decompresses `compressed`, which is expected to expand to exactly
/// `original_len` bytes.
///
/// With the `std` feature, the calling thread's cached decompression context
/// is reused; without it, zstd's one-shot `decompress` is used.
///
/// # Errors
///
/// * [`Error::InvalidData`] if zstd reports an error.
/// * [`Error::IncorrectLength`] if the number of decompressed bytes differs
///   from `original_len`.
#[inline(never)]
pub fn zstd_decompress(compressed: &[u8], original_len: usize) -> Result<Vec<u8>> {
    // Zero-filled so the slice handed to zstd is fully initialised.
    // NOTE(review): `original_len` is allocated up front without validation;
    // callers should bound it when it comes from untrusted input.
    let mut out: Vec<u8> = Vec::new();
    out.resize(original_len, 0);
    #[cfg(feature = "std")]
    let written = ZSTD_STATE
        .with(|cell| cell.borrow_mut().dctx.decompress(&mut out[..], compressed))
        .map_err(|_| Error::InvalidData)?;
    #[cfg(not(feature = "std"))]
    let written =
        zstd_safe::decompress(&mut out[..], compressed).map_err(|_| Error::InvalidData)?;
    if written != original_len {
        return Err(Error::IncorrectLength);
    }
    Ok(out)
}
#[inline(always)]
pub fn zstd_content_size(compressed: &[u8]) -> Result<usize> {
match zstd_safe::get_frame_content_size(compressed) {
Ok(Some(n)) => Ok(n as usize),
_ => Err(Error::InvalidData),
}
}
/// Number of bytes the flagged-header varint needs to encode `val`.
///
/// Values below `0x80` fit in a single byte. Larger values take one marker
/// byte plus the minimal number of whole bytes that hold every significant
/// bit of `val`.
#[inline(always)]
const fn varint_len_usize(val: usize) -> usize {
    if val < 0x80 {
        1
    } else {
        let significant_bits = usize::BITS - val.leading_zeros();
        1 + significant_bits.div_ceil(8) as usize
    }
}
/// Length, in bytes, of the varint header that precedes a flagged payload.
///
/// The header encodes `payload_len << 1`, with the low bit set when the payload
/// is compressed.
#[inline(always)]
pub const fn flagged_header_len(payload_len: usize, compressed: bool) -> usize {
    let flag: usize = if compressed { 1 } else { 0 };
    let encoded = (payload_len << 1) | flag;
    varint_len_usize(encoded)
}
/// Writes `bytes` to `writer` after a varint header encoding
/// `(bytes.len() << 1) | flag`, and returns the total number of bytes written.
///
/// Header layout:
///
/// * values `<= 0x7F` are stored as a single byte;
/// * larger values are a marker byte `0x80 | n`, followed by the `n`
///   low-order bytes of the value in little-endian order.
///
/// `flag` occupies the low bit and must be `0` or `1`. Callers in this file
/// pass `1` for zstd-compressed payloads.
///
/// # Errors
///
/// Propagates errors from `writer`.
#[inline(always)]
pub(crate) fn write_flagged_raw(
    writer: &mut impl crate::io::Write,
    bytes: &[u8],
    flag: usize,
) -> Result<usize> {
    debug_assert!(flag <= 1, "flag must be 0 or 1");
    let raw_len = bytes.len();
    let header_val = ((raw_len << 1) | flag) as u64;

    // Encode the header once. Both write paths emit exactly these `hdr_len`
    // bytes, so nothing is ever written past `total`.
    let mut header_buf = [0u8; 9];
    let hdr_len = if header_val <= 0x7F {
        header_buf[0] = header_val as u8;
        1
    } else {
        // Number of significant little-endian bytes (1..=8).
        let n = ((64 - header_val.leading_zeros() + 7) >> 3) as usize;
        header_buf[0] = 0x80 | n as u8;
        header_buf[1..=n].copy_from_slice(&header_val.to_le_bytes()[..n]);
        1 + n
    };
    let header = &header_buf[..hdr_len];
    let total = hdr_len + raw_len;

    writer.reserve(total);
    if let Some(dst) = writer.buf_mut()
        && dst.len() >= total
    {
        // SAFETY: `dst` holds at least `total = hdr_len + raw_len` bytes, so
        // both copies stay in bounds. `header` is a local array, and `bytes` is
        // a shared borrow held alongside `&mut writer`, so neither source
        // overlaps the writer's buffer.
        unsafe {
            let p = dst.as_mut_ptr();
            core::ptr::copy_nonoverlapping(header.as_ptr(), p, hdr_len);
            core::ptr::copy_nonoverlapping(bytes.as_ptr(), p.add(hdr_len), raw_len);
        }
        writer.advance_mut(total);
        return Ok(total);
    }

    writer.write(header)?;
    writer.write(bytes)?;
    Ok(total)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::VecWriter;
    use crate::prelude::Encode;

    /// Writes `payload` through a fresh zstd context, applying the same
    /// "only if smaller" rule as `compress_and_write`. Returns the byte count,
    /// or `None` when compression would not shrink the frame.
    fn write_fresh_context(payload: &[u8], writer: &mut VecWriter) -> Option<usize> {
        let bound = zstd_safe::compress_bound(payload.len());
        let mut tmp = vec![0u8; bound];
        let comp_len =
            zstd_safe::compress(&mut tmp[..], payload, ZSTD_LEVEL).expect("zstd_safe::compress");
        let raw_frame = payload.len() + flagged_header_len(payload.len(), false);
        if comp_len + flagged_header_len(comp_len, true) >= raw_frame {
            return None;
        }
        Some(write_flagged_raw(writer, &tmp[..comp_len], 1).unwrap())
    }

    #[test]
    fn wire_format_compress_is_deterministic() {
        let payload = vec![0u8; 512];
        let mut writer_a = VecWriter::new();
        let _ = compress_and_write(&payload, &mut writer_a)
            .expect("compress_and_write")
            .expect("compression should beat raw for all-zero payload");
        let bound = zstd_safe::compress_bound(payload.len());
        let mut tmp = vec![0u8; bound];
        let written =
            zstd_safe::compress(&mut tmp[..], &payload, ZSTD_LEVEL).expect("zstd_safe::compress");
        tmp.truncate(written);
        let mut writer_b = VecWriter::new();
        write_flagged_raw(&mut writer_b, &tmp, 1).unwrap();
        assert_eq!(
            writer_a.into_inner(),
            writer_b.into_inner(),
            "thread-local CCtx and fresh-context paths must produce identical wire bytes"
        );
    }

    /// Regression test: the reused scratch buffer must be resized correctly
    /// when inputs grow, shrink and grow again on one thread.
    #[test]
    fn compress_and_write_matches_fresh_context_across_sizes() {
        for &len in &[64usize, 200, 130, 1_000, 300, 4_096, 100, 20_000] {
            let payload: Vec<u8> = (0..len).map(|i| (i % 13) as u8).collect();
            let mut actual = VecWriter::new();
            let wrote = compress_and_write(&payload, &mut actual).expect("compress_and_write");
            let mut expected = VecWriter::new();
            let reference = write_fresh_context(&payload, &mut expected);
            assert_eq!(wrote, reference, "len = {len}");
            assert_eq!(actual.into_inner(), expected.into_inner(), "len = {len}");
        }
    }

    #[test]
    fn write_flagged_raw_single_byte_header() {
        let payload = [0xABu8; 10];
        let mut writer = VecWriter::new();
        let n = write_flagged_raw(&mut writer, &payload, 1).unwrap();
        assert_eq!(n, 11);
        // (10 << 1) | 1 == 21 fits in one byte.
        let mut expected = vec![21u8];
        expected.extend_from_slice(&payload);
        assert_eq!(writer.into_inner(), expected);
    }

    #[test]
    fn write_flagged_raw_wide_header() {
        let payload = vec![0x5Au8; 300];
        let mut writer = VecWriter::new();
        let n = write_flagged_raw(&mut writer, &payload, 1).unwrap();
        assert_eq!(n, 303);
        assert_eq!(flagged_header_len(300, true), 3);
        // (300 << 1) | 1 == 601 == 0x0259 -> marker 0x82, then 0x59 0x02.
        let mut expected = vec![0x82u8, 0x59, 0x02];
        expected.extend_from_slice(&payload);
        assert_eq!(writer.into_inner(), expected);
    }

    #[test]
    fn zstd_helpers_round_trip() {
        let payload: Vec<u8> = (0..5_000u32).map(|i| (i % 251) as u8).collect();
        let compressed = zstd_compress(&payload).unwrap();
        assert_eq!(zstd_content_size(&compressed).unwrap(), payload.len());
        assert_eq!(zstd_decompress(&compressed, payload.len()).unwrap(), payload);
        assert!(zstd_decompress(&compressed, payload.len() + 1).is_err());
    }

    #[test]
    fn looks_incompressible_heuristic() {
        assert!(!looks_incompressible(&[0u8; 16]));
        assert!(!looks_incompressible(&[7u8; 1024]));
        let ramp: Vec<u8> = (0..=255u8).collect();
        assert!(looks_incompressible(&ramp));
    }

    #[test]
    fn wire_format_compressible_round_trip() {
        let payload = vec![7u8; 1024];
        let mut writer = VecWriter::new();
        Encode::encode_ext(&payload, &mut writer, None).unwrap();
        let bytes = writer.into_inner();
        let decoded: Vec<u8> = crate::decode(&mut crate::io::Cursor::new(&bytes)).unwrap();
        assert_eq!(decoded, payload);
    }

    #[test]
    fn wire_format_repeated_encode_stable() {
        let payload = b"the quick brown fox jumps over the lazy dog the quick brown fox jumps over the lazy dog the quick brown fox".to_vec();
        let mut w1 = VecWriter::new();
        Encode::encode_ext(&payload, &mut w1, None).unwrap();
        let mut w2 = VecWriter::new();
        Encode::encode_ext(&payload, &mut w2, None).unwrap();
        assert_eq!(w1.into_inner(), w2.into_inner());
    }
}