use crate::estimate::EstimationError;
use crate::estimate::{FitGeometry, UnifiedFitResult};
use crate::faer_ndarray::FaerArrayView;
use crate::linalg::utils::StableSolver;
use crate::pirls;
use crate::types::LinkFunction;
use faer::Mat as FaerMat;
use faer::linalg::matmul::matmul;
use faer::prelude::ReborrowMut;
use faer::{Accum, Par};
use ndarray::{Array1, Array2, ArrayView1, ShapeBuilder, s};
/// Per-observation approximate leave-one-out (ALO) diagnostics for a fitted model.
#[derive(Debug, Clone)]
pub struct AloDiagnostics {
    /// Leave-one-out corrected linear predictor per observation.
    pub eta_tilde: Array1<f64>,
    /// Bayesian standard error of eta: sqrt(phi * x' H^{-1} x), clamped at zero.
    pub se_bayes: Array1<f64>,
    /// Sandwich-corrected standard error of eta (equals `se_bayes` when no
    /// penalty root is supplied).
    pub se_sandwich: Array1<f64>,
    /// Fitted linear predictor at convergence (copied from the input `eta`).
    pub pred_identity: Array1<f64>,
    /// Leverage a_ii = max(w_h[i], 0) * x_i' H^{-1} x_i.
    pub leverage: Array1<f64>,
    /// Hessian-side (Fisher) working weights used to form the leverage.
    pub fisherweights: Array1<f64>,
}
/// Leave-one-out update of the linear predictor, carried out on the
/// offset-centered scale.
///
/// The offset is removed before the rank-one ALO correction and restored
/// afterwards; separate Hessian-side (`hessian_weight`, used in the
/// `1 - a_ii` denominator) and score-side (`score_weight`, used in the
/// residual score) weights are supported.
#[inline]
fn alo_eta_updatewith_offset(
    eta_hat: f64,
    z: f64,
    offset: f64,
    x_hinv_x: f64,
    hessian_weight: f64,
    score_weight: f64,
) -> f64 {
    // Center both the fitted predictor and the working response on the offset.
    let eta_c = eta_hat - offset;
    let z_c = z - offset;
    // Score contribution of this observation, scaled by the score-side weight.
    let residual_score = score_weight * (eta_c - z_c);
    // 1 - a_ii, where a_ii uses the Hessian-side weight.
    let one_minus_aii = 1.0 - hessian_weight * x_hinv_x;
    offset + eta_c + x_hinv_x * residual_score / one_minus_aii
}
/// Bayesian predictive variance of eta: dispersion times x' H^{-1} x.
#[inline]
fn bayesvar_eta(phi: f64, x_hinv_x: f64) -> f64 {
    let quad_form = x_hinv_x;
    phi * quad_form
}
/// Sandwich-corrected variance of eta: phi * (x' H^{-1} x - ||E s||^2 - ridge * ||s||^2),
/// subtracting the penalty and ridge contributions from the Bayesian quadratic form.
#[inline]
fn sandwichvar_eta(phi: f64, x_hinv_x: f64, es_norm2: f64, ridge: f64, s_norm2: f64) -> f64 {
    // Keep the original left-to-right evaluation order for bit-identical results.
    let ridge_term = ridge * s_norm2;
    phi * (x_hinv_x - es_norm2 - ridge_term)
}
/// Relative tolerance for deciding that a variance is "materially" negative:
/// 1e-12 times the magnitude of `scale`, floored at 1e-12.
#[inline]
fn variance_negative_tolerance(scale: f64) -> f64 {
    let reference = scale.abs().max(1.0);
    1e-12 * reference
}
/// Leverage above this counts toward the "high values" summary warning.
const LEVERAGE_HIGH_THRESHOLD: f64 = 0.99;
/// Leverage above this triggers an additional per-observation warning.
const LEVERAGE_VERY_HIGH_THRESHOLD: f64 = 0.999;
/// Exceedance thresholds reported as high-leverage rates in the log summary.
const LEVERAGE_RATE_THRESHOLDS: [f64; 3] = [0.90, 0.95, 0.99];
/// Quantiles (median, p95, p99) reported in the leverage log summary.
const LEVERAGE_PERCENTILES: [f64; 3] = [0.50, 0.95, 0.99];
/// Global scratch-memory budget (256 MiB) shared by the multi-block ALO
/// parallel planning helpers.
const MULTIBLOCK_ALO_MEMORY_BUDGET_BYTES: usize = 256 * 1024 * 1024;
/// Index of the `quantile`-th order statistic in a sample of `sample_size`,
/// using nearest-rank rounding and clamping to the last valid index.
#[inline]
fn percentile_index(sample_size: usize, quantile: f64) -> usize {
    match sample_size {
        // Empty or singleton samples always map to index 0.
        0 | 1 => 0,
        n => {
            let max_index = n - 1;
            let rounded = (quantile * max_index as f64).round() as usize;
            rounded.min(max_index)
        }
    }
}
/// `quantile`-th order statistic of an ascending-sorted slice; 0.0 when empty.
#[inline]
fn percentile_from_sorted(sorted: &[f64], quantile: f64) -> f64 {
    sorted
        .get(percentile_index(sorted.len(), quantile))
        .copied()
        .unwrap_or(0.0)
}
/// Exclusive prefix sum of block widths: `offsets[b]` is the first column of
/// block `b` inside the concatenated coefficient space.
#[inline]
fn multiblock_col_offsets(block_designs: &[Array2<f64>]) -> Vec<usize> {
    block_designs
        .iter()
        .scan(0usize, |running, design| {
            let start = *running;
            *running += design.ncols();
            Some(start)
        })
        .collect()
}
/// Observation-chunk size for the leverage-only multi-block ALO path.
///
/// Sizes chunks so each worker's scratch (per observation: `p_tot * (n_blocks + 1)`
/// f64 values — the Q stripes plus the transposed design chunk) fits within its
/// share of `MULTIBLOCK_ALO_MEMORY_BUDGET_BYTES`. Saturating arithmetic keeps
/// the result well-defined for pathological shapes.
#[inline]
fn multiblock_alo_parallel_leverage_chunk_size(
    p_tot: usize,
    n_blocks: usize,
    n_obs: usize,
    max_workers: usize,
) -> usize {
    // Degenerate problem: fall back to single-observation chunks.
    if p_tot == 0 || n_blocks == 0 || n_obs == 0 {
        return 1;
    }
    // Per-observation footprint in f64 elements, then bytes (both floored at 1).
    let elems_per_obs = p_tot.saturating_mul(n_blocks.saturating_add(1)).max(1);
    let bytes_per_obs = elems_per_obs
        .saturating_mul(std::mem::size_of::<f64>())
        .max(1);
    // Split the global budget across workers and convert to observations.
    let budget_per_worker = (MULTIBLOCK_ALO_MEMORY_BUDGET_BYTES / max_workers.max(1)).max(1);
    let obs_in_budget = (budget_per_worker / bytes_per_obs).max(1);
    obs_in_budget.min(n_obs)
}
/// Build ALO diagnostics from a converged PIRLS fit.
///
/// Densifies the transformed design and the stabilized penalized Hessian,
/// determines the dispersion `phi` (weighted RSS over residual degrees of
/// freedom for the identity link; fixed at 1.0 for all other links), runs the
/// core ALO computation, logs leverage diagnostics, and rejects results that
/// contain NaNs.
fn compute_alo_diagnostics_from_pirls_impl(
    base: &pirls::PirlsResult,
    y: ArrayView1<f64>,
    link: LinkFunction,
) -> Result<AloDiagnostics, EstimationError> {
    // ALO needs the exact dense transformed design matrix.
    let x_dense_arc = base
        .x_transformed
        .try_to_dense_arc("ALO diagnostics require dense transformed design")
        .map_err(EstimationError::InvalidInput)?;
    let x_dense = x_dense_arc.as_ref();
    let n = x_dense.nrows();
    // Dispersion: 1.0 for non-Gaussian links; for the identity link, estimate
    // from the weighted residual sum of squares over (n - edf), floored at 1.
    let phi = match link {
        LinkFunction::Log => 1.0,
        LinkFunction::Logit
        | LinkFunction::Probit
        | LinkFunction::CLogLog
        | LinkFunction::Sas
        | LinkFunction::BetaLogistic => 1.0,
        LinkFunction::Identity => {
            use rayon::iter::{IntoParallelIterator, ParallelIterator};
            let rss: f64 = (0..n)
                .into_par_iter()
                .map(|i| {
                    let r = y[i] - base.finalmu[i];
                    base.finalweights[i] * r * r
                })
                .sum();
            let dof = (n as f64) - base.edf;
            let denom = dof.max(1.0);
            rss / denom
        }
    };
    // Penalty square root and ridge used for the sandwich variance correction.
    let e = &base.reparam_result.e_transformed;
    let ridge = base.ridge_passport.laplacehessianridge().max(0.0);
    let h_dense_for_alo = base.dense_stabilizedhessian_transformed(
        "ALO diagnostics require exact dense stabilized penalized Hessian",
    )?;
    let input = AloInput {
        design: x_dense,
        penalized_hessian: &h_dense_for_alo,
        hessian_weights: &base.finalweights,
        score_weights: &base.solveweights,
        working_response: &base.solveworking_response,
        eta: &base.final_eta,
        offset: &base.final_offset,
        link,
        phi,
        // An empty penalty root (zero rows) is treated as "no penalty".
        penalty_root: if e.nrows() > 0 { Some(e) } else { None },
        ridge,
    };
    let result = compute_alo_from_input(&input)?;
    log_leverage_diagnostics(&result.leverage, phi);
    // Any NaN in the outputs indicates an ill-conditioned model; report per-field
    // counts before failing so the log pinpoints which quantity broke.
    let has_nan_pred = result.eta_tilde.iter().any(|&x| x.is_nan());
    let has_nan_se_bayes = result.se_bayes.iter().any(|&x| x.is_nan());
    let has_nan_se_sandwich = result.se_sandwich.iter().any(|&x| x.is_nan());
    let has_nan_leverage = result.leverage.iter().any(|&x| x.is_nan());
    if has_nan_pred || has_nan_se_bayes || has_nan_se_sandwich || has_nan_leverage {
        log::error!("[GAM ALO] NaN values found in ALO diagnostics:");
        log::error!(
            "[GAM ALO] eta_tilde: {} NaN values",
            result.eta_tilde.iter().filter(|&&x| x.is_nan()).count()
        );
        log::error!(
            "[GAM ALO] se_bayes: {} NaN values",
            result.se_bayes.iter().filter(|&&x| x.is_nan()).count()
        );
        log::error!(
            "[GAM ALO] se_sandwich: {} NaN values",
            result.se_sandwich.iter().filter(|&&x| x.is_nan()).count()
        );
        log::error!(
            "[GAM ALO] leverage: {} NaN values",
            result.leverage.iter().filter(|&&x| x.is_nan()).count()
        );
        return Err(EstimationError::ModelIsIllConditioned {
            condition_number: f64::INFINITY,
        });
    }
    Ok(result)
}
/// Emit log diagnostics for the leverage vector.
///
/// Warns per observation for invalid (non-finite or outside [0, 1]) and very
/// high leverage, then logs aggregate counts, order statistics, and
/// high-leverage exceedance rates. Quantiles and the mean are computed over
/// finite values only; the exceedance rates are relative to all n observations.
fn log_leverage_diagnostics(leverage: &Array1<f64>, phi: f64) {
    let n = leverage.len();
    if n == 0 {
        return;
    }
    let mut invalid_count = 0usize;
    let mut high_leverage_count = 0usize;
    let mut threshold_counts = [0usize; LEVERAGE_RATE_THRESHOLDS.len()];
    let mut finite_leverage = Vec::with_capacity(n);
    for (obs, &ai) in leverage.iter().enumerate() {
        if ai.is_finite() {
            finite_leverage.push(ai);
        }
        if !(0.0..=1.0).contains(&ai) || !ai.is_finite() {
            invalid_count += 1;
            log::warn!("[GAM ALO] invalid leverage at i={}, a_ii={:.6e}", obs, ai);
        } else if ai > LEVERAGE_HIGH_THRESHOLD {
            high_leverage_count += 1;
            if ai > LEVERAGE_VERY_HIGH_THRESHOLD {
                log::warn!("[GAM ALO] very high leverage at i={}, a_ii={:.6e}", obs, ai);
            }
        }
        // Exceedance counts intentionally include every value (also invalid
        // ones above 1), so the rates reflect the raw vector.
        for (idx, threshold) in LEVERAGE_RATE_THRESHOLDS.iter().enumerate() {
            if ai > *threshold {
                threshold_counts[idx] += 1;
            }
        }
    }
    if invalid_count > 0 || high_leverage_count > 0 {
        log::warn!(
            "[GAM ALO] leverage diagnostics: {} invalid values, {} high values (>0.99)",
            invalid_count,
            high_leverage_count
        );
    }
    // Order statistics over the finite subset (total_cmp gives a total order).
    finite_leverage.sort_by(f64::total_cmp);
    let finite_n = finite_leverage.len();
    let a_mean = if finite_n > 0 {
        finite_leverage.iter().copied().sum::<f64>() / finite_n as f64
    } else {
        0.0
    };
    let a_median = percentile_from_sorted(&finite_leverage, LEVERAGE_PERCENTILES[0]);
    let a_p95 = percentile_from_sorted(&finite_leverage, LEVERAGE_PERCENTILES[1]);
    let a_p99 = percentile_from_sorted(&finite_leverage, LEVERAGE_PERCENTILES[2]);
    let a_max = finite_leverage.last().copied().unwrap_or(0.0);
    log::warn!(
        "[GAM ALO] leverage: n={}, mean={:.3e}, median={:.3e}, p95={:.3e}, p99={:.3e}, max={:.3e}",
        n,
        a_mean,
        a_median,
        a_p95,
        a_p99,
        a_max
    );
    log::warn!(
        "[GAM ALO] high-leverage: a>0.90: {:.2}%, a>0.95: {:.2}%, a>0.99: {:.2}%, dispersion phi={:.3e}",
        100.0 * (threshold_counts[0] as f64) / n as f64,
        100.0 * (threshold_counts[1] as f64) / n as f64,
        100.0 * (threshold_counts[2] as f64) / n as f64,
        phi
    );
}
/// Explicit inputs for the core ALO computation (`compute_alo_from_input`).
pub struct AloInput<'a> {
    /// Dense n x p design matrix.
    pub design: &'a Array2<f64>,
    /// Dense p x p penalized Hessian; must be finite and symmetric.
    pub penalized_hessian: &'a Array2<f64>,
    /// Hessian-side working weights (length n); drive the leverage a_ii.
    pub hessian_weights: &'a Array1<f64>,
    /// Score-side working weights (length n); drive the eta update score.
    pub score_weights: &'a Array1<f64>,
    /// Working responses z (length n).
    pub working_response: &'a Array1<f64>,
    /// Fitted linear predictor (length n).
    pub eta: &'a Array1<f64>,
    /// Offset (length n); removed before and restored after the ALO update.
    pub offset: &'a Array1<f64>,
    /// Link function of the fitted model.
    pub link: LinkFunction,
    /// Dispersion parameter; must be positive and finite.
    pub phi: f64,
    /// Optional penalty square root E (rows = penalty rank, p columns);
    /// `None` means no sandwich correction is applied.
    pub penalty_root: Option<&'a Array2<f64>>,
    /// Non-negative ridge used in the sandwich variance term.
    pub ridge: f64,
}
impl<'a> AloInput<'a> {
    /// Build an ALO input from converged working-set geometry.
    ///
    /// The same working weights serve as both Hessian-side and score-side
    /// weights, and no penalty root or ridge is supplied, so the sandwich
    /// variance falls back to the Bayesian one downstream.
    pub fn from_geometry(
        geom: &'a FitGeometry,
        design: &'a Array2<f64>,
        eta: &'a Array1<f64>,
        offset: &'a Array1<f64>,
        link: LinkFunction,
        phi: f64,
    ) -> Self {
        Self {
            design,
            penalized_hessian: &geom.penalized_hessian,
            hessian_weights: &geom.working_weights,
            score_weights: &geom.working_weights,
            working_response: &geom.working_response,
            eta,
            offset,
            link,
            phi,
            penalty_root: None,
            ridge: 0.0,
        }
    }
}
/// Core ALO computation.
///
/// For every observation i this derives:
/// - s_i = H^{-1} x_i via one stable factorization of the penalized Hessian,
/// - the quadratic form x_i' H^{-1} x_i and the leverage a_ii,
/// - Bayesian and sandwich variances of eta_i (with round-off tolerance checks),
/// - the leave-one-out corrected predictor eta_tilde_i.
///
/// Right-hand sides are processed in column chunks of 8192 to bound memory.
pub fn compute_alo_from_input(input: &AloInput) -> Result<AloDiagnostics, EstimationError> {
    let x_dense = input.design;
    let n = x_dense.nrows();
    let p = x_dense.ncols();
    let w_h = input.hessian_weights;
    let w_s = input.score_weights;
    validate_alo_solve_setup(input, n, p)?;
    // Factor the penalized Hessian once; all per-observation solves reuse it.
    let factor = StableSolver::new("alo penalized hessian")
        .factorize(input.penalized_hessian)
        .map_err(|_| EstimationError::ModelIsIllConditioned {
            condition_number: f64::INFINITY,
        })?;
    let xt = x_dense.t();
    let phi = input.phi;
    let ridge = input.ridge;
    // Zero rank means no penalty root: the sandwich path is skipped entirely.
    let e_rank = input.penalty_root.map(|e| e.nrows()).unwrap_or(0);
    let mut aii = Array1::<f64>::zeros(n);
    let mut x_hinv_x_diag = Array1::<f64>::zeros(n);
    let mut se_bayes = Array1::<f64>::zeros(n);
    let mut se_sandwich = Array1::<f64>::zeros(n);
    // Fixed-width chunk buffers, allocated once and reused; the RHS buffer is
    // column-major (`.f()`) so each observation's column is contiguous.
    let block_cols = 8192usize;
    let mut rhs_chunk_buf = Array2::<f64>::zeros((p, block_cols).f());
    let mut es_chunk_storage = if e_rank > 0 {
        FaerMat::<f64>::zeros(e_rank, block_cols)
    } else {
        FaerMat::<f64>::zeros(0, 0)
    };
    for chunk_start in (0..n).step_by(block_cols) {
        let chunk_end = (chunk_start + block_cols).min(n);
        let width = chunk_end - chunk_start;
        // RHS chunk = X^T columns for observations [chunk_start, chunk_end).
        rhs_chunk_buf
            .slice_mut(s![.., ..width])
            .assign(&xt.slice(s![.., chunk_start..chunk_end]));
        let rhs_chunkview = rhs_chunk_buf.slice(s![.., ..width]);
        let rhs_chunk = FaerArrayView::new(&rhs_chunkview);
        // s_chunk = H^{-1} X^T for this chunk.
        let s_chunk = factor.solve(rhs_chunk.as_ref());
        if e_rank > 0 {
            if let Some(e) = input.penalty_root {
                // E * H^{-1} X^T, needed for the sandwich correction term.
                let eview = FaerArrayView::new(e);
                let mut es_target = es_chunk_storage.as_mut().subcols_mut(0, width);
                matmul(
                    es_target.rb_mut(),
                    Accum::Replace,
                    eview.as_ref(),
                    s_chunk.as_ref(),
                    1.0,
                    Par::Seq,
                );
            }
        }
        let rhs_view = rhs_chunk_buf.slice(s![.., ..width]);
        for local_col in 0..width {
            let obs = chunk_start + local_col;
            let rhs_col = rhs_view.column(local_col);
            let rhs_slice = rhs_col.as_slice().expect("column-major col contiguous");
            let s_slice = s_chunk.col_as_slice(local_col);
            // x_i' H^{-1} x_i and ||H^{-1} x_i||^2 in one fused pass.
            let mut x_hinv_x = 0.0f64;
            let mut s_norm2 = 0.0f64;
            for k in 0..p {
                let sval = s_slice[k];
                let xval = rhs_slice[k];
                x_hinv_x = sval.mul_add(xval, x_hinv_x);
                s_norm2 = sval.mul_add(sval, s_norm2);
            }
            // Leverage uses the Hessian-side weight, clamped at zero.
            let ai = w_h[obs].max(0.0) * x_hinv_x;
            let mut es_norm2 = 0.0f64;
            if e_rank > 0 {
                let es_slice = es_chunk_storage.col_as_slice(local_col);
                for r in 0..e_rank {
                    let v = es_slice[r];
                    es_norm2 = v.mul_add(v, es_norm2);
                }
            }
            aii[obs] = ai;
            x_hinv_x_diag[obs] = x_hinv_x;
            let var_bayes = bayesvar_eta(phi, x_hinv_x);
            // Without a penalty root the sandwich variance equals the Bayesian one.
            let var_sandwich = if e_rank > 0 {
                sandwichvar_eta(phi, x_hinv_x, es_norm2, ridge, s_norm2)
            } else {
                var_bayes
            };
            if !var_bayes.is_finite() || !var_sandwich.is_finite() {
                return Err(EstimationError::InvalidInput(format!(
                    "ALO variance is not finite at row {obs}: bayes={var_bayes:.6e}, sandwich={var_sandwich:.6e}"
                )));
            }
            // Allow tiny negative values from round-off; reject material negativity.
            let bayes_tol = variance_negative_tolerance(phi * x_hinv_x.abs());
            if var_bayes < -bayes_tol {
                return Err(EstimationError::InvalidInput(format!(
                    "ALO Bayesian variance is materially negative at row {obs}: var={var_bayes:.6e}, tol={bayes_tol:.6e}"
                )));
            }
            if e_rank > 0 {
                // Tolerance scaled by the magnitudes of all sandwich terms.
                let sandwich_scale =
                    phi * (x_hinv_x.abs() + es_norm2.abs() + (ridge * s_norm2).abs());
                let sandwich_tol = variance_negative_tolerance(sandwich_scale);
                if var_sandwich < -sandwich_tol {
                    return Err(EstimationError::InvalidInput(format!(
                        "ALO sandwich variance is materially negative at row {obs}: var={var_sandwich:.6e}, tol={sandwich_tol:.6e}"
                    )));
                }
            }
            se_bayes[obs] = var_bayes.max(0.0).sqrt();
            se_sandwich[obs] = var_sandwich.max(0.0).sqrt();
        }
    }
    let eta_hat = input.eta;
    let z = input.working_response;
    let offset = input.offset;
    use rayon::prelude::*;
    // Leave-one-out eta update, parallel over observations; fails fast on a
    // non-positive denominator (leverage >= 1) or a non-finite result.
    let eta_tilde_vec: Vec<f64> = (0..n)
        .into_par_iter()
        .map(|i| {
            let denom_raw = 1.0 - aii[i];
            if denom_raw <= 0.0 || !denom_raw.is_finite() {
                return Err(EstimationError::InvalidInput(format!(
                    "ALO denominator is non-positive at row {i}: a_ii={:.6e}, 1-a_ii={:.6e}",
                    aii[i], denom_raw
                )));
            }
            let v = alo_eta_updatewith_offset(
                eta_hat[i],
                z[i],
                offset[i],
                x_hinv_x_diag[i],
                w_h[i],
                w_s[i],
            );
            if !v.is_finite() {
                return Err(EstimationError::InvalidInput(format!(
                    "ALO eta_tilde is not finite at row {i}: eta_tilde={v}"
                )));
            }
            Ok(v)
        })
        .collect::<Result<_, _>>()?;
    let eta_tilde = Array1::from(eta_tilde_vec);
    Ok(AloDiagnostics {
        eta_tilde,
        se_bayes,
        se_sandwich,
        pred_identity: eta_hat.clone(),
        leverage: aii,
        fisherweights: w_h.clone(),
    })
}
/// Validate shapes, finiteness, and symmetry of all ALO inputs before solving.
///
/// `n`/`p` are the design's row/column counts. The Hessian must be `p x p`,
/// finite, and symmetric to a relative tolerance of 1e-8; every per-observation
/// vector must have length `n`; `phi` must be positive finite; `ridge` must be
/// non-negative finite; and an optional penalty root must be finite with `p`
/// columns.
fn validate_alo_solve_setup(input: &AloInput, n: usize, p: usize) -> Result<(), EstimationError> {
    let h = input.penalized_hessian;
    if h.nrows() != p || h.ncols() != p {
        return Err(EstimationError::InvalidInput(format!(
            "ALO diagnostics require a dense exact penalized Hessian with shape {p}x{p}; got {}x{}",
            h.nrows(),
            h.ncols()
        )));
    }
    if h.iter().any(|v| !v.is_finite()) {
        return Err(EstimationError::InvalidInput(
            "ALO diagnostics require a finite dense exact penalized Hessian".to_string(),
        ));
    }
    // Relative symmetry check over the strict lower triangle.
    let sym_tol = 1e-8;
    for i in 0..p {
        for j in 0..i {
            let a = h[[i, j]];
            let b = h[[j, i]];
            let scale = a.abs().max(b.abs()).max(1.0);
            if (a - b).abs() > sym_tol * scale {
                return Err(EstimationError::InvalidInput(format!(
                    "ALO diagnostics require a symmetric dense exact penalized Hessian; entries ({i},{j}) and ({j},{i}) differ by {:.3e}",
                    (a - b).abs()
                )));
            }
        }
    }
    // Every per-observation vector must match the design's row count.
    let vector_lengths = [
        ("hessian_weights", input.hessian_weights.len()),
        ("score_weights", input.score_weights.len()),
        ("working_response", input.working_response.len()),
        ("eta", input.eta.len()),
        ("offset", input.offset.len()),
    ];
    for (name, len) in vector_lengths {
        if len != n {
            return Err(EstimationError::InvalidInput(format!(
                "ALO diagnostics require {name} length {n}; got {len}"
            )));
        }
    }
    if input.hessian_weights.iter().any(|v| !v.is_finite()) {
        return Err(EstimationError::InvalidInput(
            "ALO diagnostics require finite Hessian-side weights".to_string(),
        ));
    }
    if input.score_weights.iter().any(|v| !v.is_finite()) {
        return Err(EstimationError::InvalidInput(
            "ALO diagnostics require finite score-side weights".to_string(),
        ));
    }
    if input.working_response.iter().any(|v| !v.is_finite()) {
        return Err(EstimationError::InvalidInput(
            "ALO diagnostics require finite working responses".to_string(),
        ));
    }
    if input.eta.iter().any(|v| !v.is_finite()) || input.offset.iter().any(|v| !v.is_finite()) {
        return Err(EstimationError::InvalidInput(
            "ALO diagnostics require finite linear predictors and offsets".to_string(),
        ));
    }
    if !input.phi.is_finite() || input.phi <= 0.0 {
        return Err(EstimationError::InvalidInput(format!(
            "ALO diagnostics require positive finite dispersion phi; got {}",
            input.phi
        )));
    }
    if !input.ridge.is_finite() || input.ridge < 0.0 {
        return Err(EstimationError::InvalidInput(format!(
            "ALO diagnostics require a finite non-negative Hessian ridge; got {}",
            input.ridge
        )));
    }
    if let Some(e) = input.penalty_root {
        if e.ncols() != p {
            return Err(EstimationError::InvalidInput(format!(
                "ALO diagnostics require penalty root to have {p} columns; got {}",
                e.ncols()
            )));
        }
        if e.iter().any(|v| !v.is_finite()) {
            return Err(EstimationError::InvalidInput(
                "ALO diagnostics require finite penalty-root entries".to_string(),
            ));
        }
    }
    Ok(())
}
/// Compute ALO diagnostics from a unified fit that carries PIRLS artifacts.
///
/// Returns `InvalidInput` when the fit was not produced by a PIRLS-backed
/// solver and therefore exposes no PIRLS geometry.
pub fn compute_alo_diagnostics_from_fit(
    fit: &UnifiedFitResult,
    y: ArrayView1<f64>,
    link: LinkFunction,
) -> Result<AloDiagnostics, EstimationError> {
    match fit.artifacts.pirls.as_ref() {
        Some(pirls) => compute_alo_diagnostics_from_pirls_impl(pirls, y, link),
        None => Err(EstimationError::InvalidInput(
            "ALO diagnostics require a PIRLS-backed fit; this fit does not expose PIRLS geometry"
                .to_string(),
        )),
    }
}
/// Compute ALO diagnostics from a unified fit's stored working-set geometry.
///
/// Errors when the fit carries no geometry snapshot at convergence.
pub fn compute_alo_diagnostics_from_unified(
    unified: &UnifiedFitResult,
    design: &Array2<f64>,
    eta: &Array1<f64>,
    offset: &Array1<f64>,
    link: LinkFunction,
    phi: f64,
) -> Result<AloDiagnostics, EstimationError> {
    let geom = match unified.geometry.as_ref() {
        Some(geometry) => geometry,
        None => {
            return Err(EstimationError::InvalidInput(
                "UnifiedFitResult does not contain working-set geometry; \
                 ALO diagnostics require geometry at convergence"
                    .to_string(),
            ));
        }
    };
    compute_alo_from_input(&AloInput::from_geometry(geom, design, eta, offset, link, phi))
}
/// Public entry point: compute ALO diagnostics directly from a PIRLS result.
pub fn compute_alo_diagnostics_from_pirls(
    base: &pirls::PirlsResult,
    y: ArrayView1<f64>,
    link: LinkFunction,
) -> Result<AloDiagnostics, EstimationError> {
    // Thin, API-stable wrapper over the shared implementation.
    compute_alo_diagnostics_from_pirls_impl(base, y, link)
}
/// Per-observation ALO diagnostics for a multi-block (B-dimensional) fit.
#[derive(Debug, Clone)]
pub struct MultiBlockAloDiagnostics {
    /// Leave-one-out corrected length-B eta vector per observation.
    pub eta_tilde: Vec<Array1<f64>>,
    /// Trace of the per-observation B x B product A_i W_i.
    pub leverage: Array1<f64>,
    /// Diagonal of the per-observation ALO variance form (length B each).
    pub alo_variance: Vec<Array1<f64>>,
    /// Cook-style influence delta_eta' W_i delta_eta per observation.
    pub cook_distance: Array1<f64>,
}
/// Inputs for the multi-block ALO computation.
pub struct MultiBlockAloInput<'a> {
    /// Number of observations n.
    pub n_obs: usize,
    /// Number of coefficient blocks B.
    pub n_blocks: usize,
    /// One n x p_b design matrix per block; widths must sum to p_tot.
    pub block_designs: &'a [Array2<f64>],
    /// Inverse of the full penalized Hessian (p_tot x p_tot).
    pub penalized_hessian_inv: &'a Array2<f64>,
    /// Per-observation B x B weight matrices.
    pub block_weights: Vec<Array2<f64>>,
    /// Per-observation length-B score vectors.
    pub scores: Vec<Array1<f64>>,
    /// Per-observation length-B fitted linear predictors.
    pub eta_hat: Vec<Array1<f64>>,
}
/// Multi-block ALO diagnostics: leave-one-out corrections for a fit whose
/// coefficients are partitioned into B coupled blocks.
///
/// Validates the block layout, splits the n observations into memory-bounded
/// chunks, processes waves of chunks in parallel (sized by
/// `multiblock_alo_parallel_plan`), and reassembles the per-chunk results in
/// observation order.
pub fn compute_multiblock_alo(
    input: &MultiBlockAloInput,
) -> Result<MultiBlockAloDiagnostics, EstimationError> {
    use rayon::prelude::*;
    let n = input.n_obs;
    let b = input.n_blocks;
    let p_tot = input.penalized_hessian_inv.nrows();
    if input.block_designs.len() != b {
        return Err(EstimationError::InvalidInput(format!(
            "MultiBlockAloInput: expected {} block designs, got {}",
            b,
            input.block_designs.len()
        )));
    }
    // Block design columns must tile the full coefficient space exactly.
    let col_sum: usize = input.block_designs.iter().map(|d| d.ncols()).sum();
    if col_sum != p_tot {
        return Err(EstimationError::InvalidInput(format!(
            "MultiBlockAloInput: total design columns ({}) != penalized_hessian_inv size ({})",
            col_sum, p_tot
        )));
    }
    let col_offsets = multiblock_col_offsets(input.block_designs);
    let (chunk_size, max_concurrent_chunks) = multiblock_alo_parallel_plan(p_tot, b, n);
    let chunk_starts: Vec<usize> = (0..n).step_by(chunk_size).collect();
    let mut chunk_results: Vec<Result<MultiBlockAloChunkDiagnostics, EstimationError>> =
        Vec::with_capacity(chunk_starts.len());
    // Run waves of at most `max_concurrent_chunks` chunks to cap peak memory;
    // each rayon worker reuses a single scratch allocation across its chunks.
    for chunk_wave in chunk_starts.chunks(max_concurrent_chunks) {
        let mut wave_results: Vec<Result<MultiBlockAloChunkDiagnostics, EstimationError>> =
            chunk_wave
                .par_iter()
                .map_init(
                    || MultiBlockAloScratch::new(b),
                    |scratch, &chunk_start| {
                        let chunk_end = (chunk_start + chunk_size).min(n);
                        compute_multiblock_alo_chunk(
                            input,
                            &col_offsets,
                            chunk_start,
                            chunk_end,
                            scratch,
                        )
                    },
                )
                .collect();
        chunk_results.append(&mut wave_results);
    }
    let mut eta_tilde = Vec::with_capacity(n);
    let mut leverage = Array1::<f64>::zeros(n);
    let mut alo_variance = Vec::with_capacity(n);
    let mut cook_distance = Array1::<f64>::zeros(n);
    let mut chunks = Vec::with_capacity(chunk_results.len());
    for result in chunk_results {
        chunks.push(result?);
    }
    // Restore deterministic observation order before stitching outputs together.
    chunks.sort_unstable_by_key(|chunk| chunk.chunk_start);
    for chunk in chunks {
        let chunk_start = chunk.chunk_start;
        eta_tilde.extend(chunk.eta_tilde);
        alo_variance.extend(chunk.alo_variance);
        for (local_i, lev) in chunk.leverage.into_iter().enumerate() {
            leverage[chunk_start + local_i] = lev;
        }
        for (local_i, cook) in chunk.cook_distance.into_iter().enumerate() {
            cook_distance[chunk_start + local_i] = cook;
        }
    }
    Ok(MultiBlockAloDiagnostics {
        eta_tilde,
        leverage,
        alo_variance,
        cook_distance,
    })
}
/// Plan the chunked parallel execution of `compute_multiblock_alo`.
///
/// Returns `(chunk_size_in_observations, max_concurrent_chunks)`: chunks are
/// sized so that the chunks in flight together stay within
/// `MULTIBLOCK_ALO_MEMORY_BUDGET_BYTES`, with concurrency additionally capped
/// by the rayon worker count.
///
/// Fix: the per-observation byte count now uses saturating multiplication,
/// matching `multiblock_alo_parallel_leverage_chunk_size`; the previous
/// unchecked `p_tot * n_blocks * size_of::<f64>()` could overflow (panic in
/// debug, wrap in release) for extreme shapes.
#[inline]
fn multiblock_alo_parallel_plan(p_tot: usize, n_blocks: usize, n_obs: usize) -> (usize, usize) {
    // Degenerate inputs: one observation per chunk, one chunk in flight.
    if p_tot == 0 || n_blocks == 0 || n_obs == 0 {
        return (1, 1);
    }
    // Per-observation scratch footprint in bytes, floored at 1 so the
    // divisions below are well-defined.
    let bytes_per_obs = p_tot
        .saturating_mul(n_blocks)
        .saturating_mul(std::mem::size_of::<f64>())
        .max(1);
    let workers = rayon::current_num_threads().max(1);
    // Concurrency is limited by both the memory budget and available workers.
    let max_concurrent_chunks = (MULTIBLOCK_ALO_MEMORY_BUDGET_BYTES / bytes_per_obs)
        .max(1)
        .min(workers);
    // Share the budget across concurrent chunks, then convert the per-chunk
    // budget into an observation count (at least one observation per chunk).
    let per_worker_budget =
        (MULTIBLOCK_ALO_MEMORY_BUDGET_BYTES / max_concurrent_chunks).max(bytes_per_obs);
    let budget_obs = (per_worker_budget / bytes_per_obs).max(1);
    (budget_obs.min(n_obs), max_concurrent_chunks)
}
/// Reusable per-worker buffers for `compute_multiblock_alo_chunk`.
///
/// All B x B matrices are stored row-major in flat `Vec<f64>`s of length B*B.
struct MultiBlockAloScratch {
    // A_i = X_i H^{-1} X_i' (B x B) for the current observation.
    a_i: Vec<f64>,
    // W_i * A_i.
    wa: Vec<f64>,
    // A_i * W_i.
    aw: Vec<f64>,
    // I - W_i A_i; overwritten by its LU factors.
    imwa: Vec<f64>,
    // I - A_i W_i; overwritten by its LU factors.
    imaw: Vec<f64>,
    // Row permutation for the LU factors of `imwa`.
    perm_imwa: Vec<usize>,
    // Row permutation for the LU factors of `imaw`.
    perm_imaw: Vec<usize>,
    // Leave-one-out correction to eta (length B).
    delta_eta: Vec<f64>,
    // Right-hand-side buffer for the LU solves (length B).
    rhs_buf: Vec<f64>,
    // Intermediate W * u vector used by the variance solve (length B).
    w_u: Vec<f64>,
    // Diagonal of the per-observation ALO variance (length B).
    var_diag_buf: Vec<f64>,
    // Flattened row-major copy of the observation's B x B weight matrix.
    w_flat: Vec<f64>,
    // Forward-substitution scratch for the LU solves (length B).
    lu_scratch: Vec<f64>,
}
impl MultiBlockAloScratch {
    /// Allocate all length-B and B*B buffers once for a block count `b`, so a
    /// rayon worker can reuse them across every observation it processes.
    fn new(b: usize) -> Self {
        let bb_sz = b * b;
        Self {
            a_i: vec![0.0f64; bb_sz],
            wa: vec![0.0f64; bb_sz],
            aw: vec![0.0f64; bb_sz],
            imwa: vec![0.0f64; bb_sz],
            imaw: vec![0.0f64; bb_sz],
            perm_imwa: vec![0usize; b],
            perm_imaw: vec![0usize; b],
            delta_eta: vec![0.0f64; b],
            rhs_buf: vec![0.0f64; b],
            w_u: vec![0.0f64; b],
            var_diag_buf: vec![0.0f64; b],
            w_flat: vec![0.0f64; bb_sz],
            lu_scratch: vec![0.0f64; b],
        }
    }
}
/// Partial results for one observation chunk, reassembled in order by
/// `compute_multiblock_alo`.
struct MultiBlockAloChunkDiagnostics {
    // First observation index covered by this chunk.
    chunk_start: usize,
    // Corrected length-B eta vectors for the chunk's observations.
    eta_tilde: Vec<Array1<f64>>,
    // tr(A_i W_i) per observation.
    leverage: Vec<f64>,
    // ALO variance diagonals (length B each) per observation.
    alo_variance: Vec<Array1<f64>>,
    // Cook-style influence per observation.
    cook_distance: Vec<f64>,
}
/// Process one chunk of observations for the multi-block ALO computation.
///
/// Precomputes Q_blk = H^{-1}[:, block blk] * X_blk' for the chunk, then for
/// each observation i assembles A_i = X_i H^{-1} X_i' (B x B), and derives:
/// - leverage tr(A_i W_i),
/// - the leave-one-out correction delta_eta = A_i (I - W_i A_i)^{-1} s_i,
/// - a Cook-style influence delta_eta' W_i delta_eta,
/// - the diagonal of A_i (I - W_i A_i)^{-1} W_i (I - A_i W_i)^{-1} A_i'
///   as the per-coordinate ALO variance (clamped at zero).
///
/// Singular (I - WA) / (I - AW) systems are rebuilt and retried once with a
/// 1e-6 diagonal jitter.
fn compute_multiblock_alo_chunk(
    input: &MultiBlockAloInput,
    col_offsets: &[usize],
    chunk_start: usize,
    chunk_end: usize,
    scratch: &mut MultiBlockAloScratch,
) -> Result<MultiBlockAloChunkDiagnostics, EstimationError> {
    let b = input.n_blocks;
    let chunk_len = chunk_end - chunk_start;
    // Q_blk = H^{-1}[:, cols of blk] * X_blk' restricted to this chunk
    // (p_tot x chunk_len per block).
    let mut q_blocks = Vec::with_capacity(b);
    for blk in 0..b {
        let x_chunk_t = input.block_designs[blk]
            .slice(s![chunk_start..chunk_end, ..])
            .t()
            .to_owned();
        let off_b = col_offsets[blk];
        let h_slice = input
            .penalized_hessian_inv
            .slice(s![.., off_b..off_b + x_chunk_t.nrows()])
            .to_owned();
        q_blocks.push(h_slice.dot(&x_chunk_t));
    }
    let mut eta_tilde = Vec::with_capacity(chunk_len);
    let mut leverage = vec![0.0f64; chunk_len];
    let mut alo_variance = Vec::with_capacity(chunk_len);
    let mut cook_distance = vec![0.0f64; chunk_len];
    for local_i in 0..chunk_len {
        let i = chunk_start + local_i;
        // Flatten the observation's B x B weight matrix (row-major).
        let w_i = &input.block_weights[i];
        for r in 0..b {
            for c in 0..b {
                scratch.w_flat[r * b + c] = w_i[(r, c)];
            }
        }
        // A_i[a, bb] = x_a(i)' [H^{-1}]_{a,bb} x_bb(i), read off the Q columns.
        for a in 0..b {
            let x_a = &input.block_designs[a];
            let p_a = x_a.ncols();
            let off_a = col_offsets[a];
            let xa_row = x_a.row(i);
            for bb in 0..b {
                let q_bb = &q_blocks[bb];
                let mut dot = 0.0f64;
                for k in 0..p_a {
                    dot += xa_row[k] * q_bb[(off_a + k, local_i)];
                }
                scratch.a_i[a * b + bb] = dot;
            }
        }
        mat_mul_flat(&scratch.w_flat, &scratch.a_i, &mut scratch.wa, b);
        mat_mul_flat(&scratch.a_i, &scratch.w_flat, &mut scratch.aw, b);
        // Leverage is the trace of A_i W_i.
        let mut tr = 0.0f64;
        for d in 0..b {
            tr += scratch.aw[d * b + d];
        }
        leverage[local_i] = tr;
        for r in 0..b {
            for c in 0..b {
                let idx = r * b + c;
                let id = if r == c { 1.0 } else { 0.0 };
                scratch.imwa[idx] = id - scratch.wa[idx];
                scratch.imaw[idx] = id - scratch.aw[idx];
            }
        }
        // Factor I - W A; on singularity rebuild the matrix (the failed
        // factorization destroyed the buffer) and retry with diagonal jitter.
        if !lu_factor_in_place(&mut scratch.imwa, &mut scratch.perm_imwa, b) {
            for r in 0..b {
                for c in 0..b {
                    let idx = r * b + c;
                    let id = if r == c { 1.0 } else { 0.0 };
                    scratch.imwa[idx] = id - scratch.wa[idx];
                }
            }
            for d in 0..b {
                scratch.imwa[d * b + d] += 1e-6;
            }
            let _ = lu_factor_in_place(&mut scratch.imwa, &mut scratch.perm_imwa, b);
        }
        // Same treatment for I - A W.
        if !lu_factor_in_place(&mut scratch.imaw, &mut scratch.perm_imaw, b) {
            for r in 0..b {
                for c in 0..b {
                    let idx = r * b + c;
                    let id = if r == c { 1.0 } else { 0.0 };
                    scratch.imaw[idx] = id - scratch.aw[idx];
                }
            }
            for d in 0..b {
                scratch.imaw[d * b + d] += 1e-6;
            }
            let _ = lu_factor_in_place(&mut scratch.imaw, &mut scratch.perm_imaw, b);
        }
        // delta_eta = A (I - WA)^{-1} s_i.
        let s_i = &input.scores[i];
        for k in 0..b {
            scratch.rhs_buf[k] = s_i[k];
        }
        lu_solve_in_place(
            &scratch.imwa,
            &scratch.perm_imwa,
            &mut scratch.rhs_buf,
            &mut scratch.lu_scratch,
            b,
        );
        for r in 0..b {
            let mut acc = 0.0f64;
            let row_off = r * b;
            for k in 0..b {
                acc += scratch.a_i[row_off + k] * scratch.rhs_buf[k];
            }
            scratch.delta_eta[r] = acc;
        }
        let eta_i = &input.eta_hat[i];
        let mut corrected = Array1::<f64>::zeros(b);
        for d in 0..b {
            corrected[d] = eta_i[d] + scratch.delta_eta[d];
        }
        eta_tilde.push(corrected);
        // Cook-style influence: delta_eta' W delta_eta.
        let mut cook = 0.0f64;
        for r in 0..b {
            let mut w_delta_r = 0.0f64;
            let row_off = r * b;
            for k in 0..b {
                w_delta_r += scratch.w_flat[row_off + k] * scratch.delta_eta[k];
            }
            cook += scratch.delta_eta[r] * w_delta_r;
        }
        cook_distance[local_i] = cook;
        // Variance diagonal: for each coordinate d, with a_d = row d of A,
        // v_dd = a_d' (I - WA)^{-1} W (I - AW)^{-1} a_d, clamped at zero.
        for d in 0..b {
            let row_off = d * b;
            for k in 0..b {
                scratch.rhs_buf[k] = scratch.a_i[row_off + k];
            }
            lu_solve_in_place(
                &scratch.imaw,
                &scratch.perm_imaw,
                &mut scratch.rhs_buf,
                &mut scratch.lu_scratch,
                b,
            );
            for r in 0..b {
                let mut acc = 0.0f64;
                let wr = r * b;
                for k in 0..b {
                    acc += scratch.w_flat[wr + k] * scratch.rhs_buf[k];
                }
                scratch.w_u[r] = acc;
            }
            lu_solve_in_place(
                &scratch.imwa,
                &scratch.perm_imwa,
                &mut scratch.w_u,
                &mut scratch.lu_scratch,
                b,
            );
            let mut v_dd = 0.0f64;
            for k in 0..b {
                v_dd += scratch.a_i[row_off + k] * scratch.w_u[k];
            }
            scratch.var_diag_buf[d] = v_dd.max(0.0);
        }
        let mut var_diag = Array1::<f64>::zeros(b);
        for d in 0..b {
            var_diag[d] = scratch.var_diag_buf[d];
        }
        alo_variance.push(var_diag);
    }
    Ok(MultiBlockAloChunkDiagnostics {
        chunk_start,
        eta_tilde,
        leverage,
        alo_variance,
        cook_distance,
    })
}
/// Dense b x b matrix product of row-major flat buffers: `out = a * b_mat`.
#[inline]
fn mat_mul_flat(a: &[f64], b_mat: &[f64], out: &mut [f64], b: usize) {
    for r in 0..b {
        // Row r starts at the same offset in both `a` and `out`.
        let row = r * b;
        for c in 0..b {
            let mut sum = 0.0f64;
            for k in 0..b {
                sum += a[row + k] * b_mat[k * b + c];
            }
            out[row + c] = sum;
        }
    }
}
/// In-place LU factorization with partial pivoting of a row-major b x b matrix.
///
/// On success `m` holds the unit-lower multipliers below the diagonal and U on
/// and above it, while `perm` records the row permutation for
/// `lu_solve_in_place`. Returns `false` when a pivot magnitude falls below
/// 1e-12 (numerically singular); the buffer contents are then unspecified.
fn lu_factor_in_place(m: &mut [f64], perm: &mut [usize], b: usize) -> bool {
    for (i, entry) in perm.iter_mut().enumerate().take(b) {
        *entry = i;
    }
    for col in 0..b {
        // Partial pivoting: largest entry in this column at or below the diagonal.
        let mut best_row = col;
        let mut best_abs = m[col * b + col].abs();
        for row in (col + 1)..b {
            let candidate = m[row * b + col].abs();
            if candidate > best_abs {
                best_abs = candidate;
                best_row = row;
            }
        }
        // A vanishing pivot means the matrix is numerically singular.
        if best_abs < 1e-12 {
            return false;
        }
        if best_row != col {
            for k in 0..b {
                m.swap(col * b + k, best_row * b + k);
            }
            perm.swap(col, best_row);
        }
        // Eliminate below the pivot, storing each multiplier in L's slot.
        let pivot = m[col * b + col];
        for row in (col + 1)..b {
            let multiplier = m[row * b + col] / pivot;
            m[row * b + col] = multiplier;
            for k in (col + 1)..b {
                let update = multiplier * m[col * b + k];
                m[row * b + k] -= update;
            }
        }
    }
    true
}
/// Solve `m * x = rhs` using factors and permutation from `lu_factor_in_place`.
///
/// `scratch` must provide at least `b` slots for the forward-substitution
/// intermediate; the solution overwrites `rhs`.
fn lu_solve_in_place(m: &[f64], perm: &[usize], rhs: &mut [f64], scratch: &mut [f64], b: usize) {
    let y = &mut scratch[..b];
    // Forward substitution with L (unit diagonal), applying the row
    // permutation on the fly while reading rhs.
    for row in 0..b {
        let mut acc = rhs[perm[row]];
        for k in 0..row {
            acc -= m[row * b + k] * y[k];
        }
        y[row] = acc;
    }
    // Back substitution with U, writing the solution into rhs.
    for row in (0..b).rev() {
        let mut acc = y[row];
        for k in (row + 1)..b {
            acc -= m[row * b + k] * rhs[k];
        }
        rhs[row] = acc / m[row * b + row];
    }
}
/// Leverage-only variant of the multi-block ALO computation.
///
/// Computes tr(A_i W_i) per observation — where A_i = X_i H^{-1} X_i' in
/// block form — without eta corrections or variances. The H^{-1} column
/// stripes are copied into faer matrices once, the per-chunk products use a
/// faer matmul, and chunks are processed in parallel with rayon.
pub fn compute_multiblock_alo_leverages(
    n_obs: usize,
    n_blocks: usize,
    block_designs: &[Array2<f64>],
    penalized_hessian_inv: &Array2<f64>,
    block_weights: &[Array2<f64>],
) -> Result<Array1<f64>, EstimationError> {
    use rayon::prelude::*;
    let n = n_obs;
    let b = n_blocks;
    let p_tot = penalized_hessian_inv.nrows();
    let col_offsets = multiblock_col_offsets(block_designs);
    let max_workers = rayon::current_num_threads();
    let chunk_size = multiblock_alo_parallel_leverage_chunk_size(p_tot, b, n, max_workers);
    let mut leverage = Array1::<f64>::zeros(n);
    let block_widths: Vec<usize> = block_designs.iter().map(|d| d.ncols()).collect();
    // Copy each block's H^{-1} column stripe (p_tot x p_blk) into faer storage.
    let mut h_stripes: Vec<FaerMat<f64>> = block_widths
        .iter()
        .map(|&p_blk| FaerMat::<f64>::zeros(p_tot, p_blk))
        .collect();
    for blk in 0..b {
        let off_b = col_offsets[blk];
        let p_blk = block_widths[blk];
        let stripe = &mut h_stripes[blk];
        for c in 0..p_blk {
            for r in 0..p_tot {
                stripe[(r, c)] = penalized_hessian_inv[(r, off_b + c)];
            }
        }
    }
    // Parallel over disjoint leverage chunks; chunk index recovers the
    // observation range each slice covers.
    leverage
        .as_slice_mut()
        .expect("newly allocated Array1 is contiguous")
        .par_chunks_mut(chunk_size)
        .enumerate()
        .for_each(|(chunk_idx, leverage_chunk)| {
            let chunk_start = chunk_idx * chunk_size;
            let chunk_len = leverage_chunk.len();
            let chunk_end = chunk_start + chunk_len;
            let bb_sz = b * b;
            let mut a_i = vec![0.0f64; bb_sz];
            let mut aw = vec![0.0f64; bb_sz];
            let mut w_flat = vec![0.0f64; bb_sz];
            // Per-chunk workspaces: Q_blk = H^{-1} stripe * X_blk' and the
            // transposed design chunk feeding that product.
            let mut q_storage: Vec<FaerMat<f64>> = block_widths
                .iter()
                .map(|_| FaerMat::<f64>::zeros(p_tot, chunk_len))
                .collect();
            let mut xt_storage: Vec<FaerMat<f64>> = block_widths
                .iter()
                .map(|&p_blk| FaerMat::<f64>::zeros(p_blk, chunk_len))
                .collect();
            for blk in 0..b {
                let p_blk = block_widths[blk];
                let x_chunk = block_designs[blk].slice(s![chunk_start..chunk_end, ..]);
                let xt = &mut xt_storage[blk];
                for local_i in 0..chunk_len {
                    let row = x_chunk.row(local_i);
                    for j in 0..p_blk {
                        xt[(j, local_i)] = row[j];
                    }
                }
                matmul(
                    q_storage[blk].as_mut(),
                    Accum::Replace,
                    h_stripes[blk].as_ref(),
                    xt_storage[blk].as_ref(),
                    1.0,
                    Par::Seq,
                );
            }
            for local_i in 0..chunk_len {
                let i = chunk_start + local_i;
                let w_i = &block_weights[i];
                for r in 0..b {
                    for c in 0..b {
                        w_flat[r * b + c] = w_i[(r, c)];
                    }
                }
                for r in 0..bb_sz {
                    a_i[r] = 0.0;
                }
                // A_i[a, k] = x_a(i)' [H^{-1}]_{a,k} x_k(i), via the Q columns.
                for k in 0..b {
                    let q_k = &q_storage[k];
                    let q_col = q_k.col_as_slice(local_i);
                    for a in 0..b {
                        let p_a = block_widths[a];
                        let off_a = col_offsets[a];
                        let xa_row = block_designs[a].row(i);
                        let mut dot = 0.0f64;
                        for j in 0..p_a {
                            dot = xa_row[j].mul_add(q_col[off_a + j], dot);
                        }
                        a_i[a * b + k] = dot;
                    }
                }
                // Leverage is the trace of A_i W_i.
                mat_mul_flat(&a_i, &w_flat, &mut aw, b);
                let mut tr = 0.0f64;
                for d in 0..b {
                    tr += aw[d * b + d];
                }
                leverage_chunk[local_i] = tr;
            }
        });
    Ok(leverage)
}
// Unit tests for the scalar ALO helpers and the multi-block ALO paths.
#[cfg(test)]
mod tests {
    use super::{
        alo_eta_updatewith_offset, bayesvar_eta, percentile_from_sorted, percentile_index,
        sandwichvar_eta,
    };
    // The offset-shifted update must match (eta - a*z)/(1 - a) on the centered scale.
    #[test]
    fn alo_offset_update_matches_centered_algebra() {
        let eta_hat = 11.0;
        let z = 13.0;
        let offset = 10.0;
        let x_hinv_x = 0.2;
        let hessian_weight = 1.0;
        let score_weight = 1.0;
        let leverage = hessian_weight * x_hinv_x;
        let expected = offset + ((eta_hat - offset) - leverage * (z - offset)) / (1.0 - leverage);
        let got =
            alo_eta_updatewith_offset(eta_hat, z, offset, x_hinv_x, hessian_weight, score_weight);
        assert!((got - expected).abs() < 1e-12);
    }
    // With zero offset the update reduces to the classic ALO formula.
    #[test]
    fn alo_offset_update_reduces_to_classicwhen_offsetzero() {
        let eta_hat = 1.25;
        let z = -0.5;
        let x_hinv_x = 0.35;
        let hessian_weight = 1.0;
        let score_weight = 1.0;
        let leverage = hessian_weight * x_hinv_x;
        let expected = (eta_hat - leverage * z) / (1.0 - leverage);
        let got =
            alo_eta_updatewith_offset(eta_hat, z, 0.0, x_hinv_x, hessian_weight, score_weight);
        assert!((got - expected).abs() < 1e-12);
    }
    // Hessian-side weight drives the denominator, score-side weight the score.
    #[test]
    fn alo_offset_update_uses_distinct_score_and_hessian_weights() {
        let eta_hat = 1.7;
        let z = 0.4;
        let offset = -0.2;
        let x_hinv_x = 0.15;
        let hessian_weight = 3.0;
        let score_weight = 5.0;
        let expected = offset
            + (eta_hat - offset)
            + x_hinv_x * score_weight * ((eta_hat - offset) - (z - offset))
                / (1.0 - hessian_weight * x_hinv_x);
        let got =
            alo_eta_updatewith_offset(eta_hat, z, offset, x_hinv_x, hessian_weight, score_weight);
        assert!((got - expected).abs() < 1e-12);
    }
    // Zero Hessian weight makes the denominator 1.
    #[test]
    fn alo_offset_update_handles_zero_hessian_weight() {
        let eta_hat = 0.8;
        let z = -0.3;
        let offset = 0.1;
        let x_hinv_x = 0.4;
        let hessian_weight = 0.0;
        let score_weight = 2.5;
        let expected = offset
            + (eta_hat - offset)
            + x_hinv_x * score_weight * ((eta_hat - offset) - (z - offset));
        let got =
            alo_eta_updatewith_offset(eta_hat, z, offset, x_hinv_x, hessian_weight, score_weight);
        assert!((got - expected).abs() < 1e-12);
    }
    // With no penalty/ridge terms the sandwich variance equals the Bayesian one.
    #[test]
    fn gaussian_unpenalized_sandwich_equals_bayes() {
        let phi = 2.5;
        let x_hinv_x = 0.3;
        let es_norm2 = 0.0;
        let ridge = 0.0;
        let s_norm2 = 0.0;
        let vb = bayesvar_eta(phi, x_hinv_x);
        let vs = sandwichvar_eta(phi, x_hinv_x, es_norm2, ridge, s_norm2);
        assert!((vb - vs).abs() < 1e-12);
    }
    #[test]
    fn sandwich_matches_direct_linear_gaussian_formula() {
        let phi = 1.7;
        let x_hinv_x = 0.41;
        let es_norm2 = 0.05;
        let ridge = 1e-3;
        let s_norm2 = 2.0;
        let got = sandwichvar_eta(phi, x_hinv_x, es_norm2, ridge, s_norm2);
        let expected = phi * (x_hinv_x - es_norm2 - ridge * s_norm2);
        assert!((got - expected).abs() < 1e-12);
    }
    #[test]
    fn percentile_index_matches_expected_rounding() {
        assert_eq!(percentile_index(0, 0.95), 0);
        assert_eq!(percentile_index(1, 0.95), 0);
        assert_eq!(percentile_index(10, 0.50), 5);
        assert_eq!(percentile_index(10, 0.95), 9);
    }
    #[test]
    fn percentile_from_sorted_returns_order_statistic() {
        let values = [1.0, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(percentile_from_sorted(&values, 0.50), 3.0);
        assert_eq!(percentile_from_sorted(&values, 0.95), 5.0);
        assert_eq!(percentile_from_sorted(&[], 0.95), 0.0);
    }
    use super::{MultiBlockAloInput, compute_multiblock_alo, compute_multiblock_alo_leverages};
    use ndarray::{Array1, Array2};
    // With a single block, multi-block leverage must equal the scalar w * x'H^{-1}x.
    #[test]
    fn multiblock_b1_matches_scalar_leverage() {
        let n = 3;
        let p = 2;
        let x = Array2::from_shape_vec((n, p), vec![1.0, 0.5, 0.8, -0.3, 0.2, 1.1]).unwrap();
        let w = vec![1.0, 2.0, 0.5];
        // Penalized Hessian H = I + X' diag(w) X, inverted by the 2x2 adjugate.
        let mut h = Array2::<f64>::eye(p);
        for i in 0..n {
            for r in 0..p {
                for c in 0..p {
                    h[(r, c)] += w[i] * x[(i, r)] * x[(i, c)];
                }
            }
        }
        let det = h[(0, 0)] * h[(1, 1)] - h[(0, 1)] * h[(1, 0)];
        let mut h_inv = Array2::<f64>::zeros((p, p));
        h_inv[(0, 0)] = h[(1, 1)] / det;
        h_inv[(1, 1)] = h[(0, 0)] / det;
        h_inv[(0, 1)] = -h[(0, 1)] / det;
        h_inv[(1, 0)] = -h[(1, 0)] / det;
        let mut scalar_lev = vec![0.0f64; n];
        for i in 0..n {
            let mut xhx = 0.0;
            for r in 0..p {
                for c in 0..p {
                    xhx += x[(i, r)] * h_inv[(r, c)] * x[(i, c)];
                }
            }
            scalar_lev[i] = w[i] * xhx;
        }
        let block_designs = vec![x.clone()];
        let block_weights: Vec<Array2<f64>> =
            w.iter().map(|&wi| Array2::from_elem((1, 1), wi)).collect();
        let scores: Vec<Array1<f64>> = (0..n).map(|_| Array1::from_vec(vec![0.1])).collect();
        let eta_hat: Vec<Array1<f64>> = (0..n).map(|i| Array1::from_vec(vec![i as f64])).collect();
        let input = MultiBlockAloInput {
            n_obs: n,
            n_blocks: 1,
            block_designs: &block_designs,
            penalized_hessian_inv: &h_inv,
            block_weights,
            scores,
            eta_hat,
        };
        let result = compute_multiblock_alo(&input).unwrap();
        for i in 0..n {
            assert!(
                (result.leverage[i] - scalar_lev[i]).abs() < 1e-10,
                "leverage mismatch at i={}: got {}, expected {}",
                i,
                result.leverage[i],
                scalar_lev[i]
            );
        }
    }
    // The leverage-only fast path must agree with the full computation.
    #[test]
    fn multiblock_leverage_only_matches_full() {
        let n = 4;
        let p1 = 2;
        let p2 = 3;
        let x1 = Array2::from_shape_fn((n, p1), |(i, j)| (i + j + 1) as f64 * 0.3);
        let x2 = Array2::from_shape_fn((n, p2), |(i, j)| (i * 2 + j) as f64 * 0.2 - 0.1);
        let p_tot = p1 + p2;
        let h_inv = Array2::<f64>::eye(p_tot);
        let block_weights: Vec<Array2<f64>> = (0..n)
            .map(|i| {
                let v = (i + 1) as f64;
                Array2::from_shape_vec((2, 2), vec![v, 0.1, 0.1, v * 0.5]).unwrap()
            })
            .collect();
        let scores: Vec<Array1<f64>> = (0..n).map(|_| Array1::from_vec(vec![0.0, 0.0])).collect();
        let eta_hat: Vec<Array1<f64>> = (0..n).map(|_| Array1::from_vec(vec![0.0, 0.0])).collect();
        let block_designs = vec![x1.clone(), x2.clone()];
        let input = MultiBlockAloInput {
            n_obs: n,
            n_blocks: 2,
            block_designs: &block_designs,
            penalized_hessian_inv: &h_inv,
            block_weights: block_weights.clone(),
            scores,
            eta_hat,
        };
        let full = compute_multiblock_alo(&input).unwrap();
        let lev_only =
            compute_multiblock_alo_leverages(n, 2, &block_designs, &h_inv, &block_weights).unwrap();
        for i in 0..n {
            assert!(
                (full.leverage[i] - lev_only[i]).abs() < 1e-12,
                "leverage mismatch at i={}: full={}, lev_only={}",
                i,
                full.leverage[i],
                lev_only[i]
            );
        }
    }
    // A zero weight makes (I - WA) the identity: the correction is A * s and
    // both the Cook distance and the variance vanish.
    #[test]
    fn multiblock_singular_weight_still_corrects() {
        let n = 1;
        let p = 2;
        let x = Array2::from_shape_vec((1, p), vec![1.0, 0.5]).unwrap();
        let h_inv = Array2::eye(p);
        let block_designs = vec![x.clone()];
        let block_weights = vec![Array2::from_elem((1, 1), 0.0)];
        let scores = vec![Array1::from_vec(vec![1.0])];
        let eta_hat = vec![Array1::from_vec(vec![std::f64::consts::PI])];
        let input = MultiBlockAloInput {
            n_obs: n,
            n_blocks: 1,
            block_designs: &block_designs,
            penalized_hessian_inv: &h_inv,
            block_weights,
            scores,
            eta_hat,
        };
        let result = compute_multiblock_alo(&input).unwrap();
        // A = x' H^{-1} x = 1.25, so eta_tilde = pi + 1.25 * 1.0.
        let expected = std::f64::consts::PI + 1.25;
        assert!(
            (result.eta_tilde[0][0] - expected).abs() < 1e-12,
            "expected {}, got {}",
            expected,
            result.eta_tilde[0][0]
        );
        assert!(result.cook_distance[0].abs() < 1e-14);
        assert!(result.alo_variance[0][0].abs() < 1e-14);
    }
    // Smoke test: all outputs are finite for a well-posed 1x1 problem.
    #[test]
    fn multiblock_cook_and_variance_basic() {
        let n = 1;
        let x = Array2::from_elem((1, 1), 1.0);
        let h_inv = Array2::from_elem((1, 1), 0.5);
        let block_designs = vec![x.clone()];
        let w_val = 2.0;
        let s_val = 0.4;
        let block_weights = vec![Array2::from_elem((1, 1), w_val)];
        let scores = vec![Array1::from_vec(vec![s_val])];
        let eta_hat = vec![Array1::from_vec(vec![1.0])];
        let input = MultiBlockAloInput {
            n_obs: n,
            n_blocks: 1,
            block_designs: &block_designs,
            penalized_hessian_inv: &h_inv,
            block_weights,
            scores,
            eta_hat,
        };
        let result = compute_multiblock_alo(&input).unwrap();
        assert!(result.eta_tilde[0][0].is_finite());
        assert!(result.cook_distance[0].is_finite());
        assert!(result.alo_variance[0][0].is_finite());
    }
}