cubecl_linalg/matmul/components/
base.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use super::{InputRuntimeArg, MatmulConfigFactory, MatmulSpec, OutputRuntimeArg};
5
6#[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)]
7pub struct MatmulSize {
8 pub m: u32,
9 pub n: u32,
10 pub k: u32,
11}
12
13#[derive(Debug)]
14pub struct MatmulSelection {
15 pub tile_shape: MatmulSize,
16 pub tile_count: MatmulSize,
17 pub plane_dim: u32,
18 pub rows_per_plane: u32,
19}
20
21pub trait MatmulLaunch: MatmulConfigFactory {
23 unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>(
29 client: &ComputeClient<<R as Runtime>::Server, <R as Runtime>::Channel>,
30 cube_dim: CubeDim,
31 cube_count: CubeCount,
32 input: InputRuntimeArg<'a, MS, R>,
33 output: OutputRuntimeArg<'a, MS, R>,
34 size_k: ScalarArg<u32>,
35 config: <Self as MatmulConfigFactory>::Config,
36 );
37}