// cubecl_wgpu/lib.rs

1#[macro_use]
2extern crate derive_new;
3
4extern crate alloc;
5
6mod backend;
7mod compiler;
8mod compute;
9mod device;
10mod element;
11mod graphics;
12mod runtime;
13
14pub use compiler::base::*;
15pub use compiler::wgsl::WgslCompiler;
16pub use compute::*;
17pub use device::*;
18pub use element::*;
19pub use graphics::*;
20pub use runtime::*;
21
22#[cfg(feature = "spirv")]
23pub use backend::vulkan;
24
25#[cfg(all(feature = "msl", target_os = "macos"))]
26pub use backend::metal;
27
// Test suite for the WGSL configuration: compiled only for test builds in which
// neither the `spirv` nor the `msl` feature is enabled.
#[cfg(all(test, not(feature = "spirv"), not(feature = "msl")))]
// NOTE(review): presumably silences `unexpected_cfgs` warnings triggered by cfg
// names used inside the expanded test-generation macros — confirm.
#[allow(unexpected_cfgs)]
mod tests {
    // NOTE(review): presumably the runtime alias the test-generation macros
    // below look up by name — confirm against the macro definitions.
    pub type TestRuntime = crate::WgpuRuntime;
    use half::f16;

    // 64-bit integer types (i64, u64) are included for WGSL because wgpu
    // supports them, even though they do not exist in native WebGPU.
    //
    // f16 is included as well: it is an extension, but one that is supported by
    // both wgpu and WebGPU.
    cubecl_core::testgen_all!(f32: [f16, f32], i32: [i32, i64], u32: [u32, u64]);
    cubecl_std::testgen!();
    // NOTE(review): `flex32` is not imported in this module; presumably the
    // macro resolves it itself — confirm.
    cubecl_std::testgen_tensor_identity!([flex32, f32, u32]);
    cubecl_std::testgen_quantized_view!(f32);
}
43
// Test suite for the SPIR-V configuration: compiled only for test builds with
// the `spirv` feature enabled.
#[cfg(all(test, feature = "spirv"))]
// NOTE(review): presumably silences `unexpected_cfgs` warnings triggered by cfg
// names used inside the expanded test-generation macros — confirm.
#[allow(unexpected_cfgs)]
mod tests_spirv {
    // NOTE(review): presumably the runtime alias the test-generation macros
    // below look up by name — confirm against the macro definitions.
    pub type TestRuntime = crate::WgpuRuntime;
    use cubecl_core::flex32;
    use half::f16;

    // This configuration lists more element types than the other two suites in
    // this file: `flex32` as a float type, plus 8- and 16-bit integers.
    cubecl_core::testgen_all!(
        f32: [f16, flex32, f32],
        i32: [i8, i16, i32, i64],
        u32: [u8, u16, u32, u64]
    );
    cubecl_std::testgen!();
    cubecl_std::testgen_tensor_identity!([f16, flex32, f32, u32]);
    cubecl_std::testgen_quantized_view!(f16);
}
56
// Test suite for the MSL configuration: compiled only for test builds with the
// `msl` feature enabled.
//
// NOTE(review): unlike the `backend::metal` re-export in this file, this gate
// does not also require `target_os = "macos"` — confirm that is intended.
#[cfg(all(test, feature = "msl"))]
// NOTE(review): presumably silences `unexpected_cfgs` warnings triggered by cfg
// names used inside the expanded test-generation macros — confirm.
#[allow(unexpected_cfgs)]
mod tests_msl {
    // NOTE(review): presumably the runtime alias the test-generation macros
    // below look up by name — confirm against the macro definitions.
    pub type TestRuntime = crate::WgpuRuntime;
    use half::f16;

    // Integer coverage here is 16- and 32-bit only; no 64-bit types are listed.
    cubecl_core::testgen_all!(f32: [f16, f32], i32: [i16, i32], u32: [u16, u32]);
    cubecl_std::testgen!();
    // NOTE(review): `flex32` is not imported in this module; presumably the
    // macro resolves it itself — confirm.
    cubecl_std::testgen_tensor_identity!([f16, flex32, f32, u32]);
    cubecl_std::testgen_quantized_view!(f16);
}