use faer::{ColRef, MatRef};
use crate::error::{PlsKitError, PlsKitResult};
/// Options controlling [`pls1_perm_null`].
///
/// `pls1_perm_null` calls [`PermNullOpts::validate`] on these before doing
/// any work.
#[derive(Debug, Clone, Copy)]
#[allow(clippy::struct_excessive_bools)]
pub struct PermNullOpts {
    /// Number of permutation replicates to draw; `validate` rejects values
    /// below 100.
    pub n_perm: usize,
    /// When `true`, the output's `beta_perm_matrix` holds every replicate's
    /// coefficient vector; when `false` it is `None`.
    pub return_perm_matrix: bool,
    /// When `true`, `x` and `y` are used as given (copied, not standardized);
    /// when `false`, both are standardized before any fit.
    pub pre_standardized: bool,
    /// Forwarded to `crate::resample::parallel_for_each_seeded`. The tests in
    /// this file expect serial and parallel runs with the same seed to give
    /// byte-identical results.
    pub disable_parallelism: bool,
    /// Not read by the code in this file. NOTE(review): confirm whether it is
    /// consumed elsewhere or reserved for future logging.
    pub verbose: bool,
}
impl PermNullOpts {
    /// Checks these options, together with the requested number of PLS
    /// components `k`, before a permutation run.
    ///
    /// `n_perm` is checked first, then `k`.
    ///
    /// # Errors
    /// Returns [`PlsKitError::InvalidArgument`] when `n_perm < 100` or when
    /// `k == 0`.
    pub fn validate(&self, k: usize) -> PlsKitResult<()> {
        let problem = if self.n_perm < 100 {
            Some(format!("n_perm must be ≥ 100, got {}", self.n_perm))
        } else if k == 0 {
            Some(format!("k must be ≥ 1, got {k}"))
        } else {
            None
        };
        problem.map_or(Ok(()), |msg| Err(PlsKitError::InvalidArgument(msg)))
    }
}
/// Result of [`pls1_perm_null`]: the reference PLS1 coefficients plus
/// summary statistics of their permutation null distribution.
///
/// Every per-coefficient vector has length `d`, the number of columns of `x`.
#[derive(Debug, Clone)]
#[allow(clippy::doc_markdown)]
pub struct PermNullOutput {
    /// Number of permutation replicates drawn (`PermNullOpts::n_perm`).
    pub n_perm: usize,
    /// Number of PLS components used for the reference fit and every
    /// replicate.
    pub k: usize,
    /// Seed reported by `crate::rng::resolve_seed` for this run.
    pub seed: u64,
    /// Effective sample size reported by the weight-validation step.
    pub n_eff: f64,
    /// Coefficients of the fit on the unpermuted data (standardized first
    /// unless `pre_standardized` was set).
    pub beta_ref: Vec<f64>,
    /// Per-coefficient mean over the replicates, using finite values only.
    pub beta_perm_mean: Vec<f64>,
    /// Per-coefficient sample standard deviation over the replicates, using
    /// finite values only.
    pub beta_perm_sd: Vec<f64>,
    /// `beta_ref / beta_perm_sd`, and `NaN` wherever the SD is not above
    /// `sqrt(f64::EPSILON)`. Note it is not centred on `beta_perm_mean`.
    pub beta_perm_z: Vec<f64>,
    /// With `return_perm_matrix`, the row-major `n_perm × d` matrix of replicate
    /// coefficients (replicate `b` occupies `b * d .. b * d + d`). Replicates
    /// whose fit failed are rows of `NaN`. `None` otherwise.
    pub beta_perm_matrix: Option<Vec<f64>>,
}
/// Builds a permutation null distribution for the PLS1 coefficients of `y`
/// regressed on `x` with `k` components.
///
/// A reference fit is computed on the prepared data. `opts.n_perm` further
/// fits then shuffle the rows of `y` against a fixed `x` and fixed weights.
/// The output summarises the permuted coefficients per column and divides
/// the reference coefficients by their permutation SD.
///
/// `seed` goes through `crate::rng::resolve_seed`, and the seed it resolves
/// to is reported in [`PermNullOutput::seed`].
///
/// # Errors
/// - `InvalidArgument` from [`PermNullOpts::validate`];
/// - `DimensionMismatch` when `y` and `x` have different row counts;
/// - `KExceedsMax` when `k` exceeds the number of columns of `x`;
/// - any error from weight validation or from the reference fit.
///
/// A failing individual permutation fit is not an error. That replicate is
/// recorded as `NaN` and excluded from the summaries.
#[allow(clippy::needless_pass_by_value, clippy::many_single_char_names)]
pub fn pls1_perm_null(
    x: MatRef<'_, f64>,
    y: ColRef<'_, f64>,
    k: usize,
    weights: Option<ColRef<'_, f64>>,
    opts: PermNullOpts,
    seed: Option<u64>,
) -> PlsKitResult<PermNullOutput> {
    use crate::fit::{pls1_fit, validate_and_normalize_weights, FitOpts, KSpec};
    use crate::linalg::{standardize, standardize1};
    use faer::{Col, Mat};

    opts.validate(k)?;

    let (n_rows, n_cols) = (x.nrows(), x.ncols());
    let y_rows = y.nrows();
    if y_rows != n_rows {
        return Err(PlsKitError::DimensionMismatch {
            x: (n_rows, n_cols),
            y: y_rows,
        });
    }
    if k > n_cols {
        return Err(PlsKitError::KExceedsMax { k, k_max: n_cols });
    }

    // Weights are validated once and the same normalized vector is shared by
    // the reference fit and every replicate.
    let (w_norm, n_eff_val, _all_uniform) = validate_and_normalize_weights(weights, n_rows, k)?;
    let wref = w_norm.as_ref().map(Col::as_ref);

    // Both branches yield owned data. The pre-standardized branch is a
    // plain copy of the inputs.
    let (x_work, y_work): (Mat<f64>, Col<f64>) = if opts.pre_standardized {
        let x_copy = Mat::<f64>::from_fn(n_rows, n_cols, |i, j| x[(i, j)]);
        let y_copy = Col::<f64>::from_fn(n_rows, |i| y[i]);
        (x_copy, y_copy)
    } else {
        (standardize(x).0, standardize1(y).0)
    };
    let xs = x_work.as_ref();
    let ys = y_work.as_ref();

    let ref_opts = FitOpts {
        pre_standardized: true,
        ..FitOpts::default()
    };
    let fit_ref = pls1_fit(xs, ys, KSpec::Fixed(k), wref, ref_opts)?;
    let beta_ref = (0..n_cols).map(|j| fit_ref.beta[j]).collect::<Vec<f64>>();

    let (seed_used, mut rng) = crate::rng::resolve_seed(seed);
    if !opts.return_perm_matrix {
        return run_engine_streaming(
            xs, ys, k, wref, beta_ref, n_eff_val, opts, seed_used, &mut rng,
        );
    }
    run_engine_retained(
        xs, ys, k, wref, beta_ref, n_eff_val, opts, seed_used, &mut rng,
    )
}
/// Permutation engine that keeps the full replicate matrix in the output.
///
/// Draws the replicates with [`draw_perm_matrix`], summarises them with
/// [`summarize_perm_matrix`], then attaches the row-major `n_perm × d` matrix
/// as `beta_perm_matrix`.
#[allow(clippy::too_many_arguments, clippy::unnecessary_wraps)]
fn run_engine_retained(
    xs: MatRef<'_, f64>,
    ys: ColRef<'_, f64>,
    k: usize,
    wref: Option<faer::ColRef<'_, f64>>,
    beta_ref: Vec<f64>,
    n_eff_val: f64,
    opts: PermNullOpts,
    seed_used: u64,
    rng: &mut crate::rng::Rng,
) -> PlsKitResult<PermNullOutput> {
    let flat = draw_perm_matrix(xs, ys, k, wref, opts, rng);
    let mut out = summarize_perm_matrix(
        &flat,
        xs.ncols(),
        k,
        beta_ref,
        n_eff_val,
        opts,
        seed_used,
    );
    out.beta_perm_matrix = Some(flat);
    Ok(out)
}

/// Permutation engine that returns only the summaries (`beta_perm_matrix` is
/// `None`).
///
/// Despite the name, this path still materialises the full replicate matrix
/// internally and reduces it exactly as the retained engine does. Both
/// engines share [`draw_perm_matrix`] and [`summarize_perm_matrix`], so their
/// summaries are byte-identical for the same seed, which the tests require.
#[allow(clippy::too_many_arguments, clippy::unnecessary_wraps)]
fn run_engine_streaming(
    xs: MatRef<'_, f64>,
    ys: ColRef<'_, f64>,
    k: usize,
    wref: Option<faer::ColRef<'_, f64>>,
    beta_ref: Vec<f64>,
    n_eff_val: f64,
    opts: PermNullOpts,
    seed_used: u64,
    rng: &mut crate::rng::Rng,
) -> PlsKitResult<PermNullOutput> {
    let flat = draw_perm_matrix(xs, ys, k, wref, opts, rng);
    Ok(summarize_perm_matrix(
        &flat,
        xs.ncols(),
        k,
        beta_ref,
        n_eff_val,
        opts,
        seed_used,
    ))
}

/// Draws `opts.n_perm` permutation replicates and packs their coefficient
/// vectors into one row-major `n_perm × d` buffer (replicate `bi` occupies
/// `bi * d .. bi * d + d`).
///
/// A replicate whose fit returns `Err` is recorded as a row of `NaN`.
/// `reduce_two_pass` skips non-finite entries, so such rows are excluded
/// from the summaries.
fn draw_perm_matrix(
    xs: MatRef<'_, f64>,
    ys: ColRef<'_, f64>,
    k: usize,
    wref: Option<faer::ColRef<'_, f64>>,
    opts: PermNullOpts,
    rng: &mut crate::rng::Rng,
) -> Vec<f64> {
    let d = xs.ncols();
    let b = opts.n_perm;
    let beta_rows: Vec<Vec<f64>> =
        crate::resample::parallel_for_each_seeded(rng, b, opts.disable_parallelism, |_, child| {
            run_one_perm(xs, ys, k, wref, child).unwrap_or_else(|_| vec![f64::NAN; d])
        });
    let mut flat = vec![0.0_f64; b * d];
    for (bi, row) in beta_rows.iter().enumerate() {
        let off = bi * d;
        flat[off..off + d].copy_from_slice(row);
    }
    flat
}

/// Reduces a packed replicate matrix to per-coefficient mean and SD and the
/// signed z-scores of `beta_ref`. Returns the output with
/// `beta_perm_matrix: None`.
fn summarize_perm_matrix(
    flat: &[f64],
    d: usize,
    k: usize,
    beta_ref: Vec<f64>,
    n_eff_val: f64,
    opts: PermNullOpts,
    seed_used: u64,
) -> PermNullOutput {
    let b = opts.n_perm;
    let (beta_perm_mean, beta_perm_sd) = reduce_two_pass(flat, b, d);
    let beta_perm_z = signed_z(&beta_ref, &beta_perm_sd);
    PermNullOutput {
        n_perm: b,
        k,
        seed: seed_used,
        n_eff: n_eff_val,
        beta_ref,
        beta_perm_mean,
        beta_perm_sd,
        beta_perm_z,
        beta_perm_matrix: None,
    }
}
/// Per-coefficient mean and sample standard deviation of a packed replicate
/// matrix, ignoring non-finite entries.
///
/// `flat` is read row-major as `b` rows of `d` values (row `bi` at
/// `bi * d .. bi * d + d`). Non-finite values (`NaN`, `±inf`, e.g. rows from
/// failed fits) are skipped column by column, so each column's statistics use
/// only its own finite samples.
///
/// The mean and the squared deviations are computed in two separate passes,
/// always in row order, so the result depends only on `flat`. The SD uses the
/// Bessel-corrected divisor `count - 1`.
///
/// Degenerate columns report `NaN` rather than `0.0`:
/// - the mean is `NaN` when the column has no finite sample;
/// - the SD is `NaN` when the column has fewer than two finite samples.
///
/// A sample SD is undefined in that case, and `0.0` would falsely claim a
/// zero-spread null.
///
/// # Panics
/// If `flat.len() < b * d`.
fn reduce_two_pass(flat: &[f64], b: usize, d: usize) -> (Vec<f64>, Vec<f64>) {
    assert!(
        flat.len() >= b * d,
        "permutation matrix holds {} values, expected at least {}",
        flat.len(),
        b * d
    );

    // Pass 1: per-column sum and count of finite samples.
    let mut sum = vec![0.0_f64; d];
    let mut count = vec![0_usize; d];
    for bi in 0..b {
        let row = &flat[bi * d..bi * d + d];
        for ((s, c), &v) in sum.iter_mut().zip(count.iter_mut()).zip(row) {
            if v.is_finite() {
                *s += v;
                *c += 1;
            }
        }
    }
    let mean: Vec<f64> = sum
        .iter()
        .zip(&count)
        .map(|(&s, &c)| if c > 0 { s / c as f64 } else { f64::NAN })
        .collect();

    // Pass 2: squared deviations from the column mean, finite samples only.
    let mut m2 = vec![0.0_f64; d];
    for bi in 0..b {
        let row = &flat[bi * d..bi * d + d];
        for ((acc, &mu), &v) in m2.iter_mut().zip(&mean).zip(row) {
            if v.is_finite() {
                let dv = v - mu;
                *acc += dv * dv;
            }
        }
    }
    let sd: Vec<f64> = m2
        .iter()
        .zip(&count)
        .map(|(&q, &c)| {
            if c > 1 {
                (q / (c - 1) as f64).sqrt()
            } else {
                f64::NAN
            }
        })
        .collect();
    (mean, sd)
}
/// Signed z-scores `beta_ref[j] / sd[j]`.
///
/// An SD that is not strictly above `sqrt(f64::EPSILON)`, including a `NaN`
/// SD, yields `NaN` instead of a huge or meaningless ratio. The output has
/// the length of the shorter input.
#[allow(clippy::doc_markdown)]
fn signed_z(beta_ref: &[f64], sd: &[f64]) -> Vec<f64> {
    let floor = f64::EPSILON.sqrt();
    let mut scores = Vec::with_capacity(beta_ref.len().min(sd.len()));
    for (&beta, &spread) in beta_ref.iter().zip(sd) {
        let z = if spread > floor { beta / spread } else { f64::NAN };
        scores.push(z);
    }
    scores
}
/// Fits one permutation replicate and returns its `d` coefficients.
///
/// Draws a permutation of the row indices from `rng` and builds
/// `y_perm[i] = ys_std[perm[i]]`. It then refits PLS1 with `k` components on
/// the unchanged `xs` and weights, telling the fitter the inputs are already
/// standardized.
///
/// # Errors
/// Propagates any error from `pls1_fit`.
fn run_one_perm(
    xs: MatRef<'_, f64>,
    ys_std: ColRef<'_, f64>,
    k: usize,
    wref: Option<faer::ColRef<'_, f64>>,
    rng: &mut crate::rng::Rng,
) -> PlsKitResult<Vec<f64>> {
    use crate::fit::{pls1_fit, FitOpts, KSpec};
    let n_rows = xs.nrows();
    let n_cols = xs.ncols();
    let order = crate::resample::permute_indices(n_rows, rng);
    let y_perm = faer::Col::<f64>::from_fn(n_rows, |row| ys_std[order[row]]);
    let fit_opts = FitOpts {
        pre_standardized: true,
        ..FitOpts::default()
    };
    let fit = pls1_fit(xs, y_perm.as_ref(), KSpec::Fixed(k), wref, fit_opts)?;
    Ok((0..n_cols).map(|col| fit.beta[col]).collect())
}
/// Test hook that prepares `x` and `y` as `pls1_perm_null` does and runs a
/// single unweighted permutation replicate with the caller's RNG.
///
/// With `pre_standardized_x`, both `x` and `y` are copied unchanged; despite
/// the name, `y` is not standardized either. Otherwise both are standardized.
#[cfg(test)]
#[allow(clippy::many_single_char_names)]
pub(crate) fn run_one_perm_for_test(
    x: MatRef<'_, f64>,
    y: ColRef<'_, f64>,
    k: usize,
    pre_standardized_x: bool,
    rng: &mut crate::rng::Rng,
) -> PlsKitResult<Vec<f64>> {
    use crate::linalg::{standardize, standardize1};
    use faer::{Col, Mat};
    let rows = x.nrows();
    let cols = x.ncols();
    let prepared: (Mat<f64>, Col<f64>) = if pre_standardized_x {
        let x_copy = Mat::<f64>::from_fn(rows, cols, |r, c| x[(r, c)]);
        let y_copy = Col::<f64>::from_fn(rows, |r| y[r]);
        (x_copy, y_copy)
    } else {
        (standardize(x).0, standardize1(y).0)
    };
    run_one_perm(prepared.0.as_ref(), prepared.1.as_ref(), k, None, rng)
}
#[cfg(test)]
mod tests_worker {
    //! Unit tests for the single-replicate worker.
    use super::*;
    use faer::{Col, Mat};
    use rand::RngExt;
    use rand::SeedableRng;

    /// Uniform(-1, 1) design. The response is `snr * (x0 + x1)` plus
    /// Uniform(-1, 1) noise.
    fn synth(n: usize, d: usize, snr: f64, seed: u64) -> (Mat<f64>, Col<f64>) {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(seed);
        let x = Mat::<f64>::from_fn(n, d, |_, _| rng.random_range(-1.0..1.0));
        let beta = Col::<f64>::from_fn(d, |j| if j < 2 { 1.0 } else { 0.0 });
        // Previously garbled to `&x * β` (HTML-decoded `&beta;`), which did
        // not compile.
        let signal: Col<f64> = &x * &beta;
        let noise = Col::<f64>::from_fn(n, |_| rng.random_range(-1.0..1.0));
        let y = Col::<f64>::from_fn(n, |i| signal[i] * snr + noise[i]);
        (x, y)
    }

    #[test]
    fn worker_returns_finite_beta_with_correct_length() {
        let (x, y) = synth(80, 5, 4.0, 1);
        let (_, mut rng) = crate::rng::resolve_seed(Some(11));
        let beta = run_one_perm_for_test(x.as_ref(), y.as_ref(), 2, false, &mut rng).unwrap();
        assert_eq!(beta.len(), 5);
        for v in &beta {
            assert!(v.is_finite());
        }
    }
}
#[cfg(test)]
mod tests_engine_retained {
    //! End-to-end tests for the engine that retains the replicate matrix.
    use super::*;
    use faer::{Col, Mat};
    use rand::RngExt;
    use rand::SeedableRng;

    /// Uniform(-1, 1) design. The response is `snr * (x0 + x1)` plus
    /// Uniform(-1, 1) noise.
    fn synth(n: usize, d: usize, snr: f64, seed: u64) -> (Mat<f64>, Col<f64>) {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(seed);
        let x = Mat::<f64>::from_fn(n, d, |_, _| rng.random_range(-1.0..1.0));
        let beta = Col::<f64>::from_fn(d, |j| if j < 2 { 1.0 } else { 0.0 });
        // Previously garbled to `&x * β` (HTML-decoded `&beta;`), which did
        // not compile.
        let signal: Col<f64> = &x * &beta;
        let noise = Col::<f64>::from_fn(n, |_| rng.random_range(-1.0..1.0));
        let y = Col::<f64>::from_fn(n, |i| signal[i] * snr + noise[i]);
        (x, y)
    }

    fn opts_retained() -> PermNullOpts {
        PermNullOpts {
            n_perm: 200,
            return_perm_matrix: true,
            pre_standardized: false,
            disable_parallelism: true,
            verbose: false,
        }
    }

    #[test]
    fn retained_path_runs_end_to_end() {
        let (x, y) = synth(100, 5, 4.0, 42);
        let out =
            pls1_perm_null(x.as_ref(), y.as_ref(), 2, None, opts_retained(), Some(7)).unwrap();
        assert_eq!(out.n_perm, 200);
        assert_eq!(out.k, 2);
        assert_eq!(out.beta_ref.len(), 5);
        assert_eq!(out.beta_perm_mean.len(), 5);
        assert_eq!(out.beta_perm_sd.len(), 5);
        assert_eq!(out.beta_perm_z.len(), 5);
        let m = out
            .beta_perm_matrix
            .as_ref()
            .expect("matrix should be retained");
        assert_eq!(m.len(), 200 * 5);
        for &z in &out.beta_perm_z {
            assert!(z.is_finite() || z.is_nan());
        }
    }

    #[test]
    fn retained_path_signal_voxels_have_higher_abs_z() {
        let (x, y) = synth(150, 8, 6.0, 11);
        let out =
            pls1_perm_null(x.as_ref(), y.as_ref(), 2, None, opts_retained(), Some(13)).unwrap();
        let signal: f64 = out.beta_perm_z[..2].iter().map(|z| z.abs()).sum::<f64>() / 2.0;
        let noise: f64 = out.beta_perm_z[2..].iter().map(|z| z.abs()).sum::<f64>() / 6.0;
        assert!(
            signal > noise,
            "signal mean |z|={signal}, noise mean |z|={noise}"
        );
    }
}
#[cfg(test)]
mod tests_engine_streaming {
    //! End-to-end tests for the summary-only engine, including byte-exact
    //! agreement with the retained engine.
    use super::*;
    use faer::{Col, Mat};
    use rand::RngExt;
    use rand::SeedableRng;

    /// Uniform(-1, 1) design. The response is `snr * (x0 + x1)` plus
    /// Uniform(-1, 1) noise.
    fn synth(n: usize, d: usize, snr: f64, seed: u64) -> (Mat<f64>, Col<f64>) {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(seed);
        let x = Mat::<f64>::from_fn(n, d, |_, _| rng.random_range(-1.0..1.0));
        let beta = Col::<f64>::from_fn(d, |j| if j < 2 { 1.0 } else { 0.0 });
        // Previously garbled to `&x * β` (HTML-decoded `&beta;`), which did
        // not compile.
        let signal: Col<f64> = &x * &beta;
        let noise = Col::<f64>::from_fn(n, |_| rng.random_range(-1.0..1.0));
        let y = Col::<f64>::from_fn(n, |i| signal[i] * snr + noise[i]);
        (x, y)
    }

    #[test]
    fn streaming_path_runs_end_to_end() {
        let (x, y) = synth(100, 5, 4.0, 42);
        let opts = PermNullOpts {
            n_perm: 200,
            return_perm_matrix: false,
            pre_standardized: false,
            disable_parallelism: true,
            verbose: false,
        };
        let out = pls1_perm_null(x.as_ref(), y.as_ref(), 2, None, opts, Some(7)).unwrap();
        assert!(out.beta_perm_matrix.is_none());
        assert_eq!(out.beta_perm_sd.len(), 5);
        assert_eq!(out.beta_perm_z.len(), 5);
    }

    #[test]
    fn streaming_matches_retained_byte_exact() {
        let (x, y) = synth(100, 5, 4.0, 42);
        let opts_retained = PermNullOpts {
            n_perm: 200,
            return_perm_matrix: true,
            pre_standardized: false,
            disable_parallelism: true,
            verbose: false,
        };
        let opts_streaming = PermNullOpts {
            return_perm_matrix: false,
            ..opts_retained
        };
        let r1 = pls1_perm_null(x.as_ref(), y.as_ref(), 2, None, opts_retained, Some(99)).unwrap();
        let r2 = pls1_perm_null(x.as_ref(), y.as_ref(), 2, None, opts_streaming, Some(99)).unwrap();
        assert_eq!(
            r1.beta_perm_mean, r2.beta_perm_mean,
            "beta_perm_mean must be byte-exact between retained and streaming"
        );
        assert_eq!(
            r1.beta_perm_sd, r2.beta_perm_sd,
            "beta_perm_sd must be byte-exact between retained and streaming"
        );
        assert_eq!(
            r1.beta_perm_z, r2.beta_perm_z,
            "beta_perm_z must be byte-exact between retained and streaming"
        );
    }
}
#[cfg(test)]
mod tests_validate {
    //! Tests for `PermNullOpts::validate`.
    use super::*;

    /// Baseline options that pass validation.
    fn opts_default() -> PermNullOpts {
        PermNullOpts {
            n_perm: 1000,
            return_perm_matrix: false,
            pre_standardized: false,
            disable_parallelism: false,
            verbose: false,
        }
    }

    #[test]
    fn validate_accepts_defaults() {
        opts_default()
            .validate(2)
            .expect("baseline options must validate");
    }

    #[test]
    fn validate_rejects_low_n_perm() {
        let too_few = PermNullOpts {
            n_perm: 50,
            ..opts_default()
        };
        let err = too_few
            .validate(2)
            .expect_err("n_perm = 50 must be rejected");
        assert_eq!(err.code(), "invalid_argument");
        assert!(err.to_string().contains("n_perm"));
    }

    #[test]
    fn validate_rejects_zero_k() {
        let err = opts_default()
            .validate(0)
            .expect_err("k = 0 must be rejected");
        assert_eq!(err.code(), "invalid_argument");
        assert!(err.to_string().contains('k'));
    }
}
#[cfg(test)]
mod tests_calibration {
    //! Statistical calibration (H0 / H1) and serial-vs-parallel determinism
    //! checks for `pls1_perm_null`.
    use super::*;
    use faer::{Col, Mat};
    use rand::RngExt;
    use rand::SeedableRng;

    /// Options shared by every test here. Only the replicate count, matrix
    /// retention and the parallelism switch vary.
    fn perm_opts(
        n_perm: usize,
        return_perm_matrix: bool,
        disable_parallelism: bool,
    ) -> PermNullOpts {
        PermNullOpts {
            n_perm,
            return_perm_matrix,
            pre_standardized: false,
            disable_parallelism,
            verbose: false,
        }
    }

    /// Pure noise: `y` is drawn independently of `x`.
    fn synth_h0(n: usize, d: usize, seed: u64) -> (Mat<f64>, Col<f64>) {
        let mut prng = rand_chacha::ChaCha8Rng::seed_from_u64(seed);
        let x = Mat::<f64>::from_fn(n, d, |_, _| prng.random_range(-1.0..1.0));
        let y = Col::<f64>::from_fn(n, |_| prng.random_range(-1.0..1.0));
        (x, y)
    }

    /// `y` depends positively on column 0 only, plus small noise.
    fn synth_h1_signed(n: usize, d: usize, seed: u64) -> (Mat<f64>, Col<f64>) {
        let mut prng = rand_chacha::ChaCha8Rng::seed_from_u64(seed);
        let x = Mat::<f64>::from_fn(n, d, |_, _| prng.random_range(-1.0..1.0));
        let y =
            Col::<f64>::from_fn(n, |i| 2.0 * x[(i, 0)] + 0.3 * prng.random_range(-1.0..1.0));
        (x, y)
    }

    #[test]
    fn h0_mean_perm_close_to_zero() {
        let (x, y) = synth_h0(80, 5, 1);
        let out = pls1_perm_null(x.as_ref(), y.as_ref(), 1, None, perm_opts(1000, false, true), Some(7))
            .unwrap();
        let sqrt_b = (1000.0_f64).sqrt();
        for j in 0..5 {
            let mean = out.beta_perm_mean[j];
            let band = 5.0 * out.beta_perm_sd[j] / sqrt_b;
            assert!(
                mean.abs() < band,
                "mean_perm[{j}] = {} exceeds band {}",
                mean,
                band,
            );
        }
    }

    #[test]
    fn h0_uncorrected_fpr_close_to_alpha() {
        let mut total = 0_usize;
        let mut rejects = 0_usize;
        for seed in 0..3_u64 {
            let (x, y) = synth_h0(80, 30, seed * 17 + 3);
            let out = pls1_perm_null(
                x.as_ref(),
                y.as_ref(),
                1,
                None,
                perm_opts(500, false, true),
                Some(seed * 11 + 7),
            )
            .unwrap();
            let finite: Vec<f64> = out
                .beta_perm_z
                .iter()
                .copied()
                .filter(|z| z.is_finite())
                .collect();
            total += finite.len();
            rejects += finite.iter().filter(|z| z.abs() > 1.96).count();
        }
        let fpr = rejects as f64 / total as f64;
        assert!(
            (0.01..=0.15).contains(&fpr),
            "FPR={fpr} (rejects={rejects} / total={total}) outside [0.01, 0.15]",
        );
    }

    #[test]
    fn h1_signed_signal_recovered() {
        let (x, y) = synth_h1_signed(150, 8, 23);
        let out = pls1_perm_null(x.as_ref(), y.as_ref(), 1, None, perm_opts(500, false, true), Some(31))
            .unwrap();
        let z0 = out.beta_perm_z[0];
        assert!(z0 > 0.0, "z[0] = {} not positive", z0);
        let abs_z0 = z0.abs();
        let max_other = out.beta_perm_z[1..]
            .iter()
            .map(|z| z.abs())
            .fold(0.0_f64, f64::max);
        assert!(
            abs_z0 > max_other,
            "|z[0]|={} not larger than max |z[1..]|={}",
            abs_z0,
            max_other,
        );
    }

    #[test]
    fn parallelism_determinism_disable_vs_enable() {
        let (x, y) = synth_h0(60, 5, 41);
        let serial = perm_opts(200, false, true);
        let parallel = PermNullOpts {
            disable_parallelism: false,
            ..serial
        };
        let r1 = pls1_perm_null(x.as_ref(), y.as_ref(), 2, None, serial, Some(2026)).unwrap();
        let r2 = pls1_perm_null(x.as_ref(), y.as_ref(), 2, None, parallel, Some(2026)).unwrap();
        assert_eq!(
            r1.beta_perm_mean, r2.beta_perm_mean,
            "beta_perm_mean must be byte-exact across serial/parallel"
        );
        assert_eq!(
            r1.beta_perm_sd, r2.beta_perm_sd,
            "beta_perm_sd must be byte-exact across serial/parallel"
        );
        assert_eq!(
            r1.beta_perm_z, r2.beta_perm_z,
            "beta_perm_z must be byte-exact across serial/parallel"
        );
    }

    #[test]
    fn parallelism_determinism_retained_matrix_byte_exact() {
        let (x, y) = synth_h0(60, 5, 41);
        let serial = perm_opts(200, true, true);
        let parallel = PermNullOpts {
            disable_parallelism: false,
            ..serial
        };
        let r1 = pls1_perm_null(x.as_ref(), y.as_ref(), 2, None, serial, Some(2026)).unwrap();
        let r2 = pls1_perm_null(x.as_ref(), y.as_ref(), 2, None, parallel, Some(2026)).unwrap();
        let m1 = r1.beta_perm_matrix.as_ref().unwrap();
        let m2 = r2.beta_perm_matrix.as_ref().unwrap();
        assert_eq!(
            m1, m2,
            "retained matrices diverge between serial and parallel"
        );
    }
}
#[cfg(test)]
mod tests_validation {
    //! Input-shape validation performed by `pls1_perm_null` itself.
    use super::*;
    use faer::{Col, Mat};

    /// Options that pass `validate`, so the failures below come from shape checks.
    fn serial_opts() -> PermNullOpts {
        PermNullOpts {
            n_perm: 200,
            return_perm_matrix: false,
            pre_standardized: false,
            disable_parallelism: true,
            verbose: false,
        }
    }

    /// Runs with seed 1 and returns the error, panicking if the call succeeds.
    fn run_expecting_error(x: &Mat<f64>, y: &Col<f64>, k: usize) -> PlsKitError {
        pls1_perm_null(x.as_ref(), y.as_ref(), k, None, serial_opts(), Some(1)).unwrap_err()
    }

    #[test]
    fn rejects_dim_mismatch() {
        let err = run_expecting_error(&Mat::<f64>::zeros(10, 5), &Col::<f64>::zeros(9), 2);
        assert_eq!(err.code(), "dimension_mismatch");
    }

    #[test]
    fn rejects_k_exceeds_max() {
        let err = run_expecting_error(&Mat::<f64>::zeros(20, 4), &Col::<f64>::zeros(20), 5);
        assert_eq!(err.code(), "k_exceeds_max");
    }
}