1use glam::{Vec3, Vec4};
4use wgpu::util::DeviceExt;
5
6use crate::point_cloud_render::PointUniforms;
7
8#[repr(C)]
11#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
12#[allow(clippy::pub_underscore_fields)]
13pub struct CurveNetworkUniforms {
14 pub color: [f32; 4],
16 pub radius: f32,
18 pub radius_is_relative: u32,
20 pub render_mode: u32,
22 pub _padding: f32,
24}
25
26impl Default for CurveNetworkUniforms {
27 fn default() -> Self {
28 Self {
29 color: [0.2, 0.5, 0.8, 1.0],
30 radius: 0.005,
31 radius_is_relative: 1,
32 render_mode: 0, _padding: 0.0,
34 }
35 }
36}
37
/// GPU resources for one curve network (nodes connected by edges).
///
/// The core buffers and `bind_group` are created by `new`. The tube-rendering and
/// node-rendering resources are optional and filled in later by `init_tube_resources` and
/// `init_node_render_resources` respectively.
pub struct CurveNetworkRenderData {
    /// Node positions, one `vec4` per node (`xyz` position, `w = 1`).
    pub node_buffer: wgpu::Buffer,
    /// Node colors, one RGBA `vec4` per node (zero-initialised in `new`).
    pub node_color_buffer: wgpu::Buffer,

    /// Edge endpoints, two `vec4`s per edge (tail then tip, each with `w = 1`).
    pub edge_vertex_buffer: wgpu::Buffer,
    /// Edge colors, one RGBA `vec4` per edge (zero-initialised in `new`).
    pub edge_color_buffer: wgpu::Buffer,

    /// Holds a `CurveNetworkUniforms` value.
    pub uniform_buffer: wgpu::Buffer,
    /// Main bind group: camera (0), uniforms (1), node positions (2), node colors (3),
    /// edge vertices (4), edge colors (5).
    pub bind_group: wgpu::BindGroup,

    /// Number of nodes uploaded in `new`.
    pub num_nodes: u32,
    /// Number of edges uploaded in `new`.
    pub num_edges: u32,

    /// Output of the tube compute pass (36 vertices x 32 bytes per edge); usable as both
    /// a storage and a vertex buffer. `None` until `init_tube_resources` runs.
    pub generated_vertex_buffer: Option<wgpu::Buffer>,
    /// Uniform buffer holding `num_edges` as a single `u32`, for the tube compute pass.
    pub num_edges_buffer: Option<wgpu::Buffer>,
    /// Bind group for the tube compute pass.
    pub compute_bind_group: Option<wgpu::BindGroup>,
    /// Bind group for rendering tubes.
    pub tube_render_bind_group: Option<wgpu::BindGroup>,

    /// `PointUniforms` buffer used when drawing nodes with the point pipeline's layout.
    pub node_uniform_buffer: Option<wgpu::Buffer>,
    /// Bind group for drawing nodes; `None` until `init_node_render_resources` runs.
    pub node_render_bind_group: Option<wgpu::BindGroup>,
}
77
78impl CurveNetworkRenderData {
79 #[must_use]
89 pub fn new(
90 device: &wgpu::Device,
91 bind_group_layout: &wgpu::BindGroupLayout,
92 camera_buffer: &wgpu::Buffer,
93 node_positions: &[Vec3],
94 edge_tail_inds: &[u32],
95 edge_tip_inds: &[u32],
96 ) -> Self {
97 let num_nodes = node_positions.len() as u32;
98 let num_edges = edge_tail_inds.len() as u32;
99
100 let node_data: Vec<f32> = node_positions
102 .iter()
103 .flat_map(|p| [p.x, p.y, p.z, 1.0])
104 .collect();
105 let node_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
106 label: Some("curve network node positions"),
107 contents: bytemuck::cast_slice(&node_data),
108 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
109 });
110
111 let node_color_data: Vec<f32> = vec![0.0; node_positions.len() * 4];
113 let node_color_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
114 label: Some("curve network node colors"),
115 contents: bytemuck::cast_slice(&node_color_data),
116 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
117 });
118
119 let mut edge_vertex_data: Vec<f32> = Vec::with_capacity(edge_tail_inds.len() * 8);
121 for i in 0..edge_tail_inds.len() {
122 let tail = node_positions[edge_tail_inds[i] as usize];
123 let tip = node_positions[edge_tip_inds[i] as usize];
124 edge_vertex_data.extend_from_slice(&[tail.x, tail.y, tail.z, 1.0]);
125 edge_vertex_data.extend_from_slice(&[tip.x, tip.y, tip.z, 1.0]);
126 }
127 let edge_vertex_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
128 label: Some("curve network edge vertices"),
129 contents: bytemuck::cast_slice(&edge_vertex_data),
130 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
131 });
132
133 let edge_color_data: Vec<f32> = vec![0.0; edge_tail_inds.len() * 4];
135 let edge_color_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
136 label: Some("curve network edge colors"),
137 contents: bytemuck::cast_slice(&edge_color_data),
138 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
139 });
140
141 let uniforms = CurveNetworkUniforms::default();
143 let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
144 label: Some("curve network uniforms"),
145 contents: bytemuck::cast_slice(&[uniforms]),
146 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
147 });
148
149 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
158 label: Some("curve network bind group"),
159 layout: bind_group_layout,
160 entries: &[
161 wgpu::BindGroupEntry {
162 binding: 0,
163 resource: camera_buffer.as_entire_binding(),
164 },
165 wgpu::BindGroupEntry {
166 binding: 1,
167 resource: uniform_buffer.as_entire_binding(),
168 },
169 wgpu::BindGroupEntry {
170 binding: 2,
171 resource: node_buffer.as_entire_binding(),
172 },
173 wgpu::BindGroupEntry {
174 binding: 3,
175 resource: node_color_buffer.as_entire_binding(),
176 },
177 wgpu::BindGroupEntry {
178 binding: 4,
179 resource: edge_vertex_buffer.as_entire_binding(),
180 },
181 wgpu::BindGroupEntry {
182 binding: 5,
183 resource: edge_color_buffer.as_entire_binding(),
184 },
185 ],
186 });
187
188 Self {
189 node_buffer,
190 node_color_buffer,
191 edge_vertex_buffer,
192 edge_color_buffer,
193 uniform_buffer,
194 bind_group,
195 num_nodes,
196 num_edges,
197 generated_vertex_buffer: None,
198 num_edges_buffer: None,
199 compute_bind_group: None,
200 tube_render_bind_group: None,
201 node_uniform_buffer: None,
202 node_render_bind_group: None,
203 }
204 }
205
206 pub fn init_tube_resources(
208 &mut self,
209 device: &wgpu::Device,
210 compute_bind_group_layout: &wgpu::BindGroupLayout,
211 render_bind_group_layout: &wgpu::BindGroupLayout,
212 camera_buffer: &wgpu::Buffer,
213 ) {
214 let vertex_buffer_size = (self.num_edges as usize * 36 * 32) as u64;
216 let generated_vertex_buffer = device.create_buffer(&wgpu::BufferDescriptor {
217 label: Some("Curve Network Generated Vertices"),
218 size: vertex_buffer_size.max(32), usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::VERTEX,
220 mapped_at_creation: false,
221 });
222
223 let num_edges_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
225 label: Some("Curve Network Num Edges"),
226 contents: bytemuck::cast_slice(&[self.num_edges]),
227 usage: wgpu::BufferUsages::UNIFORM,
228 });
229
230 let compute_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
232 label: Some("Curve Network Tube Compute Bind Group"),
233 layout: compute_bind_group_layout,
234 entries: &[
235 wgpu::BindGroupEntry {
236 binding: 0,
237 resource: self.edge_vertex_buffer.as_entire_binding(),
238 },
239 wgpu::BindGroupEntry {
240 binding: 1,
241 resource: self.uniform_buffer.as_entire_binding(),
242 },
243 wgpu::BindGroupEntry {
244 binding: 2,
245 resource: generated_vertex_buffer.as_entire_binding(),
246 },
247 wgpu::BindGroupEntry {
248 binding: 3,
249 resource: num_edges_buffer.as_entire_binding(),
250 },
251 ],
252 });
253
254 let tube_render_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
256 label: Some("Curve Network Tube Render Bind Group"),
257 layout: render_bind_group_layout,
258 entries: &[
259 wgpu::BindGroupEntry {
260 binding: 0,
261 resource: camera_buffer.as_entire_binding(),
262 },
263 wgpu::BindGroupEntry {
264 binding: 1,
265 resource: self.uniform_buffer.as_entire_binding(),
266 },
267 wgpu::BindGroupEntry {
268 binding: 2,
269 resource: self.edge_vertex_buffer.as_entire_binding(),
270 },
271 wgpu::BindGroupEntry {
272 binding: 3,
273 resource: self.edge_color_buffer.as_entire_binding(),
274 },
275 ],
276 });
277
278 self.generated_vertex_buffer = Some(generated_vertex_buffer);
279 self.num_edges_buffer = Some(num_edges_buffer);
280 self.compute_bind_group = Some(compute_bind_group);
281 self.tube_render_bind_group = Some(tube_render_bind_group);
282 }
283
284 #[must_use]
286 pub fn has_tube_resources(&self) -> bool {
287 self.generated_vertex_buffer.is_some()
288 }
289
290 pub fn init_node_render_resources(
293 &mut self,
294 device: &wgpu::Device,
295 point_bind_group_layout: &wgpu::BindGroupLayout,
296 camera_buffer: &wgpu::Buffer,
297 ) {
298 let uniforms = PointUniforms::default();
300 let node_uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
301 label: Some("Curve Network Node Uniforms"),
302 contents: bytemuck::cast_slice(&[uniforms]),
303 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
304 });
305
306 let node_render_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
308 label: Some("Curve Network Node Render Bind Group"),
309 layout: point_bind_group_layout,
310 entries: &[
311 wgpu::BindGroupEntry {
312 binding: 0,
313 resource: camera_buffer.as_entire_binding(),
314 },
315 wgpu::BindGroupEntry {
316 binding: 1,
317 resource: node_uniform_buffer.as_entire_binding(),
318 },
319 wgpu::BindGroupEntry {
320 binding: 2,
321 resource: self.node_buffer.as_entire_binding(),
322 },
323 wgpu::BindGroupEntry {
324 binding: 3,
325 resource: self.node_color_buffer.as_entire_binding(),
326 },
327 ],
328 });
329
330 self.node_uniform_buffer = Some(node_uniform_buffer);
331 self.node_render_bind_group = Some(node_render_bind_group);
332 }
333
334 #[must_use]
336 pub fn has_node_render_resources(&self) -> bool {
337 self.node_render_bind_group.is_some()
338 }
339
340 pub fn update_node_uniforms(&self, queue: &wgpu::Queue, uniforms: &PointUniforms) {
342 if let Some(buffer) = &self.node_uniform_buffer {
343 queue.write_buffer(buffer, 0, bytemuck::cast_slice(&[*uniforms]));
344 }
345 }
346
347 pub fn update_uniforms(&self, queue: &wgpu::Queue, uniforms: &CurveNetworkUniforms) {
349 queue.write_buffer(&self.uniform_buffer, 0, bytemuck::cast_slice(&[*uniforms]));
350 }
351
352 pub fn update_node_colors(&self, queue: &wgpu::Queue, colors: &[Vec4]) {
354 let color_data: Vec<f32> = colors.iter().flat_map(glam::Vec4::to_array).collect();
355 queue.write_buffer(
356 &self.node_color_buffer,
357 0,
358 bytemuck::cast_slice(&color_data),
359 );
360 }
361
362 pub fn update_edge_colors(&self, queue: &wgpu::Queue, colors: &[Vec4]) {
364 let color_data: Vec<f32> = colors.iter().flat_map(glam::Vec4::to_array).collect();
365 queue.write_buffer(
366 &self.edge_color_buffer,
367 0,
368 bytemuck::cast_slice(&color_data),
369 );
370 }
371
372 pub fn update_node_positions(&self, queue: &wgpu::Queue, positions: &[Vec3]) {
374 let pos_data: Vec<f32> = positions
375 .iter()
376 .flat_map(|p| [p.x, p.y, p.z, 1.0])
377 .collect();
378 queue.write_buffer(&self.node_buffer, 0, bytemuck::cast_slice(&pos_data));
379 }
380
381 pub fn update_edge_vertices(
383 &self,
384 queue: &wgpu::Queue,
385 node_positions: &[Vec3],
386 edge_tail_inds: &[u32],
387 edge_tip_inds: &[u32],
388 ) {
389 let mut edge_vertex_data: Vec<f32> = Vec::with_capacity(edge_tail_inds.len() * 8);
390 for i in 0..edge_tail_inds.len() {
391 let tail = node_positions[edge_tail_inds[i] as usize];
392 let tip = node_positions[edge_tip_inds[i] as usize];
393 edge_vertex_data.extend_from_slice(&[tail.x, tail.y, tail.z, 1.0]);
394 edge_vertex_data.extend_from_slice(&[tip.x, tip.y, tip.z, 1.0]);
395 }
396 queue.write_buffer(
397 &self.edge_vertex_buffer,
398 0,
399 bytemuck::cast_slice(&edge_vertex_data),
400 );
401 }
402}
403
#[cfg(test)]
mod tests {
    use super::*;
    use glam::Vec3;

    #[test]
    fn test_curve_network_uniforms_size() {
        let size = std::mem::size_of::<CurveNetworkUniforms>();

        assert_eq!(size, 32, "CurveNetworkUniforms should be 32 bytes");

        assert_eq!(size % 16, 0, "CurveNetworkUniforms must be 16-byte aligned");
    }

    /// CPU reference for intersecting a ray with the end caps of a finite cylinder, for the
    /// case where the ray is parallel (or anti-parallel) to the cylinder axis.
    ///
    /// Returns `None` if:
    /// * the ray origin lies farther than `cyl_radius` from the axis line, or
    /// * the ray is perpendicular to the axis (`|ray_dir . axis| < 1e-8`), or
    /// * both cap planes lie behind `t = 0.001`.
    ///
    /// Otherwise it returns the nearest cap-plane hit with `t >= 0.001`, together with the
    /// hit point `ray_origin + t * ray_dir`.
    fn ray_cylinder_parallel_intersect(
        ray_origin: Vec3,
        ray_dir: Vec3,
        cyl_start: Vec3,
        cyl_end: Vec3,
        cyl_radius: f32,
    ) -> Option<(f32, Vec3)> {
        let cyl_axis = cyl_end - cyl_start;
        let cyl_dir = cyl_axis.normalize();
        let delta = ray_origin - cyl_start;
        // Component of the origin offset perpendicular to the axis; for a parallel ray this
        // distance stays constant along the whole ray.
        let delta_perp = delta - cyl_dir.dot(delta) * cyl_dir;

        if delta_perp.length_squared() > cyl_radius * cyl_radius {
            return None;
        }
        let ray_dot_cyl = ray_dir.dot(cyl_dir);
        if ray_dot_cyl.abs() < 1e-8 {
            return None;
        }
        // Ray parameters at which the ray crosses the start and end cap planes.
        let t_start = (cyl_start - ray_origin).dot(cyl_dir) / ray_dot_cyl;
        let t_end = (cyl_end - ray_origin).dot(cyl_dir) / ray_dot_cyl;
        // Prefer the nearer cap; fall back to the farther one if the nearer is behind us.
        let mut t_cap = t_start.min(t_end);
        if t_cap < 0.001 {
            t_cap = t_start.max(t_end);
            if t_cap < 0.001 {
                return None;
            }
        }
        Some((t_cap, ray_origin + t_cap * ray_dir))
    }

    #[test]
    fn parallel_ray_through_axis_hits_front_cap() {
        let cyl_start = Vec3::new(0.0, 0.0, 0.0);
        let cyl_end = Vec3::new(0.0, 0.0, 5.0);
        let radius = 0.1_f32;
        let ray_dir = Vec3::new(0.0, 0.0, -1.0);
        let world_position = Vec3::new(0.0, 0.0, 5.5);
        let extent = (cyl_end - cyl_start).length() + 2.0 * radius;
        let ray_origin = world_position - extent * ray_dir;

        let hit = ray_cylinder_parallel_intersect(ray_origin, ray_dir, cyl_start, cyl_end, radius);
        let (t, p) = hit.expect("parallel ray through axis should hit cylinder cap");
        assert!(t > 0.001, "t must be positive, got {t}");
        assert!(
            (p.z - cyl_end.z).abs() < 1e-4,
            "expected hit at z={}, got {p:?}",
            cyl_end.z
        );
    }

    #[test]
    fn parallel_ray_offset_within_radius_hits() {
        let cyl_start = Vec3::ZERO;
        let cyl_end = Vec3::new(0.0, 0.0, 5.0);
        let radius = 0.1_f32;
        let ray_dir = Vec3::new(0.0, 0.0, -1.0);
        let world_position = Vec3::new(0.05, 0.0, 5.5);
        let extent = (cyl_end - cyl_start).length() + 2.0 * radius;
        let ray_origin = world_position - extent * ray_dir;

        let hit = ray_cylinder_parallel_intersect(ray_origin, ray_dir, cyl_start, cyl_end, radius);
        assert!(hit.is_some(), "ray within radius should hit cap");
    }

    #[test]
    fn parallel_ray_offset_beyond_radius_misses() {
        let cyl_start = Vec3::ZERO;
        let cyl_end = Vec3::new(0.0, 0.0, 5.0);
        let radius = 0.1_f32;
        let ray_dir = Vec3::new(0.0, 0.0, -1.0);
        let world_position = Vec3::new(0.5, 0.0, 5.5);
        let ray_origin = world_position - 10.0 * ray_dir;

        let hit = ray_cylinder_parallel_intersect(ray_origin, ray_dir, cyl_start, cyl_end, radius);
        assert!(hit.is_none(), "ray outside radius must miss");
    }

    #[test]
    fn parallel_ray_reverse_direction_hits_other_cap() {
        let cyl_start = Vec3::new(0.0, 0.0, 0.0);
        let cyl_end = Vec3::new(0.0, 0.0, 5.0);
        let radius = 0.1_f32;
        let ray_dir = Vec3::new(0.0, 0.0, 1.0);
        let world_position = Vec3::new(0.0, 0.0, -0.5);
        let extent = (cyl_end - cyl_start).length() + 2.0 * radius;
        let ray_origin = world_position - extent * ray_dir;

        let (t, p) =
            ray_cylinder_parallel_intersect(ray_origin, ray_dir, cyl_start, cyl_end, radius)
                .expect("reverse-direction parallel ray should hit");
        assert!(t > 0.001);
        assert!(
            (p.z - cyl_start.z).abs() < 1e-4,
            "expected hit at z={}, got {p:?}",
            cyl_start.z
        );
    }
}