trueno/blis/microkernels/mod.rs
1//! BLIS Microkernels - High-Performance SIMD Compute Kernels
2//!
//! This module contains the microkernel implementations for different architectures:
//! - Scalar reference (correctness validation)
//! - AVX2 intrinsics
//! - AVX2 hand-tuned ASM with software pipelining
//! - AVX-512
//! - ARM NEON
8//!
9//! # Performance Targets
10//!
11//! - 70%+ FMA utilization on Haswell+ CPUs
12//! - 4-way K unrolling for software pipelining
13//! - 10-12 instruction latency hiding
14//!
15//! # References
16//!
17//! - Goto, K., & Van de Geijn, R. A. (2008). Anatomy of High-Performance Matrix Multiplication.
18//! - Agner Fog (2024). Optimizing subroutines in assembly language, Section 12.7.
19//! - Intel(R) 64 and IA-32 Architectures Optimization Reference Manual.
20
21#[cfg(target_arch = "x86_64")]
22mod avx2;
23#[cfg(target_arch = "x86_64")]
24mod avx512;
25#[cfg(target_arch = "x86_64")]
26pub mod codegen;
27mod neon;
28
29// Re-export all public microkernel functions
30#[cfg(target_arch = "x86_64")]
31pub use avx2::{
32 microkernel_8x6_avx2, microkernel_8x6_avx2_asm, microkernel_8x6_true_asm,
33 microkernel_8x8_avx2_fma,
34};
35#[cfg(target_arch = "x86_64")]
36pub use avx512::{microkernel_16x8_avx512, microkernel_32x6_avx512};
37#[cfg(target_arch = "aarch64")]
38pub use neon::microkernel_8x8_neon;
39
40use super::{MR, NR};
41
42/// Scalar microkernel for correctness validation
43///
44/// Computes C[MR x NR] += A[MR x K] * B[K x NR]
45/// where A is packed column-major and B is packed row-major.
46///
47/// This serves as the reference for validating SIMD microkernels.
48#[inline(never)]
49pub fn microkernel_scalar(
50 k: usize,
51 a: &[f32], // MR x K, column-major (MR stride)
52 b: &[f32], // K x NR, row-major (NR stride)
53 c: &mut [f32], // MR x NR, column-major
54 ldc: usize, // Leading dimension of C
55) {
56 // Accumulate MR x NR output tile
57 for p in 0..k {
58 for jr in 0..NR {
59 let b_val = b[p * NR + jr];
60 for ir in 0..MR {
61 let a_val = a[p * MR + ir];
62 c[jr * ldc + ir] += a_val * b_val;
63 }
64 }
65 }
66}