Skip to main content

polyscope_render/
curve_network_render.rs

1//! Curve network GPU rendering resources.
2
3use glam::{Vec3, Vec4};
4use wgpu::util::DeviceExt;
5
6use crate::point_cloud_render::PointUniforms;
7
8/// Uniforms for curve network rendering.
9/// Layout must match WGSL `CurveNetworkUniforms` exactly (32 bytes).
10#[repr(C)]
11#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
12#[allow(clippy::pub_underscore_fields)]
13pub struct CurveNetworkUniforms {
14    /// Base color (RGBA)
15    pub color: [f32; 4],
16    /// Radius for nodes and edges
17    pub radius: f32,
18    /// Whether radius is relative to scene scale (0 = absolute, 1 = relative)
19    pub radius_is_relative: u32,
20    /// Render mode: 0 = line, 1 = tube (cylinder)
21    pub render_mode: u32,
22    /// Padding to 16-byte alignment
23    pub _padding: f32,
24}
25
26impl Default for CurveNetworkUniforms {
27    fn default() -> Self {
28        Self {
29            color: [0.2, 0.5, 0.8, 1.0],
30            radius: 0.005,
31            radius_is_relative: 1,
32            render_mode: 0, // lines by default
33            _padding: 0.0,
34        }
35    }
36}
37
38/// GPU resources for rendering a curve network.
39pub struct CurveNetworkRenderData {
40    /// Node position buffer (storage buffer, vec4 for alignment).
41    pub node_buffer: wgpu::Buffer,
42    /// Node color buffer (storage buffer, vec4).
43    pub node_color_buffer: wgpu::Buffer,
44
45    /// Edge vertex buffer - contains tail and tip positions per edge.
46    /// Layout: [tail0, tip0, tail1, tip1, ...] (vec4 each for alignment)
47    pub edge_vertex_buffer: wgpu::Buffer,
48    /// Edge color buffer (per-edge colors, vec4).
49    pub edge_color_buffer: wgpu::Buffer,
50
51    /// Uniform buffer for curve network settings.
52    pub uniform_buffer: wgpu::Buffer,
53    /// Bind group for this curve network.
54    pub bind_group: wgpu::BindGroup,
55
56    /// Number of nodes.
57    pub num_nodes: u32,
58    /// Number of edges.
59    pub num_edges: u32,
60
61    // Tube rendering resources
62    /// Generated vertex buffer from compute shader (36 vertices per edge).
63    pub generated_vertex_buffer: Option<wgpu::Buffer>,
64    /// Buffer containing `num_edges` as uniform.
65    pub num_edges_buffer: Option<wgpu::Buffer>,
66    /// Bind group for tube compute shader.
67    pub compute_bind_group: Option<wgpu::BindGroup>,
68    /// Bind group for tube render shader.
69    pub tube_render_bind_group: Option<wgpu::BindGroup>,
70
71    // Node sphere rendering resources (for tube mode joint filling)
72    /// Uniform buffer for node sphere rendering (matches `PointUniforms`).
73    pub node_uniform_buffer: Option<wgpu::Buffer>,
74    /// Bind group for node sphere rendering (uses point pipeline).
75    pub node_render_bind_group: Option<wgpu::BindGroup>,
76}
77
78impl CurveNetworkRenderData {
79    /// Creates new render data from curve network geometry.
80    ///
81    /// # Arguments
82    /// * `device` - The wgpu device
83    /// * `bind_group_layout` - The bind group layout for curve networks
84    /// * `camera_buffer` - The camera uniform buffer
85    /// * `node_positions` - Node positions
86    /// * `edge_tail_inds` - Edge start indices
87    /// * `edge_tip_inds` - Edge end indices
88    #[must_use]
89    pub fn new(
90        device: &wgpu::Device,
91        bind_group_layout: &wgpu::BindGroupLayout,
92        camera_buffer: &wgpu::Buffer,
93        node_positions: &[Vec3],
94        edge_tail_inds: &[u32],
95        edge_tip_inds: &[u32],
96    ) -> Self {
97        let num_nodes = node_positions.len() as u32;
98        let num_edges = edge_tail_inds.len() as u32;
99
100        // Create node position buffer (vec4 for alignment)
101        let node_data: Vec<f32> = node_positions
102            .iter()
103            .flat_map(|p| [p.x, p.y, p.z, 1.0])
104            .collect();
105        let node_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
106            label: Some("curve network node positions"),
107            contents: bytemuck::cast_slice(&node_data),
108            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
109        });
110
111        // Create node color buffer (default zero - shader uses base color when zero)
112        let node_color_data: Vec<f32> = vec![0.0; node_positions.len() * 4];
113        let node_color_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
114            label: Some("curve network node colors"),
115            contents: bytemuck::cast_slice(&node_color_data),
116            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
117        });
118
119        // Create edge vertex buffer - 2 vertices per edge (tail, tip)
120        let mut edge_vertex_data: Vec<f32> = Vec::with_capacity(edge_tail_inds.len() * 8);
121        for i in 0..edge_tail_inds.len() {
122            let tail = node_positions[edge_tail_inds[i] as usize];
123            let tip = node_positions[edge_tip_inds[i] as usize];
124            edge_vertex_data.extend_from_slice(&[tail.x, tail.y, tail.z, 1.0]);
125            edge_vertex_data.extend_from_slice(&[tip.x, tip.y, tip.z, 1.0]);
126        }
127        let edge_vertex_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
128            label: Some("curve network edge vertices"),
129            contents: bytemuck::cast_slice(&edge_vertex_data),
130            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
131        });
132
133        // Create edge color buffer (default zero - shader uses base color when zero)
134        let edge_color_data: Vec<f32> = vec![0.0; edge_tail_inds.len() * 4];
135        let edge_color_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
136            label: Some("curve network edge colors"),
137            contents: bytemuck::cast_slice(&edge_color_data),
138            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
139        });
140
141        // Create uniform buffer
142        let uniforms = CurveNetworkUniforms::default();
143        let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
144            label: Some("curve network uniforms"),
145            contents: bytemuck::cast_slice(&[uniforms]),
146            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
147        });
148
149        // Create bind group
150        // Bindings:
151        // 0: camera uniforms (uniform)
152        // 1: curve network uniforms (uniform)
153        // 2: node positions (storage)
154        // 3: node colors (storage)
155        // 4: edge vertices (storage)
156        // 5: edge colors (storage)
157        let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
158            label: Some("curve network bind group"),
159            layout: bind_group_layout,
160            entries: &[
161                wgpu::BindGroupEntry {
162                    binding: 0,
163                    resource: camera_buffer.as_entire_binding(),
164                },
165                wgpu::BindGroupEntry {
166                    binding: 1,
167                    resource: uniform_buffer.as_entire_binding(),
168                },
169                wgpu::BindGroupEntry {
170                    binding: 2,
171                    resource: node_buffer.as_entire_binding(),
172                },
173                wgpu::BindGroupEntry {
174                    binding: 3,
175                    resource: node_color_buffer.as_entire_binding(),
176                },
177                wgpu::BindGroupEntry {
178                    binding: 4,
179                    resource: edge_vertex_buffer.as_entire_binding(),
180                },
181                wgpu::BindGroupEntry {
182                    binding: 5,
183                    resource: edge_color_buffer.as_entire_binding(),
184                },
185            ],
186        });
187
188        Self {
189            node_buffer,
190            node_color_buffer,
191            edge_vertex_buffer,
192            edge_color_buffer,
193            uniform_buffer,
194            bind_group,
195            num_nodes,
196            num_edges,
197            generated_vertex_buffer: None,
198            num_edges_buffer: None,
199            compute_bind_group: None,
200            tube_render_bind_group: None,
201            node_uniform_buffer: None,
202            node_render_bind_group: None,
203        }
204    }
205
206    /// Initializes tube rendering resources.
207    pub fn init_tube_resources(
208        &mut self,
209        device: &wgpu::Device,
210        compute_bind_group_layout: &wgpu::BindGroupLayout,
211        render_bind_group_layout: &wgpu::BindGroupLayout,
212        camera_buffer: &wgpu::Buffer,
213    ) {
214        // Create generated vertex buffer (36 vertices per edge, 32 bytes per vertex)
215        let vertex_buffer_size = (self.num_edges as usize * 36 * 32) as u64;
216        let generated_vertex_buffer = device.create_buffer(&wgpu::BufferDescriptor {
217            label: Some("Curve Network Generated Vertices"),
218            size: vertex_buffer_size.max(32), // Minimum size
219            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::VERTEX,
220            mapped_at_creation: false,
221        });
222
223        // Create num_edges uniform buffer
224        let num_edges_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
225            label: Some("Curve Network Num Edges"),
226            contents: bytemuck::cast_slice(&[self.num_edges]),
227            usage: wgpu::BufferUsages::UNIFORM,
228        });
229
230        // Create compute bind group
231        let compute_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
232            label: Some("Curve Network Tube Compute Bind Group"),
233            layout: compute_bind_group_layout,
234            entries: &[
235                wgpu::BindGroupEntry {
236                    binding: 0,
237                    resource: self.edge_vertex_buffer.as_entire_binding(),
238                },
239                wgpu::BindGroupEntry {
240                    binding: 1,
241                    resource: self.uniform_buffer.as_entire_binding(),
242                },
243                wgpu::BindGroupEntry {
244                    binding: 2,
245                    resource: generated_vertex_buffer.as_entire_binding(),
246                },
247                wgpu::BindGroupEntry {
248                    binding: 3,
249                    resource: num_edges_buffer.as_entire_binding(),
250                },
251            ],
252        });
253
254        // Create render bind group
255        let tube_render_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
256            label: Some("Curve Network Tube Render Bind Group"),
257            layout: render_bind_group_layout,
258            entries: &[
259                wgpu::BindGroupEntry {
260                    binding: 0,
261                    resource: camera_buffer.as_entire_binding(),
262                },
263                wgpu::BindGroupEntry {
264                    binding: 1,
265                    resource: self.uniform_buffer.as_entire_binding(),
266                },
267                wgpu::BindGroupEntry {
268                    binding: 2,
269                    resource: self.edge_vertex_buffer.as_entire_binding(),
270                },
271                wgpu::BindGroupEntry {
272                    binding: 3,
273                    resource: self.edge_color_buffer.as_entire_binding(),
274                },
275            ],
276        });
277
278        self.generated_vertex_buffer = Some(generated_vertex_buffer);
279        self.num_edges_buffer = Some(num_edges_buffer);
280        self.compute_bind_group = Some(compute_bind_group);
281        self.tube_render_bind_group = Some(tube_render_bind_group);
282    }
283
284    /// Returns whether tube resources are initialized.
285    #[must_use]
286    pub fn has_tube_resources(&self) -> bool {
287        self.generated_vertex_buffer.is_some()
288    }
289
290    /// Initializes node sphere rendering resources for tube mode.
291    /// Uses the point pipeline to render spheres at each node to fill gaps at joints.
292    pub fn init_node_render_resources(
293        &mut self,
294        device: &wgpu::Device,
295        point_bind_group_layout: &wgpu::BindGroupLayout,
296        camera_buffer: &wgpu::Buffer,
297    ) {
298        // Create uniform buffer matching PointUniforms structure
299        let uniforms = PointUniforms::default();
300        let node_uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
301            label: Some("Curve Network Node Uniforms"),
302            contents: bytemuck::cast_slice(&[uniforms]),
303            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
304        });
305
306        // Create bind group matching point pipeline layout
307        let node_render_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
308            label: Some("Curve Network Node Render Bind Group"),
309            layout: point_bind_group_layout,
310            entries: &[
311                wgpu::BindGroupEntry {
312                    binding: 0,
313                    resource: camera_buffer.as_entire_binding(),
314                },
315                wgpu::BindGroupEntry {
316                    binding: 1,
317                    resource: node_uniform_buffer.as_entire_binding(),
318                },
319                wgpu::BindGroupEntry {
320                    binding: 2,
321                    resource: self.node_buffer.as_entire_binding(),
322                },
323                wgpu::BindGroupEntry {
324                    binding: 3,
325                    resource: self.node_color_buffer.as_entire_binding(),
326                },
327            ],
328        });
329
330        self.node_uniform_buffer = Some(node_uniform_buffer);
331        self.node_render_bind_group = Some(node_render_bind_group);
332    }
333
334    /// Returns whether node render resources are initialized.
335    #[must_use]
336    pub fn has_node_render_resources(&self) -> bool {
337        self.node_render_bind_group.is_some()
338    }
339
340    /// Updates node sphere uniforms (radius, color, etc.).
341    pub fn update_node_uniforms(&self, queue: &wgpu::Queue, uniforms: &PointUniforms) {
342        if let Some(buffer) = &self.node_uniform_buffer {
343            queue.write_buffer(buffer, 0, bytemuck::cast_slice(&[*uniforms]));
344        }
345    }
346
347    /// Updates the uniform buffer.
348    pub fn update_uniforms(&self, queue: &wgpu::Queue, uniforms: &CurveNetworkUniforms) {
349        queue.write_buffer(&self.uniform_buffer, 0, bytemuck::cast_slice(&[*uniforms]));
350    }
351
352    /// Updates node colors.
353    pub fn update_node_colors(&self, queue: &wgpu::Queue, colors: &[Vec4]) {
354        let color_data: Vec<f32> = colors.iter().flat_map(glam::Vec4::to_array).collect();
355        queue.write_buffer(
356            &self.node_color_buffer,
357            0,
358            bytemuck::cast_slice(&color_data),
359        );
360    }
361
362    /// Updates edge colors.
363    pub fn update_edge_colors(&self, queue: &wgpu::Queue, colors: &[Vec4]) {
364        let color_data: Vec<f32> = colors.iter().flat_map(glam::Vec4::to_array).collect();
365        queue.write_buffer(
366            &self.edge_color_buffer,
367            0,
368            bytemuck::cast_slice(&color_data),
369        );
370    }
371
372    /// Updates node positions.
373    pub fn update_node_positions(&self, queue: &wgpu::Queue, positions: &[Vec3]) {
374        let pos_data: Vec<f32> = positions
375            .iter()
376            .flat_map(|p| [p.x, p.y, p.z, 1.0])
377            .collect();
378        queue.write_buffer(&self.node_buffer, 0, bytemuck::cast_slice(&pos_data));
379    }
380
381    /// Updates edge vertices (when node positions change).
382    pub fn update_edge_vertices(
383        &self,
384        queue: &wgpu::Queue,
385        node_positions: &[Vec3],
386        edge_tail_inds: &[u32],
387        edge_tip_inds: &[u32],
388    ) {
389        let mut edge_vertex_data: Vec<f32> = Vec::with_capacity(edge_tail_inds.len() * 8);
390        for i in 0..edge_tail_inds.len() {
391            let tail = node_positions[edge_tail_inds[i] as usize];
392            let tip = node_positions[edge_tip_inds[i] as usize];
393            edge_vertex_data.extend_from_slice(&[tail.x, tail.y, tail.z, 1.0]);
394            edge_vertex_data.extend_from_slice(&[tip.x, tip.y, tip.z, 1.0]);
395        }
396        queue.write_buffer(
397            &self.edge_vertex_buffer,
398            0,
399            bytemuck::cast_slice(&edge_vertex_data),
400        );
401    }
402}
403
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_curve_network_uniforms_size() {
        let size = std::mem::size_of::<CurveNetworkUniforms>();

        // Should be 32 bytes:
        // color: 16 bytes ([f32; 4])
        // radius: 4 bytes (f32)
        // radius_is_relative: 4 bytes (u32)
        // render_mode: 4 bytes (u32)
        // _padding: 4 bytes (f32)
        // Total: 32 bytes
        assert_eq!(size, 32, "CurveNetworkUniforms should be 32 bytes");

        // Must be 16-byte aligned for GPU uniform buffers
        assert_eq!(size % 16, 0, "CurveNetworkUniforms must be 16-byte aligned");
    }

    /// The size check alone would not catch two same-sized fields being swapped; this pins
    /// every field to its `repr(C)` byte offset by reinterpreting the struct as eight words.
    #[test]
    fn test_curve_network_uniforms_field_offsets() {
        // Distinct values in every slot so a reordered field shows up as a wrong word.
        let uniforms = CurveNetworkUniforms {
            color: [1.0, 2.0, 3.0, 4.0],
            radius: 5.0,
            radius_is_relative: 6,
            render_mode: 7,
            _padding: 8.0,
        };
        let words: [u32; 8] = bytemuck::cast(uniforms);

        for (i, component) in [1.0f32, 2.0, 3.0, 4.0].iter().enumerate() {
            assert_eq!(
                words[i],
                component.to_bits(),
                "color[{i}] must be at byte offset {}",
                i * 4
            );
        }
        assert_eq!(words[4], 5.0f32.to_bits(), "radius must be at byte offset 16");
        assert_eq!(words[5], 6, "radius_is_relative must be at byte offset 20");
        assert_eq!(words[6], 7, "render_mode must be at byte offset 24");
        assert_eq!(words[7], 8.0f32.to_bits(), "_padding must be at byte offset 28");
    }
}