use smallvec::SmallVec;
use std::sync::Arc;
use vyre_foundation::ir::{BufferAccess, BufferDecl, MemoryKind, Program};
use crate::BackendError;
/// How a buffer binding participates in a dispatch, as assigned by `role_for_buffer`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BindingRole {
/// `ReadOnly` access: consumes one caller-supplied input, gets no output slot
/// (unless `pipeline_live_out` is set, in which case the buffer is classified `Output`).
Input,
/// `is_output`, `pipeline_live_out`, or `WriteOnly` access: gets an output slot, no input slot.
Output,
/// `ReadWrite` access: consumes an input and also gets an output slot.
InputOutput,
/// `Uniform` access: consumes a caller-supplied input, gets no output slot.
Uniform,
/// `MemoryKind::Shared` or `Workgroup` access: recorded in `BindingPlan::shared_indices`;
/// takes no input and only gets an output slot when `pipeline_live_out` is set.
Shared,
/// `MemoryKind::Persistent`: takes no input and only gets an output slot when
/// `pipeline_live_out` is set.
Persistent,
}
/// One planned buffer binding, derived from a single `BufferDecl` of the Program.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Binding {
/// Declared buffer name (shared with the `BufferDecl`).
pub name: Arc<str>,
/// Binding slot number taken from `BufferDecl::binding()`.
pub binding: u32,
/// Index of the source declaration in `Program::buffers()` (declaration order).
pub buffer_index: usize,
/// Role assigned by `role_for_buffer`.
pub role: BindingRole,
/// `min_bytes()` of the element data type (1 for packed sub-byte types such as `I4`).
pub element_size: usize,
/// Larger of the power-of-two alignment hint and `element_size` (never below 1).
pub preferred_alignment: usize,
/// Declared element count; for dynamic (`count == 0`) buffers, derived from the
/// input byte length when one is known, otherwise 0.
pub element_count: u32,
/// Exact packed byte length for buffers with a declared count; `None` for dynamic buffers.
pub static_byte_len: Option<usize>,
/// Position in the caller's input list when the role consumes an input.
pub input_index: Option<usize>,
/// Position in the output list when the binding produces an output.
pub output_index: Option<usize>,
}
/// Binding layout for a Program: every buffer binding sorted by binding slot, plus
/// index lists that map input/output/shared positions back to `Program::buffers()`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BindingPlan {
/// All bindings, ordered by binding slot (stable for equal slots).
pub bindings: Vec<Binding>,
/// `buffer_index` of each input-consuming binding; position `i` matches `Binding::input_index == Some(i)`.
pub input_indices: Vec<usize>,
/// `buffer_index` of each output-producing binding; position `i` matches `Binding::output_index == Some(i)`.
pub output_indices: Vec<usize>,
/// `buffer_index` of each `BindingRole::Shared` binding, in binding-slot order.
pub shared_indices: Vec<usize>,
}
/// Borrowed view over the different ways callers describe dispatch inputs.
/// Only the per-input byte lengths are ever read.
#[derive(Clone, Copy)]
enum InputLengths<'a> {
/// No inputs available (plan built from declarations alone).
None,
/// Owned input buffers.
Owned(&'a [Vec<u8>]),
/// Borrowed input slices.
Borrowed(&'a [&'a [u8]]),
/// Byte lengths supplied directly (the `from_input_lengths` / resident-input path).
Lengths(&'a [usize]),
}
impl InputLengths<'_> {
    /// Number of inputs this view describes (0 for `None`).
    fn len(self) -> usize {
        match self {
            Self::Owned(buffers) => buffers.len(),
            Self::Borrowed(slices) => slices.len(),
            Self::Lengths(byte_lengths) => byte_lengths.len(),
            Self::None => 0,
        }
    }

    /// Byte length of the input at `index`, or `None` when `index` is out of range
    /// (always `None` for the `None` view).
    fn get(self, index: usize) -> Option<usize> {
        match self {
            Self::Owned(buffers) => buffers.get(index).map(|buffer| buffer.len()),
            Self::Borrowed(slices) => slices.get(index).map(|slice| slice.len()),
            Self::Lengths(byte_lengths) => byte_lengths.get(index).copied(),
            Self::None => None,
        }
    }
}
impl BindingPlan {
    /// Plans bindings from the Program's buffer declarations alone.
    ///
    /// No input lengths are checked, and dynamically sized (`count == 0`) bindings get
    /// `element_count == 0` because there are no input bytes to size them from.
    ///
    /// # Errors
    /// Returns [`BackendError::InvalidProgram`] when a declaration cannot be planned
    /// (unknown access mode, malformed layout, invalid alignment hint, allocation failure).
    pub fn build(program: &Program) -> Result<Self, BackendError> {
        Self::build_inner(program, InputLengths::None, false)
    }

    /// Plans bindings and validates them against owned input buffers.
    ///
    /// # Errors
    /// Everything [`BindingPlan::build`] rejects, plus an input-count or byte-length mismatch.
    pub fn from_program(program: &Program, inputs: &[Vec<u8>]) -> Result<Self, BackendError> {
        Self::build_inner(program, InputLengths::Owned(inputs), true)
    }

    /// Same as [`BindingPlan::from_program`], for borrowed input slices.
    pub fn from_borrowed_inputs(program: &Program, inputs: &[&[u8]]) -> Result<Self, BackendError> {
        Self::build_inner(program, InputLengths::Borrowed(inputs), true)
    }

    /// Same as [`BindingPlan::from_program`], when only the input byte lengths are known.
    pub fn from_input_lengths(
        program: &Program,
        input_lengths: &[usize],
    ) -> Result<Self, BackendError> {
        Self::build_inner(program, InputLengths::Lengths(input_lengths), true)
    }

    /// Re-checks this plan against a fresh set of input byte lengths.
    pub fn validate_input_byte_lengths(&self, input_lengths: &[usize]) -> Result<(), BackendError> {
        self.validate_input_lengths(InputLengths::Lengths(input_lengths))
    }

    /// Re-checks this plan against owned input buffers.
    pub fn validate_inputs(&self, inputs: &[Vec<u8>]) -> Result<(), BackendError> {
        self.validate_input_lengths(InputLengths::Owned(inputs))
    }

    /// Re-checks this plan against borrowed input slices.
    pub fn validate_borrowed_inputs(&self, inputs: &[&[u8]]) -> Result<(), BackendError> {
        self.validate_input_lengths(InputLengths::Borrowed(inputs))
    }

    /// Checks the input count first, then each input-consuming binding's byte length,
    /// in binding order, stopping at the first failure.
    fn validate_input_lengths(&self, input_lens: InputLengths<'_>) -> Result<(), BackendError> {
        let expected = self.input_indices.len();
        let received = input_lens.len();
        if received != expected {
            return Err(BackendError::InvalidProgram {
                fix: format!(
                    "Fix: dispatch expected {expected} input buffer(s) from Program declarations but received {received}."
                ),
            });
        }
        let strict = matches!(input_lens, InputLengths::Lengths(_));
        self.bindings
            .iter()
            .filter_map(|binding| binding.input_index.map(|input_index| (binding, input_index)))
            .try_for_each(|(binding, input_index)| {
                let byte_len = input_lens.get(input_index).ok_or_else(|| {
                    BackendError::InvalidProgram {
                        fix: format!(
                            "Fix: dispatch input index {input_index} for `{}` was missing after input-count validation.",
                            binding.name
                        ),
                    }
                })?;
                validate_input_len(binding, byte_len, strict)
            })
    }

    /// Shared "could not reserve" diagnostic for the fallible reservations in `build_inner`.
    fn reservation_failure(
        count: usize,
        what: &str,
        error: impl std::fmt::Display,
    ) -> BackendError {
        BackendError::InvalidProgram {
            fix: format!(
                "Fix: binding-plan construction could not reserve {count} {what}: {error}. Split the program buffers or construct a smaller pipeline."
            ),
        }
    }

    /// Core planner shared by every public constructor.
    ///
    /// Buffers are visited in binding-slot order (stable for equal slots). Input and
    /// output positions are handed out in that same order. When `validate_inputs_now`
    /// is set, the finished plan is checked against `input_lens` before it is returned.
    fn build_inner(
        program: &Program,
        input_lens: InputLengths<'_>,
        validate_inputs_now: bool,
    ) -> Result<Self, BackendError> {
        let buffers = program.buffers();
        let buffer_count = buffers.len();

        let mut ordered = SmallVec::<[(usize, &BufferDecl); 16]>::new();
        vyre_foundation::allocation::try_reserve_smallvec_to_capacity(&mut ordered, buffer_count)
            .map_err(|error| {
                Self::reservation_failure(buffer_count, "ordered buffer slot(s)", error)
            })?;
        ordered.extend(buffers.iter().enumerate());
        ordered.sort_by_key(|(_, buffer)| buffer.binding());

        let mut bindings = Vec::new();
        crate::allocation::try_reserve_vec_to_capacity(&mut bindings, ordered.len()).map_err(
            |error| Self::reservation_failure(ordered.len(), "binding descriptor(s)", error),
        )?;

        let (input_slot_count, output_slot_count, shared_slot_count) =
            binding_role_counts(&ordered)?;
        let mut input_indices = SmallVec::<[usize; 8]>::new();
        let mut output_indices = SmallVec::<[usize; 8]>::new();
        let mut shared_indices = SmallVec::<[usize; 4]>::new();
        vyre_foundation::allocation::try_reserve_smallvec_to_capacity(
            &mut input_indices,
            input_slot_count,
        )
        .map_err(|error| {
            Self::reservation_failure(input_slot_count, "input index slot(s)", error)
        })?;
        vyre_foundation::allocation::try_reserve_smallvec_to_capacity(
            &mut output_indices,
            output_slot_count,
        )
        .map_err(|error| {
            Self::reservation_failure(output_slot_count, "output index slot(s)", error)
        })?;
        vyre_foundation::allocation::try_reserve_smallvec_to_capacity(
            &mut shared_indices,
            shared_slot_count,
        )
        .map_err(|error| {
            Self::reservation_failure(shared_slot_count, "shared index slot(s)", error)
        })?;

        for (buffer_index, buffer) in ordered {
            let role = role_for_buffer(buffer)?;
            buffer
                .element()
                .validate_layout()
                .map_err(|error| BackendError::InvalidProgram {
                    fix: format!(
                        "Fix: binding `{}` has malformed data-type layout metadata: {error}",
                        buffer.name()
                    ),
                })?;
            let element_size = buffer.element().min_bytes();
            let byte_len = static_byte_len(buffer)?;
            let alignment = preferred_alignment(buffer, element_size)?;

            let consumes_input = matches!(
                role,
                BindingRole::Input | BindingRole::InputOutput | BindingRole::Uniform
            );
            let produces_output =
                matches!(role, BindingRole::Output | BindingRole::InputOutput)
                    || buffer.pipeline_live_out;

            // Positions are assigned in visit order: index == length before the push.
            let input_index = if consumes_input {
                input_indices.push(buffer_index);
                Some(input_indices.len() - 1)
            } else {
                None
            };
            let output_index = if produces_output {
                output_indices.push(buffer_index);
                Some(output_indices.len() - 1)
            } else {
                None
            };
            if role == BindingRole::Shared {
                shared_indices.push(buffer_index);
            }

            // A declared count of 0 means "sized at dispatch time": derive it from the
            // matching input length when there is one, otherwise leave it at 0.
            let element_count = match buffer.count() {
                0 => input_index
                    .and_then(|index| input_lens.get(index))
                    .and_then(|len| dynamic_element_count_from_bytes(buffer, len))
                    .unwrap_or(0),
                declared => declared,
            };

            bindings.push(Binding {
                name: Arc::clone(&buffer.name),
                binding: buffer.binding(),
                buffer_index,
                role,
                element_size,
                preferred_alignment: alignment,
                element_count,
                static_byte_len: byte_len,
                input_index,
                output_index,
            });
        }

        let plan = Self {
            bindings,
            input_indices: input_indices.into_vec(),
            output_indices: output_indices.into_vec(),
            shared_indices: shared_indices.into_vec(),
        };
        if validate_inputs_now {
            plan.validate_input_lengths(input_lens)?;
        }
        Ok(plan)
    }
}
/// Counts how many input, output, and shared index slots `ordered` will need, so the
/// index lists can be reserved up front.
///
/// The classification must stay in sync with the per-binding decisions in
/// `BindingPlan::build_inner`.
///
/// # Errors
/// Propagates `role_for_buffer` failures. Also reports (unreachable in practice)
/// counter overflow as [`BackendError::InvalidProgram`].
fn binding_role_counts(
    ordered: &SmallVec<[(usize, &BufferDecl); 16]>,
) -> Result<(usize, usize, usize), BackendError> {
    // Each counter is bounded by `ordered.len()`, but keep the arithmetic checked so an
    // overflow would surface as a diagnostic rather than a silent wrap.
    let bump = |count: usize, hit: bool, what: &str| {
        count
            .checked_add(usize::from(hit))
            .ok_or_else(|| BackendError::InvalidProgram {
                fix: format!(
                    "Fix: binding-plan {what} role count overflowed usize. Split the program buffers before binding-plan construction."
                ),
            })
    };

    let mut totals = (0usize, 0usize, 0usize);
    for (_, buffer) in ordered.iter() {
        let role = role_for_buffer(buffer)?;
        let is_input = matches!(
            role,
            BindingRole::Input | BindingRole::InputOutput | BindingRole::Uniform
        );
        let is_output = matches!(role, BindingRole::Output | BindingRole::InputOutput)
            || buffer.pipeline_live_out;
        let is_shared = role == BindingRole::Shared;
        totals = (
            bump(totals.0, is_input, "input")?,
            bump(totals.1, is_output, "output")?,
            bump(totals.2, is_shared, "shared")?,
        );
    }
    Ok(totals)
}
/// Classifies a buffer declaration into a [`BindingRole`].
///
/// Precedence: shared memory (by kind or `Workgroup` access) wins first, then persistent
/// memory, then the explicit `is_output` / `pipeline_live_out` flags, and only then the
/// declared access mode.
///
/// # Errors
/// Returns [`BackendError::InvalidProgram`] for a `BufferAccess` variant this mapping
/// does not recognise.
fn role_for_buffer(buffer: &BufferDecl) -> Result<BindingRole, BackendError> {
    let kind = buffer.kind();
    let access = buffer.access();
    let role = if kind == MemoryKind::Shared || access == BufferAccess::Workgroup {
        BindingRole::Shared
    } else if kind == MemoryKind::Persistent {
        BindingRole::Persistent
    } else if buffer.is_output || buffer.pipeline_live_out {
        BindingRole::Output
    } else {
        match access {
            BufferAccess::ReadOnly => BindingRole::Input,
            BufferAccess::ReadWrite => BindingRole::InputOutput,
            BufferAccess::WriteOnly => BindingRole::Output,
            BufferAccess::Uniform => BindingRole::Uniform,
            BufferAccess::Workgroup => BindingRole::Shared,
            _ => {
                return Err(BackendError::InvalidProgram {
                    fix: format!(
                        "Fix: binding `{}` uses an unknown BufferAccess variant; update vyre-driver binding role mapping.",
                        buffer.name()
                    ),
                });
            }
        }
    };
    Ok(role)
}
/// Resolves the effective alignment of `buffer`: the larger of its `preferred_alignment`
/// hint and `element_size`, never below 1. A hint of 0 means "no preference".
///
/// # Errors
/// Rejects a hint that does not fit `usize`, or one that is neither 0 nor a power of two.
fn preferred_alignment(buffer: &BufferDecl, element_size: usize) -> Result<usize, BackendError> {
    let hinted = usize::try_from(buffer.hints().preferred_alignment).map_err(|_| {
        BackendError::InvalidProgram {
            fix: format!(
                "Fix: binding `{}` preferred_alignment does not fit usize on this target.",
                buffer.name()
            ),
        }
    })?;

    let hint_is_valid = hinted == 0 || hinted.is_power_of_two();
    if !hint_is_valid {
        return Err(BackendError::InvalidProgram {
            fix: format!(
                "Fix: binding `{}` preferred_alignment={} is not a power of two. Use 0 or a power-of-two byte alignment.",
                buffer.name(),
                hinted
            ),
        });
    }

    let element_floor = element_size.max(1);
    Ok(element_floor.max(hinted))
}
/// Exact packed byte length of a buffer with a declared element count.
///
/// Returns `Ok(None)` for dynamically sized buffers (`count() == 0`).
///
/// # Errors
/// Fails when the count does not fit `usize`, when the packed size cannot be computed,
/// or when the element type is runtime-sized (no fixed packed size).
fn static_byte_len(buffer: &BufferDecl) -> Result<Option<usize>, BackendError> {
    let declared = buffer.count();
    if declared == 0 {
        return Ok(None);
    }

    let count = usize::try_from(declared).map_err(|_| BackendError::InvalidProgram {
        fix: format!(
            "Fix: binding `{}` element count does not fit usize; split the buffer or reduce element count.",
            buffer.name()
        ),
    })?;

    let packed = buffer
        .element()
        .packed_size_bytes(count)
        .map_err(|error| BackendError::InvalidProgram {
            fix: format!(
                "Fix: binding `{}` static byte length could not be computed: {error}",
                buffer.name(),
            ),
        })?;

    match packed {
        Some(bytes) => Ok(Some(bytes)),
        None => Err(BackendError::InvalidProgram {
            fix: format!(
                "Fix: binding `{}` declares {} elements of a runtime-sized data type; use a byte-addressed buffer contract or a fixed-width element type.",
                buffer.name(),
                declared
            ),
        }),
    }
}
/// Derives how many elements of `buffer`'s data type fit in `byte_len` input bytes.
///
/// Types that report a bit width (including packed sub-byte types) are counted as
/// `floor(byte_len * 8 / bits)`. Otherwise the type's fixed byte size is used.
/// Returns `None` when the element width is unknown or zero, or when the count does
/// not fit in `u32`.
fn dynamic_element_count_from_bytes(buffer: &BufferDecl, byte_len: usize) -> Option<u32> {
    let count = match buffer.element().bit_width() {
        Some(bits) => {
            // `byte_len * 8` overflows `usize` on 32-bit targets once inputs reach 512 MiB,
            // which used to collapse the count to 0. Split the length into whole and
            // partial `bits`-sized chunks instead:
            //   floor(8 * (q * bits + r) / bits) == 8 * q + floor(8 * r / bits).
            // `checked_div` also turns a zero bit width into `None` instead of a panic.
            let whole = byte_len.checked_div(bits)?;
            let partial_bits = (byte_len % bits).checked_mul(8)?;
            whole.checked_mul(8)?.checked_add(partial_bits / bits)?
        }
        None => {
            let element_size = buffer.element().size_bytes()?;
            byte_len.checked_div(element_size)?
        }
    };
    u32::try_from(count).ok()
}
/// Checks one caller-supplied input length against its binding's contract.
///
/// When `element_size > 1`, the length must be a whole multiple of `element_size`.
/// When the buffer has a declared count (`static_byte_len` is `Some`), the length must
/// match it exactly. `_strict_static_input_len` is accepted for call-site compatibility
/// but is currently ignored.
fn validate_input_len(
    binding: &Binding,
    input_len: usize,
    _strict_static_input_len: bool,
) -> Result<(), BackendError> {
    let element_size = binding.element_size;
    let misaligned = element_size > 1 && input_len % element_size != 0;
    if misaligned {
        return Err(BackendError::InvalidProgram {
            fix: format!(
                "Fix: input `{}` has {} bytes, which is not aligned to its {}-byte element size.",
                binding.name, input_len, element_size
            ),
        });
    }

    match binding.static_byte_len {
        Some(expected) if expected != input_len => Err(BackendError::InvalidProgram {
            fix: format!(
                "Fix: input `{}` expected {expected} bytes from its static buffer declaration but received {} bytes.",
                binding.name,
                input_len
            ),
        }),
        _ => Ok(()),
    }
}
#[cfg(test)]
mod exact_length_tests {
    use super::*;
    use vyre_foundation::ir::DataType;

    /// One-buffer program whose only binding is a read-only `u32` input holding `count`
    /// elements. `count == 0` makes the input dynamically sized.
    fn static_u32_input_program(count: u32) -> Program {
        Program::wrapped(
            vec![BufferDecl::read("input", 0, DataType::U32).with_count(count)],
            [1, 1, 1],
            Vec::new(),
        )
    }

    // A declared u32[2] input must be exactly 8 bytes on every input path.
    #[test]
    fn static_input_lengths_are_exact_for_owned_borrowed_and_resident_inputs() {
        let program = static_u32_input_program(2);
        let short = vec![0u8; 4];
        let exact = vec![0u8; 8];
        let owned_err = BindingPlan::from_program(&program, &[short.clone()])
            .expect_err("owned static input length must be exact");
        assert!(owned_err.to_string().contains("expected 8 bytes"));
        assert!(BindingPlan::from_program(&program, &[exact]).is_ok());
        let borrowed_short = [short.as_slice()];
        let borrowed_err = BindingPlan::from_borrowed_inputs(&program, &borrowed_short)
            .expect_err("borrowed static input length must be exact");
        assert!(borrowed_err.to_string().contains("expected 8 bytes"));
        let resident_err = BindingPlan::from_input_lengths(&program, &[4])
            .expect_err("resident static input length must be exact");
        assert!(resident_err.to_string().contains("expected 8 bytes"));
    }

    // A dynamic u32 input's element count comes from its byte length (12 bytes -> 3).
    #[test]
    fn dynamic_input_length_sets_runtime_element_count() {
        let program = static_u32_input_program(0);
        let plan = BindingPlan::from_program(&program, &[vec![0u8; 12]]).expect(
            "Fix: a 12-byte dynamic u32 input should build a plan whose element count is derived from its byte length",
        );
        assert_eq!(plan.bindings[0].element_count, 3);
        assert_eq!(plan.bindings[0].static_byte_len, None);
    }

    // Dynamic inputs have no exact length, but must still be whole elements.
    #[test]
    fn dynamic_input_length_must_align_to_element_size() {
        let program = static_u32_input_program(0);
        let err = BindingPlan::from_program(&program, &[vec![0u8; 10]])
            .expect_err("a 10-byte input cannot hold whole u32 elements");
        assert!(
            err.to_string().contains("not aligned"),
            "Fix: misaligned dynamic input must report an alignment error: {err}"
        );
    }
}
/// Layout identity of a [`BindingPlan`]: one `(binding, role, element_size)` tuple per
/// binding, sorted by binding slot. Element counts and byte lengths are deliberately
/// excluded, so plans that differ only in buffer sizes compare equal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BindingSetFingerprint {
/// `(binding slot, role, element size)` tuples in ascending binding-slot order.
pub slots: Vec<(u32, BindingRole, usize)>,
}
impl BindingSetFingerprint {
    /// Captures `(binding, role, element_size)` for every binding of `plan`, ordered by
    /// binding slot. The sort is stable, so duplicate slots keep their plan order.
    #[must_use]
    pub fn from_plan(plan: &BindingPlan) -> Self {
        let mut slots: Vec<(u32, BindingRole, usize)> = Vec::with_capacity(plan.bindings.len());
        for binding in &plan.bindings {
            slots.push((binding.binding, binding.role, binding.element_size));
        }
        slots.sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0));
        Self { slots }
    }
}
/// Returns `true` when both plans have the same `(binding, role, element_size)` layout,
/// ignoring element counts (see [`BindingSetFingerprint`]).
#[must_use]
pub fn binding_plans_share_layout(a: &BindingPlan, b: &BindingPlan) -> bool {
    // Differing binding counts can never produce equal fingerprints; skip building them.
    if a.bindings.len() != b.bindings.len() {
        return false;
    }
    let lhs = BindingSetFingerprint::from_plan(a);
    let rhs = BindingSetFingerprint::from_plan(b);
    lhs.slots == rhs.slots
}
/// One backend-level binding slot, the unit compared by [`BackendLayoutFingerprint`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BackendLayoutSlot {
/// Group index (primary sort key in `BackendLayoutFingerprint::new`).
pub group: u32,
/// Binding index within the group (secondary sort key).
pub binding: u32,
/// Storage vs. uniform buffer class.
pub class: BackendLayoutClass,
/// Whether the slot is bound read-only.
pub read_only: bool,
/// Element size in bytes.
pub element_size: usize,
}
/// Buffer class of a [`BackendLayoutSlot`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BackendLayoutClass {
/// Storage buffer binding.
Storage,
/// Uniform buffer binding.
Uniform,
}
/// Backend layout identity: slots kept sorted by `(group, binding)`, so construction
/// order does not affect equality or hashing as long as each `(group, binding)` pair
/// is unique.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BackendLayoutFingerprint {
/// Slots in ascending `(group, binding)` order.
pub slots: Vec<BackendLayoutSlot>,
}
impl BackendLayoutFingerprint {
    /// Builds a fingerprint from `slots` in any order, sorting them by `(group, binding)`.
    /// The sort is stable, so slots sharing a key keep their given order.
    #[must_use]
    pub fn new(slots: Vec<BackendLayoutSlot>) -> Self {
        let mut slots = slots;
        slots.sort_by(|lhs, rhs| {
            lhs.group
                .cmp(&rhs.group)
                .then(lhs.binding.cmp(&rhs.binding))
        });
        Self { slots }
    }
}
#[cfg(test)]
mod n7_tests {
    use super::*;
    use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Program};

    /// `input` (read-only u32, binding 0) plus `out` (u32 output, binding 1), both
    /// holding `count` elements.
    fn add_one_program_with_count(count: u32) -> Program {
        Program::wrapped(
            vec![
                BufferDecl::storage("input", 0, BufferAccess::ReadOnly, DataType::U32)
                    .with_count(count),
                BufferDecl::output("out", 1, DataType::U32).with_count(count),
            ],
            [16, 1, 1],
            vec![],
        )
    }

    /// Two read-only inputs plus one output: one more binding than the add-one program.
    fn different_layout_program() -> Program {
        Program::wrapped(
            vec![
                BufferDecl::storage("a", 0, BufferAccess::ReadOnly, DataType::U32).with_count(16),
                BufferDecl::storage("b", 1, BufferAccess::ReadOnly, DataType::U32).with_count(16),
                BufferDecl::output("out", 2, DataType::U32).with_count(16),
            ],
            [16, 1, 1],
            vec![],
        )
    }

    /// Builds a plan from declarations only, panicking on failure.
    fn plan_of(program: Program) -> BindingPlan {
        BindingPlan::build(&program).unwrap()
    }

    #[test]
    fn same_layout_with_different_element_counts_shares_fingerprint() {
        let small = plan_of(add_one_program_with_count(16));
        let large = plan_of(add_one_program_with_count(64));
        assert!(
            binding_plans_share_layout(&small, &large),
            "plans with same (binding, role, element_size) tuples must share layout"
        );
    }

    #[test]
    fn different_binding_count_does_not_share_layout() {
        let two_bindings = plan_of(add_one_program_with_count(16));
        let three_bindings = plan_of(different_layout_program());
        assert!(
            !binding_plans_share_layout(&two_bindings, &three_bindings),
            "plans with different binding count must not share layout"
        );
    }

    #[test]
    fn fingerprint_is_stable_across_repeated_builds() {
        let first = plan_of(add_one_program_with_count(16));
        let second = plan_of(add_one_program_with_count(16));
        assert_eq!(
            BindingSetFingerprint::from_plan(&first),
            BindingSetFingerprint::from_plan(&second),
            "repeated build of the same Program must produce identical fingerprints"
        );
    }

    #[test]
    fn fingerprint_slots_are_sorted_by_binding_index() {
        let fingerprint = BindingSetFingerprint::from_plan(&plan_of(add_one_program_with_count(16)));
        let indices: Vec<u32> = fingerprint.slots.iter().map(|(slot, _, _)| *slot).collect();
        assert_eq!(indices, [0, 1], "slots must be sorted by binding index");
    }

    #[test]
    fn backend_layout_fingerprint_sorts_slots() {
        let storage = BackendLayoutSlot {
            group: 1,
            binding: 4,
            class: BackendLayoutClass::Storage,
            read_only: false,
            element_size: 4,
        };
        let uniform = BackendLayoutSlot {
            group: 0,
            binding: 1,
            class: BackendLayoutClass::Uniform,
            read_only: true,
            element_size: 4,
        };
        let unsorted = BackendLayoutFingerprint::new(vec![storage, uniform]);
        let presorted = BackendLayoutFingerprint::new(vec![uniform, storage]);
        assert_eq!(unsorted, presorted);
    }
}
// Binding-plan tests for alignment hints, packed sub-byte layouts, malformed layouts,
// and re-validation of cached plans against fresh input lengths.
#[cfg(test)]
mod tests {
use super::*;
use vyre_foundation::ir::{CacheLocality, DataType, MemoryHints};
// A valid power-of-two alignment hint larger than the element size is carried through.
#[test]
fn binding_plan_carries_alignment_hints() {
let program = Program::wrapped(
vec![BufferDecl::output("out", 0, DataType::U32)
.with_count(16)
.with_hints(MemoryHints {
coalesce_axis: Some(0),
preferred_alignment: 64,
cache_locality: CacheLocality::Streaming,
})],
[64, 1, 1],
vec![],
);
let plan = BindingPlan::build(&program).expect("Fix: alignment hint should build");
assert_eq!(plan.bindings[0].preferred_alignment, 64);
}
// A non-power-of-two hint (48) is rejected, and the diagnostic names the bad value.
#[test]
fn binding_plan_rejects_non_power_of_two_alignment_hint() {
let program = Program::wrapped(
vec![BufferDecl::output("out", 0, DataType::U32)
.with_count(16)
.with_hints(MemoryHints {
coalesce_axis: None,
preferred_alignment: 48,
cache_locality: CacheLocality::Temporal,
})],
[64, 1, 1],
vec![],
);
let err = BindingPlan::build(&program).expect_err("bad alignment must fail");
assert!(format!("{err}").contains("preferred_alignment=48"));
}
// With no hint, alignment falls back to the element size (4 bytes for u32).
#[test]
fn binding_plan_alignment_defaults_to_element_size() {
let program = Program::wrapped(
vec![BufferDecl::output("out", 0, DataType::U32).with_count(16)],
[64, 1, 1],
vec![],
);
let plan = BindingPlan::build(&program).expect("Fix: default alignment should build");
assert_eq!(plan.bindings[0].preferred_alignment, 4);
}
// Three I4 elements pack into 2 bytes; element_size reports the 1-byte minimum.
#[test]
fn binding_plan_uses_packed_static_byte_len_for_subbyte_elements() {
let program = Program::wrapped(
vec![
BufferDecl::storage("packed_i4", 0, BufferAccess::ReadOnly, DataType::I4)
.with_count(3),
],
[1, 1, 1],
vec![],
);
let plan =
BindingPlan::build(&program).expect("Fix: packed I4 binding layout should build");
assert_eq!(plan.bindings[0].element_size, 1);
assert_eq!(plan.bindings[0].static_byte_len, Some(2));
}
// Packed I4 inputs must use the packed byte count (2), not one byte per element (3).
#[test]
fn binding_plan_validates_packed_static_input_lengths() {
let program = Program::wrapped(
vec![
BufferDecl::storage("packed_i4", 0, BufferAccess::ReadOnly, DataType::I4)
.with_count(3),
],
[1, 1, 1],
vec![],
);
let plan = BindingPlan::from_input_lengths(&program, &[2])
.expect("Fix: packed I4 input should accept the exact packed byte count");
plan.validate_input_byte_lengths(&[2])
.expect("Fix: cached packed I4 input length should remain valid");
let error = plan
.validate_input_byte_lengths(&[3])
.expect_err("unpacked byte length must not satisfy packed I4 contract");
assert!(
format!("{error}").contains("expected 2 bytes"),
"Fix: packed byte mismatch must be explicit: {error}"
);
}
// validate_layout() failures (here a zero-lane Vec) surface through build() unchanged.
#[test]
fn binding_plan_rejects_malformed_data_type_layouts() {
let program = Program::wrapped(
vec![BufferDecl::output(
"bad_vec",
0,
DataType::Vec {
element: Box::new(DataType::U32),
count: 0,
},
)
.with_count(1)],
[1, 1, 1],
vec![],
);
let error = BindingPlan::build(&program)
.expect_err("zero-lane vector layout must not enter binding planning");
assert!(
format!("{error}").contains("Vec count must be > 0"),
"Fix: malformed data-type layout diagnostics must survive binding planning: {error}"
);
}
// A plan cached from one set of lengths must reject a later call with a different length.
#[test]
fn binding_plan_validates_cached_resident_input_lengths() {
let program = Program::wrapped(
vec![
BufferDecl::read("in", 0, DataType::U32).with_count(4),
BufferDecl::output("out", 1, DataType::U32).with_count(4),
],
[4, 1, 1],
vec![],
);
let plan = BindingPlan::from_input_lengths(&program, &[16])
.expect("Fix: resident input length should match the declared u32[4] input");
plan.validate_input_byte_lengths(&[16])
.expect("Fix: cached resident plan should accept the same input byte length");
let error = plan
.validate_input_byte_lengths(&[12])
.expect_err("cached resident plan must reject stale pipeline shape reuse");
assert!(
format!("{error}").contains("expected 16 bytes"),
"wrong resident input length must produce an actionable size mismatch: {error}"
);
}
}