use faer::linalg::matmul::matmul;
use faer::{Accum, Mat, MatRef, Par};
use crate::ProcrustesError;
/// Solves the orthogonal Procrustes problem for `a` against `reference`.
///
/// Finds the orthogonal `k × k` matrix `R` (with `k = a.ncols()`) that minimises the
/// Frobenius norm `‖a·R − reference‖`. The method:
///
/// 1. Forms the cross-covariance `M = aᵀ·reference`.
/// 2. Takes its SVD `M = U·Σ·Vᵀ`.
/// 3. Returns `R = U·Vᵀ`.
///
/// The accompanying `scale` is `Σᵢⱼ Mᵢⱼ·Rᵢⱼ`, i.e. `tr(Mᵀ·R)`. In exact arithmetic this
/// equals the sum of the singular values of `M`.
///
/// # Arguments
/// * `a` — the matrix to be rotated.
/// * `reference` — the target matrix; must have the same shape as `a`.
/// * `check_finite` — when `true`, both inputs are scanned and any NaN/±∞ entry is rejected.
///
/// # Errors
/// Validation happens in this order:
/// * [`ProcrustesError::DimensionMismatch`] if `a` and `reference` differ in shape.
/// * [`ProcrustesError::EmptyInput`] if the inputs have zero rows or zero columns.
/// * [`ProcrustesError::NonFinite`] if `check_finite` is `true` and a non-finite entry exists.
///
/// If faer's SVD itself reports failure, this function does not return an error. It returns
/// `Ok` with every entry of `rotation`, and `scale`, set to NaN. This can presumably happen
/// for non-finite input when `check_finite` is `false`.
pub fn orthogonal(
    a: MatRef<'_, f64>,
    reference: MatRef<'_, f64>,
    check_finite: bool,
) -> Result<OrthogonalAlignment, ProcrustesError> {
    let a_rows = a.nrows();
    let a_cols = a.ncols();
    let ref_rows = reference.nrows();
    let ref_cols = reference.ncols();

    let same_shape = a_rows == ref_rows && a_cols == ref_cols;
    if !same_shape {
        return Err(ProcrustesError::DimensionMismatch {
            a_rows,
            a_cols,
            ref_rows,
            ref_cols,
        });
    }
    if a_rows == 0 || a_cols == 0 {
        return Err(ProcrustesError::EmptyInput);
    }
    if check_finite && !(is_all_finite(a) && is_all_finite(reference)) {
        return Err(ProcrustesError::NonFinite);
    }

    let dim = a_cols;

    // Cross-covariance aᵀ·reference, a dim × dim matrix.
    let mut cross = Mat::<f64>::zeros(dim, dim);
    matmul(
        cross.as_mut(),
        Accum::Replace,
        a.transpose(),
        reference,
        1.0,
        Par::Seq,
    );

    let decomposition = match cross.as_ref().svd() {
        Ok(found) => found,
        Err(_) => {
            // SVD failure is reported through NaN values rather than an error variant.
            return Ok(OrthogonalAlignment {
                rotation: Mat::<f64>::from_fn(dim, dim, |_, _| f64::NAN),
                scale: f64::NAN,
            });
        }
    };

    // Optimal rotation R = U·Vᵀ.
    let mut rotation = Mat::<f64>::zeros(dim, dim);
    matmul(
        rotation.as_mut(),
        Accum::Replace,
        decomposition.U(),
        decomposition.V().transpose(),
        1.0,
        Par::Seq,
    );

    // tr(Mᵀ·R) as an element-wise dot product, accumulated row by row.
    let scale = (0..dim)
        .flat_map(|row| (0..dim).map(move |col| (row, col)))
        .fold(0.0_f64, |acc, idx| acc + cross[idx] * rotation[idx]);

    Ok(OrthogonalAlignment { rotation, scale })
}
/// Reports whether every entry of `x` is finite (neither NaN nor ±∞).
///
/// Entries are visited column by column, and the scan stops at the first non-finite value.
fn is_all_finite(x: MatRef<'_, f64>) -> bool {
    (0..x.ncols()).all(|col| (0..x.nrows()).all(|row| x[(row, col)].is_finite()))
}
/// Result of the orthogonal Procrustes solver [`orthogonal`].
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct OrthogonalAlignment {
    /// The `k × k` matrix `R = U·Vᵀ` such that `a·R` best matches `reference`.
    ///
    /// Every entry is NaN if the SVD failed.
    pub rotation: Mat<f64>,
    /// `tr(Mᵀ·R)` with `M = aᵀ·reference`. In exact arithmetic this is the sum of the
    /// singular values of `M`. It is NaN if the SVD failed.
    pub scale: f64,
}

impl OrthogonalAlignment {
    /// Returns the Frobenius norm of the alignment residual, `‖a·R − reference‖_F`.
    ///
    /// `a·R` is never formed. Instead the method uses the identity
    /// `‖a·R − B‖² = ‖a‖² + ‖B‖² − 2·tr(Rᵀ·aᵀ·B)`, where the trace term is `self.scale`.
    ///
    /// `a` and `reference` should be the matrices this alignment was computed from. Shapes
    /// are not checked, and other inputs give a meaningless value.
    ///
    /// # Accuracy
    /// The subtraction can cancel catastrophically. Residuals much smaller than roughly
    /// `sqrt(ε·(‖a‖² + ‖B‖²))`, such as a near-perfect fit, are not resolved accurately.
    /// They may come out as a small positive number or as zero.
    ///
    /// # NaN handling
    /// Returns NaN if any term of the identity is NaN, for example when `scale` is NaN
    /// because the SVD failed, or when the inputs contain NaN. A failed alignment is
    /// therefore never reported as a perfect fit.
    #[must_use]
    pub fn residual_frobenius(&self, a: MatRef<'_, f64>, reference: MatRef<'_, f64>) -> f64 {
        let a_sq = frobenius_sq(a);
        let r_sq = frobenius_sq(reference);
        let residual_sq = a_sq + r_sq - 2.0 * self.scale;
        if residual_sq.is_nan() {
            // `f64::max` returns its non-NaN operand, so `NaN.max(0.0)` would yield 0.0
            // and hide the failure. Propagate NaN explicitly instead.
            return f64::NAN;
        }
        // Rounding can push a mathematically non-negative value slightly below zero.
        residual_sq.max(0.0).sqrt()
    }
}
/// Sum of the squared entries of `x`, i.e. the squared Frobenius norm.
///
/// Entries are accumulated column by column, starting from `0.0`, so an empty matrix
/// yields `0.0`.
fn frobenius_sq(x: MatRef<'_, f64>) -> f64 {
    (0..x.ncols())
        .flat_map(move |col| (0..x.nrows()).map(move |row| x[(row, col)]))
        .fold(0.0_f64, |total, entry| total + entry * entry)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ProcrustesError;
    use faer::linalg::matmul::matmul;
    use faer::{Accum, Mat, MatRef, Par};
    use rand::{Rng, SeedableRng};
    use rand_chacha::ChaCha8Rng;

    /// Draws a `rows × cols` matrix with entries uniform in `[-1, 1)`.
    ///
    /// `rng` is consumed in `Mat::from_fn` fill order.
    fn uniform_mat(rng: &mut ChaCha8Rng, rows: usize, cols: usize) -> Mat<f64> {
        Mat::<f64>::from_fn(rows, cols, |_, _| rng.gen_range(-1.0..1.0))
    }

    /// Gram matrix `rᵀ·r`, computed with a sequential matmul.
    fn gram(r: MatRef<'_, f64>) -> Mat<f64> {
        let n = r.ncols();
        let mut product = Mat::<f64>::zeros(n, n);
        matmul(
            product.as_mut(),
            Accum::Replace,
            r.transpose(),
            r,
            1.0,
            Par::Seq,
        );
        product
    }

    /// Asserts that every entry of `m` is within `tol` of the identity matrix.
    fn assert_near_identity(m: MatRef<'_, f64>, tol: f64) {
        for i in 0..m.nrows() {
            for j in 0..m.ncols() {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (m[(i, j)] - expected).abs() < tol,
                    "RᵀR[{i},{j}] = {}",
                    m[(i, j)]
                );
            }
        }
    }

    #[test]
    fn procrustes_recovers_known_rotation() {
        let reference = Mat::<f64>::from_fn(4, 2, |i, j| match (i, j) {
            (0, 0) | (1, 1) => 1.0,
            (2, 0) | (3, 1) => 0.5,
            _ => 0.0,
        });
        // 30° rotation applied to the reference; the solver must undo it.
        let (sin, cos) = (std::f64::consts::PI / 6.0).sin_cos();
        let r0 = Mat::<f64>::from_fn(2, 2, |i, j| match (i, j) {
            (0, 0) | (1, 1) => cos,
            (0, 1) => -sin,
            (1, 0) => sin,
            _ => unreachable!(),
        });
        let a: Mat<f64> = &reference * &r0;
        let aln = orthogonal(a.as_ref(), reference.as_ref(), false).unwrap();
        let recovered: Mat<f64> = &a * &aln.rotation;
        for (i, j) in (0..4).flat_map(|i| (0..2).map(move |j| (i, j))) {
            let got = recovered[(i, j)];
            let want = reference[(i, j)];
            assert!(
                (got - want).abs() < 1e-10,
                "i={i} j={j} got {} want {}",
                got,
                want
            );
        }
    }

    #[test]
    #[allow(non_snake_case)]
    fn orthogonality_of_R() {
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let a = uniform_mat(&mut rng, 8, 4);
        let reference = uniform_mat(&mut rng, 8, 4);
        let aln = orthogonal(a.as_ref(), reference.as_ref(), false).unwrap();
        assert_near_identity(gram(aln.rotation.as_ref()).as_ref(), 1e-12);
    }

    #[test]
    fn procrustes_zero_input_returns_orthogonal() {
        // A zero cross-covariance still has to yield an orthogonal rotation.
        let w = Mat::<f64>::zeros(5, 3);
        let aln = orthogonal(w.as_ref(), w.as_ref(), false).unwrap();
        assert_near_identity(gram(aln.rotation.as_ref()).as_ref(), 1e-10);
    }

    #[test]
    fn scale_matches_nuclear_norm() {
        // aᵀa is the 3×3 identity, whose singular values sum to 3.
        let a = Mat::<f64>::from_fn(6, 3, |i, j| if i == j { 1.0 } else { 0.0 });
        let scale = orthogonal(a.as_ref(), a.as_ref(), false).unwrap().scale;
        assert!((scale - 3.0).abs() < 1e-12, "scale = {} want 3", scale);
    }

    #[test]
    fn residual_frobenius_method_matches_direct_computation() {
        let mut rng = ChaCha8Rng::seed_from_u64(7);
        let a = uniform_mat(&mut rng, 10, 3);
        let reference = uniform_mat(&mut rng, 10, 3);
        let aln = orthogonal(a.as_ref(), reference.as_ref(), false).unwrap();

        let mut ar = Mat::<f64>::zeros(10, 3);
        matmul(
            ar.as_mut(),
            Accum::Replace,
            a.as_ref(),
            aln.rotation.as_ref(),
            1.0,
            Par::Seq,
        );
        let direct = (0..10)
            .flat_map(|i| (0..3).map(move |j| (i, j)))
            .map(|idx| ar[idx] - reference[idx])
            .fold(0.0_f64, |acc, d| acc + d * d)
            .sqrt();

        let via_method = aln.residual_frobenius(a.as_ref(), reference.as_ref());
        assert!(
            (via_method - direct).abs() < 1e-12,
            "method {via_method} direct {direct}"
        );
    }

    #[test]
    #[allow(clippy::cast_precision_loss)]
    fn k_eq_1() {
        // A single column mapped onto its negation: the optimal 1×1 rotation is -1.
        let centred = |i: usize| (i as f64) - 2.0;
        let a = Mat::<f64>::from_fn(5, 1, |i, _| centred(i));
        let reference = Mat::<f64>::from_fn(5, 1, |i, _| -centred(i));
        let aln = orthogonal(a.as_ref(), reference.as_ref(), false).unwrap();
        assert!((aln.rotation[(0, 0)] + 1.0).abs() < 1e-12);
    }

    #[test]
    fn empty_input_returns_error() {
        for empty in [Mat::<f64>::zeros(0, 3), Mat::<f64>::zeros(5, 0)] {
            assert!(matches!(
                orthogonal(empty.as_ref(), empty.as_ref(), false),
                Err(ProcrustesError::EmptyInput)
            ));
        }
    }

    #[test]
    fn dim_mismatch_returns_error() {
        let a = Mat::<f64>::zeros(5, 3);
        // Wrong row count first, then wrong column count.
        for other in [Mat::<f64>::zeros(4, 3), Mat::<f64>::zeros(5, 2)] {
            assert!(matches!(
                orthogonal(a.as_ref(), other.as_ref(), false),
                Err(ProcrustesError::DimensionMismatch { .. })
            ));
        }
    }

    #[test]
    fn nan_with_check_finite_true_returns_error() {
        let mut a = Mat::<f64>::zeros(3, 2);
        a[(1, 1)] = f64::NAN;
        let reference = Mat::<f64>::zeros(3, 2);
        let outcome = orthogonal(a.as_ref(), reference.as_ref(), true);
        assert!(matches!(outcome, Err(ProcrustesError::NonFinite)));
    }

    #[test]
    fn nan_with_check_finite_false_does_not_panic() {
        let mut a = Mat::<f64>::zeros(3, 2);
        a[(1, 1)] = f64::NAN;
        let reference = Mat::<f64>::zeros(3, 2);
        // Only the absence of a panic is checked; the result itself is irrelevant.
        let _ = orthogonal(a.as_ref(), reference.as_ref(), false);
    }
}