use super::{Device, Error, Result, ShaderModule};
use std::borrow::Cow;
use std::path::PathBuf;
/// A named shader binary held in `'static` storage.
///
/// Values of this type form the built-in table that a shader registry
/// searches by exact `name` match.
#[derive(Clone, Copy, Debug)]
pub struct ShaderSource {
    /// Key the shader is looked up by.
    pub name: &'static str,
    /// Raw bytes of the shader; consumers in this module treat them as SPIR-V.
    pub spv: &'static [u8],
}
/// Reasons a shader could not be produced by name.
#[derive(Debug)]
pub enum ShaderLoadError {
    /// No shader with the given name was available; the payload is the
    /// requested name.
    NotFound(String),
    /// Reading the shader file from the override directory failed.
    Io {
        /// Name of the shader that was requested.
        name: String,
        /// Underlying I/O failure.
        source: std::io::Error,
    },
    /// The shader's byte length is not a multiple of four, so it cannot be
    /// split into 32-bit words.
    MalformedSpirv {
        /// Name of the shader that was requested.
        name: String,
        /// Length, in bytes, of the rejected shader data.
        byte_len: usize,
    },
}
impl std::fmt::Display for ShaderLoadError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::NotFound(name) => write!(f, "shader not found: {name}"),
Self::Io { name, source } => {
write!(f, "failed to read shader {name} from override directory: {source}")
}
Self::MalformedSpirv { name, byte_len } => write!(
f,
"shader {name} has malformed SPIR-V: {byte_len} bytes is not a multiple of 4",
),
}
}
}
impl std::error::Error for ShaderLoadError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::Io { source, .. } => Some(source),
_ => None,
}
}
}
impl From<ShaderLoadError> for Error {
fn from(e: ShaderLoadError) -> Self {
Error::ShaderLoad(e)
}
}
/// Name-based source of shader binaries.
///
/// Holds a table of built-in shaders and, optionally, the name of an
/// environment variable that points at a directory of replacement files.
/// The default value has an empty table and no override variable.
///
/// `Debug` and `Clone` are derived so the registry can be logged and embedded
/// in other `Debug`/`Clone` types; both fields are plain `'static` references.
#[derive(Debug, Clone, Default)]
pub struct ShaderRegistry {
    /// Built-in shaders, searched by exact name.
    embedded: &'static [ShaderSource],
    /// Name of the environment variable holding the override directory, if any.
    env_override: Option<&'static str>,
}
impl ShaderRegistry {
pub fn new() -> Self {
Self::default()
}
pub fn with_embedded(mut self, shaders: &'static [ShaderSource]) -> Self {
self.embedded = shaders;
self
}
pub fn with_env_override(mut self, var: &'static str) -> Self {
self.env_override = Some(var);
self
}
pub fn load(&self, name: &str) -> std::result::Result<Cow<'_, [u8]>, ShaderLoadError> {
if let Some(var) = self.env_override {
if let Some(dir) = override_dir(var) {
let path = dir.join(format!("{name}.spv"));
if path.exists() {
return std::fs::read(&path).map(Cow::Owned).map_err(|source| {
ShaderLoadError::Io {
name: name.to_owned(),
source,
}
});
}
}
}
self.embedded
.iter()
.find(|s| s.name == name)
.map(|s| Cow::Borrowed(s.spv))
.ok_or_else(|| ShaderLoadError::NotFound(name.to_owned()))
}
pub fn load_words(&self, name: &str) -> std::result::Result<Vec<u32>, ShaderLoadError> {
let bytes = self.load(name)?;
if bytes.len() % 4 != 0 {
return Err(ShaderLoadError::MalformedSpirv {
name: name.to_owned(),
byte_len: bytes.len(),
});
}
Ok(bytes
.chunks_exact(4)
.map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
.collect())
}
pub fn load_module(&self, device: &Device, name: &str) -> Result<ShaderModule> {
let bytes = self.load(name)?;
ShaderModule::from_spirv_bytes(device, &bytes)
}
}
/// Resolves the environment variable `var` to an override directory.
///
/// Returns `None` when the variable is unset, when the metadata of the path it
/// names cannot be read, or when that path is not a directory.
fn override_dir(var: &str) -> Option<PathBuf> {
    let candidate = PathBuf::from(std::env::var_os(var)?);
    match candidate.metadata() {
        Ok(meta) if meta.is_dir() => Some(candidate),
        _ => None,
    }
}