cubecl_matmul/components/batch/partitioned_matmul/hypercube/
global_order.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use crate::components::batch::partitioned_matmul::hypercube::base::CubeSpan;
5
/// Order in which cubes visit the global output as their flattened position grows.
///
/// Variants:
/// - [`GlobalOrder::RowMajor`]: walk each row completely before moving to the next one.
/// - [`GlobalOrder::ColMajor`]: walk each column completely before moving to the next one.
/// - [`GlobalOrder::SwizzleRowMajor`]: zigzag across rows, in steps `w` wide.
/// - [`GlobalOrder::SwizzleColMajor`]: zigzag down columns, in steps `w` wide.
///
/// A swizzle width of `1` degenerates to the plain order with the same orientation:
/// `SwizzleRowMajor(1)` behaves like `RowMajor` and `SwizzleColMajor(1)` like `ColMajor`.
#[derive(Default, Copy, Clone, Debug, Hash, PartialEq, Eq)]
// Every variant deliberately ends in `Major` to name its orientation.
#[allow(clippy::enum_variant_names)]
pub enum GlobalOrder {
    /// Row-first traversal (the default).
    #[default]
    RowMajor,
    /// Column-first traversal.
    ColMajor,
    /// Zigzag traversal across rows; the payload is the step width `w`.
    SwizzleRowMajor(u32),
    /// Zigzag traversal down columns; the payload is the step width `w`.
    SwizzleColMajor(u32),
}
25
26impl GlobalOrder {
27    /// Since they are equivalent but the latter form will skip some calculations,
28    /// - `SwizzleColMajor(1)` becomes `ColMajor`
29    /// - `SwizzleRowMajor(1)` becomes `RowMajor`
30    pub fn canonicalize(self) -> Self {
31        match self {
32            GlobalOrder::SwizzleColMajor(1) => GlobalOrder::ColMajor,
33            GlobalOrder::SwizzleRowMajor(1) => GlobalOrder::RowMajor,
34            _ => self,
35        }
36    }
37}
38
39#[derive(Default)]
40/// Used to create [GlobalOrder].
41#[allow(unused)]
42pub enum GlobalOrderSelection {
43    /// It creates the default global order.
44    #[default]
45    Default,
46    /// Set a global order.
47    Fixed(GlobalOrder),
48    /// Creates swizzle row global order if possible.
49    ///
50    /// Fallbacks to row global order otherwise.
51    SwizzleRow { m: u32, w: u32 },
52    /// Creates swizzle col global order if possible.
53    ///
54    /// Fallbacks to col global order otherwise.
55    SwizzleCol { n: u32, w: u32 },
56}
57
58impl GlobalOrderSelection {
59    pub fn into_order(self, span: &CubeSpan) -> GlobalOrder {
60        match self {
61            GlobalOrderSelection::Default => GlobalOrder::default(),
62            GlobalOrderSelection::Fixed(order) => order,
63            GlobalOrderSelection::SwizzleRow { m, w } => {
64                let m_cubes = m.div_ceil(span.m);
65                if m_cubes % w != 0 {
66                    GlobalOrder::RowMajor
67                } else {
68                    GlobalOrder::SwizzleRowMajor(w)
69                }
70            }
71            GlobalOrderSelection::SwizzleCol { n, w } => {
72                let n_cubes = n.div_ceil(span.n);
73                if n_cubes % w != 0 {
74                    GlobalOrder::RowMajor
75                } else {
76                    GlobalOrder::SwizzleRowMajor(w)
77                }
78            }
79        }
80        .canonicalize()
81    }
82}
83
84#[cube]
85/// Maps a linear `index` to 2D zigzag coordinates `(x, y)` within horizontal or vertical strips.
86///
87/// Each strip is made of `num_steps` steps, each of length `step_length`.
88/// Strips alternate direction: even strips go top-down, odd strips bottom-up.
89/// Steps alternate direction: even steps go left-to-right, odd steps right-to-left.
90///
91/// - Prefer **odd `num_steps`** for smoother transitions between strips.
92/// - Prefer **power-of-two `step_length`** for better performance.
93///
94/// # Parameters
95/// - `index`: linear input index
96/// - `num_steps`: number of snaking steps in a strip
97/// - `step_length`: number of elements in each step (must be > 0)
98///
99/// # Returns
100/// `(x, y)` coordinates after swizzling
101pub fn swizzle(index: u32, num_steps: u32, #[comptime] step_length: u32) -> (u32, u32) {
102    comptime!(assert!(step_length > 0));
103
104    let num_elements_per_strip = num_steps * step_length;
105    let strip_index = index / num_elements_per_strip;
106    let pos_in_strip = index % num_elements_per_strip;
107    let strip_offset = step_length * strip_index;
108
109    // Indices without regards to direction
110    let abs_step_index = pos_in_strip / step_length;
111    let abs_pos_in_step = pos_in_strip % step_length;
112
113    // Top-down (0) or Bottom-up (1)
114    let strip_direction = strip_index % 2;
115    // Left-right (0) or Right-left (1)
116    let step_direction = abs_step_index % 2;
117
118    // Update indices with direction
119    let step_index =
120        strip_direction * (num_steps - abs_step_index - 1) + (1 - strip_direction) * abs_step_index;
121
122    let pos_in_step = if comptime!(step_length & (step_length - 1) == 0) {
123        abs_pos_in_step ^ (step_direction * (step_length - 1))
124    } else {
125        step_direction * (step_length - abs_pos_in_step - 1)
126            + (1 - step_direction) * abs_pos_in_step
127    };
128
129    (step_index, pos_in_step + strip_offset)
130}