// cubecl_linalg/matmul/components/batch/one_to_one.rs

1use std::marker::PhantomData;
2
3use crate::matmul::components::{
4    Args, EA, EI, EO, ES, Ident, InputRuntimeArg, InvalidConfigError, MatmulConfigFactory,
5    MatmulLaunch, MatmulPrecision, MatmulProblem, MatmulSpec, OutputRuntimeArg, TilingDimensions,
6    batch::{self, shared::gmm_execute},
7    config::MatmulConfig,
8    global::{self, GlobalMatmul, GlobalMatmulFamily, Quantization},
9};
10use crate::matmul::kernels::MatmulAvailabilityError;
11use batch::{BatchMatmul, BatchMatmulFamily};
12use cubecl_core as cubecl;
13use cubecl_core::prelude::*;
14use cubecl_std::CubeOption;
15use cubecl_std::tensor::r#virtual::{ReadWrite, VirtualTensor};
16
17use super::{BatchConfig as _, CubeDispatch};
18
/// Batch-level matmul family that assigns exactly one global matmul per cube.
///
/// `GMM` selects the global matmul implementation; `C` decides how cube
/// indices are mapped to (x, y, batch) positions (see [`CubeDispatch`]).
pub struct OneToOneMatmulFamily<GMM: GlobalMatmulFamily, C: CubeDispatch> {
    // Zero-sized markers: the type parameters only select implementations,
    // no runtime state is stored.
    _gmm: PhantomData<GMM>,
    _c: PhantomData<C>,
}
23
impl<GMM: GlobalMatmulFamily, C: CubeDispatch> BatchMatmulFamily for OneToOneMatmulFamily<GMM, C> {
    // Instantiate the concrete batch matmul using this family's global matmul
    // specialized for the precision `MP`.
    type Matmul<MP: MatmulPrecision> = OneToOneMatmul<MP, GMM::Matmul<MP>, C>;
}
27
28impl<GMM: GlobalMatmulFamily, C: CubeDispatch> MatmulConfigFactory
29    for OneToOneMatmulFamily<GMM, C>
30{
31    type Input = GMM::Input;
32    type Config = Config<GMM::Config, C>;
33
34    fn check_config(config: &Self::Config) -> Result<(), InvalidConfigError> {
35        GMM::check_config(&config.to_gmm_config())
36    }
37
38    fn check_availability<R: Runtime, MP: MatmulPrecision>(
39        client: &ComputeClient<R::Server, R::Channel>,
40        config: &Self::Config,
41    ) -> Result<(), MatmulAvailabilityError> {
42        GMM::check_availability::<R, MP>(client, &config.gmm_config)
43    }
44
45    fn make_config(
46        input: Self::Input,
47        problem: &MatmulProblem,
48        cube_dim: &CubeDim,
49        cube_count: &CubeCount,
50        quantized: bool,
51    ) -> Self::Config {
52        let gmm_config = GMM::make_config(input, problem, cube_dim, cube_count, quantized);
53        let cube_count = if let CubeCount::Static(x, y, z) = cube_count {
54            (*x, *y, *z)
55        } else {
56            panic!("Dynamic cube count unsupported")
57        };
58
59        Config::<GMM::Config, C>::new(gmm_config, cube_count, quantized)
60    }
61}
62
impl<GMM: GlobalMatmulFamily, C: CubeDispatch> MatmulLaunch for OneToOneMatmulFamily<GMM, C> {
    /// Launches the batch matmul kernel without validating the launch arguments.
    ///
    /// # Safety
    /// The caller must ensure `input`, `output`, `size_k` and `config` are
    /// mutually consistent and valid for the kernel; no checks are done here.
    unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>(
        client: &ComputeClient<<R as Runtime>::Server, <R as Runtime>::Channel>,
        cube_dim: CubeDim,
        cube_count: CubeCount,
        input: InputRuntimeArg<'a, MS, R>,
        output: OutputRuntimeArg<'a, MS, R>,
        size_k: ScalarArg<u32>,
        config: Self::Config,
    ) {
        unsafe {
            // Forward to the shared batch kernel entry point, selecting this
            // family (`Self`) as the batch matmul implementation.
            super::matmul::launch_unchecked::<Args<MS>, EI<MS>, ES<MS>, EA<MS>, EO<MS>, Self, R>(
                client, cube_count, cube_dim, input, output, size_k, config,
            );
        }
    }
}
80
/// Executes matrix multiplication at the batch level,
/// assigning each cube to a single global matmul.
///
/// Note: This algorithm requires one cube per global matmul;
/// insufficient cubes will result in incomplete computations.
pub struct OneToOneMatmul<MP: MatmulPrecision, GMM: GlobalMatmul<MP>, C: CubeDispatch> {
    // Zero-sized markers: precision, global matmul and dispatch strategy are
    // purely compile-time choices.
    _mp: PhantomData<MP>,
    _gmm: PhantomData<GMM>,
    _c: PhantomData<C>,
}
91
#[cube]
impl<MP: MatmulPrecision, GMM: GlobalMatmul<MP>, C: CubeDispatch> BatchMatmul<MP>
    for OneToOneMatmul<MP, GMM, C>
{
    type Config = Config<GMM::Config, C>;

    /// Runs one global matmul for this cube, at the (row, col, batch) region
    /// selected by the dispatch strategy `C`.
    fn execute(
        lhs: VirtualTensor<MP::EI>,
        rhs: VirtualTensor<MP::EI>,
        out: VirtualTensor<MP::EO, ReadWrite>,
        size_k: u32,
        quantization: CubeOption<Quantization<MP>>,
        #[comptime] config: Self::Config,
    ) {
        // Map this cube's position to a tile of the output: x selects rows
        // (Lhs tiling), y selects columns (Rhs tiling).
        let (x_index, y_index) = C::x_y_indices();
        let x_offset = x_index * config.tiling_dimensions(Ident::Lhs).total_row();
        let y_offset = y_index * config.tiling_dimensions(Ident::Rhs).total_col();
        let nth_batch = C::batch_index();
        // One-to-one: each cube covers the entire k dimension itself.
        let k_range = (0, size_k);

        let gmm_config = config.to_gmm_config();

        gmm_execute::<MP, GMM>(
            lhs,
            rhs,
            out,
            x_offset,
            y_offset,
            nth_batch,
            // Fresh accumulator per cube; no partial results are shared.
            &mut GMM::init_accumulator(gmm_config),
            k_range,
            quantization,
            gmm_config,
        );
    }
}
128
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
/// Configuration for the OneToOneBatchMatmul
pub struct Config<G: global::GlobalConfig, C: CubeDispatch> {
    // Underlying global matmul configuration; most queries delegate to it.
    gmm_config: G,
    // Static cube count (x, y, z) captured at config creation; used to bound
    // max_m/max_n/max_batches.
    cube_count: (u32, u32, u32),
    // Whether the matmul operates on quantized inputs.
    quantized: bool,
    _c: PhantomData<C>,
}
137
138impl<G: global::GlobalConfig, C: CubeDispatch> batch::BatchConfig for Config<G, C> {
139    type GmmConfig = G;
140
141    fn to_gmm_config(&self) -> Self::GmmConfig {
142        self.gmm_config
143    }
144
145    fn tiling_dimensions(&self, ident: Ident) -> TilingDimensions {
146        self.gmm_config.tiling_dimensions(ident)
147    }
148
149    fn max_m(&self) -> u32 {
150        C::max_x(self.cube_count) * self.tiling_dimensions(Ident::Out).total_row()
151    }
152
153    fn max_n(&self) -> u32 {
154        C::max_y(self.cube_count) * self.tiling_dimensions(Ident::Out).total_col()
155    }
156
157    fn max_batches(&self) -> u32 {
158        C::max_batches(self.cube_count)
159    }
160
161    fn quantized(&self) -> bool {
162        self.quantized
163    }
164}
165
166impl<G: global::GlobalConfig, C: CubeDispatch> MatmulConfig for Config<G, C> {}
167
impl<G: global::GlobalConfig, C: CubeDispatch> Config<G, C> {
    /// Creates a new batch config from the global matmul config, the static
    /// cube count `(x, y, z)` used to bound problem sizes, and the
    /// quantization flag.
    pub fn new(gmm_config: G, cube_count: (u32, u32, u32), quantized: bool) -> Self {
        Self {
            gmm_config,
            cube_count,
            quantized,
            _c: PhantomData,
        }
    }
}
177}