use super::wgpu;
use super::{
unsupported_execution_model_error, ConformDispatchConfig, ExecutionModel, ExecutionModelKind,
OneShotDispatch,
};
use crate::spec::types::Convention;
/// A conformance backend that runs WGSL compute shaders and returns the bytes
/// read back as output.
///
/// Implementors must provide [`WgslBackend::name`] and [`WgslBackend::dispatch`];
/// every other method has a default that is ultimately expressed in terms of
/// `dispatch`.
pub(crate) trait WgslBackend: Send + Sync {
    /// Identifier of this backend, used when reporting errors.
    fn name(&self) -> &str;

    /// Version string of this backend. Defaults to `"unspecified"`.
    fn version(&self) -> &str {
        "unspecified"
    }

    /// Checks that the runtime this backend relies on is usable.
    ///
    /// The default performs no checks and always returns `Ok(())`.
    fn verify_runtime_info(&self) -> Result<(), String> {
        Ok(())
    }

    /// Dispatches `wgsl` once with `input` and returns the resulting output bytes.
    ///
    /// `output_size` is the output size requested by the caller (presumably a
    /// byte count — confirm against implementors).
    ///
    /// # Errors
    ///
    /// Returns a human-readable message on failure.
    fn dispatch(
        &self,
        wgsl: &str,
        input: &[u8],
        output_size: usize,
        config: ConformDispatchConfig,
    ) -> Result<Vec<u8>, String>;

    /// Execution models this backend accepts. Defaults to one-shot only.
    fn supported_models(&self) -> &[ExecutionModelKind] {
        &[ExecutionModelKind::OneShot]
    }

    /// Runs `model` on this backend.
    ///
    /// # Errors
    ///
    /// Fails with the `unsupported_execution_model_error` message when
    /// `model.kind()` is not listed in [`WgslBackend::supported_models`];
    /// otherwise propagates the error from [`WgslBackend::dispatch`].
    fn execute(&self, model: &ExecutionModel) -> Result<Vec<u8>, String> {
        if !self.supported_models().contains(&model.kind()) {
            return Err(unsupported_execution_model_error(
                self.name(),
                model.kind_name(),
            ));
        }
        match model {
            ExecutionModel::OneShot(dispatch) => self.dispatch(
                &dispatch.wgsl,
                &dispatch.input,
                dispatch.output_size,
                dispatch.config.clone(),
            ),
        }
    }

    /// Dispatches `wgsl` once per input, in order, pairing `inputs[i]` with
    /// `output_sizes[i]`, and returns the outputs in the same order.
    ///
    /// The default runs the dispatches sequentially through
    /// [`WgslBackend::dispatch`]; implementors may override it.
    ///
    /// # Errors
    ///
    /// Fails when the two slices differ in length, or with the first dispatch
    /// error encountered (inputs after a failing one are not dispatched).
    fn dispatch_batch(
        &self,
        wgsl: &str,
        inputs: &[Vec<u8>],
        output_sizes: &[usize],
        config: ConformDispatchConfig,
    ) -> Result<Vec<Vec<u8>>, String> {
        if inputs.len() != output_sizes.len() {
            return Err(format!(
                "batch input/output length mismatch: {} inputs, {} output sizes. Fix: pass one output size for every batch input.",
                inputs.len(),
                output_sizes.len()
            ));
        }
        // Collecting into `Result` short-circuits on the first failing dispatch.
        inputs
            .iter()
            .zip(output_sizes)
            .map(|(input, &output_size)| self.dispatch(wgsl, input, output_size, config.clone()))
            .collect()
    }

    /// Decodes a serialized vyre IR `program`, lowers it to WGSL, and dispatches
    /// the result through [`WgslBackend::dispatch`].
    ///
    /// # Errors
    ///
    /// Fails when the program cannot be decoded, cannot be lowered to WGSL, or
    /// when the dispatch itself fails.
    fn dispatch_program(
        &self,
        program: &[u8],
        input: &[u8],
        output_size: usize,
        config: ConformDispatchConfig,
    ) -> Result<Vec<u8>, String> {
        // Give decode failures the same context + `Fix:` guidance as lowering
        // failures instead of surfacing the bare decoder message.
        let program = vyre::ir::Program::from_wire(program).map_err(|err| {
            format!(
                "failed to decode serialized IR program: {err}. Fix: provide bytes in the vyre IR wire format or override dispatch_program with a native backend lowering."
            )
        })?;
        let wgsl = vyre::lower::wgsl::lower(&program).map_err(|err| {
            format!(
                "failed to lower serialized IR program to WGSL: {err}. Fix: provide valid vyre IR or override dispatch_program with a native backend lowering."
            )
        })?;
        self.dispatch(&wgsl, input, output_size, config)
    }

    /// Maximum [`Convention`] this backend supports. Defaults to `Convention::V1`.
    fn max_convention(&self) -> Convention {
        Convention::V1
    }

    /// Backend-declared limit on workgroup invocations, if any.
    /// Defaults to `None` (no limit declared).
    fn max_workgroup_invocations(&self) -> Option<u32> {
        None
    }
}
/// Wraps a WGSL op fragment in the standard vyre-conform compute harness.
///
/// The harness calls `vyre_op(index, params.input_len)` for every output word,
/// so `op_wgsl` must define a compatible `vyre_op` function. If `op_wgsl`
/// already contains `@compute`, it is treated as a complete shader and
/// returned unchanged.
///
/// Group 0 bindings emitted by the harness: 0 = `input` (read-only storage),
/// 1 = `output` (read-write storage), 2 = `params` (uniform). Under
/// `Convention::V2` a read-only `lookup` buffer is also declared, at the binding
/// index carried by the convention's `lookup_binding` field.
#[inline]
pub fn wrap_shader(op_wgsl: &str, config: &ConformDispatchConfig) -> String {
    use std::fmt::Write;

    // Anything carrying a `@compute` attribute already has its own entry point;
    // pass it through untouched.
    if op_wgsl.contains("@compute") {
        return op_wgsl.to_string();
    }
    let workgroup_size = config.workgroup_size;
    // Use the binding index configured on the V2 convention rather than a
    // hard-coded 3, so the shader declaration matches the requested layout.
    let lookup_binding = match &config.convention {
        Convention::V2 { lookup_binding } => format!(
            "\n@group(0) @binding({lookup_binding})\nvar<storage, read> lookup: Bytes;\n"
        ),
        Convention::V1 => String::new(),
    };
    let mut shader = String::with_capacity(op_wgsl.len() + lookup_binding.len() + 1024);
    shader.push_str(op_wgsl);
    // Writing into a `String` cannot fail, so the `fmt::Result` is discarded.
    let _ = write!(
        shader,
        r"
struct Bytes {{
data: array<u32>,
}};
struct Params {{
// Original byte length of the input buffer before word-padding.
// Shaders must use this to ignore zero-padded trailing bytes.
input_len: u32,
output_len: u32,
_pad0: u32,
_pad1: u32,
}};
@group(0) @binding(0)
var<storage, read> input: Bytes;
@group(0) @binding(1)
var<storage, read_write> output: Bytes;
@group(0) @binding(2)
var<uniform> params: Params;
{lookup_binding}
@compute @workgroup_size({workgroup_size})
fn vyre_conform_main(@builtin(global_invocation_id) gid: vec3<u32>) {{
let index = gid.x;
if (index >= arrayLength(&output.data)) {{
return;
}}
output.data[index] = vyre_op(index, params.input_len);
}}
"
    );
    shader
}
#[cfg(test)]
mod tests {
    use super::{wrap_shader, ConformDispatchConfig};
    use crate::spec::types::Convention;

    /// Minimal op fragment: echoes the input word at `index`.
    fn test_op_wgsl() -> &'static str {
        "fn vyre_op(index: u32, input_len: u32) -> u32 { return input.data[index]; }"
    }

    #[test]
    fn wrap_shader_includes_op_fragment() {
        let config = ConformDispatchConfig::default();
        let shader = wrap_shader(test_op_wgsl(), &config);
        assert!(shader.contains("fn vyre_op"), "op fragment missing");
    }

    #[test]
    fn wrap_shader_v1_has_input_output_params() {
        let config = ConformDispatchConfig::default();
        let shader = wrap_shader(test_op_wgsl(), &config);
        assert!(
            shader.contains("@group(0) @binding(0)"),
            "missing input binding"
        );
        assert!(
            shader.contains("@group(0) @binding(1)"),
            "missing output binding"
        );
        assert!(
            shader.contains("@group(0) @binding(2)"),
            "missing params binding"
        );
    }

    #[test]
    fn wrap_shader_v1_no_lookup() {
        let config = ConformDispatchConfig::default();
        let shader = wrap_shader(test_op_wgsl(), &config);
        assert!(
            !shader.contains("@binding(3)"),
            "V1 should not have lookup binding"
        );
    }

    #[test]
    fn wrap_shader_v2_has_lookup() {
        let config = ConformDispatchConfig {
            convention: Convention::V2 { lookup_binding: 3 },
            ..Default::default()
        };
        let shader = wrap_shader(test_op_wgsl(), &config);
        assert!(
            shader.contains("@binding(3)"),
            "V2 should have lookup binding"
        );
        assert!(
            shader.contains("lookup: Bytes"),
            "V2 should have lookup buffer"
        );
    }

    #[test]
    fn wrap_shader_embeds_workgroup_size() {
        let config = ConformDispatchConfig {
            workgroup_size: 128,
            ..Default::default()
        };
        let shader = wrap_shader(test_op_wgsl(), &config);
        assert!(
            shader.contains("@workgroup_size(128)"),
            "workgroup size not embedded"
        );
    }

    #[test]
    fn wrap_shader_entry_point_name() {
        let config = ConformDispatchConfig::default();
        let shader = wrap_shader(test_op_wgsl(), &config);
        assert!(
            shader.contains("fn vyre_conform_main"),
            "wrong entry point name"
        );
    }

    #[test]
    fn wrap_shader_has_bounds_check() {
        let config = ConformDispatchConfig::default();
        let shader = wrap_shader(test_op_wgsl(), &config);
        assert!(shader.contains("arrayLength"), "missing bounds check");
    }

    /// The entry point must hand the unpadded input length to the op.
    #[test]
    fn wrap_shader_forwards_input_len_to_op() {
        let config = ConformDispatchConfig::default();
        let shader = wrap_shader(test_op_wgsl(), &config);
        assert!(
            shader.contains("vyre_op(index, params.input_len)"),
            "entry point must call vyre_op with params.input_len"
        );
    }

    /// A source that already declares a compute entry point is returned verbatim.
    #[test]
    fn wrap_shader_passes_complete_shader_through() {
        let config = ConformDispatchConfig::default();
        let complete = "@compute @workgroup_size(1)\nfn main() {}";
        assert_eq!(
            wrap_shader(complete, &config),
            complete,
            "complete shaders must not be wrapped"
        );
    }

    #[test]
    fn dispatch_config_default_values() {
        let config = ConformDispatchConfig::default();
        assert_eq!(config.workgroup_size, 1);
        assert_eq!(config.workgroup_count, 1);
        assert_eq!(config.convention, Convention::V1);
        assert!(config.lookup_data.is_none());
    }
}
/// Builds the wgpu backend, or explains how to make a GPU adapter available.
///
/// # Errors
///
/// Returns a `Fix:`-prefixed guidance message when `wgpu::WgpuBackend::new()`
/// yields `None`.
#[inline]
pub fn require_gpu() -> Result<wgpu::WgpuBackend, String> {
    match wgpu::WgpuBackend::new() {
        Some(backend) => Ok(backend),
        None => Err(String::from(
            "Fix: no discrete or integrated GPU adapter is available for vyre-conform. \
             Install a supported GPU/driver stack and run with the `gpu` feature; set \
             VYRE_CONFORM_GPU_REQUIRED=1 in CI to make adapter discovery fail before tests run.",
        )),
    }
}
#[cfg(all(test, feature = "gpu"))]
mod gpu_parity;