use std::marker::PhantomData;

use crate::matmul::components::batch::span::{Span, SpanDim, SpanMatmul};
use crate::matmul::components::global::GlobalMatmulFamily;
use crate::matmul::components::global::Quantization;
use crate::matmul::components::{
    Args, EA, EI, EO, ES, InputRuntimeArg, InvalidConfigError, MatmulPrecision, MatmulProblem,
    MatmulSpec, OutputRuntimeArg,
};
use crate::matmul::components::{
    Ident, MatmulConfigFactory, MatmulLaunch, TilingDimensions, batch, config::MatmulConfig, global,
};
use crate::matmul::kernels::MatmulAvailabilityError;
use cubecl_core as cubecl;
use cubecl_core::prelude::*;
use cubecl_std::CubeOption;
use cubecl_std::tensor::r#virtual::{ReadWrite, VirtualTensor};

use super::{BatchConfig as _, BatchMatmulFamily, CubeDispatch};

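/// Batch matmul family in which a single cube may perform several global matmuls,
/// iterating over a span of the output as determined by the span strategy `S` and
/// the cube dispatch `C`.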
pub struct OneToManyMatmulFamily<GMM: GlobalMatmulFamily, S: SpanMatmul, C: CubeDispatch> {
    _gmm: PhantomData<GMM>,
    _s: PhantomData<S>,
    _c: PhantomData<C>,
}

impl<GMM: GlobalMatmulFamily, S: SpanMatmul, C: CubeDispatch> BatchMatmulFamily
    for OneToManyMatmulFamily<GMM, S, C>
{
    type Matmul<MP: MatmulPrecision> = OneToManyMatmul<MP, GMM::Matmul<MP>, S, C>;
}

impl<GMM: GlobalMatmulFamily, S: SpanMatmul, C: CubeDispatch> MatmulConfigFactory
    for OneToManyMatmulFamily<GMM, S, C>
{
    type Config = Config<GMM::Config, C>;
    type Input = GMM::Input;

    fn check_config(config: &Self::Config) -> Result<(), InvalidConfigError> {
        GMM::check_config(&config.to_gmm_config())
    }

    fn check_availability<R: Runtime, MP: MatmulPrecision>(
        client: &ComputeClient<R::Server, R::Channel>,
        config: &Self::Config,
    ) -> Result<(), MatmulAvailabilityError> {
        GMM::check_availability::<R, MP>(client, &config.gmm_config)
    }

    fn make_config(
        input: Self::Input,
        problem: &MatmulProblem,
        cube_dim: &CubeDim,
        cube_count: &CubeCount,
        quantized: bool,
    ) -> Self::Config {
        let gmm_config = GMM::make_config(input, problem, cube_dim, cube_count, quantized);
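        // The batch config stores a static (x, y, z) cube count, so only CubeCount::Static is supported.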
        let cube_count = if let CubeCount::Static(x, y, z) = cube_count {
            (*x, *y, *z)
        } else {
            panic!("Dynamic cube count unsupported")
        };

        Config::new(gmm_config, cube_count, quantized)
    }
}

impl<GMM: GlobalMatmulFamily, S: SpanMatmul, C: CubeDispatch> MatmulLaunch
    for OneToManyMatmulFamily<GMM, S, C>
{
    unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>(
        client: &ComputeClient<<R as Runtime>::Server, <R as Runtime>::Channel>,
        cube_dim: CubeDim,
        cube_count: CubeCount,
        input: InputRuntimeArg<'a, MS, R>,
        output: OutputRuntimeArg<'a, MS, R>,
        size_k: ScalarArg<u32>,
        config: Self::Config,
    ) {
        unsafe {
            super::matmul::launch_unchecked::<Args<MS>, EI<MS>, ES<MS>, EA<MS>, EO<MS>, Self, R>(
                client, cube_count, cube_dim, input, output, size_k, config,
            );
        }
    }
}

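/// Batch matmul in which each cube executes one or more global matmuls,
/// covering its assigned span of the output tensor.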
pub struct OneToManyMatmul<
    MP: MatmulPrecision,
    GMM: global::GlobalMatmul<MP>,
    S: SpanMatmul,
    C: CubeDispatch,
> {
    _mp: PhantomData<MP>,
    _gmm: PhantomData<GMM>,
    _s: PhantomData<S>,
    _c: PhantomData<C>,
}

#[cube]
impl<MP: MatmulPrecision, GMM: global::GlobalMatmul<MP>, S: SpanMatmul, C: CubeDispatch>
    batch::BatchMatmul<MP> for OneToManyMatmul<MP, GMM, S, C>
{
    type Config = Config<GMM::Config, C>;

    fn execute(
        lhs: VirtualTensor<MP::EI>,
        rhs: VirtualTensor<MP::EI>,
        out: VirtualTensor<MP::EO, ReadWrite>,
        _size_k: u32,
        quantization: CubeOption<Quantization<MP>>,
        #[comptime] config: Self::Config,
    ) {
        let rank = out.rank();
        let shape_x = out.shape(rank - 2);
        let shape_y = out.shape(rank - 1);

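        // Flatten all leading (batch) dimensions of the output into a single z extent.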
        let mut shape_z = 1;
        for b in 0..rank - 2 {
            shape_z *= out.shape(b);
        }

        let cubes_x = config.cube_count_x();
        let cubes_y = config.cube_count_y();
        let cubes_z = config.cube_count_batch();

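        // One stage corresponds to the output region handled by a single global matmul.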
        let stage_x = config.tiling_dimensions(Ident::Out).total_row();
        let stage_y = config.tiling_dimensions(Ident::Out).total_col();
        let stage_z = 1;

        let (x_index, y_index) = C::x_y_indices();
        let batch_index = C::batch_index();

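        // Span of output positions assigned to this cube, one stage per step along each dimension.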
        let span = Span::new(
            SpanDim::new(shape_x, stage_x, x_index, cubes_x),
            SpanDim::new(shape_y, stage_y, y_index, cubes_y),
            SpanDim::new(shape_z, stage_z, batch_index, cubes_z),
        );

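        // Each global matmul reduces over the full k axis, taken from the last dimension of lhs.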
        let k_range = (0, lhs.shape(rank - 1));

        let gmm_config = config.to_gmm_config();
        let acc = GMM::init_accumulator(gmm_config);
        S::execute::<MP, GMM>(lhs, rhs, out, span, acc, k_range, quantization, gmm_config);
    }
}

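/// Configuration for [`OneToManyMatmul`], combining the underlying global matmul
/// configuration with the static cube count used for dispatch.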
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct Config<G: global::GlobalConfig, C: CubeDispatch> {
    gmm_config: G,
    cube_count: (u32, u32, u32),
    quantized: bool,
    _c: PhantomData<C>,
}

impl<G: global::GlobalConfig, C: CubeDispatch> batch::BatchConfig for Config<G, C> {
    type GmmConfig = G;

    fn to_gmm_config(&self) -> Self::GmmConfig {
        self.gmm_config
    }

    fn tiling_dimensions(&self, ident: Ident) -> TilingDimensions {
        self.gmm_config.tiling_dimensions(ident)
    }

    fn max_m(&self) -> u32 {
        u32::MAX
    }

    fn max_n(&self) -> u32 {
        u32::MAX
    }

    fn max_batches(&self) -> u32 {
        u32::MAX
    }

    fn quantized(&self) -> bool {
        self.quantized
    }
}

impl<G: global::GlobalConfig, C: CubeDispatch> MatmulConfig for Config<G, C> {}

impl<G: global::GlobalConfig, C: CubeDispatch> Config<G, C> {
    pub fn new(gmm_config: G, cube_count: (u32, u32, u32), quantized: bool) -> Self {
        Self {
            gmm_config,
            cube_count,
            quantized,
            _c: PhantomData,
        }
    }

    fn cube_count_x(&self) -> u32 {
        C::max_x(self.cube_count)
    }

    fn cube_count_y(&self) -> u32 {
        C::max_y(self.cube_count)
    }

    fn cube_count_batch(&self) -> u32 {
        C::max_batches(self.cube_count)
    }
}