cubecl_std/swizzle.rs
1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl};
3
4/// Swizzling strategy for a buffer.
5/// See the following docs from cutlass:
6///
7/// 0bxxxxxxxxxxxxxxxYYYxxxxxxxZZZxxxx
8/// ^--^ `MBase` is the number of least-sig bits to keep constant
9/// ^-^ ^-^ `BBits` is the number of bits in the mask
10/// ^---------^ `SShift` is the distance to shift the YYY mask
11/// (pos shifts YYY to the right, neg shifts YYY to the left)
12///
13/// # Example
14/// Given:
15/// 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxZZxxx
16/// the result is:
17/// 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx where AA = ZZ xor YY
18///
19///
20/// Some newer features, as well as cutlass in places, use a different terminology of `span` and
21/// `atom`. For shared memory swizzle specifically, the parameters map as follows:
22/// * `bits` = `log2(span / atom)`, or the number of atoms within one span, converted to address bits
23/// * `base` = `log2(atom)`, the size of the atom, converted to address bits
24/// * `shift` = `log2(all_banks_bytes / atom)`, or the total number of atoms in all 32 shared memory banks, converted to address bits
25///
26/// For example:
27/// * 32-byte span with a 16-byte atom = `[1, 4, 3]`
28/// * 128-byte span with a 32-byte atom = `[3, 5, 2]`
29///
30#[derive(CubeType, CubeLaunch, Clone, Copy)]
31pub struct Swizzle {
32 #[cube(comptime)]
33 yyy_mask: u32,
34 #[cube(comptime)]
35 shift: u32,
36 #[cube(comptime)]
37 invert_shift: bool,
38 /// Precalculate repeat after so we don't need to keep all the parts
39 #[cube(comptime)]
40 repeats_after: u32,
41}
42
43#[cube]
44impl Swizzle {
45 /// Create a new swizzle with comptime parameters
46 pub fn new(#[comptime] bits: u32, #[comptime] base: u32, #[comptime] shift: i32) -> Self {
47 let invert_shift = shift < 0;
48 let mask = (1u32 << bits) - 1;
49 let yyy_mask = comptime![mask << (base + Ord::max(shift, 0) as u32)];
50 let repeats_after = comptime![if bits > 0 {
51 1u32 << (base + bits + Ord::max(shift, 0) as u32)
52 } else {
53 1u32 << base
54 }];
55 Swizzle {
56 yyy_mask,
57 shift: comptime![shift.unsigned_abs()],
58 invert_shift,
59 repeats_after,
60 }
61 }
62
63 /// Create a new noop swizzle object
64 pub fn none() -> Self {
65 Swizzle {
66 yyy_mask: 0u32,
67 shift: 0u32,
68 invert_shift: false,
69 repeats_after: 1u32,
70 }
71 }
72
73 /// Apply the swizzle to a coordinate with a given item size. This is the size of the full type,
74 /// including vectorization.
75 /// `offset` should be in terms of vectors from the start of the buffer, and the buffer should be
76 /// aligned to `repeats_after`. This is to work around the fact we don't currently support
77 /// retrieving the actual address of an offset.
78 /// If you're using absolute/unvectorized indices, pass `E::Scalar::type_size()` instead of the full
79 /// vector size.
80 pub fn apply(&self, offset: u32, #[comptime] type_size: usize) -> u32 {
81 // Special case here so we don't need to special case in kernels that can have no swizzle.
82 // If `yyy_mask == 0`, the whole thing is a noop.
83 if comptime![self.yyy_mask == 0] {
84 offset
85 } else {
86 let offset_bytes = offset * type_size as u32;
87 let offset_masked = offset_bytes & self.yyy_mask;
88 let offset_shifted =
89 shift_right(offset_masked, self.shift, comptime![self.invert_shift]);
90 let offset_bytes = offset_bytes ^ offset_shifted;
91 offset_bytes / type_size as u32
92 }
93 }
94
95 /// After how many elements this pattern repeats. Can be used to align the buffer (i.e. smem)
96 /// so offsets match addresses.
97 pub fn repeats_after(&self) -> comptime_type!(u32) {
98 self.repeats_after
99 }
100}
101
102#[cube]
103fn shift_right(value: u32, shift: u32, #[comptime] invert: bool) -> u32 {
104 if invert {
105 value << shift
106 } else {
107 value >> shift
108 }
109}