gloss_geometry/cubecl_normals/
compute_normals.rs

1use burn::tensor::TensorMetadata;
2use burn::tensor::{
3    ops::{FloatTensor, IntTensor},
4    Shape,
5};
6use burn_cubecl::{tensor::CubeTensor, BoolElement, CubeBackend, CubeRuntime, FloatElement, IntElement};
7
8use cubecl::prelude::*; // CubeCL macros and helpers
9
10// ----------------- CubeCL kernels -----------------
11
12#[cube(launch)]
13pub fn face_normals_kernel<F: Float>(verts: &Tensor<F>, faces: &Tensor<i32>, face_normals: &mut Tensor<F>) {
14    // Each thread handles one face
15    let f = ABSOLUTE_POS_X; // X dimension corresponds to face index
16
17    let num_faces = face_normals.shape(0); // assume [num_faces, 3]
18    if f >= num_faces {
19        terminate!();
20    }
21
22    // Get the vertex indices for this face
23    let mut v_indices = Line::<i32>::empty(3u32);
24    for i in 0..3 {
25        let face_idx = f * 3 + i;
26        v_indices[i] = faces[face_idx] as i32; // faces[f, i]
27    }
28
29    // Load vertices
30    let mut v0 = Line::<F>::empty(3u32);
31    let mut v1 = Line::<F>::empty(3u32);
32    let mut v2 = Line::<F>::empty(3u32);
33
34    for c in 0..3 {
35        v0[c] = verts[v_indices[0] as u32 * 3 + c];
36        v1[c] = verts[v_indices[1] as u32 * 3 + c];
37        v2[c] = verts[v_indices[2] as u32 * 3 + c];
38    }
39
40    // Compute edges
41    let mut d1 = Line::<F>::empty(3u32);
42    let mut d2 = Line::<F>::empty(3u32);
43    for c in 0..3 {
44        d1[c] = v1[c] - v0[c];
45        d2[c] = v2[c] - v0[c];
46    }
47
48    // Cross product: normal = d1 x d2
49    let cx: F = d1[1] * d2[2] - d1[2] * d2[1];
50    let cy: F = d1[2] * d2[0] - d1[0] * d2[2];
51    let cz: F = d1[0] * d2[1] - d1[1] * d2[0];
52
53    // Normalize and write result
54    let len = F::sqrt(cx * cx + cy * cy + cz * cz);
55    let eps = F::new(1e-6);
56    let inv = F::new(1.0) / (len + eps);
57
58    face_normals[f * 3] = cx * inv;
59    face_normals[f * 3 + 1] = cy * inv;
60    face_normals[f * 3 + 2] = cz * inv;
61}
62
63#[cube(launch)]
64pub fn vertex_normals_kernel<F: Float>(
65    face_normals: &Tensor<F>,       // flattened [num_faces, 3]
66    row_ptr: &Tensor<i32>,          // [num_vertices + 1]
67    col_idx: &Tensor<i32>,          // [total_incidents]  (face indices)
68    vertex_normals: &mut Tensor<F>, // flattened [num_vertices, 3]
69) {
70    // one thread per vertex
71    let v = ABSOLUTE_POS_X; // vertex index
72
73    let num_vertices = vertex_normals.shape(0);
74    if v >= num_vertices {
75        terminate!();
76    }
77
78    // read CSR range
79    let start_i: i32 = row_ptr[v];
80    let end_i: i32 = row_ptr[v + 1];
81
82    // Accumulators for the normals components from all the incident faces
83    let mut ax: F = F::new(0.0);
84    let mut ay: F = F::new(0.0);
85    let mut az: F = F::new(0.0);
86
87    // If start_i == end_i this vertex has no incident faces -> leave zeros
88    let mut i = start_i;
89    #[allow(clippy::cast_sign_loss)]
90    while i < end_i {
91        // get face index (i32), convert to usize for indexing
92        let face_idx_i: i32 = col_idx[i as u32];
93        let face_idx = face_idx_i as u32;
94        let base = face_idx * 3;
95
96        // accumulate face normal components
97        ax += face_normals[base];
98        ay += face_normals[base + 1];
99        az += face_normals[base + 2];
100
101        i += 1;
102    }
103
104    // normalize accumulated normal (with tiny epsilon)
105    let len = F::sqrt(ax * ax + ay * ay + az * az);
106    let eps = F::new(1e-6);
107    let inv = F::new(1.0) / (len + eps);
108
109    vertex_normals[v * 3] = ax * inv;
110    vertex_normals[v * 3 + 1] = ay * inv;
111    vertex_normals[v * 3 + 2] = az * inv;
112}
113
114//launchers
115pub fn face_normals_launch<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement>(
116    verts: FloatTensor<CubeBackend<R, F, I, BT>>,
117    faces: IntTensor<CubeBackend<R, F, I, BT>>,
118) -> FloatTensor<CubeBackend<R, F, I, BT>> {
119    verts.assert_is_on_same_device(&faces);
120
121    let num_faces = faces.shape().dims::<2>()[0];
122
123    // Build output primitive: shape [num_faces, 3]
124    let shape_out = Shape::from(vec![num_faces, 3usize]);
125    let bytes = shape_out.num_elements() * core::mem::size_of::<F>();
126    let buffer = verts.client.empty(bytes);
127
128    // wrap the buffer Handle into CubeTensor primitive for output.
129    let output = CubeTensor::new_contiguous(verts.client.clone(), verts.device.clone(), shape_out, buffer, F::dtype());
130
131    // Choose cube/workgroup sizes (tune as needed)
132    let cube_dim = CubeDim { x: 256, y: 1, z: 1 }; // e.g. one face per x-thread or tune accordingly
133    #[allow(clippy::cast_possible_truncation)]
134    let cubes_needed_in_x = num_faces.div_ceil(cube_dim.x as usize) as u32;
135    let cube_count = CubeCount::Static(cubes_needed_in_x, 1, 1);
136
137    // Launch the kernel
138    face_normals_kernel::launch::<F, R>(
139        &verts.client,
140        cube_count,
141        cube_dim,
142        verts.as_tensor_arg::<F>(1),
143        faces.as_tensor_arg::<F>(1),
144        output.as_tensor_arg::<F>(1),
145    );
146
147    output
148}
149
150pub fn vertex_normals_launch<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement>(
151    face_normals: FloatTensor<CubeBackend<R, F, I, BT>>, // [num_faces, 3]
152    row_ptr: IntTensor<CubeBackend<R, F, I, BT>>,        // [num_vertices + 1]
153    col_idx: IntTensor<CubeBackend<R, F, I, BT>>,        // [total_incidents]
154    num_vertices: usize,
155) -> FloatTensor<CubeBackend<R, F, I, BT>> {
156    face_normals.assert_is_on_same_device(&row_ptr);
157    face_normals.assert_is_on_same_device(&col_idx);
158
159    // Build output primitive: shape [num_vertices, 3]
160    let shape_out = Shape::from(vec![num_vertices, 3usize]);
161    let bytes = shape_out.num_elements() * core::mem::size_of::<F>();
162    let buffer = face_normals.client.empty(bytes);
163
164    // wrap the buffer Handle into CubeTensor primitive for output.
165    let output = CubeTensor::new_contiguous(face_normals.client.clone(), face_normals.device.clone(), shape_out, buffer, F::dtype());
166
167    // Each thread handles one vertex
168    let cube_dim = CubeDim { x: 256, y: 1, z: 1 };
169    #[allow(clippy::cast_possible_truncation)]
170    let cubes_needed_in_x = num_vertices.div_ceil(cube_dim.x as usize) as u32;
171    let cube_count = CubeCount::Static(cubes_needed_in_x, 1, 1);
172
173    // Launch vertex_normals_kernel
174    vertex_normals_kernel::launch::<F, R>(
175        &face_normals.client,
176        cube_count,
177        cube_dim,
178        face_normals.as_tensor_arg::<F>(1),
179        row_ptr.as_tensor_arg::<I>(1),
180        col_idx.as_tensor_arg::<I>(1),
181        output.as_tensor_arg::<F>(1),
182    );
183
184    output
185}