1use std::collections::HashMap;
2
3use wgpu::util::DeviceExt;
4
5use super::gpu::GpuContext;
6
/// Number of `vec4<f32>` slots in each shader's `ShaderParams` uniform
/// (built-in slots followed by user slots).
const MAX_PARAM_SLOTS: usize = 16;
/// Leading slots reserved for the auto-injected values (time, delta, resolution,
/// mouse, padding) that `ShaderStore::flush` rewrites on every call.
const BUILTIN_SLOTS: usize = 2;
/// Byte size of each shader's uniform buffer: one 16-byte `vec4<f32>` per slot.
const UNIFORM_BUFFER_SIZE: usize = MAX_PARAM_SLOTS * std::mem::size_of::<[f32; 4]>();
14
15fn shader_preamble() -> &'static str {
18 let wgsl = include_str!("shaders/sprite.wgsl");
19 let idx = wgsl
20 .find("@fragment")
21 .expect("sprite.wgsl must contain @fragment");
22 &wgsl[..idx]
23}
24
25fn build_custom_wgsl(user_fragment: &str) -> String {
30 format!(
31 r#"{}
32// Custom shader params: 2 built-in vec4s + 14 user vec4 slots = 256 bytes
33struct ShaderParams {{
34 time: f32, // elapsed seconds (auto-injected)
35 delta: f32, // frame delta time (auto-injected)
36 resolution: vec2<f32>, // viewport size in logical pixels (auto-injected)
37 mouse: vec2<f32>, // mouse position in screen pixels (auto-injected)
38 _pad: vec2<f32>,
39 values: array<vec4<f32>, 14>, // user-defined uniform slots
40}};
41
42@group(3) @binding(0)
43var<uniform> shader_params: ShaderParams;
44
45{}
46"#,
47 shader_preamble(),
48 user_fragment,
49 )
50}
51
/// One compiled custom shader: its render pipeline plus the uniform buffer and bind
/// group backing its `ShaderParams` (`@group(3) @binding(0)`).
struct ShaderEntry {
    /// Pipeline built from the sprite preamble, the `ShaderParams` declaration and the
    /// user-supplied fragment source.
    pipeline: wgpu::RenderPipeline,
    /// `UNIFORM_BUFFER_SIZE`-byte uniform buffer that receives `param_data` on `flush`.
    uniform_buffer: wgpu::Buffer,
    /// Bind group exposing `uniform_buffer` at binding 0 of the params layout.
    uniform_bind_group: wgpu::BindGroup,
    /// CPU-side copy of the uniform contents: the first `BUILTIN_SLOTS` vec4s are
    /// overwritten by `flush`, later slots are written by `set_param`.
    param_data: [f32; MAX_PARAM_SLOTS * 4],
    /// Set by `set_param`, cleared by `flush`.
    // NOTE(review): no reader is visible here — `flush` uploads every entry
    // unconditionally.
    dirty: bool,
}
59
/// Registry of user-authored custom shaders, keyed by a caller-chosen `u32` id.
///
/// All shaders share one pipeline layout (bind groups 0-3: camera, texture, lighting,
/// params) and render to a single color target format.
pub struct ShaderStore {
    /// Compiled shaders by id; `create` with an existing id replaces the old entry.
    shaders: HashMap<u32, ShaderEntry>,
    /// Pipeline layout shared by every custom shader pipeline.
    pipeline_layout: wgpu::PipelineLayout,
    /// Layout of group 3 (the per-shader `ShaderParams` uniform), used to create each
    /// shader's bind group.
    params_bind_group_layout: wgpu::BindGroupLayout,
    /// Color target format for every pipeline created by this store.
    surface_format: wgpu::TextureFormat,
}
68
69impl ShaderStore {
70 pub fn new_headless(device: &wgpu::Device, format: wgpu::TextureFormat) -> Self {
72 Self::new_internal(device, format)
73 }
74
75 pub fn new(gpu: &GpuContext) -> Self {
76 Self::new_internal(&gpu.device, gpu.config.format)
77 }
78
79 fn new_internal(device: &wgpu::Device, surface_format: wgpu::TextureFormat) -> Self {
80 let camera_layout =
82 device
83 .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
84 label: Some("shader_camera_layout"),
85 entries: &[wgpu::BindGroupLayoutEntry {
86 binding: 0,
87 visibility: wgpu::ShaderStages::VERTEX,
88 ty: wgpu::BindingType::Buffer {
89 ty: wgpu::BufferBindingType::Uniform,
90 has_dynamic_offset: false,
91 min_binding_size: None,
92 },
93 count: None,
94 }],
95 });
96
97 let texture_layout =
98 device
99 .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
100 label: Some("shader_texture_layout"),
101 entries: &[
102 wgpu::BindGroupLayoutEntry {
103 binding: 0,
104 visibility: wgpu::ShaderStages::FRAGMENT,
105 ty: wgpu::BindingType::Texture {
106 multisampled: false,
107 view_dimension: wgpu::TextureViewDimension::D2,
108 sample_type: wgpu::TextureSampleType::Float { filterable: true },
109 },
110 count: None,
111 },
112 wgpu::BindGroupLayoutEntry {
113 binding: 1,
114 visibility: wgpu::ShaderStages::FRAGMENT,
115 ty: wgpu::BindingType::Sampler(wgpu::SamplerBindingType::Filtering),
116 count: None,
117 },
118 ],
119 });
120
121 let lighting_layout =
122 device
123 .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
124 label: Some("shader_lighting_layout"),
125 entries: &[wgpu::BindGroupLayoutEntry {
126 binding: 0,
127 visibility: wgpu::ShaderStages::FRAGMENT,
128 ty: wgpu::BindingType::Buffer {
129 ty: wgpu::BufferBindingType::Uniform,
130 has_dynamic_offset: false,
131 min_binding_size: None,
132 },
133 count: None,
134 }],
135 });
136
137 let params_bind_group_layout =
139 device
140 .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
141 label: Some("shader_params_layout"),
142 entries: &[wgpu::BindGroupLayoutEntry {
143 binding: 0,
144 visibility: wgpu::ShaderStages::FRAGMENT,
145 ty: wgpu::BindingType::Buffer {
146 ty: wgpu::BufferBindingType::Uniform,
147 has_dynamic_offset: false,
148 min_binding_size: None,
149 },
150 count: None,
151 }],
152 });
153
154 let pipeline_layout =
155 device
156 .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
157 label: Some("custom_shader_pipeline_layout"),
158 bind_group_layouts: &[
159 &camera_layout,
160 &texture_layout,
161 &lighting_layout,
162 ¶ms_bind_group_layout,
163 ],
164 push_constant_ranges: &[],
165 });
166
167 Self {
168 shaders: HashMap::new(),
169 pipeline_layout,
170 params_bind_group_layout,
171 surface_format,
172 }
173 }
174
175 pub fn create(&mut self, device: &wgpu::Device, id: u32, _name: &str, source: &str) {
180 let full_wgsl = build_custom_wgsl(source);
181
182 let shader_module = device
183 .create_shader_module(wgpu::ShaderModuleDescriptor {
184 label: Some("custom_shader"),
185 source: wgpu::ShaderSource::Wgsl(full_wgsl.into()),
186 });
187
188 let vertex_layout = wgpu::VertexBufferLayout {
189 array_stride: 16, step_mode: wgpu::VertexStepMode::Vertex,
191 attributes: &[
192 wgpu::VertexAttribute {
193 offset: 0,
194 shader_location: 0,
195 format: wgpu::VertexFormat::Float32x2,
196 },
197 wgpu::VertexAttribute {
198 offset: 8,
199 shader_location: 1,
200 format: wgpu::VertexFormat::Float32x2,
201 },
202 ],
203 };
204
205 let instance_layout = wgpu::VertexBufferLayout {
206 array_stride: 64, step_mode: wgpu::VertexStepMode::Instance,
208 attributes: &[
209 wgpu::VertexAttribute {
210 offset: 0,
211 shader_location: 2,
212 format: wgpu::VertexFormat::Float32x2,
213 },
214 wgpu::VertexAttribute {
215 offset: 8,
216 shader_location: 3,
217 format: wgpu::VertexFormat::Float32x2,
218 },
219 wgpu::VertexAttribute {
220 offset: 16,
221 shader_location: 4,
222 format: wgpu::VertexFormat::Float32x2,
223 },
224 wgpu::VertexAttribute {
225 offset: 24,
226 shader_location: 5,
227 format: wgpu::VertexFormat::Float32x2,
228 },
229 wgpu::VertexAttribute {
230 offset: 32,
231 shader_location: 6,
232 format: wgpu::VertexFormat::Float32x4,
233 },
234 wgpu::VertexAttribute {
235 offset: 48,
236 shader_location: 7,
237 format: wgpu::VertexFormat::Float32x4,
238 },
239 ],
240 };
241
242 let pipeline =
243 device
244 .create_render_pipeline(&wgpu::RenderPipelineDescriptor {
245 label: Some("custom_shader_pipeline"),
246 layout: Some(&self.pipeline_layout),
247 vertex: wgpu::VertexState {
248 module: &shader_module,
249 entry_point: Some("vs_main"),
250 buffers: &[vertex_layout, instance_layout],
251 compilation_options: Default::default(),
252 },
253 fragment: Some(wgpu::FragmentState {
254 module: &shader_module,
255 entry_point: Some("fs_main"),
256 targets: &[Some(wgpu::ColorTargetState {
257 format: self.surface_format,
258 blend: Some(wgpu::BlendState::ALPHA_BLENDING),
259 write_mask: wgpu::ColorWrites::ALL,
260 })],
261 compilation_options: Default::default(),
262 }),
263 primitive: wgpu::PrimitiveState {
264 topology: wgpu::PrimitiveTopology::TriangleList,
265 strip_index_format: None,
266 front_face: wgpu::FrontFace::Ccw,
267 cull_mode: None,
268 polygon_mode: wgpu::PolygonMode::Fill,
269 unclipped_depth: false,
270 conservative: false,
271 },
272 depth_stencil: None,
273 multisample: wgpu::MultisampleState::default(),
274 multiview: None,
275 cache: None,
276 });
277
278 let uniform_buffer =
280 device
281 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
282 label: Some("shader_params_buffer"),
283 contents: &[0u8; UNIFORM_BUFFER_SIZE],
284 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
285 });
286
287 let uniform_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
288 label: Some("shader_params_bind_group"),
289 layout: &self.params_bind_group_layout,
290 entries: &[wgpu::BindGroupEntry {
291 binding: 0,
292 resource: uniform_buffer.as_entire_binding(),
293 }],
294 });
295
296 self.shaders.insert(
297 id,
298 ShaderEntry {
299 pipeline,
300 uniform_buffer,
301 uniform_bind_group,
302 param_data: [0.0; MAX_PARAM_SLOTS * 4],
303 dirty: false,
304 },
305 );
306 }
307
308 pub fn set_param(&mut self, id: u32, index: u32, x: f32, y: f32, z: f32, w: f32) {
311 if let Some(entry) = self.shaders.get_mut(&id) {
312 let offset_index = (index as usize + BUILTIN_SLOTS).min(MAX_PARAM_SLOTS - 1);
313 let i = offset_index * 4;
314 entry.param_data[i] = x;
315 entry.param_data[i + 1] = y;
316 entry.param_data[i + 2] = z;
317 entry.param_data[i + 3] = w;
318 entry.dirty = true;
319 }
320 }
321
322 pub fn flush(
325 &mut self,
326 queue: &wgpu::Queue,
327 time: f32,
328 delta: f32,
329 resolution: [f32; 2],
330 mouse: [f32; 2],
331 ) {
332 for entry in self.shaders.values_mut() {
333 entry.param_data[0] = time;
335 entry.param_data[1] = delta;
336 entry.param_data[2] = resolution[0];
337 entry.param_data[3] = resolution[1];
338 entry.param_data[4] = mouse[0];
339 entry.param_data[5] = mouse[1];
340 entry.param_data[6] = 0.0; entry.param_data[7] = 0.0; queue.write_buffer(
345 &entry.uniform_buffer,
346 0,
347 bytemuck::cast_slice(&entry.param_data),
348 );
349 entry.dirty = false;
350 }
351 }
352
353 pub fn get_pipeline(&self, id: u32) -> Option<&wgpu::RenderPipeline> {
355 self.shaders.get(&id).map(|e| &e.pipeline)
356 }
357
358 pub fn get_bind_group(&self, id: u32) -> Option<&wgpu::BindGroup> {
360 self.shaders.get(&id).map(|e| &e.uniform_bind_group)
361 }
362}
363
/// Test-only copy of the slot arithmetic in `ShaderStore::set_param`.
///
/// Maps a user param index to the index of its first float in `param_data`. It skips
/// the built-in slots and clamps out-of-range indices onto the last slot.
#[cfg(test)]
fn compute_param_offset(user_index: u32) -> usize {
    // `saturating_add` avoids overflow on targets where `usize` is 32 bits.
    let slot = (user_index as usize)
        .saturating_add(BUILTIN_SLOTS)
        .min(MAX_PARAM_SLOTS - 1);
    slot * 4
}
372
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_param_offset_slot_0() {
        // User slot 0 starts right after the two built-in vec4s (8 floats).
        assert_eq!(compute_param_offset(0), 8);
    }

    #[test]
    fn test_param_offset_slot_13() {
        // Last user slot is vec4 index 15, i.e. float offset 60.
        assert_eq!(compute_param_offset(13), 60);
    }

    #[test]
    fn test_param_offset_slot_max_clamp() {
        // Out-of-range indices clamp onto the last slot instead of indexing out of bounds.
        assert_eq!(compute_param_offset(14), 60);
        assert_eq!(compute_param_offset(100), 60);
    }

    #[test]
    fn test_param_offset_extreme_index_does_not_overflow() {
        // Would overflow `index + BUILTIN_SLOTS` without saturation when usize is 32 bits.
        assert_eq!(compute_param_offset(u32::MAX), 60);
    }

    #[test]
    fn test_builtin_slots_consistency() {
        assert_eq!(BUILTIN_SLOTS, 2);
        assert_eq!(MAX_PARAM_SLOTS, 16);
        assert_eq!(UNIFORM_BUFFER_SIZE, 256);
    }

    #[test]
    fn test_param_data_layout() {
        let total_floats = MAX_PARAM_SLOTS * 4;
        assert_eq!(total_floats, 64);
        // The CPU-side mirror must be exactly as large as the GPU uniform buffer.
        assert_eq!(
            std::mem::size_of::<[f32; MAX_PARAM_SLOTS * 4]>(),
            UNIFORM_BUFFER_SIZE
        );
        assert_eq!(compute_param_offset(0), 8);
        assert_eq!(compute_param_offset(13) + 3, 63);
    }

    #[test]
    fn test_wgsl_user_slot_count_matches_constants() {
        // The WGSL struct hard-codes its user array length; it must agree with the
        // Rust-side slot constants or uploads will not line up with `values`.
        let wgsl = build_custom_wgsl("");
        let expected = format!("array<vec4<f32>, {}>", MAX_PARAM_SLOTS - BUILTIN_SLOTS);
        assert!(
            wgsl.contains(&expected),
            "expected `{expected}` in generated ShaderParams declaration"
        );
    }
}