use crate::blas3::{gemm_ex_epilogue, Epilogue};
use crate::cholesky::cholesky;
use crate::lu::lu_factorize;
use crate::qr::qr_factorize;
use crate::svd::svd;
use crate::trsm::{trsm, DiagonalType, TriangularSide};
#[test]
fn test_lu_identity() {
    // Factorizing the 2x2 identity must leave the right-hand side unchanged.
    let identity = [1.0, 0.0, 0.0, 1.0_f32];
    let factors = lu_factorize(&identity, 2).expect("lu ok");
    let solution = factors.solve(&[3.0, 7.0]).expect("solve ok");
    for (idx, &want) in [3.0_f32, 7.0].iter().enumerate() {
        assert!((solution[idx] - want).abs() < 1e-5);
    }
}
#[test]
fn test_lu_2x2() {
    // [[2, 1], [1, 3]] · x = [5, 7] has the exact solution x = [1.6, 1.8].
    let matrix = [2.0, 1.0, 1.0, 3.0_f32];
    let factors = lu_factorize(&matrix, 2).expect("lu ok");
    let solution = factors.solve(&[5.0, 7.0]).expect("solve ok");
    let (first, second) = (solution[0], solution[1]);
    assert!((first - 1.6).abs() < 1e-5, "x[0]={}", first);
    assert!((second - 1.8).abs() < 1e-5, "x[1]={}", second);
}
#[test]
fn test_lu_3x3() {
    // Tridiagonal 3x3 system; the solution is validated through the residual A·x ≈ b.
    let a = [2.0, -1.0, 0.0, -1.0, 2.0, -1.0, 0.0, -1.0, 2.0_f32];
    let b = [1.0, 0.0, 1.0_f32];
    let factors = lu_factorize(&a, 3).expect("lu ok");
    let x = factors.solve(&b).expect("solve ok");
    for (i, row) in a.chunks(3).enumerate() {
        // Row-major: each chunk of 3 entries is one row of A.
        let ax_i: f32 = row
            .iter()
            .enumerate()
            .map(|(j, &a_ij)| a_ij * x[j])
            .sum();
        assert!(
            (ax_i - b[i]).abs() < 1e-4,
            "Ax[{i}]={ax_i}, b[{i}]={}",
            b[i]
        );
    }
}
#[test]
fn test_lu_backward_error() {
    // With partial pivoting the factors must satisfy P·A = L·U up to rounding.
    const DIM: usize = 3;
    let a = [4.0, 2.0, 1.0, 2.0, 5.0, 3.0, 1.0, 3.0, 6.0_f32];
    let factors = lu_factorize(&a, 3).expect("lu ok");
    let (l, u, p) = (factors.extract_l(), factors.extract_u(), factors.extract_p());

    // Dense row-major DIM×DIM product using the classic i-j-k accumulation order.
    let matmul = |lhs: &[f32], rhs: &[f32]| -> Vec<f32> {
        let mut out = vec![0.0f32; DIM * DIM];
        for row in 0..DIM {
            for col in 0..DIM {
                for inner in 0..DIM {
                    out[row * DIM + col] += lhs[row * DIM + inner] * rhs[inner * DIM + col];
                }
            }
        }
        out
    };
    let pa = matmul(&p, &a);
    let lu_product = matmul(&l, &u);

    let max_err = pa
        .iter()
        .zip(lu_product.iter())
        .map(|(lhs, rhs)| (lhs - rhs).abs())
        .fold(0.0f32, f32::max);
    assert!(max_err < 1e-4, "PA ≠ LU, max_err={max_err}");
}
#[test]
fn test_lu_singular_detected() {
    // The second row is exactly twice the first, so the matrix is singular.
    let rank_deficient = [1.0, 2.0, 2.0, 4.0_f32];
    let outcome = lu_factorize(&rank_deficient, 2);
    assert!(outcome.is_err());
}
#[test]
fn test_qr_identity() {
    // QR of the identity yields R with unit-magnitude diagonal (sign is convention-dependent).
    let eye = [1.0, 0.0, 0.0, 1.0_f32];
    let factors = qr_factorize(&eye, 2, 2).expect("qr ok");
    let r = factors.extract_r();
    // Row-major 2x2: flat indices 0 and 3 are R[0,0] and R[1,1].
    for &diag in &[0usize, 3] {
        assert!(r[diag].abs() > 0.9);
    }
}
#[test]
fn test_qr_orthogonality() {
    // The full m×m Q factor of a 3x2 matrix must satisfy QᵀQ = I.
    let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0_f32];
    let factors = qr_factorize(&a, 3, 2).expect("qr ok");
    let q = factors.extract_q();
    let m = 3;
    for i in 0..m {
        for j in 0..m {
            // (QᵀQ)[i,j] is the dot product of columns i and j of row-major Q.
            let entry: f32 = (0..m).map(|k| q[k * m + i] * q[k * m + j]).sum();
            let expected = if i == j { 1.0 } else { 0.0 };
            assert!(
                (entry - expected).abs() < 1e-4,
                "Q^TQ[{i},{j}] = {}, expected {expected}",
                entry
            );
        }
    }
}
#[test]
fn test_qr_reconstruction() {
    // Multiplying Q (m×m) by the leading n rows of R must reproduce A (m×n).
    let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0_f32];
    let (m, n) = (3, 2);
    let factors = qr_factorize(&a, m, n).expect("qr ok");
    let q = factors.extract_q();
    let r = factors.extract_r();
    for i in 0..m {
        for j in 0..n {
            let qr_ij: f32 = (0..n).map(|k| q[i * m + k] * r[k * n + j]).sum();
            let a_ij = a[i * n + j];
            assert!(
                (qr_ij - a_ij).abs() < 1e-3,
                "QR[{i},{j}]={qr_ij}, A[{i},{j}]={}",
                a_ij
            );
        }
    }
}
#[test]
fn test_qr_solve_least_squares() {
    // Overdetermined 3x2 system. The right-hand side lies exactly in the column
    // space of A: b = 0·[1, 1, 1]ᵀ + 1·[1, 2, 3]ᵀ. A has full column rank, so the
    // least-squares solution is unique and equals the exact solution x = [0, 1]
    // with zero residual. That justifies a tight tolerance; a loose bound such as
    // `residual < 1.0` would accept clearly wrong solutions.
    let a = [1.0, 1.0, 1.0, 2.0, 1.0, 3.0_f32];
    let b = [1.0, 2.0, 3.0_f32];
    let qr = qr_factorize(&a, 3, 2).expect("qr ok");
    let x = qr.solve(&b).expect("solve ok");

    assert!(x[0].abs() < 1e-3, "x[0]={}", x[0]);
    assert!((x[1] - 1.0).abs() < 1e-3, "x[1]={}", x[1]);

    let mut residual = 0.0f32;
    for i in 0..3 {
        let mut ax = 0.0f32;
        for j in 0..2 {
            ax += a[i * 2 + j] * x[j];
        }
        residual += (ax - b[i]).powi(2);
    }
    assert!(
        residual.sqrt() < 1e-3,
        "Residual too large: {}",
        residual.sqrt()
    );
}
#[test]
fn test_svd_2x2_identity() {
    // Every singular value of the identity matrix is exactly 1.
    let eye = [1.0, 0.0, 0.0, 1.0_f32];
    let decomposition = svd(&eye, 2, 2).expect("svd ok");
    let (s0, s1) = (decomposition.sigma[0], decomposition.sigma[1]);
    assert!((s0 - 1.0).abs() < 1e-4, "σ[0]={}", s0);
    assert!((s1 - 1.0).abs() < 1e-4, "σ[1]={}", s1);
}
#[test]
fn test_svd_singular_values_nonneg_decreasing() {
    // Singular values must be non-negative and sorted in non-increasing order
    // (a small tolerance absorbs rounding).
    let a = [3.0, 2.0, 1.0, 2.0, 3.0, 2.0, 1.0, 3.0_f32, 3.0][..9]
        .iter()
        .copied()
        .collect::<Vec<f32>>();
    let a = [a[0], a[1], a[2], a[3], a[4], a[5], 1.0, 2.0, 3.0_f32];
    let result = svd(&a, 3, 3).expect("svd ok");
    for s in &result.sigma {
        assert!(*s >= -1e-6, "Negative singular value: {s}");
    }
    // `windows(2)` yields nothing for fewer than two values. This avoids the
    // `len() - 1` underflow panic on an empty sigma.
    for (i, pair) in result.sigma.windows(2).enumerate() {
        assert!(
            pair[0] >= pair[1] - 1e-6,
            "σ[{}]={} < σ[{}]={}",
            i,
            pair[0],
            i + 1,
            pair[1]
        );
    }
}
#[test]
fn test_svd_reconstruction() {
    // Rebuild A = U·diag(σ)·Vᵀ (accumulating in f64) and compare entrywise with A.
    let a = [1.0, 2.0, 3.0, 4.0_f32];
    let (m, n) = (2, 2);
    let result = svd(&a, m, n).expect("svd ok");
    let rank = m.min(n);
    let recon: Vec<f32> = (0..m * n)
        .map(|flat| {
            let (i, j) = (flat / n, flat % n);
            let acc: f64 = (0..rank)
                .map(|k| {
                    f64::from(result.u[i * m + k])
                        * f64::from(result.sigma[k])
                        * f64::from(result.vt[k * n + j])
                })
                .sum();
            acc as f32
        })
        .collect();
    for (i, (&rebuilt, &original)) in recon.iter().zip(a.iter()).enumerate() {
        assert!(
            (rebuilt - original).abs() < 1e-3,
            "Reconstruction error at {i}: recon={}, orig={}",
            rebuilt,
            original
        );
    }
}
#[test]
fn test_svd_orthogonality_u() {
    // The first min(m, n) columns of U must be orthonormal.
    let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0_f32];
    let result = svd(&a, 3, 2).expect("svd ok");
    let m = result.m;
    let min_mn = result.m.min(result.n);
    // Dot product of columns `p` and `q` of the row-major m×m U, in f64.
    let column_dot = |p: usize, q: usize| -> f64 {
        (0..m)
            .map(|k| f64::from(result.u[k * m + p]) * f64::from(result.u[k * m + q]))
            .sum()
    };
    for i in 0..min_mn {
        for j in 0..min_mn {
            let dot = column_dot(i, j);
            let expected = if i == j { 1.0 } else { 0.0 };
            assert!(
                (dot - expected).abs() < 1e-3,
                "U orthogonality: U^T U[{i},{j}]={dot}, expected {expected}"
            );
        }
    }
}
#[test]
fn test_cholesky_2x2() {
    // L·Lᵀ must reproduce the SPD input [[4, 2], [2, 3]].
    let a = [4.0, 2.0, 2.0, 3.0_f32];
    let chol = cholesky(&a, 2).expect("cholesky ok");
    let n = 2;
    let recon: Vec<f32> = (0..n * n)
        .map(|flat| {
            let (row, col) = (flat / n, flat % n);
            (0..n)
                .map(|k| chol.l[row * n + k] * chol.l[col * n + k])
                .sum::<f32>()
        })
        .collect();
    for i in 0..recon.len() {
        assert!(
            (recon[i] - a[i]).abs() < 1e-5,
            "LL^T[{i}]={}, A[{i}]={}",
            recon[i],
            a[i]
        );
    }
}
#[test]
fn test_cholesky_solve() {
    // Solve [[4, 2], [2, 3]] x = [8, 7] and check A·x against b.
    let a = [4.0, 2.0, 2.0, 3.0_f32];
    let factor = cholesky(&a, 2).expect("cholesky ok");
    let x = factor.solve(&[8.0, 7.0]).expect("solve ok");
    let residuals = [
        4.0 * x[0] + 2.0 * x[1] - 8.0,
        2.0 * x[0] + 3.0 * x[1] - 7.0,
    ];
    for r in residuals.iter() {
        assert!(r.abs() < 1e-4);
    }
}
#[test]
fn test_cholesky_not_positive_definite() {
    // Symmetric but indefinite (eigenvalues 4 and -2), so Cholesky must fail.
    let indefinite = [1.0, 3.0, 3.0, 1.0_f32];
    let outcome = cholesky(&indefinite, 2);
    assert!(outcome.is_err());
}
#[test]
fn test_cholesky_3x3() {
    // Classic 3x3 SPD example; L·Lᵀ must reproduce the input.
    let a = [4.0, 12.0, -16.0, 12.0, 37.0, -43.0, -16.0, -43.0, 98.0_f32];
    let chol = cholesky(&a, 3).expect("cholesky ok");
    let n = 3;
    let recon: Vec<f32> = (0..n * n)
        .map(|flat| {
            let (row, col) = (flat / n, flat % n);
            (0..n)
                .map(|k| chol.l[row * n + k] * chol.l[col * n + k])
                .sum::<f32>()
        })
        .collect();
    for i in 0..recon.len() {
        assert!(
            (recon[i] - a[i]).abs() < 1e-3,
            "LL^T[{i}]={}, A[{i}]={}",
            recon[i],
            a[i]
        );
    }
}
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;
    proptest! {
        // Any random 2x2 system that LU accepts and solves must have a small
        // relative residual ‖A·x − b‖ / ‖b‖.
        #[test]
        fn prop_lu_solve_residual(
            a00 in -10.0f32..10.0,
            a01 in -10.0f32..10.0,
            a10 in -10.0f32..10.0,
            a11 in -10.0f32..10.0,
            b0 in -10.0f32..10.0,
            b1 in -10.0f32..10.0,
        ) {
            let matrix = [a00, a01, a10, a11];
            // Inputs rejected by factorization or solve are skipped, not failed.
            let solution = lu_factorize(&matrix, 2)
                .ok()
                .and_then(|lu| lu.solve(&[b0, b1]).ok());
            if let Some(x) = solution {
                let r0 = a00 * x[0] + a01 * x[1] - b0;
                let r1 = a10 * x[0] + a11 * x[1] - b1;
                let residual = (r0 * r0 + r1 * r1).sqrt();
                // The small offset keeps the ratio finite when b is the zero vector.
                let b_norm = (b0 * b0 + b1 * b1).sqrt() + 1e-10;
                prop_assert!(residual / b_norm < 1e-3, "Residual too large: {residual}");
            }
        }

        // Singular values of any random 2x2 matrix are non-negative and non-increasing.
        #[test]
        fn prop_svd_values_nonneg(
            a00 in -5.0f32..5.0,
            a01 in -5.0f32..5.0,
            a10 in -5.0f32..5.0,
            a11 in -5.0f32..5.0,
        ) {
            let matrix = [a00, a01, a10, a11];
            if let Ok(decomposition) = svd(&matrix, 2, 2) {
                for s in decomposition.sigma.iter() {
                    prop_assert!(*s >= -1e-5, "Negative singular value: {s}");
                }
                if let &[first, second, ..] = &decomposition.sigma[..] {
                    prop_assert!(first >= second - 1e-5);
                }
            }
        }
    }
}
#[test]
fn test_trsm_lower_2x2() -> Result<(), Box<dyn std::error::Error>> {
    // Forward substitution on L = [[2, 0], [3, 4]]:
    // x0 = 1/2 = 0.5, x1 = (2 − 3·0.5)/4 = 0.125.
    let lower = [2.0, 0.0, 3.0, 4.0_f32];
    let rhs = [1.0, 2.0_f32];
    let solved = trsm(&lower, &rhs, 2, 1, TriangularSide::Lower, DiagonalType::NonUnit)?;
    for (idx, &want) in [0.5_f32, 0.125].iter().enumerate() {
        assert!((solved.x[idx] - want).abs() < 1e-5);
    }
    Ok(())
}
#[test]
fn test_trsm_upper_2x2() -> Result<(), Box<dyn std::error::Error>> {
    // Back substitution on U = [[3, 1], [0, 2]]:
    // x1 = 4/2 = 2, x0 = (5 − 1·2)/3 = 1.
    let upper = [3.0, 1.0, 0.0, 2.0_f32];
    let rhs = [5.0, 4.0_f32];
    let solved = trsm(&upper, &rhs, 2, 1, TriangularSide::Upper, DiagonalType::NonUnit)?;
    for (idx, &want) in [1.0_f32, 2.0].iter().enumerate() {
        assert!((solved.x[idx] - want).abs() < 1e-5);
    }
    Ok(())
}
#[test]
fn test_trsm_unit_diagonal() -> Result<(), Box<dyn std::error::Error>> {
    // Unit-diagonal lower solve on [[1, 0], [3, 1]]: x0 = 1, x1 = 5 − 3·1 = 2.
    let unit_lower = [1.0, 0.0, 3.0, 1.0_f32];
    let rhs = [1.0, 5.0_f32];
    let solved = trsm(&unit_lower, &rhs, 2, 1, TriangularSide::Lower, DiagonalType::Unit)?;
    for (idx, &want) in [1.0_f32, 2.0].iter().enumerate() {
        assert!((solved.x[idx] - want).abs() < 1e-5);
    }
    Ok(())
}
#[test]
fn test_trsm_multiple_rhs() -> Result<(), Box<dyn std::error::Error>> {
    // L = [[2, 0], [1, 3]] with a row-major 2x2 RHS whose columns are [2, 5] and [4, 7].
    // Column 0 solves to [1, 4/3]; column 1 solves to [2, 5/3].
    let lower = [2.0, 0.0, 1.0, 3.0_f32];
    let rhs = [2.0, 4.0, 5.0, 7.0_f32];
    let solved = trsm(&lower, &rhs, 2, 2, TriangularSide::Lower, DiagonalType::NonUnit)?;
    // (flat index, expected value, tolerance)
    let checks: [(usize, f32, f32); 4] = [
        (0, 1.0, 1e-5),
        (2, 4.0 / 3.0, 1e-4),
        (1, 2.0, 1e-5),
        (3, 5.0 / 3.0, 1e-4),
    ];
    for &(idx, want, tol) in checks.iter() {
        assert!((solved.x[idx] - want).abs() < tol);
    }
    Ok(())
}
#[test]
fn test_trsm_backward_error() -> Result<(), Box<dyn std::error::Error>> {
    // The relative residual ‖L·x − b‖ / ‖b‖ of a 3x3 lower solve must be near machine precision.
    let l = [3.0, 0.0, 0.0, 2.0, 5.0, 0.0, 1.0, 4.0, 6.0_f32];
    let b = [1.0, 2.0, 3.0_f32];
    let result = trsm(&l, &b, 3, 1, TriangularSide::Lower, DiagonalType::NonUnit)?;
    let n = 3;
    let ax: Vec<f32> = l
        .chunks(n)
        .map(|row| {
            row.iter()
                .zip(result.x.iter())
                .map(|(&l_ij, &x_j)| l_ij * x_j)
                .sum::<f32>()
        })
        .collect();
    let (num, den) = ax
        .iter()
        .zip(b.iter())
        .fold((0.0f32, 0.0f32), |(num, den), (&got, &want)| {
            (num + (got - want).powi(2), den + want.powi(2))
        });
    let rel_err = (num / den).sqrt();
    assert!(rel_err < 1e-5, "TRSM backward error {rel_err} too large");
    Ok(())
}
#[test]
fn test_trsm_singular_detected() {
    // A zero on the diagonal makes the non-unit triangular solve impossible.
    let zero_pivot = [1.0, 0.0, 0.0, 0.0_f32];
    let rhs = [1.0, 2.0_f32];
    let outcome = trsm(&zero_pivot, &rhs, 2, 1, TriangularSide::Lower, DiagonalType::NonUnit);
    assert!(outcome.is_err());
}
use crate::blas3::{symm, syrk, trmm};
#[test]
fn test_syrk_identity() -> Result<(), Box<dyn std::error::Error>> {
    // I·Iᵀ = I.
    let eye = [1.0, 0.0, 0.0, 1.0_f32];
    let mut gram = [0.0_f32; 4];
    syrk(&eye, &mut gram, 2, 2, 1.0, 0.0)?;
    let expected = [1.0_f32, 0.0, 0.0, 1.0];
    // Diagonal entries first, then off-diagonal.
    for &idx in &[0usize, 3, 1, 2] {
        assert!((gram[idx] - expected[idx]).abs() < 1e-6);
    }
    Ok(())
}
#[test]
fn test_syrk_known_value() -> Result<(), Box<dyn std::error::Error>> {
    // Row-major A = [[1, 2], [3, 4]] gives A·Aᵀ = [[5, 11], [11, 25]].
    let a = [1.0, 2.0, 3.0, 4.0_f32];
    let mut gram = [0.0_f32; 4];
    syrk(&a, &mut gram, 2, 2, 1.0, 0.0)?;
    let expected = [5.0_f32, 11.0, 11.0, 25.0];
    for (got, want) in gram.iter().zip(expected.iter()) {
        assert!((got - want).abs() < 1e-4);
    }
    Ok(())
}
#[test]
fn test_syrk_symmetry() -> Result<(), Box<dyn std::error::Error>> {
    // A rank-k update with a non-square 2x3 A must still yield a symmetric 2x2 result.
    let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0_f32];
    let mut gram = [0.0_f32; 4];
    syrk(&a, &mut gram, 2, 3, 2.0, 0.0)?;
    let (upper, lower) = (gram[1], gram[2]);
    assert!(
        (upper - lower).abs() < 1e-6,
        "Not symmetric: c[1]={}, c[2]={}",
        upper,
        lower
    );
    Ok(())
}
#[test]
fn test_syrk_alpha_beta() -> Result<(), Box<dyn std::error::Error>> {
    // C ← 2·I·Iᵀ + 0.5·C, so each diagonal entry becomes 2 + 0.5·10 = 7.
    let eye = [1.0, 0.0, 0.0, 1.0_f32];
    let mut acc = [10.0, 0.0, 0.0, 10.0_f32];
    syrk(&eye, &mut acc, 2, 2, 2.0, 0.5)?;
    for &diag in &[0usize, 3] {
        assert!((acc[diag] - 7.0).abs() < 1e-5);
    }
    Ok(())
}
#[test]
fn test_trmm_identity() -> Result<(), Box<dyn std::error::Error>> {
    // Multiplying B in place by the identity leaves it unchanged.
    let eye = [1.0, 0.0, 0.0, 1.0_f32];
    let original = [3.0, 4.0, 5.0, 6.0_f32];
    let mut b = original;
    trmm(&eye, &mut b, 2, 2, 1.0)?;
    for (got, want) in b.iter().zip(original.iter()) {
        assert!((got - want).abs() < 1e-6);
    }
    Ok(())
}
#[test]
fn test_trmm_lower_triangular() -> Result<(), Box<dyn std::error::Error>> {
    // [[2, 0], [3, 4]] · [1, 1]ᵀ = [2, 7]ᵀ.
    let lower = [2.0, 0.0, 3.0, 4.0_f32];
    let mut b = [1.0, 1.0_f32];
    trmm(&lower, &mut b, 2, 1, 1.0)?;
    for (idx, &want) in [2.0_f32, 7.0].iter().enumerate() {
        assert!((b[idx] - want).abs() < 1e-5);
    }
    Ok(())
}
#[test]
fn test_trmm_alpha_scaling() -> Result<(), Box<dyn std::error::Error>> {
    // With A = I, alpha = 3 simply triples B.
    let eye = [1.0, 0.0, 0.0, 1.0_f32];
    let mut b = [5.0, 10.0_f32];
    trmm(&eye, &mut b, 2, 1, 3.0)?;
    for (idx, &want) in [15.0_f32, 30.0].iter().enumerate() {
        assert!((b[idx] - want).abs() < 1e-5);
    }
    Ok(())
}
#[test]
fn test_symm_known_product() -> Result<(), Box<dyn std::error::Error>> {
    // Symmetric A times the identity returns A itself.
    let sym = [1.0, 2.0, 2.0, 3.0_f32];
    let eye = [1.0, 0.0, 0.0, 1.0_f32];
    let mut product = [0.0_f32; 4];
    symm(&sym, &eye, &mut product, 2, 2, 1.0, 0.0)?;
    for (got, want) in product.iter().zip(sym.iter()) {
        assert!((got - want).abs() < 1e-5);
    }
    Ok(())
}
#[test]
fn test_symm_alpha_beta() -> Result<(), Box<dyn std::error::Error>> {
    // C ← 2·I·[2, 3]ᵀ + 0.5·[10, 20]ᵀ = [9, 16]ᵀ.
    let eye = [1.0, 0.0, 0.0, 1.0_f32];
    let b = [2.0, 3.0_f32];
    let mut c = [10.0, 20.0_f32];
    symm(&eye, &b, &mut c, 2, 1, 2.0, 0.5)?;
    for (idx, &want) in [9.0_f32, 16.0].iter().enumerate() {
        assert!((c[idx] - want).abs() < 1e-5);
    }
    Ok(())
}
#[test]
fn test_blas3_dimension_mismatch() {
    // A 2x2 output needs four slots; a three-element C buffer must be rejected.
    let a = [1.0_f32; 4];
    let mut too_small = [0.0_f32; 3];
    let outcome = syrk(&a, &mut too_small, 2, 2, 1.0, 0.0);
    assert!(outcome.is_err());
}
use crate::blas3::syr2k;
#[test]
fn test_syr2k_identity() -> Result<(), Box<dyn std::error::Error>> {
    // A = B = I gives A·Bᵀ + B·Aᵀ = 2·I.
    let eye = [1.0, 0.0, 0.0, 1.0_f32];
    let mut c = [0.0_f32; 4];
    syr2k(&eye, &eye, &mut c, 2, 2, 1.0, 0.0)?;
    let expected = [2.0_f32, 0.0, 0.0, 2.0];
    for &idx in &[0usize, 3, 1, 2] {
        assert!((c[idx] - expected[idx]).abs() < 1e-5);
    }
    Ok(())
}
#[test]
fn test_syr2k_known_value() -> Result<(), Box<dyn std::error::Error>> {
    // 1x2 operands: A·Bᵀ + B·Aᵀ = 2·(1·3 + 2·4) = 22.
    let (a, b) = ([1.0, 2.0_f32], [3.0, 4.0_f32]);
    let mut c = [0.0_f32; 1];
    syr2k(&a, &b, &mut c, 1, 2, 1.0, 0.0)?;
    let got = c[0];
    assert!((got - 22.0).abs() < 1e-4, "got {}", got);
    Ok(())
}
#[test]
fn test_syr2k_symmetry() -> Result<(), Box<dyn std::error::Error>> {
    // A·Bᵀ + B·Aᵀ is symmetric by construction, even for unrelated A and B.
    let a = [1.0, 2.0, 3.0, 4.0_f32];
    let b = [5.0, 6.0, 7.0, 8.0_f32];
    let mut c = [0.0_f32; 4];
    syr2k(&a, &b, &mut c, 2, 2, 1.0, 0.0)?;
    let (upper, lower) = (c[1], c[2]);
    assert!(
        (upper - lower).abs() < 1e-5,
        "Not symmetric: c[1]={}, c[2]={}",
        upper,
        lower
    );
    Ok(())
}
#[test]
fn test_syr2k_alpha_beta() -> Result<(), Box<dyn std::error::Error>> {
    // C ← 2·(I + I) + 0.5·C, so each diagonal entry is 4 + 5 = 9.
    let eye = [1.0, 0.0, 0.0, 1.0_f32];
    let mut c = [10.0, 0.0, 0.0, 10.0_f32];
    syr2k(&eye, &eye, &mut c, 2, 2, 2.0, 0.5)?;
    for &diag in &[0usize, 3] {
        assert!((c[diag] - 9.0).abs() < 1e-5);
    }
    Ok(())
}
use crate::blas3::{f32_to_f16, gemm_ex, gemm_strided_batched};
#[test]
fn test_gemm_ex_identity() -> Result<(), Box<dyn std::error::Error>> {
    // f16 inputs, f32 output: I·B = B for small integers that f16 represents exactly.
    let to_half = |vals: &[f32]| -> Vec<u16> { vals.iter().map(|&v| f32_to_f16(v)).collect() };
    let a = to_half(&[1.0, 0.0, 0.0, 1.0]);
    let b = to_half(&[1.0, 2.0, 3.0, 4.0]);
    let mut c = [0.0_f32; 4];
    gemm_ex(&a, &b, &mut c, 2, 2, 2, 1.0, 0.0)?;
    for (idx, &want) in [1.0_f32, 2.0, 3.0, 4.0].iter().enumerate() {
        assert!((c[idx] - want).abs() < 1e-3);
    }
    Ok(())
}
#[test]
fn test_gemm_ex_alpha_beta() -> Result<(), Box<dyn std::error::Error>> {
    // C ← 2·I·I + 0.5·C, so each diagonal entry becomes 2 + 5 = 7.
    let eye_half = [
        f32_to_f16(1.0),
        f32_to_f16(0.0),
        f32_to_f16(0.0),
        f32_to_f16(1.0),
    ];
    let mut c = [10.0, 0.0, 0.0, 10.0_f32];
    gemm_ex(&eye_half, &eye_half, &mut c, 2, 2, 2, 2.0, 0.5)?;
    for &diag in &[0usize, 3] {
        assert!((c[diag] - 7.0).abs() < 1e-3);
    }
    Ok(())
}
#[test]
fn test_gemm_ex_matmul_2x3_3x2() -> Result<(), Box<dyn std::error::Error>> {
    // [[1,2,3],[4,5,6]] · [[7,8],[9,10],[11,12]] = [[58,64],[139,154]].
    let to_half = |vals: &[f32]| -> Vec<u16> { vals.iter().map(|&v| f32_to_f16(v)).collect() };
    let a = to_half(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    let b = to_half(&[7.0, 8.0, 9.0, 10.0, 11.0, 12.0]);
    let mut c = [0.0_f32; 4];
    gemm_ex(&a, &b, &mut c, 2, 2, 3, 1.0, 0.0)?;
    let expected = [58.0_f32, 64.0, 139.0, 154.0];
    for (got, want) in c.iter().zip(expected.iter()) {
        assert!((got - want).abs() < 0.5);
    }
    Ok(())
}
#[test]
fn test_f16_roundtrip() {
    // Route each f16-encoded value through a 1x1x1 GEMM with 1.0 to read it back as f32.
    let one = f32_to_f16(1.0);
    for v in [0.0_f32, 1.0, -1.0, 0.5, 65504.0].iter().copied() {
        let encoded = [f32_to_f16(v)];
        let mut out = [0.0_f32];
        gemm_ex(&encoded, &[one], &mut out, 1, 1, 1, 1.0, 0.0).expect("gemm_ex roundtrip ok");
        let decoded = out[0];
        assert!(
            (decoded - v).abs() < 0.01,
            "f16 roundtrip failed for {v}: got {}",
            decoded
        );
    }
}
#[test]
fn test_gemm_strided_batched_single() -> Result<(), Box<dyn std::error::Error>> {
    // One batch: [[1,2],[3,4]] · [[5,6],[7,8]] = [[19,22],[43,50]].
    let a = [1.0, 2.0, 3.0, 4.0_f32];
    let b = [5.0, 6.0, 7.0, 8.0_f32];
    let mut c = [0.0_f32; 4];
    gemm_strided_batched(&a, 4, &b, 4, &mut c, 4, 1, 2, 2, 2, 1.0, 0.0)?;
    let expected = [19.0_f32, 22.0, 43.0, 50.0];
    for (got, want) in c.iter().zip(expected.iter()) {
        assert!((got - want).abs() < 1e-5);
    }
    Ok(())
}
#[test]
fn test_gemm_strided_batched_two_batches() -> Result<(), Box<dyn std::error::Error>> {
    // Batch 0: I · [[1,2],[3,4]]; batch 1: 2I · I = 2I. Each stride is one 2x2 matrix.
    let a = [1.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 2.0_f32];
    let b = [1.0, 2.0, 3.0, 4.0, 1.0, 0.0, 0.0, 1.0_f32];
    let mut c = [0.0_f32; 8];
    gemm_strided_batched(&a, 4, &b, 4, &mut c, 4, 2, 2, 2, 2, 1.0, 0.0)?;
    let expected = [1.0_f32, 2.0, 3.0, 4.0, 2.0, 0.0, 0.0, 2.0];
    for (got, want) in c.iter().zip(expected.iter()) {
        assert!((got - want).abs() < 1e-5);
    }
    Ok(())
}
#[test]
fn test_gemm_strided_batched_alpha_beta() -> Result<(), Box<dyn std::error::Error>> {
    // C ← 3·I·I + 0.5·C, so each diagonal entry becomes 3 + 5 = 8.
    let eye = [1.0, 0.0, 0.0, 1.0_f32];
    let mut c = [10.0, 0.0, 0.0, 10.0_f32];
    gemm_strided_batched(&eye, 4, &eye, 4, &mut c, 4, 1, 2, 2, 2, 3.0, 0.5)?;
    for &diag in &[0usize, 3] {
        assert!((c[diag] - 8.0).abs() < 1e-5);
    }
    Ok(())
}
#[test]
fn test_gemm_strided_batched_zero_batch() -> Result<(), Box<dyn std::error::Error>> {
    // A batch count of zero must succeed and leave C untouched.
    let ones = [1.0_f32; 4];
    let sentinel = 999.0_f32;
    let mut c = [sentinel; 4];
    gemm_strided_batched(&ones, 4, &ones, 4, &mut c, 4, 0, 2, 2, 2, 1.0, 0.0)?;
    assert!((c[0] - sentinel).abs() < 1e-5);
    Ok(())
}
#[test]
fn test_gemm_ex_buffer_mismatch() {
    // B holds only two elements, but a k=2, n=2 operand needs four.
    let one = f32_to_f16(1.0);
    let a = [one; 4];
    let short_b = [one; 2];
    let mut c = [0.0_f32; 4];
    assert!(gemm_ex(&a, &short_b, &mut c, 2, 2, 2, 1.0, 0.0).is_err());
}
use crate::Solver;
#[test]
fn test_solver_trait_lu() -> Result<(), Box<dyn std::error::Error>> {
    // LU used through the `Solver` trait object gives the same answer as the direct API.
    let factors = lu_factorize(&[2.0, 1.0, 1.0, 3.0_f32], 2)?;
    let dyn_solver: &dyn Solver = &factors;
    assert_eq!(dyn_solver.dimension(), 2);
    let x = dyn_solver.solve(&[5.0, 7.0])?;
    for (idx, &want) in [1.6_f32, 1.8].iter().enumerate() {
        assert!((x[idx] - want).abs() < 1e-5);
    }
    Ok(())
}
#[test]
fn test_solver_trait_cholesky() -> Result<(), Box<dyn std::error::Error>> {
    // [[4, 2], [2, 3]] x = [8, 7] solves to x = [1.25, 1.5] through the trait object.
    let factor = cholesky(&[4.0, 2.0, 2.0, 3.0_f32], 2)?;
    let dyn_solver: &dyn Solver = &factor;
    assert_eq!(dyn_solver.dimension(), 2);
    let x = dyn_solver.solve(&[8.0, 7.0])?;
    for (idx, &want) in [1.25_f32, 1.5].iter().enumerate() {
        assert!((x[idx] - want).abs() < 1e-4);
    }
    Ok(())
}
#[test]
fn test_solver_trait_qr() -> Result<(), Box<dyn std::error::Error>> {
    // A QR factorization of the identity solves to the right-hand side itself.
    let factors = qr_factorize(&[1.0, 0.0, 0.0, 1.0_f32], 2, 2)?;
    let dyn_solver: &dyn Solver = &factors;
    assert_eq!(dyn_solver.dimension(), 2);
    let x = dyn_solver.solve(&[3.0, 7.0])?;
    for (idx, &want) in [3.0_f32, 7.0].iter().enumerate() {
        assert!((x[idx] - want).abs() < 1e-5);
    }
    Ok(())
}
#[test]
fn test_gemm_ex_epilogue_none() -> Result<(), Box<dyn std::error::Error>> {
    // With no epilogue, I·B is returned unchanged.
    let to_half = |vals: &[f32]| -> Vec<u16> { vals.iter().map(|&v| f32_to_f16(v)).collect() };
    let a = to_half(&[1.0, 0.0, 0.0, 1.0]);
    let b = to_half(&[3.0, 4.0, 5.0, 6.0]);
    let mut c = [0.0_f32; 4];
    gemm_ex_epilogue(&a, &b, &mut c, 2, 2, 2, 1.0, 0.0, Epilogue::None, None)?;
    for &(idx, want) in &[(0usize, 3.0_f32), (3, 6.0)] {
        assert!((c[idx] - want).abs() < 0.1);
    }
    Ok(())
}
#[test]
fn test_gemm_ex_epilogue_relu() -> Result<(), Box<dyn std::error::Error>> {
    // A = I, so the pre-epilogue product is B = [[-3, 4], [5, -6]].
    // ReLU clamps the negatives to 0 and passes positives through:
    // C = [[0, 4], [5, 0]]. Every entry is checked, including c[2], so that a
    // ReLU which zeroes or corrupts positive values in the second row is caught.
    let a: Vec<u16> = [1.0, 0.0, 0.0, 1.0_f32]
        .iter()
        .map(|&v| f32_to_f16(v))
        .collect();
    let b: Vec<u16> = [-3.0, 4.0, 5.0, -6.0_f32]
        .iter()
        .map(|&v| f32_to_f16(v))
        .collect();
    let mut c = [0.0_f32; 4];
    gemm_ex_epilogue(&a, &b, &mut c, 2, 2, 2, 1.0, 0.0, Epilogue::Relu, None)?;
    assert!((c[0] - 0.0).abs() < 0.1);
    assert!((c[1] - 4.0).abs() < 0.1);
    assert!((c[2] - 5.0).abs() < 0.1, "Relu c[2]={}", c[2]);
    assert!((c[3] - 0.0).abs() < 0.1);
    Ok(())
}
#[test]
fn test_gemm_ex_epilogue_bias() -> Result<(), Box<dyn std::error::Error>> {
    // A = I, so the pre-epilogue product is B = [[1, 2], [3, 4]].
    // The bias has one entry per output column (length n = 2), and every row
    // receives it: C = [[1+10, 2+20], [3+10, 4+20]] = [[11, 22], [13, 24]].
    // Checking both rows catches a bias that is applied only to the first row
    // or with the wrong row stride.
    let a: Vec<u16> = [1.0, 0.0, 0.0, 1.0_f32]
        .iter()
        .map(|&v| f32_to_f16(v))
        .collect();
    let b: Vec<u16> = [1.0, 2.0, 3.0, 4.0_f32]
        .iter()
        .map(|&v| f32_to_f16(v))
        .collect();
    let mut c = [0.0_f32; 4];
    let bias = [10.0, 20.0_f32];
    gemm_ex_epilogue(
        &a,
        &b,
        &mut c,
        2,
        2,
        2,
        1.0,
        0.0,
        Epilogue::Bias,
        Some(&bias),
    )?;
    assert!((c[0] - 11.0).abs() < 0.1);
    assert!((c[1] - 22.0).abs() < 0.1);
    assert!((c[2] - 13.0).abs() < 0.1, "Bias c[2]={}", c[2]);
    assert!((c[3] - 24.0).abs() < 0.1, "Bias c[3]={}", c[3]);
    Ok(())
}
#[test]
fn test_gemm_ex_epilogue_gelu() -> Result<(), Box<dyn std::error::Error>> {
    // GELU(2) is slightly below 2; GELU(-2) is a small negative number.
    let to_half = |vals: &[f32]| -> Vec<u16> { vals.iter().map(|&v| f32_to_f16(v)).collect() };
    let a = to_half(&[1.0, 0.0, 0.0, 1.0]);
    let b = to_half(&[2.0, 0.0, 0.0, -2.0]);
    let mut c = [0.0_f32; 4];
    gemm_ex_epilogue(&a, &b, &mut c, 2, 2, 2, 1.0, 0.0, Epilogue::Gelu, None)?;
    let (positive, negative) = (c[0], c[3]);
    assert!(positive > 1.9);
    assert!(negative < 0.0 && negative > -0.1);
    Ok(())
}
#[test]
fn test_gemm_ex_epilogue_bias_required() {
    // Requesting a Bias epilogue without supplying a bias slice must be an error.
    let one = vec![f32_to_f16(1.0_f32)];
    let mut c = [0.0_f32];
    let outcome = gemm_ex_epilogue(&one, &one, &mut c, 1, 1, 1, 1.0, 0.0, Epilogue::Bias, None);
    assert!(outcome.is_err());
}
#[test]
fn test_falsify_lu_1x1() {
    // Scalar system 5·x = 15, so x = 3.
    let scalar = [5.0_f32];
    let factors = lu_factorize(&scalar, 1).expect("1x1 LU");
    let x = factors.solve(&[15.0]).expect("1x1 solve");
    let solved = x[0];
    assert!((solved - 3.0).abs() < 1e-5);
}
#[test]
fn test_falsify_cholesky_1x1() {
    // For the scalar 9, L = √9 = 3, and 9·x = 18 gives x = 2.
    let scalar = [9.0_f32];
    let factor = cholesky(&scalar, 1).expect("1x1 cholesky");
    let root = factor.l[0];
    assert!((root - 3.0).abs() < 1e-5);
    let x = factor.solve(&[18.0]).expect("1x1 solve");
    let solved = x[0];
    assert!((solved - 2.0).abs() < 1e-5);
}
#[test]
fn test_falsify_cholesky_negative_diagonal() {
    // A negative leading diagonal entry rules out positive definiteness immediately.
    let negative_lead = [-1.0, 0.0, 0.0, 4.0_f32];
    let outcome = cholesky(&negative_lead, 2);
    assert!(outcome.is_err());
}
#[test]
fn test_falsify_svd_zero_matrix() {
    // The zero matrix must decompose without error, and every singular value must be zero.
    let zeros = [0.0_f32; 4];
    let decomposition = svd(&zeros, 2, 2).expect("svd of zero");
    for s in decomposition.sigma.iter() {
        assert!(
            s.abs() < 1e-5,
            "Zero matrix should have zero singular values: {s}"
        );
    }
}
#[test]
fn test_falsify_svd_1x1() {
    // The singular value of a positive scalar is the scalar itself.
    let scalar = [7.0_f32];
    let decomposition = svd(&scalar, 1, 1).expect("svd 1x1");
    let top = decomposition.sigma[0];
    assert!((top - 7.0).abs() < 1e-4);
}
#[test]
fn test_falsify_svd_rectangular_tall() {
    // A 3x1 column vector has a single singular value equal to its norm: ‖[3, 4, 0]‖ = 5.
    let column = [3.0, 4.0, 0.0_f32];
    let decomposition = svd(&column, 3, 1).expect("svd 3x1");
    let top = decomposition.sigma[0];
    assert!((top - 5.0).abs() < 0.1, "σ={}", top);
}
#[test]
fn test_falsify_trsm_1x1() -> Result<(), Box<dyn std::error::Error>> {
    // Scalar triangular solve: 4·x = 12, so x = 3.
    let (coef, rhs) = ([4.0_f32], [12.0_f32]);
    let solved = trsm(&coef, &rhs, 1, 1, TriangularSide::Lower, DiagonalType::NonUnit)?;
    let x0 = solved.x[0];
    assert!((x0 - 3.0).abs() < 1e-5);
    Ok(())
}
#[test]
fn test_falsify_syrk_zero_k() -> Result<(), Box<dyn std::error::Error>> {
    // With an empty inner dimension (k = 0) only the beta term remains: 0.5·10 = 5.
    let empty: &[f32] = &[];
    let mut c = [10.0, 0.0, 0.0, 10.0_f32];
    syrk(empty, &mut c, 2, 0, 1.0, 0.5)?;
    for &diag in &[0usize, 3] {
        assert!((c[diag] - 5.0).abs() < 1e-5);
    }
    Ok(())
}
#[test]
fn test_falsify_gemm_ex_1x1() -> Result<(), Box<dyn std::error::Error>> {
    // Scalar f16 product 3·7 = 21 (both operands are exact in f16).
    let (lhs, rhs) = ([f32_to_f16(3.0)], [f32_to_f16(7.0)]);
    let mut out = [0.0_f32];
    gemm_ex(&lhs, &rhs, &mut out, 1, 1, 1, 1.0, 0.0)?;
    let product = out[0];
    assert!((product - 21.0).abs() < 0.5);
    Ok(())
}
#[test]
fn test_falsify_epilogue_bias_relu() -> Result<(), Box<dyn std::error::Error>> {
    // A = I, so the pre-epilogue product is B itself. The bias has one entry
    // per output column, and BiasRelu must add the bias BEFORE clamping:
    //   relu(B + bias) = relu([[-10+100, 5+200], [3+100, -7+200]])
    //                  = [[90, 205], [103, 193]].
    // Applying ReLU first would give c[0] = 100, so pinning c[0] ≈ 90 checks
    // the order of operations.
    let to_f16 = |vals: &[f32]| -> Vec<u16> { vals.iter().map(|&v| f32_to_f16(v)).collect() };
    let a = to_f16(&[1.0, 0.0, 0.0, 1.0]);
    let b = to_f16(&[-10.0, 5.0, 3.0, -7.0]);
    let bias = [100.0, 200.0_f32];

    let mut c = [0.0_f32; 4];
    gemm_ex_epilogue(
        &a,
        &b,
        &mut c,
        2,
        2,
        2,
        1.0,
        0.0,
        Epilogue::BiasRelu,
        Some(&bias),
    )?;
    for (i, &want) in [90.0_f32, 205.0, 103.0, 193.0].iter().enumerate() {
        assert!(
            (c[i] - want).abs() < 0.1,
            "BiasRelu c[{i}]={}, expected {want}",
            c[i]
        );
    }

    // Second case: the bias cannot lift these negatives above zero, so the
    // ReLU clamp is actually exercised:
    //   relu([[-150+100, 5+200], [3+100, -300+200]]) = [[0, 205], [103, 0]].
    let b_clamped = to_f16(&[-150.0, 5.0, 3.0, -300.0]);
    let mut c_clamped = [0.0_f32; 4];
    gemm_ex_epilogue(
        &a,
        &b_clamped,
        &mut c_clamped,
        2,
        2,
        2,
        1.0,
        0.0,
        Epilogue::BiasRelu,
        Some(&bias),
    )?;
    for (i, &want) in [0.0_f32, 205.0, 103.0, 0.0].iter().enumerate() {
        assert!(
            (c_clamped[i] - want).abs() < 0.1,
            "BiasRelu (clamped) c[{i}]={}, expected {want}",
            c_clamped[i]
        );
    }
    Ok(())
}
#[test]
fn test_falsify_epilogue_bias_gelu() -> Result<(), Box<dyn std::error::Error>> {
    // Bias is added before GELU: c[0] = GELU(2 + 1) ≈ 3, and c[1] = GELU(-2 + 1)
    // is a small negative value.
    let to_half = |vals: &[f32]| -> Vec<u16> { vals.iter().map(|&v| f32_to_f16(v)).collect() };
    let a = to_half(&[1.0, 0.0, 0.0, 1.0]);
    let b = to_half(&[2.0, -2.0, 0.0, 3.0]);
    let bias = [1.0, 1.0_f32];
    let mut c = [0.0_f32; 4];
    gemm_ex_epilogue(
        &a,
        &b,
        &mut c,
        2,
        2,
        2,
        1.0,
        0.0,
        Epilogue::BiasGelu,
        Some(&bias),
    )?;
    let (lifted, shifted) = (c[0], c[1]);
    assert!(lifted > 2.9, "BiasGelu c[0]={}", lifted);
    assert!(shifted < 0.0 && shifted > -0.2, "BiasGelu c[1]={}", shifted);
    Ok(())
}
#[test]
fn test_falsify_solver_trait_dimension_check() -> Result<(), Box<dyn std::error::Error>> {
    // A 2x2 solver handed a length-3 right-hand side must return an error instead of panicking.
    let factors = lu_factorize(&[4.0, 1.0, 1.0, 3.0_f32], 2)?;
    let dyn_solver: &dyn crate::Solver = &factors;
    let oversized_rhs = [1.0, 2.0, 3.0];
    let outcome = dyn_solver.solve(&oversized_rhs);
    assert!(outcome.is_err(), "Solver should reject wrong-dimension RHS");
    Ok(())
}
#[test]
fn test_falsify_batched_independence() -> Result<(), Box<dyn std::error::Error>> {
    // Batch 0 multiplies I · [[5, 6], [7, 8]]. Batch 1 multiplies the zero
    // matrix by all-ones. Each batch must only see its own stride-4 slice, so:
    //   batch 0 -> [5, 6, 7, 8]  (all four entries checked, not just c[0])
    //   batch 1 -> [0, 0, 0, 0]  (no leakage from batch 0's operands)
    let a = [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0_f32];
    let b = [5.0, 6.0, 7.0, 8.0, 1.0, 1.0, 1.0, 1.0_f32];
    let mut c = [0.0_f32; 8];
    gemm_strided_batched(&a, 4, &b, 4, &mut c, 4, 2, 2, 2, 2, 1.0, 0.0)?;
    for (i, &want) in [5.0_f32, 6.0, 7.0, 8.0].iter().enumerate() {
        assert!(
            (c[i] - want).abs() < 1e-5,
            "Batch 0 wrong at c[{i}]={}, expected {want}",
            c[i]
        );
    }
    for i in 4..8 {
        assert!(
            c[i].abs() < 1e-5,
            "Batch independence violated at c[{i}]={}",
            c[i]
        );
    }
    Ok(())
}