use super::*;
use crate::TruenoError;
/// Multiplies two 2x2 matrices and checks the shape and every entry of the product.
#[test]
fn test_matmul_basic() {
    let lhs = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    let rhs = Matrix::from_vec(2, 2, vec![5.0, 6.0, 7.0, 8.0]).unwrap();
    let product = lhs.matmul(&rhs).unwrap();

    assert_eq!((product.rows(), product.cols()), (2, 2));

    // Expected entries, written out row by row.
    let expected = [[19.0, 22.0], [43.0, 50.0]];
    for (r, row) in expected.iter().enumerate() {
        for (c, want) in row.iter().enumerate() {
            assert_eq!(product.get(r, c), Some(want));
        }
    }
}
/// Right-multiplying by the 2x2 identity must give back the original entries.
#[test]
fn test_matmul_identity() {
    let source = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    let product = source.matmul(&Matrix::identity(2)).unwrap();

    // Flat list of expected entries; index `i` maps to (i / 2, i % 2).
    let expected = [1.0, 2.0, 3.0, 4.0];
    for (i, want) in expected.iter().enumerate() {
        assert_eq!(product.get(i / 2, i % 2), Some(want));
    }
}
/// Multiplying by an all-zero 2x2 matrix must produce only zeros.
#[test]
fn test_matmul_zeros() {
    let source = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    let product = source.matmul(&Matrix::zeros(2, 2)).unwrap();

    // Every stored element of the product is compared against zero.
    product
        .as_slice()
        .iter()
        .for_each(|&entry| assert_eq!(entry, 0.0));
}
/// A 2x3 matrix times a 2x2 matrix has mismatched inner dimensions (3 vs 2),
/// so `matmul` is expected to return `TruenoError::InvalidInput`.
#[test]
fn test_matmul_dimension_mismatch() {
    let wide = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
    let square = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();

    assert!(matches!(
        wide.matmul(&square),
        Err(TruenoError::InvalidInput(_))
    ));
}
/// Multiplies a 2x3 matrix by a 3x2 matrix and checks the 2x2 result exactly.
#[test]
fn test_matmul_non_square() {
    let lhs = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
    let rhs = Matrix::from_vec(3, 2, vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]).unwrap();
    let product = lhs.matmul(&rhs).unwrap();

    assert_eq!((product.rows(), product.cols()), (2, 2));

    // Expected entries, written out row by row.
    let expected = [[58.0, 64.0], [139.0, 154.0]];
    for (r, row) in expected.iter().enumerate() {
        for (c, want) in row.iter().enumerate() {
            assert_eq!(product.get(r, c), Some(want));
        }
    }
}
/// Degenerate 1x1 case: the product is the scalar product of the two entries.
#[test]
fn test_matmul_single_element() {
    let three = Matrix::from_vec(1, 1, vec![3.0]).unwrap();
    let four = Matrix::from_vec(1, 1, vec![4.0]).unwrap();

    let product = three.matmul(&four).unwrap();

    assert_eq!((product.rows(), product.cols()), (1, 1));
    assert_eq!(product.get(0, 0), Some(&12.0));
}
/// Multiplies a 5x8 matrix by an 8x6 matrix (entries 1, 2, 3, ... in row-major
/// order) and checks EVERY element of the 5x6 product against a naive
/// triple-loop reference.
///
/// The previous version only looked at element (0, 0) with a tolerance of 1.0,
/// so an incorrect value in any other row — including the last row, which the
/// test name singles out — went undetected.
///
/// All inputs, products and partial sums are integers far below 2^24, so they
/// are exactly representable in `f32` regardless of summation order; a tight
/// tolerance is therefore safe.
#[test]
fn test_matmul_remainder_rows() {
    let (m, k, n): (usize, usize, usize) = (5, 8, 6);

    let a_data: Vec<f32> = (0..m * k).map(|i| (i + 1) as f32).collect();
    let b_data: Vec<f32> = (0..k * n).map(|i| (i + 1) as f32).collect();

    let a = Matrix::from_vec(m, k, a_data.clone()).unwrap();
    let b = Matrix::from_vec(k, n, b_data.clone()).unwrap();
    let c = a.matmul(&b).unwrap();

    assert_eq!(c.rows(), m);
    assert_eq!(c.cols(), n);

    for row in 0..m {
        for col in 0..n {
            // Reference dot product of row `row` of A with column `col` of B,
            // using the same row-major layout the inputs were built with.
            let expected: f32 = (0..k)
                .map(|p| a_data[row * k + p] * b_data[p * n + col])
                .sum();
            let actual = *c.get(row, col).unwrap();
            assert!(
                (actual - expected).abs() < 1e-3,
                "mismatch at ({}, {}): got {}, expected {}",
                row,
                col,
                actual,
                expected
            );
        }
    }
}
/// Multiplies an all-ones 7x8 matrix by an all-ones 8x5 matrix; every entry of
/// the 7x5 product must equal the inner dimension, 8.
#[test]
fn test_matmul_remainder_rows_7() {
    let ones_lhs = Matrix::from_vec(7, 8, vec![1.0f32; 56]).unwrap();
    let ones_rhs = Matrix::from_vec(8, 5, vec![1.0f32; 40]).unwrap();

    let product = ones_lhs.matmul(&ones_rhs).unwrap();

    assert_eq!((product.rows(), product.cols()), (7, 5));
    product
        .as_slice()
        .iter()
        .for_each(|&entry| assert!((entry - 8.0).abs() < 1e-5));
}
/// 8x8 matrix holding 0..64 times the 8x8 identity must reproduce the input
/// element for element.
#[test]
fn test_matmul_public_api_small() {
    let ramp: Vec<f32> = (0..64).map(|i| i as f32).collect();
    let input = Matrix::from_vec(8, 8, ramp).unwrap();

    let product = input.matmul(&Matrix::identity(8)).unwrap();

    assert_eq!(product.as_slice(), input.as_slice());
}
/// 128x128 identity times a patterned 128x128 matrix must reproduce the
/// patterned matrix element for element.
#[test]
fn test_matmul_public_api_large() {
    let size = 128;

    // Repeating pattern of even values in 0..100.
    let pattern: Vec<f32> = (0..size * size).map(|i| ((i * 2) % 100) as f32).collect();
    let patterned = Matrix::from_vec(size, size, pattern).unwrap();
    let eye = Matrix::identity(size);

    let product = eye.matmul(&patterned).unwrap();

    assert_eq!(product.as_slice(), patterned.as_slice());
}
/// Multiplies a 2x3 matrix by a 3x2 matrix and checks each entry of the 2x2
/// result to within 1e-5.
#[test]
fn test_matmul_public_api_rectangular() {
    let lhs = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
    let rhs = Matrix::from_vec(3, 2, vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]).unwrap();
    let product = lhs.matmul(&rhs).unwrap();

    assert_eq!((product.rows(), product.cols()), (2, 2));

    // (row, column, expected value) for every entry of the product.
    let expected = [(0, 0, 58.0), (0, 1, 64.0), (1, 0, 139.0), (1, 1, 154.0)];
    for &(r, c, want) in expected.iter() {
        assert!((product.get(r, c).unwrap() - want).abs() < 1e-5);
    }
}
/// Smoke test: querying GPU availability must complete without panicking.
/// The returned value itself is deliberately ignored.
#[test]
#[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
fn test_gpu_availability() {
    let _: bool = crate::backends::gpu::GpuBackend::is_available();
}
/// Multiplies a 4x4 matrix holding 1..=16 by a 4x4 matrix holding 16 down to 1
/// and checks the shape plus the (0, 0) entry.
///
/// Marked `#[ignore]`, and returns early (with a note on stderr) when
/// `GpuBackend::is_available()` reports no GPU.
#[test]
#[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
#[ignore]
fn test_gpu_matmul_basic() {
    use crate::backends::gpu::GpuBackend;

    if !GpuBackend::is_available() {
        eprintln!("GPU not available, skipping test");
        return;
    }

    // Ascending 1.0..=16.0 and the same values in descending order.
    let ascending: Vec<f32> = (1..=16).map(|v| v as f32).collect();
    let descending: Vec<f32> = (1..=16).rev().map(|v| v as f32).collect();

    let lhs = Matrix::from_vec(4, 4, ascending).unwrap();
    let rhs = Matrix::from_vec(4, 4, descending).unwrap();

    let product = lhs.matmul(&rhs).expect("matmul should succeed");

    assert_eq!((product.rows(), product.cols()), (4, 4));
    // 1*16 + 2*12 + 3*8 + 4*4 = 80
    assert!((product.get(0, 0).unwrap() - 80.0).abs() < 1e-4);
}