1pub mod csr;
2pub mod cubecl_normals;
3pub mod cubecl_tangents;
4pub mod geom;
5
6pub mod cubecl {
7 pub use crate::cubecl_normals::launch_normals::compute_per_vertex_normals_cubecl;
8 pub use crate::cubecl_tangents::launch_tangents::compute_tangents_cubecl;
9
10 use burn::tensor::{Int, Tensor, TensorMetadata, TensorPrimitive};
11 use burn_cubecl::tensor::CubeTensor;
12 use core::panic;
13 use gloss_burn_multibackend::{backend::MultiBackend, tensor::MultiFloatTensor, tensor::MultiIntTensor};
14
15 #[allow(clippy::match_wildcard_for_single_variants)]
16 pub fn tensor2cube<const D: usize>(tensor: Tensor<MultiBackend, D>) -> CubeTensor<cubecl::wgpu::WgpuRuntime> {
17 let prim = tensor.into_primitive();
18 let TensorPrimitive::Float(t) = prim else {
19 panic!("Expected float tensor got {:?}", prim.dtype())
20 };
21 match t {
22 MultiFloatTensor::Wgpu(t) => t,
23 _ => panic!("Expected wgpu tensor got tensor {:?} with type {:?}", t, t.dtype()),
24 }
25 }
26
27 #[allow(clippy::match_wildcard_for_single_variants)]
28 pub fn tensor2cube_int<const D: usize>(tensor: Tensor<MultiBackend, D, Int>) -> CubeTensor<cubecl::wgpu::WgpuRuntime> {
29 let prim = tensor.into_primitive();
30 match prim {
31 MultiIntTensor::Wgpu(t) => t,
32 _ => panic!("Expected wgpu tensor got {:?}", prim.dtype()),
33 }
34 }
35
36 pub fn cube2tensor<const D: usize>(t: CubeTensor<cubecl::wgpu::WgpuRuntime>) -> Tensor<MultiBackend, D> {
37 Tensor::from_primitive(TensorPrimitive::Float(MultiFloatTensor::Wgpu(t)))
38 }
39}