use crate::TruenoError;
#[cfg(feature = "tracing")]
use tracing::instrument;
use super::super::Matrix;
impl Matrix<f32> {
    /// Matrix multiplication `self × other` with automatic kernel dispatch.
    ///
    /// Dispatch order:
    /// 1. Row-vector fast path (`self.rows == 1`) via [`Self::matmul_vector_matrix`].
    /// 2. GPU kernel (feature `gpu`, non-wasm targets) when all three dimensions
    ///    are large enough; a failed GPU attempt silently falls through to CPU.
    /// 3. Parallel BLIS GEMM (native) or the tiled kernel (wasm32) for
    ///    medium/large inputs.
    /// 4. Naive triple loop for small inputs.
    ///
    /// # Errors
    /// Returns `TruenoError::InvalidInput` when the inner dimensions do not
    /// match (`self.cols != other.rows`), or propagates errors from the
    /// selected CPU kernel.
    #[cfg_attr(feature = "tracing", instrument(skip(self, other), fields(dims = %format!("{}x{} @ {}x{}", self.rows, self.cols, other.rows, other.cols))))]
    pub fn matmul(&self, other: &Matrix<f32>) -> Result<Matrix<f32>, TruenoError> {
        if self.cols != other.rows {
            return Err(TruenoError::InvalidInput(format!(
                "Matrix dimension mismatch for multiplication: {}×{} × {}×{} (inner dimensions {} and {} must match)",
                self.rows, self.cols, other.rows, other.cols, self.cols, other.rows
            )));
        }
        // A 1×k left operand is a vector-matrix product: use GEMV directly.
        if self.rows == 1 {
            return self.matmul_vector_matrix(other);
        }
        // Block-scoped `const` items are visible throughout the enclosing
        // block, so declaring them here keeps both thresholds side by side.
        #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
        const GPU_THRESHOLD: usize = 500;
        const SIMD_THRESHOLD: usize = 64;
        #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
        {
            // Only dispatch to the GPU when every dimension is large enough to
            // amortize transfer overhead; a failed attempt is not an error —
            // the CPU paths below serve as the fallback.
            if self.rows >= GPU_THRESHOLD
                && self.cols >= GPU_THRESHOLD
                && other.cols >= GPU_THRESHOLD
            {
                if let Ok(gpu_result) = self.matmul_gpu(other) {
                    return Ok(gpu_result);
                }
            }
        }
        // Allocate the output only after the GPU path has been ruled out, so a
        // successful GPU dispatch never pays for a zeroed buffer it discards.
        let mut result = Matrix::zeros_with_backend(self.rows, other.cols, self.backend);
        if self.rows >= SIMD_THRESHOLD
            || self.cols >= SIMD_THRESHOLD
            || other.cols >= SIMD_THRESHOLD
        {
            #[cfg(target_arch = "wasm32")]
            {
                self.matmul_wasm_tiled(other, &mut result)?;
            }
            #[cfg(not(target_arch = "wasm32"))]
            {
                crate::blis::parallel::gemm_blis_parallel(
                    self.rows,
                    other.cols,
                    self.cols,
                    &self.data,
                    &other.data,
                    &mut result.data,
                )?;
            }
        } else {
            self.matmul_naive(other, &mut result)?;
        }
        Ok(result)
    }

    /// Batched GEMM: `batch` independent `m×k @ k×n` products stored
    /// contiguously in row-major order.
    ///
    /// `a_data` holds `batch` matrices of `m*k` elements each and `b_data`
    /// holds `batch` matrices of `k*n` elements; the result holds
    /// `batch * m * n` elements in the same contiguous layout.
    ///
    /// # Errors
    /// Returns `TruenoError::InvalidInput` when either input slice length does
    /// not match the declared batch dimensions, or propagates kernel errors.
    #[cfg_attr(feature = "tracing", instrument(skip(a_data, b_data), fields(batch, m, k, n)))]
    pub fn batched_matmul(
        a_data: &[f32],
        b_data: &[f32],
        batch: usize,
        m: usize,
        k: usize,
        n: usize,
    ) -> Result<Vec<f32>, TruenoError> {
        // Per-batch element strides for A, B, and the output.
        let a_stride = m * k;
        let b_stride = k * n;
        let out_stride = m * n;
        if a_data.len() != batch * a_stride {
            return Err(TruenoError::InvalidInput(format!(
                "A data size mismatch: expected {} ({}×{}×{}), got {}",
                batch * a_stride,
                batch,
                m,
                k,
                a_data.len()
            )));
        }
        if b_data.len() != batch * b_stride {
            return Err(TruenoError::InvalidInput(format!(
                "B data size mismatch: expected {} ({}×{}×{}), got {}",
                batch * b_stride,
                batch,
                k,
                n,
                b_data.len()
            )));
        }
        let mut output = vec![0.0f32; batch * out_stride];
        for ba in 0..batch {
            let a_offset = ba * a_stride;
            let b_offset = ba * b_stride;
            let out_offset = ba * out_stride;
            let a_slice = &a_data[a_offset..a_offset + a_stride];
            let b_slice = &b_data[b_offset..b_offset + b_stride];
            let c_slice = &mut output[out_offset..out_offset + out_stride];
            #[cfg(not(target_arch = "wasm32"))]
            {
                crate::blis::gemm_blis(m, n, k, a_slice, b_slice, c_slice, None)?;
            }
            #[cfg(target_arch = "wasm32")]
            {
                // No BLIS kernel on wasm: route each batch through `matmul`.
                let a_mat = Matrix::from_slice(m, k, a_slice)?;
                let b_mat = Matrix::from_slice(k, n, b_slice)?;
                let result = a_mat.matmul(&b_mat)?;
                c_slice.copy_from_slice(result.as_slice());
            }
        }
        Ok(output)
    }

    /// 4-D batched GEMM over `[batch, heads, m, k] @ [batch, heads, k, n]`
    /// laid out contiguously in row-major order (attention-style shapes).
    ///
    /// The `batch` and `heads` axes are fused into a single loop of
    /// `batch * heads` independent `m×k @ k×n` products.
    ///
    /// # Errors
    /// Returns `TruenoError::InvalidInput` when either input slice length does
    /// not match the declared dimensions, or propagates kernel errors.
    #[cfg_attr(
        feature = "tracing",
        instrument(skip(a_data, b_data), fields(batch, heads, m, k, n))
    )]
    pub fn batched_matmul_4d(
        a_data: &[f32],
        b_data: &[f32],
        batch: usize,
        heads: usize,
        m: usize,
        k: usize,
        n: usize,
    ) -> Result<Vec<f32>, TruenoError> {
        // Per-head element strides; batch and head axes are fused below.
        let a_head_stride = m * k;
        let b_head_stride = k * n;
        let out_head_stride = m * n;
        let total_heads = batch * heads;
        let expected_a = total_heads * a_head_stride;
        let expected_b = total_heads * b_head_stride;
        if a_data.len() != expected_a {
            return Err(TruenoError::InvalidInput(format!(
                "A data size mismatch: expected {} ({}×{}×{}×{}), got {}",
                expected_a,
                batch,
                heads,
                m,
                k,
                a_data.len()
            )));
        }
        if b_data.len() != expected_b {
            return Err(TruenoError::InvalidInput(format!(
                "B data size mismatch: expected {} ({}×{}×{}×{}), got {}",
                expected_b,
                batch,
                heads,
                k,
                n,
                b_data.len()
            )));
        }
        let mut output = vec![0.0f32; total_heads * out_head_stride];
        for bh in 0..total_heads {
            let a_offset = bh * a_head_stride;
            let b_offset = bh * b_head_stride;
            let out_offset = bh * out_head_stride;
            let a_slice = &a_data[a_offset..a_offset + a_head_stride];
            let b_slice = &b_data[b_offset..b_offset + b_head_stride];
            let c_slice = &mut output[out_offset..out_offset + out_head_stride];
            #[cfg(not(target_arch = "wasm32"))]
            {
                crate::blis::gemm_blis(m, n, k, a_slice, b_slice, c_slice, None)?;
            }
            #[cfg(target_arch = "wasm32")]
            {
                // No BLIS kernel on wasm: route each head through `matmul`.
                let a_mat = Matrix::from_slice(m, k, a_slice)?;
                let b_mat = Matrix::from_slice(k, n, b_slice)?;
                let result = a_mat.matmul(&b_mat)?;
                c_slice.copy_from_slice(result.as_slice());
            }
        }
        Ok(output)
    }

    /// Fast path for a `1×k` row vector times a `k×n` matrix via the GEMV
    /// kernel, avoiding the full GEMM machinery.
    ///
    /// # Errors
    /// Propagates construction errors from [`Matrix::from_vec`].
    #[cfg_attr(feature = "tracing", instrument(skip(self, other), fields(k = self.cols, n = other.cols)))]
    fn matmul_vector_matrix(&self, other: &Matrix<f32>) -> Result<Matrix<f32>, TruenoError> {
        // Caller (`matmul`) guarantees this; assert in debug builds only.
        debug_assert_eq!(self.rows, 1);
        let k = self.cols;
        let n = other.cols;
        let mut c = vec![0.0f32; n];
        crate::blis::gemv::gemv(k, n, &self.data, &other.data, &mut c);
        Ok(Matrix::from_vec(1, n, c)?)
    }

    /// Scalar triple-loop GEMM for small matrices where SIMD/parallel setup
    /// overhead would dominate. Caller guarantees `result` is `m×n` and the
    /// dimensions already match.
    fn matmul_naive(
        &self,
        other: &Matrix<f32>,
        result: &mut Matrix<f32>,
    ) -> Result<(), TruenoError> {
        let m = self.rows;
        let k = self.cols;
        let n = other.cols;
        let a = &self.data;
        let b = &other.data;
        let c = &mut result.data;
        for i in 0..m {
            // Row-major offsets of row `i` in A and C.
            let a_row = i * k;
            let c_row = i * n;
            for j in 0..n {
                let mut sum = 0.0f32;
                for kk in 0..k {
                    sum += a[a_row + kk] * b[kk * n + j];
                }
                c[c_row + j] = sum;
            }
        }
        Ok(())
    }

    /// Tiled GEMM for wasm32: processes the output row in blocks of 8 columns
    /// with a fixed-size accumulator array (auto-vectorizable), then a scalar
    /// remainder loop for the trailing `n % 8` columns.
    ///
    /// Plain portable Rust, so native unit tests can exercise it as well.
    #[allow(dead_code)]
    fn matmul_wasm_tiled(
        &self,
        other: &Matrix<f32>,
        result: &mut Matrix<f32>,
    ) -> Result<(), TruenoError> {
        let m = self.rows;
        let k = self.cols;
        let n = other.cols;
        for i in 0..m {
            let a_row_start = i * k;
            let result_row_start = i * n;
            let simd_width = 8;
            // Largest multiple of the block width that fits in `n`.
            let n_simd = (n / simd_width) * simd_width;
            #[allow(clippy::needless_range_loop)]
            for j0 in (0..n_simd).step_by(simd_width) {
                let mut acc = [0.0f32; 8];
                for kk in 0..k {
                    // Broadcast A[i,kk] across the 8-wide accumulator.
                    let a_val = self.data[a_row_start + kk];
                    let b_row_start = kk * n + j0;
                    for jj in 0..simd_width {
                        acc[jj] += a_val * other.data[b_row_start + jj];
                    }
                }
                for jj in 0..simd_width {
                    result.data[result_row_start + j0 + jj] = acc[jj];
                }
            }
            // Scalar tail for columns not covered by a full 8-wide block.
            for j in n_simd..n {
                let mut sum = 0.0f32;
                for kk in 0..k {
                    sum += self.data[a_row_start + kk] * other.data[kk * n + j];
                }
                result.data[result_row_start + j] = sum;
            }
        }
        Ok(())
    }

    /// GEMM on the GPU backend. Returns an error when no GPU is available or
    /// the kernel fails, letting `matmul` fall back to the CPU paths.
    #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
    fn matmul_gpu(&self, other: &Matrix<f32>) -> Result<Matrix<f32>, TruenoError> {
        use crate::backends::gpu::GpuBackend;
        if !GpuBackend::is_available() {
            return Err(TruenoError::InvalidInput("GPU not available".to_string()));
        }
        let mut gpu = GpuBackend::new();
        let result_data = gpu
            .matmul(&self.data, &other.data, self.rows, self.cols, other.cols)
            .map_err(|e| TruenoError::InvalidInput(format!("GPU matmul failed: {}", e)))?;
        // Propagate the source backend tag, consistent with the CPU paths
        // (plain `Matrix::zeros` would drop `self.backend`).
        let mut result = Matrix::zeros_with_backend(self.rows, other.cols, self.backend);
        result.data = result_data;
        Ok(result)
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the CPU matmul paths: the `matmul` dispatcher, the
    // batched 3-D/4-D variants, error reporting on size mismatches, and the
    // tiled kernel (portable Rust, so it runs on native test targets too).
    use super::*;

    // Hand-computed 2×2 product: [[1,2],[3,4]] @ [[5,6],[7,8]].
    #[test]
    fn test_matmul_basic() {
        let a = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let b = Matrix::from_vec(2, 2, vec![5.0, 6.0, 7.0, 8.0]).unwrap();
        let c = a.matmul(&b).unwrap();
        assert_eq!(c.get(0, 0), Some(&19.0));
        assert_eq!(c.get(0, 1), Some(&22.0));
        assert_eq!(c.get(1, 0), Some(&43.0));
        assert_eq!(c.get(1, 1), Some(&50.0));
    }

    // Inner dimensions 3 and 2 differ: must return Err, not panic.
    #[test]
    fn test_matmul_dimension_mismatch() {
        let a = Matrix::from_vec(2, 3, vec![1.0; 6]).unwrap();
        let b = Matrix::from_vec(2, 2, vec![1.0; 4]).unwrap();
        assert!(a.matmul(&b).is_err());
    }

    // A * I == A, compared element-for-element via the raw slice.
    #[test]
    fn test_matmul_identity() {
        let a = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let i = Matrix::identity(2);
        let result = a.matmul(&i).unwrap();
        assert_eq!(result.as_slice(), a.as_slice());
    }

    // Two batches, each multiplied by a 2×2 identity: output equals input A.
    #[test]
    fn test_batched_matmul() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0];
        let result = Matrix::batched_matmul(&a, &b, 2, 2, 2, 2).unwrap();
        assert_eq!(result, a);
    }

    // A slice too short for the declared batch dims → InvalidInput.
    #[test]
    fn test_batched_matmul_a_size_mismatch() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0];
        let result = Matrix::batched_matmul(&a, &b, 2, 2, 2, 2);
        assert!(matches!(result, Err(TruenoError::InvalidInput(_))));
    }

    // B slice too short for the declared batch dims → InvalidInput.
    #[test]
    fn test_batched_matmul_b_size_mismatch() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![1.0, 0.0];
        let result = Matrix::batched_matmul(&a, &b, 2, 2, 2, 2);
        assert!(matches!(result, Err(TruenoError::InvalidInput(_))));
    }

    // Single batch, non-square 3×2 @ 2×4: checks output length only (shape).
    #[test]
    fn test_batched_matmul_single_batch() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = vec![1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
        let result = Matrix::batched_matmul(&a, &b, 1, 3, 2, 4).unwrap();
        assert_eq!(result.len(), 12);
    }

    // Minimal 4-D case (batch=1, heads=1) against a 2×2 identity.
    #[test]
    fn test_batched_matmul_4d_basic() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![1.0, 0.0, 0.0, 1.0];
        let result = Matrix::batched_matmul_4d(&a, &b, 1, 1, 2, 2, 2).unwrap();
        assert_eq!(result, a);
    }

    // 4-D variant: undersized A slice → InvalidInput.
    #[test]
    fn test_batched_matmul_4d_a_size_mismatch() {
        let a = vec![1.0];
        let b: Vec<f32> = (0..80).map(|x| x as f32 * 0.1).collect();
        let result = Matrix::batched_matmul_4d(&a, &b, 2, 2, 3, 4, 5);
        assert!(matches!(result, Err(TruenoError::InvalidInput(_))));
    }

    // 4-D variant: undersized B slice → InvalidInput.
    #[test]
    fn test_batched_matmul_4d_b_size_mismatch() {
        let a: Vec<f32> = (0..48).map(|x| x as f32 * 0.1).collect();
        let b = vec![1.0];
        let result = Matrix::batched_matmul_4d(&a, &b, 2, 2, 3, 4, 5);
        assert!(matches!(result, Err(TruenoError::InvalidInput(_))));
    }

    // 4 heads of all-ones 2×2 matrices: every output element is k = 2.
    #[test]
    fn test_batched_matmul_4d_multi_head() {
        let total = 4 * 2 * 2;
        let a: Vec<f32> = (0..total).map(|_| 1.0).collect();
        let b: Vec<f32> = (0..total).map(|_| 1.0).collect();
        let result = Matrix::batched_matmul_4d(&a, &b, 1, 4, 2, 2, 2).unwrap();
        assert_eq!(result.len(), total);
        for val in &result {
            assert!((*val - 2.0).abs() < 1e-5);
        }
    }

    // rows == 1 triggers the GEMV fast path inside `matmul`.
    #[test]
    fn test_matmul_vector_matrix_path() {
        let a = Matrix::from_vec(1, 4, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let b = Matrix::from_vec(
            4,
            3,
            vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0],
        )
        .unwrap();
        let result = a.matmul(&b).unwrap();
        assert_eq!(result.rows(), 1);
        assert_eq!(result.cols(), 3);
        assert!((result.get(0, 0).unwrap() - 5.0).abs() < 1e-5);
        assert!((result.get(0, 1).unwrap() - 6.0).abs() < 1e-5);
        assert!((result.get(0, 2).unwrap() - 7.0).abs() < 1e-5);
    }

    // GEMV path with zero entries: the large B values must be annihilated.
    #[test]
    fn test_matmul_vector_matrix_with_zeros() {
        let a = Matrix::from_vec(1, 3, vec![0.0, 2.0, 0.0]).unwrap();
        let b = Matrix::from_vec(3, 2, vec![100.0, 200.0, 3.0, 4.0, 500.0, 600.0]).unwrap();
        let result = a.matmul(&b).unwrap();
        assert!((result.get(0, 0).unwrap() - 6.0).abs() < 1e-5);
        assert!((result.get(0, 1).unwrap() - 8.0).abs() < 1e-5);
    }

    // n = 3 < 8: the tiled kernel takes only its scalar remainder loop.
    #[test]
    fn test_matmul_wasm_tiled_small_no_simd() {
        let a = Matrix::from_vec(2, 4, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let b = Matrix::from_vec(
            4,
            3,
            vec![1.0, 0.0, 2.0, 0.0, 1.0, 0.0, 2.0, 0.0, 1.0, 0.0, 2.0, 0.0],
        )
        .unwrap();
        let mut result = Matrix::zeros(2, 3);
        a.matmul_wasm_tiled(&b, &mut result).unwrap();
        assert!((result.get(0, 0).unwrap() - 7.0).abs() < 1e-5);
        assert!((result.get(0, 1).unwrap() - 10.0).abs() < 1e-5);
        assert!((result.get(0, 2).unwrap() - 5.0).abs() < 1e-5);
        assert!((result.get(1, 0).unwrap() - 19.0).abs() < 1e-5);
        assert!((result.get(1, 1).unwrap() - 22.0).abs() < 1e-5);
        assert!((result.get(1, 2).unwrap() - 17.0).abs() < 1e-5);
    }

    // n = 8 exactly matches the block width: no remainder loop iterations.
    // Cross-checked against the naive reference implementation.
    #[test]
    fn test_matmul_wasm_tiled_exact_simd_width() {
        let a = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let b_data: Vec<f32> = (1..=24).map(|x| x as f32).collect();
        let b = Matrix::from_vec(3, 8, b_data).unwrap();
        let mut result = Matrix::zeros(2, 8);
        a.matmul_wasm_tiled(&b, &mut result).unwrap();
        let mut expected = Matrix::zeros(2, 8);
        a.matmul_naive(&b, &mut expected).unwrap();
        for i in 0..2 {
            for j in 0..8 {
                assert!(
                    (result.get(i, j).unwrap() - expected.get(i, j).unwrap()).abs() < 1e-4,
                    "Mismatch at ({}, {}): wasm_tiled={}, naive={}",
                    i,
                    j,
                    result.get(i, j).unwrap(),
                    expected.get(i, j).unwrap()
                );
            }
        }
    }

    // n = 11 exercises one full 8-wide block plus a 3-column remainder.
    #[test]
    fn test_matmul_wasm_tiled_simd_plus_remainder() {
        let a_data: Vec<f32> = (1..=12).map(|x| x as f32).collect();
        let a = Matrix::from_vec(3, 4, a_data).unwrap();
        let b_data: Vec<f32> = (1..=44).map(|x| x as f32 * 0.1).collect();
        let b = Matrix::from_vec(4, 11, b_data).unwrap();
        let mut result = Matrix::zeros(3, 11);
        a.matmul_wasm_tiled(&b, &mut result).unwrap();
        let mut expected = Matrix::zeros(3, 11);
        a.matmul_naive(&b, &mut expected).unwrap();
        for i in 0..3 {
            for j in 0..11 {
                assert!(
                    (result.get(i, j).unwrap() - expected.get(i, j).unwrap()).abs() < 1e-3,
                    "Mismatch at ({}, {}): wasm_tiled={}, naive={}",
                    i,
                    j,
                    result.get(i, j).unwrap(),
                    expected.get(i, j).unwrap()
                );
            }
        }
    }

    // n = 16 exercises two consecutive 8-wide blocks with no remainder.
    #[test]
    fn test_matmul_wasm_tiled_multiple_simd_blocks() {
        let a = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let b_data: Vec<f32> = (1..=32).map(|x| x as f32).collect();
        let b = Matrix::from_vec(2, 16, b_data).unwrap();
        let mut result = Matrix::zeros(2, 16);
        a.matmul_wasm_tiled(&b, &mut result).unwrap();
        let mut expected = Matrix::zeros(2, 16);
        a.matmul_naive(&b, &mut expected).unwrap();
        for i in 0..2 {
            for j in 0..16 {
                assert!(
                    (result.get(i, j).unwrap() - expected.get(i, j).unwrap()).abs() < 1e-3,
                    "Mismatch at ({}, {})",
                    i,
                    j,
                );
            }
        }
    }

    // Single-row A through the tiled kernel directly (not the GEMV dispatch).
    #[test]
    fn test_matmul_wasm_tiled_single_row() {
        let a = Matrix::from_vec(1, 5, vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let b_data: Vec<f32> = (1..=50).map(|x| x as f32 * 0.1).collect();
        let b = Matrix::from_vec(5, 10, b_data).unwrap();
        let mut result = Matrix::zeros(1, 10);
        a.matmul_wasm_tiled(&b, &mut result).unwrap();
        let mut expected = Matrix::zeros(1, 10);
        a.matmul_naive(&b, &mut expected).unwrap();
        for j in 0..10 {
            assert!(
                (result.get(0, j).unwrap() - expected.get(0, j).unwrap()).abs() < 1e-3,
                "Mismatch at col {}: wasm_tiled={}, naive={}",
                j,
                result.get(0, j).unwrap(),
                expected.get(0, j).unwrap()
            );
        }
    }

    // A * I == A through the tiled kernel (n = 4 → remainder-only path).
    #[test]
    fn test_matmul_wasm_tiled_identity() {
        let a = Matrix::from_vec(
            4,
            4,
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
                16.0,
            ],
        )
        .unwrap();
        let identity = Matrix::identity(4);
        let mut result = Matrix::zeros(4, 4);
        a.matmul_wasm_tiled(&identity, &mut result).unwrap();
        assert_eq!(result.as_slice(), a.as_slice());
    }

    // Larger mixed shape (5×10 @ 10×19): two blocks plus a 3-wide remainder,
    // cross-checked against naive with a looser tolerance for accumulation.
    #[test]
    fn test_matmul_wasm_tiled_large_mixed() {
        let a_data: Vec<f32> = (0..50).map(|x| (x as f32) * 0.1).collect();
        let a = Matrix::from_vec(5, 10, a_data).unwrap();
        let b_data: Vec<f32> = (0..190).map(|x| (x as f32) * 0.01).collect();
        let b = Matrix::from_vec(10, 19, b_data).unwrap();
        let mut result = Matrix::zeros(5, 19);
        a.matmul_wasm_tiled(&b, &mut result).unwrap();
        let mut expected = Matrix::zeros(5, 19);
        a.matmul_naive(&b, &mut expected).unwrap();
        for i in 0..5 {
            for j in 0..19 {
                assert!(
                    (result.get(i, j).unwrap() - expected.get(i, j).unwrap()).abs() < 1e-2,
                    "Mismatch at ({}, {}): wasm_tiled={}, naive={}",
                    i,
                    j,
                    result.get(i, j).unwrap(),
                    expected.get(i, j).unwrap()
                );
            }
        }
    }

    // Falsification test MM-001: output shape is [m, n] across a spread of
    // shapes, including degenerate (1×1) and threshold-crossing (64) dims.
    #[test]
    fn falsify_mm_001_shape_correctness() {
        for &(m, p, n) in &[(1, 1, 1), (2, 3, 4), (16, 32, 8), (1, 100, 1), (64, 1, 64)] {
            let a = Matrix::from_vec(m, p, vec![1.0; m * p]).unwrap();
            let b = Matrix::from_vec(p, n, vec![1.0; p * n]).unwrap();
            let c = a.matmul(&b).unwrap();
            assert_eq!(
                (c.rows(), c.cols()),
                (m, n),
                "FALSIFIED MM-001: matmul([{m},{p}], [{p},{n}]) shape = [{},{}], expected [{m},{n}]",
                c.rows(),
                c.cols()
            );
        }
    }

    // Falsification test MM-005: identity is neutral on both sides (A*I, I*A).
    #[test]
    fn falsify_mm_005_identity_matrix() {
        let a = Matrix::from_vec(3, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]).unwrap();
        let eye =
            Matrix::from_vec(3, 3, vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]).unwrap();
        let ai = a.matmul(&eye).unwrap();
        let ia = eye.matmul(&a).unwrap();
        for i in 0..3 {
            for j in 0..3 {
                let expected = a.get(i, j).unwrap();
                assert!(
                    (*ai.get(i, j).unwrap() - expected).abs() < 1e-6,
                    "FALSIFIED MM-005: (A*I)[{i},{j}] = {}, expected {expected}",
                    ai.get(i, j).unwrap()
                );
                assert!(
                    (*ia.get(i, j).unwrap() - expected).abs() < 1e-6,
                    "FALSIFIED MM-005: (I*A)[{i},{j}] = {}, expected {expected}",
                    ia.get(i, j).unwrap()
                );
            }
        }
    }

    // Falsification test MM-002: exact values of the textbook 2×2 product.
    #[test]
    fn falsify_mm_002_numerical_accuracy() {
        let a = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let b = Matrix::from_vec(2, 2, vec![5.0, 6.0, 7.0, 8.0]).unwrap();
        let c = a.matmul(&b).unwrap();
        let expected = [19.0, 22.0, 43.0, 50.0];
        for (i, &exp) in expected.iter().enumerate() {
            // Flat index → (row, col) in the 2×2 output.
            let row = i / 2;
            let col = i % 2;
            let val = *c.get(row, col).unwrap();
            assert!(
                (val - exp).abs() < 1e-5,
                "FALSIFIED MM-002: C[{row},{col}] = {val}, expected {exp}"
            );
        }
    }

    // Falsification test MM-002b: a zero matrix annihilates any right operand.
    #[test]
    fn falsify_mm_002b_zero_annihilation() {
        let zero = Matrix::from_vec(3, 4, vec![0.0; 12]).unwrap();
        let b = Matrix::from_vec(4, 2, vec![1.0; 8]).unwrap();
        let c = zero.matmul(&b).unwrap();
        for i in 0..3 {
            for j in 0..2 {
                let val = *c.get(i, j).unwrap();
                assert!(
                    val.abs() < 1e-10,
                    "FALSIFIED MM-002b: zeros*B [{i},{j}] = {val}, expected 0"
                );
            }
        }
    }
}
#[cfg(test)]
#[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
mod gpu_tests {
    // GPU-path tests, compiled only with the `gpu` feature on non-wasm
    // targets. Each test early-returns (skips) when no GPU is present so the
    // suite stays green on CPU-only machines.
    use super::*;

    // 500×500 (at the GPU dispatch threshold): A * I should round-trip A.
    // Only spot-checks corner/center elements to keep the test fast.
    #[test]
    fn test_matmul_gpu_identity() {
        use crate::backends::gpu::GpuBackend;
        if !GpuBackend::is_available() {
            eprintln!("GPU not available, skipping test_matmul_gpu_identity");
            return;
        }
        let n = 500;
        let a_data: Vec<f32> = (0..n * n).map(|i| (i % 100) as f32 * 0.01).collect();
        let mut i_data = vec![0.0f32; n * n];
        for i in 0..n {
            i_data[i * n + i] = 1.0;
        }
        let a = Matrix::from_vec(n, n, a_data.clone()).expect("valid matrix A");
        let identity = Matrix::from_vec(n, n, i_data).expect("valid identity matrix");
        let result = a.matmul(&identity).expect("matmul should succeed");
        assert_eq!(result.rows(), n);
        assert_eq!(result.cols(), n);
        // Corners plus center; 1e-2 tolerance allows for GPU f32 rounding.
        let check_indices = [(0, 0), (0, n - 1), (n - 1, 0), (n - 1, n - 1), (n / 2, n / 2)];
        for &(r, c) in &check_indices {
            let expected = a_data[r * n + c];
            let actual = *result.get(r, c).unwrap();
            assert!(
                (actual - expected).abs() < 1e-2,
                "A*I mismatch at ({},{}): gpu={}, expected={}",
                r,
                c,
                actual,
                expected
            );
        }
    }

    // All-ones 500³ product: every element should equal k (= 500). Checked on
    // a 10×10 corner with a tolerance of 1.0 for large-sum accumulation error.
    #[test]
    fn test_matmul_gpu_ones() {
        use crate::backends::gpu::GpuBackend;
        if !GpuBackend::is_available() {
            eprintln!("GPU not available, skipping test_matmul_gpu_ones");
            return;
        }
        let m = 500;
        let k = 500;
        let n = 500;
        let a = Matrix::from_vec(m, k, vec![1.0f32; m * k]).expect("valid matrix A");
        let b = Matrix::from_vec(k, n, vec![1.0f32; k * n]).expect("valid matrix B");
        let result = a.matmul(&b).expect("matmul should succeed");
        assert_eq!(result.rows(), m);
        assert_eq!(result.cols(), n);
        let expected = k as f32;
        for i in 0..10 {
            for j in 0..10 {
                assert!(
                    (result.get(i, j).unwrap() - expected).abs() < 1.0,
                    "C[{},{}] = {}, expected {}",
                    i,
                    j,
                    result.get(i, j).unwrap(),
                    expected
                );
            }
        }
    }

    // Calls `matmul_gpu` directly (bypassing the size threshold in `matmul`)
    // on a tiny 2×3 @ 3×2 product with hand-computed expected values.
    #[test]
    fn test_matmul_gpu_direct() {
        use crate::backends::gpu::GpuBackend;
        if !GpuBackend::is_available() {
            eprintln!("GPU not available, skipping test_matmul_gpu_direct");
            return;
        }
        let a = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).expect("valid A");
        let b = Matrix::from_vec(3, 2, vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]).expect("valid B");
        let result = a.matmul_gpu(&b).expect("matmul_gpu should succeed");
        assert_eq!(result.rows(), 2);
        assert_eq!(result.cols(), 2);
        assert!(
            (result.get(0, 0).unwrap() - 58.0).abs() < 1e-2,
            "Expected 58.0, got {}",
            result.get(0, 0).unwrap()
        );
        assert!(
            (result.get(0, 1).unwrap() - 64.0).abs() < 1e-2,
            "Expected 64.0, got {}",
            result.get(0, 1).unwrap()
        );
        assert!(
            (result.get(1, 0).unwrap() - 139.0).abs() < 1e-2,
            "Expected 139.0, got {}",
            result.get(1, 0).unwrap()
        );
        assert!(
            (result.get(1, 1).unwrap() - 154.0).abs() < 1e-2,
            "Expected 154.0, got {}",
            result.get(1, 1).unwrap()
        );
    }

    // Inverse gate: when no GPU is present, `matmul_gpu` must return Err
    // (this is the branch `matmul` relies on to fall back to CPU).
    #[test]
    fn test_matmul_gpu_not_available_path() {
        use crate::backends::gpu::GpuBackend;
        if !GpuBackend::is_available() {
            let a = Matrix::from_vec(2, 2, vec![1.0; 4]).unwrap();
            let b = Matrix::from_vec(2, 2, vec![1.0; 4]).unwrap();
            let result = a.matmul_gpu(&b);
            assert!(result.is_err(), "matmul_gpu should fail without GPU");
        }
    }
}