/// Cache- and register-blocking parameters for a blocked matrix multiply.
///
/// The field names follow the GotoBLAS/BLIS naming convention. The kernel that
/// consumes this struct is not visible here, so the per-field descriptions below
/// are presumed from that convention — confirm against the kernel.
///
/// Derives `PartialEq`/`Eq`/`Hash` so configurations can be compared directly
/// and used as keys (e.g. for caching packed buffers per configuration).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MatMulConfig {
    /// Presumably the cache-block size along the `m` (rows of `A`/`C`) dimension.
    pub mc: usize,
    /// Presumably the cache-block size along the shared `k` dimension.
    pub kc: usize,
    /// Presumably the cache-block size along the `n` (columns of `B`/`C`) dimension.
    pub nc: usize,
    /// Presumably the micro-kernel tile height (rows of `C` per kernel call).
    pub mr: usize,
    /// Presumably the micro-kernel tile width (columns of `C` per kernel call).
    pub nr: usize,
}
impl MatMulConfig {
    /// Returns the blocking parameters tuned for `f32` on the compilation target.
    ///
    /// The choice is made entirely at compile time. `target_arch` and
    /// `cfg!(target_feature = ...)` reflect what the build was configured for
    /// (e.g. via `-C target-cpu=...` or `-C target-feature=+avx2`), not the CPU
    /// the binary eventually runs on. A default `x86_64` build, which enables
    /// neither AVX2 nor AVX-512F, therefore takes the baseline branch.
    ///
    /// Being a `const fn`, the result can initialise `const`/`static` items.
    pub const fn for_f32() -> Self {
        #[cfg(target_arch = "x86_64")]
        {
            if cfg!(target_feature = "avx512f") {
                Self {
                    mc: 384,
                    kc: 256,
                    nc: 4096,
                    mr: 8,
                    nr: 16,
                }
            } else if cfg!(target_feature = "avx2") {
                Self {
                    mc: 384,
                    kc: 256,
                    nc: 4096,
                    mr: 8,
                    nr: 8,
                }
            } else {
                // Built without AVX2/AVX-512F enabled.
                Self {
                    mc: 256,
                    kc: 128,
                    nc: 2048,
                    mr: 4,
                    nr: 4,
                }
            }
        }
        #[cfg(target_arch = "aarch64")]
        {
            Self {
                mc: 384,
                kc: 256,
                nc: 4096,
                mr: 8,
                nr: 4,
            }
        }
        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            // Conservative defaults for architectures without a dedicated tuning.
            Self {
                mc: 128,
                kc: 128,
                nc: 1024,
                mr: 4,
                nr: 4,
            }
        }
    }

    /// Returns the blocking parameters tuned for `f64` on the compilation target.
    ///
    /// As with [`Self::for_f32`], the architecture and SIMD feature checks are
    /// resolved at compile time, not by runtime CPU detection.
    ///
    /// Being a `const fn`, the result can initialise `const`/`static` items.
    pub const fn for_f64() -> Self {
        #[cfg(target_arch = "x86_64")]
        {
            if cfg!(target_feature = "avx512f") {
                Self {
                    mc: 192,
                    kc: 128,
                    nc: 2048,
                    mr: 8,
                    nr: 8,
                }
            } else if cfg!(target_feature = "avx2") {
                Self {
                    mc: 192,
                    kc: 128,
                    nc: 2048,
                    mr: 8,
                    nr: 4,
                }
            } else {
                // Built without AVX2/AVX-512F enabled.
                Self {
                    mc: 128,
                    kc: 64,
                    nc: 1024,
                    mr: 4,
                    nr: 2,
                }
            }
        }
        #[cfg(target_arch = "aarch64")]
        {
            Self {
                mc: 192,
                kc: 128,
                nc: 2048,
                mr: 8,
                nr: 2,
            }
        }
        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            // Conservative defaults for architectures without a dedicated tuning.
            Self {
                mc: 64,
                kc: 64,
                nc: 512,
                mr: 4,
                nr: 2,
            }
        }
    }

    /// Picks the blocking parameters for element type `T`.
    ///
    /// `f32` maps to [`Self::for_f32`]. `f64`, and every other `'static` type,
    /// maps to [`Self::for_f64`], because this impl only provides `f32` and
    /// `f64` tunings.
    pub fn auto<T>() -> Self
    where
        T: 'static,
    {
        use std::any::TypeId;

        if TypeId::of::<T>() == TypeId::of::<f32>() {
            Self::for_f32()
        } else {
            // f64 and all element types without a dedicated tuning share the
            // f64 parameters.
            Self::for_f64()
        }
    }
}
impl Default for MatMulConfig {
fn default() -> Self {
Self::for_f32()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compares every field individually, so these tests do not depend on
    /// `MatMulConfig` implementing `PartialEq`.
    fn assert_same_config(actual: MatMulConfig, expected: MatMulConfig) {
        assert_eq!(actual.mc, expected.mc, "mc differs");
        assert_eq!(actual.kc, expected.kc, "kc differs");
        assert_eq!(actual.nc, expected.nc, "nc differs");
        assert_eq!(actual.mr, expected.mr, "mr differs");
        assert_eq!(actual.nr, expected.nr, "nr differs");
    }

    #[test]
    fn test_f32_config() {
        let config = MatMulConfig::for_f32();
        assert!((64..=512).contains(&config.mc), "mc = {}", config.mc);
        assert!((64..=512).contains(&config.kc), "kc = {}", config.kc);
        assert!((512..=8192).contains(&config.nc), "nc = {}", config.nc);
        assert!((4..=16).contains(&config.mr), "mr = {}", config.mr);
        assert!((4..=32).contains(&config.nr), "nr = {}", config.nr);
        assert_eq!(config.mr % 4, 0);
        assert_eq!(config.nr % 4, 0);
    }

    #[test]
    fn test_f64_config() {
        let config = MatMulConfig::for_f64();
        assert!((64..=512).contains(&config.mc), "mc = {}", config.mc);
        assert!((64..=256).contains(&config.kc), "kc = {}", config.kc);
        assert!((512..=4096).contains(&config.nc), "nc = {}", config.nc);
        assert!((4..=16).contains(&config.mr), "mr = {}", config.mr);
        assert!((2..=16).contains(&config.nr), "nr = {}", config.nr);
    }

    #[test]
    fn test_auto_config() {
        assert_same_config(MatMulConfig::auto::<f32>(), MatMulConfig::for_f32());
        assert_same_config(MatMulConfig::auto::<f64>(), MatMulConfig::for_f64());
    }

    #[test]
    fn test_auto_config_falls_back_to_f64() {
        // Element types without a dedicated tuning reuse the f64 parameters.
        let expected = MatMulConfig::for_f64();
        assert_same_config(MatMulConfig::auto::<i32>(), expected);
        assert_same_config(MatMulConfig::auto::<i64>(), expected);
        assert_same_config(MatMulConfig::auto::<u8>(), expected);
    }

    #[test]
    fn test_default_is_f32() {
        assert_same_config(MatMulConfig::default(), MatMulConfig::for_f32());
    }

    #[test]
    fn test_presets_are_const_evaluable() {
        // `for_f32`/`for_f64` are `const fn`; keep them usable in const items.
        const F32: MatMulConfig = MatMulConfig::for_f32();
        const F64: MatMulConfig = MatMulConfig::for_f64();
        assert_same_config(F32, MatMulConfig::for_f32());
        assert_same_config(F64, MatMulConfig::for_f64());
    }
}