pub mod blocked;
pub mod config;
pub mod micro_kernels;
pub mod packing;
pub use blocked::{blocked_gemm_f32, blocked_gemm_f64, should_use_blocked};
pub use config::MatMulConfig;
#[cfg(test)]
mod tests {
use super::*;
/// Sanity-checks the blocking parameters returned by `MatMulConfig::for_f32`.
#[test]
fn test_gemm_config_f32() {
    let config = MatMulConfig::for_f32();

    // All five size fields must be non-zero.
    assert!(config.mc > 0);
    assert!(config.kc > 0);
    assert!(config.nc > 0);
    assert!(config.mr > 0);
    assert!(config.nr > 0);

    // `mr` and `nr` must both be multiples of four for f32
    // (presumably to match the micro-kernel's vector width; confirm in `config`).
    assert_eq!(config.mr % 4, 0);
    assert_eq!(config.nr % 4, 0);
}
/// Sanity-checks `MatMulConfig::for_f64` and compares it with the f32 configuration.
#[test]
fn test_gemm_config_f64() {
    let config = MatMulConfig::for_f64();

    // All five size fields must be non-zero.
    assert!(config.mc > 0);
    assert!(config.kc > 0);
    assert!(config.nc > 0);
    assert!(config.mr > 0);
    assert!(config.nr > 0);

    // The f64 `mc` and `kc` must not exceed their f32 counterparts
    // (presumably because each f64 element is twice as wide; confirm in `config`).
    let config_f32 = MatMulConfig::for_f32();
    assert!(config.mc <= config_f32.mc);
    assert!(config.kc <= config_f32.kc);
}
/// Multiplying a 64×64 f32 matrix by the identity must reproduce the matrix.
#[test]
fn test_identity_multiply() {
    let n: usize = 64;

    // A holds 0, 1, 2, … in storage order; the largest value (4095) is exact in f32.
    let a: Vec<f32> = (0..n * n).map(|v| v as f32).collect();

    // B is the n×n identity: a single 1.0 on the diagonal of every row.
    let mut b = vec![0.0f32; n * n];
    for (diag, row) in b.chunks_exact_mut(n).enumerate() {
        row[diag] = 1.0;
    }

    let mut c = vec![0.0f32; n * n];
    let config = MatMulConfig::for_f32();

    // SAFETY: `a`, `b` and `c` each hold n*n elements and are passed with leading
    // dimension n. This assumes `blocked_gemm_f32` touches no memory beyond that
    // extent; confirm against its documented safety contract.
    unsafe {
        blocked_gemm_f32(
            n, n, n, // all three dimensions
            1.0,
            a.as_ptr(), n,
            b.as_ptr(), n,
            0.0,
            c.as_mut_ptr(), n,
            &config,
        );
    }

    // C must equal A element for element.
    for (i, (&got, &want)) in c.iter().zip(&a).enumerate() {
        assert!(
            (got - want).abs() < 1e-4,
            "Mismatch at index {}: expected {}, got {}",
            i,
            want,
            got
        );
    }
}
/// Multiplies a 3×4 matrix by a 4×5 matrix (f32) and verifies the whole 3×5 product.
///
/// The original version of this test only inspected `C[0,0]` and `C[0,1]`, so an
/// error in any other row or column would go unnoticed. The two hand-computed
/// spot checks are kept, and every element is additionally compared against a
/// naive row-major triple-loop product.
#[test]
fn test_rectangular_multiply() {
    let (m, k, n): (usize, usize, usize) = (3, 4, 5);

    // A = 1..=12 laid out as 3 rows of 4; B = 1..=20 laid out as 4 rows of 5.
    let a: Vec<f32> = (1..=12).map(|v| v as f32).collect();
    let b: Vec<f32> = (1..=20).map(|v| v as f32).collect();
    let mut c = vec![0.0f32; m * n];
    let config = MatMulConfig::for_f32();

    // SAFETY: `a` holds m*k, `b` holds k*n and `c` holds m*n elements, passed with
    // leading dimensions k, n and n respectively. This assumes `blocked_gemm_f32`
    // touches no memory beyond those extents; confirm against its safety contract.
    unsafe {
        blocked_gemm_f32(
            m,
            k,
            n,
            1.0,
            a.as_ptr(),
            k,
            b.as_ptr(),
            n,
            0.0,
            c.as_mut_ptr(),
            n,
            &config,
        );
    }

    // Hand-computed anchors: 1*1 + 2*6 + 3*11 + 4*16 = 110 and
    // 1*2 + 2*7 + 3*12 + 4*17 = 120.
    assert!(
        (c[0] - 110.0).abs() < 1e-4,
        "C[0,0] expected 110.0, got {}",
        c[0]
    );
    assert!(
        (c[1] - 120.0).abs() < 1e-4,
        "C[0,1] expected 120.0, got {}",
        c[1]
    );

    // Full check. All operands and sums are small integers, so the reference
    // product is exact in f32 and the tolerance only absorbs kernel rounding.
    for i in 0..m {
        for j in 0..n {
            let expected: f32 = (0..k).map(|p| a[i * k + p] * b[p * n + j]).sum();
            let actual = c[i * n + j];
            assert!(
                (actual - expected).abs() < 1e-4,
                "Mismatch at ({},{}): expected {}, got {}",
                i,
                j,
                expected,
                actual
            );
        }
    }
}
/// Runs a 4×4×4 f32 product on buffers whose leading dimension (8) is larger than
/// the logical row width (4), using identity matrices for both operands.
#[test]
fn test_gemm_with_strided_access() {
    let (m, k, n): (usize, usize, usize) = (4, 4, 4);
    let (lda, ldb, ldc): (usize, usize, usize) = (8, 8, 8);

    let mut a = vec![0.0f32; m * lda];
    let mut b = vec![0.0f32; k * ldb];
    let mut c = vec![0.0f32; m * ldc];

    // Place a 1.0 on the diagonal of each strided row; the padding stays zero.
    for (diag, row) in a.chunks_exact_mut(lda).enumerate() {
        row[diag] = 1.0;
    }
    for (diag, row) in b.chunks_exact_mut(ldb).enumerate() {
        row[diag] = 1.0;
    }

    let config = MatMulConfig::for_f32();

    // SAFETY: `a`, `b` and `c` hold m*lda, k*ldb and m*ldc elements and are passed
    // with exactly those leading dimensions. This assumes `blocked_gemm_f32`
    // touches no memory beyond those extents; confirm against its safety contract.
    unsafe {
        blocked_gemm_f32(
            m, k, n,
            1.0,
            a.as_ptr(), lda,
            b.as_ptr(), ldb,
            0.0,
            c.as_mut_ptr(), ldc,
            &config,
        );
    }

    // Only the first n columns of each strided row belong to the logical result.
    for (i, row) in c.chunks_exact(ldc).enumerate() {
        for (j, &actual) in row[..n].iter().enumerate() {
            let expected = if i == j { 1.0 } else { 0.0 };
            assert!(
                (actual - expected).abs() < 1e-5,
                "Mismatch at ({},{}): expected {}, got {}",
                i,
                j,
                expected,
                actual
            );
        }
    }
}
/// Multiplies a 200×200 matrix of ones by a 200×200 matrix of twos (f32);
/// every element of the product must equal 2 * 200.
#[test]
fn test_large_matrix_correctness() {
    let n: usize = 200;
    let ones = vec![1.0f32; n * n];
    let twos = vec![2.0f32; n * n];
    let mut product = vec![0.0f32; n * n];
    let config = MatMulConfig::for_f32();

    // SAFETY: all three buffers hold n*n elements and are passed with leading
    // dimension n. This assumes `blocked_gemm_f32` touches no memory beyond that
    // extent; confirm against its documented safety contract.
    unsafe {
        blocked_gemm_f32(
            n, n, n,
            1.0,
            ones.as_ptr(), n,
            twos.as_ptr(), n,
            0.0,
            product.as_mut_ptr(), n,
            &config,
        );
    }

    // Each dot product sums n terms of 1.0 * 2.0.
    let expected = 2.0 * n as f32;
    for (i, &got) in product.iter().enumerate() {
        assert!(
            (got - expected).abs() < 1e-2,
            "Mismatch at index {}: expected {}, got {}",
            i,
            expected,
            got
        );
    }
}
/// Multiplying a 64×64 f64 matrix by the identity must reproduce the matrix.
#[test]
fn test_identity_multiply_f64() {
    let n: usize = 64;

    // A holds 0, 1, 2, … in storage order.
    let a: Vec<f64> = (0..n * n).map(|v| v as f64).collect();

    // B is the n×n identity: a single 1.0 on the diagonal of every row.
    let mut b = vec![0.0f64; n * n];
    for (diag, row) in b.chunks_exact_mut(n).enumerate() {
        row[diag] = 1.0;
    }

    let mut c = vec![0.0f64; n * n];
    let config = MatMulConfig::for_f64();

    // SAFETY: `a`, `b` and `c` each hold n*n elements and are passed with leading
    // dimension n. This assumes `blocked_gemm_f64` touches no memory beyond that
    // extent; confirm against its documented safety contract.
    unsafe {
        blocked_gemm_f64(
            n, n, n, // all three dimensions
            1.0,
            a.as_ptr(), n,
            b.as_ptr(), n,
            0.0,
            c.as_mut_ptr(), n,
            &config,
        );
    }

    // C must equal A element for element.
    for (i, (&got, &want)) in c.iter().zip(&a).enumerate() {
        assert!(
            (got - want).abs() < 1e-10,
            "Mismatch at index {}: expected {}, got {}",
            i,
            want,
            got
        );
    }
}
/// Multiplies a 3×4 matrix by a 4×5 matrix (f64) and verifies the whole 3×5 product.
///
/// The original version of this test only inspected `C[0,0]` and `C[0,1]`, so an
/// error in any other row or column would go unnoticed. The two hand-computed
/// spot checks are kept, and every element is additionally compared against a
/// naive row-major triple-loop product.
#[test]
fn test_rectangular_multiply_f64() {
    let (m, k, n): (usize, usize, usize) = (3, 4, 5);

    // A = 1..=12 laid out as 3 rows of 4; B = 1..=20 laid out as 4 rows of 5.
    let a: Vec<f64> = (1..=12).map(|v| v as f64).collect();
    let b: Vec<f64> = (1..=20).map(|v| v as f64).collect();
    let mut c = vec![0.0f64; m * n];
    let config = MatMulConfig::for_f64();

    // SAFETY: `a` holds m*k, `b` holds k*n and `c` holds m*n elements, passed with
    // leading dimensions k, n and n respectively. This assumes `blocked_gemm_f64`
    // touches no memory beyond those extents; confirm against its safety contract.
    unsafe {
        blocked_gemm_f64(
            m,
            k,
            n,
            1.0,
            a.as_ptr(),
            k,
            b.as_ptr(),
            n,
            0.0,
            c.as_mut_ptr(),
            n,
            &config,
        );
    }

    // Hand-computed anchors: 1*1 + 2*6 + 3*11 + 4*16 = 110 and
    // 1*2 + 2*7 + 3*12 + 4*17 = 120.
    assert!(
        (c[0] - 110.0).abs() < 1e-10,
        "C[0,0] expected 110.0, got {}",
        c[0]
    );
    assert!(
        (c[1] - 120.0).abs() < 1e-10,
        "C[0,1] expected 120.0, got {}",
        c[1]
    );

    // Full check. All operands and sums are small integers, so the reference
    // product is exact in f64 and the tolerance only absorbs kernel rounding.
    for i in 0..m {
        for j in 0..n {
            let expected: f64 = (0..k).map(|p| a[i * k + p] * b[p * n + j]).sum();
            let actual = c[i * n + j];
            assert!(
                (actual - expected).abs() < 1e-10,
                "Mismatch at ({},{}): expected {}, got {}",
                i,
                j,
                expected,
                actual
            );
        }
    }
}
/// Exercises non-trivial scaling factors on a 2×2 f64 product.
///
/// With A = [[1, 2], [3, 4]], B = [[5, 6], [7, 8]] and C initialised to all ones,
/// the expected output equals 2·(A·B) + 3·C = [41, 47, 89, 103].
#[test]
fn test_gemm_with_alpha_beta_f64() {
    let (m, k, n): (usize, usize, usize) = (2, 2, 2);
    let a: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0];
    let b: Vec<f64> = vec![5.0, 6.0, 7.0, 8.0];
    let mut c = vec![1.0f64; 4];
    let config = MatMulConfig::for_f64();

    // SAFETY: each buffer holds 4 elements (2×2) and is passed with leading
    // dimension 2. This assumes `blocked_gemm_f64` touches no memory beyond that
    // extent; confirm against its documented safety contract.
    unsafe {
        blocked_gemm_f64(
            m, k, n,
            2.0, // scale applied to A·B
            a.as_ptr(), k,
            b.as_ptr(), n,
            3.0, // scale applied to the incoming C
            c.as_mut_ptr(), n,
            &config,
        );
    }

    let expected: [f64; 4] = [41.0, 47.0, 89.0, 103.0];
    for (i, (&got, &want)) in c.iter().zip(expected.iter()).enumerate() {
        assert!(
            (got - want).abs() < 1e-10,
            "Mismatch at index {}: expected {}, got {}",
            i,
            want,
            got
        );
    }
}
/// Multiplies a 128×128 matrix of ones by a 128×128 matrix of twos (f64);
/// every element of the product must equal 2 * 128.
#[test]
fn test_large_matrix_correctness_f64() {
    let n: usize = 128;
    let ones = vec![1.0f64; n * n];
    let twos = vec![2.0f64; n * n];
    let mut product = vec![0.0f64; n * n];
    let config = MatMulConfig::for_f64();

    // SAFETY: all three buffers hold n*n elements and are passed with leading
    // dimension n. This assumes `blocked_gemm_f64` touches no memory beyond that
    // extent; confirm against its documented safety contract.
    unsafe {
        blocked_gemm_f64(
            n, n, n,
            1.0,
            ones.as_ptr(), n,
            twos.as_ptr(), n,
            0.0,
            product.as_mut_ptr(), n,
            &config,
        );
    }

    // Each dot product sums n terms of 1.0 * 2.0.
    let expected = 2.0 * n as f64;
    for (i, &got) in product.iter().enumerate() {
        assert!(
            (got - expected).abs() < 1e-8,
            "Mismatch at index {}: expected {}, got {}",
            i,
            expected,
            got
        );
    }
}
/// Runs a 4×4×4 f64 product on buffers whose leading dimension (8) is larger than
/// the logical row width (4), using identity matrices for both operands.
#[test]
fn test_gemm_with_strided_access_f64() {
    let (m, k, n): (usize, usize, usize) = (4, 4, 4);
    let (lda, ldb, ldc): (usize, usize, usize) = (8, 8, 8);

    let mut a = vec![0.0f64; m * lda];
    let mut b = vec![0.0f64; k * ldb];
    let mut c = vec![0.0f64; m * ldc];

    // Place a 1.0 on the diagonal of each strided row; the padding stays zero.
    for (diag, row) in a.chunks_exact_mut(lda).enumerate() {
        row[diag] = 1.0;
    }
    for (diag, row) in b.chunks_exact_mut(ldb).enumerate() {
        row[diag] = 1.0;
    }

    let config = MatMulConfig::for_f64();

    // SAFETY: `a`, `b` and `c` hold m*lda, k*ldb and m*ldc elements and are passed
    // with exactly those leading dimensions. This assumes `blocked_gemm_f64`
    // touches no memory beyond those extents; confirm against its safety contract.
    unsafe {
        blocked_gemm_f64(
            m, k, n,
            1.0,
            a.as_ptr(), lda,
            b.as_ptr(), ldb,
            0.0,
            c.as_mut_ptr(), ldc,
            &config,
        );
    }

    // Only the first n columns of each strided row belong to the logical result.
    for (i, row) in c.chunks_exact(ldc).enumerate() {
        for (j, &actual) in row[..n].iter().enumerate() {
            let expected = if i == j { 1.0 } else { 0.0 };
            assert!(
                (actual - expected).abs() < 1e-10,
                "Mismatch at ({},{}): expected {}, got {}",
                i,
                j,
                expected,
                actual
            );
        }
    }
}
}