Skip to main content

trueno/blis/
cache_topology.rs

1//! Dynamic cache topology detection for BLIS blocking parameters.
2//!
3//! Contract: cgp-dynamic-cache-v1.yaml (C-CACHE-001 through C-CACHE-006)
4//!
5//! Reads `/sys/devices/system/cpu/cpu0/cache/` at runtime to determine
6//! L1D, L2, L3 sizes. Computes optimal MC/KC/NC for the current CPU.
7//! Falls back to hardcoded defaults if `/sys/` is not readable (C-CACHE-006).
8
9use std::sync::OnceLock;
10
/// Detected cache sizes in bytes.
///
/// Produced by [`topology`], which probes sysfs once and otherwise falls back
/// to fixed defaults.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CacheTopology {
    /// L1 data cache size in bytes (typically 32-48 KB).
    pub l1d_bytes: usize,
    /// L2 unified cache size in bytes (typically 256 KB - 2 MB).
    pub l2_bytes: usize,
    /// L3 unified cache size in bytes (typically 4-64 MB); 0 when detection
    /// succeeded but no unified L3 was reported.
    pub l3_bytes: usize,
}
21
/// BLIS blocking parameters computed from cache topology.
///
/// `mr`/`nr` are the microkernel tile dimensions; `mc`, `kc` and `nc` are the
/// cache-blocking sizes of the packed A block (`mc × kc`) and packed B panel
/// (`kc × nc`), counted in elements (the contracts assume 4-byte elements).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BlisBlocking {
    /// Microkernel rows (height of the C tile produced per kernel call).
    pub mr: usize,
    /// Microkernel columns (width of the C tile produced per kernel call).
    pub nr: usize,
    /// Rows of A per packed block; the contract requires a multiple of `mr`
    /// (C-CACHE-004) and `mc * kc * 4 <= L2` (C-CACHE-001).
    pub mc: usize,
    /// Shared inner (K) depth of the packed A block and B panel.
    pub kc: usize,
    /// Columns of B per packed panel; the contract requires a multiple of `nr`
    /// (C-CACHE-005) and `kc * nc * 4 <= L3 / 2` (C-CACHE-003).
    pub nc: usize,
    /// Whether these were computed from detected topology or are defaults.
    pub dynamic: bool,
}
33
34/// Hardcoded fallback for 8×16 path (small N, C-CACHE-006).
35const DEFAULT_BLOCKING_8X16: BlisBlocking =
36    BlisBlocking { mr: 8, nr: 16, mc: 64, kc: 256, nc: 1024, dynamic: false };
37
38/// Cached topology — detected once, reused forever.
39static TOPOLOGY: OnceLock<CacheTopology> = OnceLock::new();
40static BLOCKING_8X32: OnceLock<BlisBlocking> = OnceLock::new();
41
42/// Detect cache topology from /sys/. Returns None if not readable.
43fn detect_from_sys() -> Option<CacheTopology> {
44    let mut l1d = 0usize;
45    let mut l2 = 0usize;
46    let mut l3 = 0usize;
47
48    for idx in 0..4 {
49        let base = format!("/sys/devices/system/cpu/cpu0/cache/index{idx}");
50        let size_str = std::fs::read_to_string(format!("{base}/size")).ok()?;
51        let type_str = std::fs::read_to_string(format!("{base}/type")).ok()?;
52
53        let size_str = size_str.trim();
54        let type_str = type_str.trim();
55
56        // Parse size: "32K" → 32768, "1024K" → 1048576
57        let size_bytes = parse_cache_size(size_str)?;
58
59        match (idx, type_str) {
60            (0, "Data") => l1d = size_bytes,
61            (1, _) => {} // L1 instruction, skip
62            (2, "Unified") => l2 = size_bytes,
63            (3, "Unified") => l3 = size_bytes,
64            _ => {}
65        }
66    }
67
68    if l1d > 0 && l2 > 0 {
69        Some(CacheTopology { l1d_bytes: l1d, l2_bytes: l2, l3_bytes: l3 })
70    } else {
71        None
72    }
73}
74
/// Parses a sysfs cache size string into bytes.
///
/// Accepts a plain byte count (`"512"`) or a count with a binary `K`, `M` or
/// `G` suffix: `"32K"` → 32768, `"1024K"` → 1048576, `"32768K"` → 33554432,
/// `"2M"` → 2097152. Surrounding whitespace is ignored. Returns `None` for
/// unparsable input or when the byte count would overflow `usize`.
fn parse_cache_size(s: &str) -> Option<usize> {
    let s = s.trim();
    // Split off an optional binary-unit suffix.
    let (digits, scale) = if let Some(d) = s.strip_suffix('K') {
        (d, 1usize << 10)
    } else if let Some(d) = s.strip_suffix('M') {
        (d, 1usize << 20)
    } else if let Some(d) = s.strip_suffix('G') {
        (d, 1usize << 30)
    } else {
        (s, 1usize)
    };
    // checked_mul: a bogus huge value must not panic (debug) or wrap (release).
    digits.parse::<usize>().ok()?.checked_mul(scale)
}
86
87/// Get detected cache topology (cached after first call).
88pub fn topology() -> CacheTopology {
89    *TOPOLOGY.get_or_init(|| {
90        detect_from_sys().unwrap_or(CacheTopology {
91            l1d_bytes: 32768,   // 32K default
92            l2_bytes: 1048576,  // 1M default
93            l3_bytes: 33554432, // 32M default
94        })
95    })
96}
97
98/// Compute BLIS blocking for 8×32 microkernel from cache topology.
99///
100/// Invariants (from contract cgp-dynamic-cache-v1.yaml):
101/// - C-CACHE-001: mc * kc * 4 <= l2
102/// - C-CACHE-002: kc * nr * 4 <= l1d
103/// - C-CACHE-003: kc * nc * 4 <= l3 / 2
104/// - C-CACHE-004: mc % mr == 0
105/// - C-CACHE-005: nc % nr == 0
106fn compute_blocking_8x32(topo: &CacheTopology) -> BlisBlocking {
107    let mr = 8usize;
108    let nr = 32usize;
109
110    // C-CACHE-002: kc * nr * 4 <= l1d → kc <= l1d / (nr * 4)
111    let kc_max = topo.l1d_bytes / (nr * 4);
112    // Round down to power of 2, min 64
113    let kc = kc_max.next_power_of_two().min(kc_max).max(64);
114
115    // C-CACHE-001: mc * kc * 4 <= l2 → mc <= l2 / (kc * 4)
116    let mc_max = topo.l2_bytes / (kc * 4);
117    // Round down to multiple of MR.
118    // EMPIRICAL: MC=96 outperforms MC=192/256 on Zen 4 (tested 2026-04-05).
119    // Large MC increases A-packing overhead without proportional L2 benefit.
120    // Cap at 12*MR=96 (matches empirically-tuned value).
121    let mc = (mc_max / mr * mr).min(12 * mr).max(mr);
122
123    // C-CACHE-003: kc * nc * 4 <= l3/2 → nc <= l3 / (2 * kc * 4)
124    let nc_max = topo.l3_bytes / (2 * kc * 4);
125    // Round down to multiple of NR, cap at 4096
126    let nc = (nc_max / nr * nr).min(4096).max(nr);
127
128    BlisBlocking { mr, nr, mc, kc, nc, dynamic: true }
129}
130
131/// Get optimal BLIS blocking for 8×32 microkernel (cached).
132pub fn blocking_8x32() -> BlisBlocking {
133    *BLOCKING_8X32.get_or_init(|| {
134        let topo = topology();
135        compute_blocking_8x32(&topo)
136    })
137}
138
139/// Get optimal BLIS blocking for 8×48 codegen microkernel (cached).
140/// 8×48: 24 accumulators, 24 FMAs/K-step (3× the 8×16 kernel).
141/// KC is smaller (L1-limited at NR=48) but more FMAs per K-step may compensate.
142pub fn blocking_8x48() -> BlisBlocking {
143    static BLOCKING_8X48: OnceLock<BlisBlocking> = OnceLock::new();
144    *BLOCKING_8X48.get_or_init(|| {
145        let topo = topology();
146        let mr = 8usize;
147        let nr = 48usize;
148        // KC: l1d / (nr * 4) = 32768 / 192 = 170, round down to power-of-2 = 128
149        // Power-of-2 KC ensures aligned loop trips for vectorized packing.
150        let kc_max = topo.l1d_bytes / (nr * 4);
151        let kc = (kc_max.next_power_of_two() >> 1).max(64);
152        let mc_max = topo.l2_bytes / (kc * 4);
153        let mc = (mc_max / mr * mr).min(12 * mr).max(mr);
154        let nc_max = topo.l3_bytes / (2 * kc * 4);
155        let nc = (nc_max / nr * nr).min(4096).max(nr);
156        BlisBlocking { mr, nr, mc, kc, nc, dynamic: true }
157    })
158}
159
160/// Get optimal BLIS blocking for broadcast-B 64×6 microkernel (cached).
161/// 64×6: 24 FMA accumulators — matches faer's nano-gemm register utilization.
162/// NR=6 keeps B panel tiny → KC can be large (256-512).
163pub fn blocking_64x6_bcast_b() -> BlisBlocking {
164    static BLOCKING_64X6: OnceLock<BlisBlocking> = OnceLock::new();
165    *BLOCKING_64X6.get_or_init(|| {
166        let topo = topology();
167        let mr = 64usize;
168        let nr = 6usize;
169        // KC: B panel = NR × KC × 4 = 24 × KC bytes. At KC=512: 12KB (fits L1 easily).
170        // A panel = MR × KC × 4 = 256 × KC bytes. At KC=512: 128KB (fits L2).
171        // Limit KC so A panel fits in L2: KC ≤ L2 / (MR × 4)
172        let kc_max_l2 = topo.l2_bytes / (mr * 4);
173        let kc_max_l1 = topo.l1d_bytes * 3 / 4 / (nr * 4); // B uses 3/4 of L1
174        let kc = kc_max_l2.min(kc_max_l1).clamp(64, 512);
175        // MC: number of rows per L2 tile. Since MR=64 is large, MC should be
176        // a small multiple of MR.
177        let mc_max = topo.l2_bytes / (kc * 4);
178        let mc = (mc_max / mr * mr).min(4 * mr).max(mr);
179        // NC: columns per L3 tile
180        let nc_max = topo.l3_bytes / (2 * kc * 4);
181        let nc = (nc_max / nr * nr).min(4096).max(nr);
182        BlisBlocking { mr, nr, mc, kc, nc, dynamic: true }
183    })
184}
185
186/// Get default blocking for 8×16 microkernel (hardcoded, used for small N).
187pub fn blocking_8x16() -> BlisBlocking {
188    DEFAULT_BLOCKING_8X16
189}
190
#[cfg(test)]
mod tests {
    use super::*;

    /// FALSIFY-CACHE-001: `topology()` returns sane sizes (detected or fallback).
    #[test]
    fn test_detect_topology() {
        let topo = topology();
        assert!(topo.l1d_bytes > 0, "L1D must be > 0");
        assert!(topo.l2_bytes > 0, "L2 must be > 0");
        // L3 may legitimately be 0 (no unified L3 reported), so it is not checked.
        assert!(
            topo.l1d_bytes <= topo.l2_bytes,
            "L1D ({}) must be <= L2 ({})",
            topo.l1d_bytes,
            topo.l2_bytes
        );
    }

    /// FALSIFY-CACHE-002: computed MC * KC * 4 fits in L2.
    #[test]
    fn test_mc_kc_fits_l2() {
        let topo = topology();
        let blk = compute_blocking_8x32(&topo);
        let packed_a_bytes = blk.mc * blk.kc * 4;
        assert!(
            packed_a_bytes <= topo.l2_bytes,
            "C-CACHE-001: packed A = {} bytes > L2 = {} bytes",
            packed_a_bytes,
            topo.l2_bytes
        );
    }

    /// Contract C-CACHE-002: KC * NR * 4 fits in L1D.
    #[test]
    fn test_kc_nr_fits_l1() {
        let topo = topology();
        let blk = compute_blocking_8x32(&topo);
        let b_panel_bytes = blk.kc * blk.nr * 4;
        assert!(
            b_panel_bytes <= topo.l1d_bytes,
            "C-CACHE-002: B panel = {} bytes > L1D = {} bytes",
            b_panel_bytes,
            topo.l1d_bytes
        );
    }

    /// Contract C-CACHE-004: MC is multiple of MR.
    #[test]
    fn test_mc_multiple_of_mr() {
        let blk = blocking_8x32();
        assert_eq!(blk.mc % blk.mr, 0, "C-CACHE-004: MC={} not multiple of MR={}", blk.mc, blk.mr);
    }

    /// Contract C-CACHE-005: NC is multiple of NR.
    #[test]
    fn test_nc_multiple_of_nr() {
        let blk = blocking_8x32();
        assert_eq!(blk.nc % blk.nr, 0, "C-CACHE-005: NC={} not multiple of NR={}", blk.nc, blk.nr);
    }

    /// Parse cache size strings correctly, including sysfs-style trailing newlines.
    #[test]
    fn test_parse_cache_size() {
        assert_eq!(parse_cache_size("32K"), Some(32768));
        assert_eq!(parse_cache_size("1024K"), Some(1048576));
        assert_eq!(parse_cache_size("32768K"), Some(33554432));
        assert_eq!(parse_cache_size("2M"), Some(2097152));
        assert_eq!(parse_cache_size("48K\n"), Some(49152));
        assert_eq!(parse_cache_size("512"), Some(512));
        assert_eq!(parse_cache_size("abc"), None);
        assert_eq!(parse_cache_size(""), None);
    }

    /// Computed blocking for Zen 4 (32K L1D, 1M L2, 32M L3).
    #[test]
    fn test_zen4_blocking() {
        let topo = CacheTopology { l1d_bytes: 32768, l2_bytes: 1048576, l3_bytes: 33554432 };
        let blk = compute_blocking_8x32(&topo);
        // KC: 32768 / (32 * 4) = 256 (already a power of two)
        assert_eq!(blk.kc, 256, "KC for Zen 4");
        // MC: 1048576 / (256 * 4) = 1024, capped at 12 * MR = 96
        assert_eq!(blk.mc, 96, "MC for Zen 4");
        assert_eq!(blk.mc % 8, 0, "MC must be multiple of MR=8");
        // NC: 33554432 / (2 * 256 * 4) = 16384, capped at 4096
        assert_eq!(blk.nc, 4096, "NC for Zen 4");
        assert!(blk.dynamic);
    }

    /// With no L3 reported, NC falls to its floor of NR.
    #[test]
    fn test_no_l3_blocking() {
        let topo = CacheTopology { l1d_bytes: 32768, l2_bytes: 524288, l3_bytes: 0 };
        let blk = compute_blocking_8x32(&topo);
        assert_eq!(blk.kc, 256);
        assert_eq!(blk.mc, 96);
        assert_eq!(blk.nc, blk.nr);
    }
}