use std::collections::{HashMap};
use std::path::Path;
use wgpu::{Device, ShaderModule, ShaderModuleDescriptor, ShaderSource};
/// Reads the WGSL shader at `path`, runs it through the `#ifdef`/`#ifndef`/`#else`/
/// `#endif`/`#include` preprocessor and compiles the result into a [`ShaderModule`].
///
/// `defines` must contain every symbol named by an `#ifdef`/`#ifndef` directive in the
/// shader or any file it includes; the associated `bool` selects the branch.
///
/// The module is labelled with the shader path (converted lossily if it is not valid
/// UTF-8) to make it identifiable in wgpu diagnostics.
///
/// # Panics
///
/// Panics if the shader or one of its includes cannot be read, if a directive is
/// malformed or references an unknown define, or if the conditional directives are
/// unbalanced.
pub fn compile_wgsl(
    device: &Device,
    path: &Path,
    defines: &HashMap<String, bool>,
) -> ShaderModule {
    let source = std::fs::read_to_string(path).unwrap_or_else(|e| {
        panic!("Failed to read shader file {}: {}", path.display(), e)
    });

    let mut conditional_stack: Vec<bool> = Vec::new();
    let processed = preprocess_wgsl(path, &source, defines, &mut conditional_stack);
    // The stack is shared with included files, so an unclosed region anywhere in the
    // include tree only becomes visible once the top-level file is finished.
    if !conditional_stack.is_empty() {
        panic!(
            "Unbalanced preprocessing directives in shader {} (or one of its #included files): open #ifdef/#ifndef without matching #endif",
            path.display()
        );
    }

    // The label is only used for debugging output, so a lossy conversion is preferable
    // to rejecting a shader that was read and preprocessed successfully.
    let label = path.to_string_lossy();
    device.create_shader_module(ShaderModuleDescriptor {
        label: Some(&*label),
        source: ShaderSource::Wgsl(processed.into()),
    })
}
/// Runs the WGSL preprocessor over `src`, the contents of the file at `path`.
///
/// Recognised directives (matched after trimming surrounding whitespace):
///
/// * `#ifdef NAME` / `#ifndef NAME` open a conditional region whose lines are kept when
///   `defines[NAME]` is `true` / `false` respectively.
/// * `#else` inverts the innermost open region.
/// * `#endif` closes the innermost open region.
/// * `#include "relative/path"` is replaced by the preprocessed contents of that file,
///   resolved relative to the directory containing `path`.
///
/// `#else` and `#endif` may be followed by whitespace or a `//` comment
/// (e.g. `#endif // FOO`). Lines inside an inactive region are dropped, including
/// `#include` directives. Conditional directives are still tracked there so nesting stays
/// correct, which also means their defines must exist even in inactive regions.
///
/// `stack` holds one entry per open region (`true` = active). It is shared with included
/// files, so a region may be opened in one file and closed in another. Checking that it is
/// empty after the top-level file is the caller's responsibility.
///
/// Returns the kept lines joined with `\n` plus a trailing `\n`. An input that keeps no
/// lines yields `"\n"`.
///
/// # Panics
///
/// Panics if a directive names a symbol missing from `defines`, if `#else`/`#endif` has
/// no open region, or if an `#include` is malformed, unreadable, or (directly or
/// indirectly) includes itself.
fn preprocess_wgsl(
    path: &Path,
    src: &str,
    defines: &HashMap<String, bool>,
    stack: &mut Vec<bool>,
) -> String {
    let mut include_chain = Vec::new();
    preprocess_wgsl_file(path, src, defines, stack, &mut include_chain)
}

/// Recursive worker behind `preprocess_wgsl`.
///
/// `include_chain` holds the identities of the files currently being expanded (outermost
/// first). An `#include` cycle is therefore reported as a panic instead of recursing until
/// the thread's stack overflows. Diamond includes (the same file included twice, but not
/// recursively) are still allowed.
fn preprocess_wgsl_file(
    path: &Path,
    src: &str,
    defines: &HashMap<String, bool>,
    stack: &mut Vec<bool>,
    include_chain: &mut Vec<std::path::PathBuf>,
) -> String {
    include_chain.push(include_identity(path));

    let mut lines: Vec<String> = Vec::new();
    for (idx, line) in src.lines().enumerate() {
        let line_num = idx + 1;
        let t = line.trim();

        if let Some(rest) = t.strip_prefix("#ifdef ") {
            stack.push(lookup_define(defines, rest.trim(), "#ifdef", path, line_num));
            continue;
        }
        if let Some(rest) = t.strip_prefix("#ifndef ") {
            stack.push(!lookup_define(defines, rest.trim(), "#ifndef", path, line_num));
            continue;
        }
        if is_directive(t, "#else") {
            let top = stack.last_mut().unwrap_or_else(|| {
                panic!(
                    "{}:{}: #else without matching #ifdef/#ifndef",
                    path.display(),
                    line_num
                )
            });
            *top = !*top;
            continue;
        }
        if is_directive(t, "#endif") {
            if stack.pop().is_none() {
                panic!(
                    "{}:{}: #endif without matching #ifdef/#ifndef",
                    path.display(),
                    line_num
                );
            }
            continue;
        }

        // Everything below is only emitted (or followed, for #include) in active regions.
        if stack.contains(&false) {
            continue;
        }

        if let Some(rest) = t.strip_prefix("#include \"") {
            let p = rest.strip_suffix('"').unwrap_or_else(|| {
                panic!(
                    "{}:{}: Malformed #include directive: missing closing quote",
                    path.display(),
                    line_num
                )
            });
            let parent = path
                .parent()
                .expect("Cannot resolve relative #include: shader path has no parent directory");
            let inc_path = parent.join(p);
            if include_chain.contains(&include_identity(&inc_path)) {
                panic!(
                    "{}:{}: Recursive #include of \"{}\" (resolved to {})",
                    path.display(),
                    line_num,
                    p,
                    inc_path.display()
                );
            }
            let inc_source = std::fs::read_to_string(&inc_path).unwrap_or_else(|e| {
                panic!(
                    "{}:{}: Failed to read included shader \"{}\" (resolved to {}): {}",
                    path.display(),
                    line_num,
                    p,
                    inc_path.display(),
                    e
                );
            });
            let inc_processed =
                preprocess_wgsl_file(&inc_path, &inc_source, defines, stack, include_chain);
            lines.extend(inc_processed.lines().map(str::to_owned));
            continue;
        }

        lines.push(line.to_owned());
    }

    include_chain.pop();
    lines.join("\n") + "\n"
}

/// Returns the value of `name` in `defines`, panicking with a `file:line:` message if absent.
fn lookup_define(
    defines: &HashMap<String, bool>,
    name: &str,
    directive: &str,
    path: &Path,
    line_num: usize,
) -> bool {
    *defines.get(name).unwrap_or_else(|| {
        panic!(
            "{}:{}: Unknown preprocessing define '{}' in {}",
            path.display(),
            line_num,
            name,
            directive
        )
    })
}

/// Reports whether the trimmed line `t` is exactly the directive `keyword`, optionally
/// followed by whitespace or a `//` comment (so `#elseif` is not mistaken for `#else`).
fn is_directive(t: &str, keyword: &str) -> bool {
    match t.strip_prefix(keyword) {
        Some(rest) => {
            rest.is_empty() || rest.starts_with(char::is_whitespace) || rest.starts_with("//")
        }
        None => false,
    }
}

/// Key identifying a file in the include chain: its canonical path when it can be
/// resolved, otherwise the path as given.
fn include_identity(path: &Path) -> std::path::PathBuf {
    std::fs::canonicalize(path).unwrap_or_else(|_| path.to_path_buf())
}