1use super::packing::{pack_b_block, packed_b_size};
19use super::{KC, NC};
20
/// A right-hand GEMM operand `B` that has been packed once, up front, into the
/// tiled layout written by `pack_b_block`, so it can be handed to the
/// prepacked-B GEMM entry points (e.g. `gemm_blis_with_prepacked_b`) in place
/// of the raw matrix.
///
/// Every tile lives in a single contiguous buffer: tile `i` occupies
/// `data[tile_offsets[i]..tile_offsets[i] + tile_sizes[i]]`, and tiles are
/// indexed as `jc_idx * num_pc_tiles + pc_idx`.
#[derive(Debug, Clone)]
pub struct PrepackedB {
    /// All packed tiles, stored back to back.
    data: Vec<f32>,
    /// Row count of the original `B` (the shared inner dimension of the product).
    pub k: usize,
    /// Column count of the original `B`.
    pub n: usize,
    /// Start of each tile within `data`, in tile-index order.
    tile_offsets: Vec<usize>,
    /// Length of each tile in `f32` elements, in tile-index order.
    tile_sizes: Vec<usize>,
    /// Number of depth (`KC`) blocks per column (`NC`) block, i.e. the stride
    /// between consecutive `jc_idx` values in the flat tile index.
    num_pc_tiles: usize,
}
41
42impl PrepackedB {
43 pub fn pack(b: &[f32], k: usize, n: usize) -> Self {
53 assert_eq!(b.len(), k * n, "B size mismatch: expected {}, got {}", k * n, b.len());
54
55 if k == 0 || n == 0 {
56 return Self {
57 data: Vec::new(),
58 k,
59 n,
60 tile_offsets: Vec::new(),
61 tile_sizes: Vec::new(),
62 num_pc_tiles: 0,
63 };
64 }
65
66 let num_jc = (n + NC - 1) / NC;
67 let num_pc = (k + KC - 1) / KC;
68 let num_tiles = num_jc * num_pc;
69
70 let mut tile_offsets = Vec::with_capacity(num_tiles);
72 let mut tile_sizes = Vec::with_capacity(num_tiles);
73 let mut total_size = 0;
74
75 for jc in (0..n).step_by(NC) {
76 let nc_block = NC.min(n - jc);
77 for pc in (0..k).step_by(KC) {
78 let kc_block = KC.min(k - pc);
79 let size = packed_b_size(kc_block, nc_block);
80 tile_offsets.push(total_size);
81 tile_sizes.push(size);
82 total_size += size;
83 }
84 }
85
86 let mut data = vec![0.0_f32; total_size];
88 let mut tile_idx = 0;
89
90 for jc in (0..n).step_by(NC) {
91 let nc_block = NC.min(n - jc);
92 for pc in (0..k).step_by(KC) {
93 let kc_block = KC.min(k - pc);
94 let offset = tile_offsets[tile_idx];
95 let size = tile_sizes[tile_idx];
96 pack_b_block(b, n, pc, jc, kc_block, nc_block, &mut data[offset..offset + size]);
97 tile_idx += 1;
98 }
99 }
100
101 Self { data, k, n, tile_offsets, tile_sizes, num_pc_tiles: num_pc }
102 }
103
104 #[inline]
108 pub fn tile(&self, jc_idx: usize, pc_idx: usize) -> &[f32] {
109 let idx = jc_idx * self.num_pc_tiles + pc_idx;
110 let offset = self.tile_offsets[idx];
111 let size = self.tile_sizes[idx];
112 &self.data[offset..offset + size]
113 }
114
115 #[must_use]
117 pub fn memory_bytes(&self) -> usize {
118 self.data.len() * std::mem::size_of::<f32>()
119 }
120
121 #[must_use]
123 pub fn num_tiles(&self) -> usize {
124 self.tile_offsets.len()
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    /// Expected `(num_jc, num_pc)` tile-grid shape for a `k x n` B under the `NC`/`KC` blocking.
    fn expected_tile_grid(k: usize, n: usize) -> (usize, usize) {
        ((n + NC - 1) / NC, (k + KC - 1) / KC)
    }

    /// Visits every tile in the grid, asserts each is non-empty, and returns their summed length.
    fn checked_total_tile_len(pb: &PrepackedB, num_jc: usize, num_pc: usize) -> usize {
        let mut total = 0;
        for jc in 0..num_jc {
            for pc in 0..num_pc {
                let tile = pb.tile(jc, pc);
                assert!(!tile.is_empty(), "tile ({jc}, {pc}) is empty");
                total += tile.len();
            }
        }
        total
    }

    /// Deterministic, non-trivial operands: `A` is `m x k`, `B` is `k x n`.
    fn sample_operands(m: usize, k: usize, n: usize) -> (Vec<f32>, Vec<f32>) {
        let a = (0..m * k).map(|i| ((i * 7 + 13) % 97) as f32 / 97.0).collect();
        let b = (0..k * n).map(|i| ((i * 11 + 3) % 89) as f32 / 89.0).collect();
        (a, b)
    }

    /// Asserts two GEMM outputs have equal length and agree element-wise within `1e-5`.
    fn assert_outputs_match(standard: &[f32], prepacked: &[f32]) {
        assert_eq!(standard.len(), prepacked.len(), "output length mismatch");
        for (i, (&s, &p)) in standard.iter().zip(prepacked).enumerate() {
            assert!((s - p).abs() < 1e-5, "Mismatch at index {i}: standard={s}, prepacked={p}");
        }
    }

    #[test]
    fn test_prepack_empty() {
        let pb = PrepackedB::pack(&[], 0, 0);
        assert_eq!(pb.k, 0);
        assert_eq!(pb.n, 0);
        assert_eq!(pb.num_tiles(), 0);
        assert_eq!(pb.memory_bytes(), 0);
    }

    #[test]
    fn test_prepack_empty_single_dimension() {
        // Only one of k / n being zero must still yield an empty, tile-less result.
        for (k, n) in [(0, 5), (5, 0)] {
            let pb = PrepackedB::pack(&[], k, n);
            assert_eq!((pb.k, pb.n), (k, n));
            assert_eq!(pb.num_tiles(), 0);
            assert_eq!(pb.memory_bytes(), 0);
        }
    }

    #[test]
    fn test_prepack_small() {
        let b: Vec<f32> = (0..16).map(|i| i as f32).collect();
        let pb = PrepackedB::pack(&b, 4, 4);
        assert_eq!(pb.k, 4);
        assert_eq!(pb.n, 4);
        let (num_jc, num_pc) = expected_tile_grid(4, 4);
        assert_eq!(pb.num_tiles(), num_jc * num_pc);
        assert!(pb.memory_bytes() > 0);
    }

    #[test]
    fn test_prepack_dimensions() {
        let k = 384;
        let n = 1536;
        let b = vec![0.0_f32; k * n];
        let pb = PrepackedB::pack(&b, k, n);
        assert_eq!(pb.k, k);
        assert_eq!(pb.n, n);

        let (num_jc, num_pc) = expected_tile_grid(k, n);
        assert_eq!(pb.num_tiles(), num_jc * num_pc);
    }

    #[test]
    fn test_prepack_tile_access() {
        let k = 384;
        let n = 384;
        let b = vec![1.0_f32; k * n];
        let pb = PrepackedB::pack(&b, k, n);

        // Tiles are stored back to back, so together they must cover the whole buffer.
        let (num_jc, num_pc) = expected_tile_grid(k, n);
        let total = checked_total_tile_len(&pb, num_jc, num_pc);
        assert_eq!(total * std::mem::size_of::<f32>(), pb.memory_bytes());
    }

    #[test]
    fn test_prepack_ragged_edges() {
        // One element past a full block in each direction forces clipped edge tiles.
        let k = KC + 1;
        let n = NC + 1;
        let b = vec![1.0_f32; k * n];
        let pb = PrepackedB::pack(&b, k, n);

        let (num_jc, num_pc) = expected_tile_grid(k, n);
        assert_eq!((num_jc, num_pc), (2, 2));
        assert_eq!(pb.num_tiles(), num_jc * num_pc);
        let total = checked_total_tile_len(&pb, num_jc, num_pc);
        assert_eq!(total * std::mem::size_of::<f32>(), pb.memory_bytes());
    }

    #[test]
    #[should_panic(expected = "B size mismatch")]
    fn test_prepack_size_mismatch() {
        PrepackedB::pack(&[1.0, 2.0], 4, 4);
    }

    #[test]
    fn test_prepacked_matches_gemm_blis() {
        use crate::blis::compute::{gemm_blis, gemm_blis_with_prepacked_b};

        let (m, k, n) = (128, 64, 96);
        let (a, b) = sample_operands(m, k, n);

        let mut c_standard = vec![0.0_f32; m * n];
        gemm_blis(m, n, k, &a, &b, &mut c_standard, None).unwrap();

        let prepacked = PrepackedB::pack(&b, k, n);
        let mut c_prepacked = vec![0.0_f32; m * n];
        gemm_blis_with_prepacked_b(m, n, k, &a, &prepacked, &mut c_prepacked, None).unwrap();

        assert_outputs_match(&c_standard, &c_prepacked);
    }

    #[test]
    fn test_prepacked_parallel_matches_standard() {
        use crate::blis::parallel::{gemm_blis_parallel, gemm_blis_parallel_with_prepacked_b};

        let (m, k, n) = (256, 128, 64);
        let (a, b) = sample_operands(m, k, n);

        let mut c_standard = vec![0.0_f32; m * n];
        gemm_blis_parallel(m, n, k, &a, &b, &mut c_standard).unwrap();

        let prepacked = PrepackedB::pack(&b, k, n);
        let mut c_prepacked = vec![0.0_f32; m * n];
        gemm_blis_parallel_with_prepacked_b(m, n, k, &a, &prepacked, &mut c_prepacked).unwrap();

        assert_outputs_match(&c_standard, &c_prepacked);
    }
}