// cubecl_core/runtime.rs
1use crate::{codegen::Compiler, compute::CubeTask, ir::Elem};
2use cubecl_runtime::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer};
3
4pub use cubecl_runtime::channel;
5pub use cubecl_runtime::client;
6pub use cubecl_runtime::server;
7pub use cubecl_runtime::tune;
8pub use cubecl_runtime::ExecutionMode;
9
10/// Runtime for the CubeCL.
11pub trait Runtime: Send + Sync + 'static + core::fmt::Debug {
12 /// The compiler used to compile the inner representation into tokens.
13 type Compiler: Compiler;
14 /// The compute server used to run kernels and perform autotuning.
15 type Server: ComputeServer<Kernel = Box<dyn CubeTask<Self::Compiler>>, Feature = Feature>;
16 /// The channel used to communicate with the compute server.
17 type Channel: ComputeChannel<Self::Server>;
18 /// The device used to retrieve the compute client.
19 type Device: Default + Clone + core::fmt::Debug;
20
21 /// Retrieve the compute client from the runtime device.
22 fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel>;
23
24 /// The runtime name.
25 fn name() -> &'static str;
26
27 /// The default extension for the runtime's kernel/shader code.
28 /// Might change based on which compiler is used.
29 fn extension() -> &'static str;
30
31 /// Return true if global input array lengths should be added to kernel info.
32 fn require_array_lengths() -> bool {
33 false
34 }
35
36 /// Returns the supported line sizes for the current runtime's compiler.
37 fn supported_line_sizes() -> &'static [u8];
38
39 /// Returns all line sizes that are useful to perform IO operation on the given element.
40 fn line_size_elem(elem: &Elem) -> impl Iterator<Item = u8> + Clone {
41 Self::supported_line_sizes()
42 .iter()
43 .filter(|v| **v as usize * elem.size() <= 16)
44 .cloned() // 128 bits
45 }
46
47 /// Returns the maximum cube count on each dimension that can be launched.
48 fn max_cube_count() -> (u32, u32, u32);
49}
50
51/// Every feature that can be supported by a [cube runtime](Runtime).
52#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
53pub enum Feature {
54 /// The plane feature enables all basic warp/subgroup operations.
55 Plane,
56 /// The cmma feature enables cooperative matrix-multiply and accumulate operations.
57 Cmma {
58 a: Elem,
59 b: Elem,
60 c: Elem,
61 m: u8,
62 k: u8,
63 n: u8,
64 },
65 CmmaWarpSize(i32),
66 Type(Elem),
67 /// Features supported for floating point atomics. For integers, all methods are supported as
68 /// long as the type is.
69 AtomicFloat(AtomicFeature),
70}
71
/// Atomic features that may be supported by a [cube runtime](Runtime).
///
/// Carried by [`Feature::AtomicFloat`]. The derived `PartialOrd`/`Ord` follow declaration order,
/// so `LoadStore < Add < MinMax`; reordering variants changes comparison results.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum AtomicFeature {
    /// Atomic load and store.
    LoadStore,
    /// Atomic addition.
    Add,
    /// Atomic minimum and maximum.
    MinMax,
}