// Runtime trait, feature set and feature definitions for CubeCL.
use crate::{
codegen::Compiler,
compute::{CubeCount, CubeTask},
ir::Elem,
};
use cubecl_runtime::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer};
pub use cubecl_runtime::channel;
pub use cubecl_runtime::client;
pub use cubecl_runtime::server;
pub use cubecl_runtime::tune;
pub use cubecl_runtime::ExecutionMode;
/// Runtime for the CubeCL.
///
/// A runtime bundles a compiler, a compute server, a channel to that server and a device type,
/// and hands out a [`ComputeClient`] for a given device.
pub trait Runtime: core::fmt::Debug + Send + Sync + 'static {
    /// Compiler that turns the inner representation into tokens.
    type Compiler: Compiler;

    /// Compute server on which kernels are run and autotuning is performed.
    ///
    /// Its kernels are boxed [`CubeTask`]s, its dispatch options are a [`CubeCount`] over this
    /// same server, and its feature set is [`FeatureSet`].
    type Server: ComputeServer<
        Kernel = Box<dyn CubeTask>,
        DispatchOptions = CubeCount<Self::Server>,
        FeatureSet = FeatureSet,
    >;

    /// Channel through which the compute server is reached.
    type Channel: ComputeChannel<Self::Server>;

    /// Device type from which a compute client is obtained.
    type Device;

    /// Name of the runtime.
    fn name() -> &'static str;

    /// Obtain the compute client associated with `device`.
    fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel>;

    /// Whether the lengths of global input arrays should be added to the kernel info.
    ///
    /// The provided implementation answers `false`; a runtime overrides it to opt in.
    fn require_array_lengths() -> bool {
        false
    }
}
/// The set of [features](Feature) supported by a [runtime](Runtime).
///
/// The derived `Default` produces an empty set.
#[derive(Default)]
pub struct FeatureSet {
    /// Registered features, held in an ordered set so each value is stored at most once.
    set: alloc::collections::BTreeSet<Feature>,
}
impl FeatureSet {
    /// Build a feature set holding every entry of `features`.
    ///
    /// Entries that appear more than once in the slice are stored a single time; an empty slice
    /// gives an empty set.
    pub fn new(features: &[Feature]) -> Self {
        Self {
            set: features.iter().copied().collect(),
        }
    }

    /// Tell whether the given [feature](Feature) has been registered in this set, i.e. is
    /// supported by the runtime.
    pub fn enabled(&self, feature: Feature) -> bool {
        let Self { set } = self;
        set.contains(&feature)
    }

    /// Record a [feature](Feature) as supported by the compute server.
    ///
    /// Returns `true` when the feature was newly added and `false` when it was already present.
    ///
    /// This is meant to be called only by a [runtime](Runtime) while it initializes a device.
    pub fn register(&mut self, feature: Feature) -> bool {
        let Self { set } = self;
        set.insert(feature)
    }
}
/// Every feature that can be supported by a [cube runtime](Runtime).
///
/// The derived `Ord` is what lets values be stored in the `BTreeSet` backing [`FeatureSet`];
/// it compares variants, then fields, in declaration order.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Feature {
    /// The subcube feature enables all basic warp/subgroup operations.
    Subcube,
    /// The cmma feature enables cooperative matrix-multiply and accumulate operations.
    ///
    /// Two `Cmma` values are equal only when all six fields match, so every distinct
    /// combination counts as its own feature.
    Cmma {
        /// Element type `a` — presumably that of the first multiplied matrix (TODO confirm).
        a: Elem,
        /// Element type `b` — presumably that of the second multiplied matrix (TODO confirm).
        b: Elem,
        /// Element type `c` — presumably that of the accumulator/output matrix (TODO confirm).
        c: Elem,
        /// Size `m` — presumably the row count of the `a` and output tiles (TODO confirm).
        m: u8,
        /// Size `k` — presumably the shared inner dimension of `a` and `b` (TODO confirm).
        k: u8,
        /// Size `n` — presumably the column count of the `b` and output tiles (TODO confirm).
        n: u8,
    },
}