cubecl_core/runtime.rs

use crate::compute::CubeTask;
use crate::{codegen::Compiler, ir::Elem};
use cubecl_runtime::id::DeviceId;
use cubecl_runtime::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer};

pub use cubecl_runtime::channel;
pub use cubecl_runtime::client;
pub use cubecl_runtime::server;
pub use cubecl_runtime::tune;

/// Runtime for CubeCL.
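///
/// # Example
///
/// Illustrative sketch only: `MyRuntime` is a hypothetical type implementing this trait.
///
/// ```ignore
/// // Obtain a client for the default device and query basic runtime info.
/// let device = <MyRuntime as Runtime>::Device::default();
/// let client = MyRuntime::client(&device);
/// println!("{} on {:?}", MyRuntime::name(&client), MyRuntime::device_id(&device));
/// ```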
pub trait Runtime: Send + Sync + 'static + core::fmt::Debug {
    /// The compiler used to compile the inner representation into tokens.
    type Compiler: Compiler;
    /// The compute server used to run kernels and perform autotuning.
    type Server: ComputeServer<Kernel = Box<dyn CubeTask<Self::Compiler>>, Feature = Feature>;
    /// The channel used to communicate with the compute server.
    type Channel: ComputeChannel<Self::Server>;
    /// The device used to retrieve the compute client.
    type Device: Default + Clone + core::fmt::Debug + Send + Sync;

    /// Fetch the id for the given device.
    fn device_id(device: &Self::Device) -> DeviceId;

    /// Retrieve the compute client from the runtime device.
    fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel>;

    /// The runtime name for the given client.
    fn name(client: &ComputeClient<Self::Server, Self::Channel>) -> &'static str;

    /// Returns `true` if global input array lengths should be added to kernel info.
    fn require_array_lengths() -> bool {
        false
    }

    /// Returns the supported line sizes for the current runtime's compiler.
    fn supported_line_sizes() -> &'static [u8];

    /// Returns all line sizes that are useful to perform IO operations on the given element.
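    ///
    /// For example, with a 4-byte element such as `f32` and supported line sizes of
    /// `[1, 2, 4, 8]`, this yields `[1, 2, 4]`: a line of 8 such elements would exceed
    /// the 128-bit (16-byte) limit applied by the default implementation.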
    fn line_size_elem(elem: &Elem) -> impl Iterator<Item = u8> + Clone {
        Self::supported_line_sizes()
            .iter()
            .filter(|v| **v as usize * elem.size() <= 16) // limit lines to 128 bits (16 bytes)
            .cloned()
    }

    /// Returns the maximum cube count on each dimension that can be launched.
    fn max_cube_count() -> (u32, u32, u32);

    /// Returns `true` if a tensor with the given shape and strides can be read by the runtime.
    fn can_read_tensor(shape: &[usize], strides: &[usize]) -> bool;
}

/// Every feature that can be supported by a [cube runtime](Runtime).
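///
/// # Example
///
/// Illustrative sketch only: matching on a reported feature to check for a
/// specific CMMA tile shape.
///
/// ```ignore
/// fn is_16x16x16(feature: &Feature) -> bool {
///     matches!(feature, Feature::Cmma { m: 16, k: 16, n: 16, .. })
/// }
/// ```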
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Feature {
    /// The plane feature enables all basic warp/subgroup operations.
    Plane,
    /// The cmma feature enables cooperative matrix multiply-accumulate operations.
    Cmma {
        /// Element type of the `a` input matrix.
        a: Elem,
        /// Element type of the `b` input matrix.
        b: Elem,
        /// Element type of the `c` accumulator matrix.
        c: Elem,
        /// Tile size in the `m` dimension.
        m: u8,
        /// Tile size in the `k` dimension.
        k: u8,
        /// Tile size in the `n` dimension.
        n: u8,
    },
    /// The plane size (warp size) supported for cmma operations.
    CmmaWarpSize(i32),
    /// Support for the given element type.
    Type(Elem),
    /// Features supported for floating point atomics.
    AtomicFloat(AtomicFeature),
    /// Features supported for integer atomics.
    AtomicInt(AtomicFeature),
    /// Features supported for unsigned integer atomics.
    AtomicUInt(AtomicFeature),
    /// The pipeline feature enables pipelined (async) operations.
    Pipeline,
    /// The barrier feature enables barrier (async) operations.
    Barrier,
    /// Tensor Memory Accelerator features. Minimum H100/RTX 5000 series for the base set.
    Tma(TmaFeature),
    /// Clustered launches and intra-cluster operations like cluster shared memory.
    CubeCluster,
    /// Enables changing the line size of containers during kernel execution.
    DynamicLineSize,
    /// Enables synchronization within a plane only.
    SyncPlane,
}

/// Atomic features that may be supported by a [cube runtime](Runtime).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum AtomicFeature {
    /// Atomic load and store operations.
    LoadStore,
    /// Atomic add operations.
    Add,
    /// Atomic min and max operations.
    MinMax,
}

/// Tensor memory accelerator (TMA) features that may be supported by a [cube runtime](Runtime).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum TmaFeature {
    /// Base feature set for tensor memory accelerator features. Includes tiling and im2col.
    Base,
    /// im2colWide encoding for tensor map.
    /// TODO: Not yet implemented due to missing `cudarc` support.
    Im2colWide,
}