use std::sync::OnceLock;
/// Cache capacities that the blocking heuristics in this module are derived from.
///
/// All sizes are in bytes.
#[derive(Clone, Copy, Debug)]
pub struct CacheTopology {
    /// Capacity of the L1 data cache.
    pub l1d_bytes: usize,
    /// Capacity of the L2 cache.
    pub l2_bytes: usize,
    /// Capacity of the L3 cache; `0` when no L3 size was recorded.
    pub l3_bytes: usize,
}
/// One set of BLIS-style blocking parameters for a matrix-multiply kernel.
///
/// `mr`/`nr` describe the micro-kernel tile; `mc`/`kc`/`nc` are the cache
/// block sizes chosen around it. All values are element counts, not bytes.
#[derive(Clone, Copy, Debug)]
pub struct BlisBlocking {
    /// Micro-kernel tile dimension that `mc` is kept a multiple of.
    pub mr: usize,
    /// Micro-kernel tile dimension that `nc` is kept a multiple of.
    pub nr: usize,
    /// Block size along the `mr` dimension (an `mc x kc` block is sized against L2).
    pub mc: usize,
    /// Block size along the shared (depth) dimension (a `kc x nr` panel is sized against L1D).
    pub kc: usize,
    /// Block size along the `nr` dimension (sized against L3).
    pub nc: usize,
    /// `true` when the values were computed from a `CacheTopology`,
    /// `false` for a fixed, compile-time default.
    pub dynamic: bool,
}
/// Fixed blocking parameters for the 8x16 micro-kernel.
///
/// These values are not derived from the detected cache topology, which is
/// why `dynamic` is `false`.
const DEFAULT_BLOCKING_8X16: BlisBlocking = BlisBlocking {
    mr: 8,
    nr: 16,
    mc: 64,
    kc: 256,
    nc: 1024,
    dynamic: false,
};
/// Process-wide cache for the result of `topology()`; initialised on the first call.
static TOPOLOGY: OnceLock<CacheTopology> = OnceLock::new();
/// Process-wide cache for the result of `blocking_8x32()`; initialised on the first call.
static BLOCKING_8X32: OnceLock<BlisBlocking> = OnceLock::new();
/// Reads the cache sizes of CPU 0 from Linux sysfs.
///
/// Probes `/sys/devices/system/cpu/cpu0/cache/index{0..=3}` and records a
/// size only when the index/type pair matches the expected layout:
///
/// * `index0` of type `Data`    -> L1D
/// * `index2` of type `Unified` -> L2
/// * `index3` of type `Unified` -> L3
///
/// `index1` is never used (presumably the L1 instruction cache).
///
/// An entry that is missing, unreadable, or whose size cannot be parsed is
/// skipped rather than aborting the whole probe, so a CPU without an L3
/// (no `index3` directory) still yields its L1D and L2 sizes.
///
/// Returns `None` unless both a non-zero L1D and a non-zero L2 size were
/// found. `l3_bytes` is `0` when no matching `index3` entry exists.
fn detect_from_sys() -> Option<CacheTopology> {
    let mut l1d = 0usize;
    let mut l2 = 0usize;
    let mut l3 = 0usize;

    for idx in 0..4 {
        let base = format!("/sys/devices/system/cpu/cpu0/cache/index{idx}");

        // A missing or unreadable entry only disqualifies that one level.
        let Ok(size_str) = std::fs::read_to_string(format!("{base}/size")) else {
            continue;
        };
        let Ok(type_str) = std::fs::read_to_string(format!("{base}/type")) else {
            continue;
        };
        let Some(size_bytes) = parse_cache_size(size_str.trim()) else {
            continue;
        };

        match (idx, type_str.trim()) {
            (0, "Data") => l1d = size_bytes,
            (2, "Unified") => l2 = size_bytes,
            (3, "Unified") => l3 = size_bytes,
            _ => {}
        }
    }

    // L1D and L2 are mandatory for the blocking heuristics; L3 is optional.
    if l1d > 0 && l2 > 0 {
        Some(CacheTopology { l1d_bytes: l1d, l2_bytes: l2, l3_bytes: l3 })
    } else {
        None
    }
}
/// Parses a cache size string such as `"32K"`, `"2M"` or `"4096"` into bytes.
///
/// Surrounding whitespace is ignored. A trailing `K` multiplies by 1024, a
/// trailing `M` by 1024 * 1024; with no suffix the number is taken as bytes.
///
/// Returns `None` when the numeric part is not a valid `usize` or when the
/// scaled value would overflow `usize` (instead of panicking or wrapping).
fn parse_cache_size(s: &str) -> Option<usize> {
    let s = s.trim();

    let (digits, multiplier) = if let Some(kb) = s.strip_suffix('K') {
        (kb, 1024usize)
    } else if let Some(mb) = s.strip_suffix('M') {
        (mb, 1024 * 1024)
    } else {
        (s, 1)
    };

    digits.parse::<usize>().ok()?.checked_mul(multiplier)
}
/// Returns the cache topology used for all dynamic blocking decisions.
///
/// The value is computed once and cached in `TOPOLOGY`. It comes from
/// `detect_from_sys()` when that succeeds; otherwise a fixed fallback of
/// 32 KiB L1D, 1 MiB L2 and 32 MiB L3 is used.
pub fn topology() -> CacheTopology {
    *TOPOLOGY.get_or_init(|| match detect_from_sys() {
        Some(detected) => detected,
        None => CacheTopology {
            l1d_bytes: 32 * 1024,
            l2_bytes: 1024 * 1024,
            l3_bytes: 32 * 1024 * 1024,
        },
    })
}
/// Derives blocking parameters for the 8x32 micro-kernel from `topo`.
///
/// Elements are assumed to be 4 bytes wide (presumably `f32` — confirm).
///
/// * `kc` = `l1d_bytes / (NR * 4)`, i.e. the deepest `kc x NR` panel that
///   fits in L1D, but never less than 64.
/// * `mc` = largest multiple of `MR` with `mc * kc * 4 <= l2_bytes`,
///   capped at `12 * MR` and never less than `MR`.
/// * `nc` = largest multiple of `NR` with `2 * kc * nc * 4 <= l3_bytes`,
///   capped at 4096 and never less than `NR`.
///
/// The returned value always has `dynamic: true`.
fn compute_blocking_8x32(topo: &CacheTopology) -> BlisBlocking {
    const MR: usize = 8;
    const NR: usize = 32;
    const ELEM_BYTES: usize = 4;

    // NOTE(review): this was previously spelled
    // `kc_max.next_power_of_two().min(kc_max).max(64)`. Because
    // `next_power_of_two()` never returns less than a non-zero input (and
    // `min` restores 0 for a zero input), that is exactly `kc_max.max(64)`.
    // If rounding down to a power of two was intended, it does not happen.
    let kc_limit = topo.l1d_bytes / (NR * ELEM_BYTES);
    let kc = kc_limit.max(64);

    let mc_limit = topo.l2_bytes / (kc * ELEM_BYTES);
    let mc_aligned = mc_limit - mc_limit % MR;
    let mc = mc_aligned.min(12 * MR).max(MR);

    let nc_limit = topo.l3_bytes / (2 * kc * ELEM_BYTES);
    let nc_aligned = nc_limit - nc_limit % NR;
    let nc = nc_aligned.min(4096).max(NR);

    BlisBlocking {
        mr: MR,
        nr: NR,
        mc,
        kc,
        nc,
        dynamic: true,
    }
}
/// Returns the blocking parameters for the 8x32 micro-kernel.
///
/// Computed once from `topology()` via `compute_blocking_8x32` and cached in
/// `BLOCKING_8X32` for the lifetime of the process.
pub fn blocking_8x32() -> BlisBlocking {
    *BLOCKING_8X32.get_or_init(|| compute_blocking_8x32(&topology()))
}
/// Returns the blocking parameters for the 8x48 micro-kernel.
///
/// Computed once from `topology()` and cached in a function-local
/// `OnceLock`. Elements are assumed to be 4 bytes wide (presumably `f32` —
/// confirm).
///
/// * `kc` = half of `next_power_of_two(l1d_bytes / (NR * 4))`, never less
///   than 64.
/// * `mc` = largest multiple of `MR` with `mc * kc * 4 <= l2_bytes`,
///   capped at `12 * MR` and never less than `MR`.
/// * `nc` = largest multiple of `NR` with `2 * kc * nc * 4 <= l3_bytes`,
///   capped at 4096 and never less than `NR`.
pub fn blocking_8x48() -> BlisBlocking {
    const MR: usize = 8;
    const NR: usize = 48;
    const ELEM_BYTES: usize = 4;
    static BLOCKING_8X48: OnceLock<BlisBlocking> = OnceLock::new();

    *BLOCKING_8X48.get_or_init(|| {
        let caches = topology();

        // NOTE(review): halving the next power of two rounds *down* to a
        // power of two when `kc_limit` is not one, but yields `kc_limit / 2`
        // when it already is one — confirm that is intended.
        let kc_limit = caches.l1d_bytes / (NR * ELEM_BYTES);
        let kc = (kc_limit.next_power_of_two() / 2).max(64);

        let mc_limit = caches.l2_bytes / (kc * ELEM_BYTES);
        let mc_aligned = mc_limit - mc_limit % MR;
        let mc = mc_aligned.min(12 * MR).max(MR);

        let nc_limit = caches.l3_bytes / (2 * kc * ELEM_BYTES);
        let nc_aligned = nc_limit - nc_limit % NR;
        let nc = nc_aligned.min(4096).max(NR);

        BlisBlocking {
            mr: MR,
            nr: NR,
            mc,
            kc,
            nc,
            dynamic: true,
        }
    })
}
/// Returns the blocking parameters for the 64x6 (`bcast_b`) micro-kernel.
///
/// Computed once from `topology()` and cached in a function-local
/// `OnceLock`. Elements are assumed to be 4 bytes wide (presumably `f32` —
/// confirm).
///
/// * `kc` = the smaller of `l2_bytes / (MR * 4)` and
///   `(3/4 of l1d_bytes) / (NR * 4)`, clamped to `64..=512`.
/// * `mc` = largest multiple of `MR` with `mc * kc * 4 <= l2_bytes`,
///   capped at `4 * MR` and never less than `MR`.
/// * `nc` = largest multiple of `NR` with `2 * kc * nc * 4 <= l3_bytes`,
///   capped at 4096 and never less than `NR`.
pub fn blocking_64x6_bcast_b() -> BlisBlocking {
    const MR: usize = 64;
    const NR: usize = 6;
    const ELEM_BYTES: usize = 4;
    static BLOCKING_64X6: OnceLock<BlisBlocking> = OnceLock::new();

    *BLOCKING_64X6.get_or_init(|| {
        let caches = topology();

        // KC is bounded by both cache levels; the tighter bound wins.
        let kc_from_l2 = caches.l2_bytes / (MR * ELEM_BYTES);
        let kc_from_l1 = caches.l1d_bytes * 3 / 4 / (NR * ELEM_BYTES);
        let kc = kc_from_l2.min(kc_from_l1).clamp(64, 512);

        let mc_limit = caches.l2_bytes / (kc * ELEM_BYTES);
        let mc_aligned = mc_limit - mc_limit % MR;
        let mc = mc_aligned.min(4 * MR).max(MR);

        let nc_limit = caches.l3_bytes / (2 * kc * ELEM_BYTES);
        let nc_aligned = nc_limit - nc_limit % NR;
        let nc = nc_aligned.min(4096).max(NR);

        BlisBlocking {
            mr: MR,
            nr: NR,
            mc,
            kc,
            nc,
            dynamic: true,
        }
    })
}
/// Returns the blocking parameters for the 8x16 micro-kernel.
///
/// Unlike the other `blocking_*` functions this does not consult the cache
/// topology: it always returns the constant `DEFAULT_BLOCKING_8X16`, whose
/// `dynamic` flag is `false`.
pub fn blocking_8x16() -> BlisBlocking {
DEFAULT_BLOCKING_8X16
}
// Unit tests for topology detection and the 8x32 blocking heuristics.
// Tests that call `topology()` run against the caches of the machine
// executing the tests (or the built-in fallback when detection fails).
#[cfg(test)]
mod tests {
use super::*;
// The reported topology must have non-zero L1D and L2, with L1D <= L2.
#[test]
fn test_detect_topology() {
let topo = topology();
assert!(topo.l1d_bytes > 0, "L1D must be > 0");
assert!(topo.l2_bytes > 0, "L2 must be > 0");
assert!(
topo.l1d_bytes <= topo.l2_bytes,
"L1D ({}) must be <= L2 ({})",
topo.l1d_bytes,
topo.l2_bytes
);
}
// C-CACHE-001: an MC x KC block of 4-byte elements (packed A) must fit in L2.
#[test]
fn test_mc_kc_fits_l2() {
let topo = topology();
let blk = compute_blocking_8x32(&topo);
let packed_a_bytes = blk.mc * blk.kc * 4;
assert!(
packed_a_bytes <= topo.l2_bytes,
"C-CACHE-001: packed A = {} bytes > L2 = {} bytes",
packed_a_bytes,
topo.l2_bytes
);
}
// C-CACHE-002: a KC x NR panel of 4-byte elements (B panel) must fit in L1D.
#[test]
fn test_kc_nr_fits_l1() {
let topo = topology();
let blk = compute_blocking_8x32(&topo);
let b_panel_bytes = blk.kc * blk.nr * 4;
assert!(
b_panel_bytes <= topo.l1d_bytes,
"C-CACHE-002: B panel = {} bytes > L1D = {} bytes",
b_panel_bytes,
topo.l1d_bytes
);
}
// C-CACHE-004: MC must be a whole number of MR-sized tiles.
#[test]
fn test_mc_multiple_of_mr() {
let blk = blocking_8x32();
assert_eq!(blk.mc % blk.mr, 0, "C-CACHE-004: MC={} not multiple of MR={}", blk.mc, blk.mr);
}
// C-CACHE-005: NC must be a whole number of NR-sized tiles.
#[test]
fn test_nc_multiple_of_nr() {
let blk = blocking_8x32();
assert_eq!(blk.nc % blk.nr, 0, "C-CACHE-005: NC={} not multiple of NR={}", blk.nc, blk.nr);
}
// `K` scales by 1024 and `M` by 1024 * 1024.
#[test]
fn test_parse_cache_size() {
assert_eq!(parse_cache_size("32K"), Some(32768));
assert_eq!(parse_cache_size("1024K"), Some(1048576));
assert_eq!(parse_cache_size("32768K"), Some(33554432));
assert_eq!(parse_cache_size("2M"), Some(2097152));
}
// Pins the 8x32 blocking for a fixed 32 KiB / 1 MiB / 32 MiB topology,
// independent of the machine running the tests.
#[test]
fn test_zen4_blocking() {
let topo = CacheTopology { l1d_bytes: 32768, l2_bytes: 1048576, l3_bytes: 33554432 };
let blk = compute_blocking_8x32(&topo);
assert_eq!(blk.kc, 256, "KC for Zen 4");
assert!(blk.mc >= 64 && blk.mc <= 256, "MC={} for Zen 4", blk.mc);
assert_eq!(blk.mc % 8, 0, "MC must be multiple of MR=8");
assert_eq!(blk.nc, 4096, "NC for Zen 4");
assert!(blk.dynamic);
}
}