cubecl_core/
runtime.rs

1use crate::{codegen::Compiler, compute::CubeTask, ir::Elem};
2use cubecl_runtime::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer};
3
4pub use cubecl_runtime::channel;
5pub use cubecl_runtime::client;
6pub use cubecl_runtime::server;
7pub use cubecl_runtime::tune;
8pub use cubecl_runtime::ExecutionMode;
9
10/// Runtime for the CubeCL.
11pub trait Runtime: Send + Sync + 'static + core::fmt::Debug {
12    /// The compiler used to compile the inner representation into tokens.
13    type Compiler: Compiler;
14    /// The compute server used to run kernels and perform autotuning.
15    type Server: ComputeServer<Kernel = Box<dyn CubeTask<Self::Compiler>>, Feature = Feature>;
16    /// The channel used to communicate with the compute server.
17    type Channel: ComputeChannel<Self::Server>;
18    /// The device used to retrieve the compute client.
19    type Device: Default + Clone + core::fmt::Debug;
20
21    /// Retrieve the compute client from the runtime device.
22    fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel>;
23
24    /// The runtime name.
25    fn name() -> &'static str;
26
27    /// The default extension for the runtime's kernel/shader code.
28    /// Might change based on which compiler is used.
29    fn extension() -> &'static str;
30
31    /// Return true if global input array lengths should be added to kernel info.
32    fn require_array_lengths() -> bool {
33        false
34    }
35
36    /// Returns the supported line sizes for the current runtime's compiler.
37    fn supported_line_sizes() -> &'static [u8];
38
39    /// Returns all line sizes that are useful to perform IO operation on the given element.
40    fn line_size_elem(elem: &Elem) -> impl Iterator<Item = u8> + Clone {
41        Self::supported_line_sizes()
42            .iter()
43            .filter(|v| **v as usize * elem.size() <= 16)
44            .cloned() // 128 bits
45    }
46
47    /// Returns the maximum cube count on each dimension that can be launched.
48    fn max_cube_count() -> (u32, u32, u32);
49}
50
51/// Every feature that can be supported by a [cube runtime](Runtime).
52#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
53pub enum Feature {
54    /// The plane feature enables all basic warp/subgroup operations.
55    Plane,
56    /// The cmma feature enables cooperative matrix-multiply and accumulate operations.
57    Cmma {
58        a: Elem,
59        b: Elem,
60        c: Elem,
61        m: u8,
62        k: u8,
63        n: u8,
64    },
65    CmmaWarpSize(i32),
66    Type(Elem),
67    /// Features supported for floating point atomics. For integers, all methods are supported as
68    /// long as the type is.
69    AtomicFloat(AtomicFeature),
70}
71
/// Atomic features that may be supported by a [cube runtime](Runtime).
///
/// The derived ordering follows the declaration order of the variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum AtomicFeature {
    /// Atomic load and store.
    LoadStore,
    /// Atomic add.
    Add,
    /// Atomic min and max.
    MinMax,
}