use core::fmt;
/// A WGSL compute shader: its source text, entry point name and workgroup size.
///
/// `workgroup_size` is host-side metadata stored next to the source. It is not
/// parsed from the WGSL text, so callers should keep it in sync with the
/// shader's `@workgroup_size` attribute.
#[derive(Debug, Clone)]
pub struct WGSLShader {
    source: String,
    entry_point: String,
    workgroup_size: (u32, u32, u32),
}

impl WGSLShader {
    /// Workgroup size assigned by the constructors until overridden with
    /// [`with_workgroup_size`](Self::with_workgroup_size).
    const DEFAULT_WORKGROUP_SIZE: (u32, u32, u32) = (64, 1, 1);

    /// Creates a shader whose entry point is `"main"`, with the default
    /// workgroup size of `(64, 1, 1)`.
    pub fn new(source: impl Into<String>) -> Self {
        Self::with_entry_point(source, "main")
    }

    /// Creates a shader with a custom entry point name and the default
    /// workgroup size of `(64, 1, 1)`.
    pub fn with_entry_point(source: impl Into<String>, entry_point: impl Into<String>) -> Self {
        Self {
            source: source.into(),
            entry_point: entry_point.into(),
            workgroup_size: Self::DEFAULT_WORKGROUP_SIZE,
        }
    }

    /// Returns the shader with its workgroup-size metadata replaced by `(x, y, z)`.
    pub fn with_workgroup_size(mut self, x: u32, y: u32, z: u32) -> Self {
        self.workgroup_size = (x, y, z);
        self
    }

    /// The WGSL source text.
    pub fn source(&self) -> &str {
        &self.source
    }

    /// The entry point function name.
    pub fn entry_point(&self) -> &str {
        &self.entry_point
    }

    /// The stored `(x, y, z)` workgroup size.
    pub fn workgroup_size(&self) -> (u32, u32, u32) {
        self.workgroup_size
    }

    /// Performs lightweight sanity checks on the source.
    ///
    /// # Errors
    ///
    /// - [`ShaderError::EmptyShader`] if the source is empty or whitespace-only.
    /// - [`ShaderError::MissingComputeAttribute`] if the text `@compute` does
    ///   not appear anywhere in the source.
    pub fn validate(&self) -> Result<(), ShaderError> {
        // A whitespace-only source carries no code; report it as empty rather
        // than as merely lacking the compute attribute.
        if self.source.trim().is_empty() {
            return Err(ShaderError::EmptyShader);
        }
        // Plain substring search, not a WGSL parse: an `@compute` inside a
        // comment also satisfies this check.
        if !self.source.contains("@compute") {
            return Err(ShaderError::MissingComputeAttribute);
        }
        Ok(())
    }
}
/// Errors reported when validating or compiling a shader.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ShaderError {
    /// The shader source is empty.
    EmptyShader,
    /// The shader source does not contain an `@compute` attribute.
    MissingComputeAttribute,
    /// A syntax problem, with a human-readable description.
    SyntaxError(String),
    /// A compilation failure, with a human-readable description.
    CompilationFailed(String),
}

impl fmt::Display for ShaderError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ShaderError::EmptyShader => write!(f, "Shader source is empty"),
            ShaderError::MissingComputeAttribute => write!(f, "Missing @compute attribute"),
            ShaderError::SyntaxError(msg) => write!(f, "Syntax error: {}", msg),
            ShaderError::CompilationFailed(msg) => write!(f, "Compilation failed: {}", msg),
        }
    }
}

// Lets callers box this as `dyn Error` and propagate it with `?` into
// `Box<dyn Error>`-returning functions. There is no underlying cause, so the
// default `source()` (None) is correct.
impl std::error::Error for ShaderError {}
/// The usage class a [`GPUBuffer`] is created with.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BufferUsage {
    /// Storage-buffer usage.
    Storage,
    /// Uniform-buffer usage.
    Uniform,
    /// Staging-buffer usage.
    Staging,
    /// Vertex-buffer usage.
    Vertex,
    /// Index-buffer usage.
    Index,
}
/// Host-side description of a GPU buffer: size, usage, mapping state and an
/// optional debug label.
#[derive(Debug, Clone)]
pub struct GPUBuffer {
    size: usize,
    usage: BufferUsage,
    mapped: bool,
    label: Option<String>,
}

impl GPUBuffer {
    /// Shared constructor for the public factories: an unlabelled buffer with
    /// the given usage and initial mapping state.
    fn unlabeled(size: usize, usage: BufferUsage, mapped: bool) -> Self {
        GPUBuffer {
            size,
            usage,
            mapped,
            label: None,
        }
    }

    /// An unmapped buffer with [`BufferUsage::Storage`].
    pub fn storage(size: usize) -> Self {
        Self::unlabeled(size, BufferUsage::Storage, false)
    }

    /// An unmapped buffer with [`BufferUsage::Uniform`].
    pub fn uniform(size: usize) -> Self {
        Self::unlabeled(size, BufferUsage::Uniform, false)
    }

    /// A buffer with [`BufferUsage::Staging`]; unlike the other factories it
    /// starts out marked as mapped.
    pub fn staging(size: usize) -> Self {
        Self::unlabeled(size, BufferUsage::Staging, true)
    }

    /// Returns the buffer with its label set (replacing any previous label).
    pub fn with_label(self, label: impl Into<String>) -> Self {
        GPUBuffer {
            label: Some(label.into()),
            ..self
        }
    }

    /// The size given at construction.
    pub fn size(&self) -> usize {
        self.size
    }

    /// The usage class given at construction.
    pub fn usage(&self) -> BufferUsage {
        self.usage
    }

    /// Whether the buffer is marked as mapped.
    pub fn is_mapped(&self) -> bool {
        self.mapped
    }

    /// The label, if one was set.
    pub fn label(&self) -> Option<&str> {
        self.label.as_deref()
    }
}
/// A compute shader together with the bind group layouts and optional label
/// used to build a pipeline from it.
#[derive(Debug, Clone)]
pub struct ComputePipeline {
    shader: WGSLShader,
    bind_groups: Vec<BindGroupLayout>,
    label: Option<String>,
}

impl ComputePipeline {
    /// Creates a pipeline for `shader` with no bind groups and no label.
    pub fn new(shader: WGSLShader) -> Self {
        Self {
            shader,
            bind_groups: Vec::new(),
            label: None,
        }
    }

    /// Appends a bind group layout; layouts are kept in insertion order.
    pub fn with_bind_group(mut self, layout: BindGroupLayout) -> Self {
        self.bind_groups.push(layout);
        self
    }

    /// Returns the pipeline with its label set (replacing any previous label).
    pub fn with_label(mut self, label: impl Into<String>) -> Self {
        self.label = Some(label.into());
        self
    }

    /// The shader this pipeline was created from.
    pub fn shader(&self) -> &WGSLShader {
        &self.shader
    }

    /// The bind group layouts, in the order they were added.
    pub fn bind_groups(&self) -> &[BindGroupLayout] {
        &self.bind_groups
    }

    /// The label, if one was set.
    pub fn label(&self) -> Option<&str> {
        self.label.as_deref()
    }

    /// Returns dispatch dimensions covering `data_size` elements along x.
    ///
    /// The x component is `ceil(data_size / workgroup_x)`. It is computed without
    /// overflow and saturates at `u32::MAX`. A workgroup x size of 0 is treated
    /// as 1 instead of dividing by zero.
    ///
    /// The y and z components are the shader's workgroup y and z sizes,
    /// returned unchanged.
    pub fn optimal_dispatch_size(&self, data_size: usize) -> (u32, u32, u32) {
        let (wg_x, wg_y, wg_z) = self.shader.workgroup_size();
        // Work in u64 so neither the usize->u32 narrowing nor the rounding-up
        // term can overflow, and never divide by zero.
        let total = data_size as u64;
        let per_group = u64::from(wg_x.max(1));
        let groups = total / per_group + u64::from(total % per_group != 0);
        let workgroup_count = if groups > u64::from(u32::MAX) {
            u32::MAX
        } else {
            groups as u32
        };
        // NOTE(review): for 1-D data one would usually dispatch 1 workgroup in
        // y and z; echoing the workgroup sizes is kept for compatibility —
        // confirm with callers before changing.
        (workgroup_count, wg_y, wg_z)
    }
}
/// One slot of a [`BindGroupLayout`]: a binding index, the kind of resource
/// bound there, and its shader-stage visibility setting.
#[derive(Debug, Clone)]
pub struct BindGroupEntry {
    binding: u32,
    resource_type: ResourceType,
    visibility: ShaderStage,
}

impl BindGroupEntry {
    /// Builds an entry from its binding index, resource kind and visibility.
    pub fn new(binding: u32, resource_type: ResourceType, visibility: ShaderStage) -> Self {
        BindGroupEntry {
            visibility,
            resource_type,
            binding,
        }
    }

    /// The binding index.
    pub fn binding(&self) -> u32 {
        let BindGroupEntry { binding, .. } = *self;
        binding
    }

    /// The kind of resource bound at this slot.
    pub fn resource_type(&self) -> ResourceType {
        let BindGroupEntry { resource_type, .. } = *self;
        resource_type
    }

    /// The shader-stage visibility setting.
    pub fn visibility(&self) -> ShaderStage {
        let BindGroupEntry { visibility, .. } = *self;
        visibility
    }
}
/// An ordered list of [`BindGroupEntry`] values with an optional label.
///
/// The default value is an empty, unlabelled layout.
#[derive(Debug, Clone, Default)]
pub struct BindGroupLayout {
    entries: Vec<BindGroupEntry>,
    label: Option<String>,
}

impl BindGroupLayout {
    /// Creates an empty, unlabelled layout (identical to `Default::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `entry`; entries are kept in insertion order.
    pub fn with_entry(self, entry: BindGroupEntry) -> Self {
        let mut entries = self.entries;
        entries.push(entry);
        Self { entries, ..self }
    }

    /// Returns the layout with its label set (replacing any previous label).
    pub fn with_label(self, label: impl Into<String>) -> Self {
        Self {
            label: Some(label.into()),
            ..self
        }
    }

    /// The entries, in insertion order.
    pub fn entries(&self) -> &[BindGroupEntry] {
        self.entries.as_slice()
    }

    /// The label, if one was set.
    pub fn label(&self) -> Option<&str> {
        self.label.as_deref()
    }
}
/// The kind of resource bound at a [`BindGroupEntry`] slot.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ResourceType {
    /// A storage buffer.
    StorageBuffer,
    /// A uniform buffer.
    UniformBuffer,
    /// A read-only storage buffer.
    ReadOnlyStorageBuffer,
    /// A texture.
    Texture,
    /// A sampler.
    Sampler,
}
/// Shader-stage visibility setting for a [`BindGroupEntry`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShaderStage {
    /// The vertex stage.
    Vertex,
    /// The fragment stage.
    Fragment,
    /// The compute stage.
    Compute,
    /// All stages.
    All,
}
/// Heuristic chooser for compute-shader workgroup dimensions.
///
/// Every size it returns is capped so that the product of its dimensions does
/// not exceed [`max_workgroup_size`](Self::max_workgroup_size). A limit of 0
/// is treated as 1.
#[derive(Debug, Clone, Copy)]
pub struct WorkgroupOptimizer {
    max_workgroup_size: u32,
    preferred_1d: u32,
    preferred_2d: (u32, u32),
}

impl WorkgroupOptimizer {
    /// Creates an optimizer capped at `max_workgroup_size` invocations per
    /// workgroup, preferring 256 for 1-D and 16x16 for 2-D workloads.
    pub fn new(max_workgroup_size: u32) -> Self {
        Self {
            max_workgroup_size,
            preferred_1d: 256,
            preferred_2d: (16, 16),
        }
    }

    /// Picks a 1-D workgroup size for `data_size` elements.
    ///
    /// - Inputs of at most 64 elements get 64.
    /// - Inputs up to the preferred size (256) get the preferred size.
    /// - Larger inputs get 512.
    ///
    /// The result is then capped at the maximum workgroup size.
    pub fn optimize_1d(&self, data_size: usize) -> u32 {
        // Saturate instead of `as u32`, which would wrap huge inputs to tiny ones.
        let size = Self::saturating_u32(data_size);
        let candidate = if size <= 64 {
            64
        } else if size <= self.preferred_1d {
            self.preferred_1d
        } else {
            512
        };
        candidate.min(self.limit())
    }

    /// Picks a 2-D workgroup size for a `width` x `height` domain.
    ///
    /// - Domains up to 16x16 get (8, 8).
    /// - Domains up to 32x32 get (16, 16).
    /// - Larger domains get the preferred (16, 16).
    ///
    /// Dimensions are then halved as needed to respect the maximum workgroup size.
    pub fn optimize_2d(&self, width: usize, height: usize) -> (u32, u32) {
        let (w, h) = (Self::saturating_u32(width), Self::saturating_u32(height));
        let (x, y) = if w <= 16 && h <= 16 {
            (8, 8)
        } else if w <= 32 && h <= 32 {
            (16, 16)
        } else {
            self.preferred_2d
        };
        let mut dims = [x, y];
        self.shrink_to_limit(&mut dims);
        (dims[0], dims[1])
    }

    /// Picks a 3-D workgroup size.
    ///
    /// Domains up to 8x8x8 get (4, 4, 4); larger domains get (8, 8, 4).
    /// Dimensions are then halved as needed to respect the maximum workgroup size.
    pub fn optimize_3d(&self, width: usize, height: usize, depth: usize) -> (u32, u32, u32) {
        let (w, h, d) = (
            Self::saturating_u32(width),
            Self::saturating_u32(height),
            Self::saturating_u32(depth),
        );
        let mut dims = if w <= 8 && h <= 8 && d <= 8 {
            [4, 4, 4]
        } else {
            [8, 8, 4]
        };
        self.shrink_to_limit(&mut dims);
        (dims[0], dims[1], dims[2])
    }

    /// The configured maximum number of invocations per workgroup.
    pub fn max_workgroup_size(&self) -> u32 {
        self.max_workgroup_size
    }

    /// Effective cap on invocations per workgroup; never below 1 so that
    /// returned sizes remain usable as divisors.
    fn limit(&self) -> u32 {
        self.max_workgroup_size.max(1)
    }

    /// Converts `value` to `u32`, saturating at `u32::MAX` instead of wrapping.
    fn saturating_u32(value: usize) -> u32 {
        if value as u64 > u64::from(u32::MAX) {
            u32::MAX
        } else {
            value as u32
        }
    }

    /// Halves the largest dimension (the first one on ties) until the product
    /// of `dims` fits within [`limit`](Self::limit).
    fn shrink_to_limit(&self, dims: &mut [u32]) {
        let limit = u64::from(self.limit());
        loop {
            let total = dims
                .iter()
                .fold(1u64, |acc, &d| acc.saturating_mul(u64::from(d)));
            if total <= limit {
                return;
            }
            // total > limit >= 1 implies some dimension is >= 2, so halving
            // keeps every dimension >= 1 and the loop terminates.
            let mut largest = 0;
            for (i, &d) in dims.iter().enumerate() {
                if d > dims[largest] {
                    largest = i;
                }
            }
            dims[largest] /= 2;
        }
    }
}

impl Default for WorkgroupOptimizer {
    /// An optimizer capped at 256 invocations per workgroup.
    fn default() -> Self {
        Self::new(256)
    }
}
#[derive(Debug, Clone)]
pub struct PipelineCache {
cache: Vec<(u64, ComputePipeline)>,
max_size: usize,
}
impl PipelineCache {
pub fn new(max_size: usize) -> Self {
Self {
cache: Vec::new(),
max_size,
}
}
pub fn get_or_create<F>(&mut self, shader: &WGSLShader, create_fn: F) -> &ComputePipeline
where
F: FnOnce(&WGSLShader) -> ComputePipeline,
{
let hash = self.hash_shader(shader);
if let Some(index) = self.cache.iter().position(|(h, _)| *h == hash) {
return &self.cache[index].1;
}
let pipeline = create_fn(shader);
self.cache.push((hash, pipeline));
if self.cache.len() > self.max_size {
self.cache.remove(0);
}
&self
.cache
.last()
.expect("cache should have at least one entry after push")
.1
}
fn hash_shader(&self, shader: &WGSLShader) -> u64 {
let src = shader.source();
let mut hash = src.len() as u64;
for (i, byte) in src.bytes().take(16).enumerate() {
hash = hash
.wrapping_mul(31)
.wrapping_add(byte as u64 * (i as u64 + 1));
}
hash
}
pub fn clear(&mut self) {
self.cache.clear();
}
pub fn size(&self) -> usize {
self.cache.len()
}
pub fn max_size(&self) -> usize {
self.max_size
}
}
/// Ready-made WGSL compute shaders.
///
/// Each constructor sets the host-side workgroup-size metadata to match the
/// `@workgroup_size` attribute in its WGSL source. Dispatch calculations such
/// as `ComputePipeline::optimal_dispatch_size` read that metadata.
pub mod shaders {
    use super::*;

    /// `output[i] = input_a[i] + input_b[i]`.
    ///
    /// Bindings (group 0): 0 `input_a` (read), 1 `input_b` (read),
    /// 2 `output` (read_write). The bound is `arrayLength(&input_a)`.
    pub fn elementwise_add() -> WGSLShader {
        WGSLShader::new(
            r#"
@group(0) @binding(0) var<storage, read> input_a: array<f32>;
@group(0) @binding(1) var<storage, read> input_b: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let index = global_id.x;
if (index < arrayLength(&input_a)) {
output[index] = input_a[index] + input_b[index];
}
}
"#,
        )
        // Must match `@workgroup_size(256)` above; the default of 64 would
        // make dispatch calculations launch 4x too many workgroups.
        .with_workgroup_size(256, 1, 1)
    }

    /// `output[i] = input_a[i] * input_b[i]`.
    ///
    /// Bindings (group 0): 0 `input_a` (read), 1 `input_b` (read),
    /// 2 `output` (read_write). The bound is `arrayLength(&input_a)`.
    pub fn elementwise_mul() -> WGSLShader {
        WGSLShader::new(
            r#"
@group(0) @binding(0) var<storage, read> input_a: array<f32>;
@group(0) @binding(1) var<storage, read> input_b: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let index = global_id.x;
if (index < arrayLength(&input_a)) {
output[index] = input_a[index] * input_b[index];
}
}
"#,
        )
        // Must match `@workgroup_size(256)` above.
        .with_workgroup_size(256, 1, 1)
    }

    /// Row-major matrix product.
    ///
    /// Computes `output[row * n + col] = sum_i A[row * k + i] * B[i * n + col]`.
    /// The dimensions `m`, `n`, `k` come from the uniform `dims` at binding 3.
    ///
    /// Invocation x is the column and y is the row; out-of-range invocations
    /// return early.
    pub fn matrix_mul() -> WGSLShader {
        WGSLShader::new(
            r#"
struct Dimensions {
m: u32,
n: u32,
k: u32,
}
@group(0) @binding(0) var<storage, read> matrix_a: array<f32>;
@group(0) @binding(1) var<storage, read> matrix_b: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> dims: Dimensions;
@compute @workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let row = global_id.y;
let col = global_id.x;
if (row >= dims.m || col >= dims.n) {
return;
}
var sum = 0.0;
for (var i = 0u; i < dims.k; i++) {
let a_index = row * dims.k + i;
let b_index = i * dims.n + col;
sum += matrix_a[a_index] * matrix_b[b_index];
}
let out_index = row * dims.n + col;
output[out_index] = sum;
}
"#,
        )
        // Matches `@workgroup_size(16, 16)` above (z defaults to 1).
        .with_workgroup_size(16, 16, 1)
    }

    /// Partial sum reduction.
    ///
    /// Each workgroup of 256 invocations sums its 256-element slice of `input`
    /// with a tree reduction in workgroup memory. Elements past the end are
    /// treated as 0.
    ///
    /// The slice's sum is written to `output[global_id.x / 256]`, giving one
    /// partial sum per workgroup.
    pub fn reduce_sum() -> WGSLShader {
        WGSLShader::new(
            r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
var<workgroup> shared_data: array<f32, 256>;
@compute @workgroup_size(256)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
) {
let tid = local_id.x;
let index = global_id.x;
// Load data into shared memory
if (index < arrayLength(&input)) {
shared_data[tid] = input[index];
} else {
shared_data[tid] = 0.0;
}
workgroupBarrier();
// Reduce in shared memory
for (var s = 128u; s > 0u; s >>= 1u) {
if (tid < s) {
shared_data[tid] += shared_data[tid + s];
}
workgroupBarrier();
}
// Write result
if (tid == 0u) {
output[global_id.x / 256u] = shared_data[0];
}
}
"#,
        )
        // Must match `@workgroup_size(256)` and the 256-element shared array:
        // with the default of 64, dispatch math would launch 4x the workgroups
        // and emit spurious partial sums past the real ones.
        .with_workgroup_size(256, 1, 1)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every element compares equal to itself and to no other element.
    fn assert_all_distinct<T: PartialEq>(items: &[T]) {
        for (i, a) in items.iter().enumerate() {
            for (j, b) in items.iter().enumerate() {
                assert_eq!(i == j, a == b, "items {} and {} compared unexpectedly", i, j);
            }
        }
    }

    #[test]
    fn test_wgsl_shader_creation() {
        let shader = WGSLShader::new("@compute @workgroup_size(64) fn main() {}");
        assert!(shader.source().contains("@compute"));
        assert_eq!(shader.entry_point(), "main");
        assert_eq!(shader.workgroup_size(), (64, 1, 1));
    }

    #[test]
    fn test_shader_with_entry_point() {
        let shader = WGSLShader::with_entry_point("code", "custom_entry");
        assert_eq!(shader.entry_point(), "custom_entry");
        assert_eq!(shader.workgroup_size(), (64, 1, 1));
    }

    #[test]
    fn test_shader_workgroup_size() {
        let shader = WGSLShader::new("code").with_workgroup_size(16, 16, 1);
        assert_eq!(shader.workgroup_size(), (16, 16, 1));
    }

    #[test]
    fn test_shader_validation() {
        let shader = WGSLShader::new("@compute fn main() {}");
        assert!(shader.validate().is_ok());
        let empty_shader = WGSLShader::new("");
        assert_eq!(empty_shader.validate(), Err(ShaderError::EmptyShader));
        let invalid_shader = WGSLShader::new("fn main() {}");
        assert_eq!(
            invalid_shader.validate(),
            Err(ShaderError::MissingComputeAttribute)
        );
    }

    #[test]
    fn test_gpu_buffer_creation() {
        let storage = GPUBuffer::storage(1024);
        assert_eq!(storage.size(), 1024);
        assert_eq!(storage.usage(), BufferUsage::Storage);
        assert!(!storage.is_mapped());
        assert_eq!(storage.label(), None);

        let uniform = GPUBuffer::uniform(256);
        assert_eq!(uniform.usage(), BufferUsage::Uniform);
        assert!(!uniform.is_mapped());

        let staging = GPUBuffer::staging(512);
        assert_eq!(staging.usage(), BufferUsage::Staging);
        assert!(staging.is_mapped());
    }

    #[test]
    fn test_buffer_with_label() {
        let buffer = GPUBuffer::storage(1024).with_label("test_buffer");
        assert_eq!(buffer.label(), Some("test_buffer"));
    }

    #[test]
    fn test_compute_pipeline() {
        let shader = WGSLShader::new("@compute fn main() {}");
        let pipeline = ComputePipeline::new(shader.clone());
        assert_eq!(pipeline.shader().source(), shader.source());
        assert_eq!(pipeline.bind_groups().len(), 0);
        assert_eq!(pipeline.label(), None);
    }

    #[test]
    fn test_pipeline_with_label() {
        let shader = WGSLShader::new("@compute fn main() {}");
        let pipeline = ComputePipeline::new(shader).with_label("test_pipeline");
        assert_eq!(pipeline.label(), Some("test_pipeline"));
    }

    #[test]
    fn test_optimal_dispatch_size() {
        let shader = WGSLShader::new("code").with_workgroup_size(64, 1, 1);
        let pipeline = ComputePipeline::new(shader);
        let (x, _, _) = pipeline.optimal_dispatch_size(1000);
        // ceil(1000 / 64) == 16
        assert_eq!(x, 16);
        // An exact multiple needs no extra workgroup.
        let (x, _, _) = pipeline.optimal_dispatch_size(128);
        assert_eq!(x, 2);
    }

    #[test]
    fn test_bind_group_entry() {
        let entry = BindGroupEntry::new(0, ResourceType::StorageBuffer, ShaderStage::Compute);
        assert_eq!(entry.binding(), 0);
        assert_eq!(entry.resource_type(), ResourceType::StorageBuffer);
        assert_eq!(entry.visibility(), ShaderStage::Compute);
    }

    #[test]
    fn test_bind_group_layout() {
        let entry = BindGroupEntry::new(0, ResourceType::StorageBuffer, ShaderStage::Compute);
        let layout = BindGroupLayout::new().with_entry(entry);
        assert_eq!(layout.entries().len(), 1);
        assert_eq!(layout.entries()[0].binding(), 0);
    }

    #[test]
    fn test_workgroup_optimizer() {
        let optimizer = WorkgroupOptimizer::new(512);
        assert_eq!(optimizer.optimize_1d(50), 64);
        assert_eq!(optimizer.optimize_1d(100), 256);
        assert_eq!(optimizer.optimize_1d(500), 512);
        assert_eq!(optimizer.optimize_2d(10, 10), (8, 8));
        assert_eq!(optimizer.optimize_3d(10, 10, 10), (8, 8, 4));
    }

    #[test]
    fn test_default_workgroup_optimizer() {
        let optimizer = WorkgroupOptimizer::default();
        assert_eq!(optimizer.max_workgroup_size(), 256);
    }

    #[test]
    fn test_pipeline_cache() {
        let mut cache = PipelineCache::new(10);
        assert_eq!(cache.size(), 0);
        let shader = WGSLShader::new("@compute fn main() {}");
        let _pipeline = cache.get_or_create(&shader, |s| ComputePipeline::new(s.clone()));
        assert_eq!(cache.size(), 1);
        // A second request for the same shader must be served from the cache.
        let _again = cache.get_or_create(&shader, |_| panic!("pipeline should have been cached"));
        assert_eq!(cache.size(), 1);
        cache.clear();
        assert_eq!(cache.size(), 0);
    }

    #[test]
    fn test_shader_templates() {
        let add_shader = shaders::elementwise_add();
        assert!(add_shader.source().contains("input_a"));
        assert!(add_shader.validate().is_ok());

        let mul_shader = shaders::elementwise_mul();
        assert!(mul_shader.source().contains("input_b"));
        assert!(mul_shader.validate().is_ok());

        let matmul_shader = shaders::matrix_mul();
        assert!(matmul_shader.source().contains("matrix_a"));
        assert_eq!(matmul_shader.workgroup_size(), (16, 16, 1));
        assert!(matmul_shader.validate().is_ok());

        let reduce_shader = shaders::reduce_sum();
        assert!(reduce_shader.source().contains("shared_data"));
        assert!(reduce_shader.validate().is_ok());
    }

    #[test]
    fn test_resource_types() {
        assert_all_distinct(&[
            ResourceType::StorageBuffer,
            ResourceType::UniformBuffer,
            ResourceType::ReadOnlyStorageBuffer,
            ResourceType::Texture,
            ResourceType::Sampler,
        ]);
    }

    #[test]
    fn test_shader_stages() {
        assert_all_distinct(&[
            ShaderStage::Vertex,
            ShaderStage::Fragment,
            ShaderStage::Compute,
            ShaderStage::All,
        ]);
    }

    #[test]
    fn test_buffer_usage_types() {
        assert_all_distinct(&[
            BufferUsage::Storage,
            BufferUsage::Uniform,
            BufferUsage::Staging,
            BufferUsage::Vertex,
            BufferUsage::Index,
        ]);
    }

    #[test]
    fn test_shader_error_display() {
        let err = ShaderError::EmptyShader;
        assert_eq!(format!("{}", err), "Shader source is empty");
        let err = ShaderError::MissingComputeAttribute;
        assert_eq!(format!("{}", err), "Missing @compute attribute");
        let err = ShaderError::SyntaxError("test".to_string());
        assert_eq!(format!("{}", err), "Syntax error: test");
        let err = ShaderError::CompilationFailed("test".to_string());
        assert_eq!(format!("{}", err), "Compilation failed: test");
    }

    #[test]
    fn test_pipeline_with_bind_group() {
        let shader = WGSLShader::new("@compute fn main() {}");
        let layout = BindGroupLayout::new();
        let pipeline = ComputePipeline::new(shader).with_bind_group(layout);
        assert_eq!(pipeline.bind_groups().len(), 1);
    }

    #[test]
    fn test_bind_group_with_label() {
        let layout = BindGroupLayout::new().with_label("test_layout");
        assert_eq!(layout.label(), Some("test_layout"));
    }

    #[test]
    fn test_default_bind_group_layout() {
        let layout = BindGroupLayout::default();
        assert_eq!(layout.entries().len(), 0);
        assert_eq!(layout.label(), None);
    }

    #[test]
    fn test_cache_max_size() {
        let cache = PipelineCache::new(5);
        assert_eq!(cache.max_size(), 5);
    }
}