cubecl_runtime/tma.rs
use alloc::vec::Vec;
#[cfg(any(target_os = "windows", target_os = "linux", target_os = "macos"))]
use serde::{Deserialize, Serialize};

/// Args for tiled tensor maps
#[cfg_attr(
    any(target_os = "windows", target_os = "linux", target_os = "macos"),
    derive(Serialize, Deserialize)
)]
#[derive(Hash, PartialEq, Eq, Clone, Debug, new)]
pub struct TiledArgs {
    /// Tile size that's loaded from memory in each copy operation. Must have `rank` elements.
    /// In matmul, for example, this might be `batch x m x k`, or whatever the stage size is.
    /// If a dimension isn't present in the tile, it should just be set to `1`.
    ///
    /// For CUDA, this must be a power of two and `<= 256` on each dimension.
    pub tile_size: Vec<u32>,
}
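
// Usage sketch (illustrative only): building `TiledArgs` via the `new` constructor
// generated by `#[derive(new)]` above. The shape is an arbitrary example; each
// dimension is a power of two and `<= 256`, per the CUDA constraint above.
//
//     let tiled = TiledArgs::new(alloc::vec![1, 64, 32]); // batch x m x k stage tile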

/// Args for im2col tensor maps
#[derive(Hash, PartialEq, Eq, Clone, Debug, new)]
pub struct Im2colArgs {
    /// Pixel box lower corner. This is the logical upper left corner in the input tensor
    /// when the offset is 0. The length of this value should equal the *spatial* dimensions of
    /// the input tensor (i.e. `h, w` for an NHWC tensor). Should normally be set to `-padding`.
    pub pixel_box_lower_corner: Vec<i32>,
    /// Pixel box upper corner. This is the logical lower right corner in the input tensor
    /// when the offset is 0. The length of this value should equal the *spatial* dimensions of
    /// the input tensor (i.e. `h, w` for an NHWC tensor). Should normally be set to
    /// `padding - kernel_size - 1` (where `kernel_size` accounts for dilation). This is not
    /// equal to the padding; it's the bounding box for the *top left corner of the kernel*.
    pub pixel_box_upper_corner: Vec<i32>,
    /// Channels to load per pixel; should be a multiple or divisor of the matmul tile size.
    /// This is not the total number of channels in the tensor, but only the number loaded in
    /// each load. Must be `<= 256` and aligned to 16 bytes.
    pub channels_per_pixel: u32,
    /// Pixels per column, equivalent to the `m`/`n` dimension of each tile in the matrix
    /// multiplication, i.e. `NHW` for a 4D tensor.
    /// Must be `<= 256` and aligned to 16 bytes.
    pub pixels_per_column: u32,
}
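
// Usage sketch (illustrative only): `Im2colArgs` for an NHWC input with two spatial
// dims (`h, w`). The corner values are placeholders; derive them from your padding
// and (dilated) kernel size as described in the field docs above.
//
//     let im2col = Im2colArgs::new(
//         alloc::vec![-1, -1], // pixel_box_lower_corner, normally `-padding` per spatial dim
//         alloc::vec![-1, -1], // pixel_box_upper_corner, see the formula in the field docs
//         32,                  // channels_per_pixel: <= 256, 16-byte aligned
//         64,                  // pixels_per_column: <= 256, 16-byte aligned
//     );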

/// Args for im2col wide tensor maps
#[derive(Hash, PartialEq, Eq, Clone, Debug, new)]
pub struct Im2colWideArgs {
    /// Pixel box lower corner width. TODO: How does this work?
    pub pixel_box_lower_corner_width: i32,
    /// Pixel box upper corner width. TODO: How does this work?
    pub pixel_box_upper_corner_width: i32,
    /// Channels per pixel
    pub channels_per_pixel: u32,
    /// Pixels per column
    pub pixels_per_column: u32,
}

/// Format of [`TensorMap`]
#[derive(Hash, PartialEq, Eq, Clone, Debug)]
pub enum TensorMapFormat {
    /// Simple tiling
    Tiled(TiledArgs),
    /// Im2col indexing
    Im2col(Im2colArgs),
    /// Wide im2col
    Im2colWide(Im2colWideArgs),
}
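
// Usage sketch (illustrative only): wrapping per-format args into the enum that
// callers pass around. The tile shape here is an arbitrary example.
//
//     let format = TensorMapFormat::Tiled(TiledArgs::new(alloc::vec![1, 64, 32]));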

/// Interleave setting for [`TensorMap`]
#[derive(Default, Hash, PartialEq, Eq, Clone, Debug, Copy)]
pub enum TensorMapInterleave {
    /// No interleaving
    #[default]
    None,
    /// Interleaved with 16-byte chunks in the last dim,
    /// i.e. NC/8HWC8 with f16
    B16,
    /// Interleaved with 32-byte chunks in the last dim,
    /// i.e. NC/16HWC16 with f16
    B32,
}
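
// How the chunk sizes relate to the layouts named above (illustrative only): a
// 16-byte chunk holds 16 / 2 = 8 `f16` values (hence NC/8HWC8), and a 32-byte chunk
// holds 16 (NC/16HWC16). A hypothetical helper, not part of this API:
//
//     fn interleave_elems(interleave: TensorMapInterleave, elem_bytes: u32) -> Option<u32> {
//         match interleave {
//             TensorMapInterleave::None => None,
//             TensorMapInterleave::B16 => Some(16 / elem_bytes),
//             TensorMapInterleave::B32 => Some(32 / elem_bytes),
//         }
//     }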

/// Data are organized in a specific order in global memory; however, this may not match the order
/// in which the application accesses data in shared memory. This difference in data organization
/// may cause bank conflicts when shared memory is accessed. In order to avoid this problem, data
/// can be loaded to shared memory with shuffling across shared memory banks. When interleave is
/// [`TensorMapInterleave::B32`], swizzle must be [`TensorMapSwizzle::B32`].
/// Other interleave modes can have any swizzling pattern.
#[derive(Default, Hash, PartialEq, Eq, Clone, Debug, Copy)]
pub enum TensorMapSwizzle {
    /// No swizzling
    #[default]
    None,
    /// Swizzle 16B chunks within 32B span
    B32,
    /// Swizzle 16B chunks within 64B span
    B64,
    /// Swizzle 16B chunks within 128B span
    B128,
    /// Swizzle 32B chunks within 128B span
    B128Atom32B,
    /// Swizzle 32B chunks within 128B span, additionally swap lower 8B with upper 8B within each
    /// 16B for every alternate row
    B128Atom32BFlip8B,
    /// Swizzle 64B chunks within 128B span
    B128Atom64B,
}
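
// The doc comment above states one hard pairing rule: `TensorMapInterleave::B32`
// requires `TensorMapSwizzle::B32`, while other interleave modes accept any swizzle.
// A hypothetical validity check, not part of this API:
//
//     fn swizzle_allowed(interleave: TensorMapInterleave, swizzle: TensorMapSwizzle) -> bool {
//         match interleave {
//             TensorMapInterleave::B32 => swizzle == TensorMapSwizzle::B32,
//             _ => true,
//         }
//     }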

/// Additional prefetching to perform during load.
/// Specifies the L2 fetch size, which indicates the byte granularity at which L2 requests are
/// filled from DRAM.
#[derive(Default, Hash, PartialEq, Eq, Clone, Debug, Copy)]
pub enum TensorMapPrefetch {
    /// No extra prefetch
    #[default]
    None,
    /// Prefetch 64 bytes
    B64,
    /// Prefetch 128 bytes
    B128,
    /// Prefetch 256 bytes
    B256,
}

/// What value to use when filling out-of-bounds values
#[derive(Default, Hash, PartialEq, Eq, Clone, Debug, Copy)]
pub enum OobFill {
    /// Fill zeroes
    #[default]
    Zero,
    /// Fill NaN
    NaN,
}
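
// Putting the pieces together (illustrative only): a consistent set of tensor map
// settings a caller might assemble. How they are consumed by the runtime is outside
// this file; only types defined above are used, and the values are arbitrary.
//
//     let format = TensorMapFormat::Tiled(TiledArgs::new(alloc::vec![1, 64, 32]));
//     let interleave = TensorMapInterleave::None;
//     let swizzle = TensorMapSwizzle::B128;   // any swizzle is allowed when interleave != B32
//     let prefetch = TensorMapPrefetch::B128; // fill L2 requests at 128-byte granularity
//     let oob_fill = OobFill::Zero;           // out-of-bounds reads are filled with zeroes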