use crate::engine::decompress::dispatch_kernel::{
dispatch_decompression_kernel, DecompressionKernelInputs, DecompressionKernelLabels,
};
pub use crate::engine::decompress::uniforms::ZstdUniforms;
use vyre::error::{Error, Result};
use vyre::ops::compression::zstd;
/// Decompresses a stream of zstd blocks on the GPU and returns `(output, statuses)`.
///
/// The output size is never taken from the caller: it is derived from the block
/// headers in `args.data` via [`decoded_len_from_blocks`], then checked against
/// `zstd::MAX_OUTPUT_RATIO` and the backend capacity before any GPU work is queued.
/// `args.output_words_len` and `args.uniforms` are not read by this function.
///
/// `command_encoder` is forwarded unchanged to `dispatch_decompression_kernel`.
///
/// # Errors
///
/// Returns `Error::Decompress` when `args.descriptor_words` is non-empty or when the
/// input length does not fit in a `u32`. Errors from block-header parsing, the size
/// helpers, the validation helpers and the kernel dispatch are propagated as-is.
pub fn dispatch_zstd(
    args: ZstdDispatchArgs<'_>,
    command_encoder: Option<&mut wgpu::CommandEncoder>,
) -> Result<(Vec<u8>, Vec<u32>)> {
    // The zstd backend binds its buffers directly, so descriptor words have no consumer.
    let descriptor_word_count = args.descriptor_words.len();
    if descriptor_word_count != 0 {
        return Err(Error::Decompress {
            message: format!(
                "zstd dispatch received {} descriptor words, but the current zstd IR backend consumes input/output/params/status buffers directly. Fix: pass descriptor fields through uniforms or update the zstd IR composition.",
                descriptor_word_count
            ),
        });
    }

    let compressed = args.data;
    let compressed_len = compressed.len();

    // Output sizing comes from the block headers, in bytes, words and as a u32.
    let output_bytes = decoded_len_from_blocks(compressed)?;
    let output_words = super::words_for_output_bytes("zstd", output_bytes)?;
    let output_len = super::u32_output_len("zstd", output_bytes)?;

    let input_len = match u32::try_from(compressed_len) {
        Ok(len) => len,
        Err(source) => {
            return Err(Error::Decompress {
                message: format!(
                    "zstd input length {} cannot fit u32 shader uniforms: {source}. Fix: split the compressed stream before GPU dispatch.",
                    compressed_len
                ),
            });
        }
    };
    let uniforms = ZstdUniforms {
        input_len,
        output_len,
        total_output_size: output_len,
    };

    super::validate_output_ratio(
        "zstd",
        compressed_len,
        output_bytes,
        zstd::MAX_OUTPUT_RATIO as usize,
    )?;

    // The uniform block is measured in u32 words for the capacity check.
    let uniform_words = std::mem::size_of::<ZstdUniforms>() / std::mem::size_of::<u32>();
    super::validate_backend_capacity(
        args.device,
        compressed_len,
        0,
        output_words,
        args.status_words_len,
        uniform_words,
    )?;

    let program = zstd::ZstdDecompress::program();

    let inputs = DecompressionKernelInputs {
        data: compressed,
        uniform_bytes: bytemuck::bytes_of(&uniforms),
        output_words_len: output_words,
        status_words_len: args.status_words_len,
        total_output_size: uniforms.total_output_size,
        block_count: args.block_count,
    };
    let labels = DecompressionKernelLabels {
        format: "zstd",
        compressed_buffer: "vyre-zstd-compressed",
        output_buffer: "vyre-zstd-output",
        status_buffer: "vyre-zstd-status",
        uniforms_buffer: "vyre-zstd-uniforms",
        shader: "vyre-zstd-shader",
        bind_group_layout: "vyre-zstd-bind-group-layout",
        pipeline_layout: "vyre-zstd-pipeline-layout",
        pipeline: "vyre-zstd-pipeline",
        bind_group: "vyre-zstd-bind-group",
        readback_output: "vyre-zstd-readback-output",
        readback_status: "vyre-zstd-readback-status",
        command_encoder: "vyre-zstd-command-encoder",
        pass: "vyre-zstd-pass",
    };

    let outcome = dispatch_decompression_kernel(
        args.device,
        args.queue,
        inputs,
        labels,
        &program,
        command_encoder,
    )?;
    Ok((outcome.output, outcome.statuses))
}
/// Caller-supplied inputs for [`dispatch_zstd`].
///
/// Every borrowed field shares the lifetime `'a`. The notes on individual fields
/// describe how `dispatch_zstd` in this file uses them; other consumers, if any,
/// are not visible here.
pub struct ZstdDispatchArgs<'a> {
/// GPU device handed to the backend-capacity check and to the kernel dispatch.
pub device: &'a wgpu::Device,
/// GPU queue handed to the kernel dispatch.
pub queue: &'a wgpu::Queue,
/// Compressed input: a sequence of zstd blocks whose headers are parsed on the
/// CPU to derive the decoded length.
pub data: &'a [u8],
/// Must be empty. `dispatch_zstd` returns an error if any descriptor words are
/// supplied.
pub descriptor_words: &'a [u32],
/// Caller's output size in words. `dispatch_zstd` does not read this field; it
/// sizes the output from the block headers in `data` instead.
pub output_words_len: usize,
/// Length of the status buffer in words, forwarded to the capacity check and to
/// the kernel inputs.
pub status_words_len: usize,
/// Caller's uniforms. `dispatch_zstd` does not read this field; it builds its
/// own `ZstdUniforms` from the derived lengths.
pub uniforms: ZstdUniforms,
/// Block count forwarded unchanged to the kernel inputs.
pub block_count: u32,
}
pub(crate) fn decoded_len_from_blocks(data: &[u8]) -> Result<usize> {
let mut cursor = 0usize;
let mut output_len = 0usize;
while cursor < data.len() {
if data.len().saturating_sub(cursor) < 3 {
return Err(Error::Decompress {
message: "zstd block header is truncated. Fix: reject this malformed stream before GPU dispatch.".to_string(),
});
}
let header = u32::from(data[cursor])
| (u32::from(data[cursor + 1]) << 8)
| (u32::from(data[cursor + 2]) << 16);
cursor += 3;
let last = (header & 1) != 0;
let block_type = (header >> 1) & 0b11;
let block_size = usize::try_from(header >> 3).map_err(|source| Error::Decompress {
message: format!(
"zstd block size cannot fit usize: {source}. Fix: reject this malformed stream before GPU dispatch."
),
})?;
match block_type {
0 => {
cursor = cursor.checked_add(block_size).ok_or_else(|| Error::Decompress {
message: "zstd raw block cursor overflow. Fix: reject this malformed stream before GPU dispatch.".to_string(),
})?;
if cursor > data.len() {
return Err(Error::Decompress {
message: "zstd raw block length exceeds input bounds. Fix: reject this malformed stream before GPU dispatch.".to_string(),
});
}
output_len = output_len.checked_add(block_size).ok_or_else(|| Error::Decompress {
message: "zstd decoded length overflowed while reading a raw block. Fix: split or reject this stream.".to_string(),
})?;
}
1 => {
if cursor >= data.len() {
return Err(Error::Decompress {
message: "zstd RLE block is missing its repeated byte. Fix: reject this malformed stream before GPU dispatch.".to_string(),
});
}
cursor += 1;
output_len = output_len.checked_add(block_size).ok_or_else(|| Error::Decompress {
message: "zstd decoded length overflowed while reading an RLE block. Fix: split or reject this stream.".to_string(),
})?;
}
2 => {
return Err(Error::Decompress {
message: "zstd compressed block output length is not derivable from the block header. Fix: dispatch only raw/RLE blocks or provide a trusted frame parser that supplies verified sizes.".to_string(),
});
}
_ => {
return Err(Error::Decompress {
message: "zstd reserved block type encountered. Fix: reject this malformed stream before GPU dispatch.".to_string(),
});
}
}
if last {
if cursor != data.len() {
return Err(Error::Decompress {
message: "zstd stream has trailing bytes after the last block. Fix: reject this malformed stream before GPU dispatch.".to_string(),
});
}
return Ok(output_len);
}
}
Ok(output_len)
}