use glam::{Vec3, Vec4};
use wgpu::util::DeviceExt;
use crate::point_cloud_render::PointUniforms;
/// Per-curve-network uniform block uploaded to the GPU (bound at binding 1 of
/// the curve network bind groups).
///
/// The layout is `#[repr(C)]` and must match the shader-side declaration; the
/// unit tests assert it is exactly 32 bytes (a multiple of 16).
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
// `_padding` is public, presumably so the struct can be built with a literal
// outside this module; that is what trips `pub_underscore_fields`.
#[allow(clippy::pub_underscore_fields)]
pub struct CurveNetworkUniforms {
    /// Base color, four floats (presumably RGBA — confirm against the shader).
    pub color: [f32; 4],
    /// Curve/tube radius; how it is scaled presumably depends on
    /// `radius_is_relative` (NOTE(review): confirm in the shader).
    pub radius: f32,
    /// Flag stored as `u32` (nonzero = set), likely because WGSL `bool` is not
    /// host-shareable. `Default` sets it to 1.
    pub radius_is_relative: u32,
    /// Render mode selector; the meaning of each value lives in the shader,
    /// which is not visible here. `Default` sets it to 0.
    pub render_mode: u32,
    /// Explicit padding that brings the struct to 32 bytes.
    pub _padding: f32,
}
impl Default for CurveNetworkUniforms {
    /// Returns the initial uniforms: color `[0.2, 0.5, 0.8, 1.0]`, radius
    /// `0.005` with `radius_is_relative` set to 1, render mode 0, and zeroed
    /// padding.
    fn default() -> Self {
        const BASE_COLOR: [f32; 4] = [0.2, 0.5, 0.8, 1.0];
        const BASE_RADIUS: f32 = 0.005;

        Self {
            color: BASE_COLOR,
            radius: BASE_RADIUS,
            radius_is_relative: u32::from(true),
            render_mode: 0,
            _padding: 0.0,
        }
    }
}
/// GPU resources for drawing one curve network (nodes plus edges).
///
/// The core buffers and `bind_group` are created up front. The `Option`
/// fields stay `None` until `init_tube_resources` /
/// `init_node_render_resources` populate them.
pub struct CurveNetworkRenderData {
    /// Node positions, packed as four `f32` per node (`[x, y, z, 1.0]`).
    pub node_buffer: wgpu::Buffer,
    /// Per-node colors, four `f32` per node; zero-filled until updated.
    pub node_color_buffer: wgpu::Buffer,
    /// Edge endpoints: tail then tip, each packed as `[x, y, z, 1.0]`.
    pub edge_vertex_buffer: wgpu::Buffer,
    /// Per-edge colors, four `f32` per edge; zero-filled until updated.
    pub edge_color_buffer: wgpu::Buffer,
    /// Holds a `CurveNetworkUniforms` value.
    pub uniform_buffer: wgpu::Buffer,
    /// Main bind group: camera (0), uniforms (1), node positions (2),
    /// node colors (3), edge vertices (4), edge colors (5).
    pub bind_group: wgpu::BindGroup,
    /// Number of nodes supplied at construction.
    pub num_nodes: u32,
    /// Number of edges supplied at construction.
    pub num_edges: u32,
    /// Output of the tube compute pass, also usable as a vertex buffer.
    pub generated_vertex_buffer: Option<wgpu::Buffer>,
    /// Uniform buffer holding `num_edges` for the tube compute pass.
    pub num_edges_buffer: Option<wgpu::Buffer>,
    /// Bind group for the tube compute pass.
    pub compute_bind_group: Option<wgpu::BindGroup>,
    /// Bind group for the tube render pass.
    pub tube_render_bind_group: Option<wgpu::BindGroup>,
    /// `PointUniforms` used when drawing nodes with the point bind group layout.
    pub node_uniform_buffer: Option<wgpu::Buffer>,
    /// Bind group for drawing nodes with the point bind group layout.
    pub node_render_bind_group: Option<wgpu::BindGroup>,
}
impl CurveNetworkRenderData {
#[must_use]
pub fn new(
device: &wgpu::Device,
bind_group_layout: &wgpu::BindGroupLayout,
camera_buffer: &wgpu::Buffer,
node_positions: &[Vec3],
edge_tail_inds: &[u32],
edge_tip_inds: &[u32],
) -> Self {
let num_nodes = node_positions.len() as u32;
let num_edges = edge_tail_inds.len() as u32;
let node_data: Vec<f32> = node_positions
.iter()
.flat_map(|p| [p.x, p.y, p.z, 1.0])
.collect();
let node_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("curve network node positions"),
contents: bytemuck::cast_slice(&node_data),
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
});
let node_color_data: Vec<f32> = vec![0.0; node_positions.len() * 4];
let node_color_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("curve network node colors"),
contents: bytemuck::cast_slice(&node_color_data),
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
});
let mut edge_vertex_data: Vec<f32> = Vec::with_capacity(edge_tail_inds.len() * 8);
for i in 0..edge_tail_inds.len() {
let tail = node_positions[edge_tail_inds[i] as usize];
let tip = node_positions[edge_tip_inds[i] as usize];
edge_vertex_data.extend_from_slice(&[tail.x, tail.y, tail.z, 1.0]);
edge_vertex_data.extend_from_slice(&[tip.x, tip.y, tip.z, 1.0]);
}
let edge_vertex_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("curve network edge vertices"),
contents: bytemuck::cast_slice(&edge_vertex_data),
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
});
let edge_color_data: Vec<f32> = vec![0.0; edge_tail_inds.len() * 4];
let edge_color_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("curve network edge colors"),
contents: bytemuck::cast_slice(&edge_color_data),
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
});
let uniforms = CurveNetworkUniforms::default();
let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("curve network uniforms"),
contents: bytemuck::cast_slice(&[uniforms]),
usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
});
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("curve network bind group"),
layout: bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: camera_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: uniform_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: node_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 3,
resource: node_color_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 4,
resource: edge_vertex_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 5,
resource: edge_color_buffer.as_entire_binding(),
},
],
});
Self {
node_buffer,
node_color_buffer,
edge_vertex_buffer,
edge_color_buffer,
uniform_buffer,
bind_group,
num_nodes,
num_edges,
generated_vertex_buffer: None,
num_edges_buffer: None,
compute_bind_group: None,
tube_render_bind_group: None,
node_uniform_buffer: None,
node_render_bind_group: None,
}
}
pub fn init_tube_resources(
&mut self,
device: &wgpu::Device,
compute_bind_group_layout: &wgpu::BindGroupLayout,
render_bind_group_layout: &wgpu::BindGroupLayout,
camera_buffer: &wgpu::Buffer,
) {
let vertex_buffer_size = (self.num_edges as usize * 36 * 32) as u64;
let generated_vertex_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Curve Network Generated Vertices"),
size: vertex_buffer_size.max(32), usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::VERTEX,
mapped_at_creation: false,
});
let num_edges_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Curve Network Num Edges"),
contents: bytemuck::cast_slice(&[self.num_edges]),
usage: wgpu::BufferUsages::UNIFORM,
});
let compute_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("Curve Network Tube Compute Bind Group"),
layout: compute_bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: self.edge_vertex_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: self.uniform_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: generated_vertex_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 3,
resource: num_edges_buffer.as_entire_binding(),
},
],
});
let tube_render_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("Curve Network Tube Render Bind Group"),
layout: render_bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: camera_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: self.uniform_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: self.edge_vertex_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 3,
resource: self.edge_color_buffer.as_entire_binding(),
},
],
});
self.generated_vertex_buffer = Some(generated_vertex_buffer);
self.num_edges_buffer = Some(num_edges_buffer);
self.compute_bind_group = Some(compute_bind_group);
self.tube_render_bind_group = Some(tube_render_bind_group);
}
#[must_use]
pub fn has_tube_resources(&self) -> bool {
self.generated_vertex_buffer.is_some()
}
pub fn init_node_render_resources(
&mut self,
device: &wgpu::Device,
point_bind_group_layout: &wgpu::BindGroupLayout,
camera_buffer: &wgpu::Buffer,
) {
let uniforms = PointUniforms::default();
let node_uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Curve Network Node Uniforms"),
contents: bytemuck::cast_slice(&[uniforms]),
usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
});
let node_render_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("Curve Network Node Render Bind Group"),
layout: point_bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: camera_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: node_uniform_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: self.node_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 3,
resource: self.node_color_buffer.as_entire_binding(),
},
],
});
self.node_uniform_buffer = Some(node_uniform_buffer);
self.node_render_bind_group = Some(node_render_bind_group);
}
#[must_use]
pub fn has_node_render_resources(&self) -> bool {
self.node_render_bind_group.is_some()
}
pub fn update_node_uniforms(&self, queue: &wgpu::Queue, uniforms: &PointUniforms) {
if let Some(buffer) = &self.node_uniform_buffer {
queue.write_buffer(buffer, 0, bytemuck::cast_slice(&[*uniforms]));
}
}
pub fn update_uniforms(&self, queue: &wgpu::Queue, uniforms: &CurveNetworkUniforms) {
queue.write_buffer(&self.uniform_buffer, 0, bytemuck::cast_slice(&[*uniforms]));
}
pub fn update_node_colors(&self, queue: &wgpu::Queue, colors: &[Vec4]) {
let color_data: Vec<f32> = colors.iter().flat_map(glam::Vec4::to_array).collect();
queue.write_buffer(
&self.node_color_buffer,
0,
bytemuck::cast_slice(&color_data),
);
}
pub fn update_edge_colors(&self, queue: &wgpu::Queue, colors: &[Vec4]) {
let color_data: Vec<f32> = colors.iter().flat_map(glam::Vec4::to_array).collect();
queue.write_buffer(
&self.edge_color_buffer,
0,
bytemuck::cast_slice(&color_data),
);
}
pub fn update_node_positions(&self, queue: &wgpu::Queue, positions: &[Vec3]) {
let pos_data: Vec<f32> = positions
.iter()
.flat_map(|p| [p.x, p.y, p.z, 1.0])
.collect();
queue.write_buffer(&self.node_buffer, 0, bytemuck::cast_slice(&pos_data));
}
pub fn update_edge_vertices(
&self,
queue: &wgpu::Queue,
node_positions: &[Vec3],
edge_tail_inds: &[u32],
edge_tip_inds: &[u32],
) {
let mut edge_vertex_data: Vec<f32> = Vec::with_capacity(edge_tail_inds.len() * 8);
for i in 0..edge_tail_inds.len() {
let tail = node_positions[edge_tail_inds[i] as usize];
let tip = node_positions[edge_tip_inds[i] as usize];
edge_vertex_data.extend_from_slice(&[tail.x, tail.y, tail.z, 1.0]);
edge_vertex_data.extend_from_slice(&[tip.x, tip.y, tip.z, 1.0]);
}
queue.write_buffer(
&self.edge_vertex_buffer,
0,
bytemuck::cast_slice(&edge_vertex_data),
);
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use glam::Vec3;

    /// Ray parameters below this value are not accepted as cap hits.
    const T_MIN: f32 = 0.001;

    #[test]
    fn test_curve_network_uniforms_size() {
        let bytes = std::mem::size_of::<CurveNetworkUniforms>();
        assert_eq!(bytes, 32, "CurveNetworkUniforms should be 32 bytes");
        assert_eq!(bytes % 16, 0, "CurveNetworkUniforms must be 16-byte aligned");
    }

    /// CPU reference for intersecting a ray with a cylinder's end-cap planes.
    ///
    /// Returns `None` in any of these cases:
    /// - the ray origin is farther than `cyl_radius` from the axis line;
    /// - the ray is (numerically) perpendicular to the axis;
    /// - both cap planes are crossed at `t < T_MIN`.
    ///
    /// Otherwise returns the smaller cap-plane parameter that is at least
    /// `T_MIN`, together with the hit point.
    fn ray_cylinder_parallel_intersect(
        ray_origin: Vec3,
        ray_dir: Vec3,
        cyl_start: Vec3,
        cyl_end: Vec3,
        cyl_radius: f32,
    ) -> Option<(f32, Vec3)> {
        let axis = (cyl_end - cyl_start).normalize();
        let offset = ray_origin - cyl_start;
        let radial = offset - axis.dot(offset) * axis;
        if radial.length_squared() > cyl_radius * cyl_radius {
            return None;
        }

        let cos_angle = ray_dir.dot(axis);
        if cos_angle.abs() < 1e-8 {
            return None;
        }

        // Ray parameter at which the ray crosses the cap plane through `point`.
        let plane_t = |point: Vec3| (point - ray_origin).dot(axis) / cos_angle;
        let (t_a, t_b) = (plane_t(cyl_start), plane_t(cyl_end));
        let (near, far) = (t_a.min(t_b), t_a.max(t_b));

        let t_cap = if near < T_MIN {
            if far < T_MIN {
                return None;
            }
            far
        } else {
            near
        };
        Some((t_cap, ray_origin + t_cap * ray_dir))
    }

    /// Cylinder along +Z from the origin to `z = 5` with radius `0.1`.
    fn z_cylinder() -> (Vec3, Vec3, f32) {
        (Vec3::ZERO, Vec3::new(0.0, 0.0, 5.0), 0.1)
    }

    /// Steps back from `world_position` against `ray_dir` by the cylinder
    /// length plus one diameter.
    fn origin_behind(
        world_position: Vec3,
        ray_dir: Vec3,
        cyl_start: Vec3,
        cyl_end: Vec3,
        radius: f32,
    ) -> Vec3 {
        let extent = (cyl_end - cyl_start).length() + 2.0 * radius;
        world_position - extent * ray_dir
    }

    #[test]
    fn parallel_ray_through_axis_hits_front_cap() {
        let (start, end, radius) = z_cylinder();
        let dir = Vec3::new(0.0, 0.0, -1.0);
        let origin = origin_behind(Vec3::new(0.0, 0.0, 5.5), dir, start, end, radius);

        let (t, p) = ray_cylinder_parallel_intersect(origin, dir, start, end, radius)
            .expect("parallel ray through axis should hit cylinder cap");

        assert!(t > T_MIN, "t must be positive, got {t}");
        assert!(
            (p.z - end.z).abs() < 1e-4,
            "expected hit at z={}, got {p:?}",
            end.z
        );
    }

    #[test]
    fn parallel_ray_offset_within_radius_hits() {
        let (start, end, radius) = z_cylinder();
        let dir = Vec3::new(0.0, 0.0, -1.0);
        let origin = origin_behind(Vec3::new(0.05, 0.0, 5.5), dir, start, end, radius);

        let hit = ray_cylinder_parallel_intersect(origin, dir, start, end, radius);
        assert!(hit.is_some(), "ray within radius should hit cap");
    }

    #[test]
    fn parallel_ray_offset_beyond_radius_misses() {
        let (start, end, radius) = z_cylinder();
        let dir = Vec3::new(0.0, 0.0, -1.0);
        let origin = Vec3::new(0.5, 0.0, 5.5) - 10.0 * dir;

        let hit = ray_cylinder_parallel_intersect(origin, dir, start, end, radius);
        assert!(hit.is_none(), "ray outside radius must miss");
    }

    #[test]
    fn parallel_ray_reverse_direction_hits_other_cap() {
        let (start, end, radius) = z_cylinder();
        let dir = Vec3::new(0.0, 0.0, 1.0);
        let origin = origin_behind(Vec3::new(0.0, 0.0, -0.5), dir, start, end, radius);

        let (t, p) = ray_cylinder_parallel_intersect(origin, dir, start, end, radius)
            .expect("reverse-direction parallel ray should hit");

        assert!(t > T_MIN);
        assert!(
            (p.z - start.z).abs() < 1e-4,
            "expected hit at z={}, got {p:?}",
            start.z
        );
    }
}