// cubecl_runtime/tma.rs

1use alloc::vec::Vec;
2#[cfg(any(target_os = "windows", target_os = "linux", target_os = "macos"))]
3use serde::{Deserialize, Serialize};
4
5/// Format of [`TensorMap`]
6#[cfg_attr(
7    any(target_os = "windows", target_os = "linux", target_os = "macos"),
8    derive(Serialize, Deserialize)
9)]
10#[derive(Hash, PartialEq, Eq, Clone, Debug)]
11pub enum TensorMapFormat {
12    /// Simple tiling
13    Tiled {
14        /// Tile size that's loaded from memory in each copy operation. Must have `rank` elements.
15        /// In matmul, for example, this might be `batch x m x k`, or whatever the stage size is.
16        /// If a dimension isn't present in the tile, it should just be set to `1`.
17        ///
18        /// For CUDA, this must be a power of two and `<= 256` on each dimension.
19        tile_size: Vec<u32>,
20    },
21    /// Im2col indexing. Loads a "column" (not the same column as im2col) of pixels into shared
22    /// memory, with a certain offset (kernel position). The corners are the bounds to load pixels
23    /// from *at offset 0*, so the top left corner of the kernel. The offset is added to the
24    /// corner offsets, so a `(-1, -1)` corner will stop the bounding box at `(1, 1)` for kernel
25    /// offset `(2, 2)`.
26    Im2col {
27        /// Pixel box lower corner. This is the logical upper left corner in the input tensor,
28        /// when offset is 0. The length of this value should equal the *spatial* dimensions of
29        /// the input tensor (i.e. `h, w` for an NHWC tensor). Should normally be set to `-padding`.
30        pixel_box_lower_corner: Vec<i32>,
31        /// Pixel box top corner. This is the logical lower right corner in the input tensor,
32        /// when offset is 0. The length of this value should equal the *spatial* dimensions of
33        /// the input tensor (i.e. `h, w` for an NHWC tensor). Should normally be set to
34        /// `padding - kernel_size - 1` (where `kernel_size` accounts for dilation). This is not
35        /// equal to padding, it's equal to the bounding box for the *top left corner of the kernel*.
36        pixel_box_upper_corner: Vec<i32>,
37        /// Channels to load per pixel, should be a multiple or divisor of the matmul tile size.
38        /// This is not the total number of channels in the tensor, but only the number loaded in
39        /// each load. Must be <= 256 and aligned to 16 bytes.
40        channels_per_pixel: u32,
41        /// Pixels per column, equivalent to the `m`/`n` dimension of each tile in the matrix
42        /// multiplication. i.e. `NHW` for a 4D tensor.
43        /// Must be <= 256 and aligned to 16 bytes
44        pixels_per_column: u32,
45    },
46    /// Wide im2col
47    Im2colWide {
48        /// Pixel box lower corner width. TODO: How does this work?
49        pixel_box_lower_corner_width: i32,
50        /// Pixel box upper corner width. TODO: How does this work?
51        pixel_box_upper_corner_width: i32,
52        /// Channels per pixel
53        channels_per_pixel: u32,
54        /// Pixels per column
55        pixels_per_column: u32,
56    },
57}
58
59/// Interleave setting for [`TensorMap`]
60#[cfg_attr(
61    any(target_os = "windows", target_os = "linux", target_os = "macos"),
62    derive(Serialize, Deserialize)
63)]
64#[derive(Default, Hash, PartialEq, Eq, Clone, Debug, Copy)]
65pub enum TensorMapInterleave {
66    /// No interleaving
67    #[default]
68    None,
69    /// Interleaved with 16 bytes chunks in the last dim.
70    /// i.e. NC/8HWC8 with f16
71    B16,
72    /// Interleaved with 32 bytes chunks in the last dim.
73    /// i.e. NC/16HWC16 with f16
74    B32,
75}
76
77/// Data are organized in a specific order in global memory; however, this may not match the order
78/// in which the application accesses data in shared memory. This difference in data organization
79/// may cause bank conflicts when shared memory is accessed. In order to avoid this problem, data
80/// can be loaded to shared memory with shuffling across shared memory banks. When interleave is
81/// [`TensorMapInterleave::B32`], swizzle must be [`TensorMapSwizzle::B32`].
82/// Other interleave modes can have any swizzling pattern.
83#[cfg_attr(
84    any(target_os = "windows", target_os = "linux", target_os = "macos"),
85    derive(Serialize, Deserialize)
86)]
87#[derive(Default, Hash, PartialEq, Eq, Clone, Debug, Copy)]
88pub enum TensorMapSwizzle {
89    /// No swizzling
90    #[default]
91    None,
92    /// Swizzle 16B chunks within 32B span
93    B32,
94    /// Swizzle 16B chunks within 64B span
95    B64,
96    /// Swizzle 16B chunks within 128B span
97    B128,
98}
99
100/// Additional prefetching to perform during load
101/// Specifies L2 fetch size which indicates the byte granularity at which L2 requests are filled from DRAM
102#[cfg_attr(
103    any(target_os = "windows", target_os = "linux", target_os = "macos"),
104    derive(Serialize, Deserialize)
105)]
106#[derive(Default, Hash, PartialEq, Eq, Clone, Debug, Copy)]
107pub enum TensorMapPrefetch {
108    /// No extra prefetch
109    #[default]
110    None,
111    /// Prefetch 64 bytes
112    B64,
113    /// Prefetch 128 bytes
114    B128,
115    /// Prefetch 256 bytes
116    B256,
117}
118
119/// What value to use when filling out of bounds values
120#[cfg_attr(
121    any(target_os = "windows", target_os = "linux", target_os = "macos"),
122    derive(Serialize, Deserialize)
123)]
124#[derive(Default, Hash, PartialEq, Eq, Clone, Debug, Copy)]
125pub enum OobFill {
126    /// Fill zeroes
127    #[default]
128    Zero,
129    /// Fill NaN
130    NaN,
131}