use super::config::MatMulConfig;

/// Packs a `kc x nr` panel of `B` (row stride `ldb`) into a contiguous
/// row-major buffer so downstream kernels can stream it sequentially.
///
/// # Safety
/// `b` must be valid for reads at every offset `p * ldb + j` with `p < kc`
/// and `j < nr`; `buffer` must be valid for writes of `kc * nr` elements.
#[inline]
pub unsafe fn pack_b_f32(kc: usize, nr: usize, b: *const f32, ldb: usize, buffer: *mut f32) {
    for p in 0..kc {
        for j in 0..nr {
            *buffer.add(p * nr + j) = *b.add(p * ldb + j);
        }
    }
}

/// `f64` variant of [`pack_b_f32`]; same layout and safety contract.
#[inline]
pub unsafe fn pack_b_f64(kc: usize, nr: usize, b: *const f64, ldb: usize, buffer: *mut f64) {
    for p in 0..kc {
        for j in 0..nr {
            *buffer.add(p * nr + j) = *b.add(p * ldb + j);
        }
    }
}

/// Packs an `mr x kc` panel of `A` (row stride `lda`) into a contiguous
/// row-major buffer.
///
/// # Safety
/// `a` must be valid for reads at every offset `i * lda + k` with `i < mr`
/// and `k < kc`; `buffer` must be valid for writes of `mr * kc` elements.
#[inline]
pub unsafe fn pack_a_f32(mr: usize, kc: usize, a: *const f32, lda: usize, buffer: *mut f32) {
    for i in 0..mr {
        for k in 0..kc {
            *buffer.add(i * kc + k) = *a.add(i * lda + k);
        }
    }
}

/// `f64` variant of [`pack_a_f32`]; same layout and safety contract.
#[inline]
pub unsafe fn pack_a_f64(mr: usize, kc: usize, a: *const f64, lda: usize, buffer: *mut f64) {
    for i in 0..mr {
        for k in 0..kc {
            *buffer.add(i * kc + k) = *a.add(i * lda + k);
        }
    }
}

/// Runtime-dispatched variant of [`pack_b_f32`]: uses AVX2 on x86_64 or NEON
/// on aarch64 when the feature is detected and `nr` is at least one vector
/// wide, otherwise falls back to the scalar loop. The config parameter is
/// currently unused.
///
/// # Safety
/// Same contract as [`pack_b_f32`].
#[inline]
pub unsafe fn pack_b_f32_fast(
    kc: usize,
    nr: usize,
    b: *const f32,
    ldb: usize,
    buffer: *mut f32,
    _config: &MatMulConfig,
) {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") && nr >= 8 {
            pack_b_f32_avx2(kc, nr, b, ldb, buffer);
            return;
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") && nr >= 4 {
            pack_b_f32_neon(kc, nr, b, ldb, buffer);
            return;
        }
    }
    pack_b_f32(kc, nr, b, ldb, buffer);
}

/// AVX2 row copy: eight `f32` lanes per step via unaligned loads/stores,
/// with a scalar tail for the remaining `nr % 8` columns.
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn pack_b_f32_avx2(kc: usize, nr: usize, b: *const f32, ldb: usize, buffer: *mut f32) {
    use std::arch::x86_64::*;
    for p in 0..kc {
        let src = b.add(p * ldb);
        let dst = buffer.add(p * nr);
        let mut j = 0;
        // 256-bit bulk copy over the row.
        while j + 8 <= nr {
            let values = _mm256_loadu_ps(src.add(j));
            _mm256_storeu_ps(dst.add(j), values);
            j += 8;
        }
        // Scalar tail for the last few columns.
        while j < nr {
            *dst.add(j) = *src.add(j);
            j += 1;
        }
    }
}

/// NEON row copy: four `f32` lanes per step, with a scalar tail for the
/// remaining `nr % 4` columns. NEON is baseline on aarch64, so no
/// `target_feature` attribute is required.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn pack_b_f32_neon(kc: usize, nr: usize, b: *const f32, ldb: usize, buffer: *mut f32) {
    use std::arch::aarch64::*;
    for p in 0..kc {
        let src = b.add(p * ldb);
        let dst = buffer.add(p * nr);
        let mut j = 0;
        // 128-bit bulk copy over the row.
        while j + 4 <= nr {
            let values = vld1q_f32(src.add(j));
            vst1q_f32(dst.add(j), values);
            j += 4;
        }
        // Scalar tail for the last few columns.
        while j < nr {
            *dst.add(j) = *src.add(j);
            j += 1;
        }
    }
}

/// Runtime-dispatched variant of [`pack_b_f64`]; see [`pack_b_f32_fast`].
///
/// # Safety
/// Same contract as [`pack_b_f64`].
#[inline]
pub unsafe fn pack_b_f64_fast(
    kc: usize,
    nr: usize,
    b: *const f64,
    ldb: usize,
    buffer: *mut f64,
    _config: &MatMulConfig,
) {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") && nr >= 4 {
            pack_b_f64_avx2(kc, nr, b, ldb, buffer);
            return;
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") && nr >= 2 {
            pack_b_f64_neon(kc, nr, b, ldb, buffer);
            return;
        }
    }
    pack_b_f64(kc, nr, b, ldb, buffer);
}

/// AVX2 row copy: four `f64` lanes per step, with a scalar tail for the
/// remaining `nr % 4` columns.
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn pack_b_f64_avx2(kc: usize, nr: usize, b: *const f64, ldb: usize, buffer: *mut f64) {
    use std::arch::x86_64::*;
    for p in 0..kc {
        let src = b.add(p * ldb);
        let dst = buffer.add(p * nr);
        let mut j = 0;
        while j + 4 <= nr {
            let values = _mm256_loadu_pd(src.add(j));
            _mm256_storeu_pd(dst.add(j), values);
            j += 4;
        }
        while j < nr {
            *dst.add(j) = *src.add(j);
            j += 1;
        }
    }
}

/// NEON row copy: two `f64` lanes per step, with a scalar tail for the
/// remaining `nr % 2` columns.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn pack_b_f64_neon(kc: usize, nr: usize, b: *const f64, ldb: usize, buffer: *mut f64) {
    use std::arch::aarch64::*;
    for p in 0..kc {
        let src = b.add(p * ldb);
        let dst = buffer.add(p * nr);
        let mut j = 0;
        while j + 2 <= nr {
            let values = vld1q_f64(src.add(j));
            vst1q_f64(dst.add(j), values);
            j += 2;
        }
        while j < nr {
            *dst.add(j) = *src.add(j);
            j += 1;
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pack_b_f32_small() {
        let kc = 4;
        let nr = 4;
        let ldb = 8;
        let b: Vec<f32> = (0..32).map(|i| i as f32).collect();
        let mut packed = vec![0.0f32; kc * nr];
        unsafe {
            pack_b_f32(kc, nr, b.as_ptr(), ldb, packed.as_mut_ptr());
        }
        for p in 0..kc {
            for j in 0..nr {
                let expected = b[p * ldb + j];
                let actual = packed[p * nr + j];
                assert_eq!(
                    actual, expected,
                    "Mismatch at p={}, j={}: expected {}, got {}",
                    p, j, expected, actual
                );
            }
        }
    }

    #[test]
    fn test_pack_a_f32_small() {
        let mr = 4;
        let kc = 4;
        let lda = 8;
        let a: Vec<f32> = (0..32).map(|i| i as f32).collect();
        let mut packed = vec![0.0f32; mr * kc];
        unsafe {
            pack_a_f32(mr, kc, a.as_ptr(), lda, packed.as_mut_ptr());
        }
        for i in 0..mr {
            for k in 0..kc {
                let expected = a[i * lda + k];
                let actual = packed[i * kc + k];
                assert_eq!(
                    actual, expected,
                    "Mismatch at i={}, k={}: expected {}, got {}",
                    i, k, expected, actual
                );
            }
        }
    }
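
    // A consistency sketch for the dispatching fast path: whichever SIMD
    // branch is taken at runtime, the output must match the scalar reference.
    // `nr = 11` is deliberately not a multiple of the 8-lane AVX2 or 4-lane
    // NEON width, so the scalar tail loop is exercised too. This assumes
    // `MatMulConfig` implements `Default`; construct the config however the
    // type is actually built.
    #[test]
    fn test_pack_b_f32_fast_matches_scalar() {
        let kc = 3;
        let nr = 11;
        let ldb = 16;
        let b: Vec<f32> = (0..kc * ldb).map(|i| i as f32).collect();
        let mut reference = vec![0.0f32; kc * nr];
        let mut fast = vec![0.0f32; kc * nr];
        let config = MatMulConfig::default(); // assumed constructor
        unsafe {
            pack_b_f32(kc, nr, b.as_ptr(), ldb, reference.as_mut_ptr());
            pack_b_f32_fast(kc, nr, b.as_ptr(), ldb, fast.as_mut_ptr(), &config);
        }
        assert_eq!(fast, reference);
    }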

    #[test]
    fn test_pack_b_f64() {
        let kc = 2;
        let nr = 2;
        let ldb = 4;
        let b: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut packed = vec![0.0f64; kc * nr];
        unsafe {
            pack_b_f64(kc, nr, b.as_ptr(), ldb, packed.as_mut_ptr());
        }
        // Rows 0 and 1 of B (stride 4), first two columns of each.
        assert_eq!(packed[0], 1.0);
        assert_eq!(packed[1], 2.0);
        assert_eq!(packed[2], 5.0);
        assert_eq!(packed[3], 6.0);
    }
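
    // The same consistency sketch for the f64 fast path, with `nr = 7` so
    // neither the 4-lane AVX2 loop nor the 2-lane NEON loop covers the row
    // exactly. As above, `MatMulConfig::default()` is an assumption.
    #[test]
    fn test_pack_b_f64_fast_matches_scalar() {
        let kc = 3;
        let nr = 7;
        let ldb = 8;
        let b: Vec<f64> = (0..kc * ldb).map(|i| i as f64).collect();
        let mut reference = vec![0.0f64; kc * nr];
        let mut fast = vec![0.0f64; kc * nr];
        let config = MatMulConfig::default(); // assumed constructor
        unsafe {
            pack_b_f64(kc, nr, b.as_ptr(), ldb, reference.as_mut_ptr());
            pack_b_f64_fast(kc, nr, b.as_ptr(), ldb, fast.as_mut_ptr(), &config);
        }
        assert_eq!(fast, reference);
    }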
}