cubecl_hip/
lib.rs

1#[allow(unused_imports)]
2#[macro_use]
3extern crate derive_new;
4extern crate alloc;
5
6pub mod compute;
7pub mod device;
8pub mod runtime;
9pub use device::*;
10pub use runtime::HipRuntime;
11
12#[cfg(not(feature = "rocwmma"))]
13pub(crate) type HipWmmaCompiler = cubecl_cpp::hip::mma::WmmaIntrinsicCompiler;
14
15#[cfg(feature = "rocwmma")]
16pub(crate) type HipWmmaCompiler = cubecl_cpp::hip::mma::RocWmmaCompiler;
17
#[cfg(test)]
mod tests {
    // Runtime alias picked up by the `testgen*` macros below; they presumably
    // look up `TestRuntime` by name in this module — TODO confirm.
    pub type TestRuntime = super::HipRuntime;

    // Half-precision element types named in the macro argument lists.
    use half::bf16;
    use half::f16;

    // Suites generated unconditionally whenever tests are built.
    cubecl_std::testgen!();
    cubecl_core::testgen_all!(f32: [f16, f32], i32: [i16, i32], u32: [u16, u32]);
    cubecl_quant::testgen_quant!();

    // Matmul suites, each gated behind opt-in feature flags.
    #[cfg(feature = "matmul_tests_plane")]
    cubecl_matmul::testgen_matmul_plane_accelerated!();
    #[cfg(all(feature = "matmul_tests_plane", feature = "matmul_tests_vecmat"))]
    cubecl_matmul::testgen_matmul_vecmat_accelerated!();
    #[cfg(feature = "matmul_tests_simple")]
    cubecl_matmul::testgen_matmul_simple!([f16, f32]);

    // Reduce suites.
    cubecl_reduce::testgen_reduce!([f16, bf16, f32, f64]);
    cubecl_reduce::testgen_shared_sum!([f32]);
}