cubecl_matmul/components/
selection.rs1use cubecl_core::{Runtime, client::ComputeClient, flex32, prelude::CubePrimitive, tf32};
2
3use crate::components::{
4 MatmulElems, TilingScheme,
5 batch::HypercubeSelection,
6 global::{LoadSpecializationConfig, read::ReaderMode},
7 stage::{PartitionBuffering, SwizzleMode},
8};
9
/// Bundle of matmul configuration choices, normally assembled with
/// [`MatmulSelection::builder`] and finalized by [`MatmulSelectionBuilder::build`].
#[derive(Debug, Clone)]
pub struct MatmulSelection {
    /// Plane dimension. NOTE(review): presumably the hardware plane (subgroup/warp) width —
    /// confirm against the callers that choose it.
    pub plane_dim: u32,
    /// Tiling scheme; [`MatmulSelection::builder`] also derives `hypercube_selection` from it.
    pub tiling_scheme: TilingScheme,
    /// Quantization flag; `false` unless explicitly set on the builder.
    pub quantized: bool,
    /// One swizzle mode per operand (`lhs`, `rhs`, `acc`, `out`).
    pub shared_swizzle: SwizzleConfig,
    /// Partition buffering mode (builder default: `PartitionBuffering::default()`).
    pub partition_buffering: PartitionBuffering,
    /// Whether loading precomputation is enabled (builder default: `Never`).
    pub loading_precompute_strategy: LoadingPrecomputeStrategy,
    /// Reader mode from `global::read` (builder default: `ReaderMode::default()`).
    pub reader_mode: ReaderMode,
    /// Load specialization settings (builder default: `LoadSpecializationConfig::default()`).
    pub load_specialization_config: LoadSpecializationConfig,
    /// Hypercube selection; [`MatmulSelection::builder`] derives it from the tiling scheme.
    pub hypercube_selection: HypercubeSelection,
}
22
23pub fn adjust_dtypes<R: Runtime>(
25 client: &ComputeClient<R::Server>,
26 dtypes: &mut MatmulElems,
27 requires_accelerator: bool,
28) {
29 let f32_dtype = f32::as_type_native_unchecked();
30 let flex_dtype = flex32::as_type_native_unchecked();
31 let tf32_dtype = tf32::as_type_native_unchecked();
32 let f16_dtype = half::f16::as_type_native_unchecked();
33
34 if requires_accelerator {
35 if dtypes.lhs_global == f32_dtype
36 && dtypes.rhs_global == f32_dtype
37 && client.properties().supports_type(tf32_dtype)
38 {
39 dtypes.lhs_stage = tf32_dtype;
40 dtypes.rhs_stage = tf32_dtype;
41 dtypes.lhs_register = tf32_dtype;
42 dtypes.rhs_register = tf32_dtype;
43 } else if dtypes.lhs_global == flex_dtype
44 && dtypes.rhs_global == flex_dtype
45 && client.properties().supports_type(f16_dtype)
46 {
47 dtypes.lhs_stage = f16_dtype;
48 dtypes.rhs_stage = f16_dtype;
49 dtypes.lhs_register = f16_dtype;
50 dtypes.rhs_register = f16_dtype;
51 }
52 }
53}
54
/// Per-operand swizzle modes: one [`SwizzleMode`] each for `lhs`, `rhs`, `acc` and `out`.
///
/// The derived `Default` uses `SwizzleMode::default()` for all four operands.
/// NOTE(review): the owning field on `MatmulSelection` is named `shared_swizzle`, which
/// suggests these apply to shared-memory stages — confirm against the stage code.
#[derive(Default, Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct SwizzleConfig {
    /// Swizzle mode for the left-hand-side operand.
    pub lhs: SwizzleMode,
    /// Swizzle mode for the right-hand-side operand.
    pub rhs: SwizzleMode,
    /// Swizzle mode for the accumulator.
    pub acc: SwizzleMode,
    /// Swizzle mode for the output.
    pub out: SwizzleMode,
}
62
63impl MatmulSelection {
64 pub fn builder(tiling_scheme: TilingScheme, plane_dim: u32) -> MatmulSelectionBuilder {
65 let hypercube_config = HypercubeSelection::builder(&tiling_scheme).build();
66 MatmulSelectionBuilder::new()
67 .tiling_scheme(tiling_scheme)
68 .hypercube_config(hypercube_config)
69 .plane_dim(plane_dim)
70 }
71}
72
73pub struct MatmulSelectionBuilder {
74 plane_dim: Option<u32>,
75 pub tiling_scheme: Option<TilingScheme>,
76 shared_swizzle: SwizzleConfig,
77 hypercube_selection: Option<HypercubeSelection>,
78 quantized: bool,
79 partition_buffering: PartitionBuffering,
80 loading_precompute_strategy: LoadingPrecomputeStrategy,
81 reader_mode: ReaderMode,
82 load_specialization_config: LoadSpecializationConfig,
83}
84
85impl MatmulSelectionBuilder {
86 fn new() -> Self {
87 Self {
88 plane_dim: None,
89 tiling_scheme: None,
90 shared_swizzle: Default::default(),
91 hypercube_selection: None,
92 quantized: false,
93 partition_buffering: PartitionBuffering::default(),
94 loading_precompute_strategy: LoadingPrecomputeStrategy::default(),
95 reader_mode: ReaderMode::default(),
96 load_specialization_config: LoadSpecializationConfig::default(),
97 }
98 }
99
100 pub fn plane_dim(mut self, plane_dim: u32) -> Self {
101 self.plane_dim = Some(plane_dim);
102 self
103 }
104
105 pub fn tiling_scheme(mut self, tiling_scheme: TilingScheme) -> Self {
106 self.tiling_scheme = Some(tiling_scheme);
107 self
108 }
109
110 pub fn shared_swizzle(mut self, swizzle: SwizzleConfig) -> Self {
111 self.shared_swizzle = swizzle;
112 self
113 }
114
115 pub fn hypercube_config(mut self, hypercube_config: HypercubeSelection) -> Self {
116 self.hypercube_selection = Some(hypercube_config);
117 self
118 }
119
120 pub fn quantized(mut self, quantized: bool) -> Self {
121 self.quantized = quantized;
122 self
123 }
124
125 pub fn partition_buffering(mut self, partition_buffering: PartitionBuffering) -> Self {
126 self.partition_buffering = partition_buffering;
127 self
128 }
129
130 pub fn loading_precompute_strategy(
131 mut self,
132 loading_precompute_strategy: LoadingPrecomputeStrategy,
133 ) -> Self {
134 self.loading_precompute_strategy = loading_precompute_strategy;
135 self
136 }
137
138 pub fn reader_mode(mut self, reader_mode: ReaderMode) -> Self {
139 self.reader_mode = reader_mode;
140 self
141 }
142
143 pub fn load_specialization_config(
144 mut self,
145 load_specialization_config: LoadSpecializationConfig,
146 ) -> Self {
147 self.load_specialization_config = load_specialization_config;
148 self
149 }
150
151 pub fn build(self) -> MatmulSelection {
152 MatmulSelection {
153 plane_dim: self.plane_dim.unwrap(),
154 tiling_scheme: self.tiling_scheme.unwrap(),
155 shared_swizzle: self.shared_swizzle,
156 hypercube_selection: self.hypercube_selection.unwrap(),
157 quantized: self.quantized,
158 partition_buffering: self.partition_buffering,
159 loading_precompute_strategy: self.loading_precompute_strategy,
160 reader_mode: self.reader_mode,
161 load_specialization_config: self.load_specialization_config,
162 }
163 }
164}
165
/// Three-way strategy selector; defaults to [`MultiRowStrategy::Never`].
///
/// Derives `PartialEq`, `Eq` and `Hash` like the sibling config types
/// (`LoadingPrecomputeStrategy`, `SwizzleConfig`), so it can be compared and used in hashed
/// keys. Every payload is a `u32`, so these derives are always available.
///
/// NOTE(review): the meaning of each mode is defined by the code that consumes this enum,
/// which is not in this file. The variant notes below only restate the declaration.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum MultiRowStrategy {
    /// Default mode.
    #[default]
    Never,
    /// Fixed mode carrying a `u32`. NOTE(review): presumably the row count to use; confirm at
    /// the use site.
    Always(u32),
    /// Adaptive mode parameterized by a `minimum_stage_count` threshold. Its exact use is
    /// defined by the consumer.
    Adaptive { minimum_stage_count: u32 },
}
176
/// Whether loading precomputation is enabled; defaults to [`LoadingPrecomputeStrategy::Never`].
///
/// Converts into `bool` through the `From` impl in this module (`Always` → `true`,
/// `Never` → `false`).
#[derive(Default, Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum LoadingPrecomputeStrategy {
    /// Precomputation disabled (the default).
    #[default]
    Never,
    /// Precomputation enabled.
    Always,
}
185
186impl From<LoadingPrecomputeStrategy> for bool {
187 fn from(strategy: LoadingPrecomputeStrategy) -> Self {
188 match strategy {
189 LoadingPrecomputeStrategy::Always => true,
190 LoadingPrecomputeStrategy::Never => false,
191 }
192 }
193}