// cubecl_runtime/server.rs
use crate::{
memory_management::{
memory_pool::{SliceBinding, SliceHandle},
MemoryHandle, MemoryUsage,
},
storage::{BindingResource, ComputeStorage},
ExecutionMode,
};
use alloc::vec::Vec;
use core::{fmt::Debug, future::Future};
use cubecl_common::benchmark::TimestampsResult;
/// A compute server handles resources and runs computations over those resources.
///
/// All state held by a server is mutable, so it should only ever be reached through the
/// [compute channel](crate::channel::ComputeChannel), which is what provides thread safety.
pub trait ComputeServer: Send + Debug + Sized {
    /// Kernel type, which defines the computation algorithms.
    type Kernel: Send;
    /// [Storage](ComputeStorage) type, which defines how data is stored and accessed.
    type Storage: ComputeStorage;
    /// Type describing the features the server may support.
    type Feature: Ord + Copy + Debug + Send + Sync;

    /// Reads the resource behind `binding` and resolves to its content as owned bytes.
    fn read(&mut self, binding: Binding) -> impl Future<Output = Vec<u8>> + Send + 'static;

    /// Resolves `binding` to the storage resource it refers to.
    fn get_resource(&mut self, binding: Binding) -> BindingResource<Self>;

    /// Stores `data` (raw bytes) and returns the memory handle of the stored resource.
    fn create(&mut self, data: &[u8]) -> Handle;

    /// Reserves `size` bytes in the storage and returns a handle over them.
    fn empty(&mut self, size: usize) -> Handle;

    /// Runs `kernel` over the memory referenced by `bindings`.
    ///
    /// A kernel gets mutable access to every resource it is given and is responsible for
    /// determining which of them are read and which are written.
    ///
    /// # Safety
    ///
    /// With mode [ExecutionMode::Unchecked], out-of-bound reads and writes can happen.
    unsafe fn execute(
        &mut self,
        kernel: Self::Kernel,
        count: CubeCount,
        bindings: Vec<Binding>,
        kind: ExecutionMode,
    );

    /// Flushes all outstanding tasks in the server.
    fn flush(&mut self);

    /// Waits until every task in the server has completed.
    fn sync(&mut self) -> impl Future<Output = ()> + Send + 'static;

    /// Waits until every task in the server has completed.
    ///
    /// Resolves to the (approximate) total amount of GPU work done since the last sync.
    fn sync_elapsed(&mut self) -> impl Future<Output = TimestampsResult> + Send + 'static;

    /// Reports the current memory usage of the server.
    fn memory_usage(&self) -> MemoryUsage;

    /// Turns timestamp collection on.
    fn enable_timestamps(&mut self);

    /// Turns timestamp collection off.
    fn disable_timestamps(&mut self);
}
/// Server handle containing the [memory handle](SliceHandle) and optional byte offsets
/// into that memory.
#[derive(new, Debug)]
pub struct Handle {
/// Memory handle.
pub memory: SliceHandle,
/// Start offset in bytes, or `None` when no start offset has been set.
pub offset_start: Option<u64>,
/// End offset in bytes, or `None` when no end offset has been set.
// NOTE(review): the reference point of this offset (measured from the start of the memory,
// or trimmed from its end) is not visible in this file — confirm against the consumers.
pub offset_end: Option<u64>,
}
impl Handle {
    /// Increases the start offset by `offset` bytes and returns the updated handle.
    ///
    /// If no start offset was set yet, `offset` becomes the start offset.
    pub fn offset_start(mut self, offset: u64) -> Self {
        self.offset_start = Some(match self.offset_start {
            Some(current) => current + offset,
            None => offset,
        });
        self
    }

    /// Increases the end offset by `offset` bytes and returns the updated handle.
    ///
    /// If no end offset was set yet, `offset` becomes the end offset.
    pub fn offset_end(mut self, offset: u64) -> Self {
        self.offset_end = Some(match self.offset_end {
            Some(current) => current + offset,
            None => offset,
        });
        self
    }
}
/// Binding of a [tensor handle](Handle) to execute a kernel.
///
/// Carries the same optional byte offsets as the [Handle] it is created from.
#[derive(new, Debug)]
pub struct Binding {
/// Memory binding.
pub memory: SliceBinding,
/// Start offset in bytes, or `None` when no start offset has been set.
pub offset_start: Option<u64>,
/// End offset in bytes, or `None` when no end offset has been set.
// NOTE(review): the reference point of this offset is not visible in this file — confirm
// against the code that resolves bindings to resources.
pub offset_end: Option<u64>,
}
impl Handle {
    /// Tells whether the tensor handle can be reused in place.
    ///
    /// The answer is delegated to the underlying memory handle.
    pub fn can_mut(&self) -> bool {
        let Self { memory, .. } = self;
        memory.can_mut()
    }
}
impl Handle {
/// Convert the [handle](Handle) into a [binding](Binding).
pub fn binding(self) -> Binding {
Binding {
memory: MemoryHandle::binding(self.memory),
offset_start: self.offset_start,
offset_end: self.offset_end,
}
}
}
impl Clone for Handle {
    /// Clones the memory handle and copies both offsets.
    fn clone(&self) -> Self {
        let Self {
            memory,
            offset_start,
            offset_end,
        } = self;

        Self {
            memory: memory.clone(),
            offset_start: *offset_start,
            offset_end: *offset_end,
        }
    }
}
impl Clone for Binding {
    /// Clones the memory binding and copies both offsets.
    fn clone(&self) -> Self {
        let Self {
            memory,
            offset_start,
            offset_end,
        } = self;

        Self {
            memory: memory.clone(),
            offset_start: *offset_start,
            offset_end: *offset_end,
        }
    }
}
/// Specifies the number of cubes to be dispatched for a kernel.
///
/// This translates to e.g. a grid for CUDA, or to num_workgroups for wgsl.
pub enum CubeCount {
/// Dispatch a known count of x, y, z cubes.
Static(u32, u32, u32),
/// Dispatch an amount based on the values in this buffer. The buffer should contain a u32 array [x, y, z].
Dynamic(Binding),
}
impl Debug for CubeCount {
    /// Writes a short description of the cube count.
    ///
    /// A static count is printed as `(x, y, z)`; a dynamic count is printed as the literal
    /// `binding`, without showing the binding itself.
    ///
    /// The `core::fmt` paths are used (rather than `std::fmt`) to match the rest of this
    /// file, which only relies on `core` and `alloc`, so the impl also builds without `std`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            CubeCount::Static(x, y, z) => write!(f, "({x}, {y}, {z})"),
            CubeCount::Dynamic(_) => f.write_str("binding"),
        }
    }
}
impl Clone for CubeCount {
fn clone(&self) -> Self {
match self {
Self::Static(x, y, z) => Self::Static(*x, *y, *z),
Self::Dynamic(handle) => Self::Dynamic(handle.clone()),
}
}
}