cubecl_std/swizzle.rs
1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3
4/// Swizzling strategy for a buffer.
5/// See the following docs from cutlass:
6///
7/// 0bxxxxxxxxxxxxxxxYYYxxxxxxxZZZxxxx
8/// ^--^ MBase is the number of least-sig bits to keep constant
9/// ^-^ ^-^ BBits is the number of bits in the mask
10/// ^---------^ SShift is the distance to shift the YYY mask
11/// (pos shifts YYY to the right, neg shifts YYY to the left)
12///
13/// # Example
14/// Given:
15/// 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxZZxxx
16/// the result is:
17/// 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx where AA = ZZ xor YY
18///
19///
20/// Some newer features, as well as cutlass in places, use a different terminology of `span` and
21/// `atom`. For shared memory swizzle specifically, the parameters map as follows:
22/// * `bits` = `log2(span / atom)`, or the number of atoms within one span, converted to address bits
23/// * `base` = `log2(atom)`, the size of the atom, converted to address bits
24/// * `shift` = `log2(all_banks_bytes / atom)`, or the total number of atoms in all 32 shared memory banks, converted to address bits
25///
26/// For example:
27/// * 32-byte span with a 16-byte atom = `[1, 4, 3]`
28/// * 128-byte span with a 32-byte atom = `[3, 5, 2]`
29///
30#[derive(CubeType, CubeLaunch, Clone, Copy)]
31pub struct Swizzle {
32 #[cube(comptime)]
33 yyy_mask: u32,
34 #[cube(comptime)]
35 shift: u32,
36 #[cube(comptime)]
37 invert_shift: bool,
38 /// Precalculate repeat after so we don't need to keep all the parts
39 #[cube(comptime)]
40 repeats_after: u32,
41}
42
43#[cube]
44impl Swizzle {
45 /// Create a new swizzle with comptime parameters
46 pub fn new(#[comptime] bits: u32, #[comptime] base: u32, #[comptime] shift: i32) -> Self {
47 let invert_shift = shift < 0;
48 let mask = (1u32 << bits) - 1;
49 let yyy_mask = comptime![mask << (base + Ord::max(shift, 0) as u32)];
50 let repeats_after = comptime![if bits > 0 {
51 1u32 << (base + bits + Ord::max(shift, 0) as u32)
52 } else {
53 1u32 << base
54 }];
55 Swizzle {
56 yyy_mask,
57 shift: comptime![shift.unsigned_abs()],
58 invert_shift,
59 repeats_after,
60 }
61 }
62
63 /// Create a new noop swizzle object
64 pub fn none() -> Self {
65 Swizzle {
66 yyy_mask: 0u32,
67 shift: 0u32,
68 invert_shift: false,
69 repeats_after: 1u32,
70 }
71 }
72
73 /// Apply the swizzle to a coordinate with a given item size. This is the size of the full type,
74 /// including line size. Use `type_size` helper for lines.
75 /// `offset` should be in terms of lines from the start of the buffer, and the buffer should be
76 /// aligned to `repeats_after`. This is to work around the fact we don't currently support
77 /// retrieving the actual address of an offset.
78 /// If you're using absolute/unlined indices, pass `E::type_size()` instead of the full line size.
79 pub fn apply(&self, offset: u32, #[comptime] type_size: u32) -> u32 {
80 // Special case here so we don't need to special case in kernels that can have no swizzle.
81 // If `yyy_mask == 0`, the whole thing is a noop.
82 if comptime![self.yyy_mask == 0] {
83 offset
84 } else {
85 let offset_bytes = offset * type_size;
86 let offset_masked = offset_bytes & self.yyy_mask;
87 let offset_shifted =
88 shift_right(offset_masked, self.shift, comptime![self.invert_shift]);
89 let offset_bytes = offset_bytes ^ offset_shifted;
90 offset_bytes / type_size
91 }
92 }
93
94 /// After how many elements this pattern repeats. Can be used to align the buffer (i.e. smem)
95 /// so offsets match addresses.
96 pub fn repeats_after(&self) -> comptime_type!(u32) {
97 self.repeats_after
98 }
99}
100
101/// Retrieve the type size of a lined buffer.
102#[cube]
103pub fn type_size<E: CubePrimitive>(#[comptime] line_size: u32) -> comptime_type!(u32) {
104 let storage_size = E::type_size();
105 comptime![storage_size * line_size]
106}
107
108#[cube]
109fn shift_right(value: u32, shift: u32, #[comptime] invert: bool) -> u32 {
110 if invert {
111 value << shift
112 } else {
113 value >> shift
114 }
115}