Skip to main content

cubecl_runtime/server/
base.rs

1use super::Handle;
2use crate::{
3    client::ComputeClient,
4    compiler::CompilationError,
5    config::{GlobalConfig, compilation::BoundsCheckMode},
6    kernel::KernelMetadata,
7    logging::ServerLogger,
8    memory_management::{ManagedMemoryHandle, MemoryAllocationMode, MemoryUsage},
9    runtime::Runtime,
10    server::Binding,
11    storage::{ComputeStorage, ManagedResource},
12    tma::{OobFill, TensorMapFormat, TensorMapInterleave, TensorMapPrefetch, TensorMapSwizzle},
13};
14use alloc::boxed::Box;
15#[cfg(feature = "profile-tracy")]
16use alloc::format;
17use alloc::string::String;
18use alloc::sync::Arc;
19use alloc::vec::Vec;
20use core::fmt::Debug;
21use cubecl_common::{
22    backtrace::BackTrace,
23    bytes::Bytes,
24    device::{self, DeviceId},
25    future::DynFut,
26    profile::ProfileDuration,
27    stream_id::StreamId,
28};
29use cubecl_ir::{DeviceProperties, ElemType, StorageType};
30use cubecl_zspace::{Shape, Strides, metadata::Metadata};
31use thiserror::Error;
32
#[derive(Error, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
/// An error during profiling.
///
/// The `Display` text of each variant is the message declared in its `#[error(...)]`
/// attribute. Backtraces are skipped when the type is (de)serialized.
pub enum ProfileError {
    /// An unknown error happened during profiling
    #[error(
        "An unknown error happened during profiling\nCaused by:\n  {reason}\nBacktrace:\n{backtrace}"
    )]
    Unknown {
        /// The cause of the error
        reason: String,
        /// The captured backtrace.
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },

    /// No profiling was registered
    #[error("No profiling registered\nBacktrace:\n{backtrace}")]
    NotRegistered {
        /// The captured backtrace.
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },

    /// A launch error happened during profiling
    ///
    /// Can be created from a [`LaunchError`] with `?` or `.into()`.
    #[error("A launch error happened during profiling\nCaused by:\n  {0}")]
    Launch(#[from] LaunchError),

    /// An execution error happened during profiling
    ///
    /// The wrapped [`ServerError`] is boxed; the conversion is from `Box<ServerError>`.
    #[error("An execution error happened during profiling\nCaused by:\n  {0}")]
    Server(#[from] Box<ServerError>),
}
65
66impl core::fmt::Debug for ProfileError {
67    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
68        f.write_fmt(format_args!("{self}"))
69    }
70}
71
/// Contains many different types that are useful for server implementations and compute clients.
///
/// The `epoch_time` and `gpu_client` fields only exist when the `profile-tracy` feature is
/// enabled.
pub struct ServerUtilities<Server: ComputeServer> {
    /// The time when `profile-tracy` is activated.
    #[cfg(feature = "profile-tracy")]
    pub epoch_time: web_time::Instant,
    /// The GPU client when `profile-tracy` is activated.
    #[cfg(feature = "profile-tracy")]
    pub gpu_client: tracy_client::GpuContext,
    /// Information shared between all servers.
    pub properties: DeviceProperties,
    /// Stable hash of the device properties
    pub properties_hash: u64,
    /// Information specific to the current server.
    pub info: Server::Info,
    /// The logger based on global cubecl configs.
    pub logger: Arc<ServerLogger>,
    /// How to create the allocation.
    pub layout_policy: Server::MemoryLayoutPolicy,
    /// How to enforce bounds checking on kernels.
    pub check_mode: BoundsCheckMode,
}
93
/// Defines how the memory layout is determined.
pub trait MemoryLayoutPolicy: Send + Sync + 'static {
    /// Applies the memory layout policy to a list of descriptors.
    ///
    /// # Arguments
    ///
    /// * `stream_id` - The stream the layouts are requested for.
    /// * `descriptors` - The descriptors to compute a layout for.
    ///
    /// # Returns
    ///
    /// A [`Handle`] together with a vector of [`MemoryLayout`], one per descriptor.
    /// NOTE(review): presumably the returned handle is the single allocation shared by all
    /// the layouts — confirm with the implementors.
    fn apply(
        &self,
        stream_id: StreamId,
        descriptors: &[MemoryLayoutDescriptor],
    ) -> (Handle, Vec<MemoryLayout>);
}
106
107impl<Server: core::fmt::Debug> core::fmt::Debug for ServerUtilities<Server>
108where
109    Server: ComputeServer,
110    Server::Info: core::fmt::Debug,
111{
112    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
113        f.debug_struct("ServerUtilities")
114            .field("properties", &self.properties)
115            .field("info", &self.info)
116            .field("logger", &self.logger)
117            .finish()
118    }
119}
120
121impl<S: ComputeServer> ServerUtilities<S> {
122    /// Creates a new server utilities.
123    pub fn new(
124        properties: DeviceProperties,
125        logger: Arc<ServerLogger>,
126        info: S::Info,
127        allocator: S::MemoryLayoutPolicy,
128    ) -> Self {
129        // Start a tracy client if needed.
130        #[cfg(feature = "profile-tracy")]
131        let client = tracy_client::Client::start();
132
133        Self {
134            properties_hash: properties.checksum(),
135            properties,
136            logger,
137            // Create the GPU client if needed.
138            #[cfg(feature = "profile-tracy")]
139            gpu_client: client
140                .clone()
141                .new_gpu_context(
142                    Some(&format!("{info:?}")),
143                    // In the future should ask the server what makes sense here. 'Invalid' atm is a generic stand-in (Tracy doesn't have CUDA/RocM atm anyway).
144                    tracy_client::GpuContextType::Invalid,
145                    0,   // Timestamps are manually aligned to this epoch so start at 0.
146                    1.0, // Timestamps are manually converted to be nanoseconds so period is 1.
147                )
148                .unwrap(),
149            #[cfg(feature = "profile-tracy")]
150            epoch_time: web_time::Instant::now(),
151            info,
152            layout_policy: allocator,
153            check_mode: GlobalConfig::get().compilation.check_mode,
154        }
155    }
156}
157
/// Kernel Launch Errors.
///
/// The `Display` text of each variant is the message declared in its `#[error(...)]`
/// attribute. Backtraces are skipped when the type is (de)serialized.
#[derive(Error, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum LaunchError {
    /// The given kernel can't be compiled.
    #[error("A compilation error happened during launch\nCaused by:\n  {0}")]
    CompilationError(#[from] CompilationError),

    /// The server is out of memory.
    #[error(
        "An out-of-memory error happened during launch\nCaused by:\n  {reason}\nBacktrace\n{backtrace}"
    )]
    OutOfMemory {
        /// The cause of the memory error.
        reason: String,
        /// The backtrace for this error.
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },

    /// Too many resources were requested
    #[error("Too many resources were requested during launch\n{0}")]
    TooManyResources(#[from] ResourceLimitError),

    /// Unknown launch error.
    #[error(
        "An unknown error happened during launch\nCaused by:\n  {reason}\nBacktrace\n{backtrace}"
    )]
    Unknown {
        /// The cause of the unknown error.
        reason: String,
        /// The backtrace for this error.
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },

    /// Can't launch because of an IO Error.
    #[error("An io error happened during launch\nCaused by:\n  {0}")]
    IoError(#[from] IoError),
}
198
/// Resource limit errors.
///
/// Each variant records the requested value and the maximum it was compared against, and
/// prints both in its `Display` message.
#[derive(Error, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum ResourceLimitError {
    /// Shared memory exceeds maximum
    #[error(
        "Too much shared memory requested.\nRequested {requested} bytes, maximum {max} bytes available.\nBacktrace\n{backtrace}"
    )]
    SharedMemory {
        /// Value requested, in bytes
        requested: usize,
        /// Maximum value, in bytes
        max: usize,
        /// The backtrace for this error.
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },
    /// Total units exceeds maximum
    #[error(
        "Total unit count exceeds maximum.\nRequested {requested} units, max units is {max}.\nBacktrace\n{backtrace}"
    )]
    Units {
        /// Requested number of units
        requested: u32,
        /// Maximum number of units
        max: u32,
        /// The backtrace for this error.
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },
    /// `CubeDim` exceeds maximum
    #[error(
        "Cube dim exceeds maximum bounds.\nRequested {requested:?}, max is {max:?}.\nBacktrace\n{backtrace}"
    )]
    CubeDim {
        /// Requested value, as a 3-tuple
        requested: (u32, u32, u32),
        /// Maximum value, as a 3-tuple
        max: (u32, u32, u32),
        /// The backtrace for this error.
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },
}
243
244impl core::fmt::Debug for LaunchError {
245    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
246        f.write_fmt(format_args!("{self}"))
247    }
248}
249
250impl core::fmt::Debug for ResourceLimitError {
251    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
252        f.write_fmt(format_args!("{self}"))
253    }
254}
255
256/// Error that can happen asynchronously while executing registered kernels.
257#[derive(Error, Debug, Clone)]
258#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
259pub enum ServerError {
260    /// A generic runtime error.
261    #[error("An error happened during execution\nCaused by:\n  {reason}\nBacktrace:\n{backtrace}")]
262    Generic {
263        /// The details of the generic error.
264        reason: String,
265        /// The backtrace for this error.
266        #[cfg_attr(std_io, serde(skip))]
267        backtrace: BackTrace,
268    },
269
270    /// A launch error happened during profiling
271    #[error("A launch error happened during profiling\nCaused by:\n  {0}")]
272    Launch(#[from] LaunchError),
273
274    /// An execution error happened during profiling
275    #[error("An execution error happened during profiling\nCaused by:\n  {0}")]
276    Profile(#[from] ProfileError),
277
278    /// An execution error happened during profiling
279    #[error("An execution error happened during profiling\nCaused by:\n  {0}")]
280    Io(#[from] IoError),
281
282    /// The server is an invalid state.
283    #[error("The server is in an invalid state\nCaused by:\n  {errors:?}")]
284    ServerUnhealthy {
285        /// The details of the generic error.
286        errors: Vec<Self>,
287        /// The backtrace for this error.
288        #[cfg_attr(std_io, serde(skip))]
289        backtrace: BackTrace,
290    },
291}
292
/// How errors are handled in a stream when executing a task.
///
/// The two flags are independent of each other.
#[derive(Clone, Copy)]
pub struct StreamErrorMode {
    /// Whether the task still executes even if the stream is in error.
    pub ignore: bool,
    /// Whether the errors are flushed by the current task.
    pub flush: bool,
}
301
/// The compute server is responsible for handling resources and computations over resources.
///
/// Everything in the server is mutable, therefore it should be solely accessed through the
/// [`ComputeClient`] for thread safety.
///
/// Apart from [`logger`](ComputeServer::logger) and [`utilities`](ComputeServer::utilities),
/// every method takes the [`StreamId`] of the stream it operates on.
pub trait ComputeServer:
    Send + core::fmt::Debug + ServerCommunication + device::DeviceService + 'static
where
    Self: Sized,
{
    /// The kernel type defines the computation algorithms.
    type Kernel: KernelMetadata;
    /// Information that can be retrieved for the runtime.
    type Info: Debug + Send + Sync;
    /// Manages how allocations are performed for a server.
    type MemoryLayoutPolicy: MemoryLayoutPolicy;
    /// The [storage](ComputeStorage) type defines how data is stored and accessed.
    type Storage: ComputeStorage;

    /// Initializes [memory](ManagedMemoryHandle) on the given [stream](StreamId) with the given size.
    fn initialize_memory(&mut self, memory: ManagedMemoryHandle, size: u64, stream_id: StreamId);

    /// Reserves N [Bytes] of the provided sizes to be used as staging to load data.
    ///
    /// The default implementation does not support staging and always returns
    /// [`IoError::UnsupportedIoOperation`] converted into a [`ServerError`].
    fn staging(
        &mut self,
        _sizes: &[usize],
        _stream_id: StreamId,
    ) -> Result<Vec<Bytes>, ServerError> {
        Err(IoError::UnsupportedIoOperation {
            backtrace: BackTrace::capture(),
        }
        .into())
    }

    /// Retrieve the server logger.
    fn logger(&self) -> Arc<ServerLogger>;

    /// Retrieve the server utilities.
    fn utilities(&self) -> Arc<ServerUtilities<Self>>;

    /// Given bindings, returns the owned resources as bytes.
    ///
    /// The result is produced asynchronously through the returned future.
    fn read(
        &mut self,
        descriptors: Vec<CopyDescriptor>,
        stream_id: StreamId,
    ) -> DynFut<Result<Vec<Bytes>, ServerError>>;

    /// Writes the specified bytes into the buffers given
    ///
    /// Each entry pairs the destination descriptor with the bytes to write into it.
    fn write(&mut self, descriptors: Vec<(CopyDescriptor, Bytes)>, stream_id: StreamId);

    /// Wait for the completion of every task in the server.
    fn sync(&mut self, stream_id: StreamId) -> DynFut<Result<(), ServerError>>;

    /// Given a resource handle, returns the storage resource.
    fn get_resource(
        &mut self,
        binding: Binding,
        stream_id: StreamId,
    ) -> Result<ManagedResource<<Self::Storage as ComputeStorage>::Resource>, ServerError>;

    /// Executes the `kernel` over the given memory `handles`.
    ///
    /// Kernels have mutable access to every resource they are given
    /// and are responsible of determining which should be read or written.
    ///
    /// This method returns nothing, so failures are not reported through its return value.
    ///
    /// # Safety
    ///
    /// When executing with mode [`ExecutionMode::Unchecked`], out-of-bound reads and writes can happen.
    unsafe fn launch(
        &mut self,
        kernel: Self::Kernel,
        count: CubeCount,
        bindings: KernelArguments,
        kind: ExecutionMode,
        stream_id: StreamId,
    );

    /// Flush all outstanding tasks in the server.
    fn flush(&mut self, stream_id: StreamId) -> Result<(), ServerError>;

    /// The current memory usage of the server.
    fn memory_usage(&mut self, stream_id: StreamId) -> Result<MemoryUsage, ServerError>;

    /// Ask the server to release memory that it can release.
    fn memory_cleanup(&mut self, stream_id: StreamId);

    /// Enable collecting timestamps.
    ///
    /// Returns the [`ProfilingToken`] to pass to [`end_profile`](ComputeServer::end_profile).
    fn start_profile(&mut self, stream_id: StreamId) -> Result<ProfilingToken, ServerError>;

    /// Disable collecting timestamps.
    ///
    /// `token` identifies the profiling session, as returned by
    /// [`start_profile`](ComputeServer::start_profile).
    fn end_profile(
        &mut self,
        stream_id: StreamId,
        token: ProfilingToken,
    ) -> Result<ProfileDuration, ProfileError>;

    /// Update the memory mode of allocation in the server.
    fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId);
}
400
/// Different reduce operations.
///
/// Passed as the `op` argument of [`ServerCommunication::all_reduce`].
pub enum ReduceOperation {
    /// Sum.
    Sum,
    /// Mean.
    Mean,
}
408
/// Defines functions for optimized data transfer between servers, supporting custom communication
/// mechanisms such as peer-to-peer communication or specialized implementations.
///
/// Every provided default method body panics; servers that support an operation must
/// override it.
pub trait ServerCommunication {
    /// Indicates whether server-to-server communication is enabled for this implementation.
    const SERVER_COMM_ENABLED: bool;

    /// Ensure that all queued collective operations have been executed.
    ///
    /// # Arguments
    ///
    /// * `stream_id` - The [`StreamId`] of the stream waiting for the sync.
    ///
    /// # Returns
    ///
    /// Returns a `Result` containing an `ServerError` if the operation fails.
    ///
    /// # Panics
    ///
    /// The default implementation always panics (`todo!()`).
    #[allow(unused_variables)]
    fn sync_collective(&mut self, stream_id: StreamId) -> Result<(), ServerError> {
        todo!() // For backends other than cuda.
    }

    /// Performs an `all_reduce` operation on the input data and writes it to the output buffer.
    /// see <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allreduce>
    ///
    /// # Arguments
    ///
    /// * `src` - The data to be reduced.
    /// * `dst` - Where to write the result.
    /// * `dtype` - The element type of the data.
    /// * `stream_id` - The data's stream id.
    /// * `op` - The reduce's aggregation operation e.g. mean, sum, etc.
    /// * `device_ids` - The list of device ids from which to `all_reduce`.
    ///
    /// # Returns
    ///
    /// Returns a `Result` containing an `ServerError` if the operation fails.
    ///
    /// # Panics
    ///
    /// The default implementation always panics (`unimplemented!()`).
    #[allow(unused_variables)]
    fn all_reduce(
        &mut self,
        src: Binding,
        dst: Binding,
        dtype: ElemType,
        stream_id: StreamId,
        op: ReduceOperation,
        device_ids: Vec<DeviceId>,
    ) -> Result<(), ServerError> {
        unimplemented!()
    }

    /// Copies data from a source server to a destination server.
    ///
    /// # Arguments
    ///
    /// * `handle_dst` - NOTE(review): presumably the handle on the destination server that
    ///   receives the copied data — confirm with the implementors.
    /// * `server_src` - A mutable reference to the source server from which data is copied.
    /// * `server_dst` - A mutable reference to the destination server receiving the data.
    /// * `src` - A descriptor specifying the data to be copied, including shape, strides, and binding.
    /// * `stream_id_src` - The stream ID associated with the source server's operation.
    /// * `stream_id_dst` - The stream ID associated with the destination server's operation.
    ///
    /// # Returns
    ///
    /// Returns `Ok(())` on success, or a `ServerError` if the operation fails.
    ///
    /// # Panics
    ///
    /// The default implementation always panics: either because server communication is not
    /// enabled (`SERVER_COMM_ENABLED` is `false`), or because the trait is incorrectly
    /// implemented by the server (communication is enabled but `copy` was not overridden).
    #[allow(unused_variables)]
    fn copy(
        handle_dst: Handle,
        server_src: &mut Self,
        server_dst: &mut Self,
        src: CopyDescriptor,
        stream_id_src: StreamId,
        stream_id_dst: StreamId,
    ) -> Result<(), ServerError> {
        if !Self::SERVER_COMM_ENABLED {
            panic!("Server-to-server communication is not supported by this server.");
        } else {
            panic!(
                "[Internal Error] The `ServerCommunication` trait is incorrectly implemented by the server."
            );
        }
    }
}
492
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
/// Profiling identification so that the server can support recursive and overlapping profilings.
///
/// Returned by [`ComputeServer::start_profile`] and consumed by
/// [`ComputeServer::end_profile`].
pub struct ProfilingToken {
    /// The token value.
    pub id: u64,
}
499
/// Type of allocation, either contiguous or optimized (row-aligned when possible)
///
/// Selected per tensor through [`MemoryLayoutDescriptor::strategy`].
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum MemoryLayoutStrategy {
    /// Contiguous layout, with no padding
    Contiguous,
    /// Optimized for access speed. In practice this means row-aligned with padding for runtimes
    /// that support it.
    Optimized,
}
509
/// Descriptor for a new tensor allocation
///
/// The derived `new` constructor takes the fields in declaration order; see also the
/// `optimized` and `contiguous` shortcuts.
#[derive(new, Debug, Clone)]
pub struct MemoryLayoutDescriptor {
    /// Strategy used to create the memory layout.
    pub strategy: MemoryLayoutStrategy,
    /// Shape of the tensor
    pub shape: Shape,
    /// Size of each element in the tensor (used for conversion of shape to bytes)
    pub elem_size: usize,
}
520
521impl MemoryLayoutDescriptor {
522    /// Create an optimized allocation descriptor
523    pub fn optimized(shape: Shape, elem_size: usize) -> Self {
524        MemoryLayoutDescriptor::new(MemoryLayoutStrategy::Optimized, shape, elem_size)
525    }
526
527    /// Create a contiguous allocation descriptor
528    pub fn contiguous(shape: Shape, elem_size: usize) -> Self {
529        MemoryLayoutDescriptor::new(MemoryLayoutStrategy::Contiguous, shape, elem_size)
530    }
531}
532
/// An allocation with associated strides. Strides depend on tensor layout.
///
/// Produced by [`MemoryLayoutPolicy::apply`], one per descriptor.
#[derive(Debug, Clone)]
pub struct MemoryLayout {
    /// The handle for the memory resource
    pub memory: Handle,
    /// TODO: `Strides` should become `Layout`.
    ///
    /// The strides of the tensor
    pub strides: Strides,
}
543
544impl MemoryLayout {
545    /// Create a new memory layout.
546    pub fn new(handle: Handle, strides: impl Into<Strides>) -> Self {
547        MemoryLayout {
548            memory: handle,
549            strides: strides.into(),
550        }
551    }
552}
553
/// A reason for an error.
///
/// Can be built from a `&'static str` or from an owned `String`; the default value carries
/// no reason and displays a generic placeholder message.
#[derive(Default, Clone)]
pub struct Reason {
    inner: ReasonInner,
}

#[cfg(std_io)]
mod _reason_serde {
    use super::*;

    use alloc::string::ToString;
    use serde::{Deserialize, Deserializer, Serialize, Serializer};

    impl Serialize for Reason {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            // The reason is written as a plain string, exactly as `Display` renders it.
            let flattened = self.to_string();
            serializer.serialize_str(&flattened)
        }
    }

    impl<'de> Deserialize<'de> for Reason {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            // A runtime string cannot be turned back into a `&'static str`, so every
            // deserialized reason is stored in the dynamic variant.
            String::deserialize(deserializer).map(|text| Reason {
                inner: ReasonInner::Dynamic(Arc::new(text)),
            })
        }
    }
}

/// Storage behind [`Reason`].
#[derive(Default, Clone)]
enum ReasonInner {
    /// A message known at compile time.
    Static(&'static str),
    /// A message built at runtime, shared between clones.
    Dynamic(Arc<String>),
    /// No message was given.
    #[default]
    NotProvided,
}

impl core::fmt::Display for Reason {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Pick the text first, then write it in a single place.
        let text: &str = match &self.inner {
            ReasonInner::Static(content) => *content,
            ReasonInner::Dynamic(content) => content.as_str(),
            ReasonInner::NotProvided => "No reason provided for the error",
        };
        f.write_str(text)
    }
}

impl core::fmt::Debug for Reason {
    /// Debug output is identical to the `Display` output.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        core::fmt::Display::fmt(self, f)
    }
}

impl From<&'static str> for Reason {
    fn from(value: &'static str) -> Self {
        let inner = ReasonInner::Static(value);
        Self { inner }
    }
}

impl From<String> for Reason {
    fn from(value: String) -> Self {
        let inner = ReasonInner::Dynamic(Arc::new(value));
        Self { inner }
    }
}
633
634/// Error returned from `create`/`read`/`write` functions. Due to async execution not all errors
635/// are able to be caught, so some IO errors will still panic.
636#[derive(Error, Clone)]
637#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
638pub enum IoError {
639    /// Buffer size exceeds the max available
640    #[error("can't allocate buffer of size: {size}\n{backtrace}")]
641    BufferTooBig {
642        /// The size of the buffer in bytes.
643        size: u64,
644        /// The captured backtrace.
645        #[cfg_attr(std_io, serde(skip))]
646        backtrace: BackTrace,
647    },
648
649    /// Strides aren't supported for this copy operation on this runtime
650    #[error("the provided strides are not supported for this operation\n{backtrace}")]
651    UnsupportedStrides {
652        /// The backtrace.
653        #[cfg_attr(std_io, serde(skip))]
654        backtrace: BackTrace,
655    },
656
657    /// Memory wasn't found in the memory pool
658    #[error("couldn't find resource for that handle: {reason}\n{backtrace}")]
659    NotFound {
660        /// The backtrace.
661        #[cfg_attr(std_io, serde(skip))]
662        backtrace: BackTrace,
663        /// The reason the handle is invalid.
664        reason: Reason,
665    },
666
667    /// Handle wasn't found in the memory pool
668    #[error("couldn't free the handle, since it is currently in used. \n{backtrace}")]
669    FreeError {
670        /// The backtrace.
671        #[cfg_attr(std_io, serde(skip))]
672        backtrace: BackTrace,
673    },
674
675    /// Unknown error happened during execution
676    #[error("Unknown error happened during execution\n{backtrace}")]
677    Unknown {
678        /// Details of the error
679        description: String,
680        /// The backtrace.
681        #[cfg_attr(std_io, serde(skip))]
682        backtrace: BackTrace,
683    },
684
685    /// The current IO operation is not supported
686    #[error("The current IO operation is not supported\n{backtrace}")]
687    UnsupportedIoOperation {
688        /// The backtrace.
689        #[cfg_attr(std_io, serde(skip))]
690        backtrace: BackTrace,
691    },
692
693    /// Can't perform the IO operation because of a runtime error.
694    #[error("Can't perform the IO operation because of a runtime error: {0}")]
695    Execution(#[from] Box<ServerError>),
696}
697
698impl core::fmt::Debug for IoError {
699    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
700        f.write_fmt(format_args!("{self}"))
701    }
702}
703
/// Arguments to execute a kernel.
///
/// Starts out empty (see `Default`) and is filled with the `with_*` builder methods.
#[derive(Debug, Default)]
pub struct KernelArguments {
    /// Buffer bindings
    pub buffers: Vec<Binding>,
    /// Packed scalars and metadata. First scalars sorted by type, then static metadata,
    /// then dynamic metadata.
    pub info: MetadataBindingInfo,
    /// Tensor map bindings
    pub tensor_maps: Vec<TensorMapBinding>,
}
715
716impl core::fmt::Display for KernelArguments {
717    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
718        f.write_str("KernelArguments")?;
719        for b in self.buffers.iter() {
720            f.write_fmt(format_args!("\n - buffer: {b:?}\n"))?;
721        }
722
723        Ok(())
724    }
725}
726
727impl KernelArguments {
728    /// Create a new bindings struct
729    pub fn new() -> Self {
730        Self::default()
731    }
732
733    /// Add a buffer binding
734    pub fn with_buffer(mut self, binding: Binding) -> Self {
735        self.buffers.push(binding);
736        self
737    }
738
739    /// Extend the buffers with `bindings`
740    pub fn with_buffers(mut self, bindings: Vec<Binding>) -> Self {
741        self.buffers.extend(bindings);
742        self
743    }
744
745    /// Set the info to `info`
746    pub fn with_info(mut self, info: MetadataBindingInfo) -> Self {
747        self.info = info;
748        self
749    }
750
751    /// Extend the tensor maps with `bindings`
752    pub fn with_tensor_maps(mut self, bindings: Vec<TensorMapBinding>) -> Self {
753        self.tensor_maps.extend(bindings);
754        self
755    }
756}
757
/// Binding of a set of scalars of the same type to execute a kernel.
///
/// The [`ComputeServer`] is responsible to convert those info into actual [`Binding`] when launching
/// kernels.
///
/// The derived `new` constructor takes the fields in declaration order.
#[derive(new, Debug, Default)]
pub struct MetadataBindingInfo {
    /// Scalar and metadata values
    pub data: Vec<u64>,
    /// Start of the dynamically sized portion of the metadata, relative to the entire info buffer
    pub dynamic_metadata_offset: usize,
}
769
770impl MetadataBindingInfo {
771    /// Create a new binding info for custom data, for externally compiled kernels.
772    pub fn custom(data: Vec<u64>) -> Self {
773        Self::new(data, 0)
774    }
775}
776
/// A binding with shape and stride info for non-contiguous reading
///
/// The derived `new` constructor takes the fields in declaration order.
#[derive(new, Debug)]
pub struct CopyDescriptor {
    /// Binding for the memory resource
    pub handle: Binding,
    /// Shape of the resource
    pub shape: Shape,
    /// Strides of the resource
    pub strides: Strides,
    /// Size of each element in the resource
    pub elem_size: usize,
}
789
/// A tensor map used with TMA ops
///
/// The derived `new` constructor takes the fields in declaration order.
#[derive(new, Debug)]
pub struct TensorMapBinding {
    /// The binding for the backing tensor
    pub binding: Binding,
    /// The tensormap metadata
    pub map: TensorMapMeta,
}
798
/// `TensorMap` metadata for the opaque proxy used in TMA copies
///
/// Stored in [`TensorMapBinding::map`] next to the binding of the backing tensor.
#[derive(Debug, Clone)]
pub struct TensorMapMeta {
    /// Tensormap format (tiled or im2col)
    pub format: TensorMapFormat,
    /// Metadata of the backing tensor
    pub metadata: Metadata,
    /// Element stride, usually 1 but may be 2 for complex tensors
    /// For im2col, this is equivalent to the kernel stride
    pub elem_stride: Strides,
    /// Interleave mode
    pub interleave: TensorMapInterleave,
    /// Swizzle mode
    pub swizzle: TensorMapSwizzle,
    /// Prefetch settings
    pub prefetch: TensorMapPrefetch,
    /// OOB fill value
    pub oob_fill: OobFill,
    /// Storage type
    pub storage_ty: StorageType,
}
820
/// Specifies the number of cubes to be dispatched for a kernel.
///
/// This translates to eg. a grid for CUDA, or to `num_workgroups` for wgsl.
#[allow(clippy::large_enum_variant)]
pub enum CubeCount {
    /// Dispatch a known count of x, y, z cubes.
    Static(u32, u32, u32),
    /// Dispatch an amount based on the values in this buffer. The buffer should contain a u32 array [x, y, z].
    Dynamic(Binding),
}
831
/// Defines how to select cube count based on the number of cubes required.
pub enum CubeCountSelection {
    /// If the number of cubes is the same as required.
    Exact(CubeCount),
    /// If the number of cubes isn't the same as required.
    ///
    /// This can happen based on the hardware limit, requiring the kernel to perform OOB checks.
    ///
    /// The second field is the number of cubes actually dispatched (the product of the three
    /// dimensions of the cube count).
    Approx(CubeCount, u32),
}
841
842impl CubeCountSelection {
843    /// Creates a [`CubeCount`] while respecting the hardware limits.
844    pub fn new<R: Runtime>(client: &ComputeClient<R>, num_cubes: u32) -> Self {
845        let cube_count = cube_count_spread(&client.properties().hardware.max_cube_count, num_cubes);
846
847        let num_cubes_actual = cube_count[0] * cube_count[1] * cube_count[2];
848        let cube_count = CubeCount::Static(cube_count[0], cube_count[1], cube_count[2]);
849
850        match num_cubes_actual == num_cubes {
851            true => CubeCountSelection::Exact(cube_count),
852            false => CubeCountSelection::Approx(cube_count, num_cubes_actual),
853        }
854    }
855
856    /// If some cubes will be idle.
857    pub fn has_idle(&self) -> bool {
858        matches!(self, Self::Approx(..))
859    }
860
861    /// Converts into [`CubeCount`].
862    pub fn cube_count(self) -> CubeCount {
863        match self {
864            CubeCountSelection::Exact(cube_count) => cube_count,
865            CubeCountSelection::Approx(cube_count, _) => cube_count,
866        }
867    }
868}
869
870impl From<CubeCountSelection> for CubeCount {
871    fn from(value: CubeCountSelection) -> Self {
872        value.cube_count()
873    }
874}
875
876impl CubeCount {
877    /// Create a new static cube count with the given x = y = z = 1.
878    pub fn new_single() -> Self {
879        CubeCount::Static(1, 1, 1)
880    }
881
882    /// Create a new static cube count with the given x, and y = z = 1.
883    pub fn new_1d(x: u32) -> Self {
884        CubeCount::Static(x, 1, 1)
885    }
886
887    /// Create a new static cube count with the given x and y, and z = 1.
888    pub fn new_2d(x: u32, y: u32) -> Self {
889        CubeCount::Static(x, y, 1)
890    }
891
892    /// Create a new static cube count with the given x, y and z.
893    pub fn new_3d(x: u32, y: u32, z: u32) -> Self {
894        CubeCount::Static(x, y, z)
895    }
896
897    /// Checks whether the cube count is definitely empty, i.e. has 0 dispatches.
898    pub fn is_empty(&self) -> bool {
899        match self {
900            Self::Static(x, y, z) => *x == 0 || *y == 0 || *z == 0,
901            Self::Dynamic(_) => false,
902        }
903    }
904}
905
906impl Debug for CubeCount {
907    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
908        match self {
909            CubeCount::Static(x, y, z) => f.write_fmt(format_args!("({x}, {y}, {z})")),
910            CubeCount::Dynamic(_) => f.write_str("binding"),
911        }
912    }
913}
914
915impl Clone for CubeCount {
916    fn clone(&self) -> Self {
917        match self {
918            Self::Static(x, y, z) => Self::Static(*x, *y, *z),
919            Self::Dynamic(binding) => Self::Dynamic(binding.clone()),
920        }
921    }
922}
923
/// The number of units along each of the 3 axes, totalling to the number of working units in a
/// cube.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
#[allow(missing_docs)]
pub struct CubeDim {
    /// Unit count along the x axis.
    pub x: u32,
    /// Unit count along the y axis.
    pub y: u32,
    /// Unit count along the z axis.
    pub z: u32,
}
936
937impl CubeDim {
938    /// Creates a new [`CubeDim`] based on the maximum number of tasks that can be parellalized by units, in other words,
939    /// by the maximum number of working units.
940    ///
941    /// # Notes
942    ///
943    /// For complex problems, you probably want to have your own logic function to create the
944    /// [`CubeDim`], but for simpler problems such as elemwise-operation, this is a great default.
945    pub fn new<R: Runtime>(client: &ComputeClient<R>, working_units: usize) -> Self {
946        let properties = client.properties();
947        let plane_size = properties.hardware.plane_size_max;
948        let plane_count = Self::calculate_plane_count_per_cube(
949            working_units as u32,
950            plane_size,
951            properties.hardware.num_cpu_cores,
952        );
953
954        // Make sure it respects the max units per cube (especially on wasm)
955        let limit = properties.hardware.max_units_per_cube / plane_size;
956
957        // Ensure at least 1 plane so CubeDim is always valid (num_elems() > 0).
958        Self::new_2d(plane_size, u32::min(limit, plane_count).max(1))
959    }
960
961    fn calculate_plane_count_per_cube(
962        working_units: u32,
963        plane_dim: u32,
964        num_cpu_cores: Option<u32>,
965    ) -> u32 {
966        match num_cpu_cores {
967            Some(num_cores) => core::cmp::min(num_cores, working_units),
968            None => {
969                let plane_count_max = core::cmp::max(1, working_units / plane_dim);
970
971                // Ensures `plane_count` is a power of 2.
972                const NUM_PLANE_MAX: u32 = 8u32;
973                const NUM_PLANE_MAX_LOG2: u32 = NUM_PLANE_MAX.ilog2();
974                let plane_count_max_log2 =
975                    core::cmp::min(NUM_PLANE_MAX_LOG2, u32::ilog2(plane_count_max));
976                2u32.pow(plane_count_max_log2)
977            }
978        }
979    }
980
981    /// Create a new cube dim with x = y = z = 1.
982    pub const fn new_single() -> Self {
983        Self { x: 1, y: 1, z: 1 }
984    }
985
986    /// Create a new cube dim with the given x, and y = z = 1.
987    pub const fn new_1d(x: u32) -> Self {
988        Self { x, y: 1, z: 1 }
989    }
990
991    /// Create a new cube dim with the given x and y, and z = 1.
992    pub const fn new_2d(x: u32, y: u32) -> Self {
993        Self { x, y, z: 1 }
994    }
995
996    /// Create a new cube dim with the given x, y and z.
997    /// This is equivalent to the [new](CubeDim::new) function.
998    pub const fn new_3d(x: u32, y: u32, z: u32) -> Self {
999        Self { x, y, z }
1000    }
1001
1002    /// Total numbers of units per cube
1003    pub const fn num_elems(&self) -> u32 {
1004        self.x * self.y * self.z
1005    }
1006
1007    /// Whether this `CubeDim` can fully contain `other`
1008    pub const fn can_contain(&self, other: CubeDim) -> bool {
1009        self.x >= other.x && self.y >= other.y && self.z >= other.z
1010    }
1011}
1012
1013impl From<(u32, u32, u32)> for CubeDim {
1014    fn from(value: (u32, u32, u32)) -> Self {
1015        CubeDim::new_3d(value.0, value.1, value.2)
1016    }
1017}
1018
1019impl From<CubeDim> for (u32, u32, u32) {
1020    fn from(val: CubeDim) -> Self {
1021        (val.x, val.y, val.z)
1022    }
1023}
1024
/// Selects how a kernel execution is performed.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum ExecutionMode {
    /// Checked execution: kernels run in this mode are safe. This is the default mode.
    #[default]
    Checked,
    /// Validates out-of-bounds accesses and alerts when one occurs.
    Validate,
    /// Unchecked execution: kernels run in this mode are unsafe.
    Unchecked,
}
1037
/// Spreads `num_cubes` over three axes so that the x and y axes fit within `max`.
///
/// Starting from `[num_cubes, 1, 1]`, an axis that exceeds its limit is repeatedly halved
/// (rounding up) while the next axis is doubled, until the axis fits. The x axis is handled
/// first, then the y axis; processing stops at the first axis that already fits. Because of
/// the rounding, the product of the returned axes can be larger than `num_cubes`. The z axis
/// is never compared against its limit.
fn cube_count_spread(max: &(u32, u32, u32), num_cubes: u32) -> [u32; 3] {
    const BASE: u32 = 2;

    let limits = [max.0, max.1, max.2];
    let mut spread = [num_cubes, 1, 1];

    for axis in 0..2 {
        // An axis that already fits means every later axis is untouched: nothing left to do.
        if spread[axis] <= limits[axis] {
            break;
        }

        // Move the overflow of this axis onto the next one, a factor of `BASE` at a time.
        while spread[axis] > limits[axis] {
            spread[axis] = spread[axis].div_ceil(BASE);
            spread[axis + 1] *= BASE;
        }
    }

    spread
}
1066
#[cfg(test)]
mod tests {
    use super::*;

    /// A power-of-two request spills evenly from x into y and then into z.
    #[test_log::test]
    fn safe_num_cubes_even() {
        let spread = cube_count_spread(&(32, 32, 32), 2048);
        assert_eq!(spread, [32, 32, 2]);
    }

    /// An odd request is rounded up while halving, so the spread over-covers the request.
    #[test_log::test]
    fn safe_num_cubes_odd() {
        let spread = cube_count_spread(&(48, 32, 16), 3177);
        assert_eq!(spread, [25, 32, 4]);
    }
}