// cubecl_matmul/components/batch/base.rs

use crate::components::{
    AvailableLineSizes, InputRuntimeArg, MatmulLineSizes, MatmulPrecision, MatmulProblem,
    MatmulSelection, MatmulSpec, OutputRuntimeArg, TilingScheme,
    batch::{CubeCountInput, CubeCountInputArgs, HypercubeConfig},
    error::MatmulSetupError,
    global::{self, GlobalConfig as _, Quantization},
};
use cubecl_core as cubecl;
use cubecl_core::prelude::*;
use cubecl_std::{
    CubeOption,
    tensor::r#virtual::{ReadWrite, VirtualTensor},
};
use std::{fmt::Debug, hash::Hash};

16/// A family of [matmuls](BatchMatmul) working with any [precision](MatmulPrecision).
17pub trait BatchMatmulFamily: 'static + Send + Sync {
18    /// The specific [BatchMatmul] implementation associated with this family.
19    type Matmul<MP: MatmulPrecision>: BatchMatmul<MP, Config = Self::Config>;
20
21    /// The configuration type associated with this matmul family.
22    type Config: BatchConfig;
23
24    /// Constructs the configuration based on the matmul problem, selection, and line sizes.
25    ///
26    /// This function may return an error if the configuration cannot be supported on the current runtime.
27    fn setup<MP: MatmulPrecision, R: Runtime>(
28        client: &ComputeClient<R::Server, R::Channel>,
29        problem: &MatmulProblem,
30        selection: &MatmulSelection,
31        line_sizes: &MatmulLineSizes,
32    ) -> Result<Self::Config, MatmulSetupError>;
33
34    /// Entry point
35    ///
36    /// # Safety
37    ///
38    /// Out-of-bounds can happen
39    unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>(
40        client: &ComputeClient<<R as Runtime>::Server, <R as Runtime>::Channel>,
41        cube_dim: CubeDim,
42        cube_count: CubeCount,
43        input: InputRuntimeArg<'a, MS, R>,
44        output: OutputRuntimeArg<'a, MS, R>,
45        cube_count_input: CubeCountInputArgs<'a, R>,
46        config: Self::Config,
47    );
48
49    /// Filters out line sizes that are incompatible with this matmul family.
50    ///
51    /// By default, returns the input unchanged.
52    fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
53        available_line_sizes
54    }
55}
56
57#[cube]
58/// Provides matrix multiplication operations at the batch level.
59///
60/// At the batch level,
61///  - Inputs are whole tensors in global memory.
62///  - All Cubes are used to solve the problem
63///  - Dimensions M, N and K can be arbitrary large,
64///    as well as the number of batches.
65///
66/// # Assumptions
67/// - Line sizes of the inputs evenly divide the dimension they are aligned with.
68///
69/// # Safety
70///
71/// - It is not assumed that the matmul's dimensions match its inputs dimensions perfectly.
72///   It is therefore important to use an underlying global matmul that performs check bounds,
73/// - It is accepted to launch more Cube than necessary, providing a CubeCountInput that states
74///   the max cube position
75pub trait BatchMatmul<MP: MatmulPrecision>: 'static + Send + Sync {
76    type Config: BatchConfig;
77
78    /// Performs batchwise matrix multiplication over tensors.
79    fn execute(
80        lhs: VirtualTensor<MP::EI>,
81        rhs: VirtualTensor<MP::EI>,
82        out: VirtualTensor<MP::EO, ReadWrite>,
83        quantization: CubeOption<Quantization<MP>>,
84        cube_count_args: CubeCountInput,
85        #[comptime] config: Self::Config,
86    );
87}
88
89/// Configuration for the [batch matmul](BatchMatmul) level.
90pub trait BatchConfig:
91    Copy + Clone + Eq + PartialEq + Hash + Debug + Send + Sync + 'static
92{
93    /// Underlying Global matmul config
94    type GlobalConfig: global::GlobalConfig;
95
96    /// Convert itself to the underlying global matmul config
97    fn global_config(&self) -> Self::GlobalConfig;
98
99    /// Returns true if the matmul is quantized.
100    fn quantized(&self) -> bool;
101
102    /// Returns the [TilingScheme]
103    fn tiling_scheme(&self) -> TilingScheme {
104        self.global_config().tiling_scheme()
105    }
106
107    /// Returns the [CubeDim]
108    fn cube_dim(&self) -> CubeDim;
109
110    /// Returns the line sizes for Lhs, Rhs and output
111    fn line_sizes(&self) -> MatmulLineSizes;
112
113    /// Returns the [HypercubeConfig]
114    fn hypercube_config(&self) -> HypercubeConfig;
115
116    /// Whether it may launch more cubes than the minimum required
117    fn can_yield_extra_cubes(&self) -> bool;
118}