pub(crate) mod dispatch_kernel;
pub mod formats;
pub(crate) mod uniforms;
use vyre::error::{Error, Result};
pub use self::formats::lz4::{dispatch_lz4, Lz4DispatchArgs};
pub use self::formats::zstd::{dispatch_zstd, ZstdDispatchArgs};
pub use self::uniforms::{Lz4Uniforms, ZstdUniforms};
/// Upper bound, in bytes, on a single decompressed output (256 MiB).
///
/// Larger outputs are rejected before dispatch and before readback unpacking.
pub const MAX_DECOMPRESS_OUTPUT_BYTES: usize = 256 << 20;
/// Size caps for a single decompression request.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot build it
/// with a struct literal. Start from [`DecompressLimits::default`] and adjust the
/// public fields instead.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub struct DecompressLimits {
    /// Largest accepted compressed input (presumably bytes; the default is 64 MiB).
    pub max_input_size: usize,
    /// Largest accepted decompressed output; defaults to [`MAX_DECOMPRESS_OUTPUT_BYTES`].
    pub max_output_size: usize,
}
impl Default for DecompressLimits {
    /// Builds limits of 64 MiB of compressed input and
    /// [`MAX_DECOMPRESS_OUTPUT_BYTES`] of decompressed output.
    fn default() -> Self {
        const DEFAULT_MAX_INPUT_BYTES: usize = 64 << 20;
        Self {
            max_output_size: MAX_DECOMPRESS_OUTPUT_BYTES,
            max_input_size: DEFAULT_MAX_INPUT_BYTES,
        }
    }
}
/// Rejects a declared decompressed size that is implausibly large for its
/// compressed input, guarding against decompression bombs before GPU dispatch.
///
/// * `format` - codec name, used only to label error messages.
/// * `compressed_len` - length of the compressed input.
/// * `declared_output_len` - output length the request claims it will produce.
/// * `max_output_ratio` - largest allowed expansion factor, i.e. the request is
///   accepted while `declared_output_len <= compressed_len * max_output_ratio`.
///
/// A `declared_output_len` of 0 is always accepted.
///
/// # Errors
///
/// Returns [`Error::Decompress`] in either of these cases:
/// * a non-zero output is declared for an empty input;
/// * `declared_output_len` exceeds `compressed_len * max_output_ratio`.
pub fn validate_output_ratio(
    format: &str,
    compressed_len: usize,
    declared_output_len: usize,
    max_output_ratio: usize,
) -> Result<()> {
    if declared_output_len == 0 {
        return Ok(());
    }
    if compressed_len == 0 {
        return Err(Error::Decompress {
            message: format!(
                "{format} declared {declared_output_len} output bytes from an empty compressed input. Fix: reject empty-input decompression bombs before GPU dispatch."
            ),
        });
    }
    // For integers, `declared > compressed * ratio` holds exactly when
    // `ceil(declared / compressed) > ratio`. The division form cannot overflow.
    // When `compressed * ratio` would exceed `usize::MAX`, no `usize` length can
    // exceed the limit, so such requests must pass rather than be rejected.
    if declared_output_len.div_ceil(compressed_len) > max_output_ratio {
        return Err(Error::Decompress {
            message: format!(
                "{format} declared {declared_output_len} output bytes for {compressed_len} compressed bytes, exceeding max_output_ratio {max_output_ratio}. Fix: reject the decompression bomb or lower the declared output size."
            ),
        });
    }
    Ok(())
}
pub(crate) fn validate_backend_capacity(
device: &wgpu::Device,
compressed_len: usize,
descriptor_words: usize,
output_words: usize,
status_words: usize,
uniform_words: usize,
) -> Result<()> {
let limits = device.limits();
let storage_binding_limit = u64::from(limits.max_storage_buffer_binding_size);
for (name, bytes) in [
("compressed input", align_to_copy(u64::try_from(compressed_len).map_err(|source| Error::Gpu {
message: format!("compressed input length cannot fit u64: {source}. Fix: split the compressed input before GPU dispatch."),
})?)),
("descriptor buffer", bytes_for_u32s(descriptor_words)?),
("output buffer", bytes_for_u32s(output_words)?),
("status buffer", bytes_for_u32s(status_words)?),
("uniform buffer", bytes_for_u32s(uniform_words)?),
] {
if bytes > limits.max_buffer_size || bytes > storage_binding_limit {
return Err(Error::Gpu {
message: format!(
"{name} requires {bytes} bytes, exceeding the adapter storage-buffer limit. Fix: split the input or use a GPU with larger storage buffers."
),
});
}
}
Ok(())
}
/// Flattens GPU readback words into exactly `byte_len` bytes.
///
/// Each word contributes its bytes least-significant first. The resulting
/// stream is truncated to `byte_len`.
///
/// # Errors
///
/// Returns [`Error::Decompress`] in either of these cases:
/// * `byte_len` exceeds [`MAX_DECOMPRESS_OUTPUT_BYTES`];
/// * `byte_len` asks for more bytes than `words` holds.
pub(crate) fn unpack_words_to_bytes(words: &[u32], byte_len: usize) -> Result<Vec<u8>> {
    if byte_len > MAX_DECOMPRESS_OUTPUT_BYTES {
        return Err(Error::Decompress {
            message: format!(
                "decompressed output is {byte_len} bytes, exceeding {MAX_DECOMPRESS_OUTPUT_BYTES}. Fix: split the payload or lower declared output size."
            ),
        });
    }
    let Some(available) = words.len().checked_mul(4) else {
        return Err(Error::Decompress {
            message: "readback capacity overflowed usize. Fix: split the decompression workload."
                .to_string(),
        });
    };
    if byte_len > available {
        return Err(Error::Decompress {
            message: format!(
                "declared decompressed output {byte_len} bytes exceeds readback capacity {available}. Fix: reject this malformed decompression descriptor."
            ),
        });
    }
    let unpacked: Vec<u8> = words
        .iter()
        .flat_map(|word| word.to_le_bytes())
        .take(byte_len)
        .collect();
    Ok(unpacked)
}
/// Decodes a readback byte buffer into little-endian `u32` words.
///
/// # Errors
///
/// Returns [`Error::Gpu`] when `bytes.len()` is not a multiple of 4.
pub(crate) fn decode_u32s(bytes: &[u8]) -> Result<Vec<u32>> {
    let quads = bytes.chunks_exact(4);
    if !quads.remainder().is_empty() {
        return Err(Error::Gpu {
            message: format!(
                "GPU readback length {} is not divisible by 4. Fix: check decompression buffer sizing.",
                bytes.len()
            ),
        });
    }
    // Fold from the most significant byte (last) down to the least (first).
    let words: Vec<u32> = quads
        .map(|quad| {
            quad.iter()
                .rev()
                .fold(0_u32, |acc, &byte| (acc << 8) | u32::from(byte))
        })
        .collect();
    Ok(words)
}
pub(crate) fn bytes_for_u32s(words: usize) -> Result<u64> {
let bytes = words.checked_mul(4).ok_or_else(|| Error::Gpu {
message: "buffer size overflow. Fix: split the decompression workload before GPU dispatch."
.to_string(),
})?;
u64::try_from(bytes.max(4)).map_err(|source| Error::Gpu {
message: format!("buffer byte length cannot fit u64: {source}. Fix: split the workload."),
})
}
pub(crate) fn words_for_output_bytes(format: &str, bytes: usize) -> Result<usize> {
if bytes > MAX_DECOMPRESS_OUTPUT_BYTES {
return Err(Error::Decompress {
message: format!(
"{format} decompressed output is {bytes} bytes, exceeding {MAX_DECOMPRESS_OUTPUT_BYTES}. Fix: split the compressed stream before GPU dispatch."
),
});
}
Ok(bytes.div_ceil(4).max(1))
}
pub(crate) fn u32_output_len(format: &str, bytes: usize) -> Result<u32> {
u32::try_from(bytes).map_err(|source| Error::Decompress {
message: format!(
"{format} decompressed output length {bytes} cannot fit u32 shader uniforms: {source}. Fix: split the compressed stream before GPU dispatch."
),
})
}
/// Rounds `size` up to the next multiple of [`wgpu::COPY_BUFFER_ALIGNMENT`],
/// returning at least one alignment unit, so `0` maps to the alignment itself.
///
/// Sizes within one alignment unit of `u64::MAX` have no representable aligned
/// round-up. Those saturate to `u64::MAX` instead of overflowing, so callers
/// that compare the result against device limits still see an oversized request.
pub(crate) fn align_to_copy(size: u64) -> u64 {
    let alignment = wgpu::COPY_BUFFER_ALIGNMENT;
    // A plain `*` panics in debug builds and wraps to 0 in release builds for the
    // largest inputs, which would make an enormous request look tiny.
    size.div_ceil(alignment).max(1).saturating_mul(alignment)
}
/// Builds a bind-group entry that binds the whole of `buffer` at slot `binding`.
pub(crate) fn binding(binding: u32, buffer: &wgpu::Buffer) -> wgpu::BindGroupEntry<'_> {
    let resource = buffer.as_entire_binding();
    wgpu::BindGroupEntry { binding, resource }
}