cubecl_runtime/
server.rs

1use crate::{
2    DeviceProperties,
3    kernel::KernelMetadata,
4    logging::ServerLogger,
5    memory_management::{
6        MemoryAllocationMode, MemoryHandle, MemoryUsage,
7        memory_pool::{SliceBinding, SliceHandle},
8    },
9    storage::{BindingResource, ComputeStorage},
10    tma::{OobFill, TensorMapFormat, TensorMapInterleave, TensorMapPrefetch, TensorMapSwizzle},
11};
12use alloc::collections::BTreeMap;
13use alloc::string::String;
14use alloc::sync::Arc;
15use alloc::vec;
16use alloc::vec::Vec;
17use core::fmt::Debug;
18use cubecl_common::{
19    ExecutionMode, bytes::Bytes, device, future::DynFut, profile::ProfileDuration,
20    stream_id::StreamId,
21};
22use cubecl_ir::StorageType;
23use thiserror::Error;
24
/// Failure modes reported when a profiling session cannot produce a result.
#[derive(Debug, Clone)]
pub enum ProfileError {
    /// A failure that has no dedicated variant; the payload carries its description.
    Unknown(String),
    /// No profiling was registered for the request.
    NotRegistered,
}
33
34#[derive(Debug)]
35/// Contains many different types that are useful for server implementations and compute clients.
36pub struct ServerUtilities<Server: ComputeServer> {
37    /// The time when `profile-tracy` is activated.
38    #[cfg(feature = "profile-tracy")]
39    pub epoch_time: web_time::Instant,
40    /// The GPU client when `profile-tracy` is activated.
41    #[cfg(feature = "profile-tracy")]
42    pub gpu_client: tracy_client::GpuContext,
43    /// Information shared between all servers.
44    pub properties: DeviceProperties,
45    /// Information specific to the current server.
46    pub info: Server::Info,
47    /// The logger based on global cubecl configs.
48    pub logger: Arc<ServerLogger>,
49}
50
51impl<S: ComputeServer> ServerUtilities<S> {
52    /// Creates a new server utilities.
53    pub fn new(properties: DeviceProperties, logger: Arc<ServerLogger>, info: S::Info) -> Self {
54        // Start a tracy client if needed.
55        #[cfg(feature = "profile-tracy")]
56        let client = tracy_client::Client::start();
57
58        Self {
59            properties,
60            logger,
61            // Create the GPU client if needed.
62            #[cfg(feature = "profile-tracy")]
63            gpu_client: client
64                .clone()
65                .new_gpu_context(
66                    Some(&format!("{info:?}")),
67                    // In the future should ask the server what makes sense here. 'Invalid' atm is a generic stand-in (Tracy doesn't have CUDA/RocM atm anyway).
68                    tracy_client::GpuContextType::Invalid,
69                    0,   // Timestamps are manually aligned to this epoch so start at 0.
70                    1.0, // Timestamps are manually converted to be nanoseconds so period is 1.
71                )
72                .unwrap(),
73            #[cfg(feature = "profile-tracy")]
74            epoch_time: web_time::Instant::now(),
75            info,
76        }
77    }
78}
79
80/// The compute server is responsible for handling resources and computations over resources.
81///
82/// Everything in the server is mutable, therefore it should be solely accessed through the
83/// [compute channel](crate::channel::ComputeChannel) for thread safety.
84pub trait ComputeServer:
85    Send + core::fmt::Debug + ServerCommunication + device::DeviceState + 'static
86where
87    Self: Sized,
88{
89    /// The kernel type defines the computation algorithms.
90    type Kernel: KernelMetadata;
91    /// Information that can be retrieved for the runtime.
92    type Info: Debug + Send + Sync;
93    /// The [storage](ComputeStorage) type defines how data is stored and accessed.
94    type Storage: ComputeStorage;
95
96    /// Reserves `size` bytes in the storage, and returns a handle over them.
97    fn create(
98        &mut self,
99        descriptors: Vec<AllocationDescriptor<'_>>,
100        stream_id: StreamId,
101    ) -> Result<Vec<Allocation>, IoError>;
102
103    /// Reserves N [Bytes] of the provided sizes to be used as staging to load data.
104    fn staging(&mut self, _sizes: &[usize], _stream_id: StreamId) -> Result<Vec<Bytes>, IoError> {
105        Err(IoError::UnsupportedIoOperation)
106    }
107
108    /// Retrieve the server logger.
109    fn logger(&self) -> Arc<ServerLogger>;
110
111    /// Retrieve the server utilities.
112    fn utilities(&self) -> Arc<ServerUtilities<Self>>;
113
114    /// Utility to create a new buffer and immediately copy contiguous data into it
115    fn create_with_data(&mut self, data: &[u8], stream_id: StreamId) -> Result<Handle, IoError> {
116        let alloc = self
117            .create(
118                vec![AllocationDescriptor::new(
119                    AllocationKind::Contiguous,
120                    &[data.len()],
121                    1,
122                )],
123                stream_id,
124            )?
125            .remove(0);
126        self.write(
127            vec![(
128                CopyDescriptor::new(
129                    alloc.handle.clone().binding(),
130                    &[data.len()],
131                    &alloc.strides,
132                    1,
133                ),
134                Bytes::from_bytes_vec(data.to_vec()),
135            )],
136            stream_id,
137        )?;
138        Ok(alloc.handle)
139    }
140
141    /// Utility to create a new buffer and immediately copy contiguous data into it
142    fn create_with_bytes(&mut self, data: Bytes, stream_id: StreamId) -> Result<Handle, IoError> {
143        let alloc = self
144            .create(
145                vec![AllocationDescriptor::new(
146                    AllocationKind::Contiguous,
147                    &[data.len()],
148                    1,
149                )],
150                stream_id,
151            )?
152            .remove(0);
153        self.write(
154            vec![(
155                CopyDescriptor::new(
156                    alloc.handle.clone().binding(),
157                    &[data.len()],
158                    &alloc.strides,
159                    1,
160                ),
161                data,
162            )],
163            stream_id,
164        )?;
165        Ok(alloc.handle)
166    }
167
168    /// Given bindings, returns the owned resources as bytes.
169    fn read<'a>(
170        &mut self,
171        descriptors: Vec<CopyDescriptor<'a>>,
172        stream_id: StreamId,
173    ) -> DynFut<Result<Vec<Bytes>, IoError>>;
174
175    /// Writes the specified bytes into the buffers given
176    fn write(
177        &mut self,
178        descriptors: Vec<(CopyDescriptor<'_>, Bytes)>,
179        stream_id: StreamId,
180    ) -> Result<(), IoError>;
181
182    /// Wait for the completion of every task in the server.
183    fn sync(&mut self, stream_id: StreamId) -> DynFut<()>;
184
185    /// Given a resource handle, returns the storage resource.
186    fn get_resource(
187        &mut self,
188        binding: Binding,
189        stream_id: StreamId,
190    ) -> BindingResource<<Self::Storage as ComputeStorage>::Resource>;
191
192    /// Executes the `kernel` over the given memory `handles`.
193    ///
194    /// Kernels have mutable access to every resource they are given
195    /// and are responsible of determining which should be read or written.
196    ///
197    /// # Safety
198    ///
199    /// When executing with mode [ExecutionMode::Unchecked], out-of-bound reads and writes can happen.
200    unsafe fn execute(
201        &mut self,
202        kernel: Self::Kernel,
203        count: CubeCount,
204        bindings: Bindings,
205        kind: ExecutionMode,
206        stream_id: StreamId,
207    );
208
209    /// Flush all outstanding tasks in the server.
210    fn flush(&mut self, stream_id: StreamId);
211
212    /// The current memory usage of the server.
213    fn memory_usage(&mut self, stream_id: StreamId) -> MemoryUsage;
214
215    /// Ask the server to release memory that it can release.
216    fn memory_cleanup(&mut self, stream_id: StreamId);
217
218    /// Enable collecting timestamps.
219    fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken;
220
221    /// Disable collecting timestamps.
222    fn end_profile(
223        &mut self,
224        stream_id: StreamId,
225        token: ProfilingToken,
226    ) -> Result<ProfileDuration, ProfileError>;
227
228    /// Update the memory mode of allocation in the server.
229    fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId);
230}
231
232/// Defines functions for optimized data transfer between servers, supporting custom communication
233/// mechanisms such as peer-to-peer communication or specialized implementations.
234pub trait ServerCommunication {
235    /// Indicates whether server-to-server communication is enabled for this implementation.
236    const SERVER_COMM_ENABLED: bool;
237
238    /// Copies data from a source server to a destination server.
239    ///
240    /// # Arguments
241    ///
242    /// * `server_src` - A mutable reference to the source server from which data is copied.
243    /// * `server_dst` - A mutable reference to the destination server receiving the data.
244    /// * `src` - A descriptor specifying the data to be copied, including shape, strides, and binding.
245    /// * `stream_id_src` - The stream ID associated with the source server's operation.
246    /// * `stream_id_dst` - The stream ID associated with the destination server's operation.
247    ///
248    /// # Returns
249    ///
250    /// Returns a `Result` containing an `Allocation` on success, or an `IoError` if the operation fails.
251    ///
252    /// # Panics
253    ///
254    /// Panics if server communication is not enabled (`SERVER_COMM_ENABLED` is `false`) or if the
255    /// trait is incorrectly implemented by the server.
256    #[allow(unused_variables)]
257    fn copy(
258        server_src: &mut Self,
259        server_dst: &mut Self,
260        src: CopyDescriptor<'_>,
261        stream_id_src: StreamId,
262        stream_id_dst: StreamId,
263    ) -> Result<Allocation, IoError> {
264        if !Self::SERVER_COMM_ENABLED {
265            panic!("Server-to-server communication is not supported by this server.");
266        } else {
267            panic!(
268                "[Internal Error] The `ServerCommunication` trait is incorrectly implemented by the server."
269            );
270        }
271    }
272}
273
/// Identifier of a profiling session, so that the server can support recursive and
/// overlapping profilings.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct ProfilingToken {
    /// Raw value of the token.
    pub id: u64,
}
280
281/// Server handle containing the [memory handle](crate::server::Handle).
282#[derive(new, Debug, PartialEq, Eq)]
283pub struct Handle {
284    /// Memory handle.
285    pub memory: SliceHandle,
286    /// Memory offset in bytes.
287    pub offset_start: Option<u64>,
288    /// Memory offset in bytes.
289    pub offset_end: Option<u64>,
290    /// The stream where the data was created.
291    pub stream: cubecl_common::stream_id::StreamId,
292    /// The stream position when the tensor became available.
293    pub cursor: u64,
294    /// Length of the underlying buffer ignoring offsets
295    size: u64,
296}
297
/// Layout strategy requested for an allocation: contiguous, or optimized (row-aligned when
/// possible).
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum AllocationKind {
    /// Contiguous layout, without any padding.
    Contiguous,
    /// Layout optimized for access speed. In practice this means row-aligned with padding
    /// for runtimes that support it.
    Optimized,
}
307
308/// Descriptor for a new tensor allocation
309#[derive(new, Debug, Clone, Copy)]
310pub struct AllocationDescriptor<'a> {
311    /// Layout for the tensor
312    pub kind: AllocationKind,
313    /// Shape of the tensor
314    pub shape: &'a [usize],
315    /// Size of each element in the tensor (used for conversion of shape to bytes)
316    pub elem_size: usize,
317}
318
319impl<'a> AllocationDescriptor<'a> {
320    /// Create an optimized allocation descriptor
321    pub fn optimized(shape: &'a [usize], elem_size: usize) -> Self {
322        AllocationDescriptor::new(AllocationKind::Optimized, shape, elem_size)
323    }
324
325    /// Create a contiguous allocation descriptor
326    pub fn contiguous(shape: &'a [usize], elem_size: usize) -> Self {
327        AllocationDescriptor::new(AllocationKind::Contiguous, shape, elem_size)
328    }
329}
330
331/// An allocation with associated strides. Strides depend on tensor layout.
332#[derive(new, Debug)]
333pub struct Allocation {
334    /// The handle for the memory resource
335    pub handle: Handle,
336    /// The strides of the tensor
337    pub strides: Vec<usize>,
338}
339
340/// Error returned from `create`/`read`/`write` functions. Due to async execution not all errors
341/// are able to be caught, so some IO errors will still panic.
342#[derive(Debug, Error)]
343pub enum IoError {
344    /// Buffer size exceeds the max available
345    #[error("can't allocate buffer of size")]
346    BufferTooBig(usize),
347    /// Strides aren't supported for this copy operation on this runtime
348    #[error("the provided strides are not supported for this operation")]
349    UnsupportedStrides,
350    /// Handle wasn't found in the memory pool
351    #[error("couldn't find resource for that handle")]
352    InvalidHandle,
353    /// Unknown error happened during execution
354    #[error("Unknown error happened during execution")]
355    Unknown(String),
356    /// The current IO operation is not supported
357    #[error("The current IO operation is not supported")]
358    UnsupportedIoOperation,
359}
360
361impl Handle {
362    /// Add to the current offset in bytes.
363    pub fn offset_start(mut self, offset: u64) -> Self {
364        if let Some(val) = &mut self.offset_start {
365            *val += offset;
366        } else {
367            self.offset_start = Some(offset);
368        }
369
370        self
371    }
372    /// Add to the current offset in bytes.
373    pub fn offset_end(mut self, offset: u64) -> Self {
374        if let Some(val) = &mut self.offset_end {
375            *val += offset;
376        } else {
377            self.offset_end = Some(offset);
378        }
379
380        self
381    }
382
383    /// Get the size of the handle, in bytes, accounting for offsets
384    pub fn size(&self) -> u64 {
385        self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
386    }
387}
388
389/// Bindings to execute a kernel.
390#[derive(Debug, Default)]
391pub struct Bindings {
392    /// Buffer bindings
393    pub buffers: Vec<Binding>,
394    /// Packed metadata for tensor bindings (len, shape, stride, etc).
395    /// Ordered by inputs, then outputs, then tensormaps
396    pub metadata: MetadataBinding,
397    /// Scalar bindings
398    pub scalars: BTreeMap<StorageType, ScalarBinding>,
399    /// Tensor map bindings
400    pub tensor_maps: Vec<TensorMapBinding>,
401}
402
403impl Bindings {
404    /// Create a new bindings struct
405    pub fn new() -> Self {
406        Self::default()
407    }
408
409    /// Add a buffer binding
410    pub fn with_buffer(mut self, binding: Binding) -> Self {
411        self.buffers.push(binding);
412        self
413    }
414
415    /// Extend the buffers with `bindings`
416    pub fn with_buffers(mut self, bindings: Vec<Binding>) -> Self {
417        self.buffers.extend(bindings);
418        self
419    }
420
421    /// Add a scalar parameter
422    pub fn with_scalar(mut self, ty: StorageType, length: usize, data: Vec<u64>) -> Self {
423        self.scalars
424            .insert(ty, ScalarBinding::new(ty, length, data));
425        self
426    }
427
428    /// Extend the scalars with `bindings`
429    pub fn with_scalars(mut self, bindings: Vec<ScalarBinding>) -> Self {
430        self.scalars
431            .extend(bindings.into_iter().map(|binding| (binding.ty, binding)));
432        self
433    }
434
435    /// Set the metadata to `meta`
436    pub fn with_metadata(mut self, meta: MetadataBinding) -> Self {
437        self.metadata = meta;
438        self
439    }
440
441    /// Extend the tensor maps with `bindings`
442    pub fn with_tensor_maps(mut self, bindings: Vec<TensorMapBinding>) -> Self {
443        self.tensor_maps.extend(bindings);
444        self
445    }
446}
447
448/// Binding of a set of scalars of the same type to execute a kernel.
449#[derive(new, Debug, Default)]
450pub struct MetadataBinding {
451    /// Metadata values
452    pub data: Vec<u32>,
453    /// Length of the static portion (rank, len, buffer_len, shape_offsets, stride_offsets).
454    pub static_len: usize,
455}
456
457/// Binding of a set of scalars of the same type to execute a kernel.
458#[derive(new, Debug, Clone)]
459pub struct ScalarBinding {
460    /// Type of the scalars
461    pub ty: StorageType,
462    /// Unpadded length of the underlying data
463    pub length: usize,
464    /// Type-erased data of the scalars. Padded and represented by u64 to prevent misalignment.
465    pub data: Vec<u64>,
466}
467
468impl ScalarBinding {
469    /// Get data as byte slice
470    pub fn data(&self) -> &[u8] {
471        bytemuck::cast_slice(&self.data)
472    }
473}
474
475/// Binding of a [tensor handle](Handle) to execute a kernel.
476#[derive(new, Debug)]
477pub struct Binding {
478    /// Memory binding.
479    pub memory: SliceBinding,
480    /// Memory offset in bytes.
481    pub offset_start: Option<u64>,
482    /// Memory offset in bytes.
483    pub offset_end: Option<u64>,
484    /// The stream where the data was created.
485    pub stream: cubecl_common::stream_id::StreamId,
486    /// The stream position when the tensor became available.
487    pub cursor: u64,
488    /// Size in bytes
489    size: u64,
490}
491
492impl Binding {
493    /// Get the size of the handle, in bytes, accounting for offsets
494    pub fn size(&self) -> u64 {
495        self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
496    }
497}
498
499/// A binding with shape and stride info for non-contiguous reading
500#[derive(new, Debug, Clone)]
501pub struct CopyDescriptor<'a> {
502    /// Binding for the memory resource
503    pub binding: Binding,
504    /// Shape of the resource
505    pub shape: &'a [usize],
506    /// Strides of the resource
507    pub strides: &'a [usize],
508    /// Size of each element in the resource
509    pub elem_size: usize,
510}
511
512/// A tensor map used with TMA ops
513#[derive(new, Debug, Clone)]
514pub struct TensorMapBinding {
515    /// The binding for the backing tensor
516    pub binding: Binding,
517    /// The tensormap metadata
518    pub map: TensorMapMeta,
519}
520
521/// TensorMap metadata for the opaque proxy used in TMA copies
522#[derive(Debug, Clone)]
523pub struct TensorMapMeta {
524    /// Tensormap format (tiled or im2col)
525    pub format: TensorMapFormat,
526    /// Rank of the backing tensor
527    pub rank: usize,
528    /// Shape of the backing tensor
529    pub shape: Vec<usize>,
530    /// Strides of the backing tensor
531    pub strides: Vec<usize>,
532    /// Element stride, usually 1 but may be 2 for complex tensors
533    /// For im2col, this is equivalent to the kernel stride
534    pub elem_stride: Vec<usize>,
535    /// Interleave mode
536    pub interleave: TensorMapInterleave,
537    /// Swizzle mode
538    pub swizzle: TensorMapSwizzle,
539    /// Prefetch settings
540    pub prefetch: TensorMapPrefetch,
541    /// OOB fill value
542    pub oob_fill: OobFill,
543    /// Storage type
544    pub storage_ty: StorageType,
545}
546
547impl Handle {
548    /// If the tensor handle can be reused inplace.
549    pub fn can_mut(&self) -> bool {
550        self.memory.can_mut() && self.stream == StreamId::current()
551    }
552}
553
554impl Handle {
555    /// Convert the [handle](Handle) into a [binding](Binding).
556    pub fn binding(self) -> Binding {
557        Binding {
558            memory: MemoryHandle::binding(self.memory),
559            offset_start: self.offset_start,
560            offset_end: self.offset_end,
561            size: self.size,
562            stream: self.stream,
563            cursor: self.cursor,
564        }
565    }
566
567    /// Convert the [handle](Handle) into a [binding](Binding) with shape and stride metadata.
568    pub fn copy_descriptor<'a>(
569        &'a self,
570        shape: &'a [usize],
571        strides: &'a [usize],
572        elem_size: usize,
573    ) -> CopyDescriptor<'a> {
574        CopyDescriptor {
575            shape,
576            strides,
577            elem_size,
578            binding: self.clone().binding(),
579        }
580    }
581}
582
583impl Clone for Handle {
584    fn clone(&self) -> Self {
585        Self {
586            memory: self.memory.clone(),
587            offset_start: self.offset_start,
588            offset_end: self.offset_end,
589            size: self.size,
590            stream: self.stream,
591            cursor: self.cursor,
592        }
593    }
594}
595
596impl Clone for Binding {
597    fn clone(&self) -> Self {
598        Self {
599            memory: self.memory.clone(),
600            offset_start: self.offset_start,
601            offset_end: self.offset_end,
602            size: self.size,
603            stream: self.stream,
604            cursor: self.cursor,
605        }
606    }
607}
608
609/// Specifieds the number of cubes to be dispatched for a kernel.
610///
611/// This translates to eg. a grid for CUDA, or to num_workgroups for wgsl.
612#[allow(clippy::large_enum_variant)]
613pub enum CubeCount {
614    /// Dispatch a known count of x, y, z cubes.
615    Static(u32, u32, u32),
616    /// Dispatch an amount based on the values in this buffer. The buffer should contain a u32 array [x, y, z].
617    Dynamic(Binding),
618}
619
620impl CubeCount {
621    /// Create a new static cube count with the given x = y = z = 1.
622    pub fn new_single() -> Self {
623        CubeCount::Static(1, 1, 1)
624    }
625
626    /// Create a new static cube count with the given x, and y = z = 1.
627    pub fn new_1d(x: u32) -> Self {
628        CubeCount::Static(x, 1, 1)
629    }
630
631    /// Create a new static cube count with the given x and y, and z = 1.
632    pub fn new_2d(x: u32, y: u32) -> Self {
633        CubeCount::Static(x, y, 1)
634    }
635
636    /// Create a new static cube count with the given x, y and z.
637    pub fn new_3d(x: u32, y: u32, z: u32) -> Self {
638        CubeCount::Static(x, y, z)
639    }
640}
641
642impl Debug for CubeCount {
643    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
644        match self {
645            CubeCount::Static(x, y, z) => f.write_fmt(format_args!("({x}, {y}, {z})")),
646            CubeCount::Dynamic(_) => f.write_str("binding"),
647        }
648    }
649}
650
651impl Clone for CubeCount {
652    fn clone(&self) -> Self {
653        match self {
654            Self::Static(x, y, z) => Self::Static(*x, *y, *z),
655            Self::Dynamic(handle) => Self::Dynamic(handle.clone()),
656        }
657    }
658}