use crate::bytemuck_safe::{safe_bytes_of_slice, safe_cast_slice};
use crate::engine::decode::{DecodeFormat, DecodeRules, DecodedRegion};
use crate::runtime::cache::{BufferPool, PooledBuffer};
use crate::runtime::{bg_entry, compile_compute_pipeline};
use vyre::{Error, Result};
use bytemuck::{Pod, Zeroable};
use std::sync::mpsc;
pub(crate) fn dispatch_decode(
format: DecodeFormat,
input: &[u8],
rules: &DecodeRules,
command_encoder: Option<&mut wgpu::CommandEncoder>,
) -> Result<Vec<DecodedRegion>> {
if input.is_empty() {
return Ok(Vec::new());
}
if input.len() > MAX_DECODE_INPUT_BYTES {
return Err(Error::Decode {
message: format!(
"decode input is {} bytes, exceeding {MAX_DECODE_INPUT_BYTES}. Fix: split the input before GPU decode dispatch.",
input.len()
),
});
}
let (device, queue) = crate::runtime::cached_device()?;
let input_len = u32::try_from(input.len()).map_err(|source| Error::Decode {
message: format!("input size {} exceeds u32::MAX: {source}. Fix: split the decode input before GPU dispatch.", input.len()),
})?;
validate_gpu_sizes(device, input_len, input.len())?;
let max_regions = input_len.min(MAX_DECODE_REGIONS);
let region_bytes = usize::try_from(max_regions)
.map_err(|source| Error::Decode {
message: format!("decode max_regions {max_regions} cannot fit usize: {source}. Fix: run on a supported target."),
})?
.checked_mul(size_of::<RegionMeta>())
.ok_or_else(|| Error::Decode {
message: "decode regions buffer size overflow. Fix: split the decode input before GPU dispatch.".to_string(),
})?;
let params = Params {
input_len,
min_run: format.min_run(rules),
max_regions,
output_size: input_len,
};
let pool = BufferPool::global();
let input_bytes = align_storage_bytes(input.len())?;
let input_buf = pool.acquire(
device,
"vyre decode input",
input_bytes,
wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
)?;
queue.write_buffer(&input_buf, 0, input);
write_zero_padding(queue, &input_buf, input.len(), input_bytes)?;
let regions_buf = zeroed_storage(device, "vyre decode regions", region_bytes.max(16))?;
let output_bytes = input.len().checked_mul(4).ok_or_else(|| Error::Decode {
message:
"decode output buffer size overflow. Fix: split the decode input before GPU dispatch."
.to_string(),
})?;
let output_buf = zeroed_storage(device, "vyre decode output", output_bytes.max(16))?;
let counters_buf = zeroed_storage(device, "vyre decode counters", 16)?;
let params_array = [params];
let params_bytes = safe_bytes_of_slice(¶ms_array);
let params_buf = pool.acquire(
device,
"vyre decode params",
u64::try_from(params_bytes.len()).map_err(|source| Error::Decode {
message: format!("decode params buffer size cannot fit u64: {source}. Fix: run on a supported target."),
})?,
wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
)?;
queue.write_buffer(¶ms_buf, 0, params_bytes);
vyre::ops::registry::gate::verify_certificate(format.op_id()).map_err(|source| Error::Decode {
message: source.to_string(),
})?;
let wgsl = format.wgsl();
let pipeline = compile_compute_pipeline(device, format.label(), &wgsl, "main")?;
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("vyre decode bind group"),
layout: &pipeline.get_bind_group_layout(0),
entries: &[
bg_entry(0, &input_buf),
bg_entry(1, ®ions_buf),
bg_entry(2, &output_buf),
bg_entry(3, &counters_buf),
bg_entry(4, ¶ms_buf),
],
});
let mut owned_encoder = command_encoder.is_none().then(|| {
device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("vyre decode encoder"),
})
});
let encoder = if let Some(encoder) = command_encoder {
encoder
} else {
owned_encoder
.as_mut()
.expect("owned encoder must be present when command_encoder is omitted")
};
encoder.clear_buffer(®ions_buf, 0, Some(regions_buf.size()));
encoder.clear_buffer(&output_buf, 0, Some(output_buf.size()));
encoder.clear_buffer(&counters_buf, 0, Some(counters_buf.size()));
{
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("vyre decode pass"),
timestamp_writes: None,
});
pass.set_pipeline(&pipeline);
pass.set_bind_group(0, &bind_group, &[]);
pass.dispatch_workgroups(input_len.div_ceil(WORKGROUP_SIZE), 1, 1);
}
let counters_readback = readback_buffer(device, encoder, &counters_buf, 16)?;
let regions_readback = readback_buffer(
device,
encoder,
®ions_buf,
u64::try_from(region_bytes.max(16)).map_err(|source| Error::Decode {
message: format!("decode region readback size cannot fit u64: {source}. Fix: split the decode input before GPU dispatch."),
})?,
)?;
let output_readback = readback_buffer(
device,
encoder,
&output_buf,
u64::try_from(output_bytes.max(16)).map_err(|source| Error::Decode {
message: format!("decode output readback size cannot fit u64: {source}. Fix: split the decode input before GPU dispatch."),
})?,
)?;
let Some(owned_encoder) = owned_encoder else {
return Err(Error::Decode {
message: "dispatch_decode was called with an external command encoder, but this API returns decoded readback data that is unavailable until the caller submits that encoder. Fix: call with `None` for immediate submit/readback, or add a deferred decode API that returns readback buffers.".to_string(),
});
};
let submission = queue.submit(Some(owned_encoder.finish()));
let regions = readback_regions(
device,
&counters_readback,
®ions_readback,
&output_readback,
input.len(),
submission,
)?;
pool.release(input_buf);
pool.release(regions_buf);
pool.release(output_buf);
pool.release(counters_buf);
pool.release(params_buf);
pool.release(counters_readback);
pool.release(regions_readback);
pool.release(output_readback);
Ok(regions)
}
pub fn readback_regions(
device: &wgpu::Device,
counters_readback: &wgpu::Buffer,
regions_readback: &wgpu::Buffer,
output_readback: &wgpu::Buffer,
input_len: usize,
submission: wgpu::SubmissionIndex,
) -> Result<Vec<DecodedRegion>> {
let counters = map_readback_u32(device, counters_readback, submission.clone())?;
let region_meta = map_readback_region_meta(device, regions_readback, submission.clone())?;
let output_words = map_readback_u32(device, output_readback, submission)?;
let max_regions = u32::try_from(input_len).map_err(|source| Error::Decode {
message: format!("input length {input_len} cannot fit u32 while bounding readback: {source}. Fix: split the decode input before GPU dispatch."),
})?.min(MAX_DECODE_REGIONS);
let region_count = usize::try_from(counters.first().copied().unwrap_or(0).min(max_regions))
.map_err(|source| Error::Decode {
message: format!("region count cannot fit usize: {source}. Fix: reject this GPU readback on this platform."),
})?;
if region_count
> usize::try_from(MAX_DECODE_REGIONS).map_err(|source| Error::Decode {
message: format!(
"MAX_DECODE_REGIONS cannot fit usize: {source}. Fix: run on a supported target."
),
})?
{
return Err(Error::Decode {
message: format!(
"GPU region count {region_count} exceeds {MAX_DECODE_REGIONS}. Fix: reject this malformed decoder output."
),
});
}
let mut decoded = Vec::with_capacity(region_count);
for meta in region_meta.into_iter().take(region_count) {
if meta.src_len == 0 || meta.dst_len == 0 {
continue;
}
let src_offset = usize::try_from(meta.src_offset).map_err(|source| Error::Decode {
message: format!("source offset {} cannot fit usize: {source}. Fix: reject this GPU readback on this platform.", meta.src_offset),
})?;
let src_len = usize::try_from(meta.src_len).map_err(|source| Error::Decode {
message: format!("source length {} cannot fit usize: {source}. Fix: reject this GPU readback on this platform.", meta.src_len),
})?;
let src_end = src_offset.checked_add(src_len).ok_or_else(|| Error::Decode {
message: "GPU readback src_offset+src_len overflow. Fix: reject this malformed decoder output.".to_string(),
})?;
if src_end > input_len {
return Err(Error::Decode {
message: format!(
"GPU source region [{src_offset}, {src_end}) exceeds input length {input_len}. Fix: reject this malformed decoder output and inspect the decode shader."
),
});
}
let dst_start = usize::try_from(meta.dst_offset).map_err(|source| Error::Decode {
message: format!("output offset {} cannot fit usize: {source}. Fix: reject this GPU readback on this platform.", meta.dst_offset),
})?;
let dst_end = dst_start
.checked_add(usize::try_from(meta.dst_len).map_err(|source| Error::Decode {
message: format!("output length {} cannot fit usize: {source}. Fix: reject this GPU readback on this platform.", meta.dst_len),
})?)
.ok_or_else(|| Error::Decode {
message: "output region overflow during readback. Fix: reject this malformed decoder output.".to_string(),
})?;
if dst_end > output_words.len() {
return Err(Error::Decode {
message: "shader emitted output beyond allocated readback storage. Fix: reject this malformed decoder output and inspect the decode shader.".to_string(),
});
}
let decoded_bytes = output_words[dst_start..dst_end]
.iter()
.map(|word| {
u8::try_from(*word & 0xff).map_err(|source| Error::Decode {
message: format!("masked output byte could not fit u8: {source}. Fix: report this impossible conversion failure."),
})
})
.collect::<Result<Vec<_>>>()?;
decoded.push(DecodedRegion {
offset: src_offset,
length: src_len,
decoded_bytes,
});
}
Ok(decoded)
}
/// Divisor that turns the input length into the workgroup count of the decode dispatch.
/// NOTE(review): presumably has to match the `@workgroup_size` of the decode shaders — confirm.
pub const WORKGROUP_SIZE: u32 = 64;
/// Bytes budgeted per region record when `validate_gpu_sizes` sizes the regions buffer.
/// NOTE(review): equals four `u32` fields, i.e. the layout of `RegionMeta` — keep the two in sync.
pub const REGION_META_SIZE: u64 = 16;
/// Largest input, in bytes, that `dispatch_decode` accepts (64 MiB); longer inputs are rejected.
pub const MAX_DECODE_INPUT_BYTES: usize = 64 * 1024 * 1024;
/// Upper bound on the number of region records allocated for and read back from one decode.
pub const MAX_DECODE_REGIONS: u32 = 1_000_000;
/// Parameter block that `dispatch_decode` uploads into the uniform buffer bound at binding 4.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
pub struct Params {
/// Length of the input in bytes.
input_len: u32,
/// Value of `DecodeFormat::min_run` for the active rules; interpreted by the shader.
min_run: u32,
/// Capacity of the regions buffer in records: `min(input_len, MAX_DECODE_REGIONS)`.
max_regions: u32,
/// Set to `input_len` by `dispatch_decode`, which allocates that many `u32` words of output.
output_size: u32,
}
/// One region record read back from the regions buffer as four consecutive `u32` words.
///
/// `readback_regions` skips records whose `src_len` or `dst_len` is zero.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
pub struct RegionMeta {
/// Start of the region within the input; `src_offset + src_len` must not exceed the input length.
src_offset: u32,
/// Length of the region within the input.
src_len: u32,
/// Index of the first output word that belongs to this region.
dst_offset: u32,
/// Number of output words that belong to this region.
dst_len: u32,
}
pub fn validate_gpu_sizes(device: &wgpu::Device, input_len: u32, input_size: usize) -> Result<()> {
let limits = device.limits();
let gpu_limit = u64::from(limits.max_storage_buffer_binding_size).min(limits.max_buffer_size);
let regions_bytes = u64::from(input_len.min(MAX_DECODE_REGIONS)) * REGION_META_SIZE;
let output_bytes = u64::from(input_len) * 4;
if regions_bytes > gpu_limit || output_bytes > gpu_limit {
return Err(Error::Gpu {
message: format!(
"input size {input_size} exceeds GPU buffer limit ({regions_bytes} byte regions buffer, {output_bytes} byte output buffer, {gpu_limit} limit). Fix: split the input or run on an adapter with larger storage buffers."
),
});
}
Ok(())
}
fn align_storage_bytes(len: usize) -> Result<u64> {
let aligned = len.max(1).next_multiple_of(4);
u64::try_from(aligned).map_err(|source| Error::Decode {
message: format!(
"decode input buffer size {aligned} cannot fit u64: {source}. Fix: split the decode input before GPU dispatch."
),
})
}
/// Writes zero bytes into `buffer` over the range `[written, total)`.
///
/// Does nothing when `written >= total`.
///
/// NOTE(review): wgpu validates that `write_buffer` offsets and sizes are
/// multiples of 4 bytes, so callers should pass a 4-byte aligned `written`
/// and `total` — confirm against the wgpu version in use.
///
/// # Errors
/// Returns `Error::Decode` when `written` does not fit `u64` or the padding
/// length does not fit `usize`.
fn write_zero_padding(
    queue: &wgpu::Queue,
    buffer: &wgpu::Buffer,
    written: usize,
    total: u64,
) -> Result<()> {
    const WORD_ZEROS: [u8; 4] = [0; 4];

    let written_u64 = u64::try_from(written).map_err(|source| Error::Decode {
        message: format!(
            "decode written byte count {written} cannot fit u64: {source}. Fix: split the decode input before GPU dispatch."
        ),
    })?;
    if written_u64 >= total {
        return Ok(());
    }
    let padding_len = usize::try_from(total - written_u64).map_err(|source| Error::Decode {
        message: format!(
            "decode zero padding length cannot fit usize: {source}. Fix: split the decode input before GPU dispatch."
        ),
    })?;

    // Gaps of up to one word come from the stack constant. Larger gaps fall
    // back to a heap buffer rather than panicking on an out-of-range slice.
    let heap_zeros;
    let zeros: &[u8] = match WORD_ZEROS.get(..padding_len) {
        Some(short) => short,
        None => {
            heap_zeros = vec![0u8; padding_len];
            &heap_zeros
        }
    };
    queue.write_buffer(buffer, written_u64, zeros);
    Ok(())
}
/// Maps `buffer` for reading, waits for `submission`, and returns the whole
/// buffer as a vector of `u32` words.
///
/// The buffer is unmapped again before returning, including when
/// reinterpreting the mapped bytes as `u32` fails.
///
/// # Errors
/// Returns `Error::Gpu` when the map callback never reports a result or
/// reports a failure, and `Error::Decode` when the mapped bytes cannot be
/// viewed as `u32` words.
pub fn map_readback_u32(
    device: &wgpu::Device,
    buffer: &wgpu::Buffer,
    submission: wgpu::SubmissionIndex,
) -> Result<Vec<u32>> {
    let slice = buffer.slice(..);
    let (sender, receiver) = mpsc::channel();
    slice.map_async(wgpu::MapMode::Read, move |result| {
        if let Err(send_err) = sender.send(result) {
            tracing::warn!(
                ?send_err,
                "decode readback receiver dropped before map_async result delivery"
            );
        }
    });
    // Block until the submission that fills the buffer has completed. Both
    // poll outcomes mean there is nothing further to wait for.
    match device.poll(wgpu::Maintain::wait_for(submission)) {
        wgpu::MaintainResult::Ok | wgpu::MaintainResult::SubmissionQueueEmpty => {}
    }
    receiver
        .recv()
        .map_err(|source| Error::Gpu {
            message: format!("readback channel closed unexpectedly: {source}. Fix: keep the decode readback receiver alive until map_async completes."),
        })?
        .map_err(|error| Error::Gpu {
            message: format!("map_async failed: {error:?}. Fix: check for device loss, adapter timeout, or invalid readback buffer usage."),
        })?;

    // Copy the words out, then unmap unconditionally before reporting a cast
    // failure, so the buffer is never left mapped on the error path.
    let mapped = slice.get_mapped_range();
    let words = safe_cast_slice::<u32>(&mapped)
        .map(|view| view.to_vec())
        .map_err(|error| Error::Decode {
            message: format!(
                "safe cast failed in map_readback_u32: {error}. Fix: ensure the readback buffer is aligned and sized correctly."
            ),
        });
    drop(mapped);
    buffer.unmap();
    words
}
pub fn map_readback_region_meta(
device: &wgpu::Device,
buffer: &wgpu::Buffer,
submission: wgpu::SubmissionIndex,
) -> Result<Vec<RegionMeta>> {
let words = map_readback_u32(device, buffer, submission)?;
let region_words = usize::try_from(MAX_DECODE_REGIONS)
.map_err(|source| Error::Decode {
message: format!(
"MAX_DECODE_REGIONS cannot fit usize: {source}. Fix: run on a supported target."
),
})?
.checked_mul(4)
.ok_or_else(|| Error::Decode {
message: "decode region word bound overflow. Fix: lower MAX_DECODE_REGIONS."
.to_string(),
})?;
if words.len() > region_words {
return Err(Error::Decode {
message: format!(
"decode region metadata contains {} u32 words, exceeding {region_words}. Fix: split the input before GPU decode dispatch.",
words.len()
),
});
}
let mut regions = Vec::with_capacity(words.len() / 4);
for chunk in words.chunks_exact(4) {
regions.push(RegionMeta {
src_offset: chunk[0],
src_len: chunk[1],
dst_offset: chunk[2],
dst_len: chunk[3],
});
}
Ok(regions)
}
pub fn zeroed_storage(device: &wgpu::Device, label: &str, bytes: usize) -> Result<PooledBuffer> {
BufferPool::global().acquire(
device,
label,
u64::try_from(bytes).map_err(|source| Error::Decode {
message: format!(
"decode zeroed storage size {bytes} cannot fit u64: {source}. Fix: split the decode input before GPU dispatch."
),
})?,
wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC | wgpu::BufferUsages::COPY_DST,
)
}
/// Acquires a mappable readback buffer from the global pool and records a
/// copy of the first `size` bytes of `source` into it on `encoder`.
///
/// The copy only takes place once the encoder is finished and submitted.
///
/// # Errors
/// Propagates pool acquisition errors; no copy is recorded in that case.
pub fn readback_buffer(
    device: &wgpu::Device,
    encoder: &mut wgpu::CommandEncoder,
    source: &wgpu::Buffer,
    size: u64,
) -> Result<PooledBuffer> {
    let usage = wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ;
    BufferPool::global()
        .acquire(device, "vyre decode readback", size, usage)
        .map(|staging| {
            encoder.copy_buffer_to_buffer(source, 0, &staging, 0, size);
            staging
        })
}