cubecl_linalg/matmul/components/batch/cube_dispatch.rs

1use cubecl_core::prelude::*;
2use cubecl_core::{self as cubecl};
3use std::fmt::Debug;
4use std::hash::Hash;
5
6use crate::matmul::components::batch::shared::swizzle;
7
8#[cube]
9/// Distributes cube instances across the tensor, assigning each to compute data in distinct regions.
10pub trait CubeDispatch: Clone + Copy + 'static + Send + Sync + Debug + Hash + Eq {
11    fn x_y_indices() -> (u32, u32);
12    fn batch_index() -> u32;
13    fn max_x(#[comptime] cube_count: (u32, u32, u32)) -> comptime_type!(u32);
14    fn max_y(#[comptime] cube_count: (u32, u32, u32)) -> comptime_type!(u32);
15    fn max_batches(#[comptime] cube_count: (u32, u32, u32)) -> comptime_type!(u32);
16}
17
18pub trait CubeCountDispatch {
19    fn cube_count(cubes_for_m: u32, cubes_for_n: u32, cubes_for_batches: u32) -> CubeCount;
20}
21
/// Identity dispatch: the `m` dimension follows `cube_pos_x` and the `n` dimension follows
/// `cube_pos_y`.
///
/// A cube with a larger `cube_pos_x` works further along `m`. A cube with a larger
/// `cube_pos_y` works further along `n`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct NaturalDispatch;
26
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)]
/// Operates on data further along the m dimension as `cube_pos_y` increases,
/// and further along the n dimension as `cube_pos_x` increases.
///
/// This is the transpose of [`NaturalDispatch`]. The launch grid puts the `n` cubes on the
/// x axis and the `m` cubes on the y axis, and the two positions are swapped back when
/// computing tile coordinates.
pub struct TransposedDispatch;
31
/// Swizzled dispatch in which cubes are ranked x-first (along the x axis before the y axis).
///
/// The cube's linear rank is `CUBE_POS_Y * CUBE_COUNT_X + CUBE_POS_X`. That rank, together
/// with `CUBE_COUNT_X`, is handed to `swizzle` to obtain the tile coordinates.
///
/// # Generics
/// - `W`: width of a swizzle column, forwarded as-is to `swizzle`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct SwizzleNaturalDispatch<const W: u32>;
38
/// Swizzled dispatch in which cubes are ranked y-first (along the y axis before the x axis).
///
/// The cube's linear rank is `CUBE_POS_X * CUBE_COUNT_Y + CUBE_POS_Y`. That rank, together
/// with `CUBE_COUNT_Y`, is handed to `swizzle` to obtain the tile coordinates.
///
/// # Generics
/// - `W`: width of a swizzle column, forwarded as-is to `swizzle`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct SwizzleTransposedDispatch<const W: u32>;
45
46#[cube]
47impl CubeDispatch for NaturalDispatch {
48    fn x_y_indices() -> (u32, u32) {
49        (CUBE_POS_X, CUBE_POS_Y)
50    }
51
52    fn batch_index() -> u32 {
53        CUBE_POS_Z
54    }
55
56    fn max_x(#[comptime] cube_count: (u32, u32, u32)) -> comptime_type!(u32) {
57        cube_count.0
58    }
59
60    fn max_y(#[comptime] cube_count: (u32, u32, u32)) -> comptime_type!(u32) {
61        cube_count.1
62    }
63
64    fn max_batches(#[comptime] cube_count: (u32, u32, u32)) -> comptime_type!(u32) {
65        cube_count.2
66    }
67}
68
69impl CubeCountDispatch for NaturalDispatch {
70    fn cube_count(cubes_for_m: u32, cubes_for_n: u32, cubes_for_batches: u32) -> CubeCount {
71        CubeCount::Static(cubes_for_m, cubes_for_n, cubes_for_batches)
72    }
73}
74
75#[cube]
76impl CubeDispatch for TransposedDispatch {
77    fn x_y_indices() -> (u32, u32) {
78        (CUBE_POS_Y, CUBE_POS_X)
79    }
80
81    fn batch_index() -> u32 {
82        CUBE_POS_Z
83    }
84
85    fn max_x(#[comptime] cube_count: (u32, u32, u32)) -> comptime_type!(u32) {
86        cube_count.1
87    }
88
89    fn max_y(#[comptime] cube_count: (u32, u32, u32)) -> comptime_type!(u32) {
90        cube_count.0
91    }
92
93    fn max_batches(#[comptime] cube_count: (u32, u32, u32)) -> comptime_type!(u32) {
94        cube_count.2
95    }
96}
97
98impl CubeCountDispatch for TransposedDispatch {
99    fn cube_count(cubes_for_m: u32, cubes_for_n: u32, cubes_for_batches: u32) -> CubeCount {
100        CubeCount::Static(cubes_for_n, cubes_for_m, cubes_for_batches)
101    }
102}
103
104#[cube]
105impl<const W: u32> CubeDispatch for SwizzleNaturalDispatch<W> {
106    fn x_y_indices() -> (u32, u32) {
107        let height = CUBE_COUNT_X;
108        let nth_cube = CUBE_POS_Y * height + CUBE_POS_X;
109        swizzle(nth_cube, height, W)
110    }
111
112    fn batch_index() -> u32 {
113        CUBE_POS_Z
114    }
115
116    fn max_x(#[comptime] cube_count: (u32, u32, u32)) -> comptime_type!(u32) {
117        cube_count.0
118    }
119
120    fn max_y(#[comptime] cube_count: (u32, u32, u32)) -> comptime_type!(u32) {
121        cube_count.1
122    }
123
124    fn max_batches(#[comptime] cube_count: (u32, u32, u32)) -> comptime_type!(u32) {
125        cube_count.2
126    }
127}
128
129impl<const W: u32> CubeCountDispatch for SwizzleNaturalDispatch<W> {
130    fn cube_count(cubes_for_m: u32, cubes_for_n: u32, cubes_for_batches: u32) -> CubeCount {
131        CubeCount::Static(cubes_for_m, cubes_for_n, cubes_for_batches)
132    }
133}
134
135#[cube]
136impl<const W: u32> CubeDispatch for SwizzleTransposedDispatch<W> {
137    fn x_y_indices() -> (u32, u32) {
138        let height = CUBE_COUNT_Y;
139        let nth_cube = CUBE_POS_X * height + CUBE_POS_Y;
140        swizzle(nth_cube, height, W)
141    }
142
143    fn batch_index() -> u32 {
144        CUBE_POS_Z
145    }
146
147    fn max_x(#[comptime] cube_count: (u32, u32, u32)) -> comptime_type!(u32) {
148        cube_count.1
149    }
150
151    fn max_y(#[comptime] cube_count: (u32, u32, u32)) -> comptime_type!(u32) {
152        cube_count.0
153    }
154
155    fn max_batches(#[comptime] cube_count: (u32, u32, u32)) -> comptime_type!(u32) {
156        cube_count.2
157    }
158}
159
160impl<const W: u32> CubeCountDispatch for SwizzleTransposedDispatch<W> {
161    fn cube_count(cubes_for_m: u32, cubes_for_n: u32, cubes_for_batches: u32) -> CubeCount {
162        CubeCount::Static(cubes_for_n, cubes_for_m, cubes_for_batches)
163    }
164}