// cubecl_matmul/components/selection.rs
use cubecl_core::{Runtime, client::ComputeClient, flex32, prelude::CubePrimitive, tf32};

use crate::components::{
    MatmulElems, TilingScheme,
    batch::HypercubeSelection,
    global::{LoadSpecializationConfig, read::ReaderMode},
    stage::PartitionBuffering,
};
9
10#[derive(Debug, Clone)]
11pub struct MatmulSelection {
12 pub plane_dim: u32,
13 pub tiling_scheme: TilingScheme,
14 pub quantized: bool,
15 pub partition_buffering: PartitionBuffering,
16 pub loading_precompute_strategy: LoadingPrecomputeStrategy,
17 pub reader_mode: ReaderMode,
18 pub load_specialization_config: LoadSpecializationConfig,
19 pub hypercube_selection: HypercubeSelection,
20}
21
22pub fn adjust_dtypes<R: Runtime>(
24 client: &ComputeClient<R::Server>,
25 dtypes: &mut MatmulElems,
26 requires_accelerator: bool,
27) {
28 let f32_dtype = f32::as_type_native_unchecked();
29 let flex_dtype = flex32::as_type_native_unchecked();
30 let tf32_dtype = tf32::as_type_native_unchecked();
31 let f16_dtype = half::f16::as_type_native_unchecked();
32
33 if requires_accelerator {
34 if dtypes.lhs_global == f32_dtype
35 && dtypes.rhs_global == f32_dtype
36 && client.properties().supports_type(tf32_dtype)
37 {
38 dtypes.lhs_stage = tf32_dtype;
39 dtypes.rhs_stage = tf32_dtype;
40 dtypes.lhs_register = tf32_dtype;
41 dtypes.rhs_register = tf32_dtype;
42 } else if dtypes.lhs_global == flex_dtype
43 && dtypes.rhs_global == flex_dtype
44 && client.properties().supports_type(f16_dtype)
45 {
46 dtypes.lhs_stage = f16_dtype;
47 dtypes.rhs_stage = f16_dtype;
48 dtypes.lhs_register = f16_dtype;
49 dtypes.rhs_register = f16_dtype;
50 }
51 }
52}
53
54impl MatmulSelection {
55 pub fn builder(tiling_scheme: TilingScheme, plane_dim: u32) -> MatmulSelectionBuilder {
56 let hypercube_config = HypercubeSelection::builder(&tiling_scheme).build();
57 MatmulSelectionBuilder::new()
58 .tiling_scheme(tiling_scheme)
59 .hypercube_config(hypercube_config)
60 .plane_dim(plane_dim)
61 }
62}
63
64pub struct MatmulSelectionBuilder {
65 plane_dim: Option<u32>,
66 pub tiling_scheme: Option<TilingScheme>,
67 hypercube_selection: Option<HypercubeSelection>,
68 quantized: bool,
69 partition_buffering: PartitionBuffering,
70 loading_precompute_strategy: LoadingPrecomputeStrategy,
71 reader_mode: ReaderMode,
72 load_specialization_config: LoadSpecializationConfig,
73}
74
75impl MatmulSelectionBuilder {
76 fn new() -> Self {
77 Self {
78 plane_dim: None,
79 tiling_scheme: None,
80 hypercube_selection: None,
81 quantized: false,
82 partition_buffering: PartitionBuffering::default(),
83 loading_precompute_strategy: LoadingPrecomputeStrategy::default(),
84 reader_mode: ReaderMode::default(),
85 load_specialization_config: LoadSpecializationConfig::default(),
86 }
87 }
88
89 pub fn plane_dim(mut self, plane_dim: u32) -> Self {
90 self.plane_dim = Some(plane_dim);
91 self
92 }
93
94 pub fn tiling_scheme(mut self, tiling_scheme: TilingScheme) -> Self {
95 self.tiling_scheme = Some(tiling_scheme);
96 self
97 }
98
99 pub fn hypercube_config(mut self, hypercube_config: HypercubeSelection) -> Self {
100 self.hypercube_selection = Some(hypercube_config);
101 self
102 }
103
104 pub fn quantized(mut self, quantized: bool) -> Self {
105 self.quantized = quantized;
106 self
107 }
108
109 pub fn partition_buffering(mut self, partition_buffering: PartitionBuffering) -> Self {
110 self.partition_buffering = partition_buffering;
111 self
112 }
113
114 pub fn loading_precompute_strategy(
115 mut self,
116 loading_precompute_strategy: LoadingPrecomputeStrategy,
117 ) -> Self {
118 self.loading_precompute_strategy = loading_precompute_strategy;
119 self
120 }
121
122 pub fn reader_mode(mut self, reader_mode: ReaderMode) -> Self {
123 self.reader_mode = reader_mode;
124 self
125 }
126
127 pub fn load_specialization_config(
128 mut self,
129 load_specialization_config: LoadSpecializationConfig,
130 ) -> Self {
131 self.load_specialization_config = load_specialization_config;
132 self
133 }
134
135 pub fn build(self) -> MatmulSelection {
136 MatmulSelection {
137 plane_dim: self.plane_dim.unwrap(),
138 tiling_scheme: self.tiling_scheme.unwrap(),
139 hypercube_selection: self.hypercube_selection.unwrap(),
140 quantized: self.quantized,
141 partition_buffering: self.partition_buffering,
142 loading_precompute_strategy: self.loading_precompute_strategy,
143 reader_mode: self.reader_mode,
144 load_specialization_config: self.load_specialization_config,
145 }
146 }
147}
148
/// Policy selecting between the `Never`, `Always`, and `Adaptive` multi-row
/// modes.
///
/// Derives `PartialEq`, `Eq`, and `Hash` so values can be compared, asserted
/// on, and used as map or set keys.
///
/// NOTE(review): what a "row" refers to is not visible in this file — confirm
/// against the code that consumes this enum.
#[derive(Debug, Clone, Copy, Default, Hash, PartialEq, Eq)]
pub enum MultiRowStrategy {
    /// The default variant.
    #[default]
    Never,
    /// Carries one `u32`.
    ///
    /// NOTE(review): presumably a row count — confirm with the consumer.
    Always(u32),
    /// Carries a `minimum_stage_count` threshold.
    ///
    /// NOTE(review): how the threshold is applied is decided by the consumer —
    /// confirm there.
    Adaptive { minimum_stage_count: u32 },
}
159
/// Two-state switch for loading precomputation.
///
/// NOTE(review): what exactly gets precomputed is decided by the code that
/// consumes this value — confirm there.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum LoadingPrecomputeStrategy {
    /// The default variant.
    #[default]
    Never,
    /// The non-default variant.
    Always,
}
168
169impl From<LoadingPrecomputeStrategy> for bool {
170 fn from(strategy: LoadingPrecomputeStrategy) -> Self {
171 match strategy {
172 LoadingPrecomputeStrategy::Always => true,
173 LoadingPrecomputeStrategy::Never => false,
174 }
175 }
176}