use std::{
borrow::Cow,
io::{Read, Write},
};
/// Magic bytes marking a blob as zstd-compressed: every compress function prepends them, and
/// [`decompress`] only inflates blobs that start with them.
const ZSTD_PREFIX: [u8; 8] = [0x52, 0xBC, 0x53, 0x76, 0x46, 0xDB, 0x8E, 0x05];
/// A bomb limit of 50 MiB, intended (going by its name) for code blobs.
pub const CODE_BLOB_BOMB_LIMIT: usize = 50 << 20;
#[derive(Debug, Clone, PartialEq, thiserror::Error)]
pub enum Error {
#[error("Possible compression bomb encountered")]
PossibleBomb,
#[error("Blob had invalid format")]
Invalid,
}
/// Drains `decoder` into a buffer, refusing to return more than `bomb_limit` bytes.
///
/// `blob_len` is only a capacity hint for the output buffer (callers pass the compressed
/// length). It is clamped so the up-front allocation never exceeds what the limit permits.
///
/// # Errors
///
/// - [`Error::Invalid`] if reading from `decoder` fails.
/// - [`Error::PossibleBomb`] if `decoder` yields more than `bomb_limit` bytes.
fn read_from_decoder(
    decoder: impl Read,
    blob_len: usize,
    bomb_limit: usize,
) -> Result<Vec<u8>, Error> {
    // Read at most one byte past the limit. Seeing that extra byte is how we detect an
    // oversized output without decoding (and buffering) the rest of a potential bomb.
    // `usize -> u64` is lossless on every supported target. `saturating_add` keeps
    // `bomb_limit == usize::MAX` from overflowing, which would panic in debug builds or wrap
    // to `take(0)` in release and silently return an empty buffer.
    let read_cap = (bomb_limit as u64).saturating_add(1);
    let mut decoder = decoder.take(read_cap);
    let mut buf = Vec::with_capacity(blob_len.min(bomb_limit.saturating_add(1)));
    decoder.read_to_end(&mut buf).map_err(|_| Error::Invalid)?;
    if buf.len() <= bomb_limit {
        Ok(buf)
    } else {
        Err(Error::PossibleBomb)
    }
}
fn decompress_zstd(blob: &[u8], bomb_limit: usize) -> Result<Vec<u8>, Error> {
let decoder = zstd::Decoder::new(blob).map_err(|_| Error::Invalid)?;
read_from_decoder(decoder, blob.len(), bomb_limit)
}
/// Inflates `blob` if it starts with the zstd magic prefix; otherwise returns it untouched.
///
/// A prefixed blob is decoded into an owned buffer and must decode to at most `bomb_limit`
/// bytes. A blob without the prefix is returned borrowed and is *not* checked against
/// `bomb_limit`.
///
/// # Errors
///
/// [`Error::Invalid`] if the compressed payload cannot be decoded, [`Error::PossibleBomb`] if
/// it decodes to more than `bomb_limit` bytes.
pub fn decompress(blob: &[u8], bomb_limit: usize) -> Result<Cow<'_, [u8]>, Error> {
    match blob.strip_prefix(&ZSTD_PREFIX[..]) {
        Some(payload) => decompress_zstd(payload, bomb_limit).map(Cow::Owned),
        None => Ok(Cow::Borrowed(blob)),
    }
}
/// Compresses `blob` at zstd level 3, which is cheaper but weaker than [`compress_strongly`].
///
/// Returns `None` if `blob` is larger than `bomb_limit` (a decoder using the same limit would
/// reject the output as a possible bomb) or if encoding fails.
pub fn compress_weakly(blob: &[u8], bomb_limit: usize) -> Option<Vec<u8>> {
    const WEAK_LEVEL: i32 = 3;
    compress_with_level(blob, bomb_limit, WEAK_LEVEL)
}
/// Compresses `blob` at zstd level 22 (zstd's highest standard level). This is slower than
/// [`compress_weakly`] but aims for a smaller output.
///
/// Returns `None` if `blob` is larger than `bomb_limit` (a decoder using the same limit would
/// reject the output as a possible bomb) or if encoding fails.
pub fn compress_strongly(blob: &[u8], bomb_limit: usize) -> Option<Vec<u8>> {
    const STRONG_LEVEL: i32 = 22;
    compress_with_level(blob, bomb_limit, STRONG_LEVEL)
}
#[deprecated(
note = "Will be removed after June 2026. Use compress_strongly, compress_weakly or compress_with_level instead"
)]
pub fn compress(blob: &[u8], bomb_limit: usize) -> Option<Vec<u8>> {
compress_with_level(blob, bomb_limit, 3)
}
/// Compresses `blob` with zstd at `level` and prepends `ZSTD_PREFIX`, so that [`decompress`]
/// recognizes the result.
///
/// Returns `None` in two cases:
/// - `blob` is larger than `bomb_limit`. A decoder using the same limit rejects any output
///   above it as a possible bomb, so such a blob could never be decompressed.
/// - The zstd encoder fails.
pub fn compress_with_level(blob: &[u8], bomb_limit: usize, level: i32) -> Option<Vec<u8>> {
    if blob.len() > bomb_limit {
        return None;
    }

    let mut buf = ZSTD_PREFIX.to_vec();
    let mut encoder = zstd::Encoder::new(&mut buf, level).ok()?;
    encoder.write_all(blob).ok()?;
    // Finish explicitly. `auto_finish()` would call `finish()` on drop and discard its
    // error, which could hand back a truncated frame as `Some`.
    encoder.finish().ok()?;
    Some(buf)
}
#[cfg(test)]
mod tests {
    use super::*;

    const BOMB_LIMIT: usize = 10;

    #[test]
    fn refuse_to_encode_over_limit() {
        let mut v = vec![0; BOMB_LIMIT + 1];
        assert!(compress_weakly(&v, BOMB_LIMIT).is_none());
        assert!(compress_strongly(&v, BOMB_LIMIT).is_none());

        // A blob exactly at the limit is still accepted.
        v.truncate(BOMB_LIMIT);
        assert!(compress_weakly(&v, BOMB_LIMIT).is_some());
        assert!(compress_strongly(&v, BOMB_LIMIT).is_some());
    }

    #[test]
    fn compress_and_decompress() {
        let v = vec![0; BOMB_LIMIT];

        let compressed_weakly = compress_weakly(&v, BOMB_LIMIT).unwrap();
        let compressed_strongly = compress_strongly(&v, BOMB_LIMIT).unwrap();

        assert!(compressed_weakly.starts_with(&ZSTD_PREFIX));
        assert!(compressed_strongly.starts_with(&ZSTD_PREFIX));

        assert_eq!(&decompress(&compressed_weakly, BOMB_LIMIT).unwrap()[..], &v[..]);
        assert_eq!(&decompress(&compressed_strongly, BOMB_LIMIT).unwrap()[..], &v[..]);
    }

    #[test]
    fn decompresses_only_when_magic() {
        let v = vec![0; BOMB_LIMIT + 1];
        assert_eq!(&decompress(&v, BOMB_LIMIT).unwrap()[..], &v[..]);
    }

    #[test]
    fn possible_bomb_fails() {
        // Encode by hand, because the compress functions refuse over-limit input.
        let encoded_bigger_than_bomb = vec![0; BOMB_LIMIT + 1];
        let mut buf = ZSTD_PREFIX.to_vec();
        {
            let mut v = zstd::Encoder::new(&mut buf, 3).unwrap().auto_finish();
            v.write_all(&encoded_bigger_than_bomb[..]).unwrap();
        }
        assert_eq!(decompress(&buf[..], BOMB_LIMIT).err(), Some(Error::PossibleBomb));
    }

    #[test]
    fn invalid_payload_after_magic_fails() {
        // The prefix is present, but the bytes after it are not a zstd frame.
        let mut buf = ZSTD_PREFIX.to_vec();
        buf.extend_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
        assert_eq!(decompress(&buf, BOMB_LIMIT).err(), Some(Error::Invalid));
    }
}