cubecl_matmul/components/
selection.rs1use cubecl_core::{Runtime, client::ComputeClient, flex32, prelude::CubePrimitive, tf32};
2
3use crate::components::{
4 MatmulElems, TilingScheme,
5 batch::HypercubeSelection,
6 global::{LoadSpecializationConfig, read::ReaderMode},
7 stage::{PartitionBuffering, SwizzleMode},
8};
9
10#[derive(Debug, Clone)]
11pub struct MatmulSelection {
12 pub plane_dim: u32,
13 pub tiling_scheme: TilingScheme,
14 pub shared_swizzle: SwizzleConfig,
15 pub partition_buffering: PartitionBuffering,
16 pub loading_precompute_strategy: LoadingPrecomputeStrategy,
17 pub reader_mode: ReaderMode,
18 pub load_specialization_config: LoadSpecializationConfig,
19 pub hypercube_selection: HypercubeSelection,
20}
21
22pub fn adjust_dtypes<R: Runtime>(
24 client: &ComputeClient<R>,
25 dtypes: &mut MatmulElems,
26 requires_accelerator: bool,
27) {
28 let f32_dtype = f32::as_type_native_unchecked();
29 let flex_dtype = flex32::as_type_native_unchecked();
30 let tf32_dtype = tf32::as_type_native_unchecked();
31 let f16_dtype = half::f16::as_type_native_unchecked();
32
33 if requires_accelerator {
34 if *dtypes.lhs_global == f32_dtype
35 && *dtypes.rhs_global == f32_dtype
36 && client.properties().supports_type(tf32_dtype)
37 {
38 dtypes.lhs_stage.dtype = tf32_dtype;
39 dtypes.rhs_stage.dtype = tf32_dtype;
40 dtypes.lhs_register.dtype = tf32_dtype;
41 dtypes.rhs_register.dtype = tf32_dtype;
42 } else if *dtypes.lhs_global == flex_dtype
43 && *dtypes.rhs_global == flex_dtype
44 && client.properties().supports_type(f16_dtype)
45 {
46 dtypes.lhs_stage.dtype = f16_dtype;
47 dtypes.rhs_stage.dtype = f16_dtype;
48 dtypes.lhs_register.dtype = f16_dtype;
49 dtypes.rhs_register.dtype = f16_dtype;
50 }
51 }
52}
53
54#[derive(Default, Copy, Clone, Debug, Hash, PartialEq, Eq)]
55pub struct SwizzleConfig {
56 pub lhs: SwizzleMode,
57 pub rhs: SwizzleMode,
58 pub acc: SwizzleMode,
59 pub out: SwizzleMode,
60}
61
62impl SwizzleConfig {
63 pub fn has_swizzle(&self) -> bool {
64 self.lhs != SwizzleMode::None
65 || self.rhs != SwizzleMode::None
66 || self.acc != SwizzleMode::None
67 || self.out != SwizzleMode::None
68 }
69}
70
71impl MatmulSelection {
72 pub fn builder(tiling_scheme: TilingScheme, plane_dim: u32) -> MatmulSelectionBuilder {
73 let hypercube_config = HypercubeSelection::builder(&tiling_scheme).build();
74 MatmulSelectionBuilder::new()
75 .tiling_scheme(tiling_scheme)
76 .hypercube_config(hypercube_config)
77 .plane_dim(plane_dim)
78 }
79}
80
81pub struct MatmulSelectionBuilder {
82 plane_dim: Option<u32>,
83 pub tiling_scheme: Option<TilingScheme>,
84 shared_swizzle: SwizzleConfig,
85 hypercube_selection: Option<HypercubeSelection>,
86 partition_buffering: PartitionBuffering,
87 loading_precompute_strategy: LoadingPrecomputeStrategy,
88 reader_mode: ReaderMode,
89 load_specialization_config: LoadSpecializationConfig,
90}
91
92impl MatmulSelectionBuilder {
93 fn new() -> Self {
94 Self {
95 plane_dim: None,
96 tiling_scheme: None,
97 shared_swizzle: Default::default(),
98 hypercube_selection: None,
99 partition_buffering: PartitionBuffering::default(),
100 loading_precompute_strategy: LoadingPrecomputeStrategy::default(),
101 reader_mode: ReaderMode::default(),
102 load_specialization_config: LoadSpecializationConfig::default(),
103 }
104 }
105
106 pub fn plane_dim(mut self, plane_dim: u32) -> Self {
107 self.plane_dim = Some(plane_dim);
108 self
109 }
110
111 pub fn tiling_scheme(mut self, tiling_scheme: TilingScheme) -> Self {
112 self.tiling_scheme = Some(tiling_scheme);
113 self
114 }
115
116 pub fn shared_swizzle(mut self, swizzle: SwizzleConfig) -> Self {
117 self.shared_swizzle = swizzle;
118 self
119 }
120
121 pub fn hypercube_config(mut self, hypercube_config: HypercubeSelection) -> Self {
122 self.hypercube_selection = Some(hypercube_config);
123 self
124 }
125
126 pub fn partition_buffering(mut self, partition_buffering: PartitionBuffering) -> Self {
127 self.partition_buffering = partition_buffering;
128 self
129 }
130
131 pub fn loading_precompute_strategy(
132 mut self,
133 loading_precompute_strategy: LoadingPrecomputeStrategy,
134 ) -> Self {
135 self.loading_precompute_strategy = loading_precompute_strategy;
136 self
137 }
138
139 pub fn reader_mode(mut self, reader_mode: ReaderMode) -> Self {
140 self.reader_mode = reader_mode;
141 self
142 }
143
144 pub fn load_specialization_config(
145 mut self,
146 load_specialization_config: LoadSpecializationConfig,
147 ) -> Self {
148 self.load_specialization_config = load_specialization_config;
149 self
150 }
151
152 pub fn build(self) -> MatmulSelection {
153 MatmulSelection {
154 plane_dim: self.plane_dim.unwrap(),
155 tiling_scheme: self.tiling_scheme.unwrap(),
156 shared_swizzle: self.shared_swizzle,
157 hypercube_selection: self.hypercube_selection.unwrap(),
158 partition_buffering: self.partition_buffering,
159 loading_precompute_strategy: self.loading_precompute_strategy,
160 reader_mode: self.reader_mode,
161 load_specialization_config: self.load_specialization_config,
162 }
163 }
164}
165
/// Strategy for using multiple rows in the matmul selection.
///
/// NOTE(review): the exact meaning of "row" here (e.g. rows handled per plane) is not visible
/// in this file — confirm against the selector that consumes this enum.
///
/// Derives `PartialEq`, `Eq` and `Hash` so that strategies can be compared and used as keys,
/// consistent with [`LoadingPrecomputeStrategy`]. All payloads are `u32`, so these derives
/// are well-defined.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum MultiRowStrategy {
    /// Never use multiple rows. This is the `Default`.
    #[default]
    Never,
    /// Always use multiple rows. NOTE(review): the `u32` is presumably the row count — confirm.
    Always(u32),
    /// Decide at selection time. NOTE(review): presumably multiple rows are used only when at
    /// least `minimum_stage_count` stages would result — confirm.
    Adaptive { minimum_stage_count: u32 },
}
176
/// Whether loading precomputes values.
///
/// NOTE(review): presumably this controls whether loading jobs compute shared values ahead of
/// time — confirm with the loaders that read it. Converts to `bool` via `From`
/// (`Always` maps to `true`).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum LoadingPrecomputeStrategy {
    /// Do not precompute. This is the `Default`.
    #[default]
    Never,
    /// Always precompute.
    Always,
}
185
186impl From<LoadingPrecomputeStrategy> for bool {
187 fn from(strategy: LoadingPrecomputeStrategy) -> Self {
188 match strategy {
189 LoadingPrecomputeStrategy::Always => true,
190 LoadingPrecomputeStrategy::Never => false,
191 }
192 }
193}