cubecl_linalg/matmul/components/batch/
base.rs1use crate::matmul::components::{
2 Ident, MatmulLaunch, MatmulPrecision, Quantized, TilingDimensions,
3 config::MatmulConfig,
4 global::{
5 self, Quantization,
6 args::{self, MatmulArgs, TensorInput, TensorOutput},
7 },
8};
9use cubecl_core as cubecl;
10use cubecl_core::prelude::*;
11use cubecl_std::{
12 CubeOption,
13 tensor::r#virtual::{ReadWrite, VirtualTensor},
14};
15
16pub trait BatchMatmulFamily: 'static + Send + Sync + MatmulLaunch<Config: BatchConfig> {
18 type Matmul<MP: MatmulPrecision>: BatchMatmul<MP, Config = Self::Config>;
19}
20
21#[cube]
22pub trait BatchMatmul<MP: MatmulPrecision>: 'static + Send + Sync {
40 type Config: BatchConfig;
41
42 fn execute(
44 lhs: VirtualTensor<MP::EI>,
45 rhs: VirtualTensor<MP::EI>,
46 out: VirtualTensor<MP::EO, ReadWrite>,
47 size_k: u32,
48 quantization: CubeOption<Quantization<MP>>,
49 #[comptime] config: Self::Config,
50 );
51}
52
53pub trait BatchConfig: MatmulConfig {
55 type GmmConfig: global::GlobalConfig;
57
58 fn to_gmm_config(&self) -> Self::GmmConfig;
60
61 fn tiling_dimensions(&self, ident: Ident) -> TilingDimensions;
63
64 fn max_m(&self) -> u32;
66
67 fn max_n(&self) -> u32;
69
70 fn max_batches(&self) -> u32;
72
73 fn quantized(&self) -> bool;
75}
76
77type Input<Args, EI> = <Args as MatmulArgs>::Input<EI>;
78type Output<Args, EO> = <Args as MatmulArgs>::Output<EO>;
79
80#[cube(launch_unchecked)]
81pub(crate) fn matmul<
82 Args: MatmulArgs,
83 EI: Numeric,
84 ES: Numeric,
85 EA: Numeric,
86 EO: Numeric,
87 BMM: BatchMatmulFamily,
88>(
89 inputs: &Input<Args, EI>,
90 output: &mut Output<Args, EO>,
91 size_k: u32,
92 #[comptime] config: BMM::Config,
93) {
94 let mut state = Args::init_state(inputs, output);
95
96 let lhs = TensorInput::<EI, EO, Args>::new(&state, args::TensorInputIdent::Lhs);
97 let rhs = TensorInput::<EI, EO, Args>::new(&state, args::TensorInputIdent::Rhs);
98 let mut out = TensorOutput::<EI, EO, Args>::new(&mut state);
99
100 let lhs = VirtualTensor::<EI>::new::<TensorInput<EI, EO, Args>>(&lhs);
101 let rhs = VirtualTensor::<EI>::new::<TensorInput<EI, EO, Args>>(&rhs);
102 let out = VirtualTensor::<EO, ReadWrite>::new::<TensorOutput<EI, EO, Args>>(&mut out);
103
104 if config.quantized() {
105 let quantization = Args::quantization::<(EI, ES, EA, EO, Quantized)>(&state);
106 BMM::Matmul::<(EI, ES, EA, EO, Quantized)>::execute(
107 lhs,
108 rhs,
109 out,
110 size_k,
111 CubeOption::new_Some(quantization),
112 config,
113 );
114 } else {
115 BMM::Matmul::<(EI, ES, EA, EO)>::execute(
116 lhs,
117 rhs,
118 out,
119 size_k,
120 CubeOption::new_None(),
121 config,
122 );
123 };
124}