cubecl_runtime/
server.rs

1use crate::{
2    DeviceProperties,
3    client::ComputeClient,
4    compiler::CompilationError,
5    kernel::KernelMetadata,
6    logging::ServerLogger,
7    memory_management::{
8        MemoryAllocationMode, MemoryHandle, MemoryUsage,
9        memory_pool::{SliceBinding, SliceHandle},
10    },
11    runtime::Runtime,
12    storage::{BindingResource, ComputeStorage},
13    tma::{OobFill, TensorMapFormat, TensorMapInterleave, TensorMapPrefetch, TensorMapSwizzle},
14};
15use alloc::collections::BTreeMap;
16use alloc::string::String;
17use alloc::sync::Arc;
18use alloc::vec;
19use alloc::vec::Vec;
20use core::fmt::Debug;
21use cubecl_common::{
22    backtrace::BackTrace, bytes::Bytes, device, future::DynFut, profile::ProfileDuration,
23    stream_id::StreamId,
24};
25use cubecl_ir::StorageType;
26use serde::{Deserialize, Serialize};
27use thiserror::Error;
28
#[derive(Error, Clone)]
/// An error during profiling.
pub enum ProfileError {
    /// An unknown error happened during profiling.
    #[error(
        "An unknown error happened during profiling\nCaused by:\n  {reason}\nBacktrace:\n{backtrace}"
    )]
    Unknown {
        /// The cause of the error.
        reason: String,
        /// The captured backtrace.
        backtrace: BackTrace,
    },

    /// No profiling was registered.
    #[error("No profiling registered\nBacktrace:\n{backtrace}")]
    NotRegistered {
        /// The captured backtrace.
        backtrace: BackTrace,
    },

    /// A launch error happened during profiling.
    #[error("A launch error happened during profiling\nCaused by:\n  {0}")]
    Launch(#[from] LaunchError),

    /// An execution error happened during profiling.
    #[error("An execution error happened during profiling\nCaused by:\n  {0}")]
    Execution(#[from] ExecutionError),
}
58
59impl core::fmt::Debug for ProfileError {
60    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
61        f.write_fmt(format_args!("{self}"))
62    }
63}
64
/// Contains many different types that are useful for server implementations and compute clients.
pub struct ServerUtilities<Server: ComputeServer> {
    /// Reference instant captured when the utilities are created; GPU timestamps are aligned
    /// to it. Only present when `profile-tracy` is activated.
    #[cfg(feature = "profile-tracy")]
    pub epoch_time: web_time::Instant,
    /// The Tracy GPU context. Only present when `profile-tracy` is activated.
    #[cfg(feature = "profile-tracy")]
    pub gpu_client: tracy_client::GpuContext,
    /// Information shared between all servers.
    pub properties: DeviceProperties,
    /// Information specific to the current server.
    pub info: Server::Info,
    /// The logger based on global cubecl configs.
    pub logger: Arc<ServerLogger>,
}
80
81impl<Server: core::fmt::Debug> core::fmt::Debug for ServerUtilities<Server>
82where
83    Server: ComputeServer,
84    Server::Info: core::fmt::Debug,
85{
86    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
87        f.debug_struct("ServerUtilities")
88            .field("properties", &self.properties)
89            .field("info", &self.info)
90            .field("logger", &self.logger)
91            .finish()
92    }
93}
94
95impl<S: ComputeServer> ServerUtilities<S> {
96    /// Creates a new server utilities.
97    pub fn new(properties: DeviceProperties, logger: Arc<ServerLogger>, info: S::Info) -> Self {
98        // Start a tracy client if needed.
99        #[cfg(feature = "profile-tracy")]
100        let client = tracy_client::Client::start();
101
102        Self {
103            properties,
104            logger,
105            // Create the GPU client if needed.
106            #[cfg(feature = "profile-tracy")]
107            gpu_client: client
108                .clone()
109                .new_gpu_context(
110                    Some(&format!("{info:?}")),
111                    // In the future should ask the server what makes sense here. 'Invalid' atm is a generic stand-in (Tracy doesn't have CUDA/RocM atm anyway).
112                    tracy_client::GpuContextType::Invalid,
113                    0,   // Timestamps are manually aligned to this epoch so start at 0.
114                    1.0, // Timestamps are manually converted to be nanoseconds so period is 1.
115                )
116                .unwrap(),
117            #[cfg(feature = "profile-tracy")]
118            epoch_time: web_time::Instant::now(),
119            info,
120        }
121    }
122}
123
124/// Error that can happen when calling [ComputeServer::execute];
125///
126/// # Notes
127///
128/// Not all errors are going to be catched when calling [ComputeServer::execute] only the one that
129/// won't block the compute queue.
130#[derive(Error, Clone)]
131#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
132pub enum LaunchError {
133    /// The given kernel can't be compiled.
134    #[error("A compilation error happened during launch\nCaused by:\n  {0}")]
135    CompilationError(#[from] CompilationError),
136
137    /// The server is out of memory.
138    #[error(
139        "An out-of-memory error happened during launch\nCaused by:\n  {reason}\nBacktrace\n{backtrace}"
140    )]
141    OutOfMemory {
142        /// The caused of the memory error.
143        reason: String,
144        /// The backtrace for this error.
145        #[cfg_attr(std_io, serde(skip))]
146        backtrace: BackTrace,
147    },
148
149    /// Unknown launch error.
150    #[error(
151        "An unknown error happened during launch\nCaused by:\n  {reason}\nBacktrace\n{backtrace}"
152    )]
153    Unknown {
154        /// The caused of the unknown error.
155        reason: String,
156        /// The backtrace for this error.
157        #[cfg_attr(std_io, serde(skip))]
158        backtrace: BackTrace,
159    },
160
161    /// Can't launch because of an IO Error.
162    #[error("An io error happened during launch\nCaused by:\n  {0}")]
163    IoError(#[from] IoError),
164}
165
166impl core::fmt::Debug for LaunchError {
167    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
168        f.write_fmt(format_args!("{self}"))
169    }
170}
171
/// Error that can happen asynchronously while executing registered kernels.
// Unlike the other error types in this file, `Debug` is derived here instead of being
// forwarded to the `Display` message.
#[derive(Error, Debug, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum ExecutionError {
    /// A generic runtime error.
    #[error("An error happened during execution\nCaused by:\n  {reason}\nBacktrace:\n{backtrace}")]
    Generic {
        /// The details of the generic error.
        reason: String,
        /// The backtrace for this error (skipped by serde when `std_io` is enabled).
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },
}
186
187/// The compute server is responsible for handling resources and computations over resources.
188///
189/// Everything in the server is mutable, therefore it should be solely accessed through the
190/// [compute channel](crate::channel::ComputeChannel) for thread safety.
191pub trait ComputeServer:
192    Send + core::fmt::Debug + ServerCommunication + device::DeviceState + 'static
193where
194    Self: Sized,
195{
196    /// The kernel type defines the computation algorithms.
197    type Kernel: KernelMetadata;
198    /// Information that can be retrieved for the runtime.
199    type Info: Debug + Send + Sync;
200    /// The [storage](ComputeStorage) type defines how data is stored and accessed.
201    type Storage: ComputeStorage;
202
203    /// Reserves `size` bytes in the storage, and returns a handle over them.
204    fn create(
205        &mut self,
206        descriptors: Vec<AllocationDescriptor<'_>>,
207        stream_id: StreamId,
208    ) -> Result<Vec<Allocation>, IoError>;
209
210    /// Reserves N [Bytes] of the provided sizes to be used as staging to load data.
211    fn staging(&mut self, _sizes: &[usize], _stream_id: StreamId) -> Result<Vec<Bytes>, IoError> {
212        Err(IoError::UnsupportedIoOperation {
213            backtrace: BackTrace::capture(),
214        })
215    }
216
217    /// Retrieve the server logger.
218    fn logger(&self) -> Arc<ServerLogger>;
219
220    /// Retrieve the server utilities.
221    fn utilities(&self) -> Arc<ServerUtilities<Self>>;
222
223    /// Utility to create a new buffer and immediately copy contiguous data into it
224    fn create_with_data(&mut self, data: &[u8], stream_id: StreamId) -> Result<Handle, IoError> {
225        let alloc = self
226            .create(
227                vec![AllocationDescriptor::new(
228                    AllocationKind::Contiguous,
229                    &[data.len()],
230                    1,
231                )],
232                stream_id,
233            )?
234            .remove(0);
235        self.write(
236            vec![(
237                CopyDescriptor::new(
238                    alloc.handle.clone().binding(),
239                    &[data.len()],
240                    &alloc.strides,
241                    1,
242                ),
243                Bytes::from_bytes_vec(data.to_vec()),
244            )],
245            stream_id,
246        )?;
247        Ok(alloc.handle)
248    }
249
250    /// Utility to create a new buffer and immediately copy contiguous data into it
251    fn create_with_bytes(&mut self, data: Bytes, stream_id: StreamId) -> Result<Handle, IoError> {
252        let alloc = self
253            .create(
254                vec![AllocationDescriptor::new(
255                    AllocationKind::Contiguous,
256                    &[data.len()],
257                    1,
258                )],
259                stream_id,
260            )?
261            .remove(0);
262        self.write(
263            vec![(
264                CopyDescriptor::new(
265                    alloc.handle.clone().binding(),
266                    &[data.len()],
267                    &alloc.strides,
268                    1,
269                ),
270                data,
271            )],
272            stream_id,
273        )?;
274        Ok(alloc.handle)
275    }
276
277    /// Given bindings, returns the owned resources as bytes.
278    fn read<'a>(
279        &mut self,
280        descriptors: Vec<CopyDescriptor<'a>>,
281        stream_id: StreamId,
282    ) -> DynFut<Result<Vec<Bytes>, IoError>>;
283
284    /// Writes the specified bytes into the buffers given
285    fn write(
286        &mut self,
287        descriptors: Vec<(CopyDescriptor<'_>, Bytes)>,
288        stream_id: StreamId,
289    ) -> Result<(), IoError>;
290
291    /// Wait for the completion of every task in the server.
292    fn sync(&mut self, stream_id: StreamId) -> DynFut<Result<(), ExecutionError>>;
293
294    /// Given a resource handle, returns the storage resource.
295    fn get_resource(
296        &mut self,
297        binding: Binding,
298        stream_id: StreamId,
299    ) -> BindingResource<<Self::Storage as ComputeStorage>::Resource>;
300
301    /// Executes the `kernel` over the given memory `handles`.
302    ///
303    /// Kernels have mutable access to every resource they are given
304    /// and are responsible of determining which should be read or written.
305    ///
306    /// # Safety
307    ///
308    /// When executing with mode [ExecutionMode::Unchecked], out-of-bound reads and writes can happen.
309    unsafe fn launch(
310        &mut self,
311        kernel: Self::Kernel,
312        count: CubeCount,
313        bindings: Bindings,
314        kind: ExecutionMode,
315        stream_id: StreamId,
316    ) -> Result<(), LaunchError>;
317
318    /// Flush all outstanding tasks in the server.
319    fn flush(&mut self, stream_id: StreamId);
320
321    /// The current memory usage of the server.
322    fn memory_usage(&mut self, stream_id: StreamId) -> MemoryUsage;
323
324    /// Ask the server to release memory that it can release.
325    fn memory_cleanup(&mut self, stream_id: StreamId);
326
327    /// Enable collecting timestamps.
328    fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken;
329
330    /// Disable collecting timestamps.
331    fn end_profile(
332        &mut self,
333        stream_id: StreamId,
334        token: ProfilingToken,
335    ) -> Result<ProfileDuration, ProfileError>;
336
337    /// Update the memory mode of allocation in the server.
338    fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId);
339}
340
341/// Defines functions for optimized data transfer between servers, supporting custom communication
342/// mechanisms such as peer-to-peer communication or specialized implementations.
343pub trait ServerCommunication {
344    /// Indicates whether server-to-server communication is enabled for this implementation.
345    const SERVER_COMM_ENABLED: bool;
346
347    /// Copies data from a source server to a destination server.
348    ///
349    /// # Arguments
350    ///
351    /// * `server_src` - A mutable reference to the source server from which data is copied.
352    /// * `server_dst` - A mutable reference to the destination server receiving the data.
353    /// * `src` - A descriptor specifying the data to be copied, including shape, strides, and binding.
354    /// * `stream_id_src` - The stream ID associated with the source server's operation.
355    /// * `stream_id_dst` - The stream ID associated with the destination server's operation.
356    ///
357    /// # Returns
358    ///
359    /// Returns a `Result` containing an `Allocation` on success, or an `IoError` if the operation fails.
360    ///
361    /// # Panics
362    ///
363    /// Panics if server communication is not enabled (`SERVER_COMM_ENABLED` is `false`) or if the
364    /// trait is incorrectly implemented by the server.
365    #[allow(unused_variables)]
366    fn copy(
367        server_src: &mut Self,
368        server_dst: &mut Self,
369        src: CopyDescriptor<'_>,
370        stream_id_src: StreamId,
371        stream_id_dst: StreamId,
372    ) -> Result<Allocation, IoError> {
373        if !Self::SERVER_COMM_ENABLED {
374            panic!("Server-to-server communication is not supported by this server.");
375        } else {
376            panic!(
377                "[Internal Error] The `ServerCommunication` trait is incorrectly implemented by the server."
378            );
379        }
380    }
381}
382
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
/// Profiling identification so that the server can support recursive and overlapping profilings.
///
/// Returned by [ComputeServer::start_profile] and passed back to [ComputeServer::end_profile].
pub struct ProfilingToken {
    /// The token value.
    pub id: u64,
}
389
/// Server handle containing the [memory handle](SliceHandle) together with byte offsets,
/// the originating stream and the underlying buffer size.
// `Clone` is implemented manually further down in this file.
#[derive(new, Debug, PartialEq, Eq)]
pub struct Handle {
    /// Memory handle.
    pub memory: SliceHandle,
    /// Offset in bytes from the start of the underlying buffer.
    pub offset_start: Option<u64>,
    /// Offset in bytes from the end of the underlying buffer.
    pub offset_end: Option<u64>,
    /// The stream where the data was created.
    pub stream: cubecl_common::stream_id::StreamId,
    /// The stream position when the tensor became available.
    pub cursor: u64,
    /// Length of the underlying buffer in bytes, ignoring offsets.
    size: u64,
}
406
/// Type of allocation, either contiguous or optimized (row-aligned when possible).
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum AllocationKind {
    /// Contiguous layout, with no padding.
    Contiguous,
    /// Optimized for access speed. In practice this means row-aligned with padding for runtimes
    /// that support it.
    Optimized,
}
416
/// Descriptor for a new tensor allocation, as passed to [ComputeServer::create].
// The derived `new` takes the fields in declaration order: (kind, shape, elem_size).
#[derive(new, Debug, Clone, Copy)]
pub struct AllocationDescriptor<'a> {
    /// Layout for the tensor.
    pub kind: AllocationKind,
    /// Shape of the tensor.
    pub shape: &'a [usize],
    /// Size of each element in the tensor (used for conversion of shape to bytes).
    pub elem_size: usize,
}
427
428impl<'a> AllocationDescriptor<'a> {
429    /// Create an optimized allocation descriptor
430    pub fn optimized(shape: &'a [usize], elem_size: usize) -> Self {
431        AllocationDescriptor::new(AllocationKind::Optimized, shape, elem_size)
432    }
433
434    /// Create a contiguous allocation descriptor
435    pub fn contiguous(shape: &'a [usize], elem_size: usize) -> Self {
436        AllocationDescriptor::new(AllocationKind::Contiguous, shape, elem_size)
437    }
438}
439
/// An allocation with associated strides. Strides depend on tensor layout
/// (see [AllocationKind]).
#[derive(new, Debug)]
pub struct Allocation {
    /// The handle for the memory resource.
    pub handle: Handle,
    /// The strides of the tensor.
    pub strides: Vec<usize>,
}
448
449/// Error returned from `create`/`read`/`write` functions. Due to async execution not all errors
450/// are able to be caught, so some IO errors will still panic.
451#[derive(Error, Clone)]
452#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
453pub enum IoError {
454    /// Buffer size exceeds the max available
455    #[error("can't allocate buffer of size: {size}\n{backtrace}")]
456    BufferTooBig {
457        /// The size of the buffer in bytes.
458        size: u64,
459        /// The captured backtrace.
460        #[cfg_attr(std_io, serde(skip))]
461        backtrace: BackTrace,
462    },
463
464    /// Strides aren't supported for this copy operation on this runtime
465    #[error("the provided strides are not supported for this operation\n{backtrace}")]
466    UnsupportedStrides {
467        /// The backtrace.
468        #[cfg_attr(std_io, serde(skip))]
469        backtrace: BackTrace,
470    },
471
472    /// Handle wasn't found in the memory pool
473    #[error("couldn't find resource for that handle\n{backtrace}")]
474    InvalidHandle {
475        /// The backtrace.
476        #[cfg_attr(std_io, serde(skip))]
477        backtrace: BackTrace,
478    },
479
480    /// Unknown error happened during execution
481    #[error("Unknown error happened during execution\n{backtrace}")]
482    Unknown {
483        /// Details of the error
484        description: String,
485        /// The backtrace.
486        #[cfg_attr(std_io, serde(skip))]
487        backtrace: BackTrace,
488    },
489
490    /// The current IO operation is not supported
491    #[error("The current IO operation is not supported\n{backtrace}")]
492    UnsupportedIoOperation {
493        /// The backtrace.
494        #[cfg_attr(std_io, serde(skip))]
495        backtrace: BackTrace,
496    },
497
498    /// Can't perform the IO operation because of a runtime error.
499    #[error("Can't perform the IO operation because of a runtime error")]
500    Execution(#[from] ExecutionError),
501}
502
503impl core::fmt::Debug for IoError {
504    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
505        f.write_fmt(format_args!("{self}"))
506    }
507}
508
509impl Handle {
510    /// Add to the current offset in bytes.
511    pub fn offset_start(mut self, offset: u64) -> Self {
512        if let Some(val) = &mut self.offset_start {
513            *val += offset;
514        } else {
515            self.offset_start = Some(offset);
516        }
517
518        self
519    }
520    /// Add to the current offset in bytes.
521    pub fn offset_end(mut self, offset: u64) -> Self {
522        if let Some(val) = &mut self.offset_end {
523            *val += offset;
524        } else {
525            self.offset_end = Some(offset);
526        }
527
528        self
529    }
530
531    /// Get the size of the handle, in bytes, accounting for offsets
532    pub fn size(&self) -> u64 {
533        self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
534    }
535}
536
/// Bindings to execute a kernel.
#[derive(Debug, Default)]
pub struct Bindings {
    /// Buffer bindings.
    pub buffers: Vec<Binding>,
    /// Packed metadata for tensor bindings (len, shape, stride, etc).
    /// Ordered by inputs, then outputs, then tensormaps.
    pub metadata: MetadataBinding,
    /// Scalar bindings, keyed by their storage type (one entry per type).
    pub scalars: BTreeMap<StorageType, ScalarBinding>,
    /// Tensor map bindings.
    pub tensor_maps: Vec<TensorMapBinding>,
}
550
551impl Bindings {
552    /// Create a new bindings struct
553    pub fn new() -> Self {
554        Self::default()
555    }
556
557    /// Add a buffer binding
558    pub fn with_buffer(mut self, binding: Binding) -> Self {
559        self.buffers.push(binding);
560        self
561    }
562
563    /// Extend the buffers with `bindings`
564    pub fn with_buffers(mut self, bindings: Vec<Binding>) -> Self {
565        self.buffers.extend(bindings);
566        self
567    }
568
569    /// Add a scalar parameter
570    pub fn with_scalar(mut self, ty: StorageType, length: usize, data: Vec<u64>) -> Self {
571        self.scalars
572            .insert(ty, ScalarBinding::new(ty, length, data));
573        self
574    }
575
576    /// Extend the scalars with `bindings`
577    pub fn with_scalars(mut self, bindings: Vec<ScalarBinding>) -> Self {
578        self.scalars
579            .extend(bindings.into_iter().map(|binding| (binding.ty, binding)));
580        self
581    }
582
583    /// Set the metadata to `meta`
584    pub fn with_metadata(mut self, meta: MetadataBinding) -> Self {
585        self.metadata = meta;
586        self
587    }
588
589    /// Extend the tensor maps with `bindings`
590    pub fn with_tensor_maps(mut self, bindings: Vec<TensorMapBinding>) -> Self {
591        self.tensor_maps.extend(bindings);
592        self
593    }
594}
595
/// Packed tensor metadata passed to a kernel (see [Bindings::metadata]).
#[derive(new, Debug, Default)]
pub struct MetadataBinding {
    /// Metadata values.
    pub data: Vec<u32>,
    /// Length of the static portion (rank, len, buffer_len, shape_offsets, stride_offsets).
    pub static_len: usize,
}
604
/// Binding of a set of scalars of the same type to execute a kernel.
#[derive(new, Debug, Clone)]
pub struct ScalarBinding {
    /// Type of the scalars.
    pub ty: StorageType,
    /// Unpadded length of the underlying data.
    // NOTE(review): unclear from this file whether this counts elements or bytes — confirm.
    pub length: usize,
    /// Type-erased data of the scalars. Padded and represented by u64 to prevent misalignment.
    pub data: Vec<u64>,
}
615
616impl ScalarBinding {
617    /// Get data as byte slice
618    pub fn data(&self) -> &[u8] {
619        bytemuck::cast_slice(&self.data)
620    }
621}
622
/// Binding of a [tensor handle](Handle) to execute a kernel.
///
/// Obtained from a [Handle] through [Handle::binding].
// `Clone` is implemented manually further down in this file.
#[derive(new, Debug)]
pub struct Binding {
    /// Memory binding.
    pub memory: SliceBinding,
    /// Offset in bytes from the start of the underlying buffer.
    pub offset_start: Option<u64>,
    /// Offset in bytes from the end of the underlying buffer.
    pub offset_end: Option<u64>,
    /// The stream where the data was created.
    pub stream: cubecl_common::stream_id::StreamId,
    /// The stream position when the tensor became available.
    pub cursor: u64,
    /// Size in bytes of the underlying buffer, ignoring offsets (see [Binding::size]).
    size: u64,
}
639
640impl Binding {
641    /// Get the size of the handle, in bytes, accounting for offsets
642    pub fn size(&self) -> u64 {
643        self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
644    }
645}
646
/// A binding with shape and stride info for non-contiguous reading.
#[derive(new, Debug, Clone)]
pub struct CopyDescriptor<'a> {
    /// Binding for the memory resource.
    pub binding: Binding,
    /// Shape of the resource.
    pub shape: &'a [usize],
    /// Strides of the resource.
    pub strides: &'a [usize],
    /// Size of each element in the resource.
    // NOTE(review): presumably in bytes (raw `u8` data uses 1 in this file) — confirm.
    pub elem_size: usize,
}
659
/// A tensor map used with TMA ops.
#[derive(new, Debug, Clone)]
pub struct TensorMapBinding {
    /// The binding for the backing tensor.
    pub binding: Binding,
    /// The tensormap metadata.
    pub map: TensorMapMeta,
}
668
/// TensorMap metadata for the opaque proxy used in TMA copies.
#[derive(Debug, Clone)]
pub struct TensorMapMeta {
    /// Tensormap format (tiled or im2col).
    pub format: TensorMapFormat,
    /// Rank of the backing tensor.
    pub rank: usize,
    /// Shape of the backing tensor.
    pub shape: Vec<usize>,
    /// Strides of the backing tensor.
    pub strides: Vec<usize>,
    /// Element stride, usually 1 but may be 2 for complex tensors.
    /// For im2col, this is equivalent to the kernel stride.
    pub elem_stride: Vec<usize>,
    /// Interleave mode.
    pub interleave: TensorMapInterleave,
    /// Swizzle mode.
    pub swizzle: TensorMapSwizzle,
    /// Prefetch settings.
    pub prefetch: TensorMapPrefetch,
    /// OOB fill value.
    pub oob_fill: OobFill,
    /// Storage type.
    pub storage_ty: StorageType,
}
694
695impl Handle {
696    /// If the tensor handle can be reused inplace.
697    pub fn can_mut(&self) -> bool {
698        self.memory.can_mut() && self.stream == StreamId::current()
699    }
700}
701
702impl Handle {
703    /// Convert the [handle](Handle) into a [binding](Binding).
704    pub fn binding(self) -> Binding {
705        Binding {
706            memory: MemoryHandle::binding(self.memory),
707            offset_start: self.offset_start,
708            offset_end: self.offset_end,
709            size: self.size,
710            stream: self.stream,
711            cursor: self.cursor,
712        }
713    }
714
715    /// Convert the [handle](Handle) into a [binding](Binding) with shape and stride metadata.
716    pub fn copy_descriptor<'a>(
717        &'a self,
718        shape: &'a [usize],
719        strides: &'a [usize],
720        elem_size: usize,
721    ) -> CopyDescriptor<'a> {
722        CopyDescriptor {
723            shape,
724            strides,
725            elem_size,
726            binding: self.clone().binding(),
727        }
728    }
729}
730
731impl Clone for Handle {
732    fn clone(&self) -> Self {
733        Self {
734            memory: self.memory.clone(),
735            offset_start: self.offset_start,
736            offset_end: self.offset_end,
737            size: self.size,
738            stream: self.stream,
739            cursor: self.cursor,
740        }
741    }
742}
743
744impl Clone for Binding {
745    fn clone(&self) -> Self {
746        Self {
747            memory: self.memory.clone(),
748            offset_start: self.offset_start,
749            offset_end: self.offset_end,
750            size: self.size,
751            stream: self.stream,
752            cursor: self.cursor,
753        }
754    }
755}
756
/// Specifies the number of cubes to be dispatched for a kernel.
///
/// This translates to e.g. a grid for CUDA, or to `num_workgroups` for WGSL.
// NOTE(review): `Dynamic` carries a whole `Binding`, which presumably makes it much larger
// than `Static`; the lint is silenced instead of boxing the binding.
#[allow(clippy::large_enum_variant)]
pub enum CubeCount {
    /// Dispatch a known count of x, y, z cubes.
    Static(u32, u32, u32),
    /// Dispatch an amount based on the values in this buffer. The buffer should contain a u32 array [x, y, z].
    Dynamic(Binding),
}
767
/// Defines how to select cube count based on the number of cubes required.
pub enum CubeCountSelection {
    /// The number of dispatched cubes is exactly the number required.
    Exact(CubeCount),
    /// The number of dispatched cubes differs from the number required; the `u32` is the
    /// actual number of cubes dispatched.
    ///
    /// This can happen because of hardware limits, requiring the kernel to perform OOB checks.
    Approx(CubeCount, u32),
}
777
778impl CubeCountSelection {
779    /// Creates a [CubeCount] while respecting the hardware limits.
780    pub fn new<R: Runtime>(client: &ComputeClient<R>, num_cubes: u32) -> Self {
781        let cube_count = cube_count_spread(&client.properties().hardware.max_cube_count, num_cubes);
782
783        let num_cubes_actual = cube_count[0] * cube_count[1] * cube_count[2];
784        let cube_count = CubeCount::Static(cube_count[0], cube_count[1], cube_count[2]);
785
786        match num_cubes_actual == num_cubes {
787            true => CubeCountSelection::Exact(cube_count),
788            false => CubeCountSelection::Approx(cube_count, num_cubes_actual),
789        }
790    }
791
792    /// If some cubes will be idle.
793    pub fn has_idle(&self) -> bool {
794        matches!(self, Self::Approx(..))
795    }
796
797    /// Converts into [CubeCount].
798    pub fn cube_count(self) -> CubeCount {
799        match self {
800            CubeCountSelection::Exact(cube_count) => cube_count,
801            CubeCountSelection::Approx(cube_count, _) => cube_count,
802        }
803    }
804}
805
806impl From<CubeCountSelection> for CubeCount {
807    fn from(value: CubeCountSelection) -> Self {
808        value.cube_count()
809    }
810}
811
812impl CubeCount {
813    /// Create a new static cube count with the given x = y = z = 1.
814    pub fn new_single() -> Self {
815        CubeCount::Static(1, 1, 1)
816    }
817
818    /// Create a new static cube count with the given x, and y = z = 1.
819    pub fn new_1d(x: u32) -> Self {
820        CubeCount::Static(x, 1, 1)
821    }
822
823    /// Create a new static cube count with the given x and y, and z = 1.
824    pub fn new_2d(x: u32, y: u32) -> Self {
825        CubeCount::Static(x, y, 1)
826    }
827
828    /// Create a new static cube count with the given x, y and z.
829    pub fn new_3d(x: u32, y: u32, z: u32) -> Self {
830        CubeCount::Static(x, y, z)
831    }
832}
833
834impl Debug for CubeCount {
835    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
836        match self {
837            CubeCount::Static(x, y, z) => f.write_fmt(format_args!("({x}, {y}, {z})")),
838            CubeCount::Dynamic(_) => f.write_str("binding"),
839        }
840    }
841}
842
843impl Clone for CubeCount {
844    fn clone(&self) -> Self {
845        match self {
846            Self::Static(x, y, z) => Self::Static(*x, *y, *z),
847            Self::Dynamic(handle) => Self::Dynamic(handle.clone()),
848        }
849    }
850}
851
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, serde::Serialize, serde::Deserialize)]
#[allow(missing_docs)]
/// The number of units along each of the 3 axes; their product is the number of working
/// units in a cube.
pub struct CubeDim {
    /// The number of units in the x axis.
    pub x: u32,
    /// The number of units in the y axis.
    pub y: u32,
    /// The number of units in the z axis.
    pub z: u32,
}
863
864impl CubeDim {
865    /// Creates a new [CubeDim] based on the maximum number of tasks that can be parellalized by units, in other words,
866    /// by the maximum number of working units.
867    ///
868    /// # Notes
869    ///
870    /// For complex problems, you probably want to have your own logic function to create the
871    /// [CubeDim], but for simpler problems such as elemwise-operation, this is a great default.
872    pub fn new<R: Runtime>(client: &ComputeClient<R>, working_units: usize) -> Self {
873        let properties = client.properties();
874        let plane_size = properties.hardware.plane_size_max;
875        let plane_count = Self::calculate_plane_count_per_cube(
876            working_units as u32,
877            plane_size,
878            properties.hardware.num_cpu_cores,
879        );
880
881        Self::new_2d(plane_size, plane_count)
882    }
883
884    fn calculate_plane_count_per_cube(
885        working_units: u32,
886        plane_dim: u32,
887        num_cpu_cores: Option<u32>,
888    ) -> u32 {
889        match num_cpu_cores {
890            Some(num_cores) => core::cmp::min(num_cores, working_units),
891            None => {
892                let plane_count_max = core::cmp::max(1, working_units / plane_dim);
893
894                // Ensures `plane_count` is a power of 2.
895                const NUM_PLANE_MAX: u32 = 8u32;
896                const NUM_PLANE_MAX_LOG2: u32 = NUM_PLANE_MAX.ilog2();
897                let plane_count_max_log2 =
898                    core::cmp::min(NUM_PLANE_MAX_LOG2, u32::ilog2(plane_count_max));
899                2u32.pow(plane_count_max_log2)
900            }
901        }
902    }
903
904    /// Create a new cube dim with x = y = z = 1.
905    pub const fn new_single() -> Self {
906        Self { x: 1, y: 1, z: 1 }
907    }
908
909    /// Create a new cube dim with the given x, and y = z = 1.
910    pub const fn new_1d(x: u32) -> Self {
911        Self { x, y: 1, z: 1 }
912    }
913
914    /// Create a new cube dim with the given x and y, and z = 1.
915    pub const fn new_2d(x: u32, y: u32) -> Self {
916        Self { x, y, z: 1 }
917    }
918
919    /// Create a new cube dim with the given x, y and z.
920    /// This is equivalent to the [new](CubeDim::new) function.
921    pub const fn new_3d(x: u32, y: u32, z: u32) -> Self {
922        Self { x, y, z }
923    }
924
925    /// Total numbers of units per cube
926    pub const fn num_elems(&self) -> u32 {
927        self.x * self.y * self.z
928    }
929
930    /// Whether this `CubeDim` can fully contain `other`
931    pub const fn can_contain(&self, other: CubeDim) -> bool {
932        self.x >= other.x && self.y >= other.y && self.z >= other.z
933    }
934}
935
/// The kind of execution to be performed.
#[derive(Default, Hash, PartialEq, Eq, Clone, Debug, Copy, Serialize, Deserialize)]
pub enum ExecutionMode {
    /// Checked kernels are safe. This is the default mode.
    #[default]
    Checked,
    /// Unchecked kernels are unsafe: out-of-bound reads and writes can happen
    /// (see [ComputeServer::launch]).
    Unchecked,
}
945
946fn cube_count_spread(max: &CubeCount, num_cubes: u32) -> [u32; 3] {
947    let max_cube_counts = match max {
948        CubeCount::Static(x, y, z) => [*x, *y, *z],
949        CubeCount::Dynamic(_) => panic!("No static max cube count"),
950    };
951    let mut num_cubes = [num_cubes, 1, 1];
952    let base = 2;
953
954    let mut reduce_count = |i: usize| {
955        if num_cubes[i] <= max_cube_counts[i] {
956            return true;
957        }
958
959        loop {
960            num_cubes[i] = num_cubes[i].div_ceil(base);
961            num_cubes[i + 1] *= base;
962
963            if num_cubes[i] <= max_cube_counts[i] {
964                return false;
965            }
966        }
967    };
968
969    for i in 0..2 {
970        if reduce_count(i) {
971            break;
972        }
973    }
974
975    num_cubes
976}
977
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn safe_num_cubes_even() {
        let max = CubeCount::Static(32, 32, 32);
        let required = 2048;

        let actual = cube_count_spread(&max, required);
        let expected = [32, 32, 2];
        assert_eq!(actual, expected);
    }

    #[test]
    fn safe_num_cubes_odd() {
        let max = CubeCount::Static(48, 32, 16);
        let required = 3177;

        let actual = cube_count_spread(&max, required);
        let expected = [25, 32, 4];
        assert_eq!(actual, expected);
    }

    /// Counts that already fit on the x axis must be left untouched, including zero.
    #[test]
    fn num_cubes_within_limits_is_unchanged() {
        let max = CubeCount::Static(32, 32, 32);

        assert_eq!(cube_count_spread(&max, 20), [20, 1, 1]);
        assert_eq!(cube_count_spread(&max, 32), [32, 1, 1]);
        assert_eq!(cube_count_spread(&max, 0), [0, 1, 1]);
    }

    /// A count that exactly fills x and y must not spill onto z.
    #[test]
    fn num_cubes_exact_fit_on_two_axes() {
        let max = CubeCount::Static(32, 32, 32);

        assert_eq!(cube_count_spread(&max, 1024), [32, 32, 1]);
    }
}