use std::sync::OnceLock;
const BLOCK_GROUP_SIZE: usize = 128;
static GLOBAL_PROFILE: OnceLock<PlatformProfile> = OnceLock::new();
static GLOBAL_THRESHOLDS: OnceLock<TunedThresholds> = OnceLock::new();
#[derive(Debug, Clone)]
pub struct PlatformProfile {
pub logical_cores: usize,
pub physical_cores: usize,
pub cache_line_bytes: usize,
pub l1_cache_bytes: usize,
pub l2_cache_bytes: usize,
pub has_avx2: bool,
pub has_avx512: bool,
pub has_neon: bool,
}
#[derive(Debug, Clone)]
pub struct TunedThresholds {
pub par_gemv_min_rows: usize,
pub par_gemm_min_batch: usize,
pub tiled_gemm_block_m: usize,
pub tiled_gemm_block_n: usize,
pub tiled_gemm_block_k: usize,
}
impl PlatformProfile {
pub fn detect() -> Self {
let logical_cores = std::thread::available_parallelism()
.map(|p| p.get())
.unwrap_or(1);
let physical_cores = Self::estimate_physical_cores(logical_cores);
let cache_line_bytes = 64;
let (l1_cache_bytes, l2_cache_bytes) = Self::estimate_cache_sizes();
let has_avx2 = Self::detect_avx2();
let has_avx512 = Self::detect_avx512();
let has_neon = Self::detect_neon();
Self {
logical_cores,
physical_cores,
cache_line_bytes,
l1_cache_bytes,
l2_cache_bytes,
has_avx2,
has_avx512,
has_neon,
}
}
pub fn global() -> &'static PlatformProfile {
GLOBAL_PROFILE.get_or_init(Self::detect)
}
pub fn compute_thresholds(&self) -> TunedThresholds {
let par_gemv_min_rows = self.compute_gemv_threshold();
let par_gemm_min_batch = self.compute_gemm_threshold();
let (block_m, block_n, block_k) = self.compute_tile_sizes();
TunedThresholds {
par_gemv_min_rows,
par_gemm_min_batch,
tiled_gemm_block_m: block_m,
tiled_gemm_block_n: block_n,
tiled_gemm_block_k: block_k,
}
}
pub fn global_thresholds() -> &'static TunedThresholds {
GLOBAL_THRESHOLDS.get_or_init(|| Self::global().compute_thresholds())
}
pub fn with_cores(logical: usize, physical: usize) -> Self {
Self {
logical_cores: logical.max(1),
physical_cores: physical.max(1),
cache_line_bytes: 64,
l1_cache_bytes: 32 * 1024,
l2_cache_bytes: 256 * 1024,
has_avx2: false,
has_avx512: false,
has_neon: false,
}
}
pub fn with_cache(l1_bytes: usize, l2_bytes: usize) -> Self {
Self {
logical_cores: 4,
physical_cores: 4,
cache_line_bytes: 64,
l1_cache_bytes: l1_bytes,
l2_cache_bytes: l2_bytes,
has_avx2: false,
has_avx512: false,
has_neon: false,
}
}
fn estimate_physical_cores(logical: usize) -> usize {
#[cfg(target_arch = "x86_64")]
{
(logical / 2).max(1)
}
#[cfg(target_arch = "aarch64")]
{
logical
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
{
logical
}
}
fn estimate_cache_sizes() -> (usize, usize) {
#[cfg(target_arch = "aarch64")]
{
(64 * 1024, 512 * 1024)
}
#[cfg(not(target_arch = "aarch64"))]
{
(32 * 1024, 256 * 1024)
}
}
fn detect_avx2() -> bool {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx2") {
return true;
}
cfg!(target_feature = "avx2")
}
#[cfg(not(target_arch = "x86_64"))]
{
false
}
}
fn detect_avx512() -> bool {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx512f") {
return true;
}
cfg!(target_feature = "avx512f")
}
#[cfg(not(target_arch = "x86_64"))]
{
false
}
}
fn detect_neon() -> bool {
#[cfg(target_arch = "aarch64")]
{
true
}
#[cfg(not(target_arch = "aarch64"))]
{
false
}
}
fn compute_gemv_threshold(&self) -> usize {
let base = match self.physical_cores {
0..=2 => 256,
3..=4 => 128,
5..=8 => 64,
9..=15 => 48,
_ => 32,
};
let simd_factor = if self.has_avx512 {
3
} else if self.has_avx2 || self.has_neon {
2
} else {
1
};
let adjusted = base + (base * simd_factor) / 4;
round_up_to(adjusted, 8)
}
fn compute_gemm_threshold(&self) -> usize {
match self.physical_cores {
0..=2 => 16,
3..=4 => 8,
5..=8 => 4,
9..=15 => 3,
_ => 2,
}
}
fn compute_tile_sizes(&self) -> (usize, usize, usize) {
let l1 = self.l1_cache_bytes;
let sizeof_f32 = std::mem::size_of::<f32>();
let usable_l1 = (l1 * 3) / 4;
let raw_block = isqrt(usable_l1 / (3 * sizeof_f32));
let block_mn = round_down_to(raw_block, 8).max(8);
let block_k = if block_mn >= BLOCK_GROUP_SIZE {
round_down_to(block_mn, BLOCK_GROUP_SIZE)
} else {
BLOCK_GROUP_SIZE
};
(block_mn, block_mn, block_k)
}
}
impl TunedThresholds {
#[inline]
pub fn should_parallelize_gemv(&self, n_rows: usize) -> bool {
n_rows >= self.par_gemv_min_rows
}
#[inline]
pub fn should_parallelize_gemm(&self, batch_size: usize) -> bool {
batch_size >= self.par_gemm_min_batch
}
}
#[derive(Debug, Clone)]
pub struct TuningSummary {
pub profile: PlatformProfile,
pub thresholds: TunedThresholds,
}
impl TuningSummary {
pub fn current() -> Self {
Self {
profile: PlatformProfile::global().clone(),
thresholds: PlatformProfile::global_thresholds().clone(),
}
}
}
impl std::fmt::Display for TuningSummary {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
writeln!(f, "Platform Tuning Summary")?;
writeln!(
f,
" Cores: {} logical, {} physical",
self.profile.logical_cores, self.profile.physical_cores
)?;
writeln!(
f,
" Cache: L1={} KB, L2={} KB, line={} B",
self.profile.l1_cache_bytes / 1024,
self.profile.l2_cache_bytes / 1024,
self.profile.cache_line_bytes
)?;
writeln!(
f,
" SIMD: AVX2={}, AVX-512={}, NEON={}",
self.profile.has_avx2, self.profile.has_avx512, self.profile.has_neon
)?;
writeln!(f, " Thresholds:")?;
writeln!(
f,
" par_gemv_min_rows: {}",
self.thresholds.par_gemv_min_rows
)?;
writeln!(
f,
" par_gemm_min_batch: {}",
self.thresholds.par_gemm_min_batch
)?;
writeln!(
f,
" tiled_gemm_block: {}x{}x{}",
self.thresholds.tiled_gemm_block_m,
self.thresholds.tiled_gemm_block_n,
self.thresholds.tiled_gemm_block_k
)?;
Ok(())
}
}
fn isqrt(n: usize) -> usize {
if n == 0 {
return 0;
}
let mut x = n;
let mut y = x.div_ceil(2);
while y < x {
x = y;
y = (x + n / x) / 2;
}
x
}
fn round_up_to(v: usize, align: usize) -> usize {
debug_assert!(align > 0);
let rem = v % align;
if rem == 0 {
v
} else {
v + (align - rem)
}
}
fn round_down_to(v: usize, align: usize) -> usize {
debug_assert!(align > 0);
v - (v % align)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn detect_returns_valid_profile() {
let profile = PlatformProfile::detect();
assert!(
profile.logical_cores > 0,
"must have at least 1 logical core"
);
assert!(
profile.physical_cores > 0,
"must have at least 1 physical core"
);
assert!(profile.physical_cores <= profile.logical_cores);
assert_eq!(profile.cache_line_bytes, 64);
assert!(
profile.l1_cache_bytes >= 16 * 1024,
"L1 should be at least 16KB"
);
assert!(
profile.l2_cache_bytes >= 64 * 1024,
"L2 should be at least 64KB"
);
}
#[test]
fn thresholds_are_reasonable() {
let profile = PlatformProfile::detect();
let thresholds = profile.compute_thresholds();
assert!(thresholds.par_gemv_min_rows >= 8);
assert!(thresholds.par_gemv_min_rows <= 1024);
assert!(thresholds.par_gemm_min_batch >= 1);
assert!(thresholds.par_gemm_min_batch <= 64);
}
#[test]
fn tile_sizes_multiple_of_8() {
let profile = PlatformProfile::detect();
let t = profile.compute_thresholds();
assert_eq!(t.tiled_gemm_block_m % 8, 0, "block_m must be multiple of 8");
assert_eq!(t.tiled_gemm_block_n % 8, 0, "block_n must be multiple of 8");
assert_eq!(
t.tiled_gemm_block_k % BLOCK_GROUP_SIZE,
0,
"block_k must be multiple of group size (128)"
);
}
#[test]
fn more_cores_lower_gemv_threshold() {
let p2 = PlatformProfile::with_cores(2, 2);
let p16 = PlatformProfile::with_cores(16, 16);
let t2 = p2.compute_thresholds();
let t16 = p16.compute_thresholds();
assert!(
t16.par_gemv_min_rows < t2.par_gemv_min_rows,
"16 cores ({}) should have lower threshold than 2 cores ({})",
t16.par_gemv_min_rows,
t2.par_gemv_min_rows
);
}
#[test]
fn more_cores_lower_gemm_threshold() {
let p2 = PlatformProfile::with_cores(2, 2);
let p16 = PlatformProfile::with_cores(16, 16);
let t2 = p2.compute_thresholds();
let t16 = p16.compute_thresholds();
assert!(
t16.par_gemm_min_batch < t2.par_gemm_min_batch,
"16 cores ({}) should have lower threshold than 2 cores ({})",
t16.par_gemm_min_batch,
t2.par_gemm_min_batch
);
}
#[test]
fn isqrt_correctness() {
assert_eq!(isqrt(0), 0);
assert_eq!(isqrt(1), 1);
assert_eq!(isqrt(4), 2);
assert_eq!(isqrt(9), 3);
assert_eq!(isqrt(10), 3); assert_eq!(isqrt(100), 10);
assert_eq!(isqrt(8192), 90); }
#[test]
fn round_helpers() {
assert_eq!(round_up_to(0, 8), 0);
assert_eq!(round_up_to(1, 8), 8);
assert_eq!(round_up_to(8, 8), 8);
assert_eq!(round_up_to(9, 8), 16);
assert_eq!(round_down_to(0, 8), 0);
assert_eq!(round_down_to(7, 8), 0);
assert_eq!(round_down_to(8, 8), 8);
assert_eq!(round_down_to(15, 8), 8);
}
#[test]
fn global_profile_consistent() {
let p1 = PlatformProfile::global();
let p2 = PlatformProfile::global();
assert_eq!(p1.logical_cores, p2.logical_cores);
assert_eq!(p1.physical_cores, p2.physical_cores);
}
#[test]
fn tuning_summary_display() {
let summary = TuningSummary::current();
let text = format!("{summary}");
assert!(text.contains("Platform Tuning Summary"));
assert!(text.contains("Cores:"));
assert!(text.contains("Cache:"));
assert!(text.contains("SIMD:"));
}
#[test]
fn should_parallelize_decisions() {
let p = PlatformProfile::with_cores(8, 8);
let t = p.compute_thresholds();
assert!(!t.should_parallelize_gemv(1));
assert!(t.should_parallelize_gemv(t.par_gemv_min_rows));
assert!(t.should_parallelize_gemv(t.par_gemv_min_rows + 1));
}
#[test]
fn with_cache_custom_sizes() {
let p = PlatformProfile::with_cache(64 * 1024, 1024 * 1024);
assert_eq!(p.l1_cache_bytes, 64 * 1024);
assert_eq!(p.l2_cache_bytes, 1024 * 1024);
let t = p.compute_thresholds();
let p_small = PlatformProfile::with_cache(16 * 1024, 128 * 1024);
let t_small = p_small.compute_thresholds();
assert!(
t.tiled_gemm_block_m >= t_small.tiled_gemm_block_m,
"larger L1 ({}) should give >= block_m than smaller L1 ({})",
t.tiled_gemm_block_m,
t_small.tiled_gemm_block_m
);
}
}