cubecl_runtime/server.rs

use crate::{
    kernel::KernelMetadata,
    logging::ServerLogger,
    memory_management::{
        MemoryHandle, MemoryUsage,
        memory_pool::{SliceBinding, SliceHandle},
    },
    storage::{BindingResource, ComputeStorage},
    tma::{OobFill, TensorMapFormat, TensorMapInterleave, TensorMapPrefetch, TensorMapSwizzle},
};
use alloc::collections::BTreeMap;
use alloc::string::String;
use alloc::sync::Arc;
use alloc::vec::Vec;
use core::fmt::Debug;
use cubecl_common::{ExecutionMode, future::DynFut, profile::ProfileDuration};
use cubecl_ir::Elem;

#[derive(Debug, Clone)]
/// An error during profiling.
pub enum ProfileError {
    /// Unknown error.
    Unknown(String),
    /// When no profiling has been registered.
    NotRegistered,
}

/// The compute server is responsible for handling resources and computations over resources.
///
/// Everything in the server is mutable, therefore it should be solely accessed through the
/// [compute channel](crate::channel::ComputeChannel) for thread safety.
pub trait ComputeServer: Send + core::fmt::Debug
where
    Self: Sized,
{
    /// The kernel type defines the computation algorithms.
    type Kernel: KernelMetadata;
    /// Information that can be retrieved for the runtime.
    type Info: Debug + Send + Sync;
    /// The [storage](ComputeStorage) type defines how data is stored and accessed.
    type Storage: ComputeStorage;
    /// The type of the features supported by the server.
    type Feature: Ord + Copy + Debug + Send + Sync;

    /// Given bindings, returns the owned resources as bytes.
    fn read(&mut self, bindings: Vec<Binding>) -> DynFut<Vec<Vec<u8>>>;

    /// Given bindings with shape and stride metadata, returns the owned resources as bytes.
    fn read_tensor(&mut self, bindings: Vec<BindingWithMeta>) -> DynFut<Vec<Vec<u8>>>;

    /// Wait for the completion of every task in the server.
    fn sync(&mut self) -> DynFut<()>;

    /// Given a resource handle, returns the storage resource.
    fn get_resource(
        &mut self,
        binding: Binding,
    ) -> BindingResource<<Self::Storage as ComputeStorage>::Resource>;

    /// Given a resource as bytes, stores it and returns the memory handle.
    fn create(&mut self, data: &[u8]) -> Handle;

    /// Given resources as bytes with their `shapes`, stores them and returns the tensor handles
    /// along with their strides. The allocations may or may not be contiguous, depending on
    /// what's best for the given runtime, so always use the returned strides to index.
    /// For example, in CUDA this allocates a padded tensor whose last dimension is padded to the
    /// cache line size, so row access is faster.
    fn create_tensors(
        &mut self,
        data: Vec<&[u8]>,
        shapes: Vec<&[usize]>,
        elem_sizes: Vec<usize>,
    ) -> Vec<(Handle, Vec<usize>)>;

    /// Reserves `size` bytes in the storage, and returns a handle over them.
    fn empty(&mut self, size: usize) -> Handle;

    /// Reserves storage for tensors with the given `shapes`, and returns the handles along with
    /// their strides.
    fn empty_tensors(
        &mut self,
        shapes: Vec<&[usize]>,
        elem_sizes: Vec<usize>,
    ) -> Vec<(Handle, Vec<usize>)>;

    /// Executes the `kernel` over the given memory `handles`.
    ///
    /// Kernels have mutable access to every resource they are given
    /// and are responsible for determining which should be read or written.
    ///
    /// # Safety
    ///
    /// When executing with mode [ExecutionMode::Unchecked], out-of-bound reads and writes can happen.
    unsafe fn execute(
        &mut self,
        kernel: Self::Kernel,
        count: CubeCount,
        bindings: Bindings,
        kind: ExecutionMode,
        logger: Arc<ServerLogger>,
    );

    /// Flush all outstanding tasks in the server.
    fn flush(&mut self);

    /// The current memory usage of the server.
    fn memory_usage(&self) -> MemoryUsage;

    /// Ask the server to release any memory it can.
    fn memory_cleanup(&mut self);

    /// Start a profiling session and return a token identifying it.
    fn start_profile(&mut self) -> ProfilingToken;

    /// End the profiling session identified by `token` and return the measured duration.
    fn end_profile(&mut self, token: ProfilingToken) -> Result<ProfileDuration, ProfileError>;
}
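
// A minimal usage sketch, assuming some concrete server type `S: ComputeServer`
// and an already compiled `kernel`; the 64-byte sizes and single-cube dispatch
// are arbitrary placeholders, not part of the API.
#[allow(dead_code)]
fn run_once_sketch<S: ComputeServer>(server: &mut S, kernel: S::Kernel, logger: Arc<ServerLogger>) {
    // Upload the input bytes and reserve space for the output.
    let input = server.create(&[0u8; 64]);
    let output = server.empty(64);

    let bindings = Bindings::new()
        .with_buffer(input.binding())
        .with_buffer(output.clone().binding());

    // `Checked` keeps bounds checks; `Unchecked` is where the safety contract matters.
    unsafe {
        server.execute(kernel, CubeCount::new_1d(1), bindings, ExecutionMode::Checked, logger);
    }

    // Reading returns a future over the owned bytes of each binding.
    let _bytes_fut = server.read(alloc::vec![output.binding()]);
}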

#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
/// Profiling identification so that the server can support recursive and overlapping profiling
/// sessions.
pub struct ProfilingToken {
    /// The token value.
    pub id: u64,
}
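
// A minimal sketch of the token-based profiling flow, assuming some concrete
// server type `S: ComputeServer`; the nesting below is only illustrative.
#[allow(dead_code)]
fn profile_region_sketch<S: ComputeServer>(
    server: &mut S,
) -> Result<ProfileDuration, ProfileError> {
    let outer = server.start_profile();
    // Tokens allow a second, overlapping profile to start before the first ends.
    let inner = server.start_profile();
    let _inner_duration = server.end_profile(inner)?;
    server.end_profile(outer)
}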

/// Server handle containing the [memory handle](crate::server::Handle).
#[derive(new, Debug)]
pub struct Handle {
    /// Memory handle.
    pub memory: SliceHandle,
    /// Memory offset in bytes, from the start of the buffer.
    pub offset_start: Option<u64>,
    /// Memory offset in bytes, from the end of the buffer.
    pub offset_end: Option<u64>,
    /// Length of the underlying buffer, ignoring offsets.
    size: u64,
}

impl Handle {
    /// Add to the current start offset, in bytes.
    pub fn offset_start(mut self, offset: u64) -> Self {
        if let Some(val) = &mut self.offset_start {
            *val += offset;
        } else {
            self.offset_start = Some(offset);
        }

        self
    }
    /// Add to the current end offset, in bytes.
    pub fn offset_end(mut self, offset: u64) -> Self {
        if let Some(val) = &mut self.offset_end {
            *val += offset;
        } else {
            self.offset_end = Some(offset);
        }

        self
    }

    /// Get the size of the handle, in bytes, accounting for offsets.
    pub fn size(&self) -> u64 {
        self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
    }
}
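
// Minimal sketch: offsets accumulate and shrink the visible size, while the
// underlying buffer length stays untouched. For a handle over a 1024-byte
// slice, `offset_start(128).offset_end(128)` leaves `size() == 768`.
#[allow(dead_code)]
fn trim_handle_sketch(handle: Handle) -> Handle {
    handle.offset_start(128).offset_end(128)
}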

/// Bindings to execute a kernel.
#[derive(Debug, Default)]
pub struct Bindings {
    /// Buffer bindings
    pub buffers: Vec<Binding>,
    /// Packed metadata for tensor bindings (len, shape, stride, etc).
    /// Ordered by inputs, then outputs, then tensormaps
    pub metadata: MetadataBinding,
    /// Scalar bindings
    pub scalars: BTreeMap<Elem, ScalarBinding>,
    /// Tensor map bindings
    pub tensor_maps: Vec<TensorMapBinding>,
}

impl Bindings {
    /// Create a new bindings struct
    pub fn new() -> Self {
        Self::default()
    }

    /// Add a buffer binding
    pub fn with_buffer(mut self, binding: Binding) -> Self {
        self.buffers.push(binding);
        self
    }

    /// Extend the buffers with `bindings`
    pub fn with_buffers(mut self, bindings: Vec<Binding>) -> Self {
        self.buffers.extend(bindings);
        self
    }

    /// Add a scalar parameter
    pub fn with_scalar(mut self, elem: Elem, length: usize, data: Vec<u64>) -> Self {
        self.scalars
            .insert(elem, ScalarBinding::new(elem, length, data));
        self
    }

    /// Extend the scalars with `bindings`
    pub fn with_scalars(mut self, bindings: Vec<ScalarBinding>) -> Self {
        self.scalars
            .extend(bindings.into_iter().map(|binding| (binding.elem, binding)));
        self
    }

    /// Set the metadata to `meta`
    pub fn with_metadata(mut self, meta: MetadataBinding) -> Self {
        self.metadata = meta;
        self
    }

    /// Extend the tensor maps with `bindings`
    pub fn with_tensor_maps(mut self, bindings: Vec<TensorMapBinding>) -> Self {
        self.tensor_maps.extend(bindings);
        self
    }
}
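
// A minimal sketch of the builder API, assuming `input` and `output` are
// existing buffer bindings; the scalar values below are arbitrary placeholders
// already packed into u64 words.
#[allow(dead_code)]
fn bindings_sketch(input: Binding, output: Binding, scalar_elem: Elem) -> Bindings {
    Bindings::new()
        .with_buffer(input)
        .with_buffer(output)
        .with_scalar(scalar_elem, 2, alloc::vec![3u64, 7u64])
}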

/// Binding of the packed metadata (rank, lengths, shapes, strides) used to execute a kernel.
#[derive(new, Debug, Default)]
pub struct MetadataBinding {
    /// Metadata values
    pub data: Vec<u32>,
    /// Length of the static portion (rank, len, buffer_len, shape_offsets, stride_offsets).
    pub static_len: usize,
}

/// Binding of a set of scalars of the same type to execute a kernel.
#[derive(new, Debug)]
pub struct ScalarBinding {
    /// Type of the scalars
    pub elem: Elem,
    /// Unpadded length of the underlying data
    pub length: usize,
    /// Type-erased data of the scalars. Padded and represented by u64 to prevent misalignment.
    pub data: Vec<u64>,
}

impl ScalarBinding {
    /// Get data as byte slice
    pub fn data(&self) -> &[u8] {
        bytemuck::cast_slice(&self.data)
    }
}
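
// Minimal sketch: the byte view always spans the padded u64 words, so it is
// `data.len() * 8` bytes long; the unpadded `length` is what tells a consumer
// how many scalar values are actually meaningful.
#[allow(dead_code)]
fn scalar_byte_len_sketch(binding: &ScalarBinding) -> usize {
    binding.data().len()
}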

/// Binding of a [tensor handle](Handle) to execute a kernel.
#[derive(new, Debug)]
pub struct Binding {
    /// Memory binding.
    pub memory: SliceBinding,
    /// Memory offset in bytes, from the start of the buffer.
    pub offset_start: Option<u64>,
    /// Memory offset in bytes, from the end of the buffer.
    pub offset_end: Option<u64>,
}

/// A binding with shape and stride info for non-contiguous reading.
#[derive(new, Debug)]
pub struct BindingWithMeta {
    /// Binding for the memory resource
    pub binding: Binding,
    /// Shape of the resource
    pub shape: Vec<usize>,
    /// Strides of the resource
    pub strides: Vec<usize>,
    /// Size of each element in the resource
    pub elem_size: usize,
}

/// A tensor map used with TMA ops.
#[derive(new, Debug, Clone)]
pub struct TensorMapBinding {
    /// The binding for the backing tensor
    pub binding: Binding,
    /// The tensor map metadata
    pub map: TensorMapMeta,
}

/// TensorMap metadata for the opaque proxy used in TMA copies.
#[derive(Debug, Clone)]
pub struct TensorMapMeta {
    /// Tensor map format (tiled or im2col)
    pub format: TensorMapFormat,
    /// Rank of the backing tensor
    pub rank: usize,
    /// Shape of the backing tensor
    pub shape: Vec<usize>,
    /// Strides of the backing tensor
    pub strides: Vec<usize>,
    /// Element stride, usually 1 but may be 2 for complex tensors.
    /// For im2col, this is equivalent to the kernel stride.
    pub elem_stride: Vec<usize>,
    /// Interleave mode
    pub interleave: TensorMapInterleave,
    /// Swizzle mode
    pub swizzle: TensorMapSwizzle,
    /// Prefetch settings
    pub prefetch: TensorMapPrefetch,
    /// OOB fill value
    pub oob_fill: OobFill,
    /// Element type
    pub elem: Elem,
}

impl Handle {
    /// Whether the handle can be reused in place.
    pub fn can_mut(&self) -> bool {
        self.memory.can_mut()
    }
}

impl Handle {
    /// Convert the [handle](Handle) into a [binding](Binding).
    pub fn binding(self) -> Binding {
        Binding {
            memory: MemoryHandle::binding(self.memory),
            offset_start: self.offset_start,
            offset_end: self.offset_end,
        }
    }

    /// Convert the [handle](Handle) into a [binding](Binding) with shape and stride metadata.
    pub fn binding_with_meta(
        self,
        shape: Vec<usize>,
        strides: Vec<usize>,
        elem_size: usize,
    ) -> BindingWithMeta {
        BindingWithMeta {
            shape,
            strides,
            elem_size,
            binding: self.binding(),
        }
    }
}
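
// A minimal sketch of turning a handle into a metadata-carrying binding,
// assuming a contiguous rank-2 tensor of f32 values; the shape, strides, and
// element size are illustrative only.
#[allow(dead_code)]
fn tensor_binding_sketch(handle: Handle) -> BindingWithMeta {
    handle.binding_with_meta(alloc::vec![2, 3], alloc::vec![3, 1], core::mem::size_of::<f32>())
}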

impl Clone for Handle {
    fn clone(&self) -> Self {
        Self {
            memory: self.memory.clone(),
            offset_start: self.offset_start,
            offset_end: self.offset_end,
            size: self.size,
        }
    }
}

impl Clone for Binding {
    fn clone(&self) -> Self {
        Self {
            memory: self.memory.clone(),
            offset_start: self.offset_start,
            offset_end: self.offset_end,
        }
    }
}

/// Specifies the number of cubes to be dispatched for a kernel.
///
/// This translates to e.g. a grid in CUDA, or to `num_workgroups` in WGSL.
#[allow(clippy::large_enum_variant)]
pub enum CubeCount {
    /// Dispatch a known count of x, y, z cubes.
    Static(u32, u32, u32),
    /// Dispatch an amount based on the values in this buffer. The buffer should contain a u32 array [x, y, z].
    Dynamic(Binding),
}

impl CubeCount {
    /// Create a new static cube count with x = y = z = 1.
    pub fn new_single() -> Self {
        CubeCount::Static(1, 1, 1)
    }

    /// Create a new static cube count with the given x, and y = z = 1.
    pub fn new_1d(x: u32) -> Self {
        CubeCount::Static(x, 1, 1)
    }

    /// Create a new static cube count with the given x and y, and z = 1.
    pub fn new_2d(x: u32, y: u32) -> Self {
        CubeCount::Static(x, y, 1)
    }

    /// Create a new static cube count with the given x, y and z.
    pub fn new_3d(x: u32, y: u32, z: u32) -> Self {
        CubeCount::Static(x, y, z)
    }
}
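
// Minimal sketch: a 256x256 problem tiled into 16x16 cubes dispatches a
// 16x16x1 static count; the problem and tile sizes are arbitrary examples.
#[allow(dead_code)]
fn dispatch_sketch() -> CubeCount {
    let (width, height, tile) = (256u32, 256u32, 16u32);
    CubeCount::new_2d(width.div_ceil(tile), height.div_ceil(tile))
}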

impl Debug for CubeCount {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            CubeCount::Static(x, y, z) => f.write_fmt(format_args!("({x}, {y}, {z})")),
            CubeCount::Dynamic(_) => f.write_str("binding"),
        }
    }
}

impl Clone for CubeCount {
    fn clone(&self) -> Self {
        match self {
            Self::Static(x, y, z) => Self::Static(*x, *y, *z),
            Self::Dynamic(handle) => Self::Dynamic(handle.clone()),
        }
    }
}