use blas::*;
use lapack::*;
/// Dense matrix–vector product `y := alpha * op(A) * x + beta * y` via BLAS `dgemv`.
///
/// `A` is an `m × n` column-major matrix with leading dimension `lda`; `op(A)` is `A`
/// when `trans_a` is `false` and `Aᵀ` when it is `true`. Hence `x` holds `n` logical
/// elements (`m` when transposed) and `y` holds `m` (`n` when transposed), each spaced
/// by its increment.
///
/// # Errors
/// Returns `Err` when `m` or `n` is non-positive, `lda < m`, an increment is zero, or
/// any buffer is shorter than the layout described above. Size arithmetic is done in
/// `i64`, so large dimensions cannot wrap an `i32` and slip past the checks.
#[inline(always)]
pub fn gemv(
    m: i32,
    n: i32,
    a: &[f64],
    lda: i32,
    x: &[f64],
    incx: i32,
    y: &mut [f64],
    incy: i32,
    alpha: f64,
    beta: f64,
    trans_a: bool,
) -> Result<(), &'static str> {
    if m <= 0 || n <= 0 {
        return Err("m and n must be positive");
    }
    if lda < m {
        return Err("lda must be ≥ m");
    }
    if incx == 0 || incy == 0 {
        return Err("incx and incy must be non-zero");
    }
    // Widen to i64: `lda * n` can exceed i32::MAX for large panels, and a wrapped
    // product could let an undersized `a` pass the check.
    if (a.len() as i64) < i64::from(lda) * i64::from(n) {
        return Err("A too small");
    }
    // Logical vector lengths depend on whether A is transposed.
    let (len_x, len_y) = if trans_a { (m, n) } else { (n, m) };
    // A strided vector of `len` elements spans `1 + (len - 1) * |inc|` slots.
    // The i64 `abs` cannot overflow, unlike `i32::MIN.abs()`.
    let span = |len: i32, inc: i32| 1 + (i64::from(len) - 1) * i64::from(inc).abs();
    if (x.len() as i64) < span(len_x, incx) {
        return Err("x too small");
    }
    if (y.len() as i64) < span(len_y, incy) {
        return Err("y too small");
    }
    // SAFETY: dimensions, `lda` and the increments were validated above, and each
    // slice covers the full extent BLAS reads (`a`, `x`) or writes (`y`).
    unsafe {
        dgemv(
            if trans_a { b'T' } else { b'N' },
            m,
            n,
            alpha,
            a,
            lda,
            x,
            incx,
            beta,
            y,
            incy,
        );
    }
    Ok(())
}
/// Fixed-size 4×4 kernel computing `C := alpha * A * B + beta * C` with BLAS `dgemm`.
///
/// All three operands are column-major with leading dimension 4. The 16-element
/// array types guarantee each buffer exactly matches the dimensions passed to BLAS.
#[inline(always)]
pub fn gemm_4x4_microkernel(
    a: &[f64; 16],
    b: &[f64; 16],
    c: &mut [f64; 16],
    alpha: f64,
    beta: f64,
) {
    const DIM: i32 = 4;
    const NO_TRANSPOSE: u8 = b'N';
    // SAFETY: each operand holds exactly DIM * DIM entries, matching the sizes and
    // leading dimensions handed to dgemm.
    unsafe {
        dgemm(
            NO_TRANSPOSE,
            NO_TRANSPOSE,
            DIM,
            DIM,
            DIM,
            alpha,
            a,
            DIM,
            b,
            DIM,
            beta,
            c,
            DIM,
        );
    }
}
/// Solves the 2×2 triangular system `T x = b` in place with BLAS `dtrsv`.
///
/// `a` is column-major (`[t00, t10, t01, t11]`). `upper` selects which triangle is
/// referenced, and the diagonal is non-unit. On success `b` is overwritten with `x`.
///
/// # Errors
/// Returns `Err` when a diagonal entry is exactly zero. `dtrsv` performs no
/// singularity test and would otherwise silently fill `b` with `inf`/`NaN`.
#[inline(always)]
pub fn trisolve_2x2(
    upper: bool,
    a: &mut [f64; 4],
    b: &mut [f64; 2],
) -> Result<(), &'static str> {
    // In a column-major 2×2 array the diagonal sits at indices 0 and 3 for both
    // the upper and the lower triangle.
    if a[0] == 0.0 || a[3] == 0.0 {
        return Err("Triangular matrix is singular");
    }
    // SAFETY: `a` holds lda * n = 4 entries and `b` holds n = 2 entries at unit stride.
    unsafe {
        dtrsv(if upper { b'U' } else { b'L' }, b'N', b'N', 2, a, 2, b, 1);
    }
    Ok(())
}
#[inline(always)]
pub fn householder_apply(
m: i32,
n: i32,
a: &mut [f64], lda: i32,
taus: &mut [f64], ) -> Result<(), &'static str> {
let lwork = (n.max(1) * 64) as i32;
let mut work = vec![0.0_f64; lwork as usize];
let mut info = 0;
unsafe {
dgeqrf(m, n, a, lda, taus, &mut work, lwork, &mut info);
}
if info == 0 {
Ok(())
} else {
Err("LAPACK dgeqrf failed")
}
}
/// General matrix multiply `C := alpha * op(A) * op(B) + beta * C` via BLAS `dgemm`.
///
/// `op(A)` is `m × k`, `op(B)` is `k × n`, and `C` is `m × n`. `trans_a` / `trans_b`
/// select whether the stored matrices are transposed first. The stored `A` is therefore
/// `m × k` (or `k × m`) and the stored `B` is `k × n` (or `n × k`). All are
/// column-major with leading dimensions `lda`, `ldb`, `ldc`.
///
/// # Errors
/// Returns `Err` when:
/// - a dimension is negative;
/// - a leading dimension is smaller than `max(1, rows)` of its stored matrix (BLAS
///   would otherwise reject the call through `xerbla`, which may terminate the
///   process);
/// - a buffer is shorter than `ld * cols`.
#[inline(always)]
pub fn blocked_gemm(
    m: i32,
    n: i32,
    k: i32,
    alpha: f64,
    a: &[f64],
    lda: i32,
    b: &[f64],
    ldb: i32,
    beta: f64,
    c: &mut [f64],
    ldc: i32,
    trans_a: bool,
    trans_b: bool,
) -> Result<(), &'static str> {
    if m < 0 || n < 0 || k < 0 {
        return Err("m, n and k must be non-negative");
    }
    // Stored shapes (rows × cols) of A and B after accounting for the transpose flags.
    let (rows_a, cols_a) = if trans_a { (k, m) } else { (m, k) };
    let (rows_b, cols_b) = if trans_b { (n, k) } else { (k, n) };
    if lda < rows_a.max(1) {
        return Err("lda too small");
    }
    if ldb < rows_b.max(1) {
        return Err("ldb too small");
    }
    if ldc < m.max(1) {
        return Err("ldc too small");
    }
    // i64 arithmetic so `ld * cols` cannot overflow i32 and wrap past the checks.
    let too_small = |len: usize, ld: i32, cols: i32| (len as i64) < i64::from(ld) * i64::from(cols);
    if too_small(a.len(), lda, cols_a) {
        return Err("A too small");
    }
    if too_small(b.len(), ldb, cols_b) {
        return Err("B too small");
    }
    if too_small(c.len(), ldc, n) {
        return Err("C too small");
    }
    // SAFETY: dimensions, leading dimensions and buffer extents were validated above.
    unsafe {
        dgemm(
            if trans_a { b'T' } else { b'N' },
            if trans_b { b'T' } else { b'N' },
            m,
            n,
            k,
            alpha,
            a,
            lda,
            b,
            ldb,
            beta,
            c,
            ldc,
        );
    }
    Ok(())
}
/// Lower Cholesky factorisation `A = L Lᵀ` of an `n × n` panel with LAPACK `dpotrf`.
///
/// `a` is column-major with leading dimension `lda`. On success its lower triangle
/// holds `L`; the strict upper triangle is not referenced.
///
/// # Errors
/// Returns `Err` when:
/// - `a` is shorter than `lda * n`;
/// - the matrix is not positive-definite (`info > 0`);
/// - LAPACK rejects an argument (`info < 0`).
pub fn cholesky_panel(
    n: i32,
    a: &mut [f64],
    lda: i32,
) -> Result<(), &'static str> {
    // dpotrf walks `lda * n` entries of `a`. Reject short buffers up front instead
    // of letting LAPACK write past the end of the slice. i64 avoids overflow.
    if (a.len() as i64) < i64::from(lda) * i64::from(n) {
        return Err("A buffer too small");
    }
    let mut info = 0;
    // SAFETY: `a` spans at least `lda * n` entries, as checked above.
    unsafe {
        dpotrf(b'L', n, a, lda, &mut info);
    }
    if info == 0 {
        Ok(())
    } else if info > 0 {
        Err("Matrix not positive-definite")
    } else {
        Err("LAPACK dpotrf argument error")
    }
}
/// LU factorisation with partial pivoting, `P A = L U`, via LAPACK `dgetrf`.
///
/// `a` (`m × n`, column-major, leading dimension `lda`) is overwritten by the factors.
/// `L` is unit lower, stored strictly below the diagonal; `U` is on and above it.
/// `ipiv` receives the 1-based row interchanges, one per each of the `min(m, n)` steps.
///
/// # Errors
/// Returns `Err` when:
/// - `a` or `ipiv` is undersized;
/// - `dgetrf` finds an exactly zero pivot (`info > 0`), though the factors are still
///   written;
/// - LAPACK rejects an argument (`info < 0`).
#[inline(always)]
pub fn lu_with_piv(
    m: i32,
    n: i32,
    a: &mut [f64],
    lda: i32,
    ipiv: &mut [i32],
) -> Result<(), &'static str> {
    use std::cmp::Ordering;
    let entries_needed = (lda * n) as usize;
    if a.len() < entries_needed {
        return Err("A too small for GETRF");
    }
    let pivots_needed = m.min(n) as usize;
    if ipiv.len() < pivots_needed {
        return Err("ipiv too small for GETRF");
    }
    let mut info = 0;
    // SAFETY: `a` and `ipiv` were checked against the extents dgetrf accesses.
    unsafe { dgetrf(m, n, a, lda, ipiv, &mut info) };
    match info.cmp(&0) {
        Ordering::Equal => Ok(()),
        Ordering::Greater => Err("Matrix is singular to machine precision"),
        Ordering::Less => Err("LAPACK dgetrf argument error"),
    }
}
/// Householder QR factorisation of an `m × n` block with LAPACK `dgeqrf`.
///
/// The optimal workspace is obtained with a size query (`lwork = -1`) before the
/// real factorisation. On success `a` holds `R` in its upper triangle and the
/// reflector vectors below it; `taus` holds the reflector scalars.
///
/// # Errors
/// Returns `Err` when:
/// - `a` is shorter than `lda * n`, or `taus` is shorter than `n`;
/// - the workspace query fails;
/// - the factorisation reports a non-zero `info`.
#[inline(always)]
pub fn qr_block(
    m: i32,
    n: i32,
    a: &mut [f64],
    lda: i32,
    taus: &mut [f64],
) -> Result<(), &'static str> {
    if a.len() < (lda * n) as usize {
        return Err("A too small for GEQRF");
    }
    if taus.len() < n as usize {
        return Err("taus too small for GEQRF");
    }
    let mut info = 0;
    let mut size_hint = [0.0_f64; 1];
    // SAFETY: query mode only writes `size_hint[0]`.
    unsafe { dgeqrf(m, n, a, lda, taus, &mut size_hint, -1, &mut info) };
    if info != 0 {
        return Err("GEQRF workspace query failed");
    }
    let optimal = size_hint[0] as i32;
    let mut scratch = vec![0.0_f64; optimal as usize];
    // SAFETY: `scratch` holds exactly `optimal` entries.
    unsafe { dgeqrf(m, n, a, lda, taus, &mut scratch, optimal, &mut info) };
    if info != 0 {
        Err("LAPACK dgeqrf failed")
    } else {
        Ok(())
    }
}
/// Symmetric rank-k update of the lower triangle of `C` with BLAS `dsyrk`.
///
/// - With `trans_a == false`, `A` is stored `n × k` and `C := alpha * A Aᵀ + beta * C`.
/// - With `trans_a == true`, `A` is stored `k × n` and `C := alpha * Aᵀ A + beta * C`.
///
/// Only the lower triangle of the `n × n` matrix `C` is referenced and updated.
///
/// # Errors
/// Returns `Err` when:
/// - `n` or `k` is negative;
/// - a leading dimension is smaller than `max(1, rows)` of its stored matrix;
/// - `a` or `c` is shorter than `ld * cols`.
#[inline(always)]
pub fn syrk_panel(
    n: i32,
    k: i32,
    alpha: f64,
    a: &[f64],
    lda: i32,
    beta: f64,
    c: &mut [f64],
    ldc: i32,
    trans_a: bool,
) -> Result<(), &'static str> {
    use blas::dsyrk;
    if n < 0 || k < 0 {
        return Err("n and k must be non-negative");
    }
    // The stored shape of A depends on the transpose flag. Transposed, A is k × n and
    // spans `lda * n` entries, not `lda * k`.
    let (rows_a, cols_a) = if trans_a { (k, n) } else { (n, k) };
    if lda < rows_a.max(1) {
        return Err("lda too small for SYRK");
    }
    if ldc < n.max(1) {
        return Err("ldc too small for SYRK");
    }
    if (a.len() as i64) < i64::from(lda) * i64::from(cols_a) {
        return Err("A too small for SYRK");
    }
    if (c.len() as i64) < i64::from(ldc) * i64::from(n) {
        return Err("C too small for SYRK");
    }
    // SAFETY: both buffers cover the extents dsyrk accesses for the validated shapes.
    unsafe {
        dsyrk(
            b'L',
            if trans_a { b'T' } else { b'N' },
            n,
            k,
            alpha,
            a,
            lda,
            beta,
            c,
            ldc,
        );
    }
    Ok(())
}
/// Eigen-decomposition of a symmetric 2×2 matrix using LAPACK `dsyev` (upper triangle).
///
/// `a_in` is column-major and left untouched; a local copy is decomposed. On success
/// `eigvals` receives the eigenvalues in ascending order. `eigvecs` receives the
/// matching orthonormal eigenvectors as its columns (column-major).
///
/// # Errors
/// Returns `Err` if `dsyev` reports a non-zero `info`; `eigvecs` is then left unchanged.
#[inline(always)]
pub fn symeig2x2(
    a_in: &[f64; 4],
    eigvals: &mut [f64; 2],
    eigvecs: &mut [f64; 4],
) -> Result<(), &'static str> {
    // dsyev needs lwork ≥ 3n - 1 = 5 for n = 2; ten slots leave room to spare.
    const WORK_LEN: usize = 10;
    let mut matrix: [f64; 4] = *a_in;
    let mut scratch = [0.0_f64; WORK_LEN];
    let mut info = 0;
    // SAFETY: `matrix` is 2×2 with lda = 2, `eigvals` holds 2 values, and `scratch`
    // holds WORK_LEN entries as advertised.
    unsafe {
        dsyev(
            b'V',
            b'U',
            2,
            &mut matrix,
            2,
            eigvals,
            &mut scratch,
            WORK_LEN as i32,
            &mut info,
        );
    }
    if info == 0 {
        // dsyev overwrote `matrix` with the eigenvectors.
        *eigvecs = matrix;
        Ok(())
    } else {
        Err("LAPACK dsyev failed on 2×2")
    }
}
/// Reduces a general `m × n` matrix to bidiagonal form `Qᵀ A P = B` with LAPACK `dgebrd`.
///
/// On success:
/// - `d` holds the `min(m, n)` diagonal entries of `B`;
/// - `e` holds the `min(m, n) - 1` off-diagonal entries;
/// - `tauq` / `taup` hold the scalar factors of the reflectors for `Q` and `P`,
///   whose vectors are stored in `a`.
///
/// # Errors
/// Returns `Err` when a buffer is too short, the workspace query fails, or the
/// reduction reports a non-zero `info`.
#[inline(always)]
pub fn bidiag_reduction(
    m: i32,
    n: i32,
    a: &mut [f64],
    lda: i32,
    d: &mut [f64],
    e: &mut [f64],
    tauq: &mut [f64],
    taup: &mut [f64],
) -> Result<(), &'static str> {
    use lapack::dgebrd;
    let diag = m.min(n);
    let diag_len = diag as usize;
    if a.len() < (lda * n) as usize {
        return Err("A too small for GEBRD");
    }
    if d.len() < diag_len {
        return Err("d too small for GEBRD");
    }
    if e.len() < (diag - 1).max(0) as usize {
        return Err("e too small for GEBRD");
    }
    if tauq.len() < diag_len {
        return Err("tauq too small");
    }
    if taup.len() < diag_len {
        return Err("taup too small");
    }
    let mut info = 0;
    let mut optimal = [0.0_f64];
    // SAFETY: lwork = -1 is a size query that only writes `optimal[0]`.
    unsafe { dgebrd(m, n, a, lda, d, e, tauq, taup, &mut optimal, -1, &mut info) };
    if info != 0 {
        return Err("GEBRD workspace query failed");
    }
    let lwork = optimal[0] as i32;
    let mut scratch = vec![0.0_f64; lwork as usize];
    // SAFETY: `scratch` holds exactly `lwork` entries; other buffers were checked above.
    unsafe { dgebrd(m, n, a, lda, d, e, tauq, taup, &mut scratch, lwork, &mut info) };
    match info {
        0 => Ok(()),
        _ => Err("LAPACK dgebrd failed"),
    }
}
/// Placeholder for an SVD QR-iteration routine; not implemented yet.
///
/// # Panics
/// Always panics via `todo!()` ("not yet implemented"). Do not call until a real
/// implementation lands.
#[inline(always)]
pub fn svd_qr_iter() {
    todo!()
}
/// Singular value decomposition `A = U Σ Vᵀ` via LAPACK `dgesvd`.
///
/// `a` is `m × n` column-major (leading dimension `lda`) and is overwritten.
/// `s` receives the `min(m, n)` singular values. `jobu` / `jobvt` follow the LAPACK
/// convention (case-insensitive):
///
/// * `b'A'`: all `m` columns of `U` (`ldu ≥ m`, `u.len() ≥ ldu * m`), or all `n`
///   rows of `Vᵀ` (`ldvt ≥ n`, `vt.len() ≥ ldvt * n`);
/// * `b'S'`: the first `min(m, n)` columns of `U` (`ldu ≥ m`,
///   `u.len() ≥ ldu * min(m, n)`), or rows of `Vᵀ` (`ldvt ≥ min(m, n)`,
///   `vt.len() ≥ ldvt * n`);
/// * `b'O'` / `b'N'`: `u` / `vt` are not referenced.
///
/// # Errors
/// Returns `Err` when:
/// - `m` or `n` is negative;
/// - a buffer or leading dimension is too small for the requested jobs;
/// - the workspace query fails;
/// - `dgesvd` fails to converge (`info > 0`) or rejects an argument (`info < 0`).
#[inline(always)]
pub fn svd_block(
    jobu: u8,
    jobvt: u8,
    m: i32,
    n: i32,
    a: &mut [f64],
    lda: i32,
    s: &mut [f64],
    u: &mut [f64],
    ldu: i32,
    vt: &mut [f64],
    ldvt: i32,
) -> Result<(), &'static str> {
    use lapack::dgesvd;
    if m < 0 || n < 0 {
        return Err("m and n must be non-negative");
    }
    let k = m.min(n);
    // i64 arithmetic so `ld * cols` cannot overflow i32.
    let fits = |len: usize, ld: i32, cols: i32| (len as i64) >= i64::from(ld) * i64::from(cols);
    if !fits(a.len(), lda, n) {
        return Err("A too small for DGESVD");
    }
    if s.len() < k as usize {
        return Err("s too small for DGESVD");
    }
    // dgesvd writes U / Vᵀ for jobs 'A' and 'S'. Size them per the LAPACK contract
    // so it can never write past the end of `u` or `vt`.
    let u_cols = match jobu.to_ascii_uppercase() {
        b'A' => Some(m),
        b'S' => Some(k),
        _ => None,
    };
    if let Some(cols) = u_cols {
        if ldu < m.max(1) {
            return Err("ldu too small for DGESVD");
        }
        if !fits(u.len(), ldu, cols) {
            return Err("U too small for DGESVD");
        }
    }
    let vt_rows = match jobvt.to_ascii_uppercase() {
        b'A' => Some(n),
        b'S' => Some(k),
        _ => None,
    };
    if let Some(rows) = vt_rows {
        if ldvt < rows.max(1) {
            return Err("ldvt too small for DGESVD");
        }
        // Vᵀ always has n columns, whether all n rows or only min(m, n) are requested.
        if !fits(vt.len(), ldvt, n) {
            return Err("VT too small for DGESVD");
        }
    }
    let mut wk = [0.0_f64];
    let mut info = 0;
    // SAFETY: workspace query; only `wk[0]` is written.
    unsafe {
        dgesvd(
            jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, &mut wk, -1, &mut info,
        );
    }
    if info != 0 {
        return Err("DGESVD workspace query failed");
    }
    let lwork = (wk[0] as i32).max(1);
    let mut work = vec![0.0_f64; lwork as usize];
    // SAFETY: all output buffers were validated for the requested jobs above.
    unsafe {
        dgesvd(
            jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, &mut work, lwork, &mut info,
        );
    }
    if info == 0 {
        Ok(())
    } else if info > 0 {
        Err("DGESVD failed to converge")
    } else {
        Err("DGESVD argument error")
    }
}
/// Projects `m` observations onto `k` components: `Y := alpha * X * W + beta * Y`.
///
/// `X` is `m × f` (leading dimension `ldx`), `W` is `f × k` (`ldw`), and `Y` is
/// `m × k` (`ldy`), all column-major. The product is computed with BLAS `dgemm`.
///
/// # Errors
/// Returns `Err` when any of the three buffers is shorter than `ld * columns`.
#[inline(always)]
pub fn pca_project(
    m: i32,
    f: i32,
    k: i32,
    alpha: f64,
    x: &[f64],
    ldx: i32,
    w: &[f64],
    ldw: i32,
    beta: f64,
    y: &mut [f64],
    ldy: i32,
) -> Result<(), &'static str> {
    use blas::dgemm;
    let too_small = |len: usize, ld: i32, cols: i32| len < (ld * cols) as usize;
    if too_small(x.len(), ldx, f) {
        return Err("X buffer too small");
    }
    if too_small(w.len(), ldw, k) {
        return Err("W buffer too small");
    }
    if too_small(y.len(), ldy, k) {
        return Err("Y buffer too small");
    }
    // SAFETY: each buffer spans at least `ld * cols` entries as checked above.
    unsafe { dgemm(b'N', b'N', m, k, f, alpha, x, ldx, w, ldw, beta, y, ldy) };
    Ok(())
}
/// Accumulates `XᵀX` into a packed lower-triangular buffer with BLAS `dsyrk`.
///
/// `x` is an `obs × n_feat` column-major tile with leading dimension `ldx`. `c` holds
/// the lower triangle of the `n_feat × n_feat` result, packed row by row: entry
/// `(row, col)` with `col ≤ row` lives at `row * (row + 1) / 2 + col`. On return
/// `c := c + XᵀX`, so repeated calls accumulate across tiles.
///
/// # Errors
/// Returns `Err` when:
/// - `n_feat` or `obs` is negative;
/// - `c` is shorter than `n_feat * (n_feat + 1) / 2`;
/// - `ldx < max(1, obs)`;
/// - `x` is shorter than `ldx * n_feat`.
#[inline(always)]
pub fn cachecov_syrk(
    n_feat: i32,
    obs: i32,
    x: &[f64],
    ldx: i32,
    c: &mut [f64],
) -> Result<(), &'static str> {
    use blas::dsyrk;
    if n_feat < 0 || obs < 0 {
        return Err("n_feat and obs must be non-negative");
    }
    let n = n_feat as usize;
    let need = n * (n + 1) / 2;
    if c.len() < need {
        return Err("C buffer too small");
    }
    if ldx < obs.max(1) {
        return Err("ldx must be ≥ max(1, obs)");
    }
    // X is obs × n_feat with one column per feature, so it spans `ldx * n_feat`
    // entries. `ldx * obs` under-sizes it whenever n_feat > obs.
    if (x.len() as i64) < i64::from(ldx) * i64::from(n_feat) {
        return Err("X tile too small");
    }
    // Row-wise packed index of lower-triangle entry (row, col), col ≤ row.
    let packed = |row: usize, col: usize| row * (row + 1) / 2 + col;
    // dsyrk needs a full column-major matrix; unpack the lower triangle into one.
    let mut full = vec![0.0_f64; n * n];
    for col in 0..n {
        for row in col..n {
            full[row + col * n] = c[packed(row, col)];
        }
    }
    // SAFETY: `x` spans `ldx * n_feat` entries and `full` is n × n. `ldc` is clamped
    // to ≥ 1 because BLAS rejects ldc = 0 even when n_feat == 0.
    unsafe {
        dsyrk(
            b'L',
            b'T',
            n_feat,
            obs,
            1.0,
            x,
            ldx,
            1.0,
            &mut full,
            n_feat.max(1),
        );
    }
    // dsyrk updated only the lower triangle; pack it back.
    for col in 0..n {
        for row in col..n {
            c[packed(row, col)] = full[row + col * n];
        }
    }
    Ok(())
}
/// LU factorisation with partial pivoting via LAPACK `dgetrf`.
///
/// Overwrites the `m × n` column-major matrix `a` (leading dimension `lda`) with its
/// `L` and `U` factors. Stores the 1-based pivot indices in the first `min(m, n)`
/// entries of `piv`.
///
/// # Errors
/// Returns `Err` when:
/// - `a` or `piv` is undersized;
/// - `U` has an exactly zero diagonal entry (`info > 0`);
/// - LAPACK rejects an argument (`info < 0`).
#[inline(always)]
pub fn lufactor(
    m: i32,
    n: i32,
    a: &mut [f64],
    lda: i32,
    piv: &mut [i32],
) -> Result<(), &'static str> {
    use lapack::dgetrf;
    let pivot_count = m.min(n);
    if a.len() < (lda * n) as usize {
        return Err("A too small for GETRF");
    }
    if piv.len() < pivot_count as usize {
        return Err("pivot array too small");
    }
    let mut status = 0;
    // SAFETY: `a` and `piv` cover the extents dgetrf accesses.
    unsafe { dgetrf(m, n, a, lda, piv, &mut status) };
    match status {
        0 => Ok(()),
        code if code > 0 => Err("U is singular"),
        _ => Err("GETRF argument error"),
    }
}
/// Solves `U X = B` for an upper-triangular `n × n` matrix `U` with non-unit diagonal.
///
/// `u` (leading dimension `ldu`) and `b` (`n × nrhs`, leading dimension `ldb`) are
/// column-major. Each column of `b` is solved independently with BLAS `dtrsv` and
/// overwritten with the corresponding column of `X`.
///
/// # Errors
/// Returns `Err` when:
/// - `n` or `nrhs` is negative;
/// - `ldu` or `ldb` is smaller than `max(1, n)`;
/// - a buffer is shorter than `ldu * n` / `ldb * nrhs`.
///
/// The `ldb` check matters: with `ldb < n`, consecutive columns would overlap and the
/// last column slice would run past `b`, panicking instead of returning an error.
#[inline(always)]
pub fn trisolve_upper(
    n: i32,
    nrhs: i32,
    u: &[f64],
    ldu: i32,
    b: &mut [f64],
    ldb: i32,
) -> Result<(), &'static str> {
    if n < 0 || nrhs < 0 {
        return Err("n and nrhs must be non-negative");
    }
    if ldu < n.max(1) || ldb < n.max(1) {
        return Err("leading dimensions must be ≥ max(1, n)");
    }
    if (u.len() as i64) < i64::from(ldu) * i64::from(n) {
        return Err("U buffer too small");
    }
    if (b.len() as i64) < i64::from(ldb) * i64::from(nrhs) {
        return Err("B buffer too small");
    }
    let rows = n as usize;
    // ldb ≥ 1, so chunks_mut gets a valid chunk size. b.len() ≥ ldb * nrhs, so the
    // first nrhs chunks are all full-length (≥ n).
    for col in b.chunks_mut(ldb as usize).take(nrhs as usize) {
        // SAFETY: `u` spans ldu * n entries and `col[..rows]` holds n entries.
        unsafe { dtrsv(b'U', b'N', b'N', n, u, ldu, &mut col[..rows], 1) };
    }
    Ok(())
}
/// Solves `L X = B` for a lower-triangular `n × n` matrix `L` with non-unit diagonal.
///
/// `l` (leading dimension `ldl`) and `b` (`n × nrhs`, leading dimension `ldb`) are
/// column-major. Each column of `b` is solved independently with BLAS `dtrsv` and
/// overwritten with the corresponding column of `X`.
///
/// # Errors
/// Returns `Err` when:
/// - `n` or `nrhs` is negative;
/// - `ldl` or `ldb` is smaller than `max(1, n)`;
/// - a buffer is shorter than `ldl * n` / `ldb * nrhs`.
///
/// The `ldb` check matters: with `ldb < n`, consecutive columns would overlap and the
/// last column slice would run past `b`, panicking instead of returning an error.
#[inline(always)]
pub fn trisolve_lower(
    n: i32,
    nrhs: i32,
    l: &[f64],
    ldl: i32,
    b: &mut [f64],
    ldb: i32,
) -> Result<(), &'static str> {
    if n < 0 || nrhs < 0 {
        return Err("n and nrhs must be non-negative");
    }
    if ldl < n.max(1) || ldb < n.max(1) {
        return Err("leading dimensions must be ≥ max(1, n)");
    }
    if (l.len() as i64) < i64::from(ldl) * i64::from(n) {
        return Err("L buffer too small");
    }
    if (b.len() as i64) < i64::from(ldb) * i64::from(nrhs) {
        return Err("B buffer too small");
    }
    let rows = n as usize;
    // ldb ≥ 1, so chunks_mut gets a valid chunk size. b.len() ≥ ldb * nrhs, so the
    // first nrhs chunks are all full-length (≥ n).
    for col in b.chunks_mut(ldb as usize).take(nrhs as usize) {
        // SAFETY: `l` spans ldl * n entries and `col[..rows]` holds n entries.
        unsafe { dtrsv(b'L', b'N', b'N', n, l, ldl, &mut col[..rows], 1) };
    }
    Ok(())
}
/// Inverts a triangular `n × n` matrix in place with LAPACK `dtrtri` (non-unit diagonal).
///
/// `upper` selects whether the upper or lower triangle of `t` (column-major, leading
/// dimension `ldt`) is referenced and overwritten with the inverse.
///
/// # Errors
/// Returns `Err` when:
/// - `t` is shorter than `ldt * n`;
/// - the matrix has an exactly zero diagonal entry (`info > 0`);
/// - LAPACK rejects an argument (`info < 0`).
#[inline(always)]
pub fn tri_inverse(n: i32, t: &mut [f64], ldt: i32, upper: bool) -> Result<(), &'static str> {
    use std::cmp::Ordering;
    let required = (ldt * n) as usize;
    if t.len() < required {
        return Err("T buffer too small");
    }
    let triangle = if upper { b'U' } else { b'L' };
    let mut info = 0;
    // SAFETY: `t` spans at least `ldt * n` entries.
    unsafe { dtrtri(triangle, b'N', n, t, ldt, &mut info) };
    match info.cmp(&0) {
        Ordering::Equal => Ok(()),
        Ordering::Greater => Err("T is singular"),
        Ordering::Less => Err("DTRTRI argument error"),
    }
}
/// Lower Cholesky factorisation `A = L Lᵀ` of a symmetric positive-definite matrix.
///
/// Calls LAPACK `dpotrf` with `uplo = 'L'`. On success the lower triangle of `a`
/// (column-major, leading dimension `lda`) holds `L`.
///
/// # Errors
/// Returns `Err` when:
/// - `a` is shorter than `lda * n`;
/// - the matrix is not positive-definite (`info > 0`);
/// - LAPACK rejects an argument (`info < 0`).
#[inline(always)]
pub fn spd_cholesky(n: i32, a: &mut [f64], lda: i32) -> Result<(), &'static str> {
    let needed = (lda * n) as usize;
    if needed > a.len() {
        return Err("A buffer too small");
    }
    let mut info = 0;
    // SAFETY: `a` spans at least `lda * n` entries.
    unsafe {
        dpotrf(b'L', n, a, lda, &mut info);
    }
    if info > 0 {
        Err("Matrix is not SPD")
    } else if info < 0 {
        Err("DPOTRF argument error")
    } else {
        Ok(())
    }
}
/// Solves `A X = B` for a symmetric positive-definite `A` given its lower Cholesky factor.
///
/// `l` must hold the factor produced by a lower-triangular `dpotrf`, with leading
/// dimension `ldl`. `b` (`n × nrhs`, leading dimension `ldb`) is overwritten with `X`
/// via LAPACK `dpotrs`.
///
/// # Errors
/// Returns `Err` when a buffer is too short or `dpotrs` reports a non-zero `info`.
#[inline(always)]
pub fn spd_solve(
    n: i32,
    nrhs: i32,
    l: &[f64],
    ldl: i32,
    b: &mut [f64],
    ldb: i32,
) -> Result<(), &'static str> {
    let factor_len = (ldl * n) as usize;
    if l.len() < factor_len {
        return Err("L buffer too small");
    }
    let rhs_len = (ldb * nrhs) as usize;
    if b.len() < rhs_len {
        return Err("B buffer too small");
    }
    let mut status = 0;
    // SAFETY: `l` and `b` cover the extents dpotrs accesses.
    unsafe { dpotrs(b'L', n, nrhs, l, ldl, b, ldb, &mut status) };
    if status != 0 {
        return Err("DPOTRS failed");
    }
    Ok(())
}
/// Computes the inverse of an SPD matrix from its lower Cholesky factor (LAPACK `dpotri`).
///
/// `l` must hold the lower factor from `dpotrf`. On success its lower triangle is
/// overwritten with the lower triangle of `A⁻¹`.
///
/// # Errors
/// Returns `Err` when:
/// - `l` is shorter than `ldl * n`;
/// - a diagonal entry of the factor is zero (`info > 0`);
/// - LAPACK rejects an argument (`info < 0`).
#[inline(always)]
pub fn spd_inverse(n: i32, l: &mut [f64], ldl: i32) -> Result<(), &'static str> {
    if l.len() >= (ldl * n) as usize {
        let mut info = 0;
        // SAFETY: `l` spans at least `ldl * n` entries.
        unsafe { dpotri(b'L', n, l, ldl, &mut info) };
        match info {
            0 => Ok(()),
            code if code > 0 => Err("Matrix not SPD (dpotri)"),
            _ => Err("DPOTRI argument error"),
        }
    } else {
        Err("L buffer too small")
    }
}
#[inline(always)]
pub fn qr_panel(
m: i32,
n: i32,
a: &mut [f64],
lda: i32, taus: &mut [f64], ) -> Result<(), &'static str> {
if a.len() < (lda * n) as usize {
return Err("A buffer too small");
}
if taus.len() < n as usize {
return Err("TAU buffer too small");
}
let mut lwork = -1;
let mut work_q = [0.0_f64];
let mut info = 0;
unsafe { dgeqrf(m, n, a, lda, taus, &mut work_q, lwork, &mut info) };
if info != 0 {
return Err("DGEQRF work-query failed");
}
lwork = work_q[0] as i32;
let mut work = vec![0.0_f64; lwork as usize];
unsafe { dgeqrf(m, n, a, lda, taus, &mut work, lwork, &mut info) };
if info == 0 {
Ok(())
} else {
Err("DGEQRF factorisation failed")
}
}
/// Forms the explicit `m × n` matrix `Q` with orthonormal columns via LAPACK `dorgqr`.
///
/// `Q` is built from the first `k` Householder reflectors stored in `a` and `taus`,
/// as produced by `dgeqrf`. The result overwrites `a` (leading dimension `lda`).
///
/// # Errors
/// Returns `Err` when:
/// - `a` is shorter than `lda * n`, or `taus` is shorter than `k`;
/// - the workspace query fails;
/// - `dorgqr` reports a non-zero `info`.
#[inline(always)]
pub fn qr_form_q(
    m: i32,
    n: i32,
    k: i32,
    a: &mut [f64],
    lda: i32,
    taus: &[f64],
) -> Result<(), &'static str> {
    if a.len() < (lda * n) as usize {
        return Err("A buffer too small");
    }
    if taus.len() < k as usize {
        return Err("TAU buffer too small");
    }
    let mut info = 0;
    let mut optimal = [0.0_f64];
    // Query pass: lwork = -1 makes dorgqr store its preferred workspace size in optimal[0].
    // SAFETY: buffers validated above; query mode only writes `optimal[0]`.
    unsafe { dorgqr(m, n, k, a, lda, taus, &mut optimal, -1, &mut info) };
    if info != 0 {
        return Err("DORGQR work-query failed");
    }
    let lwork = optimal[0] as i32;
    let mut workspace = vec![0.0_f64; lwork as usize];
    // SAFETY: `workspace` holds exactly `lwork` entries.
    unsafe { dorgqr(m, n, k, a, lda, taus, &mut workspace, lwork, &mut info) };
    if info != 0 {
        Err("DORGQR failed")
    } else {
        Ok(())
    }
}
/// Solves the least-squares (or minimum-norm) problem for `A X ≈ B` with LAPACK `dgels`.
///
/// Uses a QR/LQ factorisation of `A` with no transpose. `a` is `m × n` column-major
/// (leading dimension `lda`) and is overwritten by its factorisation. `b` holds the
/// `nrhs` right-hand sides with leading dimension `ldb`; LAPACK requires
/// `ldb ≥ max(1, m, n)`. On success the first `n` rows of each column of `b` hold
/// the solution.
///
/// # Errors
/// Returns `Err` when:
/// - `a` or `b` is undersized;
/// - the workspace query fails;
/// - the solve reports a non-zero `info` (`info > 0` means `A` is not of full rank).
#[inline(always)]
pub fn least_squares_qr(
    m: i32,
    n: i32,
    nrhs: i32,
    a: &mut [f64],
    lda: i32,
    b: &mut [f64],
    ldb: i32,
) -> Result<(), &'static str> {
    // i64 arithmetic so `ld * cols` cannot overflow i32.
    if (a.len() as i64) < i64::from(lda) * i64::from(n) {
        return Err("A buffer too small");
    }
    if (b.len() as i64) < i64::from(ldb) * i64::from(nrhs) {
        return Err("B buffer too small");
    }
    let mut work_q = [0.0_f64];
    let mut info = 0;
    // NOTE(review): called with the eleven arguments of the `lapack` crate signature
    // (trans, m, n, nrhs, a, lda, b, ldb, work, lwork, info), matching how every
    // other LAPACK routine in this file is invoked. Confirm against the binding.
    // SAFETY: workspace query; only `work_q[0]` is written.
    unsafe {
        dgels(
            b'N',
            m,
            n,
            nrhs,
            a,
            lda,
            b,
            ldb,
            &mut work_q,
            -1,
            &mut info,
        )
    };
    if info != 0 {
        return Err("DGELS work-query failed");
    }
    let lwork = (work_q[0] as i32).max(1);
    let mut work = vec![0.0_f64; lwork as usize];
    // SAFETY: `work` holds exactly `lwork` entries; `a` and `b` were validated above.
    unsafe { dgels(b'N', m, n, nrhs, a, lda, b, ldb, &mut work, lwork, &mut info) };
    if info == 0 {
        Ok(())
    } else {
        Err("DGELS failed")
    }
}
#[inline(always)]
pub fn symeig_full(
n: i32,
a: &mut [f64],
lda: i32, w: &mut [f64], ) -> Result<(), &'static str> {
if a.len() < (lda * n) as usize {
return Err("A buffer too small");
}
if w.len() < n as usize {
return Err("W buffer too small");
}
let mut lwork = -1;
let mut work_q = [0.0_f64];
let mut info = 0;
unsafe { dsyev(b'V', b'U', n, a, lda, w, &mut work_q, lwork, &mut info) };
if info != 0 {
return Err("DSYEV work-query failed");
}
lwork = work_q[0] as i32;
let mut work = vec![0.0_f64; lwork as usize];
unsafe { dsyev(b'V', b'U', n, a, lda, w, &mut work, lwork, &mut info) };
if info == 0 {
Ok(())
} else {
Err("DSYEV failed")
}
}
/// Rank-k update of the lower triangle: `C := alpha * A Aᵀ + beta * C` (BLAS `dsyrk`).
///
/// `A` is stored `n × k` (leading dimension `lda`). Only the lower triangle of the
/// `n × n` matrix `C` (leading dimension `ldc`) is referenced and updated.
///
/// # Errors
/// Returns `Err` when `a` is shorter than `lda * k` or `c` is shorter than `ldc * n`.
#[inline(always)]
pub fn syrk_fisher_info(
    n: i32,
    k: i32,
    alpha: f64,
    a: &[f64],
    lda: i32,
    beta: f64,
    c: &mut [f64],
    ldc: i32,
) -> Result<(), &'static str> {
    const LOWER: u8 = b'L';
    const NO_TRANS: u8 = b'N';
    if (lda * k) as usize > a.len() {
        return Err("A buffer too small");
    }
    if (ldc * n) as usize > c.len() {
        return Err("C buffer too small");
    }
    // SAFETY: both buffers span the extents dsyrk accesses for an n × k `A`.
    unsafe { dsyrk(LOWER, NO_TRANS, n, k, alpha, a, lda, beta, c, ldc) };
    Ok(())
}
/// Symmetric rank-2k update `C := alpha * (Aᵀ B + Bᵀ A) + beta * C` via BLAS `dsyr2k`.
///
/// Uses `uplo = 'L'` and `trans = 'T'`. `A` and `B` are stored `k × n` column-major,
/// with leading dimensions `lda`, `ldb ≥ max(1, k)`. Only the lower triangle of the
/// `n × n` matrix `C` (leading dimension `ldc ≥ max(1, n)`) is referenced and updated.
///
/// # Errors
/// Returns `Err` when:
/// - `n` or `k` is negative;
/// - a leading dimension is too small;
/// - `a`, `b` or `c` is shorter than `ld * n`.
#[inline(always)]
pub fn sym_rank2k_update(
    n: i32,
    k: i32,
    alpha: f64,
    a: &[f64],
    lda: i32,
    b: &[f64],
    ldb: i32,
    beta: f64,
    c: &mut [f64],
    ldc: i32,
) -> Result<(), &'static str> {
    if n < 0 || k < 0 {
        return Err("n and k must be non-negative");
    }
    if lda < k.max(1) || ldb < k.max(1) {
        return Err("lda and ldb must be ≥ max(1, k)");
    }
    if ldc < n.max(1) {
        return Err("ldc must be ≥ max(1, n)");
    }
    // With trans = 'T', A and B are k × n, so each spans `ld * n` entries, not `ld * k`.
    // i64 arithmetic keeps the products from overflowing i32.
    let fits = |len: usize, ld: i32| (len as i64) >= i64::from(ld) * i64::from(n);
    if !fits(a.len(), lda) {
        return Err("A buffer too small");
    }
    if !fits(b.len(), ldb) {
        return Err("B buffer too small");
    }
    if !fits(c.len(), ldc) {
        return Err("C buffer too small");
    }
    // SAFETY: every buffer covers the extent dsyr2k accesses for the validated shapes.
    unsafe {
        dsyr2k(b'L', b'T', n, k, alpha, a, lda, b, ldb, beta, c, ldc);
    }
    Ok(())
}
/// Solves `A X = B` using an LU factorisation from `dgetrf` (LAPACK `dgetrs`, no transpose).
///
/// `a` holds the packed `L`/`U` factors (leading dimension `lda`) and `ipiv` holds
/// the pivots. `b` (`n × nrhs`, leading dimension `ldb`) is overwritten with `X`.
///
/// # Errors
/// Returns `Err` when a buffer is too short or `dgetrs` reports a non-zero `info`.
#[inline(always)]
pub fn getrs(
    n: i32,
    nrhs: i32,
    a: &[f64],
    lda: i32,
    ipiv: &[i32],
    b: &mut [f64],
    ldb: i32,
) -> Result<(), &'static str> {
    let order = n as usize;
    if a.len() < (lda * n) as usize {
        return Err("A buffer too small for GETRS");
    }
    if ipiv.len() < order {
        return Err("ipiv too small for GETRS");
    }
    if b.len() < (ldb * nrhs) as usize {
        return Err("B buffer too small for GETRS");
    }
    let mut status = 0;
    // SAFETY: all buffers were checked against the extents dgetrs accesses.
    unsafe { dgetrs(b'N', n, nrhs, a, lda, ipiv, b, ldb, &mut status) };
    if status != 0 {
        return Err("LAPACK dgetrs failed");
    }
    Ok(())
}
/// Inverts a matrix in place from its LU factorisation (LAPACK `dgetri`).
///
/// `a` must hold the `dgetrf` factors (leading dimension `lda`) and `ipiv` the
/// matching pivots. A workspace-size query precedes the inversion.
///
/// # Errors
/// Returns `Err` when:
/// - a buffer is too short, or the query fails;
/// - `U` has an exactly zero diagonal entry (`info > 0`);
/// - LAPACK rejects an argument (`info < 0`).
#[inline(always)]
pub fn getri(n: i32, a: &mut [f64], lda: i32, ipiv: &[i32]) -> Result<(), &'static str> {
    use std::cmp::Ordering;
    if a.len() < (lda * n) as usize {
        return Err("A buffer too small for GETRI");
    }
    if ipiv.len() < n as usize {
        return Err("ipiv too small for GETRI");
    }
    let mut info = 0;
    let mut probe = [0.0_f64; 1];
    // SAFETY: query mode (lwork = -1) only writes `probe[0]`.
    unsafe { dgetri(n, a, lda, ipiv, &mut probe, -1, &mut info) };
    if info != 0 {
        return Err("DGETRI workspace query failed");
    }
    let lwork = probe[0] as i32;
    let mut scratch = vec![0.0_f64; lwork as usize];
    // SAFETY: `scratch` holds exactly `lwork` entries.
    unsafe { dgetri(n, a, lda, ipiv, &mut scratch, lwork, &mut info) };
    match info.cmp(&0) {
        Ordering::Equal => Ok(()),
        Ordering::Greater => Err("Matrix is singular (dgetri)"),
        Ordering::Less => Err("DGETRI argument error"),
    }
}
#[cfg(test)]
#[cfg(feature = "linear_algebra")]
mod tests {
    // Tests for the BLAS/LAPACK wrappers above; compiled only in test builds with the
    // `linear_algebra` feature enabled. Matrices are flat column-major arrays:
    // element (row i, col j) of a matrix with leading dimension `ld` is at `j * ld + i`.
    use super::*;
    // Absolute tolerance for floating-point comparisons.
    const TOL: f64 = 1e-10;
    // Asserts |a - b| < TOL; on failure reports `msg`, the expected and the actual value.
    fn assert_near(a: f64, b: f64, msg: &str) {
        assert!(
            (a - b).abs() < TOL,
            "{msg}: expected {b}, got {a}, diff {}",
            (a - b).abs()
        );
    }
    #[test]
    fn test_gemv_no_trans() {
        // A = [[1, 2, 3], [4, 5, 6]] (2×3, lda 2), x = [1, 2, 3] → A x = [14, 32].
        let a = [1.0, 4.0, 2.0, 5.0, 3.0, 6.0];
        let x = [1.0, 2.0, 3.0];
        let mut y = [0.0; 2];
        gemv(2, 3, &a, 2, &x, 1, &mut y, 1, 1.0, 0.0, false).unwrap();
        assert_near(y[0], 14.0, "y[0]");
        assert_near(y[1], 32.0, "y[1]");
    }
    #[test]
    fn test_gemv_trans() {
        // Same A; Aᵀ x with x = [1, 2] → [9, 12, 15].
        let a = [1.0, 4.0, 2.0, 5.0, 3.0, 6.0];
        let x = [1.0, 2.0];
        let mut y = [0.0; 3];
        gemv(2, 3, &a, 2, &x, 1, &mut y, 1, 1.0, 0.0, true).unwrap();
        assert_near(y[0], 9.0, "y[0]");
        assert_near(y[1], 12.0, "y[1]");
        assert_near(y[2], 15.0, "y[2]");
    }
    #[test]
    fn test_gemv_bad_buffer() {
        // A 2×3 matrix needs 6 entries; only 4 are supplied.
        let a = [1.0; 4]; let x = [1.0; 3];
        let mut y = [0.0; 2];
        let result = gemv(2, 3, &a, 2, &x, 1, &mut y, 1, 1.0, 0.0, false);
        assert!(result.is_err());
    }
    #[test]
    fn test_gemm_4x4_identity() {
        // I · I = I.
        let mut a = [0.0_f64; 16];
        let mut b = [0.0_f64; 16];
        for i in 0..4 {
            a[i * 4 + i] = 1.0;
            b[i * 4 + i] = 1.0;
        }
        let mut c = [0.0_f64; 16];
        gemm_4x4_microkernel(&a, &b, &mut c, 1.0, 0.0);
        for i in 0..4 {
            for j in 0..4 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_near(c[j * 4 + i], expected, &format!("c[{i},{j}]"));
            }
        }
    }
    #[test]
    fn test_gemm_4x4_accumulation() {
        // beta = 1 keeps the existing 2·I in C, so the result is I + 2·I = 3·I.
        let mut a = [0.0_f64; 16];
        let mut b = [0.0_f64; 16];
        let mut c = [0.0_f64; 16];
        for i in 0..4 {
            a[i * 4 + i] = 1.0;
            b[i * 4 + i] = 1.0;
            c[i * 4 + i] = 2.0;
        }
        gemm_4x4_microkernel(&a, &b, &mut c, 1.0, 1.0);
        for i in 0..4 {
            assert_near(c[i * 4 + i], 3.0, &format!("c[{i},{i}]"));
        }
    }
    #[test]
    fn test_trisolve_2x2_upper() {
        // U = [[2, 3], [0, 4]], b = [11, 4] → x = [4, 1].
        let mut u = [2.0, 0.0, 3.0, 4.0];
        let mut b = [11.0, 4.0];
        trisolve_2x2(true, &mut u, &mut b).unwrap();
        assert_near(b[0], 4.0, "x[0]");
        assert_near(b[1], 1.0, "x[1]");
    }
    #[test]
    fn test_trisolve_2x2_lower() {
        // L = [[2, 0], [3, 4]], b = [6, 19] → x = [3, 2.5].
        let mut l = [2.0, 3.0, 0.0, 4.0];
        let mut b = [6.0, 19.0];
        trisolve_2x2(false, &mut l, &mut b).unwrap();
        assert_near(b[0], 3.0, "x[0]");
        assert_near(b[1], 2.5, "x[1]");
    }
    #[test]
    fn test_householder_3x2() {
        // A = [[1, 2], [3, 4], [5, 6]] (3×2, lda 3).
        let mut a = [1.0, 3.0, 5.0, 2.0, 4.0, 6.0];
        let mut taus = [0.0; 2];
        householder_apply(3, 2, &mut a, 3, &mut taus).unwrap();
        let r00 = a[0];
        assert!(r00.abs() > 1e-10, "R[0,0] should be nonzero");
        assert!(taus[0].abs() > 1e-10, "tau[0] should be nonzero");
    }
    #[test]
    fn test_householder_r_upper_triangular() {
        // NOTE(review): despite its name, this test only checks that R[0,0] is
        // non-zero; it does not verify the structure below the diagonal.
        let mut a = [2.0, 0.0, 0.0, 1.0, 3.0, 0.0]; let mut taus = [0.0; 2];
        householder_apply(3, 2, &mut a, 3, &mut taus).unwrap();
        assert!(a[0].abs() > 1e-10, "R[0,0] should be nonzero");
    }
    #[test]
    fn test_cholesky_panel_2x2() {
        // A = [[4, 2], [2, 3]] → L = [[2, 0], [1, √2]].
        let mut a = [4.0, 2.0, 2.0, 3.0];
        cholesky_panel(2, &mut a, 2).unwrap();
        assert_near(a[0], 2.0, "L[0,0]");
        assert_near(a[1], 1.0, "L[1,0]");
        assert_near(a[3], 2.0_f64.sqrt(), "L[1,1]");
    }
    #[test]
    fn test_cholesky_panel_not_spd() {
        // [[1, 2], [2, 1]] has a negative eigenvalue, so dpotrf must fail.
        let mut a = [1.0, 2.0, 2.0, 1.0];
        let result = cholesky_panel(2, &mut a, 2);
        assert!(result.is_err());
    }
    #[test]
    fn test_lu_with_piv_2x2() {
        // A = [[2, 1], [6, 4]].
        let mut a = [2.0, 6.0, 1.0, 4.0];
        let mut ipiv = [0_i32; 2];
        lu_with_piv(2, 2, &mut a, 2, &mut ipiv).unwrap();
        assert!(ipiv[0] > 0, "pivot should be set");
    }
    #[test]
    fn test_lu_with_piv_singular() {
        // [[1, 2], [2, 4]] has linearly dependent columns.
        let mut a = [1.0, 2.0, 2.0, 4.0];
        let mut ipiv = [0_i32; 2];
        let result = lu_with_piv(2, 2, &mut a, 2, &mut ipiv);
        assert!(result.is_err());
    }
    #[test]
    fn test_lu_with_piv_reconstruct() {
        // Rebuilds L·U from the packed factors and compares it with P·A. Only the
        // first pivot is applied; for a 2×2 factorisation the last pivot is always 2.
        let mut a = [2.0, 6.0, 1.0, 4.0];
        let a_orig = a;
        let mut ipiv = [0_i32; 2];
        lu_with_piv(2, 2, &mut a, 2, &mut ipiv).unwrap();
        let l00 = 1.0;
        let l10 = a[1]; let l01 = 0.0;
        let l11 = 1.0;
        let u00 = a[0]; let u01 = a[2];
        let u10 = 0.0;
        let u11 = a[3];
        let lu00 = l00 * u00 + l01 * u10;
        let lu10 = l10 * u00 + l11 * u10;
        let lu01 = l00 * u01 + l01 * u11;
        let lu11 = l10 * u01 + l11 * u11;
        let mut pa = a_orig;
        if ipiv[0] != 1 {
            // Swap row 0 with row ipiv[0]-1 in both columns.
            pa.swap(0, (ipiv[0] - 1) as usize);
            pa.swap(2, 2 + (ipiv[0] - 1) as usize);
        }
        assert_near(lu00, pa[0], "PA=LU [0,0]");
        assert_near(lu10, pa[1], "PA=LU [1,0]");
        assert_near(lu01, pa[2], "PA=LU [0,1]");
        assert_near(lu11, pa[3], "PA=LU [1,1]");
    }
    #[test]
    fn test_qr_block_3x2() {
        // |R[0,0]| equals the norm of the first column [1, 3, 5], i.e. √35.
        let mut a = [1.0, 3.0, 5.0, 2.0, 4.0, 6.0];
        let mut taus = [0.0; 2];
        qr_block(3, 2, &mut a, 3, &mut taus).unwrap();
        assert!(taus[0].abs() > 1e-10, "tau[0] should be nonzero");
        assert!(taus[1].abs() > 1e-10, "tau[1] should be nonzero");
        assert_near(a[0].abs(), 35.0_f64.sqrt(), "|R[0,0]|");
    }
    #[test]
    fn test_qr_block_taus_too_small() {
        // n = 2 reflectors need two tau slots; only one is supplied.
        let mut a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut taus = [0.0; 1]; let result = qr_block(3, 2, &mut a, 3, &mut taus);
        assert!(result.is_err());
    }
    #[test]
    fn test_spd_cholesky_3x3() {
        // A = [[4, 2, 1], [2, 5, 3], [1, 3, 6]] → first column of L is [2, 1, 0.5], L[1,1] = 2.
        let mut a = [4.0, 2.0, 1.0, 2.0, 5.0, 3.0, 1.0, 3.0, 6.0];
        spd_cholesky(3, &mut a, 3).unwrap();
        assert_near(a[0], 2.0, "L[0,0]");
        assert_near(a[1], 1.0, "L[1,0]");
        assert_near(a[2], 0.5, "L[2,0]");
        assert_near(a[4], 2.0, "L[1,1]");
    }
    #[test]
    fn test_spd_cholesky_not_spd() {
        // The leading 2×2 block [[1, 2], [2, 1]] is indefinite.
        let mut a = [1.0, 2.0, 0.0, 2.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let result = spd_cholesky(3, &mut a, 3);
        assert!(result.is_err());
    }
    #[test]
    fn test_spd_solve_2x2() {
        // [[4, 2], [2, 3]] x = [8, 7] → x = [1.25, 1.5].
        let mut l = [4.0, 2.0, 2.0, 3.0];
        spd_cholesky(2, &mut l, 2).unwrap();
        let mut b = [8.0, 7.0];
        spd_solve(2, 1, &l, 2, &mut b, 2).unwrap();
        assert_near(b[0], 1.25, "x[0]");
        assert_near(b[1], 1.5, "x[1]");
    }
    #[test]
    fn test_spd_solve_identity() {
        // Solving with the identity leaves b unchanged.
        let mut l = [1.0, 0.0, 0.0, 1.0];
        spd_cholesky(2, &mut l, 2).unwrap();
        let mut b = [3.0, 7.0];
        spd_solve(2, 1, &l, 2, &mut b, 2).unwrap();
        assert_near(b[0], 3.0, "x[0]");
        assert_near(b[1], 7.0, "x[1]");
    }
    #[test]
    fn test_spd_inverse_2x2() {
        // [[4, 2], [2, 3]]⁻¹ = (1/8)·[[3, -2], [-2, 4]]; only the lower triangle is checked.
        let mut a = [4.0, 2.0, 2.0, 3.0];
        spd_cholesky(2, &mut a, 2).unwrap();
        spd_inverse(2, &mut a, 2).unwrap();
        assert_near(a[0], 3.0 / 8.0, "A_inv[0,0]");
        assert_near(a[1], -2.0 / 8.0, "A_inv[1,0]");
        assert_near(a[3], 4.0 / 8.0, "A_inv[1,1]");
    }
    #[test]
    fn test_spd_inverse_product_identity() {
        // dpotri fills only the lower triangle, so mirror it before multiplying A · A⁻¹.
        let a_orig = [4.0, 2.0, 2.0, 3.0];
        let mut a = a_orig;
        spd_cholesky(2, &mut a, 2).unwrap();
        spd_inverse(2, &mut a, 2).unwrap();
        let a_inv = [a[0], a[1], a[1], a[3]];
        let mut prod = [0.0_f64; 4];
        blocked_gemm(
            2, 2, 2, 1.0, &a_orig, 2, &a_inv, 2, 0.0, &mut prod, 2, false, false,
        )
        .unwrap();
        assert_near(prod[0], 1.0, "I[0,0]");
        assert_near(prod[1], 0.0, "I[1,0]");
        assert_near(prod[2], 0.0, "I[0,1]");
        assert_near(prod[3], 1.0, "I[1,1]");
    }
    #[test]
    fn test_trisolve_upper_3x3() {
        // U = [[2, 1, 3], [0, 4, 2], [0, 0, 5]], b = [13, 18, 10] → x = [1.75, 3.5, 2].
        let u = [2.0, 0.0, 0.0, 1.0, 4.0, 0.0, 3.0, 2.0, 5.0];
        let mut b = [13.0, 18.0, 10.0];
        trisolve_upper(3, 1, &u, 3, &mut b, 3).unwrap();
        assert_near(b[0], 1.75, "x[0]");
        assert_near(b[1], 3.5, "x[1]");
        assert_near(b[2], 2.0, "x[2]");
    }
    #[test]
    fn test_trisolve_upper_identity() {
        let u = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let mut b = [5.0, 7.0, 9.0];
        trisolve_upper(3, 1, &u, 3, &mut b, 3).unwrap();
        assert_near(b[0], 5.0, "x[0]");
        assert_near(b[1], 7.0, "x[1]");
        assert_near(b[2], 9.0, "x[2]");
    }
    #[test]
    fn test_trisolve_lower_3x3() {
        // L = [[3, 0, 0], [1, 4, 0], [2, 3, 5]], b = [9, 17, 37] → x = [3, 3.5, 4.1].
        let l = [3.0, 1.0, 2.0, 0.0, 4.0, 3.0, 0.0, 0.0, 5.0];
        let mut b = [9.0, 17.0, 37.0];
        trisolve_lower(3, 1, &l, 3, &mut b, 3).unwrap();
        assert_near(b[0], 3.0, "x[0]");
        assert_near(b[1], 3.5, "x[1]");
        assert_near(b[2], 4.1, "x[2]");
    }
    #[test]
    fn test_trisolve_lower_identity() {
        let l = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let mut b = [2.0, 4.0, 6.0];
        trisolve_lower(3, 1, &l, 3, &mut b, 3).unwrap();
        assert_near(b[0], 2.0, "x[0]");
        assert_near(b[1], 4.0, "x[1]");
        assert_near(b[2], 6.0, "x[2]");
    }
    #[test]
    fn test_tri_inverse_upper_3x3() {
        // T · T⁻¹ must equal the identity; the zero lower triangle of T is left as zero.
        let mut t = [2.0, 0.0, 0.0, 1.0, 4.0, 0.0, 3.0, 2.0, 5.0];
        let t_orig = t;
        tri_inverse(3, &mut t, 3, true).unwrap();
        let mut prod = [0.0_f64; 9];
        blocked_gemm(
            3, 3, 3, 1.0, &t_orig, 3, &t, 3, 0.0, &mut prod, 3, false, false,
        )
        .unwrap();
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_near(prod[j * 3 + i], expected, &format!("I[{i},{j}]"));
            }
        }
    }
    #[test]
    fn test_tri_inverse_lower_3x3() {
        // Same identity check for a lower-triangular T.
        let mut t = [3.0, 1.0, 2.0, 0.0, 4.0, 3.0, 0.0, 0.0, 5.0];
        let t_orig = t;
        tri_inverse(3, &mut t, 3, false).unwrap();
        let mut prod = [0.0_f64; 9];
        blocked_gemm(
            3, 3, 3, 1.0, &t_orig, 3, &t, 3, 0.0, &mut prod, 3, false, false,
        )
        .unwrap();
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_near(prod[j * 3 + i], expected, &format!("I[{i},{j}]"));
            }
        }
    }
    #[test]
    fn test_blocked_gemm_3x3_times_3x2() {
        // A = [[1, 2, 3], [4, 5, 6], [7, 8, 9]], B = [[1, 2], [3, 4], [5, 6]].
        let a = [1.0, 4.0, 7.0, 2.0, 5.0, 8.0, 3.0, 6.0, 9.0];
        let b = [1.0, 3.0, 5.0, 2.0, 4.0, 6.0];
        let mut c = [0.0_f64; 6]; blocked_gemm(3, 2, 3, 1.0, &a, 3, &b, 3, 0.0, &mut c, 3, false, false).unwrap();
        assert_near(c[0], 22.0, "c[0,0]");
        assert_near(c[1], 49.0, "c[1,0]");
        assert_near(c[2], 76.0, "c[2,0]");
        assert_near(c[3], 28.0, "c[0,1]");
        assert_near(c[4], 64.0, "c[1,1]");
        assert_near(c[5], 100.0, "c[2,1]");
    }
    #[test]
    fn test_blocked_gemm_trans_a() {
        // A = [[1, 2], [3, 4]]; Aᵀ A = [[10, 14], [14, 20]].
        let a = [1.0, 3.0, 2.0, 4.0];
        let mut c = [0.0_f64; 4];
        blocked_gemm(2, 2, 2, 1.0, &a, 2, &a, 2, 0.0, &mut c, 2, true, false).unwrap();
        assert_near(c[0], 10.0, "c[0,0]");
        assert_near(c[1], 14.0, "c[1,0]");
        assert_near(c[2], 14.0, "c[0,1]");
        assert_near(c[3], 20.0, "c[1,1]");
    }
    #[test]
    fn test_blocked_gemm_trans_b() {
        // A Aᵀ = [[5, 11], [11, 25]].
        let a = [1.0, 3.0, 2.0, 4.0];
        let b = [1.0, 3.0, 2.0, 4.0];
        let mut c = [0.0_f64; 4];
        blocked_gemm(2, 2, 2, 1.0, &a, 2, &b, 2, 0.0, &mut c, 2, false, true).unwrap();
        assert_near(c[0], 5.0, "c[0,0]");
        assert_near(c[1], 11.0, "c[1,0]");
        assert_near(c[2], 11.0, "c[0,1]");
        assert_near(c[3], 25.0, "c[1,1]");
    }
    #[test]
    fn test_blocked_gemm_bad_buffer() {
        // B is 2×3 and needs 6 entries; only 2 are supplied.
        let a = [1.0; 4]; let b = [1.0; 2]; let mut c = [0.0_f64; 6];
        let result = blocked_gemm(2, 3, 2, 1.0, &a, 2, &b, 2, 0.0, &mut c, 2, false, false);
        assert!(result.is_err());
    }
    #[test]
    fn test_syrk_panel_aat() {
        // A = [[1, 2, 3], [4, 5, 6]] (2×3); A Aᵀ = [[14, 32], [32, 77]] (lower triangle checked).
        let a = [1.0, 4.0, 2.0, 5.0, 3.0, 6.0];
        let mut c = [0.0_f64; 4]; syrk_panel(2, 3, 1.0, &a, 2, 0.0, &mut c, 2, false).unwrap();
        assert_near(c[0], 14.0, "c[0,0]");
        assert_near(c[1], 32.0, "c[1,0]");
        assert_near(c[3], 77.0, "c[1,1]");
    }
    #[test]
    fn test_syrk_panel_symmetric() {
        // A is 3×2 with columns [1, 2, 3] and [4, 5, 6]; checks the first column of A Aᵀ.
        let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; let mut c = [0.0_f64; 9]; syrk_panel(3, 2, 1.0, &a, 3, 0.0, &mut c, 3, false).unwrap();
        assert_near(c[0], 17.0, "c[0,0]");
        assert_near(c[1], 22.0, "c[1,0]");
        assert_near(c[2], 27.0, "c[2,0]");
    }
    #[test]
    fn test_symeig2x2_known() {
        // [[2, 1], [1, 3]] has eigenvalues (5 ∓ √5) / 2, returned in ascending order.
        let a = [2.0, 1.0, 1.0, 3.0];
        let mut eigvals = [0.0_f64; 2];
        let mut eigvecs = [0.0_f64; 4];
        symeig2x2(&a, &mut eigvals, &mut eigvecs).unwrap();
        let expected_0 = (5.0 - 5.0_f64.sqrt()) / 2.0;
        let expected_1 = (5.0 + 5.0_f64.sqrt()) / 2.0;
        assert_near(eigvals[0], expected_0, "lambda_0");
        assert_near(eigvals[1], expected_1, "lambda_1");
    }
    #[test]
    fn test_symeig2x2_orthogonal_eigvecs() {
        // The two eigenvector columns must be orthogonal.
        let a = [2.0, 1.0, 1.0, 3.0];
        let mut eigvals = [0.0_f64; 2];
        let mut eigvecs = [0.0_f64; 4];
        symeig2x2(&a, &mut eigvals, &mut eigvecs).unwrap();
        let dot = eigvecs[0] * eigvecs[2] + eigvecs[1] * eigvecs[3];
        assert_near(dot, 0.0, "v1.v2 orthogonality");
    }
    #[test]
    fn test_bidiag_reduction_3x3() {
        let mut a = [1.0, 4.0, 7.0, 2.0, 5.0, 8.0, 3.0, 6.0, 9.0];
        let mut d = [0.0_f64; 3]; let mut e = [0.0_f64; 2]; let mut tauq = [0.0_f64; 3];
        let mut taup = [0.0_f64; 3];
        bidiag_reduction(3, 3, &mut a, 3, &mut d, &mut e, &mut tauq, &mut taup).unwrap();
        assert!(d[0].abs() > 1e-10, "d[0] should be nonzero");
        assert!(e[0].abs() > 1e-10, "e[0] should be nonzero");
    }
    #[test]
    fn test_bidiag_reduction_lengths() {
        // A is a 4×3 matrix of ones (lda 4); buffers sized exactly min(m, n) / min(m, n) - 1.
        let mut a = [1.0; 12]; let mut d = [0.0_f64; 3];
        let mut e = [0.0_f64; 2];
        let mut tauq = [0.0_f64; 3];
        let mut taup = [0.0_f64; 3];
        bidiag_reduction(4, 3, &mut a, 4, &mut d, &mut e, &mut tauq, &mut taup).unwrap();
        assert!(tauq[0].abs() > 1e-10, "tauq[0] should be nonzero");
    }
    #[test]
    fn test_svd_block_3x2_reconstruct() {
        // Thin SVD ('S', 'S'): rebuild A as (U · diag(s)) · Vᵀ and compare with the original.
        let a_orig = [1.0, 3.0, 5.0, 2.0, 4.0, 6.0];
        let mut a = a_orig;
        let mut s = [0.0_f64; 2];
        let mut u = [0.0_f64; 6]; let mut vt = [0.0_f64; 4]; svd_block(b'S', b'S', 3, 2, &mut a, 3, &mut s, &mut u, 3, &mut vt, 2).unwrap();
        let mut us = [0.0_f64; 6];
        for j in 0..2 {
            for i in 0..3 {
                us[j * 3 + i] = u[j * 3 + i] * s[j];
            }
        }
        let mut a_hat = [0.0_f64; 6];
        blocked_gemm(
            3, 2, 2, 1.0, &us, 3, &vt, 2, 0.0, &mut a_hat, 3, false, false,
        )
        .unwrap();
        for i in 0..6 {
            assert_near(a_hat[i], a_orig[i], &format!("A_hat[{i}]"));
        }
    }
    #[test]
    fn test_svd_block_singular_values_descending() {
        let mut a = [1.0, 3.0, 5.0, 2.0, 4.0, 6.0];
        let mut s = [0.0_f64; 2];
        let mut u = [0.0_f64; 6];
        let mut vt = [0.0_f64; 4];
        svd_block(b'S', b'S', 3, 2, &mut a, 3, &mut s, &mut u, 3, &mut vt, 2).unwrap();
        assert!(s[0] >= s[1], "singular values should be descending");
        assert!(s[0] > 0.0, "first singular value should be positive");
    }
    #[test]
    fn test_pca_project_4x3_to_4x2() {
        // X is 4×3 (columns [1,2,3,4], [0,1,0,1], [0,0,1,0]); W is 3×2 with columns
        // [1,0,1] and [0,1,0], so Y = [X₀ + X₂, X₁].
        let x = [1.0, 2.0, 3.0, 4.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0];
        let w = [1.0, 0.0, 1.0, 0.0, 1.0, 0.0];
        let mut y = [0.0_f64; 8]; pca_project(4, 3, 2, 1.0, &x, 4, &w, 3, 0.0, &mut y, 4).unwrap();
        assert_near(y[0], 1.0, "Y[0,0]");
        assert_near(y[1], 2.0, "Y[1,0]");
        assert_near(y[2], 4.0, "Y[2,0]");
        assert_near(y[3], 4.0, "Y[3,0]");
        assert_near(y[4], 0.0, "Y[0,1]");
        assert_near(y[5], 1.0, "Y[1,1]");
        assert_near(y[6], 0.0, "Y[2,1]");
        assert_near(y[7], 1.0, "Y[3,1]");
    }
    #[test]
    fn test_pca_project_identity_weight() {
        // Projecting onto the identity leaves X unchanged.
        let x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]; let w = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]; let mut y = [0.0_f64; 9];
        pca_project(3, 3, 3, 1.0, &x, 3, &w, 3, 0.0, &mut y, 3).unwrap();
        for i in 0..9 {
            assert_near(y[i], x[i], &format!("Y[{i}]"));
        }
    }
    #[test]
    fn test_cachecov_syrk_3feat_4obs() {
        // X is 4 obs × 3 features (ldx 4): columns [1..4], [5..8], [9..12]; the trailing
        // zeros are padding. c is XᵀX's lower triangle packed row by row.
        let x = [
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 0.0, 0.0, 0.0, 0.0, ];
        let mut c = [0.0_f64; 6];
        cachecov_syrk(3, 4, &x, 4, &mut c).unwrap();
        assert_near(c[0], 30.0, "XtX[0,0]");
        assert_near(c[1], 70.0, "XtX[1,0]");
        assert_near(c[2], 174.0, "XtX[1,1]");
        assert_near(c[3], 110.0, "XtX[2,0]");
        assert_near(c[4], 278.0, "XtX[2,1]");
        assert_near(c[5], 446.0, "XtX[2,2]");
    }
    #[test]
    fn test_cachecov_syrk_accumulation() {
        // A second call adds XᵀX again, doubling every packed entry.
        let x = [1.0, 2.0, 3.0, 4.0]; let mut c = [0.0_f64; 3]; cachecov_syrk(2, 2, &x, 2, &mut c).unwrap();
        let first = c;
        cachecov_syrk(2, 2, &x, 2, &mut c).unwrap();
        for i in 0..3 {
            assert_near(c[i], 2.0 * first[i], &format!("c[{i}] doubled"));
        }
    }
    #[test]
    fn test_lufactor_3x3() {
        let mut a = [2.0, 4.0, 2.0, 1.0, 3.0, 1.0, 1.0, 1.0, 3.0];
        let mut piv = [0_i32; 3];
        lufactor(3, 3, &mut a, 3, &mut piv).unwrap();
        assert!(piv[0] > 0, "piv[0] should be set");
    }
    #[test]
    fn test_lufactor_consistency_with_lu_with_piv() {
        // Both LU wrappers call dgetrf and must produce identical factors and pivots.
        let a_orig = [2.0, 6.0, 1.0, 4.0]; let mut a1 = a_orig;
        let mut a2 = a_orig;
        let mut piv1 = [0_i32; 2];
        let mut piv2 = [0_i32; 2];
        lu_with_piv(2, 2, &mut a1, 2, &mut piv1).unwrap();
        lufactor(2, 2, &mut a2, 2, &mut piv2).unwrap();
        for i in 0..4 {
            assert_near(a1[i], a2[i], &format!("a[{i}]"));
        }
        assert_eq!(piv1, piv2, "pivots should match");
    }
    #[test]
    fn test_qr_panel_3x2() {
        // |R[0,0]| = ‖[1, 3, 5]‖ = √35.
        let mut a = [1.0, 3.0, 5.0, 2.0, 4.0, 6.0]; let mut taus = [0.0_f64; 2];
        qr_panel(3, 2, &mut a, 3, &mut taus).unwrap();
        assert!(taus[0].abs() > 1e-10, "tau[0] should be nonzero");
        assert_near(a[0].abs(), 35.0_f64.sqrt(), "|R[0,0]|");
    }
    #[test]
    fn test_qr_panel_consistent_with_qr_block() {
        // Both QR wrappers call dgeqrf and must produce identical outputs.
        let a_orig = [1.0, 3.0, 5.0, 2.0, 4.0, 6.0]; let mut a1 = a_orig;
        let mut a2 = a_orig;
        let mut taus1 = [0.0_f64; 2];
        let mut taus2 = [0.0_f64; 2];
        qr_block(3, 2, &mut a1, 3, &mut taus1).unwrap();
        qr_panel(3, 2, &mut a2, 3, &mut taus2).unwrap();
        for i in 0..6 {
            assert_near(a1[i], a2[i], &format!("a[{i}]"));
        }
        for i in 0..2 {
            assert_near(taus1[i], taus2[i], &format!("tau[{i}]"));
        }
    }
    #[test]
    fn test_qr_form_q_orthogonal() {
        // The thin Q (3×2) must satisfy QᵀQ = I₂.
        let mut a = [1.0, 3.0, 5.0, 2.0, 4.0, 6.0];
        let mut taus = [0.0_f64; 2];
        qr_panel(3, 2, &mut a, 3, &mut taus).unwrap();
        qr_form_q(3, 2, 2, &mut a, 3, &taus).unwrap();
        let mut qtq = [0.0_f64; 4];
        blocked_gemm(2, 2, 3, 1.0, &a, 3, &a, 3, 0.0, &mut qtq, 2, true, false).unwrap();
        assert_near(qtq[0], 1.0, "QtQ[0,0]");
        assert_near(qtq[1], 0.0, "QtQ[1,0]");
        assert_near(qtq[2], 0.0, "QtQ[0,1]");
        assert_near(qtq[3], 1.0, "QtQ[1,1]");
    }
    #[test]
    fn test_qr_form_q_full_square() {
        // Full 3×3 Q must satisfy QᵀQ = I₃.
        let mut a = [1.0, 4.0, 7.0, 2.0, 5.0, 8.0, 3.0, 6.0, 10.0]; let mut taus = [0.0_f64; 3];
        qr_panel(3, 3, &mut a, 3, &mut taus).unwrap();
        qr_form_q(3, 3, 3, &mut a, 3, &taus).unwrap();
        let mut qtq = [0.0_f64; 9];
        blocked_gemm(3, 3, 3, 1.0, &a, 3, &a, 3, 0.0, &mut qtq, 3, true, false).unwrap();
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_near(qtq[j * 3 + i], expected, &format!("QtQ[{i},{j}]"));
            }
        }
    }
    #[test]
    fn test_least_squares_qr_overdetermined() {
        // Fit y = c₀ + c₁·t to (1,1), (2,2), (3,3): the exact line y = t gives [0, 1].
        let mut a = [1.0, 1.0, 1.0, 1.0, 2.0, 3.0]; let mut b = [1.0, 2.0, 3.0]; least_squares_qr(3, 2, 1, &mut a, 3, &mut b, 3).unwrap();
        assert_near(b[0], 0.0, "x[0]");
        assert_near(b[1], 1.0, "x[1]");
    }
    #[test]
    fn test_least_squares_qr_exact() {
        // Square system [[2, 1], [1, 3]] x = [5, 7] → x = [1.6, 1.8].
        let mut a = [2.0, 1.0, 1.0, 3.0]; let mut b = [5.0, 7.0];
        least_squares_qr(2, 2, 1, &mut a, 2, &mut b, 2).unwrap();
        assert_near(b[0], 1.6, "x[0]");
        assert_near(b[1], 1.8, "x[1]");
    }
    #[test]
    fn test_symeig_full_3x3_trace() {
        // The sum of the eigenvalues equals the trace (2 + 3 + 4 = 9); dsyev sorts them ascending.
        let mut a = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 4.0];
        let mut w = [0.0_f64; 3];
        symeig_full(3, &mut a, 3, &mut w).unwrap();
        let sum = w[0] + w[1] + w[2];
        assert_near(sum, 9.0, "sum(eigenvalues) = trace");
        assert!(w[0] <= w[1], "eigenvalues ascending");
        assert!(w[1] <= w[2], "eigenvalues ascending");
    }
    #[test]
    fn test_symeig_full_orthogonal_q() {
        // The eigenvector matrix returned in `a` must be orthogonal.
        let mut a = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 4.0];
        let mut w = [0.0_f64; 3];
        symeig_full(3, &mut a, 3, &mut w).unwrap();
        let mut qtq = [0.0_f64; 9];
        blocked_gemm(3, 3, 3, 1.0, &a, 3, &a, 3, 0.0, &mut qtq, 3, true, false).unwrap();
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_near(qtq[j * 3 + i], expected, &format!("QtQ[{i},{j}]"));
            }
        }
    }
    #[test]
    fn test_syrk_fisher_info_2x3() {
        // A = [[1, 2, 3], [4, 5, 6]]; A Aᵀ = [[14, 32], [32, 77]] (lower triangle checked).
        let a = [1.0, 4.0, 2.0, 5.0, 3.0, 6.0];
        let mut c = [0.0_f64; 4]; syrk_fisher_info(2, 3, 1.0, &a, 2, 0.0, &mut c, 2).unwrap();
        assert_near(c[0], 14.0, "c[0,0]");
        assert_near(c[1], 32.0, "c[1,0]");
        assert_near(c[3], 77.0, "c[1,1]");
    }
    #[test]
    fn test_syrk_fisher_info_positive_semidefinite() {
        // NOTE(review): dsyrk writes only the lower triangle, so c[2] (upper) stays 0
        // here; the determinant check relies on that.
        let a = [1.0, 0.0, 0.0, 1.0, 0.0, 0.0]; let mut c = [0.0_f64; 4];
        syrk_fisher_info(2, 3, 1.0, &a, 2, 0.0, &mut c, 2).unwrap();
        assert!(c[0] >= 0.0, "c[0,0] >= 0");
        assert!(c[3] >= 0.0, "c[1,1] >= 0");
        let det = c[0] * c[3] - c[1] * c[2];
        assert!(det >= -TOL, "det(C) >= 0 for PSD");
    }
    #[test]
    fn test_sym_rank2k_update_symmetric() {
        // A and B are 3×2 (k × n, lda = ldb = 3) followed by three padding zeros.
        // C = AᵀB + BᵀA; lower triangle checked.
        let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.0, 0.0, 0.0]; let b = [7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 0.0, 0.0, 0.0]; let mut c = [0.0_f64; 4]; sym_rank2k_update(2, 3, 1.0, &a, 3, &b, 3, 0.0, &mut c, 2).unwrap();
        assert_near(c[0], 100.0, "C[0,0]");
        assert_near(c[1], 190.0, "C[1,0]");
        assert_near(c[3], 334.0, "C[1,1]");
    }
    #[test]
    fn test_sym_rank2k_update_a_equals_b() {
        // With B = A, AᵀA + AᵀA equals syrk with alpha = 2 and a transposed A.
        let a = [1.0, 2.0, 3.0, 4.0]; let mut c_rank2k = [0.0_f64; 4];
        sym_rank2k_update(2, 2, 1.0, &a, 2, &a, 2, 0.0, &mut c_rank2k, 2).unwrap();
        let mut c_syrk = [0.0_f64; 4];
        syrk_panel(2, 2, 2.0, &a, 2, 0.0, &mut c_syrk, 2, true).unwrap();
        assert_near(c_rank2k[0], c_syrk[0], "C[0,0]");
        assert_near(c_rank2k[1], c_syrk[1], "C[1,0]");
        assert_near(c_rank2k[3], c_syrk[3], "C[1,1]");
    }
}