gloss_geometry/cubecl_tangents/
launch_tangents.rs

1// Required imports
2use burn::tensor::{Int, Tensor};
3
4use crate::cubecl::{cube2tensor, tensor2cube, tensor2cube_int};
5use crate::{csr::VertexFaceCSRBurn, cubecl_tangents};
6
7use gloss_burn_multibackend::backend::MultiBackend;
8
9pub fn compute_tangents_cubecl(
10    verts: Tensor<MultiBackend, 2>,
11    faces: Tensor<MultiBackend, 2, Int>,
12    normals: Tensor<MultiBackend, 2>,
13    uv: Tensor<MultiBackend, 2>,
14    csr: &VertexFaceCSRBurn<MultiBackend>,
15) -> Tensor<MultiBackend, 2> {
16    let verts_cube = tensor2cube(verts);
17    let faces_cube = tensor2cube_int(faces);
18    let normals_cube = tensor2cube(normals);
19    let uv_cube = tensor2cube(uv);
20    let row_ptr_cube = tensor2cube_int(csr.row_ptr.clone());
21    let col_idx_cube = tensor2cube_int(csr.col_idx.clone());
22    let num_vertices = csr.num_vertices;
23
24    let (faces_tangents_cube, faces_bitangents_cube) = cubecl_tangents::compute_tangents::face_tangents_launch::<
25        cubecl::wgpu::WgpuRuntime,
26        <MultiBackend as burn::prelude::Backend>::FloatElem,
27        <MultiBackend as burn::prelude::Backend>::IntElem,
28        <MultiBackend as burn::prelude::Backend>::BoolElem,
29    >(verts_cube.clone(), uv_cube.clone(), faces_cube.clone());
30
31    let vert_tangents_cube = cubecl_tangents::compute_tangents::vertex_tangents_launch::<
32        cubecl::wgpu::WgpuRuntime,
33        <MultiBackend as burn::prelude::Backend>::FloatElem,
34        <MultiBackend as burn::prelude::Backend>::IntElem,
35        <MultiBackend as burn::prelude::Backend>::BoolElem,
36    >(
37        faces_tangents_cube.clone(),
38        faces_bitangents_cube.clone(),
39        row_ptr_cube.clone(),
40        col_idx_cube.clone(),
41        normals_cube.clone(),
42        num_vertices,
43    );
44
45    cube2tensor(vert_tangents_cube)
46}