/// Matrix-vector product against a weight buffer stored as packed bytes.
///
/// The entire `weight_bf16` buffer is first expanded to `f32` with
/// `simd_bf16_to_f32`; the expanded weights are then handed to `simd_matmul`
/// together with `input`, `in_dim` and `out_dim`, and its result is returned
/// unchanged.
///
/// Because the full weight matrix is decoded in one go, the temporary `f32`
/// buffer holds every decoded weight at once.
#[must_use]
pub fn simd_bf16_matmul(
    input: &[f32],
    weight_bf16: &[u8],
    in_dim: usize,
    out_dim: usize,
) -> Vec<f32> {
    simd_matmul(input, &simd_bf16_to_f32(weight_bf16), in_dim, out_dim)
}
/// Matrix-vector product that decodes the packed weight buffer one row at a
/// time instead of expanding the whole matrix up front.
///
/// `weight_bf16` is read as `out_dim` consecutive rows of `in_dim` elements,
/// two bytes per element. For each row the bytes are decoded with
/// `simd_bf16_to_f32` and dotted with `input`; element `row` of the returned
/// vector is that dot product. Only a single decoded row is alive at any
/// moment.
///
/// Bytes in `weight_bf16` beyond `out_dim * in_dim * 2` are ignored.
///
/// # Panics
///
/// * if `out_dim * in_dim * 2` does not fit in `usize`;
/// * if `weight_bf16` is shorter than `out_dim * in_dim * 2` bytes;
/// * if `Vector::dot` reports an error (presumably a length mismatch between
///   `input` and a decoded row — confirm against `Vector::dot`).
#[must_use]
pub(super) fn simd_bf16_matmul_streaming(
    input: &[f32],
    weight_bf16: &[u8],
    in_dim: usize,
    out_dim: usize,
) -> Vec<f32> {
    // Two bytes per element, so one row spans `in_dim * 2` bytes. Use checked
    // arithmetic so an absurd dimension cannot silently wrap into a bogus
    // (but in-bounds) offset in release builds.
    let row_bytes = in_dim
        .checked_mul(2)
        .expect("in_dim * 2 overflows usize");
    let required_bytes = row_bytes
        .checked_mul(out_dim)
        .expect("out_dim * in_dim * 2 overflows usize");

    // Validate the buffer once, before doing any work, rather than failing
    // with an opaque slice-index panic part-way through the rows.
    assert!(
        weight_bf16.len() >= required_bytes,
        "weight_bf16 holds {} bytes but {} rows of {} elements need {} bytes",
        weight_bf16.len(),
        out_dim,
        in_dim,
        required_bytes,
    );

    let input_vec = Vector::from_slice(input);

    // Rows are visited in ascending order. The former fixed-size tiling loop
    // produced exactly this order, so it has been flattened away.
    (0..out_dim)
        .map(|row| {
            // Cannot overflow or go out of bounds: `(row + 1) * row_bytes` is
            // at most `required_bytes`, which was checked above.
            let start = row * row_bytes;
            let row_f32 = simd_bf16_to_f32(&weight_bf16[start..start + row_bytes]);
            input_vec
                .dot(&Vector::from_slice(&row_f32))
                .expect("dot product failed")
        })
        .collect()
}