#[cfg(feature = "simd")]
use crate::simd::dot::simd_dot_f32;
#[cfg(feature = "simd")]
use crate::simd::gemm::{blocked_gemm_f32, blocked_gemm_f64, should_use_blocked, MatMulConfig};
/// Computes the dot product of two `f32` slices.
///
/// With the `simd` feature enabled, both slices are wrapped in `ArrayView1`
/// views and passed to `simd_dot_f32`. Without it, the element-wise products
/// are summed with a plain iterator chain.
///
/// # Panics
///
/// Panics if `a` and `b` do not have the same length.
pub fn simd_dot_product_f32(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(
        a.len(),
        b.len(),
        "Dot product requires equal-length vectors"
    );

    #[cfg(feature = "simd")]
    {
        use crate::ndarray::ArrayView1;

        let lhs = ArrayView1::from(a);
        let rhs = ArrayView1::from(b);
        simd_dot_f32(&lhs, &rhs)
    }
    #[cfg(not(feature = "simd"))]
    {
        // Scalar fallback: multiply pairwise, then accumulate left to right.
        a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs).sum()
    }
}
/// Computes the dot product of two `f64` slices.
///
/// With the `simd` feature enabled, both slices are wrapped in `ArrayView1`
/// views and passed to `simd_dot_f64`. Without it, the element-wise products
/// are summed with a plain iterator chain.
///
/// # Panics
///
/// Panics if `a` and `b` do not have the same length.
pub fn simd_dot_product_f64(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(
        a.len(),
        b.len(),
        "Dot product requires equal-length vectors"
    );

    #[cfg(feature = "simd")]
    {
        use crate::ndarray::ArrayView1;
        use crate::simd::dot::simd_dot_f64;

        let lhs = ArrayView1::from(a);
        let rhs = ArrayView1::from(b);
        simd_dot_f64(&lhs, &rhs)
    }
    #[cfg(not(feature = "simd"))]
    {
        // Scalar fallback: multiply pairwise, then accumulate left to right.
        a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs).sum()
    }
}
/// Computes `C = alpha * A * B + beta * C` for dense, row-major `f32` matrices.
///
/// `a` holds the `m x k` matrix A, `b` the `k x n` matrix B and `c` the
/// `m x n` matrix C, each stored contiguously (leading dimensions `k`, `n`
/// and `n` respectively). With the `simd` feature enabled the work is done by
/// `blocked_gemm_f32`; otherwise the scalar `gemm_simple_f32` fallback runs.
///
/// # Panics
///
/// Panics if `m * k`, `k * n` or `m * n` overflows `usize`, or if any slice
/// length does not match the corresponding product.
pub fn simd_matrix_multiply_f32(
    m: usize,
    k: usize,
    n: usize,
    alpha: f32,
    a: &[f32],
    b: &[f32],
    beta: f32,
    c: &mut [f32],
) {
    // Use checked arithmetic: a wrapped product in release builds could make the
    // length checks below pass for slices far smaller than the dimensions imply.
    let a_len = m
        .checked_mul(k)
        .expect("Matrix A dimensions m*k overflow usize");
    let b_len = k
        .checked_mul(n)
        .expect("Matrix B dimensions k*n overflow usize");
    let c_len = m
        .checked_mul(n)
        .expect("Matrix C dimensions m*n overflow usize");

    assert_eq!(
        a.len(),
        a_len,
        "Matrix A must have m*k = {}*{} = {} elements, got {}",
        m,
        k,
        a_len,
        a.len()
    );
    assert_eq!(
        b.len(),
        b_len,
        "Matrix B must have k*n = {}*{} = {} elements, got {}",
        k,
        n,
        b_len,
        b.len()
    );
    assert_eq!(
        c.len(),
        c_len,
        "Matrix C must have m*n = {}*{} = {} elements, got {}",
        m,
        n,
        c_len,
        c.len()
    );

    #[cfg(feature = "simd")]
    {
        let config = MatMulConfig::for_f32();
        // SAFETY: the checks above guarantee that `a`, `b` and `c` contain exactly
        // m*k, k*n and m*n elements (computed without overflow), which matches
        // the dimensions and leading dimensions (k, n, n) passed here.
        // NOTE(review): assumes `blocked_gemm_f32` only touches elements inside
        // those bounds; confirm against its safety contract.
        unsafe {
            blocked_gemm_f32(
                m,
                k,
                n,
                alpha,
                a.as_ptr(),
                k,
                b.as_ptr(),
                n,
                beta,
                c.as_mut_ptr(),
                n,
                &config,
            );
        }
    }
    #[cfg(not(feature = "simd"))]
    {
        gemm_simple_f32(m, k, n, alpha, a, k, b, n, beta, c, n);
    }
}
/// Computes `C = alpha * A * B + beta * C` for dense, row-major `f64` matrices.
///
/// `a` holds the `m x k` matrix A, `b` the `k x n` matrix B and `c` the
/// `m x n` matrix C, each stored contiguously (leading dimensions `k`, `n`
/// and `n` respectively). With the `simd` feature enabled the work is done by
/// `blocked_gemm_f64`; otherwise the scalar `gemm_simple_f64` fallback runs.
///
/// # Panics
///
/// Panics if `m * k`, `k * n` or `m * n` overflows `usize`, or if any slice
/// length does not match the corresponding product.
pub fn simd_matrix_multiply_f64(
    m: usize,
    k: usize,
    n: usize,
    alpha: f64,
    a: &[f64],
    b: &[f64],
    beta: f64,
    c: &mut [f64],
) {
    // Use checked arithmetic: a wrapped product in release builds could make the
    // length checks below pass for slices far smaller than the dimensions imply.
    let a_len = m
        .checked_mul(k)
        .expect("Matrix A dimensions m*k overflow usize");
    let b_len = k
        .checked_mul(n)
        .expect("Matrix B dimensions k*n overflow usize");
    let c_len = m
        .checked_mul(n)
        .expect("Matrix C dimensions m*n overflow usize");

    assert_eq!(
        a.len(),
        a_len,
        "Matrix A must have m*k = {}*{} = {} elements, got {}",
        m,
        k,
        a_len,
        a.len()
    );
    assert_eq!(
        b.len(),
        b_len,
        "Matrix B must have k*n = {}*{} = {} elements, got {}",
        k,
        n,
        b_len,
        b.len()
    );
    assert_eq!(
        c.len(),
        c_len,
        "Matrix C must have m*n = {}*{} = {} elements, got {}",
        m,
        n,
        c_len,
        c.len()
    );

    #[cfg(feature = "simd")]
    {
        let config = MatMulConfig::for_f64();
        // SAFETY: the checks above guarantee that `a`, `b` and `c` contain exactly
        // m*k, k*n and m*n elements (computed without overflow), which matches
        // the dimensions and leading dimensions (k, n, n) passed here.
        // NOTE(review): assumes `blocked_gemm_f64` only touches elements inside
        // those bounds; confirm against its safety contract.
        unsafe {
            blocked_gemm_f64(
                m,
                k,
                n,
                alpha,
                a.as_ptr(),
                k,
                b.as_ptr(),
                n,
                beta,
                c.as_mut_ptr(),
                n,
                &config,
            );
        }
    }
    #[cfg(not(feature = "simd"))]
    {
        gemm_simple_f64(m, k, n, alpha, a, k, b, n, beta, c, n);
    }
}
/// Scalar fallback for `C = alpha * A * B + beta * C` on row-major `f32` data.
///
/// A is `m x k`, B is `k x n` and C is `m x n`, all stored contiguously.
/// The leading-dimension arguments (`_lda`, `_ldb`, `_ldc`) are accepted for
/// signature parity but are not used: rows are always addressed with strides
/// `k`, `n` and `n`.
fn gemm_simple_f32(
    m: usize,
    k: usize,
    n: usize,
    alpha: f32,
    a: &[f32],
    _lda: usize,
    b: &[f32],
    _ldb: usize,
    beta: f32,
    c: &mut [f32],
    _ldc: usize,
) {
    // Apply beta first. For beta == 0 the output is overwritten rather than
    // multiplied, so pre-existing NaN/inf values in `c` do not survive.
    if beta == 0.0 {
        c.iter_mut().for_each(|out| *out = 0.0);
    } else if beta != 1.0 {
        c.iter_mut().for_each(|out| *out *= beta);
    }

    // i-p-j ordering: each scaled A element updates one full row of C using
    // the matching row of B.
    for row in 0..m {
        let c_start = row * n;
        let c_row = &mut c[c_start..c_start + n];
        for inner in 0..k {
            let scaled = alpha * a[row * k + inner];
            let b_start = inner * n;
            let b_row = &b[b_start..b_start + n];
            for (out, &rhs) in c_row.iter_mut().zip(b_row) {
                *out += scaled * rhs;
            }
        }
    }
}
/// Scalar fallback for `C = alpha * A * B + beta * C` on row-major `f64` data.
///
/// A is `m x k`, B is `k x n` and C is `m x n`, all stored contiguously.
/// The leading-dimension arguments (`_lda`, `_ldb`, `_ldc`) are accepted for
/// signature parity but are not used: rows are always addressed with strides
/// `k`, `n` and `n`.
fn gemm_simple_f64(
    m: usize,
    k: usize,
    n: usize,
    alpha: f64,
    a: &[f64],
    _lda: usize,
    b: &[f64],
    _ldb: usize,
    beta: f64,
    c: &mut [f64],
    _ldc: usize,
) {
    // Apply beta first. For beta == 0 the output is overwritten rather than
    // multiplied, so pre-existing NaN/inf values in `c` do not survive.
    if beta == 0.0 {
        c.iter_mut().for_each(|out| *out = 0.0);
    } else if beta != 1.0 {
        c.iter_mut().for_each(|out| *out *= beta);
    }

    // i-p-j ordering: each scaled A element updates one full row of C using
    // the matching row of B.
    for row in 0..m {
        let c_start = row * n;
        let c_row = &mut c[c_start..c_start + n];
        for inner in 0..k {
            let scaled = alpha * a[row * k + inner];
            let b_start = inner * n;
            let b_row = &b[b_start..b_start + n];
            for (out, &rhs) in c_row.iter_mut().zip(b_row) {
                *out += scaled * rhs;
            }
        }
    }
}
/// Reports whether a matrix product with dimensions `m`, `n`, `k` should take
/// the SIMD/blocked path.
///
/// With the `simd` feature enabled the decision is delegated to
/// `should_use_blocked`. Without it, the answer is `true` only when every
/// dimension is at least 128.
pub fn should_use_simd_matmul(m: usize, n: usize, k: usize) -> bool {
    #[cfg(feature = "simd")]
    {
        should_use_blocked(m, n, k)
    }
    #[cfg(not(feature = "simd"))]
    {
        // Smallest size of each dimension for which the fallback says "yes".
        const MIN_DIM: usize = 128;
        [m, n, k].iter().all(|&dim| dim >= MIN_DIM)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Small f32 dot product checked against a hand-computed value.
    #[test]
    fn test_dot_product_f32() {
        let lhs = [1.0f32, 2.0, 3.0, 4.0];
        let rhs = [5.0f32, 6.0, 7.0, 8.0];
        let dot = simd_dot_product_f32(&lhs, &rhs);
        assert!((dot - 70.0).abs() < 1e-5);
    }

    /// Small f64 dot product checked against a hand-computed value.
    #[test]
    fn test_dot_product_f64() {
        let lhs = [1.0f64, 2.0, 3.0];
        let rhs = [4.0f64, 5.0, 6.0];
        let dot = simd_dot_product_f64(&lhs, &rhs);
        assert!((dot - 32.0).abs() < 1e-10);
    }

    /// Mismatched input lengths must trigger the length assertion.
    #[test]
    #[should_panic(expected = "equal-length")]
    fn test_dot_product_length_mismatch() {
        let longer = [1.0f32, 2.0, 3.0];
        let shorter = [4.0f32, 5.0];
        let _ = simd_dot_product_f32(&longer, &shorter);
    }

    /// Multiplying by the identity must reproduce the right-hand matrix.
    #[test]
    fn test_matrix_multiply_identity() {
        const DIM: usize = 4;
        let identity: Vec<f32> = (0..DIM * DIM)
            .map(|idx| if idx / DIM == idx % DIM { 1.0 } else { 0.0 })
            .collect();
        let rhs: Vec<f32> = (0..DIM * DIM).map(|idx| idx as f32).collect();
        let mut product = vec![0.0f32; DIM * DIM];

        simd_matrix_multiply_f32(DIM, DIM, DIM, 1.0, &identity, &rhs, 0.0, &mut product);

        for (idx, (got, want)) in product.iter().zip(&rhs).enumerate() {
            assert!(
                (got - want).abs() < 1e-5,
                "Mismatch at {}: expected {}, got {}",
                idx,
                want,
                got
            );
        }
    }

    /// Checks `alpha * A * B + beta * C` with non-trivial alpha and beta.
    #[test]
    fn test_matrix_multiply_alpha_beta() {
        let (m, k, n) = (2, 2, 2);
        let a = [1.0f32, 2.0, 3.0, 4.0];
        let b = [5.0f32, 6.0, 7.0, 8.0];
        let mut c = [1.0f32; 4];

        simd_matrix_multiply_f32(m, k, n, 3.0, &a, &b, 2.0, &mut c);

        let expected = [59.0f32, 68.0, 131.0, 152.0];
        for (idx, (got, want)) in c.iter().zip(&expected).enumerate() {
            assert!(
                (got - want).abs() < 1e-4,
                "Mismatch at {}: expected {}, got {}",
                idx,
                want,
                got
            );
        }
    }

    /// 128x128 product of constant matrices; every entry equals `2 * 128`.
    #[test]
    fn test_matrix_multiply_large() {
        const DIM: usize = 128;
        let ones = vec![1.0f32; DIM * DIM];
        let twos = vec![2.0f32; DIM * DIM];
        let mut product = vec![0.0f32; DIM * DIM];

        simd_matrix_multiply_f32(DIM, DIM, DIM, 1.0, &ones, &twos, 0.0, &mut product);

        let expected = 2.0 * DIM as f32;
        product.iter().for_each(|val| {
            assert!(
                (*val - expected).abs() < 1e-2,
                "Expected {}, got {}",
                expected,
                val
            )
        });
    }

    /// Non-square shapes: (3x2) * (2x4), spot-checking the first row.
    #[test]
    fn test_rectangular_multiply() {
        let (m, k, n) = (3, 2, 4);
        let a = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut c = vec![0.0f32; m * n];

        simd_matrix_multiply_f32(m, k, n, 1.0, &a, &b, 0.0, &mut c);

        let first = c[0];
        assert!(
            (first - 11.0).abs() < 1e-5,
            "C[0,0] expected 11.0, got {}",
            first
        );
        let second = c[1];
        assert!(
            (second - 14.0).abs() < 1e-5,
            "C[0,1] expected 14.0, got {}",
            second
        );
    }

    /// Heuristic expectations when the `simd` feature is active.
    #[test]
    #[cfg(feature = "simd")]
    fn test_should_use_simd_matmul() {
        let small = should_use_simd_matmul(32, 32, 32);
        let medium = should_use_simd_matmul(64, 64, 64);
        let large = should_use_simd_matmul(256, 256, 256);
        assert!(!small);
        assert!(medium);
        assert!(large);
    }

    /// f64 variant of the identity check.
    #[test]
    fn test_matrix_multiply_identity_f64() {
        const DIM: usize = 4;
        let identity: Vec<f64> = (0..DIM * DIM)
            .map(|idx| if idx / DIM == idx % DIM { 1.0 } else { 0.0 })
            .collect();
        let rhs: Vec<f64> = (0..DIM * DIM).map(|idx| idx as f64).collect();
        let mut product = vec![0.0f64; DIM * DIM];

        simd_matrix_multiply_f64(DIM, DIM, DIM, 1.0, &identity, &rhs, 0.0, &mut product);

        for (idx, (got, want)) in product.iter().zip(&rhs).enumerate() {
            assert!(
                (got - want).abs() < 1e-10,
                "Mismatch at {}: expected {}, got {}",
                idx,
                want,
                got
            );
        }
    }

    /// f64 variant of the alpha/beta check.
    #[test]
    fn test_matrix_multiply_alpha_beta_f64() {
        let (m, k, n) = (2, 2, 2);
        let a = [1.0f64, 2.0, 3.0, 4.0];
        let b = [5.0f64, 6.0, 7.0, 8.0];
        let mut c = [1.0f64; 4];

        simd_matrix_multiply_f64(m, k, n, 3.0, &a, &b, 2.0, &mut c);

        let expected = [59.0f64, 68.0, 131.0, 152.0];
        for (idx, (got, want)) in c.iter().zip(&expected).enumerate() {
            assert!(
                (got - want).abs() < 1e-10,
                "Mismatch at {}: expected {}, got {}",
                idx,
                want,
                got
            );
        }
    }

    /// f64 variant of the 128x128 constant-matrix check.
    #[test]
    fn test_matrix_multiply_large_f64() {
        const DIM: usize = 128;
        let ones = vec![1.0f64; DIM * DIM];
        let twos = vec![2.0f64; DIM * DIM];
        let mut product = vec![0.0f64; DIM * DIM];

        simd_matrix_multiply_f64(DIM, DIM, DIM, 1.0, &ones, &twos, 0.0, &mut product);

        let expected = 2.0 * DIM as f64;
        product.iter().for_each(|val| {
            assert!(
                (*val - expected).abs() < 1e-8,
                "Expected {}, got {}",
                expected,
                val
            )
        });
    }

    /// f64 variant of the rectangular spot check.
    #[test]
    fn test_rectangular_multiply_f64() {
        let (m, k, n) = (3, 2, 4);
        let a = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut c = vec![0.0f64; m * n];

        simd_matrix_multiply_f64(m, k, n, 1.0, &a, &b, 0.0, &mut c);

        let first = c[0];
        assert!(
            (first - 11.0).abs() < 1e-10,
            "C[0,0] expected 11.0, got {}",
            first
        );
        let second = c[1];
        assert!(
            (second - 14.0).abs() < 1e-10,
            "C[0,1] expected 14.0, got {}",
            second
        );
    }

    /// Dot product of 1..=1024 with itself, compared against the closed-form
    /// sum of squares using a relative tolerance.
    #[test]
    fn test_dot_product_large() {
        const COUNT: u32 = 1024;
        let values: Vec<f32> = (1..=COUNT).map(|v| v as f32).collect();
        let same_values = values.clone();

        let result = simd_dot_product_f32(&values, &same_values);

        let n64 = f64::from(COUNT);
        let expected = (n64 * (n64 + 1.0) * (2.0 * n64 + 1.0) / 6.0) as f32;
        let relative_error = (result - expected).abs() / expected;
        assert!(
            relative_error < 1e-5,
            "Large dot product: expected {}, got {}",
            expected,
            result
        );
    }

    /// f64 variant of the large dot-product check.
    #[test]
    fn test_dot_product_large_f64() {
        const COUNT: u32 = 1024;
        let values: Vec<f64> = (1..=COUNT).map(f64::from).collect();
        let same_values = values.clone();

        let result = simd_dot_product_f64(&values, &same_values);

        let n64 = f64::from(COUNT);
        let expected = n64 * (n64 + 1.0) * (2.0 * n64 + 1.0) / 6.0;
        let relative_error = (result - expected).abs() / expected;
        assert!(
            relative_error < 1e-10,
            "Large dot product: expected {}, got {}",
            expected,
            result
        );
    }
}