cubecl_common/
kernel.rs

1use serde::{Deserialize, Serialize};
2
/// Approximate plane dimension.
///
/// This is a fixed compile-time estimate, not a value queried from the device.
/// [`CubeDim::default`] uses it for both its `x` and `y` extents.
pub const PLANE_DIM_APPROX: usize = 16;
5
6#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
7#[derive(new, Debug, PartialEq, Eq, Clone, Copy, Hash)]
8#[allow(missing_docs)]
9pub struct CubeDim {
10    pub x: u32,
11    pub y: u32,
12    pub z: u32,
13}
14
15impl CubeDim {
16    /// Create a new cube dim with x = y = z = 1.
17    pub const fn new_single() -> Self {
18        Self { x: 1, y: 1, z: 1 }
19    }
20
21    /// Create a new cube dim with the given x, and y = z = 1.
22    pub const fn new_1d(x: u32) -> Self {
23        Self { x, y: 1, z: 1 }
24    }
25
26    /// Create a new cube dim with the given x and y, and z = 1.
27    pub const fn new_2d(x: u32, y: u32) -> Self {
28        Self { x, y, z: 1 }
29    }
30
31    /// Create a new cube dim with the given x, y and z.
32    /// This is equivalent to the [new](CubeDim::new) function.
33    pub const fn new_3d(x: u32, y: u32, z: u32) -> Self {
34        Self { x, y, z }
35    }
36
37    /// Total numbers of units per cube
38    pub const fn num_elems(&self) -> u32 {
39        self.x * self.y * self.z
40    }
41}
42
43impl Default for CubeDim {
44    fn default() -> Self {
45        Self {
46            x: PLANE_DIM_APPROX as u32,
47            y: PLANE_DIM_APPROX as u32,
48            z: 1,
49        }
50    }
51}
52
53/// The kind of execution to be performed.
54#[derive(Default, Hash, PartialEq, Eq, Clone, Debug, Copy, Serialize, Deserialize)]
55pub enum ExecutionMode {
56    /// Checked kernels are safe.
57    #[default]
58    Checked,
59    /// Unchecked kernels are unsafe.
60    Unchecked,
61}