use nalgebra as na;
use nalgebra::linalg::SymmetricEigen;
use num_traits::float::Float;
/// Statically-sized column vector of three `f64` components.
pub type SVector3 = na::SVector<f64, 3>;
/// Statically-sized column vector of six `f64` components.
pub type SVector6 = na::SVector<f64, 6>;
/// Statically-sized 3×3 matrix of `f64`.
pub type SMatrix3 = na::SMatrix<f64, 3, 3>;
/// Statically-sized 6×6 matrix of `f64`.
pub type SMatrix6 = na::SMatrix<f64, 6, 6>;
/// Splits a floating-point value into its truncated and fractional parts.
///
/// Returns `(num.trunc(), num.fract())`, i.e. the value rounded toward zero
/// followed by what is left over. For example `-1.5` yields `(-1.0, -0.5)`.
pub fn split_float<T: Float>(num: T) -> (T, T) {
    let whole = num.trunc();
    let fraction = num.fract();
    (whole, fraction)
}
/// Builds a 3-component `f64` vector from a plain array.
///
/// The components are taken in array order: `vec[0]`, `vec[1]`, `vec[2]`.
pub fn vector3_from_array(vec: [f64; 3]) -> na::Vector3<f64> {
    let [x, y, z] = vec;
    na::Vector3::new(x, y, z)
}
/// Builds a 6-component `f64` vector from a plain array.
///
/// The components are taken in array order, `vec[0]` through `vec[5]`.
pub fn vector6_from_array(vec: [f64; 6]) -> na::SVector<f64, 6> {
    let [c0, c1, c2, c3, c4, c5] = vec;
    na::SVector::<f64, 6>::new(c0, c1, c2, c3, c4, c5)
}
/// Builds a 3×3 `f64` matrix from a nested array of rows.
///
/// `mat[r][c]` becomes the entry at row `r`, column `c` of the result.
pub fn matrix3_from_array(mat: &[[f64; 3]; 3]) -> na::SMatrix<f64, 3, 3> {
    // Destructure row by row; the bindings are named `m<row><col>`.
    let [[m00, m01, m02], [m10, m11, m12], [m20, m21, m22]] = *mat;
    na::SMatrix::<f64, 3, 3>::new(m00, m01, m02, m10, m11, m12, m20, m21, m22)
}
/// Kronecker delta: `1` when the two indices are equal, `0` otherwise.
pub fn kronecker_delta(i: usize, j: usize) -> u8 {
    // `u8::from(bool)` maps `true` to 1 and `false` to 0.
    u8::from(i == j)
}
/// Computes the square root of a symmetric matrix with non-negative
/// eigenvalues via its eigendecomposition.
///
/// The input is decomposed with [`SymmetricEigen`] into eigenvectors `V` and
/// eigenvalues `λ`; the result is assembled as `V · diag(√λ) · Vᵀ`.
///
/// No symmetry check is performed here: the matrix is handed to the symmetric
/// eigensolver as-is, so the caller is responsible for passing a symmetric
/// matrix. Eigenvalues equal to zero are accepted.
///
/// # Errors
///
/// Returns `Err` with a descriptive message when
/// - any entry of `matrix` is NaN or infinite, or
/// - the decomposition yields a negative eigenvalue.
pub fn spd_sqrtm<const N: usize>(
    matrix: na::SMatrix<f64, N, N>,
) -> Result<na::SMatrix<f64, N, N>, String>
where
    na::Const<N>: na::DimName,
{
    // NaN compares false against 0.0, so a NaN eigenvalue would pass the
    // negativity test below and poison the result. Reject non-finite input
    // before decomposing instead.
    if let Some(bad) = matrix.iter().copied().find(|entry| !entry.is_finite()) {
        return Err(format!(
            "Matrix contains a non-finite entry ({}); cannot compute matrix square root",
            bad
        ));
    }

    // The decomposition is performed on a dynamically-sized copy of the input.
    let dmatrix = na::DMatrix::from_iterator(N, N, matrix.iter().copied());
    let eigen = SymmetricEigen::new(dmatrix);

    // A real square root of this form only exists for non-negative eigenvalues.
    if let Some(negative) = eigen.eigenvalues.iter().copied().find(|&ev| ev < 0.0) {
        return Err(format!(
            "Matrix is not positive-definite: found negative eigenvalue {}",
            negative
        ));
    }

    let sqrt_eigenvalues = eigen.eigenvalues.map(|x: f64| x.sqrt());
    let v = &eigen.eigenvectors;
    let sqrt_d = na::DMatrix::<f64>::from_diagonal(&sqrt_eigenvalues);
    let result_dmatrix = v * sqrt_d * v.transpose();

    // Copy the dynamically-sized product back into the statically-sized return type.
    Ok(na::SMatrix::<f64, N, N>::from_fn(|i, j| {
        result_dmatrix[(i, j)]
    }))
}
/// Computes a square root of a general square matrix with the
/// Denman–Beavers iteration.
///
/// Starting from `Y = A` and `Z = I`, every step replaces
/// `Y` by `(Y + Z⁻¹) / 2` and `Z` by `(Z + Y⁻¹) / 2` (both computed from the
/// previous `Y` and `Z`). Iteration stops once the norm of the change in `Y`
/// falls below `1e-10`, or after 50 steps. The final `Y` is then verified by
/// checking that the norm of `Y·Y − A` does not exceed `1e-8`.
///
/// # Errors
///
/// Returns `Err` with a descriptive message when
/// - `Y` or `Z` cannot be inverted during the iteration, or
/// - the verification residual exceeds `1e-8` or is NaN.
pub fn sqrtm<const N: usize>(
    matrix: na::SMatrix<f64, N, N>,
) -> Result<na::SMatrix<f64, N, N>, String>
where
    na::Const<N>: na::DimName,
{
    const MAX_ITERATIONS: usize = 50;
    const TOLERANCE: f64 = 1e-10;
    // Largest accepted norm of `Y·Y − A` for the returned root.
    const MAX_RESIDUAL: f64 = 1e-8;

    // The iteration runs on dynamically-sized copies of the input.
    let a = na::DMatrix::from_iterator(N, N, matrix.iter().copied());
    let mut y = a.clone();
    let mut z = na::DMatrix::<f64>::identity(N, N);

    for _ in 0..MAX_ITERATIONS {
        // `try_inverse` consumes its receiver, hence the clones.
        let y_inv = y.clone().try_inverse().ok_or_else(|| {
            "Matrix became singular during iteration; cannot compute matrix square root".to_string()
        })?;
        let z_inv = z.clone().try_inverse().ok_or_else(|| {
            "Iteration matrix became singular; cannot compute matrix square root".to_string()
        })?;

        let y_new = (&y + &z_inv) * 0.5;
        let z_new = (&z + &y_inv) * 0.5;
        let diff = (&y_new - &y).norm();

        y = y_new;
        if diff < TOLERANCE {
            break;
        }
        z = z_new;
    }

    // Verify the candidate root directly against the input.
    let check = &y * &y;
    let error = (&check - &a).norm();
    // A NaN residual compares false with `>`, so it must be tested explicitly;
    // otherwise a NaN-contaminated iterate would be returned as `Ok`.
    if error.is_nan() || error > MAX_RESIDUAL {
        return Err(format!(
            "Matrix square root did not converge to sufficient accuracy (error: {})",
            error
        ));
    }

    // Copy the dynamically-sized iterate back into the statically-sized return type.
    Ok(na::SMatrix::<f64, N, N>::from_fn(|i, j| y[(i, j)]))
}
/// Computes a square root of a dynamically-sized square matrix with the
/// Denman–Beavers iteration.
///
/// Starting from `Y = A` and `Z = I`, every step replaces
/// `Y` by `(Y + Z⁻¹) / 2` and `Z` by `(Z + Y⁻¹) / 2` (both computed from the
/// previous `Y` and `Z`). Iteration stops once the norm of the change in `Y`
/// falls below `1e-10`, or after 50 steps. The final `Y` is then verified by
/// checking that the norm of `Y·Y − A` does not exceed `1e-8`.
///
/// # Errors
///
/// Returns `Err` with a descriptive message when
/// - `matrix` is not square,
/// - `Y` or `Z` cannot be inverted during the iteration, or
/// - the verification residual exceeds `1e-8` or is NaN.
pub fn sqrtm_dmatrix(matrix: &na::DMatrix<f64>) -> Result<na::DMatrix<f64>, String> {
    const MAX_ITERATIONS: usize = 50;
    const TOLERANCE: f64 = 1e-10;
    // Largest accepted norm of `Y·Y − A` for the returned root.
    const MAX_RESIDUAL: f64 = 1e-8;

    let (rows, cols) = (matrix.nrows(), matrix.ncols());
    if rows != cols {
        return Err(format!("Matrix must be square, got {}x{}", rows, cols));
    }

    let n = rows;
    let mut y = matrix.clone();
    let mut z = na::DMatrix::<f64>::identity(n, n);

    for _ in 0..MAX_ITERATIONS {
        // `try_inverse` consumes its receiver, hence the clones.
        let y_inv = y.clone().try_inverse().ok_or_else(|| {
            "Matrix became singular during iteration; cannot compute matrix square root".to_string()
        })?;
        let z_inv = z.clone().try_inverse().ok_or_else(|| {
            "Iteration matrix became singular; cannot compute matrix square root".to_string()
        })?;

        let y_new = (&y + &z_inv) * 0.5;
        let z_new = (&z + &y_inv) * 0.5;
        let diff = (&y_new - &y).norm();

        y = y_new;
        if diff < TOLERANCE {
            break;
        }
        z = z_new;
    }

    // Verify the candidate root directly against the input.
    let check = &y * &y;
    let error = (&check - matrix).norm();
    // A NaN residual compares false with `>`, so it must be tested explicitly;
    // otherwise a NaN-contaminated iterate would be returned as `Ok`.
    if error.is_nan() || error > MAX_RESIDUAL {
        return Err(format!(
            "Matrix square root did not converge to sufficient accuracy (error: {})",
            error
        ));
    }

    Ok(y)
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;

    /// `split_float` on `f32`: positive, negative, zero and whole-number inputs.
    #[test]
    fn test_split_float_f32() {
        let cases: [(f32, (f32, f32)); 5] = [
            (1.5, (1.0, 0.5)),
            (-1.5, (-1.0, -0.5)),
            (0.0, (0.0, 0.0)),
            (1.0, (1.0, 0.0)),
            (-1.0, (-1.0, 0.0)),
        ];
        for (input, expected) in cases.iter().copied() {
            assert_eq!(split_float(input), expected);
        }
    }

    /// `split_float` on `f64`: positive, negative, zero and whole-number inputs.
    #[test]
    fn test_split_float_f64() {
        let cases: [(f64, (f64, f64)); 5] = [
            (1.5, (1.0, 0.5)),
            (-1.5, (-1.0, -0.5)),
            (0.0, (0.0, 0.0)),
            (1.0, (1.0, 0.0)),
            (-1.0, (-1.0, 0.0)),
        ];
        for (input, expected) in cases.iter().copied() {
            assert_eq!(split_float(input), expected);
        }
    }

    /// A 3-element array maps onto the vector components in order.
    #[test]
    fn test_vector3_from_array() {
        let built = vector3_from_array([1.0, 2.0, 3.0]);
        assert_eq!(built, na::Vector3::new(1.0, 2.0, 3.0));
    }

    /// A 6-element array maps onto the vector components in order.
    #[test]
    fn test_vector6_from_array() {
        let built = vector6_from_array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert_eq!(
            built,
            na::SVector::<f64, 6>::new(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)
        );
    }

    /// The nested array is interpreted as rows: `rows[r][c]` lands at `(r, c)`.
    #[test]
    fn test_matrix3_from_array() {
        let rows = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let built = matrix3_from_array(&rows);
        assert_eq!(
            built,
            na::SMatrix::<f64, 3, 3>::new(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0)
        );
        // Entry-by-entry check of all nine positions.
        for (r, row) in rows.iter().enumerate() {
            for (c, &value) in row.iter().enumerate() {
                assert_eq!(built[(r, c)], value);
            }
        }
    }

    /// Equal indices give 1, different indices give 0.
    #[test]
    fn test_kronecker_delta() {
        let cases: [((usize, usize), u8); 4] = [((0, 0), 1), ((0, 1), 0), ((1, 0), 0), ((1, 1), 1)];
        for ((i, j), expected) in cases.iter().copied() {
            assert_eq!(kronecker_delta(i, j), expected);
        }
    }

    /// The square root of the identity is the identity.
    #[test]
    fn test_spd_sqrtm_identity() {
        let eye = na::SMatrix::<f64, 3, 3>::identity();
        let root = spd_sqrtm(eye).unwrap();
        let deviation = (root - eye).norm();
        assert!(deviation < 1e-10);
    }

    /// A diagonal matrix has the element-wise square roots on its diagonal.
    #[test]
    fn test_spd_sqrtm_diagonal() {
        let input = na::SMatrix::<f64, 3, 3>::new(4.0, 0.0, 0.0, 0.0, 9.0, 0.0, 0.0, 0.0, 16.0);
        let root = spd_sqrtm(input).unwrap();

        let expected_root =
            na::SMatrix::<f64, 3, 3>::new(2.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 4.0);
        assert!((root - expected_root).norm() < 1e-10);

        // Squaring the root must give back the input.
        let squared = root * root;
        assert!((squared - input).norm() < 1e-10);
    }

    /// A 6×6 covariance-like matrix: the root squares back and is symmetric.
    #[test]
    fn test_spd_sqrtm_covariance() {
        let mut covariance = na::SMatrix::<f64, 6, 6>::identity() * 100.0;
        covariance[(0, 1)] = 10.0;
        covariance[(1, 0)] = 10.0;
        covariance[(2, 3)] = 5.0;
        covariance[(3, 2)] = 5.0;

        let root = spd_sqrtm(covariance).unwrap();

        let squared = root * root;
        assert!((squared - covariance).norm() < 1e-8);

        let root_transposed = root.transpose();
        assert!((root - root_transposed).norm() < 1e-10);
    }

    /// A matrix with a negative eigenvalue is rejected with a matching message.
    #[test]
    fn test_spd_sqrtm_error_negative_eigenvalue() {
        let indefinite = na::SMatrix::<f64, 2, 2>::new(1.0, 0.0, 0.0, -1.0);
        let outcome = spd_sqrtm(indefinite);
        assert!(outcome.is_err());
        let message = outcome.unwrap_err();
        assert!(message.contains("negative eigenvalue"));
    }

    /// 2×2 example whose square root is [[5, 2], [4, 7]].
    #[test]
    fn test_sqrtm_wiki_test_case() {
        let input = na::SMatrix::<f64, 2, 2>::new(33.0, 24.0, 48.0, 57.0);
        let root = sqrtm(input).unwrap();

        let expected_root = na::SMatrix::<f64, 2, 2>::new(5.0, 2.0, 4.0, 7.0);
        assert!((root - expected_root).norm() < 1e-10);

        let squared = root * root;
        assert!((squared - input).norm() < 1e-10);
    }

    /// A non-symmetric upper-triangular 3×3 matrix: the root squares back.
    #[test]
    fn test_sqrtm_3x3_general() {
        let input = na::SMatrix::<f64, 3, 3>::new(5.0, 2.0, 1.0, 0.0, 3.0, 1.0, 0.0, 0.0, 2.0);
        let root = sqrtm(input).unwrap();
        let squared = root * root;
        assert!((squared - input).norm() < 1e-10);
    }

    /// On a symmetric input the general and the eigendecomposition routes agree.
    #[test]
    fn test_sqrtm_symmetric_matches_spd() {
        let input = na::SMatrix::<f64, 3, 3>::new(4.0, 2.0, 0.0, 2.0, 3.0, 0.0, 0.0, 0.0, 5.0);
        let via_eigen = spd_sqrtm(input).unwrap();
        let via_iteration = sqrtm(input).unwrap();
        assert!((via_eigen - via_iteration).norm() < 1e-8);
    }

    /// A matrix with a negative eigenvalue makes the iteration fail with one
    /// of its error messages.
    #[test]
    fn test_sqrtm_error_negative_eigenvalue() {
        let indefinite = na::SMatrix::<f64, 2, 2>::new(-1.0, 0.0, 0.0, 4.0);
        let outcome = sqrtm(indefinite);
        assert!(outcome.is_err());
        let message = outcome.unwrap_err();
        let accepted_fragments = ["singular", "converge", "accuracy"];
        assert!(accepted_fragments
            .iter()
            .any(|&fragment| message.contains(fragment)));
    }

    /// Same 2×2 example as above, through the dynamically-sized entry point.
    #[test]
    fn test_sqrtm_dmatrix_wiki_test_case() {
        let input = na::DMatrix::from_row_slice(2, 2, &[33.0, 24.0, 48.0, 57.0]);
        let root = sqrtm_dmatrix(&input).unwrap();

        let expected_root = na::DMatrix::from_row_slice(2, 2, &[5.0, 2.0, 4.0, 7.0]);
        assert!((&root - &expected_root).norm() < 1e-10);

        let squared = &root * &root;
        assert!((squared - &input).norm() < 1e-10);
    }

    /// The square root of a 4×4 identity is the identity.
    #[test]
    fn test_sqrtm_dmatrix_identity() {
        let eye = na::DMatrix::<f64>::identity(4, 4);
        let root = sqrtm_dmatrix(&eye).unwrap();
        assert!((root - &eye).norm() < 1e-10);
    }

    /// A diagonal matrix has the element-wise square roots on its diagonal.
    #[test]
    fn test_sqrtm_dmatrix_diagonal() {
        let input = na::DMatrix::from_row_slice(2, 2, &[4.0, 0.0, 0.0, 9.0]);
        let root = sqrtm_dmatrix(&input).unwrap();
        let expected_root = na::DMatrix::from_row_slice(2, 2, &[2.0, 0.0, 0.0, 3.0]);
        assert!((&root - &expected_root).norm() < 1e-10);
    }

    /// A 2×3 input is rejected with a message mentioning "square".
    #[test]
    fn test_sqrtm_dmatrix_non_square_error() {
        let rectangular = na::DMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let outcome = sqrtm_dmatrix(&rectangular);
        assert!(outcome.is_err());
        let message = outcome.unwrap_err();
        assert!(message.contains("square"));
    }
}