cubecl_cuda/
lib.rs

#[macro_use]
extern crate derive_new;
extern crate alloc;

mod compute;
mod device;
mod runtime;

pub use device::*;
pub use runtime::*;

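// Selects the WMMA code generator used by the C++ backend for accelerated
// matrix instructions: with the `ptx-wmma` feature, lowering is expected to go
// through inline PTX `mma`; otherwise the CUDA `wmma` intrinsics are used.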
#[cfg(feature = "ptx-wmma")]
pub(crate) type WmmaCompiler = cubecl_cpp::cuda::mma::PtxWmmaCompiler;

#[cfg(not(feature = "ptx-wmma"))]
pub(crate) type WmmaCompiler = cubecl_cpp::cuda::mma::CudaWmmaCompiler;

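/// Helpers for locating a local CUDA toolkit installation.
///
/// A minimal usage sketch (compile-only doctest; assumes CUDA is installed so
/// `include_path` does not panic):
///
/// ```no_run
/// let headers = cubecl_cuda::install::include_path();
/// println!("CUDA headers at {}", headers.display());
/// ```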
pub mod install {
    use std::path::PathBuf;

    pub fn include_path() -> PathBuf {
        let mut path = cuda_path().expect("
        CUDA installation not found.
        Please ensure that CUDA is installed and the CUDA_PATH environment variable is set correctly.
        Note: Default paths are used for Linux (/usr/local/cuda) and Windows (C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/), which may not be correct.
    ");
        path.push("include");
        path
    }

    pub fn cccl_include_path() -> PathBuf {
        let mut path = include_path();
        path.push("cccl");
        path
    }

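    /// Root directory of the CUDA toolkit, if one can be located.
    ///
    /// The `CUDA_PATH` environment variable takes precedence; otherwise a few
    /// common per-OS default locations are probed, so the result may not match
    /// a non-standard installation.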
    pub fn cuda_path() -> Option<PathBuf> {
        if let Ok(path) = std::env::var("CUDA_PATH") {
            return Some(PathBuf::from(path));
        }

        #[cfg(target_os = "linux")]
        {
            // Fall back to common system-wide installation locations.
            return if std::fs::exists("/usr/local/cuda").is_ok_and(|exists| exists) {
                Some(PathBuf::from("/usr/local/cuda"))
            } else if std::fs::exists("/opt/cuda").is_ok_and(|exists| exists) {
                Some(PathBuf::from("/opt/cuda"))
            } else if std::fs::exists("/usr/bin/nvcc").is_ok_and(|exists| exists) {
                // The compiler may have been installed directly under /usr.
                Some(PathBuf::from("/usr"))
            } else {
                None
            };
        }

        #[cfg(target_os = "windows")]
        {
            return Some(PathBuf::from(
                "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/",
            ));
        }

        #[allow(unreachable_code)]
        None
    }
}

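// Generated test suites: each `testgen_*` macro expands to the corresponding
// crate's test matrix, instantiated for `TestRuntime` and the listed element types.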
#[cfg(test)]
#[allow(unexpected_cfgs)]
mod tests {
    pub type TestRuntime = crate::CudaRuntime;

    pub use half::{bf16, f16};

    cubecl_core::testgen_all!(f32: [f16, bf16, f32, f64], i32: [i8, i16, i32, i64], u32: [u8, u16, u32, u64]);

    cubecl_std::testgen!();

    cubecl_matmul::testgen_matmul_plane_accelerated!();
    cubecl_matmul::testgen_matmul_plane_mma!();
    cubecl_matmul::testgen_matmul_plane_vecmat!();
    cubecl_matmul::testgen_matmul_unit!();
    cubecl_matmul::testgen_matmul_tma!();
    cubecl_quant::testgen_quant!();

    // TODO: re-instate matmul quantized tests
    cubecl_matmul::testgen_matmul_simple!([f16, bf16, f32]);
    cubecl_std::testgen_tensor_identity!([f16, bf16, f32, u32]);
    cubecl_std::testgen_quantized_view!(f16);
    cubecl_convolution::testgen_conv2d_accelerated!([f16: f16, bf16: bf16, f32: tf32]);
    cubecl_reduce::testgen_reduce!([f16, bf16, f32, f64]);
    cubecl_random::testgen_random!();
    cubecl_attention::testgen_attention!();
    cubecl_reduce::testgen_shared_sum!([f16, bf16, f32, f64]);
}