gloss_geometry/cubecl_tangents/
compute_tangents.rs1use burn::tensor::TensorMetadata;
2use burn::tensor::{
3 ops::{FloatTensor, IntTensor},
4 Shape,
5};
6use burn_cubecl::{tensor::CubeTensor, BoolElement, CubeBackend, CubeRuntime, FloatElement, IntElement};
7
8use cubecl::prelude::*;
9
10#[cube(launch)]
11#[allow(clippy::similar_names)]
12pub fn face_tangents_kernel<F: Float>(
13 verts: &Tensor<F>, uvs: &Tensor<F>, faces: &Tensor<i32>, face_tangents: &mut Tensor<F>, face_bitangents: &mut Tensor<F>, ) {
19 let fid = ABSOLUTE_POS_X;
20
21 let num_faces = face_tangents.shape(0);
22 if fid >= num_faces {
23 terminate!();
24 }
25
26 let mut v_indices = Line::<i32>::empty(3u32);
28 for i in 0..3 {
29 let face_idx = fid * 3 + i;
30 v_indices[i] = faces[face_idx] as i32;
31 }
32
33 let mut v0 = Line::<F>::empty(3u32);
35 let mut v1 = Line::<F>::empty(3u32);
36 let mut v2 = Line::<F>::empty(3u32);
37
38 for c in 0..3 {
39 v0[c] = verts[v_indices[0] as u32 * 3 + c];
40 v1[c] = verts[v_indices[1] as u32 * 3 + c];
41 v2[c] = verts[v_indices[2] as u32 * 3 + c];
42 }
43
44 let mut uv0 = Line::<F>::empty(2u32);
46 let mut uv1 = Line::<F>::empty(2u32);
47 let mut uv2 = Line::<F>::empty(2u32);
48
49 for c in 0..2 {
50 uv0[c] = uvs[v_indices[0] as u32 * 2 + c];
51 uv1[c] = uvs[v_indices[1] as u32 * 2 + c];
52 uv2[c] = uvs[v_indices[2] as u32 * 2 + c];
53 }
54
55 let mut delta_pos1 = Line::<F>::empty(3u32);
57 let mut delta_pos2 = Line::<F>::empty(3u32);
58 for c in 0..3 {
59 delta_pos1[c] = v1[c] - v0[c];
60 delta_pos2[c] = v2[c] - v0[c];
61 }
62
63 let mut delta_uv1 = Line::<F>::empty(2u32);
65 let mut delta_uv2 = Line::<F>::empty(2u32);
66 for c in 0..2 {
67 delta_uv1[c] = uv1[c] - uv0[c];
68 delta_uv2[c] = uv2[c] - uv0[c];
69 }
70
71 let denom = delta_uv1[0] * delta_uv2[1] - delta_uv1[1] * delta_uv2[0];
73 let eps = F::new(1e-6);
74 let r = F::new(1.0) / (denom + eps);
75
76 let mut tangent = Line::<F>::empty(3u32);
79 let mut bitangent = Line::<F>::empty(3u32);
80
81 for c in 0..3 {
82 tangent[c] = (delta_pos1[c] * delta_uv2[1] - delta_pos2[c] * delta_uv1[1]) * r;
83 bitangent[c] = (delta_pos2[c] * delta_uv1[0] - delta_pos1[c] * delta_uv2[0]) * r;
84 }
85
86 let mut tlen_sq = F::new(0.0);
88 let mut blen_sq = F::new(0.0);
89 for c in 0..3 {
90 tlen_sq += tangent[c] * tangent[c];
91 blen_sq += bitangent[c] * bitangent[c];
92 }
93
94 let t_inv = F::new(1.0) / (F::sqrt(tlen_sq) + eps);
95 let b_inv = F::new(1.0) / (F::sqrt(blen_sq) + eps);
96
97 for c in 0..3 {
99 face_tangents[fid * 3 + c] = tangent[c] * t_inv;
100 face_bitangents[fid * 3 + c] = bitangent[c] * b_inv;
101 }
102}
103
104#[cube(launch)]
105#[allow(clippy::similar_names)]
106pub fn vertex_tangents_kernel<F: Float>(
107 face_tangents: &Tensor<F>, face_bitangents: &Tensor<F>, row_ptr: &Tensor<i32>, col_idx: &Tensor<i32>, normals: &Tensor<F>, vertex_tangents: &mut Tensor<F>, ) {
114 let v = ABSOLUTE_POS_X;
115
116 let num_vertices = vertex_tangents.shape(0);
117 if v >= num_vertices {
118 terminate!();
119 }
120
121 let start_i: i32 = row_ptr[v];
123 let end_i: i32 = row_ptr[v + 1];
124
125 let mut ax: F = F::new(0.0);
127 let mut ay: F = F::new(0.0);
128 let mut az: F = F::new(0.0);
129
130 let mut bx: F = F::new(0.0);
132 let mut by: F = F::new(0.0);
133 let mut bz: F = F::new(0.0);
134
135 let mut i = start_i;
137 while i < end_i {
138 #[allow(clippy::cast_sign_loss)]
139 let face_idx = col_idx[i as u32] as u32;
140 let base = face_idx * 3;
141 ax += face_tangents[base];
142 ay += face_tangents[base + 1];
143 az += face_tangents[base + 2];
144
145 bx += face_bitangents[base];
146 by += face_bitangents[base + 1];
147 bz += face_bitangents[base + 2];
148
149 i += 1;
150 }
151
152 let eps = F::new(1e-6);
154 let tlen2 = ax * ax + ay * ay + az * az;
155 if tlen2 <= eps * eps {
156 vertex_tangents[v * 4] = F::new(0.0);
157 vertex_tangents[v * 4 + 1] = F::new(0.0);
158 vertex_tangents[v * 4 + 2] = F::new(0.0);
159 vertex_tangents[v * 4 + 3] = F::new(0.0);
160 terminate!();
161 }
162
163 let nx = normals[v * 3];
165 let ny = normals[v * 3 + 1];
166 let nz = normals[v * 3 + 2];
167
168 let dot = nx * ax + ny * ay + nz * az;
170
171 let tx = ax - nx * dot;
173 let ty = ay - ny * dot;
174 let tz = az - nz * dot;
175
176 let tlen = F::sqrt(tx * tx + ty * ty + tz * tz);
178 let invt = F::new(1.0) / (tlen + eps);
179 let ntx = tx * invt;
180 let nty = ty * invt;
181 let ntz = tz * invt;
182
183 let c_x = ay * bz - az * by;
188 let c_y = az * bx - ax * bz;
189 let c_z = ax * by - ay * bx;
190
191 let handed_dot = c_x * nx + c_y * ny + c_z * nz;
197
198 let hand = if handed_dot >= F::new(0.0) { F::new(1.0) } else { F::new(-1.0) };
200
201 vertex_tangents[v * 4] = ntx;
203 vertex_tangents[v * 4 + 1] = nty;
204 vertex_tangents[v * 4 + 2] = ntz;
205 vertex_tangents[v * 4 + 3] = hand;
206}
207
208#[allow(clippy::type_complexity)]
210pub fn face_tangents_launch<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement>(
211 verts: FloatTensor<CubeBackend<R, F, I, BT>>, uvs: FloatTensor<CubeBackend<R, F, I, BT>>, faces: IntTensor<CubeBackend<R, F, I, BT>>, ) -> (
215 FloatTensor<CubeBackend<R, F, I, BT>>, FloatTensor<CubeBackend<R, F, I, BT>>, ) {
218 verts.assert_is_on_same_device(&uvs);
219 verts.assert_is_on_same_device(&faces);
220
221 let num_faces = faces.shape().dims::<2>()[0];
222
223 let shape_out = Shape::from(vec![num_faces, 3usize]);
225 let bytes = shape_out.num_elements() * core::mem::size_of::<F>();
226 let buffer_tangent = verts.client.empty(bytes);
227 let buffer_bitangent = verts.client.empty(bytes);
228
229 let face_tangents = CubeTensor::new_contiguous(verts.client.clone(), verts.device.clone(), shape_out.clone(), buffer_tangent, F::dtype());
231 let face_bitangents = CubeTensor::new_contiguous(verts.client.clone(), verts.device.clone(), shape_out, buffer_bitangent, F::dtype());
232
233 let cube_dim = CubeDim { x: 256, y: 1, z: 1 };
235 #[allow(clippy::cast_possible_truncation)]
236 let cubes_needed_in_x = num_faces.div_ceil(cube_dim.x as usize) as u32;
237 let cube_count = CubeCount::Static(cubes_needed_in_x, 1, 1);
238
239 face_tangents_kernel::launch::<F, R>(
241 &verts.client,
242 cube_count,
243 cube_dim,
244 verts.as_tensor_arg::<F>(1),
245 uvs.as_tensor_arg::<F>(1),
246 faces.as_tensor_arg::<F>(1),
247 face_tangents.as_tensor_arg::<F>(1),
248 face_bitangents.as_tensor_arg::<F>(1),
249 );
250
251 (face_tangents, face_bitangents)
252}
253
254pub fn vertex_tangents_launch<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement>(
255 face_tangents: FloatTensor<CubeBackend<R, F, I, BT>>, face_bitangents: FloatTensor<CubeBackend<R, F, I, BT>>, row_ptr: IntTensor<CubeBackend<R, F, I, BT>>, col_idx: IntTensor<CubeBackend<R, F, I, BT>>, normals: FloatTensor<CubeBackend<R, F, I, BT>>, num_vertices: usize,
261) -> FloatTensor<CubeBackend<R, F, I, BT>> {
262 face_tangents.assert_is_on_same_device(&face_bitangents);
264 face_tangents.assert_is_on_same_device(&row_ptr);
265 face_tangents.assert_is_on_same_device(&col_idx);
266 face_tangents.assert_is_on_same_device(&normals);
267
268 let shape_out = Shape::from(vec![num_vertices, 4usize]);
270 let bytes = shape_out.num_elements() * core::mem::size_of::<F>();
271 let buffer = face_tangents.client.empty(bytes);
272
273 let output = CubeTensor::new_contiguous(face_tangents.client.clone(), face_tangents.device.clone(), shape_out, buffer, F::dtype());
275
276 let cube_dim = CubeDim { x: 256, y: 1, z: 1 };
278 #[allow(clippy::cast_possible_truncation)]
279 let cubes_needed_in_x = num_vertices.div_ceil(cube_dim.x as usize) as u32;
280 let cube_count = CubeCount::Static(cubes_needed_in_x, 1, 1);
281
282 vertex_tangents_kernel::launch::<F, R>(
284 &face_tangents.client,
285 cube_count,
286 cube_dim,
287 face_tangents.as_tensor_arg::<F>(1),
288 face_bitangents.as_tensor_arg::<F>(1),
289 row_ptr.as_tensor_arg::<I>(1),
290 col_idx.as_tensor_arg::<I>(1),
291 normals.as_tensor_arg::<F>(1),
292 output.as_tensor_arg::<F>(1),
293 );
294
295 output
296}