cubecl_core/runtime.rs
1use crate::compute::CubeTask;
2use crate::{codegen::Compiler, ir::Elem};
3use cubecl_runtime::id::DeviceId;
4use cubecl_runtime::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer};
5
6pub use cubecl_runtime::channel;
7pub use cubecl_runtime::client;
8pub use cubecl_runtime::server;
9pub use cubecl_runtime::tune;
10
11/// Runtime for the CubeCL.
12pub trait Runtime: Send + Sync + 'static + core::fmt::Debug {
13 /// The compiler used to compile the inner representation into tokens.
14 type Compiler: Compiler;
15 /// The compute server used to run kernels and perform autotuning.
16 type Server: ComputeServer<Kernel = Box<dyn CubeTask<Self::Compiler>>, Feature = Feature>;
17 /// The channel used to communicate with the compute server.
18 type Channel: ComputeChannel<Self::Server>;
19 /// The device used to retrieve the compute client.
20 type Device: Default + Clone + core::fmt::Debug + Send + Sync;
21
22 /// Fetch the id for the given device.
23 fn device_id(device: &Self::Device) -> DeviceId;
24
25 /// Retrieve the compute client from the runtime device.
26 fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel>;
27
28 /// The runtime name on the given device.
29 fn name(client: &ComputeClient<Self::Server, Self::Channel>) -> &'static str;
30
31 /// Return true if global input array lengths should be added to kernel info.
32 fn require_array_lengths() -> bool {
33 false
34 }
35
36 /// Returns the supported line sizes for the current runtime's compiler.
37 fn supported_line_sizes() -> &'static [u8];
38
39 /// Returns all line sizes that are useful to perform IO operation on the given element.
40 fn line_size_elem(elem: &Elem) -> impl Iterator<Item = u8> + Clone {
41 Self::supported_line_sizes()
42 .iter()
43 .filter(|v| **v as usize * elem.size() <= 16)
44 .cloned() // 128 bits
45 }
46
47 /// Returns the maximum cube count on each dimension that can be launched.
48 fn max_cube_count() -> (u32, u32, u32);
49
50 fn can_read_tensor(shape: &[usize], strides: &[usize]) -> bool;
51}
52
/// Every feature that can be supported by a [cube runtime](Runtime).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Feature {
    /// The plane feature enables all basic warp/subgroup operations.
    Plane,
    /// The cmma feature enables cooperative matrix-multiply and accumulate operations.
    Cmma {
        /// Element type of the `a` input matrix.
        a: Elem,
        /// Element type of the `b` input matrix.
        b: Elem,
        /// Element type of the `c` accumulator matrix.
        c: Elem,
        /// Supported tile size along the `m` dimension.
        m: u8,
        /// Supported tile size along the `k` dimension.
        k: u8,
        /// Supported tile size along the `n` dimension.
        n: u8,
    },
    /// A warp/plane size supported for cmma operations.
    CmmaWarpSize(i32),
    /// The given element type is supported by the runtime.
    Type(Elem),
    /// Features supported for floating point atomics.
    AtomicFloat(AtomicFeature),
    /// Features supported for integer atomics.
    AtomicInt(AtomicFeature),
    /// Features supported for unsigned integer atomics.
    AtomicUInt(AtomicFeature),
    /// The pipeline feature enables pipelined (async) operations
    Pipeline,
    /// The barrier feature enables barrier (async) operations
    Barrier,
    /// Tensor Memory Accelerator features. Minimum H100/RTX 5000 series for base set
    Tma(TmaFeature),
    /// Clustered launches and intra-cluster operations like cluster shared memory
    CubeCluster,
    /// Enables to change the line size of containers during kernel execution.
    DynamicLineSize,
    /// Enables synchronization within a plane only
    SyncPlane,
}
88
/// Atomic features that may be supported by a [cube runtime](Runtime).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum AtomicFeature {
    /// Atomic load and store operations are supported.
    LoadStore,
    /// Atomic addition is supported.
    Add,
    /// Atomic min and max operations are supported.
    MinMax,
}
96
/// Tensor Memory Accelerator (TMA) features that may be supported by a
/// [cube runtime](Runtime).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum TmaFeature {
    /// Base feature set for tensor memory accelerator features. Includes tiling and im2col
    Base,
    /// im2colWide encoding for tensor map.
    /// TODO: Not yet implemented due to missing `cudarc` support
    Im2colWide,
}