cubecl_core/
runtime.rs

1use crate::codegen::Compiler;
2use crate::compute::CubeTask;
3use cubecl_common::device::Device;
4use cubecl_ir::{StorageType, TargetProperties};
5use cubecl_runtime::{client::ComputeClient, server::ComputeServer};
6
7pub use cubecl_runtime::client;
8pub use cubecl_runtime::server;
9pub use cubecl_runtime::tune;
10
/// Maximum width, in bits, of a single load.
///
/// May want to make this a property in the future, since Nvidia seems to have some support for
/// 256-bit loads on Blackwell.
const LOAD_WIDTH: usize = 128;
14
15/// Runtime for the CubeCL.
16pub trait Runtime: Send + Sync + 'static + core::fmt::Debug {
17    /// The compiler used to compile the inner representation into tokens.
18    type Compiler: Compiler;
19    /// The compute server used to run kernels and perform autotuning.
20    type Server: ComputeServer<Kernel = Box<dyn CubeTask<Self::Compiler>>>;
21    /// The device used to retrieve the compute client.
22    type Device: Device;
23
24    /// Retrieve the compute client from the runtime device.
25    fn client(device: &Self::Device) -> ComputeClient<Self::Server>;
26
27    /// The runtime name on the given device.
28    fn name(client: &ComputeClient<Self::Server>) -> &'static str;
29
30    /// Return true if global input array lengths should be added to kernel info.
31    fn require_array_lengths() -> bool {
32        false
33    }
34
35    /// Returns the supported line sizes for the current runtime's compiler.
36    fn supported_line_sizes() -> &'static [u8];
37
38    /// The maximum line size that can be used for global buffer bindings.
39    fn max_global_line_size() -> u8 {
40        u8::MAX
41    }
42
43    /// Returns all line sizes that are useful to perform optimal IO operation on the given element.
44    fn io_optimized_line_sizes(elem: &StorageType) -> impl Iterator<Item = u8> + Clone {
45        let max = (LOAD_WIDTH / elem.size_bits()) as u8;
46        let supported = Self::supported_line_sizes();
47        supported.iter().filter(move |v| **v <= max).cloned()
48    }
49
50    /// Returns all line sizes that are useful to perform optimal IO operation on the given element.
51    /// Ignores native support, and allows all line sizes. This means the returned size may be
52    /// unrolled, and may not support dynamic indexing.
53    fn io_optimized_line_sizes_unchecked(size: usize) -> impl Iterator<Item = u8> + Clone {
54        let size_bits = size * 8;
55        let max = LOAD_WIDTH / size_bits;
56        let max = usize::min(Self::max_global_line_size() as usize, max);
57
58        // If the max is 8, we want to test 1, 2, 4, 8 which is log2(8) + 1.
59        let num_candidates = f32::log2(max as f32) as u32 + 1;
60
61        (0..num_candidates).map(|i| 2u8.pow(i)).rev()
62    }
63
64    /// Returns the maximum cube count on each dimension that can be launched.
65    fn max_cube_count() -> (u32, u32, u32);
66
67    fn can_read_tensor(shape: &[usize], strides: &[usize]) -> bool;
68
69    /// Returns the properties of the target hardware architecture.
70    fn target_properties() -> TargetProperties;
71}