use faer::linalg::matmul::matmul;
use faer::{Accum, Mat, MatRef, Par};
use crate::ProcrustesError;
/// Per-iteration strategy used by [`generalized`] to align each input
/// matrix to the current consensus (dispatched in `apply_inner`).
#[non_exhaustive]
#[derive(Debug, Clone, Copy)]
pub enum InnerAligner {
    /// Fit via `crate::orthogonal` and right-multiply by its rotation.
    Orthogonal,
    /// Fit via `crate::signed_permutation`: output columns are source
    /// columns permuted and sign-flipped according to the assignment.
    SignedPermutation,
    /// Fit via `crate::rotation_only` and right-multiply by its rotation.
    RotationOnly,
}
/// Choice of the initial consensus estimate for [`generalized`].
#[non_exhaustive]
#[derive(Debug, Clone, Copy)]
pub enum GpaInit {
    /// Start from a copy of the first input matrix.
    FirstMatrix,
    /// Start from the (optionally weighted) mean of all inputs.
    Mean,
}
/// Configuration for [`generalized`].
#[derive(Debug, Clone)]
pub struct GpaOptions {
    /// Inner alignment strategy applied on every iteration.
    pub inner: InnerAligner,
    /// How the initial consensus is chosen.
    pub init: GpaInit,
    /// Convergence threshold: iteration stops once the Frobenius norm of
    /// the consensus change drops below this value.
    pub tol: f64,
    /// Upper bound on the number of alignment iterations.
    pub max_iters: usize,
    /// When `true`, every input is pre-scaled to unit Frobenius norm;
    /// inputs with near-zero norm are rejected as invalid.
    pub procrustes_form: bool,
    /// Optional per-matrix weights for the consensus mean. Must match the
    /// number of matrices, contain only finite non-negative values, and
    /// must not sum to zero.
    pub weights: Option<Vec<f64>>,
}
impl Default for GpaOptions {
    /// Orthogonal inner aligner, first-matrix initialization,
    /// `tol = 1e-10`, at most 100 iterations, no pre-scaling, and an
    /// unweighted (uniform) consensus mean.
    fn default() -> Self {
        Self {
            inner: InnerAligner::Orthogonal,
            init: GpaInit::FirstMatrix,
            tol: 1e-10,
            max_iters: 100,
            procrustes_form: false,
            weights: None,
        }
    }
}
/// Result of a [`generalized`] Procrustes run.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct GpaAlignment {
    /// Final consensus matrix (weighted mean of the aligned inputs).
    pub consensus: Mat<f64>,
    /// The inputs after alignment, in the same order as they were given.
    pub aligned: Vec<Mat<f64>>,
    /// Number of iterations actually executed.
    pub n_iters: usize,
    /// Frobenius-norm drift of the consensus measured on the last iteration.
    pub final_drift: f64,
    /// `true` iff the drift fell below `tol` within `max_iters` iterations.
    pub converged: bool,
}
/// Generalized Procrustes analysis: iteratively aligns every input matrix
/// to a running consensus until the consensus stops moving.
///
/// Each iteration aligns all inputs to the current consensus using
/// `opts.inner`, recomputes the consensus as the (optionally weighted)
/// mean of the aligned matrices, and measures the Frobenius-norm drift
/// between consecutive consensus estimates. Iteration stops as soon as
/// the drift falls below `opts.tol`, or after `opts.max_iters` rounds.
///
/// # Errors
///
/// * [`ProcrustesError::EmptyInput`] — no matrices, or zero-sized matrices.
/// * [`ProcrustesError::DimensionMismatch`] — inputs differ in shape.
/// * [`ProcrustesError::InvalidOptions`] — `max_iters == 0`; NaN or
///   negative `tol`; bad weights (wrong length, non-finite, negative, or
///   zero-sum); or `procrustes_form` with a (near-)zero-norm input.
#[allow(clippy::many_single_char_names)]
#[allow(clippy::needless_pass_by_value)]
#[allow(clippy::too_many_lines)]
pub fn generalized(
    matrices: &[MatRef<'_, f64>],
    opts: GpaOptions,
) -> Result<GpaAlignment, ProcrustesError> {
    if matrices.is_empty() {
        return Err(ProcrustesError::EmptyInput);
    }
    let (ref_rows, ref_cols) = (matrices[0].nrows(), matrices[0].ncols());
    if ref_rows == 0 || ref_cols == 0 {
        return Err(ProcrustesError::EmptyInput);
    }
    for m in matrices.iter().skip(1) {
        if m.nrows() != ref_rows || m.ncols() != ref_cols {
            return Err(ProcrustesError::DimensionMismatch {
                a_rows: m.nrows(),
                a_cols: m.ncols(),
                ref_rows,
                ref_cols,
            });
        }
    }
    // Fix: `max_iters == 0` previously returned Ok with an EMPTY `aligned`
    // vector and a NaN `final_drift` — meaningless output. Reject it here,
    // consistent with the other option validation.
    if opts.max_iters == 0 {
        return Err(ProcrustesError::InvalidOptions(
            "max_iters must be at least 1",
        ));
    }
    // Fix: a NaN or negative tolerance makes `drift < tol` unsatisfiable
    // (drift is always >= 0), so the run could never converge.
    if opts.tol.is_nan() || opts.tol < 0.0 {
        return Err(ProcrustesError::InvalidOptions("tol must be non-negative"));
    }
    if let Some(w) = &opts.weights {
        if w.len() != matrices.len() {
            return Err(ProcrustesError::InvalidOptions(
                "weights length must equal matrices length",
            ));
        }
        for &wi in w {
            if !wi.is_finite() {
                return Err(ProcrustesError::InvalidOptions(
                    "weights contain non-finite value",
                ));
            }
            if wi < 0.0 {
                return Err(ProcrustesError::InvalidOptions(
                    "weights contain negative value",
                ));
            }
        }
        let sum: f64 = w.iter().sum();
        if sum <= 0.0 {
            return Err(ProcrustesError::InvalidOptions("weights sum to zero"));
        }
    }
    // Optional "Procrustes form": pre-scale every input to unit Frobenius
    // norm so the fit is insensitive to the inputs' overall scale.
    let scaled_storage: Vec<Mat<f64>> = if opts.procrustes_form {
        let mut storage = Vec::with_capacity(matrices.len());
        for &m in matrices {
            let norm = frobenius(m);
            if norm < f64::EPSILON {
                return Err(ProcrustesError::InvalidOptions(
                    "procrustes_form requires non-zero inputs",
                ));
            }
            let inv = 1.0 / norm;
            storage.push(Mat::<f64>::from_fn(ref_rows, ref_cols, |i, j| {
                m[(i, j)] * inv
            }));
        }
        storage
    } else {
        Vec::new()
    };
    let scaled_refs: Vec<MatRef<'_, f64>> = scaled_storage.iter().map(Mat::as_ref).collect();
    let inputs: &[MatRef<'_, f64>] = if opts.procrustes_form {
        &scaled_refs
    } else {
        matrices
    };
    let mut consensus = match opts.init {
        GpaInit::FirstMatrix => {
            // Deep copy so the consensus can be replaced each iteration.
            let m = inputs[0];
            Mat::<f64>::from_fn(m.nrows(), m.ncols(), |i, j| m[(i, j)])
        }
        GpaInit::Mean => weighted_mean(inputs, opts.weights.as_deref()),
    };
    // `max_iters >= 1` is guaranteed above, so both placeholders are
    // overwritten before the fall-through return below can use them.
    let mut last_aligned: Vec<Mat<f64>> = Vec::new();
    let mut last_drift = f64::NAN;
    for iter in 0..opts.max_iters {
        // Align every input to the current consensus...
        let aligned: Vec<Mat<f64>> = inputs
            .iter()
            .map(|&m| apply_inner(m, consensus.as_ref(), opts.inner))
            .collect();
        let aligned_refs: Vec<MatRef<'_, f64>> = aligned.iter().map(Mat::as_ref).collect();
        // ...then move the consensus to the mean of the aligned set and
        // measure how far it moved.
        let new_consensus = weighted_mean(&aligned_refs, opts.weights.as_deref());
        let drift = frobenius_diff(new_consensus.as_ref(), consensus.as_ref());
        consensus = new_consensus;
        if drift < opts.tol {
            return Ok(GpaAlignment {
                consensus,
                aligned,
                n_iters: iter + 1,
                final_drift: drift,
                converged: true,
            });
        }
        last_aligned = aligned;
        last_drift = drift;
    }
    Ok(GpaAlignment {
        consensus,
        aligned: last_aligned,
        n_iters: opts.max_iters,
        final_drift: last_drift,
        converged: false,
    })
}
/// Aligns `matrix` to `reference` with the selected inner strategy and
/// returns the transformed copy. Shapes are assumed to have been
/// validated by the caller (`generalized`).
#[allow(clippy::many_single_char_names)]
fn apply_inner(
    matrix: MatRef<'_, f64>,
    reference: MatRef<'_, f64>,
    inner: InnerAligner,
) -> Mat<f64> {
    let rows = matrix.nrows();
    let cols = matrix.ncols();
    match inner {
        InnerAligner::Orthogonal => {
            let fit = crate::orthogonal(matrix, reference, false).expect("validated upstream");
            matmul_apply(matrix, fit.rotation.as_ref())
        }
        InnerAligner::RotationOnly => {
            let fit = crate::rotation_only(matrix, reference, false).expect("validated upstream");
            matmul_apply(matrix, fit.rotation.as_ref())
        }
        InnerAligner::SignedPermutation => {
            let fit =
                crate::signed_permutation(matrix, reference, false).expect("validated upstream");
            // Output column `c` is the assigned source column, sign-flipped.
            Mat::<f64>::from_fn(rows, cols, |r, c| {
                fit.signs[c] * matrix[(r, fit.assigned[c])]
            })
        }
    }
}
/// Computes `matrix * rot` into a freshly allocated matrix using faer's
/// sequential matmul kernel.
fn matmul_apply(matrix: MatRef<'_, f64>, rot: MatRef<'_, f64>) -> Mat<f64> {
    let mut product = Mat::<f64>::zeros(matrix.nrows(), matrix.ncols());
    matmul(product.as_mut(), Accum::Replace, matrix, rot, 1.0, Par::Seq);
    product
}
/// Element-wise mean of `matrices`, optionally weighted.
///
/// With `weights`, each matrix contributes `w_i / sum(w)`; zero-weight
/// matrices are skipped entirely. Without weights, every matrix
/// contributes `1 / n`. Assumes a non-empty slice of equally sized
/// matrices (validated by the caller).
#[allow(clippy::many_single_char_names)]
#[allow(clippy::cast_precision_loss)]
fn weighted_mean(matrices: &[MatRef<'_, f64>], weights: Option<&[f64]>) -> Mat<f64> {
    let rows = matrices[0].nrows();
    let cols = matrices[0].ncols();
    let mut acc = Mat::<f64>::zeros(rows, cols);
    // Shared accumulation kernel: acc += scale * m, column-major order.
    let mut accumulate = |scale: f64, m: MatRef<'_, f64>| {
        for c in 0..cols {
            for r in 0..rows {
                acc[(r, c)] += scale * m[(r, c)];
            }
        }
    };
    if let Some(w) = weights {
        let total_inv = 1.0 / w.iter().sum::<f64>();
        for (&wi, &m) in w.iter().zip(matrices) {
            if wi != 0.0 {
                accumulate(wi * total_inv, m);
            }
        }
    } else {
        let uniform = 1.0 / (matrices.len() as f64);
        for &m in matrices {
            accumulate(uniform, m);
        }
    }
    acc
}
fn frobenius_diff(a: MatRef<'_, f64>, b: MatRef<'_, f64>) -> f64 {
let mut s = 0.0;
for j in 0..a.ncols() {
for i in 0..a.nrows() {
let d = a[(i, j)] - b[(i, j)];
s += d * d;
}
}
s.sqrt()
}
fn frobenius(x: MatRef<'_, f64>) -> f64 {
let mut s = 0.0;
for j in 0..x.ncols() {
for i in 0..x.nrows() {
let v = x[(i, j)];
s += v * v;
}
}
s.sqrt()
}