cubecl_matmul/components/
selection.rs

1use cubecl_core::{Runtime, client::ComputeClient, flex32, prelude::CubePrimitive, tf32};
2
3use crate::components::{
4    MatmulElems, TilingScheme,
5    batch::HypercubeSelection,
6    global::{LoadSpecializationConfig, read::ReaderMode},
7    stage::{PartitionBuffering, SwizzleMode},
8};
9
/// Fully-resolved set of choices driving a matmul kernel: the tiling scheme
/// plus per-component tuning knobs.
///
/// Constructed through [`MatmulSelection::builder`].
#[derive(Debug, Clone)]
pub struct MatmulSelection {
    /// Number of units in a plane on the target device.
    pub plane_dim: u32,
    /// Tile / partition / stage size configuration.
    pub tiling_scheme: TilingScheme,
    /// Swizzle modes applied to shared memory, one per operand role
    /// (lhs / rhs / acc / out).
    pub shared_swizzle: SwizzleConfig,
    /// Buffering strategy across stage partitions (e.g. single vs double).
    pub partition_buffering: PartitionBuffering,
    /// Whether loading jobs precompute values shared across tasks.
    pub loading_precompute_strategy: LoadingPrecomputeStrategy,
    /// Mode used by the global-memory readers.
    pub reader_mode: ReaderMode,
    /// Specialization of planes into load vs compute roles.
    pub load_specialization_config: LoadSpecializationConfig,
    /// Batch-level cube dispatch (hypercube) selection.
    pub hypercube_selection: HypercubeSelection,
}
21
22/// Modifies the given matmul element types based on the kind of accelerator the kernel is run on.
23pub fn adjust_dtypes<R: Runtime>(
24    client: &ComputeClient<R>,
25    dtypes: &mut MatmulElems,
26    requires_accelerator: bool,
27) {
28    let f32_dtype = f32::as_type_native_unchecked();
29    let flex_dtype = flex32::as_type_native_unchecked();
30    let tf32_dtype = tf32::as_type_native_unchecked();
31    let f16_dtype = half::f16::as_type_native_unchecked();
32
33    if requires_accelerator {
34        if dtypes.lhs_global == f32_dtype
35            && dtypes.rhs_global == f32_dtype
36            && client.properties().supports_type(tf32_dtype)
37        {
38            dtypes.lhs_stage = tf32_dtype;
39            dtypes.rhs_stage = tf32_dtype;
40            dtypes.lhs_register = tf32_dtype;
41            dtypes.rhs_register = tf32_dtype;
42        } else if dtypes.lhs_global == flex_dtype
43            && dtypes.rhs_global == flex_dtype
44            && client.properties().supports_type(f16_dtype)
45        {
46            dtypes.lhs_stage = f16_dtype;
47            dtypes.rhs_stage = f16_dtype;
48            dtypes.lhs_register = f16_dtype;
49            dtypes.rhs_register = f16_dtype;
50        }
51    }
52}
53
/// Shared-memory swizzle modes, one per operand role.
///
/// Defaults to `SwizzleMode::default()` for every role.
#[derive(Default, Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct SwizzleConfig {
    /// Swizzle applied to the lhs stage.
    pub lhs: SwizzleMode,
    /// Swizzle applied to the rhs stage.
    pub rhs: SwizzleMode,
    /// Swizzle applied to the accumulator stage.
    pub acc: SwizzleMode,
    /// Swizzle applied to the output stage.
    pub out: SwizzleMode,
}
61
62impl MatmulSelection {
63    pub fn builder(tiling_scheme: TilingScheme, plane_dim: u32) -> MatmulSelectionBuilder {
64        let hypercube_config = HypercubeSelection::builder(&tiling_scheme).build();
65        MatmulSelectionBuilder::new()
66            .tiling_scheme(tiling_scheme)
67            .hypercube_config(hypercube_config)
68            .plane_dim(plane_dim)
69    }
70}
71
/// Builder for [`MatmulSelection`].
///
/// `plane_dim`, `tiling_scheme`, and `hypercube_selection` are required;
/// all other fields fall back to their `Default` values.
pub struct MatmulSelectionBuilder {
    // Required; set via `plane_dim()`.
    plane_dim: Option<u32>,
    // Required; set via `tiling_scheme()`.
    // NOTE(review): this is the only `pub` field while the rest go through
    // setters — presumably some caller reads it directly; confirm before
    // making it private for consistency.
    pub tiling_scheme: Option<TilingScheme>,
    // Optional; defaults to no swizzling.
    shared_swizzle: SwizzleConfig,
    // Required; set via `hypercube_config()`.
    hypercube_selection: Option<HypercubeSelection>,
    // Optional tuning knobs, each defaulting to the type's `Default`.
    partition_buffering: PartitionBuffering,
    loading_precompute_strategy: LoadingPrecomputeStrategy,
    reader_mode: ReaderMode,
    load_specialization_config: LoadSpecializationConfig,
}
82
83impl MatmulSelectionBuilder {
84    fn new() -> Self {
85        Self {
86            plane_dim: None,
87            tiling_scheme: None,
88            shared_swizzle: Default::default(),
89            hypercube_selection: None,
90            partition_buffering: PartitionBuffering::default(),
91            loading_precompute_strategy: LoadingPrecomputeStrategy::default(),
92            reader_mode: ReaderMode::default(),
93            load_specialization_config: LoadSpecializationConfig::default(),
94        }
95    }
96
97    pub fn plane_dim(mut self, plane_dim: u32) -> Self {
98        self.plane_dim = Some(plane_dim);
99        self
100    }
101
102    pub fn tiling_scheme(mut self, tiling_scheme: TilingScheme) -> Self {
103        self.tiling_scheme = Some(tiling_scheme);
104        self
105    }
106
107    pub fn shared_swizzle(mut self, swizzle: SwizzleConfig) -> Self {
108        self.shared_swizzle = swizzle;
109        self
110    }
111
112    pub fn hypercube_config(mut self, hypercube_config: HypercubeSelection) -> Self {
113        self.hypercube_selection = Some(hypercube_config);
114        self
115    }
116
117    pub fn partition_buffering(mut self, partition_buffering: PartitionBuffering) -> Self {
118        self.partition_buffering = partition_buffering;
119        self
120    }
121
122    pub fn loading_precompute_strategy(
123        mut self,
124        loading_precompute_strategy: LoadingPrecomputeStrategy,
125    ) -> Self {
126        self.loading_precompute_strategy = loading_precompute_strategy;
127        self
128    }
129
130    pub fn reader_mode(mut self, reader_mode: ReaderMode) -> Self {
131        self.reader_mode = reader_mode;
132        self
133    }
134
135    pub fn load_specialization_config(
136        mut self,
137        load_specialization_config: LoadSpecializationConfig,
138    ) -> Self {
139        self.load_specialization_config = load_specialization_config;
140        self
141    }
142
143    pub fn build(self) -> MatmulSelection {
144        MatmulSelection {
145            plane_dim: self.plane_dim.unwrap(),
146            tiling_scheme: self.tiling_scheme.unwrap(),
147            shared_swizzle: self.shared_swizzle,
148            hypercube_selection: self.hypercube_selection.unwrap(),
149            partition_buffering: self.partition_buffering,
150            loading_precompute_strategy: self.loading_precompute_strategy,
151            reader_mode: self.reader_mode,
152            load_specialization_config: self.load_specialization_config,
153        }
154    }
155}
156
/// Strategy for how many rows each plane processes.
#[derive(Debug, Clone, Copy, Default)]
pub enum MultiRowStrategy {
    /// Always one row per plane
    #[default]
    Never,
    /// Always multiple rows per plane
    ///
    /// The payload is the row count per plane.
    Always(u32),
    /// Uses multiple rows if the `m` dimension of the matmul implies at least the minimum number of stages along `m`
    Adaptive {
        /// Threshold: minimum number of stages along `m` required to enable
        /// multiple rows per plane.
        minimum_stage_count: u32,
    },
}
167
/// Whether loading jobs precompute values that are shared across tasks.
///
/// Converts into `bool` (`Always` → `true`, `Never` → `false`).
#[derive(Default, Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum LoadingPrecomputeStrategy {
    /// Don't precompute anything in loading jobs
    #[default]
    Never,
    /// Precompute values that are shared across tasks
    Always,
}
176
177impl From<LoadingPrecomputeStrategy> for bool {
178    fn from(strategy: LoadingPrecomputeStrategy) -> Self {
179        match strategy {
180            LoadingPrecomputeStrategy::Always => true,
181            LoadingPrecomputeStrategy::Never => false,
182        }
183    }
184}