cubecl_matmul/components/batch/partitioned_matmul/hypercube/
global_order.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::tensor::layout::Coords2d;
4
5use crate::components::batch::partitioned_matmul::hypercube::base::CubeSpan;
6
/// Concrete traversal order selected for the partitioned (hypercube) batch matmul.
///
/// NOTE(review): presumably this is the order in which cubes are laid out over the
/// global problem — confirm against the hypercube code that consumes it.
///
/// Values are usually produced by `GlobalOrderSelection::into_order`. A swizzled
/// variant whose payload is `1` is rewritten to the matching plain variant by
/// [`GlobalOrder::canonicalize`].
#[derive(Default, Copy, Clone, Debug, Hash, PartialEq, Eq)]
// Every variant name ends in `Major`, which trips `clippy::enum_variant_names`; the
// suffix is meaningful here, so the lint is silenced deliberately.
#[allow(clippy::enum_variant_names)]
pub enum GlobalOrder {
    /// Plain row-major order. This is the [`Default`] variant.
    #[default]
    RowMajor,
    /// Plain column-major order.
    ColMajor,
    /// Swizzled row-major order; the payload is the swizzle parameter `w`.
    SwizzleRowMajor(u32),
    /// Swizzled column-major order; the payload is the swizzle parameter `w`.
    SwizzleColMajor(u32),
}
26
27impl GlobalOrder {
28 pub fn canonicalize(self) -> Self {
32 match self {
33 GlobalOrder::SwizzleColMajor(1) => GlobalOrder::ColMajor,
34 GlobalOrder::SwizzleRowMajor(1) => GlobalOrder::RowMajor,
35 _ => self,
36 }
37 }
38}
39
40#[derive(Default)]
41#[allow(unused)]
43pub enum GlobalOrderSelection {
44 #[default]
46 Default,
47 Fixed(GlobalOrder),
49 SwizzleRow { m: u32, w: u32 },
53 SwizzleCol { n: u32, w: u32 },
57}
58
59impl GlobalOrderSelection {
60 pub fn into_order(self, span: &CubeSpan) -> GlobalOrder {
61 match self {
62 GlobalOrderSelection::Default => GlobalOrder::default(),
63 GlobalOrderSelection::Fixed(order) => order,
64 GlobalOrderSelection::SwizzleRow { m, w } => {
65 let m_cubes = m.div_ceil(span.m);
66 if m_cubes % w != 0 {
67 GlobalOrder::RowMajor
68 } else {
69 GlobalOrder::SwizzleRowMajor(w)
70 }
71 }
72 GlobalOrderSelection::SwizzleCol { n, w } => {
73 let n_cubes = n.div_ceil(span.n);
74 if n_cubes % w != 0 {
75 GlobalOrder::RowMajor
76 } else {
77 GlobalOrder::SwizzleRowMajor(w)
78 }
79 }
80 }
81 .canonicalize()
82 }
83}
84
/// Maps a linear `index` to a 2D coordinate along a serpentine ("snake") path.
///
/// Consecutive indices are grouped into *strips* of `num_steps * step_length`
/// indices. Each strip is split into `num_steps` *steps* of `step_length` indices.
/// The returned pair is `(step_index, strip_offset + pos_in_step)`:
///
/// - The first component is the step within the strip. It counts forward in even
///   strips and backward in odd strips.
/// - The second component is the position within the step, offset by
///   `strip_index * step_length`. The position runs forward when the step's
///   un-reversed index (`abs_step_index`) is even, and backward when it is odd.
///
/// For example, with `num_steps = 2` and `step_length = 2`, indices `0..8` map to
/// `(0,0) (0,1) (1,1) (1,0) (1,2) (1,3) (0,3) (0,2)`.
///
/// Which component the caller treats as a row or a column is not shown here.
///
/// `step_length` must be non-zero; this is asserted inside `comptime!`.
/// NOTE(review): `num_steps == 0` makes `num_elements_per_strip` zero, so the
/// division and remainder below would divide by zero. Presumably callers guarantee
/// `num_steps > 0` — confirm.
#[cube]
pub fn swizzle(index: u32, num_steps: u32, #[comptime] step_length: u32) -> Coords2d {
    comptime!(assert!(step_length > 0));

    // Locate the strip containing `index` and the position inside it.
    let num_elements_per_strip = num_steps * step_length;
    let strip_index = index / num_elements_per_strip;
    let pos_in_strip = index % num_elements_per_strip;
    let strip_offset = step_length * strip_index;

    // Step and in-step position before any direction reversal is applied.
    let abs_step_index = pos_in_strip / step_length;
    let abs_pos_in_step = pos_in_strip % step_length;

    // 0 = forward, 1 = reversed. These act as arithmetic selectors below, so
    // forward/reversed indices are chosen by multiplication instead of a branch.
    let strip_direction = strip_index % 2;
    let step_direction = abs_step_index % 2;

    // Odd strips walk their steps backward.
    let step_index =
        strip_direction * (num_steps - abs_step_index - 1) + (1 - strip_direction) * abs_step_index;

    // `step_length` is comptime, so this branch is selected at comptime. For a power
    // of two, `x ^ (step_length - 1)` equals `step_length - 1 - x` for
    // `x < step_length`, so the XOR performs the reversal; otherwise the
    // multiply-add selector form is used.
    let pos_in_step = if comptime!(step_length & (step_length - 1) == 0) {
        abs_pos_in_step ^ (step_direction * (step_length - 1))
    } else {
        step_direction * (step_length - abs_pos_in_step - 1)
            + (1 - step_direction) * abs_pos_in_step
    };

    (step_index, pos_in_step + strip_offset)
}