cubecl_std/
swizzle.rs

1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3
4/// Swizzling strategy for a buffer.
5/// See the following docs from cutlass:
6///
7/// 0bxxxxxxxxxxxxxxxYYYxxxxxxxZZZxxxx
8///                                ^--^ MBase is the number of least-sig bits to keep constant
9///                   ^-^       ^-^     BBits is the number of bits in the mask
10///                     ^---------^     SShift is the distance to shift the YYY mask
11///                                        (pos shifts YYY to the right, neg shifts YYY to the left)
12///
13/// # Example
14/// Given:
15/// 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxZZxxx
16/// the result is:
17/// 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx where AA = ZZ xor YY
18///
19///
20/// Some newer features, as well as cutlass in places, use a different terminology of `span` and
21/// `atom`. For shared memory swizzle specifically, the parameters map as follows:
22/// * `bits` = `log2(span / atom)`, or the number of atoms within one span, converted to address bits
23/// * `base` = `log2(atom)`, the size of the atom, converted to address bits
24/// * `shift` = `log2(all_banks_bytes / atom)`, or the total number of atoms in all 32 shared memory banks, converted to address bits
25///
26/// For example:
27/// * 32-byte span with a 16-byte atom = `[1, 4, 3]`
28/// * 128-byte span with a 32-byte atom = `[3, 5, 2]`
29///
30#[derive(CubeType, CubeLaunch, Clone, Copy)]
31pub struct Swizzle {
32    #[cube(comptime)]
33    yyy_mask: u32,
34    #[cube(comptime)]
35    shift: u32,
36    #[cube(comptime)]
37    invert_shift: bool,
38    /// Precalculate repeat after so we don't need to keep all the parts
39    #[cube(comptime)]
40    repeats_after: u32,
41}
42
43#[cube]
44impl Swizzle {
45    /// Create a new swizzle with comptime parameters
46    pub fn new(#[comptime] bits: u32, #[comptime] base: u32, #[comptime] shift: i32) -> Self {
47        let invert_shift = shift < 0;
48        let mask = (1u32 << bits) - 1;
49        let yyy_mask = comptime![mask << (base + Ord::max(shift, 0) as u32)];
50        let repeats_after = comptime![if bits > 0 {
51            1u32 << (base + bits + Ord::max(shift, 0) as u32)
52        } else {
53            1u32 << base
54        }];
55        Swizzle {
56            yyy_mask,
57            shift: comptime![shift.unsigned_abs()],
58            invert_shift,
59            repeats_after,
60        }
61    }
62
63    /// Create a new noop swizzle object
64    pub fn none() -> Self {
65        Swizzle {
66            yyy_mask: 0u32,
67            shift: 0u32,
68            invert_shift: false,
69            repeats_after: 1u32,
70        }
71    }
72
73    /// Apply the swizzle to a coordinate with a given item size. This is the size of the full type,
74    /// including line size. Use `type_size` helper for lines.
75    /// `offset` should be in terms of lines from the start of the buffer, and the buffer should be
76    /// aligned to `repeats_after`. This is to work around the fact we don't currently support
77    /// retrieving the actual address of an offset.
78    /// If you're using absolute/unlined indices, pass `E::type_size()` instead of the full line size.
79    pub fn apply(&self, offset: u32, #[comptime] type_size: u32) -> u32 {
80        // Special case here so we don't need to special case in kernels that can have no swizzle.
81        // If `yyy_mask == 0`, the whole thing is a noop.
82        if comptime![self.yyy_mask == 0] {
83            offset
84        } else {
85            let offset_bytes = offset * type_size;
86            let offset_masked = offset_bytes & self.yyy_mask;
87            let offset_shifted =
88                shift_right(offset_masked, self.shift, comptime![self.invert_shift]);
89            let offset_bytes = offset_bytes ^ offset_shifted;
90            offset_bytes / type_size
91        }
92    }
93
94    /// After how many elements this pattern repeats. Can be used to align the buffer (i.e. smem)
95    /// so offsets match addresses.
96    pub fn repeats_after(&self) -> comptime_type!(u32) {
97        self.repeats_after
98    }
99}
100
101/// Retrieve the type size of a lined buffer.
102#[cube]
103pub fn type_size<E: CubePrimitive>(#[comptime] line_size: u32) -> comptime_type!(u32) {
104    let storage_size = E::type_size();
105    comptime![storage_size * line_size]
106}
107
108#[cube]
109fn shift_right(value: u32, shift: u32, #[comptime] invert: bool) -> u32 {
110    if invert {
111        value << shift
112    } else {
113        value >> shift
114    }
115}