1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
//! BLIS-Style Matrix Multiplication
//!
//! High-performance GEMM implementation based on the BLIS framework.
//!
//! # References
//!
//! - Goto, K., & Van de Geijn, R. A. (2008). Anatomy of High-Performance Matrix Multiplication.
//! ACM TOMS, 34(3). <https://doi.org/10.1145/1356052.1356053>
//! - Van Zee, F. G., & Van de Geijn, R. A. (2015). BLIS: A Framework for Rapidly Instantiating
//! BLAS Functionality. ACM TOMS, 41(3). <https://doi.org/10.1145/2764454>
//! - Low, T. M., et al. (2016). Analytical Modeling Is Enough for High-Performance BLIS.
//! ACM TOMS, 43(2). <https://doi.org/10.1145/2925987>
//!
//! # Toyota Production System Integration
//!
//! - **Jidoka**: Runtime guards that stop on numerical errors (see [`jidoka`] module)
//! - **Poka-Yoke**: Compile-time type safety for panel dimensions
//! - **Heijunka**: Load-balanced parallel execution
//! - **Kaizen**: Performance tracking for continuous improvement (see [`profiler`] module)
//!
//! # Module Structure
//!
//! - [`jidoka`]: Runtime validation guards (stop-on-defect)
//! - [`profiler`]: Performance tracking at all BLIS hierarchy levels
//! - [`microkernels`]: High-performance SIMD compute kernels
//! - [`backend_selection`]: Automatic CPU/GPU backend selection
//! - [`reference`]: Scalar reference GEMM for validation
//! - [`packing`]: Cache-optimized matrix packing routines
//! - [`compute`]: Core BLIS blocked GEMM computation
//! - [`parallel`]: Parallel GEMM with Heijunka scheduling
//! - [`transpose`]: Matrix transpose operations
// Re-export jidoka types for backwards compatibility
pub use ;
// Re-export profiler types for backwards compatibility
pub use ;
// Re-export microkernel functions
pub use microkernel_8x8_neon;
pub use microkernel_scalar;
pub use ;
// Re-export backend selection types
pub use ;
// Re-export reference GEMM
pub use ;
// Re-export packing functions
pub use ;
// Re-export compute
pub use gemm_blis_broadcast_b;
pub use ;
// Re-export parallel
pub use gemm_blis_parallel_shared_b;
pub use ;
// Re-export prepacked
pub use PrepackedB;
// Re-export transpose
pub use transpose;
use crateTruenoError;
// ============================================================================
// BLIS Configuration Constants
// ============================================================================
/// Microkernel row dimension: 8 f32 lanes fill one AVX2 ymm register.
pub const MR: usize = 8;
/// Microkernel column dimension: 6 columns fit in the remaining ymm registers.
pub const NR: usize = 6;
/// K-dimension block size sized for the L1 cache (256 f32 elements = 1 KB).
pub const KC: usize = 256;
/// M-dimension block size sized for the L2 cache.
///
/// Always a multiple of `MR`; 16 × MR = 128 (the previous value was 72 = 9 × MR).
/// A larger `MC` means fewer ic-loop iterations, so less packing overhead per
/// macroblock. On Zen 4 (1 MB L2 per core) the packed A panel occupies
/// MC × KC × 4 B = 128 × 256 × 4 = 128 KB, comfortably inside L2.
pub const MC: usize = 16 * MR;
/// N-dimension block size sized for the L3 cache.
pub const NC: usize = 4096;
// ============================================================================
// AVX-512 BLIS Configuration Constants
// ============================================================================
/// AVX-512 microkernel row dimension: 16 f32 lanes fill one zmm register.
pub const MR_512: usize = 16;
/// AVX-512 microkernel column dimension: 8 columns in the remaining zmm registers.
pub const NR_512: usize = 8;
/// AVX-512 K-dimension block size (identical to AVX2; the limit is L1 capacity).
pub const KC_512: usize = 256;
/// AVX-512 M-dimension block size sized for the L2 cache.
///
/// 8 × MR_512 = 128; on Zen 4 (1 MB L2) the packed panel is 128 × 256 × 4 = 128 KB.
pub const MC_512: usize = 8 * MR_512;
/// AVX-512 N-dimension block size sized for the L3 cache.
pub const NC_512: usize = 4096;
// ============================================================================
// AVX-512 32×6 Microkernel Constants (Phase 4, Appendix D optimization #1)
// ============================================================================
/// Row dimension of the 32×6 microkernel: 2 zmm rows × 6 columns = 12
/// accumulators, i.e. 1.5× more FMAs per K step than the 16×8 kernel (12 vs 8).
pub const MR_512V2: usize = 32;
/// Column dimension of the 32×6 microkernel: 6 columns balance register
/// pressure (12 accumulators + 2 A loads = 14 zmm registers).
pub const NR_512V2: usize = 6;
/// K-dimension block size for 32×6: the 32 × 256 × 4 B = 32 KB micropanel
/// fits L1 exactly (32 KB on Zen 4).
pub const KC_512V2: usize = 256;
/// M-dimension block size for 32×6: 6 × MR_512V2 = 192, so the packed A
/// panel of 192 × 256 × 4 B = 192 KB fits L2.
pub const MC_512V2: usize = 6 * MR_512V2;
/// N-dimension block size for 32×6: unchanged L3 blocking.
pub const NC_512V2: usize = 4096;
// ============================================================================
// Public API
// ============================================================================
/// High-performance GEMM using BLIS algorithm
///
/// Computes C += A * B where:
/// - A is M x K (row-major)
/// - B is K x N (row-major)
/// - C is M x N (row-major)
///
/// Automatically selects single-threaded or parallel execution based on matrix size.
/// GEMM with profiling enabled
/// Fused GEMM + bias + ReLU: C = max(0, A×B + bias)
///
/// Performs matmul then applies bias addition and ReLU activation in a single
/// pass over C while the output tiles are still in L1/L2 cache. This avoids
/// two extra full-matrix memory passes that separate add+relu would require.
///
/// For GEMM 64: saves ~2µs (bias+relu would cost 2×0.8µs on cold data).
/// For GEMM 128: saves ~5µs.
///
/// # Arguments
///
/// * `bias` - Per-column bias vector of length `n` (broadcast across rows)
///
/// # Errors
///
/// Returns `Err` if dimensions don't match or bias length != n.