#![allow(clippy::doc_markdown)]
use faer::{Col, ColRef, Mat, MatRef};
use rand::{RngExt, SeedableRng};
use crate::error::{PlsKitError, PlsKitResult};
use crate::rotate::{RotationMethod, VarimaxArgs};
use crate::subsample::CIScalar;
/// Number of paired-bootstrap resamples (drawn with replacement over the finite
/// replicate rows) used to build the confidence intervals of the variance ratios.
const N_BOOT_PAIRED: usize = 1000;
/// Offset added (with wrapping) to the resolved seed when seeding the ChaCha8 RNG
/// that drives the paired bootstrap.
const BOOT_SEED_OFFSET: u64 = 0xB007_u64;
/// If the fraction of paired-bootstrap resamples whose aggregate unrotated
/// variance is not positive exceeds this value, the output is flagged as
/// `degenerate_baseline`.
const DEGENERATE_BOOTSTRAP_SKIP_THRESHOLD: f64 = 0.05;
/// Tuning knobs for [`pls1_rotation_stability`].
#[derive(Debug, Clone, Copy)]
pub struct RotationStabilityOpts {
    /// Number of subsample replicates to draw.
    pub n_boot: usize,
    /// Rate handed to `crate::subsample::resolve_m` to derive the subsample size `m`.
    pub m_rate: f64,
    /// Confidence level of the reported intervals (`alpha = 1 - level`).
    pub level: f64,
    /// When `true`, `x`/`y` are used as-is and replicates skip re-standardization.
    pub pre_standardized: bool,
    /// RNG seed; `None` lets `crate::rng::resolve_seed` pick one (the seed actually
    /// used is reported in the output).
    pub seed: Option<u64>,
    /// Run the replicates sequentially instead of in parallel.
    pub disable_parallelism: bool,
    /// Verbosity flag (not read by `pls1_rotation_stability` itself).
    pub verbose: bool,
    /// Largest tolerated fraction of non-finite replicates before the run fails.
    pub max_skip_rate: f64,
}

impl Default for RotationStabilityOpts {
    /// 1000 replicates, `m_rate = 0.7`, 95% intervals, at most 1% skipped
    /// replicates, random seed, parallel execution, in-routine standardization,
    /// not verbose.
    fn default() -> Self {
        let replicates = 1000;
        let confidence = 0.95;
        Self {
            seed: None,
            verbose: false,
            disable_parallelism: false,
            pre_standardized: false,
            max_skip_rate: 0.01,
            m_rate: 0.7,
            level: confidence,
            n_boot: replicates,
        }
    }
}
/// Rotation whose stability is assessed by [`pls1_rotation_stability`].
#[derive(Debug, Clone)]
pub enum RotationStabilityMethod {
/// Varimax rotation configured by the given arguments.
Varimax(VarimaxArgs),
}
/// Result of [`pls1_rotation_stability`].
#[derive(Debug, Clone)]
pub struct RotationStabilityOutput {
/// Rotation method label (currently always `"varimax"`).
pub method: String,
/// Requested number of subsample replicates (`opts.n_boot`).
pub n_boot: usize,
/// Resolved subsample size.
pub m: usize,
/// Subsample rate from which `m` was resolved.
pub m_rate: f64,
/// Confidence level of the reported intervals.
pub level: f64,
/// Seed actually used (resolved when none was supplied).
pub seed: u64,
/// `V_rot / V_unrot` with a paired-bootstrap percentile interval; the point is
/// NaN when the unrotated variance is zero (or there were no finite replicates).
pub variance_ratio: CIScalar,
/// Per-axis ratios `V_rot[axis] / V_unrot[axis]`; a point is NaN where the
/// axis' unrotated variance is zero.
pub variance_ratio_per_axis: Vec<CIScalar>,
/// Total unrotated variance: sum over axes of the mean (over finite replicates)
/// squared distance between sign-permutation-aligned replicate weights and the
/// unrotated reference weights.
pub variance_unrot: f64,
/// Same as `variance_unrot`, for the rotated replicate weights against the
/// rotated reference weights.
pub variance_rot: f64,
/// Per-axis terms of `variance_unrot` (they sum to it).
pub variance_unrot_per_axis: Vec<f64>,
/// Per-axis terms of `variance_rot` (they sum to it).
pub variance_rot_per_axis: Vec<f64>,
/// Set when the unrotated variance is zero, or when more than 5% of the
/// paired-bootstrap resamples had a non-positive unrotated variance; the ratio
/// estimates should then be treated as unreliable.
pub degenerate_baseline: bool,
/// Number of replicates whose squared distances were all finite.
pub n_boot_finite: usize,
/// Effective sample size as reported by `crate::fit::validate_and_normalize_weights`.
pub n_eff: f64,
}
/// Bootstrap diagnostic that compares how much the PLS1 weight subspace varies
/// across subsamples before and after a (varimax) rotation.
///
/// A reference PLS1 fit with `k` components is computed on the full data and
/// rotated with `method`. Then `opts.n_boot` subsamples of size
/// `m = crate::subsample::resolve_m(n, opts.m_rate)` are refit. For each
/// replicate:
/// * the unrotated weights are aligned to the reference by signed permutation;
/// * the Procrustes-pre-aligned, rotated weights are aligned to the rotated
///   reference by signed permutation.
///
/// Per-axis squared distances are recorded for both. The headline statistic is
/// `V_rot / V_unrot`, with a paired-bootstrap percentile confidence interval.
///
/// # Arguments
/// * `x` — `n × d` predictor matrix.
/// * `y` — response of length `n`.
/// * `k` — number of PLS components; must satisfy `2 <= k <= min(7, d)`.
/// * `method` — rotation to evaluate (varimax only).
/// * `l` — optional matrix forwarded to `crate::rotate::rotate`; must have `k` columns.
/// * `weights` — optional observation weights, validated and normalized by
///   `crate::fit::validate_and_normalize_weights`.
/// * `opts` — see [`RotationStabilityOpts`].
///
/// # Errors
/// * `DimensionMismatch` if `y.nrows() != x.nrows()`.
/// * `KExceedsMax` if `k == 0` or `k > d`.
/// * `InvalidArgument` if `k == 1`, `k > 7`, `opts.max_skip_rate` is NaN or
///   negative, or the resolved `m` is smaller than `k + 2`.
/// * `ShapeMismatch` if `l` does not have `k` columns.
/// * Any error from option validation, weight validation, the reference fit or
///   the reference rotation.
/// * `ResamplingDegenerate` if the fraction of non-finite replicates exceeds
///   `opts.max_skip_rate`.
#[allow(clippy::many_single_char_names)]
#[allow(clippy::too_many_lines)]
#[allow(clippy::needless_pass_by_value)]
pub fn pls1_rotation_stability(
    x: MatRef<'_, f64>,
    y: ColRef<'_, f64>,
    k: usize,
    method: RotationStabilityMethod,
    l: Option<MatRef<'_, f64>>,
    weights: Option<ColRef<'_, f64>>,
    opts: RotationStabilityOpts,
) -> PlsKitResult<RotationStabilityOutput> {
    let n = x.nrows();
    let d = x.ncols();
    if y.nrows() != n {
        return Err(PlsKitError::DimensionMismatch {
            x: (n, d),
            y: y.nrows(),
        });
    }
    if k == 0 || k > d {
        return Err(PlsKitError::KExceedsMax { k, k_max: d });
    }
    if k == 1 {
        return Err(PlsKitError::InvalidArgument(
            "k=1: rotation indeterminacy diagnostic is meaningless on a 1-D subspace".into(),
        ));
    }
    if k > 7 {
        return Err(PlsKitError::InvalidArgument(format!(
            "k>7 ({k}): signed-permutation enumeration is not tractable for k > 7 \
             (2^k * k! candidates per replicate). Use k <= 7, or reduce the \
             subspace dimension before calling this function."
        )));
    }
    if let Some(l_ref) = l {
        if l_ref.ncols() != k {
            return Err(PlsKitError::ShapeMismatch(format!(
                "L.ncols={} but k={}",
                l_ref.ncols(),
                k
            )));
        }
    }
    // `max_skip_rate` is not covered by `SubsampleOpts::validate` below (it is
    // forwarded there as 1.0). A NaN would silently disable the skip-rate guard,
    // and a negative value would reject every run, so both are rejected up front.
    if opts.max_skip_rate.is_nan() || opts.max_skip_rate < 0.0 {
        return Err(PlsKitError::InvalidArgument(format!(
            "max_skip_rate must be a non-negative number, got {}",
            opts.max_skip_rate
        )));
    }
    let sub_opts = crate::subsample::SubsampleOpts {
        n_boot: opts.n_boot,
        m_rate: opts.m_rate,
        level: opts.level,
        pre_standardized: opts.pre_standardized,
        disable_parallelism: opts.disable_parallelism,
        max_failure_rate: 1.0,
        max_skip_rate: 1.0,
    };
    sub_opts.validate()?;
    // Resolve and check the subsample size before doing the comparatively
    // expensive reference fit and rotation.
    let m = crate::subsample::resolve_m(n, opts.m_rate);
    if m < k + 2 {
        return Err(PlsKitError::InvalidArgument(format!(
            "resolved m = {m} (from n={n}, m_rate={}) is too small for k={k}; need m ≥ k+2",
            opts.m_rate
        )));
    }
    let (w_norm, n_eff_val, _all_uniform) =
        crate::fit::validate_and_normalize_weights(weights, n, k)?;
    let (seed_used, mut rng) = crate::rng::resolve_seed(opts.seed);
    let fit_ref = {
        use crate::fit::{pls1_fit, FitOpts, KSpec};
        pls1_fit(
            x,
            y,
            KSpec::Fixed(k),
            w_norm.as_ref().map(faer::Col::as_ref),
            FitOpts {
                pre_standardized: opts.pre_standardized,
                ..FitOpts::default()
            },
        )?
    };
    let RotationStabilityMethod::Varimax(varimax_args) = method;
    let rot_ref = crate::rotate::rotate(
        fit_ref.w_star.as_ref(),
        RotationMethod::Varimax(varimax_args),
        l,
    )?;
    let w_rot_ref = rot_ref.w_rot;
    let pre_std = opts.pre_standardized;
    // A failed replicate becomes an all-NaN row; the reducer counts such rows
    // against `max_skip_rate` instead of aborting the run.
    let rows: Vec<RotationStabilityWorkerRow> = crate::resample::parallel_for_each_seeded(
        &mut rng,
        opts.n_boot,
        opts.disable_parallelism,
        |_, child| {
            run_one_rotation_stability(
                x,
                y,
                k,
                m,
                pre_std,
                fit_ref.w_star.as_ref(),
                w_rot_ref.as_ref(),
                varimax_args,
                l,
                w_norm.as_ref().map(faer::Col::as_ref),
                child,
            )
            .unwrap_or_else(|_| RotationStabilityWorkerRow::nan(k))
        },
    );
    reduce_variance_ratio(
        &rows,
        k,
        opts.n_boot,
        m,
        opts.m_rate,
        opts.level,
        seed_used,
        n_eff_val,
        opts.max_skip_rate,
    )
}
/// Squared distances from one subsample replicate to the reference weights,
/// with one entry per axis.
#[derive(Debug, Clone)]
struct RotationStabilityWorkerRow {
    /// Unrotated replicate vs unrotated reference, per axis.
    sq_unrot_per_axis: Vec<f64>,
    /// Rotated replicate vs rotated reference, per axis.
    sq_rot_per_axis: Vec<f64>,
}

impl RotationStabilityWorkerRow {
    /// Placeholder row for a failed replicate: `k` NaN entries on each side, so
    /// [`Self::is_finite`] reports `false` whenever `k > 0`.
    fn nan(k: usize) -> Self {
        let nan_column = || vec![f64::NAN; k];
        Self {
            sq_unrot_per_axis: nan_column(),
            sq_rot_per_axis: nan_column(),
        }
    }

    /// `true` when no entry on either side is NaN or infinite.
    fn is_finite(&self) -> bool {
        self.sq_unrot_per_axis
            .iter()
            .chain(self.sq_rot_per_axis.iter())
            .all(|value| value.is_finite())
    }
}
#[allow(clippy::too_many_arguments)]
#[allow(clippy::many_single_char_names)]
fn run_one_rotation_stability(
x: MatRef<'_, f64>,
y: ColRef<'_, f64>,
k: usize,
m: usize,
pre_standardized_x: bool,
w_unrot_ref: MatRef<'_, f64>,
w_rot_ref: MatRef<'_, f64>,
varimax_args: VarimaxArgs,
l: Option<MatRef<'_, f64>>,
weights: Option<ColRef<'_, f64>>,
rng: &mut crate::rng::Rng,
) -> PlsKitResult<RotationStabilityWorkerRow> {
use crate::fit::{pls1_fit, validate_and_normalize_weights, FitOpts, KSpec};
use crate::linalg::{col_row_subset, row_subset, standardize, standardize1};
let n = x.nrows();
let d = x.ncols();
let (sample_idx, _holdout_idx) = crate::subsample::subsample_indices(n, m, rng);
let x_sub = row_subset(x, &sample_idx);
let y_sub = col_row_subset(y, &sample_idx);
let w_sub_norm: Option<Col<f64>> = match weights {
Some(w_full) => {
let w_sub = col_row_subset(w_full, &sample_idx);
let (w_norm_sub, _, _) = validate_and_normalize_weights(Some(w_sub.as_ref()), m, k)?;
w_norm_sub
}
None => None,
};
let (xs, ys) = if pre_standardized_x {
(
Mat::<f64>::from_fn(x_sub.nrows(), d, |i, j| x_sub[(i, j)]),
Col::<f64>::from_fn(y_sub.nrows(), |i| y_sub[i]),
)
} else {
let (xs, _, _) = standardize(x_sub.as_ref());
let (ys, _, _) = standardize1(y_sub.as_ref());
(xs, ys)
};
let fit_b = pls1_fit(
xs.as_ref(),
ys.as_ref(),
KSpec::Fixed(k),
w_sub_norm.as_ref().map(Col::as_ref),
FitOpts {
pre_standardized: true,
..FitOpts::default()
},
)?;
let w_b = fit_b.w_star;
let aln_unrot = procrustes::signed_permutation(w_b.as_ref(), w_unrot_ref, false)
.expect("procrustes invariants pre-validated by plskit");
let sq_unrot_per_axis: Vec<f64> = (0..k)
.map(|kk| {
let src = aln_unrot.assigned[kk];
let s = aln_unrot.signs[kk];
let mut acc = 0.0_f64;
for j in 0..d {
let diff = s * w_b[(j, src)] - w_unrot_ref[(j, kk)];
acc += diff * diff;
}
acc
})
.collect();
let r_orth = procrustes::orthogonal(w_b.as_ref(), w_unrot_ref, false)
.expect("procrustes invariants pre-validated by plskit")
.rotation;
let mut w_b_rot_input = Mat::<f64>::zeros(d, k);
faer::linalg::matmul::matmul(
w_b_rot_input.as_mut(),
faer::Accum::Replace,
w_b.as_ref(),
r_orth.as_ref(),
1.0,
faer::Par::Seq,
);
let rot_b = crate::rotate::rotate(
w_b_rot_input.as_ref(),
RotationMethod::Varimax(varimax_args),
l,
)?;
let w_b_rot = rot_b.w_rot;
let aln_rot = procrustes::signed_permutation(w_b_rot.as_ref(), w_rot_ref, false)
.expect("procrustes invariants pre-validated by plskit");
let sq_rot_per_axis: Vec<f64> = (0..k)
.map(|kk| {
let src = aln_rot.assigned[kk];
let s = aln_rot.signs[kk];
let mut acc = 0.0_f64;
for j in 0..d {
let diff = s * w_b_rot[(j, src)] - w_rot_ref[(j, kk)];
acc += diff * diff;
}
acc
})
.collect();
Ok(RotationStabilityWorkerRow {
sq_unrot_per_axis,
sq_rot_per_axis,
})
}
/// Aggregates replicate rows into the variance-ratio summary.
///
/// Non-finite rows are dropped. If the dropped fraction exceeds
/// `max_skip_rate`, a `ResamplingDegenerate` error is returned.
///
/// Over the `b` remaining rows, `V_unrot[axis]` / `V_rot[axis]` are the means
/// of the squared distances, and `V_unrot` / `V_rot` are their sums across
/// axes. The point ratio is `V_rot / V_unrot`.
///
/// Intervals come from `N_BOOT_PAIRED` paired bootstrap resamples of the rows
/// (a row's unrotated and rotated values are resampled together). They are
/// drawn with a ChaCha8 RNG seeded with `seed` wrapping-plus `BOOT_SEED_OFFSET`.
///
/// `degenerate_baseline` is set when there are no finite rows, when
/// `V_unrot == 0`, or when more than `DEGENERATE_BOOTSTRAP_SKIP_THRESHOLD` of
/// the paired resamples have a non-positive aggregate baseline.
///
/// # Errors
/// `PlsKitError::ResamplingDegenerate` when too many replicates are non-finite.
#[allow(clippy::too_many_arguments)]
#[allow(clippy::many_single_char_names)]
#[allow(clippy::too_many_lines)]
fn reduce_variance_ratio(
    rows: &[RotationStabilityWorkerRow],
    k: usize,
    n_boot: usize,
    m: usize,
    m_rate: f64,
    level: f64,
    seed: u64,
    n_eff: f64,
    max_skip_rate: f64,
) -> PlsKitResult<RotationStabilityOutput> {
    let finite: Vec<&RotationStabilityWorkerRow> = rows.iter().filter(|r| r.is_finite()).collect();
    let n_boot_finite = finite.len();
    let total = rows.len();
    let skipped = total - n_boot_finite;
    #[allow(clippy::cast_precision_loss)]
    {
        let skip_rate = skipped as f64 / total.max(1) as f64;
        if skip_rate > max_skip_rate {
            return Err(PlsKitError::ResamplingDegenerate {
                skipped,
                total,
                skip_rate,
                threshold: max_skip_rate,
            });
        }
    }
    let b = n_boot_finite;
    #[allow(clippy::cast_precision_loss)]
    let b_f = b as f64;

    let mut v_unrot_per_axis = vec![0.0_f64; k];
    let mut v_rot_per_axis = vec![0.0_f64; k];
    fill_per_axis_means(
        finite.iter().copied(),
        b_f,
        &mut v_unrot_per_axis,
        &mut v_rot_per_axis,
    );
    let v_unrot: f64 = v_unrot_per_axis.iter().sum();
    let v_rot: f64 = v_rot_per_axis.iter().sum();
    // With no finite rows every mean above is 0/0 = NaN, which `v_unrot == 0.0`
    // would not catch; an empty row set is a degenerate baseline as well.
    let primary_degenerate = b == 0 || v_unrot == 0.0;
    let rho_point = if primary_degenerate {
        f64::NAN
    } else {
        v_rot / v_unrot
    };
    let rho_per_axis_point: Vec<f64> = v_unrot_per_axis
        .iter()
        .zip(&v_rot_per_axis)
        .map(|(&unrot, &rot)| if unrot == 0.0 { f64::NAN } else { rot / unrot })
        .collect();

    let mut boot_rng = rand_chacha::ChaCha8Rng::seed_from_u64(seed.wrapping_add(BOOT_SEED_OFFSET));
    let mut rho_star: Vec<f64> = Vec::with_capacity(N_BOOT_PAIRED);
    // Built per axis so every inner vector really gets the capacity: `vec![v; k]`
    // clones `v`, and cloning a `Vec` does not preserve its spare capacity.
    let mut rho_per_axis_star: Vec<Vec<f64>> =
        (0..k).map(|_| Vec::with_capacity(N_BOOT_PAIRED)).collect();
    let mut bootstrap_skipped = 0usize;
    // `!primary_degenerate` implies `b > 0`, so `0..b` below is a non-empty range.
    if !primary_degenerate {
        let mut idx_buf: Vec<usize> = vec![0; b];
        let mut v_unrot_b_per_axis = vec![0.0_f64; k];
        let mut v_rot_b_per_axis = vec![0.0_f64; k];
        for _ in 0..N_BOOT_PAIRED {
            // Resample replicate rows with replacement; each row's unrotated and
            // rotated values stay paired.
            for slot in &mut idx_buf {
                *slot = boot_rng.random_range(0..b);
            }
            fill_per_axis_means(
                idx_buf.iter().map(|&i| finite[i]),
                b_f,
                &mut v_unrot_b_per_axis,
                &mut v_rot_b_per_axis,
            );
            let v_unrot_b: f64 = v_unrot_b_per_axis.iter().sum();
            let v_rot_b: f64 = v_rot_b_per_axis.iter().sum();
            if v_unrot_b > 0.0 {
                rho_star.push(v_rot_b / v_unrot_b);
            } else {
                bootstrap_skipped += 1;
            }
            for ((star, &unrot), &rot) in rho_per_axis_star
                .iter_mut()
                .zip(&v_unrot_b_per_axis)
                .zip(&v_rot_b_per_axis)
            {
                if unrot > 0.0 {
                    star.push(rot / unrot);
                }
            }
        }
    }
    #[allow(clippy::cast_precision_loss)]
    let bootstrap_skip_rate = bootstrap_skipped as f64 / N_BOOT_PAIRED as f64;
    let degenerate_baseline =
        primary_degenerate || bootstrap_skip_rate > DEGENERATE_BOOTSTRAP_SKIP_THRESHOLD;
    let alpha = 1.0 - level;
    let variance_ratio = build_ciscalar_from_bootstrap(rho_point, &mut rho_star, alpha);
    let variance_ratio_per_axis: Vec<CIScalar> = rho_per_axis_point
        .iter()
        .zip(rho_per_axis_star.iter_mut())
        .map(|(&point, star)| build_ciscalar_from_bootstrap(point, star, alpha))
        .collect();
    Ok(RotationStabilityOutput {
        method: "varimax".to_owned(),
        n_boot,
        m,
        m_rate,
        level,
        seed,
        variance_ratio,
        variance_ratio_per_axis,
        variance_unrot: v_unrot,
        variance_rot: v_rot,
        variance_unrot_per_axis: v_unrot_per_axis,
        variance_rot_per_axis: v_rot_per_axis,
        degenerate_baseline,
        n_boot_finite,
        n_eff,
    })
}

/// Overwrites `unrot` / `rot` with the per-axis sums of the rows' squared
/// distances (accumulated in iteration order) divided by `denom`.
fn fill_per_axis_means<'a>(
    rows: impl Iterator<Item = &'a RotationStabilityWorkerRow>,
    denom: f64,
    unrot: &mut [f64],
    rot: &mut [f64],
) {
    unrot.fill(0.0);
    rot.fill(0.0);
    for row in rows {
        for (acc, value) in unrot.iter_mut().zip(&row.sq_unrot_per_axis) {
            *acc += *value;
        }
        for (acc, value) in rot.iter_mut().zip(&row.sq_rot_per_axis) {
            *acc += *value;
        }
    }
    for value in unrot.iter_mut().chain(rot.iter_mut()) {
        *value /= denom;
    }
}
/// Summarizes bootstrap `samples` around a given `point` estimate.
///
/// `samples` is sorted in place. The result carries `point` unchanged, the
/// `alpha / 2` and `1 - alpha / 2` empirical quantiles (via
/// `crate::linalg::empirical_quantile`) and the sample standard deviation
/// (`n - 1` denominator; `0.0` for a single sample). With no samples, the
/// bounds and `sd` are NaN.
fn build_ciscalar_from_bootstrap(point: f64, samples: &mut [f64], alpha: f64) -> CIScalar {
    if samples.is_empty() {
        return CIScalar {
            point,
            lower: f64::NAN,
            upper: f64::NAN,
            sd: f64::NAN,
        };
    }
    // `total_cmp` is a genuine total order, so the sort is well-defined (and
    // cannot panic) even if a NaN slips in. For non-NaN values it orders like
    // `partial_cmp`, apart from placing -0.0 before +0.0.
    samples.sort_unstable_by(f64::total_cmp);
    let lower = crate::linalg::empirical_quantile(samples, alpha / 2.0);
    let upper = crate::linalg::empirical_quantile(samples, 1.0 - alpha / 2.0);
    #[allow(clippy::cast_precision_loss)]
    let n = samples.len() as f64;
    let mean: f64 = samples.iter().sum::<f64>() / n;
    let sd: f64 = if samples.len() > 1 {
        let var = samples.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / (n - 1.0);
        var.sqrt()
    } else {
        0.0
    };
    CIScalar {
        point,
        lower,
        upper,
        sd,
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the rotation-stability diagnostic.
    use super::*;
    use crate::rotate::VarimaxArgs;
    use faer::Mat;
    use rand::RngExt;
    use rand::SeedableRng;

    /// Linear design: uniform(-1, 1) predictors and
    /// `y = snr * (x_0 + x_1) + uniform(-1, 1) noise`.
    fn synth(n: usize, d: usize, snr: f64, seed: u64) -> (Mat<f64>, faer::Col<f64>) {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(seed);
        let x = Mat::<f64>::from_fn(n, d, |_, _| rng.random_range(-1.0..1.0));
        let beta = faer::Col::<f64>::from_fn(d, |j| if j < 2 { 1.0 } else { 0.0 });
        // Matrix-vector product X * beta.
        let signal: faer::Col<f64> = x.as_ref() * beta.as_ref();
        let noise = faer::Col::<f64>::from_fn(n, |_| rng.random_range(-1.0..1.0));
        let y = faer::Col::<f64>::from_fn(n, |i| signal[i] * snr + noise[i]);
        (x, y)
    }

    /// Two-factor design with 8 predictors: columns 0..4 load on `f1`, columns
    /// 4..8 on `f2` (weakly correlated with `f1`), plus small uniform noise.
    #[allow(clippy::many_single_char_names)]
    fn synth_factor_model(n: usize, seed: u64) -> (Mat<f64>, faer::Col<f64>) {
        let d = 8usize;
        let rho = 0.05_f64;
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(seed);
        let f1 = faer::Col::<f64>::from_fn(n, |_| rng.random_range(-1.0..1.0));
        let z = faer::Col::<f64>::from_fn(n, |_| rng.random_range(-1.0..1.0));
        let f2 = faer::Col::<f64>::from_fn(n, |i| rho * f1[i] + (1.0 - rho * rho).sqrt() * z[i]);
        let mut x = Mat::<f64>::zeros(n, d);
        for i in 0..n {
            for j in 0..d {
                let base = if j < 4 { f1[i] } else { f2[i] };
                let noise = rng.random_range(-0.05..0.05);
                x[(i, j)] = base + noise;
            }
        }
        let y = faer::Col::<f64>::from_fn(n, |i| f1[i] + f2[i] + 0.1 * rng.random_range(-1.0..1.0));
        (x, y)
    }

    /// Runs the diagnostic sequentially with default varimax, no `L`, no weights.
    fn run_one(
        x: &Mat<f64>,
        y: &faer::Col<f64>,
        k: usize,
        n_boot: usize,
        seed: u64,
    ) -> RotationStabilityOutput {
        let opts = RotationStabilityOpts {
            n_boot,
            m_rate: 0.7,
            level: 0.95,
            seed: Some(seed),
            disable_parallelism: true,
            ..Default::default()
        };
        pls1_rotation_stability(
            x.as_ref(),
            y.as_ref(),
            k,
            RotationStabilityMethod::Varimax(VarimaxArgs::default()),
            None,
            None,
            opts,
        )
        .unwrap()
    }

    #[test]
    fn rotation_stability_runs_end_to_end() {
        let (x, y) = synth(100, 6, 4.0, 7);
        let r = run_one(&x, &y, 2, 200, 13);
        assert_eq!(r.method, "varimax");
        assert_eq!(r.n_boot, 200);
        assert_eq!(r.m, 26);
        assert_eq!(r.variance_ratio_per_axis.len(), 2);
        assert_eq!(r.variance_unrot_per_axis.len(), 2);
        assert_eq!(r.variance_rot_per_axis.len(), 2);
        assert!(r.variance_unrot >= 0.0);
        assert!(r.variance_rot >= 0.0);
    }

    #[test]
    fn variance_ratio_factor_model_below_one() {
        let (x, y) = synth_factor_model(300, 17);
        let r = run_one(&x, &y, 2, 500, 23);
        assert!(
            !r.degenerate_baseline,
            "factor-model design should not flag degenerate baseline"
        );
        assert!(
            r.variance_ratio.point.is_finite(),
            "ratio must be finite on factor-model design (point={})",
            r.variance_ratio.point
        );
        assert!(
            r.variance_ratio.upper < 1.5,
            "rotation should not blow up the variance ratio on a factor-model \
             design, got upper={} (point={})",
            r.variance_ratio.upper,
            r.variance_ratio.point,
        );
    }

    #[test]
    fn variance_ratio_one_under_pure_noise() {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
        let x = Mat::<f64>::from_fn(200, 6, |_, _| rng.random_range(-1.0..1.0));
        let y = faer::Col::<f64>::from_fn(200, |_| rng.random_range(-1.0..1.0));
        let r = run_one(&x, &y, 2, 200, 31);
        assert!(
            r.variance_ratio.point.is_finite(),
            "ratio must be finite on noise design"
        );
        assert!(
            (0.5..=1.6).contains(&r.variance_ratio.point),
            "expected ratio ≈ 1 under pure noise, got {}",
            r.variance_ratio.point
        );
    }

    #[test]
    fn per_axis_decomposition_sums_to_aggregate() {
        let (x, y) = synth(150, 6, 4.0, 11);
        let r = run_one(&x, &y, 2, 200, 5);
        let sum_unrot: f64 = r.variance_unrot_per_axis.iter().sum();
        let sum_rot: f64 = r.variance_rot_per_axis.iter().sum();
        assert!(
            (sum_unrot - r.variance_unrot).abs() < 1e-10,
            "sum_unrot={} aggregate={}",
            sum_unrot,
            r.variance_unrot,
        );
        assert!(
            (sum_rot - r.variance_rot).abs() < 1e-10,
            "sum_rot={} aggregate={}",
            sum_rot,
            r.variance_rot,
        );
        if r.variance_unrot > 0.0 {
            let expected = r.variance_rot / r.variance_unrot;
            assert!(
                (r.variance_ratio.point - expected).abs() < 1e-10,
                "ratio={} V_rot/V_unrot={}",
                r.variance_ratio.point,
                expected,
            );
        }
    }

    #[test]
    fn paired_bootstrap_ci_contains_point_estimate() {
        let (x, y) = synth(120, 6, 3.0, 9);
        let r = run_one(&x, &y, 2, 200, 14);
        assert!(
            r.variance_ratio.lower <= r.variance_ratio.point + 1e-10,
            "lower={} > point={}",
            r.variance_ratio.lower,
            r.variance_ratio.point,
        );
        assert!(
            r.variance_ratio.point <= r.variance_ratio.upper + 1e-10,
            "point={} > upper={}",
            r.variance_ratio.point,
            r.variance_ratio.upper,
        );
        for (kk, ci) in r.variance_ratio_per_axis.iter().enumerate() {
            if ci.point.is_finite() {
                assert!(
                    ci.lower <= ci.point + 1e-10 && ci.point <= ci.upper + 1e-10,
                    "axis {kk}: lower={} point={} upper={}",
                    ci.lower,
                    ci.point,
                    ci.upper,
                );
            }
        }
    }

    #[test]
    fn degenerate_baseline_flagged() {
        let n = 200;
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(101);
        let x = Mat::<f64>::from_fn(n, 2, |i, j| {
            let sign = if (i + j) % 2 == 0 { 1.0 } else { -1.0 };
            if (j == 0 && i % 2 == 0) || (j == 1 && i % 2 == 1) {
                sign * (1.0 + 0.01 * rng.random_range(-1.0..1.0))
            } else {
                0.0
            }
        });
        let y = faer::Col::<f64>::from_fn(n, |i| if i % 2 == 0 { 1.0 } else { -1.0 });
        let r = run_one(&x, &y, 2, 200, 7);
        if r.degenerate_baseline {
            assert!(
                r.variance_ratio.point.is_nan(),
                "degenerate_baseline=true requires variance_ratio.point=NaN, got {}",
                r.variance_ratio.point,
            );
        }
    }

    #[test]
    #[allow(clippy::float_cmp)]
    fn parallel_matches_sequential() {
        let (x, y) = synth(120, 6, 4.0, 19);
        let opts_seq = RotationStabilityOpts {
            n_boot: 200,
            m_rate: 0.7,
            level: 0.95,
            seed: Some(33),
            disable_parallelism: true,
            ..Default::default()
        };
        let opts_par = RotationStabilityOpts {
            disable_parallelism: false,
            ..opts_seq
        };
        let r_seq = pls1_rotation_stability(
            x.as_ref(),
            y.as_ref(),
            2,
            RotationStabilityMethod::Varimax(VarimaxArgs::default()),
            None,
            None,
            opts_seq,
        )
        .unwrap();
        let r_par = pls1_rotation_stability(
            x.as_ref(),
            y.as_ref(),
            2,
            RotationStabilityMethod::Varimax(VarimaxArgs::default()),
            None,
            None,
            opts_par,
        )
        .unwrap();
        assert_eq!(r_seq.variance_ratio.point, r_par.variance_ratio.point);
        assert_eq!(r_seq.variance_ratio.lower, r_par.variance_ratio.lower);
        assert_eq!(r_seq.variance_ratio.upper, r_par.variance_ratio.upper);
        assert_eq!(r_seq.variance_unrot, r_par.variance_unrot);
        assert_eq!(r_seq.variance_rot, r_par.variance_rot);
    }

    #[test]
    #[allow(clippy::float_cmp)]
    fn reproducibility_under_fixed_seed() {
        let (x, y) = synth(120, 6, 4.0, 21);
        let a = run_one(&x, &y, 2, 200, 99);
        let b = run_one(&x, &y, 2, 200, 99);
        assert_eq!(a.variance_ratio.point, b.variance_ratio.point);
        assert_eq!(a.variance_ratio.lower, b.variance_ratio.lower);
        assert_eq!(a.variance_ratio.upper, b.variance_ratio.upper);
        assert_eq!(a.variance_unrot, b.variance_unrot);
        assert_eq!(a.variance_rot, b.variance_rot);
        for kk in 0..a.variance_ratio_per_axis.len() {
            assert_eq!(
                a.variance_ratio_per_axis[kk].point,
                b.variance_ratio_per_axis[kk].point,
            );
        }
    }

    #[test]
    fn signed_perm_alignment_used_on_both_sides() {
        let (x, y) = synth(150, 6, 4.0, 25);
        let r = run_one(&x, &y, 2, 200, 41);
        let sum_unrot: f64 = r.variance_unrot_per_axis.iter().sum();
        assert!(
            (sum_unrot - r.variance_unrot).abs() < 1e-10,
            "alignment payload identity broken: sum_per_axis={}, aggregate={}",
            sum_unrot,
            r.variance_unrot,
        );
        assert!(
            r.variance_unrot > 1e-6,
            "V_unrot collapsed to {} — likely using orthogonal (not signed-perm) alignment",
            r.variance_unrot,
        );
    }

    #[test]
    fn rotation_stability_rejects_k_eq_1() {
        let (x, y) = synth(80, 5, 3.0, 1);
        let err = pls1_rotation_stability(
            x.as_ref(),
            y.as_ref(),
            1,
            RotationStabilityMethod::Varimax(VarimaxArgs::default()),
            None,
            None,
            RotationStabilityOpts::default(),
        )
        .unwrap_err();
        assert_eq!(err.code(), "invalid_argument");
        assert!(format!("{err}").contains("k=1"));
    }

    #[test]
    fn rotation_stability_rejects_k_gt_7() {
        let (x, y) = synth(80, 10, 3.0, 1);
        let err = pls1_rotation_stability(
            x.as_ref(),
            y.as_ref(),
            8,
            RotationStabilityMethod::Varimax(VarimaxArgs::default()),
            None,
            None,
            RotationStabilityOpts::default(),
        )
        .unwrap_err();
        assert_eq!(err.code(), "invalid_argument");
        assert!(format!("{err}").contains("k>7") || format!("{err}").contains("k > 7"));
    }

    #[test]
    fn rotation_stability_rejects_l_shape_mismatch() {
        let (x, y) = synth(80, 6, 3.0, 1);
        let l_bad = Mat::<f64>::zeros(4, 3);
        let err = pls1_rotation_stability(
            x.as_ref(),
            y.as_ref(),
            2,
            RotationStabilityMethod::Varimax(VarimaxArgs::default()),
            Some(l_bad.as_ref()),
            None,
            RotationStabilityOpts::default(),
        )
        .unwrap_err();
        assert_eq!(err.code(), "shape_mismatch");
    }
}