use crate::core::{self, Device, Pipeline, Shader, ShaderReflection, ShaderStage};
use crate::ex::errors::{PipelineId, ShaderId, ShaderManagerError};
use crate::ex::PipelineBuilder;
use std::sync::Arc;
/// One registered shader together with the metadata recorded when it was added.
struct ShaderEntry {
/// Shader object created on the manager's device; destroyed explicitly in the manager's `Drop`.
shader: Shader,
/// Reflection data, or `None` when reflection failed (reflection is best-effort).
reflection: Option<ShaderReflection>,
/// Caller-supplied label, returned by `shader_name` and listed in the `Debug` output.
name: String,
/// Stage the shader was created for, exactly as passed by the caller.
stage: ShaderStage,
}
/// One registered graphics pipeline and its caller-supplied label.
struct PipelineEntry {
/// Pipeline from `PipelineBuilder::build_graphics`; destroyed explicitly in the manager's `Drop`.
pipeline: Pipeline,
/// Caller-supplied label, returned by `pipeline_name` and listed in the `Debug` output.
name: String,
}
/// Registry of shaders and graphics pipelines created on a shared [`Device`].
///
/// Objects are referred to by [`ShaderId`] / [`PipelineId`] handles, which are plain
/// indices into the internal lists. Every registered object is destroyed when the
/// manager is dropped (all pipelines first, then all shaders).
pub struct ShaderManager {
/// Device used both to create and to destroy every registered object.
device: Arc<Device>,
/// Registered shaders; `ShaderId(i)` refers to `shaders[i]`.
shaders: Vec<ShaderEntry>,
/// Registered pipelines; `PipelineId(i)` refers to `pipelines[i]`.
pipelines: Vec<PipelineEntry>,
}
impl ShaderManager {
pub fn new(device: Arc<Device>) -> Result<Self, ShaderManagerError> {
Ok(Self {
device,
shaders: Vec::new(),
pipelines: Vec::new(),
})
}
pub fn add_shader(
&mut self,
source: &str,
stage: ShaderStage,
name: &str,
) -> Result<ShaderId, ShaderManagerError> {
let shader = Shader::from_glsl(&self.device, source, stage, "main")?;
let reflection = Self::try_reflect_shader(source, stage);
let id = ShaderId(self.shaders.len());
self.shaders.push(ShaderEntry {
shader,
reflection,
name: name.to_string(),
stage,
});
Ok(id)
}
pub fn add_shader_spirv(
&mut self,
spirv: &[u32],
stage: ShaderStage,
name: &str,
) -> Result<ShaderId, ShaderManagerError> {
let shader = Shader::from_spirv(&self.device, spirv, stage, "main")?;
let reflection = ShaderReflection::from_spirv(spirv).ok();
let id = ShaderId(self.shaders.len());
self.shaders.push(ShaderEntry {
shader,
reflection,
name: name.to_string(),
stage,
});
Ok(id)
}
#[inline]
pub fn get_shader(&self, id: ShaderId) -> Result<&Shader, ShaderManagerError> {
self.shaders
.get(id.0)
.map(|entry| &entry.shader)
.ok_or(ShaderManagerError::InvalidShaderId(id))
}
#[inline]
pub fn shader_reflection(
&self,
id: ShaderId,
) -> Result<Option<&ShaderReflection>, ShaderManagerError> {
self.shaders
.get(id.0)
.map(|entry| entry.reflection.as_ref())
.ok_or(ShaderManagerError::InvalidShaderId(id))
}
pub fn shader_name(&self, id: ShaderId) -> Result<&str, ShaderManagerError> {
self.shaders
.get(id.0)
.map(|entry| entry.name.as_str())
.ok_or(ShaderManagerError::InvalidShaderId(id))
}
pub fn shader_stage(&self, id: ShaderId) -> Result<ShaderStage, ShaderManagerError> {
self.shaders
.get(id.0)
.map(|entry| entry.stage)
.ok_or(ShaderManagerError::InvalidShaderId(id))
}
pub fn build_pipeline(
&mut self,
builder: PipelineBuilder,
name: &str,
) -> Result<PipelineId, ShaderManagerError> {
let pipeline = builder.build_graphics(&self.device)?;
let id = PipelineId(self.pipelines.len());
self.pipelines.push(PipelineEntry {
pipeline,
name: name.to_string(),
});
Ok(id)
}
#[inline]
pub fn get_pipeline(&self, id: PipelineId) -> Result<&Pipeline, ShaderManagerError> {
self.pipelines
.get(id.0)
.map(|entry| &entry.pipeline)
.ok_or(ShaderManagerError::InvalidPipelineId(id))
}
#[inline]
pub fn pipeline_name(&self, id: PipelineId) -> Result<&str, ShaderManagerError> {
self.pipelines
.get(id.0)
.map(|entry| entry.name.as_str())
.ok_or(ShaderManagerError::InvalidPipelineId(id))
}
#[inline]
pub fn device(&self) -> Arc<Device> {
Arc::clone(&self.device)
}
#[inline]
pub fn shader_count(&self) -> usize {
self.shaders.len()
}
#[inline]
pub fn pipeline_count(&self) -> usize {
self.pipelines.len()
}
fn try_reflect_shader(source: &str, stage: ShaderStage) -> Option<ShaderReflection> {
match core::ShaderCompiler::compile_glsl(source, stage) {
Ok(spirv) => ShaderReflection::from_spirv(&spirv).ok(),
Err(_) => None,
}
}
}
impl Drop for ShaderManager {
fn drop(&mut self) {
for entry in &mut self.pipelines {
entry.pipeline.destroy(&self.device);
}
self.pipelines.clear();
for entry in &mut self.shaders {
entry.shader.destroy(&self.device);
}
self.shaders.clear();
}
}
impl std::fmt::Debug for ShaderManager {
    /// Shows the shader/pipeline counts and their registered names; the device handle
    /// and the GPU objects themselves are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let shader_names: Vec<&String> = self.shaders.iter().map(|entry| &entry.name).collect();
        let pipeline_names: Vec<&String> =
            self.pipelines.iter().map(|entry| &entry.name).collect();

        f.debug_struct("ShaderManager")
            .field("shader_count", &shader_names.len())
            .field("pipeline_count", &pipeline_names.len())
            .field("shaders", &shader_names)
            .field("pipelines", &pipeline_names)
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ex::RuntimeConfig;
    use crate::ex::RuntimeManager;

    /// Runtime configuration shared by these tests: validation disabled, all else default.
    fn test_config() -> RuntimeConfig {
        RuntimeConfig {
            enable_validation: false,
            ..RuntimeConfig::default()
        }
    }

    #[test]
    fn test_shader_manager_creation() {
        let runtime = RuntimeManager::new(test_config()).unwrap();
        let manager = ShaderManager::new(runtime.device());
        assert!(manager.is_ok());
    }

    #[test]
    fn test_add_shader() {
        const SIMPLE_VERT: &str = r#"
#version 450
void main() {
gl_Position = vec4(0.0, 0.0, 0.0, 1.0);
}
"#;
        let runtime = RuntimeManager::new(test_config()).unwrap();
        let mut manager = ShaderManager::new(runtime.device()).unwrap();

        let id = manager
            .add_shader(SIMPLE_VERT, ShaderStage::Vertex, "test_vert")
            .expect("a minimal vertex shader should be accepted");

        assert_eq!(manager.shader_count(), 1);
        assert!(manager.get_shader(id).is_ok());
        assert_eq!(manager.shader_name(id).unwrap(), "test_vert");
    }

    #[test]
    fn test_invalid_shader_id() {
        let runtime = RuntimeManager::new(test_config()).unwrap();
        let manager = ShaderManager::new(runtime.device()).unwrap();
        // Nothing has been registered, so any index is out of range.
        assert!(manager.get_shader(ShaderId(999)).is_err());
    }

    #[test]
    fn test_multiple_shaders() {
        const VERT: &str = "#version 450\nvoid main() { gl_Position = vec4(0.0); }";
        const FRAG: &str =
            "#version 450\nlayout(location=0) out vec4 color; void main() { color = vec4(1.0); }";
        let runtime = RuntimeManager::new(test_config()).unwrap();
        let mut manager = ShaderManager::new(runtime.device()).unwrap();

        let vert_id = manager
            .add_shader(VERT, ShaderStage::Vertex, "vert")
            .unwrap();
        let frag_id = manager
            .add_shader(FRAG, ShaderStage::Fragment, "frag")
            .unwrap();

        assert_eq!(manager.shader_count(), 2);
        for (id, expected_stage) in [
            (vert_id, ShaderStage::Vertex),
            (frag_id, ShaderStage::Fragment),
        ] {
            assert!(manager.get_shader(id).is_ok());
            assert_eq!(manager.shader_stage(id).unwrap(), expected_stage);
        }
    }
}