use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
/// Shader input language.
///
/// `WGSL` and `GLSL` are text sources handled by `compile_ir`; `SPIRV` is a
/// binary module that `compile_ir` rejects and `compile_spirv` handles.
// The all-caps variant names are part of the public API; renaming them to
// satisfy clippy's `upper_case_acronyms` lint would break existing callers.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Language {
    /// WebGPU Shading Language source text.
    WGSL,
    /// OpenGL Shading Language source text.
    GLSL,
    /// SPIR-V binary module.
    SPIRV,
}
/// Descriptive information attached to a compiled [`IR`].
///
/// `compile_ir` and `compile_spirv` both set `version` to `"1.0"` and leave
/// `author` and `description` empty. All fields are public, so callers may
/// fill or overwrite them afterwards.
#[derive(Debug, Clone)]
pub struct Metadata {
    /// Version string; the compile functions set this to `"1.0"`.
    pub version: String,
    /// Free-form author text; left empty by the compile functions.
    pub author: String,
    /// Free-form description; left empty by the compile functions.
    pub description: String,
    /// Shader stage: the requested stage for `compile_ir`, or the first entry
    /// point's stage for `compile_spirv` (falling back to `Vertex` when the
    /// module has no entry points).
    pub stage: naga::ShaderStage,
    /// Name of the selected entry point, or an empty string when none was found.
    pub entry_point: String,
    /// `DefaultHasher` hash of the source text or SPIR-V bytes. The standard
    /// library does not guarantee that hasher's algorithm across Rust releases,
    /// so do not persist this value and compare it across toolchain upgrades.
    pub source_hash: u64,
    /// Language the module was parsed from.
    pub language: Language,
}
/// A parsed shader: the naga module paired with its [`Metadata`].
///
/// Produced by `compile_ir` (WGSL/GLSL text) and `compile_spirv` (SPIR-V bytes).
#[derive(Clone)]
pub struct IR {
    /// naga's in-memory representation of the parsed shader.
    pub module: naga::Module,
    /// Provenance of `module`: language, stage, entry point, source hash, etc.
    pub metadata: Metadata,
}
/// Parses WGSL or GLSL source text into an [`IR`] with default metadata.
///
/// `source_type` has two roles. The GLSL frontend receives it as the shader
/// stage. For both languages it also selects the recorded entry point: the
/// first entry point in the parsed module whose stage equals `source_type`.
/// When no entry point matches, `entry_point` is left empty and no error is
/// raised.
///
/// `source_hash` is a `DefaultHasher` hash of `source`. It depends only on the
/// text, not on `source_type` or `lang`.
///
/// # Errors
///
/// Returns `Err` with a formatted message in two cases:
/// - the WGSL or GLSL frontend rejects `source`;
/// - `lang` is [`Language::SPIRV`], because binary input must go through
///   `compile_spirv`.
pub fn compile_ir(
    source: &str,
    source_type: naga::ShaderStage,
    lang: Language,
) -> Result<IR, String> {
    let module = match lang {
        Language::WGSL => naga::front::wgsl::parse_str(source)
            .map_err(|err| format!("WGSL parse error: {:?}", err))?,
        Language::GLSL => {
            let glsl_options = naga::front::glsl::Options {
                stage: source_type,
                defines: Default::default(),
            };
            let mut frontend = naga::front::glsl::Frontend::default();
            frontend
                .parse(&glsl_options, source)
                .map_err(|err| format!("GLSL parse error: {:?}", err))?
        }
        Language::SPIRV => return Err("Use compile_spirv() for SPIR-V input".to_string()),
    };

    // Hashing is independent of parsing; doing it afterwards yields the same value.
    let source_hash = {
        let mut state = DefaultHasher::new();
        source.hash(&mut state);
        state.finish()
    };

    let entry_point = match module.entry_points.iter().find(|ep| ep.stage == source_type) {
        Some(ep) => ep.name.clone(),
        None => String::new(),
    };

    let metadata = Metadata {
        version: "1.0".to_string(),
        author: String::new(),
        description: String::new(),
        stage: source_type,
        entry_point,
        source_hash,
        language: lang,
    };
    Ok(IR { module, metadata })
}
/// Parses a SPIR-V binary into an [`IR`] with default metadata.
///
/// The recorded stage and entry-point name come from the module's first entry
/// point. If the module has no entry points, the stage falls back to
/// `naga::ShaderStage::Vertex` and the entry-point name is empty.
///
/// `source_hash` is a `DefaultHasher` hash of `spirv_bytes`.
///
/// # Errors
///
/// Returns `Err` prefixed with `"SPIR-V parse error: "` when naga's SPIR-V
/// frontend rejects the bytes.
pub fn compile_spirv(spirv_bytes: &[u8]) -> Result<IR, String> {
    let spv_options = naga::front::spv::Options::default();
    let module = match naga::front::spv::parse_u8_slice(spirv_bytes, &spv_options) {
        Ok(parsed) => parsed,
        Err(err) => return Err(format!("SPIR-V parse error: {:?}", err)),
    };

    // Hashing is independent of parsing; doing it afterwards yields the same value.
    let source_hash = {
        let mut state = DefaultHasher::new();
        spirv_bytes.hash(&mut state);
        state.finish()
    };

    let (stage, entry_point) = match module.entry_points.first() {
        Some(ep) => (ep.stage, ep.name.clone()),
        None => (naga::ShaderStage::Vertex, String::new()),
    };

    let metadata = Metadata {
        version: "1.0".to_string(),
        author: String::new(),
        description: String::new(),
        stage,
        entry_point,
        source_hash,
        language: Language::SPIRV,
    };
    Ok(IR { module, metadata })
}