// gloss_geometry/cubecl_normals/launch_normals.rs

// Required imports
use burn::tensor::{Int, Tensor};
use gloss_burn_multibackend::backend::MultiBackend;

use crate::cubecl::{cube2tensor, tensor2cube, tensor2cube_int};
use crate::{csr::VertexFaceCSRBurn, cubecl_normals};

9pub fn compute_per_vertex_normals_cubecl(
10    verts: Tensor<MultiBackend, 2>,
11    faces: Tensor<MultiBackend, 2, Int>,
12    csr: &VertexFaceCSRBurn<MultiBackend>,
13) -> Tensor<MultiBackend, 2> {
14    let verts_cube = tensor2cube(verts);
15    let faces_cube = tensor2cube_int(faces);
16    let row_ptr_cube = tensor2cube_int(csr.row_ptr.clone());
17    let col_idx_cube = tensor2cube_int(csr.col_idx.clone());
18    let num_vertices = csr.num_vertices;
19
20    let faces_normals_cube = cubecl_normals::compute_normals::face_normals_launch::<
21        cubecl::wgpu::WgpuRuntime,
22        <MultiBackend as burn::prelude::Backend>::FloatElem,
23        <MultiBackend as burn::prelude::Backend>::IntElem,
24        <MultiBackend as burn::prelude::Backend>::BoolElem,
25    >(verts_cube.clone(), faces_cube.clone());
26
27    let vert_normals_cube = cubecl_normals::compute_normals::vertex_normals_launch::<
28        cubecl::wgpu::WgpuRuntime,
29        <MultiBackend as burn::prelude::Backend>::FloatElem,
30        <MultiBackend as burn::prelude::Backend>::IntElem,
31        <MultiBackend as burn::prelude::Backend>::BoolElem,
32    >(faces_normals_cube.clone(), row_ptr_cube.clone(), col_idx_cube.clone(), num_vertices);
33
34    cube2tensor(vert_normals_cube)
35}