cubecl_matmul/components/
selection.rs

1use cubecl_core::{Runtime, client::ComputeClient, flex32, prelude::CubePrimitive, tf32};
2
3use crate::components::{
4    MatmulElems, TilingScheme,
5    batch::HypercubeSelection,
6    global::{LoadSpecializationConfig, read::ReaderMode},
7    stage::PartitionBuffering,
8};
9
/// Fully-resolved configuration for a matmul kernel launch, produced by
/// [`MatmulSelectionBuilder`].
#[derive(Debug, Clone)]
pub struct MatmulSelection {
    /// Width of a plane — presumably the warp/subgroup lane count on the
    /// target hardware. TODO(review): confirm against runtime docs.
    pub plane_dim: u32,
    /// Tile / partition / stage sizing for the matmul.
    pub tiling_scheme: TilingScheme,
    /// Whether the matmul runs on quantized inputs.
    pub quantized: bool,
    /// Buffering strategy used at the stage-partition level.
    pub partition_buffering: PartitionBuffering,
    /// Whether loading jobs precompute values shared across tasks
    /// (see [`LoadingPrecomputeStrategy`]).
    pub loading_precompute_strategy: LoadingPrecomputeStrategy,
    /// Mode used by the global-memory readers.
    pub reader_mode: ReaderMode,
    /// Specialization configuration for load planes.
    pub load_specialization_config: LoadSpecializationConfig,
    /// Cube-dispatch (hypercube) selection for the batch level.
    pub hypercube_selection: HypercubeSelection,
}
21
22/// Modifies the given matmul element types based on the kind of accelerator the kernel is run on.
23pub fn adjust_dtypes<R: Runtime>(
24    client: &ComputeClient<R::Server>,
25    dtypes: &mut MatmulElems,
26    requires_accelerator: bool,
27) {
28    let f32_dtype = f32::as_type_native_unchecked();
29    let flex_dtype = flex32::as_type_native_unchecked();
30    let tf32_dtype = tf32::as_type_native_unchecked();
31    let f16_dtype = half::f16::as_type_native_unchecked();
32
33    if requires_accelerator {
34        if dtypes.lhs_global == f32_dtype
35            && dtypes.rhs_global == f32_dtype
36            && client.properties().supports_type(tf32_dtype)
37        {
38            dtypes.lhs_stage = tf32_dtype;
39            dtypes.rhs_stage = tf32_dtype;
40            dtypes.lhs_register = tf32_dtype;
41            dtypes.rhs_register = tf32_dtype;
42        } else if dtypes.lhs_global == flex_dtype
43            && dtypes.rhs_global == flex_dtype
44            && client.properties().supports_type(f16_dtype)
45        {
46            dtypes.lhs_stage = f16_dtype;
47            dtypes.rhs_stage = f16_dtype;
48            dtypes.lhs_register = f16_dtype;
49            dtypes.rhs_register = f16_dtype;
50        }
51    }
52}
53
54impl MatmulSelection {
55    pub fn builder(tiling_scheme: TilingScheme, plane_dim: u32) -> MatmulSelectionBuilder {
56        let hypercube_config = HypercubeSelection::builder(&tiling_scheme).build();
57        MatmulSelectionBuilder::new()
58            .tiling_scheme(tiling_scheme)
59            .hypercube_config(hypercube_config)
60            .plane_dim(plane_dim)
61    }
62}
63
/// Builder for [`MatmulSelection`].
///
/// `plane_dim`, `tiling_scheme` and `hypercube_selection` are required and
/// have no default; the remaining fields fall back to their `Default` impls.
pub struct MatmulSelectionBuilder {
    // Required: plane (warp/subgroup) width.
    plane_dim: Option<u32>,
    // Required: tile/partition/stage sizing.
    pub tiling_scheme: Option<TilingScheme>,
    // Required: batch-level cube dispatch selection.
    hypercube_selection: Option<HypercubeSelection>,
    // Optional: defaults to `false`.
    quantized: bool,
    // Optional: defaults per `PartitionBuffering::default()`.
    partition_buffering: PartitionBuffering,
    // Optional: defaults to `Never`.
    loading_precompute_strategy: LoadingPrecomputeStrategy,
    // Optional: defaults per `ReaderMode::default()`.
    reader_mode: ReaderMode,
    // Optional: defaults per `LoadSpecializationConfig::default()`.
    load_specialization_config: LoadSpecializationConfig,
}
74
75impl MatmulSelectionBuilder {
76    fn new() -> Self {
77        Self {
78            plane_dim: None,
79            tiling_scheme: None,
80            hypercube_selection: None,
81            quantized: false,
82            partition_buffering: PartitionBuffering::default(),
83            loading_precompute_strategy: LoadingPrecomputeStrategy::default(),
84            reader_mode: ReaderMode::default(),
85            load_specialization_config: LoadSpecializationConfig::default(),
86        }
87    }
88
89    pub fn plane_dim(mut self, plane_dim: u32) -> Self {
90        self.plane_dim = Some(plane_dim);
91        self
92    }
93
94    pub fn tiling_scheme(mut self, tiling_scheme: TilingScheme) -> Self {
95        self.tiling_scheme = Some(tiling_scheme);
96        self
97    }
98
99    pub fn hypercube_config(mut self, hypercube_config: HypercubeSelection) -> Self {
100        self.hypercube_selection = Some(hypercube_config);
101        self
102    }
103
104    pub fn quantized(mut self, quantized: bool) -> Self {
105        self.quantized = quantized;
106        self
107    }
108
109    pub fn partition_buffering(mut self, partition_buffering: PartitionBuffering) -> Self {
110        self.partition_buffering = partition_buffering;
111        self
112    }
113
114    pub fn loading_precompute_strategy(
115        mut self,
116        loading_precompute_strategy: LoadingPrecomputeStrategy,
117    ) -> Self {
118        self.loading_precompute_strategy = loading_precompute_strategy;
119        self
120    }
121
122    pub fn reader_mode(mut self, reader_mode: ReaderMode) -> Self {
123        self.reader_mode = reader_mode;
124        self
125    }
126
127    pub fn load_specialization_config(
128        mut self,
129        load_specialization_config: LoadSpecializationConfig,
130    ) -> Self {
131        self.load_specialization_config = load_specialization_config;
132        self
133    }
134
135    pub fn build(self) -> MatmulSelection {
136        MatmulSelection {
137            plane_dim: self.plane_dim.unwrap(),
138            tiling_scheme: self.tiling_scheme.unwrap(),
139            hypercube_selection: self.hypercube_selection.unwrap(),
140            quantized: self.quantized,
141            partition_buffering: self.partition_buffering,
142            loading_precompute_strategy: self.loading_precompute_strategy,
143            reader_mode: self.reader_mode,
144            load_specialization_config: self.load_specialization_config,
145        }
146    }
147}
148
/// Controls whether a plane handles one or several rows of the output.
#[derive(Debug, Clone, Copy, Default)]
pub enum MultiRowStrategy {
    /// Always one row per plane
    #[default]
    Never,
    /// Always multiple rows per plane (the `u32` is presumably the row count —
    /// confirm at the use site).
    Always(u32),
    /// Uses multiple rows if the `m` dimension of the matmul implies at least the minimum number of stages along `m`
    Adaptive { minimum_stage_count: u32 },
}
159
/// Controls whether loading jobs precompute shared values up front.
/// Convertible to `bool` (`Always` → `true`, `Never` → `false`).
#[derive(Default, Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum LoadingPrecomputeStrategy {
    /// Don't precompute anything in loading jobs
    #[default]
    Never,
    /// Precompute values that are shared across tasks
    Always,
}
168
169impl From<LoadingPrecomputeStrategy> for bool {
170    fn from(strategy: LoadingPrecomputeStrategy) -> Self {
171        match strategy {
172            LoadingPrecomputeStrategy::Always => true,
173            LoadingPrecomputeStrategy::Never => false,
174        }
175    }
176}