// cubecl_runtime/
// server.rs

1use crate::{
2    DeviceProperties,
3    kernel::KernelMetadata,
4    logging::ServerLogger,
5    memory_management::{
6        MemoryAllocationMode, MemoryHandle, MemoryUsage,
7        memory_pool::{SliceBinding, SliceHandle},
8    },
9    storage::{BindingResource, ComputeStorage},
10    tma::{OobFill, TensorMapFormat, TensorMapInterleave, TensorMapPrefetch, TensorMapSwizzle},
11};
12use alloc::collections::BTreeMap;
13use alloc::string::String;
14use alloc::sync::Arc;
15use alloc::vec;
16use alloc::vec::Vec;
17use core::fmt::Debug;
18use cubecl_common::{
19    ExecutionMode, bytes::Bytes, device, future::DynFut, profile::ProfileDuration,
20    stream_id::StreamId,
21};
22use cubecl_ir::StorageType;
23use thiserror::Error;
24
/// Failure reported when profiling results cannot be produced.
#[derive(Debug, Clone)]
pub enum ProfileError {
    /// A profiling failure without a dedicated variant; the string describes it.
    Unknown(String),
    /// Returned when no profiling has been registered.
    NotRegistered,
}
33
34#[derive(Debug)]
35/// Contains many different types that are useful for server implementations and compute clients.
36pub struct ServerUtilities<Server: ComputeServer> {
37    /// The time when `profile-tracy` is activated.
38    #[cfg(feature = "profile-tracy")]
39    pub epoch_time: web_time::Instant,
40    /// The GPU client when `profile-tracy` is activated.
41    #[cfg(feature = "profile-tracy")]
42    pub gpu_client: tracy_client::GpuContext,
43    /// Information shared between all servers.
44    pub properties: DeviceProperties,
45    /// Information specific to the current server.
46    pub info: Server::Info,
47    /// The logger based on global cubecl configs.
48    pub logger: Arc<ServerLogger>,
49}
50
51impl<S: ComputeServer> ServerUtilities<S> {
52    /// Creates a new server utilities.
53    pub fn new(properties: DeviceProperties, logger: Arc<ServerLogger>, info: S::Info) -> Self {
54        // Start a tracy client if needed.
55        #[cfg(feature = "profile-tracy")]
56        let client = tracy_client::Client::start();
57
58        Self {
59            properties,
60            logger,
61            // Create the GPU client if needed.
62            #[cfg(feature = "profile-tracy")]
63            gpu_client: client
64                .clone()
65                .new_gpu_context(
66                    Some(&format!("{info:?}")),
67                    // In the future should ask the server what makes sense here. 'Invalid' atm is a generic stand-in (Tracy doesn't have CUDA/RocM atm anyway).
68                    tracy_client::GpuContextType::Invalid,
69                    0,   // Timestamps are manually aligned to this epoch so start at 0.
70                    1.0, // Timestamps are manually converted to be nanoseconds so period is 1.
71                )
72                .unwrap(),
73            #[cfg(feature = "profile-tracy")]
74            epoch_time: web_time::Instant::now(),
75            info,
76        }
77    }
78}
79
80/// The compute server is responsible for handling resources and computations over resources.
81///
82/// Everything in the server is mutable, therefore it should be solely accessed through the
83/// [compute channel](crate::channel::ComputeChannel) for thread safety.
84pub trait ComputeServer:
85    Send + core::fmt::Debug + ServerCommunication + device::DeviceState + 'static
86where
87    Self: Sized,
88{
89    /// The kernel type defines the computation algorithms.
90    type Kernel: KernelMetadata;
91    /// Information that can be retrieved for the runtime.
92    type Info: Debug + Send + Sync;
93    /// The [storage](ComputeStorage) type defines how data is stored and accessed.
94    type Storage: ComputeStorage;
95
96    /// Reserves `size` bytes in the storage, and returns a handle over them.
97    fn create(
98        &mut self,
99        descriptors: Vec<AllocationDescriptor<'_>>,
100        stream_id: StreamId,
101    ) -> Result<Vec<Allocation>, IoError>;
102
103    /// Retrieve the server logger.
104    fn logger(&self) -> Arc<ServerLogger>;
105
106    /// Retrieve the server utilities.
107    fn utilities(&self) -> Arc<ServerUtilities<Self>>;
108
109    /// Utility to create a new buffer and immediately copy contiguous data into it
110    fn create_with_data(&mut self, data: &[u8], stream_id: StreamId) -> Result<Handle, IoError> {
111        let alloc = self
112            .create(
113                vec![AllocationDescriptor::new(
114                    AllocationKind::Contiguous,
115                    &[data.len()],
116                    1,
117                )],
118                stream_id,
119            )?
120            .remove(0);
121        self.write(
122            vec![(
123                CopyDescriptor::new(
124                    alloc.handle.clone().binding(),
125                    &[data.len()],
126                    &alloc.strides,
127                    1,
128                ),
129                data,
130            )],
131            stream_id,
132        )?;
133        Ok(alloc.handle)
134    }
135
136    /// Given bindings, returns the owned resources as bytes.
137    fn read<'a>(
138        &mut self,
139        descriptors: Vec<CopyDescriptor<'a>>,
140        stream_id: StreamId,
141    ) -> DynFut<Result<Vec<Bytes>, IoError>>;
142
143    /// Writes the specified bytes into the buffers given
144    fn write(
145        &mut self,
146        descriptors: Vec<(CopyDescriptor<'_>, &[u8])>,
147        stream_id: StreamId,
148    ) -> Result<(), IoError>;
149
150    /// Wait for the completion of every task in the server.
151    fn sync(&mut self, stream_id: StreamId) -> DynFut<()>;
152
153    /// Given a resource handle, returns the storage resource.
154    fn get_resource(
155        &mut self,
156        binding: Binding,
157        stream_id: StreamId,
158    ) -> BindingResource<<Self::Storage as ComputeStorage>::Resource>;
159
160    /// Executes the `kernel` over the given memory `handles`.
161    ///
162    /// Kernels have mutable access to every resource they are given
163    /// and are responsible of determining which should be read or written.
164    ///
165    /// # Safety
166    ///
167    /// When executing with mode [ExecutionMode::Unchecked], out-of-bound reads and writes can happen.
168    unsafe fn execute(
169        &mut self,
170        kernel: Self::Kernel,
171        count: CubeCount,
172        bindings: Bindings,
173        kind: ExecutionMode,
174        stream_id: StreamId,
175    );
176
177    /// Flush all outstanding tasks in the server.
178    fn flush(&mut self, stream_id: StreamId);
179
180    /// The current memory usage of the server.
181    fn memory_usage(&mut self, stream_id: StreamId) -> MemoryUsage;
182
183    /// Ask the server to release memory that it can release.
184    fn memory_cleanup(&mut self, stream_id: StreamId);
185
186    /// Enable collecting timestamps.
187    fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken;
188
189    /// Disable collecting timestamps.
190    fn end_profile(
191        &mut self,
192        stream_id: StreamId,
193        token: ProfilingToken,
194    ) -> Result<ProfileDuration, ProfileError>;
195
196    /// Update the memory mode of allocation in the server.
197    fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId);
198}
199
200/// Defines functions for optimized data transfer between servers, supporting custom communication
201/// mechanisms such as peer-to-peer communication or specialized implementations.
202pub trait ServerCommunication {
203    /// Indicates whether server-to-server communication is enabled for this implementation.
204    const SERVER_COMM_ENABLED: bool;
205
206    /// Copies data from a source server to a destination server.
207    ///
208    /// # Arguments
209    ///
210    /// * `server_src` - A mutable reference to the source server from which data is copied.
211    /// * `server_dst` - A mutable reference to the destination server receiving the data.
212    /// * `src` - A descriptor specifying the data to be copied, including shape, strides, and binding.
213    /// * `stream_id_src` - The stream ID associated with the source server's operation.
214    /// * `stream_id_dst` - The stream ID associated with the destination server's operation.
215    ///
216    /// # Returns
217    ///
218    /// Returns a `Result` containing an `Allocation` on success, or an `IoError` if the operation fails.
219    ///
220    /// # Panics
221    ///
222    /// Panics if server communication is not enabled (`SERVER_COMM_ENABLED` is `false`) or if the
223    /// trait is incorrectly implemented by the server.
224    #[allow(unused_variables)]
225    fn copy(
226        server_src: &mut Self,
227        server_dst: &mut Self,
228        src: CopyDescriptor<'_>,
229        stream_id_src: StreamId,
230        stream_id_dst: StreamId,
231    ) -> Result<Allocation, IoError> {
232        if !Self::SERVER_COMM_ENABLED {
233            panic!("Server-to-server communication is not supported by this server.");
234        } else {
235            panic!(
236                "[Internal Error] The `ServerCommunication` trait is incorrectly implemented by the server."
237            );
238        }
239    }
240}
241
/// Identifies one profiling session so the server can support nested and overlapping profiles.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct ProfilingToken {
    /// Numeric identifier of the session.
    pub id: u64,
}
248
249/// Server handle containing the [memory handle](crate::server::Handle).
250#[derive(new, Debug, PartialEq, Eq)]
251pub struct Handle {
252    /// Memory handle.
253    pub memory: SliceHandle,
254    /// Memory offset in bytes.
255    pub offset_start: Option<u64>,
256    /// Memory offset in bytes.
257    pub offset_end: Option<u64>,
258    /// The stream where the data was created.
259    pub stream: cubecl_common::stream_id::StreamId,
260    /// The stream position when the tensor became available.
261    pub cursor: u64,
262    /// Length of the underlying buffer ignoring offsets
263    size: u64,
264}
265
/// Memory layout requested for a new allocation.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum AllocationKind {
    /// Densely packed layout without any padding.
    Contiguous,
    /// Layout chosen for access speed. In practice rows are aligned, with padding, on runtimes
    /// that support it.
    Optimized,
}
275
276/// Descriptor for a new tensor allocation
277#[derive(new, Debug, Clone, Copy)]
278pub struct AllocationDescriptor<'a> {
279    /// Layout for the tensor
280    pub kind: AllocationKind,
281    /// Shape of the tensor
282    pub shape: &'a [usize],
283    /// Size of each element in the tensor (used for conversion of shape to bytes)
284    pub elem_size: usize,
285}
286
287impl<'a> AllocationDescriptor<'a> {
288    /// Create an optimized allocation descriptor
289    pub fn optimized(shape: &'a [usize], elem_size: usize) -> Self {
290        AllocationDescriptor::new(AllocationKind::Optimized, shape, elem_size)
291    }
292
293    /// Create a contiguous allocation descriptor
294    pub fn contiguous(shape: &'a [usize], elem_size: usize) -> Self {
295        AllocationDescriptor::new(AllocationKind::Contiguous, shape, elem_size)
296    }
297}
298
299/// An allocation with associated strides. Strides depend on tensor layout.
300#[derive(new, Debug)]
301pub struct Allocation {
302    /// The handle for the memory resource
303    pub handle: Handle,
304    /// The strides of the tensor
305    pub strides: Vec<usize>,
306}
307
308/// Error returned from `create`/`read`/`write` functions. Due to async execution not all errors
309/// are able to be caught, so some IO errors will still panic.
310#[derive(Debug, Error)]
311pub enum IoError {
312    /// Buffer size exceeds the max available
313    #[error("can't allocate buffer of size")]
314    BufferTooBig(usize),
315    /// Strides aren't supported for this copy operation on this runtime
316    #[error("the provided strides are not supported for this operation")]
317    UnsupportedStrides,
318    /// Handle wasn't found in the memory pool
319    #[error("couldn't find resource for that handle")]
320    InvalidHandle,
321    /// Unknown error happened during execution
322    #[error("Unknown error happened during execution")]
323    Unknown(String),
324}
325
326impl Handle {
327    /// Add to the current offset in bytes.
328    pub fn offset_start(mut self, offset: u64) -> Self {
329        if let Some(val) = &mut self.offset_start {
330            *val += offset;
331        } else {
332            self.offset_start = Some(offset);
333        }
334
335        self
336    }
337    /// Add to the current offset in bytes.
338    pub fn offset_end(mut self, offset: u64) -> Self {
339        if let Some(val) = &mut self.offset_end {
340            *val += offset;
341        } else {
342            self.offset_end = Some(offset);
343        }
344
345        self
346    }
347
348    /// Get the size of the handle, in bytes, accounting for offsets
349    pub fn size(&self) -> u64 {
350        self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
351    }
352}
353
354/// Bindings to execute a kernel.
355#[derive(Debug, Default)]
356pub struct Bindings {
357    /// Buffer bindings
358    pub buffers: Vec<Binding>,
359    /// Packed metadata for tensor bindings (len, shape, stride, etc).
360    /// Ordered by inputs, then outputs, then tensormaps
361    pub metadata: MetadataBinding,
362    /// Scalar bindings
363    pub scalars: BTreeMap<StorageType, ScalarBinding>,
364    /// Tensor map bindings
365    pub tensor_maps: Vec<TensorMapBinding>,
366}
367
368impl Bindings {
369    /// Create a new bindings struct
370    pub fn new() -> Self {
371        Self::default()
372    }
373
374    /// Add a buffer binding
375    pub fn with_buffer(mut self, binding: Binding) -> Self {
376        self.buffers.push(binding);
377        self
378    }
379
380    /// Extend the buffers with `bindings`
381    pub fn with_buffers(mut self, bindings: Vec<Binding>) -> Self {
382        self.buffers.extend(bindings);
383        self
384    }
385
386    /// Add a scalar parameter
387    pub fn with_scalar(mut self, ty: StorageType, length: usize, data: Vec<u64>) -> Self {
388        self.scalars
389            .insert(ty, ScalarBinding::new(ty, length, data));
390        self
391    }
392
393    /// Extend the scalars with `bindings`
394    pub fn with_scalars(mut self, bindings: Vec<ScalarBinding>) -> Self {
395        self.scalars
396            .extend(bindings.into_iter().map(|binding| (binding.ty, binding)));
397        self
398    }
399
400    /// Set the metadata to `meta`
401    pub fn with_metadata(mut self, meta: MetadataBinding) -> Self {
402        self.metadata = meta;
403        self
404    }
405
406    /// Extend the tensor maps with `bindings`
407    pub fn with_tensor_maps(mut self, bindings: Vec<TensorMapBinding>) -> Self {
408        self.tensor_maps.extend(bindings);
409        self
410    }
411}
412
413/// Binding of a set of scalars of the same type to execute a kernel.
414#[derive(new, Debug, Default)]
415pub struct MetadataBinding {
416    /// Metadata values
417    pub data: Vec<u32>,
418    /// Length of the static portion (rank, len, buffer_len, shape_offsets, stride_offsets).
419    pub static_len: usize,
420}
421
422/// Binding of a set of scalars of the same type to execute a kernel.
423#[derive(new, Debug, Clone)]
424pub struct ScalarBinding {
425    /// Type of the scalars
426    pub ty: StorageType,
427    /// Unpadded length of the underlying data
428    pub length: usize,
429    /// Type-erased data of the scalars. Padded and represented by u64 to prevent misalignment.
430    pub data: Vec<u64>,
431}
432
433impl ScalarBinding {
434    /// Get data as byte slice
435    pub fn data(&self) -> &[u8] {
436        bytemuck::cast_slice(&self.data)
437    }
438}
439
440/// Binding of a [tensor handle](Handle) to execute a kernel.
441#[derive(new, Debug)]
442pub struct Binding {
443    /// Memory binding.
444    pub memory: SliceBinding,
445    /// Memory offset in bytes.
446    pub offset_start: Option<u64>,
447    /// Memory offset in bytes.
448    pub offset_end: Option<u64>,
449    /// The stream where the data was created.
450    pub stream: cubecl_common::stream_id::StreamId,
451    /// The stream position when the tensor became available.
452    pub cursor: u64,
453    /// Size in bytes
454    size: u64,
455}
456
457impl Binding {
458    /// Get the size of the handle, in bytes, accounting for offsets
459    pub fn size(&self) -> u64 {
460        self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
461    }
462}
463
464/// A binding with shape and stride info for non-contiguous reading
465#[derive(new, Debug, Clone)]
466pub struct CopyDescriptor<'a> {
467    /// Binding for the memory resource
468    pub binding: Binding,
469    /// Shape of the resource
470    pub shape: &'a [usize],
471    /// Strides of the resource
472    pub strides: &'a [usize],
473    /// Size of each element in the resource
474    pub elem_size: usize,
475}
476
477/// A tensor map used with TMA ops
478#[derive(new, Debug, Clone)]
479pub struct TensorMapBinding {
480    /// The binding for the backing tensor
481    pub binding: Binding,
482    /// The tensormap metadata
483    pub map: TensorMapMeta,
484}
485
486/// TensorMap metadata for the opaque proxy used in TMA copies
487#[derive(Debug, Clone)]
488pub struct TensorMapMeta {
489    /// Tensormap format (tiled or im2col)
490    pub format: TensorMapFormat,
491    /// Rank of the backing tensor
492    pub rank: usize,
493    /// Shape of the backing tensor
494    pub shape: Vec<usize>,
495    /// Strides of the backing tensor
496    pub strides: Vec<usize>,
497    /// Element stride, usually 1 but may be 2 for complex tensors
498    /// For im2col, this is equivalent to the kernel stride
499    pub elem_stride: Vec<usize>,
500    /// Interleave mode
501    pub interleave: TensorMapInterleave,
502    /// Swizzle mode
503    pub swizzle: TensorMapSwizzle,
504    /// Prefetch settings
505    pub prefetch: TensorMapPrefetch,
506    /// OOB fill value
507    pub oob_fill: OobFill,
508    /// Storage type
509    pub storage_ty: StorageType,
510}
511
512impl Handle {
513    /// If the tensor handle can be reused inplace.
514    pub fn can_mut(&self) -> bool {
515        self.memory.can_mut() && self.stream == StreamId::current()
516    }
517}
518
519impl Handle {
520    /// Convert the [handle](Handle) into a [binding](Binding).
521    pub fn binding(self) -> Binding {
522        Binding {
523            memory: MemoryHandle::binding(self.memory),
524            offset_start: self.offset_start,
525            offset_end: self.offset_end,
526            size: self.size,
527            stream: self.stream,
528            cursor: self.cursor,
529        }
530    }
531
532    /// Convert the [handle](Handle) into a [binding](Binding) with shape and stride metadata.
533    pub fn copy_descriptor<'a>(
534        &'a self,
535        shape: &'a [usize],
536        strides: &'a [usize],
537        elem_size: usize,
538    ) -> CopyDescriptor<'a> {
539        CopyDescriptor {
540            shape,
541            strides,
542            elem_size,
543            binding: self.clone().binding(),
544        }
545    }
546}
547
548impl Clone for Handle {
549    fn clone(&self) -> Self {
550        Self {
551            memory: self.memory.clone(),
552            offset_start: self.offset_start,
553            offset_end: self.offset_end,
554            size: self.size,
555            stream: self.stream,
556            cursor: self.cursor,
557        }
558    }
559}
560
561impl Clone for Binding {
562    fn clone(&self) -> Self {
563        Self {
564            memory: self.memory.clone(),
565            offset_start: self.offset_start,
566            offset_end: self.offset_end,
567            size: self.size,
568            stream: self.stream,
569            cursor: self.cursor,
570        }
571    }
572}
573
574/// Specifieds the number of cubes to be dispatched for a kernel.
575///
576/// This translates to eg. a grid for CUDA, or to num_workgroups for wgsl.
577#[allow(clippy::large_enum_variant)]
578pub enum CubeCount {
579    /// Dispatch a known count of x, y, z cubes.
580    Static(u32, u32, u32),
581    /// Dispatch an amount based on the values in this buffer. The buffer should contain a u32 array [x, y, z].
582    Dynamic(Binding),
583}
584
585impl CubeCount {
586    /// Create a new static cube count with the given x = y = z = 1.
587    pub fn new_single() -> Self {
588        CubeCount::Static(1, 1, 1)
589    }
590
591    /// Create a new static cube count with the given x, and y = z = 1.
592    pub fn new_1d(x: u32) -> Self {
593        CubeCount::Static(x, 1, 1)
594    }
595
596    /// Create a new static cube count with the given x and y, and z = 1.
597    pub fn new_2d(x: u32, y: u32) -> Self {
598        CubeCount::Static(x, y, 1)
599    }
600
601    /// Create a new static cube count with the given x, y and z.
602    pub fn new_3d(x: u32, y: u32, z: u32) -> Self {
603        CubeCount::Static(x, y, z)
604    }
605}
606
607impl Debug for CubeCount {
608    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
609        match self {
610            CubeCount::Static(x, y, z) => f.write_fmt(format_args!("({x}, {y}, {z})")),
611            CubeCount::Dynamic(_) => f.write_str("binding"),
612        }
613    }
614}
615
616impl Clone for CubeCount {
617    fn clone(&self) -> Self {
618        match self {
619            Self::Static(x, y, z) => Self::Static(*x, *y, *z),
620            Self::Dynamic(handle) => Self::Dynamic(handle.clone()),
621        }
622    }
623}