use crate::engine::decompress::dispatch_kernel::{
dispatch_decompression_kernel, DecompressionKernelInputs, DecompressionKernelLabels,
};
pub use crate::engine::decompress::uniforms::Lz4Uniforms;
use vyre::error::{Error, Result};
use vyre::ops::compression::lz4;
/// Validates one LZ4 block on the host and hands it to the shared decompression kernel dispatcher.
///
/// The decoded size is derived from the compressed block itself via
/// `decoded_len_from_block`; the uniforms passed to the kernel are built here
/// from that derived size and the input length. `args.output_words_len` and
/// `args.uniforms` are not consulted by this function.
///
/// On success returns the dispatcher's `output` bytes and `statuses` words.
///
/// # Errors
///
/// Returns `Error::Decompress` when `args.descriptor_words` is non-empty or when
/// the input length does not fit in `u32`, and propagates any error from block
/// parsing, the output size/ratio/capacity validators, or the kernel dispatch.
pub fn dispatch_lz4(
    args: Lz4DispatchArgs<'_>,
    command_encoder: Option<&mut wgpu::CommandEncoder>,
) -> Result<(Vec<u8>, Vec<u32>)> {
    // Descriptor words are rejected outright, before any parsing happens.
    let descriptor_count = args.descriptor_words.len();
    if descriptor_count != 0 {
        return Err(Error::Decompress {
            message: format!(
                "LZ4 dispatch received {} descriptor words, but the current LZ4 IR backend consumes input/output/params/status buffers directly. Fix: pass descriptor fields through uniforms or update the LZ4 IR composition.",
                descriptor_count
            ),
        });
    }

    let compressed = args.data;
    let compressed_len = compressed.len();

    // Size the output from the block contents rather than from caller-supplied values.
    let decoded_bytes = decoded_len_from_block(compressed)?;
    let decoded_words = super::words_for_output_bytes("LZ4", decoded_bytes)?;
    let decoded_len = super::u32_output_len("LZ4", decoded_bytes)?;

    let input_len = match u32::try_from(compressed_len) {
        Ok(len) => len,
        Err(source) => {
            return Err(Error::Decompress {
                message: format!(
                    "LZ4 input length {} cannot fit u32 shader uniforms: {source}. Fix: split the compressed stream before GPU dispatch.",
                    compressed_len
                ),
            });
        }
    };
    let uniforms = Lz4Uniforms {
        input_len,
        output_len: decoded_len,
        total_output_size: decoded_len,
    };

    super::validate_output_ratio(
        "LZ4",
        compressed_len,
        decoded_bytes,
        lz4::MAX_OUTPUT_RATIO as usize,
    )?;

    // The uniform block is reported to the capacity check as a count of u32 words.
    let uniform_words = std::mem::size_of::<Lz4Uniforms>() / std::mem::size_of::<u32>();
    super::validate_backend_capacity(
        args.device,
        compressed_len,
        0,
        decoded_words,
        args.status_words_len,
        uniform_words,
    )?;

    let program = lz4::Lz4Decompress::program();

    let inputs = DecompressionKernelInputs {
        data: compressed,
        uniform_bytes: bytemuck::bytes_of(&uniforms),
        output_words_len: decoded_words,
        status_words_len: args.status_words_len,
        total_output_size: uniforms.total_output_size,
        block_count: args.block_count,
    };
    let labels = DecompressionKernelLabels {
        format: "LZ4",
        compressed_buffer: "vyre-lz4-compressed",
        output_buffer: "vyre-lz4-output",
        status_buffer: "vyre-lz4-status",
        uniforms_buffer: "vyre-lz4-uniforms",
        shader: "vyre-lz4-shader",
        bind_group_layout: "vyre-lz4-bind-group-layout",
        pipeline_layout: "vyre-lz4-pipeline-layout",
        pipeline: "vyre-lz4-pipeline",
        bind_group: "vyre-lz4-bind-group",
        readback_output: "vyre-lz4-readback-output",
        readback_status: "vyre-lz4-readback-status",
        command_encoder: "vyre-lz4-command-encoder",
        pass: "vyre-lz4-pass",
    };

    let outcome = dispatch_decompression_kernel(
        args.device,
        args.queue,
        inputs,
        labels,
        &program,
        command_encoder,
    )?;
    Ok((outcome.output, outcome.statuses))
}
/// Borrowed inputs for a single `dispatch_lz4` call.
///
/// All references share the lifetime `'a`; the struct owns nothing except the
/// by-value `uniforms` and the scalar lengths/counts.
pub struct Lz4DispatchArgs<'a> {
/// wgpu device handed to the backend capacity check and the kernel dispatcher.
pub device: &'a wgpu::Device,
/// wgpu queue handed to the kernel dispatcher.
pub queue: &'a wgpu::Queue,
/// Compressed LZ4 block bytes; `dispatch_lz4` parses these to derive the decoded length.
pub data: &'a [u8],
/// Descriptor words. `dispatch_lz4` returns an error unless this slice is empty.
pub descriptor_words: &'a [u32],
/// Caller-supplied output length in words.
/// NOTE(review): `dispatch_lz4` in this file never reads this field; it derives
/// the output word count from `data` instead — confirm whether other callers use it.
pub output_words_len: usize,
/// Length of the status buffer in words, forwarded to the capacity check and the kernel inputs.
pub status_words_len: usize,
/// Caller-supplied uniforms.
/// NOTE(review): `dispatch_lz4` in this file never reads this field; it builds
/// its own `Lz4Uniforms` from the derived sizes — confirm whether other callers use it.
pub uniforms: Lz4Uniforms,
/// Block count forwarded unchanged to the kernel inputs.
pub block_count: u32,
}
pub(crate) fn decoded_len_from_block(data: &[u8]) -> Result<usize> {
let mut cursor = 0usize;
let mut output_len = 0usize;
while cursor < data.len() {
let token = data[cursor];
cursor += 1;
let literal_len = read_extended_len(data, &mut cursor, usize::from(token >> 4), "literal")?;
cursor = cursor.checked_add(literal_len).ok_or_else(|| Error::Decompress {
message: "LZ4 literal cursor overflow. Fix: reject this malformed block before GPU dispatch.".to_string(),
})?;
if cursor > data.len() {
return Err(Error::Decompress {
message: "LZ4 literal length exceeds compressed block bounds. Fix: reject this malformed block before GPU dispatch.".to_string(),
});
}
output_len = output_len.checked_add(literal_len).ok_or_else(|| Error::Decompress {
message: "LZ4 decoded length overflowed while reading literals. Fix: split or reject this compressed block.".to_string(),
})?;
if cursor == data.len() {
break;
}
if data.len().saturating_sub(cursor) < 2 {
return Err(Error::Decompress {
message: "LZ4 match offset is truncated. Fix: reject this malformed block before GPU dispatch.".to_string(),
});
}
let offset = u16::from_le_bytes([data[cursor], data[cursor + 1]]);
cursor += 2;
if offset == 0 || usize::from(offset) > output_len {
return Err(Error::Decompress {
message: format!(
"LZ4 match offset {offset} is invalid for decoded length {output_len}. Fix: reject this malformed block before GPU dispatch."
),
});
}
let match_len = read_extended_len(data, &mut cursor, usize::from(token & 0x0f), "match")?
.checked_add(4)
.ok_or_else(|| Error::Decompress {
message: "LZ4 match length overflow. Fix: reject this malformed block before GPU dispatch.".to_string(),
})?;
output_len = output_len.checked_add(match_len).ok_or_else(|| Error::Decompress {
message: "LZ4 decoded length overflowed while reading match. Fix: split or reject this compressed block.".to_string(),
})?;
}
Ok(output_len)
}
pub fn read_extended_len(
data: &[u8],
cursor: &mut usize,
nibble: usize,
label: &'static str,
) -> Result<usize> {
let mut len = nibble;
if nibble != 15 {
return Ok(len);
}
loop {
let Some(byte) = data.get(*cursor).copied() else {
return Err(Error::Decompress {
message: format!(
"LZ4 {label} length extension is truncated. Fix: reject this malformed block before GPU dispatch."
),
});
};
*cursor += 1;
len = len.checked_add(usize::from(byte)).ok_or_else(|| Error::Decompress {
message: format!(
"LZ4 {label} length extension overflowed. Fix: reject this malformed block before GPU dispatch."
),
})?;
if byte != 255 {
return Ok(len);
}
}
}