use crate::simd::{SimdLevel, detect_simd_level};
use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
/// Default value for [`TuningConfig::block_m`]; also seeds the global tuning cache.
pub const DEFAULT_BLOCK_M: usize = 64;
/// Default value for [`TuningConfig::block_n`]; also seeds the global tuning cache.
pub const DEFAULT_BLOCK_N: usize = 64;
/// Default value for [`TuningConfig::block_k`]; also seeds the global tuning cache.
pub const DEFAULT_BLOCK_K: usize = 256;
/// Default value for [`TuningConfig::l2_block`].
pub const DEFAULT_L2_BLOCK: usize = 256;
/// Assumed L1 cache size (32 * 1024; presumably bytes — confirm against the target hardware).
pub const L1_CACHE_SIZE: usize = 32 * 1024;
/// Assumed L2 cache size (256 * 1024). `TuningConfig::for_dimensions` keeps the
/// `block_m * block_k * 8` panel footprint within half of this value.
pub const L2_CACHE_SIZE: usize = 256 * 1024;
/// Assumed L3 cache size (8 * 1024 * 1024; presumably bytes — confirm against the target hardware).
pub const L3_CACHE_SIZE: usize = 8 * 1024 * 1024;
/// Blocking and execution parameters chosen for a kernel invocation.
///
/// Build one with [`TuningConfig::new`] for the defaults, or with
/// [`TuningConfig::for_dimensions`] to adapt the block sizes to a problem shape.
#[derive(Debug, Clone, Copy)]
pub struct TuningConfig {
/// Block size associated with the `m` dimension (clamped to `m` by `for_dimensions`).
pub block_m: usize,
/// Block size associated with the `n` dimension (clamped to `n` by `for_dimensions`).
pub block_n: usize,
/// Block size for the third blocking dimension (presumably the shared `k` dimension — confirm).
pub block_k: usize,
/// Secondary block size; defaults to `DEFAULT_L2_BLOCK`.
/// NOTE(review): not consumed anywhere in this file — confirm its meaning at the call sites.
pub l2_block: usize,
/// SIMD capability level reported by `detect_simd_level()` when the config was built.
pub simd_level: SimdLevel,
/// Whether parallel execution is requested; `for_dimensions` sets it when
/// `m * n` reaches `par_threshold`.
pub parallel: bool,
/// Minimum `m * n` at which `for_dimensions` turns `parallel` on.
pub par_threshold: usize,
}
impl Default for TuningConfig {
    /// Builds the baseline configuration: compile-time default block sizes,
    /// the SIMD level detected at call time, and parallelism switched off.
    fn default() -> Self {
        // Probe the SIMD level up front so the struct literal below is pure data.
        let simd_level = detect_simd_level();
        // Threshold on `m * n`; written as a product to show the 64 x 64 origin.
        let par_threshold = 64 * 64;

        Self {
            simd_level,
            par_threshold,
            parallel: false,
            l2_block: DEFAULT_L2_BLOCK,
            block_k: DEFAULT_BLOCK_K,
            block_n: DEFAULT_BLOCK_N,
            block_m: DEFAULT_BLOCK_M,
        }
    }
}
impl TuningConfig {
    /// Creates a configuration with the default block sizes.
    ///
    /// Equivalent to [`TuningConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Derives a configuration for an `m x n` problem.
    ///
    /// The block sizes start from a table keyed by the detected SIMD level, are
    /// clamped to the problem dimensions (but never below 1, so they remain
    /// valid loop steps even for empty inputs), and `block_k` is shrunk when a
    /// `block_m x block_k` panel of 8-byte elements would exceed half of
    /// `L2_CACHE_SIZE`. `parallel` is enabled once `m * n` reaches
    /// `par_threshold`.
    ///
    /// `_k` is accepted for interface symmetry but does not influence the
    /// result at present.
    #[must_use]
    pub fn for_dimensions(m: usize, n: usize, _k: usize) -> Self {
        let mut config = Self::new();

        // `Self::new()` already probed the SIMD level; reuse it so the block
        // sizes are always consistent with the `simd_level` stored in the result.
        let (base_m, base_n, base_k) = match config.simd_level {
            SimdLevel::Scalar => (32, 32, 128),
            SimdLevel::Simd128 => (64, 64, 256),
            SimdLevel::Simd256 => (96, 96, 384),
            SimdLevel::Simd512 => (128, 128, 512),
        };

        // No base block exceeds 128, so clamping unconditionally only has an
        // effect for dimensions below 128. The `.max(1)` keeps a zero-sized
        // dimension from producing a zero block size.
        config.block_m = base_m.min(m).max(1);
        config.block_n = base_n.min(n).max(1);
        config.block_k = base_k;

        // 8 bytes per element (presumably f64 — confirm if other element
        // types are routed through this function).
        let element_size = 8;
        let panel_budget = L2_CACHE_SIZE / 2;
        let panel_row_bytes = config.block_m * element_size;
        if panel_row_bytes * config.block_k > panel_budget {
            config.block_k = (panel_budget / panel_row_bytes).max(1);
        }

        // Saturate instead of overflowing: a wrapped product would wrongly
        // disable parallelism (release) or panic (debug) for huge shapes.
        config.parallel = m.saturating_mul(n) >= config.par_threshold;
        config
    }

    /// Returns the GEMV block size associated with this configuration's SIMD
    /// level (128 for scalar, doubling with each wider SIMD level up to 1024).
    #[must_use]
    pub fn gemv_block_size(&self) -> usize {
        match self.simd_level {
            SimdLevel::Scalar => 128,
            SimdLevel::Simd128 => 256,
            SimdLevel::Simd256 => 512,
            SimdLevel::Simd512 => 1024,
        }
    }

    /// Returns the factorization panel width associated with this
    /// configuration's SIMD level (16, 32, 48 or 64).
    #[must_use]
    pub fn factorization_panel_width(&self) -> usize {
        match self.simd_level {
            SimdLevel::Scalar => 16,
            SimdLevel::Simd128 => 32,
            SimdLevel::Simd256 => 48,
            SimdLevel::Simd512 => 64,
        }
    }
}
/// Storage for the globally shared block sizes, held in atomics so it can live
/// in a `static` and be updated through shared references.
pub struct TuningCache {
/// Set to `true` once `TuningCache::get` or `TuningCache::set` has run.
initialized: AtomicBool,
/// Cached value for `TuningConfig::block_m`.
block_m: AtomicUsize,
/// Cached value for `TuningConfig::block_n`.
block_n: AtomicUsize,
/// Cached value for `TuningConfig::block_k`.
block_k: AtomicUsize,
}
/// The single global cache instance. It starts out not initialized and seeded
/// with the `DEFAULT_BLOCK_*` constants.
static TUNING_CACHE: TuningCache = TuningCache {
initialized: AtomicBool::new(false),
block_m: AtomicUsize::new(DEFAULT_BLOCK_M),
block_n: AtomicUsize::new(DEFAULT_BLOCK_N),
block_k: AtomicUsize::new(DEFAULT_BLOCK_K),
};
impl TuningCache {
    /// Returns the currently cached tuning configuration.
    ///
    /// Only the three block sizes are cached; every other field of the result
    /// comes from a fresh [`TuningConfig::new`].
    ///
    /// The sizes are held in three independent atomics, so a `get` that
    /// overlaps an in-flight [`TuningCache::set`] may observe a mix of old and
    /// new sizes.
    pub fn get() -> TuningConfig {
        // Acquire pairs with the Release store in `set`: once the flag is seen
        // as `true`, the block sizes published before it are visible as well.
        if !TUNING_CACHE.initialized.load(Ordering::Acquire) {
            // The static is already seeded with the `DEFAULT_BLOCK_*` values,
            // which are exactly the sizes `TuningConfig::new()` yields, so
            // there is nothing to write. Storing defaults here could clobber
            // the values of a `set` racing with this first `get`.
            //
            // A compare-exchange (rather than a plain store) is used so that a
            // concurrent `set` is either observed (failure path, Acquire) or
            // ordered after this flip; the outcome itself is irrelevant.
            let _ = TUNING_CACHE.initialized.compare_exchange(
                false,
                true,
                Ordering::AcqRel,
                Ordering::Acquire,
            );
        }

        TuningConfig {
            block_m: TUNING_CACHE.block_m.load(Ordering::Relaxed),
            block_n: TUNING_CACHE.block_n.load(Ordering::Relaxed),
            block_k: TUNING_CACHE.block_k.load(Ordering::Relaxed),
            ..TuningConfig::new()
        }
    }

    /// Stores the block sizes of `config` in the global cache.
    ///
    /// Fields other than `block_m`, `block_n` and `block_k` are not cached.
    pub fn set(config: &TuningConfig) {
        TUNING_CACHE
            .block_m
            .store(config.block_m, Ordering::Relaxed);
        TUNING_CACHE
            .block_n
            .store(config.block_n, Ordering::Relaxed);
        TUNING_CACHE
            .block_k
            .store(config.block_k, Ordering::Relaxed);
        // Release publishes the three stores above to any thread that later
        // observes the flag with Acquire ordering.
        TUNING_CACHE.initialized.store(true, Ordering::Release);
    }
}
/// Holds the most recently selected [`TuningConfig`] and refreshes it on demand.
pub struct AutoTuner {
/// Configuration last read from the cache or produced by `tune_gemm`.
config: TuningConfig,
}
impl Default for AutoTuner {
fn default() -> Self {
Self::new()
}
}
impl AutoTuner {
    /// Creates a tuner whose initial configuration is read from the global
    /// tuning cache.
    #[must_use]
    pub fn new() -> Self {
        let config = TuningCache::get();
        Self { config }
    }

    /// Recomputes the configuration for an `m x n x k` problem, publishes it
    /// to the global tuning cache, and returns a reference to the stored copy.
    pub fn tune_gemm(&mut self, m: usize, n: usize, k: usize) -> &TuningConfig {
        let tuned = TuningConfig::for_dimensions(m, n, k);
        TuningCache::set(&tuned);
        // `TuningConfig` is `Copy`, so the tuner keeps its own copy.
        self.config = tuned;
        &self.config
    }

    /// Returns the configuration this tuner currently holds.
    #[must_use]
    pub const fn config(&self) -> &TuningConfig {
        &self.config
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration carries the compile-time default block sizes.
    #[test]
    fn test_default_config() {
        let cfg = TuningConfig::default();
        assert_eq!(
            (cfg.block_m, cfg.block_n, cfg.block_k),
            (DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K)
        );
    }

    /// Small problems get clamped block sizes; large ones request parallelism.
    #[test]
    fn test_dimension_based_tuning() {
        let small = TuningConfig::for_dimensions(32, 32, 32);
        assert!(small.block_m <= 32);
        assert!(small.block_n <= 32);

        let large = TuningConfig::for_dimensions(1024, 1024, 1024);
        assert!(large.parallel);
    }

    /// Values written with `set` are returned by the next `get`.
    // NOTE(review): this shares the global cache with `test_auto_tuner`;
    // confirm the two cannot interleave under the parallel test runner.
    #[test]
    fn test_tuning_cache() {
        let wanted = TuningConfig {
            block_m: 77,
            block_n: 88,
            block_k: 99,
            ..TuningConfig::default()
        };
        TuningCache::set(&wanted);

        let seen = TuningCache::get();
        assert_eq!(seen.block_m, 77);
        assert_eq!(seen.block_n, 88);
        assert_eq!(seen.block_k, 99);
    }

    /// Tuning a 512-cube yields positive block sizes no larger than the problem.
    #[test]
    fn test_auto_tuner() {
        let tuned = *AutoTuner::new().tune_gemm(512, 512, 512);
        assert!((1..=512).contains(&tuned.block_m));
        assert!((1..=512).contains(&tuned.block_n));
        assert!(tuned.block_k > 0);
    }

    /// The GEMV block size stays within the range covered by the SIMD table.
    #[test]
    fn test_gemv_block_size() {
        let size = TuningConfig::default().gemv_block_size();
        assert!((128..=1024).contains(&size));
    }

    /// The factorization panel width stays within its expected bounds.
    #[test]
    fn test_factorization_panel_width() {
        let width = TuningConfig::default().factorization_panel_width();
        assert!((16..=128).contains(&width));
    }
}