cubecl_runtime/server.rs

1use crate::{
2    client::ComputeClient,
3    compiler::CompilationError,
4    kernel::KernelMetadata,
5    logging::ServerLogger,
6    memory_management::{
7        MemoryAllocationMode, MemoryHandle, MemoryUsage,
8        memory_pool::{SliceBinding, SliceHandle},
9    },
10    runtime::Runtime,
11    storage::{BindingResource, ComputeStorage},
12    tma::{OobFill, TensorMapFormat, TensorMapInterleave, TensorMapPrefetch, TensorMapSwizzle},
13};
14use alloc::collections::BTreeMap;
15use alloc::string::String;
16use alloc::sync::Arc;
17use alloc::vec;
18use alloc::vec::Vec;
19use core::fmt::Debug;
20use cubecl_common::{
21    backtrace::BackTrace, bytes::Bytes, device, future::DynFut, profile::ProfileDuration,
22    stream_id::StreamId,
23};
24use cubecl_ir::{DeviceProperties, StorageType};
25use serde::{Deserialize, Serialize};
26use thiserror::Error;
27
28#[derive(Error, Clone)]
29/// An error during profiling.
30pub enum ProfileError {
31    /// An unknown error happened during profiling
32    #[error(
33        "An unknown error happened during profiling\nCaused by:\n  {reason}\nBacktrace:\n{backtrace}"
34    )]
35    Unknown {
36        /// The cause of the error.
37        reason: String,
38        /// The captured backtrace.
39        backtrace: BackTrace,
40    },
41
42    /// No profiling was registered
43    #[error("No profiling registered\nBacktrace:\n{backtrace}")]
44    NotRegistered {
45        /// The captured backtrace.
46        backtrace: BackTrace,
47    },
48
49    /// A launch error happened during profiling
50    #[error("A launch error happened during profiling\nCaused by:\n  {0}")]
51    Launch(#[from] LaunchError),
52
53    /// An execution error happened during profiling
54    #[error("An execution error happened during profiling\nCaused by:\n  {0}")]
55    Execution(#[from] ExecutionError),
56}
57
58impl core::fmt::Debug for ProfileError {
59    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
60        f.write_fmt(format_args!("{self}"))
61    }
62}
63
64/// Groups utilities that are useful for server implementations and compute clients.
65pub struct ServerUtilities<Server: ComputeServer> {
66    /// The epoch time used to align profiling timestamps when `profile-tracy` is activated.
67    #[cfg(feature = "profile-tracy")]
68    pub epoch_time: web_time::Instant,
69    /// The Tracy GPU context when `profile-tracy` is activated.
70    #[cfg(feature = "profile-tracy")]
71    pub gpu_client: tracy_client::GpuContext,
72    /// Information shared between all servers.
73    pub properties: DeviceProperties,
74    /// Stable hash of the device properties
75    pub properties_hash: u64,
76    /// Information specific to the current server.
77    pub info: Server::Info,
78    /// The logger based on global cubecl configs.
79    pub logger: Arc<ServerLogger>,
80}
81
82impl<Server: core::fmt::Debug> core::fmt::Debug for ServerUtilities<Server>
83where
84    Server: ComputeServer,
85    Server::Info: core::fmt::Debug,
86{
87    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
88        f.debug_struct("ServerUtilities")
89            .field("properties", &self.properties)
90            .field("info", &self.info)
91            .field("logger", &self.logger)
92            .finish()
93    }
94}
95
96impl<S: ComputeServer> ServerUtilities<S> {
97    /// Creates a new set of server utilities.
98    pub fn new(properties: DeviceProperties, logger: Arc<ServerLogger>, info: S::Info) -> Self {
99        // Start a tracy client if needed.
100        #[cfg(feature = "profile-tracy")]
101        let client = tracy_client::Client::start();
102
103        Self {
104            properties_hash: properties.checksum(),
105            properties,
106            logger,
107            // Create the GPU client if needed.
108            #[cfg(feature = "profile-tracy")]
109            gpu_client: client
110                .clone()
111                .new_gpu_context(
112                    Some(&format!("{info:?}")),
113                    // In the future we should ask the server what makes sense here. 'Invalid' is a generic stand-in for now (Tracy doesn't have CUDA/ROCm contexts anyway).
114                    tracy_client::GpuContextType::Invalid,
115                    0,   // Timestamps are manually aligned to this epoch so start at 0.
116                    1.0, // Timestamps are manually converted to be nanoseconds so period is 1.
117                )
118                .unwrap(),
119            #[cfg(feature = "profile-tracy")]
120            epoch_time: web_time::Instant::now(),
121            info,
122        }
123    }
124}
125
126/// Kernel Launch Errors.
127#[derive(Error, Clone)]
128#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
129pub enum LaunchError {
130    /// The given kernel can't be compiled.
131    #[error("A compilation error happened during launch\nCaused by:\n  {0}")]
132    CompilationError(#[from] CompilationError),
133
134    /// The server is out of memory.
135    #[error(
136        "An out-of-memory error happened during launch\nCaused by:\n  {reason}\nBacktrace\n{backtrace}"
137    )]
138    OutOfMemory {
139        /// The cause of the memory error.
140        reason: String,
141        /// The backtrace for this error.
142        #[cfg_attr(std_io, serde(skip))]
143        backtrace: BackTrace,
144    },
145
146    /// Too many resources were requested
147    #[error("Too many resources were requested during launch\n{0}")]
148    TooManyResources(#[from] ResourceLimitError),
149
150    /// Unknown launch error.
151    #[error(
152        "An unknown error happened during launch\nCaused by:\n  {reason}\nBacktrace\n{backtrace}"
153    )]
154    Unknown {
155        /// The cause of the unknown error.
156        reason: String,
157        /// The backtrace for this error.
158        #[cfg_attr(std_io, serde(skip))]
159        backtrace: BackTrace,
160    },
161
162    /// Can't launch because of an IO Error.
163    #[error("An io error happened during launch\nCaused by:\n  {0}")]
164    IoError(#[from] IoError),
165}
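// Illustrative sketch (not part of the original source): collapsing a
// `LaunchError` into a short user-facing label. The variant names come from the
// enum above; the function name is hypothetical.
fn describe_launch_error(err: &LaunchError) -> &'static str {
    match err {
        LaunchError::CompilationError(_) => "kernel failed to compile",
        LaunchError::OutOfMemory { .. } => "device ran out of memory",
        LaunchError::TooManyResources(_) => "kernel exceeds a resource limit",
        LaunchError::Unknown { .. } => "unknown launch failure",
        LaunchError::IoError(_) => "I/O error while preparing the launch",
    }
}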
166
167/// Resource limit errors.
168#[derive(Error, Clone)]
169#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
170pub enum ResourceLimitError {
171    /// Shared memory exceeds maximum
172    #[error(
173        "Too much shared memory requested.\nRequested {requested} bytes, maximum {max} bytes available.\nBacktrace\n{backtrace}"
174    )]
175    SharedMemory {
176        /// Value requested
177        requested: usize,
178        /// Maximum value
179        max: usize,
180        /// The backtrace for this error.
181        #[cfg_attr(std_io, serde(skip))]
182        backtrace: BackTrace,
183    },
184    /// Total units exceeds maximum
185    #[error(
186        "Total unit count exceeds maximum.\nRequested {requested} units, max units is {max}.\nBacktrace\n{backtrace}"
187    )]
188    Units {
189        /// Requested value
190        requested: u32,
191        /// Maximum value
192        max: u32,
193        /// The backtrace for this error.
194        #[cfg_attr(std_io, serde(skip))]
195        backtrace: BackTrace,
196    },
197    /// `CubeDim` exceeds maximum
198    #[error(
199        "Cube dim exceeds maximum bounds.\nRequested {requested:?}, max is {max:?}.\nBacktrace\n{backtrace}"
200    )]
201    CubeDim {
202        /// Requested value
203        requested: (u32, u32, u32),
204        /// Maximum value
205        max: (u32, u32, u32),
206        /// The backtrace for this error.
207        #[cfg_attr(std_io, serde(skip))]
208        backtrace: BackTrace,
209    },
210}
211
212impl core::fmt::Debug for LaunchError {
213    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
214        f.write_fmt(format_args!("{self}"))
215    }
216}
217
218impl core::fmt::Debug for ResourceLimitError {
219    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
220        f.write_fmt(format_args!("{self}"))
221    }
222}
223
224/// Error that can happen asynchronously while executing registered kernels.
225#[derive(Error, Debug, Clone)]
226#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
227pub enum ExecutionError {
228    /// A generic runtime error.
229    #[error("An error happened during execution\nCaused by:\n  {reason}\nBacktrace:\n{backtrace}")]
230    Generic {
231        /// The details of the generic error.
232        reason: String,
233        /// The backtrace for this error.
234        #[cfg_attr(std_io, serde(skip))]
235        backtrace: BackTrace,
236    },
237}
238
239/// The compute server is responsible for handling resources and computations over resources.
240///
241/// Everything in the server is mutable, therefore it should be solely accessed through the
242/// [`ComputeClient`] for thread safety.
243pub trait ComputeServer:
244    Send + core::fmt::Debug + ServerCommunication + device::DeviceState + 'static
245where
246    Self: Sized,
247{
248    /// The kernel type defines the computation algorithms.
249    type Kernel: KernelMetadata;
250    /// Information that can be retrieved for the runtime.
251    type Info: Debug + Send + Sync;
252    /// The [storage](ComputeStorage) type defines how data is stored and accessed.
253    type Storage: ComputeStorage;
254
255    /// Reserves memory in the storage for each allocation descriptor, and returns the resulting allocations.
256    fn create(
257        &mut self,
258        descriptors: Vec<AllocationDescriptor<'_>>,
259        stream_id: StreamId,
260    ) -> Result<Vec<Allocation>, IoError>;
261
262    /// Reserves one [Bytes] buffer per provided size, to be used as staging buffers when loading data.
263    fn staging(&mut self, _sizes: &[usize], _stream_id: StreamId) -> Result<Vec<Bytes>, IoError> {
264        Err(IoError::UnsupportedIoOperation {
265            backtrace: BackTrace::capture(),
266        })
267    }
268
269    /// Retrieve the server logger.
270    fn logger(&self) -> Arc<ServerLogger>;
271
272    /// Retrieve the server utilities.
273    fn utilities(&self) -> Arc<ServerUtilities<Self>>;
274
275    /// Utility to create a new buffer and immediately copy contiguous data into it
276    fn create_with_data(&mut self, data: &[u8], stream_id: StreamId) -> Result<Handle, IoError> {
277        let alloc = self
278            .create(
279                vec![AllocationDescriptor::new(
280                    AllocationKind::Contiguous,
281                    &[data.len()],
282                    1,
283                )],
284                stream_id,
285            )?
286            .remove(0);
287        self.write(
288            vec![(
289                CopyDescriptor::new(
290                    alloc.handle.clone().binding(),
291                    &[data.len()],
292                    &alloc.strides,
293                    1,
294                ),
295                Bytes::from_bytes_vec(data.to_vec()),
296            )],
297            stream_id,
298        )?;
299        Ok(alloc.handle)
300    }
301
302    /// Utility to create a new buffer and immediately copy the given contiguous [Bytes] into it
303    fn create_with_bytes(&mut self, data: Bytes, stream_id: StreamId) -> Result<Handle, IoError> {
304        let alloc = self
305            .create(
306                vec![AllocationDescriptor::new(
307                    AllocationKind::Contiguous,
308                    &[data.len()],
309                    1,
310                )],
311                stream_id,
312            )?
313            .remove(0);
314        self.write(
315            vec![(
316                CopyDescriptor::new(
317                    alloc.handle.clone().binding(),
318                    &[data.len()],
319                    &alloc.strides,
320                    1,
321                ),
322                data,
323            )],
324            stream_id,
325        )?;
326        Ok(alloc.handle)
327    }
328
329    /// Given copy descriptors, returns the owned resources as bytes.
330    fn read<'a>(
331        &mut self,
332        descriptors: Vec<CopyDescriptor<'a>>,
333        stream_id: StreamId,
334    ) -> DynFut<Result<Vec<Bytes>, IoError>>;
335
336    /// Writes the specified bytes into the given buffers
337    fn write(
338        &mut self,
339        descriptors: Vec<(CopyDescriptor<'_>, Bytes)>,
340        stream_id: StreamId,
341    ) -> Result<(), IoError>;
342
343    /// Wait for the completion of every task in the server.
344    fn sync(&mut self, stream_id: StreamId) -> DynFut<Result<(), ExecutionError>>;
345
346    /// Given a resource handle, returns the storage resource.
347    fn get_resource(
348        &mut self,
349        binding: Binding,
350        stream_id: StreamId,
351    ) -> BindingResource<<Self::Storage as ComputeStorage>::Resource>;
352
353    /// Executes the `kernel` over the given memory `handles`.
354    ///
355    /// Kernels have mutable access to every resource they are given
356    /// and are responsible for determining which should be read or written.
357    ///
358    /// # Safety
359    ///
360    /// When executing with mode [`ExecutionMode::Unchecked`], out-of-bound reads and writes can happen.
361    unsafe fn launch(
362        &mut self,
363        kernel: Self::Kernel,
364        count: CubeCount,
365        bindings: Bindings,
366        kind: ExecutionMode,
367        stream_id: StreamId,
368    ) -> Result<(), LaunchError>;
369
370    /// Flush all outstanding tasks in the server.
371    fn flush(&mut self, stream_id: StreamId);
372
373    /// The current memory usage of the server.
374    fn memory_usage(&mut self, stream_id: StreamId) -> MemoryUsage;
375
376    /// Ask the server to release memory that it can release.
377    fn memory_cleanup(&mut self, stream_id: StreamId);
378
379    /// Start a profiling session, enabling timestamp collection, and return a token identifying it.
380    fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken;
381
382    /// End the profiling session for the given token and return the measured duration.
383    fn end_profile(
384        &mut self,
385        stream_id: StreamId,
386        token: ProfilingToken,
387    ) -> Result<ProfileDuration, ProfileError>;
388
389    /// Update the memory allocation mode of the server.
390    fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId);
391}
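// Hedged sketch: a helper generic over any `ComputeServer`, showing the
// allocate-and-upload path defined by the trait above. `server`, `data`, and
// `stream_id` are assumed to come from the caller; reading back would go through
// `ComputeServer::read`, which returns a future and is omitted here.
fn upload<S: ComputeServer>(
    server: &mut S,
    data: &[u8],
    stream_id: StreamId,
) -> Result<Handle, IoError> {
    // `create_with_data` allocates a contiguous buffer and writes `data` into it.
    server.create_with_data(data, stream_id)
}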
392
393/// Defines functions for optimized data transfer between servers, supporting custom communication
394/// mechanisms such as peer-to-peer communication or specialized implementations.
395pub trait ServerCommunication {
396    /// Indicates whether server-to-server communication is enabled for this implementation.
397    const SERVER_COMM_ENABLED: bool;
398
399    /// Copies data from a source server to a destination server.
400    ///
401    /// # Arguments
402    ///
403    /// * `server_src` - A mutable reference to the source server from which data is copied.
404    /// * `server_dst` - A mutable reference to the destination server receiving the data.
405    /// * `src` - A descriptor specifying the data to be copied, including shape, strides, and binding.
406    /// * `stream_id_src` - The stream ID associated with the source server's operation.
407    /// * `stream_id_dst` - The stream ID associated with the destination server's operation.
408    ///
409    /// # Returns
410    ///
411    /// Returns a `Result` containing an `Allocation` on success, or an `IoError` if the operation fails.
412    ///
413    /// # Panics
414    ///
415    /// Panics if server communication is not enabled (`SERVER_COMM_ENABLED` is `false`) or if the
416    /// trait is incorrectly implemented by the server.
417    #[allow(unused_variables)]
418    fn copy(
419        server_src: &mut Self,
420        server_dst: &mut Self,
421        src: CopyDescriptor<'_>,
422        stream_id_src: StreamId,
423        stream_id_dst: StreamId,
424    ) -> Result<Allocation, IoError> {
425        if !Self::SERVER_COMM_ENABLED {
426            panic!("Server-to-server communication is not supported by this server.");
427        } else {
428            panic!(
429                "[Internal Error] The `ServerCommunication` trait is incorrectly implemented by the server."
430            );
431        }
432    }
433}
434
435#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
436/// Profiling identifier, allowing the server to support recursive and overlapping profiling sessions.
437pub struct ProfilingToken {
438    /// The token value.
439    pub id: u64,
440}
441
442/// Server handle containing a memory [slice handle](SliceHandle) and byte offsets into it.
443#[derive(new, Debug, PartialEq, Eq)]
444pub struct Handle {
445    /// Memory handle.
446    pub memory: SliceHandle,
447    /// Memory offset from the start, in bytes.
448    pub offset_start: Option<u64>,
449    /// Memory offset from the end, in bytes.
450    pub offset_end: Option<u64>,
451    /// The stream where the data was created.
452    pub stream: cubecl_common::stream_id::StreamId,
453    /// The stream position when the tensor became available.
454    pub cursor: u64,
455    /// Length of the underlying buffer ignoring offsets
456    size: u64,
457}
458
459/// Type of allocation, either contiguous or optimized (row-aligned when possible)
460#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
461pub enum AllocationKind {
462    /// Contiguous layout, with no padding
463    Contiguous,
464    /// Optimized for access speed. In practice this means row-aligned with padding for runtimes
465    /// that support it.
466    Optimized,
467}
468
469/// Descriptor for a new tensor allocation
470#[derive(new, Debug, Clone, Copy)]
471pub struct AllocationDescriptor<'a> {
472    /// Layout for the tensor
473    pub kind: AllocationKind,
474    /// Shape of the tensor
475    pub shape: &'a [usize],
476    /// Size of each element in the tensor (used for conversion of shape to bytes)
477    pub elem_size: usize,
478}
479
480impl<'a> AllocationDescriptor<'a> {
481    /// Create an optimized allocation descriptor
482    pub fn optimized(shape: &'a [usize], elem_size: usize) -> Self {
483        AllocationDescriptor::new(AllocationKind::Optimized, shape, elem_size)
484    }
485
486    /// Create a contiguous allocation descriptor
487    pub fn contiguous(shape: &'a [usize], elem_size: usize) -> Self {
488        AllocationDescriptor::new(AllocationKind::Contiguous, shape, elem_size)
489    }
490}
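// Sketch: describing two allocations of the same shape, one contiguous and one
// optimized (possibly row-padded). The 4-byte element size assumes `f32` data.
fn example_descriptors(shape: &[usize]) -> [AllocationDescriptor<'_>; 2] {
    [
        AllocationDescriptor::contiguous(shape, 4),
        AllocationDescriptor::optimized(shape, 4),
    ]
}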
491
492/// An allocation with associated strides. Strides depend on tensor layout.
493#[derive(new, Debug)]
494pub struct Allocation {
495    /// The handle for the memory resource
496    pub handle: Handle,
497    /// The strides of the tensor
498    pub strides: Vec<usize>,
499}
500
501/// Error returned from `create`/`read`/`write` functions. Due to asynchronous execution, not all
502/// errors can be caught, so some IO errors will still panic.
503#[derive(Error, Clone)]
504#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
505pub enum IoError {
506    /// Buffer size exceeds the max available
507    #[error("can't allocate buffer of size: {size}\n{backtrace}")]
508    BufferTooBig {
509        /// The size of the buffer in bytes.
510        size: u64,
511        /// The captured backtrace.
512        #[cfg_attr(std_io, serde(skip))]
513        backtrace: BackTrace,
514    },
515
516    /// Strides aren't supported for this copy operation on this runtime
517    #[error("the provided strides are not supported for this operation\n{backtrace}")]
518    UnsupportedStrides {
519        /// The backtrace.
520        #[cfg_attr(std_io, serde(skip))]
521        backtrace: BackTrace,
522    },
523
524    /// Handle wasn't found in the memory pool
525    #[error("couldn't find resource for that handle\n{backtrace}")]
526    InvalidHandle {
527        /// The backtrace.
528        #[cfg_attr(std_io, serde(skip))]
529        backtrace: BackTrace,
530    },
531
532    /// Unknown error happened during execution
533    #[error("Unknown error happened during execution\n{backtrace}")]
534    Unknown {
535        /// Details of the error
536        description: String,
537        /// The backtrace.
538        #[cfg_attr(std_io, serde(skip))]
539        backtrace: BackTrace,
540    },
541
542    /// The current IO operation is not supported
543    #[error("The current IO operation is not supported\n{backtrace}")]
544    UnsupportedIoOperation {
545        /// The backtrace.
546        #[cfg_attr(std_io, serde(skip))]
547        backtrace: BackTrace,
548    },
549
550    /// Can't perform the IO operation because of a runtime error.
551    #[error("Can't perform the IO operation because of a runtime error")]
552    Execution(#[from] ExecutionError),
553}
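// Sketch: building the `InvalidHandle` variant with a freshly captured
// backtrace, mirroring how the default `staging` implementation above builds
// `UnsupportedIoOperation`. The function name is hypothetical.
fn invalid_handle_error() -> IoError {
    IoError::InvalidHandle {
        backtrace: BackTrace::capture(),
    }
}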
554
555impl core::fmt::Debug for IoError {
556    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
557        f.write_fmt(format_args!("{self}"))
558    }
559}
560
561impl Handle {
562    /// Add to the current start offset, in bytes.
563    pub fn offset_start(mut self, offset: u64) -> Self {
564        if let Some(val) = &mut self.offset_start {
565            *val += offset;
566        } else {
567            self.offset_start = Some(offset);
568        }
569
570        self
571    }
572    /// Add to the current end offset, in bytes.
573    pub fn offset_end(mut self, offset: u64) -> Self {
574        if let Some(val) = &mut self.offset_end {
575            *val += offset;
576        } else {
577            self.offset_end = Some(offset);
578        }
579
580        self
581    }
582
583    /// Get the size of the handle, in bytes, accounting for offsets
584    pub fn size(&self) -> u64 {
585        self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
586    }
587}
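// Sketch: trimming a window out of a handle (assumed to be at least 24 bytes
// long). After adding byte offsets on both ends, `size()` reports the remaining
// window rather than the full buffer length.
fn slice_window(handle: Handle) -> Handle {
    let full = handle.size();
    let sliced = handle.offset_start(16).offset_end(8);
    debug_assert_eq!(sliced.size(), full - 24);
    sliced
}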
588
589/// Bindings to execute a kernel.
590#[derive(Debug, Default)]
591pub struct Bindings {
592    /// Buffer bindings
593    pub buffers: Vec<Binding>,
594    /// Packed metadata for tensor bindings (len, shape, stride, etc).
595    /// Ordered by inputs, then outputs, then tensormaps
596    pub metadata: MetadataBinding,
597    /// Scalar bindings
598    pub scalars: BTreeMap<StorageType, ScalarBinding>,
599    /// Tensor map bindings
600    pub tensor_maps: Vec<TensorMapBinding>,
601}
602
603impl Bindings {
604    /// Create a new bindings struct
605    pub fn new() -> Self {
606        Self::default()
607    }
608
609    /// Add a buffer binding
610    pub fn with_buffer(mut self, binding: Binding) -> Self {
611        self.buffers.push(binding);
612        self
613    }
614
615    /// Extend the buffers with `bindings`
616    pub fn with_buffers(mut self, bindings: Vec<Binding>) -> Self {
617        self.buffers.extend(bindings);
618        self
619    }
620
621    /// Add a scalar parameter
622    pub fn with_scalar(mut self, ty: StorageType, length: usize, data: Vec<u64>) -> Self {
623        self.scalars
624            .insert(ty, ScalarBinding::new(ty, length, data));
625        self
626    }
627
628    /// Extend the scalars with `bindings`
629    pub fn with_scalars(mut self, bindings: Vec<ScalarBinding>) -> Self {
630        self.scalars
631            .extend(bindings.into_iter().map(|binding| (binding.ty, binding)));
632        self
633    }
634
635    /// Set the metadata to `meta`
636    pub fn with_metadata(mut self, meta: MetadataBinding) -> Self {
637        self.metadata = meta;
638        self
639    }
640
641    /// Extend the tensor maps with `bindings`
642    pub fn with_tensor_maps(mut self, bindings: Vec<TensorMapBinding>) -> Self {
643        self.tensor_maps.extend(bindings);
644        self
645    }
646}
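// Sketch: assembling kernel bindings with the builder methods above. `input`,
// `output`, and `scalar_ty` are assumed to be provided by the caller;
// `StorageType` comes from `cubecl_ir`.
fn example_bindings(input: Binding, output: Binding, scalar_ty: StorageType) -> Bindings {
    Bindings::new()
        .with_buffers(vec![input, output])
        .with_scalar(scalar_ty, 1, vec![42u64])
}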
647
648/// Binding of the packed tensor metadata (lengths, shapes, strides) used to execute a kernel.
649#[derive(new, Debug, Default)]
650pub struct MetadataBinding {
651    /// Metadata values
652    pub data: Vec<u64>,
653    /// Length of the static portion (rank, len, `buffer_len`, `shape_offsets`, `stride_offsets`).
654    pub static_len: usize,
655}
656
657/// Binding of a set of scalars of the same type to execute a kernel.
658#[derive(new, Debug, Clone)]
659pub struct ScalarBinding {
660    /// Type of the scalars
661    pub ty: StorageType,
662    /// Unpadded length of the underlying data
663    pub length: usize,
664    /// Type-erased data of the scalars. Padded and represented by u64 to prevent misalignment.
665    pub data: Vec<u64>,
666}
667
668impl ScalarBinding {
669    /// Get data as byte slice
670    pub fn data(&self) -> &[u8] {
671        bytemuck::cast_slice(&self.data)
672    }
673}
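// Sketch: packing `u32` scalars into the `u64`-aligned payload that
// `ScalarBinding` expects. `ty` is assumed to be the matching 32-bit storage
// type; any padding bytes in the last word stay zero, and the in-word byte
// order follows the target's endianness.
fn pack_u32_scalars(ty: StorageType, values: &[u32]) -> ScalarBinding {
    let mut data = vec![0u64; values.len().div_ceil(2)];
    bytemuck::cast_slice_mut::<u64, u32>(&mut data)[..values.len()].copy_from_slice(values);
    ScalarBinding::new(ty, values.len(), data)
}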
674
675/// Binding of a [tensor handle](Handle) to execute a kernel.
676#[derive(new, Debug)]
677pub struct Binding {
678    /// Memory binding.
679    pub memory: SliceBinding,
680    /// Memory offset from the start, in bytes.
681    pub offset_start: Option<u64>,
682    /// Memory offset from the end, in bytes.
683    pub offset_end: Option<u64>,
684    /// The stream where the data was created.
685    pub stream: cubecl_common::stream_id::StreamId,
686    /// The stream position when the tensor became available.
687    pub cursor: u64,
688    /// Size in bytes
689    size: u64,
690}
691
692impl Binding {
693    /// Get the size of the handle, in bytes, accounting for offsets
694    pub fn size(&self) -> u64 {
695        self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
696    }
697}
698
699/// A binding with shape and stride info for non-contiguous reading
700#[derive(new, Debug, Clone)]
701pub struct CopyDescriptor<'a> {
702    /// Binding for the memory resource
703    pub binding: Binding,
704    /// Shape of the resource
705    pub shape: &'a [usize],
706    /// Strides of the resource
707    pub strides: &'a [usize],
708    /// Size of each element in the resource
709    pub elem_size: usize,
710}
711
712/// A tensor map used with TMA ops
713#[derive(new, Debug, Clone)]
714pub struct TensorMapBinding {
715    /// The binding for the backing tensor
716    pub binding: Binding,
717    /// The tensormap metadata
718    pub map: TensorMapMeta,
719}
720
721/// `TensorMap` metadata for the opaque proxy used in TMA copies
722#[derive(Debug, Clone)]
723pub struct TensorMapMeta {
724    /// Tensormap format (tiled or im2col)
725    pub format: TensorMapFormat,
726    /// Rank of the backing tensor
727    pub rank: usize,
728    /// Shape of the backing tensor
729    pub shape: Vec<usize>,
730    /// Strides of the backing tensor
731    pub strides: Vec<usize>,
732    /// Element stride, usually 1 but may be 2 for complex tensors
733    /// For im2col, this is equivalent to the kernel stride
734    pub elem_stride: Vec<usize>,
735    /// Interleave mode
736    pub interleave: TensorMapInterleave,
737    /// Swizzle mode
738    pub swizzle: TensorMapSwizzle,
739    /// Prefetch settings
740    pub prefetch: TensorMapPrefetch,
741    /// OOB fill value
742    pub oob_fill: OobFill,
743    /// Storage type
744    pub storage_ty: StorageType,
745}
746
747impl Handle {
748    /// If the tensor handle can be reused inplace.
749    pub fn can_mut(&self) -> bool {
750        self.memory.can_mut() && self.stream == StreamId::current()
751    }
752}
753
754impl Handle {
755    /// Convert the [handle](Handle) into a [binding](Binding).
756    pub fn binding(self) -> Binding {
757        Binding {
758            memory: MemoryHandle::binding(self.memory),
759            offset_start: self.offset_start,
760            offset_end: self.offset_end,
761            size: self.size,
762            stream: self.stream,
763            cursor: self.cursor,
764        }
765    }
766
767    /// Convert the [handle](Handle) into a [binding](Binding) with shape and stride metadata.
768    pub fn copy_descriptor<'a>(
769        &'a self,
770        shape: &'a [usize],
771        strides: &'a [usize],
772        elem_size: usize,
773    ) -> CopyDescriptor<'a> {
774        CopyDescriptor {
775            shape,
776            strides,
777            elem_size,
778            binding: self.clone().binding(),
779        }
780    }
781}
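// Sketch: computing row-major strides for a contiguous shape when building a
// `CopyDescriptor` by hand. Strides are assumed to be in elements, with the
// element size carried separately by `CopyDescriptor::elem_size`.
fn row_major_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; shape.len()];
    for i in (0..shape.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}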
782
783impl Clone for Handle {
784    fn clone(&self) -> Self {
785        Self {
786            memory: self.memory.clone(),
787            offset_start: self.offset_start,
788            offset_end: self.offset_end,
789            size: self.size,
790            stream: self.stream,
791            cursor: self.cursor,
792        }
793    }
794}
795
796impl Clone for Binding {
797    fn clone(&self) -> Self {
798        Self {
799            memory: self.memory.clone(),
800            offset_start: self.offset_start,
801            offset_end: self.offset_end,
802            size: self.size,
803            stream: self.stream,
804            cursor: self.cursor,
805        }
806    }
807}
808
809/// Specifies the number of cubes to be dispatched for a kernel.
810///
811/// This translates to e.g. a grid for CUDA, or to `num_workgroups` for WGSL.
812#[allow(clippy::large_enum_variant)]
813pub enum CubeCount {
814    /// Dispatch a known count of x, y, z cubes.
815    Static(u32, u32, u32),
816    /// Dispatch an amount based on the values in this buffer. The buffer should contain a u32 array [x, y, z].
817    Dynamic(Binding),
818}
819
820/// Defines how the cube count was selected based on the number of cubes required.
821pub enum CubeCountSelection {
822    /// The dispatched number of cubes is exactly the number required.
823    Exact(CubeCount),
824    /// The dispatched number of cubes differs from the number required; the second value is the actual count dispatched.
825    ///
826    /// This can happen because of per-axis hardware limits, requiring the kernel to perform OOB checks.
827    Approx(CubeCount, u32),
828}
829
830impl CubeCountSelection {
831    /// Creates a [`CubeCount`] while respecting the hardware limits.
832    pub fn new<R: Runtime>(client: &ComputeClient<R>, num_cubes: u32) -> Self {
833        let cube_count = cube_count_spread(&client.properties().hardware.max_cube_count, num_cubes);
834
835        let num_cubes_actual = cube_count[0] * cube_count[1] * cube_count[2];
836        let cube_count = CubeCount::Static(cube_count[0], cube_count[1], cube_count[2]);
837
838        match num_cubes_actual == num_cubes {
839            true => CubeCountSelection::Exact(cube_count),
840            false => CubeCountSelection::Approx(cube_count, num_cubes_actual),
841        }
842    }
843
844    /// Whether some cubes will be idle.
845    pub fn has_idle(&self) -> bool {
846        matches!(self, Self::Approx(..))
847    }
848
849    /// Converts into [`CubeCount`].
850    pub fn cube_count(self) -> CubeCount {
851        match self {
852            CubeCountSelection::Exact(cube_count) => cube_count,
853            CubeCountSelection::Approx(cube_count, _) => cube_count,
854        }
855    }
856}
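// Sketch: selecting a cube count for `num_cubes` and noting whether the kernel
// must guard against idle cubes. `client` is assumed to be an existing
// `ComputeClient`; `R` is any `Runtime`.
fn pick_cube_count<R: Runtime>(client: &ComputeClient<R>, num_cubes: u32) -> (CubeCount, bool) {
    let selection = CubeCountSelection::new(client, num_cubes);
    let needs_oob_check = selection.has_idle();
    (selection.cube_count(), needs_oob_check)
}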
857
858impl From<CubeCountSelection> for CubeCount {
859    fn from(value: CubeCountSelection) -> Self {
860        value.cube_count()
861    }
862}
863
864impl CubeCount {
865    /// Create a new static cube count with x = y = z = 1.
866    pub fn new_single() -> Self {
867        CubeCount::Static(1, 1, 1)
868    }
869
870    /// Create a new static cube count with the given x, and y = z = 1.
871    pub fn new_1d(x: u32) -> Self {
872        CubeCount::Static(x, 1, 1)
873    }
874
875    /// Create a new static cube count with the given x and y, and z = 1.
876    pub fn new_2d(x: u32, y: u32) -> Self {
877        CubeCount::Static(x, y, 1)
878    }
879
880    /// Create a new static cube count with the given x, y and z.
881    pub fn new_3d(x: u32, y: u32, z: u32) -> Self {
882        CubeCount::Static(x, y, z)
883    }
884}
885
886impl Debug for CubeCount {
887    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
888        match self {
889            CubeCount::Static(x, y, z) => f.write_fmt(format_args!("({x}, {y}, {z})")),
890            CubeCount::Dynamic(_) => f.write_str("binding"),
891        }
892    }
893}
894
895impl Clone for CubeCount {
896    fn clone(&self) -> Self {
897        match self {
898            Self::Static(x, y, z) => Self::Static(*x, *y, *z),
899            Self::Dynamic(handle) => Self::Dynamic(handle.clone()),
900        }
901    }
902}
903
904#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, serde::Serialize, serde::Deserialize)]
905#[allow(missing_docs)]
906/// The number of units along each of the 3 axes, which together give the total number of working units in a cube.
907pub struct CubeDim {
908    /// The number of units in the x axis.
909    pub x: u32,
910    /// The number of units in the y axis.
911    pub y: u32,
912    /// The number of units in the z axis.
913    pub z: u32,
914}
915
916impl CubeDim {
917    /// Creates a new [`CubeDim`] based on the maximum number of tasks that can be parallelized across units, in other
918    /// words, based on the maximum number of working units.
919    ///
920    /// # Notes
921    ///
922    /// For complex problems, you probably want to have your own logic function to create the
923    /// [`CubeDim`], but for simpler problems such as elementwise operations, this is a great default.
924    pub fn new<R: Runtime>(client: &ComputeClient<R>, working_units: usize) -> Self {
925        let properties = client.properties();
926        let plane_size = properties.hardware.plane_size_max;
927        let plane_count = Self::calculate_plane_count_per_cube(
928            working_units as u32,
929            plane_size,
930            properties.hardware.num_cpu_cores,
931        );
932
933        // Make sure it respects the max units per cube (especially on wasm)
934        let limit = properties.hardware.max_units_per_cube / plane_size;
935
936        Self::new_2d(plane_size, u32::min(limit, plane_count))
937    }
938
939    fn calculate_plane_count_per_cube(
940        working_units: u32,
941        plane_dim: u32,
942        num_cpu_cores: Option<u32>,
943    ) -> u32 {
944        match num_cpu_cores {
945            Some(num_cores) => core::cmp::min(num_cores, working_units),
946            None => {
947                let plane_count_max = core::cmp::max(1, working_units / plane_dim);
948
949                // Ensures `plane_count` is a power of 2.
950                const NUM_PLANE_MAX: u32 = 8u32;
951                const NUM_PLANE_MAX_LOG2: u32 = NUM_PLANE_MAX.ilog2();
952                let plane_count_max_log2 =
953                    core::cmp::min(NUM_PLANE_MAX_LOG2, u32::ilog2(plane_count_max));
954                2u32.pow(plane_count_max_log2)
955            }
956        }
957    }
958
959    /// Create a new cube dim with x = y = z = 1.
960    pub const fn new_single() -> Self {
961        Self { x: 1, y: 1, z: 1 }
962    }
963
964    /// Create a new cube dim with the given x, and y = z = 1.
965    pub const fn new_1d(x: u32) -> Self {
966        Self { x, y: 1, z: 1 }
967    }
968
969    /// Create a new cube dim with the given x and y, and z = 1.
970    pub const fn new_2d(x: u32, y: u32) -> Self {
971        Self { x, y, z: 1 }
972    }
973
974    /// Create a new cube dim with the given x, y and z.
975    /// Unlike [new](CubeDim::new), this sets the x, y and z dimensions explicitly.
976    pub const fn new_3d(x: u32, y: u32, z: u32) -> Self {
977        Self { x, y, z }
978    }
979
980    /// Total number of units per cube
981    pub const fn num_elems(&self) -> u32 {
982        self.x * self.y * self.z
983    }
984
985    /// Whether this `CubeDim` can fully contain `other`
986    pub const fn can_contain(&self, other: CubeDim) -> bool {
987        self.x >= other.x && self.y >= other.y && self.z >= other.z
988    }
989}
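// Illustrative sketch: a 32x4 cube dim holds 128 units and can fully contain a
// 32x2 one.
fn example_cube_dims() {
    let dim = CubeDim::new_2d(32, 4);
    assert_eq!(dim.num_elems(), 128);
    assert!(dim.can_contain(CubeDim::new_2d(32, 2)));
}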
990
991impl From<(u32, u32, u32)> for CubeDim {
992    fn from(value: (u32, u32, u32)) -> Self {
993        CubeDim::new_3d(value.0, value.1, value.2)
994    }
995}
996
997impl From<CubeDim> for (u32, u32, u32) {
998    fn from(val: CubeDim) -> Self {
999        (val.x, val.y, val.z)
1000    }
1001}
1002
1003/// The kind of execution to be performed.
1004#[derive(Default, Hash, PartialEq, Eq, Clone, Debug, Copy, Serialize, Deserialize)]
1005pub enum ExecutionMode {
1006    /// Checked kernels are safe.
1007    #[default]
1008    Checked,
1009    /// Unchecked kernels are unsafe.
1010    Unchecked,
1011}
1012
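// Spreads `num_cubes` across (x, y, z): whenever an axis exceeds its hardware
// maximum, the count on that axis is halved (rounding up) and the next axis is
// doubled, so the result covers at least the requested number of cubes and may
// slightly exceed it (e.g. 3177 with max (48, 32, 16) becomes [25, 32, 4] = 3200).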
1013fn cube_count_spread(max: &(u32, u32, u32), num_cubes: u32) -> [u32; 3] {
1014    let max_cube_counts = [max.0, max.1, max.2];
1015    let mut num_cubes = [num_cubes, 1, 1];
1016    let base = 2;
1017
1018    let mut reduce_count = |i: usize| {
1019        if num_cubes[i] <= max_cube_counts[i] {
1020            return true;
1021        }
1022
1023        loop {
1024            num_cubes[i] = num_cubes[i].div_ceil(base);
1025            num_cubes[i + 1] *= base;
1026
1027            if num_cubes[i] <= max_cube_counts[i] {
1028                return false;
1029            }
1030        }
1031    };
1032
1033    for i in 0..2 {
1034        if reduce_count(i) {
1035            break;
1036        }
1037    }
1038
1039    num_cubes
1040}
1041
1042#[cfg(test)]
1043mod tests {
1044    use super::*;
1045
1046    #[test_log::test]
1047    fn safe_num_cubes_even() {
1048        let max = (32, 32, 32);
1049        let required = 2048;
1050
1051        let actual = cube_count_spread(&max, required);
1052        let expected = [32, 32, 2];
1053        assert_eq!(actual, expected);
1054    }
1055
1056    #[test_log::test]
1057    fn safe_num_cubes_odd() {
1058        let max = (48, 32, 16);
1059        let required = 3177;
1060
1061        let actual = cube_count_spread(&max, required);
1062        let expected = [25, 32, 4];
1063        assert_eq!(actual, expected);
1064    }
1065}