gloss_geometry/cubecl_tangents/
compute_tangents.rs

1use burn::tensor::TensorMetadata;
2use burn::tensor::{
3    ops::{FloatTensor, IntTensor},
4    Shape,
5};
6use burn_cubecl::{tensor::CubeTensor, BoolElement, CubeBackend, CubeRuntime, FloatElement, IntElement};
7
8use cubecl::prelude::*;
9
10#[cube(launch)]
11#[allow(clippy::similar_names)]
12pub fn face_tangents_kernel<F: Float>(
13    verts: &Tensor<F>,               // flattened [num_verts, 3]
14    uvs: &Tensor<F>,                 // flattened [num_verts, 2]
15    faces: &Tensor<i32>,             // flattened [num_faces, 3]
16    face_tangents: &mut Tensor<F>,   // flattened [num_faces, 3]
17    face_bitangents: &mut Tensor<F>, // flattened [num_faces, 3]
18) {
19    let fid = ABSOLUTE_POS_X;
20
21    let num_faces = face_tangents.shape(0);
22    if fid >= num_faces {
23        terminate!();
24    }
25
26    // Get vertex indices for this face
27    let mut v_indices = Line::<i32>::empty(3u32);
28    for i in 0..3 {
29        let face_idx = fid * 3 + i;
30        v_indices[i] = faces[face_idx] as i32;
31    }
32
33    // Load vertex positions
34    let mut v0 = Line::<F>::empty(3u32);
35    let mut v1 = Line::<F>::empty(3u32);
36    let mut v2 = Line::<F>::empty(3u32);
37
38    for c in 0..3 {
39        v0[c] = verts[v_indices[0] as u32 * 3 + c];
40        v1[c] = verts[v_indices[1] as u32 * 3 + c];
41        v2[c] = verts[v_indices[2] as u32 * 3 + c];
42    }
43
44    // Load UV coordinates
45    let mut uv0 = Line::<F>::empty(2u32);
46    let mut uv1 = Line::<F>::empty(2u32);
47    let mut uv2 = Line::<F>::empty(2u32);
48
49    for c in 0..2 {
50        uv0[c] = uvs[v_indices[0] as u32 * 2 + c];
51        uv1[c] = uvs[v_indices[1] as u32 * 2 + c];
52        uv2[c] = uvs[v_indices[2] as u32 * 2 + c];
53    }
54
55    // Compute position deltas
56    let mut delta_pos1 = Line::<F>::empty(3u32);
57    let mut delta_pos2 = Line::<F>::empty(3u32);
58    for c in 0..3 {
59        delta_pos1[c] = v1[c] - v0[c];
60        delta_pos2[c] = v2[c] - v0[c];
61    }
62
63    // Compute UV deltas
64    let mut delta_uv1 = Line::<F>::empty(2u32);
65    let mut delta_uv2 = Line::<F>::empty(2u32);
66    for c in 0..2 {
67        delta_uv1[c] = uv1[c] - uv0[c];
68        delta_uv2[c] = uv2[c] - uv0[c];
69    }
70
71    // denominator (for solving tangent/bitangent)
72    let denom = delta_uv1[0] * delta_uv2[1] - delta_uv1[1] * delta_uv2[0];
73    let eps = F::new(1e-6);
74    let r = F::new(1.0) / (denom + eps);
75
76    // tangent = (deltaPos1 * dv2 - deltaPos2 * dv1) * r
77    // bitangent = (deltaPos2 * du1 - deltaPos1 * du2) * r
78    let mut tangent = Line::<F>::empty(3u32);
79    let mut bitangent = Line::<F>::empty(3u32);
80
81    for c in 0..3 {
82        tangent[c] = (delta_pos1[c] * delta_uv2[1] - delta_pos2[c] * delta_uv1[1]) * r;
83        bitangent[c] = (delta_pos2[c] * delta_uv1[0] - delta_pos1[c] * delta_uv2[0]) * r;
84    }
85
86    // Normalize tangent and bitangent (helps stability)
87    let mut tlen_sq = F::new(0.0);
88    let mut blen_sq = F::new(0.0);
89    for c in 0..3 {
90        tlen_sq += tangent[c] * tangent[c];
91        blen_sq += bitangent[c] * bitangent[c];
92    }
93
94    let t_inv = F::new(1.0) / (F::sqrt(tlen_sq) + eps);
95    let b_inv = F::new(1.0) / (F::sqrt(blen_sq) + eps);
96
97    // Write normalized results
98    for c in 0..3 {
99        face_tangents[fid * 3 + c] = tangent[c] * t_inv;
100        face_bitangents[fid * 3 + c] = bitangent[c] * b_inv;
101    }
102}
103
104#[cube(launch)]
105#[allow(clippy::similar_names)]
106pub fn vertex_tangents_kernel<F: Float>(
107    face_tangents: &Tensor<F>,       // flattened [num_faces, 3]
108    face_bitangents: &Tensor<F>,     // flattened [num_faces, 3]
109    row_ptr: &Tensor<i32>,           // [num_vertices + 1]
110    col_idx: &Tensor<i32>,           // [total_incidents] (face indices)
111    normals: &Tensor<F>,             // flattened [num_vertices, 3]
112    vertex_tangents: &mut Tensor<F>, // flattened [num_vertices, 4] (x,y,z,handness)
113) {
114    let v = ABSOLUTE_POS_X;
115
116    let num_vertices = vertex_tangents.shape(0);
117    if v >= num_vertices {
118        terminate!();
119    }
120
121    // read CSR range
122    let start_i: i32 = row_ptr[v];
123    let end_i: i32 = row_ptr[v + 1];
124
125    // accumulator for tangent vector accumulated over all incident faces
126    let mut ax: F = F::new(0.0);
127    let mut ay: F = F::new(0.0);
128    let mut az: F = F::new(0.0);
129
130    //accumulators for bitangent vector accumulated over all incident faces
131    let mut bx: F = F::new(0.0);
132    let mut by: F = F::new(0.0);
133    let mut bz: F = F::new(0.0);
134
135    // accumulate tangents/bitangents touching vertex v
136    let mut i = start_i;
137    while i < end_i {
138        #[allow(clippy::cast_sign_loss)]
139        let face_idx = col_idx[i as u32] as u32;
140        let base = face_idx * 3;
141        ax += face_tangents[base];
142        ay += face_tangents[base + 1];
143        az += face_tangents[base + 2];
144
145        bx += face_bitangents[base];
146        by += face_bitangents[base + 1];
147        bz += face_bitangents[base + 2];
148
149        i += 1;
150    }
151
152    // If accumulated tangent is zero-length, write zeros
153    let eps = F::new(1e-6);
154    let tlen2 = ax * ax + ay * ay + az * az;
155    if tlen2 <= eps * eps {
156        vertex_tangents[v * 4] = F::new(0.0);
157        vertex_tangents[v * 4 + 1] = F::new(0.0);
158        vertex_tangents[v * 4 + 2] = F::new(0.0);
159        vertex_tangents[v * 4 + 3] = F::new(0.0);
160        terminate!();
161    }
162
163    // Gram-Schmidt: make tangent orthogonal to normal
164    let nx = normals[v * 3];
165    let ny = normals[v * 3 + 1];
166    let nz = normals[v * 3 + 2];
167
168    // dot = normal . tangent
169    let dot = nx * ax + ny * ay + nz * az;
170
171    // t = t - n * dot
172    let tx = ax - nx * dot;
173    let ty = ay - ny * dot;
174    let tz = az - nz * dot;
175
176    // normalize t
177    let tlen = F::sqrt(tx * tx + ty * ty + tz * tz);
178    let invt = F::new(1.0) / (tlen + eps);
179    let ntx = tx * invt;
180    let nty = ty * invt;
181    let ntz = tz * invt;
182
183    // compute handedness: sign( (tangent cross bitangent) dot normal )
184    // cross = tangent x bitangent
185    // unclear weather we do the cross with original accumulations or with normalized tangent and accumulated bitangent
186    // for now I just do it with the accumulations ( so unnormalized tangent and unnormalized bitangent )
187    let c_x = ay * bz - az * by;
188    let c_y = az * bx - ax * bz;
189    let c_z = ax * by - ay * bx;
190
191    // The following is with the normalized tangent instead of the accumulated tangent
192    // let c_x = nty * bz - ntz * by;
193    // let c_y = ntz * bx - ntx * bz;
194    // let c_z = ntx * by - nty * bx;
195
196    let handed_dot = c_x * nx + c_y * ny + c_z * nz;
197
198    // handedness sign: +1 or -1
199    let hand = if handed_dot >= F::new(0.0) { F::new(1.0) } else { F::new(-1.0) };
200
201    // write tangent (x,y,z) and handness as w
202    vertex_tangents[v * 4] = ntx;
203    vertex_tangents[v * 4 + 1] = nty;
204    vertex_tangents[v * 4 + 2] = ntz;
205    vertex_tangents[v * 4 + 3] = hand;
206}
207
// Launchers: host-side wrappers that allocate the output tensors and dispatch the kernels above.
209#[allow(clippy::type_complexity)]
210pub fn face_tangents_launch<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement>(
211    verts: FloatTensor<CubeBackend<R, F, I, BT>>, // [num_verts, 3]
212    uvs: FloatTensor<CubeBackend<R, F, I, BT>>,   // [num_verts, 2]
213    faces: IntTensor<CubeBackend<R, F, I, BT>>,   // [num_faces, 3]
214) -> (
215    FloatTensor<CubeBackend<R, F, I, BT>>, // face_tangents [num_faces, 3]
216    FloatTensor<CubeBackend<R, F, I, BT>>, // face_bitangents [num_faces, 3]
217) {
218    verts.assert_is_on_same_device(&uvs);
219    verts.assert_is_on_same_device(&faces);
220
221    let num_faces = faces.shape().dims::<2>()[0];
222
223    // Allocate buffers for tangents and bitangents
224    let shape_out = Shape::from(vec![num_faces, 3usize]);
225    let bytes = shape_out.num_elements() * core::mem::size_of::<F>();
226    let buffer_tangent = verts.client.empty(bytes);
227    let buffer_bitangent = verts.client.empty(bytes);
228
229    // Create CubeTensors
230    let face_tangents = CubeTensor::new_contiguous(verts.client.clone(), verts.device.clone(), shape_out.clone(), buffer_tangent, F::dtype());
231    let face_bitangents = CubeTensor::new_contiguous(verts.client.clone(), verts.device.clone(), shape_out, buffer_bitangent, F::dtype());
232
233    // Workgroup/cube config
234    let cube_dim = CubeDim { x: 256, y: 1, z: 1 };
235    #[allow(clippy::cast_possible_truncation)]
236    let cubes_needed_in_x = num_faces.div_ceil(cube_dim.x as usize) as u32;
237    let cube_count = CubeCount::Static(cubes_needed_in_x, 1, 1);
238
239    // Launch kernel
240    face_tangents_kernel::launch::<F, R>(
241        &verts.client,
242        cube_count,
243        cube_dim,
244        verts.as_tensor_arg::<F>(1),
245        uvs.as_tensor_arg::<F>(1),
246        faces.as_tensor_arg::<F>(1),
247        face_tangents.as_tensor_arg::<F>(1),
248        face_bitangents.as_tensor_arg::<F>(1),
249    );
250
251    (face_tangents, face_bitangents)
252}
253
254pub fn vertex_tangents_launch<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement>(
255    face_tangents: FloatTensor<CubeBackend<R, F, I, BT>>,   // [num_faces, 3]
256    face_bitangents: FloatTensor<CubeBackend<R, F, I, BT>>, // [num_faces, 3]
257    row_ptr: IntTensor<CubeBackend<R, F, I, BT>>,           // [num_vertices + 1]
258    col_idx: IntTensor<CubeBackend<R, F, I, BT>>,           // [total_incidents]
259    normals: FloatTensor<CubeBackend<R, F, I, BT>>,         // [num_vertices, 3]
260    num_vertices: usize,
261) -> FloatTensor<CubeBackend<R, F, I, BT>> {
262    // Device checks
263    face_tangents.assert_is_on_same_device(&face_bitangents);
264    face_tangents.assert_is_on_same_device(&row_ptr);
265    face_tangents.assert_is_on_same_device(&col_idx);
266    face_tangents.assert_is_on_same_device(&normals);
267
268    // Output: [num_vertices, 4] (tangent vec3 + handedness scalar)
269    let shape_out = Shape::from(vec![num_vertices, 4usize]);
270    let bytes = shape_out.num_elements() * core::mem::size_of::<F>();
271    let buffer = face_tangents.client.empty(bytes);
272
273    // wrap the buffer Handle into CubeTensor primitive for output.
274    let output = CubeTensor::new_contiguous(face_tangents.client.clone(), face_tangents.device.clone(), shape_out, buffer, F::dtype());
275
276    // Workgroup config: one thread per vertex
277    let cube_dim = CubeDim { x: 256, y: 1, z: 1 };
278    #[allow(clippy::cast_possible_truncation)]
279    let cubes_needed_in_x = num_vertices.div_ceil(cube_dim.x as usize) as u32;
280    let cube_count = CubeCount::Static(cubes_needed_in_x, 1, 1);
281
282    // Launch kernel
283    vertex_tangents_kernel::launch::<F, R>(
284        &face_tangents.client,
285        cube_count,
286        cube_dim,
287        face_tangents.as_tensor_arg::<F>(1),
288        face_bitangents.as_tensor_arg::<F>(1),
289        row_ptr.as_tensor_arg::<I>(1),
290        col_idx.as_tensor_arg::<I>(1),
291        normals.as_tensor_arg::<F>(1),
292        output.as_tensor_arg::<F>(1),
293    );
294
295    output
296}