use naga_oil::compose::{
ComposableModuleDescriptor, Composer, NagaModuleDescriptor, ShaderDefValue, ShaderLanguage,
};
use std::collections::HashMap;
use std::sync::{Mutex, OnceLock};
/// A naga_oil [`Composer`] behind a [`Mutex`], so modules can be registered and shaders
/// composed through a shared `&self` (it is stored in a process-wide `static` below).
pub struct ShaderComposer {
    // The underlying composer; every access goes through this lock.
    inner: Mutex<Composer>,
}
impl Default for ShaderComposer {
fn default() -> Self {
Self::new()
}
}
impl ShaderComposer {
pub fn new() -> Self {
Self {
inner: Mutex::new(Composer::default()),
}
}
pub fn register_module(
&self,
source: &str,
file_path: &str,
) -> Result<(), Box<naga_oil::compose::ComposerError>> {
let mut composer = self.inner.lock().unwrap();
composer
.add_composable_module(ComposableModuleDescriptor {
source,
file_path,
language: ShaderLanguage::Wgsl,
as_name: None,
additional_imports: &[],
shader_defs: HashMap::new(),
})
.map_err(Box::new)?;
Ok(())
}
pub fn compile_wgsl(
&self,
device: &wgpu::Device,
label: &str,
source: &str,
shader_defs: &[(&str, ShaderDefValue)],
) -> wgpu::ShaderModule {
let defs: HashMap<String, ShaderDefValue> = shader_defs
.iter()
.map(|(name, value)| ((*name).to_string(), *value))
.collect();
let mut composer = self.inner.lock().unwrap();
let module = composer
.make_naga_module(NagaModuleDescriptor {
source,
file_path: label,
shader_type: naga_oil::compose::ShaderType::Wgsl,
shader_defs: defs,
additional_imports: &[],
})
.unwrap_or_else(|error| {
panic!(
"failed to compose shader {label}: {}",
error.emit_to_string(&composer)
)
});
device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some(label),
source: wgpu::ShaderSource::Naga(std::borrow::Cow::Owned(module)),
})
}
}
// Lazily-initialised, process-wide composer shared by the free functions below.
static GLOBAL_COMPOSER: OnceLock<ShaderComposer> = OnceLock::new();

/// Returns the process-wide [`ShaderComposer`], creating an empty one on first use.
pub fn global() -> &'static ShaderComposer {
    GLOBAL_COMPOSER.get_or_init(ShaderComposer::default)
}
pub fn compile_wgsl(device: &wgpu::Device, label: &str, source: &str) -> wgpu::ShaderModule {
global().compile_wgsl(device, label, source, &[])
}
/// Composes and uploads `source` through the [`global`] composer using `shader_defs` as
/// preprocessor definitions.
///
/// # Panics
///
/// Panics if composition fails (see [`ShaderComposer::compile_wgsl`]).
pub fn compile_wgsl_with_defs(
    device: &wgpu::Device,
    label: &str,
    source: &str,
    shader_defs: &[(&str, ShaderDefValue)],
) -> wgpu::ShaderModule {
    let composer = global();
    composer.compile_wgsl(device, label, source, shader_defs)
}