cubecl_runtime/
server.rs

1use crate::{
2    memory_management::{
3        memory_pool::{SliceBinding, SliceHandle},
4        MemoryHandle, MemoryUsage,
5    },
6    storage::{BindingResource, ComputeStorage},
7    ExecutionMode,
8};
9use alloc::vec::Vec;
10use core::{fmt::Debug, future::Future};
11use cubecl_common::benchmark::TimestampsResult;
12
13/// The compute server is responsible for handling resources and computations over resources.
14///
15/// Everything in the server is mutable, therefore it should be solely accessed through the
16/// [compute channel](crate::channel::ComputeChannel) for thread safety.
17pub trait ComputeServer: Send + core::fmt::Debug
18where
19    Self: Sized,
20{
21    /// The kernel type defines the computation algorithms.
22    type Kernel: Send;
23    /// The [storage](ComputeStorage) type defines how data is stored and accessed.
24    type Storage: ComputeStorage;
25    /// The type of the features supported by the server.
26    type Feature: Ord + Copy + Debug + Send + Sync;
27
28    /// Given bindings, returns the owned resources as bytes.
29    fn read(
30        &mut self,
31        bindings: Vec<Binding>,
32    ) -> impl Future<Output = Vec<Vec<u8>>> + Send + 'static;
33
34    /// Given a resource handle, returns the storage resource.
35    fn get_resource(&mut self, binding: Binding) -> BindingResource<Self>;
36
37    /// Given a resource as bytes, stores it and returns the memory handle.
38    fn create(&mut self, data: &[u8]) -> Handle;
39
40    /// Reserves `size` bytes in the storage, and returns a handle over them.
41    fn empty(&mut self, size: usize) -> Handle;
42
43    /// Executes the `kernel` over the given memory `handles`.
44    ///
45    /// Kernels have mutable access to every resource they are given
46    /// and are responsible of determining which should be read or written.
47    ///
48    /// # Safety
49    ///
50    /// When executing with mode [ExecutionMode::Unchecked], out-of-bound reads and writes can happen.
51    unsafe fn execute(
52        &mut self,
53        kernel: Self::Kernel,
54        count: CubeCount,
55        bindings: Vec<Binding>,
56        kind: ExecutionMode,
57    );
58
59    /// Flush all outstanding tasks in the server.
60    fn flush(&mut self);
61
62    /// Wait for the completion of every task in the server.
63    fn sync(&mut self) -> impl Future<Output = ()> + Send + 'static;
64
65    /// Wait for the completion of every task in the server.
66    ///
67    /// Returns the (approximate) total amount of GPU work done since the last sync.
68    fn sync_elapsed(&mut self) -> impl Future<Output = TimestampsResult> + Send + 'static;
69
70    /// The current memory usage of the server.
71    fn memory_usage(&self) -> MemoryUsage;
72
73    /// Enable collecting timestamps.
74    fn enable_timestamps(&mut self);
75
76    /// Disable collecting timestamps.
77    fn disable_timestamps(&mut self);
78}
79
80/// Server handle containing the [memory handle](MemoryManagement::Handle).
81#[derive(new, Debug)]
82pub struct Handle {
83    /// Memory handle.
84    pub memory: SliceHandle,
85    /// Memory offset in bytes.
86    pub offset_start: Option<u64>,
87    /// Memory offset in bytes.
88    pub offset_end: Option<u64>,
89    /// Length of the underlying buffer ignoring offsets
90    size: u64,
91}
92
93impl Handle {
94    /// Add to the current offset in bytes.
95    pub fn offset_start(mut self, offset: u64) -> Self {
96        if let Some(val) = &mut self.offset_start {
97            *val += offset;
98        } else {
99            self.offset_start = Some(offset);
100        }
101
102        self
103    }
104    /// Add to the current offset in bytes.
105    pub fn offset_end(mut self, offset: u64) -> Self {
106        if let Some(val) = &mut self.offset_end {
107            *val += offset;
108        } else {
109            self.offset_end = Some(offset);
110        }
111
112        self
113    }
114
115    /// Get the size of the handle, in bytes, accounting for offsets
116    pub fn size(&self) -> u64 {
117        self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
118    }
119}
120
121/// Binding of a [tensor handle](Handle) to execute a kernel.
122#[derive(new, Debug)]
123pub struct Binding {
124    /// Memory binding.
125    pub memory: SliceBinding,
126    /// Memory offset in bytes.
127    pub offset_start: Option<u64>,
128    /// Memory offset in bytes.
129    pub offset_end: Option<u64>,
130}
131
132impl Handle {
133    /// If the tensor handle can be reused inplace.
134    pub fn can_mut(&self) -> bool {
135        self.memory.can_mut()
136    }
137}
138
139impl Handle {
140    /// Convert the [handle](Handle) into a [binding](Binding).
141    pub fn binding(self) -> Binding {
142        Binding {
143            memory: MemoryHandle::binding(self.memory),
144            offset_start: self.offset_start,
145            offset_end: self.offset_end,
146        }
147    }
148}
149
150impl Clone for Handle {
151    fn clone(&self) -> Self {
152        Self {
153            memory: self.memory.clone(),
154            offset_start: self.offset_start,
155            offset_end: self.offset_end,
156            size: self.size,
157        }
158    }
159}
160
161impl Clone for Binding {
162    fn clone(&self) -> Self {
163        Self {
164            memory: self.memory.clone(),
165            offset_start: self.offset_start,
166            offset_end: self.offset_end,
167        }
168    }
169}
170
171/// Specifieds the number of cubes to be dispatched for a kernel.
172///
173/// This translates to eg. a grid for CUDA, or to num_workgroups for wgsl.
174pub enum CubeCount {
175    /// Dispatch a known count of x, y, z cubes.
176    Static(u32, u32, u32),
177    /// Dispatch an amount based on the values in this buffer. The buffer should contain a u32 array [x, y, z].
178    Dynamic(Binding),
179}
180
181impl CubeCount {
182    /// Create a new static cube count with the given x = y = z = 1.
183    pub fn new_single() -> Self {
184        CubeCount::Static(1, 1, 1)
185    }
186
187    /// Create a new static cube count with the given x, and y = z = 1.
188    pub fn new_1d(x: u32) -> Self {
189        CubeCount::Static(x, 1, 1)
190    }
191
192    /// Create a new static cube count with the given x and y, and z = 1.
193    pub fn new_2d(x: u32, y: u32) -> Self {
194        CubeCount::Static(x, y, 1)
195    }
196
197    /// Create a new static cube count with the given x, y and z.
198    pub fn new_3d(x: u32, y: u32) -> Self {
199        CubeCount::Static(x, y, 1)
200    }
201}
202
203impl Debug for CubeCount {
204    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
205        match self {
206            CubeCount::Static(x, y, z) => f.write_fmt(format_args!("({x}, {y}, {z})")),
207            CubeCount::Dynamic(_) => f.write_str("binding"),
208        }
209    }
210}
211
212impl Clone for CubeCount {
213    fn clone(&self) -> Self {
214        match self {
215            Self::Static(x, y, z) => Self::Static(*x, *y, *z),
216            Self::Dynamic(handle) => Self::Dynamic(handle.clone()),
217        }
218    }
219}