cubecl_matmul/components/batch/partitioned_matmul/hypercube/
global_order.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::tensor::layout::Coords2d;
4
5use crate::components::batch::partitioned_matmul::hypercube::base::CubeSpan;
6
/// Order in which the global problem is walked as the flattened cube position grows.
///
/// | Variant              | Traversal                                             |
/// |----------------------|-------------------------------------------------------|
/// | `RowMajor`           | standard row-first traversal                          |
/// | `ColMajor`           | standard column-first traversal                       |
/// | `SwizzleRowMajor(w)` | zigzag pattern across rows, with `w`-wide steps       |
/// | `SwizzleColMajor(w)` | zigzag pattern down columns, with `w`-wide steps      |
///
/// A swizzle of width 1 degenerates to the plain order:
/// `SwizzleRowMajor(1)` behaves like `RowMajor` and `SwizzleColMajor(1)` like `ColMajor`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
// Every variant name ends in `Major`; the shared suffix is intentional.
#[allow(clippy::enum_variant_names)]
pub enum GlobalOrder {
    /// Plain row-first traversal (the default).
    #[default]
    RowMajor,
    /// Plain column-first traversal.
    ColMajor,
    /// Zigzag across rows; the payload is the step width `w`.
    SwizzleRowMajor(u32),
    /// Zigzag down columns; the payload is the step width `w`.
    SwizzleColMajor(u32),
}
26
27impl GlobalOrder {
28    /// Since they are equivalent but the latter form will skip some calculations,
29    /// - `SwizzleColMajor(1)` becomes `ColMajor`
30    /// - `SwizzleRowMajor(1)` becomes `RowMajor`
31    pub fn canonicalize(self) -> Self {
32        match self {
33            GlobalOrder::SwizzleColMajor(1) => GlobalOrder::ColMajor,
34            GlobalOrder::SwizzleRowMajor(1) => GlobalOrder::RowMajor,
35            _ => self,
36        }
37    }
38}
39
40#[derive(Default)]
41/// Used to create [GlobalOrder].
42#[allow(unused)]
43pub enum GlobalOrderSelection {
44    /// It creates the default global order.
45    #[default]
46    Default,
47    /// Set a global order.
48    Fixed(GlobalOrder),
49    /// Creates swizzle row global order if possible.
50    ///
51    /// Fallbacks to row global order otherwise.
52    SwizzleRow { m: u32, w: u32 },
53    /// Creates swizzle col global order if possible.
54    ///
55    /// Fallbacks to col global order otherwise.
56    SwizzleCol { n: u32, w: u32 },
57}
58
59impl GlobalOrderSelection {
60    pub fn into_order(self, span: &CubeSpan) -> GlobalOrder {
61        match self {
62            GlobalOrderSelection::Default => GlobalOrder::default(),
63            GlobalOrderSelection::Fixed(order) => order,
64            GlobalOrderSelection::SwizzleRow { m, w } => {
65                let m_cubes = m.div_ceil(span.m);
66                if m_cubes % w != 0 {
67                    GlobalOrder::RowMajor
68                } else {
69                    GlobalOrder::SwizzleRowMajor(w)
70                }
71            }
72            GlobalOrderSelection::SwizzleCol { n, w } => {
73                let n_cubes = n.div_ceil(span.n);
74                if n_cubes % w != 0 {
75                    GlobalOrder::RowMajor
76                } else {
77                    GlobalOrder::SwizzleRowMajor(w)
78                }
79            }
80        }
81        .canonicalize()
82    }
83}
84
85#[cube]
86/// Maps a linear `index` to 2D zigzag coordinates `(x, y)` within horizontal or vertical strips.
87///
88/// Each strip is made of `num_steps` steps, each of length `step_length`.
89/// Strips alternate direction: even strips go top-down, odd strips bottom-up.
90/// Steps alternate direction: even steps go left-to-right, odd steps right-to-left.
91///
92/// - Prefer **odd `num_steps`** for smoother transitions between strips.
93/// - Prefer **power-of-two `step_length`** for better performance.
94///
95/// # Parameters
96/// - `index`: linear input index
97/// - `num_steps`: number of snaking steps in a strip
98/// - `step_length`: number of elements in each step (must be > 0)
99///
100/// # Returns
101/// `(x, y)` coordinates after swizzling
102pub fn swizzle(index: u32, num_steps: u32, #[comptime] step_length: u32) -> Coords2d {
103    comptime!(assert!(step_length > 0));
104
105    let num_elements_per_strip = num_steps * step_length;
106    let strip_index = index / num_elements_per_strip;
107    let pos_in_strip = index % num_elements_per_strip;
108    let strip_offset = step_length * strip_index;
109
110    // Indices without regards to direction
111    let abs_step_index = pos_in_strip / step_length;
112    let abs_pos_in_step = pos_in_strip % step_length;
113
114    // Top-down (0) or Bottom-up (1)
115    let strip_direction = strip_index % 2;
116    // Left-right (0) or Right-left (1)
117    let step_direction = abs_step_index % 2;
118
119    // Update indices with direction
120    let step_index =
121        strip_direction * (num_steps - abs_step_index - 1) + (1 - strip_direction) * abs_step_index;
122
123    let pos_in_step = if comptime!(step_length & (step_length - 1) == 0) {
124        abs_pos_in_step ^ (step_direction * (step_length - 1))
125    } else {
126        step_direction * (step_length - abs_pos_in_step - 1)
127            + (1 - step_direction) * abs_pos_in_step
128    };
129
130    (step_index, pos_in_step + strip_offset)
131}