use crate::ir_inner::model::program::BufferDecl;
use crate::ir_inner::model::types::BufferAccess;
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct OutputSet(Vec<u32>);
impl OutputSet {
#[must_use]
pub fn from_buffers(buffers: &[BufferDecl]) -> Self {
let mut indices = Vec::with_capacity(buffers.len());
for (index, buffer) in buffers.iter().enumerate() {
if buffer.access() == BufferAccess::ReadWrite {
indices.push(index as u32);
}
}
Self(indices)
}
pub fn encode_into(&self, dst: &mut Vec<u8>) -> Result<(), String> {
let count = u64::try_from(self.0.len()).map_err(|err| {
format!("Fix: output-set count cannot fit u64 ({err}); split the Program.")
})?;
dst.reserve(leb_u64_len(count) + self.0.len().saturating_mul(5));
put_leb_u64(dst, count);
for &index in &self.0 {
put_leb_u32(dst, index);
}
Ok(())
}
pub fn encode_from_buffers_into(
buffers: &[BufferDecl],
dst: &mut Vec<u8>,
) -> Result<(), String> {
let count = buffers
.iter()
.filter(|buffer| buffer.access() == BufferAccess::ReadWrite)
.count();
let count_u64 = u64::try_from(count).map_err(|err| {
format!("Fix: output-set count cannot fit u64 ({err}); split the Program.")
})?;
dst.reserve(leb_u64_len(count_u64) + count.saturating_mul(5));
put_leb_u64(dst, count_u64);
for (index, buffer) in buffers.iter().enumerate() {
if buffer.access() != BufferAccess::ReadWrite {
continue;
}
let index = u32::try_from(index).map_err(|err| {
format!("Fix: output-set index cannot fit u32 ({err}); split the Program.")
})?;
put_leb_u32(dst, index);
}
Ok(())
}
#[must_use]
pub fn from_vec(indices: Vec<u32>) -> Self {
Self(indices)
}
pub(crate) fn decode_from(
reader: &mut crate::serial::wire::Reader<'_>,
metadata: &crate::serial::wire::decode::from_wire::DecodedMetadata,
) -> Result<Self, String> {
use crate::serial::wire::decode::from_wire::LebReader;
let count = reader.leb_len(crate::serial::wire::MAX_BUFFERS, "output set count")?;
let mut indices = Vec::with_capacity(count);
for _ in 0..count {
let index = reader.leb_u32()?;
let usize_index = usize::try_from(index).map_err(|err| {
format!(
"InvalidDiscriminant: output-set index {index} cannot fit usize ({err}). Fix: reserialize with Program::to_wire()."
)
})?;
let Some(buffer) = metadata.buffers.get(usize_index) else {
return Err(format!(
"InvalidDiscriminant: output-set index {index} is out of bounds for {} buffers. Fix: reject tampered Program bytes.",
metadata.buffers.len()
));
};
if buffer.access != BufferAccess::ReadWrite {
return Err(format!(
"InvalidDiscriminant: output-set index {index} names non-writable buffer `{}`. Fix: reserialize with Program::to_wire().",
buffer.name
));
}
if indices.last().is_some_and(|previous| *previous >= index) {
return Err(format!(
"InvalidDiscriminant: output-set index {index} is not in strict declaration order. Fix: reserialize with Program::to_wire()."
));
}
indices.push(index);
}
if !matches_canonical_output_indices(&indices, metadata)? {
let expected = canonical_output_indices(metadata)?;
return Err(format!(
"InvalidDiscriminant: output-set {indices:?} does not match writable buffer indices {expected:?}. Fix: reject tampered Program bytes."
));
}
Ok(Self(indices))
}
#[must_use]
pub fn indices(&self) -> &[u32] {
&self.0
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
#[must_use]
pub fn len(&self) -> usize {
self.0.len()
}
#[must_use]
pub fn into_vec(self) -> Vec<u32> {
self.0
}
}
fn matches_canonical_output_indices(
indices: &[u32],
metadata: &crate::serial::wire::decode::from_wire::DecodedMetadata,
) -> Result<bool, String> {
let mut actual = indices.iter().copied();
for (index, buffer) in metadata.buffers.iter().enumerate() {
if buffer.access != BufferAccess::ReadWrite {
continue;
}
let expected = u32::try_from(index).map_err(|err| {
format!(
"InvalidDiscriminant: canonical output-set index cannot fit u32 ({err}). Fix: split the Program."
)
})?;
if actual.next() != Some(expected) {
return Ok(false);
}
}
Ok(actual.next().is_none())
}
fn canonical_output_indices(
metadata: &crate::serial::wire::decode::from_wire::DecodedMetadata,
) -> Result<Vec<u32>, String> {
let mut expected = Vec::with_capacity(metadata.buffers.len());
for (index, buffer) in metadata.buffers.iter().enumerate() {
if buffer.access != BufferAccess::ReadWrite {
continue;
}
expected.push(u32::try_from(index).map_err(|err| {
format!(
"InvalidDiscriminant: canonical output-set index cannot fit u32 ({err}). Fix: split the Program."
)
})?);
}
Ok(expected)
}
impl AsRef<[u32]> for OutputSet {
fn as_ref(&self) -> &[u32] {
&self.0
}
}
fn put_leb_u32(out: &mut Vec<u8>, value: u32) {
put_leb_u64(out, u64::from(value));
}
fn leb_u64_len(mut value: u64) -> usize {
let mut len = 1;
while value >= 0x80 {
value >>= 7;
len += 1;
}
len
}
fn put_leb_u64(out: &mut Vec<u8>, mut value: u64) {
loop {
let mut byte = (value & 0x7F) as u8;
value >>= 7;
if value != 0 {
byte |= 0x80;
}
out.push(byte);
if value == 0 {
break;
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ir::{BufferDecl, DataType};
#[test]
fn from_buffers_picks_read_write_only() {
let buffers = vec![
BufferDecl::read("a", 0, DataType::U32),
BufferDecl::read_write("b", 1, DataType::U32),
BufferDecl::read("c", 2, DataType::U32),
BufferDecl::read_write("d", 3, DataType::U32),
];
let set = OutputSet::from_buffers(&buffers);
assert_eq!(set.indices(), &[1, 3]);
}
#[test]
fn roundtrip_encode_decode() {
let buffers = vec![
BufferDecl::read("a", 0, DataType::U32),
BufferDecl::read_write("b", 1, DataType::U32),
BufferDecl::read_write("c", 2, DataType::U32),
];
let original = OutputSet::from_buffers(&buffers);
let mut encoded = Vec::new();
original.encode_into(&mut encoded).unwrap();
let metadata = crate::serial::wire::decode::from_wire::DecodedMetadata {
buffers: buffers
.iter()
.map(|b| crate::serial::wire::decode::from_wire::DecodedBuffer {
name: b.name().to_string(),
binding: b.binding(),
access: b.access(),
kind: b.kind(),
element: b.element(),
count: b.count(),
is_output: b.is_output(),
pipeline_live_out: b.is_pipeline_live_out(),
output_byte_range: b.output_byte_range(),
hints: b.hints(),
bytes_extraction: false,
})
.collect(),
..Default::default()
};
let mut reader = crate::serial::wire::Reader {
bytes: &encoded,
pos: 0,
depth: 0,
};
let decoded = OutputSet::decode_from(&mut reader, &metadata).unwrap();
assert_eq!(original, decoded);
}
#[test]
fn direct_buffer_encode_matches_materialized_output_set() {
let buffers = vec![
BufferDecl::read("a", 0, DataType::U32),
BufferDecl::read_write("b", 1, DataType::U32),
BufferDecl::read("c", 2, DataType::U32),
BufferDecl::read_write("d", 3, DataType::U32),
];
let mut materialized = Vec::new();
OutputSet::from_buffers(&buffers)
.encode_into(&mut materialized)
.unwrap();
let mut direct = Vec::new();
OutputSet::encode_from_buffers_into(&buffers, &mut direct).unwrap();
assert_eq!(direct, materialized);
}
}