use crate::spec::types::OpSpec;
use bindings::{check_binding_0, check_binding_1, check_binding_2, check_binding_3};
use parse::{extract_bindings, parse_wgsl, resolve_wgsl};
use prototype::check_vyre_op_prototype;
/// Validates that the WGSL behind `op` matches the op's declared signature.
///
/// The shader source is resolved (and wrapped when it declares no `@group(0)`
/// bindings), parsed with naga, and then checked in this order:
/// binding 0 (input), binding 1 (output), binding 2 (uniform params),
/// binding 3 (convention-dependent lookup buffer), and finally the
/// `vyre_op` function prototype.
///
/// # Errors
///
/// Returns the message of the first step that fails; later checks are not run.
#[inline]
pub fn enforce_signature(op: &OpSpec) -> Result<(), String> {
    let source = resolve_wgsl(op)?;
    let module = parse_wgsl(&source, op)?;
    let found = extract_bindings(&module)?;
    check_binding_0(&module, &found, op)
        .and_then(|()| check_binding_1(&module, &found, op))
        .and_then(|()| check_binding_2(&module, &found, op))
        .and_then(|()| check_binding_3(&module, &found, op))
        .and_then(|()| check_vyre_op_prototype(&module, op))
}
/// Per-slot layout checks for the `@group(0)` resources an op shader declares.
///
/// Slots checked here:
/// - binding 0: `var<storage, read>` input buffer sized to the input signature
/// - binding 1: `var<storage, read_write>` output buffer sized to the output signature
/// - binding 2: `var<uniform>` struct whose first two fields are `input_len: u32`
///   and `output_len: u32`
/// - binding 3: `var<storage, read>` lookup buffer, required when the convention is
///   `Convention::V2` and rejected otherwise
mod bindings {
    use crate::spec::types::{Convention, DataType, OpSpec};
    use naga::{AddressSpace, StorageAccess, TypeInner};
    use crate::enforce::enforcers::signature_match::buffer_element::{
        buffer_element, check_array_stride, check_variable_buffer_element, variable_input_type,
    };
    use crate::enforce::enforcers::signature_match::parse::BindingInfo;
    use crate::enforce::enforcers::signature_match::types::{
        access_desc, format_inputs, type_name,
    };

    /// Looks up `@group(0) @binding(slot)` in the extracted binding list.
    ///
    /// When the same slot appears more than once, the first entry wins.
    fn find_binding(bindings: &[(u32, BindingInfo)], slot: u32) -> Option<&BindingInfo> {
        bindings
            .iter()
            .find(|(binding, _)| *binding == slot)
            .map(|(_, info)| info)
    }

    /// Renders an address space roughly as it would appear inside WGSL `var<...>`.
    ///
    /// Storage spaces become `storage, <access>`. Every other space falls back to
    /// its lower-cased `Debug` name (e.g. `uniform`, `workgroup`), which is
    /// readable but not guaranteed to be exact WGSL for every naga variant.
    fn wgsl_space(space: &AddressSpace) -> String {
        match space {
            AddressSpace::Storage { access } => format!("storage, {}", access_desc(*access)),
            other => format!("{other:?}").to_lowercase(),
        }
    }

    /// Checks binding 0: a read-only storage buffer compatible with the input signature.
    ///
    /// For a variable-size input (`Bytes` / `Array`), the buffer element and array
    /// stride must match that declared type. Otherwise, the element size must not
    /// exceed `op.signature.min_input_bytes()` and must evenly divide it.
    pub(super) fn check_binding_0(
        module: &naga::Module,
        bindings: &[(u32, BindingInfo)],
        op: &OpSpec,
    ) -> Result<(), String> {
        let Some(info) = find_binding(bindings, 0) else {
            return Err(
                "Fix: shader is missing @group(0) @binding(0). Add a read-only storage buffer for input."
                    .to_string(),
            );
        };
        match info.space {
            // Exactly read-only: `read_write` is rejected for the input buffer.
            AddressSpace::Storage { access } if access == StorageAccess::LOAD => {}
            AddressSpace::Storage { access } => {
                return Err(format!(
                    "Fix: binding 0 must be `var<storage, read>`. It is currently `var<storage, {}>`. Change the access mode to `read`.",
                    access_desc(access)
                ));
            }
            _ => {
                return Err(
                    "Fix: binding 0 must be a `storage` buffer. Change it to `var<storage, read>`."
                        .to_string(),
                );
            }
        }
        let element = buffer_element(module, info.ty)?;
        if let Some(variable) = variable_input_type(&op.signature.inputs) {
            check_variable_buffer_element(variable, &element, "binding 0")?;
            check_array_stride(module, info.ty, variable, "binding 0")?;
            return Ok(());
        }
        let element_size = element.byte_size();
        let min_bytes = op.signature.min_input_bytes();
        if element_size > min_bytes {
            return Err(format!(
                "Fix: declared input signature {} (minimum {} bytes) but shader binding 0 has element size {} bytes. Either change signature.inputs to a larger type or rewrite the shader binding to use smaller elements.",
                format_inputs(&op.signature.inputs),
                min_bytes,
                element_size
            ));
        }
        // Reached only when element_size <= min_bytes, so the input spans whole elements
        // exactly when the remainder is zero.
        if min_bytes % element_size != 0 {
            return Err(format!(
                "Fix: declared input signature {} ({} bytes) but shader binding 0 element size {} does not evenly divide the input size. Either change signature.inputs or rewrite the shader to use a compatible element type.",
                format_inputs(&op.signature.inputs),
                min_bytes,
                element_size
            ));
        }
        Ok(())
    }

    /// Checks binding 1: a read_write storage buffer compatible with the output type.
    ///
    /// Variable-size outputs (`Bytes` / `Array`) are checked element- and stride-wise.
    /// A fixed-size output requires the buffer element to be exactly
    /// `output.min_bytes()` bytes.
    pub(super) fn check_binding_1(
        module: &naga::Module,
        bindings: &[(u32, BindingInfo)],
        op: &OpSpec,
    ) -> Result<(), String> {
        let Some(info) = find_binding(bindings, 1) else {
            return Err(
                "Fix: shader is missing @group(0) @binding(1). Add a read_write storage buffer for output."
                    .to_string(),
            );
        };
        match info.space {
            AddressSpace::Storage { access }
                if access.contains(StorageAccess::LOAD)
                    && access.contains(StorageAccess::STORE) => {}
            AddressSpace::Storage { access } => {
                return Err(format!(
                    "Fix: binding 1 must be `var<storage, read_write>`. It is currently `var<storage, {}>`. Change the access mode to `read_write`.",
                    access_desc(access)
                ));
            }
            _ => {
                return Err(
                    "Fix: binding 1 must be a `storage, read_write` buffer. Change it to `var<storage, read_write>`."
                        .to_string(),
                );
            }
        }
        let element = buffer_element(module, info.ty)?;
        match &op.signature.output {
            DataType::Bytes | DataType::Array { .. } => {
                check_variable_buffer_element(&op.signature.output, &element, "binding 1")?;
                check_array_stride(module, info.ty, &op.signature.output, "binding 1")?;
            }
            output => {
                let element_size = element.byte_size();
                let out_bytes = output.min_bytes();
                if element_size != out_bytes {
                    let declared = output.to_string();
                    return Err(format!(
                        "Fix: declared {} output ({} bytes) but shader binding 1 element size is {} bytes. Either change signature.output to match the shader or rewrite the shader binding to produce {}.",
                        declared,
                        out_bytes,
                        element_size,
                        declared
                    ));
                }
            }
        }
        Ok(())
    }

    /// Checks binding 2: a uniform struct whose first two fields are
    /// `input_len: u32` and `output_len: u32`, in that order.
    ///
    /// Extra trailing fields are allowed. `_op` is unused and is accepted only
    /// for signature parity with the other binding checks.
    pub(super) fn check_binding_2(
        module: &naga::Module,
        bindings: &[(u32, BindingInfo)],
        _op: &OpSpec,
    ) -> Result<(), String> {
        let Some(info) = find_binding(bindings, 2) else {
            return Err(
                "Fix: shader is missing @group(0) @binding(2). Add a uniform Params struct with input_len and output_len fields."
                    .to_string(),
            );
        };
        if !matches!(info.space, AddressSpace::Uniform) {
            return Err(format!(
                "Fix: binding 2 must be a `uniform` struct (Params). It is currently `var<{}> {}`. Change it to `var<uniform>`.",
                wgsl_space(&info.space),
                type_name(module, info.ty)
            ));
        }
        let TypeInner::Struct { members, .. } = &module.types[info.ty].inner else {
            let ty_name = type_name(module, info.ty);
            return Err(format!(
                "Fix: binding 2 must be a uniform struct with input_len/output_len fields, but it is type '{}'. Declare it as `struct Params {{ input_len: u32, output_len: u32, ... }}`.",
                ty_name
            ));
        };
        // Verifies that member `idx` exists, is named `expected`, and has type u32.
        let check_field = |idx: usize, expected: &str| -> Result<(), String> {
            let Some(member) = members.get(idx) else {
                return Err(format!(
                    "Fix: binding 2 uniform struct is missing field '{}'. Add `{}: u32` to the struct.",
                    expected, expected
                ));
            };
            let actual = member.name.as_deref().unwrap_or("");
            if actual != expected {
                return Err(format!(
                    "Fix: binding 2 uniform struct field {} is named '{}' but must be '{}'. Rename it to `{}: u32`.",
                    idx, actual, expected, expected
                ));
            }
            match &module.types[member.ty].inner {
                TypeInner::Scalar(naga::Scalar {
                    kind: naga::ScalarKind::Uint,
                    width: 4,
                }) => Ok(()),
                _ => Err(format!(
                    "Fix: binding 2 uniform struct field '{}' must be u32. Change its type to u32.",
                    expected
                )),
            }
        };
        check_field(0, "input_len")?;
        check_field(1, "output_len")?;
        Ok(())
    }

    /// Checks binding 3 against the op's convention.
    ///
    /// Under `Convention::V2`, binding 3 must exist and be `var<storage, read>`.
    /// Under any other convention, it must be absent.
    pub(super) fn check_binding_3(
        _module: &naga::Module,
        bindings: &[(u32, BindingInfo)],
        op: &OpSpec,
    ) -> Result<(), String> {
        let needs_binding_3 = matches!(op.convention, Convention::V2 { .. });
        // Single lookup serves both the "unexpected" and the "missing" cases.
        let binding_3 = find_binding(bindings, 3);
        if !needs_binding_3 {
            return match binding_3 {
                Some(_) => Err(
                    "Fix: binding 3 is present but convention is V1. Either remove binding 3 or change convention to V2."
                        .to_string(),
                ),
                None => Ok(()),
            };
        }
        let Some(info) = binding_3 else {
            return Err(
                "Fix: convention is V2 but shader is missing @group(0) @binding(3). Add a read-only lookup storage buffer."
                    .to_string(),
            );
        };
        match info.space {
            AddressSpace::Storage { access } if access == StorageAccess::LOAD => Ok(()),
            AddressSpace::Storage { access } => Err(format!(
                "Fix: binding 3 must be `var<storage, read>`. It is currently `var<storage, {}>`. Change the access mode to `read`.",
                access_desc(access)
            )),
            _ => Err(
                "Fix: binding 3 must be a `storage, read` buffer. Change it to `var<storage, read>`."
                    .to_string(),
            ),
        }
    }
}
/// Storage-buffer element classification and the helpers that compare a
/// shader's buffer layout with an op's declared `DataType`.
mod buffer_element {
    use crate::spec::types::DataType;
    use naga::TypeInner;

    /// Innermost element of a storage buffer, found by peeling arrays and
    /// single-member wrapper structs.
    #[derive(Debug, Clone, Copy)]
    pub(super) enum BufferElement {
        /// A scalar element such as `u32`.
        Scalar(naga::Scalar),
        /// A vector element such as `vec4<u32>`.
        Vector {
            size: naga::VectorSize,
            scalar: naga::Scalar,
        },
    }

    impl BufferElement {
        /// Packed size in bytes: lane count times lane width.
        ///
        /// Alignment padding is not included (a `vec3` counts 12 bytes, not 16).
        pub(super) fn byte_size(self) -> usize {
            let (lanes, scalar) = match self {
                Self::Scalar(scalar) => (1, scalar),
                Self::Vector { size, scalar } => (size as usize, scalar),
            };
            lanes * scalar.width as usize
        }

        /// True only for a plain 4-byte unsigned scalar (`u32`).
        pub(super) fn is_scalar_u32(self) -> bool {
            match self {
                Self::Scalar(scalar) => {
                    scalar.kind == naga::ScalarKind::Uint && scalar.width == 4
                }
                Self::Vector { .. } => false,
            }
        }

        /// Human-readable element description used in diagnostics.
        pub(super) fn describe(self) -> String {
            let scalar = match self {
                Self::Vector { size, scalar } => {
                    return format!(
                        "{size:?} vector of {:?} lanes ({} bytes)",
                        scalar.kind,
                        self.byte_size()
                    );
                }
                Self::Scalar(scalar) => scalar,
            };
            match (scalar.kind, scalar.width) {
                (naga::ScalarKind::Uint, 4) => "u32 element (4-byte storage word)".to_string(),
                (naga::ScalarKind::Sint, 4) => "i32 element (4 bytes)".to_string(),
                (naga::ScalarKind::Float, 4) => "f32 element (4 bytes)".to_string(),
                _ => format!("{:?} scalar element ({} bytes)", scalar.kind, scalar.width),
            }
        }
    }

    /// Diagnostic for a struct buffer type that is not a single-member wrapper.
    fn wrapper_struct_error(module: &naga::Module, ty: naga::Handle<naga::Type>) -> String {
        format!(
            "Fix: buffer type '{}' is a struct but is not a single-member wrapper around an array. Either use array<u32> or a recognized Bytes wrapper.",
            super::types::type_name(module, ty)
        )
    }

    /// Resolves the element type stored in the buffer typed `ty`.
    ///
    /// Recurses through arrays and single-member structs until a scalar or
    /// vector is reached.
    ///
    /// # Errors
    ///
    /// Fails on multi-member (or empty) structs and on any other type kind.
    pub(super) fn buffer_element(
        module: &naga::Module,
        ty: naga::Handle<naga::Type>,
    ) -> Result<BufferElement, String> {
        match &module.types[ty].inner {
            TypeInner::Scalar(scalar) => Ok(BufferElement::Scalar(*scalar)),
            TypeInner::Vector { size, scalar } => Ok(BufferElement::Vector {
                size: *size,
                scalar: *scalar,
            }),
            TypeInner::Array { base, .. } => buffer_element(module, *base),
            TypeInner::Struct { members, .. } => match members.as_slice() {
                [only] => buffer_element(module, only.ty),
                _ => Err(wrapper_struct_error(module, ty)),
            },
            other => Err(format!(
                "Fix: buffer type '{}' ({:?}) is not a supported storage buffer element type. Use u32, vec2<u32>, vec4<u32>, or array<T>.",
                super::types::type_name(module, ty),
                other
            )),
        }
    }

    /// Returns the stride of the outermost array reachable from `ty`.
    ///
    /// The lookup goes through single-member wrapper structs. It returns
    /// `Ok(None)` when `ty` is neither an array nor such a wrapper.
    ///
    /// # Errors
    ///
    /// Fails on multi-member (or empty) structs.
    pub(super) fn array_stride(
        module: &naga::Module,
        ty: naga::Handle<naga::Type>,
    ) -> Result<Option<usize>, String> {
        match &module.types[ty].inner {
            TypeInner::Array { stride, .. } => Ok(Some(*stride as usize)),
            TypeInner::Struct { members, .. } => match members.as_slice() {
                [only] => array_stride(module, only.ty),
                _ => Err(wrapper_struct_error(module, ty)),
            },
            _ => Ok(None),
        }
    }

    /// For a declared `DataType::Array`, requires the buffer's array stride to
    /// equal `element_size`. Every other declared type passes.
    pub(super) fn check_array_stride(
        module: &naga::Module,
        ty: naga::Handle<naga::Type>,
        declared: &DataType,
        binding: &str,
    ) -> Result<(), String> {
        let DataType::Array { element_size } = declared else {
            return Ok(());
        };
        match array_stride(module, ty)? {
            None => Err(format!(
                "Fix: declared array<{element_size}B> requires {binding} to be an array storage buffer with a declared stride. Use `array<T>` whose stride matches the Array element_size."
            )),
            Some(stride) if stride == *element_size => Ok(()),
            Some(stride) => Err(format!(
                "Fix: declared array<{element_size}B> requires {binding} array stride {element_size} bytes, but the shader stride is {stride}. Change the WGSL storage element or the declared Array element_size."
            )),
        }
    }

    /// Checks a buffer element against a variable-size declared type.
    ///
    /// `Bytes` requires a scalar `u32` element. `Array` requires the element's
    /// packed size to equal `element_size`. Any other declared type passes.
    pub(super) fn check_variable_buffer_element(
        declared: &DataType,
        element: &BufferElement,
        binding: &str,
    ) -> Result<(), String> {
        match declared {
            DataType::Bytes if element.is_scalar_u32() => Ok(()),
            DataType::Bytes => Err(format!(
                "Fix: declared bytes signature requires {binding} to be a scalar u32 storage buffer with logical element_size=1 byte, but the shader uses {}. Use `array<u32>` or the canonical `Bytes {{ data: array<u32> }}` wrapper.",
                element.describe()
            )),
            DataType::Array { element_size } if element.byte_size() == *element_size => Ok(()),
            DataType::Array { element_size } => Err(format!(
                "Fix: declared array<{element_size}B> signature requires {binding} element size {element_size} bytes, but the shader uses {}. Change the WGSL storage element or the declared Array element_size so they agree.",
                element.describe()
            )),
            _ => Ok(()),
        }
    }

    /// First variable-size (`Bytes` or `Array`) entry in `inputs`, if any.
    pub(super) fn variable_input_type(inputs: &[DataType]) -> Option<&DataType> {
        for candidate in inputs {
            if let DataType::Bytes | DataType::Array { .. } = candidate {
                return Some(candidate);
            }
        }
        None
    }
}
/// Shader source resolution, WGSL parsing, and `@group(0)` binding extraction.
mod parse {
    use crate::pipeline::backend::{wrap_shader, ConformDispatchConfig};
    use crate::spec::types::{BufferInitPolicy, OpSpec};
    use naga::AddressSpace;

    /// One `@group(0)` global declared by the shader.
    #[derive(Debug, Clone)]
    pub(super) struct BindingInfo {
        /// Address space of the global (storage, uniform, ...).
        pub(super) space: AddressSpace,
        /// The global's type in the module's type arena.
        pub(super) ty: naga::Handle<naga::Type>,
    }

    /// Produces the WGSL to validate for `op`.
    ///
    /// If the op's source already contains a `@group(0) @binding(` attribute
    /// (with or without the single space), it is used verbatim. Otherwise it is
    /// wrapped with `wrap_shader`, using a 1x1 dispatch, the op's convention, no
    /// lookup data, and zero-initialised buffers.
    ///
    /// NOTE(review): detection is a plain substring match, so `@binding(N) @group(0)`
    /// ordering or other whitespace is not recognised and such sources get wrapped.
    /// Confirm this mirrors the runtime's own detection before changing it.
    pub(super) fn resolve_wgsl(op: &OpSpec) -> Result<String, String> {
        const BOUND_MARKERS: [&str; 2] = ["@group(0) @binding(", "@group(0)@binding("];
        let raw = (op.wgsl_fn)();
        if BOUND_MARKERS.iter().any(|marker| raw.contains(*marker)) {
            return Ok(raw);
        }
        let config = ConformDispatchConfig {
            workgroup_size: 1,
            workgroup_count: 1,
            convention: op.convention,
            lookup_data: None,
            buffer_init: BufferInitPolicy::Zero,
        };
        Ok(wrap_shader(&raw, &config))
    }

    /// Parses `source` with naga's WGSL front end, naming `op` in the error.
    pub(super) fn parse_wgsl(source: &str, op: &OpSpec) -> Result<naga::Module, String> {
        match naga::front::wgsl::parse_str(source) {
            Ok(module) => Ok(module),
            Err(err) => Err(format!(
                "WGSL parse failed for {}: {err}. Fix: provide syntactically valid WGSL.",
                op.id
            )),
        }
    }

    /// Collects every global bound in `@group(0)` as `(binding, info)`, ordered by
    /// binding number.
    ///
    /// The sort is stable, so duplicate slots keep declaration order. This never
    /// fails; the `Result` return leaves room for future validation.
    pub(super) fn extract_bindings(
        module: &naga::Module,
    ) -> Result<Vec<(u32, BindingInfo)>, String> {
        let mut found: Vec<(u32, BindingInfo)> = module
            .global_variables
            .iter()
            .filter_map(|(_, global)| {
                let resource = global.binding.as_ref()?;
                (resource.group == 0).then(|| {
                    (
                        resource.binding,
                        BindingInfo {
                            space: global.space,
                            ty: global.ty,
                        },
                    )
                })
            })
            .collect();
        found.sort_by_key(|&(slot, _)| slot);
        Ok(found)
    }
}
/// Validation of the shader's `vyre_op` entry function.
mod prototype {
    use crate::spec::types::OpSpec;
    use crate::enforce::enforcers::signature_match::types::{
        is_u32_type, return_type_matches_signature, wgsl_type_desc,
    };

    /// Requires the module to define `fn vyre_op(index: u32, input_len: u32) -> T`.
    ///
    /// `T` must be compatible with `op.signature.output`.
    ///
    /// # Errors
    ///
    /// Reports, in order of checking: a missing function, wrong arity, the first
    /// non-u32 parameter, a missing return type, or an incompatible return type.
    pub(super) fn check_vyre_op_prototype(
        module: &naga::Module,
        op: &OpSpec,
    ) -> Result<(), String> {
        let function = module
            .functions
            .iter()
            .map(|(_, function)| function)
            .find(|function| function.name.as_deref() == Some("vyre_op"))
            .ok_or_else(|| {
                format!(
                    "Fix: shader for {} must declare `fn vyre_op(index: u32, input_len: u32) -> ...`. Add the vyre_op function with the canonical two-parameter prototype.",
                    op.id
                )
            })?;

        let arity = function.arguments.len();
        if arity != 2 {
            return Err(format!(
                "Fix: vyre_op for {} has {} parameters, expected exactly 2: `index: u32` and `input_len: u32`.",
                op.id, arity
            ));
        }

        let first_non_u32 = function
            .arguments
            .iter()
            .enumerate()
            .find(|(_, arg)| !is_u32_type(module, arg.ty));
        if let Some((idx, arg)) = first_non_u32 {
            let name = arg.name.as_deref().unwrap_or("unnamed");
            return Err(format!(
                "Fix: vyre_op parameter {idx} (`{name}`) for {} must be u32. Use `fn vyre_op(index: u32, input_len: u32) -> ...`.",
                op.id
            ));
        }

        let result = function.result.as_ref().ok_or_else(|| {
            format!(
                "Fix: vyre_op for {} must return {}. Add an explicit compatible return type.",
                op.id, op.signature.output
            )
        })?;

        if return_type_matches_signature(module, result.ty, &op.signature.output) {
            Ok(())
        } else {
            Err(format!(
                "Fix: vyre_op for {} returns {}, but the declared output signature is {}. Change the WGSL return type or the OpSignature so they agree.",
                op.id,
                wgsl_type_desc(module, result.ty),
                op.signature.output
            ))
        }
    }
}
mod types {
use crate::spec::types::DataType;
use naga::{StorageAccess, TypeInner};
pub(super) fn type_name(module: &naga::Module, ty: naga::Handle<naga::Type>) -> String {
module.types[ty]
.name
.as_deref()
.unwrap_or("unnamed")
.to_string()
}
pub(super) fn type_byte_size(
module: &naga::Module,
ty: naga::Handle<naga::Type>,
) -> Option<usize> {
match &module.types[ty].inner {
TypeInner::Scalar(scalar) => Some(scalar.width as usize),
TypeInner::Vector { size, scalar } => Some((*size as usize) * (scalar.width as usize)),
_ => None,
}
}
pub(super) fn is_u32_type(module: &naga::Module, ty: naga::Handle<naga::Type>) -> bool {
matches!(
module.types[ty].inner,
TypeInner::Scalar(naga::Scalar {
kind: naga::ScalarKind::Uint,
width: 4,
})
)
}
pub(super) fn return_type_matches_variable_output(
module: &naga::Module,
ty: naga::Handle<naga::Type>,
output: &DataType,
) -> bool {
match output {
DataType::Bytes => is_u32_type(module, ty),
DataType::Array { element_size } => type_byte_size(module, ty) == Some(*element_size),
_ => false,
}
}
pub(super) fn return_type_matches_signature(
module: &naga::Module,
ty: naga::Handle<naga::Type>,
output: &DataType,
) -> bool {
match output {
DataType::U32 => is_u32_type(module, ty),
DataType::Bytes | DataType::Array { .. } => {
return_type_matches_variable_output(module, ty, output)
}
DataType::I32 => matches!(
module.types[ty].inner,
TypeInner::Scalar(naga::Scalar {
kind: naga::ScalarKind::Sint,
width: 4,
})
),
DataType::F32 => matches!(
module.types[ty].inner,
TypeInner::Scalar(naga::Scalar {
kind: naga::ScalarKind::Float,
width: 4,
})
),
DataType::Vec2U32 => matches!(
module.types[ty].inner,
TypeInner::Vector {
size: naga::VectorSize::Bi,
scalar: naga::Scalar {
kind: naga::ScalarKind::Uint,
width: 4,
},
}
),
DataType::Vec4U32 => matches!(
module.types[ty].inner,
TypeInner::Vector {
size: naga::VectorSize::Quad,
scalar: naga::Scalar {
kind: naga::ScalarKind::Uint,
width: 4,
},
}
),
DataType::U64
| DataType::Bool
| DataType::F16
| DataType::BF16
| DataType::F64
| DataType::Tensor => false,
}
}
pub(super) fn wgsl_type_desc(module: &naga::Module, ty: naga::Handle<naga::Type>) -> String {
match &module.types[ty].inner {
TypeInner::Scalar(naga::Scalar {
kind: naga::ScalarKind::Uint,
width: 4,
}) => "u32".to_string(),
TypeInner::Scalar(naga::Scalar {
kind: naga::ScalarKind::Sint,
width: 4,
}) => "i32".to_string(),
TypeInner::Scalar(naga::Scalar {
kind: naga::ScalarKind::Float,
width: 4,
}) => "f32".to_string(),
other => format!("{other:?}"),
}
}
pub(super) fn access_desc(access: StorageAccess) -> &'static str {
if access.contains(StorageAccess::LOAD) && access.contains(StorageAccess::STORE) {
"read_write"
} else if access.contains(StorageAccess::LOAD) {
"read"
} else if access.contains(StorageAccess::STORE) {
"write"
} else {
"none"
}
}
pub(super) fn format_inputs(inputs: &[DataType]) -> String {
if inputs.is_empty() {
return "()".to_string();
}
inputs
.iter()
.map(|dt| format!("{}", dt))
.collect::<Vec<_>>()
.join(", ")
}
}
/// Enforcement gate that checks each op's WGSL against its declared `OpSignature`.
pub struct SignatureMatchEnforcer;

impl crate::enforce::EnforceGate for SignatureMatchEnforcer {
    fn id(&self) -> &'static str {
        "signature_match"
    }

    fn name(&self) -> &'static str {
        "signature_match"
    }

    /// Runs `enforce_signature` on every spec in `ctx.specs`.
    ///
    /// Each failure becomes one `signature(<op id>): <reason>` message, and the
    /// collected messages are handed to `finding_result`.
    fn run(&self, ctx: &crate::enforce::EnforceCtx<'_>) -> Vec<crate::enforce::Finding> {
        let mut messages = Vec::new();
        for spec in ctx.specs.iter() {
            if let Err(reason) = enforce_signature(spec) {
                messages.push(format!("signature({}): {reason}", spec.id));
            }
        }
        crate::enforce::finding_result(self.id(), messages)
    }
}

/// Exported gate instance (presumably picked up by the enforcer registry — confirm at the call site).
pub const REGISTERED: SignatureMatchEnforcer = SignatureMatchEnforcer;