Skip to main content

arcane_core/renderer/
shader.rs

1use std::collections::HashMap;
2
3use wgpu::util::DeviceExt;
4
5use super::gpu::GpuContext;
6
/// Number of `vec4<f32>` uniform slots available to each custom shader.
const MAX_PARAM_SLOTS: usize = 16;
/// Byte size of a custom shader's params uniform buffer: one 16-byte `[f32; 4]` per slot.
const UNIFORM_BUFFER_SIZE: usize = MAX_PARAM_SLOTS * std::mem::size_of::<[f32; 4]>();
11
12/// Extract the vertex shader + shared declarations from sprite.wgsl.
13/// Everything before `@fragment` is the preamble.
14fn shader_preamble() -> &'static str {
15    let wgsl = include_str!("shaders/sprite.wgsl");
16    let idx = wgsl
17        .find("@fragment")
18        .expect("sprite.wgsl must contain @fragment");
19    &wgsl[..idx]
20}
21
22/// Build complete WGSL for a custom shader by combining:
23/// 1. Standard preamble (camera, texture, lighting, vertex shader)
24/// 2. Custom uniform params declaration (group 3)
25/// 3. User's fragment shader code
26fn build_custom_wgsl(user_fragment: &str) -> String {
27    format!(
28        r#"{}
29// Custom shader uniform params (16 vec4 slots = 64 floats)
30struct ShaderParams {{
31    values: array<vec4<f32>, 16>,
32}};
33
34@group(3) @binding(0)
35var<uniform> shader_params: ShaderParams;
36
37{}
38"#,
39        shader_preamble(),
40        user_fragment,
41    )
42}
43
/// GPU objects and CPU-side parameter state for one compiled custom shader.
struct ShaderEntry {
    /// Render pipeline compiled from the sprite preamble plus the user's fragment code.
    pipeline: wgpu::RenderPipeline,
    /// `UNIFORM_BUFFER_SIZE`-byte uniform buffer backing `shader_params` in the shader.
    uniform_buffer: wgpu::Buffer,
    /// Bind group exposing `uniform_buffer` at binding 0; bound as group 3.
    uniform_bind_group: wgpu::BindGroup,
    /// CPU copy of the params: `MAX_PARAM_SLOTS` vec4 slots flattened to f32s.
    param_data: [f32; MAX_PARAM_SLOTS * 4],
    /// True when `param_data` changed since it was last written to `uniform_buffer`.
    dirty: bool,
}
51
/// Manages custom user-defined fragment shaders.
/// Each shader gets its own render pipeline and uniform buffer.
pub struct ShaderStore {
    /// Compiled shaders keyed by the caller-chosen id passed to `create`.
    shaders: HashMap<u32, ShaderEntry>,
    /// Pipeline layout shared by all custom shaders: groups 0-2 are built to match
    /// SpritePipeline's camera/texture/lighting layouts, and group 3 holds the params.
    pipeline_layout: wgpu::PipelineLayout,
    /// Layout of the group-3 params bind group, reused for every shader's bind group.
    params_bind_group_layout: wgpu::BindGroupLayout,
    /// Color target format that every custom pipeline is built for.
    surface_format: wgpu::TextureFormat,
}
60
61impl ShaderStore {
62    /// Create a shader store for headless testing.
63    pub fn new_headless(device: &wgpu::Device, format: wgpu::TextureFormat) -> Self {
64        Self::new_internal(device, format)
65    }
66
67    pub fn new(gpu: &GpuContext) -> Self {
68        Self::new_internal(&gpu.device, gpu.config.format)
69    }
70
71    fn new_internal(device: &wgpu::Device, surface_format: wgpu::TextureFormat) -> Self {
72        // Create bind group layouts matching SpritePipeline's groups 0-2
73        let camera_layout =
74            device
75                .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
76                    label: Some("shader_camera_layout"),
77                    entries: &[wgpu::BindGroupLayoutEntry {
78                        binding: 0,
79                        visibility: wgpu::ShaderStages::VERTEX,
80                        ty: wgpu::BindingType::Buffer {
81                            ty: wgpu::BufferBindingType::Uniform,
82                            has_dynamic_offset: false,
83                            min_binding_size: None,
84                        },
85                        count: None,
86                    }],
87                });
88
89        let texture_layout =
90            device
91                .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
92                    label: Some("shader_texture_layout"),
93                    entries: &[
94                        wgpu::BindGroupLayoutEntry {
95                            binding: 0,
96                            visibility: wgpu::ShaderStages::FRAGMENT,
97                            ty: wgpu::BindingType::Texture {
98                                multisampled: false,
99                                view_dimension: wgpu::TextureViewDimension::D2,
100                                sample_type: wgpu::TextureSampleType::Float { filterable: true },
101                            },
102                            count: None,
103                        },
104                        wgpu::BindGroupLayoutEntry {
105                            binding: 1,
106                            visibility: wgpu::ShaderStages::FRAGMENT,
107                            ty: wgpu::BindingType::Sampler(wgpu::SamplerBindingType::Filtering),
108                            count: None,
109                        },
110                    ],
111                });
112
113        let lighting_layout =
114            device
115                .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
116                    label: Some("shader_lighting_layout"),
117                    entries: &[wgpu::BindGroupLayoutEntry {
118                        binding: 0,
119                        visibility: wgpu::ShaderStages::FRAGMENT,
120                        ty: wgpu::BindingType::Buffer {
121                            ty: wgpu::BufferBindingType::Uniform,
122                            has_dynamic_offset: false,
123                            min_binding_size: None,
124                        },
125                        count: None,
126                    }],
127                });
128
129        // Group 3: custom uniform params
130        let params_bind_group_layout =
131            device
132                .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
133                    label: Some("shader_params_layout"),
134                    entries: &[wgpu::BindGroupLayoutEntry {
135                        binding: 0,
136                        visibility: wgpu::ShaderStages::FRAGMENT,
137                        ty: wgpu::BindingType::Buffer {
138                            ty: wgpu::BufferBindingType::Uniform,
139                            has_dynamic_offset: false,
140                            min_binding_size: None,
141                        },
142                        count: None,
143                    }],
144                });
145
146        let pipeline_layout =
147            device
148                .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
149                    label: Some("custom_shader_pipeline_layout"),
150                    bind_group_layouts: &[
151                        &camera_layout,
152                        &texture_layout,
153                        &lighting_layout,
154                        &params_bind_group_layout,
155                    ],
156                    push_constant_ranges: &[],
157                });
158
159        Self {
160            shaders: HashMap::new(),
161            pipeline_layout,
162            params_bind_group_layout,
163            surface_format,
164        }
165    }
166
167    /// Compile a custom shader from user-provided WGSL fragment source.
168    /// The source must contain a `@fragment fn fs_main(in: VertexOutput) -> @location(0) vec4<f32>`.
169    /// Standard declarations (camera, texture, lighting, vertex shader) are prepended automatically.
170    /// Custom uniforms are available as `shader_params.values[0..15]` (vec4 array).
171    pub fn create(&mut self, device: &wgpu::Device, id: u32, _name: &str, source: &str) {
172        let full_wgsl = build_custom_wgsl(source);
173
174        let shader_module = device
175            .create_shader_module(wgpu::ShaderModuleDescriptor {
176                label: Some("custom_shader"),
177                source: wgpu::ShaderSource::Wgsl(full_wgsl.into()),
178            });
179
180        let vertex_layout = wgpu::VertexBufferLayout {
181            array_stride: 16, // QuadVertex: 2×f32 + 2×f32 = 16 bytes
182            step_mode: wgpu::VertexStepMode::Vertex,
183            attributes: &[
184                wgpu::VertexAttribute {
185                    offset: 0,
186                    shader_location: 0,
187                    format: wgpu::VertexFormat::Float32x2,
188                },
189                wgpu::VertexAttribute {
190                    offset: 8,
191                    shader_location: 1,
192                    format: wgpu::VertexFormat::Float32x2,
193                },
194            ],
195        };
196
197        let instance_layout = wgpu::VertexBufferLayout {
198            array_stride: 64, // SpriteInstance: 16 floats × 4 bytes = 64
199            step_mode: wgpu::VertexStepMode::Instance,
200            attributes: &[
201                wgpu::VertexAttribute {
202                    offset: 0,
203                    shader_location: 2,
204                    format: wgpu::VertexFormat::Float32x2,
205                },
206                wgpu::VertexAttribute {
207                    offset: 8,
208                    shader_location: 3,
209                    format: wgpu::VertexFormat::Float32x2,
210                },
211                wgpu::VertexAttribute {
212                    offset: 16,
213                    shader_location: 4,
214                    format: wgpu::VertexFormat::Float32x2,
215                },
216                wgpu::VertexAttribute {
217                    offset: 24,
218                    shader_location: 5,
219                    format: wgpu::VertexFormat::Float32x2,
220                },
221                wgpu::VertexAttribute {
222                    offset: 32,
223                    shader_location: 6,
224                    format: wgpu::VertexFormat::Float32x4,
225                },
226                wgpu::VertexAttribute {
227                    offset: 48,
228                    shader_location: 7,
229                    format: wgpu::VertexFormat::Float32x4,
230                },
231            ],
232        };
233
234        let pipeline =
235            device
236                .create_render_pipeline(&wgpu::RenderPipelineDescriptor {
237                    label: Some("custom_shader_pipeline"),
238                    layout: Some(&self.pipeline_layout),
239                    vertex: wgpu::VertexState {
240                        module: &shader_module,
241                        entry_point: Some("vs_main"),
242                        buffers: &[vertex_layout, instance_layout],
243                        compilation_options: Default::default(),
244                    },
245                    fragment: Some(wgpu::FragmentState {
246                        module: &shader_module,
247                        entry_point: Some("fs_main"),
248                        targets: &[Some(wgpu::ColorTargetState {
249                            format: self.surface_format,
250                            blend: Some(wgpu::BlendState::ALPHA_BLENDING),
251                            write_mask: wgpu::ColorWrites::ALL,
252                        })],
253                        compilation_options: Default::default(),
254                    }),
255                    primitive: wgpu::PrimitiveState {
256                        topology: wgpu::PrimitiveTopology::TriangleList,
257                        strip_index_format: None,
258                        front_face: wgpu::FrontFace::Ccw,
259                        cull_mode: None,
260                        polygon_mode: wgpu::PolygonMode::Fill,
261                        unclipped_depth: false,
262                        conservative: false,
263                    },
264                    depth_stencil: None,
265                    multisample: wgpu::MultisampleState::default(),
266                    multiview: None,
267                    cache: None,
268                });
269
270        // Create uniform buffer (zero-initialized)
271        let uniform_buffer =
272            device
273                .create_buffer_init(&wgpu::util::BufferInitDescriptor {
274                    label: Some("shader_params_buffer"),
275                    contents: &[0u8; UNIFORM_BUFFER_SIZE],
276                    usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
277                });
278
279        let uniform_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
280            label: Some("shader_params_bind_group"),
281            layout: &self.params_bind_group_layout,
282            entries: &[wgpu::BindGroupEntry {
283                binding: 0,
284                resource: uniform_buffer.as_entire_binding(),
285            }],
286        });
287
288        self.shaders.insert(
289            id,
290            ShaderEntry {
291                pipeline,
292                uniform_buffer,
293                uniform_bind_group,
294                param_data: [0.0; MAX_PARAM_SLOTS * 4],
295                dirty: false,
296            },
297        );
298    }
299
300    /// Set a vec4 parameter slot for a shader. Index 0-15.
301    pub fn set_param(&mut self, id: u32, index: u32, x: f32, y: f32, z: f32, w: f32) {
302        if let Some(entry) = self.shaders.get_mut(&id) {
303            let i = (index as usize).min(MAX_PARAM_SLOTS - 1) * 4;
304            entry.param_data[i] = x;
305            entry.param_data[i + 1] = y;
306            entry.param_data[i + 2] = z;
307            entry.param_data[i + 3] = w;
308            entry.dirty = true;
309        }
310    }
311
312    /// Flush dirty uniform buffers to GPU.
313    pub fn flush(&mut self, queue: &wgpu::Queue) {
314        for entry in self.shaders.values_mut() {
315            if entry.dirty {
316                queue.write_buffer(
317                    &entry.uniform_buffer,
318                    0,
319                    bytemuck::cast_slice(&entry.param_data),
320                );
321                entry.dirty = false;
322            }
323        }
324    }
325
326    /// Get the pipeline for a custom shader.
327    pub fn get_pipeline(&self, id: u32) -> Option<&wgpu::RenderPipeline> {
328        self.shaders.get(&id).map(|e| &e.pipeline)
329    }
330
331    /// Get the uniform bind group for a custom shader (group 3).
332    pub fn get_bind_group(&self, id: u32) -> Option<&wgpu::BindGroup> {
333        self.shaders.get(&id).map(|e| &e.uniform_bind_group)
334    }
335}