gemm_common/
cache.rs

/// Geometry of one cache level, as reported by the detection probes in `cache_info`.
#[derive(Clone, Copy, Debug, Default)]
pub struct CacheInfo {
    /// Number of ways (set associativity).
    pub associativity: usize,
    /// Size in bytes. Depending on the probe that produced it, this can be the
    /// share of one CPU rather than the size of the whole cache.
    pub cache_bytes: usize,
    /// Size of a cache line (coherency line) in bytes.
    pub cache_line_bytes: usize,
}
7
/// Cache-blocking sizes chosen by `kernel_params`.
#[derive(Clone, Copy, Debug, Default)]
pub struct KernelParams {
    /// Block size along the shared `k` dimension.
    pub kc: usize,
    /// Block size along `m` (rows of the left-hand side).
    pub mc: usize,
    /// Block size along `n` (columns of the right-hand side).
    pub nc: usize,
}
14
/// Round-up integer helpers.
///
/// NOTE(review): the `msrv_` prefix presumably marks these as stand-ins for the
/// std methods `div_ceil`, `next_multiple_of` and `checked_next_multiple_of`,
/// for toolchains older than those methods — confirm against the crate's MSRV.
pub trait DivCeil: Sized {
    /// Quotient `self / rhs`, rounded up.
    fn msrv_div_ceil(self, rhs: Self) -> Self;
    /// Smallest multiple of `rhs` that is greater than or equal to `self`.
    fn msrv_next_multiple_of(self, rhs: Self) -> Self;
    /// Like [`DivCeil::msrv_next_multiple_of`], but `None` when `rhs` is zero
    /// or the result does not fit.
    fn msrv_checked_next_multiple_of(self, rhs: Self) -> Option<Self>;
}

impl DivCeil for usize {
    /// Panics if `rhs == 0`.
    #[inline]
    fn msrv_div_ceil(self, rhs: Self) -> Self {
        let quotient = self / rhs;
        if self % rhs == 0 {
            quotient
        } else {
            quotient + 1
        }
    }

    /// Panics if `rhs == 0`; the addition can overflow for values near `usize::MAX`.
    #[inline]
    fn msrv_next_multiple_of(self, rhs: Self) -> Self {
        let remainder = self % rhs;
        if remainder == 0 {
            self
        } else {
            self + (rhs - remainder)
        }
    }

    #[inline]
    fn msrv_checked_next_multiple_of(self, rhs: Self) -> Option<Self> {
        // `checked_rem` yields `None` for a zero divisor.
        let remainder = self.checked_rem(rhs)?;
        if remainder == 0 {
            Some(self)
        } else {
            self.checked_add(rhs - remainder)
        }
    }
}
51
/// Reports whether the CPU brand string (sysctl `machdep.cpu.brand_string`)
/// names an Apple M1, M2 or M3 chip; those are the chips treated as having AMX.
///
/// Always `false` under Miri, without the `std` feature, or when the sysctl
/// cannot be read.
#[cfg(target_vendor = "apple")]
fn has_amx_impl() -> bool {
    if cfg!(miri) {
        return false;
    }

    #[cfg(feature = "std")]
    {
        use sysctl::Ctl;
        use sysctl::Sysctl;

        let brand_string =
            Ctl::new("machdep.cpu.brand_string").and_then(|ctl| ctl.value_string());
        if let Ok(brand_string) = brand_string {
            // Expected shape: "Apple M<n> ...".
            let mut tokens = brand_string.split_whitespace();
            let vendor_is_apple = tokens.next() == Some("Apple");
            let chip = tokens.next();
            return vendor_is_apple && matches!(chip, Some("M1" | "M2" | "M3"));
        }
    }

    false
}
72
73fn cache_info() -> Option<[CacheInfo; 3]> {
74    if !cfg!(miri) {
75        #[cfg(feature = "std")]
76        {
77            #[cfg(target_os = "linux")]
78            {
79                use std::fs;
80                fn try_cache_info_linux() -> Result<[CacheInfo; 3], std::io::Error> {
81                    let mut all_info = [CacheInfo {
82                        associativity: 8,
83                        cache_bytes: 0,
84                        cache_line_bytes: 64,
85                    }; 3];
86
87                    for cpu_x in fs::read_dir("/sys/devices/system/cpu")? {
88                        let cpu_x = cpu_x?.path();
89                        let Some(cpu_x_name) = cpu_x.file_name().and_then(|f| f.to_str()) else {
90                            continue;
91                        };
92                        if !cpu_x_name.starts_with("cpu") {
93                            continue;
94                        }
95                        let cache = cpu_x.join("cache");
96                        if !cache.is_dir() {
97                            continue;
98                        }
99                        'index: for index_y in fs::read_dir(cache)? {
100                            let index_y = index_y?.path();
101                            if !index_y.is_dir() {
102                                continue;
103                            }
104                            let Some(index_y_name) = index_y.file_name().and_then(|f| f.to_str())
105                            else {
106                                continue;
107                            };
108                            if !index_y_name.starts_with("index") {
109                                continue;
110                            }
111
112                            let mut cache_info = CacheInfo {
113                                associativity: 8,
114                                cache_bytes: 0,
115                                cache_line_bytes: 64,
116                            };
117                            let mut level: usize = 0;
118                            let mut shared_count: usize = 0;
119
120                            for entry in fs::read_dir(index_y)? {
121                                let entry = entry?.path();
122                                if let Some(name) = entry.file_name() {
123                                    let contents = fs::read_to_string(&entry)?;
124                                    let contents = contents.trim();
125                                    if name == "type" && !matches!(contents, "Data" | "Unified") {
126                                        continue 'index;
127                                    }
128                                    if name == "shared_cpu_list" {
129                                        for item in contents.split(',') {
130                                            if item.contains('-') {
131                                                let mut item = item.split('-');
132                                                let Some(start) = item.next() else {
133                                                    continue 'index;
134                                                };
135                                                let Some(end) = item.next() else {
136                                                    continue 'index;
137                                                };
138
139                                                let Ok(start) = start.parse::<usize>() else {
140                                                    continue 'index;
141                                                };
142                                                let Ok(end) = end.parse::<usize>() else {
143                                                    continue 'index;
144                                                };
145
146                                                shared_count += end + 1 - start;
147                                            } else {
148                                                shared_count += 1;
149                                            }
150                                        }
151                                    }
152
153                                    if name == "level" {
154                                        let Ok(contents) = contents.parse::<usize>() else {
155                                            continue 'index;
156                                        };
157                                        level = contents;
158                                    }
159
160                                    if name == "coherency_line_size" {
161                                        let Ok(contents) = contents.parse::<usize>() else {
162                                            continue 'index;
163                                        };
164                                        cache_info.cache_line_bytes = contents;
165                                    }
166                                    if name == "ways_of_associativity" {
167                                        let Ok(contents) = contents.parse::<usize>() else {
168                                            continue 'index;
169                                        };
170                                        cache_info.associativity = contents;
171                                    }
172                                    if name == "size" {
173                                        if contents.ends_with("G") {
174                                            let Ok(contents) =
175                                                contents.trim_end_matches('G').parse::<usize>()
176                                            else {
177                                                continue 'index;
178                                            };
179                                            cache_info.cache_bytes = contents * 1024 * 1024 * 1024;
180                                        } else if contents.ends_with("M") {
181                                            let Ok(contents) =
182                                                contents.trim_end_matches('M').parse::<usize>()
183                                            else {
184                                                continue 'index;
185                                            };
186                                            cache_info.cache_bytes = contents * 1024 * 1024;
187                                        } else if contents.ends_with("K") {
188                                            let Ok(contents) =
189                                                contents.trim_end_matches('K').parse::<usize>()
190                                            else {
191                                                continue 'index;
192                                            };
193                                            cache_info.cache_bytes = contents * 1024;
194                                        } else {
195                                            let Ok(contents) = contents.parse::<usize>() else {
196                                                continue 'index;
197                                            };
198                                            cache_info.cache_bytes = contents;
199                                        }
200                                    }
201                                }
202                            }
203                            if level == 3 {
204                                shared_count = 1;
205                            }
206                            if level > 0 {
207                                if cache_info.cache_line_bytes
208                                    >= all_info[level - 1].cache_line_bytes
209                                {
210                                    all_info[level - 1].associativity = cache_info.associativity;
211                                    all_info[level - 1].cache_line_bytes =
212                                        cache_info.cache_line_bytes;
213                                    all_info[level - 1].cache_bytes =
214                                        cache_info.cache_bytes / shared_count;
215                                }
216                            }
217                        }
218                    }
219
220                    for (info, default) in core::iter::zip(&mut all_info, CACHE_INFO_DEFAULT) {
221                        if info.cache_bytes == 0 {
222                            *info = default;
223                        }
224                    }
225
226                    Ok(all_info)
227                }
228                if let Ok(info) = try_cache_info_linux() {
229                    return Some(info);
230                }
231
232                if let Ok(lscpu) = std::process::Command::new("lscpu")
233                    .arg("-B")
234                    .arg("-C=type,level,ways,coherency-size,one-size")
235                    .output()
236                {
237                    if lscpu.status.success() {
238                        if let Ok(lscpu) = String::from_utf8(lscpu.stdout).as_deref() {
239                            let mut info = CACHE_INFO_DEFAULT;
240                            for line in lscpu.lines().skip(1) {
241                                let mut iter = line.split_whitespace();
242                                if let [Some(cache_type), Some(level), Some(ways), Some(coherency_size), Some(one_size)] = [
243                                    iter.next(),
244                                    iter.next(),
245                                    iter.next(),
246                                    iter.next(),
247                                    iter.next(),
248                                ] {
249                                    if let "Data" | "Unified" = cache_type {
250                                        let level = level.parse::<usize>().unwrap();
251                                        let ways = ways.parse::<usize>().unwrap();
252                                        let coherency_size =
253                                            coherency_size.parse::<usize>().unwrap();
254                                        let one_size = one_size.parse::<usize>().unwrap();
255
256                                        info[level - 1].associativity = ways;
257                                        info[level - 1].cache_line_bytes = coherency_size;
258                                        info[level - 1].cache_bytes = one_size;
259                                    }
260                                }
261                            }
262                            return Some(info);
263                        }
264                    }
265                }
266            }
267            #[cfg(target_vendor = "apple")]
268            {
269                use sysctl::Ctl;
270                use sysctl::Sysctl;
271
272                let mut all_info = CACHE_INFO_DEFAULT;
273                if let Ok(l1) =
274                    Ctl::new("hw.perflevel0.l1dcachesize").and_then(|ctl| ctl.value_string())
275                {
276                    if let Ok(l1) = l1.parse::<usize>() {
277                        all_info[0].cache_bytes = l1;
278                    }
279                }
280                if let (Ok(physicalcpu), Ok(cpusperl2), Ok(l2)) = (
281                    Ctl::new("hw.perflevel0.physicalcpu").and_then(|ctl| ctl.value_string()),
282                    Ctl::new("hw.perflevel0.cpusperl2").and_then(|ctl| ctl.value_string()),
283                    Ctl::new("hw.perflevel0.l2cachesize").and_then(|ctl| ctl.value_string()),
284                ) {
285                    if let (Ok(_physicalcpu), Ok(cpusperl2), Ok(l2)) = (
286                        physicalcpu.parse::<usize>(),
287                        cpusperl2.parse::<usize>(),
288                        l2.parse::<usize>(),
289                    ) {
290                        all_info[1].cache_bytes = l2 / cpusperl2;
291                    }
292                }
293                all_info[2].cache_bytes = 0;
294                return Some(all_info);
295            }
296        }
297
298        #[cfg(any(
299            all(target_arch = "x86", not(target_env = "sgx"), target_feature = "sse"),
300            all(target_arch = "x86_64", not(target_env = "sgx"))
301        ))]
302        {
303            use raw_cpuid::CpuId;
304            let cpuid = CpuId::new();
305
306            if let Some(vf) = cpuid.get_vendor_info() {
307                let vf = vf.as_str();
308                if vf == "GenuineIntel" {
309                    if let Some(cparams) = cpuid.get_cache_parameters() {
310                        // not sure why, intel cpus seem to prefer smaller mc
311                        let mut info = [CacheInfo {
312                            cache_bytes: 0,
313                            associativity: 0,
314                            cache_line_bytes: 64,
315                        }; 3];
316
317                        for cache in cparams {
318                            use raw_cpuid::CacheType::*;
319                            match cache.cache_type() {
320                                Null | Instruction | Reserved => continue,
321                                Data | Unified => {
322                                    let level = cache.level() as usize;
323                                    let associativity = cache.associativity();
324                                    let nsets = cache.sets();
325                                    let cache_line_bytes = cache.coherency_line_size();
326                                    if level > 0 && level < 4 {
327                                        let info = &mut info[level - 1];
328                                        info.cache_line_bytes = cache_line_bytes;
329                                        info.associativity = associativity;
330                                        info.cache_bytes = associativity * nsets * cache_line_bytes;
331                                    }
332                                }
333                            }
334                        }
335                        return Some(info);
336                    }
337                }
338
339                if vf == "AuthenticAMD" {
340                    if let Some(l1) = cpuid.get_l1_cache_and_tlb_info() {
341                        if let Some(l23) = cpuid.get_l2_l3_cache_and_tlb_info() {
342                            let compute_info = |associativity: raw_cpuid::Associativity,
343                                                cache_kb: usize,
344                                                cache_line_bytes: u8|
345                             -> CacheInfo {
346                                let cache_bytes = cache_kb as usize * 1024;
347                                let cache_line_bytes = cache_line_bytes as usize;
348
349                                use raw_cpuid::Associativity::*;
350                                let associativity = match associativity {
351                                    Unknown | Disabled => {
352                                        return CacheInfo {
353                                            associativity: 0,
354                                            cache_bytes: 0,
355                                            cache_line_bytes: 64,
356                                        }
357                                    }
358                                    FullyAssociative => cache_bytes / cache_line_bytes,
359                                    DirectMapped => 1,
360                                    NWay(n) => n as usize,
361                                };
362
363                                CacheInfo {
364                                    associativity,
365                                    cache_bytes,
366                                    cache_line_bytes,
367                                }
368                            };
369                            return Some([
370                                compute_info(
371                                    l1.dcache_associativity(),
372                                    l1.dcache_size() as usize,
373                                    l1.dcache_line_size(),
374                                ),
375                                compute_info(
376                                    l23.l2cache_associativity(),
377                                    l23.l2cache_size() as usize,
378                                    l23.l2cache_line_size(),
379                                ),
380                                compute_info(
381                                    l23.l3cache_associativity(),
382                                    l23.l3cache_size() as usize * 512,
383                                    l23.l3cache_line_size(),
384                                ),
385                            ]);
386                        }
387                    }
388                }
389            }
390        }
391    }
392    None
393}
394
395#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
396static CACHE_INFO_DEFAULT: [CacheInfo; 3] = [
397    CacheInfo {
398        associativity: 8,
399        cache_bytes: 32 * 1024, // 32KiB
400        cache_line_bytes: 64,
401    },
402    CacheInfo {
403        associativity: 8,
404        cache_bytes: 256 * 1024, // 256KiB
405        cache_line_bytes: 64,
406    },
407    CacheInfo {
408        associativity: 8,
409        cache_bytes: 2 * 1024 * 1024, // 2MiB
410        cache_line_bytes: 64,
411    },
412];
413
/// Fallback L1/L2/L3 geometry for 32- and 64-bit PowerPC. It is used when
/// detection fails, and by the sysfs/lscpu/sysctl probes as a starting point or
/// for levels they do not report.
#[cfg(any(target_arch = "powerpc", target_arch = "powerpc64"))]
static CACHE_INFO_DEFAULT: [CacheInfo; 3] = [
    // L1: 64 KiB
    CacheInfo {
        cache_bytes: 64 << 10,
        cache_line_bytes: 64,
        associativity: 8,
    },
    // L2: 512 KiB
    CacheInfo {
        cache_bytes: 512 << 10,
        cache_line_bytes: 64,
        associativity: 8,
    },
    // L3: 4 MiB
    CacheInfo {
        cache_bytes: 4 << 20,
        cache_line_bytes: 64,
        associativity: 8,
    },
];
432
/// Fallback L1/L2/L3 geometry for every architecture other than x86 and
/// PowerPC. It is used when detection fails, and by the sysfs/lscpu/sysctl
/// probes as a starting point or for levels they do not report.
#[cfg(not(any(
    target_arch = "powerpc",
    target_arch = "powerpc64",
    target_arch = "x86",
    target_arch = "x86_64"
)))]
static CACHE_INFO_DEFAULT: [CacheInfo; 3] = [
    // L1: 16 KiB
    CacheInfo {
        cache_bytes: 16 << 10,
        cache_line_bytes: 64,
        associativity: 8,
    },
    // L2: 512 KiB
    CacheInfo {
        cache_bytes: 512 << 10,
        cache_line_bytes: 64,
        associativity: 8,
    },
    // L3: 1 MiB
    CacheInfo {
        cache_bytes: 1 << 20,
        cache_line_bytes: 64,
        associativity: 8,
    },
];
456
/// Zero-sized handle whose `Deref` impl yields the `[L1, L2, L3]` cache
/// information, detected lazily and cached.
pub struct CacheInfoDeref;
/// Zero-sized namespace for `HasAmx::get`, which reports whether the CPU is
/// one of the Apple chips treated as having AMX. Exists only on Apple targets.
#[cfg(target_vendor = "apple")]
pub struct HasAmx;
460
461impl core::ops::Deref for CacheInfoDeref {
462    type Target = [CacheInfo; 3];
463
464    #[inline]
465    fn deref(&self) -> &Self::Target {
466        #[cfg(not(feature = "std"))]
467        {
468            static CACHE_INFO: once_cell::race::OnceBox<[CacheInfo; 3]> =
469                once_cell::race::OnceBox::new();
470            CACHE_INFO
471                .get_or_init(|| alloc::boxed::Box::new(cache_info().unwrap_or(CACHE_INFO_DEFAULT)))
472        }
473        #[cfg(feature = "std")]
474        {
475            static CACHE_INFO: once_cell::sync::OnceCell<[CacheInfo; 3]> =
476                once_cell::sync::OnceCell::new();
477            CACHE_INFO.get_or_init(|| cache_info().unwrap_or(CACHE_INFO_DEFAULT))
478        }
479    }
480}
481
#[cfg(target_vendor = "apple")]
impl HasAmx {
    /// Returns whether AMX is considered available, caching the answer of
    /// `has_amx_impl` in a static.
    ///
    /// `u8::MAX` marks "not computed yet". Concurrent first calls may each run
    /// detection; each stores its own result (presumably identical).
    #[inline]
    pub fn get() -> bool {
        use core::sync::atomic::{AtomicU8, Ordering};

        const NOT_YET_DETECTED: u8 = u8::MAX;
        static HAS_AMX: AtomicU8 = AtomicU8::new(NOT_YET_DETECTED);

        match HAS_AMX.load(Ordering::Relaxed) {
            NOT_YET_DETECTED => {
                let detected = has_amx_impl();
                HAS_AMX.store(detected as u8, Ordering::Relaxed);
                detected
            }
            cached => cached != 0,
        }
    }
}
496
/// Global access point to the cache hierarchy. `*CACHE_INFO` is a
/// `[CacheInfo; 3]` for L1, L2 and L3, detected on first use. It falls back to
/// `CACHE_INFO_DEFAULT` when detection yields nothing.
pub static CACHE_INFO: CacheInfoDeref = CacheInfoDeref;
498
/// Greatest common divisor by Euclid's algorithm; `gcd(a, 0) == a`.
#[inline]
fn gcd(a: usize, b: usize) -> usize {
    let (mut x, mut y) = (a, b);
    while y != 0 {
        (x, y) = (y, x % y);
    }
    x
}
508
/// Rounds `a` down to a multiple of `b`. Panics if `b == 0`.
#[inline]
fn round_down(a: usize, b: usize) -> usize {
    a - a % b
}
513
514pub fn kernel_params(
515    m: usize,
516    n: usize,
517    k: usize,
518    mr: usize,
519    nr: usize,
520    sizeof: usize,
521) -> KernelParams {
522    if m == 0 || n == 0 || k == 0 {
523        return KernelParams {
524            kc: k,
525            mc: m,
526            nc: n,
527        };
528    }
529
530    let info = *CACHE_INFO;
531
532    let l1_cache_bytes = info[0].cache_bytes.max(32 * 1024);
533    let l2_cache_bytes = info[1].cache_bytes;
534    let l3_cache_bytes = info[2].cache_bytes;
535
536    let l1_line_bytes = info[0].cache_line_bytes.max(64);
537
538    let l1_assoc = info[0].associativity.max(2);
539    let l2_assoc = info[1].associativity.max(2);
540    let l3_assoc = info[2].associativity.max(2);
541
542    let l1_n_sets = l1_cache_bytes / (l1_line_bytes * l1_assoc);
543
544    // requires
545    // A micropanels must occupy different cache sets
546    // so that loading a micropanel evicts the previous one
547    // => byte stride must be multiple of n_sets×line_bytes
548    //
549    // => mr×kc×scalar_bytes == C_A × l1_line_bytes × l1_n_sets
550    //
551    // l1 must be able to hold A micropanel, B micropanel
552    //
553    // => C_A + C_B <= l1_assoc
554
555    // a×n = b×m
556    // find lcm of a, b
557    // n = lcm / a = b/gcd(a,b)
558    // m = lcm / b = a/gcd(a,b)
559
560    let gcd = gcd(mr * sizeof, l1_line_bytes * l1_n_sets);
561    let kc_0 = (l1_line_bytes * l1_n_sets) / gcd;
562    let c_lhs = (mr * sizeof) / gcd;
563    let c_rhs = (nr * kc_0 * sizeof) / (l1_line_bytes * l1_n_sets);
564    let kc_multiplier = l1_assoc / (c_lhs + c_rhs);
565    // let auto_kc = kc_0 * kc_multiplier;
566    let auto_kc = (kc_0 * kc_multiplier.next_power_of_two()).max(512).min(k);
567    let k_iter = k.msrv_div_ceil(auto_kc);
568    let auto_kc = k.msrv_div_ceil(k_iter);
569
570    // l2 cache must hold
571    //  - B micropanel: nr×kc: assume 1 assoc degree
572    //  - A macropanel: mc×kc
573    // mc×kc×scalar_bytes
574    let auto_mc = if l2_cache_bytes == 0 {
575        panic!();
576    } else {
577        let rhs_micropanel_bytes = nr * auto_kc * sizeof;
578        let rhs_l2_assoc = rhs_micropanel_bytes.msrv_div_ceil(l2_cache_bytes / l2_assoc);
579        let lhs_l2_assoc = (l2_assoc - 1 - rhs_l2_assoc).max(1);
580
581        let mc_from_lhs_l2_assoc = |lhs_l2_assoc: usize| -> usize {
582            (lhs_l2_assoc * l2_cache_bytes) / (l2_assoc * sizeof * auto_kc)
583        };
584
585        let auto_mc = round_down(mc_from_lhs_l2_assoc(lhs_l2_assoc), mr);
586        let m_iter = m.msrv_div_ceil(auto_mc);
587        m.msrv_div_ceil(m_iter * mr) * mr
588    };
589    let auto_mc = Ord::min(auto_mc, 8 * mr);
590
591    // l3 cache must hold
592    //  - A macropanel: mc×kc: assume 1 assoc degree
593    //  - B macropanel: nc×kc
594    let auto_nc = if l3_cache_bytes == 0 {
595        0
596    } else {
597        // let lhs_macropanel_bytes = auto_mc * auto_kc * sizeof;
598        // let lhs_l3_assoc = msrv_div_ceil(lhs_macropanel_bytes, l3_cache_bytes / l3_assoc);
599        let rhs_l3_assoc = l3_assoc - 1;
600        let rhs_macropanel_max_bytes = (rhs_l3_assoc * l3_cache_bytes) / l3_assoc;
601
602        let auto_nc = round_down(rhs_macropanel_max_bytes / (sizeof * auto_kc), nr);
603        let n_iter = n.msrv_div_ceil(auto_nc);
604        n.msrv_div_ceil(n_iter * nr) * nr
605    };
606
607    KernelParams {
608        kc: auto_kc,
609        mc: auto_mc,
610        nc: auto_nc,
611    }
612}