cubecl_matmul/components/
selection.rs1use cubecl_core::{Runtime, client::ComputeClient, flex32, prelude::CubePrimitive, tf32};
2
3use crate::components::{
4 MatmulElems, TilingScheme,
5 batch::HypercubeSelection,
6 global::{LoadSpecializationConfig, read::ReaderMode},
7 stage::{PartitionBuffering, SwizzleMode},
8};
9
10#[derive(Debug, Clone)]
11pub struct MatmulSelection {
12 pub plane_dim: u32,
13 pub tiling_scheme: TilingScheme,
14 pub shared_swizzle: SwizzleConfig,
15 pub partition_buffering: PartitionBuffering,
16 pub loading_precompute_strategy: LoadingPrecomputeStrategy,
17 pub reader_mode: ReaderMode,
18 pub load_specialization_config: LoadSpecializationConfig,
19 pub hypercube_selection: HypercubeSelection,
20}
21
22pub fn adjust_dtypes<R: Runtime>(
24 client: &ComputeClient<R>,
25 dtypes: &mut MatmulElems,
26 requires_accelerator: bool,
27) {
28 let f32_dtype = f32::as_type_native_unchecked();
29 let flex_dtype = flex32::as_type_native_unchecked();
30 let tf32_dtype = tf32::as_type_native_unchecked();
31 let f16_dtype = half::f16::as_type_native_unchecked();
32
33 if requires_accelerator {
34 if dtypes.lhs_global == f32_dtype
35 && dtypes.rhs_global == f32_dtype
36 && client.properties().supports_type(tf32_dtype)
37 {
38 dtypes.lhs_stage = tf32_dtype;
39 dtypes.rhs_stage = tf32_dtype;
40 dtypes.lhs_register = tf32_dtype;
41 dtypes.rhs_register = tf32_dtype;
42 } else if dtypes.lhs_global == flex_dtype
43 && dtypes.rhs_global == flex_dtype
44 && client.properties().supports_type(f16_dtype)
45 {
46 dtypes.lhs_stage = f16_dtype;
47 dtypes.rhs_stage = f16_dtype;
48 dtypes.lhs_register = f16_dtype;
49 dtypes.rhs_register = f16_dtype;
50 }
51 }
52}
53
54#[derive(Default, Copy, Clone, Debug, Hash, PartialEq, Eq)]
55pub struct SwizzleConfig {
56 pub lhs: SwizzleMode,
57 pub rhs: SwizzleMode,
58 pub acc: SwizzleMode,
59 pub out: SwizzleMode,
60}
61
62impl MatmulSelection {
63 pub fn builder(tiling_scheme: TilingScheme, plane_dim: u32) -> MatmulSelectionBuilder {
64 let hypercube_config = HypercubeSelection::builder(&tiling_scheme).build();
65 MatmulSelectionBuilder::new()
66 .tiling_scheme(tiling_scheme)
67 .hypercube_config(hypercube_config)
68 .plane_dim(plane_dim)
69 }
70}
71
72pub struct MatmulSelectionBuilder {
73 plane_dim: Option<u32>,
74 pub tiling_scheme: Option<TilingScheme>,
75 shared_swizzle: SwizzleConfig,
76 hypercube_selection: Option<HypercubeSelection>,
77 partition_buffering: PartitionBuffering,
78 loading_precompute_strategy: LoadingPrecomputeStrategy,
79 reader_mode: ReaderMode,
80 load_specialization_config: LoadSpecializationConfig,
81}
82
83impl MatmulSelectionBuilder {
84 fn new() -> Self {
85 Self {
86 plane_dim: None,
87 tiling_scheme: None,
88 shared_swizzle: Default::default(),
89 hypercube_selection: None,
90 partition_buffering: PartitionBuffering::default(),
91 loading_precompute_strategy: LoadingPrecomputeStrategy::default(),
92 reader_mode: ReaderMode::default(),
93 load_specialization_config: LoadSpecializationConfig::default(),
94 }
95 }
96
97 pub fn plane_dim(mut self, plane_dim: u32) -> Self {
98 self.plane_dim = Some(plane_dim);
99 self
100 }
101
102 pub fn tiling_scheme(mut self, tiling_scheme: TilingScheme) -> Self {
103 self.tiling_scheme = Some(tiling_scheme);
104 self
105 }
106
107 pub fn shared_swizzle(mut self, swizzle: SwizzleConfig) -> Self {
108 self.shared_swizzle = swizzle;
109 self
110 }
111
112 pub fn hypercube_config(mut self, hypercube_config: HypercubeSelection) -> Self {
113 self.hypercube_selection = Some(hypercube_config);
114 self
115 }
116
117 pub fn partition_buffering(mut self, partition_buffering: PartitionBuffering) -> Self {
118 self.partition_buffering = partition_buffering;
119 self
120 }
121
122 pub fn loading_precompute_strategy(
123 mut self,
124 loading_precompute_strategy: LoadingPrecomputeStrategy,
125 ) -> Self {
126 self.loading_precompute_strategy = loading_precompute_strategy;
127 self
128 }
129
130 pub fn reader_mode(mut self, reader_mode: ReaderMode) -> Self {
131 self.reader_mode = reader_mode;
132 self
133 }
134
135 pub fn load_specialization_config(
136 mut self,
137 load_specialization_config: LoadSpecializationConfig,
138 ) -> Self {
139 self.load_specialization_config = load_specialization_config;
140 self
141 }
142
143 pub fn build(self) -> MatmulSelection {
144 MatmulSelection {
145 plane_dim: self.plane_dim.unwrap(),
146 tiling_scheme: self.tiling_scheme.unwrap(),
147 shared_swizzle: self.shared_swizzle,
148 hypercube_selection: self.hypercube_selection.unwrap(),
149 partition_buffering: self.partition_buffering,
150 loading_precompute_strategy: self.loading_precompute_strategy,
151 reader_mode: self.reader_mode,
152 load_specialization_config: self.load_specialization_config,
153 }
154 }
155}
156
/// Strategy selector with three modes; the default is [`MultiRowStrategy::Never`].
///
/// NOTE(review): how each mode is interpreted is decided by the code consuming this enum,
/// which is not part of this file — confirm the per-variant notes below against it.
#[derive(Clone, Copy, Debug, Default)]
pub enum MultiRowStrategy {
    /// Default variant; carries no data.
    #[default]
    Never,
    /// Carries a single `u32`; presumably the row count to always use — TODO confirm.
    Always(u32),
    /// Carries a `minimum_stage_count` threshold; presumably multi-row is enabled only
    /// once that many stages are reached — TODO confirm.
    Adaptive { minimum_stage_count: u32 },
}
167
/// Two-state switch for precomputation during loading.
///
/// The default is [`LoadingPrecomputeStrategy::Never`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum LoadingPrecomputeStrategy {
    /// Precomputation disabled (default).
    #[default]
    Never,
    /// Precomputation enabled.
    Always,
}
176
177impl From<LoadingPrecomputeStrategy> for bool {
178 fn from(strategy: LoadingPrecomputeStrategy) -> Self {
179 match strategy {
180 LoadingPrecomputeStrategy::Always => true,
181 LoadingPrecomputeStrategy::Never => false,
182 }
183 }
184}