/// Geometry of one level of the data-cache hierarchy.
///
/// A `cache_bytes` of 0 is treated as "unknown" by the detection and blocking code in
/// this module.
#[derive(Clone, Copy, Debug, Default)]
pub struct CacheInfo {
    /// Number of ways per cache set.
    pub associativity: usize,
    /// Capacity in bytes attributed to this level.
    pub cache_bytes: usize,
    /// Size of a single cache line in bytes.
    pub cache_line_bytes: usize,
}

/// Cache-blocking sizes produced by `kernel_params`.
#[derive(Clone, Copy, Debug, Default)]
pub struct KernelParams {
    /// Block size along the `k` dimension.
    pub kc: usize,
    /// Block size along the `m` dimension.
    pub mc: usize,
    /// Block size along the `n` dimension.
    pub nc: usize,
}
14
/// Integer rounding helpers.
///
/// The `msrv_` prefix presumably marks these as stand-ins for the std `div_ceil` /
/// `next_multiple_of` family on older toolchains (NOTE(review): confirm the MSRV rationale).
pub trait DivCeil: Sized {
    /// Returns `ceil(self / rhs)`. Panics if `rhs` is zero.
    fn msrv_div_ceil(self, rhs: Self) -> Self;
    /// Returns the smallest multiple of `rhs` that is `>= self`. Panics if `rhs` is zero
    /// (and on overflow when overflow checks are enabled).
    fn msrv_next_multiple_of(self, rhs: Self) -> Self;
    /// Like `msrv_next_multiple_of`, but returns `None` when `rhs` is zero or the result
    /// would overflow.
    fn msrv_checked_next_multiple_of(self, rhs: Self) -> Option<Self>;
}

impl DivCeil for usize {
    #[inline]
    fn msrv_div_ceil(self, rhs: Self) -> Self {
        let quotient = self / rhs;
        // Any non-zero remainder means one more (partial) chunk is needed.
        if self % rhs == 0 {
            quotient
        } else {
            quotient + 1
        }
    }

    #[inline]
    fn msrv_next_multiple_of(self, rhs: Self) -> Self {
        let rem = self % rhs;
        if rem == 0 {
            self
        } else {
            self + (rhs - rem)
        }
    }

    #[inline]
    fn msrv_checked_next_multiple_of(self, rhs: Self) -> Option<Self> {
        // `checked_rem` yields `None` for a zero divisor.
        let rem = self.checked_rem(rhs)?;
        if rem == 0 {
            Some(self)
        } else {
            self.checked_add(rhs - rem)
        }
    }
}
51
/// Probes whether the host should be treated as AMX-capable.
///
/// Reads the `machdep.cpu.brand_string` sysctl and answers `true` only when its first two
/// words are `Apple` followed by `M1`, `M2` or `M3`. Always returns `false` without the
/// `std` feature, under Miri, or when the sysctl cannot be read.
#[cfg(target_vendor = "apple")]
fn has_amx_impl() -> bool {
    #[cfg(feature = "std")]
    {
        if !cfg!(miri) {
            use sysctl::Ctl;
            use sysctl::Sysctl;

            let brand_string =
                Ctl::new("machdep.cpu.brand_string").and_then(|ctl| ctl.value_string());
            if let Ok(brand) = brand_string {
                let mut words = brand.split_whitespace();
                let head = (words.next(), words.next());
                return matches!(head, (Some("Apple"), Some("M1" | "M2" | "M3")));
            }
        }
    }
    false
}
72
73fn cache_info() -> Option<[CacheInfo; 3]> {
74 if !cfg!(miri) {
75 #[cfg(feature = "std")]
76 {
77 #[cfg(target_os = "linux")]
78 {
79 use std::fs;
80 fn try_cache_info_linux() -> Result<[CacheInfo; 3], std::io::Error> {
81 let mut all_info = [CacheInfo {
82 associativity: 8,
83 cache_bytes: 0,
84 cache_line_bytes: 64,
85 }; 3];
86
87 for cpu_x in fs::read_dir("/sys/devices/system/cpu")? {
88 let cpu_x = cpu_x?.path();
89 let Some(cpu_x_name) = cpu_x.file_name().and_then(|f| f.to_str()) else {
90 continue;
91 };
92 if !cpu_x_name.starts_with("cpu") {
93 continue;
94 }
95 let cache = cpu_x.join("cache");
96 if !cache.is_dir() {
97 continue;
98 }
99 'index: for index_y in fs::read_dir(cache)? {
100 let index_y = index_y?.path();
101 if !index_y.is_dir() {
102 continue;
103 }
104 let Some(index_y_name) = index_y.file_name().and_then(|f| f.to_str())
105 else {
106 continue;
107 };
108 if !index_y_name.starts_with("index") {
109 continue;
110 }
111
112 let mut cache_info = CacheInfo {
113 associativity: 8,
114 cache_bytes: 0,
115 cache_line_bytes: 64,
116 };
117 let mut level: usize = 0;
118 let mut shared_count: usize = 0;
119
120 for entry in fs::read_dir(index_y)? {
121 let entry = entry?.path();
122 if let Some(name) = entry.file_name() {
123 let contents = fs::read_to_string(&entry)?;
124 let contents = contents.trim();
125 if name == "type" && !matches!(contents, "Data" | "Unified") {
126 continue 'index;
127 }
128 if name == "shared_cpu_list" {
129 for item in contents.split(',') {
130 if item.contains('-') {
131 let mut item = item.split('-');
132 let Some(start) = item.next() else {
133 continue 'index;
134 };
135 let Some(end) = item.next() else {
136 continue 'index;
137 };
138
139 let Ok(start) = start.parse::<usize>() else {
140 continue 'index;
141 };
142 let Ok(end) = end.parse::<usize>() else {
143 continue 'index;
144 };
145
146 shared_count += end + 1 - start;
147 } else {
148 shared_count += 1;
149 }
150 }
151 }
152
153 if name == "level" {
154 let Ok(contents) = contents.parse::<usize>() else {
155 continue 'index;
156 };
157 level = contents;
158 }
159
160 if name == "coherency_line_size" {
161 let Ok(contents) = contents.parse::<usize>() else {
162 continue 'index;
163 };
164 cache_info.cache_line_bytes = contents;
165 }
166 if name == "ways_of_associativity" {
167 let Ok(contents) = contents.parse::<usize>() else {
168 continue 'index;
169 };
170 cache_info.associativity = contents;
171 }
172 if name == "size" {
173 if contents.ends_with("G") {
174 let Ok(contents) =
175 contents.trim_end_matches('G').parse::<usize>()
176 else {
177 continue 'index;
178 };
179 cache_info.cache_bytes = contents * 1024 * 1024 * 1024;
180 } else if contents.ends_with("M") {
181 let Ok(contents) =
182 contents.trim_end_matches('M').parse::<usize>()
183 else {
184 continue 'index;
185 };
186 cache_info.cache_bytes = contents * 1024 * 1024;
187 } else if contents.ends_with("K") {
188 let Ok(contents) =
189 contents.trim_end_matches('K').parse::<usize>()
190 else {
191 continue 'index;
192 };
193 cache_info.cache_bytes = contents * 1024;
194 } else {
195 let Ok(contents) = contents.parse::<usize>() else {
196 continue 'index;
197 };
198 cache_info.cache_bytes = contents;
199 }
200 }
201 }
202 }
203 if level == 3 {
204 shared_count = 1;
205 }
206 if level > 0 {
207 if cache_info.cache_line_bytes
208 >= all_info[level - 1].cache_line_bytes
209 {
210 all_info[level - 1].associativity = cache_info.associativity;
211 all_info[level - 1].cache_line_bytes =
212 cache_info.cache_line_bytes;
213 all_info[level - 1].cache_bytes =
214 cache_info.cache_bytes / shared_count;
215 }
216 }
217 }
218 }
219
220 for (info, default) in core::iter::zip(&mut all_info, CACHE_INFO_DEFAULT) {
221 if info.cache_bytes == 0 {
222 *info = default;
223 }
224 }
225
226 Ok(all_info)
227 }
228 if let Ok(info) = try_cache_info_linux() {
229 return Some(info);
230 }
231
232 if let Ok(lscpu) = std::process::Command::new("lscpu")
233 .arg("-B")
234 .arg("-C=type,level,ways,coherency-size,one-size")
235 .output()
236 {
237 if lscpu.status.success() {
238 if let Ok(lscpu) = String::from_utf8(lscpu.stdout).as_deref() {
239 let mut info = CACHE_INFO_DEFAULT;
240 for line in lscpu.lines().skip(1) {
241 let mut iter = line.split_whitespace();
242 if let [Some(cache_type), Some(level), Some(ways), Some(coherency_size), Some(one_size)] = [
243 iter.next(),
244 iter.next(),
245 iter.next(),
246 iter.next(),
247 iter.next(),
248 ] {
249 if let "Data" | "Unified" = cache_type {
250 let level = level.parse::<usize>().unwrap();
251 let ways = ways.parse::<usize>().unwrap();
252 let coherency_size =
253 coherency_size.parse::<usize>().unwrap();
254 let one_size = one_size.parse::<usize>().unwrap();
255
256 info[level - 1].associativity = ways;
257 info[level - 1].cache_line_bytes = coherency_size;
258 info[level - 1].cache_bytes = one_size;
259 }
260 }
261 }
262 return Some(info);
263 }
264 }
265 }
266 }
267 #[cfg(target_vendor = "apple")]
268 {
269 use sysctl::Ctl;
270 use sysctl::Sysctl;
271
272 let mut all_info = CACHE_INFO_DEFAULT;
273 if let Ok(l1) =
274 Ctl::new("hw.perflevel0.l1dcachesize").and_then(|ctl| ctl.value_string())
275 {
276 if let Ok(l1) = l1.parse::<usize>() {
277 all_info[0].cache_bytes = l1;
278 }
279 }
280 if let (Ok(physicalcpu), Ok(cpusperl2), Ok(l2)) = (
281 Ctl::new("hw.perflevel0.physicalcpu").and_then(|ctl| ctl.value_string()),
282 Ctl::new("hw.perflevel0.cpusperl2").and_then(|ctl| ctl.value_string()),
283 Ctl::new("hw.perflevel0.l2cachesize").and_then(|ctl| ctl.value_string()),
284 ) {
285 if let (Ok(_physicalcpu), Ok(cpusperl2), Ok(l2)) = (
286 physicalcpu.parse::<usize>(),
287 cpusperl2.parse::<usize>(),
288 l2.parse::<usize>(),
289 ) {
290 all_info[1].cache_bytes = l2 / cpusperl2;
291 }
292 }
293 all_info[2].cache_bytes = 0;
294 return Some(all_info);
295 }
296 }
297
298 #[cfg(any(
299 all(target_arch = "x86", not(target_env = "sgx"), target_feature = "sse"),
300 all(target_arch = "x86_64", not(target_env = "sgx"))
301 ))]
302 {
303 use raw_cpuid::CpuId;
304 let cpuid = CpuId::new();
305
306 if let Some(vf) = cpuid.get_vendor_info() {
307 let vf = vf.as_str();
308 if vf == "GenuineIntel" {
309 if let Some(cparams) = cpuid.get_cache_parameters() {
310 let mut info = [CacheInfo {
312 cache_bytes: 0,
313 associativity: 0,
314 cache_line_bytes: 64,
315 }; 3];
316
317 for cache in cparams {
318 use raw_cpuid::CacheType::*;
319 match cache.cache_type() {
320 Null | Instruction | Reserved => continue,
321 Data | Unified => {
322 let level = cache.level() as usize;
323 let associativity = cache.associativity();
324 let nsets = cache.sets();
325 let cache_line_bytes = cache.coherency_line_size();
326 if level > 0 && level < 4 {
327 let info = &mut info[level - 1];
328 info.cache_line_bytes = cache_line_bytes;
329 info.associativity = associativity;
330 info.cache_bytes = associativity * nsets * cache_line_bytes;
331 }
332 }
333 }
334 }
335 return Some(info);
336 }
337 }
338
339 if vf == "AuthenticAMD" {
340 if let Some(l1) = cpuid.get_l1_cache_and_tlb_info() {
341 if let Some(l23) = cpuid.get_l2_l3_cache_and_tlb_info() {
342 let compute_info = |associativity: raw_cpuid::Associativity,
343 cache_kb: usize,
344 cache_line_bytes: u8|
345 -> CacheInfo {
346 let cache_bytes = cache_kb as usize * 1024;
347 let cache_line_bytes = cache_line_bytes as usize;
348
349 use raw_cpuid::Associativity::*;
350 let associativity = match associativity {
351 Unknown | Disabled => {
352 return CacheInfo {
353 associativity: 0,
354 cache_bytes: 0,
355 cache_line_bytes: 64,
356 }
357 }
358 FullyAssociative => cache_bytes / cache_line_bytes,
359 DirectMapped => 1,
360 NWay(n) => n as usize,
361 };
362
363 CacheInfo {
364 associativity,
365 cache_bytes,
366 cache_line_bytes,
367 }
368 };
369 return Some([
370 compute_info(
371 l1.dcache_associativity(),
372 l1.dcache_size() as usize,
373 l1.dcache_line_size(),
374 ),
375 compute_info(
376 l23.l2cache_associativity(),
377 l23.l2cache_size() as usize,
378 l23.l2cache_line_size(),
379 ),
380 compute_info(
381 l23.l3cache_associativity(),
382 l23.l3cache_size() as usize * 512,
383 l23.l3cache_line_size(),
384 ),
385 ]);
386 }
387 }
388 }
389 }
390 }
391 }
392 None
393}
394
395#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
396static CACHE_INFO_DEFAULT: [CacheInfo; 3] = [
397 CacheInfo {
398 associativity: 8,
399 cache_bytes: 32 * 1024, cache_line_bytes: 64,
401 },
402 CacheInfo {
403 associativity: 8,
404 cache_bytes: 256 * 1024, cache_line_bytes: 64,
406 },
407 CacheInfo {
408 associativity: 8,
409 cache_bytes: 2 * 1024 * 1024, cache_line_bytes: 64,
411 },
412];
413
/// Fallback cache geometry for PowerPC: 64 KiB L1, 512 KiB L2, 4 MiB L3,
/// each 8-way with 64-byte lines.
#[cfg(any(target_arch = "powerpc", target_arch = "powerpc64"))]
static CACHE_INFO_DEFAULT: [CacheInfo; 3] = [
    CacheInfo { associativity: 8, cache_bytes: 64 * 1024, cache_line_bytes: 64 },
    CacheInfo { associativity: 8, cache_bytes: 512 * 1024, cache_line_bytes: 64 },
    CacheInfo { associativity: 8, cache_bytes: 4 * 1024 * 1024, cache_line_bytes: 64 },
];
432
/// Fallback cache geometry for every other architecture: 16 KiB L1, 512 KiB L2,
/// 1 MiB L3, each 8-way with 64-byte lines.
#[cfg(not(any(
    target_arch = "powerpc",
    target_arch = "powerpc64",
    target_arch = "x86",
    target_arch = "x86_64"
)))]
static CACHE_INFO_DEFAULT: [CacheInfo; 3] = [
    CacheInfo { associativity: 8, cache_bytes: 16 * 1024, cache_line_bytes: 64 },
    CacheInfo { associativity: 8, cache_bytes: 512 * 1024, cache_line_bytes: 64 },
    CacheInfo { associativity: 8, cache_bytes: 1024 * 1024, cache_line_bytes: 64 },
];
456
/// Zero-sized handle that dereferences to the lazily detected `[L1, L2, L3]` cache info.
pub struct CacheInfoDeref;

/// Namespace for the cached "treat host as AMX-capable" query (Apple targets only).
#[cfg(target_vendor = "apple")]
pub struct HasAmx;
460
461impl core::ops::Deref for CacheInfoDeref {
462 type Target = [CacheInfo; 3];
463
464 #[inline]
465 fn deref(&self) -> &Self::Target {
466 #[cfg(not(feature = "std"))]
467 {
468 static CACHE_INFO: once_cell::race::OnceBox<[CacheInfo; 3]> =
469 once_cell::race::OnceBox::new();
470 CACHE_INFO
471 .get_or_init(|| alloc::boxed::Box::new(cache_info().unwrap_or(CACHE_INFO_DEFAULT)))
472 }
473 #[cfg(feature = "std")]
474 {
475 static CACHE_INFO: once_cell::sync::OnceCell<[CacheInfo; 3]> =
476 once_cell::sync::OnceCell::new();
477 CACHE_INFO.get_or_init(|| cache_info().unwrap_or(CACHE_INFO_DEFAULT))
478 }
479 }
480}
481
#[cfg(target_vendor = "apple")]
impl HasAmx {
    /// Returns the result of `has_amx_impl`, probing only until a value has been cached.
    ///
    /// Uses relaxed atomics, so concurrent first callers may each probe. They all store the
    /// same answer.
    #[inline]
    pub fn get() -> bool {
        use core::sync::atomic::{AtomicU8, Ordering};

        // `u8::MAX` marks "not probed yet"; 0 and 1 hold the cached answer.
        const UNKNOWN: u8 = u8::MAX;
        static HAS_AMX: AtomicU8 = AtomicU8::new(UNKNOWN);

        let cached = match HAS_AMX.load(Ordering::Relaxed) {
            UNKNOWN => {
                let probed = has_amx_impl() as u8;
                HAS_AMX.store(probed, Ordering::Relaxed);
                probed
            }
            known => known,
        };
        cached != 0
    }
}
496
497pub static CACHE_INFO: CacheInfoDeref = CacheInfoDeref;
498
/// Greatest common divisor by Euclid's algorithm.
///
/// `gcd(x, 0) == x`, so `gcd(0, 0) == 0`.
#[inline]
fn gcd(a: usize, b: usize) -> usize {
    let (mut x, mut y) = (a, b);
    while y > 0 {
        (x, y) = (y, x % y);
    }
    x
}
508
/// Rounds `a` down to the nearest multiple of `b`. Panics if `b` is zero.
#[inline]
fn round_down(a: usize, b: usize) -> usize {
    a - a % b
}
513
514pub fn kernel_params(
515 m: usize,
516 n: usize,
517 k: usize,
518 mr: usize,
519 nr: usize,
520 sizeof: usize,
521) -> KernelParams {
522 if m == 0 || n == 0 || k == 0 {
523 return KernelParams {
524 kc: k,
525 mc: m,
526 nc: n,
527 };
528 }
529
530 let info = *CACHE_INFO;
531
532 let l1_cache_bytes = info[0].cache_bytes.max(32 * 1024);
533 let l2_cache_bytes = info[1].cache_bytes;
534 let l3_cache_bytes = info[2].cache_bytes;
535
536 let l1_line_bytes = info[0].cache_line_bytes.max(64);
537
538 let l1_assoc = info[0].associativity.max(2);
539 let l2_assoc = info[1].associativity.max(2);
540 let l3_assoc = info[2].associativity.max(2);
541
542 let l1_n_sets = l1_cache_bytes / (l1_line_bytes * l1_assoc);
543
544 let gcd = gcd(mr * sizeof, l1_line_bytes * l1_n_sets);
561 let kc_0 = (l1_line_bytes * l1_n_sets) / gcd;
562 let c_lhs = (mr * sizeof) / gcd;
563 let c_rhs = (nr * kc_0 * sizeof) / (l1_line_bytes * l1_n_sets);
564 let kc_multiplier = l1_assoc / (c_lhs + c_rhs);
565 let auto_kc = (kc_0 * kc_multiplier.next_power_of_two()).max(512).min(k);
567 let k_iter = k.msrv_div_ceil(auto_kc);
568 let auto_kc = k.msrv_div_ceil(k_iter);
569
570 let auto_mc = if l2_cache_bytes == 0 {
575 panic!();
576 } else {
577 let rhs_micropanel_bytes = nr * auto_kc * sizeof;
578 let rhs_l2_assoc = rhs_micropanel_bytes.msrv_div_ceil(l2_cache_bytes / l2_assoc);
579 let lhs_l2_assoc = (l2_assoc - 1 - rhs_l2_assoc).max(1);
580
581 let mc_from_lhs_l2_assoc = |lhs_l2_assoc: usize| -> usize {
582 (lhs_l2_assoc * l2_cache_bytes) / (l2_assoc * sizeof * auto_kc)
583 };
584
585 let auto_mc = round_down(mc_from_lhs_l2_assoc(lhs_l2_assoc), mr);
586 let m_iter = m.msrv_div_ceil(auto_mc);
587 m.msrv_div_ceil(m_iter * mr) * mr
588 };
589 let auto_mc = Ord::min(auto_mc, 8 * mr);
590
591 let auto_nc = if l3_cache_bytes == 0 {
595 0
596 } else {
597 let rhs_l3_assoc = l3_assoc - 1;
600 let rhs_macropanel_max_bytes = (rhs_l3_assoc * l3_cache_bytes) / l3_assoc;
601
602 let auto_nc = round_down(rhs_macropanel_max_bytes / (sizeof * auto_kc), nr);
603 let n_iter = n.msrv_div_ceil(auto_nc);
604 n.msrv_div_ceil(n_iter * nr) * nr
605 };
606
607 KernelParams {
608 kc: auto_kc,
609 mc: auto_mc,
610 nc: auto_nc,
611 }
612}