Skip to main content

trueno/blis/
prepacked.rs

//! Pre-packed B matrix for BLIS GEMM.
//!
//! Eliminates redundant B packing in parallel GEMM by pre-packing all
//! (jc, pc) tiles once at weight-load time. The packed data can then be
//! shared immutably across all threads without per-thread repacking.
//!
//! # Motivation (WAPR-KAIZEN Cycle 12)
//!
//! In `gemm_blis_parallel`, each thread independently calls `gemm_blis`, which
//! packs B internally via `pack_b_block()`. For the encoder FFN with 16 threads,
//! 2 GEMMs per block, and 4 layers, this results in 16 × 2 × 4 = 128 redundant
//! B packings per encoder pass. Pre-packing eliminates this entirely.
//!
//! # References
//!
//! - Van Zee & Van de Geijn (2015): BLIS framework, Section 3.2 (packing)
17
18use super::packing::{pack_b_block, packed_b_size};
19use super::{KC, NC};
20
/// A B matrix (k × n, row-major) that has already been packed into BLIS tiles.
///
/// Every (jc, pc) tile of the source matrix is stored back-to-back in one flat
/// buffer, in the packed micro-panel layout the BLIS microkernels consume.
/// Nothing in this type mutates the buffer after construction, so a shared
/// `&PrepackedB` can be handed to any number of threads.
#[derive(Debug, Clone)]
pub struct PrepackedB {
    /// Contiguous storage holding every packed tile, one after another.
    data: Vec<f32>,
    /// K dimension (row count) of the original, unpacked B matrix.
    pub k: usize,
    /// N dimension (column count) of the original, unpacked B matrix.
    pub n: usize,
    /// Start position inside `data` of each tile, indexed by
    /// `jc_idx * num_pc_tiles + pc_idx`.
    tile_offsets: Vec<usize>,
    /// Length (in `f32` elements) of each tile, indexed the same way as
    /// `tile_offsets`.
    tile_sizes: Vec<usize>,
    /// How many tiles the K dimension is split into (the pc tile count).
    num_pc_tiles: usize,
}
41
42impl PrepackedB {
43    /// Pre-pack a B matrix (k × n, row-major) into BLIS tile format.
44    ///
45    /// This iterates the same (jc, pc) loop as `gemm_blis` and packs each
46    /// B tile into the NR-aligned micro-panel layout. The result can be
47    /// reused across many GEMM calls with different A matrices.
48    ///
49    /// # Panics
50    ///
51    /// Panics if `b.len() != k * n`.
52    pub fn pack(b: &[f32], k: usize, n: usize) -> Self {
53        assert_eq!(b.len(), k * n, "B size mismatch: expected {}, got {}", k * n, b.len());
54
55        if k == 0 || n == 0 {
56            return Self {
57                data: Vec::new(),
58                k,
59                n,
60                tile_offsets: Vec::new(),
61                tile_sizes: Vec::new(),
62                num_pc_tiles: 0,
63            };
64        }
65
66        let num_jc = (n + NC - 1) / NC;
67        let num_pc = (k + KC - 1) / KC;
68        let num_tiles = num_jc * num_pc;
69
70        // First pass: compute tile sizes and cumulative offsets
71        let mut tile_offsets = Vec::with_capacity(num_tiles);
72        let mut tile_sizes = Vec::with_capacity(num_tiles);
73        let mut total_size = 0;
74
75        for jc in (0..n).step_by(NC) {
76            let nc_block = NC.min(n - jc);
77            for pc in (0..k).step_by(KC) {
78                let kc_block = KC.min(k - pc);
79                let size = packed_b_size(kc_block, nc_block);
80                tile_offsets.push(total_size);
81                tile_sizes.push(size);
82                total_size += size;
83            }
84        }
85
86        // Second pass: pack all tiles
87        let mut data = vec![0.0_f32; total_size];
88        let mut tile_idx = 0;
89
90        for jc in (0..n).step_by(NC) {
91            let nc_block = NC.min(n - jc);
92            for pc in (0..k).step_by(KC) {
93                let kc_block = KC.min(k - pc);
94                let offset = tile_offsets[tile_idx];
95                let size = tile_sizes[tile_idx];
96                pack_b_block(b, n, pc, jc, kc_block, nc_block, &mut data[offset..offset + size]);
97                tile_idx += 1;
98            }
99        }
100
101        Self { data, k, n, tile_offsets, tile_sizes, num_pc_tiles: num_pc }
102    }
103
104    /// Get the pre-packed tile for the given (jc, pc) tile indices.
105    ///
106    /// `jc_idx` = jc / NC, `pc_idx` = pc / KC
107    #[inline]
108    pub fn tile(&self, jc_idx: usize, pc_idx: usize) -> &[f32] {
109        let idx = jc_idx * self.num_pc_tiles + pc_idx;
110        let offset = self.tile_offsets[idx];
111        let size = self.tile_sizes[idx];
112        &self.data[offset..offset + size]
113    }
114
115    /// Total memory used by packed data (bytes).
116    #[must_use]
117    pub fn memory_bytes(&self) -> usize {
118        self.data.len() * std::mem::size_of::<f32>()
119    }
120
121    /// Number of packed tiles.
122    #[must_use]
123    pub fn num_tiles(&self) -> usize {
124        self.tile_offsets.len()
125    }
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `len` deterministic values `((i * mul + add) % modulus) / modulus`.
    fn patterned(len: usize, mul: usize, add: usize, modulus: usize) -> Vec<f32> {
        (0..len).map(|i| ((i * mul + add) % modulus) as f32 / modulus as f32).collect()
    }

    /// Asserts that two GEMM outputs agree element-wise within an absolute
    /// tolerance of 1e-5.
    fn assert_outputs_match(standard: &[f32], prepacked: &[f32], len: usize) {
        for (i, (s, p)) in standard[..len].iter().zip(&prepacked[..len]).enumerate() {
            assert!(
                (s - p).abs() < 1e-5,
                "Mismatch at index {i}: standard={}, prepacked={}",
                s,
                p
            );
        }
    }

    /// Packing a 0×0 matrix yields no tiles and no data.
    #[test]
    fn test_prepack_empty() {
        let packed = PrepackedB::pack(&[], 0, 0);
        assert_eq!(packed.k, 0);
        assert_eq!(packed.n, 0);
        assert_eq!(packed.num_tiles(), 0);
        assert_eq!(packed.memory_bytes(), 0);
    }

    /// A tiny 4×4 matrix still produces at least one non-empty tile.
    #[test]
    fn test_prepack_small() {
        let values: Vec<f32> = (0..16u8).map(f32::from).collect();
        let packed = PrepackedB::pack(&values, 4, 4);
        assert_eq!(packed.k, 4);
        assert_eq!(packed.n, 4);
        assert!(packed.num_tiles() > 0);
        assert!(packed.memory_bytes() > 0);
    }

    /// Tile count equals ceil(n / NC) * ceil(k / KC) for a realistic shape
    /// (Whisper-tiny fc1: B is 384×1536, transposed weights).
    #[test]
    fn test_prepack_dimensions() {
        let (k, n) = (384, 1536);
        let zeros = vec![0.0_f32; k * n];
        let packed = PrepackedB::pack(&zeros, k, n);
        assert_eq!(packed.k, k);
        assert_eq!(packed.n, n);

        let expected_jc = (n + NC - 1) / NC;
        let expected_pc = (k + KC - 1) / KC;
        assert_eq!(packed.num_tiles(), expected_jc * expected_pc);
    }

    /// The first tile of a non-empty matrix is retrievable and non-empty.
    #[test]
    fn test_prepack_tile_access() {
        let (k, n) = (384, 384);
        let ones = vec![1.0_f32; k * n];
        let packed = PrepackedB::pack(&ones, k, n);

        let first_tile = packed.tile(0, 0);
        assert!(!first_tile.is_empty());
    }

    /// A slice whose length disagrees with `k * n` is rejected.
    #[test]
    #[should_panic(expected = "B size mismatch")]
    fn test_prepack_size_mismatch() {
        let _ = PrepackedB::pack(&[1.0, 2.0], 4, 4);
    }

    /// Golden test: `gemm_blis_with_prepacked_b` must agree with `gemm_blis`.
    #[test]
    fn test_prepacked_matches_gemm_blis() {
        use crate::blis::compute::{gemm_blis, gemm_blis_with_prepacked_b};

        let (m, k, n) = (128, 64, 96);

        // Deterministic pseudo-random inputs.
        let a = patterned(m * k, 7, 13, 97);
        let b = patterned(k * n, 11, 3, 89);

        // Reference result from the standard path.
        let mut c_standard = vec![0.0_f32; m * n];
        gemm_blis(m, n, k, &a, &b, &mut c_standard, None).unwrap();

        // Result from the pre-packed path.
        let prepacked = PrepackedB::pack(&b, k, n);
        let mut c_prepacked = vec![0.0_f32; m * n];
        gemm_blis_with_prepacked_b(m, n, k, &a, &prepacked, &mut c_prepacked, None).unwrap();

        // Both paths are expected to use the same packing and microkernel;
        // the comparison itself allows an absolute difference below 1e-5.
        assert_outputs_match(&c_standard, &c_prepacked, m * n);
    }

    /// Golden test for parallel pre-packed GEMM.
    #[test]
    fn test_prepacked_parallel_matches_standard() {
        use crate::blis::parallel::{gemm_blis_parallel, gemm_blis_parallel_with_prepacked_b};

        // Dimensions chosen to reach the parallel path, which per the original
        // note needs m*n*k >= 1_000_000 (here 256 * 64 * 128 = 2_097_152).
        let (m, k, n) = (256, 128, 64);

        let a = patterned(m * k, 7, 13, 97);
        let b = patterned(k * n, 11, 3, 89);

        // Reference result from the standard parallel path.
        let mut c_standard = vec![0.0_f32; m * n];
        gemm_blis_parallel(m, n, k, &a, &b, &mut c_standard).unwrap();

        // Result from the pre-packed parallel path.
        let prepacked = PrepackedB::pack(&b, k, n);
        let mut c_prepacked = vec![0.0_f32; m * n];
        gemm_blis_parallel_with_prepacked_b(m, n, k, &a, &prepacked, &mut c_prepacked).unwrap();

        assert_outputs_match(&c_standard, &c_prepacked, m * n);
    }
}