cubecl_matmul/components/batch/
base.rs1use crate::components::{
2 AvailableLineSizes, InputRuntimeArg, MatmulLineSizes, MatmulPrecision, MatmulProblem,
3 MatmulSelection, MatmulSpec, OutputRuntimeArg, TilingScheme,
4 batch::{CubeCountInput, CubeCountInputArgs, HypercubeConfig},
5 error::MatmulSetupError,
6 global::{self, GlobalConfig as _, Quantization},
7};
8use cubecl_core as cubecl;
9use cubecl_core::prelude::*;
10use cubecl_std::{
11 CubeOption,
12 tensor::r#virtual::{ReadWrite, VirtualTensor},
13};
14use std::{fmt::Debug, hash::Hash};
15
16pub trait BatchMatmulFamily: 'static + Send + Sync {
18 type Matmul<MP: MatmulPrecision>: BatchMatmul<MP, Config = Self::Config>;
20
21 type Config: BatchConfig;
23
24 fn setup<MP: MatmulPrecision, R: Runtime>(
28 client: &ComputeClient<R::Server, R::Channel>,
29 problem: &MatmulProblem,
30 selection: &MatmulSelection,
31 line_sizes: &MatmulLineSizes,
32 ) -> Result<Self::Config, MatmulSetupError>;
33
34 unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>(
40 client: &ComputeClient<<R as Runtime>::Server, <R as Runtime>::Channel>,
41 cube_dim: CubeDim,
42 cube_count: CubeCount,
43 input: InputRuntimeArg<'a, MS, R>,
44 output: OutputRuntimeArg<'a, MS, R>,
45 cube_count_input: CubeCountInputArgs<'a, R>,
46 config: Self::Config,
47 );
48
49 fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
53 available_line_sizes
54 }
55}
56
57#[cube]
58pub trait BatchMatmul<MP: MatmulPrecision>: 'static + Send + Sync {
76 type Config: BatchConfig;
77
78 fn execute(
80 lhs: VirtualTensor<MP::EI>,
81 rhs: VirtualTensor<MP::EI>,
82 out: VirtualTensor<MP::EO, ReadWrite>,
83 quantization: CubeOption<Quantization<MP>>,
84 cube_count_args: CubeCountInput,
85 #[comptime] config: Self::Config,
86 );
87}
88
89pub trait BatchConfig:
91 Copy + Clone + Eq + PartialEq + Hash + Debug + Send + Sync + 'static
92{
93 type GlobalConfig: global::GlobalConfig;
95
96 fn global_config(&self) -> Self::GlobalConfig;
98
99 fn quantized(&self) -> bool;
101
102 fn tiling_scheme(&self) -> TilingScheme {
104 self.global_config().tiling_scheme()
105 }
106
107 fn cube_dim(&self) -> CubeDim;
109
110 fn line_sizes(&self) -> MatmulLineSizes;
112
113 fn hypercube_config(&self) -> HypercubeConfig;
115
116 fn can_yield_extra_cubes(&self) -> bool;
118}