cubecl_matmul/components/global/shared.rs

1use crate::components::{
2    InputIdent, MatmulLineSizes, TilingScheme,
3    error::MatmulSetupError,
4    global::{GlobalConfig, multi_stage::LoadMaxRoundPlaneCount},
5};
6
7pub(crate) fn shared_global_config_validation<G: GlobalConfig>(
8    config: G,
9) -> Result<G, MatmulSetupError> {
10    #[cfg(target_os = "macos")]
11    {
12        let cube_dim = config.cube_dim();
13        if cube_dim.num_elems() >= 512 {
14            use crate::components::error::MatmulAvailabilityError;
15
16            return Err(MatmulSetupError::Unavailable(
17                MatmulAvailabilityError::CubeDimTooBig(cube_dim),
18            ));
19        }
20    }
21
22    Ok(config)
23}
24
/// Maximal number of planes each loader can handle to divide its workload evenly.
///
/// Plain data (two `u32`s), so it derives the common value traits: it can be
/// copied freely, compared, and printed for diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MaxLoaderPlanes {
    /// Maximal plane count for the Lhs loader.
    pub lhs: u32,
    /// Maximal plane count for the Rhs loader.
    pub rhs: u32,
}

31impl MaxLoaderPlanes {
32    /// Create a MaxLoaderPlanes
33    pub fn new<LL: LoadMaxRoundPlaneCount, RL: LoadMaxRoundPlaneCount>(
34        tiling_scheme: &TilingScheme,
35        line_sizes: &MatmulLineSizes,
36        plane_dim: u32,
37    ) -> Self {
38        MaxLoaderPlanes {
39            lhs: LL::max_round_plane_count(
40                tiling_scheme,
41                InputIdent::Lhs,
42                line_sizes.lhs,
43                plane_dim,
44            ),
45            rhs: RL::max_round_plane_count(
46                tiling_scheme,
47                InputIdent::Rhs,
48                line_sizes.rhs,
49                plane_dim,
50            ),
51        }
52    }
53}