use faer::{Col, Mat};
use plskit::{
pls1_confirmatory_test, pls1_fit, CIOpts, ConfirmatoryArgs, ConfirmatoryTestInput,
ConfirmatoryTestOpts, FitOpts, KSpec,
};
use rand::{RngExt, SeedableRng};
use rand_chacha::ChaCha8Rng;
/// Number of synthetic datasets drawn for each grid cell in the coverage test.
const N_DATASETS: usize = 200;
/// Value passed as `n_boot` in every `CIOpts` built by the coverage test.
const N_BOOT: usize = 300;
/// Nominal confidence level passed as `CIOpts::level`; also the centre of the acceptance band.
const LEVEL: f64 = 0.95;
/// Half-width of the acceptance band: empirical coverage must lie in
/// `[LEVEL - BAND_HALF_WIDTH, LEVEL + BAND_HALF_WIDTH]`.
const BAND_HALF_WIDTH: f64 = 0.05;
/// Row count of the large "oracle" dataset whose estimates serve as the coverage targets.
const ORACLE_N: usize = 50_000;
/// Grid axis: per-dataset sample sizes `n`.
const CELL_NS: &[usize] = &[100, 200, 500];
/// Grid axis: predictor counts `d`.
const CELL_DS: &[usize] = &[6, 20];
/// Grid axis: number of PLS components `k`.
const CELL_KS: &[usize] = &[1, 2, 3];
/// Grid axis: multiplier applied to the signal term in `synth` (the `snr` argument).
const CELL_SNRS: &[f64] = &[1.0, 4.0];
/// Draws one synthetic regression dataset with `n` rows and `d` predictors.
///
/// Every entry of the design matrix is sampled with `rng.random_range(-1.0..1.0)`.
/// The response for a row is `signal * snr + noise`, where `signal` is the dot product of
/// the row with a coefficient vector that is `1.0` for column indices below 2 and `0.0`
/// elsewhere, and `noise` is one further `random_range(-1.0..1.0)` draw per row.
///
/// The whole design matrix is sampled before any noise value, so the sequence of draws
/// taken from `rng` depends only on `(n, d)`.
///
/// Returns `(x, y)`: the `n × d` design matrix and the length-`n` response.
fn synth(rng: &mut ChaCha8Rng, n: usize, d: usize, snr: f64) -> (Mat<f64>, Col<f64>) {
    let design = Mat::<f64>::from_fn(n, d, |_, _| rng.random_range(-1.0..1.0));

    // Only the first two predictors carry signal.
    let coefs: Vec<f64> = (0..d)
        .map(|col| if col < 2 { 1.0 } else { 0.0 })
        .collect();

    // Noise is drawn after the design matrix; keep this ordering to preserve the RNG stream.
    let noise = Col::<f64>::from_fn(n, |_| rng.random_range(-1.0..1.0));

    let response = Col::<f64>::from_fn(n, |row| {
        let signal: f64 = coefs
            .iter()
            .enumerate()
            .map(|(col, &coef)| design[(row, col)] * coef)
            .sum();
        signal * snr + noise[row]
    });

    (design, response)
}
/// Computes the per-predictor leverage `diag(W (WᵀW)⁻¹ Wᵀ)` for a `k`-component PLS1 fit.
///
/// `W` is the `w_star` matrix of the model returned by `pls1_fit(x, y, KSpec::Fixed(k), ..)`.
/// The `k × k` Gram matrix `WᵀW` is inverted by LU-solving against the identity, and entry
/// `j` of the result is the quadratic form `w_j (WᵀW)⁻¹ w_jᵀ` for row `w_j` of `W`.
///
/// Returns one value per column of `x`.
///
/// # Panics
///
/// Panics if `pls1_fit` returns an error.
#[allow(clippy::similar_names, clippy::many_single_char_names)]
fn oracle_leverage(x: faer::MatRef<f64>, y: faer::ColRef<f64>, k: usize) -> Vec<f64> {
    use faer::prelude::Solve;

    let model = pls1_fit(x, y, KSpec::Fixed(k), None, FitOpts::default())
        .expect("oracle pls1_fit must succeed");

    // Gram matrix of the weight columns: WᵀW.
    let mut gram = faer::Mat::<f64>::zeros(k, k);
    faer::linalg::matmul::matmul(
        gram.as_mut(),
        faer::Accum::Replace,
        model.w_star.transpose(),
        model.w_star.as_ref(),
        1.0,
        faer::Par::Seq,
    );

    // Start from the identity and solve in place, leaving (WᵀW)⁻¹ behind.
    let mut gram_inv =
        faer::Mat::<f64>::from_fn(k, k, |row, col| if row == col { 1.0 } else { 0.0 });
    faer::linalg::solvers::PartialPivLu::new(gram.as_ref()).solve_in_place(gram_inv.as_mut());

    // Explicit folds seeded with 0.0 keep the accumulation order of a plain nested loop.
    let n_predictors = x.ncols();
    (0..n_predictors)
        .map(|j| {
            (0..k).fold(0.0_f64, |quad, col| {
                let projected = (0..k).fold(0.0_f64, |acc, inner| {
                    acc + model.w_star[(j, inner)] * gram_inv[(inner, col)]
                });
                quad + projected * model.w_star[(j, col)]
            })
        })
        .collect()
}
/// Derives a deterministic base seed for one grid cell.
///
/// Each coordinate (`n`, `d`, `k`, `snr_idx`) is multiplied (wrapping) by its own fixed
/// 64-bit constant and XOR-ed into a fixed salt. A coordinate equal to zero therefore
/// leaves the running value unchanged.
#[allow(clippy::cast_possible_truncation)]
fn cell_base_seed(n: usize, d: usize, k: usize, snr_idx: usize) -> u64 {
    const SALT: u64 = 0x5EED_C0FF_BAAD_F00D;
    // One multiplier per coordinate, in the order (n, d, k, snr_idx).
    const MULTIPLIERS: [u64; 4] = [
        0x9E37_79B9_7F4A_7C15,
        0xBF58_476D_1CE4_E5B9,
        0x94D0_49BB_1331_11EB,
        0xD6E8_FEB8_6659_FD93,
    ];

    let coords = [n, d, k, snr_idx];
    coords
        .iter()
        .zip(MULTIPLIERS.iter())
        .fold(SALT, |seed, (&coord, &mult)| {
            seed ^ (coord as u64).wrapping_mul(mult)
        })
}
/// Monte Carlo coverage check over the `(n, d, k, snr)` grid.
///
/// For every cell the test:
/// 1. builds an oracle dataset of `ORACLE_N` rows and records the oracle
///    `holdout_corr.point` and the oracle leverage vector as coverage targets;
/// 2. draws `N_DATASETS` datasets of `n` rows, runs `pls1_confirmatory_test` with CIs
///    enabled on each, and counts how often each reported interval contains its target;
/// 3. prints the empirical coverage for `holdout_corr` and for every leverage entry.
///
/// The test fails if the coverage of `holdout_corr`, or of the leverage of either of the
/// first two predictors, lies outside `[LEVEL - BAND_HALF_WIDTH, LEVEL + BAND_HALF_WIDTH]`
/// in any cell. Coverage for the remaining predictors is printed but not asserted.
#[test]
#[ignore = "slow MC coverage test; run before release with --ignored"]
#[allow(clippy::too_many_lines, clippy::similar_names)]
fn coverage_mc_two_sided_grid() {
    let (band_lo, band_hi) = (LEVEL - BAND_HALF_WIDTH, LEVEL + BAND_HALF_WIDTH);
    let n_cells = CELL_NS.len() * CELL_DS.len() * CELL_KS.len() * CELL_SNRS.len();
    println!(
        "coverage_mc grid: cells = {} × {} × {} × {} = {}, datasets/cell = {}",
        CELL_NS.len(),
        CELL_DS.len(),
        CELL_KS.len(),
        CELL_SNRS.len(),
        n_cells,
        N_DATASETS,
    );
    println!(
        "level = {LEVEL:.2}, two-sided band = [{band_lo:.2}, {band_hi:.2}], n_boot = {N_BOOT}",
    );

    // CI settings shared by the oracle run and every Monte Carlo run.
    let strict_ci = || CIOpts {
        n_boot: N_BOOT,
        m_rate: 0.7,
        level: LEVEL,
        max_failure_rate: 0.0,
    };
    // A NaN coverage is treated as out of band.
    let within_band = |coverage: f64| (band_lo..=band_hi).contains(&coverage);
    #[allow(clippy::cast_precision_loss)]
    let as_rate = |hits: usize| hits as f64 / N_DATASETS as f64;
    let fmt_rates = |rates: &[f64]| {
        rates
            .iter()
            .map(|rate| format!("{rate:.3}"))
            .collect::<Vec<String>>()
            .join(", ")
    };

    let mut failures: Vec<String> = Vec::new();
    for &n in CELL_NS {
        for &d in CELL_DS {
            for &k in CELL_KS {
                for (snr_idx, &snr) in CELL_SNRS.iter().enumerate() {
                    let cell_seed = cell_base_seed(n, d, k, snr_idx);

                    // Oracle targets from one large dataset.
                    let mut oracle_rng =
                        ChaCha8Rng::seed_from_u64(cell_seed ^ 0xDEAD_BEEF_DEAD_BEEF);
                    let (x_big, y_big) = synth(&mut oracle_rng, ORACLE_N, d, snr);
                    let oracle_outcome = pls1_confirmatory_test(
                        ConfirmatoryTestInput::Raw {
                            x: x_big.as_ref(),
                            y: y_big.as_ref(),
                            k,
                            weights: None,
                        },
                        ConfirmatoryTestOpts {
                            args: ConfirmatoryArgs::SplitNb { n_splits: 50 },
                            ci: Some(strict_ci()),
                            seed: Some(cell_seed ^ 0xC0FF_EE00_C0FF_EE00),
                            disable_parallelism: false,
                            ..Default::default()
                        },
                    )
                    .expect("oracle confirmatory_test must succeed");
                    let target_corr = oracle_outcome
                        .ci
                        .expect("oracle CI must be Some")
                        .holdout_corr
                        .point;
                    let target_lev = oracle_leverage(x_big.as_ref(), y_big.as_ref(), k);

                    // Count interval hits over the Monte Carlo replicates.
                    let mut holdout_hits = 0_usize;
                    let mut lev_hits = vec![0_usize; d];
                    for rep in 0..N_DATASETS {
                        let rep_seed = cell_seed.wrapping_add(rep as u64);
                        let mut rep_rng = ChaCha8Rng::seed_from_u64(rep_seed);
                        let (x_rep, y_rep) = synth(&mut rep_rng, n, d, snr);
                        let outcome = pls1_confirmatory_test(
                            ConfirmatoryTestInput::Raw {
                                x: x_rep.as_ref(),
                                y: y_rep.as_ref(),
                                k,
                                weights: None,
                            },
                            ConfirmatoryTestOpts {
                                args: ConfirmatoryArgs::SplitNb { n_splits: 30 },
                                ci: Some(strict_ci()),
                                seed: Some(rep_seed ^ 0xBEEF_BEEF_BEEF_BEEF),
                                disable_parallelism: false,
                                ..Default::default()
                            },
                        )
                        .expect("strict-mode MC dataset must not fail");
                        let interval = outcome.ci.expect("ci=Some must produce ci.is_some");

                        let corr_covered = interval.holdout_corr.lower <= target_corr
                            && target_corr <= interval.holdout_corr.upper;
                        holdout_hits += usize::from(corr_covered);

                        for (j, hits) in lev_hits.iter_mut().enumerate() {
                            let target = target_lev[j];
                            if interval.leverage_ci_lower[j] <= target
                                && target <= interval.leverage_ci_upper[j]
                            {
                                *hits += 1;
                            }
                        }
                    }

                    // Turn hit counts into empirical coverage rates.
                    let holdout_rate = as_rate(holdout_hits);
                    let lev_rates: Vec<f64> = lev_hits.iter().map(|&hits| as_rate(hits)).collect();
                    // The first two predictors carry signal; the rest are noise.
                    let (signal_rates, noise_rates) = lev_rates.split_at(2);

                    println!(
                        "[cell n={n:>3} d={d:>2} k={k} snr={snr:.0}] holdout_corr={holdout_rate:.3} \
                         leverage_signal=[{}] leverage_noise=[{}]",
                        fmt_rates(signal_rates),
                        fmt_rates(noise_rates),
                    );

                    if !within_band(holdout_rate) {
                        failures.push(format!(
                            "cell (n={n}, d={d}, k={k}, snr={snr}): holdout_corr coverage \
                             {holdout_rate:.3} outside band [{band_lo:.2}, {band_hi:.2}]",
                        ));
                    }
                    for (j, &rate) in signal_rates.iter().enumerate() {
                        if !within_band(rate) {
                            failures.push(format!(
                                "cell (n={n}, d={d}, k={k}, snr={snr}): leverage[{j}] (signal) \
                                 coverage {rate:.3} outside band [{band_lo:.2}, {band_hi:.2}]",
                            ));
                        }
                    }
                }
            }
        }
    }

    assert!(
        failures.is_empty(),
        "coverage_mc grid: {} cell-metric failures outside two-sided \
         band [{band_lo:.2}, {band_hi:.2}]:\n {}",
        failures.len(),
        failures.join("\n "),
    );
    println!("coverage_mc grid: all cells within band [{band_lo:.2}, {band_hi:.2}]");
}