trueno/blis/
cache_topology.rs1use std::sync::OnceLock;
10
/// Cache sizes, in bytes, from which the BLIS blocking parameters are derived.
///
/// `PartialEq`/`Eq` are derived so that topologies can be compared directly
/// (for example with `assert_eq!` in tests).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CacheTopology {
    /// L1 data cache size in bytes.
    pub l1d_bytes: usize,
    /// L2 cache size in bytes.
    pub l2_bytes: usize,
    /// L3 cache size in bytes; stays `0` when detection found no L3 entry.
    pub l3_bytes: usize,
}
21
/// Blocking parameters for one BLIS-style GEMM micro-kernel shape.
///
/// `PartialEq`/`Eq` are derived so that parameter sets can be compared
/// directly (for example with `assert_eq!` in tests).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BlisBlocking {
    /// Micro-tile row count; `mc` is kept a multiple of this value.
    pub mr: usize,
    /// Micro-tile column count; `nc` is kept a multiple of this value.
    pub nr: usize,
    /// Row block size; the packed A block is `mc * kc` elements and is sized
    /// against the L2 cache.
    pub mc: usize,
    /// Inner-dimension block size; the `kc * nr` B panel is sized against the
    /// L1 data cache.
    pub kc: usize,
    /// Column block size, sized against the L3 cache.
    pub nc: usize,
    /// `true` when the values were computed from a `CacheTopology`, `false`
    /// for the fixed compile-time default.
    pub dynamic: bool,
}
33
34const DEFAULT_BLOCKING_8X16: BlisBlocking =
36 BlisBlocking { mr: 8, nr: 16, mc: 64, kc: 256, nc: 1024, dynamic: false };
37
38static TOPOLOGY: OnceLock<CacheTopology> = OnceLock::new();
40static BLOCKING_8X32: OnceLock<BlisBlocking> = OnceLock::new();
41
42fn detect_from_sys() -> Option<CacheTopology> {
44 let mut l1d = 0usize;
45 let mut l2 = 0usize;
46 let mut l3 = 0usize;
47
48 for idx in 0..4 {
49 let base = format!("/sys/devices/system/cpu/cpu0/cache/index{idx}");
50 let size_str = std::fs::read_to_string(format!("{base}/size")).ok()?;
51 let type_str = std::fs::read_to_string(format!("{base}/type")).ok()?;
52
53 let size_str = size_str.trim();
54 let type_str = type_str.trim();
55
56 let size_bytes = parse_cache_size(size_str)?;
58
59 match (idx, type_str) {
60 (0, "Data") => l1d = size_bytes,
61 (1, _) => {} (2, "Unified") => l2 = size_bytes,
63 (3, "Unified") => l3 = size_bytes,
64 _ => {}
65 }
66 }
67
68 if l1d > 0 && l2 > 0 {
69 Some(CacheTopology { l1d_bytes: l1d, l2_bytes: l2, l3_bytes: l3 })
70 } else {
71 None
72 }
73}
74
/// Parses a cache size string into a byte count.
///
/// Surrounding whitespace is ignored. A trailing `K` multiplies the number by
/// 1024, a trailing `M` by 1024 * 1024, and a string without suffix is taken
/// as a plain byte count.
///
/// Returns `None` when the numeric part is not a valid `usize` or when the
/// resulting byte count does not fit in a `usize`.
fn parse_cache_size(s: &str) -> Option<usize> {
    let s = s.trim();

    let (digits, multiplier) = if let Some(kib) = s.strip_suffix('K') {
        (kib, 1024usize)
    } else if let Some(mib) = s.strip_suffix('M') {
        (mib, 1024 * 1024)
    } else {
        (s, 1)
    };

    // `checked_mul` turns an oversized value into `None` instead of a debug
    // panic or a silently wrapped size in release builds.
    digits.parse::<usize>().ok()?.checked_mul(multiplier)
}
86
87pub fn topology() -> CacheTopology {
89 *TOPOLOGY.get_or_init(|| {
90 detect_from_sys().unwrap_or(CacheTopology {
91 l1d_bytes: 32768, l2_bytes: 1048576, l3_bytes: 33554432, })
95 })
96}
97
98fn compute_blocking_8x32(topo: &CacheTopology) -> BlisBlocking {
107 let mr = 8usize;
108 let nr = 32usize;
109
110 let kc_max = topo.l1d_bytes / (nr * 4);
112 let kc = kc_max.next_power_of_two().min(kc_max).max(64);
114
115 let mc_max = topo.l2_bytes / (kc * 4);
117 let mc = (mc_max / mr * mr).min(12 * mr).max(mr);
122
123 let nc_max = topo.l3_bytes / (2 * kc * 4);
125 let nc = (nc_max / nr * nr).min(4096).max(nr);
127
128 BlisBlocking { mr, nr, mc, kc, nc, dynamic: true }
129}
130
131pub fn blocking_8x32() -> BlisBlocking {
133 *BLOCKING_8X32.get_or_init(|| {
134 let topo = topology();
135 compute_blocking_8x32(&topo)
136 })
137}
138
139pub fn blocking_8x48() -> BlisBlocking {
143 static BLOCKING_8X48: OnceLock<BlisBlocking> = OnceLock::new();
144 *BLOCKING_8X48.get_or_init(|| {
145 let topo = topology();
146 let mr = 8usize;
147 let nr = 48usize;
148 let kc_max = topo.l1d_bytes / (nr * 4);
151 let kc = (kc_max.next_power_of_two() >> 1).max(64);
152 let mc_max = topo.l2_bytes / (kc * 4);
153 let mc = (mc_max / mr * mr).min(12 * mr).max(mr);
154 let nc_max = topo.l3_bytes / (2 * kc * 4);
155 let nc = (nc_max / nr * nr).min(4096).max(nr);
156 BlisBlocking { mr, nr, mc, kc, nc, dynamic: true }
157 })
158}
159
160pub fn blocking_64x6_bcast_b() -> BlisBlocking {
164 static BLOCKING_64X6: OnceLock<BlisBlocking> = OnceLock::new();
165 *BLOCKING_64X6.get_or_init(|| {
166 let topo = topology();
167 let mr = 64usize;
168 let nr = 6usize;
169 let kc_max_l2 = topo.l2_bytes / (mr * 4);
173 let kc_max_l1 = topo.l1d_bytes * 3 / 4 / (nr * 4); let kc = kc_max_l2.min(kc_max_l1).clamp(64, 512);
175 let mc_max = topo.l2_bytes / (kc * 4);
178 let mc = (mc_max / mr * mr).min(4 * mr).max(mr);
179 let nc_max = topo.l3_bytes / (2 * kc * 4);
181 let nc = (nc_max / nr * nr).min(4096).max(nr);
182 BlisBlocking { mr, nr, mc, kc, nc, dynamic: true }
183 })
184}
185
186pub fn blocking_8x16() -> BlisBlocking {
188 DEFAULT_BLOCKING_8X16
189}
190
#[cfg(test)]
mod tests {
    use super::*;

    /// The detected (or fallback) topology reports non-zero L1D and L2 sizes,
    /// with L1D no larger than L2.
    #[test]
    fn test_detect_topology() {
        let CacheTopology { l1d_bytes, l2_bytes, .. } = topology();
        assert!(l1d_bytes > 0, "L1D must be > 0");
        assert!(l2_bytes > 0, "L2 must be > 0");
        assert!(l1d_bytes <= l2_bytes, "L1D ({}) must be <= L2 ({})", l1d_bytes, l2_bytes);
    }

    /// C-CACHE-001: the packed A block (`mc * kc` elements, 4 bytes each) fits in L2.
    #[test]
    fn test_mc_kc_fits_l2() {
        let host = topology();
        let BlisBlocking { mc, kc, .. } = compute_blocking_8x32(&host);
        let packed_a_bytes = mc * kc * 4;
        assert!(
            packed_a_bytes <= host.l2_bytes,
            "C-CACHE-001: packed A = {} bytes > L2 = {} bytes",
            packed_a_bytes,
            host.l2_bytes
        );
    }

    /// C-CACHE-002: the B panel (`kc * nr` elements, 4 bytes each) fits in L1D.
    #[test]
    fn test_kc_nr_fits_l1() {
        let host = topology();
        let BlisBlocking { kc, nr, .. } = compute_blocking_8x32(&host);
        let b_panel_bytes = kc * nr * 4;
        assert!(
            b_panel_bytes <= host.l1d_bytes,
            "C-CACHE-002: B panel = {} bytes > L1D = {} bytes",
            b_panel_bytes,
            host.l1d_bytes
        );
    }

    /// C-CACHE-004: MC is a whole number of MR-row micro-tiles.
    #[test]
    fn test_mc_multiple_of_mr() {
        let BlisBlocking { mc, mr, .. } = blocking_8x32();
        assert_eq!(mc % mr, 0, "C-CACHE-004: MC={} not multiple of MR={}", mc, mr);
    }

    /// C-CACHE-005: NC is a whole number of NR-column micro-tiles.
    #[test]
    fn test_nc_multiple_of_nr() {
        let BlisBlocking { nc, nr, .. } = blocking_8x32();
        assert_eq!(nc % nr, 0, "C-CACHE-005: NC={} not multiple of NR={}", nc, nr);
    }

    /// `K` and `M` suffixes scale by 1024 and 1024 * 1024 respectively.
    #[test]
    fn test_parse_cache_size() {
        let cases: [(&str, usize); 4] = [
            ("32K", 32768),
            ("1024K", 1048576),
            ("32768K", 33554432),
            ("2M", 2097152),
        ];
        for (text, bytes) in cases {
            assert_eq!(parse_cache_size(text), Some(bytes));
        }
    }

    /// Blocking for a 32 KiB / 1 MiB / 32 MiB topology, labelled "Zen 4" in
    /// the assertion messages.
    #[test]
    fn test_zen4_blocking() {
        let zen4 = CacheTopology { l1d_bytes: 32768, l2_bytes: 1048576, l3_bytes: 33554432 };
        let blk = compute_blocking_8x32(&zen4);

        assert_eq!(blk.kc, 256, "KC for Zen 4");
        assert!((64..=256).contains(&blk.mc), "MC={} for Zen 4", blk.mc);
        assert_eq!(blk.mc % 8, 0, "MC must be multiple of MR=8");
        assert_eq!(blk.nc, 4096, "NC for Zen 4");
        assert!(blk.dynamic);
    }
}