cubecl_matmul/components/
selection.rs

1use cubecl_core::{Runtime, client::ComputeClient, flex32, prelude::CubePrimitive, tf32};
2
3use crate::components::{
4    MatmulElems, TilingScheme,
5    batch::HypercubeSelection,
6    global::{LoadSpecializationConfig, read::ReaderMode},
7    stage::{PartitionBuffering, SwizzleMode},
8};
9
/// Fully-resolved set of choices for a matmul kernel, produced by
/// [`MatmulSelectionBuilder::build`].
#[derive(Debug, Clone)]
pub struct MatmulSelection {
    // Number of lanes in a plane — presumably the subgroup/warp width; confirm per backend.
    pub plane_dim: u32,
    // Tile/partition/stage sizing used by the kernel.
    pub tiling_scheme: TilingScheme,
    // Whether the matmul runs on quantized inputs.
    pub quantized: bool,
    // Swizzle modes for each operand's shared buffer (see [`SwizzleConfig`]).
    pub shared_swizzle: SwizzleConfig,
    // Single vs. double buffering at the partition level.
    pub partition_buffering: PartitionBuffering,
    // Whether loading jobs precompute shared values (see [`LoadingPrecomputeStrategy`]).
    pub loading_precompute_strategy: LoadingPrecomputeStrategy,
    // Reader mode used by global-memory reads.
    pub reader_mode: ReaderMode,
    // Plane-specialization settings for loading.
    pub load_specialization_config: LoadSpecializationConfig,
    // Hypercube (batch-level CUBE distribution) selection.
    pub hypercube_selection: HypercubeSelection,
}
22
23/// Modifies the given matmul element types based on the kind of accelerator the kernel is run on.
24pub fn adjust_dtypes<R: Runtime>(
25    client: &ComputeClient<R::Server>,
26    dtypes: &mut MatmulElems,
27    requires_accelerator: bool,
28) {
29    let f32_dtype = f32::as_type_native_unchecked();
30    let flex_dtype = flex32::as_type_native_unchecked();
31    let tf32_dtype = tf32::as_type_native_unchecked();
32    let f16_dtype = half::f16::as_type_native_unchecked();
33
34    if requires_accelerator {
35        if dtypes.lhs_global == f32_dtype
36            && dtypes.rhs_global == f32_dtype
37            && client.properties().supports_type(tf32_dtype)
38        {
39            dtypes.lhs_stage = tf32_dtype;
40            dtypes.rhs_stage = tf32_dtype;
41            dtypes.lhs_register = tf32_dtype;
42            dtypes.rhs_register = tf32_dtype;
43        } else if dtypes.lhs_global == flex_dtype
44            && dtypes.rhs_global == flex_dtype
45            && client.properties().supports_type(f16_dtype)
46        {
47            dtypes.lhs_stage = f16_dtype;
48            dtypes.rhs_stage = f16_dtype;
49            dtypes.lhs_register = f16_dtype;
50            dtypes.rhs_register = f16_dtype;
51        }
52    }
53}
54
/// Per-operand swizzle modes, one for each buffer role of the matmul.
#[derive(Default, Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct SwizzleConfig {
    // Swizzle mode for the LHS buffer.
    pub lhs: SwizzleMode,
    // Swizzle mode for the RHS buffer.
    pub rhs: SwizzleMode,
    // Swizzle mode for the accumulator buffer.
    pub acc: SwizzleMode,
    // Swizzle mode for the output buffer.
    pub out: SwizzleMode,
}
62
63impl MatmulSelection {
64    pub fn builder(tiling_scheme: TilingScheme, plane_dim: u32) -> MatmulSelectionBuilder {
65        let hypercube_config = HypercubeSelection::builder(&tiling_scheme).build();
66        MatmulSelectionBuilder::new()
67            .tiling_scheme(tiling_scheme)
68            .hypercube_config(hypercube_config)
69            .plane_dim(plane_dim)
70    }
71}
72
/// Builder for [`MatmulSelection`].
///
/// `plane_dim`, `tiling_scheme` and `hypercube_selection` are required and must be
/// set before calling `build`; every other setting has a default.
pub struct MatmulSelectionBuilder {
    // Required; `build` panics if unset.
    plane_dim: Option<u32>,
    // Required; `build` panics if unset.
    // NOTE(review): this is the only `pub` field — confirm that direct external
    // mutation (bypassing the setter) is intentional.
    pub tiling_scheme: Option<TilingScheme>,
    shared_swizzle: SwizzleConfig,
    // Required; `build` panics if unset.
    hypercube_selection: Option<HypercubeSelection>,
    quantized: bool,
    partition_buffering: PartitionBuffering,
    loading_precompute_strategy: LoadingPrecomputeStrategy,
    reader_mode: ReaderMode,
    load_specialization_config: LoadSpecializationConfig,
}
84
85impl MatmulSelectionBuilder {
86    fn new() -> Self {
87        Self {
88            plane_dim: None,
89            tiling_scheme: None,
90            shared_swizzle: Default::default(),
91            hypercube_selection: None,
92            quantized: false,
93            partition_buffering: PartitionBuffering::default(),
94            loading_precompute_strategy: LoadingPrecomputeStrategy::default(),
95            reader_mode: ReaderMode::default(),
96            load_specialization_config: LoadSpecializationConfig::default(),
97        }
98    }
99
100    pub fn plane_dim(mut self, plane_dim: u32) -> Self {
101        self.plane_dim = Some(plane_dim);
102        self
103    }
104
105    pub fn tiling_scheme(mut self, tiling_scheme: TilingScheme) -> Self {
106        self.tiling_scheme = Some(tiling_scheme);
107        self
108    }
109
110    pub fn shared_swizzle(mut self, swizzle: SwizzleConfig) -> Self {
111        self.shared_swizzle = swizzle;
112        self
113    }
114
115    pub fn hypercube_config(mut self, hypercube_config: HypercubeSelection) -> Self {
116        self.hypercube_selection = Some(hypercube_config);
117        self
118    }
119
120    pub fn quantized(mut self, quantized: bool) -> Self {
121        self.quantized = quantized;
122        self
123    }
124
125    pub fn partition_buffering(mut self, partition_buffering: PartitionBuffering) -> Self {
126        self.partition_buffering = partition_buffering;
127        self
128    }
129
130    pub fn loading_precompute_strategy(
131        mut self,
132        loading_precompute_strategy: LoadingPrecomputeStrategy,
133    ) -> Self {
134        self.loading_precompute_strategy = loading_precompute_strategy;
135        self
136    }
137
138    pub fn reader_mode(mut self, reader_mode: ReaderMode) -> Self {
139        self.reader_mode = reader_mode;
140        self
141    }
142
143    pub fn load_specialization_config(
144        mut self,
145        load_specialization_config: LoadSpecializationConfig,
146    ) -> Self {
147        self.load_specialization_config = load_specialization_config;
148        self
149    }
150
151    pub fn build(self) -> MatmulSelection {
152        MatmulSelection {
153            plane_dim: self.plane_dim.unwrap(),
154            tiling_scheme: self.tiling_scheme.unwrap(),
155            shared_swizzle: self.shared_swizzle,
156            hypercube_selection: self.hypercube_selection.unwrap(),
157            quantized: self.quantized,
158            partition_buffering: self.partition_buffering,
159            loading_precompute_strategy: self.loading_precompute_strategy,
160            reader_mode: self.reader_mode,
161            load_specialization_config: self.load_specialization_config,
162        }
163    }
164}
165
/// Strategy controlling whether a plane handles one or several rows.
#[derive(Debug, Clone, Copy, Default)]
pub enum MultiRowStrategy {
    /// Always one row per plane
    #[default]
    Never,
    /// Always multiple rows per plane
    // NOTE(review): the `u32` payload is presumably the row count — confirm at use site.
    Always(u32),
    /// Uses multiple rows if the `m` dimension of the matmul implies at least the minimum number of stages along `m`
    Adaptive { minimum_stage_count: u32 },
}
176
/// Whether loading jobs precompute values shared across their tasks.
/// Convertible to `bool` (`Always` → `true`, `Never` → `false`).
#[derive(Default, Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum LoadingPrecomputeStrategy {
    /// Don't precompute anything in loading jobs
    #[default]
    Never,
    /// Precompute values that are shared across tasks
    Always,
}
185
186impl From<LoadingPrecomputeStrategy> for bool {
187    fn from(strategy: LoadingPrecomputeStrategy) -> Self {
188        match strategy {
189            LoadingPrecomputeStrategy::Always => true,
190            LoadingPrecomputeStrategy::Never => false,
191        }
192    }
193}