/// CPU architecture family, as distinguished by [`CpuArch::current`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CpuArch {
    /// `aarch64` target running macOS.
    AppleSilicon,
    /// `aarch64` target on any operating system other than macOS.
    AarchGeneric,
    /// `x86_64` target.
    X86_64,
    /// Any target that is neither `aarch64` nor `x86_64`.
    Other,
}

impl CpuArch {
    /// Returns the architecture family this crate was compiled for.
    ///
    /// The answer is fixed at compile time from `target_arch` / `target_os`;
    /// nothing is probed at run time, so the function is usable in `const`
    /// contexts.
    pub const fn current() -> Self {
        // `cfg!` expands to a boolean literal, so exactly one branch survives
        // constant folding for any given target.
        if cfg!(target_arch = "x86_64") {
            Self::X86_64
        } else if cfg!(target_arch = "aarch64") {
            if cfg!(target_os = "macos") {
                Self::AppleSilicon
            } else {
                Self::AarchGeneric
            }
        } else {
            Self::Other
        }
    }
}
/// Coarse category of operation for which a kernel configuration is requested.
///
/// NOTE(review): the per-variant descriptions below are read off the variant
/// names only; confirm against the code that dispatches on this enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OpClass {
    /// Presumably matrix-multiplication style operations.
    Matmul,
    /// Presumably attention operations.
    Attention,
    /// Presumably element-wise operations.
    Elementwise,
    /// Presumably shape-manipulating operations.
    Shape,
}
/// Bundle of tuning thresholds handed to kernels.
///
/// The type is a plain `Copy` value made of `usize` fields. It derives
/// `PartialEq`, `Eq` and `Hash` so that configurations can be compared,
/// asserted on in tests, and used as map/set keys.
///
/// NOTE(review): the meaning and units of each field are not visible in this
/// file; the field notes below are inferred from the names and should be
/// confirmed at the use sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct KernelConfig {
    /// Presumably a sequence-length cut-off for choosing a NEON code path
    /// (TODO confirm; the field is also populated for non-ARM targets).
    pub neon_seq_threshold: usize,
    /// Presumably the chunk size handed to each parallel task (TODO confirm).
    pub par_grain: usize,
    /// Presumably the minimum amount of work before going parallel
    /// (TODO confirm).
    pub par_threshold: usize,
    /// Presumably the cut-off that decides whether attention is fused
    /// (TODO confirm).
    pub fuse_attn_threshold: usize,
}
/// Thresholds that `kernel_config_for` returns for `CpuArch::AppleSilicon`.
const APPLE_SILICON: KernelConfig = KernelConfig {
    neon_seq_threshold: 32,
    par_grain: 64,
    par_threshold: 30_000,
    fuse_attn_threshold: 64,
};
/// Thresholds that `kernel_config_for` returns for `CpuArch::AarchGeneric`.
const AARCH_GENERIC: KernelConfig = KernelConfig {
    neon_seq_threshold: 24,
    par_grain: 32,
    par_threshold: 20_000,
    fuse_attn_threshold: 48,
};
/// Thresholds that `kernel_config_for` returns for `CpuArch::X86_64`.
const X86_DEFAULT: KernelConfig = KernelConfig {
    neon_seq_threshold: 16,
    par_grain: 32,
    par_threshold: 20_000,
    fuse_attn_threshold: 32,
};
/// Thresholds that `kernel_config_for` returns for `CpuArch::Other`, i.e. any
/// architecture without a dedicated table.
const FALLBACK: KernelConfig = KernelConfig {
    neon_seq_threshold: 16,
    par_grain: 16,
    par_threshold: 10_000,
    fuse_attn_threshold: 16,
};
/// Looks up the tuning thresholds for `arch` and `op`.
///
/// Selection currently depends on `arch` alone: every [`OpClass`] receives the
/// same per-architecture table, which the wildcard in the second tuple slot
/// makes explicit.
pub const fn kernel_config_for(arch: CpuArch, op: OpClass) -> KernelConfig {
    match (arch, op) {
        (CpuArch::AppleSilicon, _) => APPLE_SILICON,
        (CpuArch::AarchGeneric, _) => AARCH_GENERIC,
        (CpuArch::X86_64, _) => X86_DEFAULT,
        (CpuArch::Other, _) => FALLBACK,
    }
}
/// Returns the thresholds for `op` on the architecture this crate was compiled
/// for, by combining [`CpuArch::current`] with [`kernel_config_for`].
pub const fn current_config(op: OpClass) -> KernelConfig {
    let compiled_for = CpuArch::current();
    kernel_config_for(compiled_for, op)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Whatever the build target is, the selected config has non-zero
    /// `neon_seq_threshold` and `par_threshold`.
    #[test]
    fn current_resolves() {
        let KernelConfig {
            neon_seq_threshold,
            par_threshold,
            ..
        } = current_config(OpClass::Matmul);

        assert!(neon_seq_threshold > 0);
        assert!(par_threshold > 0);
    }

    /// The Apple Silicon table is never below the fallback table for the two
    /// thresholds checked here.
    #[test]
    fn apple_silicon_picks_higher_thresholds() {
        let apple = kernel_config_for(CpuArch::AppleSilicon, OpClass::Matmul);
        let fallback = kernel_config_for(CpuArch::Other, OpClass::Matmul);

        assert!(fallback.neon_seq_threshold <= apple.neon_seq_threshold);
        assert!(fallback.par_threshold <= apple.par_threshold);
    }
}