Skip to main content

polyscope_render/
volume_grid_render.rs

1//! Volume grid GPU rendering resources for gridcube and isosurface visualization.
2
3use glam::Vec3;
4use wgpu::util::DeviceExt;
5
6use crate::surface_mesh_render::ShadowModelUniforms;
7
/// Width, in texels, of the 1D colormap texture: the colormap is resampled at this many
/// evenly spaced points in `[0, 1]` before being uploaded to the GPU.
const COLORMAP_RESOLUTION: u32 = 256;
10
11/// Uniforms for the simple mesh (isosurface) shader.
12/// Layout must match WGSL `SimpleMeshUniforms` exactly.
13#[repr(C)]
14#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
15#[allow(clippy::pub_underscore_fields)]
16pub struct SimpleMeshUniforms {
17    /// Model transform matrix.
18    pub model: [[f32; 4]; 4],
19    /// Base surface color (RGBA).
20    pub base_color: [f32; 4],
21    /// Transparency (0.0 = opaque, 1.0 = fully transparent).
22    pub transparency: f32,
23    /// Slice plane clipping enable: 0 = off, 1 = on.
24    pub slice_planes_enabled: u32,
25    /// Backface policy: 0 = identical, 1 = different, 3 = cull.
26    pub backface_policy: u32,
27    /// Padding to 16-byte alignment.
28    pub _pad: f32,
29}
30
31impl Default for SimpleMeshUniforms {
32    fn default() -> Self {
33        Self {
34            model: [
35                [1.0, 0.0, 0.0, 0.0],
36                [0.0, 1.0, 0.0, 0.0],
37                [0.0, 0.0, 1.0, 0.0],
38                [0.0, 0.0, 0.0, 1.0],
39            ],
40            base_color: [0.047, 0.451, 0.690, 1.0], // default isosurface blue
41            transparency: 0.0,
42            slice_planes_enabled: 1,
43            backface_policy: 0,
44            _pad: 0.0,
45        }
46    }
47}
48
49/// Uniforms for the gridcube shader.
50/// Layout must match WGSL `GridcubeUniforms` exactly.
51#[repr(C)]
52#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
53#[allow(clippy::pub_underscore_fields)]
54pub struct GridcubeUniforms {
55    /// Model transform matrix.
56    pub model: [[f32; 4]; 4],
57    /// Cube size factor (0..1, default 1.0 = full size, 0.5 = half).
58    pub cube_size_factor: f32,
59    /// Scalar data range minimum.
60    pub data_min: f32,
61    /// Scalar data range maximum.
62    pub data_max: f32,
63    /// Transparency (0.0 = opaque, 1.0 = fully transparent).
64    pub transparency: f32,
65    /// Slice plane clipping enable: 0 = off, 1 = on.
66    pub slice_planes_enabled: u32,
67    /// Padding to 16-byte alignment.
68    pub _pad0: f32,
69    pub _pad1: f32,
70    pub _pad2: f32,
71}
72
73impl Default for GridcubeUniforms {
74    fn default() -> Self {
75        Self {
76            model: [
77                [1.0, 0.0, 0.0, 0.0],
78                [0.0, 1.0, 0.0, 0.0],
79                [0.0, 0.0, 1.0, 0.0],
80                [0.0, 0.0, 0.0, 1.0],
81            ],
82            cube_size_factor: 1.0,
83            data_min: 0.0,
84            data_max: 1.0,
85            transparency: 0.0,
86            slice_planes_enabled: 1,
87            _pad0: 0.0,
88            _pad1: 0.0,
89            _pad2: 0.0,
90        }
91    }
92}
93
94/// GPU uniforms for gridcube pick rendering.
95///
96/// Layout must match WGSL `GridcubePickUniforms` exactly.
97#[repr(C)]
98#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
99#[allow(clippy::pub_underscore_fields)]
100pub struct GridcubePickUniforms {
101    /// Model transform matrix.
102    pub model: [[f32; 4]; 4],
103    /// The starting global index for this quantity's elements.
104    pub global_start: u32,
105    /// Cube size factor (0..1).
106    pub cube_size_factor: f32,
107    /// Padding to 16-byte alignment.
108    pub _pad0: f32,
109    pub _pad1: f32,
110}
111
112impl Default for GridcubePickUniforms {
113    fn default() -> Self {
114        Self {
115            model: [
116                [1.0, 0.0, 0.0, 0.0],
117                [0.0, 1.0, 0.0, 0.0],
118                [0.0, 0.0, 1.0, 0.0],
119                [0.0, 0.0, 0.0, 1.0],
120            ],
121            global_start: 0,
122            cube_size_factor: 1.0,
123            _pad0: 0.0,
124            _pad1: 0.0,
125        }
126    }
127}
128
129/// GPU resources for isosurface (simple mesh) visualization.
130pub struct IsosurfaceRenderData {
131    /// Position buffer (storage, vec4 per expanded triangle vertex).
132    pub vertex_buffer: wgpu::Buffer,
133    /// Normal buffer (storage, vec4 per expanded triangle vertex).
134    pub normal_buffer: wgpu::Buffer,
135    /// Uniform buffer.
136    pub uniform_buffer: wgpu::Buffer,
137    /// Bind group (Group 0).
138    pub bind_group: wgpu::BindGroup,
139    /// Number of vertices (expanded triangle vertices, for non-indexed draw).
140    pub num_vertices: u32,
141    /// Shadow pass bind group.
142    pub shadow_bind_group: Option<wgpu::BindGroup>,
143    /// Shadow model uniform buffer.
144    pub shadow_model_buffer: Option<wgpu::Buffer>,
145}
146
147impl IsosurfaceRenderData {
148    /// Creates new isosurface render data from marching cubes output.
149    ///
150    /// Vertices/normals are expanded per-triangle (non-indexed drawing with storage buffers),
151    /// matching the surface mesh pattern.
152    #[must_use]
153    pub fn new(
154        device: &wgpu::Device,
155        bind_group_layout: &wgpu::BindGroupLayout,
156        camera_buffer: &wgpu::Buffer,
157        vertices: &[Vec3],
158        normals: &[Vec3],
159        indices: &[u32],
160    ) -> Self {
161        // Expand to per-triangle-vertex layout (non-indexed)
162        let num_triangles = indices.len() / 3;
163        let num_vertices = (num_triangles * 3) as u32;
164
165        let mut expanded_positions: Vec<f32> = Vec::with_capacity(num_triangles * 3 * 4);
166        let mut expanded_normals: Vec<f32> = Vec::with_capacity(num_triangles * 3 * 4);
167
168        for tri_idx in 0..num_triangles {
169            for v in 0..3 {
170                let vi = indices[tri_idx * 3 + v] as usize;
171                let p = vertices[vi];
172                expanded_positions.extend_from_slice(&[p.x, p.y, p.z, 1.0]);
173                let n = normals[vi];
174                expanded_normals.extend_from_slice(&[n.x, n.y, n.z, 0.0]);
175            }
176        }
177
178        let vertex_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
179            label: Some("isosurface vertices"),
180            contents: bytemuck::cast_slice(&expanded_positions),
181            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
182        });
183
184        let normal_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
185            label: Some("isosurface normals"),
186            contents: bytemuck::cast_slice(&expanded_normals),
187            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
188        });
189
190        let uniforms = SimpleMeshUniforms::default();
191        let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
192            label: Some("isosurface uniforms"),
193            contents: bytemuck::cast_slice(&[uniforms]),
194            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
195        });
196
197        let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
198            label: Some("isosurface bind group"),
199            layout: bind_group_layout,
200            entries: &[
201                wgpu::BindGroupEntry {
202                    binding: 0,
203                    resource: camera_buffer.as_entire_binding(),
204                },
205                wgpu::BindGroupEntry {
206                    binding: 1,
207                    resource: uniform_buffer.as_entire_binding(),
208                },
209                wgpu::BindGroupEntry {
210                    binding: 2,
211                    resource: vertex_buffer.as_entire_binding(),
212                },
213                wgpu::BindGroupEntry {
214                    binding: 3,
215                    resource: normal_buffer.as_entire_binding(),
216                },
217            ],
218        });
219
220        Self {
221            vertex_buffer,
222            normal_buffer,
223            uniform_buffer,
224            bind_group,
225            num_vertices,
226            shadow_bind_group: None,
227            shadow_model_buffer: None,
228        }
229    }
230
231    /// Updates the uniform buffer.
232    pub fn update_uniforms(&self, queue: &wgpu::Queue, uniforms: &SimpleMeshUniforms) {
233        queue.write_buffer(&self.uniform_buffer, 0, bytemuck::cast_slice(&[*uniforms]));
234    }
235
236    /// Initializes shadow rendering resources.
237    pub fn init_shadow_resources(
238        &mut self,
239        device: &wgpu::Device,
240        shadow_bind_group_layout: &wgpu::BindGroupLayout,
241        light_buffer: &wgpu::Buffer,
242    ) {
243        let shadow_model_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
244            label: Some("isosurface shadow model buffer"),
245            contents: bytemuck::cast_slice(&[ShadowModelUniforms::default()]),
246            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
247        });
248
249        let shadow_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
250            label: Some("isosurface shadow bind group"),
251            layout: shadow_bind_group_layout,
252            entries: &[
253                wgpu::BindGroupEntry {
254                    binding: 0,
255                    resource: light_buffer.as_entire_binding(),
256                },
257                wgpu::BindGroupEntry {
258                    binding: 1,
259                    resource: shadow_model_buffer.as_entire_binding(),
260                },
261                wgpu::BindGroupEntry {
262                    binding: 2,
263                    resource: self.vertex_buffer.as_entire_binding(),
264                },
265            ],
266        });
267
268        self.shadow_model_buffer = Some(shadow_model_buffer);
269        self.shadow_bind_group = Some(shadow_bind_group);
270    }
271
272    /// Returns whether shadow resources have been initialized.
273    #[must_use]
274    pub fn has_shadow_resources(&self) -> bool {
275        self.shadow_bind_group.is_some()
276    }
277
278    /// Updates the shadow model uniform buffer with the current transform.
279    pub fn update_shadow_model(&self, queue: &wgpu::Queue, model_matrix: [[f32; 4]; 4]) {
280        if let Some(buffer) = &self.shadow_model_buffer {
281            let uniforms = ShadowModelUniforms {
282                model: model_matrix,
283            };
284            queue.write_buffer(buffer, 0, bytemuck::cast_slice(&[uniforms]));
285        }
286    }
287}
288
289/// GPU resources for gridcube visualization.
290pub struct GridcubeRenderData {
291    /// Combined buffer: first 36 entries are unit cube template vertices (vec4),
292    /// followed by per-instance data (vec4: xyz=center, w=`half_size`).
293    pub position_buffer: wgpu::Buffer,
294    /// Cube template normals (36 entries, vec4).
295    pub normal_buffer: wgpu::Buffer,
296    /// Per-instance scalar values.
297    pub scalar_buffer: wgpu::Buffer,
298    /// Uniform buffer.
299    pub uniform_buffer: wgpu::Buffer,
300    /// Colormap texture (1D, RGBA).
301    pub colormap_texture: wgpu::Texture,
302    /// Colormap texture view.
303    pub colormap_view: wgpu::TextureView,
304    /// Colormap sampler.
305    pub colormap_sampler: wgpu::Sampler,
306    /// Bind group (Group 0).
307    pub bind_group: wgpu::BindGroup,
308    /// Number of instances (grid nodes/cells).
309    pub num_instances: u32,
310    /// Shadow pass bind group.
311    pub shadow_bind_group: Option<wgpu::BindGroup>,
312    /// Shadow model uniform buffer.
313    pub shadow_model_buffer: Option<wgpu::Buffer>,
314}
315
316/// Generates the 36 vertices and 36 normals for a unit cube ([-0.5, 0.5]^3).
317/// Returns (positions, normals) as vec4 arrays.
318fn generate_unit_cube() -> (Vec<[f32; 4]>, Vec<[f32; 4]>) {
319    // 6 faces, 2 triangles each, 3 vertices each = 36 vertices
320    // Face order: +X, -X, +Y, -Y, +Z, -Z
321    let faces: [([f32; 3], [[f32; 3]; 4]); 6] = [
322        // +X face (normal = +X)
323        (
324            [1.0, 0.0, 0.0],
325            [
326                [0.5, -0.5, -0.5],
327                [0.5, 0.5, -0.5],
328                [0.5, 0.5, 0.5],
329                [0.5, -0.5, 0.5],
330            ],
331        ),
332        // -X face
333        (
334            [-1.0, 0.0, 0.0],
335            [
336                [-0.5, -0.5, 0.5],
337                [-0.5, 0.5, 0.5],
338                [-0.5, 0.5, -0.5],
339                [-0.5, -0.5, -0.5],
340            ],
341        ),
342        // +Y face
343        (
344            [0.0, 1.0, 0.0],
345            [
346                [-0.5, 0.5, -0.5],
347                [-0.5, 0.5, 0.5],
348                [0.5, 0.5, 0.5],
349                [0.5, 0.5, -0.5],
350            ],
351        ),
352        // -Y face
353        (
354            [0.0, -1.0, 0.0],
355            [
356                [-0.5, -0.5, 0.5],
357                [-0.5, -0.5, -0.5],
358                [0.5, -0.5, -0.5],
359                [0.5, -0.5, 0.5],
360            ],
361        ),
362        // +Z face
363        (
364            [0.0, 0.0, 1.0],
365            [
366                [-0.5, -0.5, 0.5],
367                [0.5, -0.5, 0.5],
368                [0.5, 0.5, 0.5],
369                [-0.5, 0.5, 0.5],
370            ],
371        ),
372        // -Z face
373        (
374            [0.0, 0.0, -1.0],
375            [
376                [0.5, -0.5, -0.5],
377                [-0.5, -0.5, -0.5],
378                [-0.5, 0.5, -0.5],
379                [0.5, 0.5, -0.5],
380            ],
381        ),
382    ];
383
384    let mut positions = Vec::with_capacity(36);
385    let mut normals = Vec::with_capacity(36);
386
387    for (normal, verts) in &faces {
388        // Two triangles per face: 0-1-2 and 0-2-3
389        let tri_indices = [[0, 1, 2], [0, 2, 3]];
390        for tri in &tri_indices {
391            for &vi in tri {
392                let v = verts[vi];
393                positions.push([v[0], v[1], v[2], 1.0]);
394                normals.push([normal[0], normal[1], normal[2], 0.0]);
395            }
396        }
397    }
398
399    (positions, normals)
400}
401
402/// Creates a 1D colormap texture from color samples.
403fn create_colormap_texture(
404    device: &wgpu::Device,
405    queue: &wgpu::Queue,
406    colors: &[Vec3],
407) -> (wgpu::Texture, wgpu::TextureView, wgpu::Sampler) {
408    // Sample the colormap at COLORMAP_RESOLUTION points
409    let mut pixel_data: Vec<u8> = Vec::with_capacity(COLORMAP_RESOLUTION as usize * 4);
410    let n = colors.len();
411
412    for i in 0..COLORMAP_RESOLUTION {
413        let t = i as f32 / (COLORMAP_RESOLUTION - 1) as f32;
414        let t_clamped = t.clamp(0.0, 1.0);
415
416        // Linear interpolation (matches ColorMap::sample)
417        let color = if n <= 1 {
418            colors.first().copied().unwrap_or(Vec3::ZERO)
419        } else {
420            let segments = n - 1;
421            let idx = (t_clamped * segments as f32).floor() as usize;
422            let idx = idx.min(segments - 1);
423            let frac = t_clamped * segments as f32 - idx as f32;
424            colors[idx].lerp(colors[idx + 1], frac)
425        };
426
427        pixel_data.push((color.x * 255.0) as u8);
428        pixel_data.push((color.y * 255.0) as u8);
429        pixel_data.push((color.z * 255.0) as u8);
430        pixel_data.push(255); // alpha
431    }
432
433    let texture = device.create_texture(&wgpu::TextureDescriptor {
434        label: Some("colormap texture"),
435        size: wgpu::Extent3d {
436            width: COLORMAP_RESOLUTION,
437            height: 1,
438            depth_or_array_layers: 1,
439        },
440        mip_level_count: 1,
441        sample_count: 1,
442        dimension: wgpu::TextureDimension::D1,
443        format: wgpu::TextureFormat::Rgba8UnormSrgb,
444        usage: wgpu::TextureUsages::TEXTURE_BINDING | wgpu::TextureUsages::COPY_DST,
445        view_formats: &[],
446    });
447
448    queue.write_texture(
449        wgpu::TexelCopyTextureInfo {
450            texture: &texture,
451            mip_level: 0,
452            origin: wgpu::Origin3d::ZERO,
453            aspect: wgpu::TextureAspect::All,
454        },
455        &pixel_data,
456        wgpu::TexelCopyBufferLayout {
457            offset: 0,
458            bytes_per_row: Some(COLORMAP_RESOLUTION * 4),
459            rows_per_image: None,
460        },
461        wgpu::Extent3d {
462            width: COLORMAP_RESOLUTION,
463            height: 1,
464            depth_or_array_layers: 1,
465        },
466    );
467
468    let view = texture.create_view(&wgpu::TextureViewDescriptor {
469        dimension: Some(wgpu::TextureViewDimension::D1),
470        ..Default::default()
471    });
472
473    let sampler = device.create_sampler(&wgpu::SamplerDescriptor {
474        label: Some("colormap sampler"),
475        address_mode_u: wgpu::AddressMode::ClampToEdge,
476        mag_filter: wgpu::FilterMode::Linear,
477        min_filter: wgpu::FilterMode::Linear,
478        ..Default::default()
479    });
480
481    (texture, view, sampler)
482}
483
484impl GridcubeRenderData {
485    /// Creates new gridcube render data.
486    ///
487    /// # Arguments
488    /// * `centers` - Per-instance cube center positions
489    /// * `half_size` - Half the cube side length (grid spacing / 2)
490    /// * `scalars` - Per-instance scalar values
491    /// * `colormap_colors` - Color samples for the colormap
492    #[must_use]
493    pub fn new(
494        device: &wgpu::Device,
495        queue: &wgpu::Queue,
496        bind_group_layout: &wgpu::BindGroupLayout,
497        camera_buffer: &wgpu::Buffer,
498        centers: &[Vec3],
499        half_size: f32,
500        scalars: &[f32],
501        colormap_colors: &[Vec3],
502    ) -> Self {
503        let num_instances = centers.len() as u32;
504        let (cube_positions, cube_normals) = generate_unit_cube();
505
506        // Build combined position buffer: 36 cube template verts + N instance data entries
507        let mut position_data: Vec<[f32; 4]> = cube_positions;
508        for center in centers {
509            position_data.push([center.x, center.y, center.z, half_size]);
510        }
511
512        let position_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
513            label: Some("gridcube positions"),
514            contents: bytemuck::cast_slice(&position_data),
515            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
516        });
517
518        let normal_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
519            label: Some("gridcube normals"),
520            contents: bytemuck::cast_slice(&cube_normals),
521            usage: wgpu::BufferUsages::STORAGE,
522        });
523
524        let scalar_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
525            label: Some("gridcube scalars"),
526            contents: bytemuck::cast_slice(scalars),
527            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
528        });
529
530        let uniforms = GridcubeUniforms::default();
531        let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
532            label: Some("gridcube uniforms"),
533            contents: bytemuck::cast_slice(&[uniforms]),
534            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
535        });
536
537        let (colormap_texture, colormap_view, colormap_sampler) =
538            create_colormap_texture(device, queue, colormap_colors);
539
540        let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
541            label: Some("gridcube bind group"),
542            layout: bind_group_layout,
543            entries: &[
544                wgpu::BindGroupEntry {
545                    binding: 0,
546                    resource: camera_buffer.as_entire_binding(),
547                },
548                wgpu::BindGroupEntry {
549                    binding: 1,
550                    resource: uniform_buffer.as_entire_binding(),
551                },
552                wgpu::BindGroupEntry {
553                    binding: 2,
554                    resource: position_buffer.as_entire_binding(),
555                },
556                wgpu::BindGroupEntry {
557                    binding: 3,
558                    resource: normal_buffer.as_entire_binding(),
559                },
560                wgpu::BindGroupEntry {
561                    binding: 4,
562                    resource: scalar_buffer.as_entire_binding(),
563                },
564                wgpu::BindGroupEntry {
565                    binding: 5,
566                    resource: wgpu::BindingResource::TextureView(&colormap_view),
567                },
568                wgpu::BindGroupEntry {
569                    binding: 6,
570                    resource: wgpu::BindingResource::Sampler(&colormap_sampler),
571                },
572            ],
573        });
574
575        Self {
576            position_buffer,
577            normal_buffer,
578            scalar_buffer,
579            uniform_buffer,
580            colormap_texture,
581            colormap_view,
582            colormap_sampler,
583            bind_group,
584            num_instances,
585            shadow_bind_group: None,
586            shadow_model_buffer: None,
587        }
588    }
589
590    /// Updates the uniform buffer.
591    pub fn update_uniforms(&self, queue: &wgpu::Queue, uniforms: &GridcubeUniforms) {
592        queue.write_buffer(&self.uniform_buffer, 0, bytemuck::cast_slice(&[*uniforms]));
593    }
594
595    /// Updates the colormap texture with new colors.
596    pub fn update_colormap(&self, queue: &wgpu::Queue, colors: &[Vec3]) {
597        let mut pixel_data: Vec<u8> = Vec::with_capacity(COLORMAP_RESOLUTION as usize * 4);
598        let n = colors.len();
599
600        for i in 0..COLORMAP_RESOLUTION {
601            let t = i as f32 / (COLORMAP_RESOLUTION - 1) as f32;
602            let t_clamped = t.clamp(0.0, 1.0);
603            let color = if n <= 1 {
604                colors.first().copied().unwrap_or(Vec3::ZERO)
605            } else {
606                let segments = n - 1;
607                let idx = (t_clamped * segments as f32).floor() as usize;
608                let idx = idx.min(segments - 1);
609                let frac = t_clamped * segments as f32 - idx as f32;
610                colors[idx].lerp(colors[idx + 1], frac)
611            };
612
613            pixel_data.push((color.x * 255.0) as u8);
614            pixel_data.push((color.y * 255.0) as u8);
615            pixel_data.push((color.z * 255.0) as u8);
616            pixel_data.push(255);
617        }
618
619        queue.write_texture(
620            wgpu::TexelCopyTextureInfo {
621                texture: &self.colormap_texture,
622                mip_level: 0,
623                origin: wgpu::Origin3d::ZERO,
624                aspect: wgpu::TextureAspect::All,
625            },
626            &pixel_data,
627            wgpu::TexelCopyBufferLayout {
628                offset: 0,
629                bytes_per_row: Some(COLORMAP_RESOLUTION * 4),
630                rows_per_image: None,
631            },
632            wgpu::Extent3d {
633                width: COLORMAP_RESOLUTION,
634                height: 1,
635                depth_or_array_layers: 1,
636            },
637        );
638    }
639
640    /// Returns the total number of vertices to draw (36 per instance).
641    #[must_use]
642    pub fn total_vertices(&self) -> u32 {
643        36 * self.num_instances
644    }
645
646    /// Initializes shadow rendering resources.
647    pub fn init_shadow_resources(
648        &mut self,
649        device: &wgpu::Device,
650        shadow_bind_group_layout: &wgpu::BindGroupLayout,
651        light_buffer: &wgpu::Buffer,
652    ) {
653        let shadow_model_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
654            label: Some("gridcube shadow model buffer"),
655            contents: bytemuck::cast_slice(&[ShadowModelUniforms::default()]),
656            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
657        });
658
659        let shadow_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
660            label: Some("gridcube shadow bind group"),
661            layout: shadow_bind_group_layout,
662            entries: &[
663                wgpu::BindGroupEntry {
664                    binding: 0,
665                    resource: light_buffer.as_entire_binding(),
666                },
667                wgpu::BindGroupEntry {
668                    binding: 1,
669                    resource: shadow_model_buffer.as_entire_binding(),
670                },
671                wgpu::BindGroupEntry {
672                    binding: 2,
673                    resource: self.position_buffer.as_entire_binding(),
674                },
675            ],
676        });
677
678        self.shadow_model_buffer = Some(shadow_model_buffer);
679        self.shadow_bind_group = Some(shadow_bind_group);
680    }
681
682    /// Returns whether shadow resources have been initialized.
683    #[must_use]
684    pub fn has_shadow_resources(&self) -> bool {
685        self.shadow_bind_group.is_some()
686    }
687
688    /// Updates the shadow model uniform buffer.
689    pub fn update_shadow_model(&self, queue: &wgpu::Queue, model_matrix: [[f32; 4]; 4]) {
690        if let Some(buffer) = &self.shadow_model_buffer {
691            let uniforms = ShadowModelUniforms {
692                model: model_matrix,
693            };
694            queue.write_buffer(buffer, 0, bytemuck::cast_slice(&[uniforms]));
695        }
696    }
697}
698
699#[cfg(test)]
700mod tests {
701    use super::*;
702
703    #[test]
704    fn test_simple_mesh_uniforms_size() {
705        let size = std::mem::size_of::<SimpleMeshUniforms>();
706        assert_eq!(
707            size % 16,
708            0,
709            "SimpleMeshUniforms size ({size} bytes) must be 16-byte aligned"
710        );
711        // model(64) + base_color(16) + transparency(4) + slice_planes_enabled(4) + backface_policy(4) + pad(4) = 96
712        assert_eq!(
713            size, 96,
714            "SimpleMeshUniforms should be 96 bytes, got {size}"
715        );
716    }
717
718    #[test]
719    fn test_gridcube_pick_uniforms_size() {
720        let size = std::mem::size_of::<GridcubePickUniforms>();
721        assert_eq!(
722            size % 16,
723            0,
724            "GridcubePickUniforms size ({size} bytes) must be 16-byte aligned"
725        );
726        // model(64) + global_start(4) + cube_size_factor(4) + pad0(4) + pad1(4) = 80
727        assert_eq!(
728            size, 80,
729            "GridcubePickUniforms should be 80 bytes, got {size}"
730        );
731    }
732
733    #[test]
734    fn test_gridcube_uniforms_size() {
735        let size = std::mem::size_of::<GridcubeUniforms>();
736        assert_eq!(
737            size % 16,
738            0,
739            "GridcubeUniforms size ({size} bytes) must be 16-byte aligned"
740        );
741        // model(64) + cube_size_factor(4) + data_min(4) + data_max(4) + transparency(4)
742        // + slice_planes_enabled(4) + pad0(4) + pad1(4) + pad2(4) = 96
743        assert_eq!(size, 96, "GridcubeUniforms should be 96 bytes, got {size}");
744    }
745
746    #[test]
747    fn test_unit_cube_generation() {
748        let (positions, normals) = generate_unit_cube();
749        assert_eq!(positions.len(), 36);
750        assert_eq!(normals.len(), 36);
751
752        // All positions should be within [-0.5, 0.5]
753        for p in &positions {
754            assert!(p[0].abs() <= 0.5 + f32::EPSILON);
755            assert!(p[1].abs() <= 0.5 + f32::EPSILON);
756            assert!(p[2].abs() <= 0.5 + f32::EPSILON);
757            assert!((p[3] - 1.0).abs() < f32::EPSILON); // w = 1.0
758        }
759
760        // All normals should be unit length axis-aligned
761        for n in &normals {
762            let len = (n[0] * n[0] + n[1] * n[1] + n[2] * n[2]).sqrt();
763            assert!((len - 1.0).abs() < 0.01);
764            assert!((n[3]).abs() < f32::EPSILON); // w = 0.0
765        }
766    }
767}