use faer::MatRef;
use crate::ProcrustesError;
/// Finds the signed column permutation of `a` that best matches `reference`.
///
/// Every permutation `p` of the `k` columns is visited, and for each pairing the better of
/// the two signs is chosen, minimising `sum_j || s_j * a[:, p[j]] - reference[:, j] ||^2`.
/// The squared column norms contribute the same total for every permutation. Minimising
/// that objective is therefore equivalent to maximising
/// `sum_j |<a[:, p[j]], reference[:, j]>|`, which is the quantity the search compares.
/// The sign for each pair is the sign of that inner product (`+1.0` when it is zero).
///
/// # Arguments
/// * `a` - matrix whose columns are permuted and sign-flipped.
/// * `reference` - target matrix; must have the same shape as `a`.
/// * `check_finite` - when `true`, inputs containing NaN or infinity are rejected up front.
///
/// # Returns
/// A [`SignedPermutationAlignment`] with the following fields:
/// * `assigned[j]` is the column of `a` matched to column `j` of `reference`.
/// * `signs[j]` is the sign applied to that column.
/// * `residual_frobenius` is the Frobenius norm of the remaining difference, computed directly
///   from the aligned columns.
///
/// Ties go to the first maximal permutation visited by Heap's algorithm, which starts
/// from the identity.
///
/// With `check_finite == false`, non-finite input does not panic. However, the returned
/// alignment is unspecified and the residual is non-finite.
///
/// # Errors
/// * [`ProcrustesError::DimensionMismatch`] if `a` and `reference` have different shapes.
/// * [`ProcrustesError::EmptyInput`] if either dimension is zero.
/// * [`ProcrustesError::NonFinite`] if `check_finite` is set and any entry is NaN or infinite.
///
/// # Performance
/// The search is exhaustive. It costs O(d·k²) for the inner products plus O(k!·k) for the
/// enumeration, so it is intended for a small number of columns `k` only.
// `d`, `k`, `i`, `j`, `r`, `p` mirror the matrix notation used in the docs above.
#[allow(clippy::many_single_char_names)]
pub fn signed_permutation(
    a: MatRef<'_, f64>,
    reference: MatRef<'_, f64>,
    check_finite: bool,
) -> Result<SignedPermutationAlignment, ProcrustesError> {
    /// Heap's algorithm: permutes `buf[..n]` in place and calls `on_perm` with the whole
    /// buffer once per permutation.
    fn heap_permute(buf: &mut [usize], n: usize, on_perm: &mut dyn FnMut(&[usize])) {
        if n <= 1 {
            on_perm(buf);
            return;
        }
        for i in 0..n {
            heap_permute(buf, n - 1, on_perm);
            if n % 2 == 0 {
                buf.swap(i, n - 1);
            } else {
                buf.swap(0, n - 1);
            }
        }
    }

    let (a_rows, a_cols) = (a.nrows(), a.ncols());
    let (ref_rows, ref_cols) = (reference.nrows(), reference.ncols());
    if a_rows != ref_rows || a_cols != ref_cols {
        return Err(ProcrustesError::DimensionMismatch {
            a_rows,
            a_cols,
            ref_rows,
            ref_cols,
        });
    }
    if a_rows == 0 || a_cols == 0 {
        return Err(ProcrustesError::EmptyInput);
    }
    if check_finite && (!is_all_finite(a) || !is_all_finite(reference)) {
        return Err(ProcrustesError::NonFinite);
    }

    let d = a_rows;
    let k = a_cols;

    // dot[i * k + j] = <a[:, i], reference[:, j]>. `k >= 1` here, so chunking is valid.
    let mut dot = vec![0.0_f64; k * k];
    for (i, row) in dot.chunks_exact_mut(k).enumerate() {
        for (j, slot) in row.iter_mut().enumerate() {
            *slot = (0..d).map(|r| a[(r, i)] * reference[(r, j)]).sum();
        }
    }

    let mut perm: Vec<usize> = (0..k).collect();
    let mut best_assigned = perm.clone();
    let mut best_signs = vec![1.0_f64; k];
    // Start below every attainable score so the first comparable permutation is kept.
    // A NaN score never compares greater and is therefore never selected.
    let mut best_score = f64::NEG_INFINITY;
    let mut on_perm = |p: &[usize]| {
        let score: f64 = p
            .iter()
            .enumerate()
            .map(|(j, &i)| dot[i * k + j].abs())
            .sum();
        // Strict `>` keeps the first permutation found among exact ties.
        if score > best_score {
            best_score = score;
            best_assigned.copy_from_slice(p);
            for (j, (sign, &i)) in best_signs.iter_mut().zip(p).enumerate() {
                *sign = if dot[i * k + j] >= 0.0 { 1.0 } else { -1.0 };
            }
        }
    };
    heap_permute(&mut perm, k, &mut on_perm);

    // Recompute the residual from the data instead of from the expanded
    // ||x||^2 - 2|<x, y>| + ||y||^2 form, which cancels catastrophically when the
    // alignment is close to exact.
    let residual_frobenius = best_assigned
        .iter()
        .zip(&best_signs)
        .enumerate()
        .map(|(j, (&i, &sign))| {
            (0..d)
                .map(|r| {
                    let diff = sign * a[(r, i)] - reference[(r, j)];
                    diff * diff
                })
                .sum::<f64>()
        })
        .sum::<f64>()
        .sqrt();

    Ok(SignedPermutationAlignment {
        assigned: best_assigned,
        signs: best_signs,
        residual_frobenius,
    })
}
/// Reports whether every entry of `x` is finite (neither NaN nor ±infinity).
///
/// Entries are visited column by column, and the scan stops at the first non-finite value.
fn is_all_finite(x: MatRef<'_, f64>) -> bool {
    (0..x.ncols()).all(|col| (0..x.nrows()).all(|row| x[(row, col)].is_finite()))
}
/// Result of [`signed_permutation`]: the best signed column permutation of `a` onto
/// `reference`.
///
/// Column `j` of `reference` is approximated by `signs[j] * a[:, assigned[j]]`.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub struct SignedPermutationAlignment {
    /// `assigned[j]` is the index of the column of `a` matched to column `j` of
    /// `reference`. The vector is a permutation of `0..k`, where `k` is the column count.
    pub assigned: Vec<usize>,
    /// `signs[j]` is `1.0` or `-1.0`: the sign applied to column `assigned[j]` of `a`.
    pub signs: Vec<f64>,
    /// Frobenius norm of the difference `signs[j] * a[:, assigned[j]] - reference[:, j]`,
    /// taken over all columns `j`, at the chosen alignment.
    pub residual_frobenius: f64,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ProcrustesError;
    use faer::Mat;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    /// Draws a `rows x cols` matrix of uniform samples in `[-1, 1)` from `rng`, in the
    /// element order that `Mat::from_fn` visits.
    fn uniform_matrix(rng: &mut ChaCha8Rng, rows: usize, cols: usize) -> Mat<f64> {
        Mat::<f64>::from_fn(rows, cols, |_, _| rand::Rng::gen_range(&mut *rng, -1.0..1.0))
    }

    /// A 3x2 zero matrix with a single NaN at position (1, 0).
    fn zeros_with_nan() -> Mat<f64> {
        let mut m = Mat::<f64>::zeros(3, 2);
        m[(1, 0)] = f64::NAN;
        m
    }

    #[test]
    fn signed_perm_recovers_swap_with_sign_flip() {
        // reference = [e0 e1]; a = [-e1 e0], i.e. the targets swapped, with the one landing
        // in column 0 of `a` negated.
        let reference = Mat::<f64>::from_fn(3, 2, |row, col| if row == col { 1.0 } else { 0.0 });
        let a = Mat::<f64>::from_fn(3, 2, |row, col| {
            if (row, col) == (0, 1) {
                1.0
            } else if (row, col) == (1, 0) {
                -1.0
            } else {
                0.0
            }
        });
        let aligned = signed_permutation(a.as_ref(), reference.as_ref(), false).unwrap();
        assert_eq!(aligned.assigned, vec![1, 0]);
        assert_eq!(aligned.signs, vec![1.0, -1.0]);
        let residual = aligned.residual_frobenius;
        assert!(residual < 1e-10, "got {}", residual);
    }

    #[test]
    fn signed_perm_identity_when_already_aligned() {
        let eye = Mat::<f64>::from_fn(4, 3, |row, col| f64::from(u8::from(row == col)));
        let aligned = signed_permutation(eye.as_ref(), eye.as_ref(), false).unwrap();
        assert_eq!(aligned.assigned, (0..3).collect::<Vec<usize>>());
        assert_eq!(aligned.signs, [1.0; 3].to_vec());
        assert!(aligned.residual_frobenius < 1e-12);
    }

    #[test]
    #[allow(clippy::cast_precision_loss)]
    fn k_eq_1_trivial() {
        let centered = |row: usize| (row as f64) - 2.0;
        let reference = Mat::<f64>::from_fn(5, 1, |row, _| centered(row));
        let a = Mat::<f64>::from_fn(5, 1, |row, _| -centered(row));
        let aligned = signed_permutation(a.as_ref(), reference.as_ref(), false).unwrap();
        assert_eq!(aligned.assigned, vec![0]);
        assert_eq!(aligned.signs, vec![-1.0]);
        assert!(aligned.residual_frobenius < 1e-12);
    }

    #[test]
    fn k_eq_8_recovers_known_alignment() {
        let (rows, cols) = (24, 8);
        let mut rng = ChaCha8Rng::seed_from_u64(123);
        let a = uniform_matrix(&mut rng, rows, cols);
        let true_perm: Vec<usize> = vec![3, 0, 7, 1, 5, 2, 6, 4];
        let true_signs: Vec<f64> = vec![1.0, -1.0, 1.0, 1.0, -1.0, 1.0, -1.0, 1.0];
        // Column `col` of the reference is column `true_perm[col]` of `a`, signed.
        let reference = Mat::<f64>::from_fn(rows, cols, |row, col| {
            true_signs[col] * a[(row, true_perm[col])]
        });
        let aligned = signed_permutation(a.as_ref(), reference.as_ref(), false).unwrap();
        assert_eq!(aligned.assigned, true_perm);
        assert_eq!(aligned.signs, true_signs);
        let residual = aligned.residual_frobenius;
        assert!(residual < 1e-10, "got {}", residual);
    }

    #[test]
    fn empty_input_returns_error() {
        for (rows, cols) in [(0, 3), (5, 0)] {
            let empty = Mat::<f64>::zeros(rows, cols);
            let result = signed_permutation(empty.as_ref(), empty.as_ref(), false);
            assert!(matches!(result, Err(ProcrustesError::EmptyInput)));
        }
    }

    #[test]
    fn dim_mismatch_returns_error() {
        let a = Mat::<f64>::zeros(5, 3);
        for (rows, cols) in [(4, 3), (5, 2)] {
            let reference = Mat::<f64>::zeros(rows, cols);
            let result = signed_permutation(a.as_ref(), reference.as_ref(), false);
            assert!(matches!(
                result,
                Err(ProcrustesError::DimensionMismatch { .. })
            ));
        }
    }

    #[test]
    fn nan_with_check_finite_true_returns_error() {
        let a = zeros_with_nan();
        let reference = Mat::<f64>::zeros(3, 2);
        let result = signed_permutation(a.as_ref(), reference.as_ref(), true);
        assert!(matches!(result, Err(ProcrustesError::NonFinite)));
    }

    #[test]
    fn nan_with_check_finite_false_does_not_panic() {
        let a = zeros_with_nan();
        let reference = Mat::<f64>::zeros(3, 2);
        // Only the absence of a panic is checked; the result itself is not inspected.
        let _ = signed_permutation(a.as_ref(), reference.as_ref(), false);
    }

    #[test]
    fn print_runtime_table() {
        if std::env::var_os("PROCRUSTES_SKIP_TIMING").is_some() {
            return;
        }
        let rows = 32;
        for k in [4_usize, 6, 8, 10] {
            // Re-seed per size so each row of the table is independently reproducible.
            let mut rng = ChaCha8Rng::seed_from_u64(0x00C0_FFEE);
            let a = uniform_matrix(&mut rng, rows, k);
            let reference = a.clone();
            let start = std::time::Instant::now();
            let _ = signed_permutation(a.as_ref(), reference.as_ref(), false).unwrap();
            let dur = start.elapsed();
            eprintln!("signed_permutation K={k}: {dur:?}");
        }
    }
}