cubecl_matmul/components/batch/
base.rs

1use crate::components::{
2    AccG, AvailableLineSizes, InputRuntimeArg, LhsG, MatmulLineSizes, MatmulPrecision,
3    MatmulProblem, MatmulSelection, MatmulSpec, OutputRuntimeArg, RhsG, TilingScheme,
4    batch::{CubeCountInput, CubeCountInputArgs, HypercubeConfig},
5    error::MatmulSetupError,
6    global::{self, GlobalConfig as _},
7};
8use cubecl_core as cubecl;
9use cubecl_core::prelude::*;
10use cubecl_std::{
11    CubeOption,
12    tensor::{View, layout::Coords3d},
13};
14use std::{fmt::Debug, hash::Hash};
15
16/// A family of [matmuls](BatchMatmul) working with any [precision](MatmulPrecision).
17pub trait BatchMatmulFamily: 'static + Send + Sync {
18    /// The specific [BatchMatmul] implementation associated with this family.
19    type Matmul<MP: MatmulPrecision>: BatchMatmul<MP, Config = Self::Config>;
20
21    /// The configuration type associated with this matmul family.
22    type Config: BatchConfig;
23
24    /// Constructs the configuration based on the matmul problem, selection, and line sizes.
25    ///
26    /// This function may return an error if the configuration cannot be supported on the current runtime.
27    fn setup<MP: MatmulPrecision, R: Runtime>(
28        client: &ComputeClient<R::Server>,
29        problem: &MatmulProblem,
30        selection: &MatmulSelection,
31        line_sizes: &MatmulLineSizes,
32    ) -> Result<Self::Config, MatmulSetupError>;
33
34    /// Entry point
35    ///
36    /// # Safety
37    ///
38    /// Out-of-bounds can happen
39    unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>(
40        client: &ComputeClient<<R as Runtime>::Server>,
41        cube_dim: CubeDim,
42        cube_count: CubeCount,
43        input: InputRuntimeArg<'a, MS, R>,
44        output: OutputRuntimeArg<'a, MS, R>,
45        cube_count_input: CubeCountInputArgs<'a, R>,
46        config: Self::Config,
47    );
48
49    /// Filters out line sizes that are incompatible with this matmul family.
50    ///
51    /// By default, returns the input unchanged.
52    fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
53        available_line_sizes
54    }
55}
56
#[cube]
/// Provides matrix multiplication operations at the batch level.
///
/// At the batch level,
///  - Inputs are whole tensors in global memory.
///  - All Cubes are used to solve the problem.
///  - Dimensions M, N and K can be arbitrarily large,
///    as well as the number of batches.
///
/// # Assumptions
/// - Line sizes of the inputs evenly divide the dimension they are aligned with.
///
/// # Safety
///
/// - It is not assumed that the matmul's dimensions match its inputs' dimensions perfectly.
///   It is therefore important to use an underlying global matmul that performs bounds checks.
/// - It is accepted to launch more Cubes than necessary, provided that a [CubeCountInput]
///   stating the max cube position is supplied.
pub trait BatchMatmul<MP: MatmulPrecision>: 'static + Send + Sync {
    /// Configuration type used by this batch matmul.
    type Config: BatchConfig;

    /// Performs batchwise matrix multiplication over tensors.
    ///
    /// `a` and `b` are views over the Lhs and Rhs element types, `c` is an optional view
    /// over the accumulator element type, and `out` is the read-write view of the same
    /// element type. `config` is a comptime argument.
    ///
    /// NOTE(review): `c` is presumably an optional initial accumulator and `out` the
    /// destination of the result; confirm against the implementations.
    fn execute(
        a: View<Line<LhsG<MP>>, Coords3d>,
        b: View<Line<RhsG<MP>>, Coords3d>,
        c: CubeOption<View<Line<AccG<MP>>, Coords3d>>,
        out: View<Line<AccG<MP>>, Coords3d, ReadWrite>,
        cube_count_args: CubeCountInput,
        #[comptime] config: Self::Config,
    );
}
88
89/// Configuration for the [batch matmul](BatchMatmul) level.
90pub trait BatchConfig:
91    Copy + Clone + Eq + PartialEq + Hash + Debug + Send + Sync + 'static
92{
93    /// Underlying Global matmul config
94    type GlobalConfig: global::GlobalConfig;
95
96    /// Convert itself to the underlying global matmul config
97    fn global_config(&self) -> Self::GlobalConfig;
98
99    /// Returns the [TilingScheme]
100    fn tiling_scheme(&self) -> TilingScheme {
101        self.global_config().tiling_scheme()
102    }
103
104    /// Returns the [CubeDim]
105    fn cube_dim(&self) -> CubeDim;
106
107    /// Returns the line sizes for Lhs, Rhs and output
108    fn line_sizes(&self) -> MatmulLineSizes;
109
110    /// Returns the [HypercubeConfig]
111    fn hypercube_config(&self) -> HypercubeConfig;
112
113    /// Whether it may launch more cubes than the minimum required
114    fn can_yield_extra_cubes(&self) -> bool;
115}