use super::packing::{pack_b_block, packed_b_size};
use super::{KC, NC};
/// Matrix `B` packed once into the tiled layout accepted by the `*_with_prepacked_b` GEMM
/// entry points (presumably so repeated multiplications by the same `B` avoid re-packing it).
///
/// NOTE(review): `k` and `n` are public and mutable; changing them does not re-pack `data`
/// or rebuild the tile tables.
#[derive(Debug, Clone)]
pub struct PrepackedB {
    /// Every packed tile, concatenated into one contiguous buffer.
    data: Vec<f32>,
    /// Shared (inner) GEMM dimension; the source `b` held `k * n` elements.
    pub k: usize,
    /// The other dimension of `B`, stepped over in blocks of `NC`.
    pub n: usize,
    /// Start offset in `data` of each tile, indexed by `jc_idx * num_pc_tiles + pc_idx`.
    tile_offsets: Vec<usize>,
    /// Length (in `f32`s) of each tile, indexed the same way as `tile_offsets`.
    tile_sizes: Vec<usize>,
    /// Number of `KC` blocks along `k` (0 for an empty matrix); this is the stride between
    /// consecutive `jc` blocks in the tile tables.
    num_pc_tiles: usize,
}
impl PrepackedB {
    /// Packs the `k * n`-element matrix `b` into `KC × NC`-blocked tiles.
    ///
    /// Tiles are produced with `jc` stepping over `0..n` by `NC` in the outer loop and `pc`
    /// stepping over `0..k` by `KC` in the inner loop. The last block in each dimension is
    /// clipped to the matrix edge. Each tile is written by `pack_b_block` (which is passed `n`
    /// as the leading dimension of `b`) into its own slice of one contiguous buffer.
    ///
    /// If `k == 0` or `n == 0`, the result has no tiles and an empty buffer.
    ///
    /// # Panics
    /// Panics with a "B size mismatch" message if `b.len() != k * n`, or if `k * n`
    /// overflows `usize`.
    pub fn pack(b: &[f32], k: usize, n: usize) -> Self {
        // Guard the product itself: a wrapped `k * n` could otherwise equal `b.len()`.
        let expected_len = k.checked_mul(n).unwrap_or_else(|| {
            panic!("B size mismatch: k * n overflows usize (k = {k}, n = {n})")
        });
        assert_eq!(
            b.len(),
            expected_len,
            "B size mismatch: expected {expected_len}, got {}",
            b.len()
        );
        if k == 0 || n == 0 {
            return Self {
                data: Vec::new(),
                k,
                n,
                tile_offsets: Vec::new(),
                tile_sizes: Vec::new(),
                num_pc_tiles: 0,
            };
        }

        let num_pc = (k + KC - 1) / KC;
        let num_tiles = ((n + NC - 1) / NC) * num_pc;

        // Pass 1: lay out every tile back to back so the buffer is allocated exactly once.
        let mut tile_offsets = Vec::with_capacity(num_tiles);
        let mut tile_sizes = Vec::with_capacity(num_tiles);
        let mut total_size = 0;
        for (_, _, kc_block, nc_block) in Self::tile_grid(k, n) {
            let size = packed_b_size(kc_block, nc_block);
            tile_offsets.push(total_size);
            tile_sizes.push(size);
            total_size += size;
        }

        // Pass 2: pack each tile into its slot. Walking the same `tile_grid` keeps the tile
        // order identical to the layout computed above.
        let mut data = vec![0.0_f32; total_size];
        let layout = tile_offsets.iter().zip(&tile_sizes);
        for ((jc, pc, kc_block, nc_block), (&offset, &size)) in Self::tile_grid(k, n).zip(layout)
        {
            pack_b_block(b, n, pc, jc, kc_block, nc_block, &mut data[offset..offset + size]);
        }

        Self { data, k, n, tile_offsets, tile_sizes, num_pc_tiles: num_pc }
    }

    /// Yields `(jc, pc, kc_block, nc_block)` for every tile in packing order.
    ///
    /// `jc` steps over `0..n` by `NC` (outer) and `pc` steps over `0..k` by `KC` (inner).
    /// `nc_block`/`kc_block` are the block extents, clipped at the matrix edge.
    fn tile_grid(k: usize, n: usize) -> impl Iterator<Item = (usize, usize, usize, usize)> {
        (0..n).step_by(NC).flat_map(move |jc| {
            let nc_block = NC.min(n - jc);
            (0..k).step_by(KC).map(move |pc| (jc, pc, KC.min(k - pc), nc_block))
        })
    }

    /// Returns the packed tile for the `jc_idx`-th `NC` block along `n` and the `pc_idx`-th
    /// `KC` block along `k`.
    ///
    /// # Panics
    /// Panics if `pc_idx` is not below the number of `KC` blocks, or if `jc_idx` is out of
    /// range. An empty matrix has no tiles, so every call on one panics.
    #[inline]
    pub fn tile(&self, jc_idx: usize, pc_idx: usize) -> &[f32] {
        // Tiles live in one flat jc-major table. Without this check, an out-of-range `pc_idx`
        // would silently return a tile belonging to a later `jc` block.
        assert!(
            pc_idx < self.num_pc_tiles,
            "pc_idx {pc_idx} out of range (num_pc_tiles = {})",
            self.num_pc_tiles
        );
        let idx = jc_idx * self.num_pc_tiles + pc_idx;
        let offset = self.tile_offsets[idx];
        let size = self.tile_sizes[idx];
        &self.data[offset..offset + size]
    }

    /// Size in bytes of the packed tile buffer (excludes the offset/size tables).
    #[must_use]
    pub fn memory_bytes(&self) -> usize {
        self.data.len() * std::mem::size_of::<f32>()
    }

    /// Total number of packed tiles (`0` for an empty matrix).
    #[must_use]
    pub fn num_tiles(&self) -> usize {
        self.tile_offsets.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic, non-trivial GEMM inputs: `m * k` values for `a` and `k * n` values
    /// for `b`, all in `[0, 1)`.
    fn gemm_inputs(m: usize, k: usize, n: usize) -> (Vec<f32>, Vec<f32>) {
        let a = (0..m * k).map(|i| ((i * 7 + 13) % 97) as f32 / 97.0).collect();
        let b = (0..k * n).map(|i| ((i * 11 + 3) % 89) as f32 / 89.0).collect();
        (a, b)
    }

    /// Asserts that two GEMM outputs agree element-wise.
    ///
    /// The tolerance is `1e-5` scaled by the larger magnitude and never below `1e-5` absolute.
    /// Sums of `k` products routinely exceed 1, and a fixed `1e-5` is then only a few ULPs.
    fn assert_outputs_close(standard: &[f32], prepacked: &[f32]) {
        assert_eq!(standard.len(), prepacked.len(), "output length mismatch");
        for (i, (&s, &p)) in standard.iter().zip(prepacked).enumerate() {
            let tol = 1e-5_f32 * s.abs().max(p.abs()).max(1.0);
            assert!((s - p).abs() < tol, "Mismatch at index {i}: standard={s}, prepacked={p}");
        }
    }

    /// Expected packed-buffer size in bytes: `packed_b_size` summed over every clipped
    /// `KC × NC` block of a `k × n` matrix.
    fn expected_packed_bytes(k: usize, n: usize) -> usize {
        let floats: usize = (0..n)
            .step_by(NC)
            .flat_map(move |jc| {
                (0..k)
                    .step_by(KC)
                    .map(move |pc| packed_b_size(KC.min(k - pc), NC.min(n - jc)))
            })
            .sum();
        floats * std::mem::size_of::<f32>()
    }

    #[test]
    fn test_prepack_empty() {
        let pb = PrepackedB::pack(&[], 0, 0);
        assert_eq!(pb.k, 0);
        assert_eq!(pb.n, 0);
        assert_eq!(pb.num_tiles(), 0);
        assert_eq!(pb.memory_bytes(), 0);
    }

    #[test]
    fn test_prepack_zero_dimension() {
        // A zero in only one dimension must still produce an empty packing.
        for (k, n) in [(0, 5), (3, 0)] {
            let pb = PrepackedB::pack(&[], k, n);
            assert_eq!((pb.k, pb.n), (k, n));
            assert_eq!(pb.num_tiles(), 0);
            assert_eq!(pb.memory_bytes(), 0);
        }
    }

    #[test]
    fn test_prepack_small() {
        let b: Vec<f32> = (0..16).map(|i| i as f32).collect();
        let pb = PrepackedB::pack(&b, 4, 4);
        assert_eq!(pb.k, 4);
        assert_eq!(pb.n, 4);
        assert!(pb.num_tiles() > 0);
        assert!(pb.memory_bytes() > 0);
        assert_eq!(pb.memory_bytes(), expected_packed_bytes(4, 4));
    }

    #[test]
    fn test_prepack_dimensions() {
        let k = 384;
        let n = 1536;
        let b = vec![0.0_f32; k * n];
        let pb = PrepackedB::pack(&b, k, n);
        assert_eq!(pb.k, k);
        assert_eq!(pb.n, n);
        let num_jc = (n + NC - 1) / NC;
        let num_pc = (k + KC - 1) / KC;
        assert_eq!(pb.num_tiles(), num_jc * num_pc);
        assert_eq!(pb.memory_bytes(), expected_packed_bytes(k, n));
    }

    #[test]
    fn test_prepack_tile_access() {
        let k = 384;
        let n = 384;
        let b = vec![1.0_f32; k * n];
        let pb = PrepackedB::pack(&b, k, n);
        let tile = pb.tile(0, 0);
        assert!(!tile.is_empty());
        assert_eq!(tile.len(), packed_b_size(KC.min(k), NC.min(n)));
    }

    #[test]
    fn test_prepack_tiles_cover_buffer() {
        // One element past a block boundary in each dimension forces clipped edge tiles.
        let k = KC + 1;
        let n = NC + 1;
        let b = vec![0.5_f32; k * n];
        let pb = PrepackedB::pack(&b, k, n);
        let num_jc = (n + NC - 1) / NC;
        let num_pc = (k + KC - 1) / KC;
        assert_eq!(pb.num_tiles(), num_jc * num_pc);

        let mut total = 0;
        for jc_idx in 0..num_jc {
            for pc_idx in 0..num_pc {
                let tile = pb.tile(jc_idx, pc_idx);
                let kc_block = KC.min(k - pc_idx * KC);
                let nc_block = NC.min(n - jc_idx * NC);
                assert_eq!(tile.len(), packed_b_size(kc_block, nc_block));
                total += tile.len();
            }
        }
        // Every byte of the buffer belongs to exactly one tile.
        assert_eq!(total * std::mem::size_of::<f32>(), pb.memory_bytes());
    }

    #[test]
    #[should_panic(expected = "B size mismatch")]
    fn test_prepack_size_mismatch() {
        PrepackedB::pack(&[1.0, 2.0], 4, 4);
    }

    #[test]
    fn test_prepacked_matches_gemm_blis() {
        use crate::blis::compute::{gemm_blis, gemm_blis_with_prepacked_b};
        let (m, k, n) = (128, 64, 96);
        let (a, b) = gemm_inputs(m, k, n);

        let mut c_standard = vec![0.0_f32; m * n];
        gemm_blis(m, n, k, &a, &b, &mut c_standard, None).unwrap();

        let prepacked = PrepackedB::pack(&b, k, n);
        let mut c_prepacked = vec![0.0_f32; m * n];
        gemm_blis_with_prepacked_b(m, n, k, &a, &prepacked, &mut c_prepacked, None).unwrap();

        assert_outputs_close(&c_standard, &c_prepacked);
    }

    #[test]
    fn test_prepacked_parallel_matches_standard() {
        use crate::blis::parallel::{gemm_blis_parallel, gemm_blis_parallel_with_prepacked_b};
        let (m, k, n) = (256, 128, 64);
        let (a, b) = gemm_inputs(m, k, n);

        let mut c_standard = vec![0.0_f32; m * n];
        gemm_blis_parallel(m, n, k, &a, &b, &mut c_standard).unwrap();

        let prepacked = PrepackedB::pack(&b, k, n);
        let mut c_prepacked = vec![0.0_f32; m * n];
        gemm_blis_parallel_with_prepacked_b(m, n, k, &a, &prepacked, &mut c_prepacked).unwrap();

        assert_outputs_close(&c_standard, &c_prepacked);
    }
}