cubecl_core/runtime.rs
use crate::DeviceId;
use crate::{codegen::Compiler, compute::CubeTask, ir::Elem};
use cubecl_runtime::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer};

pub use cubecl_runtime::channel;
pub use cubecl_runtime::client;
pub use cubecl_runtime::server;
pub use cubecl_runtime::tune;

/// Runtime for CubeCL.
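///
/// # Example
///
/// A minimal usage sketch; `MyRuntime` stands in for a concrete runtime
/// implementation (e.g. a wgpu or CUDA backend) and is not defined in this crate.
///
/// ```rust,ignore
/// // Grab the default device, retrieve its compute client, and query the runtime name.
/// let device = <MyRuntime as Runtime>::Device::default();
/// let client = MyRuntime::client(&device);
/// println!("Running on {}", MyRuntime::name(&client));
/// ```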
pub trait Runtime: Send + Sync + 'static + core::fmt::Debug {
    /// The compiler used to compile the inner representation into tokens.
    type Compiler: Compiler;
    /// The compute server used to run kernels and perform autotuning.
    type Server: ComputeServer<Kernel = Box<dyn CubeTask<Self::Compiler>>, Feature = Feature>;
    /// The channel used to communicate with the compute server.
    type Channel: ComputeChannel<Self::Server>;
    /// The device used to retrieve the compute client.
    type Device: Default + Clone + core::fmt::Debug + Send + Sync;

    /// Fetch the id for the given device.
    fn device_id(device: &Self::Device) -> DeviceId;

    /// Retrieve the compute client from the runtime device.
    fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel>;

    /// The runtime name on the given device.
    fn name(client: &ComputeClient<Self::Server, Self::Channel>) -> &'static str;

    /// Return true if global input array lengths should be added to kernel info.
    fn require_array_lengths() -> bool {
        false
    }

    /// Returns the supported line sizes for the current runtime's compiler.
    fn supported_line_sizes() -> &'static [u8];

    /// Returns all line sizes that are useful to perform IO operations on the given element.
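    ///
    /// Only sizes whose total byte width fits into a 128-bit (16-byte) load are kept.
    ///
    /// # Example
    ///
    /// A sketch assuming a runtime whose `supported_line_sizes()` returns `[1, 2, 4, 8]`;
    /// `MyRuntime` is a hypothetical implementation, not a type defined in this crate.
    ///
    /// ```rust,ignore
    /// // For `f32` (4 bytes per element), 8 * 4 = 32 bytes exceeds the 16-byte limit,
    /// // so only 1, 2 and 4 remain.
    /// let sizes: Vec<u8> = MyRuntime::line_size_elem(&Elem::Float(FloatKind::F32)).collect();
    /// assert_eq!(sizes, vec![1, 2, 4]);
    /// ```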
    fn line_size_elem(elem: &Elem) -> impl Iterator<Item = u8> + Clone {
        Self::supported_line_sizes()
            .iter()
            .filter(|v| **v as usize * elem.size() <= 16) // keep loads within 128 bits (16 bytes)
            .cloned()
    }

    /// Returns the maximum cube count on each dimension that can be launched.
    fn max_cube_count() -> (u32, u32, u32);
}

/// Every feature that can be supported by a [cube runtime](Runtime).
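///
/// # Example
///
/// A sketch of a typical feature check before choosing a kernel. The
/// `properties()` / `feature_enabled()` accessors on the client are assumed here
/// and may differ from the exact `cubecl_runtime` API.
///
/// ```rust,ignore
/// // `client` is a `ComputeClient` obtained through `Runtime::client`.
/// if client.properties().feature_enabled(Feature::Plane) {
///     // Launch the plane (warp/subgroup) optimized kernel.
/// } else {
///     // Fall back to the scalar variant.
/// }
/// ```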
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Feature {
    /// The plane feature enables all basic warp/subgroup operations.
    Plane,
    /// The cmma feature enables cooperative matrix-multiply and accumulate operations.
    Cmma {
        /// Element type of the `a` input matrix.
        a: Elem,
        /// Element type of the `b` input matrix.
        b: Elem,
        /// Element type of the `c` accumulator matrix.
        c: Elem,
        /// Tile size along the `m` dimension.
        m: u8,
        /// Tile size along the `k` dimension.
        k: u8,
        /// Tile size along the `n` dimension.
        n: u8,
    },
    /// The plane (warp) size required for cmma operations.
    CmmaWarpSize(i32),
    /// Support for the given element type.
    Type(Elem),
    /// Features supported for floating-point atomics. For integers, all methods are supported
    /// as long as the type itself is.
    AtomicFloat(AtomicFeature),
    /// The pipeline feature enables pipelined (async) operations.
    Pipeline,
    /// The barrier feature enables barrier (async) operations.
    Barrier,
    /// Tensor Memory Accelerator features. Minimum H100/RTX 5000 series for the base set.
    Tma(TmaFeature),
    /// Clustered launches and intra-cluster operations like cluster shared memory.
    CubeCluster,
    /// Enables changing the line size of containers during kernel execution.
    DynamicLineSize,
}

/// Atomic features that may be supported by a [cube runtime](Runtime).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum AtomicFeature {
    /// Atomic load and store operations.
    LoadStore,
    /// Atomic addition.
    Add,
    /// Atomic min and max operations.
    MinMax,
}

/// Tensor Memory Accelerator (TMA) features that may be supported by a [cube runtime](Runtime).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum TmaFeature {
    /// Base feature set for tensor memory accelerator features. Includes tiling and im2col.
    Base,
    /// im2colWide encoding for tensor map.
    /// TODO: Not yet implemented due to missing `cudarc` support.
    Im2colWide,
}