// cubecl_matmul/components/batch/partitioned_matmul/setup.rs

1use std::marker::PhantomData;
2
3use crate::components::batch::entry_point::matmul;
4use crate::components::batch::partitioned_matmul::config::PartitionedBatchConfig;
5use crate::components::batch::partitioned_matmul::matmul::PartitionedBatchMatmul;
6use crate::components::batch::partitioned_matmul::partition::GlobalPartitionMatmul;
7use crate::components::batch::{BatchMatmulFamily, CubeCountInputArgs};
8use crate::components::global::GlobalMatmulFamily;
9use crate::components::{
10    Args, EA, EI, EO, ES, InputRuntimeArg, MatmulPrecision, MatmulProblem, MatmulSelection,
11    MatmulSpec, OutputRuntimeArg,
12};
13use crate::components::{MatmulLineSizes, MatmulSetupError};
14use cubecl_core::prelude::*;
15
16/// Simple partitioned batch matmul family for any precision
17pub struct PartitionedBatchMatmulFamily<GMM: GlobalMatmulFamily, S: GlobalPartitionMatmul> {
18    _gmm: PhantomData<GMM>,
19    _s: PhantomData<S>,
20}
21
22impl<GMM: GlobalMatmulFamily, S: GlobalPartitionMatmul> BatchMatmulFamily
23    for PartitionedBatchMatmulFamily<GMM, S>
24{
25    type Matmul<MP: MatmulPrecision> = PartitionedBatchMatmul<MP, GMM::Matmul<MP>, S>;
26    type Config = PartitionedBatchConfig<GMM::Config>;
27
28    fn setup<MP: MatmulPrecision, R: Runtime>(
29        client: &ComputeClient<R::Server, R::Channel>,
30        problem: &MatmulProblem,
31        selection: &MatmulSelection,
32        line_sizes: &MatmulLineSizes,
33    ) -> Result<Self::Config, MatmulSetupError> {
34        let global_config = GMM::setup::<MP, R>(client, problem, selection, line_sizes)?;
35
36        PartitionedBatchConfig::new(
37            global_config,
38            selection
39                .hypercube_selection
40                .to_hypercube_config(problem, client.properties().hardware.max_cube_count.clone()),
41        )
42        .validate(problem)
43    }
44
45    unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>(
46        client: &ComputeClient<<R as Runtime>::Server, <R as Runtime>::Channel>,
47        cube_dim: CubeDim,
48        cube_count: CubeCount,
49        input: InputRuntimeArg<'a, MS, R>,
50        output: OutputRuntimeArg<'a, MS, R>,
51        cube_count_input: CubeCountInputArgs<'a, R>,
52        config: Self::Config,
53    ) {
54        unsafe {
55            matmul::launch_unchecked::<Args<MS>, EI<MS>, ES<MS>, EA<MS>, EO<MS>, Self, R>(
56                client,
57                cube_count,
58                cube_dim,
59                input,
60                output,
61                cube_count_input,
62                config,
63            );
64        }
65    }
66}