cubecl_linalg/matmul/components/
base.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use super::{InputRuntimeArg, MatmulConfigFactory, MatmulSpec, OutputRuntimeArg};
5
6#[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)]
7pub struct MatmulSize {
8    pub m: u32,
9    pub n: u32,
10    pub k: u32,
11}
12
13#[derive(Debug)]
14pub struct MatmulSelection {
15    pub tile_shape: MatmulSize,
16    pub tile_count: MatmulSize,
17    pub plane_dim: u32,
18    pub rows_per_plane: u32,
19}
20
21/// Provides launch entry point to solve a matmul
22pub trait MatmulLaunch: MatmulConfigFactory {
23    /// Entry point
24    ///
25    /// # Safety
26    ///
27    /// Out-of-bounds can happen
28    unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>(
29        client: &ComputeClient<<R as Runtime>::Server, <R as Runtime>::Channel>,
30        cube_dim: CubeDim,
31        cube_count: CubeCount,
32        input: InputRuntimeArg<'a, MS, R>,
33        output: OutputRuntimeArg<'a, MS, R>,
34        size_k: ScalarArg<u32>,
35        config: <Self as MatmulConfigFactory>::Config,
36    );
37}