1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
use alloc::boxed::Box;
use alloc::vec::Vec;
use cubecl_common::device::{Device, DeviceId};
use cubecl_ir::TargetProperties;
use cubecl_zspace::{Shape, Strides};
use crate::{
client::ComputeClient,
compiler::{Compiler, CubeTask},
server::ComputeServer,
};
/// Entry point tying together the pieces a `CubeCL` backend needs.
///
/// An implementor picks a [`Compiler`], a [`ComputeServer`] whose kernels are boxed
/// [`CubeTask`]s for that compiler, and a [`Device`] type. It also answers capability
/// queries (launch limits, tensor layout support, target properties, device enumeration).
pub trait Runtime: Sized + Send + Sync + 'static + core::fmt::Debug + Clone {
    // ----- Associated types -------------------------------------------------------------

    /// Compiler that lowers the inner representation into tokens.
    type Compiler: Compiler;

    /// Server that executes kernels and performs autotuning.
    ///
    /// Its kernel type is pinned to a boxed [`CubeTask`] for [`Self::Compiler`].
    type Server: ComputeServer<Kernel = Box<dyn CubeTask<Self::Compiler>>>;

    /// Device type from which a [`ComputeClient`] is obtained.
    type Device: Device;

    // ----- Required methods -------------------------------------------------------------

    /// Returns the compute client associated with `device`.
    fn client(device: &Self::Device) -> ComputeClient<Self>;

    /// Returns the runtime's name for the device served by `client`.
    fn name(client: &ComputeClient<Self>) -> &'static str;

    /// Returns the largest cube count that may be launched, as `(x, y, z)`.
    fn max_cube_count() -> (u32, u32, u32);

    /// Reports whether a tensor described by `shape` and `strides` can be read directly.
    ///
    /// When this returns `false`, the tensor should first be made contiguous and then read.
    fn can_read_tensor(shape: &Shape, strides: &Strides) -> bool;

    /// Returns the properties of the target hardware architecture.
    fn target_properties() -> TargetProperties;

    /// Lists the identifiers of every device available for the given `type_id`.
    fn enumerate_devices(
        type_id: u16,
        info: &<Self::Server as ComputeServer>::Info,
    ) -> Vec<DeviceId>;

    // ----- Provided methods -------------------------------------------------------------

    /// Whether kernel info should include the lengths of global input arrays.
    ///
    /// The default answer is `false`; runtimes that need the lengths override it.
    fn require_array_lengths() -> bool {
        false
    }

    /// Lists every device this runtime can handle.
    ///
    /// The default forwards to [`Runtime::enumerate_devices`] with a type id of `0`.
    // NOTE(review): this relies on type id `0` covering all devices for the implementor;
    // runtimes where that does not hold should override this method — confirm per backend.
    fn enumerate_all_devices(info: &<Self::Server as ComputeServer>::Info) -> Vec<DeviceId> {
        <Self as Runtime>::enumerate_devices(0, info)
    }
}