Skip to main content

cubecl_runtime/server/
base.rs

1use super::Handle;
2use crate::{
3    client::ComputeClient,
4    compiler::CompilationError,
5    config::{CubeClRuntimeConfig, RuntimeConfig, compilation::BoundsCheckMode},
6    kernel::KernelMetadata,
7    logging::ServerLogger,
8    memory_management::{ManagedMemoryHandle, MemoryAllocationMode, MemoryUsage},
9    runtime::Runtime,
10    server::Binding,
11    storage::{ComputeStorage, ManagedResource},
12    tma::{OobFill, TensorMapFormat, TensorMapInterleave, TensorMapPrefetch, TensorMapSwizzle},
13};
14use ahash::AHasher;
15use alloc::boxed::Box;
16#[cfg(feature = "profile-tracy")]
17use alloc::format;
18use alloc::string::String;
19use alloc::sync::Arc;
20use alloc::vec::Vec;
21use core::{
22    fmt::Debug,
23    hash::{Hash, Hasher},
24};
25use cubecl_common::{
26    backtrace::BackTrace,
27    bytes::Bytes,
28    device::{self, DeviceId},
29    future::DynFut,
30    profile::ProfileDuration,
31    stream_id::StreamId,
32    stub::RwLock,
33};
34use cubecl_ir::{DeviceProperties, ElemType, StorageType};
35use cubecl_zspace::{Shape, Strides, metadata::Metadata};
36use hashbrown::HashSet;
37use thiserror::Error;
38
#[derive(Error, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
/// An error during profiling.
pub enum ProfileError {
    /// An unknown error happened during profiling
    #[error(
        "An unknown error happened during profiling\nCaused by:\n  {reason}\nBacktrace:\n{backtrace}"
    )]
    Unknown {
        /// The cause of the error.
        reason: String,
        /// The captured backtrace.
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },

    /// No profiling was registered
    #[error("No profiling registered\nBacktrace:\n{backtrace}")]
    NotRegistered {
        /// The captured backtrace.
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },

    /// A launch error happened during profiling
    #[error("A launch error happened during profiling\nCaused by:\n  {0}")]
    Launch(#[from] LaunchError),

    /// An execution error happened during profiling
    // Boxed to keep `ProfileError` small; `ServerError` can recursively contain
    // more errors (see `ServerError::ServerUnhealthy`).
    #[error("An execution error happened during profiling\nCaused by:\n  {0}")]
    Server(#[from] Box<ServerError>),
}
71
72impl core::fmt::Debug for ProfileError {
73    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
74        f.write_fmt(format_args!("{self}"))
75    }
76}
77
/// Contains many different types that are useful for server implementations and compute clients.
pub struct ServerUtilities<Server: ComputeServer> {
    /// The time when `profile-tracy` is activated.
    #[cfg(feature = "profile-tracy")]
    pub epoch_time: web_time::Instant,
    /// The GPU client when `profile-tracy` is activated.
    #[cfg(feature = "profile-tracy")]
    pub gpu_client: tracy_client::GpuContext,
    /// Information shared between all servers.
    pub properties: DeviceProperties,
    /// Stable hash of the device properties (see `DeviceProperties::checksum`).
    pub properties_hash: u64,
    /// Information specific to the current server.
    pub info: Server::Info,
    /// The logger based on global cubecl configs.
    pub logger: Arc<ServerLogger>,
    /// How to create the allocation.
    pub layout_policy: Server::MemoryLayoutPolicy,
    /// How to enforce bounds checking on kernels.
    pub check_mode: BoundsCheckMode,
    /// A set containing the ids for which the inter-device communication has already been initialized.
    pub initialized_comms: RwLock<HashSet<CommunicationId>>,
}
101
/// Defines how the memory layout is determined.
pub trait MemoryLayoutPolicy: Send + Sync + 'static {
    /// Applies the memory layout policy to a list of descriptors.
    ///
    /// Returns a vector of `MemoryLayout`, one per descriptor, with layouts that share a
    /// single `Binding` (the returned [`Handle`] backs all of them).
    fn apply(
        &self,
        stream_id: StreamId,
        descriptors: &[MemoryLayoutDescriptor],
    ) -> (Handle, Vec<MemoryLayout>);
}
114
115impl<Server: core::fmt::Debug> core::fmt::Debug for ServerUtilities<Server>
116where
117    Server: ComputeServer,
118    Server::Info: core::fmt::Debug,
119{
120    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
121        f.debug_struct("ServerUtilities")
122            .field("properties", &self.properties)
123            .field("info", &self.info)
124            .field("logger", &self.logger)
125            .finish()
126    }
127}
128
impl<S: ComputeServer> ServerUtilities<S> {
    /// Creates a new server utilities.
    ///
    /// Computes and caches the property checksum, reads the bounds-check mode
    /// from the global runtime config, and (with `profile-tracy`) sets up the
    /// tracy client and GPU context.
    pub fn new(
        properties: DeviceProperties,
        logger: Arc<ServerLogger>,
        info: S::Info,
        allocator: S::MemoryLayoutPolicy,
    ) -> Self {
        // Start a tracy client if needed.
        #[cfg(feature = "profile-tracy")]
        let client = tracy_client::Client::start();

        Self {
            properties_hash: properties.checksum(),
            properties,
            logger,
            // Create the GPU client if needed.
            #[cfg(feature = "profile-tracy")]
            gpu_client: client
                .clone()
                .new_gpu_context(
                    Some(&format!("{info:?}")),
                    // In the future should ask the server what makes sense here. 'Invalid' atm is a generic stand-in (Tracy doesn't have CUDA/RocM atm anyway).
                    tracy_client::GpuContextType::Invalid,
                    0,   // Timestamps are manually aligned to this epoch so start at 0.
                    1.0, // Timestamps are manually converted to be nanoseconds so period is 1.
                )
                .unwrap(),
            #[cfg(feature = "profile-tracy")]
            epoch_time: web_time::Instant::now(),
            info,
            layout_policy: allocator,
            check_mode: CubeClRuntimeConfig::get().compilation.check_mode,
            initialized_comms: RwLock::new(HashSet::default()),
        }
    }
}
166
/// Kernel Launch Errors.
#[derive(Error, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum LaunchError {
    /// The given kernel can't be compiled.
    #[error("A compilation error happened during launch\nCaused by:\n  {0}")]
    CompilationError(#[from] CompilationError),

    /// The server is out of memory.
    #[error(
        "An out-of-memory error happened during launch\nCaused by:\n  {reason}\nBacktrace\n{backtrace}"
    )]
    OutOfMemory {
        /// The cause of the memory error.
        reason: String,
        /// The backtrace for this error.
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },

    /// Too many resources were requested
    #[error("Too many resources were requested during launch\n{0}")]
    TooManyResources(#[from] ResourceLimitError),

    /// Unknown launch error.
    #[error(
        "An unknown error happened during launch\nCaused by:\n  {reason}\nBacktrace\n{backtrace}"
    )]
    Unknown {
        /// The cause of the unknown error.
        reason: String,
        /// The backtrace for this error.
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },

    /// Can't launch because of an IO Error.
    #[error("An io error happened during launch\nCaused by:\n  {0}")]
    IoError(#[from] IoError),
}
207
/// Resource limit errors.
#[derive(Error, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum ResourceLimitError {
    /// Shared memory exceeds maximum
    #[error(
        "Too much shared memory requested.\nRequested {requested} bytes, maximum {max} bytes available.\nBacktrace\n{backtrace}"
    )]
    SharedMemory {
        /// Value requested, in bytes.
        requested: usize,
        /// Maximum value, in bytes.
        max: usize,
        /// The backtrace for this error.
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },
    /// Total units exceeds maximum
    #[error(
        "Total unit count exceeds maximum.\nRequested {requested} units, max units is {max}.\nBacktrace\n{backtrace}"
    )]
    Units {
        /// Requested value
        requested: u32,
        /// Maximum value
        max: u32,
        /// The backtrace for this error.
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },
    /// `CubeDim` exceeds maximum
    #[error(
        "Cube dim exceeds maximum bounds.\nRequested {requested:?}, max is {max:?}.\nBacktrace\n{backtrace}"
    )]
    CubeDim {
        /// Requested value as (x, y, z).
        requested: (u32, u32, u32),
        /// Maximum value as (x, y, z).
        max: (u32, u32, u32),
        /// The backtrace for this error.
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },
}
252
253impl core::fmt::Debug for LaunchError {
254    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
255        f.write_fmt(format_args!("{self}"))
256    }
257}
258
259impl core::fmt::Debug for ResourceLimitError {
260    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
261        f.write_fmt(format_args!("{self}"))
262    }
263}
264
265/// Error that can happen asynchronously while executing registered kernels.
266#[derive(Error, Debug, Clone)]
267#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
268pub enum ServerError {
269    /// A generic runtime error.
270    #[error("An error happened during execution\nCaused by:\n  {reason}\nBacktrace:\n{backtrace}")]
271    Generic {
272        /// The details of the generic error.
273        reason: String,
274        /// The backtrace for this error.
275        #[cfg_attr(std_io, serde(skip))]
276        backtrace: BackTrace,
277    },
278
279    /// A launch error happened during profiling
280    #[error("A launch error happened during profiling\nCaused by:\n  {0}")]
281    Launch(#[from] LaunchError),
282
283    /// An execution error happened during profiling
284    #[error("An execution error happened during profiling\nCaused by:\n  {0}")]
285    Profile(#[from] ProfileError),
286
287    /// An execution error happened during profiling
288    #[error("An execution error happened during profiling\nCaused by:\n  {0}")]
289    Io(#[from] IoError),
290
291    /// The server is an invalid state.
292    #[error("The server is in an invalid state\nCaused by:\n  {errors:?}")]
293    ServerUnhealthy {
294        /// The details of the generic error.
295        errors: Vec<Self>,
296        /// The backtrace for this error.
297        #[cfg_attr(std_io, serde(skip))]
298        backtrace: BackTrace,
299    },
300}
301
/// How errors are handled in a stream when executing a task.
#[derive(Clone, Copy)]
pub struct StreamErrorMode {
    /// Whether the task still executes even if the stream is in error.
    pub ignore: bool,
    /// Whether the errors are flushed by the current task.
    pub flush: bool,
}
310
/// The compute server is responsible for handling resources and computations over resources.
///
/// Everything in the server is mutable, therefore it should be solely accessed through the
/// [`ComputeClient`] for thread safety.
pub trait ComputeServer:
    Send + core::fmt::Debug + ServerCommunication + device::DeviceService + 'static
where
    Self: Sized,
{
    /// The kernel type defines the computation algorithms.
    type Kernel: KernelMetadata;
    /// Information that can be retrieved for the runtime.
    type Info: Debug + Send + Sync;
    /// Manages how allocations are performed for a server.
    type MemoryLayoutPolicy: MemoryLayoutPolicy;
    /// The [storage](ComputeStorage) type defines how data is stored and accessed.
    type Storage: ComputeStorage;

    /// Initializes [memory](ManagedMemoryHandle) on the given [stream](StreamId) with the given size.
    fn initialize_memory(&mut self, memory: ManagedMemoryHandle, size: u64, stream_id: StreamId);

    /// Reserves N [Bytes] of the provided sizes to be used as staging to load data.
    ///
    /// The default implementation reports the operation as unsupported;
    /// servers that support staging buffers should override it.
    fn staging(
        &mut self,
        _sizes: &[usize],
        _stream_id: StreamId,
    ) -> Result<Vec<Bytes>, ServerError> {
        Err(IoError::UnsupportedIoOperation {
            backtrace: BackTrace::capture(),
        }
        .into())
    }

    /// Retrieve the server logger.
    fn logger(&self) -> Arc<ServerLogger>;

    /// Retrieve the server utilities.
    fn utilities(&self) -> Arc<ServerUtilities<Self>>;

    /// Given bindings, returns the owned resources as bytes.
    fn read(
        &mut self,
        descriptors: Vec<CopyDescriptor>,
        stream_id: StreamId,
    ) -> DynFut<Result<Vec<Bytes>, ServerError>>;

    /// Writes the specified bytes into the buffers given
    fn write(&mut self, descriptors: Vec<(CopyDescriptor, Bytes)>, stream_id: StreamId);

    /// Wait for the completion of every task in the server.
    fn sync(&mut self, stream_id: StreamId) -> DynFut<Result<(), ServerError>>;

    /// Given a resource handle, returns the storage resource.
    fn get_resource(
        &mut self,
        binding: Binding,
        stream_id: StreamId,
    ) -> Result<ManagedResource<<Self::Storage as ComputeStorage>::Resource>, ServerError>;

    /// Executes the `kernel` over the given memory `handles`.
    ///
    /// Kernels have mutable access to every resource they are given
    /// and are responsible for determining which should be read or written.
    ///
    /// # Safety
    ///
    /// When executing with mode [`ExecutionMode::Unchecked`], out-of-bound reads and writes can happen.
    unsafe fn launch(
        &mut self,
        kernel: Self::Kernel,
        count: CubeCount,
        bindings: KernelArguments,
        kind: ExecutionMode,
        stream_id: StreamId,
    );

    /// Flush all outstanding tasks in the server.
    fn flush(&mut self, stream_id: StreamId) -> Result<(), ServerError>;

    /// The current memory usage of the server.
    fn memory_usage(&mut self, stream_id: StreamId) -> Result<MemoryUsage, ServerError>;

    /// Ask the server to release memory that it can release.
    fn memory_cleanup(&mut self, stream_id: StreamId);

    /// Enable collecting timestamps.
    fn start_profile(&mut self, stream_id: StreamId) -> Result<ProfilingToken, ServerError>;

    /// Disable collecting timestamps.
    fn end_profile(
        &mut self,
        stream_id: StreamId,
        token: ProfilingToken,
    ) -> Result<ProfileDuration, ProfileError>;

    /// Update the memory mode of allocation in the server.
    fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId);
}
409
/// An ID unique to any unordered combination of devices.
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
pub struct CommunicationId {
    /// The id, a stable hash of the sorted device ids (see `From<Vec<DeviceId>>`).
    pub id: u64,
}
416
417impl From<Vec<DeviceId>> for CommunicationId {
418    fn from(mut value: Vec<DeviceId>) -> Self {
419        // Make sure that device ids are sorted so that any combination of the same devices uses the same communicator.
420        value.sort();
421        let mut hasher = AHasher::default();
422        value.hash(&mut hasher);
423        CommunicationId {
424            id: hasher.finish(),
425        }
426    }
427}
428
/// Different reduce operations.
pub enum ReduceOperation {
    /// Sum of all elements across devices.
    Sum,
    /// Mean of all elements across devices.
    Mean,
}
436
/// Defines functions for optimized data transfer between servers, supporting custom communication
/// mechanisms such as peer-to-peer communication or specialized implementations.
///
/// All default method implementations panic; backends that enable server-to-server
/// communication must override them.
pub trait ServerCommunication {
    /// Indicates whether server-to-server communication is enabled for this implementation.
    const SERVER_COMM_ENABLED: bool;

    /// Ensure that all queued collective operations have been executed.
    ///
    /// # Arguments
    ///
    /// * `stream_id` - The [`StreamId`] of the stream waiting for the sync.
    ///
    /// # Returns
    ///
    /// Returns a `Result` containing a `ServerError` if the operation fails.
    #[allow(unused_variables)]
    fn sync_collective(&mut self, stream_id: StreamId) -> Result<(), ServerError> {
        todo!() // For backends other than cuda.
    }

    /// Initialize the communication between the devices in `device_ids`.
    ///
    /// # Arguments
    ///
    /// * `device_ids` - The IDs of the devices that need communication.
    ///
    /// # Returns
    ///
    /// Returns a `Result` containing a `ServerError` if the operation fails.
    #[allow(unused_variables)]
    fn comm_init(&mut self, device_ids: Vec<DeviceId>) -> Result<(), ServerError> {
        unimplemented!()
    }

    /// Performs an `all_reduce` operation on the input data and writes it to the output buffer.
    /// see <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allreduce>
    ///
    /// # Arguments
    ///
    /// * `src` - The data to be reduced.
    /// * `dst` - Where to write the result.
    /// * `dtype` - The element type of the data being reduced
    /// * `stream_id` - The data's stream id.
    /// * `op` - The reduce's aggregation operation e.g. mean, sum, etc.
    /// * `device_ids` - The list of device ids from which to `all_reduce`.
    ///
    /// # Returns
    ///
    /// Returns a `Result` containing a `ServerError` if the operation fails.
    #[allow(unused_variables)]
    fn all_reduce(
        &mut self,
        src: Binding,
        dst: Binding,
        dtype: ElemType,
        stream_id: StreamId,
        op: ReduceOperation,
        device_ids: Vec<DeviceId>,
    ) -> Result<(), ServerError> {
        unimplemented!()
    }

    /// Sends data from this server to a destination server.
    ///
    /// # Arguments
    ///
    /// * `desc` - A descriptor specifying the data to be sent, including shape, strides, and binding.
    /// * `dtype` - The element type of the data being sent.
    /// * `stream_id` - The stream ID associated with the server's operation.
    /// * `device_id_dst` - ID of the device receiving the data.
    ///
    /// # Returns
    ///
    /// Returns a `Result` containing a `ServerError` if the operation fails.
    #[allow(unused_variables)]
    fn send(
        &mut self,
        desc: CopyDescriptor,
        dtype: ElemType,
        stream_id: StreamId,
        device_id_dst: DeviceId,
    ) -> Result<(), ServerError> {
        unimplemented!()
    }

    /// Receive data from another server.
    ///
    /// # Arguments
    ///
    /// * `handle` - The handle in which the received data is written.
    /// * `dtype` - The element type of the data being received.
    /// * `stream_id` - The stream ID associated with the server's operation.
    /// * `device_id_src` - ID of the device sending the data.
    ///
    /// # Returns
    ///
    /// Returns a `Result` containing a `ServerError` if the operation fails.
    #[allow(unused_variables)]
    fn recv(
        &mut self,
        handle: Handle,
        dtype: ElemType,
        stream_id: StreamId,
        device_id_src: DeviceId,
    ) -> Result<(), ServerError> {
        unimplemented!()
    }
}
545
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
/// Profiling identification so that the server can support recursive and overlapping profilings.
pub struct ProfilingToken {
    /// The token value.
    pub id: u64,
}
552
/// Type of allocation, either contiguous or optimized (row-aligned when possible)
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum MemoryLayoutStrategy {
    /// Contiguous layout, with no padding
    Contiguous,
    /// Optimized for access speed. In practice this means row-aligned with padding for runtimes
    /// that support it.
    Optimized,
}
562
/// Descriptor for a new tensor allocation
#[derive(new, Debug, Clone)]
pub struct MemoryLayoutDescriptor {
    /// Strategy used to create the memory layout.
    pub strategy: MemoryLayoutStrategy,
    /// Shape of the tensor
    pub shape: Shape,
    /// Size of each element in the tensor (used for conversion of shape to bytes)
    pub elem_size: usize,
}
573
574impl MemoryLayoutDescriptor {
575    /// Create an optimized allocation descriptor
576    pub fn optimized(shape: Shape, elem_size: usize) -> Self {
577        MemoryLayoutDescriptor::new(MemoryLayoutStrategy::Optimized, shape, elem_size)
578    }
579
580    /// Create a contiguous allocation descriptor
581    pub fn contiguous(shape: Shape, elem_size: usize) -> Self {
582        MemoryLayoutDescriptor::new(MemoryLayoutStrategy::Contiguous, shape, elem_size)
583    }
584}
585
/// An allocation with associated strides. Strides depend on tensor layout.
#[derive(Debug, Clone)]
pub struct MemoryLayout {
    /// The handle for the memory resource
    pub memory: Handle,
    /// TODO: `Strides` should become `Layout`.
    ///
    /// The strides of the tensor
    pub strides: Strides,
}
596
597impl MemoryLayout {
598    /// Create a new memory layout.
599    pub fn new(handle: Handle, strides: impl Into<Strides>) -> Self {
600        MemoryLayout {
601            memory: handle,
602            strides: strides.into(),
603        }
604    }
605}
606
/// A reason for an error.
///
/// Wraps either a `'static` string, a shared owned string, or nothing
/// (see `ReasonInner`); defaults to "not provided".
#[derive(Default, Clone)]
pub struct Reason {
    inner: ReasonInner,
}
612
#[cfg(std_io)]
mod _reason_serde {
    //! Manual serde impls for [`Reason`]: serialized as a plain string so the
    //! internal `ReasonInner` enum stays private and stable on the wire.
    use super::*;

    use alloc::string::ToString;
    use serde::{Deserialize, Deserializer, Serialize, Serializer};

    impl Serialize for Reason {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            // Use the Display implementation (via to_string) to flatten the enum
            serializer.serialize_str(&self.to_string())
        }
    }

    impl<'de> Deserialize<'de> for Reason {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            // Deserialize into a standard String first
            let s = String::deserialize(deserializer)?;

            // Wrap it in the Dynamic variant since we can't safely
            // reconstruct a 'static str from a runtime string.
            Ok(Reason {
                inner: ReasonInner::Dynamic(Arc::new(s)),
            })
        }
    }
}
646
#[derive(Default, Clone)]
enum ReasonInner {
    // Zero-cost borrow of a compile-time string.
    Static(&'static str),
    // Shared owned string; `Arc` keeps `Reason` cheap to clone.
    Dynamic(Arc<String>),
    #[default]
    NotProvided,
}
654
655impl core::fmt::Display for Reason {
656    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
657        match &self.inner {
658            ReasonInner::Static(content) => f.write_str(content),
659            ReasonInner::Dynamic(content) => f.write_str(content),
660            ReasonInner::NotProvided => f.write_str("No reason provided for the error"),
661        }
662    }
663}
664
665impl core::fmt::Debug for Reason {
666    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
667        core::fmt::Display::fmt(&self, f)
668    }
669}
670
671impl From<&'static str> for Reason {
672    fn from(value: &'static str) -> Self {
673        Self {
674            inner: ReasonInner::Static(value),
675        }
676    }
677}
678
679impl From<String> for Reason {
680    fn from(value: String) -> Self {
681        Self {
682            inner: ReasonInner::Dynamic(Arc::new(value)),
683        }
684    }
685}
686
687/// Error returned from `create`/`read`/`write` functions. Due to async execution not all errors
688/// are able to be caught, so some IO errors will still panic.
689#[derive(Error, Clone)]
690#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
691pub enum IoError {
692    /// Buffer size exceeds the max available
693    #[error("can't allocate buffer of size: {size}\n{backtrace}")]
694    BufferTooBig {
695        /// The size of the buffer in bytes.
696        size: u64,
697        /// The captured backtrace.
698        #[cfg_attr(std_io, serde(skip))]
699        backtrace: BackTrace,
700    },
701
702    /// Strides aren't supported for this copy operation on this runtime
703    #[error("the provided strides are not supported for this operation\n{backtrace}")]
704    UnsupportedStrides {
705        /// The backtrace.
706        #[cfg_attr(std_io, serde(skip))]
707        backtrace: BackTrace,
708    },
709
710    /// Memory wasn't found in the memory pool
711    #[error("couldn't find resource for that handle: {reason}\n{backtrace}")]
712    NotFound {
713        /// The backtrace.
714        #[cfg_attr(std_io, serde(skip))]
715        backtrace: BackTrace,
716        /// The reason the handle is invalid.
717        reason: Reason,
718    },
719
720    /// Handle wasn't found in the memory pool
721    #[error("couldn't free the handle, since it is currently in used. \n{backtrace}")]
722    FreeError {
723        /// The backtrace.
724        #[cfg_attr(std_io, serde(skip))]
725        backtrace: BackTrace,
726    },
727
728    /// Unknown error happened during execution
729    #[error("Unknown error happened during execution\n{backtrace}")]
730    Unknown {
731        /// Details of the error
732        description: String,
733        /// The backtrace.
734        #[cfg_attr(std_io, serde(skip))]
735        backtrace: BackTrace,
736    },
737
738    /// The current IO operation is not supported
739    #[error("The current IO operation is not supported\n{backtrace}")]
740    UnsupportedIoOperation {
741        /// The backtrace.
742        #[cfg_attr(std_io, serde(skip))]
743        backtrace: BackTrace,
744    },
745
746    /// Can't perform the IO operation because of a runtime error.
747    #[error("Can't perform the IO operation because of a runtime error: {0}")]
748    Execution(#[from] Box<ServerError>),
749}
750
751impl core::fmt::Debug for IoError {
752    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
753        f.write_fmt(format_args!("{self}"))
754    }
755}
756
/// Arguments to execute a kernel.
#[derive(Debug, Default)]
pub struct KernelArguments {
    /// Buffer bindings
    pub buffers: Vec<Binding>,
    /// Packed scalars and metadata. First scalars sorted by type, then static metadata,
    /// then dynamic metadata.
    pub info: MetadataBindingInfo,
    /// Tensor map bindings
    pub tensor_maps: Vec<TensorMapBinding>,
}
768
769impl core::fmt::Display for KernelArguments {
770    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
771        f.write_str("KernelArguments")?;
772        for b in self.buffers.iter() {
773            f.write_fmt(format_args!("\n - buffer: {b:?}\n"))?;
774        }
775
776        Ok(())
777    }
778}
779
780impl KernelArguments {
781    /// Create a new bindings struct
782    pub fn new() -> Self {
783        Self::default()
784    }
785
786    /// Add a buffer binding
787    pub fn with_buffer(mut self, binding: Binding) -> Self {
788        self.buffers.push(binding);
789        self
790    }
791
792    /// Extend the buffers with `bindings`
793    pub fn with_buffers(mut self, bindings: Vec<Binding>) -> Self {
794        self.buffers.extend(bindings);
795        self
796    }
797
798    /// Set the info to `info`
799    pub fn with_info(mut self, info: MetadataBindingInfo) -> Self {
800        self.info = info;
801        self
802    }
803
804    /// Extend the tensor maps with `bindings`
805    pub fn with_tensor_maps(mut self, bindings: Vec<TensorMapBinding>) -> Self {
806        self.tensor_maps.extend(bindings);
807        self
808    }
809}
810
/// Binding of a set of scalars of the same type to execute a kernel.
///
/// The [`ComputeServer`] is responsible to convert those info into actual [`Binding`] when launching
/// kernels.
#[derive(new, Debug, Default)]
pub struct MetadataBindingInfo {
    /// Scalar and metadata values
    pub data: Vec<u64>,
    /// Start of the dynamically sized portion of the metadata, relative to the entire info buffer
    pub dynamic_metadata_offset: usize,
}
822
823impl MetadataBindingInfo {
824    /// Create a new binding info for custom data, for externally compiled kernels.
825    pub fn custom(data: Vec<u64>) -> Self {
826        Self::new(data, 0)
827    }
828}
829
/// A binding with shape and stride info for non-contiguous reading
#[derive(new, Debug)]
pub struct CopyDescriptor {
    /// Binding for the memory resource
    pub handle: Binding,
    /// Shape of the resource
    pub shape: Shape,
    /// Strides of the resource
    pub strides: Strides,
    /// Size of each element in the resource, in bytes
    pub elem_size: usize,
}
842
/// A tensor map used with TMA ops
#[derive(new, Debug)]
pub struct TensorMapBinding {
    /// The binding for the backing tensor
    pub binding: Binding,
    /// The tensormap metadata
    pub map: TensorMapMeta,
}
851
/// `TensorMap` metadata for the opaque proxy used in TMA copies
#[derive(Debug, Clone)]
pub struct TensorMapMeta {
    /// Tensormap format (tiled or im2col)
    pub format: TensorMapFormat,
    /// Metadata of the backing tensor
    pub metadata: Metadata,
    /// Element stride, usually 1 but may be 2 for complex tensors
    /// For im2col, this is equivalent to the kernel stride
    pub elem_stride: Strides,
    /// Interleave mode
    pub interleave: TensorMapInterleave,
    /// Swizzle mode
    pub swizzle: TensorMapSwizzle,
    /// Prefetch settings
    pub prefetch: TensorMapPrefetch,
    /// OOB fill value
    pub oob_fill: OobFill,
    /// Storage type
    pub storage_ty: StorageType,
}
873
/// Specifies the number of cubes to be dispatched for a kernel.
///
/// This translates to eg. a grid for CUDA, or to `num_workgroups` for wgsl.
#[allow(clippy::large_enum_variant)]
pub enum CubeCount {
    /// Dispatch a known count of x, y, z cubes.
    Static(u32, u32, u32),
    /// Dispatch an amount based on the values in this buffer. The buffer should contain a u32 array [x, y, z].
    Dynamic(Binding),
}
884
/// Defines how to select cube count based on the number of cubes required.
pub enum CubeCountSelection {
    /// If the number of cubes is the same as required.
    Exact(CubeCount),
    /// If the number of cubes isn't the same as required.
    ///
    /// This can happen based on the hardware limit, requiring the kernel to perform OOB checks.
    /// The `u32` is the actual number of cubes that will be dispatched.
    Approx(CubeCount, u32),
}
894
895impl CubeCountSelection {
896    /// Creates a [`CubeCount`] while respecting the hardware limits.
897    pub fn new<R: Runtime>(client: &ComputeClient<R>, num_cubes: u32) -> Self {
898        let cube_count = cube_count_spread(&client.properties().hardware.max_cube_count, num_cubes);
899
900        let num_cubes_actual = cube_count[0] * cube_count[1] * cube_count[2];
901        let cube_count = CubeCount::Static(cube_count[0], cube_count[1], cube_count[2]);
902
903        match num_cubes_actual == num_cubes {
904            true => CubeCountSelection::Exact(cube_count),
905            false => CubeCountSelection::Approx(cube_count, num_cubes_actual),
906        }
907    }
908
909    /// If some cubes will be idle.
910    pub fn has_idle(&self) -> bool {
911        matches!(self, Self::Approx(..))
912    }
913
914    /// Converts into [`CubeCount`].
915    pub fn cube_count(self) -> CubeCount {
916        match self {
917            CubeCountSelection::Exact(cube_count) => cube_count,
918            CubeCountSelection::Approx(cube_count, _) => cube_count,
919        }
920    }
921}
922
923impl From<CubeCountSelection> for CubeCount {
924    fn from(value: CubeCountSelection) -> Self {
925        value.cube_count()
926    }
927}
928
929impl CubeCount {
930    /// Create a new static cube count with the given x = y = z = 1.
931    pub fn new_single() -> Self {
932        CubeCount::Static(1, 1, 1)
933    }
934
935    /// Create a new static cube count with the given x, and y = z = 1.
936    pub fn new_1d(x: u32) -> Self {
937        CubeCount::Static(x, 1, 1)
938    }
939
940    /// Create a new static cube count with the given x and y, and z = 1.
941    pub fn new_2d(x: u32, y: u32) -> Self {
942        CubeCount::Static(x, y, 1)
943    }
944
945    /// Create a new static cube count with the given x, y and z.
946    pub fn new_3d(x: u32, y: u32, z: u32) -> Self {
947        CubeCount::Static(x, y, z)
948    }
949
950    /// Checks whether the cube count is definitely empty, i.e. has 0 dispatches.
951    pub fn is_empty(&self) -> bool {
952        match self {
953            Self::Static(x, y, z) => *x == 0 || *y == 0 || *z == 0,
954            Self::Dynamic(_) => false,
955        }
956    }
957}
958
959impl Debug for CubeCount {
960    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
961        match self {
962            CubeCount::Static(x, y, z) => f.write_fmt(format_args!("({x}, {y}, {z})")),
963            CubeCount::Dynamic(_) => f.write_str("binding"),
964        }
965    }
966}
967
968impl Clone for CubeCount {
969    fn clone(&self) -> Self {
970        match self {
971            Self::Static(x, y, z) => Self::Static(*x, *y, *z),
972            Self::Dynamic(binding) => Self::Dynamic(binding.clone()),
973        }
974    }
975}
976
/// The number of units across all 3 axes, totalling the number of working units in a cube.
//
// Note: the former `#[allow(missing_docs)]` was removed — the struct and every
// field are documented, so the lint suppression was dead weight.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub struct CubeDim {
    /// The number of units in the x axis.
    pub x: u32,
    /// The number of units in the y axis.
    pub y: u32,
    /// The number of units in the z axis.
    pub z: u32,
}
989
990impl CubeDim {
991    /// Creates a new [`CubeDim`] based on the maximum number of tasks that can be parellalized by units, in other words,
992    /// by the maximum number of working units.
993    ///
994    /// # Notes
995    ///
996    /// For complex problems, you probably want to have your own logic function to create the
997    /// [`CubeDim`], but for simpler problems such as elemwise-operation, this is a great default.
998    pub fn new<R: Runtime>(client: &ComputeClient<R>, working_units: usize) -> Self {
999        let properties = client.properties();
1000        let plane_size = properties.hardware.plane_size_max;
1001        let plane_count = Self::calculate_plane_count_per_cube(
1002            working_units as u32,
1003            plane_size,
1004            properties.hardware.num_cpu_cores,
1005        );
1006
1007        // Make sure it respects the max units per cube (especially on wasm)
1008        let limit = properties.hardware.max_units_per_cube / plane_size;
1009
1010        // Ensure at least 1 plane so CubeDim is always valid (num_elems() > 0).
1011        Self::new_2d(plane_size, u32::min(limit, plane_count).max(1))
1012    }
1013
1014    fn calculate_plane_count_per_cube(
1015        working_units: u32,
1016        plane_dim: u32,
1017        num_cpu_cores: Option<u32>,
1018    ) -> u32 {
1019        match num_cpu_cores {
1020            Some(num_cores) => core::cmp::min(num_cores, working_units),
1021            None => {
1022                let plane_count_max = core::cmp::max(1, working_units / plane_dim);
1023
1024                // Ensures `plane_count` is a power of 2.
1025                const NUM_PLANE_MAX: u32 = 8u32;
1026                const NUM_PLANE_MAX_LOG2: u32 = NUM_PLANE_MAX.ilog2();
1027                let plane_count_max_log2 =
1028                    core::cmp::min(NUM_PLANE_MAX_LOG2, u32::ilog2(plane_count_max));
1029                2u32.pow(plane_count_max_log2)
1030            }
1031        }
1032    }
1033
1034    /// Create a new cube dim with x = y = z = 1.
1035    pub const fn new_single() -> Self {
1036        Self { x: 1, y: 1, z: 1 }
1037    }
1038
1039    /// Create a new cube dim with the given x, and y = z = 1.
1040    pub const fn new_1d(x: u32) -> Self {
1041        Self { x, y: 1, z: 1 }
1042    }
1043
1044    /// Create a new cube dim with the given x and y, and z = 1.
1045    pub const fn new_2d(x: u32, y: u32) -> Self {
1046        Self { x, y, z: 1 }
1047    }
1048
1049    /// Create a new cube dim with the given x, y and z.
1050    /// This is equivalent to the [new](CubeDim::new) function.
1051    pub const fn new_3d(x: u32, y: u32, z: u32) -> Self {
1052        Self { x, y, z }
1053    }
1054
1055    /// Total numbers of units per cube
1056    pub const fn num_elems(&self) -> u32 {
1057        self.x * self.y * self.z
1058    }
1059
1060    /// Whether this `CubeDim` can fully contain `other`
1061    pub const fn can_contain(&self, other: CubeDim) -> bool {
1062        self.x >= other.x && self.y >= other.y && self.z >= other.z
1063    }
1064}
1065
1066impl From<(u32, u32, u32)> for CubeDim {
1067    fn from(value: (u32, u32, u32)) -> Self {
1068        CubeDim::new_3d(value.0, value.1, value.2)
1069    }
1070}
1071
1072impl From<CubeDim> for (u32, u32, u32) {
1073    fn from(val: CubeDim) -> Self {
1074        (val.x, val.y, val.z)
1075    }
1076}
1077
/// The kind of execution to be performed.
#[derive(Default, Hash, PartialEq, Eq, Clone, Debug, Copy)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum ExecutionMode {
    /// Checked kernels are safe.
    #[default]
    Checked,
    /// Validate OOB and alert if an OOB access occurs.
    Validate,
    /// Unchecked kernels are unsafe: no OOB validation is performed.
    Unchecked,
}
1090
/// Spreads `num_cubes` over the x, y and z axes so that each spread axis stays
/// within the corresponding hardware maximum.
///
/// Overflow on one axis is repeatedly halved (rounding up) while the next axis
/// is doubled to compensate; the `div_ceil` rounding means the total may
/// overshoot the requested count.
///
/// NOTE(review): the z axis is never clamped, so the result can exceed `max.2`
/// when the total doesn't fit within the full 3D limit — confirm callers
/// tolerate this.
fn cube_count_spread(max: &(u32, u32, u32), num_cubes: u32) -> [u32; 3] {
    const BASE: u32 = 2;
    let limits = [max.0, max.1, max.2];
    let mut counts = [num_cubes, 1, 1];

    // Spread along x then y. Stop as soon as an axis already fits, since later
    // axes only grow through earlier spreading.
    for axis in 0..2 {
        if counts[axis] <= limits[axis] {
            break;
        }
        while counts[axis] > limits[axis] {
            counts[axis] = counts[axis].div_ceil(BASE);
            counts[axis + 1] *= BASE;
        }
    }

    counts
}
1119
#[cfg(test)]
mod tests {
    use super::*;

    // 2048 cubes divide evenly under the (32, 32, 32) limits:
    // 32 * 32 * 2 == 2048, an exact spread.
    #[test_log::test]
    fn safe_num_cubes_even() {
        let max = (32, 32, 32);
        let required = 2048;

        let actual = cube_count_spread(&max, required);
        let expected = [32, 32, 2];
        assert_eq!(actual, expected);
    }

    // 3177 does not divide evenly; `div_ceil` overshoots:
    // 25 * 32 * 4 == 3200 >= 3177, with every axis within its limit.
    #[test_log::test]
    fn safe_num_cubes_odd() {
        let max = (48, 32, 16);
        let required = 3177;

        let actual = cube_count_spread(&max, required);
        let expected = [25, 32, 4];
        assert_eq!(actual, expected);
    }
}