//! Tensor map (TMA) configuration types: tiled, im2col, and im2col-wide map
//! arguments plus interleave, swizzle, prefetch, and out-of-bounds fill settings.
1use alloc::vec::Vec;
2use cubecl_zspace::Shape;
3#[cfg(any(target_os = "windows", target_os = "linux", target_os = "macos"))]
4use serde::{Deserialize, Serialize};
5
/// Args for tiled tensor maps.
// NOTE(review): serde support is only derived on desktop targets — presumably a
// stand-in for `std`/`alloc` availability; confirm the intended gate.
#[cfg_attr(
    any(target_os = "windows", target_os = "linux", target_os = "macos"),
    derive(Serialize, Deserialize)
)]
// `new` is presumably the `derive_new` derive (generates a constructor taking the
// fields in declaration order) — so field order here is part of the public API.
#[derive(Hash, PartialEq, Eq, Clone, Debug, new)]
pub struct TiledArgs {
    /// Tile size that's loaded from memory in each copy operation. Must have `rank` elements.
    /// In matmul, for example, this might be `batch x m x k`, or whatever the stage size is.
    /// If a dimension isn't present in the tile, it should just be set to `1`.
    ///
    /// For CUDA, this must be a power of two and `<= 256` on each dimension.
    pub tile_size: Shape,
}
20
/// Args for im2col tensor maps.
#[derive(Hash, PartialEq, Eq, Clone, Debug, new)]
pub struct Im2colArgs {
    /// Pixel box lower corner. This is the logical upper left corner in the input tensor,
    /// when offset is 0. The length of this value should equal the *spatial* dimensions of
    /// the input tensor (i.e. `h, w` for an NHWC tensor). Should normally be set to `-padding`.
    pub pixel_box_lower_corner: Vec<i32>,
    /// Pixel box upper corner. This is the logical lower right corner in the input tensor,
    /// when offset is 0. The length of this value should equal the *spatial* dimensions of
    /// the input tensor (i.e. `h, w` for an NHWC tensor). Should normally be set to
    /// `padding - kernel_size - 1` (where `kernel_size` accounts for dilation). This is not
    /// equal to padding, it's equal to the bounding box for the *top left corner of the kernel*.
    pub pixel_box_upper_corner: Vec<i32>,
    /// Channels to load per pixel, should be a multiple or divisor of the matmul tile size.
    /// This is not the total number of channels in the tensor, but only the number loaded in
    /// each load. Must be <= 256 and aligned to 16 bytes.
    pub channels_per_pixel: u32,
    /// Pixels per column, equivalent to the `m`/`n` dimension of each tile in the matrix
    /// multiplication. i.e. `NHW` for a 4D tensor.
    /// Must be <= 256 and aligned to 16 bytes.
    pub pixels_per_column: u32,
}
43
/// Args for im2col wide tensor maps.
#[derive(Hash, PartialEq, Eq, Clone, Debug, new)]
pub struct Im2colWideArgs {
    /// Pixel box lower corner width. TODO: How does this work?
    pub pixel_box_lower_corner_width: i32,
    /// Pixel box upper corner width. TODO: How does this work?
    pub pixel_box_upper_corner_width: i32,
    /// Channels per pixel.
    // NOTE(review): presumably analogous to `Im2colArgs::channels_per_pixel`
    // (channels loaded per copy, not total channels) — confirm.
    pub channels_per_pixel: u32,
    /// Pixels per column.
    // NOTE(review): presumably analogous to `Im2colArgs::pixels_per_column` — confirm.
    pub pixels_per_column: u32,
}
56
/// Format of ``TensorMap``.
///
/// Selects the indexing scheme used by the map and carries the scheme-specific
/// arguments.
#[derive(Hash, PartialEq, Eq, Clone, Debug)]
pub enum TensorMapFormat {
    /// Simple tiling; see [`TiledArgs`].
    Tiled(TiledArgs),
    /// Im2col indexing; see [`Im2colArgs`].
    Im2col(Im2colArgs),
    /// Wide im2col; see [`Im2colWideArgs`].
    Im2colWide(Im2colWideArgs),
}
67
/// Interleave setting for ``TensorMap``.
///
/// Controls whether fixed-size byte chunks are interleaved in the last dimension
/// of the mapped tensor. Defaults to [`TensorMapInterleave::None`].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum TensorMapInterleave {
    /// No interleaving.
    #[default]
    None,
    /// Interleaved with 16 byte chunks in the last dim,
    /// i.e. NC/8HWC8 with f16.
    B16,
    /// Interleaved with 32 byte chunks in the last dim,
    /// i.e. NC/16HWC16 with f16.
    B32,
}
81
/// Shared-memory bank shuffling applied when loading through a ``TensorMap``.
///
/// Data are organized in a specific order in global memory, which may not match the
/// order in which the application accesses data in shared memory. That mismatch can
/// cause bank conflicts when shared memory is accessed; to avoid it, data can be
/// loaded into shared memory with shuffling across shared memory banks. When
/// interleave is [`TensorMapInterleave::B32`], swizzle must be
/// [`TensorMapSwizzle::B32`]; other interleave modes can have any swizzling pattern.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum TensorMapSwizzle {
    /// No swizzling.
    #[default]
    None,
    /// Swizzle 16B chunks within a 32B span.
    B32,
    /// Swizzle 16B chunks within a 64B span.
    B64,
    /// Swizzle 16B chunks within a 128B span.
    B128,
    /// Swizzle 32B chunks within a 128B span.
    B128Atom32B,
    /// Swizzle 32B chunks within a 128B span; additionally, for every alternate row,
    /// swap the lower 8B with the upper 8B within each 16B.
    B128Atom32BFlip8B,
    /// Swizzle 64B chunks within a 128B span.
    B128Atom64B,
}
106
/// Additional prefetching to perform during load.
///
/// Specifies the L2 fetch size, i.e. the byte granularity at which L2 requests are
/// filled from DRAM. Defaults to [`TensorMapPrefetch::None`].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum TensorMapPrefetch {
    /// No extra prefetch.
    #[default]
    None,
    /// Prefetch 64 bytes.
    B64,
    /// Prefetch 128 bytes.
    B128,
    /// Prefetch 256 bytes.
    B256,
}
121
/// Value used when filling out-of-bounds elements.
///
/// Defaults to [`OobFill::Zero`].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum OobFill {
    /// Fill with zeroes.
    #[default]
    Zero,
    /// Fill with NaN.
    NaN,
}