1#[macro_use]
2extern crate derive_new;
3extern crate alloc;
4
5mod compute;
6mod device;
7mod runtime;
8
9pub use device::*;
10pub use runtime::*;
11
12#[cfg(feature = "ptx-wmma")]
13pub(crate) type WmmaCompiler = cubecl_cpp::cuda::mma::PtxWmmaCompiler;
14
15#[cfg(not(feature = "ptx-wmma"))]
16pub(crate) type WmmaCompiler = cubecl_cpp::cuda::mma::CudaWmmaCompiler;
17
/// Helpers for locating the CUDA toolkit installation on the host.
pub mod install {
    use std::path::PathBuf;

    /// Returns the toolkit's header directory, `<cuda_path>/include`.
    ///
    /// # Panics
    ///
    /// Panics if [`cuda_path`] cannot locate a CUDA installation.
    pub fn include_path() -> PathBuf {
        let mut path = cuda_path().expect("
        CUDA installation not found.
        Please ensure that CUDA is installed and the CUDA_PATH environment variable is set correctly.
        Note: Default paths are used for Linux (/usr/local/cuda) and Windows (C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/), which may not be correct.
        ");
        path.push("include");
        path
    }

    /// Returns the CCCL header directory, `<cuda_path>/include/cccl`.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`include_path`].
    pub fn cccl_include_path() -> PathBuf {
        let mut path = include_path();
        path.push("cccl");
        path
    }

    /// Locates the root directory of the CUDA toolkit installation.
    ///
    /// Lookup order:
    /// 1. The `CUDA_PATH` environment variable, if it is set and non-empty. Non-UTF-8
    ///    values are accepted.
    /// 2. On Linux, the first existing entry among `/usr/local/cuda` and `/opt/cuda`.
    ///    Failing those, `/usr` is used when `/usr/bin/nvcc` exists.
    /// 3. On Windows, the highest-versioned `v*` subdirectory of
    ///    `C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/`. If no such subdirectory
    ///    exists, that root directory itself is returned.
    ///
    /// Returns `None` when nothing is found. On platforms other than Linux and Windows,
    /// this happens whenever `CUDA_PATH` is unset.
    pub fn cuda_path() -> Option<PathBuf> {
        // `var_os` keeps non-UTF-8 paths usable. An empty value is treated as unset;
        // otherwise it would turn `include_path()` into the relative path `include`.
        if let Some(path) = std::env::var_os("CUDA_PATH").filter(|path| !path.is_empty()) {
            return Some(PathBuf::from(path));
        }

        #[cfg(target_os = "linux")]
        {
            // (probe, root): `root` is returned for the first `probe` that exists.
            const CANDIDATES: [(&str, &str); 3] = [
                ("/usr/local/cuda", "/usr/local/cuda"),
                ("/opt/cuda", "/opt/cuda"),
                // A toolkit installed into the system prefix puts `nvcc` in `/usr/bin`.
                ("/usr/bin/nvcc", "/usr"),
            ];
            return CANDIDATES
                .iter()
                .find(|(probe, _)| std::fs::exists(*probe).is_ok_and(|exists| exists))
                .map(|(_, root)| PathBuf::from(*root));
        }

        #[cfg(target_os = "windows")]
        {
            // The NVIDIA installer places each toolkit version in its own subdirectory
            // (e.g. `v12.4`), so prefer the newest one over the bare root.
            let root = PathBuf::from("C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/");
            return Some(newest_versioned_install(&root).unwrap_or(root));
        }

        // Only reachable on platforms without a default-location probe. The Linux and
        // Windows blocks above always return, so the lint is silenced there.
        #[allow(unreachable_code)]
        None
    }

    /// Returns the subdirectory of `root` whose name parses as the highest `v<n>[.<n>...]`
    /// version. Returns `None` if `root` cannot be read or contains no such directory.
    #[cfg(target_os = "windows")]
    fn newest_versioned_install(root: &std::path::Path) -> Option<PathBuf> {
        std::fs::read_dir(root)
            .ok()?
            .filter_map(Result::ok)
            .filter_map(|entry| {
                let version = entry.file_name().to_str().and_then(parse_cuda_version)?;
                let path = entry.path();
                path.is_dir().then_some((version, path))
            })
            // Tuples compare by version first; `Vec<u32>` orders numerically per component.
            .max()
            .map(|(_, path)| path)
    }

    /// Parses a toolkit directory name such as `v12.4` into `[12, 4]`.
    /// Returns `None` for any other name.
    #[cfg(target_os = "windows")]
    fn parse_cuda_version(name: &str) -> Option<Vec<u32>> {
        name.strip_prefix('v')?
            .split('.')
            .map(|part| part.parse::<u32>().ok())
            .collect()
    }
}
68
// Test-only module: instantiates the shared cubecl test suites (via the `testgen_*!`
// macros below) for this crate's CUDA runtime.
#[cfg(test)]
// NOTE(review): presumably some macro expansions below use `cfg(...)` names this crate
// does not declare, hence the allow — confirm which ones before removing it.
#[allow(unexpected_cfgs)]
mod tests {
    // Runtime under test. Presumably the `testgen_*!` expansions refer to `TestRuntime`
    // by name — TODO confirm against the macro definitions.
    pub type TestRuntime = crate::CudaRuntime;

    // Brings the half-precision types into scope; `f16` and `bf16` appear in the
    // element-type lists passed to the macros below.
    pub use half::{bf16, f16};

    // Core suite. The element-type lists are grouped under `f32`/`i32`/`u32` keys; their
    // exact meaning is defined by `testgen_all!`.
    cubecl_core::testgen_all!(f32: [f16, bf16, f32, f64], i32: [i8, i16, i32, i64], u32: [u8, u16, u32, u64]);

    cubecl_std::testgen!();

    // Matmul suites, one per algorithm family named by each macro.
    cubecl_matmul::testgen_matmul_plane_accelerated!();
    cubecl_matmul::testgen_matmul_plane_mma!();
    cubecl_matmul::testgen_matmul_plane_vecmat!();
    cubecl_matmul::testgen_matmul_unit!();
    cubecl_matmul::testgen_matmul_tma!();
    cubecl_quant::testgen_quant!();

    // Suites parameterised over explicit element-type lists.
    cubecl_matmul::testgen_matmul_simple!([f16, bf16, f32]);
    cubecl_std::testgen_tensor_identity!([f16, bf16, f32, u32]);
    cubecl_std::testgen_quantized_view!(f16);
    // NOTE(review): `a: b` pairs presumably map a storage type to the compute type used
    // by the accelerated path (e.g. `f32: tf32`) — confirm in the macro.
    cubecl_convolution::testgen_conv2d_accelerated!([f16: f16, bf16: bf16, f32: tf32]);
    cubecl_reduce::testgen_reduce!([f16, bf16, f32, f64]);
    cubecl_random::testgen_random!();
    cubecl_attention::testgen_attention!();
    cubecl_reduce::testgen_shared_sum!([f16, bf16, f32, f64]);
}