Skip to main content

burn_cuda/
lib.rs

1#![cfg_attr(docsrs, feature(doc_cfg))]
2
3extern crate alloc;
4
5use burn_cubecl::CubeBackend;
6pub use cubecl::cuda::CudaDevice;
7use cubecl::cuda::CudaRuntime;
8
9#[cfg(not(feature = "fusion"))]
10pub type Cuda<F = f32, I = i32> = CubeBackend<CudaRuntime, F, I, u8>;
11
12#[cfg(feature = "fusion")]
13pub type Cuda<F = f32, I = i32> = burn_fusion::Fusion<CubeBackend<CudaRuntime, F, I, u8>>;
14
#[cfg(test)]
mod tests {
    use super::*;
    use burn_backend::{Backend, DType, QTensorPrimitive};
    use burn_cubecl::tensor::CubeTensor;

    /// Checks which element types the default [`Cuda`] backend reports as supported on the
    /// default device.
    ///
    /// All expected-supported dtypes are queried before asserting. A single failing run
    /// therefore lists every unsupported dtype instead of stopping at the first one.
    #[test]
    fn should_support_dtypes() {
        type B = Cuda;
        let device: <B as Backend>::Device = Default::default();

        let expected_supported = [
            DType::F32,
            DType::Flex32,
            DType::F16,
            DType::BF16,
            DType::I64,
            DType::I32,
            DType::I16,
            DType::I8,
            DType::U64,
            DType::U32,
            DType::U16,
            DType::U8,
            DType::Bool,
            DType::QFloat(CubeTensor::<CudaRuntime>::default_scheme()),
        ];

        let unsupported: Vec<DType> = expected_supported
            .iter()
            .copied()
            .filter(|&dtype| !B::supports_dtype(&device, dtype))
            .collect();
        assert!(
            unsupported.is_empty(),
            "expected these dtypes to be supported: {unsupported:?}"
        );

        // Currently not registered in supported types
        assert!(!B::supports_dtype(&device, DType::F64));
    }
}