cubecl_matmul/components/batch/partitioned_matmul/hypercube/
global_order.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use crate::components::batch::partitioned_matmul::hypercube::base::CubeSpan;
5
/// Order in which a linear index is laid out over a 2D grid.
///
/// NOTE(review): the exact consumer of this order (cube-to-tile mapping in the
/// hypercube) is inferred from the module path — confirm against the callers.
// Every variant name ends in `Major`, which `clippy::enum_variant_names` would flag.
#[allow(clippy::enum_variant_names)]
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
pub enum GlobalOrder {
    /// Plain row-major order. This is the default.
    #[default]
    RowMajor,
    /// Plain column-major order.
    ColMajor,
    /// Swizzled row-major order with the given swizzle width.
    ///
    /// A width of 1 is normalized to [`GlobalOrder::RowMajor`] by
    /// [`GlobalOrder::canonicalize`].
    SwizzleRowMajor(u32),
    /// Swizzled column-major order with the given swizzle width.
    ///
    /// A width of 1 is normalized to [`GlobalOrder::ColMajor`] by
    /// [`GlobalOrder::canonicalize`].
    SwizzleColMajor(u32),
}

impl GlobalOrder {
    /// Maps orders that this module treats as equivalent onto one representation.
    ///
    /// `SwizzleRowMajor(1)` becomes `RowMajor` and `SwizzleColMajor(1)` becomes
    /// `ColMajor`; every other value (including other swizzle widths) is returned
    /// unchanged.
    pub fn canonicalize(self) -> Self {
        match self {
            Self::SwizzleRowMajor(1) => Self::RowMajor,
            Self::SwizzleColMajor(1) => Self::ColMajor,
            unchanged => unchanged,
        }
    }
}
38
39#[derive(Default)]
40#[allow(unused)]
42pub enum GlobalOrderSelection {
43 #[default]
45 Default,
46 Fixed(GlobalOrder),
48 SwizzleRow { m: u32, w: u32 },
52 SwizzleCol { n: u32, w: u32 },
56}
57
58impl GlobalOrderSelection {
59 pub fn into_order(self, span: &CubeSpan) -> GlobalOrder {
60 match self {
61 GlobalOrderSelection::Default => GlobalOrder::default(),
62 GlobalOrderSelection::Fixed(order) => order,
63 GlobalOrderSelection::SwizzleRow { m, w } => {
64 let m_cubes = m.div_ceil(span.m);
65 if m_cubes % w != 0 {
66 GlobalOrder::RowMajor
67 } else {
68 GlobalOrder::SwizzleRowMajor(w)
69 }
70 }
71 GlobalOrderSelection::SwizzleCol { n, w } => {
72 let n_cubes = n.div_ceil(span.n);
73 if n_cubes % w != 0 {
74 GlobalOrder::RowMajor
75 } else {
76 GlobalOrder::SwizzleRowMajor(w)
77 }
78 }
79 }
80 .canonicalize()
81 }
82}
83
84#[cube]
85pub fn swizzle(index: u32, num_steps: u32, #[comptime] step_length: u32) -> (u32, u32) {
102 comptime!(assert!(step_length > 0));
103
104 let num_elements_per_strip = num_steps * step_length;
105 let strip_index = index / num_elements_per_strip;
106 let pos_in_strip = index % num_elements_per_strip;
107 let strip_offset = step_length * strip_index;
108
109 let abs_step_index = pos_in_strip / step_length;
111 let abs_pos_in_step = pos_in_strip % step_length;
112
113 let strip_direction = strip_index % 2;
115 let step_direction = abs_step_index % 2;
117
118 let step_index =
120 strip_direction * (num_steps - abs_step_index - 1) + (1 - strip_direction) * abs_step_index;
121
122 let pos_in_step = if comptime!(step_length & (step_length - 1) == 0) {
123 abs_pos_in_step ^ (step_direction * (step_length - 1))
124 } else {
125 step_direction * (step_length - abs_pos_in_step - 1)
126 + (1 - step_direction) * abs_pos_in_step
127 };
128
129 (step_index, pos_in_step + strip_offset)
130}