cubecl_runtime/
server.rs

1use crate::{
2    memory_management::{
3        MemoryHandle, MemoryUsage,
4        memory_pool::{SliceBinding, SliceHandle},
5    },
6    storage::{BindingResource, ComputeStorage},
7    tma::{OobFill, TensorMapFormat, TensorMapInterleave, TensorMapPrefetch, TensorMapSwizzle},
8};
9use alloc::collections::BTreeMap;
10use alloc::vec::Vec;
11use core::{fmt::Debug, future::Future};
12use cubecl_common::{ExecutionMode, benchmark::ProfileDuration};
13use cubecl_ir::Elem;
14
15/// The compute server is responsible for handling resources and computations over resources.
16///
17/// Everything in the server is mutable, therefore it should be solely accessed through the
18/// [compute channel](crate::channel::ComputeChannel) for thread safety.
19pub trait ComputeServer: Send + core::fmt::Debug
20where
21    Self: Sized,
22{
23    /// The kernel type defines the computation algorithms.
24    type Kernel: Send;
25    /// Information that can be retrieved for the runtime.
26    type Info: Debug + Send + Sync;
27    /// The [storage](ComputeStorage) type defines how data is stored and accessed.
28    type Storage: ComputeStorage;
29    /// The type of the features supported by the server.
30    type Feature: Ord + Copy + Debug + Send + Sync;
31
32    /// Given bindings, returns the owned resources as bytes.
33    fn read(
34        &mut self,
35        bindings: Vec<Binding>,
36    ) -> impl Future<Output = Vec<Vec<u8>>> + Send + 'static;
37
38    /// Given tensor handles, returns the owned resources as bytes.
39    fn read_tensor(
40        &mut self,
41        bindings: Vec<BindingWithMeta>,
42    ) -> impl Future<Output = Vec<Vec<u8>>> + Send + 'static;
43
44    /// Given a resource handle, returns the storage resource.
45    fn get_resource(
46        &mut self,
47        binding: Binding,
48    ) -> BindingResource<<Self::Storage as ComputeStorage>::Resource>;
49
50    /// Given a resource as bytes, stores it and returns the memory handle.
51    fn create(&mut self, data: &[u8]) -> Handle;
52
53    /// Given a resource as bytes with `shape`, stores it and returns the tensor handle.
54    /// May or may not be contiguous, depending on what's best for the given runtime. Always use
55    /// strides to index.
56    /// For example, in CUDA, this will allocate a padded tensor where the last dimension is padded
57    /// to the cache lines, so row access is faster.
58    fn create_tensor(
59        &mut self,
60        data: &[u8],
61        shape: &[usize],
62        elem_size: usize,
63    ) -> (Handle, Vec<usize>);
64
65    /// Reserves `size` bytes in the storage, and returns a handle over them.
66    fn empty(&mut self, size: usize) -> Handle;
67
68    /// Reserves `shape` bytes in the storage, and returns a handle to it.
69    fn empty_tensor(&mut self, shape: &[usize], elem_size: usize) -> (Handle, Vec<usize>);
70
71    /// Executes the `kernel` over the given memory `handles`.
72    ///
73    /// Kernels have mutable access to every resource they are given
74    /// and are responsible of determining which should be read or written.
75    ///
76    /// # Safety
77    ///
78    /// When executing with mode [ExecutionMode::Unchecked], out-of-bound reads and writes can happen.
79    unsafe fn execute(
80        &mut self,
81        kernel: Self::Kernel,
82        count: CubeCount,
83        bindings: Bindings,
84        kind: ExecutionMode,
85    );
86
87    /// Flush all outstanding tasks in the server.
88    fn flush(&mut self);
89
90    /// Wait for the completion of every task in the server.
91    fn sync(&mut self) -> impl Future<Output = ()> + Send + 'static;
92
93    /// The current memory usage of the server.
94    fn memory_usage(&self) -> MemoryUsage;
95
96    /// Ask the server to release memory that it can release.
97    fn memory_cleanup(&mut self);
98
99    /// Enable collecting timestamps.
100    fn start_profile(&mut self);
101
102    /// Disable collecting timestamps.
103    fn end_profile(&mut self) -> ProfileDuration;
104}
105
106/// Server handle containing the [memory handle](crate::server::Handle).
107#[derive(new, Debug)]
108pub struct Handle {
109    /// Memory handle.
110    pub memory: SliceHandle,
111    /// Memory offset in bytes.
112    pub offset_start: Option<u64>,
113    /// Memory offset in bytes.
114    pub offset_end: Option<u64>,
115    /// Length of the underlying buffer ignoring offsets
116    size: u64,
117}
118
119impl Handle {
120    /// Add to the current offset in bytes.
121    pub fn offset_start(mut self, offset: u64) -> Self {
122        if let Some(val) = &mut self.offset_start {
123            *val += offset;
124        } else {
125            self.offset_start = Some(offset);
126        }
127
128        self
129    }
130    /// Add to the current offset in bytes.
131    pub fn offset_end(mut self, offset: u64) -> Self {
132        if let Some(val) = &mut self.offset_end {
133            *val += offset;
134        } else {
135            self.offset_end = Some(offset);
136        }
137
138        self
139    }
140
141    /// Get the size of the handle, in bytes, accounting for offsets
142    pub fn size(&self) -> u64 {
143        self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
144    }
145}
146
147/// Bindings to execute a kernel.
148#[derive(Debug, Default)]
149pub struct Bindings {
150    /// Buffer bindings
151    pub buffers: Vec<Binding>,
152    /// Packed metadata for tensor bindings (len, shape, stride, etc).
153    /// Ordered by inputs, then outputs, then tensormaps
154    pub metadata: MetadataBinding,
155    /// Scalar bindings
156    pub scalars: BTreeMap<Elem, ScalarBinding>,
157    /// Tensor map bindings
158    pub tensor_maps: Vec<TensorMapBinding>,
159}
160
161impl Bindings {
162    /// Create a new bindings struct
163    pub fn new() -> Self {
164        Self::default()
165    }
166
167    /// Add a buffer binding
168    pub fn with_buffer(mut self, binding: Binding) -> Self {
169        self.buffers.push(binding);
170        self
171    }
172
173    /// Extend the buffers with `bindings`
174    pub fn with_buffers(mut self, bindings: Vec<Binding>) -> Self {
175        self.buffers.extend(bindings);
176        self
177    }
178
179    /// Add a scalar parameter
180    pub fn with_scalar(mut self, elem: Elem, length: usize, data: Vec<u64>) -> Self {
181        self.scalars
182            .insert(elem, ScalarBinding::new(elem, length, data));
183        self
184    }
185
186    /// Extend the scalars with `bindings`
187    pub fn with_scalars(mut self, bindings: Vec<ScalarBinding>) -> Self {
188        self.scalars
189            .extend(bindings.into_iter().map(|binding| (binding.elem, binding)));
190        self
191    }
192
193    /// Set the metadata to `meta`
194    pub fn with_metadata(mut self, meta: MetadataBinding) -> Self {
195        self.metadata = meta;
196        self
197    }
198
199    /// Extend the tensor maps with `bindings`
200    pub fn with_tensor_maps(mut self, bindings: Vec<TensorMapBinding>) -> Self {
201        self.tensor_maps.extend(bindings);
202        self
203    }
204}
205
206/// Binding of a set of scalars of the same type to execute a kernel.
207#[derive(new, Debug, Default)]
208pub struct MetadataBinding {
209    /// Metadata values
210    pub data: Vec<u32>,
211    /// Length of the static portion (rank, len, buffer_len, shape_offsets, stride_offsets).
212    pub static_len: usize,
213}
214
215/// Binding of a set of scalars of the same type to execute a kernel.
216#[derive(new, Debug)]
217pub struct ScalarBinding {
218    /// Type of the scalars
219    pub elem: Elem,
220    /// Unpadded length of the underlying data
221    pub length: usize,
222    /// Type-erased data of the scalars. Padded and represented by u64 to prevent misalignment.
223    pub data: Vec<u64>,
224}
225
226impl ScalarBinding {
227    /// Get data as byte slice
228    pub fn data(&self) -> &[u8] {
229        bytemuck::cast_slice(&self.data)
230    }
231}
232
233/// Binding of a [tensor handle](Handle) to execute a kernel.
234#[derive(new, Debug)]
235pub struct Binding {
236    /// Memory binding.
237    pub memory: SliceBinding,
238    /// Memory offset in bytes.
239    pub offset_start: Option<u64>,
240    /// Memory offset in bytes.
241    pub offset_end: Option<u64>,
242}
243
244/// A binding with shape and stride info for non-contiguous reading
245#[derive(new, Debug)]
246pub struct BindingWithMeta {
247    /// Binding for the memory resource
248    pub binding: Binding,
249    /// Shape of the resource
250    pub shape: Vec<usize>,
251    /// Strides of the resource
252    pub strides: Vec<usize>,
253    /// Size of each element in the resource
254    pub elem_size: usize,
255}
256
257/// A tensor map used with TMA ops
258#[derive(new, Debug, Clone)]
259pub struct TensorMapBinding {
260    /// The binding for the backing tensor
261    pub binding: Binding,
262    /// The tensormap metadata
263    pub map: TensorMapMeta,
264}
265
266/// TensorMap metadata for the opaque proxy used in TMA copies
267#[derive(Debug, Clone)]
268pub struct TensorMapMeta {
269    /// Tensormap format (tiled or im2col)
270    pub format: TensorMapFormat,
271    /// Rank of the backing tensor
272    pub rank: usize,
273    /// Shape of the backing tensor
274    pub shape: Vec<usize>,
275    /// Strides of the backing tensor
276    pub strides: Vec<usize>,
277    /// Element stride, usually 1 but may be 2 for complex tensors
278    /// For im2col, this is equivalent to the kernel stride
279    pub elem_stride: Vec<usize>,
280    /// Interleave mode
281    pub interleave: TensorMapInterleave,
282    /// Swizzle mode
283    pub swizzle: TensorMapSwizzle,
284    /// Prefetch settings
285    pub prefetch: TensorMapPrefetch,
286    /// OOB fill value
287    pub oob_fill: OobFill,
288    /// Element type
289    pub elem: Elem,
290}
291
292impl Handle {
293    /// If the tensor handle can be reused inplace.
294    pub fn can_mut(&self) -> bool {
295        self.memory.can_mut()
296    }
297}
298
299impl Handle {
300    /// Convert the [handle](Handle) into a [binding](Binding).
301    pub fn binding(self) -> Binding {
302        Binding {
303            memory: MemoryHandle::binding(self.memory),
304            offset_start: self.offset_start,
305            offset_end: self.offset_end,
306        }
307    }
308
309    /// Convert the [handle](Handle) into a [binding](Binding) with shape and stride metadata.
310    pub fn binding_with_meta(
311        self,
312        shape: Vec<usize>,
313        strides: Vec<usize>,
314        elem_size: usize,
315    ) -> BindingWithMeta {
316        BindingWithMeta {
317            shape,
318            strides,
319            elem_size,
320            binding: self.binding(),
321        }
322    }
323}
324
325impl Clone for Handle {
326    fn clone(&self) -> Self {
327        Self {
328            memory: self.memory.clone(),
329            offset_start: self.offset_start,
330            offset_end: self.offset_end,
331            size: self.size,
332        }
333    }
334}
335
336impl Clone for Binding {
337    fn clone(&self) -> Self {
338        Self {
339            memory: self.memory.clone(),
340            offset_start: self.offset_start,
341            offset_end: self.offset_end,
342        }
343    }
344}
345
346/// Specifieds the number of cubes to be dispatched for a kernel.
347///
348/// This translates to eg. a grid for CUDA, or to num_workgroups for wgsl.
349#[allow(clippy::large_enum_variant)]
350pub enum CubeCount {
351    /// Dispatch a known count of x, y, z cubes.
352    Static(u32, u32, u32),
353    /// Dispatch an amount based on the values in this buffer. The buffer should contain a u32 array [x, y, z].
354    Dynamic(Binding),
355}
356
357impl CubeCount {
358    /// Create a new static cube count with the given x = y = z = 1.
359    pub fn new_single() -> Self {
360        CubeCount::Static(1, 1, 1)
361    }
362
363    /// Create a new static cube count with the given x, and y = z = 1.
364    pub fn new_1d(x: u32) -> Self {
365        CubeCount::Static(x, 1, 1)
366    }
367
368    /// Create a new static cube count with the given x and y, and z = 1.
369    pub fn new_2d(x: u32, y: u32) -> Self {
370        CubeCount::Static(x, y, 1)
371    }
372
373    /// Create a new static cube count with the given x, y and z.
374    pub fn new_3d(x: u32, y: u32, z: u32) -> Self {
375        CubeCount::Static(x, y, z)
376    }
377}
378
379impl Debug for CubeCount {
380    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
381        match self {
382            CubeCount::Static(x, y, z) => f.write_fmt(format_args!("({x}, {y}, {z})")),
383            CubeCount::Dynamic(_) => f.write_str("binding"),
384        }
385    }
386}
387
388impl Clone for CubeCount {
389    fn clone(&self) -> Self {
390        match self {
391            Self::Static(x, y, z) => Self::Static(*x, *y, *z),
392            Self::Dynamic(handle) => Self::Dynamic(handle.clone()),
393        }
394    }
395}