use crate::{
codegen::Compiler,
compute::{CubeCount, CubeTask},
ir::Elem,
};
use cubecl_runtime::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer};
pub use cubecl_runtime::channel;
pub use cubecl_runtime::client;
pub use cubecl_runtime::server;
pub use cubecl_runtime::tune;
pub use cubecl_runtime::ExecutionMode;
/// A compute runtime.
///
/// Bundles together the shader compiler, the compute server that executes kernels, the channel
/// used to talk to that server, and the device type for which clients are created.
pub trait Runtime: core::fmt::Debug + Send + Sync + 'static {
    /// Compiler used to produce kernels for this runtime.
    type Compiler: Compiler;

    /// Compute server backing this runtime.
    ///
    /// Its kernels are boxed `CubeTask`s, it is dispatched with a `CubeCount`, and its
    /// capabilities are reported as a `FeatureSet`.
    type Server: ComputeServer<
        FeatureSet = FeatureSet,
        Kernel = Box<dyn CubeTask>,
        DispatchOptions = CubeCount<Self::Server>,
    >;

    /// Channel through which a client reaches `Self::Server`.
    type Channel: ComputeChannel<Self::Server>;

    /// Device type a client can be requested for.
    type Device;

    /// Returns a compute client for `device`.
    fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel>;

    /// Returns the name of this runtime.
    fn name() -> &'static str;

    /// Defaults to `false`.
    ///
    /// NOTE(review): presumably tells code generation whether array lengths must be supplied
    /// explicitly for this runtime — confirm against callers.
    fn require_array_lengths() -> bool {
        false
    }
}
/// Set of optional [`Feature`]s.
///
/// Backed by an ordered set, so each feature is stored at most once.
#[derive(Default, Debug, Clone)]
pub struct FeatureSet {
    set: alloc::collections::BTreeSet<Feature>,
}

impl FeatureSet {
    /// Creates a set containing every feature in `features`.
    ///
    /// Duplicate entries in `features` are collapsed into a single entry.
    pub fn new(features: &[Feature]) -> Self {
        Self {
            set: features.iter().copied().collect(),
        }
    }

    /// Returns `true` if `feature` has been registered in this set.
    pub fn enabled(&self, feature: Feature) -> bool {
        self.set.contains(&feature)
    }

    /// Adds `feature` to the set.
    ///
    /// Returns `true` if the feature was newly added and `false` if it was already present.
    pub fn register(&mut self, feature: Feature) -> bool {
        self.set.insert(feature)
    }
}

/// An optional capability that can be recorded in a [`FeatureSet`].
// `Ord` is required because features are stored in a `BTreeSet`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Feature {
    /// Subcube support.
    // NOTE(review): the exact set of subcube operations this covers is not defined here.
    Subcube,
    /// CMMA support for the given element types and `m`/`k`/`n` sizes.
    // NOTE(review): presumably cooperative-matrix multiply-accumulate, with `a`/`b` as the input
    // element types, `c` as the accumulator type and `m`×`k`×`n` as the tile shape — confirm.
    Cmma {
        a: Elem,
        b: Elem,
        c: Elem,
        m: u8,
        k: u8,
        n: u8,
    },
}