1use crate::simd::{SimdLevel, detect_simd_level};
8use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
9
/// Default block size along the M dimension; the initial value of
/// `TuningConfig::block_m` and of the global tuning cache.
pub const DEFAULT_BLOCK_M: usize = 64;

/// Default block size along the N dimension; the initial value of
/// `TuningConfig::block_n` and of the global tuning cache.
pub const DEFAULT_BLOCK_N: usize = 64;

/// Default block size along the K dimension; the initial value of
/// `TuningConfig::block_k` and of the global tuning cache.
pub const DEFAULT_BLOCK_K: usize = 256;

/// Default value of `TuningConfig::l2_block`.
pub const DEFAULT_L2_BLOCK: usize = 256;

/// Assumed L1 cache size (32 KiB, presumably in bytes like `L2_CACHE_SIZE`).
/// A fixed constant, not probed at runtime; not read anywhere in this file.
pub const L1_CACHE_SIZE: usize = 32 * 1024;

/// Assumed L2 cache size in bytes (256 KiB). A fixed constant, not probed at
/// runtime. `TuningConfig::for_dimensions` limits one `block_m x block_k`
/// panel to half of this.
pub const L2_CACHE_SIZE: usize = 256 * 1024;

/// Assumed L3 cache size (8 MiB, presumably in bytes like `L2_CACHE_SIZE`).
/// A fixed constant, not probed at runtime; not read anywhere in this file.
pub const L3_CACHE_SIZE: usize = 8 * 1024 * 1024;
30
/// Block sizes and execution hints produced by the tuner.
///
/// Built either by [`TuningConfig::default`] from the `DEFAULT_*` constants or
/// by [`TuningConfig::for_dimensions`] from the problem shape and the detected
/// SIMD level.
#[derive(Debug, Clone, Copy)]
pub struct TuningConfig {
    /// Block size along the M dimension.
    pub block_m: usize,
    /// Block size along the N dimension.
    pub block_n: usize,
    /// Block size along the K dimension.
    pub block_k: usize,
    /// Initialized to `DEFAULT_L2_BLOCK`; presumably an L2-level blocking
    /// factor, but it is not read anywhere in this file (confirm with callers).
    pub l2_block: usize,
    /// SIMD level reported by `detect_simd_level()` when the config was built.
    pub simd_level: SimdLevel,
    /// Parallel-execution hint; `for_dimensions` turns it on when
    /// `m * n >= par_threshold`, and the default is `false`.
    pub parallel: bool,
    /// Minimum `m * n` at which `for_dimensions` sets `parallel`.
    pub par_threshold: usize,
}
49
50impl Default for TuningConfig {
51 fn default() -> Self {
52 Self {
53 block_m: DEFAULT_BLOCK_M,
54 block_n: DEFAULT_BLOCK_N,
55 block_k: DEFAULT_BLOCK_K,
56 l2_block: DEFAULT_L2_BLOCK,
57 simd_level: detect_simd_level(),
58 parallel: false,
59 par_threshold: 64 * 64,
60 }
61 }
62}
63
64impl TuningConfig {
65 #[must_use]
67 pub fn new() -> Self {
68 Self::default()
69 }
70
71 #[must_use]
76 pub fn for_dimensions(m: usize, n: usize, _k: usize) -> Self {
77 let mut config = Self::new();
78 let simd_level = detect_simd_level();
79
80 match simd_level {
82 SimdLevel::Scalar => {
83 config.block_m = 32;
84 config.block_n = 32;
85 config.block_k = 128;
86 }
87 SimdLevel::Simd128 => {
88 config.block_m = 64;
89 config.block_n = 64;
90 config.block_k = 256;
91 }
92 SimdLevel::Simd256 => {
93 config.block_m = 96;
94 config.block_n = 96;
95 config.block_k = 384;
96 }
97 SimdLevel::Simd512 => {
98 config.block_m = 128;
99 config.block_n = 128;
100 config.block_k = 512;
101 }
102 }
103
104 if m < 128 || n < 128 {
106 config.block_m = config.block_m.min(m);
107 config.block_n = config.block_n.min(n);
108 }
109
110 let element_size = 8; let panel_size = config.block_m * config.block_k * element_size;
113 if panel_size > L2_CACHE_SIZE / 2 {
114 config.block_k = (L2_CACHE_SIZE / 2) / (config.block_m * element_size);
115 }
116
117 config.parallel = m * n >= config.par_threshold;
119
120 config
121 }
122
123 #[must_use]
125 pub fn gemv_block_size(&self) -> usize {
126 match self.simd_level {
127 SimdLevel::Scalar => 128,
128 SimdLevel::Simd128 => 256,
129 SimdLevel::Simd256 => 512,
130 SimdLevel::Simd512 => 1024,
131 }
132 }
133
134 #[must_use]
136 pub fn factorization_panel_width(&self) -> usize {
137 match self.simd_level {
138 SimdLevel::Scalar => 16,
139 SimdLevel::Simd128 => 32,
140 SimdLevel::Simd256 => 48,
141 SimdLevel::Simd512 => 64,
142 }
143 }
144}
145
/// Lock-free storage for the most recently selected block sizes.
///
/// Only the three block sizes are stored; every other [`TuningConfig`] field
/// is rebuilt from defaults by [`TuningCache::get`]. Each field is an
/// independent atomic, so the three sizes are not updated as a single unit.
#[derive(Debug)]
pub struct TuningCache {
    /// Set to `true` once the cache has been read or written at least once.
    initialized: AtomicBool,
    /// Cached block size along the M dimension.
    block_m: AtomicUsize,
    /// Cached block size along the N dimension.
    block_n: AtomicUsize,
    /// Cached block size along the K dimension.
    block_k: AtomicUsize,
}
153
/// Process-wide cache instance behind [`TuningCache::get`] and
/// [`TuningCache::set`].
///
/// Starts out flagged as uninitialized and holding the `DEFAULT_BLOCK_*`
/// block sizes.
static TUNING_CACHE: TuningCache = TuningCache {
    initialized: AtomicBool::new(false),
    block_m: AtomicUsize::new(DEFAULT_BLOCK_M),
    block_n: AtomicUsize::new(DEFAULT_BLOCK_N),
    block_k: AtomicUsize::new(DEFAULT_BLOCK_K),
};
160
161impl TuningCache {
162 pub fn get() -> TuningConfig {
164 if !TUNING_CACHE.initialized.load(Ordering::Relaxed) {
165 let config = TuningConfig::new();
166 TUNING_CACHE
167 .block_m
168 .store(config.block_m, Ordering::Relaxed);
169 TUNING_CACHE
170 .block_n
171 .store(config.block_n, Ordering::Relaxed);
172 TUNING_CACHE
173 .block_k
174 .store(config.block_k, Ordering::Relaxed);
175 TUNING_CACHE.initialized.store(true, Ordering::Relaxed);
176 }
177
178 TuningConfig {
179 block_m: TUNING_CACHE.block_m.load(Ordering::Relaxed),
180 block_n: TUNING_CACHE.block_n.load(Ordering::Relaxed),
181 block_k: TUNING_CACHE.block_k.load(Ordering::Relaxed),
182 ..TuningConfig::new()
183 }
184 }
185
186 pub fn set(config: &TuningConfig) {
188 TUNING_CACHE
189 .block_m
190 .store(config.block_m, Ordering::Relaxed);
191 TUNING_CACHE
192 .block_n
193 .store(config.block_n, Ordering::Relaxed);
194 TUNING_CACHE
195 .block_k
196 .store(config.block_k, Ordering::Relaxed);
197 TUNING_CACHE.initialized.store(true, Ordering::Relaxed);
198 }
199}
200
201pub struct AutoTuner {
206 config: TuningConfig,
207}
208
209impl Default for AutoTuner {
210 fn default() -> Self {
211 Self::new()
212 }
213}
214
215impl AutoTuner {
216 #[must_use]
218 pub fn new() -> Self {
219 Self {
220 config: TuningCache::get(),
221 }
222 }
223
224 pub fn tune_gemm(&mut self, m: usize, n: usize, k: usize) -> &TuningConfig {
229 self.config = TuningConfig::for_dimensions(m, n, k);
230 TuningCache::set(&self.config);
231 &self.config
232 }
233
234 #[must_use]
236 pub const fn config(&self) -> &TuningConfig {
237 &self.config
238 }
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use core::hint::spin_loop;
    use core::sync::atomic::{AtomicBool, Ordering};

    /// Serializes the tests that read or write the process-wide tuning cache.
    ///
    /// The default test harness runs tests on several threads; without this,
    /// one test's `TuningCache::set` can land between another test's `set`
    /// and `get` and make its assertions fail intermittently. A `core`-only
    /// spin lock is used so the tests do not depend on `std`.
    static CACHE_TEST_LOCK: AtomicBool = AtomicBool::new(false);

    /// RAII guard for `CACHE_TEST_LOCK`; released on drop, including when the
    /// owning test panics.
    struct CacheTestGuard;

    impl CacheTestGuard {
        fn acquire() -> Self {
            while CACHE_TEST_LOCK
                .compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed)
                .is_err()
            {
                spin_loop();
            }
            Self
        }
    }

    impl Drop for CacheTestGuard {
        fn drop(&mut self) {
            CACHE_TEST_LOCK.store(false, Ordering::Release);
        }
    }

    #[test]
    fn test_default_config() {
        let config = TuningConfig::default();
        assert_eq!(config.block_m, DEFAULT_BLOCK_M);
        assert_eq!(config.block_n, DEFAULT_BLOCK_N);
        assert_eq!(config.block_k, DEFAULT_BLOCK_K);
    }

    #[test]
    fn test_dimension_based_tuning() {
        let config_small = TuningConfig::for_dimensions(32, 32, 32);
        let config_large = TuningConfig::for_dimensions(1024, 1024, 1024);

        // Blocks never exceed the matrix dimensions.
        assert!(config_small.block_m <= 32);
        assert!(config_small.block_n <= 32);

        // 1024 * 1024 is far above the default parallel threshold.
        assert!(config_large.parallel);
    }

    #[test]
    fn test_tuning_cache() {
        let _guard = CacheTestGuard::acquire();

        let config = TuningConfig {
            block_m: 77,
            block_n: 88,
            block_k: 99,
            ..TuningConfig::default()
        };

        TuningCache::set(&config);
        let cached = TuningCache::get();

        assert_eq!(cached.block_m, 77);
        assert_eq!(cached.block_n, 88);
        assert_eq!(cached.block_k, 99);
    }

    #[test]
    fn test_auto_tuner() {
        // `AutoTuner::new` reads and `tune_gemm` writes the global cache.
        let _guard = CacheTestGuard::acquire();

        let mut tuner = AutoTuner::new();
        let config = tuner.tune_gemm(512, 512, 512);

        assert!(config.block_m > 0);
        assert!(config.block_n > 0);
        assert!(config.block_k > 0);
        assert!(config.block_m <= 512);
        assert!(config.block_n <= 512);
    }

    #[test]
    fn test_gemv_block_size() {
        let config = TuningConfig::default();
        let block_size = config.gemv_block_size();

        // Bounds are the smallest (scalar) and largest (widest SIMD) table entries.
        assert!(block_size >= 128);
        assert!(block_size <= 1024);
    }

    #[test]
    fn test_factorization_panel_width() {
        let config = TuningConfig::default();
        let panel_width = config.factorization_panel_width();

        // Bounds are the smallest (scalar) and largest (widest SIMD) table entries.
        assert!(panel_width >= 16);
        assert!(panel_width <= 64);
    }
}