#![cfg(feature = "gpu")]
use trueno::backends::gpu::wgpu;
/// WGSL source of the fused SwiGLU compute shader.
///
/// For every index `i` below `dims.n`, the `main` entry point writes
/// `output[i] = gate[i] * sigmoid(gate[i]) * up[i]`, where
/// `sigmoid(g) = 1.0 / (1.0 + exp(-g))`.
///
/// Bindings (all in group 0):
/// - binding 0: `gate`, read-only storage `array<f32>`
/// - binding 1: `up`, read-only storage `array<f32>`
/// - binding 2: `output`, read-write storage `array<f32>`
/// - binding 3: `dims`, a uniform struct with a single `n: u32` field that is
///   the exclusive upper bound on processed indices (presumably the element
///   count of the buffers; confirm against the caller)
///
/// The entry point is declared with `@workgroup_size(256)` and indexes only by
/// `global_invocation_id.x`. Invocations whose index is `>= dims.n` return
/// before reading or writing any buffer.
///
/// NOTE: the unit tests in this file match exact substrings of this text, so
/// even whitespace-only edits inside the string can make them fail.
pub const FUSED_SWIGLU_WGSL: &str = r"
@group(0) @binding(0) var<storage, read> gate: array<f32>;
@group(0) @binding(1) var<storage, read> up: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
struct Dims {
n: u32,
}
@group(0) @binding(3) var<uniform> dims: Dims;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let i = gid.x;
if (i >= dims.n) {
return;
}
let g = gate[i];
let sigmoid_g = 1.0 / (1.0 + exp(-g));
output[i] = g * sigmoid_g * up[i];
}
";
/// Compiled GPU objects for running the [`FUSED_SWIGLU_WGSL`] shader.
///
/// Holds the compute pipeline together with the bind group layout it was
/// created from, so callers can build bind groups that are compatible with
/// the pipeline. Construct it with [`FusedSwigluWgpuKernel::new`].
///
/// `Debug` is derived so the kernel can be embedded in other `Debug` types
/// and printed in diagnostics.
#[derive(Debug)]
pub struct FusedSwigluWgpuKernel {
    /// Compute pipeline whose entry point is the shader's `main` function.
    pub pipeline: wgpu::ComputePipeline,
    /// Layout of bind group 0: bindings 0 and 1 are read-only storage
    /// buffers, binding 2 is a read-write storage buffer, and binding 3 is a
    /// uniform buffer.
    pub bind_group_layout: wgpu::BindGroupLayout,
}
impl FusedSwigluWgpuKernel {
    /// Compiles [`FUSED_SWIGLU_WGSL`] on `device` and builds the matching
    /// compute pipeline.
    ///
    /// The bind group layout declares four compute-visible buffer bindings in
    /// the same order as the shader:
    ///
    /// | binding | kind                 |
    /// |---------|----------------------|
    /// | 0       | storage, read-only   |
    /// | 1       | storage, read-only   |
    /// | 2       | storage, read-write  |
    /// | 3       | uniform              |
    ///
    /// None of the bindings uses a dynamic offset or a minimum binding size,
    /// and the pipeline layout has no push-constant ranges.
    pub fn new(device: &wgpu::Device) -> Self {
        // All four bindings differ only in their index and buffer kind, so
        // the shared fields are filled in once here.
        fn buffer_entry(
            binding: u32,
            ty: wgpu::BufferBindingType,
        ) -> wgpu::BindGroupLayoutEntry {
            wgpu::BindGroupLayoutEntry {
                binding,
                visibility: wgpu::ShaderStages::COMPUTE,
                ty: wgpu::BindingType::Buffer {
                    ty,
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            }
        }

        let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("fused_swiglu_wgpu"),
            source: wgpu::ShaderSource::Wgsl(FUSED_SWIGLU_WGSL.into()),
        });

        // Order mirrors the `@binding(n)` declarations in the WGSL source.
        let entries = [
            buffer_entry(0, wgpu::BufferBindingType::Storage { read_only: true }),
            buffer_entry(1, wgpu::BufferBindingType::Storage { read_only: true }),
            buffer_entry(2, wgpu::BufferBindingType::Storage { read_only: false }),
            buffer_entry(3, wgpu::BufferBindingType::Uniform),
        ];
        let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some("fused_swiglu_bgl"),
            entries: &entries,
        });

        let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("fused_swiglu_pl"),
            bind_group_layouts: &[&bgl],
            push_constant_ranges: &[],
        });

        let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("fused_swiglu_pipe"),
            layout: Some(&layout),
            module: &module,
            entry_point: Some("main"),
            compilation_options: wgpu::PipelineCompilationOptions::default(),
            cache: None,
        });

        Self {
            pipeline: compute_pipeline,
            bind_group_layout: bgl,
        }
    }
}
/// Number of invocations per workgroup, mirroring the `@workgroup_size(256)`
/// attribute in [`FUSED_SWIGLU_WGSL`].
///
/// The shader literal is not generated from this constant, so the two must be
/// changed together; a unit test in this file checks that they agree.
/// Presumably callers divide the element count by this value (rounding up) to
/// get the dispatch size; confirm at the call sites.
pub const WORKGROUP_SIZE: u32 = 256;
#[cfg(test)]
mod shader_source_tests {
    //! Text-level checks on the WGSL source; none of these tests needs a GPU.

    use super::*;

    /// Fails the calling test with `why` unless the shader text contains
    /// `needle`. `#[track_caller]` keeps the reported panic location at the
    /// call site rather than inside this helper.
    #[track_caller]
    fn expect_in_shader(needle: &str, why: &str) {
        assert!(FUSED_SWIGLU_WGSL.contains(needle), "{why}");
    }

    /// The shader exposes a 256-wide compute entry point named `main`.
    #[test]
    fn wgsl_source_has_compute_entry_point() {
        expect_in_shader(
            "@compute @workgroup_size(256)",
            "WGSL must declare compute entry with workgroup_size(256)",
        );
        expect_in_shader("fn main(", "WGSL must declare main() entry function");
    }

    /// Each of the four bindings is declared with the expected address space,
    /// access mode, and variable name.
    #[test]
    fn wgsl_source_declares_all_four_bindings() {
        let expected = [
            (
                "@binding(0) var<storage, read> gate",
                "binding(0) must be the gate input",
            ),
            (
                "@binding(1) var<storage, read> up",
                "binding(1) must be the up input",
            ),
            (
                "@binding(2) var<storage, read_write> output",
                "binding(2) must be the read_write output",
            ),
            (
                "@binding(3) var<uniform> dims",
                "binding(3) must be the uniform dims block",
            ),
        ];
        // Checked in binding order so the first missing declaration is the
        // one reported.
        for (declaration, why) in expected {
            expect_in_shader(declaration, why);
        }
    }

    /// The sigmoid is written with `exp(-g)` and the product is fused into a
    /// single store.
    #[test]
    fn wgsl_source_uses_exp_for_sigmoid() {
        expect_in_shader(
            "exp(-g)",
            "WGSL must use exp(-g) for sigmoid (matches denominator of silu formula)",
        );
        expect_in_shader(
            "g * sigmoid_g * up[i]",
            "WGSL must compute fused silu(gate) * up in one statement",
        );
    }

    /// Invocations past the end of the data are rejected before any access.
    #[test]
    fn wgsl_source_has_bounds_check() {
        expect_in_shader(
            "if (i >= dims.n)",
            "WGSL must guard against out-of-range thread IDs (workgroup ceiling)",
        );
    }

    /// The Rust-side constant and the shader's `@workgroup_size` literal agree.
    #[test]
    fn workgroup_size_constant_matches_wgsl() {
        assert_eq!(
            WORKGROUP_SIZE, 256,
            "WORKGROUP_SIZE Rust constant must match the WGSL @workgroup_size literal"
        );
        let expected_attr = format!("@workgroup_size({WORKGROUP_SIZE})");
        let why = format!("WGSL must contain @workgroup_size({WORKGROUP_SIZE})");
        expect_in_shader(&expected_attr, &why);
    }
}