cubecl_runtime/
server.rs

1use crate::{
2    client::ComputeClient,
3    compiler::CompilationError,
4    kernel::KernelMetadata,
5    logging::ServerLogger,
6    memory_management::{
7        MemoryAllocationMode, MemoryHandle, MemoryUsage,
8        memory_pool::{SliceBinding, SliceHandle},
9    },
10    runtime::Runtime,
11    storage::{BindingResource, ComputeStorage},
12    tma::{OobFill, TensorMapFormat, TensorMapInterleave, TensorMapPrefetch, TensorMapSwizzle},
13};
14use alloc::collections::BTreeMap;
15use alloc::string::String;
16use alloc::sync::Arc;
17use alloc::vec;
18use alloc::vec::Vec;
19use core::fmt::Debug;
20use cubecl_common::{
21    backtrace::BackTrace, bytes::Bytes, device, future::DynFut, profile::ProfileDuration,
22    stream_id::StreamId,
23};
24use cubecl_ir::{DeviceProperties, StorageType};
25use serde::{Deserialize, Serialize};
26use thiserror::Error;
27
28#[derive(Error, Clone)]
29/// An error during profiling.
30pub enum ProfileError {
31    /// An unknown error happened during profiling
32    #[error(
33        "An unknown error happened during profiling\nCaused by:\n  {reason}\nBacktrace:\n{backtrace}"
34    )]
35    Unknown {
36        /// The caused of the error
37        reason: String,
38        /// The captured backtrace.
39        backtrace: BackTrace,
40    },
41
42    /// No profiling was registered
43    #[error("No profiling registered\nBacktrace:\n{backtrace}")]
44    NotRegistered {
45        /// The captured backtrace.
46        backtrace: BackTrace,
47    },
48
49    /// A launch error happened during profiling
50    #[error("A launch error happened during profiling\nCaused by:\n  {0}")]
51    Launch(#[from] LaunchError),
52
53    /// An execution error happened during profiling
54    #[error("An execution error happened during profiling\nCaused by:\n  {0}")]
55    Execution(#[from] ExecutionError),
56}
57
58impl core::fmt::Debug for ProfileError {
59    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
60        f.write_fmt(format_args!("{self}"))
61    }
62}
63
64/// Contains many different types that are useful for server implementations and compute clients.
65pub struct ServerUtilities<Server: ComputeServer> {
66    /// The time when `profile-tracy` is activated.
67    #[cfg(feature = "profile-tracy")]
68    pub epoch_time: web_time::Instant,
69    /// The GPU client when `profile-tracy` is activated.
70    #[cfg(feature = "profile-tracy")]
71    pub gpu_client: tracy_client::GpuContext,
72    /// Information shared between all servers.
73    pub properties: DeviceProperties,
74    /// Information specific to the current server.
75    pub info: Server::Info,
76    /// The logger based on global cubecl configs.
77    pub logger: Arc<ServerLogger>,
78}
79
80impl<Server: core::fmt::Debug> core::fmt::Debug for ServerUtilities<Server>
81where
82    Server: ComputeServer,
83    Server::Info: core::fmt::Debug,
84{
85    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
86        f.debug_struct("ServerUtilities")
87            .field("properties", &self.properties)
88            .field("info", &self.info)
89            .field("logger", &self.logger)
90            .finish()
91    }
92}
93
94impl<S: ComputeServer> ServerUtilities<S> {
95    /// Creates a new server utilities.
96    pub fn new(properties: DeviceProperties, logger: Arc<ServerLogger>, info: S::Info) -> Self {
97        // Start a tracy client if needed.
98        #[cfg(feature = "profile-tracy")]
99        let client = tracy_client::Client::start();
100
101        Self {
102            properties,
103            logger,
104            // Create the GPU client if needed.
105            #[cfg(feature = "profile-tracy")]
106            gpu_client: client
107                .clone()
108                .new_gpu_context(
109                    Some(&format!("{info:?}")),
110                    // In the future should ask the server what makes sense here. 'Invalid' atm is a generic stand-in (Tracy doesn't have CUDA/RocM atm anyway).
111                    tracy_client::GpuContextType::Invalid,
112                    0,   // Timestamps are manually aligned to this epoch so start at 0.
113                    1.0, // Timestamps are manually converted to be nanoseconds so period is 1.
114                )
115                .unwrap(),
116            #[cfg(feature = "profile-tracy")]
117            epoch_time: web_time::Instant::now(),
118            info,
119        }
120    }
121}
122
123/// Error that can happen when calling [ComputeServer::execute];
124///
125/// # Notes
126///
127/// Not all errors are going to be caught when calling [ComputeServer::execute] only the one that
128/// won't block the compute queue.
129#[derive(Error, Clone)]
130#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
131pub enum LaunchError {
132    /// The given kernel can't be compiled.
133    #[error("A compilation error happened during launch\nCaused by:\n  {0}")]
134    CompilationError(#[from] CompilationError),
135
136    /// The server is out of memory.
137    #[error(
138        "An out-of-memory error happened during launch\nCaused by:\n  {reason}\nBacktrace\n{backtrace}"
139    )]
140    OutOfMemory {
141        /// The caused of the memory error.
142        reason: String,
143        /// The backtrace for this error.
144        #[cfg_attr(std_io, serde(skip))]
145        backtrace: BackTrace,
146    },
147
148    /// Unknown launch error.
149    #[error(
150        "An unknown error happened during launch\nCaused by:\n  {reason}\nBacktrace\n{backtrace}"
151    )]
152    Unknown {
153        /// The caused of the unknown error.
154        reason: String,
155        /// The backtrace for this error.
156        #[cfg_attr(std_io, serde(skip))]
157        backtrace: BackTrace,
158    },
159
160    /// Can't launch because of an IO Error.
161    #[error("An io error happened during launch\nCaused by:\n  {0}")]
162    IoError(#[from] IoError),
163}
164
165impl core::fmt::Debug for LaunchError {
166    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
167        f.write_fmt(format_args!("{self}"))
168    }
169}
170
171/// Error that can happen asynchronously while executing registered kernels.
172#[derive(Error, Debug, Clone)]
173#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
174pub enum ExecutionError {
175    /// A generic runtime error.
176    #[error("An error happened during execution\nCaused by:\n  {reason}\nBacktrace:\n{backtrace}")]
177    Generic {
178        /// The details of the generic error.
179        reason: String,
180        /// The backtrace for this error.
181        #[cfg_attr(std_io, serde(skip))]
182        backtrace: BackTrace,
183    },
184}
185
186/// The compute server is responsible for handling resources and computations over resources.
187///
188/// Everything in the server is mutable, therefore it should be solely accessed through the
189/// [compute channel](crate::channel::ComputeChannel) for thread safety.
190pub trait ComputeServer:
191    Send + core::fmt::Debug + ServerCommunication + device::DeviceState + 'static
192where
193    Self: Sized,
194{
195    /// The kernel type defines the computation algorithms.
196    type Kernel: KernelMetadata;
197    /// Information that can be retrieved for the runtime.
198    type Info: Debug + Send + Sync;
199    /// The [storage](ComputeStorage) type defines how data is stored and accessed.
200    type Storage: ComputeStorage;
201
202    /// Reserves `size` bytes in the storage, and returns a handle over them.
203    fn create(
204        &mut self,
205        descriptors: Vec<AllocationDescriptor<'_>>,
206        stream_id: StreamId,
207    ) -> Result<Vec<Allocation>, IoError>;
208
209    /// Reserves N [Bytes] of the provided sizes to be used as staging to load data.
210    fn staging(&mut self, _sizes: &[usize], _stream_id: StreamId) -> Result<Vec<Bytes>, IoError> {
211        Err(IoError::UnsupportedIoOperation {
212            backtrace: BackTrace::capture(),
213        })
214    }
215
216    /// Retrieve the server logger.
217    fn logger(&self) -> Arc<ServerLogger>;
218
219    /// Retrieve the server utilities.
220    fn utilities(&self) -> Arc<ServerUtilities<Self>>;
221
222    /// Utility to create a new buffer and immediately copy contiguous data into it
223    fn create_with_data(&mut self, data: &[u8], stream_id: StreamId) -> Result<Handle, IoError> {
224        let alloc = self
225            .create(
226                vec![AllocationDescriptor::new(
227                    AllocationKind::Contiguous,
228                    &[data.len()],
229                    1,
230                )],
231                stream_id,
232            )?
233            .remove(0);
234        self.write(
235            vec![(
236                CopyDescriptor::new(
237                    alloc.handle.clone().binding(),
238                    &[data.len()],
239                    &alloc.strides,
240                    1,
241                ),
242                Bytes::from_bytes_vec(data.to_vec()),
243            )],
244            stream_id,
245        )?;
246        Ok(alloc.handle)
247    }
248
249    /// Utility to create a new buffer and immediately copy contiguous data into it
250    fn create_with_bytes(&mut self, data: Bytes, stream_id: StreamId) -> Result<Handle, IoError> {
251        let alloc = self
252            .create(
253                vec![AllocationDescriptor::new(
254                    AllocationKind::Contiguous,
255                    &[data.len()],
256                    1,
257                )],
258                stream_id,
259            )?
260            .remove(0);
261        self.write(
262            vec![(
263                CopyDescriptor::new(
264                    alloc.handle.clone().binding(),
265                    &[data.len()],
266                    &alloc.strides,
267                    1,
268                ),
269                data,
270            )],
271            stream_id,
272        )?;
273        Ok(alloc.handle)
274    }
275
276    /// Given bindings, returns the owned resources as bytes.
277    fn read<'a>(
278        &mut self,
279        descriptors: Vec<CopyDescriptor<'a>>,
280        stream_id: StreamId,
281    ) -> DynFut<Result<Vec<Bytes>, IoError>>;
282
283    /// Writes the specified bytes into the buffers given
284    fn write(
285        &mut self,
286        descriptors: Vec<(CopyDescriptor<'_>, Bytes)>,
287        stream_id: StreamId,
288    ) -> Result<(), IoError>;
289
290    /// Wait for the completion of every task in the server.
291    fn sync(&mut self, stream_id: StreamId) -> DynFut<Result<(), ExecutionError>>;
292
293    /// Given a resource handle, returns the storage resource.
294    fn get_resource(
295        &mut self,
296        binding: Binding,
297        stream_id: StreamId,
298    ) -> BindingResource<<Self::Storage as ComputeStorage>::Resource>;
299
300    /// Executes the `kernel` over the given memory `handles`.
301    ///
302    /// Kernels have mutable access to every resource they are given
303    /// and are responsible of determining which should be read or written.
304    ///
305    /// # Safety
306    ///
307    /// When executing with mode [ExecutionMode::Unchecked], out-of-bound reads and writes can happen.
308    unsafe fn launch(
309        &mut self,
310        kernel: Self::Kernel,
311        count: CubeCount,
312        bindings: Bindings,
313        kind: ExecutionMode,
314        stream_id: StreamId,
315    ) -> Result<(), LaunchError>;
316
317    /// Flush all outstanding tasks in the server.
318    fn flush(&mut self, stream_id: StreamId);
319
320    /// The current memory usage of the server.
321    fn memory_usage(&mut self, stream_id: StreamId) -> MemoryUsage;
322
323    /// Ask the server to release memory that it can release.
324    fn memory_cleanup(&mut self, stream_id: StreamId);
325
326    /// Enable collecting timestamps.
327    fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken;
328
329    /// Disable collecting timestamps.
330    fn end_profile(
331        &mut self,
332        stream_id: StreamId,
333        token: ProfilingToken,
334    ) -> Result<ProfileDuration, ProfileError>;
335
336    /// Update the memory mode of allocation in the server.
337    fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId);
338}
339
340/// Defines functions for optimized data transfer between servers, supporting custom communication
341/// mechanisms such as peer-to-peer communication or specialized implementations.
342pub trait ServerCommunication {
343    /// Indicates whether server-to-server communication is enabled for this implementation.
344    const SERVER_COMM_ENABLED: bool;
345
346    /// Copies data from a source server to a destination server.
347    ///
348    /// # Arguments
349    ///
350    /// * `server_src` - A mutable reference to the source server from which data is copied.
351    /// * `server_dst` - A mutable reference to the destination server receiving the data.
352    /// * `src` - A descriptor specifying the data to be copied, including shape, strides, and binding.
353    /// * `stream_id_src` - The stream ID associated with the source server's operation.
354    /// * `stream_id_dst` - The stream ID associated with the destination server's operation.
355    ///
356    /// # Returns
357    ///
358    /// Returns a `Result` containing an `Allocation` on success, or an `IoError` if the operation fails.
359    ///
360    /// # Panics
361    ///
362    /// Panics if server communication is not enabled (`SERVER_COMM_ENABLED` is `false`) or if the
363    /// trait is incorrectly implemented by the server.
364    #[allow(unused_variables)]
365    fn copy(
366        server_src: &mut Self,
367        server_dst: &mut Self,
368        src: CopyDescriptor<'_>,
369        stream_id_src: StreamId,
370        stream_id_dst: StreamId,
371    ) -> Result<Allocation, IoError> {
372        if !Self::SERVER_COMM_ENABLED {
373            panic!("Server-to-server communication is not supported by this server.");
374        } else {
375            panic!(
376                "[Internal Error] The `ServerCommunication` trait is incorrectly implemented by the server."
377            );
378        }
379    }
380}
381
/// Opaque token returned when a profile starts and handed back when it ends, so a server can
/// track several recursive or overlapping profiles at once.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct ProfilingToken {
    /// Identifier of the profiling session.
    pub id: u64,
}
388
389/// Server handle containing the [memory handle](crate::server::Handle).
390#[derive(new, Debug, PartialEq, Eq)]
391pub struct Handle {
392    /// Memory handle.
393    pub memory: SliceHandle,
394    /// Memory offset in bytes.
395    pub offset_start: Option<u64>,
396    /// Memory offset in bytes.
397    pub offset_end: Option<u64>,
398    /// The stream where the data was created.
399    pub stream: cubecl_common::stream_id::StreamId,
400    /// The stream position when the tensor became available.
401    pub cursor: u64,
402    /// Length of the underlying buffer ignoring offsets
403    size: u64,
404}
405
406/// Type of allocation, either contiguous or optimized (row-aligned when possible)
407#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
408pub enum AllocationKind {
409    /// Contiguous layout, with no padding
410    Contiguous,
411    /// Optimized for access speed. In practice this means row-aligned with padding for runtimes
412    /// that support it.
413    Optimized,
414}
415
416/// Descriptor for a new tensor allocation
417#[derive(new, Debug, Clone, Copy)]
418pub struct AllocationDescriptor<'a> {
419    /// Layout for the tensor
420    pub kind: AllocationKind,
421    /// Shape of the tensor
422    pub shape: &'a [usize],
423    /// Size of each element in the tensor (used for conversion of shape to bytes)
424    pub elem_size: usize,
425}
426
427impl<'a> AllocationDescriptor<'a> {
428    /// Create an optimized allocation descriptor
429    pub fn optimized(shape: &'a [usize], elem_size: usize) -> Self {
430        AllocationDescriptor::new(AllocationKind::Optimized, shape, elem_size)
431    }
432
433    /// Create a contiguous allocation descriptor
434    pub fn contiguous(shape: &'a [usize], elem_size: usize) -> Self {
435        AllocationDescriptor::new(AllocationKind::Contiguous, shape, elem_size)
436    }
437}
438
439/// An allocation with associated strides. Strides depend on tensor layout.
440#[derive(new, Debug)]
441pub struct Allocation {
442    /// The handle for the memory resource
443    pub handle: Handle,
444    /// The strides of the tensor
445    pub strides: Vec<usize>,
446}
447
448/// Error returned from `create`/`read`/`write` functions. Due to async execution not all errors
449/// are able to be caught, so some IO errors will still panic.
450#[derive(Error, Clone)]
451#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
452pub enum IoError {
453    /// Buffer size exceeds the max available
454    #[error("can't allocate buffer of size: {size}\n{backtrace}")]
455    BufferTooBig {
456        /// The size of the buffer in bytes.
457        size: u64,
458        /// The captured backtrace.
459        #[cfg_attr(std_io, serde(skip))]
460        backtrace: BackTrace,
461    },
462
463    /// Strides aren't supported for this copy operation on this runtime
464    #[error("the provided strides are not supported for this operation\n{backtrace}")]
465    UnsupportedStrides {
466        /// The backtrace.
467        #[cfg_attr(std_io, serde(skip))]
468        backtrace: BackTrace,
469    },
470
471    /// Handle wasn't found in the memory pool
472    #[error("couldn't find resource for that handle\n{backtrace}")]
473    InvalidHandle {
474        /// The backtrace.
475        #[cfg_attr(std_io, serde(skip))]
476        backtrace: BackTrace,
477    },
478
479    /// Unknown error happened during execution
480    #[error("Unknown error happened during execution\n{backtrace}")]
481    Unknown {
482        /// Details of the error
483        description: String,
484        /// The backtrace.
485        #[cfg_attr(std_io, serde(skip))]
486        backtrace: BackTrace,
487    },
488
489    /// The current IO operation is not supported
490    #[error("The current IO operation is not supported\n{backtrace}")]
491    UnsupportedIoOperation {
492        /// The backtrace.
493        #[cfg_attr(std_io, serde(skip))]
494        backtrace: BackTrace,
495    },
496
497    /// Can't perform the IO operation because of a runtime error.
498    #[error("Can't perform the IO operation because of a runtime error")]
499    Execution(#[from] ExecutionError),
500}
501
502impl core::fmt::Debug for IoError {
503    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
504        f.write_fmt(format_args!("{self}"))
505    }
506}
507
508impl Handle {
509    /// Add to the current offset in bytes.
510    pub fn offset_start(mut self, offset: u64) -> Self {
511        if let Some(val) = &mut self.offset_start {
512            *val += offset;
513        } else {
514            self.offset_start = Some(offset);
515        }
516
517        self
518    }
519    /// Add to the current offset in bytes.
520    pub fn offset_end(mut self, offset: u64) -> Self {
521        if let Some(val) = &mut self.offset_end {
522            *val += offset;
523        } else {
524            self.offset_end = Some(offset);
525        }
526
527        self
528    }
529
530    /// Get the size of the handle, in bytes, accounting for offsets
531    pub fn size(&self) -> u64 {
532        self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
533    }
534}
535
536/// Bindings to execute a kernel.
537#[derive(Debug, Default)]
538pub struct Bindings {
539    /// Buffer bindings
540    pub buffers: Vec<Binding>,
541    /// Packed metadata for tensor bindings (len, shape, stride, etc).
542    /// Ordered by inputs, then outputs, then tensormaps
543    pub metadata: MetadataBinding,
544    /// Scalar bindings
545    pub scalars: BTreeMap<StorageType, ScalarBinding>,
546    /// Tensor map bindings
547    pub tensor_maps: Vec<TensorMapBinding>,
548}
549
550impl Bindings {
551    /// Create a new bindings struct
552    pub fn new() -> Self {
553        Self::default()
554    }
555
556    /// Add a buffer binding
557    pub fn with_buffer(mut self, binding: Binding) -> Self {
558        self.buffers.push(binding);
559        self
560    }
561
562    /// Extend the buffers with `bindings`
563    pub fn with_buffers(mut self, bindings: Vec<Binding>) -> Self {
564        self.buffers.extend(bindings);
565        self
566    }
567
568    /// Add a scalar parameter
569    pub fn with_scalar(mut self, ty: StorageType, length: usize, data: Vec<u64>) -> Self {
570        self.scalars
571            .insert(ty, ScalarBinding::new(ty, length, data));
572        self
573    }
574
575    /// Extend the scalars with `bindings`
576    pub fn with_scalars(mut self, bindings: Vec<ScalarBinding>) -> Self {
577        self.scalars
578            .extend(bindings.into_iter().map(|binding| (binding.ty, binding)));
579        self
580    }
581
582    /// Set the metadata to `meta`
583    pub fn with_metadata(mut self, meta: MetadataBinding) -> Self {
584        self.metadata = meta;
585        self
586    }
587
588    /// Extend the tensor maps with `bindings`
589    pub fn with_tensor_maps(mut self, bindings: Vec<TensorMapBinding>) -> Self {
590        self.tensor_maps.extend(bindings);
591        self
592    }
593}
594
595/// Binding of a set of scalars of the same type to execute a kernel.
596#[derive(new, Debug, Default)]
597pub struct MetadataBinding {
598    /// Metadata values
599    pub data: Vec<u64>,
600    /// Length of the static portion (rank, len, buffer_len, shape_offsets, stride_offsets).
601    pub static_len: usize,
602}
603
604/// Binding of a set of scalars of the same type to execute a kernel.
605#[derive(new, Debug, Clone)]
606pub struct ScalarBinding {
607    /// Type of the scalars
608    pub ty: StorageType,
609    /// Unpadded length of the underlying data
610    pub length: usize,
611    /// Type-erased data of the scalars. Padded and represented by u64 to prevent misalignment.
612    pub data: Vec<u64>,
613}
614
615impl ScalarBinding {
616    /// Get data as byte slice
617    pub fn data(&self) -> &[u8] {
618        bytemuck::cast_slice(&self.data)
619    }
620}
621
622/// Binding of a [tensor handle](Handle) to execute a kernel.
623#[derive(new, Debug)]
624pub struct Binding {
625    /// Memory binding.
626    pub memory: SliceBinding,
627    /// Memory offset in bytes.
628    pub offset_start: Option<u64>,
629    /// Memory offset in bytes.
630    pub offset_end: Option<u64>,
631    /// The stream where the data was created.
632    pub stream: cubecl_common::stream_id::StreamId,
633    /// The stream position when the tensor became available.
634    pub cursor: u64,
635    /// Size in bytes
636    size: u64,
637}
638
639impl Binding {
640    /// Get the size of the handle, in bytes, accounting for offsets
641    pub fn size(&self) -> u64 {
642        self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
643    }
644}
645
646/// A binding with shape and stride info for non-contiguous reading
647#[derive(new, Debug, Clone)]
648pub struct CopyDescriptor<'a> {
649    /// Binding for the memory resource
650    pub binding: Binding,
651    /// Shape of the resource
652    pub shape: &'a [usize],
653    /// Strides of the resource
654    pub strides: &'a [usize],
655    /// Size of each element in the resource
656    pub elem_size: usize,
657}
658
659/// A tensor map used with TMA ops
660#[derive(new, Debug, Clone)]
661pub struct TensorMapBinding {
662    /// The binding for the backing tensor
663    pub binding: Binding,
664    /// The tensormap metadata
665    pub map: TensorMapMeta,
666}
667
668/// TensorMap metadata for the opaque proxy used in TMA copies
669#[derive(Debug, Clone)]
670pub struct TensorMapMeta {
671    /// Tensormap format (tiled or im2col)
672    pub format: TensorMapFormat,
673    /// Rank of the backing tensor
674    pub rank: usize,
675    /// Shape of the backing tensor
676    pub shape: Vec<usize>,
677    /// Strides of the backing tensor
678    pub strides: Vec<usize>,
679    /// Element stride, usually 1 but may be 2 for complex tensors
680    /// For im2col, this is equivalent to the kernel stride
681    pub elem_stride: Vec<usize>,
682    /// Interleave mode
683    pub interleave: TensorMapInterleave,
684    /// Swizzle mode
685    pub swizzle: TensorMapSwizzle,
686    /// Prefetch settings
687    pub prefetch: TensorMapPrefetch,
688    /// OOB fill value
689    pub oob_fill: OobFill,
690    /// Storage type
691    pub storage_ty: StorageType,
692}
693
694impl Handle {
695    /// If the tensor handle can be reused inplace.
696    pub fn can_mut(&self) -> bool {
697        self.memory.can_mut() && self.stream == StreamId::current()
698    }
699}
700
701impl Handle {
702    /// Convert the [handle](Handle) into a [binding](Binding).
703    pub fn binding(self) -> Binding {
704        Binding {
705            memory: MemoryHandle::binding(self.memory),
706            offset_start: self.offset_start,
707            offset_end: self.offset_end,
708            size: self.size,
709            stream: self.stream,
710            cursor: self.cursor,
711        }
712    }
713
714    /// Convert the [handle](Handle) into a [binding](Binding) with shape and stride metadata.
715    pub fn copy_descriptor<'a>(
716        &'a self,
717        shape: &'a [usize],
718        strides: &'a [usize],
719        elem_size: usize,
720    ) -> CopyDescriptor<'a> {
721        CopyDescriptor {
722            shape,
723            strides,
724            elem_size,
725            binding: self.clone().binding(),
726        }
727    }
728}
729
730impl Clone for Handle {
731    fn clone(&self) -> Self {
732        Self {
733            memory: self.memory.clone(),
734            offset_start: self.offset_start,
735            offset_end: self.offset_end,
736            size: self.size,
737            stream: self.stream,
738            cursor: self.cursor,
739        }
740    }
741}
742
743impl Clone for Binding {
744    fn clone(&self) -> Self {
745        Self {
746            memory: self.memory.clone(),
747            offset_start: self.offset_start,
748            offset_end: self.offset_end,
749            size: self.size,
750            stream: self.stream,
751            cursor: self.cursor,
752        }
753    }
754}
755
756/// Specifieds the number of cubes to be dispatched for a kernel.
757///
758/// This translates to eg. a grid for CUDA, or to num_workgroups for wgsl.
759#[allow(clippy::large_enum_variant)]
760pub enum CubeCount {
761    /// Dispatch a known count of x, y, z cubes.
762    Static(u32, u32, u32),
763    /// Dispatch an amount based on the values in this buffer. The buffer should contain a u32 array [x, y, z].
764    Dynamic(Binding),
765}
766
767/// Defines how to select cube count based on the number of cubes required.
768pub enum CubeCountSelection {
769    /// If the number of cubes is the same as required.
770    Exact(CubeCount),
771    /// If the number of cubes isn't the same as required.
772    ///
773    /// This can happened based on hardware limit, requirering the kernel to perform OOB checks.
774    Approx(CubeCount, u32),
775}
776
777impl CubeCountSelection {
778    /// Creates a [CubeCount] while respecting the hardware limits.
779    pub fn new<R: Runtime>(client: &ComputeClient<R>, num_cubes: u32) -> Self {
780        let cube_count = cube_count_spread(&client.properties().hardware.max_cube_count, num_cubes);
781
782        let num_cubes_actual = cube_count[0] * cube_count[1] * cube_count[2];
783        let cube_count = CubeCount::Static(cube_count[0], cube_count[1], cube_count[2]);
784
785        match num_cubes_actual == num_cubes {
786            true => CubeCountSelection::Exact(cube_count),
787            false => CubeCountSelection::Approx(cube_count, num_cubes_actual),
788        }
789    }
790
791    /// If some cubes will be idle.
792    pub fn has_idle(&self) -> bool {
793        matches!(self, Self::Approx(..))
794    }
795
796    /// Converts into [CubeCount].
797    pub fn cube_count(self) -> CubeCount {
798        match self {
799            CubeCountSelection::Exact(cube_count) => cube_count,
800            CubeCountSelection::Approx(cube_count, _) => cube_count,
801        }
802    }
803}
804
805impl From<CubeCountSelection> for CubeCount {
806    fn from(value: CubeCountSelection) -> Self {
807        value.cube_count()
808    }
809}
810
811impl CubeCount {
812    /// Create a new static cube count with the given x = y = z = 1.
813    pub fn new_single() -> Self {
814        CubeCount::Static(1, 1, 1)
815    }
816
817    /// Create a new static cube count with the given x, and y = z = 1.
818    pub fn new_1d(x: u32) -> Self {
819        CubeCount::Static(x, 1, 1)
820    }
821
822    /// Create a new static cube count with the given x and y, and z = 1.
823    pub fn new_2d(x: u32, y: u32) -> Self {
824        CubeCount::Static(x, y, 1)
825    }
826
827    /// Create a new static cube count with the given x, y and z.
828    pub fn new_3d(x: u32, y: u32, z: u32) -> Self {
829        CubeCount::Static(x, y, z)
830    }
831}
832
833impl Debug for CubeCount {
834    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
835        match self {
836            CubeCount::Static(x, y, z) => f.write_fmt(format_args!("({x}, {y}, {z})")),
837            CubeCount::Dynamic(_) => f.write_str("binding"),
838        }
839    }
840}
841
842impl Clone for CubeCount {
843    fn clone(&self) -> Self {
844        match self {
845            Self::Static(x, y, z) => Self::Static(*x, *y, *z),
846            Self::Dynamic(handle) => Self::Dynamic(handle.clone()),
847        }
848    }
849}
850
851#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, serde::Serialize, serde::Deserialize)]
852#[allow(missing_docs)]
853/// The number of units across all 3 axis totalling to the number of working units in a cube.
854pub struct CubeDim {
855    /// The number of units in the x axis.
856    pub x: u32,
857    /// The number of units in the y axis.
858    pub y: u32,
859    /// The number of units in the z axis.
860    pub z: u32,
861}
862
863impl CubeDim {
864    /// Creates a new [CubeDim] based on the maximum number of tasks that can be parellalized by units, in other words,
865    /// by the maximum number of working units.
866    ///
867    /// # Notes
868    ///
869    /// For complex problems, you probably want to have your own logic function to create the
870    /// [CubeDim], but for simpler problems such as elemwise-operation, this is a great default.
871    pub fn new<R: Runtime>(client: &ComputeClient<R>, working_units: usize) -> Self {
872        let properties = client.properties();
873        let plane_size = properties.hardware.plane_size_max;
874        let plane_count = Self::calculate_plane_count_per_cube(
875            working_units as u32,
876            plane_size,
877            properties.hardware.num_cpu_cores,
878        );
879
880        // Make sure it respects the max units per cube (especially on wasm)
881        let limit = properties.hardware.max_units_per_cube / plane_size;
882
883        Self::new_2d(plane_size, u32::min(limit, plane_count))
884    }
885
886    fn calculate_plane_count_per_cube(
887        working_units: u32,
888        plane_dim: u32,
889        num_cpu_cores: Option<u32>,
890    ) -> u32 {
891        match num_cpu_cores {
892            Some(num_cores) => core::cmp::min(num_cores, working_units),
893            None => {
894                let plane_count_max = core::cmp::max(1, working_units / plane_dim);
895
896                // Ensures `plane_count` is a power of 2.
897                const NUM_PLANE_MAX: u32 = 8u32;
898                const NUM_PLANE_MAX_LOG2: u32 = NUM_PLANE_MAX.ilog2();
899                let plane_count_max_log2 =
900                    core::cmp::min(NUM_PLANE_MAX_LOG2, u32::ilog2(plane_count_max));
901                2u32.pow(plane_count_max_log2)
902            }
903        }
904    }
905
906    /// Create a new cube dim with x = y = z = 1.
907    pub const fn new_single() -> Self {
908        Self { x: 1, y: 1, z: 1 }
909    }
910
911    /// Create a new cube dim with the given x, and y = z = 1.
912    pub const fn new_1d(x: u32) -> Self {
913        Self { x, y: 1, z: 1 }
914    }
915
916    /// Create a new cube dim with the given x and y, and z = 1.
917    pub const fn new_2d(x: u32, y: u32) -> Self {
918        Self { x, y, z: 1 }
919    }
920
921    /// Create a new cube dim with the given x, y and z.
922    /// This is equivalent to the [new](CubeDim::new) function.
923    pub const fn new_3d(x: u32, y: u32, z: u32) -> Self {
924        Self { x, y, z }
925    }
926
927    /// Total numbers of units per cube
928    pub const fn num_elems(&self) -> u32 {
929        self.x * self.y * self.z
930    }
931
932    /// Whether this `CubeDim` can fully contain `other`
933    pub const fn can_contain(&self, other: CubeDim) -> bool {
934        self.x >= other.x && self.y >= other.y && self.z >= other.z
935    }
936}
937
938/// The kind of execution to be performed.
939#[derive(Default, Hash, PartialEq, Eq, Clone, Debug, Copy, Serialize, Deserialize)]
940pub enum ExecutionMode {
941    /// Checked kernels are safe.
942    #[default]
943    Checked,
944    /// Unchecked kernels are unsafe.
945    Unchecked,
946}
947
/// Spreads `num_cubes` over the x, y and z axes so the x and y axes stay within `max`.
///
/// While an axis exceeds its limit, it is halved (rounding up) and the next axis is doubled.
/// Because of the rounding, the product of the result can be larger than `num_cubes`. The z
/// axis is never checked against its limit.
fn cube_count_spread(max: &(u32, u32, u32), num_cubes: u32) -> [u32; 3] {
    const BASE: u32 = 2;
    let limits = [max.0, max.1, max.2];
    let mut counts = [num_cubes, 1, 1];

    for axis in 0..2 {
        if counts[axis] <= limits[axis] {
            // This axis already fits, so every following axis is still 1 and fits too.
            break;
        }
        while counts[axis] > limits[axis] {
            counts[axis] = counts[axis].div_ceil(BASE);
            counts[axis + 1] *= BASE;
        }
    }

    counts
}
976
#[cfg(test)]
mod tests {
    use super::*;

    #[test_log::test]
    fn safe_num_cubes_even() {
        let max = (32, 32, 32);
        let required = 2048;

        let actual = cube_count_spread(&max, required);
        let expected = [32, 32, 2];
        assert_eq!(actual, expected);
    }

    #[test_log::test]
    fn safe_num_cubes_odd() {
        let max = (48, 32, 16);
        let required = 3177;

        let actual = cube_count_spread(&max, required);
        let expected = [25, 32, 4];
        assert_eq!(actual, expected);
    }

    /// A count that already fits on the x axis is left untouched.
    #[test_log::test]
    fn safe_num_cubes_fits_first_axis() {
        let max = (65535, 65535, 65535);

        assert_eq!(cube_count_spread(&max, 100), [100, 1, 1]);
        assert_eq!(cube_count_spread(&max, 65535), [65535, 1, 1]);
    }

    /// Zero cubes stays zero on x, with unit y and z.
    #[test_log::test]
    fn safe_num_cubes_zero() {
        let max = (32, 32, 32);

        assert_eq!(cube_count_spread(&max, 0), [0, 1, 1]);
    }

    /// Halving rounds up, so the spread may dispatch more cubes than requested.
    #[test_log::test]
    fn safe_num_cubes_rounds_up() {
        let max = (4, 4, 4);

        let actual = cube_count_spread(&max, 5);
        assert_eq!(actual, [3, 2, 1]);
        assert!(actual.iter().product::<u32>() >= 5);
    }
}