Skip to main content

cubecl_runtime/
runtime.rs

use alloc::boxed::Box;
use alloc::vec::Vec;
use cubecl_common::device::{Device, DeviceId};
use cubecl_ir::TargetProperties;
use cubecl_zspace::{Shape, Strides};

use crate::{
    client::ComputeClient,
    compiler::{Compiler, CubeTask},
    server::ComputeServer,
};
12
13/// Runtime for the `CubeCL`.
14pub trait Runtime: Sized + Send + Sync + 'static + core::fmt::Debug + Clone {
15    /// The compiler used to compile the inner representation into tokens.
16    type Compiler: Compiler;
17    /// The compute server used to run kernels and perform autotuning.
18    type Server: ComputeServer<Kernel = Box<dyn CubeTask<Self::Compiler>>>;
19    /// The device used to retrieve the compute client.
20    type Device: Device;
21
22    /// Retrieve the compute client from the runtime device.
23    fn client(device: &Self::Device) -> ComputeClient<Self>;
24
25    /// The runtime name on the given device.
26    fn name(client: &ComputeClient<Self>) -> &'static str;
27
28    /// Return true if global input array lengths should be added to kernel info.
29    fn require_array_lengths() -> bool {
30        false
31    }
32
33    /// Returns the maximum cube count on each dimension that can be launched.
34    fn max_cube_count() -> (u32, u32, u32);
35
36    /// Whether a tensor with `shape` and `strides` can be read as is. If the result is false, the
37    /// tensor should be made contiguous before reading.
38    fn can_read_tensor(shape: &Shape, strides: &Strides) -> bool;
39
40    /// Returns the properties of the target hardware architecture.
41    fn target_properties() -> TargetProperties;
42
43    /// Returns all devices available under the provided type id.
44    fn enumerate_devices(
45        type_id: u16,
46        info: &<Self::Server as ComputeServer>::Info,
47    ) -> Vec<DeviceId>;
48    /// Returns all devices that can be handled by the runtime.
49    fn enumerate_all_devices(info: &<Self::Server as ComputeServer>::Info) -> Vec<DeviceId> {
50        Self::enumerate_devices(0, info)
51    }
52}