cubecl_linalg/matmul/components/batch/
base.rs

1use crate::matmul::components::{
2    Ident, MatmulLaunch, MatmulPrecision, Quantized, TilingDimensions,
3    config::MatmulConfig,
4    global::{
5        self, Quantization,
6        args::{self, MatmulArgs, TensorInput, TensorOutput},
7    },
8};
9use cubecl_core as cubecl;
10use cubecl_core::prelude::*;
11use cubecl_std::{
12    CubeOption,
13    tensor::r#virtual::{ReadWrite, VirtualTensor},
14};
15
/// A family of [matmuls](BatchMatmul) working with any [precision](MatmulPrecision).
///
/// The launch configuration must satisfy [BatchConfig], and every concrete
/// matmul instantiated from this family shares that same `Config` type.
pub trait BatchMatmulFamily: 'static + Send + Sync + MatmulLaunch<Config: BatchConfig> {
    /// The concrete batch matmul implementation for the precision `MP`.
    type Matmul<MP: MatmulPrecision>: BatchMatmul<MP, Config = Self::Config>;
}
20
#[cube]
/// Provides matrix multiplication operations at the batch level.
///
/// At the batch level,
///  - Inputs are whole tensors in global memory.
///  - All Cubes can collaborate to solve the problem
///  - Dimensions M, N and K can be arbitrary large,
///    as well as the number of batches.
///
/// # Assumptions
/// - Line sizes of the inputs evenly divide the dimension they are aligned with.
/// - Enough Cubes are launched to perform the whole computation.
///
/// # Safety
///
/// It is not assumed that the matmul's dimensions match its inputs dimensions perfectly.
/// It is therefore important to use an underlying global matmul that performs check bounds,
/// and to not launch more Cubes than necessary.
pub trait BatchMatmul<MP: MatmulPrecision>: 'static + Send + Sync {
    /// Compile-time configuration governing this batch matmul.
    type Config: BatchConfig;

    /// Performs batchwise matrix multiplication over tensors.
    ///
    /// - `lhs`/`rhs`: read-only input tensors with element type `MP::EI`.
    /// - `out`: output tensor with element type `MP::EO`, written in place.
    /// - `size_k`: runtime extent of the shared K dimension.
    /// - `quantization`: quantization parameters, or `None` when the matmul
    ///   is not quantized.
    /// - `config`: comptime configuration.
    fn execute(
        lhs: VirtualTensor<MP::EI>,
        rhs: VirtualTensor<MP::EI>,
        out: VirtualTensor<MP::EO, ReadWrite>,
        size_k: u32,
        quantization: CubeOption<Quantization<MP>>,
        #[comptime] config: Self::Config,
    );
}
52
/// Configuration for the [batch matmul](BatchMatmul) level.
pub trait BatchConfig: MatmulConfig {
    /// Underlying Global matmul config
    type GmmConfig: global::GlobalConfig;

    /// Convert itself to the underlying global matmul config
    fn to_gmm_config(&self) -> Self::GmmConfig;

    /// Returns the [TilingDimensions] for the given ident
    fn tiling_dimensions(&self, ident: Ident) -> TilingDimensions;

    /// Returns the largest m dimension supported with these configs
    fn max_m(&self) -> u32;

    /// Returns the largest n dimension supported with these configs
    fn max_n(&self) -> u32;

    /// Returns the largest number of batches supported with these configs
    fn max_batches(&self) -> u32;

    /// Returns true if the matmul is quantized.
    fn quantized(&self) -> bool;
}
76
/// Input handle type chosen by the [MatmulArgs] implementation for element type `EI`.
type Input<Args, EI> = <Args as MatmulArgs>::Input<EI>;
/// Output handle type chosen by the [MatmulArgs] implementation for element type `EO`.
type Output<Args, EO> = <Args as MatmulArgs>::Output<EO>;
79
#[cube(launch_unchecked)]
/// Kernel entry point for a batch matmul.
///
/// Wraps the raw `Args` input/output handles into [VirtualTensor]s and
/// dispatches to `BMM`'s matmul, with quantization state when
/// `config.quantized()` is set and without it otherwise.
///
/// Element type parameters: `EI` input, `ES` stage, `EA` accumulator,
/// `EO` output. The `Quantized` marker is appended to the precision tuple
/// only on the quantized path.
pub(crate) fn matmul<
    Args: MatmulArgs,
    EI: Numeric,
    ES: Numeric,
    EA: Numeric,
    EO: Numeric,
    BMM: BatchMatmulFamily,
>(
    inputs: &Input<Args, EI>,
    output: &mut Output<Args, EO>,
    size_k: u32,
    #[comptime] config: BMM::Config,
) {
    // Shared state from which both input views and the output view are derived.
    let mut state = Args::init_state(inputs, output);

    let lhs = TensorInput::<EI, EO, Args>::new(&state, args::TensorInputIdent::Lhs);
    let rhs = TensorInput::<EI, EO, Args>::new(&state, args::TensorInputIdent::Rhs);
    let mut out = TensorOutput::<EI, EO, Args>::new(&mut state);

    // Re-wrap as virtual tensors so the batch matmul stays agnostic to the
    // concrete `Args` tensor representation.
    let lhs = VirtualTensor::<EI>::new::<TensorInput<EI, EO, Args>>(&lhs);
    let rhs = VirtualTensor::<EI>::new::<TensorInput<EI, EO, Args>>(&rhs);
    let out = VirtualTensor::<EO, ReadWrite>::new::<TensorOutput<EI, EO, Args>>(&mut out);

    // Branch on the comptime config — presumably specialized at kernel
    // compile time by the cube macro (TODO confirm comptime branch folding).
    if config.quantized() {
        // Quantized path: fetch quantization state and pass it as Some(..).
        let quantization = Args::quantization::<(EI, ES, EA, EO, Quantized)>(&state);
        BMM::Matmul::<(EI, ES, EA, EO, Quantized)>::execute(
            lhs,
            rhs,
            out,
            size_k,
            CubeOption::new_Some(quantization),
            config,
        );
    } else {
        // Non-quantized path: same call shape, but no quantization state.
        BMM::Matmul::<(EI, ES, EA, EO)>::execute(
            lhs,
            rhs,
            out,
            size_k,
            CubeOption::new_None(),
            config,
        );
    };
}