1use crate::resource::BindGroupLayoutDescriptor;
2use serde::{Deserialize, Serialize};
3use std::hash::{DefaultHasher, Hash, Hasher};
4
/// Pipeline stage that a [`ShaderPackage`] entry point is compiled for.
///
/// Do not reorder or insert variants in the middle: serde's derive encodes enum variants by
/// declaration index, and packages are serialized with postcard, so the order is part of the
/// stored format.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ShaderStage {
    /// Vertex stage.
    Vertex,
    /// Fragment stage.
    Fragment,
    /// Compute stage.
    Compute,
}
11
/// One push-constant range declared by an [`AbstractPipelineLayout`].
///
/// Field order is part of the serialized (serde/postcard) format; do not reorder fields.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct PushConstantRange {
    /// Shader stages associated with this range (flag type defined in `crate::resource`).
    pub stages: crate::resource::ShaderStageFlags,
    /// Start of the range. NOTE(review): presumably a byte offset — confirm with the backend.
    pub offset: u32,
    /// Length of the range. NOTE(review): presumably in bytes — confirm with the backend.
    pub size: u32,
}
18
/// Backend-neutral pipeline layout: bind group layouts plus push-constant ranges.
///
/// Field order is part of the serialized (serde/postcard) format; do not reorder fields.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct AbstractPipelineLayout {
    /// Bind group layout descriptors; each one's `set` index feeds `layout_signature`.
    pub bind_groups: Vec<BindGroupLayoutDescriptor>,
    /// Push-constant ranges declared by this layout.
    pub push_constants: Vec<PushConstantRange>,
}
24
/// Compiled or source shader code in one specific format.
///
/// Do not reorder or insert variants in the middle: serde's derive encodes enum variants by
/// declaration index, so the order is part of the postcard-serialized format.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ShaderPayload {
    /// SPIR-V binary module.
    SpirV(Vec<u8>),
    /// WGSL source text.
    Wgsl(String),
    /// DXIL binary.
    Dxil(Vec<u8>),
    /// HLSL source text.
    HlslSource(String),
}
32
/// One optional payload slot per shader format.
///
/// The slot name says which [`ShaderPayload`] variant it is meant to hold, but the types do not
/// enforce it: any variant can be stored in any slot. The `select_*` accessors return `None`
/// when a slot holds a different variant than its name suggests.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
pub struct ShaderVariantMap {
    /// Expected to hold [`ShaderPayload::SpirV`].
    pub spirv: Option<ShaderPayload>,
    /// Expected to hold [`ShaderPayload::Wgsl`].
    pub wgsl: Option<ShaderPayload>,
    /// Expected to hold [`ShaderPayload::Dxil`].
    pub dxil: Option<ShaderPayload>,
    /// Expected to hold [`ShaderPayload::HlslSource`].
    pub hlsl: Option<ShaderPayload>,
}
40
/// A single shader stage: entry point, pipeline layout, and its code in one or more formats.
///
/// Field order is part of the serialized (serde/postcard) format; do not reorder fields.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ShaderPackage {
    /// Caller-supplied hash (presumably of the original source — confirm). `0` means "unset":
    /// `payload_hash` then hashes one of the stored payloads instead.
    pub source_hash: u64,
    /// Stage this entry point targets.
    pub stage: ShaderStage,
    /// Name of the shader entry point.
    pub entry_point: String,
    /// Pipeline layout this stage expects.
    pub layout: AbstractPipelineLayout,
    /// The shader code, per format.
    pub variants: ShaderVariantMap,
}
49
/// A vertex + fragment stage pair together with the pipeline layout used to build them.
///
/// NOTE(review): each stage also carries its own `layout`; nothing in this file checks that
/// those agree with the combined `layout` field.
///
/// Field order is part of the serialized (serde/postcard) format; do not reorder fields.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct GraphicsShaderPackage {
    /// Vertex stage.
    pub vertex: ShaderPackage,
    /// Fragment stage.
    pub fragment: ShaderPackage,
    /// Combined pipeline layout for the pair.
    pub layout: AbstractPipelineLayout,
}
56
57impl ShaderPackage {
58 pub fn spirv_bytes(&self) -> Option<&[u8]> {
59 match self.variants.spirv.as_ref()? {
60 ShaderPayload::SpirV(bytes) => Some(bytes.as_slice()),
61 _ => None,
62 }
63 }
64
65 pub fn wgsl_source(&self) -> Option<&str> {
66 match self.variants.wgsl.as_ref()? {
67 ShaderPayload::Wgsl(source) => Some(source.as_str()),
68 _ => None,
69 }
70 }
71
72 pub fn payload_hash(&self) -> u64 {
73 if self.source_hash != 0 {
74 return self.source_hash;
75 }
76 let bytes = self
77 .spirv_bytes()
78 .or_else(|| self.wgsl_source().map(|s| s.as_bytes()))
79 .unwrap_or(b"");
80 let mut hasher = DefaultHasher::new();
81 bytes.hash(&mut hasher);
82 hasher.finish()
83 }
84
85 pub fn to_bytes(&self) -> Result<Vec<u8>, postcard::Error> {
86 postcard::to_allocvec(self)
87 }
88
89 pub fn from_bytes(bytes: &[u8]) -> Result<Self, postcard::Error> {
90 postcard::from_bytes(bytes)
91 }
92}
93
94impl GraphicsShaderPackage {
95 pub fn new(vertex: ShaderPackage, fragment: ShaderPackage, layout: AbstractPipelineLayout) -> Self {
96 Self {
97 vertex,
98 fragment,
99 layout,
100 }
101 }
102
103 pub fn to_bytes(&self) -> Result<Vec<u8>, postcard::Error> {
104 postcard::to_allocvec(self)
105 }
106
107 pub fn from_bytes(bytes: &[u8]) -> Result<Self, postcard::Error> {
108 postcard::from_bytes(bytes)
109 }
110}
111
112impl AbstractPipelineLayout {
113 pub fn layout_signature(&self) -> u8 {
114 self.bind_groups
115 .iter()
116 .fold(0u8, |signature, group| signature | (1 << group.set))
117 }
118}
119
120impl ShaderVariantMap {
121 pub fn select_spirv(&self) -> Option<&[u8]> {
122 match self.spirv.as_ref()? {
123 ShaderPayload::SpirV(bytes) => Some(bytes.as_slice()),
124 _ => None,
125 }
126 }
127
128 pub fn select_wgsl(&self) -> Option<&str> {
129 match self.wgsl.as_ref()? {
130 ShaderPayload::Wgsl(source) => Some(source.as_str()),
131 _ => None,
132 }
133 }
134}