#![allow(clippy::doc_markdown)]
use faer::{Mat, MatRef};
use crate::error::{PlsKitError, PlsKitResult};
/// Rotation algorithm requested from `rotate`, together with its settings.
#[derive(Debug, Clone)]
pub enum RotationMethod {
/// Varimax rotation, configured by the wrapped `VarimaxArgs`.
Varimax(VarimaxArgs),
}
/// Settings for the varimax solver.
#[derive(Debug, Clone, Copy)]
pub struct VarimaxArgs {
    /// Upper bound on the number of sweeps over all column pairs.
    pub max_iter: usize,
    /// Stopping threshold: iteration ends once a sweep improves the
    /// criterion by less than this amount.
    pub tol: f64,
    /// When `true`, every row of the fitting basis is scaled to unit length
    /// before the rotation is fitted.
    pub kaiser_normalize: bool,
}

impl Default for VarimaxArgs {
    /// Returns the stock configuration: 50 sweeps, tolerance `1e-8`,
    /// Kaiser normalisation enabled.
    fn default() -> Self {
        let max_iter = 50;
        let tol = 1e-8;
        let kaiser_normalize = true;
        Self {
            max_iter,
            tol,
            kaiser_normalize,
        }
    }
}
/// Result bundle returned by `rotate`.
#[derive(Debug)]
pub struct RotateOutput {
/// Rotated weights: the input `w` multiplied on the right by `r`.
pub w_rot: Mat<f64>,
/// `K x K` rotation accumulated from the pairwise plane rotations
/// (the `1 x 1` identity when `K == 1`).
pub r: Mat<f64>,
/// Number of sweeps over all column pairs that were started
/// (0 when `K == 1` or when `max_iter` is 0).
pub sweeps: usize,
/// Varimax criterion (sum over columns of the variance of the squared
/// entries) as reported by the solver when it stopped.
pub v_converged: f64,
}
/// Validates the inputs and rotates `w` with the requested `method`.
///
/// The rotation is fitted on `l` when it is supplied and on `w` otherwise;
/// the fitted rotation is then applied to `w`.
///
/// * `w` - matrix to rotate; its column count is `K`.
/// * `method` - rotation algorithm and its settings.
/// * `l` - optional fitting basis; must have `K` columns, any row count.
///
/// # Errors
///
/// * `PlsKitError::InvalidInput` when `w` has no columns, when `w` or `l`
///   contains a non-finite entry, or when the fitting basis has no rows.
/// * `PlsKitError::ShapeMismatch` when `l` and `w` disagree on column count.
// `method` is taken by value to keep the public signature stable.
#[allow(clippy::needless_pass_by_value)]
pub fn rotate(
    w: MatRef<'_, f64>,
    method: RotationMethod,
    l: Option<MatRef<'_, f64>>,
) -> PlsKitResult<RotateOutput> {
    let k = w.ncols();
    if k == 0 {
        return Err(PlsKitError::InvalidInput("W has K=0".into()));
    }
    if !mat_is_finite(w) {
        return Err(PlsKitError::InvalidInput(
            "W contains non-finite values".into(),
        ));
    }
    if let Some(ll) = l {
        if ll.ncols() != k {
            return Err(PlsKitError::ShapeMismatch(format!(
                "L.ncols={} but W.ncols={}",
                ll.ncols(),
                k
            )));
        }
        if !mat_is_finite(ll) {
            return Err(PlsKitError::InvalidInput(
                "L contains non-finite values".into(),
            ));
        }
    }

    // A basis without rows passes the finiteness checks vacuously, but every
    // column statistic then divides by a row count of zero. The result would
    // be an `Ok` carrying NaNs (in `r` for K >= 2, in `v_converged` always),
    // so reject it here instead.
    let basis_rows = l.map_or(w.nrows(), |ll| ll.nrows());
    if basis_rows == 0 {
        return Err(PlsKitError::InvalidInput(
            "rotation basis has zero rows".into(),
        ));
    }

    match method {
        RotationMethod::Varimax(args) => varimax_rotate(w, l, args, k),
    }
}
/// Fits a varimax rotation by cyclic pairwise plane rotations and applies it
/// to `w`.
///
/// The rotation is fitted on `l` when it is supplied and on `w` otherwise.
/// With `args.kaiser_normalize` set, the rows of that basis are scaled to unit
/// length first. Each sweep visits every column pair `(p, q)` with `p < q`,
/// rotates the pair by the angle from `varimax_angle_2d`, and accumulates the
/// same rotation into `r`. Iteration stops after `args.max_iter` sweeps or as
/// soon as a sweep improves the criterion by less than `args.tol`.
///
/// * `w` - matrix to rotate.
/// * `l` - optional fitting basis with the same column count as `w`.
/// * `args` - sweep cap, stopping tolerance and normalisation switch.
/// * `k` - column count of `w`; must be at least 1.
///
/// Always returns `Ok`; the `Result` wrapper matches the dispatcher in
/// `rotate`. The reported `v_converged` is the criterion of the basis after
/// the last executed sweep, i.e. it corresponds to the returned `r`.
// The `Result` return type is kept so the signature matches the dispatcher.
#[allow(clippy::unnecessary_wraps)]
// `p`, `q`, `c`, `s`, `r` follow the usual plane-rotation notation.
#[allow(clippy::many_single_char_names)]
fn varimax_rotate(
    w: MatRef<'_, f64>,
    l: Option<MatRef<'_, f64>>,
    args: VarimaxArgs,
    k: usize,
) -> PlsKitResult<RotateOutput> {
    let basis = l.unwrap_or(w);

    // A single column has nothing to rotate against: return the identity.
    if k == 1 {
        return Ok(RotateOutput {
            w_rot: mat_clone(w),
            r: identity(1),
            sweeps: 0,
            v_converged: sum_var_squared_columns(basis),
        });
    }

    let n_rows = basis.nrows();
    let mut t_simp: Mat<f64> = if args.kaiser_normalize {
        row_normalize(basis)
    } else {
        mat_clone(basis)
    };
    let mut r = identity(k);
    let mut v_current = sum_var_squared_columns(t_simp.as_ref());
    let mut sweeps_done = 0_usize;

    for sweep in 1..=args.max_iter {
        sweeps_done = sweep;
        for p in 0..(k - 1) {
            for q in (p + 1)..k {
                // Copy columns p and q side by side for the 2-D angle solver.
                let pair = Mat::<f64>::from_fn(n_rows, 2, |i, j| {
                    if j == 0 {
                        t_simp[(i, p)]
                    } else {
                        t_simp[(i, q)]
                    }
                });
                let theta = varimax_angle_2d(pair.as_ref());
                let (s, c) = theta.sin_cos();
                // Apply the same plane rotation to the working basis and to
                // the accumulated rotation so they stay in step.
                rotate_columns_inplace(&mut t_simp, p, q, c, s);
                rotate_columns_inplace(&mut r, p, q, c, s);
            }
        }

        let v_new = sum_var_squared_columns(t_simp.as_ref());
        let gain = v_new - v_current;
        // Record the new criterion *before* deciding to stop: `r` already
        // contains this sweep's rotations, so the reported value must describe
        // the same state rather than the one before the final sweep.
        v_current = v_new;
        if gain < args.tol {
            break;
        }
    }

    let w_rot = matmul(w, r.as_ref());
    Ok(RotateOutput {
        w_rot,
        r,
        sweeps: sweeps_done,
        v_converged: v_current,
    })
}
/// Returns `true` when every entry of `m` is finite (neither NaN nor ±∞).
///
/// Entries are visited column by column; an empty matrix yields `true`.
fn mat_is_finite(m: MatRef<'_, f64>) -> bool {
    let (rows, cols) = (m.nrows(), m.ncols());
    (0..cols).all(|col| (0..rows).all(|row| m[(row, col)].is_finite()))
}
/// Copies the view `m` into a freshly allocated owned matrix of equal shape.
fn mat_clone(m: MatRef<'_, f64>) -> Mat<f64> {
    let (rows, cols) = (m.nrows(), m.ncols());
    let entry = |row: usize, col: usize| m[(row, col)];
    Mat::<f64>::from_fn(rows, cols, entry)
}
/// Builds the `k x k` identity matrix.
fn identity(k: usize) -> Mat<f64> {
    // `bool -> u8 -> f64` maps the diagonal test to exactly 1.0 / 0.0.
    Mat::<f64>::from_fn(k, k, |row, col| f64::from(u8::from(row == col)))
}
/// Returns a copy of `m` in which each row is divided by its Euclidean norm.
///
/// Rows whose norm is not above `1e-12` are divided by 1 instead, i.e. they
/// are copied through unchanged rather than blown up by a tiny divisor.
// Short index names are kept for the matrix subscripts.
#[allow(clippy::many_single_char_names)]
fn row_normalize(m: MatRef<'_, f64>) -> Mat<f64> {
    let (rows, cols) = (m.nrows(), m.ncols());

    // One divisor per row, accumulated left to right across the columns.
    let divisors: Vec<f64> = (0..rows)
        .map(|row| {
            let squared_len = (0..cols).fold(0.0_f64, |acc, col| {
                let entry = m[(row, col)];
                acc + entry * entry
            });
            let len = squared_len.sqrt();
            if len > 1e-12 {
                len
            } else {
                1.0
            }
        })
        .collect();

    Mat::<f64>::from_fn(rows, cols, |row, col| m[(row, col)] / divisors[row])
}
/// Sums, over all columns of `m`, the population variance of the squared
/// entries of that column.
///
/// Per column this is `mean(x^4) - mean(x^2)^2`, with the means taken over the
/// row count. A matrix without rows divides by zero and therefore yields NaN
/// (when it has at least one column).
fn sum_var_squared_columns(m: MatRef<'_, f64>) -> f64 {
    let (rows, cols) = (m.nrows(), m.ncols());
    let count = rows as f64;

    (0..cols).fold(0.0_f64, |running, col| {
        // Accumulate sum(x^2) and sum(x^4) for this column in row order.
        let (sum_sq, sum_quart) =
            (0..rows).fold((0.0_f64, 0.0_f64), |(acc_sq, acc_quart), row| {
                let entry = m[(row, col)];
                let squared = entry * entry;
                (acc_sq + squared, acc_quart + squared * squared)
            });
        let mean_sq = sum_sq / count;
        running + (sum_quart / count - mean_sq * mean_sq)
    })
}
/// Applies a plane rotation to columns `p` and `q` of `m`, row by row.
///
/// With `c` and `s` the cosine and sine of the angle, each row is updated as
/// `new_p = c * old_p + s * old_q` and `new_q = -s * old_p + c * old_q`.
/// Both old values are read before either column is written.
// `p`, `q`, `c`, `s` follow the usual plane-rotation notation.
#[allow(clippy::many_single_char_names)]
fn rotate_columns_inplace(m: &mut Mat<f64>, p: usize, q: usize, c: f64, s: f64) {
    for row in 0..m.nrows() {
        let (old_p, old_q) = (m[(row, p)], m[(row, q)]);
        let new_p = c * old_p + s * old_q;
        let new_q = -s * old_p + c * old_q;
        m[(row, p)] = new_p;
        m[(row, q)] = new_q;
    }
}
/// Computes the dense product `a * b` with a plain triple loop.
///
/// The inner dimensions are checked only in debug builds; each output entry
/// is accumulated left to right over the shared dimension.
fn matmul(a: MatRef<'_, f64>, b: MatRef<'_, f64>) -> Mat<f64> {
    debug_assert_eq!(a.ncols(), b.nrows());
    let inner = a.ncols();
    let dot = |row: usize, col: usize| {
        (0..inner).fold(0.0_f64, |acc, idx| acc + a[(row, idx)] * b[(idx, col)])
    };
    Mat::<f64>::from_fn(a.nrows(), b.ncols(), dot)
}
/// Returns the plane-rotation angle used by the varimax sweep for a
/// two-column block `l`.
///
/// For each row `(a, b)` let `u = a^2 - b^2` and `v = 2ab`. With `n` rows the
/// angle is `atan2(B, A) / 4`, where
/// `A = sum(u^2 - v^2) - (sum(u)^2 - sum(v)^2) / n` and
/// `B = 2 * (sum(u * v) - sum(u) * sum(v) / n)`.
///
/// The two-column requirement is checked only in debug builds. With zero rows
/// the divisions by `n` produce NaN.
// `u`, `v`, `x`, `y` mirror the symbols in the formula above.
#[allow(clippy::many_single_char_names)]
fn varimax_angle_2d(l: MatRef<'_, f64>) -> f64 {
    debug_assert_eq!(l.ncols(), 2, "varimax_angle_2d requires exactly 2 columns");
    let rows = l.nrows();
    let count = rows as f64;

    // Running sums, in order: sum(u), sum(v), sum(u^2), sum(v^2), sum(u*v).
    let [sum_u, sum_v, sum_uu, sum_vv, sum_uv] =
        (0..rows).fold([0.0_f64; 5], |mut acc, row| {
            let (x, y) = (l[(row, 0)], l[(row, 1)]);
            let u = x * x - y * y;
            let v = 2.0 * x * y;
            acc[0] += u;
            acc[1] += v;
            acc[2] += u * u;
            acc[3] += v * v;
            acc[4] += u * v;
            acc
        });

    let denominator = (sum_uu - sum_vv) - (sum_u * sum_u - sum_v * sum_v) / count;
    let numerator = 2.0 * (sum_uv - sum_u * sum_v / count);
    numerator.atan2(denominator) / 4.0
}
// Unit tests for the rotation entry point and the 2-D angle helper.
#[cfg(test)]
mod tests {
use super::*;
use faer::Mat;
// A block whose second column is all zeros needs no rotation: angle ~ 0.
#[test]
fn varimax_angle_2d_zero_for_already_simple() {
let l = Mat::<f64>::from_fn(5, 2, |i, j| if j == 0 { (i + 1) as f64 } else { 0.0 });
let theta = varimax_angle_2d(l.as_ref());
assert!(theta.abs() < 1e-12, "expected ~0, got {theta}");
}
// Hand-built 5x2 fixture for which the helper is expected to return pi/4.
#[test]
#[allow(clippy::unnested_or_patterns)] fn varimax_angle_2d_known_value() {
let l = Mat::<f64>::from_fn(5, 2, |i, j| match (i, j) {
(0, 0) | (0, 1) | (1, 0) | (2, 1) => 0.5,
(1, 1) | (2, 0) | (3, 0) | (3, 1) => -0.5,
(4, 0) => 1.0,
_ => 0.0,
});
let theta = varimax_angle_2d(l.as_ref());
let expected = std::f64::consts::PI / 4.0;
assert!(
(theta - expected).abs() < 1e-12,
"expected π/4, got {theta}"
);
}
// Deterministic pseudo-random `n x k` fixture from a 64-bit linear
// congruential generator seeded with `rng_seed`; the state advances once per
// generated entry, in the order `from_fn` invokes the closure.
// NOTE(review): `s >> 33` keeps only 31 bits, so the ratio against `u32::MAX`
// stays below 0.5 and every generated value lies in [-1, 0). If [-1, 1] was
// intended the shift would be 32 - but changing it alters every fixture below.
fn random_w(rng_seed: u64, n: usize, k: usize) -> Mat<f64> {
let mut s = rng_seed;
Mat::<f64>::from_fn(n, k, |_, _| {
s = s
.wrapping_mul(6_364_136_223_846_793_005)
.wrapping_add(1_442_695_040_888_963_407);
(f64::from((s >> 33) as u32) / f64::from(u32::MAX)) * 2.0 - 1.0
})
}
// Shape-aware comparison: true when both matrices have the same shape and
// every pair of entries differs by at most `tol` in absolute value.
fn approx_eq_mat(a: MatRef<'_, f64>, b: MatRef<'_, f64>, tol: f64) -> bool {
if a.nrows() != b.nrows() || a.ncols() != b.ncols() {
return false;
}
for j in 0..a.ncols() {
for i in 0..a.nrows() {
if (a[(i, j)] - b[(i, j)]).abs() > tol {
return false;
}
}
}
true
}
// K == 1: no sweeps, 1x1 identity rotation, weights returned unchanged.
#[test]
fn rotate_k1_is_noop() {
let w = random_w(1, 20, 1);
let out = rotate(
w.as_ref(),
RotationMethod::Varimax(VarimaxArgs::default()),
None,
)
.unwrap();
assert_eq!(out.sweeps, 0);
assert_eq!(out.r.nrows(), 1);
assert_eq!(out.r.ncols(), 1);
assert!((out.r[(0, 0)] - 1.0).abs() < 1e-15);
assert!(approx_eq_mat(out.w_rot.as_ref(), w.as_ref(), 1e-15));
}
// A weight matrix without columns is rejected as invalid input.
#[test]
fn rotate_k0_errors() {
let w = Mat::<f64>::from_fn(5, 0, |_, _| 0.0);
let res = rotate(
w.as_ref(),
RotationMethod::Varimax(VarimaxArgs::default()),
None,
);
assert!(matches!(res, Err(PlsKitError::InvalidInput(_))));
}
// L with 2 columns against W with 3 columns must report a shape mismatch.
#[test]
fn rotate_l_shape_mismatch_errors() {
let w = random_w(2, 10, 3);
let l = random_w(3, 8, 2); let res = rotate(
w.as_ref(),
RotationMethod::Varimax(VarimaxArgs::default()),
Some(l.as_ref()),
);
assert!(matches!(res, Err(PlsKitError::ShapeMismatch(_))));
}
// A NaN anywhere in W is rejected before any rotation is attempted.
#[test]
fn rotate_non_finite_w_errors() {
let mut w = random_w(4, 5, 2);
w[(0, 0)] = f64::NAN;
let res = rotate(
w.as_ref(),
RotationMethod::Varimax(VarimaxArgs::default()),
None,
);
assert!(matches!(res, Err(PlsKitError::InvalidInput(_))));
}
// An infinity in L is rejected; L's row count (8) differs from W's (5),
// which is allowed because only the column counts are compared.
#[test]
fn rotate_non_finite_l_errors() {
let w = random_w(11, 5, 2);
let mut l = random_w(12, 8, 2);
l[(0, 1)] = f64::INFINITY;
let res = rotate(
w.as_ref(),
RotationMethod::Varimax(VarimaxArgs::default()),
Some(l.as_ref()),
);
assert!(matches!(res, Err(PlsKitError::InvalidInput(_))));
}
// The accumulated rotation must satisfy R^T R = I (to 1e-10).
#[test]
fn rotate_r_is_orthogonal() {
let w = random_w(5, 30, 4);
let out = rotate(
w.as_ref(),
RotationMethod::Varimax(VarimaxArgs::default()),
None,
)
.unwrap();
let rt_r = matmul(out.r.transpose(), out.r.as_ref());
let eye = identity(4);
assert!(approx_eq_mat(rt_r.as_ref(), eye.as_ref(), 1e-10));
}
// Rotating an already rotated solution again should yield R close to I.
#[test]
fn rotate_idempotent_on_converged_solution() {
let tight = VarimaxArgs {
tol: 1e-12,
..VarimaxArgs::default()
};
let w = random_w(6, 40, 3);
let out1 = rotate(w.as_ref(), RotationMethod::Varimax(tight), None).unwrap();
let out2 = rotate(out1.w_rot.as_ref(), RotationMethod::Varimax(tight), None).unwrap();
let eye = identity(3);
assert!(approx_eq_mat(out2.r.as_ref(), eye.as_ref(), 1e-6));
}
// The returned `w_rot` must equal W * R recomputed from the returned `r`.
#[test]
fn rotate_w_at_r_equals_w_rot() {
let w = random_w(7, 25, 3);
let out = rotate(
w.as_ref(),
RotationMethod::Varimax(VarimaxArgs::default()),
None,
)
.unwrap();
let recomputed = matmul(w.as_ref(), out.r.as_ref());
assert!(approx_eq_mat(
out.w_rot.as_ref(),
recomputed.as_ref(),
1e-15
));
}
// Supplying L must change the fitted rotation compared with fitting on W.
#[test]
fn rotate_with_l_uses_loading_basis() {
let w = random_w(8, 10, 3);
let l = random_w(9, 40, 3);
let out_none = rotate(
w.as_ref(),
RotationMethod::Varimax(VarimaxArgs::default()),
None,
)
.unwrap();
let out_l = rotate(
w.as_ref(),
RotationMethod::Varimax(VarimaxArgs::default()),
Some(l.as_ref()),
)
.unwrap();
assert!(
!approx_eq_mat(out_none.r.as_ref(), out_l.r.as_ref(), 1e-6),
"R should differ when L is provided"
);
}
// Toggling Kaiser normalisation must change the fitted rotation.
#[test]
fn rotate_kaiser_normalize_off_differs() {
let w = random_w(10, 30, 3);
let on = rotate(
w.as_ref(),
RotationMethod::Varimax(VarimaxArgs::default()),
None,
)
.unwrap();
let off_args = VarimaxArgs {
kaiser_normalize: false,
..VarimaxArgs::default()
};
let off = rotate(w.as_ref(), RotationMethod::Varimax(off_args), None).unwrap();
assert!(
!approx_eq_mat(on.r.as_ref(), off.r.as_ref(), 1e-6),
"R should differ when kaiser_normalize is toggled"
);
}
}