1extern crate nalgebra as na;
2use burn::prelude::Backend;
3use burn::tensor::{Device, Int, Tensor};
4
/// Vertex → incident-face adjacency of a triangle mesh in compressed sparse row (CSR) form.
///
/// As produced by `VertexFaceCSR::from_faces`, the faces incident to vertex `v` are
/// `col_idx[row_ptr[v] as usize..row_ptr[v + 1] as usize]`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct VertexFaceCSR {
    /// Offsets into `col_idx`: `num_vertices + 1` entries, starting at 0 and non-decreasing.
    pub row_ptr: Vec<u32>,

    /// Incident face ids for all vertices, concatenated and grouped by vertex.
    pub col_idx: Vec<u32>,

    /// Number of vertices (CSR rows).
    pub num_vertices: usize,
    /// Number of faces the adjacency was built from.
    pub num_faces: usize,
}
41
42#[derive(Clone, Debug)]
46pub struct VertexFaceCSRBurn<B: Backend> {
47 pub row_ptr: Tensor<B, 1, Int>,
50 pub col_idx: Tensor<B, 1, Int>,
53 pub num_vertices: usize,
55 pub num_faces: usize,
56}
57
58impl VertexFaceCSR {
59 pub fn from_faces(faces: &na::DMatrix<u32>) -> Self {
66 assert_eq!(faces.ncols(), 3, "Faces matrix must have exactly 3 columns (triangle vertex indices)");
67
68 let num_faces = faces.nrows();
69
70 let max_idx = faces.iter().copied().max().unwrap_or(0);
72 let num_vertices = (max_idx as usize) + 1;
73
74 let mut degree = vec![0usize; num_vertices];
76 for idx in faces.iter() {
77 degree[*idx as usize] += 1;
78 }
79
80 let mut row_ptr = Vec::with_capacity(num_vertices + 1);
82 row_ptr.push(0);
83 #[allow(clippy::cast_possible_truncation)]
84 for &d in °ree {
85 let last = *row_ptr.last().unwrap();
86 row_ptr.push(last + d as u32);
87 }
88
89 let total_incidents = *row_ptr.last().unwrap() as usize;
91 let mut col_idx = vec![0u32; total_incidents];
92 let mut cursor = vec![0usize; num_vertices];
93
94 #[allow(clippy::cast_possible_truncation)]
96 for fid in 0..num_faces {
97 let row = faces.row(fid);
98 for j in 0..3 {
99 let v = row[j] as usize;
100 let base = row_ptr[v] as usize;
101 let pos = base + cursor[v];
102 col_idx[pos] = fid as u32;
103 cursor[v] += 1;
104 }
105 }
106
107 VertexFaceCSR {
108 row_ptr,
109 col_idx,
110 num_vertices,
111 num_faces,
112 }
113 }
114
115 pub fn incident_faces(&self, v: usize) -> &[u32] {
117 assert!(v < self.num_vertices);
118 let start = self.row_ptr[v] as usize;
119 let end = self.row_ptr[v + 1] as usize;
120 &self.col_idx[start..end]
121 }
122
123 #[allow(clippy::cast_possible_wrap)]
128 pub fn to_burn<B: Backend>(&self, device: &Device<B>) -> VertexFaceCSRBurn<B> {
129 let row_ptr_i32: Vec<i32> = self.row_ptr.iter().map(|&x| x as i32).collect();
131 let col_idx_i32: Vec<i32> = self.col_idx.iter().map(|&x| x as i32).collect();
132
133 let row_ptr_tensor: Tensor<B, 1, Int> = Tensor::<B, 1, Int>::from_ints(row_ptr_i32.as_slice(), &device.clone());
134 let col_idx_tensor: Tensor<B, 1, Int> = Tensor::<B, 1, Int>::from_ints(col_idx_i32.as_slice(), &device.clone());
135
136 VertexFaceCSRBurn {
137 row_ptr: row_ptr_tensor,
138 col_idx: col_idx_tensor,
139 num_vertices: self.num_vertices,
140 num_faces: self.num_faces,
141 }
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::DMatrix;

    /// Three triangles sharing vertices 0 and 2: (0, 1, 2), (0, 2, 3), (2, 4, 5).
    fn small_mesh() -> DMatrix<u32> {
        let data: Vec<u32> = vec![0, 1, 2, 0, 2, 3, 2, 4, 5];
        DMatrix::from_row_slice(3, 3, &data)
    }

    #[test]
    fn csr_small_example() {
        let csr = VertexFaceCSR::from_faces(&small_mesh());

        assert_eq!(csr.num_faces, 3);
        assert_eq!(csr.num_vertices, 6);

        assert_eq!(csr.incident_faces(0), &[0, 1]);
        assert_eq!(csr.incident_faces(1), &[0]);
        assert_eq!(csr.incident_faces(2), &[0, 1, 2]);
        assert_eq!(csr.incident_faces(3), &[1]);
        assert_eq!(csr.incident_faces(4), &[2]);
        assert_eq!(csr.incident_faces(5), &[2]);
    }

    /// Pins the raw arrays, since `to_burn` uploads them verbatim.
    #[test]
    fn csr_raw_arrays() {
        let csr = VertexFaceCSR::from_faces(&small_mesh());

        assert_eq!(csr.row_ptr, vec![0, 2, 3, 6, 7, 8, 9]);
        assert_eq!(csr.col_idx, vec![0, 1, 0, 0, 1, 2, 1, 2, 2]);
    }

    /// A face that repeats a vertex is listed once per reference.
    #[test]
    fn csr_degenerate_triangle() {
        let faces: DMatrix<u32> = DMatrix::from_row_slice(1, 3, &[0, 0, 1]);
        let csr = VertexFaceCSR::from_faces(&faces);

        assert_eq!(csr.num_vertices, 2);
        assert_eq!(csr.incident_faces(0), &[0, 0]);
        assert_eq!(csr.incident_faces(1), &[0]);
    }

    #[test]
    #[should_panic(expected = "exactly 3 columns")]
    fn csr_rejects_non_triangle_faces() {
        let faces: DMatrix<u32> = DMatrix::from_row_slice(1, 4, &[0, 1, 2, 3]);
        let _ = VertexFaceCSR::from_faces(&faces);
    }
}