use faer::{Col, ColRef, Mat, MatRef};
use crate::error::PlsKitResult;
use crate::signal_test::ConfirmatoryMethod;
/// Method-specific parameters for the sequential (per-component) test.
///
/// Every variant mirrors the [`ConfirmatoryMethod`] variant of the same name
/// (see `SequentialArgs::method`). There is deliberately no `Score` variant:
/// `SequentialArgs::defaults_for` returns `None` for that method.
///
/// `PartialEq`/`Eq` are derived so configurations can be compared directly
/// (for example with `assert_eq!`); all payloads are plain `usize` counts.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum SequentialArgs {
    /// Parameters for the `RawPerm` method.
    RawPerm {
        /// Count forwarded unchanged as `n_perm` to the confirmatory test
        /// (presumably the number of permutations).
        n_perm: usize,
    },
    /// Parameters for the `SplitNb` method.
    SplitNb {
        /// Count forwarded unchanged as `n_splits` to the confirmatory test
        /// (presumably the number of data splits).
        n_splits: usize,
    },
    /// Parameters for the `SplitPerm` method.
    SplitPerm {
        /// Count forwarded unchanged as `n_perm` to the confirmatory test.
        n_perm: usize,
        /// Count forwarded unchanged as `n_splits` to the confirmatory test.
        n_splits: usize,
    },
    /// The `E` method; it carries no tunable parameters.
    E,
}
impl SequentialArgs {
    /// Reports which [`ConfirmatoryMethod`] this parameter set belongs to.
    #[must_use]
    pub(crate) fn method(&self) -> ConfirmatoryMethod {
        match *self {
            Self::RawPerm { .. } => ConfirmatoryMethod::RawPerm,
            Self::SplitNb { .. } => ConfirmatoryMethod::SplitNb,
            Self::SplitPerm { .. } => ConfirmatoryMethod::SplitPerm,
            Self::E => ConfirmatoryMethod::E,
        }
    }

    /// Builds the default parameter set for `method`.
    ///
    /// Returns `None` for [`ConfirmatoryMethod::Score`], which has no
    /// sequential counterpart in this enum. The other methods use
    /// `n_perm = 1000` and `n_splits = 50` wherever those fields exist.
    #[must_use]
    pub(crate) fn defaults_for(method: ConfirmatoryMethod) -> Option<Self> {
        const DEFAULT_N_PERM: usize = 1000;
        const DEFAULT_N_SPLITS: usize = 50;

        match method {
            ConfirmatoryMethod::Score => None,
            ConfirmatoryMethod::RawPerm => Some(Self::RawPerm {
                n_perm: DEFAULT_N_PERM,
            }),
            ConfirmatoryMethod::SplitNb => Some(Self::SplitNb {
                n_splits: DEFAULT_N_SPLITS,
            }),
            ConfirmatoryMethod::SplitPerm => Some(Self::SplitPerm {
                n_perm: DEFAULT_N_PERM,
                n_splits: DEFAULT_N_SPLITS,
            }),
            ConfirmatoryMethod::E => Some(Self::E),
        }
    }

    /// Converts these parameters into the argument type consumed by the
    /// confirmatory test.
    ///
    /// All counts are forwarded unchanged. `RawPerm` additionally receives a
    /// fixed `n_folds` of 5, because this enum has no field for it.
    #[must_use]
    pub(crate) fn to_confirmatory_args(self) -> crate::signal_test::ConfirmatoryArgs {
        use crate::signal_test::ConfirmatoryArgs;

        match self {
            Self::E => ConfirmatoryArgs::E,
            Self::RawPerm { n_perm } => ConfirmatoryArgs::RawPerm { n_perm, n_folds: 5 },
            Self::SplitNb { n_splits } => ConfirmatoryArgs::SplitNb { n_splits },
            Self::SplitPerm { n_perm, n_splits } => ConfirmatoryArgs::SplitPerm { n_perm, n_splits },
        }
    }
}
/// Options controlling a run of `run_incremental_sequence`.
#[derive(Debug, Clone, Copy)]
// The flags below are independent switches, so they stay as plain bools.
#[allow(clippy::struct_excessive_bools)]
pub(crate) struct IncrementalSequenceOpts {
    /// Test method together with its method-specific parameters.
    pub(crate) args: SequentialArgs,
    /// Significance threshold: a component counts as significant when its
    /// p-value is strictly below this value.
    pub(crate) alpha: f64,
    /// When `false`, the sequence stops at the first component that is not
    /// significant; when `true`, all `k_max` components are tested.
    pub(crate) stop_early_override: bool,
    /// NOTE(review): the incremental path in this file standardizes its
    /// inputs itself and forces this flag to `true` on the options it hands
    /// to the confirmatory test, so the caller's value is not consulted
    /// there. Confirm whether that is intended.
    pub(crate) pre_standardized: bool,
    /// Seed handed to `crate::rng::resolve_seed`; the seed actually used is
    /// reported back in the output.
    pub(crate) seed: Option<u64>,
    /// Forwarded unchanged to the confirmatory test options.
    pub(crate) disable_parallelism: bool,
    /// Forwarded unchanged to the confirmatory test options.
    pub(crate) verbose: bool,
}
/// Result of `run_incremental_sequence`.
#[derive(Debug, Clone)]
pub(crate) struct IncrementalSequenceOutput {
    /// One p-value per component `1..=k_max` (entry `h - 1` belongs to
    /// component `h`). Components that were never tested hold `NaN`.
    pub(crate) pvalues: Col<f64>,
    /// Largest tested component index whose p-value was below `alpha`, or
    /// `None` when no tested component was significant.
    pub(crate) last_significant_k: Option<usize>,
    /// String name of the method that was run.
    // Not read anywhere in this file; presumably kept for reporting.
    #[allow(dead_code)]
    pub(crate) method: String,
    /// The significance threshold the run used.
    // Not read anywhere in this file; presumably kept for reporting.
    #[allow(dead_code)]
    pub(crate) alpha: f64,
    /// Seed actually used, as returned by `crate::rng::resolve_seed`.
    pub(crate) seed: u64,
}
/// Tests components `1..=k_max` one after another and collects their p-values.
///
/// For each component `h`, `p_for_incremental` produces a p-value, which is
/// stored at index `h - 1` of the output. A component is significant when its
/// p-value is strictly below `opts.alpha`.
///
/// Unless `opts.stop_early_override` is set, the sequence stops at the first
/// component that is not significant. A `NaN` p-value is never below `alpha`,
/// so it also counts as a non-rejection and stops the sequence. Entries for
/// components that were not tested remain `NaN`.
///
/// With `opts.stop_early_override` set, every component is tested and
/// `last_significant_k` is the largest significant index, even if an earlier
/// component was not significant.
///
/// # Errors
///
/// Returns `PlsKitError::KExceedsMax` when `k_max` is zero or exceeds the
/// number of columns of `x`, and propagates any error from the per-component
/// test.
// `opts` is taken by value to keep the existing call signature.
#[allow(clippy::needless_pass_by_value)]
pub(crate) fn run_incremental_sequence(
    x: MatRef<'_, f64>,
    y: ColRef<'_, f64>,
    k_max: usize,
    weights: Option<ColRef<'_, f64>>,
    opts: IncrementalSequenceOpts,
) -> PlsKitResult<IncrementalSequenceOutput> {
    let max_allowed = x.ncols();
    if k_max == 0 || k_max > max_allowed {
        return Err(crate::error::PlsKitError::KExceedsMax {
            k: k_max,
            k_max: max_allowed,
        });
    }

    let (seed_used, mut rng) = crate::rng::resolve_seed(opts.seed);

    // NaN marks "not tested"; slots are overwritten as components are reached.
    let mut pvalues_vec: Vec<f64> = vec![f64::NAN; k_max];
    let mut last_sig: Option<usize> = None;

    for (idx, slot) in pvalues_vec.iter_mut().enumerate() {
        let h = idx + 1;
        let p = p_for_incremental(x, y, h, weights, &opts, &mut rng)?;
        *slot = p;

        if p < opts.alpha {
            last_sig = Some(h);
        } else if !opts.stop_early_override {
            // Anything that is not a rejection ends the sequence. Using
            // `else` rather than `p >= alpha` makes a NaN p-value stop the
            // sequence too, since both comparisons are false for NaN.
            break;
        }
    }

    let pvalues = Col::<f64>::from_fn(k_max, |i| pvalues_vec[i]);
    Ok(IncrementalSequenceOutput {
        pvalues,
        last_significant_k: last_sig,
        method: opts.args.method().as_str().to_owned(),
        alpha: opts.alpha,
        seed: seed_used,
    })
}
/// Runs the confirmatory test with `k` components on raw inputs and returns
/// its p-value.
///
/// Method arguments, `pre_standardized`, `disable_parallelism` and `verbose`
/// come from `opts`; no confidence interval is requested and `max_skip_rate`
/// is fixed at `0.01`.
///
/// Exactly two `u64` values are drawn from `rng` per call: the first is
/// discarded and the second seeds the confirmatory test.
///
/// # Errors
///
/// Propagates any error returned by `pls1_confirmatory_test`.
fn p_for_confirmatory_at_k(
    x: MatRef<'_, f64>,
    y: ColRef<'_, f64>,
    k: usize,
    weights: Option<ColRef<'_, f64>>,
    opts: &IncrementalSequenceOpts,
    rng: &mut crate::rng::Rng,
) -> PlsKitResult<f64> {
    use crate::signal_test::{pls1_confirmatory_test, ConfirmatoryTestInput, ConfirmatoryTestOpts};
    use rand::Rng as _;

    // NOTE(review): this first draw is thrown away. It presumably keeps the
    // random stream aligned with an earlier or reference implementation;
    // removing it would change every seeded result, so it is preserved.
    let _discarded: u64 = rng.next_u64();
    let sub_seed: u64 = rng.next_u64();

    let test_opts = ConfirmatoryTestOpts {
        args: opts.args.to_confirmatory_args(),
        pre_standardized: opts.pre_standardized,
        seed: Some(sub_seed),
        disable_parallelism: opts.disable_parallelism,
        verbose: opts.verbose,
        ci: None,
        max_skip_rate: 0.01,
    };
    let outcome = pls1_confirmatory_test(ConfirmatoryTestInput::Raw { x, y, k, weights }, test_opts)?;
    Ok(outcome.pvalue)
}
fn p_for_incremental(
x: MatRef<'_, f64>,
y: ColRef<'_, f64>,
h: usize,
weights: Option<ColRef<'_, f64>>,
opts: &IncrementalSequenceOpts,
rng: &mut crate::rng::Rng,
) -> PlsKitResult<f64> {
use crate::fit::{pls1_fit, FitOpts, KSpec};
use crate::linalg::{standardize, standardize1};
let (xs_full, _, _) = standardize(x);
let (ys_full, _, _) = standardize1(y);
let (xs_def, ys_def) = if h == 1 {
(xs_full, ys_full)
} else {
let prev = pls1_fit(
xs_full.as_ref(),
ys_full.as_ref(),
KSpec::Fixed(h - 1),
None,
FitOpts {
pre_standardized: true,
..FitOpts::default()
},
)?;
let tp: Mat<f64> = prev.t_scores.as_ref() * prev.p_loadings.transpose();
let xs_d = Mat::<f64>::from_fn(xs_full.nrows(), xs_full.ncols(), |i, j| {
xs_full[(i, j)] - tp[(i, j)]
});
let tq: Col<f64> = prev.t_scores.as_ref() * prev.q_loadings.as_ref();
let ys_d = Col::<f64>::from_fn(ys_full.nrows(), |i| ys_full[i] - tq[i]);
(xs_d, ys_d)
};
let mut sub_opts = *opts;
sub_opts.pre_standardized = true;
p_for_confirmatory_at_k(xs_def.as_ref(), ys_def.as_ref(), 1, weights, &sub_opts, rng)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates a seeded synthetic regression problem.
    ///
    /// `x` is `n x d` with entries drawn uniformly from `[-1, 1)`. `y` is the
    /// sum of the first `k_signal` columns of `x`, scaled by `snr`, plus
    /// uniform noise from `[-1, 1)`.
    fn synth(
        n: usize,
        d: usize,
        k_signal: usize,
        snr: f64,
        seed: u64,
    ) -> (faer::Mat<f64>, Col<f64>) {
        use rand::{RngExt, SeedableRng};

        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(seed);
        // Draw order matters for reproducibility: all of `x` first, then the noise.
        let x = faer::Mat::<f64>::from_fn(n, d, |_, _| rng.random_range(-1.0..1.0));
        let beta = Col::<f64>::from_fn(d, |j| if j < k_signal { 1.0 } else { 0.0 });
        let signal: Col<f64> = &x * &beta;
        let y = Col::<f64>::from_fn(n, |i| signal[i] * snr + rng.random_range(-1.0..1.0));
        (x, y)
    }

    /// Options shared by the sequence tests: `SplitNb` with 30 splits at
    /// `alpha = 0.05`, everything else switched off.
    fn split_nb_opts(stop_early_override: bool, seed: u64) -> IncrementalSequenceOpts {
        IncrementalSequenceOpts {
            args: SequentialArgs::SplitNb { n_splits: 30 },
            alpha: 0.05,
            stop_early_override,
            pre_standardized: false,
            seed: Some(seed),
            disable_parallelism: false,
            verbose: false,
        }
    }

    #[test]
    fn score_unrepresentable() {
        assert!(SequentialArgs::defaults_for(ConfirmatoryMethod::Score).is_none());
    }

    #[test]
    fn incremental_stops_early_at_first_nonrejection() {
        let (x, y) = synth(60, 5, 1, 4.0, 2);
        let r = run_incremental_sequence(
            x.as_ref(),
            y.as_ref(),
            5,
            None,
            split_nb_opts(false, 11),
        )
        .unwrap();

        // Untested components stay NaN, so count the entries that were filled.
        let n_filled = (0..r.pvalues.nrows())
            .filter(|&i| !r.pvalues[i].is_nan())
            .count();
        assert!(
            n_filled < 5,
            "stop-early did not trigger; pvalues={:?}",
            r.pvalues
        );
        assert!(r.pvalues[0] < 0.05);
    }

    #[test]
    fn override_runs_all_k() {
        let (x, y) = synth(60, 5, 1, 4.0, 1);
        let r = run_incremental_sequence(
            x.as_ref(),
            y.as_ref(),
            3,
            None,
            split_nb_opts(true, 7),
        )
        .unwrap();

        assert_eq!(r.pvalues.nrows(), 3);
        assert!((0..3).all(|i| !r.pvalues[i].is_nan()));
    }
}