use crate::engine::decompress::{
align_to_copy, binding, bytes_for_u32s, decode_u32s, unpack_words_to_bytes,
};
use crate::runtime::cache::{BufferPool, PooledBuffer};
use std::sync::mpsc;
use vyre::error::{Error, Result};
use vyre::Program;
/// Borrowed inputs for one [`dispatch_decompression_kernel`] call.
///
/// All fields are plain data or borrowed slices, so the struct is cheap to copy.
#[derive(Debug, Clone, Copy)]
pub struct DecompressionKernelInputs<'a> {
    /// Compressed payload, uploaded to the read-only storage buffer at binding 0.
    pub data: &'a [u8],
    /// Pre-encoded parameter bytes, uploaded to the read-only storage buffer at binding 2.
    pub uniform_bytes: &'a [u8],
    /// Length of the output buffer (binding 1), in `u32` words.
    pub output_words_len: usize,
    /// Length of the status buffer (binding 3), in `u32` words.
    pub status_words_len: usize,
    /// Number of decompressed bytes to keep when unpacking the output words.
    pub total_output_size: u32,
    /// Number of workgroups dispatched along the x dimension.
    pub block_count: u32,
}
/// Static names used by [`dispatch_decompression_kernel`] for errors and GPU object labels.
///
/// `format` names the compression format in error messages. The remaining fields label
/// the buffers, layouts, pipeline, bind group, encoder and pass created for one dispatch.
#[derive(Debug, Clone, Copy)]
pub struct DecompressionKernelLabels {
    /// Compression format name, interpolated into error messages.
    pub format: &'static str,
    /// Label passed to the buffer pool for the compressed-input buffer.
    pub compressed_buffer: &'static str,
    /// Label passed to the buffer pool for the output buffer.
    pub output_buffer: &'static str,
    /// Label passed to the buffer pool for the status buffer.
    pub status_buffer: &'static str,
    /// Label passed to the buffer pool for the uniform/parameter buffer.
    pub uniforms_buffer: &'static str,
    /// Shader label. NOTE(review): not currently forwarded to wgpu by the dispatcher.
    pub shader: &'static str,
    /// Debug label of the bind group layout.
    pub bind_group_layout: &'static str,
    /// Debug label of the pipeline layout.
    pub pipeline_layout: &'static str,
    /// Label passed to the compute pipeline compiler.
    pub pipeline: &'static str,
    /// Debug label of the bind group.
    pub bind_group: &'static str,
    /// Label passed to the buffer pool for the output readback buffer.
    pub readback_output: &'static str,
    /// Label passed to the buffer pool for the status readback buffer.
    pub readback_status: &'static str,
    /// Debug label of the command encoder.
    pub command_encoder: &'static str,
    /// Debug label of the compute pass.
    pub pass: &'static str,
}
/// Result of a completed [`dispatch_decompression_kernel`] call.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct DispatchOutput {
    /// Decompressed bytes, truncated to the requested total output size.
    pub output: Vec<u8>,
    /// Raw status words read back from the status buffer.
    pub statuses: Vec<u32>,
}
pub fn dispatch_decompression_kernel(
device: &wgpu::Device,
queue: &wgpu::Queue,
inputs: DecompressionKernelInputs<'_>,
labels: DecompressionKernelLabels,
program: &Program,
encoder: Option<&mut wgpu::CommandEncoder>,
) -> Result<DispatchOutput> {
vyre::ops::registry::gate::verify_program_certificate(program).map_err(|source| {
Error::Decompress {
message: source.to_string(),
}
})?;
let pool = BufferPool::global();
let compressed_buffer =
create_storage_buffer(device, queue, labels.compressed_buffer, inputs.data)?;
let storage_usage =
wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC | wgpu::BufferUsages::COPY_DST;
let output_buffer = pool.acquire(
device,
labels.output_buffer,
bytes_for_u32s(inputs.output_words_len)?,
storage_usage,
)?;
let status_buffer = pool.acquire(
device,
labels.status_buffer,
bytes_for_u32s(inputs.status_words_len)?,
storage_usage,
)?;
let uniforms_buffer =
create_storage_buffer(device, queue, labels.uniforms_buffer, inputs.uniform_bytes)?;
let shader_source = vyre::lower::wgsl::lower(program).map_err(|source| Error::Decompress {
message: format!(
"failed to lower {} decompression IR to WGSL: {source}. Fix: repair the {} IR composition before GPU dispatch.",
labels.format, labels.format
),
})?;
let _shader_label = labels.shader;
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some(labels.bind_group_layout),
entries: &[
storage_binding(0, true),
storage_binding(1, false),
storage_binding(2, true),
storage_binding(3, false),
],
});
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some(labels.pipeline_layout),
bind_group_layouts: &[&bind_group_layout],
push_constant_ranges: &[],
});
let pipeline = crate::runtime::compile_compute_pipeline_with_layout(
device,
labels.pipeline,
&shader_source,
"main",
Some(&pipeline_layout),
)?;
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some(labels.bind_group),
layout: &bind_group_layout,
entries: &[
binding(0, &compressed_buffer),
binding(1, &output_buffer),
binding(2, &uniforms_buffer),
binding(3, &status_buffer),
],
});
let output_size = align_to_copy(bytes_for_u32s(inputs.output_words_len)?);
let status_size = align_to_copy(bytes_for_u32s(inputs.status_words_len)?);
let readback_output = readback_buffer(device, labels.readback_output, output_size)?;
let readback_status = readback_buffer(device, labels.readback_status, status_size)?;
let mut owned_encoder: Option<wgpu::CommandEncoder> = encoder.is_none().then(|| {
device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some(labels.command_encoder),
})
});
let encoder = if let Some(encoder) = encoder {
encoder
} else {
owned_encoder
.as_mut()
.expect("owned encoder must be present")
};
encoder.clear_buffer(&output_buffer, 0, Some(output_buffer.size()));
encoder.clear_buffer(&status_buffer, 0, Some(status_buffer.size()));
{
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some(labels.pass),
timestamp_writes: None,
});
pass.set_pipeline(&pipeline);
pass.set_bind_group(0, &bind_group, &[]);
pass.dispatch_workgroups(inputs.block_count, 1, 1);
}
encoder.copy_buffer_to_buffer(&output_buffer, 0, &readback_output, 0, output_size);
encoder.copy_buffer_to_buffer(&status_buffer, 0, &readback_status, 0, status_size);
let Some(owned_encoder) = owned_encoder else {
return Err(Error::Decompress {
message: "dispatch_decompression_kernel was called with an external encoder; use deferred dispatch mode to submit and map externally. Fix: pass `None` for immediate readback behavior.".to_string(),
});
};
let submission = queue.submit(Some(owned_encoder.finish()));
let output_bytes =
read_mapped_buffer(device, &readback_output, output_size, submission.clone())?;
let status_bytes = read_mapped_buffer(device, &readback_status, status_size, submission)?;
pool.release(compressed_buffer);
pool.release(output_buffer);
pool.release(status_buffer);
pool.release(uniforms_buffer);
pool.release(readback_output);
pool.release(readback_status);
let statuses = decode_u32s(&status_bytes)?;
let output_words = decode_u32s(&output_bytes)?;
let total_output_size =
usize::try_from(inputs.total_output_size).map_err(|source| Error::Decompress {
message: format!(
"{} output size {} cannot fit usize: {source}. Fix: reject this decompression descriptor on this platform.",
labels.format, inputs.total_output_size
),
})?;
let output = unpack_words_to_bytes(&output_words, total_output_size)?;
Ok(DispatchOutput { output, statuses })
}
pub fn create_storage_buffer(
device: &wgpu::Device,
queue: &wgpu::Queue,
label: &'static str,
bytes: &[u8],
) -> Result<PooledBuffer> {
let size = align_to_copy(u64::try_from(bytes.len()).map_err(|source| Error::Gpu {
message: format!(
"storage buffer '{label}' length {} cannot fit u64: {source}. Fix: split the upload.",
bytes.len()
),
})?);
let buffer = BufferPool::global().acquire(
device,
label,
size,
wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
)?;
queue.write_buffer(&buffer, 0, bytes);
write_zero_padding(queue, &buffer, bytes.len(), size)?;
Ok(buffer)
}
/// Acquires a pooled buffer that GPU copies can target and the CPU can map for reading.
///
/// `label` and `size` are forwarded to the global buffer pool unchanged. Any pool error
/// is returned as-is.
pub fn readback_buffer(
    device: &wgpu::Device,
    label: &'static str,
    size: u64,
) -> Result<PooledBuffer> {
    let usage = wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST;
    let pool = BufferPool::global();
    pool.acquire(device, label, size, usage)
}
/// Maps `buffer[0..size]` for reading once `submission` has completed, and returns a copy.
///
/// Blocks the calling thread on `device.poll(Maintain::wait_for(submission))`. On success
/// the buffer is unmapped before returning.
///
/// A zero `size` returns an empty vector without mapping anything, because wgpu's
/// `Buffer::slice` panics on an empty range.
///
/// # Errors
///
/// Returns [`Error::Gpu`] in these cases:
/// - `size` does not fit `usize`;
/// - the map callback's channel closes before delivering a result;
/// - the map operation itself reports a failure.
pub fn read_mapped_buffer(
    device: &wgpu::Device,
    buffer: &wgpu::Buffer,
    size: u64,
    submission: wgpu::SubmissionIndex,
) -> Result<Vec<u8>> {
    // Validate the length before mapping so this error path can never leave the buffer mapped.
    let output_len = usize::try_from(size).map_err(|source| Error::Gpu {
        message: format!("readback size cannot fit usize: {source}. Fix: split the workload."),
    })?;
    if output_len == 0 {
        return Ok(Vec::new());
    }
    let slice = buffer.slice(0..size);
    let (sender, receiver) = mpsc::sync_channel(1);
    slice.map_async(wgpu::MapMode::Read, move |result| {
        if let Err(send_err) = sender.send(result) {
            tracing::warn!(
                ?send_err,
                "decompress readback receiver dropped before map_async result delivery"
            );
        }
    });
    // Both poll results are acceptable here; a mapping failure is reported through the
    // callback result received below.
    match device.poll(wgpu::Maintain::wait_for(submission)) {
        wgpu::MaintainResult::Ok | wgpu::MaintainResult::SubmissionQueueEmpty => {}
    }
    receiver
        .recv()
        .map_err(|error| Error::Gpu {
            message: format!(
                "GPU readback channel closed: {error}. Fix: keep the readback receiver alive until map_async completes."
            ),
        })?
        .map_err(|error| Error::Gpu {
            message: format!(
                "GPU readback map failed: {error:?}. Fix: check for device loss, adapter timeout, or invalid readback buffer usage."
            ),
        })?;
    // The mapped view must be dropped before `unmap`, hence the inner scope.
    let bytes = {
        let mapped = slice.get_mapped_range();
        mapped[..output_len].to_vec()
    };
    buffer.unmap();
    Ok(bytes)
}
fn write_zero_padding(
queue: &wgpu::Queue,
buffer: &wgpu::Buffer,
written_len: usize,
total_size: u64,
) -> Result<()> {
let mut offset = u64::try_from(written_len).map_err(|source| Error::Gpu {
message: format!(
"written byte length {written_len} cannot fit u64: {source}. Fix: split the upload."
),
})?;
let zeros = [0u8; 4096];
while offset < total_size {
let remaining = usize::try_from((total_size - offset).min(4096)).map_err(|source| {
Error::Gpu {
message: format!(
"padding byte length cannot fit usize: {source}. Fix: run on a supported target."
),
}
})?;
queue.write_buffer(buffer, offset, &zeros[..remaining]);
offset += u64::try_from(remaining).map_err(|source| Error::Gpu {
message: format!(
"padding chunk length {remaining} cannot fit u64: {source}. Fix: run on a supported target."
),
})?;
}
Ok(())
}
/// Builds a compute-only storage-buffer layout entry for `binding`.
///
/// The entry has no dynamic offset, no minimum binding size and no array count.
/// `read_only` selects a read-only or a read-write storage binding.
pub fn storage_binding(binding: u32, read_only: bool) -> wgpu::BindGroupLayoutEntry {
    let storage = wgpu::BufferBindingType::Storage { read_only };
    let buffer_ty = wgpu::BindingType::Buffer {
        ty: storage,
        has_dynamic_offset: false,
        min_binding_size: None,
    };
    wgpu::BindGroupLayoutEntry {
        binding,
        visibility: wgpu::ShaderStages::COMPUTE,
        ty: buffer_ty,
        count: None,
    }
}