pub mod blocked;
pub mod kernels;
pub mod matrix;
pub mod threaded;
pub use matrix::naive_ijk::matmul_naive_ijk;
pub use matrix::naive_ikj::matmul_naive_ikj;
pub use matrix::transpose::transpose;
/// Multiplies the `m`x`k` matrix `a` by the `k`x`n` matrix `b`, with the result going to `c`.
///
/// All three matrices are flat `f64` slices whose lengths must be exactly
/// `m * k`, `k * n` and `m * n` respectively. The kernel is chosen at runtime:
///
/// 1. on `x86_64` with `avx512f` and `fma` detected: the blocked 8x8 kernel;
/// 2. on `x86_64` with `avx2` and `fma` detected: the blocked 12x4 kernel;
/// 3. otherwise: [`matmul_naive_ikj`].
///
/// How `c` is written (overwritten vs. accumulated into) is determined by the
/// selected kernel, not by this dispatcher.
///
/// # Panics
///
/// Panics if any slice length does not match its dimensions, or if any of the
/// products `m * k`, `k * n`, `m * n` overflows `usize`.
pub fn multiply(a: &[f64], b: &[f64], c: &mut [f64], m: usize, n: usize, k: usize) {
    // The element counts are computed with `checked_mul`: a plain `m * k` wraps
    // silently in release builds, so oversized dimensions could wrap around to
    // the actual slice length and slip past the checks into the unsafe kernels.
    let a_len = m.checked_mul(k).expect("A: m * k overflows usize");
    assert_eq!(a.len(), a_len, "A: expected {}x{}={} elements", m, k, a_len);
    let b_len = k.checked_mul(n).expect("B: k * n overflows usize");
    assert_eq!(b.len(), b_len, "B: expected {}x{}={} elements", k, n, b_len);
    let c_len = m.checked_mul(n).expect("C: m * n overflows usize");
    assert_eq!(c.len(), c_len, "C: expected {}x{}={} elements", m, n, c_len);

    #[cfg(target_arch = "x86_64")]
    {
        // Widest vector extension first.
        if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("fma") {
            // SAFETY: `avx512f` and `fma` were detected at runtime just above, and
            // the slice lengths were validated against `m`, `n`, `k`.
            // NOTE(review): assumes these are the kernel's only preconditions —
            // confirm against `matmul_blocked_8x8`'s safety contract.
            // The two trailing `None`s leave the kernel's optional arguments unset.
            unsafe { blocked::gemm_8x8::matmul_blocked_8x8(a, b, c, m, n, k, None, None) };
            return;
        }
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            // SAFETY: `avx2` and `fma` were detected at runtime just above, and
            // the slice lengths were validated against `m`, `n`, `k`.
            // NOTE(review): assumes these are the kernel's only preconditions —
            // confirm against `matmul_blocked_12x4`'s safety contract.
            unsafe { blocked::gemm_12x4::matmul_blocked_12x4(a, b, c, m, n, k, None, None) };
            return;
        }
    }

    // Fallback for other architectures, or x86_64 without the features above.
    matrix::naive_ikj::matmul_naive_ikj(a, b, c, m, n, k);
}
/// Multiplies the `m`x`k` matrix `a` by the `k`x`n` matrix `b`, with the result going to `c`,
/// dispatching to a multi-threaded kernel when one is available.
///
/// All three matrices are flat `f64` slices whose lengths must be exactly
/// `m * k`, `k * n` and `m * n` respectively. The kernel is chosen at runtime:
///
/// 1. on `x86_64` with `avx512f` and `fma` detected: the threaded 8x8 kernel;
/// 2. on `x86_64` with `avx2` and `fma` detected: the threaded 12x4 kernel;
/// 3. otherwise: [`matmul_naive_ikj`].
///
/// `num_threads` is forwarded unchanged to the threaded kernels. The fallback
/// kernel does not receive it, so the value has no effect on that path.
/// How `c` is written (overwritten vs. accumulated into) is determined by the
/// selected kernel, not by this dispatcher.
///
/// # Panics
///
/// Panics if any slice length does not match its dimensions, or if any of the
/// products `m * k`, `k * n`, `m * n` overflows `usize`.
pub fn multiply_parallel(
    a: &[f64],
    b: &[f64],
    c: &mut [f64],
    m: usize,
    n: usize,
    k: usize,
    num_threads: usize,
) {
    // The element counts are computed with `checked_mul`: a plain `m * k` wraps
    // silently in release builds, so oversized dimensions could wrap around to
    // the actual slice length and slip past the checks into the kernels.
    let a_len = m.checked_mul(k).expect("A: m * k overflows usize");
    assert_eq!(a.len(), a_len, "A: expected {}x{}={} elements", m, k, a_len);
    let b_len = k.checked_mul(n).expect("B: k * n overflows usize");
    assert_eq!(b.len(), b_len, "B: expected {}x{}={} elements", k, n, b_len);
    let c_len = m.checked_mul(n).expect("C: m * n overflows usize");
    assert_eq!(c.len(), c_len, "C: expected {}x{}={} elements", m, n, c_len);

    #[cfg(target_arch = "x86_64")]
    {
        // Widest vector extension first.
        if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("fma") {
            threaded::gemm_8x8_mt::matmul_blocked_8x8_mt(a, b, c, m, n, k, num_threads);
            return;
        }
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            threaded::gemm_12x4_mt::matmul_blocked_12x4_mt(a, b, c, m, n, k, num_threads);
            return;
        }
    }

    // Fallback for other architectures, or x86_64 without the features above.
    // NOTE(review): `num_threads` is not used on this path.
    matrix::naive_ikj::matmul_naive_ikj(a, b, c, m, n, k);
}