Skip to main content

arcane_engine/renderer/
shader.rs

1use std::collections::HashMap;
2
3use wgpu::util::DeviceExt;
4
5use super::gpu::GpuContext;
6
/// Number of `vec4<f32>` uniform slots available to each custom shader.
const MAX_PARAM_SLOTS: usize = 16;
/// Byte length of the params uniform buffer: one 16-byte vec4 (four `f32`s) per slot.
const UNIFORM_BUFFER_SIZE: usize = MAX_PARAM_SLOTS * std::mem::size_of::<[f32; 4]>();
11
12/// Extract the vertex shader + shared declarations from sprite.wgsl.
13/// Everything before `@fragment` is the preamble.
14fn shader_preamble() -> &'static str {
15    let wgsl = include_str!("shaders/sprite.wgsl");
16    let idx = wgsl
17        .find("@fragment")
18        .expect("sprite.wgsl must contain @fragment");
19    &wgsl[..idx]
20}
21
22/// Build complete WGSL for a custom shader by combining:
23/// 1. Standard preamble (camera, texture, lighting, vertex shader)
24/// 2. Custom uniform params declaration (group 3)
25/// 3. User's fragment shader code
26fn build_custom_wgsl(user_fragment: &str) -> String {
27    format!(
28        r#"{}
29// Custom shader uniform params (16 vec4 slots = 64 floats)
30struct ShaderParams {{
31    values: array<vec4<f32>, 16>,
32}};
33
34@group(3) @binding(0)
35var<uniform> shader_params: ShaderParams;
36
37{}
38"#,
39        shader_preamble(),
40        user_fragment,
41    )
42}
43
/// Per-shader GPU resources plus a CPU-side copy of the shader's parameters.
struct ShaderEntry {
    /// Render pipeline built from the combined preamble + user fragment source.
    pipeline: wgpu::RenderPipeline,
    /// Uniform buffer that backs the shader's parameter block.
    uniform_buffer: wgpu::Buffer,
    /// Bind group exposing `uniform_buffer` at binding 0.
    uniform_bind_group: wgpu::BindGroup,
    /// CPU-side parameter values, four consecutive `f32`s per vec4 slot.
    param_data: [f32; MAX_PARAM_SLOTS * 4],
    /// Set when `param_data` has changed since it was last written to `uniform_buffer`.
    dirty: bool,
}
51
/// Manages custom user-defined fragment shaders.
/// Each shader gets its own render pipeline and uniform buffer.
pub struct ShaderStore {
    /// Registered shaders, keyed by the id supplied to `create`.
    shaders: HashMap<u32, ShaderEntry>,
    /// Pipeline layout shared by every custom shader (bind groups 0-3).
    pipeline_layout: wgpu::PipelineLayout,
    /// Layout of bind group 3, the per-shader params uniform.
    params_bind_group_layout: wgpu::BindGroupLayout,
    /// Color target format used when building pipelines; copied from the GPU config in `new`.
    surface_format: wgpu::TextureFormat,
}
60
61impl ShaderStore {
62    pub fn new(gpu: &GpuContext) -> Self {
63        // Create bind group layouts matching SpritePipeline's groups 0-2
64        let camera_layout =
65            gpu.device
66                .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
67                    label: Some("shader_camera_layout"),
68                    entries: &[wgpu::BindGroupLayoutEntry {
69                        binding: 0,
70                        visibility: wgpu::ShaderStages::VERTEX,
71                        ty: wgpu::BindingType::Buffer {
72                            ty: wgpu::BufferBindingType::Uniform,
73                            has_dynamic_offset: false,
74                            min_binding_size: None,
75                        },
76                        count: None,
77                    }],
78                });
79
80        let texture_layout =
81            gpu.device
82                .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
83                    label: Some("shader_texture_layout"),
84                    entries: &[
85                        wgpu::BindGroupLayoutEntry {
86                            binding: 0,
87                            visibility: wgpu::ShaderStages::FRAGMENT,
88                            ty: wgpu::BindingType::Texture {
89                                multisampled: false,
90                                view_dimension: wgpu::TextureViewDimension::D2,
91                                sample_type: wgpu::TextureSampleType::Float { filterable: true },
92                            },
93                            count: None,
94                        },
95                        wgpu::BindGroupLayoutEntry {
96                            binding: 1,
97                            visibility: wgpu::ShaderStages::FRAGMENT,
98                            ty: wgpu::BindingType::Sampler(wgpu::SamplerBindingType::Filtering),
99                            count: None,
100                        },
101                    ],
102                });
103
104        let lighting_layout =
105            gpu.device
106                .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
107                    label: Some("shader_lighting_layout"),
108                    entries: &[wgpu::BindGroupLayoutEntry {
109                        binding: 0,
110                        visibility: wgpu::ShaderStages::FRAGMENT,
111                        ty: wgpu::BindingType::Buffer {
112                            ty: wgpu::BufferBindingType::Uniform,
113                            has_dynamic_offset: false,
114                            min_binding_size: None,
115                        },
116                        count: None,
117                    }],
118                });
119
120        // Group 3: custom uniform params
121        let params_bind_group_layout =
122            gpu.device
123                .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
124                    label: Some("shader_params_layout"),
125                    entries: &[wgpu::BindGroupLayoutEntry {
126                        binding: 0,
127                        visibility: wgpu::ShaderStages::FRAGMENT,
128                        ty: wgpu::BindingType::Buffer {
129                            ty: wgpu::BufferBindingType::Uniform,
130                            has_dynamic_offset: false,
131                            min_binding_size: None,
132                        },
133                        count: None,
134                    }],
135                });
136
137        let pipeline_layout =
138            gpu.device
139                .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
140                    label: Some("custom_shader_pipeline_layout"),
141                    bind_group_layouts: &[
142                        &camera_layout,
143                        &texture_layout,
144                        &lighting_layout,
145                        &params_bind_group_layout,
146                    ],
147                    push_constant_ranges: &[],
148                });
149
150        Self {
151            shaders: HashMap::new(),
152            pipeline_layout,
153            params_bind_group_layout,
154            surface_format: gpu.config.format,
155        }
156    }
157
158    /// Compile a custom shader from user-provided WGSL fragment source.
159    /// The source must contain a `@fragment fn fs_main(in: VertexOutput) -> @location(0) vec4<f32>`.
160    /// Standard declarations (camera, texture, lighting, vertex shader) are prepended automatically.
161    /// Custom uniforms are available as `shader_params.values[0..15]` (vec4 array).
162    pub fn create(&mut self, gpu: &GpuContext, id: u32, _name: &str, source: &str) {
163        let full_wgsl = build_custom_wgsl(source);
164
165        let shader_module = gpu
166            .device
167            .create_shader_module(wgpu::ShaderModuleDescriptor {
168                label: Some("custom_shader"),
169                source: wgpu::ShaderSource::Wgsl(full_wgsl.into()),
170            });
171
172        let vertex_layout = wgpu::VertexBufferLayout {
173            array_stride: 16, // QuadVertex: 2×f32 + 2×f32 = 16 bytes
174            step_mode: wgpu::VertexStepMode::Vertex,
175            attributes: &[
176                wgpu::VertexAttribute {
177                    offset: 0,
178                    shader_location: 0,
179                    format: wgpu::VertexFormat::Float32x2,
180                },
181                wgpu::VertexAttribute {
182                    offset: 8,
183                    shader_location: 1,
184                    format: wgpu::VertexFormat::Float32x2,
185                },
186            ],
187        };
188
189        let instance_layout = wgpu::VertexBufferLayout {
190            array_stride: 64, // SpriteInstance: 16 floats × 4 bytes = 64
191            step_mode: wgpu::VertexStepMode::Instance,
192            attributes: &[
193                wgpu::VertexAttribute {
194                    offset: 0,
195                    shader_location: 2,
196                    format: wgpu::VertexFormat::Float32x2,
197                },
198                wgpu::VertexAttribute {
199                    offset: 8,
200                    shader_location: 3,
201                    format: wgpu::VertexFormat::Float32x2,
202                },
203                wgpu::VertexAttribute {
204                    offset: 16,
205                    shader_location: 4,
206                    format: wgpu::VertexFormat::Float32x2,
207                },
208                wgpu::VertexAttribute {
209                    offset: 24,
210                    shader_location: 5,
211                    format: wgpu::VertexFormat::Float32x2,
212                },
213                wgpu::VertexAttribute {
214                    offset: 32,
215                    shader_location: 6,
216                    format: wgpu::VertexFormat::Float32x4,
217                },
218                wgpu::VertexAttribute {
219                    offset: 48,
220                    shader_location: 7,
221                    format: wgpu::VertexFormat::Float32x4,
222                },
223            ],
224        };
225
226        let pipeline =
227            gpu.device
228                .create_render_pipeline(&wgpu::RenderPipelineDescriptor {
229                    label: Some("custom_shader_pipeline"),
230                    layout: Some(&self.pipeline_layout),
231                    vertex: wgpu::VertexState {
232                        module: &shader_module,
233                        entry_point: Some("vs_main"),
234                        buffers: &[vertex_layout, instance_layout],
235                        compilation_options: Default::default(),
236                    },
237                    fragment: Some(wgpu::FragmentState {
238                        module: &shader_module,
239                        entry_point: Some("fs_main"),
240                        targets: &[Some(wgpu::ColorTargetState {
241                            format: self.surface_format,
242                            blend: Some(wgpu::BlendState::ALPHA_BLENDING),
243                            write_mask: wgpu::ColorWrites::ALL,
244                        })],
245                        compilation_options: Default::default(),
246                    }),
247                    primitive: wgpu::PrimitiveState {
248                        topology: wgpu::PrimitiveTopology::TriangleList,
249                        strip_index_format: None,
250                        front_face: wgpu::FrontFace::Ccw,
251                        cull_mode: None,
252                        polygon_mode: wgpu::PolygonMode::Fill,
253                        unclipped_depth: false,
254                        conservative: false,
255                    },
256                    depth_stencil: None,
257                    multisample: wgpu::MultisampleState::default(),
258                    multiview: None,
259                    cache: None,
260                });
261
262        // Create uniform buffer (zero-initialized)
263        let uniform_buffer =
264            gpu.device
265                .create_buffer_init(&wgpu::util::BufferInitDescriptor {
266                    label: Some("shader_params_buffer"),
267                    contents: &[0u8; UNIFORM_BUFFER_SIZE],
268                    usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
269                });
270
271        let uniform_bind_group = gpu.device.create_bind_group(&wgpu::BindGroupDescriptor {
272            label: Some("shader_params_bind_group"),
273            layout: &self.params_bind_group_layout,
274            entries: &[wgpu::BindGroupEntry {
275                binding: 0,
276                resource: uniform_buffer.as_entire_binding(),
277            }],
278        });
279
280        self.shaders.insert(
281            id,
282            ShaderEntry {
283                pipeline,
284                uniform_buffer,
285                uniform_bind_group,
286                param_data: [0.0; MAX_PARAM_SLOTS * 4],
287                dirty: false,
288            },
289        );
290    }
291
292    /// Set a vec4 parameter slot for a shader. Index 0-15.
293    pub fn set_param(&mut self, id: u32, index: u32, x: f32, y: f32, z: f32, w: f32) {
294        if let Some(entry) = self.shaders.get_mut(&id) {
295            let i = (index as usize).min(MAX_PARAM_SLOTS - 1) * 4;
296            entry.param_data[i] = x;
297            entry.param_data[i + 1] = y;
298            entry.param_data[i + 2] = z;
299            entry.param_data[i + 3] = w;
300            entry.dirty = true;
301        }
302    }
303
304    /// Flush dirty uniform buffers to GPU.
305    pub fn flush(&mut self, gpu: &GpuContext) {
306        for entry in self.shaders.values_mut() {
307            if entry.dirty {
308                gpu.queue.write_buffer(
309                    &entry.uniform_buffer,
310                    0,
311                    bytemuck::cast_slice(&entry.param_data),
312                );
313                entry.dirty = false;
314            }
315        }
316    }
317
318    /// Get the pipeline for a custom shader.
319    pub fn get_pipeline(&self, id: u32) -> Option<&wgpu::RenderPipeline> {
320        self.shaders.get(&id).map(|e| &e.pipeline)
321    }
322
323    /// Get the uniform bind group for a custom shader (group 3).
324    pub fn get_bind_group(&self, id: u32) -> Option<&wgpu::BindGroup> {
325        self.shaders.get(&id).map(|e| &e.uniform_bind_group)
326    }
327}