1use std::collections::HashMap;
2
3use wgpu::util::DeviceExt;
4
5use super::gpu::GpuContext;
6
/// Number of `vec4<f32>` parameter slots available to each custom shader.
const MAX_PARAM_SLOTS: usize = 16;

/// Size in bytes of a shader's parameter uniform buffer: one four-float
/// (16-byte) vector per slot.
const UNIFORM_BUFFER_SIZE: usize = std::mem::size_of::<[f32; 4]>() * MAX_PARAM_SLOTS;
11
12fn shader_preamble() -> &'static str {
15 let wgsl = include_str!("shaders/sprite.wgsl");
16 let idx = wgsl
17 .find("@fragment")
18 .expect("sprite.wgsl must contain @fragment");
19 &wgsl[..idx]
20}
21
22fn build_custom_wgsl(user_fragment: &str) -> String {
27 format!(
28 r#"{}
29// Custom shader uniform params (16 vec4 slots = 64 floats)
30struct ShaderParams {{
31 values: array<vec4<f32>, 16>,
32}};
33
34@group(3) @binding(0)
35var<uniform> shader_params: ShaderParams;
36
37{}
38"#,
39 shader_preamble(),
40 user_fragment,
41 )
42}
43
44struct ShaderEntry {
45 pipeline: wgpu::RenderPipeline,
46 uniform_buffer: wgpu::Buffer,
47 uniform_bind_group: wgpu::BindGroup,
48 param_data: [f32; MAX_PARAM_SLOTS * 4],
49 dirty: bool,
50}
51
52pub struct ShaderStore {
55 shaders: HashMap<u32, ShaderEntry>,
56 pipeline_layout: wgpu::PipelineLayout,
57 params_bind_group_layout: wgpu::BindGroupLayout,
58 surface_format: wgpu::TextureFormat,
59}
60
61impl ShaderStore {
62 pub fn new(gpu: &GpuContext) -> Self {
63 let camera_layout =
65 gpu.device
66 .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
67 label: Some("shader_camera_layout"),
68 entries: &[wgpu::BindGroupLayoutEntry {
69 binding: 0,
70 visibility: wgpu::ShaderStages::VERTEX,
71 ty: wgpu::BindingType::Buffer {
72 ty: wgpu::BufferBindingType::Uniform,
73 has_dynamic_offset: false,
74 min_binding_size: None,
75 },
76 count: None,
77 }],
78 });
79
80 let texture_layout =
81 gpu.device
82 .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
83 label: Some("shader_texture_layout"),
84 entries: &[
85 wgpu::BindGroupLayoutEntry {
86 binding: 0,
87 visibility: wgpu::ShaderStages::FRAGMENT,
88 ty: wgpu::BindingType::Texture {
89 multisampled: false,
90 view_dimension: wgpu::TextureViewDimension::D2,
91 sample_type: wgpu::TextureSampleType::Float { filterable: true },
92 },
93 count: None,
94 },
95 wgpu::BindGroupLayoutEntry {
96 binding: 1,
97 visibility: wgpu::ShaderStages::FRAGMENT,
98 ty: wgpu::BindingType::Sampler(wgpu::SamplerBindingType::Filtering),
99 count: None,
100 },
101 ],
102 });
103
104 let lighting_layout =
105 gpu.device
106 .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
107 label: Some("shader_lighting_layout"),
108 entries: &[wgpu::BindGroupLayoutEntry {
109 binding: 0,
110 visibility: wgpu::ShaderStages::FRAGMENT,
111 ty: wgpu::BindingType::Buffer {
112 ty: wgpu::BufferBindingType::Uniform,
113 has_dynamic_offset: false,
114 min_binding_size: None,
115 },
116 count: None,
117 }],
118 });
119
120 let params_bind_group_layout =
122 gpu.device
123 .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
124 label: Some("shader_params_layout"),
125 entries: &[wgpu::BindGroupLayoutEntry {
126 binding: 0,
127 visibility: wgpu::ShaderStages::FRAGMENT,
128 ty: wgpu::BindingType::Buffer {
129 ty: wgpu::BufferBindingType::Uniform,
130 has_dynamic_offset: false,
131 min_binding_size: None,
132 },
133 count: None,
134 }],
135 });
136
137 let pipeline_layout =
138 gpu.device
139 .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
140 label: Some("custom_shader_pipeline_layout"),
141 bind_group_layouts: &[
142 &camera_layout,
143 &texture_layout,
144 &lighting_layout,
145 ¶ms_bind_group_layout,
146 ],
147 push_constant_ranges: &[],
148 });
149
150 Self {
151 shaders: HashMap::new(),
152 pipeline_layout,
153 params_bind_group_layout,
154 surface_format: gpu.config.format,
155 }
156 }
157
158 pub fn create(&mut self, gpu: &GpuContext, id: u32, _name: &str, source: &str) {
163 let full_wgsl = build_custom_wgsl(source);
164
165 let shader_module = gpu
166 .device
167 .create_shader_module(wgpu::ShaderModuleDescriptor {
168 label: Some("custom_shader"),
169 source: wgpu::ShaderSource::Wgsl(full_wgsl.into()),
170 });
171
172 let vertex_layout = wgpu::VertexBufferLayout {
173 array_stride: 16, step_mode: wgpu::VertexStepMode::Vertex,
175 attributes: &[
176 wgpu::VertexAttribute {
177 offset: 0,
178 shader_location: 0,
179 format: wgpu::VertexFormat::Float32x2,
180 },
181 wgpu::VertexAttribute {
182 offset: 8,
183 shader_location: 1,
184 format: wgpu::VertexFormat::Float32x2,
185 },
186 ],
187 };
188
189 let instance_layout = wgpu::VertexBufferLayout {
190 array_stride: 64, step_mode: wgpu::VertexStepMode::Instance,
192 attributes: &[
193 wgpu::VertexAttribute {
194 offset: 0,
195 shader_location: 2,
196 format: wgpu::VertexFormat::Float32x2,
197 },
198 wgpu::VertexAttribute {
199 offset: 8,
200 shader_location: 3,
201 format: wgpu::VertexFormat::Float32x2,
202 },
203 wgpu::VertexAttribute {
204 offset: 16,
205 shader_location: 4,
206 format: wgpu::VertexFormat::Float32x2,
207 },
208 wgpu::VertexAttribute {
209 offset: 24,
210 shader_location: 5,
211 format: wgpu::VertexFormat::Float32x2,
212 },
213 wgpu::VertexAttribute {
214 offset: 32,
215 shader_location: 6,
216 format: wgpu::VertexFormat::Float32x4,
217 },
218 wgpu::VertexAttribute {
219 offset: 48,
220 shader_location: 7,
221 format: wgpu::VertexFormat::Float32x4,
222 },
223 ],
224 };
225
226 let pipeline =
227 gpu.device
228 .create_render_pipeline(&wgpu::RenderPipelineDescriptor {
229 label: Some("custom_shader_pipeline"),
230 layout: Some(&self.pipeline_layout),
231 vertex: wgpu::VertexState {
232 module: &shader_module,
233 entry_point: Some("vs_main"),
234 buffers: &[vertex_layout, instance_layout],
235 compilation_options: Default::default(),
236 },
237 fragment: Some(wgpu::FragmentState {
238 module: &shader_module,
239 entry_point: Some("fs_main"),
240 targets: &[Some(wgpu::ColorTargetState {
241 format: self.surface_format,
242 blend: Some(wgpu::BlendState::ALPHA_BLENDING),
243 write_mask: wgpu::ColorWrites::ALL,
244 })],
245 compilation_options: Default::default(),
246 }),
247 primitive: wgpu::PrimitiveState {
248 topology: wgpu::PrimitiveTopology::TriangleList,
249 strip_index_format: None,
250 front_face: wgpu::FrontFace::Ccw,
251 cull_mode: None,
252 polygon_mode: wgpu::PolygonMode::Fill,
253 unclipped_depth: false,
254 conservative: false,
255 },
256 depth_stencil: None,
257 multisample: wgpu::MultisampleState::default(),
258 multiview: None,
259 cache: None,
260 });
261
262 let uniform_buffer =
264 gpu.device
265 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
266 label: Some("shader_params_buffer"),
267 contents: &[0u8; UNIFORM_BUFFER_SIZE],
268 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
269 });
270
271 let uniform_bind_group = gpu.device.create_bind_group(&wgpu::BindGroupDescriptor {
272 label: Some("shader_params_bind_group"),
273 layout: &self.params_bind_group_layout,
274 entries: &[wgpu::BindGroupEntry {
275 binding: 0,
276 resource: uniform_buffer.as_entire_binding(),
277 }],
278 });
279
280 self.shaders.insert(
281 id,
282 ShaderEntry {
283 pipeline,
284 uniform_buffer,
285 uniform_bind_group,
286 param_data: [0.0; MAX_PARAM_SLOTS * 4],
287 dirty: false,
288 },
289 );
290 }
291
292 pub fn set_param(&mut self, id: u32, index: u32, x: f32, y: f32, z: f32, w: f32) {
294 if let Some(entry) = self.shaders.get_mut(&id) {
295 let i = (index as usize).min(MAX_PARAM_SLOTS - 1) * 4;
296 entry.param_data[i] = x;
297 entry.param_data[i + 1] = y;
298 entry.param_data[i + 2] = z;
299 entry.param_data[i + 3] = w;
300 entry.dirty = true;
301 }
302 }
303
304 pub fn flush(&mut self, gpu: &GpuContext) {
306 for entry in self.shaders.values_mut() {
307 if entry.dirty {
308 gpu.queue.write_buffer(
309 &entry.uniform_buffer,
310 0,
311 bytemuck::cast_slice(&entry.param_data),
312 );
313 entry.dirty = false;
314 }
315 }
316 }
317
318 pub fn get_pipeline(&self, id: u32) -> Option<&wgpu::RenderPipeline> {
320 self.shaders.get(&id).map(|e| &e.pipeline)
321 }
322
323 pub fn get_bind_group(&self, id: u32) -> Option<&wgpu::BindGroup> {
325 self.shaders.get(&id).map(|e| &e.uniform_bind_group)
326 }
327}