gloss_geometry/
csr.rs

1extern crate nalgebra as na;
2use burn::prelude::Backend;
3use burn::tensor::{Device, Int, Tensor};
4
/// Vertex → incident-face adjacency stored in Compressed Sparse Row (CSR) form.
///
/// The layout is meant for a GPU "vertex kernel": the work item for vertex `v` reads the
/// half-open range `row_ptr[v]..row_ptr[v + 1]` of `col_idx`, which lists every face that
/// references `v`; summing those faces' normals and normalizing yields the vertex normal.
///
/// Invariants:
/// - `row_ptr.len() == num_vertices + 1`.
/// - `row_ptr[num_vertices] == col_idx.len()`, the total number of vertex–face incidences
///   (typically `3 * num_faces` for a triangle mesh).
/// - Every entry of `col_idx` is a face id in `0..num_faces`.
/// - Within one vertex's range, face ids appear in the order of the input faces.
///
/// Reading the faces around vertex `v`:
/// ```text
/// let (lo, hi) = (csr.row_ptr[v] as usize, csr.row_ptr[v + 1] as usize);
/// for &fid in &csr.col_idx[lo..hi] {
///     // fid in 0..num_faces; accumulate face_normals[fid as usize]
/// }
/// ```
#[derive(Clone, Debug)]
pub struct VertexFaceCSR {
    /// Offsets into `col_idx`: one per vertex plus a trailing end offset
    /// (length `num_vertices + 1`). Vertex `v` owns `row_ptr[v]..row_ptr[v + 1]`.
    pub row_ptr: Vec<u32>,

    /// Face ids grouped contiguously by vertex; every value lies in `0..num_faces`.
    pub col_idx: Vec<u32>,

    /// Number of vertices covered by `row_ptr` (cached for convenience).
    pub num_vertices: usize,
    /// Number of faces that the ids in `col_idx` refer to.
    pub num_faces: usize,
}
41
42/// Burn-friendly CSR representation: buffers are Tensors on a chosen backend/device.
43///
44/// These can be uploaded to GPU and used inside `CubeCL` kernels.
45#[derive(Clone, Debug)]
46pub struct VertexFaceCSRBurn<B: Backend> {
47    /// Tensor of length = `num_vertices` + 1
48    /// `row_ptr`[v]..`row_ptr`[v+1] indexes into `col_idx` for vertex v
49    pub row_ptr: Tensor<B, 1, Int>,
50    /// Tensor of length = 3 * `num_faces`
51    /// Flat list of incident face IDs
52    pub col_idx: Tensor<B, 1, Int>,
53    /// Mesh metadata
54    pub num_vertices: usize,
55    pub num_faces: usize,
56}
57
58impl VertexFaceCSR {
59    /// Build a CSR mapping from a (F x 3) faces matrix.
60    ///
61    /// # Arguments
62    /// * `faces` - A `DMatrix<u32>` with shape `(F, 3)`.
63    ///
64    /// Each row is a triangle `(i0, i1, i2)` with vertex indices.
65    pub fn from_faces(faces: &na::DMatrix<u32>) -> Self {
66        assert_eq!(faces.ncols(), 3, "Faces matrix must have exactly 3 columns (triangle vertex indices)");
67
68        let num_faces = faces.nrows();
69
70        // Find maximum vertex index -> num_vertices
71        let max_idx = faces.iter().copied().max().unwrap_or(0);
72        let num_vertices = (max_idx as usize) + 1;
73
74        // 1) Count degrees (number of incident faces) per vertex
75        let mut degree = vec![0usize; num_vertices];
76        for idx in faces.iter() {
77            degree[*idx as usize] += 1;
78        }
79
80        // 2) Build row_ptr (prefix sum of degrees)
81        let mut row_ptr = Vec::with_capacity(num_vertices + 1);
82        row_ptr.push(0);
83        #[allow(clippy::cast_possible_truncation)]
84        for &d in &degree {
85            let last = *row_ptr.last().unwrap();
86            row_ptr.push(last + d as u32);
87        }
88
89        // 3) Allocate col_idx and temporary cursors
90        let total_incidents = *row_ptr.last().unwrap() as usize;
91        let mut col_idx = vec![0u32; total_incidents];
92        let mut cursor = vec![0usize; num_vertices];
93
94        // 4) Fill col_idx: for each face, push its ID into each vertex’s bucket
95        #[allow(clippy::cast_possible_truncation)]
96        for fid in 0..num_faces {
97            let row = faces.row(fid);
98            for j in 0..3 {
99                let v = row[j] as usize;
100                let base = row_ptr[v] as usize;
101                let pos = base + cursor[v];
102                col_idx[pos] = fid as u32;
103                cursor[v] += 1;
104            }
105        }
106
107        VertexFaceCSR {
108            row_ptr,
109            col_idx,
110            num_vertices,
111            num_faces,
112        }
113    }
114
115    /// Get the slice of incident face IDs for vertex v.
116    pub fn incident_faces(&self, v: usize) -> &[u32] {
117        assert!(v < self.num_vertices);
118        let start = self.row_ptr[v] as usize;
119        let end = self.row_ptr[v + 1] as usize;
120        &self.col_idx[start..end]
121    }
122
123    /// Convert CPU CSR to Burn tensor version.
124    ///
125    /// # Arguments
126    /// * `device` - The device where tensors should be allocated (CPU, CUDA, WGPU…).
127    #[allow(clippy::cast_possible_wrap)]
128    pub fn to_burn<B: Backend>(&self, device: &Device<B>) -> VertexFaceCSRBurn<B> {
129        // Convert Vec<u32> to Vec<i32> because Burn's `Int` usually maps to i32
130        let row_ptr_i32: Vec<i32> = self.row_ptr.iter().map(|&x| x as i32).collect();
131        let col_idx_i32: Vec<i32> = self.col_idx.iter().map(|&x| x as i32).collect();
132
133        let row_ptr_tensor: Tensor<B, 1, Int> = Tensor::<B, 1, Int>::from_ints(row_ptr_i32.as_slice(), &device.clone());
134        let col_idx_tensor: Tensor<B, 1, Int> = Tensor::<B, 1, Int>::from_ints(col_idx_i32.as_slice(), &device.clone());
135
136        VertexFaceCSRBurn {
137            row_ptr: row_ptr_tensor,
138            col_idx: col_idx_tensor,
139            num_vertices: self.num_vertices,
140            num_faces: self.num_faces,
141        }
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::DMatrix;

    /// Three triangles that share vertices 0 and 2:
    /// face 0 = (0, 1, 2), face 1 = (0, 2, 3), face 2 = (2, 4, 5).
    #[test]
    fn csr_small_example() {
        let faces = DMatrix::from_row_slice(3, 3, &[0u32, 1, 2, 0, 2, 3, 2, 4, 5]);
        let csr = VertexFaceCSR::from_faces(&faces);

        assert_eq!(csr.num_faces, 3);
        assert_eq!(csr.num_vertices, 6);

        // expected[v] lists the faces incident to vertex v, in face order.
        let expected: [&[u32]; 6] = [&[0, 1], &[0], &[0, 1, 2], &[1], &[2], &[2]];
        for (v, want) in expected.iter().enumerate() {
            assert_eq!(csr.incident_faces(v), *want, "incident faces of vertex {v}");
        }
    }
}