cubecl-runtime 0.10.0-pre.3

Crate that helps create high-performance async runtimes for CubeCL.
Documentation
use alloc::boxed::Box;
use alloc::vec::Vec;
use cubecl_common::device::{Device, DeviceId};
use cubecl_ir::TargetProperties;
use cubecl_zspace::{Shape, Strides};

use crate::{
    client::ComputeClient,
    compiler::{Compiler, CubeTask},
    server::ComputeServer,
};

/// A `CubeCL` runtime.
///
/// Ties together the compiler, the compute server and the device type, and exposes
/// runtime-wide queries such as launch limits, tensor readability and target properties.
pub trait Runtime: Sized + Send + Sync + 'static + core::fmt::Debug + Clone {
    /// Compiler that turns the inner representation into tokens.
    type Compiler: Compiler;
    /// Compute server that runs kernels and performs autotuning.
    ///
    /// Its kernels are boxed [`CubeTask`]s parameterized by this runtime's compiler.
    type Server: ComputeServer<Kernel = Box<dyn CubeTask<Self::Compiler>>>;
    /// Device type used to look up a [`ComputeClient`].
    type Device: Device;

    /// Returns the compute client associated with `device`.
    fn client(device: &Self::Device) -> ComputeClient<Self>;

    /// Returns the name of this runtime on the device served by `client`.
    fn name(client: &ComputeClient<Self>) -> &'static str;

    /// Whether the lengths of global input arrays should be added to the kernel info.
    ///
    /// The default implementation returns `false`; runtimes that need the lengths
    /// override it.
    fn require_array_lengths() -> bool {
        false
    }

    /// Maximum cube count that can be launched along each of the three dimensions.
    fn max_cube_count() -> (u32, u32, u32);

    /// Returns `true` when a tensor with the given `shape` and `strides` can be read as is.
    ///
    /// When this returns `false`, the tensor should be made contiguous before it is read.
    fn can_read_tensor(shape: &Shape, strides: &Strides) -> bool;

    /// Properties of the target hardware architecture.
    fn target_properties() -> TargetProperties;

    /// Lists the devices available under the type identified by `type_id`.
    fn enumerate_devices(
        type_id: u16,
        info: &<Self::Server as ComputeServer>::Info,
    ) -> Vec<DeviceId>;

    /// Lists every device this runtime can handle.
    ///
    /// The default implementation delegates to [`Runtime::enumerate_devices`] with a
    /// `type_id` of `0`. NOTE(review): this presumes type id `0` covers every device of
    /// the runtime. Implementors for which that does not hold should override this method.
    fn enumerate_all_devices(info: &<Self::Server as ComputeServer>::Info) -> Vec<DeviceId> {
        <Self as Runtime>::enumerate_devices(0, info)
    }
}