pub mod kernels;
use crate::webgpu::{WebGpuDevice, WebGpuError, WebGpuResult};
#[cfg(feature = "webgpu")]
#[allow(unused_imports)]
use md5;
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
#[cfg(feature = "webgpu")]
#[allow(unused_imports)]
use wgpu;
/// Shader program text in one of the languages this module knows about.
///
/// Only [`ShaderSource::Wgsl`] can be compiled directly; [`ShaderSource::Glsl`]
/// is rejected by `ShaderModule::new`.
#[derive(Clone, Debug)]
pub enum ShaderSource {
    /// WGSL source text.
    Wgsl(String),
    /// GLSL source text (must be translated to WGSL before it can be compiled).
    Glsl(String),
}
/// A compiled `wgpu` shader module together with the source it was built from.
#[derive(Debug)]
pub struct ShaderModule {
    /// The underlying `wgpu` shader module handle.
    pub module: wgpu::ShaderModule,
    /// The source this module was created from.
    pub source: ShaderSource,
    /// Entry-point names recorded for this module.
    ///
    /// NOTE(review): `ShaderModule::new` does not parse the shader and always
    /// records `["main"]`, so this is only accurate for shaders whose entry
    /// point is named `main`.
    pub entry_points: Vec<String>,
    /// Length of the WGSL source text in bytes.
    pub size_bytes: usize,
    /// Compiler diagnostics, if collected. `ShaderModule::new` always leaves
    /// this as `None`.
    pub compilation_info: Option<wgpu::CompilationInfo>,
}

impl ShaderModule {
    /// Compiles `source` into a shader module on `device`.
    ///
    /// `label` is forwarded to `wgpu` as the module's debug label.
    ///
    /// # Errors
    ///
    /// Returns [`WebGpuError::UnsupportedFeature`] when `source` is
    /// [`ShaderSource::Glsl`]; only WGSL is accepted directly.
    pub fn new(
        device: &WebGpuDevice,
        source: ShaderSource,
        label: Option<&str>,
    ) -> WebGpuResult<Self> {
        let wgsl: &str = match &source {
            ShaderSource::Wgsl(code) => code.as_str(),
            ShaderSource::Glsl(_glsl_code) => {
                return Err(WebGpuError::UnsupportedFeature(
                    "GLSL shaders require translation to WGSL. Please convert your GLSL to WGSL or use a translation tool like naga.".to_string(),
                ));
            }
        };
        let size_bytes = wgsl.len();

        // wgpu takes the source as a `Cow<str>` and only needs it for the
        // duration of this call, so borrow it instead of cloning the whole
        // shader text. The borrow ends here, before `source` is moved below.
        let module = device
            .device()
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label,
                source: wgpu::ShaderSource::Wgsl(wgsl.into()),
            });

        Ok(Self {
            module,
            source,
            entry_points: vec!["main".to_string()],
            size_bytes,
            compilation_info: None,
        })
    }

    /// Returns the underlying `wgpu` shader module.
    pub fn wgpu_module(&self) -> &wgpu::ShaderModule {
        &self.module
    }

    /// Returns the source this module was compiled from.
    pub fn source(&self) -> &ShaderSource {
        &self.source
    }

    /// Returns the recorded entry-point names (always `["main"]` when built by `new`).
    pub fn entry_points(&self) -> &[String] {
        &self.entry_points
    }

    /// Returns the length of the WGSL source text in bytes.
    pub fn size_bytes(&self) -> usize {
        self.size_bytes
    }
}
/// Thread-safe cache of compiled shader modules keyed by a caller-chosen string.
#[derive(Debug)]
pub struct ShaderCache {
    cache: RwLock<HashMap<String, Arc<ShaderModule>>>,
}

impl ShaderCache {
    /// Creates an empty cache.
    pub fn new() -> Self {
        Self {
            cache: RwLock::new(HashMap::new()),
        }
    }

    /// Returns the module cached under `key`. On a miss, it compiles `source`
    /// with `label` and caches the result first.
    ///
    /// On a cache hit, `source` and `label` are ignored.
    ///
    /// # Errors
    ///
    /// Propagates any error from `ShaderModule::new`. Nothing is cached in
    /// that case.
    pub fn get_or_compile(
        &self,
        device: &WebGpuDevice,
        key: String,
        source: ShaderSource,
        label: Option<&str>,
    ) -> WebGpuResult<Arc<ShaderModule>> {
        {
            let cache = self.cache.read();
            if let Some(module) = cache.get(&key) {
                return Ok(Arc::clone(module));
            }
        }

        // Compile without holding any lock so lookups of other keys are not
        // blocked behind shader compilation.
        let compiled = Arc::new(ShaderModule::new(device, source, label)?);

        // Another thread may have compiled the same key while we were
        // compiling. Keep whichever module was inserted first instead of
        // overwriting it, so every caller for this key shares one module.
        let mut cache = self.cache.write();
        let shared = cache.entry(key).or_insert(compiled);
        Ok(Arc::clone(shared))
    }

    /// Removes every cached module.
    pub fn clear(&self) {
        self.cache.write().clear();
    }

    /// Returns `(number of cached modules, sum of their size_bytes())`.
    pub fn stats(&self) -> (usize, usize) {
        let cache = self.cache.read();
        let count = cache.len();
        let total_bytes = cache.values().map(|module| module.size_bytes()).sum();
        (count, total_bytes)
    }
}

impl Default for ShaderCache {
    /// Equivalent to [`ShaderCache::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Compiles WGSL shaders and memoizes the results in a private [`ShaderCache`].
/// The cache is keyed by the MD5 digest of the source text.
#[derive(Debug)]
pub struct ShaderCompiler {
    cache: Arc<ShaderCache>,
}

impl ShaderCompiler {
    /// Creates a compiler with its own empty cache.
    pub fn new() -> Self {
        Self {
            cache: Arc::new(ShaderCache::new()),
        }
    }

    /// Compiles `source` as WGSL, or returns the previously compiled module
    /// for identical source text.
    ///
    /// The cache key is derived from the source text only. `label` therefore
    /// takes effect only on the first compilation; later hits return the
    /// module with whatever label it was first built with.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`ShaderCache::get_or_compile`].
    pub fn compile_wgsl(
        &self,
        device: &WebGpuDevice,
        source: &str,
        label: Option<&str>,
    ) -> WebGpuResult<Arc<ShaderModule>> {
        let key = format!("wgsl_{:x}", md5::compute(source));
        let source = ShaderSource::Wgsl(source.to_string());
        self.cache.get_or_compile(device, key, source, label)
    }

    /// Returns `(number of cached modules, total WGSL source bytes)`.
    pub fn cache_stats(&self) -> (usize, usize) {
        self.cache.stats()
    }

    /// Drops every cached module, forcing recompilation on the next request.
    pub fn clear_cache(&self) {
        self.cache.clear();
    }
}

impl Default for ShaderCompiler {
    /// Equivalent to [`ShaderCompiler::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Builders for the bind group layouts used by element-wise compute kernels.
pub mod layout_helpers {
    use super::*;

    /// Builds a compute-stage storage-buffer layout entry at `binding`.
    ///
    /// The buffer is read-only when `read_only` is `true` and writable
    /// otherwise. The entry has no dynamic offset, no minimum binding size,
    /// and is not an array.
    fn storage_buffer_entry(binding: u32, read_only: bool) -> wgpu::BindGroupLayoutEntry {
        wgpu::BindGroupLayoutEntry {
            binding,
            visibility: wgpu::ShaderStages::COMPUTE,
            ty: wgpu::BindingType::Buffer {
                ty: wgpu::BufferBindingType::Storage { read_only },
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        }
    }

    /// Layout for a binary op: two read-only storage inputs at bindings 0 and 1,
    /// and one writable storage output at binding 2.
    pub fn create_binary_op_layout(device: &WebGpuDevice) -> wgpu::BindGroupLayout {
        device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some("Binary Operation Layout"),
            entries: &[
                storage_buffer_entry(0, true),
                storage_buffer_entry(1, true),
                storage_buffer_entry(2, false),
            ],
        })
    }

    /// Layout for a unary op: one read-only storage input at binding 0 and one
    /// writable storage output at binding 1.
    pub fn create_unary_op_layout(device: &WebGpuDevice) -> wgpu::BindGroupLayout {
        device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some("Unary Operation Layout"),
            entries: &[
                storage_buffer_entry(0, true),
                storage_buffer_entry(1, false),
            ],
        })
    }
}