1use std::sync::atomic::{AtomicU32, Ordering};
2
3use crate::*;
4
5#[derive(Debug)]
6pub struct ShaderMap {
7 pub shaders: HashMap<ShaderId, Shader>,
8 pub watched_paths: HashMap<String, ShaderId>,
9}
10
11impl ShaderMap {
12 pub fn new() -> Self {
13 Self { shaders: Default::default(), watched_paths: Default::default() }
14 }
15
16 pub fn get(&self, id: ShaderId) -> Option<&Shader> {
17 self.shaders.get(&id)
18 }
19
20 pub fn insert_shader(&mut self, id: ShaderId, shader: Shader) {
21 self.shaders.insert(id, shader);
22 }
23
24 pub fn exists(&self, id: ShaderId) -> bool {
25 self.shaders.contains_key(&id)
26 }
27}
28
/// Uniform declarations of a shader, keyed by the uniform's name (which is also
/// used as the WGSL variable name when the uniform prelude is generated).
pub type UniformDefs = HashMap<String, UniformDef>;
30
/// A shader as stored in a [`ShaderMap`] by `create_shader`.
#[derive(Clone, Debug)]
pub struct Shader {
    /// Id this shader is stored under in the [`ShaderMap`].
    pub id: ShaderId,
    /// Display name; `create_shader` stores it as `"<name> Shader"`.
    pub name: String,
    /// Final source: generated `@group(2)` uniform declarations followed by the
    /// caller-supplied source (see `build_shader_source`).
    pub source: String,
    /// Uniform declarations the shader was created with.
    pub uniform_defs: UniformDefs,
    /// Uniform name -> `@binding` index within `@group(2)`.
    pub bindings: HashMap<String, u32>,
}
39
/// Numeric handle used as the key for shaders in a [`ShaderMap`].
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct ShaderId(pub u64);

impl std::fmt::Display for ShaderId {
    /// Renders the id as `ShaderId(<n>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let ShaderId(raw) = *self;
        f.write_fmt(format_args!("ShaderId({})", raw))
    }
}
50
/// A shader referenced by id, paired with a set of uniform values keyed by
/// uniform name.
///
/// NOTE(review): presumably these values are applied on top of the shader's
/// uniform defaults when drawing — confirm against the renderer.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ShaderInstance {
    /// Shader these uniform values belong to.
    pub id: ShaderId,
    /// Uniform name -> value.
    pub uniforms: HashMap<String, Uniform>,
}
56
/// Declaration of a shader uniform; determines the WGSL type it is declared with.
#[derive(Clone, Debug)]
pub enum UniformDef {
    /// A scalar `f32` uniform carrying an optional value (presumably its default — confirm).
    F32(Option<f32>),
    /// A uniform whose WGSL type is given verbatim by `wgsl_decl`, with optional raw `default_data` bytes.
    Custom { default_data: Option<Vec<u8>>, wgsl_decl: String },
}

impl UniformDef {
    /// Returns the WGSL type written after `var<uniform> name:` for this uniform.
    pub fn to_wgsl(&self) -> &str {
        match *self {
            Self::F32(_) => "f32",
            Self::Custom { ref wgsl_decl, .. } => wgsl_decl.as_str(),
        }
    }
}
71
/// A concrete uniform value.
///
/// Derives `Eq` and `Hash`; since plain `f32` implements neither, the `F32`
/// payload is wrapped in `OrderedFloat`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum Uniform {
    /// Scalar float value.
    F32(OrderedFloat<f32>),
    /// Raw bytes (presumably laid out to match a `UniformDef::Custom` WGSL type — confirm).
    Custom(Vec<u8>),
}
77
78static CURRENT_RENDER_TARGET: AtomicU32 = AtomicU32::new(0);
82
83pub fn use_render_target(id: RenderTargetId) {
84 CURRENT_RENDER_TARGET.store(id.0, Ordering::SeqCst);
85 }
87
88pub fn use_default_render_target() {
89 CURRENT_RENDER_TARGET.store(0, Ordering::SeqCst);
90}
91
92pub fn get_current_render_target() -> RenderTargetId {
93 RenderTargetId(CURRENT_RENDER_TARGET.load(Ordering::SeqCst))
94}
95
96pub fn set_uniform_f32(name: impl Into<String>, value: f32) {
98 set_uniform(name, Uniform::F32(OrderedFloat(value)));
99}
100
101pub fn create_shader(
112 shaders: &mut ShaderMap,
113 name: &str,
114 source: &str,
115 uniform_defs: UniformDefs,
116) -> Result<ShaderId> {
117 let id = gen_shader_id();
118
119 if !source.contains("@vertex") {
120 panic!(
121 "Missing @vertex function in shader passed to `create_shader`.
122
123 Did you forget to call `sprite_shader_from_fragment`?"
124 );
125 }
126
127 if shaders.exists(id) {
128 bail!("Shader with name '{}' already exists", name);
129 }
130
131 let bindings = uniform_defs_to_bindings(&uniform_defs);
132
133 shaders.insert_shader(id, Shader {
134 id,
135 name: format!("{} Shader", name),
136 source: build_shader_source(source, &bindings, &uniform_defs),
137 uniform_defs,
138 bindings,
139 });
140
141 Ok(id)
142}
143
144pub fn uniform_defs_to_bindings(
145 uniform_defs: &UniformDefs,
146) -> HashMap<String, u32> {
147 uniform_defs
148 .iter()
149 .sorted_by_key(|x| x.0)
150 .enumerate()
151 .map(|(i, (name, _))| (name.clone(), i as u32))
152 .collect::<HashMap<String, u32>>()
153}
154
/// Shader source text paired with the file path it came from (presumably so it
/// can be reloaded from disk at runtime — confirm against the loader).
///
/// Derives `Clone` and `Debug` like the other public shader types in this module.
#[derive(Clone, Debug)]
pub struct ReloadableShaderSource {
    /// Source text available without reading `path` (presumably embedded at build time — confirm).
    pub static_source: String,
    /// Path of the shader source file.
    pub path: String,
}
161
162pub fn build_shader_source(
163 fragment_source: &str,
164 bindings: &HashMap<String, u32>,
165 uniform_defs: &UniformDefs,
166) -> String {
167 let mut uniforms_src = String::new();
168
169 for (name, binding) in bindings.iter() {
170 let typ = uniform_defs.get(name).unwrap();
171
172 uniforms_src.push_str(&format!(
173 "@group(2) @binding({})
174 var<uniform> {}: {};",
175 binding,
176 name,
177 typ.to_wgsl()
178 ));
179 }
180
181 format!("{}\n{}", uniforms_src, fragment_source)
182}
183
/// Identifier of a render target.
///
/// The derived `Default` is `RenderTargetId(0)`, the same value that
/// `use_default_render_target` selects.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct RenderTargetId(pub u32);