// trueno/blis/mod.rs
//! BLIS-Style Matrix Multiplication
//!
//! High-performance GEMM implementation based on the BLIS framework.
//!
//! # References
//!
//! - Goto, K., & Van de Geijn, R. A. (2008). Anatomy of High-Performance Matrix Multiplication.
//!   ACM TOMS, 34(3). <https://doi.org/10.1145/1356052.1356053>
//! - Van Zee, F. G., & Van de Geijn, R. A. (2015). BLIS: A Framework for Rapidly Instantiating
//!   BLAS Functionality. ACM TOMS, 41(3). <https://doi.org/10.1145/2764454>
//! - Low, T. M., et al. (2016). Analytical Modeling Is Enough for High-Performance BLIS.
//!   ACM TOMS, 43(2). <https://doi.org/10.1145/2925987>
//!
//! # Toyota Production System Integration
//!
//! - **Jidoka**: Runtime guards that stop on numerical errors (see [`jidoka`] module)
//! - **Poka-Yoke**: Compile-time type safety for panel dimensions
//! - **Heijunka**: Load-balanced parallel execution
//! - **Kaizen**: Performance tracking for continuous improvement (see [`profiler`] module)
//!
//! # Module Structure
//!
//! - [`jidoka`]: Runtime validation guards (stop-on-defect)
//! - [`profiler`]: Performance tracking at all BLIS hierarchy levels
//! - [`microkernels`]: High-performance SIMD compute kernels
//! - [`backend_selection`]: Automatic CPU/GPU backend selection
//! - [`reference`]: Scalar reference GEMM for validation
//! - [`packing`]: Cache-optimized matrix packing routines
//! - [`compute`]: Core BLIS blocked GEMM computation
//! - [`parallel`]: Parallel GEMM with Heijunka scheduling
//! - [`transpose`]: Matrix transpose operations
//! - [`prepacked`]: Pre-packed B matrices for repeated GEMM calls
//! - [`gemv`]: Matrix–vector multiply (GEMV)
//! - [`attention`]: Attention operations
//! - [`softmax`]: Softmax operations
//! - [`elementwise`]: Elementwise operations
//! - [`norms`]: Norm computations
//! - [`cache_topology`]: CPU cache topology
33pub mod attention;
34pub mod backend_selection;
35pub mod cache_topology;
36pub mod compute;
37pub mod elementwise;
38pub mod gemv;
39pub mod jidoka;
40pub mod microkernels;
41pub mod norms;
42pub mod packing;
43pub mod parallel;
44pub mod prepacked;
45pub mod profiler;
46pub mod reference;
47pub mod softmax;
48pub mod transpose;
49
50// Re-export jidoka types for backwards compatibility
51pub use jidoka::{JidokaError, JidokaGuard};
52
53// Re-export profiler types for backwards compatibility
54pub use profiler::{BlisLevelStats, BlisProfileLevel, BlisProfiler, KaizenMetrics};
55
56// Re-export microkernel functions
57#[cfg(target_arch = "aarch64")]
58pub use microkernels::microkernel_8x8_neon;
59pub use microkernels::microkernel_scalar;
60#[cfg(target_arch = "x86_64")]
61pub use microkernels::{microkernel_8x6_avx2, microkernel_8x6_avx2_asm, microkernel_8x6_true_asm};
62
63// Re-export backend selection types
64pub use backend_selection::{
65 gemm_auto, BackendCostModel, BrickLevel, ComputeBackend, PtxMicrokernelSpec, RooflineResult,
66 UnifiedBrickProfiler, WgslMicrokernelSpec,
67};
68
69// Re-export reference GEMM
70pub use reference::{gemm_reference, gemm_reference_with_jidoka};
71
72// Re-export packing functions
73pub use packing::{pack_a, pack_b, packed_a_size, packed_b_size};
74
75// Re-export compute
76#[cfg(target_arch = "x86_64")]
77pub use compute::gemm_blis_broadcast_b;
78pub use compute::{gemm_blis, gemm_blis_with_prepacked_b};
79
80// Re-export parallel
81#[cfg(feature = "parallel")]
82pub use parallel::gemm_blis_parallel_shared_b;
83pub use parallel::{gemm_blis_parallel, gemm_blis_parallel_with_prepacked_b, HeijunkaScheduler};
84
85// Re-export prepacked
86pub use prepacked::PrepackedB;
87
88// Re-export transpose
89pub use transpose::transpose;
90
91use crate::error::TruenoError;
92
93// ============================================================================
94// BLIS Configuration Constants
95// ============================================================================
96
97/// Microkernel row dimension (AVX2: 8 f32 per ymm register)
98pub const MR: usize = 8;
99
100/// Microkernel column dimension (6 columns fit in remaining registers)
101pub const NR: usize = 6;
102
103/// K-dimension blocking for L1 cache (256 elements = 1KB)
104pub const KC: usize = 256;
105
106/// M-dimension blocking for L2 cache.
107/// Must be a multiple of MR. 128 = 16×MR for AVX2 (vs old 72 = 9×MR).
108/// Larger MC reduces packing overhead per macroblock (fewer ic-loop iterations).
109/// Zen 4 L2 = 1MB per core; MC×KC×4B = 128×256×4 = 128KB << 1MB.
110pub const MC: usize = 128;
111
112/// N-dimension blocking for L3 cache
113pub const NC: usize = 4096;
114
115// ============================================================================
116// AVX-512 BLIS Configuration Constants
117// ============================================================================
118
119/// AVX-512 microkernel row dimension (16 f32 per zmm register)
120pub const MR_512: usize = 16;
121
122/// AVX-512 microkernel column dimension (8 columns in remaining zmm registers)
123pub const NR_512: usize = 8;
124
125/// AVX-512 K-dimension blocking (same as AVX2, L1 limited)
126pub const KC_512: usize = 256;
127
128/// AVX-512 M-dimension blocking for L2 cache.
129/// 128 = 8×MR_512. Zen 4 L2 = 1MB; 128×256×4 = 128KB.
130pub const MC_512: usize = 128;
131
132/// AVX-512 N-dimension blocking for L3 cache
133pub const NC_512: usize = 4096;
134
135// ============================================================================
136// AVX-512 32×6 Microkernel Constants (Phase 4, Appendix D optimization #1)
137// ============================================================================
138
139/// 32×6 microkernel: 2 zmm rows × 6 columns = 12 accumulators.
140/// 1.5× more FMAs per K step than 16×8 (12 vs 8).
141pub const MR_512V2: usize = 32;
142
143/// 6 columns: balances register pressure (12 acc + 2 A load = 14 zmm).
144pub const NR_512V2: usize = 6;
145
146/// Increased KC for 32×6: 32×256×4 = 32 KB fits L1 (32 KB on Zen 4).
147pub const KC_512V2: usize = 256;
148
149/// MC for 32×6: 192 = 6×MR_512V2. Packed A = 192×256×4 = 192 KB fits L2.
150pub const MC_512V2: usize = 192;
151
152/// NC for 32×6: same L3 blocking.
153pub const NC_512V2: usize = 4096;
154
155// ============================================================================
156// Public API
157// ============================================================================
158
159/// High-performance GEMM using BLIS algorithm
160///
161/// Computes C += A * B where:
162/// - A is M x K (row-major)
163/// - B is K x N (row-major)
164/// - C is M x N (row-major)
165///
166/// Automatically selects single-threaded or parallel execution based on matrix size.
167pub fn gemm(
168 m: usize,
169 n: usize,
170 k: usize,
171 a: &[f32],
172 b: &[f32],
173 c: &mut [f32],
174) -> Result<(), TruenoError> {
175 // Contract: matmul-kernel-v1.yaml precondition (pv codegen)
176 contract_pre_matmul!(a);
177
178 let result = {
179 #[cfg(feature = "parallel")]
180 {
181 gemm_blis_parallel(m, n, k, a, b, c)
182 }
183 #[cfg(not(feature = "parallel"))]
184 {
185 gemm_blis(m, n, k, a, b, c, None)
186 }
187 };
188 if result.is_ok() {
189 contract_post_matmul!(c);
190 }
191 result
192}
193
/// GEMM with profiling enabled
///
/// Computes C += A * B (same shapes/layout as [`gemm`]) using the
/// single-threaded blocked kernel, handing `profiler` to [`gemm_blis`]
/// so BLIS-level performance data is recorded (see [`profiler`] module).
///
/// NOTE(review): unlike [`gemm`], no contract pre/post macros are applied
/// here and the parallel path is never taken — confirm this is intentional.
///
/// # Errors
///
/// Propagates any error returned by [`gemm_blis`].
pub fn gemm_profiled(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
    profiler: &mut BlisProfiler,
) -> Result<(), TruenoError> {
    gemm_blis(m, n, k, a, b, c, Some(profiler))
}
206
207/// Fused GEMM + bias + ReLU: C = max(0, A×B + bias)
208///
209/// Performs matmul then applies bias addition and ReLU activation in a single
210/// pass over C while the output tiles are still in L1/L2 cache. This avoids
211/// two extra full-matrix memory passes that separate add+relu would require.
212///
213/// For GEMM 64: saves ~2µs (bias+relu would cost 2×0.8µs on cold data).
214/// For GEMM 128: saves ~5µs.
215///
216/// # Arguments
217///
218/// * `bias` - Per-column bias vector of length `n` (broadcast across rows)
219///
220/// # Errors
221///
222/// Returns `Err` if dimensions don't match or bias length != n.
223pub fn gemm_bias_relu(
224 m: usize,
225 n: usize,
226 k: usize,
227 a: &[f32],
228 b: &[f32],
229 bias: &[f32],
230 c: &mut [f32],
231) -> Result<(), TruenoError> {
232 if bias.len() != n {
233 return Err(TruenoError::InvalidInput(format!(
234 "gemm_bias_relu: bias.len()={} != n={}",
235 bias.len(),
236 n
237 )));
238 }
239 // Step 1: GEMM (C = A×B)
240 gemm(m, n, k, a, b, c)?;
241
242 // Step 2: Fused bias + ReLU in-place on hot cache data.
243 // C is still in L1/L2 from the GEMM store — no DRAM reads needed.
244 for row in 0..m {
245 let row_start = row * n;
246 for col in 0..n {
247 let val = c[row_start + col] + bias[col];
248 c[row_start + col] = val.max(0.0);
249 }
250 }
251 Ok(())
252}
253
254#[cfg(test)]
255mod tests;