1use std::collections::HashMap;
2
3use wgpu::util::DeviceExt;
4
5use super::gpu::GpuContext;
6
/// Number of `vec4<f32>` parameter slots a custom shader can read from `shader_params`.
const MAX_PARAM_SLOTS: usize = 16;
/// Byte size of the per-shader parameter uniform buffer: one 16-byte `vec4<f32>` per slot.
const UNIFORM_BUFFER_SIZE: usize = MAX_PARAM_SLOTS * std::mem::size_of::<[f32; 4]>();
11
12fn shader_preamble() -> &'static str {
15 let wgsl = include_str!("shaders/sprite.wgsl");
16 let idx = wgsl
17 .find("@fragment")
18 .expect("sprite.wgsl must contain @fragment");
19 &wgsl[..idx]
20}
21
22fn build_custom_wgsl(user_fragment: &str) -> String {
27 format!(
28 r#"{}
29// Custom shader uniform params (16 vec4 slots = 64 floats)
30struct ShaderParams {{
31 values: array<vec4<f32>, 16>,
32}};
33
34@group(3) @binding(0)
35var<uniform> shader_params: ShaderParams;
36
37{}
38"#,
39 shader_preamble(),
40 user_fragment,
41 )
42}
43
/// GPU resources and CPU-side parameter state for one registered custom shader.
struct ShaderEntry {
    /// Render pipeline compiled from the sprite preamble plus the user's fragment source.
    pipeline: wgpu::RenderPipeline,
    /// Uniform buffer backing `shader_params` (`@group(3) @binding(0)` in the generated WGSL).
    uniform_buffer: wgpu::Buffer,
    /// Bind group exposing `uniform_buffer` at binding 0 of the params layout.
    uniform_bind_group: wgpu::BindGroup,
    /// CPU copy of the uniform contents: `MAX_PARAM_SLOTS` slots of four floats each.
    param_data: [f32; MAX_PARAM_SLOTS * 4],
    /// True when `param_data` has changed since it was last uploaded to `uniform_buffer`.
    dirty: bool,
}
51
/// Registry of custom shaders, keyed by a caller-chosen `u32` id.
///
/// Each shader is compiled on top of the sprite shader preamble, with only the fragment
/// stage supplied by the caller. All shaders share one pipeline layout.
pub struct ShaderStore {
    /// Compiled shaders and their parameter buffers, by id.
    shaders: HashMap<u32, ShaderEntry>,
    /// Pipeline layout shared by every custom pipeline (groups 0..=3: camera, texture,
    /// lighting, params).
    pipeline_layout: wgpu::PipelineLayout,
    /// Layout of group 3, which holds the per-shader `ShaderParams` uniform buffer.
    params_bind_group_layout: wgpu::BindGroupLayout,
    /// Color target format every custom pipeline renders into.
    surface_format: wgpu::TextureFormat,
}
60
61impl ShaderStore {
62 pub fn new_headless(device: &wgpu::Device, format: wgpu::TextureFormat) -> Self {
64 Self::new_internal(device, format)
65 }
66
67 pub fn new(gpu: &GpuContext) -> Self {
68 Self::new_internal(&gpu.device, gpu.config.format)
69 }
70
71 fn new_internal(device: &wgpu::Device, surface_format: wgpu::TextureFormat) -> Self {
72 let camera_layout =
74 device
75 .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
76 label: Some("shader_camera_layout"),
77 entries: &[wgpu::BindGroupLayoutEntry {
78 binding: 0,
79 visibility: wgpu::ShaderStages::VERTEX,
80 ty: wgpu::BindingType::Buffer {
81 ty: wgpu::BufferBindingType::Uniform,
82 has_dynamic_offset: false,
83 min_binding_size: None,
84 },
85 count: None,
86 }],
87 });
88
89 let texture_layout =
90 device
91 .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
92 label: Some("shader_texture_layout"),
93 entries: &[
94 wgpu::BindGroupLayoutEntry {
95 binding: 0,
96 visibility: wgpu::ShaderStages::FRAGMENT,
97 ty: wgpu::BindingType::Texture {
98 multisampled: false,
99 view_dimension: wgpu::TextureViewDimension::D2,
100 sample_type: wgpu::TextureSampleType::Float { filterable: true },
101 },
102 count: None,
103 },
104 wgpu::BindGroupLayoutEntry {
105 binding: 1,
106 visibility: wgpu::ShaderStages::FRAGMENT,
107 ty: wgpu::BindingType::Sampler(wgpu::SamplerBindingType::Filtering),
108 count: None,
109 },
110 ],
111 });
112
113 let lighting_layout =
114 device
115 .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
116 label: Some("shader_lighting_layout"),
117 entries: &[wgpu::BindGroupLayoutEntry {
118 binding: 0,
119 visibility: wgpu::ShaderStages::FRAGMENT,
120 ty: wgpu::BindingType::Buffer {
121 ty: wgpu::BufferBindingType::Uniform,
122 has_dynamic_offset: false,
123 min_binding_size: None,
124 },
125 count: None,
126 }],
127 });
128
129 let params_bind_group_layout =
131 device
132 .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
133 label: Some("shader_params_layout"),
134 entries: &[wgpu::BindGroupLayoutEntry {
135 binding: 0,
136 visibility: wgpu::ShaderStages::FRAGMENT,
137 ty: wgpu::BindingType::Buffer {
138 ty: wgpu::BufferBindingType::Uniform,
139 has_dynamic_offset: false,
140 min_binding_size: None,
141 },
142 count: None,
143 }],
144 });
145
146 let pipeline_layout =
147 device
148 .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
149 label: Some("custom_shader_pipeline_layout"),
150 bind_group_layouts: &[
151 &camera_layout,
152 &texture_layout,
153 &lighting_layout,
154 ¶ms_bind_group_layout,
155 ],
156 push_constant_ranges: &[],
157 });
158
159 Self {
160 shaders: HashMap::new(),
161 pipeline_layout,
162 params_bind_group_layout,
163 surface_format,
164 }
165 }
166
167 pub fn create(&mut self, device: &wgpu::Device, id: u32, _name: &str, source: &str) {
172 let full_wgsl = build_custom_wgsl(source);
173
174 let shader_module = device
175 .create_shader_module(wgpu::ShaderModuleDescriptor {
176 label: Some("custom_shader"),
177 source: wgpu::ShaderSource::Wgsl(full_wgsl.into()),
178 });
179
180 let vertex_layout = wgpu::VertexBufferLayout {
181 array_stride: 16, step_mode: wgpu::VertexStepMode::Vertex,
183 attributes: &[
184 wgpu::VertexAttribute {
185 offset: 0,
186 shader_location: 0,
187 format: wgpu::VertexFormat::Float32x2,
188 },
189 wgpu::VertexAttribute {
190 offset: 8,
191 shader_location: 1,
192 format: wgpu::VertexFormat::Float32x2,
193 },
194 ],
195 };
196
197 let instance_layout = wgpu::VertexBufferLayout {
198 array_stride: 64, step_mode: wgpu::VertexStepMode::Instance,
200 attributes: &[
201 wgpu::VertexAttribute {
202 offset: 0,
203 shader_location: 2,
204 format: wgpu::VertexFormat::Float32x2,
205 },
206 wgpu::VertexAttribute {
207 offset: 8,
208 shader_location: 3,
209 format: wgpu::VertexFormat::Float32x2,
210 },
211 wgpu::VertexAttribute {
212 offset: 16,
213 shader_location: 4,
214 format: wgpu::VertexFormat::Float32x2,
215 },
216 wgpu::VertexAttribute {
217 offset: 24,
218 shader_location: 5,
219 format: wgpu::VertexFormat::Float32x2,
220 },
221 wgpu::VertexAttribute {
222 offset: 32,
223 shader_location: 6,
224 format: wgpu::VertexFormat::Float32x4,
225 },
226 wgpu::VertexAttribute {
227 offset: 48,
228 shader_location: 7,
229 format: wgpu::VertexFormat::Float32x4,
230 },
231 ],
232 };
233
234 let pipeline =
235 device
236 .create_render_pipeline(&wgpu::RenderPipelineDescriptor {
237 label: Some("custom_shader_pipeline"),
238 layout: Some(&self.pipeline_layout),
239 vertex: wgpu::VertexState {
240 module: &shader_module,
241 entry_point: Some("vs_main"),
242 buffers: &[vertex_layout, instance_layout],
243 compilation_options: Default::default(),
244 },
245 fragment: Some(wgpu::FragmentState {
246 module: &shader_module,
247 entry_point: Some("fs_main"),
248 targets: &[Some(wgpu::ColorTargetState {
249 format: self.surface_format,
250 blend: Some(wgpu::BlendState::ALPHA_BLENDING),
251 write_mask: wgpu::ColorWrites::ALL,
252 })],
253 compilation_options: Default::default(),
254 }),
255 primitive: wgpu::PrimitiveState {
256 topology: wgpu::PrimitiveTopology::TriangleList,
257 strip_index_format: None,
258 front_face: wgpu::FrontFace::Ccw,
259 cull_mode: None,
260 polygon_mode: wgpu::PolygonMode::Fill,
261 unclipped_depth: false,
262 conservative: false,
263 },
264 depth_stencil: None,
265 multisample: wgpu::MultisampleState::default(),
266 multiview: None,
267 cache: None,
268 });
269
270 let uniform_buffer =
272 device
273 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
274 label: Some("shader_params_buffer"),
275 contents: &[0u8; UNIFORM_BUFFER_SIZE],
276 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
277 });
278
279 let uniform_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
280 label: Some("shader_params_bind_group"),
281 layout: &self.params_bind_group_layout,
282 entries: &[wgpu::BindGroupEntry {
283 binding: 0,
284 resource: uniform_buffer.as_entire_binding(),
285 }],
286 });
287
288 self.shaders.insert(
289 id,
290 ShaderEntry {
291 pipeline,
292 uniform_buffer,
293 uniform_bind_group,
294 param_data: [0.0; MAX_PARAM_SLOTS * 4],
295 dirty: false,
296 },
297 );
298 }
299
300 pub fn set_param(&mut self, id: u32, index: u32, x: f32, y: f32, z: f32, w: f32) {
302 if let Some(entry) = self.shaders.get_mut(&id) {
303 let i = (index as usize).min(MAX_PARAM_SLOTS - 1) * 4;
304 entry.param_data[i] = x;
305 entry.param_data[i + 1] = y;
306 entry.param_data[i + 2] = z;
307 entry.param_data[i + 3] = w;
308 entry.dirty = true;
309 }
310 }
311
312 pub fn flush(&mut self, queue: &wgpu::Queue) {
314 for entry in self.shaders.values_mut() {
315 if entry.dirty {
316 queue.write_buffer(
317 &entry.uniform_buffer,
318 0,
319 bytemuck::cast_slice(&entry.param_data),
320 );
321 entry.dirty = false;
322 }
323 }
324 }
325
326 pub fn get_pipeline(&self, id: u32) -> Option<&wgpu::RenderPipeline> {
328 self.shaders.get(&id).map(|e| &e.pipeline)
329 }
330
331 pub fn get_bind_group(&self, id: u32) -> Option<&wgpu::BindGroup> {
333 self.shaders.get(&id).map(|e| &e.uniform_bind_group)
334 }
335}