use alloc::vec::Vec;
use ruzstd::{io::Read, StreamingDecoder};
use super::MAX_DECOMPRESSED_EXTENT_BYTES;
use crate::error::{Error, Result};
/// Magic number that opens every standard Zstandard frame, written here as the
/// byte sequence it occupies in the stream (`28 B5 2F FD`, read little-endian).
const ZSTD_FRAME_MAGIC: u32 = u32::from_le_bytes([0x28, 0xB5, 0x2F, 0xFD]);
pub(super) fn decode(src: &[u8], dst: &mut Vec<u8>) -> Result<()> {
let mut input = src;
loop {
if input.len() < 4 {
return Ok(());
}
let magic = u32::from_le_bytes([input[0], input[1], input[2], input[3]]);
if magic != ZSTD_FRAME_MAGIC {
return Ok(());
}
let mut decoder = StreamingDecoder::new(&mut input).map_err(|_| Error::BadCompression {
algorithm: "comp_zstd",
})?;
let mut buf = [0u8; 4096];
loop {
let n = decoder.read(&mut buf).map_err(|_| Error::BadCompression {
algorithm: "comp_zstd",
})?;
if n == 0 {
break;
}
if dst.len() + n > MAX_DECOMPRESSED_EXTENT_BYTES {
return Err(Error::BadCompression {
algorithm: "comp_zstd",
});
}
dst.extend_from_slice(&buf[..n]);
}
}
}
// Unit tests for `decode`: the success path, trailing-padding handling and
// the error paths.
#[cfg(test)]
mod tests {
    use super::*;

    /// A minimal well-formed frame carrying `hello` in one raw (stored) block:
    /// - magic `28 B5 2F FD`;
    /// - frame header descriptor `0x20` (single segment, 1-byte content size,
    ///   no checksum, no dictionary);
    /// - content size `5`;
    /// - block header `29 00 00` (last block, raw, 5 bytes);
    /// - the five payload bytes, stored verbatim.
    const HELLO_FRAME: [u8; 14] = [
        0x28, 0xB5, 0x2F, 0xFD, 0x20, 0x05, 0x29, 0x00, 0x00, b'h', b'e', b'l', b'l', b'o',
    ];

    /// True when `result` is the zstd-specific `BadCompression` error.
    fn is_bad_zstd(result: Result<()>) -> bool {
        matches!(
            result,
            Err(crate::error::Error::BadCompression {
                algorithm: "comp_zstd"
            })
        )
    }

    #[test]
    fn empty_input_decodes_to_nothing() {
        let mut dst = Vec::new();
        assert!(decode(&[], &mut dst).is_ok());
        assert!(dst.is_empty());
    }

    #[test]
    fn non_zstd_prefix_terminates_cleanly() {
        let mut dst = Vec::new();
        let result = decode(b"\x00\x00\x00\x00", &mut dst);
        assert!(result.is_ok());
        assert!(dst.is_empty());
    }

    #[test]
    fn raw_block_frame_decodes() {
        let mut dst = Vec::new();
        assert!(decode(&HELLO_FRAME, &mut dst).is_ok());
        assert_eq!(&dst[..], &b"hello"[..]);
    }

    #[test]
    fn concatenated_frames_are_all_decoded() {
        let mut src = Vec::new();
        src.extend_from_slice(&HELLO_FRAME);
        src.extend_from_slice(&HELLO_FRAME);
        let mut dst = Vec::new();
        assert!(decode(&src, &mut dst).is_ok());
        assert_eq!(&dst[..], &b"hellohello"[..]);
    }

    #[test]
    fn trailing_padding_is_ignored() {
        // A tail shorter than a magic number and a longer zero-filled tail
        // must both end decoding cleanly.
        for &pad in &[3usize, 16] {
            let mut src = Vec::new();
            src.extend_from_slice(&HELLO_FRAME);
            src.resize(src.len() + pad, 0);
            let mut dst = Vec::new();
            assert!(decode(&src, &mut dst).is_ok());
            assert_eq!(&dst[..], &b"hello"[..]);
        }
    }

    #[test]
    fn output_is_appended_to_existing_contents() {
        let mut dst = Vec::new();
        dst.extend_from_slice(b"ab");
        assert!(decode(&HELLO_FRAME, &mut dst).is_ok());
        assert_eq!(&dst[..], &b"abhello"[..]);
    }

    #[test]
    fn truncated_frame_returns_bad_compression() {
        // The block header promises 5 payload bytes but only 3 are present.
        let mut dst = Vec::new();
        assert!(is_bad_zstd(decode(&HELLO_FRAME[..12], &mut dst)));
    }

    #[test]
    fn malformed_zstd_frame_returns_bad_compression() {
        let mut dst = Vec::new();
        let mut bogus = [0u8; 64];
        bogus[..4].copy_from_slice(&ZSTD_FRAME_MAGIC.to_le_bytes());
        assert!(is_bad_zstd(decode(&bogus, &mut dst)));
    }
}