Skip to main content

cubecl_runtime/
server.rs

1use crate::{
2    client::ComputeClient,
3    compiler::CompilationError,
4    kernel::KernelMetadata,
5    logging::ServerLogger,
6    memory_management::{
7        MemoryAllocationMode, MemoryHandle, MemoryUsage,
8        memory_pool::{SliceBinding, SliceHandle},
9    },
10    runtime::Runtime,
11    storage::{BindingResource, ComputeStorage},
12    tma::{OobFill, TensorMapFormat, TensorMapInterleave, TensorMapPrefetch, TensorMapSwizzle},
13};
14use alloc::collections::BTreeMap;
15#[cfg(feature = "profile-tracy")]
16use alloc::format;
17use alloc::string::String;
18use alloc::sync::Arc;
19use alloc::vec;
20use alloc::vec::Vec;
21use core::fmt::Debug;
22use cubecl_common::{
23    backtrace::BackTrace, bytes::Bytes, device, future::DynFut, profile::ProfileDuration,
24    stream_id::StreamId,
25};
26use cubecl_ir::{DeviceProperties, StorageType};
27use cubecl_zspace::{Strides, metadata::Metadata};
28use serde::{Deserialize, Serialize};
29use thiserror::Error;
30
/// An error during profiling.
#[derive(Error, Clone)]
pub enum ProfileError {
    /// An unknown error happened during profiling.
    #[error(
        "An unknown error happened during profiling\nCaused by:\n  {reason}\nBacktrace:\n{backtrace}"
    )]
    Unknown {
        /// The cause of the error.
        reason: String,
        /// The captured backtrace.
        backtrace: BackTrace,
    },

    /// No profiling was registered.
    #[error("No profiling registered\nBacktrace:\n{backtrace}")]
    NotRegistered {
        /// The captured backtrace.
        backtrace: BackTrace,
    },

    /// A launch error happened during profiling.
    #[error("A launch error happened during profiling\nCaused by:\n  {0}")]
    Launch(#[from] LaunchError),

    /// An execution error happened during profiling.
    #[error("An execution error happened during profiling\nCaused by:\n  {0}")]
    Execution(#[from] ExecutionError),
}
60
61impl core::fmt::Debug for ProfileError {
62    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
63        f.write_fmt(format_args!("{self}"))
64    }
65}
66
/// Contains many different types that are useful for server implementations and compute clients.
pub struct ServerUtilities<Server: ComputeServer> {
    /// The epoch instant captured at construction when `profile-tracy` is activated.
    #[cfg(feature = "profile-tracy")]
    pub epoch_time: web_time::Instant,
    /// The tracy GPU context when `profile-tracy` is activated.
    #[cfg(feature = "profile-tracy")]
    pub gpu_client: tracy_client::GpuContext,
    /// Information shared between all servers.
    pub properties: DeviceProperties,
    /// Stable hash of the device properties (computed once via `DeviceProperties::checksum`).
    pub properties_hash: u64,
    /// Information specific to the current server.
    pub info: Server::Info,
    /// The logger based on global cubecl configs.
    pub logger: Arc<ServerLogger>,
}
84
85impl<Server: core::fmt::Debug> core::fmt::Debug for ServerUtilities<Server>
86where
87    Server: ComputeServer,
88    Server::Info: core::fmt::Debug,
89{
90    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
91        f.debug_struct("ServerUtilities")
92            .field("properties", &self.properties)
93            .field("info", &self.info)
94            .field("logger", &self.logger)
95            .finish()
96    }
97}
98
impl<S: ComputeServer> ServerUtilities<S> {
    /// Creates a new server utilities.
    ///
    /// Hashes `properties` once up front so the checksum can be stored alongside
    /// the properties themselves.
    pub fn new(properties: DeviceProperties, logger: Arc<ServerLogger>, info: S::Info) -> Self {
        // Start a tracy client if needed.
        #[cfg(feature = "profile-tracy")]
        let client = tracy_client::Client::start();

        Self {
            // Compute the checksum before `properties` is moved into the struct.
            properties_hash: properties.checksum(),
            properties,
            logger,
            // Create the GPU client if needed.
            #[cfg(feature = "profile-tracy")]
            gpu_client: client
                .clone()
                .new_gpu_context(
                    Some(&format!("{info:?}")),
                    // In the future should ask the server what makes sense here. 'Invalid' atm is a generic stand-in (Tracy doesn't have CUDA/RocM atm anyway).
                    tracy_client::GpuContextType::Invalid,
                    0,   // Timestamps are manually aligned to this epoch so start at 0.
                    1.0, // Timestamps are manually converted to be nanoseconds so period is 1.
                )
                .unwrap(),
            #[cfg(feature = "profile-tracy")]
            epoch_time: web_time::Instant::now(),
            info,
        }
    }
}
128
129/// Kernel Launch Errors.
130#[derive(Error, Clone)]
131#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
132pub enum LaunchError {
133    /// The given kernel can't be compiled.
134    #[error("A compilation error happened during launch\nCaused by:\n  {0}")]
135    CompilationError(#[from] CompilationError),
136
137    /// The server is out of memory.
138    #[error(
139        "An out-of-memory error happened during launch\nCaused by:\n  {reason}\nBacktrace\n{backtrace}"
140    )]
141    OutOfMemory {
142        /// The caused of the memory error.
143        reason: String,
144        /// The backtrace for this error.
145        #[cfg_attr(std_io, serde(skip))]
146        backtrace: BackTrace,
147    },
148
149    /// Too many resources were requested
150    #[error("Too many resources were requested during launch\n{0}")]
151    TooManyResources(#[from] ResourceLimitError),
152
153    /// Unknown launch error.
154    #[error(
155        "An unknown error happened during launch\nCaused by:\n  {reason}\nBacktrace\n{backtrace}"
156    )]
157    Unknown {
158        /// The caused of the unknown error.
159        reason: String,
160        /// The backtrace for this error.
161        #[cfg_attr(std_io, serde(skip))]
162        backtrace: BackTrace,
163    },
164
165    /// Can't launch because of an IO Error.
166    #[error("An io error happened during launch\nCaused by:\n  {0}")]
167    IoError(#[from] IoError),
168}
169
170/// Resource limit errors.
171#[derive(Error, Clone)]
172#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
173pub enum ResourceLimitError {
174    /// Shared memory exceeds maximum
175    #[error(
176        "Too much shared memory requested.\nRequested {requested} bytes, maximum {max} bytes available.\nBacktrace\n{backtrace}"
177    )]
178    SharedMemory {
179        /// Value requested
180        requested: usize,
181        /// Maximum value
182        max: usize,
183        /// The backtrace for this error.
184        #[cfg_attr(std_io, serde(skip))]
185        backtrace: BackTrace,
186    },
187    /// Total units exceeds maximum
188    #[error(
189        "Total unit count exceeds maximum.\nRequested {requested} units, max units is {max}.\nBacktrace\n{backtrace}"
190    )]
191    Units {
192        /// Requested value
193        requested: u32,
194        /// Maximum value
195        max: u32,
196        /// The backtrace for this error.
197        #[cfg_attr(std_io, serde(skip))]
198        backtrace: BackTrace,
199    },
200    /// `CubeDim` exceeds maximum
201    #[error(
202        "Cube dim exceeds maximum bounds.\nRequested {requested:?}, max is {max:?}.\nBacktrace\n{backtrace}"
203    )]
204    CubeDim {
205        /// Requested value
206        requested: (u32, u32, u32),
207        /// Maximum value
208        max: (u32, u32, u32),
209        /// The backtrace for this error.
210        #[cfg_attr(std_io, serde(skip))]
211        backtrace: BackTrace,
212    },
213}
214
215impl core::fmt::Debug for LaunchError {
216    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
217        f.write_fmt(format_args!("{self}"))
218    }
219}
220
221impl core::fmt::Debug for ResourceLimitError {
222    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
223        f.write_fmt(format_args!("{self}"))
224    }
225}
226
/// Error that can happen asynchronously while executing registered kernels.
#[derive(Error, Debug, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum ExecutionError {
    /// A generic runtime error.
    #[error("An error happened during execution\nCaused by:\n  {reason}\nBacktrace:\n{backtrace}")]
    Generic {
        /// The details of the generic error.
        reason: String,
        /// The backtrace for this error.
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },
}
241
242/// The compute server is responsible for handling resources and computations over resources.
243///
244/// Everything in the server is mutable, therefore it should be solely accessed through the
245/// [`ComputeClient`] for thread safety.
246pub trait ComputeServer:
247    Send + core::fmt::Debug + ServerCommunication + device::DeviceState + 'static
248where
249    Self: Sized,
250{
251    /// The kernel type defines the computation algorithms.
252    type Kernel: KernelMetadata;
253    /// Information that can be retrieved for the runtime.
254    type Info: Debug + Send + Sync;
255    /// The [storage](ComputeStorage) type defines how data is stored and accessed.
256    type Storage: ComputeStorage;
257
258    /// Reserves `size` bytes in the storage, and returns a handle over them.
259    fn create(
260        &mut self,
261        descriptors: Vec<AllocationDescriptor<'_>>,
262        stream_id: StreamId,
263    ) -> Result<Vec<Allocation>, IoError>;
264
265    /// Reserves N [Bytes] of the provided sizes to be used as staging to load data.
266    fn staging(&mut self, _sizes: &[usize], _stream_id: StreamId) -> Result<Vec<Bytes>, IoError> {
267        Err(IoError::UnsupportedIoOperation {
268            backtrace: BackTrace::capture(),
269        })
270    }
271
272    /// Retrieve the server logger.
273    fn logger(&self) -> Arc<ServerLogger>;
274
275    /// Retrieve the server utilities.
276    fn utilities(&self) -> Arc<ServerUtilities<Self>>;
277
278    /// Utility to create a new buffer and immediately copy contiguous data into it
279    fn create_with_data(&mut self, data: &[u8], stream_id: StreamId) -> Result<Handle, IoError> {
280        let alloc = self
281            .create(
282                vec![AllocationDescriptor::new(
283                    AllocationKind::Contiguous,
284                    &[data.len()],
285                    1,
286                )],
287                stream_id,
288            )?
289            .remove(0);
290        self.write(
291            vec![(
292                CopyDescriptor::new(
293                    alloc.handle.clone().binding(),
294                    &[data.len()],
295                    &alloc.strides,
296                    1,
297                ),
298                Bytes::from_bytes_vec(data.to_vec()),
299            )],
300            stream_id,
301        )?;
302        Ok(alloc.handle)
303    }
304
305    /// Utility to create a new buffer and immediately copy contiguous data into it
306    fn create_with_bytes(&mut self, data: Bytes, stream_id: StreamId) -> Result<Handle, IoError> {
307        let alloc = self
308            .create(
309                vec![AllocationDescriptor::new(
310                    AllocationKind::Contiguous,
311                    &[data.len()],
312                    1,
313                )],
314                stream_id,
315            )?
316            .remove(0);
317        self.write(
318            vec![(
319                CopyDescriptor::new(
320                    alloc.handle.clone().binding(),
321                    &[data.len()],
322                    &alloc.strides,
323                    1,
324                ),
325                data,
326            )],
327            stream_id,
328        )?;
329        Ok(alloc.handle)
330    }
331
332    /// Given bindings, returns the owned resources as bytes.
333    fn read<'a>(
334        &mut self,
335        descriptors: Vec<CopyDescriptor<'a>>,
336        stream_id: StreamId,
337    ) -> DynFut<Result<Vec<Bytes>, IoError>>;
338
339    /// Writes the specified bytes into the buffers given
340    fn write(
341        &mut self,
342        descriptors: Vec<(CopyDescriptor<'_>, Bytes)>,
343        stream_id: StreamId,
344    ) -> Result<(), IoError>;
345
346    /// Wait for the completion of every task in the server.
347    fn sync(&mut self, stream_id: StreamId) -> DynFut<Result<(), ExecutionError>>;
348
349    /// Given a resource handle, returns the storage resource.
350    fn get_resource(
351        &mut self,
352        binding: Binding,
353        stream_id: StreamId,
354    ) -> BindingResource<<Self::Storage as ComputeStorage>::Resource>;
355
356    /// Executes the `kernel` over the given memory `handles`.
357    ///
358    /// Kernels have mutable access to every resource they are given
359    /// and are responsible of determining which should be read or written.
360    ///
361    /// # Safety
362    ///
363    /// When executing with mode [`ExecutionMode::Unchecked`], out-of-bound reads and writes can happen.
364    unsafe fn launch(
365        &mut self,
366        kernel: Self::Kernel,
367        count: CubeCount,
368        bindings: Bindings,
369        kind: ExecutionMode,
370        stream_id: StreamId,
371    ) -> Result<(), LaunchError>;
372
373    /// Flush all outstanding tasks in the server.
374    fn flush(&mut self, stream_id: StreamId);
375
376    /// The current memory usage of the server.
377    fn memory_usage(&mut self, stream_id: StreamId) -> MemoryUsage;
378
379    /// Ask the server to release memory that it can release.
380    fn memory_cleanup(&mut self, stream_id: StreamId);
381
382    /// Enable collecting timestamps.
383    fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken;
384
385    /// Disable collecting timestamps.
386    fn end_profile(
387        &mut self,
388        stream_id: StreamId,
389        token: ProfilingToken,
390    ) -> Result<ProfileDuration, ProfileError>;
391
392    /// Update the memory mode of allocation in the server.
393    fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId);
394}
395
396/// Defines functions for optimized data transfer between servers, supporting custom communication
397/// mechanisms such as peer-to-peer communication or specialized implementations.
398pub trait ServerCommunication {
399    /// Indicates whether server-to-server communication is enabled for this implementation.
400    const SERVER_COMM_ENABLED: bool;
401
402    /// Copies data from a source server to a destination server.
403    ///
404    /// # Arguments
405    ///
406    /// * `server_src` - A mutable reference to the source server from which data is copied.
407    /// * `server_dst` - A mutable reference to the destination server receiving the data.
408    /// * `src` - A descriptor specifying the data to be copied, including shape, strides, and binding.
409    /// * `stream_id_src` - The stream ID associated with the source server's operation.
410    /// * `stream_id_dst` - The stream ID associated with the destination server's operation.
411    ///
412    /// # Returns
413    ///
414    /// Returns a `Result` containing an `Allocation` on success, or an `IoError` if the operation fails.
415    ///
416    /// # Panics
417    ///
418    /// Panics if server communication is not enabled (`SERVER_COMM_ENABLED` is `false`) or if the
419    /// trait is incorrectly implemented by the server.
420    #[allow(unused_variables)]
421    fn copy(
422        server_src: &mut Self,
423        server_dst: &mut Self,
424        src: CopyDescriptor<'_>,
425        stream_id_src: StreamId,
426        stream_id_dst: StreamId,
427    ) -> Result<Allocation, IoError> {
428        if !Self::SERVER_COMM_ENABLED {
429            panic!("Server-to-server communication is not supported by this server.");
430        } else {
431            panic!(
432                "[Internal Error] The `ServerCommunication` trait is incorrectly implemented by the server."
433            );
434        }
435    }
436}
437
/// Profiling identification so that the server can support recursive and overlapping profilings.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct ProfilingToken {
    /// The token value.
    pub id: u64,
}
444
/// Server handle containing the [memory handle](crate::server::Handle).
#[derive(new, Debug, PartialEq, Eq)]
pub struct Handle {
    /// Memory handle.
    pub memory: SliceHandle,
    /// Offset from the start of the underlying buffer, in bytes.
    pub offset_start: Option<u64>,
    /// Offset from the end of the underlying buffer, in bytes.
    pub offset_end: Option<u64>,
    /// The stream where the data was created.
    pub stream: cubecl_common::stream_id::StreamId,
    /// The stream position when the tensor became available.
    pub cursor: u64,
    // Length of the underlying buffer ignoring offsets; see `size()` for the
    // offset-adjusted length.
    size: u64,
}
461
/// Type of allocation, either contiguous or optimized (row-aligned when possible).
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum AllocationKind {
    /// Contiguous layout, with no padding.
    Contiguous,
    /// Optimized for access speed. In practice this means row-aligned with padding for runtimes
    /// that support it.
    Optimized,
}
471
/// Descriptor for a new tensor allocation.
#[derive(new, Debug, Clone, Copy)]
pub struct AllocationDescriptor<'a> {
    /// Layout for the tensor.
    pub kind: AllocationKind,
    /// Shape of the tensor.
    pub shape: &'a [usize],
    /// Size of each element in the tensor (used for conversion of shape to bytes).
    pub elem_size: usize,
}
482
483impl<'a> AllocationDescriptor<'a> {
484    /// Create an optimized allocation descriptor
485    pub fn optimized(shape: &'a [usize], elem_size: usize) -> Self {
486        AllocationDescriptor::new(AllocationKind::Optimized, shape, elem_size)
487    }
488
489    /// Create a contiguous allocation descriptor
490    pub fn contiguous(shape: &'a [usize], elem_size: usize) -> Self {
491        AllocationDescriptor::new(AllocationKind::Contiguous, shape, elem_size)
492    }
493}
494
/// An allocation with associated strides. Strides depend on tensor layout.
#[derive(Debug)]
pub struct Allocation {
    /// The handle for the memory resource.
    pub handle: Handle,
    /// The strides of the tensor.
    pub strides: Strides,
}
503
504impl Allocation {
505    /// Create a new allocation
506    pub fn new(handle: Handle, strides: impl Into<Strides>) -> Self {
507        Allocation {
508            handle,
509            strides: strides.into(),
510        }
511    }
512}
513
514/// Error returned from `create`/`read`/`write` functions. Due to async execution not all errors
515/// are able to be caught, so some IO errors will still panic.
516#[derive(Error, Clone)]
517#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
518pub enum IoError {
519    /// Buffer size exceeds the max available
520    #[error("can't allocate buffer of size: {size}\n{backtrace}")]
521    BufferTooBig {
522        /// The size of the buffer in bytes.
523        size: u64,
524        /// The captured backtrace.
525        #[cfg_attr(std_io, serde(skip))]
526        backtrace: BackTrace,
527    },
528
529    /// Strides aren't supported for this copy operation on this runtime
530    #[error("the provided strides are not supported for this operation\n{backtrace}")]
531    UnsupportedStrides {
532        /// The backtrace.
533        #[cfg_attr(std_io, serde(skip))]
534        backtrace: BackTrace,
535    },
536
537    /// Handle wasn't found in the memory pool
538    #[error("couldn't find resource for that handle\n{backtrace}")]
539    InvalidHandle {
540        /// The backtrace.
541        #[cfg_attr(std_io, serde(skip))]
542        backtrace: BackTrace,
543    },
544
545    /// Unknown error happened during execution
546    #[error("Unknown error happened during execution\n{backtrace}")]
547    Unknown {
548        /// Details of the error
549        description: String,
550        /// The backtrace.
551        #[cfg_attr(std_io, serde(skip))]
552        backtrace: BackTrace,
553    },
554
555    /// The current IO operation is not supported
556    #[error("The current IO operation is not supported\n{backtrace}")]
557    UnsupportedIoOperation {
558        /// The backtrace.
559        #[cfg_attr(std_io, serde(skip))]
560        backtrace: BackTrace,
561    },
562
563    /// Can't perform the IO operation because of a runtime error.
564    #[error("Can't perform the IO operation because of a runtime error")]
565    Execution(#[from] ExecutionError),
566}
567
568impl core::fmt::Debug for IoError {
569    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
570        f.write_fmt(format_args!("{self}"))
571    }
572}
573
574impl Handle {
575    /// Add to the current offset in bytes.
576    pub fn offset_start(mut self, offset: u64) -> Self {
577        if let Some(val) = &mut self.offset_start {
578            *val += offset;
579        } else {
580            self.offset_start = Some(offset);
581        }
582
583        self
584    }
585    /// Add to the current offset in bytes.
586    pub fn offset_end(mut self, offset: u64) -> Self {
587        if let Some(val) = &mut self.offset_end {
588            *val += offset;
589        } else {
590            self.offset_end = Some(offset);
591        }
592
593        self
594    }
595
596    /// Get the size of the handle, in bytes, accounting for offsets
597    pub fn size(&self) -> u64 {
598        self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
599    }
600}
601
/// Bindings to execute a kernel.
#[derive(Debug, Default)]
pub struct Bindings {
    /// Buffer bindings.
    pub buffers: Vec<Binding>,
    /// Packed metadata for tensor bindings (len, shape, stride, etc).
    /// Ordered by inputs, then outputs, then tensormaps.
    pub metadata: MetadataBinding,
    /// Scalar bindings, keyed by their storage type (one binding per type).
    pub scalars: BTreeMap<StorageType, ScalarBinding>,
    /// Tensor map bindings.
    pub tensor_maps: Vec<TensorMapBinding>,
}
615
616impl Bindings {
617    /// Create a new bindings struct
618    pub fn new() -> Self {
619        Self::default()
620    }
621
622    /// Add a buffer binding
623    pub fn with_buffer(mut self, binding: Binding) -> Self {
624        self.buffers.push(binding);
625        self
626    }
627
628    /// Extend the buffers with `bindings`
629    pub fn with_buffers(mut self, bindings: Vec<Binding>) -> Self {
630        self.buffers.extend(bindings);
631        self
632    }
633
634    /// Add a scalar parameter
635    pub fn with_scalar(mut self, ty: StorageType, length: usize, data: Vec<u64>) -> Self {
636        self.scalars
637            .insert(ty, ScalarBinding::new(ty, length, data));
638        self
639    }
640
641    /// Extend the scalars with `bindings`
642    pub fn with_scalars(mut self, bindings: Vec<ScalarBinding>) -> Self {
643        self.scalars
644            .extend(bindings.into_iter().map(|binding| (binding.ty, binding)));
645        self
646    }
647
648    /// Set the metadata to `meta`
649    pub fn with_metadata(mut self, meta: MetadataBinding) -> Self {
650        self.metadata = meta;
651        self
652    }
653
654    /// Extend the tensor maps with `bindings`
655    pub fn with_tensor_maps(mut self, bindings: Vec<TensorMapBinding>) -> Self {
656        self.tensor_maps.extend(bindings);
657        self
658    }
659}
660
/// Packed metadata binding (len, shape, stride, etc) for the tensor arguments of a kernel.
// NOTE(review): the previous doc ("Binding of a set of scalars") was copy-pasted
// from `ScalarBinding`.
#[derive(new, Debug, Default)]
pub struct MetadataBinding {
    /// Metadata values.
    pub data: Vec<u64>,
    /// Length of the static portion (rank, len, `buffer_len`, `shape_offsets`, `stride_offsets`).
    pub static_len: usize,
}
669
/// Binding of a set of scalars of the same type to execute a kernel.
#[derive(new, Debug, Clone)]
pub struct ScalarBinding {
    /// Type of the scalars.
    pub ty: StorageType,
    /// Unpadded length of the underlying data.
    pub length: usize,
    /// Type-erased data of the scalars. Padded and represented by u64 to prevent misalignment.
    pub data: Vec<u64>,
}
680
681impl ScalarBinding {
682    /// Get data as byte slice
683    pub fn data(&self) -> &[u8] {
684        bytemuck::cast_slice(&self.data)
685    }
686}
687
/// Binding of a [tensor handle](Handle) to execute a kernel.
#[derive(new, Debug)]
pub struct Binding {
    /// Memory binding.
    pub memory: SliceBinding,
    /// Offset from the start of the underlying buffer, in bytes.
    pub offset_start: Option<u64>,
    /// Offset from the end of the underlying buffer, in bytes.
    pub offset_end: Option<u64>,
    /// The stream where the data was created.
    pub stream: cubecl_common::stream_id::StreamId,
    /// The stream position when the tensor became available.
    pub cursor: u64,
    // Size in bytes, ignoring offsets; see `size()` for the offset-adjusted length.
    size: u64,
}
704
705impl Binding {
706    /// Get the size of the handle, in bytes, accounting for offsets
707    pub fn size(&self) -> u64 {
708        self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
709    }
710}
711
/// A binding with shape and stride info for non-contiguous reading.
#[derive(new, Debug, Clone)]
pub struct CopyDescriptor<'a> {
    /// Binding for the memory resource.
    pub binding: Binding,
    /// Shape of the resource.
    pub shape: &'a [usize],
    /// Strides of the resource.
    pub strides: &'a [usize],
    /// Size of each element in the resource, in bytes.
    pub elem_size: usize,
}
724
/// A tensor map used with TMA ops.
#[derive(new, Debug, Clone)]
pub struct TensorMapBinding {
    /// The binding for the backing tensor.
    pub binding: Binding,
    /// The tensormap metadata.
    pub map: TensorMapMeta,
}
733
/// `TensorMap` metadata for the opaque proxy used in TMA copies.
#[derive(Debug, Clone)]
pub struct TensorMapMeta {
    /// Tensormap format (tiled or im2col).
    pub format: TensorMapFormat,
    /// Metadata of the backing tensor.
    pub metadata: Metadata,
    /// Element stride, usually 1 but may be 2 for complex tensors.
    /// For im2col, this is equivalent to the kernel stride.
    pub elem_stride: Strides,
    /// Interleave mode.
    pub interleave: TensorMapInterleave,
    /// Swizzle mode.
    pub swizzle: TensorMapSwizzle,
    /// Prefetch settings.
    pub prefetch: TensorMapPrefetch,
    /// OOB fill value.
    pub oob_fill: OobFill,
    /// Storage type.
    pub storage_ty: StorageType,
}
755
756impl Handle {
757    /// If the tensor handle can be reused inplace.
758    pub fn can_mut(&self) -> bool {
759        self.memory.can_mut() && self.stream == StreamId::current()
760    }
761}
762
763impl Handle {
764    /// Convert the [handle](Handle) into a [binding](Binding).
765    pub fn binding(self) -> Binding {
766        Binding {
767            memory: MemoryHandle::binding(self.memory),
768            offset_start: self.offset_start,
769            offset_end: self.offset_end,
770            size: self.size,
771            stream: self.stream,
772            cursor: self.cursor,
773        }
774    }
775
776    /// Convert the [handle](Handle) into a [binding](Binding) with shape and stride metadata.
777    pub fn copy_descriptor<'a>(
778        &'a self,
779        shape: &'a [usize],
780        strides: &'a [usize],
781        elem_size: usize,
782    ) -> CopyDescriptor<'a> {
783        CopyDescriptor {
784            shape,
785            strides,
786            elem_size,
787            binding: self.clone().binding(),
788        }
789    }
790}
791
792impl Clone for Handle {
793    fn clone(&self) -> Self {
794        Self {
795            memory: self.memory.clone(),
796            offset_start: self.offset_start,
797            offset_end: self.offset_end,
798            size: self.size,
799            stream: self.stream,
800            cursor: self.cursor,
801        }
802    }
803}
804
805impl Clone for Binding {
806    fn clone(&self) -> Self {
807        Self {
808            memory: self.memory.clone(),
809            offset_start: self.offset_start,
810            offset_end: self.offset_end,
811            size: self.size,
812            stream: self.stream,
813            cursor: self.cursor,
814        }
815    }
816}
817
/// Specifies the number of cubes to be dispatched for a kernel.
///
/// This translates to eg. a grid for CUDA, or to `num_workgroups` for wgsl.
#[allow(clippy::large_enum_variant)]
pub enum CubeCount {
    /// Dispatch a known count of x, y, z cubes.
    Static(u32, u32, u32),
    /// Dispatch an amount based on the values in this buffer. The buffer should contain a u32 array [x, y, z].
    Dynamic(Binding),
}
828
/// Defines how to select cube count based on the number of cubes required.
pub enum CubeCountSelection {
    /// If the number of cubes is the same as required.
    Exact(CubeCount),
    /// If the number of cubes isn't the same as required; carries the actual
    /// number of cubes dispatched.
    ///
    /// This can happen based on the hardware limit, requiring the kernel to perform OOB checks.
    Approx(CubeCount, u32),
}
838
839impl CubeCountSelection {
840    /// Creates a [`CubeCount`] while respecting the hardware limits.
841    pub fn new<R: Runtime>(client: &ComputeClient<R>, num_cubes: u32) -> Self {
842        let cube_count = cube_count_spread(&client.properties().hardware.max_cube_count, num_cubes);
843
844        let num_cubes_actual = cube_count[0] * cube_count[1] * cube_count[2];
845        let cube_count = CubeCount::Static(cube_count[0], cube_count[1], cube_count[2]);
846
847        match num_cubes_actual == num_cubes {
848            true => CubeCountSelection::Exact(cube_count),
849            false => CubeCountSelection::Approx(cube_count, num_cubes_actual),
850        }
851    }
852
853    /// If some cubes will be idle.
854    pub fn has_idle(&self) -> bool {
855        matches!(self, Self::Approx(..))
856    }
857
858    /// Converts into [`CubeCount`].
859    pub fn cube_count(self) -> CubeCount {
860        match self {
861            CubeCountSelection::Exact(cube_count) => cube_count,
862            CubeCountSelection::Approx(cube_count, _) => cube_count,
863        }
864    }
865}
866
867impl From<CubeCountSelection> for CubeCount {
868    fn from(value: CubeCountSelection) -> Self {
869        value.cube_count()
870    }
871}
872
873impl CubeCount {
874    /// Create a new static cube count with the given x = y = z = 1.
875    pub fn new_single() -> Self {
876        CubeCount::Static(1, 1, 1)
877    }
878
879    /// Create a new static cube count with the given x, and y = z = 1.
880    pub fn new_1d(x: u32) -> Self {
881        CubeCount::Static(x, 1, 1)
882    }
883
884    /// Create a new static cube count with the given x and y, and z = 1.
885    pub fn new_2d(x: u32, y: u32) -> Self {
886        CubeCount::Static(x, y, 1)
887    }
888
889    /// Create a new static cube count with the given x, y and z.
890    pub fn new_3d(x: u32, y: u32, z: u32) -> Self {
891        CubeCount::Static(x, y, z)
892    }
893}
894
895impl Debug for CubeCount {
896    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
897        match self {
898            CubeCount::Static(x, y, z) => f.write_fmt(format_args!("({x}, {y}, {z})")),
899            CubeCount::Dynamic(_) => f.write_str("binding"),
900        }
901    }
902}
903
904impl Clone for CubeCount {
905    fn clone(&self) -> Self {
906        match self {
907            Self::Static(x, y, z) => Self::Static(*x, *y, *z),
908            Self::Dynamic(handle) => Self::Dynamic(handle.clone()),
909        }
910    }
911}
912
913#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, serde::Serialize, serde::Deserialize)]
914#[allow(missing_docs)]
915/// The number of units across all 3 axis totalling to the number of working units in a cube.
916pub struct CubeDim {
917    /// The number of units in the x axis.
918    pub x: u32,
919    /// The number of units in the y axis.
920    pub y: u32,
921    /// The number of units in the z axis.
922    pub z: u32,
923}
924
925impl CubeDim {
926    /// Creates a new [`CubeDim`] based on the maximum number of tasks that can be parellalized by units, in other words,
927    /// by the maximum number of working units.
928    ///
929    /// # Notes
930    ///
931    /// For complex problems, you probably want to have your own logic function to create the
932    /// [`CubeDim`], but for simpler problems such as elemwise-operation, this is a great default.
933    pub fn new<R: Runtime>(client: &ComputeClient<R>, working_units: usize) -> Self {
934        let properties = client.properties();
935        let plane_size = properties.hardware.plane_size_max;
936        let plane_count = Self::calculate_plane_count_per_cube(
937            working_units as u32,
938            plane_size,
939            properties.hardware.num_cpu_cores,
940        );
941
942        // Make sure it respects the max units per cube (especially on wasm)
943        let limit = properties.hardware.max_units_per_cube / plane_size;
944
945        Self::new_2d(plane_size, u32::min(limit, plane_count))
946    }
947
948    fn calculate_plane_count_per_cube(
949        working_units: u32,
950        plane_dim: u32,
951        num_cpu_cores: Option<u32>,
952    ) -> u32 {
953        match num_cpu_cores {
954            Some(num_cores) => core::cmp::min(num_cores, working_units),
955            None => {
956                let plane_count_max = core::cmp::max(1, working_units / plane_dim);
957
958                // Ensures `plane_count` is a power of 2.
959                const NUM_PLANE_MAX: u32 = 8u32;
960                const NUM_PLANE_MAX_LOG2: u32 = NUM_PLANE_MAX.ilog2();
961                let plane_count_max_log2 =
962                    core::cmp::min(NUM_PLANE_MAX_LOG2, u32::ilog2(plane_count_max));
963                2u32.pow(plane_count_max_log2)
964            }
965        }
966    }
967
968    /// Create a new cube dim with x = y = z = 1.
969    pub const fn new_single() -> Self {
970        Self { x: 1, y: 1, z: 1 }
971    }
972
973    /// Create a new cube dim with the given x, and y = z = 1.
974    pub const fn new_1d(x: u32) -> Self {
975        Self { x, y: 1, z: 1 }
976    }
977
978    /// Create a new cube dim with the given x and y, and z = 1.
979    pub const fn new_2d(x: u32, y: u32) -> Self {
980        Self { x, y, z: 1 }
981    }
982
983    /// Create a new cube dim with the given x, y and z.
984    /// This is equivalent to the [new](CubeDim::new) function.
985    pub const fn new_3d(x: u32, y: u32, z: u32) -> Self {
986        Self { x, y, z }
987    }
988
989    /// Total numbers of units per cube
990    pub const fn num_elems(&self) -> u32 {
991        self.x * self.y * self.z
992    }
993
994    /// Whether this `CubeDim` can fully contain `other`
995    pub const fn can_contain(&self, other: CubeDim) -> bool {
996        self.x >= other.x && self.y >= other.y && self.z >= other.z
997    }
998}
999
1000impl From<(u32, u32, u32)> for CubeDim {
1001    fn from(value: (u32, u32, u32)) -> Self {
1002        CubeDim::new_3d(value.0, value.1, value.2)
1003    }
1004}
1005
1006impl From<CubeDim> for (u32, u32, u32) {
1007    fn from(val: CubeDim) -> Self {
1008        (val.x, val.y, val.z)
1009    }
1010}
1011
1012/// The kind of execution to be performed.
1013#[derive(Default, Hash, PartialEq, Eq, Clone, Debug, Copy, Serialize, Deserialize)]
1014pub enum ExecutionMode {
1015    /// Checked kernels are safe.
1016    #[default]
1017    Checked,
1018    /// Unchecked kernels are unsafe.
1019    Unchecked,
1020}
1021
/// Spreads `num_cubes` across the three cube-count axes so each axis stays
/// within its `max` limit, starting everything on x and overflowing into y,
/// then z. Halving rounds up, so the product may slightly exceed `num_cubes`.
fn cube_count_spread(max: &(u32, u32, u32), num_cubes: u32) -> [u32; 3] {
    let limits = [max.0, max.1, max.2];
    let mut counts = [num_cubes, 1, 1];

    for axis in 0..2 {
        // If this axis already fits, later axes are still 1 and fit trivially.
        if counts[axis] <= limits[axis] {
            break;
        }

        // Halve this axis (rounding up) while doubling the next until it fits.
        while counts[axis] > limits[axis] {
            counts[axis] = counts[axis].div_ceil(2);
            counts[axis + 1] *= 2;
        }
    }

    counts
}
1050
#[cfg(test)]
mod tests {
    use super::*;

    #[test_log::test]
    fn safe_num_cubes_fit() {
        // No spreading needed when the required count already fits on the x axis.
        let max = (32, 32, 32);
        let required = 16;

        let actual = cube_count_spread(&max, required);
        let expected = [16, 1, 1];
        assert_eq!(actual, expected);
    }

    #[test_log::test]
    fn safe_num_cubes_even() {
        let max = (32, 32, 32);
        let required = 2048;

        let actual = cube_count_spread(&max, required);
        let expected = [32, 32, 2];
        assert_eq!(actual, expected);
    }

    #[test_log::test]
    fn safe_num_cubes_odd() {
        // Ceiling division over-allocates: 25 * 32 * 4 = 3200 >= 3177.
        let max = (48, 32, 16);
        let required = 3177;

        let actual = cube_count_spread(&max, required);
        let expected = [25, 32, 4];
        assert_eq!(actual, expected);
    }
}