use alloc::{borrow::Cow, vec::Vec};
mod tests;
/// Eight-byte marker that flags a payload as zstd-compressed.
///
/// When present at the start of the input, it is stripped and the remaining bytes are
/// handed to the zstd frame decoder. It is an application-level tag and is distinct
/// from the zstd frame magic number itself.
pub(super) const ZSTD_PREFIX: [u8; 8] = [0x52, 0xBC, 0x53, 0x76, 0x46, 0xDB, 0x8E, 0x05];
pub(super) fn zstd_decode_if_necessary(
data: &'_ [u8],
max_allowed: usize,
) -> Result<Cow<'_, [u8]>, Error> {
if data.starts_with(&ZSTD_PREFIX) {
Ok(Cow::Owned(zstd_decode(
&data[ZSTD_PREFIX.len()..],
max_allowed,
)?))
} else if data.len() > max_allowed {
Err(Error::TooLarge)
} else {
Ok(Cow::Borrowed(data))
}
}
fn zstd_decode(mut data: &[u8], max_allowed: usize) -> Result<Vec<u8>, Error> {
let mut decoder = ruzstd::decoding::FrameDecoder::new();
decoder.init(&mut data).map_err(|_| Error::InvalidZstd)?;
match decoder.decode_blocks(
&mut data,
ruzstd::decoding::BlockDecodingStrategy::UptoBytes(max_allowed),
) {
Ok(true) => {}
Ok(false) => return Err(Error::TooLarge),
Err(_) => return Err(Error::InvalidZstd),
}
debug_assert!(decoder.is_finished());
let out_buf = decoder.collect().unwrap();
debug_assert!(out_buf.len() <= max_allowed);
Ok(out_buf)
}
#[derive(Debug, derive_more::Display, derive_more::Error, Clone)]
pub enum Error {
InvalidZstd,
TooLarge,
}