// cubecl_core/runtime.rs

1use crate::DeviceId;
2use crate::{codegen::Compiler, compute::CubeTask, ir::Elem};
3use cubecl_runtime::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer};
4
5pub use cubecl_runtime::channel;
6pub use cubecl_runtime::client;
7pub use cubecl_runtime::server;
8pub use cubecl_runtime::tune;
9
10/// Runtime for the CubeCL.
11pub trait Runtime: Send + Sync + 'static + core::fmt::Debug {
12    /// The compiler used to compile the inner representation into tokens.
13    type Compiler: Compiler;
14    /// The compute server used to run kernels and perform autotuning.
15    type Server: ComputeServer<Kernel = Box<dyn CubeTask<Self::Compiler>>, Feature = Feature>;
16    /// The channel used to communicate with the compute server.
17    type Channel: ComputeChannel<Self::Server>;
18    /// The device used to retrieve the compute client.
19    type Device: Default + Clone + core::fmt::Debug + Send + Sync;
20
21    /// Fetch the id for the given device.
22    fn device_id(device: &Self::Device) -> DeviceId;
23
24    /// Retrieve the compute client from the runtime device.
25    fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel>;
26
27    /// The runtime name on the given device.
28    fn name(client: &ComputeClient<Self::Server, Self::Channel>) -> &'static str;
29
30    /// Return true if global input array lengths should be added to kernel info.
31    fn require_array_lengths() -> bool {
32        false
33    }
34
35    /// Returns the supported line sizes for the current runtime's compiler.
36    fn supported_line_sizes() -> &'static [u8];
37
38    /// Returns all line sizes that are useful to perform IO operation on the given element.
39    fn line_size_elem(elem: &Elem) -> impl Iterator<Item = u8> + Clone {
40        Self::supported_line_sizes()
41            .iter()
42            .filter(|v| **v as usize * elem.size() <= 16)
43            .cloned() // 128 bits
44    }
45
46    /// Returns the maximum cube count on each dimension that can be launched.
47    fn max_cube_count() -> (u32, u32, u32);
48}
49
/// Every feature that can be supported by a [cube runtime](Runtime).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Feature {
    /// The plane feature enables all basic warp/subgroup operations.
    Plane,
    /// The cmma feature enables cooperative matrix-multiply and accumulate operations.
    Cmma {
        /// Element type of the `a` input matrix.
        a: Elem,
        /// Element type of the `b` input matrix.
        b: Elem,
        /// Element type of the `c` matrix (presumably the accumulator — confirm at registration site).
        c: Elem,
        /// Supported tile dimension `m` — assumed from the usual (m × k) · (k × n) CMMA shape convention.
        m: u8,
        /// Supported tile dimension `k` — assumed from the usual (m × k) · (k × n) CMMA shape convention.
        k: u8,
        /// Supported tile dimension `n` — assumed from the usual (m × k) · (k × n) CMMA shape convention.
        n: u8,
    },
    /// Warp/plane size associated with cmma support — NOTE(review): exact semantics
    /// (required vs. supported size) should be confirmed where this feature is registered.
    CmmaWarpSize(i32),
    /// Support for the given element type.
    Type(Elem),
    /// Features supported for floating point atomics. For integers, all methods are supported as
    /// long as the type is.
    AtomicFloat(AtomicFeature),
    /// The pipeline feature enables pipelined (async) operations
    Pipeline,
    /// The barrier feature enables barrier (async) operations
    Barrier,
    /// Tensor Memory Accelerator features. Minimum H100/RTX 5000 series for base set
    Tma(TmaFeature),
    /// Clustered launches and intra-cluster operations like cluster shared memory
    CubeCluster,
    /// Enables to change the line size of containers during kernel execution.
    DynamicLineSize,
}
80
/// Atomic features that may be supported by a [cube runtime](Runtime).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum AtomicFeature {
    /// Atomic load and store operations.
    LoadStore,
    /// Atomic addition.
    Add,
    /// Atomic min and max operations.
    MinMax,
}
88
/// Tensor Memory Accelerator features that may be supported by a [cube runtime](Runtime).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum TmaFeature {
    /// Base feature set for tensor memory accelerator features. Includes tiling and im2col
    Base,
    /// im2colWide encoding for tensor map.
    /// TODO: Not yet implemented due to missing `cudarc` support
    Im2colWide,
}
97}