cubecl_linalg/matmul/components/batch/
one_to_one.rs1use std::marker::PhantomData;
2
3use crate::matmul::components::{
4 Args, EA, EI, EO, ES, Ident, InputRuntimeArg, InvalidConfigError, MatmulConfigFactory,
5 MatmulLaunch, MatmulPrecision, MatmulProblem, MatmulSpec, OutputRuntimeArg, TilingDimensions,
6 batch::{self, shared::gmm_execute},
7 config::MatmulConfig,
8 global::{self, GlobalMatmul, GlobalMatmulFamily, Quantization},
9};
10use crate::matmul::kernels::MatmulAvailabilityError;
11use batch::{BatchMatmul, BatchMatmulFamily};
12use cubecl_core as cubecl;
13use cubecl_core::prelude::*;
14use cubecl_std::CubeOption;
15use cubecl_std::tensor::r#virtual::{ReadWrite, VirtualTensor};
16
17use super::{BatchConfig as _, CubeDispatch};
18
19pub struct OneToOneMatmulFamily<GMM: GlobalMatmulFamily, C: CubeDispatch> {
20 _gmm: PhantomData<GMM>,
21 _c: PhantomData<C>,
22}
23
24impl<GMM: GlobalMatmulFamily, C: CubeDispatch> BatchMatmulFamily for OneToOneMatmulFamily<GMM, C> {
25 type Matmul<MP: MatmulPrecision> = OneToOneMatmul<MP, GMM::Matmul<MP>, C>;
26}
27
28impl<GMM: GlobalMatmulFamily, C: CubeDispatch> MatmulConfigFactory
29 for OneToOneMatmulFamily<GMM, C>
30{
31 type Input = GMM::Input;
32 type Config = Config<GMM::Config, C>;
33
34 fn check_config(config: &Self::Config) -> Result<(), InvalidConfigError> {
35 GMM::check_config(&config.to_gmm_config())
36 }
37
38 fn check_availability<R: Runtime, MP: MatmulPrecision>(
39 client: &ComputeClient<R::Server, R::Channel>,
40 config: &Self::Config,
41 ) -> Result<(), MatmulAvailabilityError> {
42 GMM::check_availability::<R, MP>(client, &config.gmm_config)
43 }
44
45 fn make_config(
46 input: Self::Input,
47 problem: &MatmulProblem,
48 cube_dim: &CubeDim,
49 cube_count: &CubeCount,
50 quantized: bool,
51 ) -> Self::Config {
52 let gmm_config = GMM::make_config(input, problem, cube_dim, cube_count, quantized);
53 let cube_count = if let CubeCount::Static(x, y, z) = cube_count {
54 (*x, *y, *z)
55 } else {
56 panic!("Dynamic cube count unsupported")
57 };
58
59 Config::<GMM::Config, C>::new(gmm_config, cube_count, quantized)
60 }
61}
62
63impl<GMM: GlobalMatmulFamily, C: CubeDispatch> MatmulLaunch for OneToOneMatmulFamily<GMM, C> {
64 unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>(
65 client: &ComputeClient<<R as Runtime>::Server, <R as Runtime>::Channel>,
66 cube_dim: CubeDim,
67 cube_count: CubeCount,
68 input: InputRuntimeArg<'a, MS, R>,
69 output: OutputRuntimeArg<'a, MS, R>,
70 size_k: ScalarArg<u32>,
71 config: Self::Config,
72 ) {
73 unsafe {
74 super::matmul::launch_unchecked::<Args<MS>, EI<MS>, ES<MS>, EA<MS>, EO<MS>, Self, R>(
75 client, cube_count, cube_dim, input, output, size_k, config,
76 );
77 }
78 }
79}
80
81pub struct OneToOneMatmul<MP: MatmulPrecision, GMM: GlobalMatmul<MP>, C: CubeDispatch> {
87 _mp: PhantomData<MP>,
88 _gmm: PhantomData<GMM>,
89 _c: PhantomData<C>,
90}
91
92#[cube]
93impl<MP: MatmulPrecision, GMM: GlobalMatmul<MP>, C: CubeDispatch> BatchMatmul<MP>
94 for OneToOneMatmul<MP, GMM, C>
95{
96 type Config = Config<GMM::Config, C>;
97
98 fn execute(
99 lhs: VirtualTensor<MP::EI>,
100 rhs: VirtualTensor<MP::EI>,
101 out: VirtualTensor<MP::EO, ReadWrite>,
102 size_k: u32,
103 quantization: CubeOption<Quantization<MP>>,
104 #[comptime] config: Self::Config,
105 ) {
106 let (x_index, y_index) = C::x_y_indices();
107 let x_offset = x_index * config.tiling_dimensions(Ident::Lhs).total_row();
108 let y_offset = y_index * config.tiling_dimensions(Ident::Rhs).total_col();
109 let nth_batch = C::batch_index();
110 let k_range = (0, size_k);
111
112 let gmm_config = config.to_gmm_config();
113
114 gmm_execute::<MP, GMM>(
115 lhs,
116 rhs,
117 out,
118 x_offset,
119 y_offset,
120 nth_batch,
121 &mut GMM::init_accumulator(gmm_config),
122 k_range,
123 quantization,
124 gmm_config,
125 );
126 }
127}
128
129#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
130pub struct Config<G: global::GlobalConfig, C: CubeDispatch> {
132 gmm_config: G,
133 cube_count: (u32, u32, u32),
134 quantized: bool,
135 _c: PhantomData<C>,
136}
137
138impl<G: global::GlobalConfig, C: CubeDispatch> batch::BatchConfig for Config<G, C> {
139 type GmmConfig = G;
140
141 fn to_gmm_config(&self) -> Self::GmmConfig {
142 self.gmm_config
143 }
144
145 fn tiling_dimensions(&self, ident: Ident) -> TilingDimensions {
146 self.gmm_config.tiling_dimensions(ident)
147 }
148
149 fn max_m(&self) -> u32 {
150 C::max_x(self.cube_count) * self.tiling_dimensions(Ident::Out).total_row()
151 }
152
153 fn max_n(&self) -> u32 {
154 C::max_y(self.cube_count) * self.tiling_dimensions(Ident::Out).total_col()
155 }
156
157 fn max_batches(&self) -> u32 {
158 C::max_batches(self.cube_count)
159 }
160
161 fn quantized(&self) -> bool {
162 self.quantized
163 }
164}
165
166impl<G: global::GlobalConfig, C: CubeDispatch> MatmulConfig for Config<G, C> {}
167
168impl<G: global::GlobalConfig, C: CubeDispatch> Config<G, C> {
169 pub fn new(gmm_config: G, cube_count: (u32, u32, u32), quantized: bool) -> Self {
170 Self {
171 gmm_config,
172 cube_count,
173 quantized,
174 _c: PhantomData,
175 }
176 }
177}