cubecl_matmul/components/batch/partitioned_matmul/
setup.rs

1use std::marker::PhantomData;
2
3use crate::components::batch::partitioned_matmul::matmul::PartitionedBatchMatmul;
4use crate::components::batch::partitioned_matmul::partition::GlobalPartitionMatmul;
5use crate::components::batch::{BatchMatmulFamily, CubeCountInputArgs};
6use crate::components::global::GlobalMatmulFamily;
7use crate::components::{AccG, batch::entry_point::matmul};
8use crate::components::{AccS, batch::partitioned_matmul::config::PartitionedBatchConfig};
9use crate::components::{
10    Args, InputRuntimeArg, LhsG, LhsS, MatmulPrecision, MatmulProblem, MatmulSelection, MatmulSpec,
11    OutputRuntimeArg, RhsG, RhsS,
12};
13use crate::components::{MatmulLineSizes, MatmulSetupError};
14use cubecl_core::prelude::*;
15
16/// Simple partitioned batch matmul family for any precision
17pub struct PartitionedBatchMatmulFamily<GMM: GlobalMatmulFamily, S: GlobalPartitionMatmul> {
18    _gmm: PhantomData<GMM>,
19    _s: PhantomData<S>,
20}
21
22impl<GMM: GlobalMatmulFamily, S: GlobalPartitionMatmul> BatchMatmulFamily
23    for PartitionedBatchMatmulFamily<GMM, S>
24{
25    type Matmul<MP: MatmulPrecision> = PartitionedBatchMatmul<MP, GMM::Matmul<MP>, S>;
26    type Config = PartitionedBatchConfig<GMM::Config>;
27
28    fn setup<MP: MatmulPrecision, R: Runtime>(
29        client: &ComputeClient<R::Server>,
30        problem: &MatmulProblem,
31        selection: &MatmulSelection,
32        line_sizes: &MatmulLineSizes,
33    ) -> Result<Self::Config, MatmulSetupError> {
34        let global_config = GMM::setup::<MP, R>(client, problem, selection, line_sizes)?;
35
36        PartitionedBatchConfig::new(
37            global_config,
38            selection
39                .hypercube_selection
40                .to_hypercube_config(problem, client.properties().hardware.max_cube_count.clone()),
41        )
42        .validate(problem)
43    }
44
45    unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>(
46        client: &ComputeClient<<R as Runtime>::Server>,
47        cube_dim: CubeDim,
48        cube_count: CubeCount,
49        input: InputRuntimeArg<'a, MS, R>,
50        output: OutputRuntimeArg<'a, MS, R>,
51        cube_count_input: CubeCountInputArgs<'a, R>,
52        config: Self::Config,
53    ) {
54        unsafe {
55            matmul::launch_unchecked::<
56                Args<MS>,
57                LhsG<MS>,
58                RhsG<MS>,
59                AccG<MS>,
60                LhsS<MS>,
61                RhsS<MS>,
62                AccS<MS>,
63                Self,
64                R,
65            >(
66                client,
67                cube_count,
68                cube_dim,
69                input,
70                output,
71                cube_count_input,
72                config,
73            );
74        }
75    }
76}