cubecl_linalg/matmul/components/batch/one_to_many.rs

1use std::marker::PhantomData;
2
3use crate::matmul::components::batch::span::{Span, SpanDim, SpanMatmul};
4use crate::matmul::components::global::GlobalMatmulFamily;
5use crate::matmul::components::global::Quantization;
6use crate::matmul::components::{
7    Args, EA, EI, EO, ES, InputRuntimeArg, InvalidConfigError, MatmulPrecision, MatmulProblem,
8    MatmulSpec, OutputRuntimeArg,
9};
10use crate::matmul::components::{
11    Ident, MatmulConfigFactory, MatmulLaunch, TilingDimensions, batch, config::MatmulConfig, global,
12};
13use crate::matmul::kernels::MatmulAvailabilityError;
14use cubecl_core as cubecl;
15use cubecl_core::prelude::*;
16use cubecl_std::CubeOption;
17use cubecl_std::tensor::r#virtual::{ReadWrite, VirtualTensor};
18
19use super::{BatchConfig as _, BatchMatmulFamily, CubeDispatch};
20
21pub struct OneToManyMatmulFamily<GMM: GlobalMatmulFamily, S: SpanMatmul, C: CubeDispatch> {
22    _gmm: PhantomData<GMM>,
23    _s: PhantomData<S>,
24    _c: PhantomData<C>,
25}
26
27impl<GMM: GlobalMatmulFamily, S: SpanMatmul, C: CubeDispatch> BatchMatmulFamily
28    for OneToManyMatmulFamily<GMM, S, C>
29{
30    type Matmul<MP: MatmulPrecision> = OneToManyMatmul<MP, GMM::Matmul<MP>, S, C>;
31}
32
33impl<GMM: GlobalMatmulFamily, S: SpanMatmul, C: CubeDispatch> MatmulConfigFactory
34    for OneToManyMatmulFamily<GMM, S, C>
35{
36    type Config = Config<GMM::Config, C>;
37    type Input = GMM::Input;
38
39    fn check_config(config: &Self::Config) -> Result<(), InvalidConfigError> {
40        GMM::check_config(&config.to_gmm_config())
41    }
42
43    fn check_availability<R: Runtime, MP: MatmulPrecision>(
44        client: &ComputeClient<R::Server, R::Channel>,
45        config: &Self::Config,
46    ) -> Result<(), MatmulAvailabilityError> {
47        GMM::check_availability::<R, MP>(client, &config.gmm_config)
48    }
49
50    fn make_config(
51        input: Self::Input,
52        problem: &MatmulProblem,
53        cube_dim: &CubeDim,
54        cube_count: &CubeCount,
55        quantized: bool,
56    ) -> Self::Config {
57        let gmm_config = GMM::make_config(input, problem, cube_dim, cube_count, quantized);
58        let cube_count = if let CubeCount::Static(x, y, z) = cube_count {
59            (*x, *y, *z)
60        } else {
61            panic!("Dynamic cube count unsupported")
62        };
63
64        Config::new(gmm_config, cube_count, quantized)
65    }
66}
67
68impl<GMM: GlobalMatmulFamily, S: SpanMatmul, C: CubeDispatch> MatmulLaunch
69    for OneToManyMatmulFamily<GMM, S, C>
70{
71    unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>(
72        client: &ComputeClient<<R as Runtime>::Server, <R as Runtime>::Channel>,
73        cube_dim: CubeDim,
74        cube_count: CubeCount,
75        input: InputRuntimeArg<'a, MS, R>,
76        output: OutputRuntimeArg<'a, MS, R>,
77        size_k: ScalarArg<u32>,
78        config: Self::Config,
79    ) {
80        unsafe {
81            super::matmul::launch_unchecked::<Args<MS>, EI<MS>, ES<MS>, EA<MS>, EO<MS>, Self, R>(
82                client, cube_count, cube_dim, input, output, size_k, config,
83            );
84        }
85    }
86}
87
88/// Executes matrix multiplication at the batch level,
89/// assigning each cube to handle multiple global matmuls.
90///
91/// The algorithm supports any number of cubes,
92/// looping as needed to process all data.
93pub struct OneToManyMatmul<
94    MP: MatmulPrecision,
95    GMM: global::GlobalMatmul<MP>,
96    S: SpanMatmul,
97    C: CubeDispatch,
98> {
99    _mp: PhantomData<MP>,
100    _gmm: PhantomData<GMM>,
101    _s: PhantomData<S>,
102    _c: PhantomData<C>,
103}
104
105#[cube]
106impl<MP: MatmulPrecision, GMM: global::GlobalMatmul<MP>, S: SpanMatmul, C: CubeDispatch>
107    batch::BatchMatmul<MP> for OneToManyMatmul<MP, GMM, S, C>
108{
109    type Config = Config<GMM::Config, C>;
110
111    fn execute(
112        lhs: VirtualTensor<MP::EI>,
113        rhs: VirtualTensor<MP::EI>,
114        out: VirtualTensor<MP::EO, ReadWrite>,
115        _size_k: u32,
116        quantization: CubeOption<Quantization<MP>>,
117        #[comptime] config: Self::Config,
118    ) {
119        let rank = out.rank();
120        let shape_x = out.shape(rank - 2);
121        let shape_y = out.shape(rank - 1);
122
123        let mut shape_z = 1;
124        for b in 0..rank - 2 {
125            shape_z *= out.shape(b);
126        }
127
128        let cubes_x = config.cube_count_x();
129        let cubes_y = config.cube_count_y();
130        let cubes_z = config.cube_count_batch();
131
132        let stage_x = config.tiling_dimensions(Ident::Out).total_row();
133        let stage_y = config.tiling_dimensions(Ident::Out).total_col();
134        let stage_z = 1;
135
136        let (x_index, y_index) = C::x_y_indices();
137        let batch_index = C::batch_index();
138
139        let span = Span::new(
140            SpanDim::new(shape_x, stage_x, x_index, cubes_x),
141            SpanDim::new(shape_y, stage_y, y_index, cubes_y),
142            SpanDim::new(shape_z, stage_z, batch_index, cubes_z),
143        );
144
145        let k_range = (0, lhs.shape(rank - 1));
146
147        let gmm_config = config.to_gmm_config();
148        let acc = GMM::init_accumulator(gmm_config);
149        S::execute::<MP, GMM>(lhs, rhs, out, span, acc, k_range, quantization, gmm_config);
150    }
151}
152
153#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
154/// Configuration for the OneToOneBatchMatmul
155pub struct Config<G: global::GlobalConfig, C: CubeDispatch> {
156    gmm_config: G,
157    cube_count: (u32, u32, u32),
158    quantized: bool,
159    _c: PhantomData<C>,
160}
161
162impl<G: global::GlobalConfig, C: CubeDispatch> batch::BatchConfig for Config<G, C> {
163    type GmmConfig = G;
164
165    fn to_gmm_config(&self) -> Self::GmmConfig {
166        self.gmm_config
167    }
168
169    fn tiling_dimensions(&self, ident: Ident) -> TilingDimensions {
170        self.gmm_config.tiling_dimensions(ident)
171    }
172
173    fn max_m(&self) -> u32 {
174        u32::maximum_value()
175    }
176
177    fn max_n(&self) -> u32 {
178        u32::maximum_value()
179    }
180
181    fn max_batches(&self) -> u32 {
182        u32::maximum_value()
183    }
184
185    fn quantized(&self) -> bool {
186        self.quantized
187    }
188}
189
190impl<G: global::GlobalConfig, C: CubeDispatch> MatmulConfig for Config<G, C> {}
191
192impl<G: global::GlobalConfig, C: CubeDispatch> Config<G, C> {
193    pub fn new(gmm_config: G, cube_count: (u32, u32, u32), quantized: bool) -> Self {
194        Self {
195            gmm_config,
196            cube_count,
197            quantized,
198            _c: PhantomData,
199        }
200    }
201
202    fn cube_count_x(&self) -> u32 {
203        C::max_x(self.cube_count)
204    }
205
206    fn cube_count_y(&self) -> u32 {
207        C::max_y(self.cube_count)
208    }
209
210    fn cube_count_batch(&self) -> u32 {
211        C::max_batches(self.cube_count)
212    }
213}