use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{Array2, ArrayView2, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign, One};
use std::iter::Sum;
pub trait PolarFloat:
Float + NumAssign + Sum + ScalarOperand + Send + Sync + 'static
{
}
impl<F> PolarFloat for F where
F: Float + NumAssign + Sum + ScalarOperand + Send + Sync + 'static
{
}
fn matmul_nn<F: PolarFloat>(a: &Array2<F>, b: &Array2<F>) -> LinalgResult<Array2<F>> {
let (m, k) = (a.nrows(), a.ncols());
let (k2, n) = (b.nrows(), b.ncols());
if k != k2 {
return Err(LinalgError::ShapeError(format!(
"polar matmul: inner dims {} vs {}", k, k2
)));
}
let mut c = Array2::<F>::zeros((m, n));
for i in 0..m {
for l in 0..k {
let a_il = a[[i, l]];
if a_il == F::zero() {
continue;
}
for j in 0..n {
c[[i, j]] = c[[i, j]] + a_il * b[[l, j]];
}
}
}
Ok(c)
}
fn frobenius_norm<F: PolarFloat>(a: &Array2<F>) -> F {
let mut acc = F::zero();
for &v in a.iter() {
acc = acc + v * v;
}
acc.sqrt()
}
fn transpose<F: PolarFloat>(a: &Array2<F>) -> Array2<F> {
let (m, n) = (a.nrows(), a.ncols());
let mut t = Array2::<F>::zeros((n, m));
for i in 0..m {
for j in 0..n {
t[[j, i]] = a[[i, j]];
}
}
t
}
fn lu_factorize<F: PolarFloat>(a: &Array2<F>) -> LinalgResult<(Array2<F>, Vec<usize>)> {
let n = a.nrows();
let mut lu = a.clone();
let mut perm: Vec<usize> = (0..n).collect();
for k in 0..n {
let mut max_val = F::zero();
let mut max_row = k;
for i in k..n {
let v = lu[[i, k]].abs();
if v > max_val {
max_val = v;
max_row = i;
}
}
if max_val < F::epsilon() * F::from(100.0).unwrap_or(F::one()) {
return Err(LinalgError::SingularMatrixError(
"polar: near-singular matrix in LU".into(),
));
}
if max_row != k {
perm.swap(k, max_row);
for j in 0..n {
let tmp = lu[[k, j]];
lu[[k, j]] = lu[[max_row, j]];
lu[[max_row, j]] = tmp;
}
}
for i in (k + 1)..n {
lu[[i, k]] = lu[[i, k]] / lu[[k, k]];
for j in (k + 1)..n {
let l_ik = lu[[i, k]];
let u_kj = lu[[k, j]];
lu[[i, j]] = lu[[i, j]] - l_ik * u_kj;
}
}
}
Ok((lu, perm))
}
fn lu_solve<F: PolarFloat>(lu: &Array2<F>, perm: &[usize], b: &Array2<F>) -> Array2<F> {
let n = lu.nrows();
let nrhs = b.ncols();
let mut x = Array2::<F>::zeros((n, nrhs));
for col in 0..nrhs {
let mut y = vec![F::zero(); n];
for i in 0..n {
y[i] = b[[perm[i], col]];
}
for i in 0..n {
for j in 0..i {
y[i] = y[i] - lu[[i, j]] * y[j];
}
}
let mut z = vec![F::zero(); n];
for i in (0..n).rev() {
let mut sum = y[i];
for j in (i + 1)..n {
sum = sum - lu[[i, j]] * z[j];
}
z[i] = sum / lu[[i, i]];
}
for i in 0..n {
x[[i, col]] = z[i];
}
}
x
}
fn mat_inv<F: PolarFloat>(a: &Array2<F>) -> LinalgResult<Array2<F>> {
let n = a.nrows();
let eye = Array2::<F>::eye(n);
let (lu, perm) = lu_factorize(a)?;
Ok(lu_solve(&lu, &perm, &eye))
}
pub fn polar_via_svd<F>(a: &ArrayView2<F>) -> LinalgResult<(Array2<F>, Array2<F>)>
where
F: Float
+ NumAssign
+ Sum
+ One
+ ScalarOperand
+ Send
+ Sync
+ 'static
+ std::fmt::Display,
{
let n = a.nrows();
if a.ncols() != n {
return Err(LinalgError::ShapeError(
"polar_via_svd: matrix must be square".into(),
));
}
if n == 0 {
return Ok((Array2::<F>::zeros((0, 0)), Array2::<F>::zeros((0, 0))));
}
use crate::decomposition::svd;
let (p, sigma, qt) = svd(a, true, None)?;
let q = transpose(&qt);
let u = matmul_nn(&p, &qt)?;
let mut sigma_qt = Array2::<F>::zeros((n, n));
for i in 0..n {
for j in 0..n {
sigma_qt[[i, j]] = sigma[i] * qt[[i, j]];
}
}
let h = matmul_nn(&q, &sigma_qt)?;
Ok((u, h))
}
pub fn unitary_factor_halley<F: PolarFloat>(
a: &ArrayView2<F>,
max_iter: Option<usize>,
tol: Option<F>,
) -> LinalgResult<Array2<F>> {
let n = a.nrows();
if a.ncols() != n {
return Err(LinalgError::ShapeError(
"unitary_factor_halley: matrix must be square".into(),
));
}
if n == 0 {
return Ok(Array2::<F>::zeros((0, 0)));
}
let max_iter = max_iter.unwrap_or(20);
let tol = tol.unwrap_or_else(|| F::from(1e-12).unwrap_or(F::epsilon()));
let three = F::from(3.0).unwrap_or(F::one() + F::one() + F::one());
let norm_a = {
let n_cols = a.ncols();
let mut max_col = F::zero();
for j in 0..n_cols {
let mut col_sum = F::zero();
for i in 0..n {
col_sum = col_sum + a[[i, j]].abs();
}
if col_sum > max_col {
max_col = col_sum;
}
}
max_col
};
if norm_a < F::from(1e-15).unwrap_or(F::epsilon()) {
return Err(LinalgError::SingularMatrixError(
"unitary_factor_halley: zero or near-zero matrix".into(),
));
}
let mut x = a.to_owned();
for v in x.iter_mut() {
*v = *v / norm_a;
}
for _ in 0..max_iter {
let xt = transpose(&x);
let xtx = matmul_nn(&xt, &x)?;
let mut num_inner = xtx.clone();
for i in 0..n {
num_inner[[i, i]] = num_inner[[i, i]] + three;
}
let mut den_inner = xtx;
for i in 0..n {
den_inner[[i, i]] = den_inner[[i, i]] + F::one();
}
for v in den_inner.iter_mut() {
*v = *v * three;
}
let xt2 = transpose(&x);
let xtx2 = matmul_nn(&xt2, &x)?;
let mut denom = Array2::<F>::zeros((n, n));
for i in 0..n {
for j in 0..n {
let diag_one = if i == j { F::one() } else { F::zero() };
denom[[i, j]] = diag_one + three * xtx2[[i, j]];
}
}
let denom_inv = mat_inv(&denom)?;
let x_num = matmul_nn(&x, &num_inner)?;
let x_new = matmul_nn(&x_num, &denom_inv)?;
let mut diff = Array2::<F>::zeros((n, n));
for i in 0..n {
for j in 0..n {
diff[[i, j]] = x_new[[i, j]] - x[[i, j]];
}
}
let rel = frobenius_norm(&diff)
/ (frobenius_norm(&x_new) + F::from(1e-30).unwrap_or(F::epsilon()));
x = x_new;
if rel < tol {
return Ok(x);
}
}
Ok(x)
}
pub fn unitary_factor_newton<F: PolarFloat>(
a: &ArrayView2<F>,
max_iter: Option<usize>,
tol: Option<F>,
) -> LinalgResult<Array2<F>> {
let n = a.nrows();
if a.ncols() != n {
return Err(LinalgError::ShapeError(
"unitary_factor_newton: matrix must be square".into(),
));
}
if n == 0 {
return Ok(Array2::<F>::zeros((0, 0)));
}
let max_iter = max_iter.unwrap_or(50);
let tol = tol.unwrap_or_else(|| F::from(1e-12).unwrap_or(F::epsilon()));
let two = F::from(2.0).unwrap_or(F::one() + F::one());
let mut x = a.to_owned();
for _ in 0..max_iter {
let xt = transpose(&x);
let xt_inv = mat_inv(&xt)?;
let mut x_new = Array2::<F>::zeros((n, n));
for i in 0..n {
for j in 0..n {
x_new[[i, j]] = (x[[i, j]] + xt_inv[[i, j]]) / two;
}
}
let mut diff = Array2::<F>::zeros((n, n));
for i in 0..n {
for j in 0..n {
diff[[i, j]] = x_new[[i, j]] - x[[i, j]];
}
}
let rel = frobenius_norm(&diff)
/ (frobenius_norm(&x_new) + F::from(1e-30).unwrap_or(F::epsilon()));
x = x_new;
if rel < tol {
return Ok(x);
}
}
Ok(x)
}
pub fn polar_decomposition<F: PolarFloat>(
a: &ArrayView2<F>,
max_iter: Option<usize>,
tol: Option<F>,
) -> LinalgResult<(Array2<F>, Array2<F>)> {
let n = a.nrows();
if a.ncols() != n {
return Err(LinalgError::ShapeError(
"polar_decomposition: matrix must be square".into(),
));
}
if n == 0 {
return Ok((Array2::<F>::zeros((0, 0)), Array2::<F>::zeros((0, 0))));
}
let u = unitary_factor_halley(a, max_iter, tol)?;
let ut = transpose(&u);
let h = matmul_nn(&ut, &a.to_owned())?;
let ht = transpose(&h);
let two = F::from(2.0).unwrap_or(F::one() + F::one());
let mut h_sym = Array2::<F>::zeros((n, n));
for i in 0..n {
for j in 0..n {
h_sym[[i, j]] = (h[[i, j]] + ht[[i, j]]) / two;
}
}
Ok((u, h_sym))
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
use scirs2_core::ndarray::array;
#[test]
fn test_polar_identity() {
let eye = Array2::<f64>::eye(3);
let (u, h) = polar_decomposition(&eye.view(), None, None).expect("polar identity");
for i in 0..3 {
for j in 0..3 {
let expected = if i == j { 1.0 } else { 0.0 };
assert_abs_diff_eq!(u[[i, j]], expected, epsilon = 1e-8);
assert_abs_diff_eq!(h[[i, j]], expected, epsilon = 1e-8);
}
}
}
#[test]
fn test_polar_u_orthogonal() {
let a = array![[3.0_f64, 1.0], [0.0, 2.0]];
let (u, _h) = polar_decomposition(&a.view(), None, None).expect("polar u orthogonal");
let ut = transpose(&u);
let utu = matmul_nn(&ut, &u).expect("u^T u");
assert_abs_diff_eq!(utu[[0, 0]], 1.0, epsilon = 1e-7);
assert_abs_diff_eq!(utu[[1, 1]], 1.0, epsilon = 1e-7);
assert_abs_diff_eq!(utu[[0, 1]], 0.0, epsilon = 1e-7);
assert_abs_diff_eq!(utu[[1, 0]], 0.0, epsilon = 1e-7);
}
#[test]
fn test_polar_reconstruct() {
let a = array![[2.0_f64, 1.0], [1.0, 3.0]];
let (u, h) = polar_decomposition(&a.view(), None, None).expect("polar reconstruct");
let a_rec = matmul_nn(&u, &h).expect("a_rec matmul");
for i in 0..2 {
for j in 0..2 {
assert_abs_diff_eq!(a_rec[[i, j]], a[[i, j]], epsilon = 1e-7);
}
}
}
#[test]
fn test_polar_h_symmetric() {
let a = array![[3.0_f64, 2.0], [1.0, 4.0]];
let (_u, h) = polar_decomposition(&a.view(), None, None).expect("polar h symmetric");
assert_abs_diff_eq!(h[[0, 1]], h[[1, 0]], epsilon = 1e-8);
}
#[test]
fn test_unitary_factor_newton_identity() {
let eye = Array2::<f64>::eye(2);
let u = unitary_factor_newton(&eye.view(), None, None).expect("newton identity");
for i in 0..2 {
for j in 0..2 {
let expected = if i == j { 1.0 } else { 0.0 };
assert_abs_diff_eq!(u[[i, j]], expected, epsilon = 1e-8);
}
}
}
#[test]
fn test_polar_via_svd_reconstruct() {
let a = array![[2.0_f64, 0.5], [0.1, 3.0]];
let (u, h) = polar_via_svd(&a.view()).expect("polar_via_svd reconstruct");
let a_rec = matmul_nn(&u, &h).expect("a_rec svd");
for i in 0..2 {
for j in 0..2 {
assert_abs_diff_eq!(a_rec[[i, j]], a[[i, j]], epsilon = 1e-8);
}
}
}
#[test]
fn test_polar_via_svd_u_orthogonal() {
let a = array![[4.0_f64, 2.0], [1.0, 3.0]];
let (u, _h) = polar_via_svd(&a.view()).expect("polar_via_svd u orthogonal");
let ut = transpose(&u);
let utu = matmul_nn(&ut, &u).expect("utu svd");
for i in 0..2 {
for j in 0..2 {
let expected = if i == j { 1.0 } else { 0.0 };
assert_abs_diff_eq!(utu[[i, j]], expected, epsilon = 1e-8);
}
}
}
}