cubecl_linalg/matmul/components/batch/
cube_dispatch.rs1use cubecl_core::prelude::*;
2use cubecl_core::{self as cubecl};
3use std::fmt::Debug;
4use std::hash::Hash;
5
6use crate::matmul::components::batch::shared::swizzle;
7
8#[cube]
9pub trait CubeDispatch: Clone + Copy + 'static + Send + Sync + Debug + Hash + Eq {
11 fn x_y_indices() -> (u32, u32);
12 fn batch_index() -> u32;
13 fn max_x(#[comptime] cube_count: (u32, u32, u32)) -> comptime_type!(u32);
14 fn max_y(#[comptime] cube_count: (u32, u32, u32)) -> comptime_type!(u32);
15 fn max_batches(#[comptime] cube_count: (u32, u32, u32)) -> comptime_type!(u32);
16}
17
18pub trait CubeCountDispatch {
19 fn cube_count(cubes_for_m: u32, cubes_for_n: u32, cubes_for_batches: u32) -> CubeCount;
20}
21
/// Dispatch that places `m` tiles on hardware X, `n` tiles on Y and batches on Z,
/// reading the tile indices straight from the cube position.
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)]
pub struct NaturalDispatch;

/// Dispatch that places `n` tiles on hardware X and `m` tiles on Y (batches stay on Z),
/// swapping the axes back when computing tile indices.
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)]
pub struct TransposedDispatch;

/// Same axis placement as the natural dispatch (`m` on X, `n` on Y), but the linear
/// cube index is remapped through `swizzle` with width `W` before becoming `(x, y)`.
///
/// NOTE(review): the exact traversal depends on `shared::swizzle`; presumably `W` must be
/// non-zero — confirm there.
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)]
pub struct SwizzleNaturalDispatch<const W: u32>;

/// Same axis placement as the transposed dispatch (`n` on X, `m` on Y), but the linear
/// cube index is remapped through `swizzle` with width `W` before becoming `(x, y)`.
///
/// NOTE(review): the exact traversal depends on `shared::swizzle`; presumably `W` must be
/// non-zero — confirm there.
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)]
pub struct SwizzleTransposedDispatch<const W: u32>;

46#[cube]
47impl CubeDispatch for NaturalDispatch {
48 fn x_y_indices() -> (u32, u32) {
49 (CUBE_POS_X, CUBE_POS_Y)
50 }
51
52 fn batch_index() -> u32 {
53 CUBE_POS_Z
54 }
55
56 fn max_x(#[comptime] cube_count: (u32, u32, u32)) -> comptime_type!(u32) {
57 cube_count.0
58 }
59
60 fn max_y(#[comptime] cube_count: (u32, u32, u32)) -> comptime_type!(u32) {
61 cube_count.1
62 }
63
64 fn max_batches(#[comptime] cube_count: (u32, u32, u32)) -> comptime_type!(u32) {
65 cube_count.2
66 }
67}
68
69impl CubeCountDispatch for NaturalDispatch {
70 fn cube_count(cubes_for_m: u32, cubes_for_n: u32, cubes_for_batches: u32) -> CubeCount {
71 CubeCount::Static(cubes_for_m, cubes_for_n, cubes_for_batches)
72 }
73}
74
75#[cube]
76impl CubeDispatch for TransposedDispatch {
77 fn x_y_indices() -> (u32, u32) {
78 (CUBE_POS_Y, CUBE_POS_X)
79 }
80
81 fn batch_index() -> u32 {
82 CUBE_POS_Z
83 }
84
85 fn max_x(#[comptime] cube_count: (u32, u32, u32)) -> comptime_type!(u32) {
86 cube_count.1
87 }
88
89 fn max_y(#[comptime] cube_count: (u32, u32, u32)) -> comptime_type!(u32) {
90 cube_count.0
91 }
92
93 fn max_batches(#[comptime] cube_count: (u32, u32, u32)) -> comptime_type!(u32) {
94 cube_count.2
95 }
96}
97
98impl CubeCountDispatch for TransposedDispatch {
99 fn cube_count(cubes_for_m: u32, cubes_for_n: u32, cubes_for_batches: u32) -> CubeCount {
100 CubeCount::Static(cubes_for_n, cubes_for_m, cubes_for_batches)
101 }
102}
103
104#[cube]
105impl<const W: u32> CubeDispatch for SwizzleNaturalDispatch<W> {
106 fn x_y_indices() -> (u32, u32) {
107 let height = CUBE_COUNT_X;
108 let nth_cube = CUBE_POS_Y * height + CUBE_POS_X;
109 swizzle(nth_cube, height, W)
110 }
111
112 fn batch_index() -> u32 {
113 CUBE_POS_Z
114 }
115
116 fn max_x(#[comptime] cube_count: (u32, u32, u32)) -> comptime_type!(u32) {
117 cube_count.0
118 }
119
120 fn max_y(#[comptime] cube_count: (u32, u32, u32)) -> comptime_type!(u32) {
121 cube_count.1
122 }
123
124 fn max_batches(#[comptime] cube_count: (u32, u32, u32)) -> comptime_type!(u32) {
125 cube_count.2
126 }
127}
128
129impl<const W: u32> CubeCountDispatch for SwizzleNaturalDispatch<W> {
130 fn cube_count(cubes_for_m: u32, cubes_for_n: u32, cubes_for_batches: u32) -> CubeCount {
131 CubeCount::Static(cubes_for_m, cubes_for_n, cubes_for_batches)
132 }
133}
134
135#[cube]
136impl<const W: u32> CubeDispatch for SwizzleTransposedDispatch<W> {
137 fn x_y_indices() -> (u32, u32) {
138 let height = CUBE_COUNT_Y;
139 let nth_cube = CUBE_POS_X * height + CUBE_POS_Y;
140 swizzle(nth_cube, height, W)
141 }
142
143 fn batch_index() -> u32 {
144 CUBE_POS_Z
145 }
146
147 fn max_x(#[comptime] cube_count: (u32, u32, u32)) -> comptime_type!(u32) {
148 cube_count.1
149 }
150
151 fn max_y(#[comptime] cube_count: (u32, u32, u32)) -> comptime_type!(u32) {
152 cube_count.0
153 }
154
155 fn max_batches(#[comptime] cube_count: (u32, u32, u32)) -> comptime_type!(u32) {
156 cube_count.2
157 }
158}
159
160impl<const W: u32> CubeCountDispatch for SwizzleTransposedDispatch<W> {
161 fn cube_count(cubes_for_m: u32, cubes_for_n: u32, cubes_for_batches: u32) -> CubeCount {
162 CubeCount::Static(cubes_for_n, cubes_for_m, cubes_for_batches)
163 }
164}