gloss_geometry/cubecl_normals/
compute_normals.rs1use burn::tensor::TensorMetadata;
2use burn::tensor::{
3 ops::{FloatTensor, IntTensor},
4 Shape,
5};
6use burn_cubecl::{tensor::CubeTensor, BoolElement, CubeBackend, CubeRuntime, FloatElement, IntElement};
7
8use cubecl::prelude::*; #[cube(launch)]
13pub fn face_normals_kernel<F: Float>(verts: &Tensor<F>, faces: &Tensor<i32>, face_normals: &mut Tensor<F>) {
14 let f = ABSOLUTE_POS_X; let num_faces = face_normals.shape(0); if f >= num_faces {
19 terminate!();
20 }
21
22 let mut v_indices = Line::<i32>::empty(3u32);
24 for i in 0..3 {
25 let face_idx = f * 3 + i;
26 v_indices[i] = faces[face_idx] as i32; }
28
29 let mut v0 = Line::<F>::empty(3u32);
31 let mut v1 = Line::<F>::empty(3u32);
32 let mut v2 = Line::<F>::empty(3u32);
33
34 for c in 0..3 {
35 v0[c] = verts[v_indices[0] as u32 * 3 + c];
36 v1[c] = verts[v_indices[1] as u32 * 3 + c];
37 v2[c] = verts[v_indices[2] as u32 * 3 + c];
38 }
39
40 let mut d1 = Line::<F>::empty(3u32);
42 let mut d2 = Line::<F>::empty(3u32);
43 for c in 0..3 {
44 d1[c] = v1[c] - v0[c];
45 d2[c] = v2[c] - v0[c];
46 }
47
48 let cx: F = d1[1] * d2[2] - d1[2] * d2[1];
50 let cy: F = d1[2] * d2[0] - d1[0] * d2[2];
51 let cz: F = d1[0] * d2[1] - d1[1] * d2[0];
52
53 let len = F::sqrt(cx * cx + cy * cy + cz * cz);
55 let eps = F::new(1e-6);
56 let inv = F::new(1.0) / (len + eps);
57
58 face_normals[f * 3] = cx * inv;
59 face_normals[f * 3 + 1] = cy * inv;
60 face_normals[f * 3 + 2] = cz * inv;
61}
62
63#[cube(launch)]
64pub fn vertex_normals_kernel<F: Float>(
65 face_normals: &Tensor<F>, row_ptr: &Tensor<i32>, col_idx: &Tensor<i32>, vertex_normals: &mut Tensor<F>, ) {
70 let v = ABSOLUTE_POS_X; let num_vertices = vertex_normals.shape(0);
74 if v >= num_vertices {
75 terminate!();
76 }
77
78 let start_i: i32 = row_ptr[v];
80 let end_i: i32 = row_ptr[v + 1];
81
82 let mut ax: F = F::new(0.0);
84 let mut ay: F = F::new(0.0);
85 let mut az: F = F::new(0.0);
86
87 let mut i = start_i;
89 #[allow(clippy::cast_sign_loss)]
90 while i < end_i {
91 let face_idx_i: i32 = col_idx[i as u32];
93 let face_idx = face_idx_i as u32;
94 let base = face_idx * 3;
95
96 ax += face_normals[base];
98 ay += face_normals[base + 1];
99 az += face_normals[base + 2];
100
101 i += 1;
102 }
103
104 let len = F::sqrt(ax * ax + ay * ay + az * az);
106 let eps = F::new(1e-6);
107 let inv = F::new(1.0) / (len + eps);
108
109 vertex_normals[v * 3] = ax * inv;
110 vertex_normals[v * 3 + 1] = ay * inv;
111 vertex_normals[v * 3 + 2] = az * inv;
112}
113
114pub fn face_normals_launch<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement>(
116 verts: FloatTensor<CubeBackend<R, F, I, BT>>,
117 faces: IntTensor<CubeBackend<R, F, I, BT>>,
118) -> FloatTensor<CubeBackend<R, F, I, BT>> {
119 verts.assert_is_on_same_device(&faces);
120
121 let num_faces = faces.shape().dims::<2>()[0];
122
123 let shape_out = Shape::from(vec![num_faces, 3usize]);
125 let bytes = shape_out.num_elements() * core::mem::size_of::<F>();
126 let buffer = verts.client.empty(bytes);
127
128 let output = CubeTensor::new_contiguous(verts.client.clone(), verts.device.clone(), shape_out, buffer, F::dtype());
130
131 let cube_dim = CubeDim { x: 256, y: 1, z: 1 }; #[allow(clippy::cast_possible_truncation)]
134 let cubes_needed_in_x = num_faces.div_ceil(cube_dim.x as usize) as u32;
135 let cube_count = CubeCount::Static(cubes_needed_in_x, 1, 1);
136
137 face_normals_kernel::launch::<F, R>(
139 &verts.client,
140 cube_count,
141 cube_dim,
142 verts.as_tensor_arg::<F>(1),
143 faces.as_tensor_arg::<F>(1),
144 output.as_tensor_arg::<F>(1),
145 );
146
147 output
148}
149
150pub fn vertex_normals_launch<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement>(
151 face_normals: FloatTensor<CubeBackend<R, F, I, BT>>, row_ptr: IntTensor<CubeBackend<R, F, I, BT>>, col_idx: IntTensor<CubeBackend<R, F, I, BT>>, num_vertices: usize,
155) -> FloatTensor<CubeBackend<R, F, I, BT>> {
156 face_normals.assert_is_on_same_device(&row_ptr);
157 face_normals.assert_is_on_same_device(&col_idx);
158
159 let shape_out = Shape::from(vec![num_vertices, 3usize]);
161 let bytes = shape_out.num_elements() * core::mem::size_of::<F>();
162 let buffer = face_normals.client.empty(bytes);
163
164 let output = CubeTensor::new_contiguous(face_normals.client.clone(), face_normals.device.clone(), shape_out, buffer, F::dtype());
166
167 let cube_dim = CubeDim { x: 256, y: 1, z: 1 };
169 #[allow(clippy::cast_possible_truncation)]
170 let cubes_needed_in_x = num_vertices.div_ceil(cube_dim.x as usize) as u32;
171 let cube_count = CubeCount::Static(cubes_needed_in_x, 1, 1);
172
173 vertex_normals_kernel::launch::<F, R>(
175 &face_normals.client,
176 cube_count,
177 cube_dim,
178 face_normals.as_tensor_arg::<F>(1),
179 row_ptr.as_tensor_arg::<I>(1),
180 col_idx.as_tensor_arg::<I>(1),
181 output.as_tensor_arg::<F>(1),
182 );
183
184 output
185}