Skip to main content

arcane_core/renderer/
shader.rs

1use std::collections::HashMap;
2
3use wgpu::util::DeviceExt;
4
5use super::gpu::GpuContext;
6
/// Number of `vec4<f32>` slots reserved at the front of every custom shader's
/// uniform block for engine-injected values: `time`, `delta`, `resolution`,
/// `mouse`, and two floats of padding.
const BUILTIN_SLOTS: usize = 2;

/// Total number of `vec4<f32>` uniform slots per custom shader, built-in and
/// user combined. After `BUILTIN_SLOTS`, 14 slots remain for user parameters.
const MAX_PARAM_SLOTS: usize = 16;

/// Byte size of each custom shader's uniform buffer: one 16-byte `vec4<f32>`
/// per slot, i.e. 16 × 16 = 256 bytes.
const UNIFORM_BUFFER_SIZE: usize = MAX_PARAM_SLOTS * std::mem::size_of::<[f32; 4]>();
14
15/// Extract the vertex shader + shared declarations from sprite.wgsl.
16/// Everything before `@fragment` is the preamble.
17fn shader_preamble() -> &'static str {
18    let wgsl = include_str!("shaders/sprite.wgsl");
19    let idx = wgsl
20        .find("@fragment")
21        .expect("sprite.wgsl must contain @fragment");
22    &wgsl[..idx]
23}
24
25/// Build complete WGSL for a custom shader by combining:
26/// 1. Standard preamble (camera, texture, lighting, vertex shader)
27/// 2. Custom uniform params declaration (group 3)
28/// 3. User's fragment shader code
29fn build_custom_wgsl(user_fragment: &str) -> String {
30    format!(
31        r#"{}
32// Custom shader params: 2 built-in vec4s + 14 user vec4 slots = 256 bytes
33struct ShaderParams {{
34    time: f32,              // elapsed seconds (auto-injected)
35    delta: f32,             // frame delta time (auto-injected)
36    resolution: vec2<f32>,  // viewport size in logical pixels (auto-injected)
37    mouse: vec2<f32>,       // mouse position in screen pixels (auto-injected)
38    _pad: vec2<f32>,
39    values: array<vec4<f32>, 14>,  // user-defined uniform slots
40}};
41
42@group(3) @binding(0)
43var<uniform> shader_params: ShaderParams;
44
45{}
46"#,
47        shader_preamble(),
48        user_fragment,
49    )
50}
51
52struct ShaderEntry {
53    pipeline: wgpu::RenderPipeline,
54    uniform_buffer: wgpu::Buffer,
55    uniform_bind_group: wgpu::BindGroup,
56    param_data: [f32; MAX_PARAM_SLOTS * 4],
57    dirty: bool,
58}
59
60/// Manages custom user-defined fragment shaders.
61/// Each shader gets its own render pipeline and uniform buffer.
62pub struct ShaderStore {
63    shaders: HashMap<u32, ShaderEntry>,
64    pipeline_layout: wgpu::PipelineLayout,
65    params_bind_group_layout: wgpu::BindGroupLayout,
66    surface_format: wgpu::TextureFormat,
67}
68
69impl ShaderStore {
70    /// Create a shader store for headless testing.
71    pub fn new_headless(device: &wgpu::Device, format: wgpu::TextureFormat) -> Self {
72        Self::new_internal(device, format)
73    }
74
75    pub fn new(gpu: &GpuContext) -> Self {
76        Self::new_internal(&gpu.device, gpu.config.format)
77    }
78
79    fn new_internal(device: &wgpu::Device, surface_format: wgpu::TextureFormat) -> Self {
80        // Create bind group layouts matching SpritePipeline's groups 0-2
81        let camera_layout =
82            device
83                .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
84                    label: Some("shader_camera_layout"),
85                    entries: &[wgpu::BindGroupLayoutEntry {
86                        binding: 0,
87                        visibility: wgpu::ShaderStages::VERTEX,
88                        ty: wgpu::BindingType::Buffer {
89                            ty: wgpu::BufferBindingType::Uniform,
90                            has_dynamic_offset: false,
91                            min_binding_size: None,
92                        },
93                        count: None,
94                    }],
95                });
96
97        let texture_layout =
98            device
99                .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
100                    label: Some("shader_texture_layout"),
101                    entries: &[
102                        wgpu::BindGroupLayoutEntry {
103                            binding: 0,
104                            visibility: wgpu::ShaderStages::FRAGMENT,
105                            ty: wgpu::BindingType::Texture {
106                                multisampled: false,
107                                view_dimension: wgpu::TextureViewDimension::D2,
108                                sample_type: wgpu::TextureSampleType::Float { filterable: true },
109                            },
110                            count: None,
111                        },
112                        wgpu::BindGroupLayoutEntry {
113                            binding: 1,
114                            visibility: wgpu::ShaderStages::FRAGMENT,
115                            ty: wgpu::BindingType::Sampler(wgpu::SamplerBindingType::Filtering),
116                            count: None,
117                        },
118                    ],
119                });
120
121        let lighting_layout =
122            device
123                .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
124                    label: Some("shader_lighting_layout"),
125                    entries: &[wgpu::BindGroupLayoutEntry {
126                        binding: 0,
127                        visibility: wgpu::ShaderStages::FRAGMENT,
128                        ty: wgpu::BindingType::Buffer {
129                            ty: wgpu::BufferBindingType::Uniform,
130                            has_dynamic_offset: false,
131                            min_binding_size: None,
132                        },
133                        count: None,
134                    }],
135                });
136
137        // Group 3: custom uniform params
138        let params_bind_group_layout =
139            device
140                .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
141                    label: Some("shader_params_layout"),
142                    entries: &[wgpu::BindGroupLayoutEntry {
143                        binding: 0,
144                        visibility: wgpu::ShaderStages::FRAGMENT,
145                        ty: wgpu::BindingType::Buffer {
146                            ty: wgpu::BufferBindingType::Uniform,
147                            has_dynamic_offset: false,
148                            min_binding_size: None,
149                        },
150                        count: None,
151                    }],
152                });
153
154        let pipeline_layout =
155            device
156                .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
157                    label: Some("custom_shader_pipeline_layout"),
158                    bind_group_layouts: &[
159                        &camera_layout,
160                        &texture_layout,
161                        &lighting_layout,
162                        &params_bind_group_layout,
163                    ],
164                    push_constant_ranges: &[],
165                });
166
167        Self {
168            shaders: HashMap::new(),
169            pipeline_layout,
170            params_bind_group_layout,
171            surface_format,
172        }
173    }
174
175    /// Compile a custom shader from user-provided WGSL fragment source.
176    /// The source must contain a `@fragment fn fs_main(in: VertexOutput) -> @location(0) vec4<f32>`.
177    /// Standard declarations (camera, texture, lighting, vertex shader) are prepended automatically.
178    /// Custom uniforms are available as `shader_params.values[0..15]` (vec4 array).
179    pub fn create(&mut self, device: &wgpu::Device, id: u32, _name: &str, source: &str) {
180        let full_wgsl = build_custom_wgsl(source);
181
182        let shader_module = device
183            .create_shader_module(wgpu::ShaderModuleDescriptor {
184                label: Some("custom_shader"),
185                source: wgpu::ShaderSource::Wgsl(full_wgsl.into()),
186            });
187
188        let vertex_layout = wgpu::VertexBufferLayout {
189            array_stride: 16, // QuadVertex: 2×f32 + 2×f32 = 16 bytes
190            step_mode: wgpu::VertexStepMode::Vertex,
191            attributes: &[
192                wgpu::VertexAttribute {
193                    offset: 0,
194                    shader_location: 0,
195                    format: wgpu::VertexFormat::Float32x2,
196                },
197                wgpu::VertexAttribute {
198                    offset: 8,
199                    shader_location: 1,
200                    format: wgpu::VertexFormat::Float32x2,
201                },
202            ],
203        };
204
205        let instance_layout = wgpu::VertexBufferLayout {
206            array_stride: 64, // SpriteInstance: 16 floats × 4 bytes = 64
207            step_mode: wgpu::VertexStepMode::Instance,
208            attributes: &[
209                wgpu::VertexAttribute {
210                    offset: 0,
211                    shader_location: 2,
212                    format: wgpu::VertexFormat::Float32x2,
213                },
214                wgpu::VertexAttribute {
215                    offset: 8,
216                    shader_location: 3,
217                    format: wgpu::VertexFormat::Float32x2,
218                },
219                wgpu::VertexAttribute {
220                    offset: 16,
221                    shader_location: 4,
222                    format: wgpu::VertexFormat::Float32x2,
223                },
224                wgpu::VertexAttribute {
225                    offset: 24,
226                    shader_location: 5,
227                    format: wgpu::VertexFormat::Float32x2,
228                },
229                wgpu::VertexAttribute {
230                    offset: 32,
231                    shader_location: 6,
232                    format: wgpu::VertexFormat::Float32x4,
233                },
234                wgpu::VertexAttribute {
235                    offset: 48,
236                    shader_location: 7,
237                    format: wgpu::VertexFormat::Float32x4,
238                },
239            ],
240        };
241
242        let pipeline =
243            device
244                .create_render_pipeline(&wgpu::RenderPipelineDescriptor {
245                    label: Some("custom_shader_pipeline"),
246                    layout: Some(&self.pipeline_layout),
247                    vertex: wgpu::VertexState {
248                        module: &shader_module,
249                        entry_point: Some("vs_main"),
250                        buffers: &[vertex_layout, instance_layout],
251                        compilation_options: Default::default(),
252                    },
253                    fragment: Some(wgpu::FragmentState {
254                        module: &shader_module,
255                        entry_point: Some("fs_main"),
256                        targets: &[Some(wgpu::ColorTargetState {
257                            format: self.surface_format,
258                            blend: Some(wgpu::BlendState::ALPHA_BLENDING),
259                            write_mask: wgpu::ColorWrites::ALL,
260                        })],
261                        compilation_options: Default::default(),
262                    }),
263                    primitive: wgpu::PrimitiveState {
264                        topology: wgpu::PrimitiveTopology::TriangleList,
265                        strip_index_format: None,
266                        front_face: wgpu::FrontFace::Ccw,
267                        cull_mode: None,
268                        polygon_mode: wgpu::PolygonMode::Fill,
269                        unclipped_depth: false,
270                        conservative: false,
271                    },
272                    depth_stencil: None,
273                    multisample: wgpu::MultisampleState::default(),
274                    multiview: None,
275                    cache: None,
276                });
277
278        // Create uniform buffer (zero-initialized)
279        let uniform_buffer =
280            device
281                .create_buffer_init(&wgpu::util::BufferInitDescriptor {
282                    label: Some("shader_params_buffer"),
283                    contents: &[0u8; UNIFORM_BUFFER_SIZE],
284                    usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
285                });
286
287        let uniform_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
288            label: Some("shader_params_bind_group"),
289            layout: &self.params_bind_group_layout,
290            entries: &[wgpu::BindGroupEntry {
291                binding: 0,
292                resource: uniform_buffer.as_entire_binding(),
293            }],
294        });
295
296        self.shaders.insert(
297            id,
298            ShaderEntry {
299                pipeline,
300                uniform_buffer,
301                uniform_bind_group,
302                param_data: [0.0; MAX_PARAM_SLOTS * 4],
303                dirty: false,
304            },
305        );
306    }
307
308    /// Set a vec4 user parameter slot for a shader. Index 0-13 maps to WGSL `values[0..13]`.
309    /// Internally offset by BUILTIN_SLOTS so user slot 0 → param_data[8..11].
310    pub fn set_param(&mut self, id: u32, index: u32, x: f32, y: f32, z: f32, w: f32) {
311        if let Some(entry) = self.shaders.get_mut(&id) {
312            let offset_index = (index as usize + BUILTIN_SLOTS).min(MAX_PARAM_SLOTS - 1);
313            let i = offset_index * 4;
314            entry.param_data[i] = x;
315            entry.param_data[i + 1] = y;
316            entry.param_data[i + 2] = z;
317            entry.param_data[i + 3] = w;
318            entry.dirty = true;
319        }
320    }
321
322    /// Flush uniform buffers to GPU, auto-injecting built-in values.
323    /// Built-ins (time, delta, resolution, mouse) are written every frame for all shaders.
324    pub fn flush(
325        &mut self,
326        queue: &wgpu::Queue,
327        time: f32,
328        delta: f32,
329        resolution: [f32; 2],
330        mouse: [f32; 2],
331    ) {
332        for entry in self.shaders.values_mut() {
333            // Always write built-ins (first 8 floats = 2 vec4 slots)
334            entry.param_data[0] = time;
335            entry.param_data[1] = delta;
336            entry.param_data[2] = resolution[0];
337            entry.param_data[3] = resolution[1];
338            entry.param_data[4] = mouse[0];
339            entry.param_data[5] = mouse[1];
340            entry.param_data[6] = 0.0; // _pad.x
341            entry.param_data[7] = 0.0; // _pad.y
342
343            // Always upload — built-ins change every frame
344            queue.write_buffer(
345                &entry.uniform_buffer,
346                0,
347                bytemuck::cast_slice(&entry.param_data),
348            );
349            entry.dirty = false;
350        }
351    }
352
353    /// Get the pipeline for a custom shader.
354    pub fn get_pipeline(&self, id: u32) -> Option<&wgpu::RenderPipeline> {
355        self.shaders.get(&id).map(|e| &e.pipeline)
356    }
357
358    /// Get the uniform bind group for a custom shader (group 3).
359    pub fn get_bind_group(&self, id: u32) -> Option<&wgpu::BindGroup> {
360        self.shaders.get(&id).map(|e| &e.uniform_bind_group)
361    }
362}
363
/// Compute the `param_data` float index for a user slot index.
///
/// User slot 0 maps to float index 8, right after the `BUILTIN_SLOTS` built-in
/// vec4s. Slots are clamped to `MAX_PARAM_SLOTS - 1` to prevent out-of-bounds
/// access. The addition saturates so that `u32::MAX` cannot overflow `usize`
/// on 32-bit targets.
#[cfg(test)]
fn compute_param_offset(user_index: u32) -> usize {
    let slot = (user_index as usize)
        .saturating_add(BUILTIN_SLOTS)
        .min(MAX_PARAM_SLOTS - 1);
    slot * 4
}
372
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_param_offset_slot_0() {
        // User slot 0 → offset by BUILTIN_SLOTS (2) → vec4 index 2 → float index 8
        assert_eq!(compute_param_offset(0), 8);
    }

    #[test]
    fn test_param_offset_slot_13() {
        // User slot 13 → vec4 index 15 → float index 60
        assert_eq!(compute_param_offset(13), 60);
    }

    #[test]
    fn test_param_offset_slot_max_clamp() {
        // User slot 14+ → clamped to MAX_PARAM_SLOTS-1 (15) → float index 60
        assert_eq!(compute_param_offset(14), 60);
        assert_eq!(compute_param_offset(100), 60);
    }

    #[test]
    fn test_builtin_slots_consistency() {
        assert_eq!(BUILTIN_SLOTS, 2);
        assert_eq!(MAX_PARAM_SLOTS, 16);
        assert_eq!(UNIFORM_BUFFER_SIZE, 256); // 16 * 16 bytes
    }

    #[test]
    fn test_param_data_layout() {
        // Verify the full layout: built-in slots 0-1 (8 floats), user slots 2-15 (56 floats)
        let total_floats = MAX_PARAM_SLOTS * 4;
        assert_eq!(total_floats, 64);
        // First user slot starts at float index 8
        assert_eq!(compute_param_offset(0), 8);
        // Last user slot (13) starts at float index 60, ends at 63
        assert_eq!(compute_param_offset(13) + 3, 63);
    }

    #[test]
    fn test_wgsl_declares_all_user_slots() {
        // The WGSL `values` array must have exactly as many entries as the
        // Rust side reserves for user slots. Otherwise the uniform buffer
        // layout and the shader struct disagree.
        let wgsl = build_custom_wgsl("");
        let expected = format!(
            "values: array<vec4<f32>, {}>",
            MAX_PARAM_SLOTS - BUILTIN_SLOTS
        );
        assert!(
            wgsl.contains(&expected),
            "generated WGSL does not declare `{expected}`"
        );
    }

    #[test]
    fn test_preamble_excludes_fragment_stage() {
        // The preamble is cut right before the first `@fragment`.
        assert!(!shader_preamble().contains("@fragment"));
    }

    #[test]
    fn test_user_fragment_is_appended_verbatim() {
        let fragment = "@fragment fn fs_main(in: VertexOutput) -> @location(0) vec4<f32> { return vec4<f32>(1.0); }";
        let wgsl = build_custom_wgsl(fragment);
        assert!(wgsl.starts_with(shader_preamble()));
        assert!(wgsl.trim_end().ends_with(fragment));
    }
}