use super::Error;
use naga::back::spv;
use naga::front::{glsl, wgsl};
use naga::valid::{Capabilities, ValidationFlags, Validator};
use naga::{Module, ShaderStage};
/// Errors produced while compiling shader source to SPIR-V through naga.
///
/// Each variant carries a textual description of the failure reported by the
/// corresponding stage of the pipeline (parse -> validate -> emit).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NagaError {
    /// The GLSL or WGSL front end rejected the source text.
    Parse(String),
    /// The parsed module failed naga's IR validation.
    Validation(String),
    /// The SPIR-V back end failed to write out the validated module.
    SpvOut(String),
}

impl std::fmt::Display for NagaError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // `Parse` is produced by both the GLSL and the WGSL entry points,
            // so the message must not claim a specific source language.
            Self::Parse(s) => write!(f, "Shader parse error: {s}"),
            Self::Validation(s) => write!(f, "Naga validation error: {s}"),
            Self::SpvOut(s) => write!(f, "SPIR-V emission error: {s}"),
        }
    }
}

impl std::error::Error for NagaError {}
impl From<NagaError> for Error {
fn from(e: NagaError) -> Self {
Error::NagaCompile(e.to_string())
}
}
/// Compiles GLSL `source` for the pipeline `stage` into SPIR-V words.
///
/// # Errors
/// Returns [`NagaError::Parse`] (holding the Debug rendering of the front-end
/// errors) when the GLSL front end rejects the source. Otherwise it returns
/// whatever validation or SPIR-V emission reports.
pub fn compile_glsl(source: &str, stage: ShaderStage) -> Result<Vec<u32>, NagaError> {
    let options = glsl::Options::from(stage);
    let module = glsl::Frontend::default()
        .parse(&options, source)
        .map_err(|errs| NagaError::Parse(format!("{errs:?}")))?;
    spirv_from_module(module)
}
/// Compiles WGSL `source` into SPIR-V words.
///
/// # Errors
/// Returns [`NagaError::Parse`] with the parser's source-annotated message
/// when the WGSL front end rejects the input. Otherwise it returns whatever
/// validation or SPIR-V emission reports.
pub fn compile_wgsl(source: &str) -> Result<Vec<u32>, NagaError> {
    match wgsl::parse_str(source) {
        Ok(module) => spirv_from_module(module),
        Err(err) => Err(NagaError::Parse(err.emit_to_string(source))),
    }
}
/// Validates a naga `module` and writes it out as SPIR-V 1.0 words.
///
/// Validation runs with every flag and every capability enabled. All entry
/// points are emitted because no pipeline options are passed.
///
/// # Errors
/// [`NagaError::Validation`] if the validator rejects the module, and
/// [`NagaError::SpvOut`] if the SPIR-V writer fails. Both hold the Debug
/// rendering of the underlying error.
fn spirv_from_module(module: Module) -> Result<Vec<u32>, NagaError> {
    let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all());
    let module_info = match validator.validate(&module) {
        Ok(info) => info,
        Err(err) => return Err(NagaError::Validation(format!("{err:?}"))),
    };

    let options = spv::Options {
        lang_version: (1, 0),
        ..Default::default()
    };
    spv::write_vec(&module, &module_info, &options, None)
        .map_err(|err| NagaError::SpvOut(format!("{err:?}")))
}