use smallvec::SmallVec;
use std::sync::Arc;
use vyre_foundation::ir::{BufferAccess, BufferDecl, MemoryKind, Program};
use crate::BackendError;
/// How a buffer takes part in a dispatch; assigned to each buffer declaration by `role_for_buffer`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BindingRole {
/// `BufferAccess::ReadOnly` buffer; consumes a dispatch input.
Input,
/// Buffer flagged `is_output` or declared `BufferAccess::WriteOnly` (when not shared or persistent); gets an output slot.
Output,
/// `BufferAccess::ReadWrite` buffer; consumes a dispatch input and gets an output slot.
InputOutput,
/// `BufferAccess::Uniform` buffer; consumes a dispatch input.
Uniform,
/// `MemoryKind::Shared` or `BufferAccess::Workgroup` buffer; the role itself assigns no input or output slot.
Shared,
/// `MemoryKind::Persistent` buffer; the role itself assigns no input or output slot.
Persistent,
}
/// Resolved binding information for one buffer declaration of a `Program`, produced by `BindingPlan` construction.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Binding {
/// Buffer name, shared (via `Arc`) with the declaration.
pub name: Arc<str>,
/// Binding slot number taken from `BufferDecl::binding()`.
pub binding: u32,
/// Index of the declaration in `Program::buffers()`.
pub buffer_index: usize,
/// Role assigned by `role_for_buffer`.
pub role: BindingRole,
/// Minimum byte size of one element (`element().min_bytes()`); 0 for runtime-sized element types.
pub element_size: usize,
/// Allocation alignment in bytes: the declared hint raised to at least `max(element_size, 1)`.
pub preferred_alignment: usize,
/// Declared element count. For runtime-sized buffers (declared count 0) it is derived from the
/// bound input's byte length when inputs are supplied, and is 0 when it cannot be derived.
pub element_count: u32,
/// `count * element_size` for statically sized buffers; `None` when the declared count is 0.
pub static_byte_len: Option<usize>,
/// Position of this buffer among the dispatch inputs, if its role consumes one.
pub input_index: Option<usize>,
/// Position of this buffer among the dispatch outputs, if its role produces one or it is `pipeline_live_out`.
pub output_index: Option<usize>,
}
/// Binding layout derived from a `Program`'s buffer declarations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BindingPlan {
/// One entry per buffer declaration, ordered by binding slot number.
pub bindings: Vec<Binding>,
/// Declaration indices (into `Program::buffers()`) of buffers that consume a dispatch input, in input order.
pub input_indices: Vec<usize>,
/// Declaration indices of buffers that get an output slot, in output order.
pub output_indices: Vec<usize>,
/// Declaration indices of buffers with the `Shared` role.
pub shared_indices: Vec<usize>,
}
/// Byte lengths of the dispatch inputs, abstracted over the two container shapes the public
/// API accepts (owned `Vec<u8>`s or borrowed byte slices).
///
/// `None` stands for "no inputs supplied"; it is what plans built without inputs use.
#[derive(Clone, Copy)]
enum InputLengths<'a> {
    None,
    Owned(&'a [Vec<u8>]),
    Borrowed(&'a [&'a [u8]]),
}

impl InputLengths<'_> {
    /// Number of inputs available; zero for [`InputLengths::None`].
    fn len(self) -> usize {
        match self {
            Self::Owned(inputs) => inputs.len(),
            Self::Borrowed(inputs) => inputs.len(),
            Self::None => 0,
        }
    }

    /// Byte length of the input at `index`, or `None` when there is no such input.
    fn get(self, index: usize) -> Option<usize> {
        let byte_len = match self {
            Self::None => return None,
            Self::Owned(inputs) => inputs.get(index)?.len(),
            Self::Borrowed(inputs) => inputs.get(index)?.len(),
        };
        Some(byte_len)
    }
}
impl BindingPlan {
pub fn build(program: &Program) -> Result<Self, BackendError> {
Self::build_inner(program, InputLengths::None, false)
}
pub fn from_program(program: &Program, inputs: &[Vec<u8>]) -> Result<Self, BackendError> {
Self::build_inner(program, InputLengths::Owned(inputs), true)
}
pub fn from_borrowed_inputs(program: &Program, inputs: &[&[u8]]) -> Result<Self, BackendError> {
Self::build_inner(program, InputLengths::Borrowed(inputs), true)
}
pub fn validate_inputs(&self, inputs: &[Vec<u8>]) -> Result<(), BackendError> {
self.validate_input_lengths(InputLengths::Owned(inputs))
}
pub fn validate_borrowed_inputs(&self, inputs: &[&[u8]]) -> Result<(), BackendError> {
self.validate_input_lengths(InputLengths::Borrowed(inputs))
}
fn validate_input_lengths(&self, input_lens: InputLengths<'_>) -> Result<(), BackendError> {
if input_lens.len() != self.input_indices.len() {
return Err(BackendError::InvalidProgram {
fix: format!(
"Fix: dispatch expected {} input buffer(s) from Program declarations but received {}.",
self.input_indices.len(),
input_lens.len()
),
});
}
for binding in &self.bindings {
if let Some(input_index) = binding.input_index {
let byte_len = input_lens.get(input_index).ok_or_else(|| {
BackendError::InvalidProgram {
fix: format!(
"Fix: dispatch input index {input_index} for `{}` was missing after input-count validation.",
binding.name
),
}
})?;
validate_input_len(binding, byte_len)?;
}
}
Ok(())
}
fn build_inner(
program: &Program,
input_lens: InputLengths<'_>,
validate_inputs_now: bool,
) -> Result<Self, BackendError> {
let mut ordered: SmallVec<[(usize, &BufferDecl); 16]> =
program.buffers().iter().enumerate().collect();
ordered.sort_by_key(|(_, buffer)| buffer.binding());
let mut bindings = Vec::with_capacity(ordered.len());
let mut input_indices = SmallVec::<[usize; 8]>::new();
let mut output_indices = SmallVec::<[usize; 8]>::new();
let mut shared_indices = SmallVec::<[usize; 4]>::new();
for (buffer_index, buffer) in ordered {
let role = role_for_buffer(buffer)?;
let consumes_input = matches!(
role,
BindingRole::Input | BindingRole::InputOutput | BindingRole::Uniform
);
let produces_output = matches!(role, BindingRole::Output | BindingRole::InputOutput);
let element_size = buffer.element().min_bytes();
let static_byte_len = static_byte_len(buffer, element_size)?;
let preferred_alignment = preferred_alignment(buffer, element_size)?;
let input_index = if consumes_input {
let index = input_indices.len();
input_indices.push(buffer_index);
Some(index)
} else {
None
};
let output_index = if produces_output || buffer.pipeline_live_out {
let index = output_indices.len();
output_indices.push(buffer_index);
Some(index)
} else {
None
};
if role == BindingRole::Shared {
shared_indices.push(buffer_index);
}
let element_count = if buffer.count() == 0 {
input_index
.and_then(|index| input_lens.get(index))
.and_then(|byte_len| {
if element_size == 0 {
None
} else {
u32::try_from(byte_len / element_size).ok()
}
})
.unwrap_or(0)
} else {
buffer.count()
};
bindings.push(Binding {
name: Arc::clone(&buffer.name),
binding: buffer.binding(),
buffer_index,
role,
element_size,
preferred_alignment,
element_count,
static_byte_len,
input_index,
output_index,
});
}
let plan = Self {
bindings,
input_indices: input_indices.into_vec(),
output_indices: output_indices.into_vec(),
shared_indices: shared_indices.into_vec(),
};
if validate_inputs_now {
plan.validate_input_lengths(input_lens)?;
}
Ok(plan)
}
}
/// Classifies a buffer declaration into the [`BindingRole`] used to plan its binding.
///
/// Precedence is shared/workgroup memory first, then persistent memory, then the explicit
/// `is_output` flag, and finally the buffer's access mode.
///
/// # Errors
///
/// Returns [`BackendError::InvalidProgram`] for a `BufferAccess` variant this mapping does not
/// recognise.
fn role_for_buffer(buffer: &BufferDecl) -> Result<BindingRole, BackendError> {
    let kind = buffer.kind();
    let access = buffer.access();
    let role = if kind == MemoryKind::Shared || access == BufferAccess::Workgroup {
        BindingRole::Shared
    } else if kind == MemoryKind::Persistent {
        BindingRole::Persistent
    } else if buffer.is_output {
        BindingRole::Output
    } else {
        match access {
            BufferAccess::ReadOnly => BindingRole::Input,
            BufferAccess::ReadWrite => BindingRole::InputOutput,
            BufferAccess::WriteOnly => BindingRole::Output,
            BufferAccess::Uniform => BindingRole::Uniform,
            BufferAccess::Workgroup => BindingRole::Shared,
            // Kept for forward compatibility with new `BufferAccess` variants.
            _ => {
                return Err(BackendError::InvalidProgram {
                    fix: format!(
                        "Fix: binding `{}` uses an unknown BufferAccess variant; update vyre-driver binding role mapping.",
                        buffer.name()
                    ),
                })
            }
        }
    };
    Ok(role)
}
/// Resolves the byte alignment a binding should be allocated with.
///
/// The declared `preferred_alignment` hint, where 0 means "no preference", is raised to at
/// least the element size and never goes below 1.
///
/// # Errors
///
/// Returns [`BackendError::InvalidProgram`] when the hint does not fit in `usize` or is neither
/// 0 nor a power of two.
fn preferred_alignment(buffer: &BufferDecl, element_size: usize) -> Result<usize, BackendError> {
    let raw_hint = buffer.hints().preferred_alignment;
    let hinted = usize::try_from(raw_hint).map_err(|_| BackendError::InvalidProgram {
        fix: format!(
            "Fix: binding `{}` preferred_alignment does not fit usize on this target.",
            buffer.name()
        ),
    })?;
    let hint_is_valid = hinted == 0 || hinted.is_power_of_two();
    if !hint_is_valid {
        return Err(BackendError::InvalidProgram {
            fix: format!(
                "Fix: binding `{}` preferred_alignment={} is not a power of two. Use 0 or a power-of-two byte alignment.",
                buffer.name(),
                hinted
            ),
        });
    }
    let floor = element_size.max(1);
    Ok(floor.max(hinted))
}
/// Computes the fixed byte length of a statically sized buffer (`count * element_size`).
///
/// Returns `Ok(None)` for runtime-sized buffers, i.e. those with a declared count of 0.
///
/// # Errors
///
/// Returns [`BackendError::InvalidProgram`] when a non-zero count is declared for an element
/// type with no fixed size, or when the byte length overflows `usize`.
fn static_byte_len(
    buffer: &BufferDecl,
    element_size: usize,
) -> Result<Option<usize>, BackendError> {
    let count = buffer.count();
    if count == 0 {
        return Ok(None);
    }
    if element_size == 0 {
        return Err(BackendError::InvalidProgram {
            fix: format!(
                "Fix: binding `{}` declares {} elements of a runtime-sized data type; use a byte-addressed buffer contract or a fixed-width element type.",
                buffer.name(),
                count
            ),
        });
    }
    let total_bytes = usize::try_from(count)
        .ok()
        .and_then(|elements| elements.checked_mul(element_size));
    match total_bytes {
        Some(bytes) => Ok(Some(bytes)),
        None => Err(BackendError::InvalidProgram {
            fix: format!(
                "Fix: binding `{}` static byte length overflowed usize; split the buffer or reduce element count.",
                buffer.name()
            ),
        }),
    }
}
/// Checks one dispatch input's byte length against the binding it feeds.
///
/// The length must be a whole number of elements (only checked for element sizes above 1). For
/// statically sized buffers it must also equal the declared byte length exactly.
///
/// # Errors
///
/// Returns [`BackendError::InvalidProgram`] for the first rule that is violated, checking
/// element alignment before the static length.
fn validate_input_len(binding: &Binding, input_len: usize) -> Result<(), BackendError> {
    let element_size = binding.element_size;
    let misaligned = element_size > 1 && input_len % element_size != 0;
    if misaligned {
        return Err(BackendError::InvalidProgram {
            fix: format!(
                "Fix: input `{}` has {} bytes, which is not aligned to its {}-byte element size.",
                binding.name, input_len, element_size
            ),
        });
    }
    match binding.static_byte_len {
        Some(expected) if input_len != expected => Err(BackendError::InvalidProgram {
            fix: format!(
                "Fix: input `{}` expected {expected} bytes from its static buffer declaration but received {} bytes.",
                binding.name,
                input_len
            ),
        }),
        _ => Ok(()),
    }
}
/// Order-normalised summary of a [`BindingPlan`]'s layout, with one
/// `(binding, role, element_size)` triple per binding.
///
/// Element counts and byte lengths are not included, so plans that differ only in buffer sizes
/// produce equal fingerprints.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BindingSetFingerprint {
    pub slots: Vec<(u32, BindingRole, usize)>,
}

impl BindingSetFingerprint {
    /// Builds the fingerprint of `plan`, with slots sorted by binding number.
    ///
    /// The sort is stable, so bindings sharing a number keep their order from `plan`.
    #[must_use]
    pub fn from_plan(plan: &BindingPlan) -> Self {
        let mut slots: Vec<(u32, BindingRole, usize)> = Vec::with_capacity(plan.bindings.len());
        for entry in &plan.bindings {
            slots.push((entry.binding, entry.role, entry.element_size));
        }
        slots.sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0));
        Self { slots }
    }
}
/// Reports whether two plans have the same binding layout.
///
/// Layouts match when the plans have equal [`BindingSetFingerprint`]s: the same binding
/// numbers, roles and element sizes, regardless of buffer sizes.
#[must_use]
pub fn binding_plans_share_layout(a: &BindingPlan, b: &BindingPlan) -> bool {
    let left = BindingSetFingerprint::from_plan(a);
    let right = BindingSetFingerprint::from_plan(b);
    left == right
}
/// One binding slot of a backend pipeline layout.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BackendLayoutSlot {
    pub group: u32,
    pub binding: u32,
    pub class: BackendLayoutClass,
    pub read_only: bool,
    pub element_size: usize,
}

/// Buffer class of a [`BackendLayoutSlot`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BackendLayoutClass {
    Storage,
    Uniform,
}

/// Backend layout slots in a canonical order.
///
/// Fingerprints built from the same slots compare and hash equal regardless of the order the
/// slots were supplied in.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BackendLayoutFingerprint {
    pub slots: Vec<BackendLayoutSlot>,
}

impl BackendLayoutFingerprint {
    /// Builds a fingerprint whose `slots` are ordered by `(group, binding)`.
    ///
    /// The sort is stable: slots sharing a `(group, binding)` pair keep their input order.
    #[must_use]
    pub fn new(slots: Vec<BackendLayoutSlot>) -> Self {
        let mut ordered = slots;
        ordered.sort_by(|lhs, rhs| {
            lhs.group
                .cmp(&rhs.group)
                .then(lhs.binding.cmp(&rhs.binding))
        });
        Self { slots: ordered }
    }
}
/// Tests for `BindingSetFingerprint`, `binding_plans_share_layout`, and `BackendLayoutFingerprint`.
#[cfg(test)]
mod n7_tests {
use super::*;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Program};
/// Read-only u32 `input` at binding 0 and u32 output `out` at binding 1, 16 elements each.
fn add_one_program() -> Program {
Program::wrapped(
vec![
BufferDecl::storage("input", 0, BufferAccess::ReadOnly, DataType::U32)
.with_count(16),
BufferDecl::output("out", 1, DataType::U32).with_count(16),
],
[16, 1, 1],
vec![],
)
}
/// Same bindings, access modes and element type as `add_one_program`, but 64 elements per buffer.
fn add_one_program_different_input_count() -> Program {
Program::wrapped(
vec![
BufferDecl::storage("input", 0, BufferAccess::ReadOnly, DataType::U32)
.with_count(64),
BufferDecl::output("out", 1, DataType::U32).with_count(64),
],
[16, 1, 1],
vec![],
)
}
/// Three bindings (read-only `a` and `b`, output `out`), so its layout differs from `add_one_program`.
fn different_layout_program() -> Program {
Program::wrapped(
vec![
BufferDecl::storage("a", 0, BufferAccess::ReadOnly, DataType::U32).with_count(16),
BufferDecl::storage("b", 1, BufferAccess::ReadOnly, DataType::U32).with_count(16),
BufferDecl::output("out", 2, DataType::U32).with_count(16),
],
[16, 1, 1],
vec![],
)
}
/// Element counts are not part of the fingerprint.
#[test]
fn same_layout_with_different_element_counts_shares_fingerprint() {
let a = BindingPlan::build(&add_one_program()).unwrap();
let b = BindingPlan::build(&add_one_program_different_input_count()).unwrap();
assert!(
binding_plans_share_layout(&a, &b),
"plans with same (binding, role, element_size) tuples must share layout"
);
}
/// An extra binding changes the fingerprint.
#[test]
fn different_binding_count_does_not_share_layout() {
let a = BindingPlan::build(&add_one_program()).unwrap();
let b = BindingPlan::build(&different_layout_program()).unwrap();
assert!(
!binding_plans_share_layout(&a, &b),
"plans with different binding count must not share layout"
);
}
/// Building the same program twice yields equal fingerprints.
#[test]
fn fingerprint_is_stable_across_repeated_builds() {
let a = BindingPlan::build(&add_one_program()).unwrap();
let b = BindingPlan::build(&add_one_program()).unwrap();
assert_eq!(
BindingSetFingerprint::from_plan(&a),
BindingSetFingerprint::from_plan(&b),
"repeated build of the same Program must produce identical fingerprints"
);
}
/// Fingerprint slots come out in ascending binding order.
#[test]
fn fingerprint_slots_are_sorted_by_binding_index() {
let plan = BindingPlan::build(&add_one_program()).unwrap();
let fp = BindingSetFingerprint::from_plan(&plan);
let indices: Vec<u32> = fp.slots.iter().map(|(i, _, _)| *i).collect();
assert_eq!(indices, [0, 1], "slots must be sorted by binding index");
}
/// The same slots supplied in different orders produce equal backend fingerprints.
#[test]
fn backend_layout_fingerprint_sorts_slots() {
let a = BackendLayoutFingerprint::new(vec![
BackendLayoutSlot {
group: 1,
binding: 4,
class: BackendLayoutClass::Storage,
read_only: false,
element_size: 4,
},
BackendLayoutSlot {
group: 0,
binding: 1,
class: BackendLayoutClass::Uniform,
read_only: true,
element_size: 4,
},
]);
let b = BackendLayoutFingerprint::new(vec![
BackendLayoutSlot {
group: 0,
binding: 1,
class: BackendLayoutClass::Uniform,
read_only: true,
element_size: 4,
},
BackendLayoutSlot {
group: 1,
binding: 4,
class: BackendLayoutClass::Storage,
read_only: false,
element_size: 4,
},
]);
assert_eq!(a, b);
}
}
/// Tests for how `BindingPlan` resolves each buffer's `preferred_alignment`.
#[cfg(test)]
mod tests {
use super::*;
use vyre_foundation::ir::{CacheLocality, DataType, MemoryHints};
/// A valid power-of-two hint larger than the element size is used as-is.
#[test]
fn binding_plan_carries_alignment_hints() {
let program = Program::wrapped(
vec![BufferDecl::output("out", 0, DataType::U32)
.with_count(16)
.with_hints(MemoryHints {
coalesce_axis: Some(0),
preferred_alignment: 64,
cache_locality: CacheLocality::Streaming,
})],
[64, 1, 1],
vec![],
);
let plan = BindingPlan::build(&program).expect("alignment hint should build");
assert_eq!(plan.bindings[0].preferred_alignment, 64);
}
/// A non-power-of-two hint is rejected, and the error names the offending value.
#[test]
fn binding_plan_rejects_non_power_of_two_alignment_hint() {
let program = Program::wrapped(
vec![BufferDecl::output("out", 0, DataType::U32)
.with_count(16)
.with_hints(MemoryHints {
coalesce_axis: None,
preferred_alignment: 48,
cache_locality: CacheLocality::Temporal,
})],
[64, 1, 1],
vec![],
);
let err = BindingPlan::build(&program).expect_err("bad alignment must fail");
assert!(format!("{err}").contains("preferred_alignment=48"));
}
/// Without an explicit hint, alignment falls back to the element size (4 bytes for u32).
#[test]
fn binding_plan_alignment_defaults_to_element_size() {
let program = Program::wrapped(
vec![BufferDecl::output("out", 0, DataType::U32).with_count(16)],
[64, 1, 1],
vec![],
);
let plan = BindingPlan::build(&program).expect("default alignment should build");
assert_eq!(plan.bindings[0].preferred_alignment, 4);
}
}