cubecl_matmul/components/batch/partitioned_matmul/
setup.rs1use std::marker::PhantomData;
2
3use crate::components::batch::entry_point::matmul;
4use crate::components::batch::partitioned_matmul::config::PartitionedBatchConfig;
5use crate::components::batch::partitioned_matmul::matmul::PartitionedBatchMatmul;
6use crate::components::batch::partitioned_matmul::partition::GlobalPartitionMatmul;
7use crate::components::batch::{BatchMatmulFamily, CubeCountInputArgs};
8use crate::components::global::GlobalMatmulFamily;
9use crate::components::{
10 Args, EA, EI, EO, ES, InputRuntimeArg, MatmulPrecision, MatmulProblem, MatmulSelection,
11 MatmulSpec, OutputRuntimeArg,
12};
13use crate::components::{MatmulLineSizes, MatmulSetupError};
14use cubecl_core::prelude::*;
15
16pub struct PartitionedBatchMatmulFamily<GMM: GlobalMatmulFamily, S: GlobalPartitionMatmul> {
18 _gmm: PhantomData<GMM>,
19 _s: PhantomData<S>,
20}
21
22impl<GMM: GlobalMatmulFamily, S: GlobalPartitionMatmul> BatchMatmulFamily
23 for PartitionedBatchMatmulFamily<GMM, S>
24{
25 type Matmul<MP: MatmulPrecision> = PartitionedBatchMatmul<MP, GMM::Matmul<MP>, S>;
26 type Config = PartitionedBatchConfig<GMM::Config>;
27
28 fn setup<MP: MatmulPrecision, R: Runtime>(
29 client: &ComputeClient<R::Server, R::Channel>,
30 problem: &MatmulProblem,
31 selection: &MatmulSelection,
32 line_sizes: &MatmulLineSizes,
33 ) -> Result<Self::Config, MatmulSetupError> {
34 let global_config = GMM::setup::<MP, R>(client, problem, selection, line_sizes)?;
35
36 PartitionedBatchConfig::new(
37 global_config,
38 selection
39 .hypercube_selection
40 .to_hypercube_config(problem, client.properties().hardware.max_cube_count.clone()),
41 )
42 .validate(problem)
43 }
44
45 unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>(
46 client: &ComputeClient<<R as Runtime>::Server, <R as Runtime>::Channel>,
47 cube_dim: CubeDim,
48 cube_count: CubeCount,
49 input: InputRuntimeArg<'a, MS, R>,
50 output: OutputRuntimeArg<'a, MS, R>,
51 cube_count_input: CubeCountInputArgs<'a, R>,
52 config: Self::Config,
53 ) {
54 unsafe {
55 matmul::launch_unchecked::<Args<MS>, EI<MS>, ES<MS>, EA<MS>, EO<MS>, Self, R>(
56 client,
57 cube_count,
58 cube_dim,
59 input,
60 output,
61 cube_count_input,
62 config,
63 );
64 }
65 }
66}