// cubecl_cuda/lib.rs

1#[macro_use]
2extern crate derive_new;
3extern crate alloc;
4
5mod compute;
6mod device;
7mod runtime;
8
9pub use device::*;
10pub use runtime::*;
11
12#[cfg(feature = "ptx-wmma")]
13pub(crate) type WmmaCompiler = cubecl_cpp::cuda::mma::PtxWmmaCompiler;
14#[cfg(not(feature = "ptx-wmma"))]
15pub(crate) type WmmaCompiler = cubecl_cpp::cuda::mma::CudaWmmaCompiler;
16
#[cfg(test)]
// The testgen macros emit `cfg` predicates this crate does not declare;
// silence the resulting `unexpected_cfgs` lint for the whole module.
#[allow(unexpected_cfgs)]
mod tests {
    // Runtime under test for all generated test suites.
    pub type TestRuntime = crate::CudaRuntime;

    // Half-precision types exercised by the generated tests.
    pub use half::{bf16, f16};

    // Core test suite across float, signed, and unsigned element types.
    cubecl_core::testgen_all!(f32: [f16, bf16, f32, f64], i32: [i8, i16, i32, i64], u32: [u8, u16, u32, u64]);

    cubecl_std::testgen!();

    // Matmul test suites for the plane-accelerated, unit, and TMA paths.
    cubecl_matmul::testgen_matmul_plane_accelerated!();
    cubecl_matmul::testgen_matmul_unit!();
    cubecl_matmul::testgen_matmul_tma!();
    // TODO: re-instate matmul quantized tests
    cubecl_matmul::testgen_matmul_simple!([f16, bf16, f32]);
    cubecl_std::testgen_tensor_identity!([f16, bf16, f32, u32]);
    cubecl_convolution::testgen_conv2d_accelerated!([f16: f16, bf16: bf16, f32: tf32]);
    cubecl_reduce::testgen_reduce!([f16, bf16, f32, f64]);
    cubecl_random::testgen_random!();
    cubecl_reduce::testgen_shared_sum!([f16, bf16, f32, f64]);
}