cubecl_runtime/
server.rs

1use crate::{
2    DeviceProperties,
3    compiler::CompilationError,
4    kernel::KernelMetadata,
5    logging::ServerLogger,
6    memory_management::{
7        MemoryAllocationMode, MemoryHandle, MemoryUsage,
8        memory_pool::{SliceBinding, SliceHandle},
9    },
10    storage::{BindingResource, ComputeStorage},
11    tma::{OobFill, TensorMapFormat, TensorMapInterleave, TensorMapPrefetch, TensorMapSwizzle},
12};
13use alloc::collections::BTreeMap;
14use alloc::string::String;
15use alloc::sync::Arc;
16use alloc::vec;
17use alloc::vec::Vec;
18use core::fmt::Debug;
19use cubecl_common::{
20    ExecutionMode, backtrace::BackTrace, bytes::Bytes, device, future::DynFut,
21    profile::ProfileDuration, stream_id::StreamId,
22};
23use cubecl_ir::StorageType;
24use thiserror::Error;
25
26#[derive(Error, Clone)]
27/// An error during profiling.
28pub enum ProfileError {
29    /// An unknown error happened during profiling
30    #[error(
31        "An unknown error happened during profiling\nCaused by:\n  {reason}\nBacktrace:\n{backtrace}"
32    )]
33    Unknown {
34        /// The caused of the error
35        reason: String,
36        /// The captured backtrace.
37        backtrace: BackTrace,
38    },
39
40    /// No profiling was registered
41    #[error("No profiling registered\nBacktrace:\n{backtrace}")]
42    NotRegistered {
43        /// The captured backtrace.
44        backtrace: BackTrace,
45    },
46
47    /// A launch error happened during profiling
48    #[error("A launch error happened during profiling\nCaused by:\n  {0}")]
49    Launch(#[from] LaunchError),
50
51    /// An execution error happened during profiling
52    #[error("An execution error happened during profiling\nCaused by:\n  {0}")]
53    Execution(#[from] ExecutionError),
54}
55
56impl core::fmt::Debug for ProfileError {
57    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
58        f.write_fmt(format_args!("{self}"))
59    }
60}
61
62#[derive(Debug)]
63/// Contains many different types that are useful for server implementations and compute clients.
64pub struct ServerUtilities<Server: ComputeServer> {
65    /// The time when `profile-tracy` is activated.
66    #[cfg(feature = "profile-tracy")]
67    pub epoch_time: web_time::Instant,
68    /// The GPU client when `profile-tracy` is activated.
69    #[cfg(feature = "profile-tracy")]
70    pub gpu_client: tracy_client::GpuContext,
71    /// Information shared between all servers.
72    pub properties: DeviceProperties,
73    /// Information specific to the current server.
74    pub info: Server::Info,
75    /// The logger based on global cubecl configs.
76    pub logger: Arc<ServerLogger>,
77}
78
79impl<S: ComputeServer> ServerUtilities<S> {
80    /// Creates a new server utilities.
81    pub fn new(properties: DeviceProperties, logger: Arc<ServerLogger>, info: S::Info) -> Self {
82        // Start a tracy client if needed.
83        #[cfg(feature = "profile-tracy")]
84        let client = tracy_client::Client::start();
85
86        Self {
87            properties,
88            logger,
89            // Create the GPU client if needed.
90            #[cfg(feature = "profile-tracy")]
91            gpu_client: client
92                .clone()
93                .new_gpu_context(
94                    Some(&format!("{info:?}")),
95                    // In the future should ask the server what makes sense here. 'Invalid' atm is a generic stand-in (Tracy doesn't have CUDA/RocM atm anyway).
96                    tracy_client::GpuContextType::Invalid,
97                    0,   // Timestamps are manually aligned to this epoch so start at 0.
98                    1.0, // Timestamps are manually converted to be nanoseconds so period is 1.
99                )
100                .unwrap(),
101            #[cfg(feature = "profile-tracy")]
102            epoch_time: web_time::Instant::now(),
103            info,
104        }
105    }
106}
107
108/// Error that can happen when calling [ComputeServer::execute];
109///
110/// # Notes
111///
112/// Not all errors are going to be catched when calling [ComputeServer::execute] only the one that
113/// won't block the compute queue.
114#[derive(Error, Clone)]
115#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
116pub enum LaunchError {
117    /// The given kernel can't be compiled.
118    #[error("A compilation error happened during launch\nCaused by:\n  {0}")]
119    CompilationError(#[from] CompilationError),
120
121    /// The server is out of memory.
122    #[error(
123        "An out-of-memory error happened during launch\nCaused by:\n  {reason}\nBacktrace\n{backtrace}"
124    )]
125    OutOfMemory {
126        /// The caused of the memory error.
127        reason: String,
128        /// The backtrace for this error.
129        #[cfg_attr(std_io, serde(skip))]
130        backtrace: BackTrace,
131    },
132
133    /// Unknown launch error.
134    #[error(
135        "An unknown error happened during launch\nCaused by:\n  {reason}\nBacktrace\n{backtrace}"
136    )]
137    Unknown {
138        /// The caused of the unknown error.
139        reason: String,
140        /// The backtrace for this error.
141        #[cfg_attr(std_io, serde(skip))]
142        backtrace: BackTrace,
143    },
144
145    /// Can't launch because of an IO Error.
146    #[error("An io error happened during launch\nCaused by:\n  {0}")]
147    IoError(#[from] IoError),
148}
149
150impl core::fmt::Debug for LaunchError {
151    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
152        f.write_fmt(format_args!("{self}"))
153    }
154}
155
156/// Error that can happen asynchronously while executing registered kernels.
157#[derive(Error, Debug, Clone)]
158#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
159pub enum ExecutionError {
160    /// A generic runtime error.
161    #[error("An error happened during execution\nCaused by:\n  {reason}\nBacktrace:\n{backtrace}")]
162    Generic {
163        /// The details of the generic error.
164        reason: String,
165        /// The backtrace for this error.
166        #[cfg_attr(std_io, serde(skip))]
167        backtrace: BackTrace,
168    },
169}
170
171/// The compute server is responsible for handling resources and computations over resources.
172///
173/// Everything in the server is mutable, therefore it should be solely accessed through the
174/// [compute channel](crate::channel::ComputeChannel) for thread safety.
175pub trait ComputeServer:
176    Send + core::fmt::Debug + ServerCommunication + device::DeviceState + 'static
177where
178    Self: Sized,
179{
180    /// The kernel type defines the computation algorithms.
181    type Kernel: KernelMetadata;
182    /// Information that can be retrieved for the runtime.
183    type Info: Debug + Send + Sync;
184    /// The [storage](ComputeStorage) type defines how data is stored and accessed.
185    type Storage: ComputeStorage;
186
187    /// Reserves `size` bytes in the storage, and returns a handle over them.
188    fn create(
189        &mut self,
190        descriptors: Vec<AllocationDescriptor<'_>>,
191        stream_id: StreamId,
192    ) -> Result<Vec<Allocation>, IoError>;
193
194    /// Reserves N [Bytes] of the provided sizes to be used as staging to load data.
195    fn staging(&mut self, _sizes: &[usize], _stream_id: StreamId) -> Result<Vec<Bytes>, IoError> {
196        Err(IoError::UnsupportedIoOperation {
197            backtrace: BackTrace::capture(),
198        })
199    }
200
201    /// Retrieve the server logger.
202    fn logger(&self) -> Arc<ServerLogger>;
203
204    /// Retrieve the server utilities.
205    fn utilities(&self) -> Arc<ServerUtilities<Self>>;
206
207    /// Utility to create a new buffer and immediately copy contiguous data into it
208    fn create_with_data(&mut self, data: &[u8], stream_id: StreamId) -> Result<Handle, IoError> {
209        let alloc = self
210            .create(
211                vec![AllocationDescriptor::new(
212                    AllocationKind::Contiguous,
213                    &[data.len()],
214                    1,
215                )],
216                stream_id,
217            )?
218            .remove(0);
219        self.write(
220            vec![(
221                CopyDescriptor::new(
222                    alloc.handle.clone().binding(),
223                    &[data.len()],
224                    &alloc.strides,
225                    1,
226                ),
227                Bytes::from_bytes_vec(data.to_vec()),
228            )],
229            stream_id,
230        )?;
231        Ok(alloc.handle)
232    }
233
234    /// Utility to create a new buffer and immediately copy contiguous data into it
235    fn create_with_bytes(&mut self, data: Bytes, stream_id: StreamId) -> Result<Handle, IoError> {
236        let alloc = self
237            .create(
238                vec![AllocationDescriptor::new(
239                    AllocationKind::Contiguous,
240                    &[data.len()],
241                    1,
242                )],
243                stream_id,
244            )?
245            .remove(0);
246        self.write(
247            vec![(
248                CopyDescriptor::new(
249                    alloc.handle.clone().binding(),
250                    &[data.len()],
251                    &alloc.strides,
252                    1,
253                ),
254                data,
255            )],
256            stream_id,
257        )?;
258        Ok(alloc.handle)
259    }
260
261    /// Given bindings, returns the owned resources as bytes.
262    fn read<'a>(
263        &mut self,
264        descriptors: Vec<CopyDescriptor<'a>>,
265        stream_id: StreamId,
266    ) -> DynFut<Result<Vec<Bytes>, IoError>>;
267
268    /// Writes the specified bytes into the buffers given
269    fn write(
270        &mut self,
271        descriptors: Vec<(CopyDescriptor<'_>, Bytes)>,
272        stream_id: StreamId,
273    ) -> Result<(), IoError>;
274
275    /// Wait for the completion of every task in the server.
276    fn sync(&mut self, stream_id: StreamId) -> DynFut<Result<(), ExecutionError>>;
277
278    /// Given a resource handle, returns the storage resource.
279    fn get_resource(
280        &mut self,
281        binding: Binding,
282        stream_id: StreamId,
283    ) -> BindingResource<<Self::Storage as ComputeStorage>::Resource>;
284
285    /// Executes the `kernel` over the given memory `handles`.
286    ///
287    /// Kernels have mutable access to every resource they are given
288    /// and are responsible of determining which should be read or written.
289    ///
290    /// # Safety
291    ///
292    /// When executing with mode [ExecutionMode::Unchecked], out-of-bound reads and writes can happen.
293    unsafe fn launch(
294        &mut self,
295        kernel: Self::Kernel,
296        count: CubeCount,
297        bindings: Bindings,
298        kind: ExecutionMode,
299        stream_id: StreamId,
300    ) -> Result<(), LaunchError>;
301
302    /// Flush all outstanding tasks in the server.
303    fn flush(&mut self, stream_id: StreamId);
304
305    /// The current memory usage of the server.
306    fn memory_usage(&mut self, stream_id: StreamId) -> MemoryUsage;
307
308    /// Ask the server to release memory that it can release.
309    fn memory_cleanup(&mut self, stream_id: StreamId);
310
311    /// Enable collecting timestamps.
312    fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken;
313
314    /// Disable collecting timestamps.
315    fn end_profile(
316        &mut self,
317        stream_id: StreamId,
318        token: ProfilingToken,
319    ) -> Result<ProfileDuration, ProfileError>;
320
321    /// Update the memory mode of allocation in the server.
322    fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId);
323}
324
325/// Defines functions for optimized data transfer between servers, supporting custom communication
326/// mechanisms such as peer-to-peer communication or specialized implementations.
327pub trait ServerCommunication {
328    /// Indicates whether server-to-server communication is enabled for this implementation.
329    const SERVER_COMM_ENABLED: bool;
330
331    /// Copies data from a source server to a destination server.
332    ///
333    /// # Arguments
334    ///
335    /// * `server_src` - A mutable reference to the source server from which data is copied.
336    /// * `server_dst` - A mutable reference to the destination server receiving the data.
337    /// * `src` - A descriptor specifying the data to be copied, including shape, strides, and binding.
338    /// * `stream_id_src` - The stream ID associated with the source server's operation.
339    /// * `stream_id_dst` - The stream ID associated with the destination server's operation.
340    ///
341    /// # Returns
342    ///
343    /// Returns a `Result` containing an `Allocation` on success, or an `IoError` if the operation fails.
344    ///
345    /// # Panics
346    ///
347    /// Panics if server communication is not enabled (`SERVER_COMM_ENABLED` is `false`) or if the
348    /// trait is incorrectly implemented by the server.
349    #[allow(unused_variables)]
350    fn copy(
351        server_src: &mut Self,
352        server_dst: &mut Self,
353        src: CopyDescriptor<'_>,
354        stream_id_src: StreamId,
355        stream_id_dst: StreamId,
356    ) -> Result<Allocation, IoError> {
357        if !Self::SERVER_COMM_ENABLED {
358            panic!("Server-to-server communication is not supported by this server.");
359        } else {
360            panic!(
361                "[Internal Error] The `ServerCommunication` trait is incorrectly implemented by the server."
362            );
363        }
364    }
365}
366
/// Identifies one profiling session.
///
/// This lets a server keep recursive or overlapping profilings apart.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct ProfilingToken {
    /// Raw identifier of the profiling session.
    pub id: u64,
}
373
374/// Server handle containing the [memory handle](crate::server::Handle).
375#[derive(new, Debug, PartialEq, Eq)]
376pub struct Handle {
377    /// Memory handle.
378    pub memory: SliceHandle,
379    /// Memory offset in bytes.
380    pub offset_start: Option<u64>,
381    /// Memory offset in bytes.
382    pub offset_end: Option<u64>,
383    /// The stream where the data was created.
384    pub stream: cubecl_common::stream_id::StreamId,
385    /// The stream position when the tensor became available.
386    pub cursor: u64,
387    /// Length of the underlying buffer ignoring offsets
388    size: u64,
389}
390
/// Memory layout requested for an allocation.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum AllocationKind {
    /// Densely packed layout without any padding.
    Contiguous,
    /// Layout tuned for access speed.
    ///
    /// In practice this means rows are aligned (with padding) on runtimes that support it.
    Optimized,
}
400
401/// Descriptor for a new tensor allocation
402#[derive(new, Debug, Clone, Copy)]
403pub struct AllocationDescriptor<'a> {
404    /// Layout for the tensor
405    pub kind: AllocationKind,
406    /// Shape of the tensor
407    pub shape: &'a [usize],
408    /// Size of each element in the tensor (used for conversion of shape to bytes)
409    pub elem_size: usize,
410}
411
412impl<'a> AllocationDescriptor<'a> {
413    /// Create an optimized allocation descriptor
414    pub fn optimized(shape: &'a [usize], elem_size: usize) -> Self {
415        AllocationDescriptor::new(AllocationKind::Optimized, shape, elem_size)
416    }
417
418    /// Create a contiguous allocation descriptor
419    pub fn contiguous(shape: &'a [usize], elem_size: usize) -> Self {
420        AllocationDescriptor::new(AllocationKind::Contiguous, shape, elem_size)
421    }
422}
423
424/// An allocation with associated strides. Strides depend on tensor layout.
425#[derive(new, Debug)]
426pub struct Allocation {
427    /// The handle for the memory resource
428    pub handle: Handle,
429    /// The strides of the tensor
430    pub strides: Vec<usize>,
431}
432
433/// Error returned from `create`/`read`/`write` functions. Due to async execution not all errors
434/// are able to be caught, so some IO errors will still panic.
435#[derive(Error, Clone)]
436#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
437pub enum IoError {
438    /// Buffer size exceeds the max available
439    #[error("can't allocate buffer of size: {size}\n{backtrace}")]
440    BufferTooBig {
441        /// The size of the buffer in bytes.
442        size: u64,
443        /// The captured backtrace.
444        #[cfg_attr(std_io, serde(skip))]
445        backtrace: BackTrace,
446    },
447
448    /// Strides aren't supported for this copy operation on this runtime
449    #[error("the provided strides are not supported for this operation\n{backtrace}")]
450    UnsupportedStrides {
451        /// The backtrace.
452        #[cfg_attr(std_io, serde(skip))]
453        backtrace: BackTrace,
454    },
455
456    /// Handle wasn't found in the memory pool
457    #[error("couldn't find resource for that handle\n{backtrace}")]
458    InvalidHandle {
459        /// The backtrace.
460        #[cfg_attr(std_io, serde(skip))]
461        backtrace: BackTrace,
462    },
463
464    /// Unknown error happened during execution
465    #[error("Unknown error happened during execution\n{backtrace}")]
466    Unknown {
467        /// Details of the error
468        description: String,
469        /// The backtrace.
470        #[cfg_attr(std_io, serde(skip))]
471        backtrace: BackTrace,
472    },
473
474    /// The current IO operation is not supported
475    #[error("The current IO operation is not supported\n{backtrace}")]
476    UnsupportedIoOperation {
477        /// The backtrace.
478        #[cfg_attr(std_io, serde(skip))]
479        backtrace: BackTrace,
480    },
481
482    /// Can't perform the IO operation because of a runtime error.
483    #[error("Can't perform the IO operation because of a runtime error")]
484    Execution(#[from] ExecutionError),
485}
486
487impl core::fmt::Debug for IoError {
488    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
489        f.write_fmt(format_args!("{self}"))
490    }
491}
492
493impl Handle {
494    /// Add to the current offset in bytes.
495    pub fn offset_start(mut self, offset: u64) -> Self {
496        if let Some(val) = &mut self.offset_start {
497            *val += offset;
498        } else {
499            self.offset_start = Some(offset);
500        }
501
502        self
503    }
504    /// Add to the current offset in bytes.
505    pub fn offset_end(mut self, offset: u64) -> Self {
506        if let Some(val) = &mut self.offset_end {
507            *val += offset;
508        } else {
509            self.offset_end = Some(offset);
510        }
511
512        self
513    }
514
515    /// Get the size of the handle, in bytes, accounting for offsets
516    pub fn size(&self) -> u64 {
517        self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
518    }
519}
520
521/// Bindings to execute a kernel.
522#[derive(Debug, Default)]
523pub struct Bindings {
524    /// Buffer bindings
525    pub buffers: Vec<Binding>,
526    /// Packed metadata for tensor bindings (len, shape, stride, etc).
527    /// Ordered by inputs, then outputs, then tensormaps
528    pub metadata: MetadataBinding,
529    /// Scalar bindings
530    pub scalars: BTreeMap<StorageType, ScalarBinding>,
531    /// Tensor map bindings
532    pub tensor_maps: Vec<TensorMapBinding>,
533}
534
535impl Bindings {
536    /// Create a new bindings struct
537    pub fn new() -> Self {
538        Self::default()
539    }
540
541    /// Add a buffer binding
542    pub fn with_buffer(mut self, binding: Binding) -> Self {
543        self.buffers.push(binding);
544        self
545    }
546
547    /// Extend the buffers with `bindings`
548    pub fn with_buffers(mut self, bindings: Vec<Binding>) -> Self {
549        self.buffers.extend(bindings);
550        self
551    }
552
553    /// Add a scalar parameter
554    pub fn with_scalar(mut self, ty: StorageType, length: usize, data: Vec<u64>) -> Self {
555        self.scalars
556            .insert(ty, ScalarBinding::new(ty, length, data));
557        self
558    }
559
560    /// Extend the scalars with `bindings`
561    pub fn with_scalars(mut self, bindings: Vec<ScalarBinding>) -> Self {
562        self.scalars
563            .extend(bindings.into_iter().map(|binding| (binding.ty, binding)));
564        self
565    }
566
567    /// Set the metadata to `meta`
568    pub fn with_metadata(mut self, meta: MetadataBinding) -> Self {
569        self.metadata = meta;
570        self
571    }
572
573    /// Extend the tensor maps with `bindings`
574    pub fn with_tensor_maps(mut self, bindings: Vec<TensorMapBinding>) -> Self {
575        self.tensor_maps.extend(bindings);
576        self
577    }
578}
579
580/// Binding of a set of scalars of the same type to execute a kernel.
581#[derive(new, Debug, Default)]
582pub struct MetadataBinding {
583    /// Metadata values
584    pub data: Vec<u32>,
585    /// Length of the static portion (rank, len, buffer_len, shape_offsets, stride_offsets).
586    pub static_len: usize,
587}
588
589/// Binding of a set of scalars of the same type to execute a kernel.
590#[derive(new, Debug, Clone)]
591pub struct ScalarBinding {
592    /// Type of the scalars
593    pub ty: StorageType,
594    /// Unpadded length of the underlying data
595    pub length: usize,
596    /// Type-erased data of the scalars. Padded and represented by u64 to prevent misalignment.
597    pub data: Vec<u64>,
598}
599
600impl ScalarBinding {
601    /// Get data as byte slice
602    pub fn data(&self) -> &[u8] {
603        bytemuck::cast_slice(&self.data)
604    }
605}
606
607/// Binding of a [tensor handle](Handle) to execute a kernel.
608#[derive(new, Debug)]
609pub struct Binding {
610    /// Memory binding.
611    pub memory: SliceBinding,
612    /// Memory offset in bytes.
613    pub offset_start: Option<u64>,
614    /// Memory offset in bytes.
615    pub offset_end: Option<u64>,
616    /// The stream where the data was created.
617    pub stream: cubecl_common::stream_id::StreamId,
618    /// The stream position when the tensor became available.
619    pub cursor: u64,
620    /// Size in bytes
621    size: u64,
622}
623
624impl Binding {
625    /// Get the size of the handle, in bytes, accounting for offsets
626    pub fn size(&self) -> u64 {
627        self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
628    }
629}
630
631/// A binding with shape and stride info for non-contiguous reading
632#[derive(new, Debug, Clone)]
633pub struct CopyDescriptor<'a> {
634    /// Binding for the memory resource
635    pub binding: Binding,
636    /// Shape of the resource
637    pub shape: &'a [usize],
638    /// Strides of the resource
639    pub strides: &'a [usize],
640    /// Size of each element in the resource
641    pub elem_size: usize,
642}
643
644/// A tensor map used with TMA ops
645#[derive(new, Debug, Clone)]
646pub struct TensorMapBinding {
647    /// The binding for the backing tensor
648    pub binding: Binding,
649    /// The tensormap metadata
650    pub map: TensorMapMeta,
651}
652
653/// TensorMap metadata for the opaque proxy used in TMA copies
654#[derive(Debug, Clone)]
655pub struct TensorMapMeta {
656    /// Tensormap format (tiled or im2col)
657    pub format: TensorMapFormat,
658    /// Rank of the backing tensor
659    pub rank: usize,
660    /// Shape of the backing tensor
661    pub shape: Vec<usize>,
662    /// Strides of the backing tensor
663    pub strides: Vec<usize>,
664    /// Element stride, usually 1 but may be 2 for complex tensors
665    /// For im2col, this is equivalent to the kernel stride
666    pub elem_stride: Vec<usize>,
667    /// Interleave mode
668    pub interleave: TensorMapInterleave,
669    /// Swizzle mode
670    pub swizzle: TensorMapSwizzle,
671    /// Prefetch settings
672    pub prefetch: TensorMapPrefetch,
673    /// OOB fill value
674    pub oob_fill: OobFill,
675    /// Storage type
676    pub storage_ty: StorageType,
677}
678
679impl Handle {
680    /// If the tensor handle can be reused inplace.
681    pub fn can_mut(&self) -> bool {
682        self.memory.can_mut() && self.stream == StreamId::current()
683    }
684}
685
686impl Handle {
687    /// Convert the [handle](Handle) into a [binding](Binding).
688    pub fn binding(self) -> Binding {
689        Binding {
690            memory: MemoryHandle::binding(self.memory),
691            offset_start: self.offset_start,
692            offset_end: self.offset_end,
693            size: self.size,
694            stream: self.stream,
695            cursor: self.cursor,
696        }
697    }
698
699    /// Convert the [handle](Handle) into a [binding](Binding) with shape and stride metadata.
700    pub fn copy_descriptor<'a>(
701        &'a self,
702        shape: &'a [usize],
703        strides: &'a [usize],
704        elem_size: usize,
705    ) -> CopyDescriptor<'a> {
706        CopyDescriptor {
707            shape,
708            strides,
709            elem_size,
710            binding: self.clone().binding(),
711        }
712    }
713}
714
715impl Clone for Handle {
716    fn clone(&self) -> Self {
717        Self {
718            memory: self.memory.clone(),
719            offset_start: self.offset_start,
720            offset_end: self.offset_end,
721            size: self.size,
722            stream: self.stream,
723            cursor: self.cursor,
724        }
725    }
726}
727
728impl Clone for Binding {
729    fn clone(&self) -> Self {
730        Self {
731            memory: self.memory.clone(),
732            offset_start: self.offset_start,
733            offset_end: self.offset_end,
734            size: self.size,
735            stream: self.stream,
736            cursor: self.cursor,
737        }
738    }
739}
740
741/// Specifieds the number of cubes to be dispatched for a kernel.
742///
743/// This translates to eg. a grid for CUDA, or to num_workgroups for wgsl.
744#[allow(clippy::large_enum_variant)]
745pub enum CubeCount {
746    /// Dispatch a known count of x, y, z cubes.
747    Static(u32, u32, u32),
748    /// Dispatch an amount based on the values in this buffer. The buffer should contain a u32 array [x, y, z].
749    Dynamic(Binding),
750}
751
752impl CubeCount {
753    /// Create a new static cube count with the given x = y = z = 1.
754    pub fn new_single() -> Self {
755        CubeCount::Static(1, 1, 1)
756    }
757
758    /// Create a new static cube count with the given x, and y = z = 1.
759    pub fn new_1d(x: u32) -> Self {
760        CubeCount::Static(x, 1, 1)
761    }
762
763    /// Create a new static cube count with the given x and y, and z = 1.
764    pub fn new_2d(x: u32, y: u32) -> Self {
765        CubeCount::Static(x, y, 1)
766    }
767
768    /// Create a new static cube count with the given x, y and z.
769    pub fn new_3d(x: u32, y: u32, z: u32) -> Self {
770        CubeCount::Static(x, y, z)
771    }
772}
773
774impl Debug for CubeCount {
775    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
776        match self {
777            CubeCount::Static(x, y, z) => f.write_fmt(format_args!("({x}, {y}, {z})")),
778            CubeCount::Dynamic(_) => f.write_str("binding"),
779        }
780    }
781}
782
783impl Clone for CubeCount {
784    fn clone(&self) -> Self {
785        match self {
786            Self::Static(x, y, z) => Self::Static(*x, *y, *z),
787            Self::Dynamic(handle) => Self::Dynamic(handle.clone()),
788        }
789    }
790}