Skip to main content

rotex_types/shader/
mod.rs

1use crate::resource::BindGroupLayoutDescriptor;
2use serde::{Deserialize, Serialize};
3use std::hash::{DefaultHasher, Hash, Hasher};
4
5#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
6pub enum ShaderStage {
7    Vertex,
8    Fragment,
9    Compute,
10}
11
12#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
13pub struct PushConstantRange {
14    pub stages: crate::resource::ShaderStageFlags,
15    pub offset: u32,
16    pub size: u32,
17}
18
19#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
20pub struct AbstractPipelineLayout {
21    pub bind_groups: Vec<BindGroupLayoutDescriptor>,
22    pub push_constants: Vec<PushConstantRange>,
23}
24
25#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
26pub enum ShaderPayload {
27    SpirV(Vec<u8>),
28    Wgsl(String),
29    Dxil(Vec<u8>),
30    HlslSource(String),
31}
32
33#[derive(Debug, Clone, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
34pub struct ShaderVariantMap {
35    pub spirv: Option<ShaderPayload>,
36    pub wgsl: Option<ShaderPayload>,
37    pub dxil: Option<ShaderPayload>,
38    pub hlsl: Option<ShaderPayload>,
39}
40
41#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
42pub struct ShaderPackage {
43    pub source_hash: u64,
44    pub stage: ShaderStage,
45    pub entry_point: String,
46    pub layout: AbstractPipelineLayout,
47    pub variants: ShaderVariantMap,
48}
49
50#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
51pub struct GraphicsShaderPackage {
52    pub vertex: ShaderPackage,
53    pub fragment: ShaderPackage,
54    pub layout: AbstractPipelineLayout,
55}
56
57impl ShaderPackage {
58    pub fn spirv_bytes(&self) -> Option<&[u8]> {
59        match self.variants.spirv.as_ref()? {
60            ShaderPayload::SpirV(bytes) => Some(bytes.as_slice()),
61            _ => None,
62        }
63    }
64
65    pub fn wgsl_source(&self) -> Option<&str> {
66        match self.variants.wgsl.as_ref()? {
67            ShaderPayload::Wgsl(source) => Some(source.as_str()),
68            _ => None,
69        }
70    }
71
72    pub fn payload_hash(&self) -> u64 {
73        if self.source_hash != 0 {
74            return self.source_hash;
75        }
76        let bytes = self
77            .spirv_bytes()
78            .or_else(|| self.wgsl_source().map(|s| s.as_bytes()))
79            .unwrap_or(b"");
80        let mut hasher = DefaultHasher::new();
81        bytes.hash(&mut hasher);
82        hasher.finish()
83    }
84
85    pub fn to_bytes(&self) -> Result<Vec<u8>, postcard::Error> {
86        postcard::to_allocvec(self)
87    }
88
89    pub fn from_bytes(bytes: &[u8]) -> Result<Self, postcard::Error> {
90        postcard::from_bytes(bytes)
91    }
92}
93
94impl GraphicsShaderPackage {
95    pub fn new(vertex: ShaderPackage, fragment: ShaderPackage, layout: AbstractPipelineLayout) -> Self {
96        Self {
97            vertex,
98            fragment,
99            layout,
100        }
101    }
102
103    pub fn to_bytes(&self) -> Result<Vec<u8>, postcard::Error> {
104        postcard::to_allocvec(self)
105    }
106
107    pub fn from_bytes(bytes: &[u8]) -> Result<Self, postcard::Error> {
108        postcard::from_bytes(bytes)
109    }
110}
111
112impl AbstractPipelineLayout {
113    pub fn layout_signature(&self) -> u8 {
114        self.bind_groups
115            .iter()
116            .fold(0u8, |signature, group| signature | (1 << group.set))
117    }
118}
119
120impl ShaderVariantMap {
121    pub fn select_spirv(&self) -> Option<&[u8]> {
122        match self.spirv.as_ref()? {
123            ShaderPayload::SpirV(bytes) => Some(bytes.as_slice()),
124            _ => None,
125        }
126    }
127
128    pub fn select_wgsl(&self) -> Option<&str> {
129        match self.wgsl.as_ref()? {
130            ShaderPayload::Wgsl(source) => Some(source.as_str()),
131            _ => None,
132        }
133    }
134}