1use crate::estimate::EstimationError;
2use crate::estimate::{FitGeometry, UnifiedFitResult};
3use crate::pirls;
4use gam_linalg::faer_ndarray::{FaerArrayView, FaerCholesky};
5use gam_linalg::matrix::{PsdWeightsView, SignedWeightsView};
6use gam_linalg::utils::StableSolver;
7use gam_problem::LinkFunction;
8use faer::Mat as FaerMat;
9use faer::linalg::matmul::matmul;
10use faer::prelude::ReborrowMut;
11use faer::{Accum, Par};
12use ndarray::{Array1, Array2, ArrayView1, ShapeBuilder, s};
13use std::fmt;
14
/// Errors produced while computing ALO diagnostics in this module.
///
/// Every variant except [`AloError::InfluenceMatrixFailed`] carries a
/// human-readable `reason`; the `Display` impl prints that string verbatim.
#[derive(Debug, Clone)]
pub enum AloError {
    /// An input failed validation (shape, finiteness, symmetry, missing geometry, ...).
    InvalidInput { reason: String },
    /// A weight vector or the working response contained a non-finite entry.
    WeightInvalid { reason: String },
    /// The design matrix could not be used, e.g. no dense form was available.
    DesignDegenerate { reason: String },
    /// The penalized Hessian could not be factorized, or the resulting
    /// diagnostics contained NaN. Call sites in this file report
    /// `f64::INFINITY` as the condition number.
    InfluenceMatrixFailed { condition_number: f64 },
    /// A per-observation leave-one-out quantity could not be computed
    /// (non-finite variance, degenerate denominator, failed scalar solve, ...).
    LooComputationFailed { reason: String },
}
42
43impl fmt::Display for AloError {
44 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45 match self {
46 AloError::InvalidInput { reason }
47 | AloError::WeightInvalid { reason }
48 | AloError::DesignDegenerate { reason }
49 | AloError::LooComputationFailed { reason } => f.write_str(reason),
50 AloError::InfluenceMatrixFailed { condition_number } => {
51 write!(
52 f,
53 "ALO influence matrix failed (condition number {condition_number:.3e})"
54 )
55 }
56 }
57 }
58}
59
// Marker impl: `AloError` has no underlying source error, so the default
// `std::error::Error` methods are sufficient.
impl std::error::Error for AloError {}
61
62impl From<AloError> for EstimationError {
63 fn from(err: AloError) -> EstimationError {
64 match err {
65 AloError::InvalidInput { reason }
66 | AloError::WeightInvalid { reason }
67 | AloError::DesignDegenerate { reason }
68 | AloError::LooComputationFailed { reason } => EstimationError::InvalidInput(reason),
69 AloError::InfluenceMatrixFailed { condition_number } => {
70 EstimationError::ModelIsIllConditioned { condition_number }
71 }
72 }
73 }
74}
75
76impl From<AloError> for String {
77 fn from(err: AloError) -> String {
78 err.to_string()
79 }
80}
81
/// Per-observation ALO diagnostics; every array has one entry per design row.
#[derive(Debug, Clone)]
pub struct AloDiagnostics {
    /// Leave-one-out linear predictor for each observation.
    pub eta_tilde: Array1<f64>,
    /// Square root of `phi * x_i' H^{-1} x_i` (clamped at zero before the root).
    pub se_bayes: Array1<f64>,
    /// Square root of the sandwich variance; equals `se_bayes` when no
    /// penalty root with at least one row is supplied.
    pub se_sandwich: Array1<f64>,
    /// Copy of the fitted linear predictor passed in as `eta`.
    pub pred_identity: Array1<f64>,
    /// Leverage `a_ii = max(w_i, 0) * x_i' H^{-1} x_i` using the Hessian-side weights.
    pub leverage: Array1<f64>,
    /// Copy of the Hessian-side weights used for the leverage.
    pub fisherweights: Array1<f64>,
}
96
/// One-step leave-one-out update of the linear predictor with an offset.
///
/// Both `eta_hat` and the working response `z` are centred by `offset`, the
/// weighted working residual `score_weight * ((eta_hat - offset) - (z - offset))`
/// is scaled by `x_hinv_x / denom`, and the offset is added back.
#[inline]
fn alo_eta_updatewith_offset(
    eta_hat: f64,
    z: f64,
    offset: f64,
    x_hinv_x: f64,
    score_weight: f64,
    denom: f64,
) -> f64 {
    let centered_eta = eta_hat - offset;
    let centered_z = z - offset;
    let weighted_residual = score_weight * (centered_eta - centered_z);
    // Same evaluation order as `offset + eta + x_hinv_x * score / denom`.
    let correction = x_hinv_x * weighted_residual / denom;
    offset + centered_eta + correction
}
113
/// Per-observation callback `(row index, eta) -> (first, second)` used by the
/// exact scalar leave-one-out solve, which treats the returned pair as
/// `(ell_prime, ell_double)` evaluated at `eta`.
///
/// Must be `Sync` because it is invoked from a rayon parallel iterator.
pub type AloScalarScoreCurvature<'a> = dyn Fn(usize, f64) -> (f64, f64) + Sync + 'a;
124
/// Maximum number of outer Newton iterations in the exact scalar ALO solve.
const ALO_EXACT_SCALAR_MAX_ITERS: usize = 64;

/// Absolute tolerance on the scalar residual at which the exact solve stops.
const ALO_EXACT_SCALAR_TOL: f64 = 1e-12;
136
/// Failure modes of the exact scalar leave-one-out Newton solve.
#[derive(Debug, Clone, Copy, PartialEq)]
enum AloExactScalarError {
    /// The score/curvature callback returned a non-finite value at `eta`.
    NonFiniteScoreCurvature {
        eta: f64,
        ell_prime: f64,
        ell_double: f64,
    },
    /// The Newton Jacobian was non-finite or no larger in magnitude than
    /// `ALO_DENOMINATOR_MIN`.
    DegenerateJacobian {
        eta: f64,
        jacobian: f64,
    },
    /// `residual / jacobian` was not finite; `next` is the iterate it would
    /// have produced.
    NonFiniteStep {
        eta: f64,
        residual: f64,
        jacobian: f64,
        next: f64,
    },
    /// The solve ended without the residual reaching the tolerance; `eta` and
    /// `residual` describe the last accepted iterate.
    MaxIterations {
        iterations: usize,
        residual: f64,
        eta: f64,
    },
}
180
181impl fmt::Display for AloExactScalarError {
182 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
183 match *self {
184 AloExactScalarError::NonFiniteScoreCurvature {
185 eta,
186 ell_prime,
187 ell_double,
188 } => write!(
189 f,
190 "non-finite score/curvature at eta={eta:.6e}: ell_prime={ell_prime:.6e}, ell_double={ell_double:.6e}"
191 ),
192 AloExactScalarError::DegenerateJacobian { eta, jacobian } => write!(
193 f,
194 "degenerate Newton Jacobian at eta={eta:.6e}: jacobian={jacobian:.6e}, min={ALO_DENOMINATOR_MIN:.1e}"
195 ),
196 AloExactScalarError::NonFiniteStep {
197 eta,
198 residual,
199 jacobian,
200 next,
201 } => write!(
202 f,
203 "non-finite Newton step from eta={eta:.6e}: residual={residual:.6e}, jacobian={jacobian:.6e}, next={next:.6e}"
204 ),
205 AloExactScalarError::MaxIterations {
206 iterations,
207 residual,
208 eta,
209 } => write!(
210 f,
211 "did not converge within {iterations} iterations: residual={residual:.6e}, eta={eta:.6e}, tol={ALO_EXACT_SCALAR_TOL:.1e}"
212 ),
213 }
214 }
215}
216
/// Maximum number of step halvings tried per Newton iteration of the exact
/// scalar ALO solve before the iteration is declared stalled.
const ALO_EXACT_SCALAR_BACKTRACKS: usize = 40;
223
224#[inline]
225fn alo_eta_exact_frozen_curvature(
226 eta_hat: f64,
227 a_ii: f64,
228 score_curvature: &dyn Fn(f64) -> (f64, f64),
229) -> Result<f64, AloExactScalarError> {
230 let residual_and_jac = |eta: f64| -> Result<(f64, f64), AloExactScalarError> {
254 let (ell_prime, ell_double) = score_curvature(eta);
255 if !ell_prime.is_finite() || !ell_double.is_finite() {
256 return Err(AloExactScalarError::NonFiniteScoreCurvature {
257 eta,
258 ell_prime,
259 ell_double,
260 });
261 }
262 Ok((eta - eta_hat - a_ii * ell_prime, 1.0 - a_ii * ell_double))
263 };
264
265 let mut eta = eta_hat;
266 let (mut residual, mut jac) = residual_and_jac(eta)?;
267 for _ in 0..ALO_EXACT_SCALAR_MAX_ITERS {
268 if residual.abs() <= ALO_EXACT_SCALAR_TOL {
269 return Ok(eta);
270 }
271 if jac.abs() <= ALO_DENOMINATOR_MIN || !jac.is_finite() {
272 return Err(AloExactScalarError::DegenerateJacobian { eta, jacobian: jac });
273 }
274 let step = residual / jac;
275 if !step.is_finite() {
276 return Err(AloExactScalarError::NonFiniteStep {
277 eta,
278 residual,
279 jacobian: jac,
280 next: eta - step,
281 });
282 }
283 let mut t = 1.0;
288 let mut advanced = false;
289 for _ in 0..ALO_EXACT_SCALAR_BACKTRACKS {
290 let trial = eta - t * step;
291 if let Ok((r_trial, j_trial)) = residual_and_jac(trial) {
292 if r_trial.abs() < residual.abs() {
293 eta = trial;
294 residual = r_trial;
295 jac = j_trial;
296 advanced = true;
297 break;
298 }
299 }
300 t *= 0.5;
301 }
302 if !advanced {
303 break;
304 }
305 }
306 Err(AloExactScalarError::MaxIterations {
307 iterations: ALO_EXACT_SCALAR_MAX_ITERS,
308 residual,
309 eta,
310 })
311}
312
/// Bayesian variance of the linear predictor: the dispersion `phi` times the
/// quadratic form `x_hinv_x`.
#[inline]
fn bayesvar_eta(phi: f64, x_hinv_x: f64) -> f64 {
    let variance = phi * x_hinv_x;
    variance
}
317
/// Sandwich variance of the linear predictor:
/// `phi * (x_hinv_x - es_norm2 - ridge * s_norm2)`.
#[inline]
fn sandwichvar_eta(phi: f64, x_hinv_x: f64, es_norm2: f64, ridge: f64, s_norm2: f64) -> f64 {
    let ridge_term = ridge * s_norm2;
    // Subtract left to right, exactly as in the closed-form expression.
    let unscaled = x_hinv_x - es_norm2 - ridge_term;
    phi * unscaled
}
325
/// Tolerance below zero that a computed variance may reach before it is
/// treated as materially negative: `1e-12` relative to `|scale|`, with the
/// scale floored at one.
#[inline]
fn variance_negative_tolerance(scale: f64) -> f64 {
    let magnitude = f64::max(scale.abs(), 1.0);
    1e-12 * magnitude
}
331
/// Leverage above this value is counted as "high" in the leverage log summary.
const LEVERAGE_HIGH_THRESHOLD: f64 = 0.99;
/// Leverage above this value additionally triggers a per-observation warning.
const LEVERAGE_VERY_HIGH_THRESHOLD: f64 = 0.999;
/// Thresholds whose exceedance rates are reported in the leverage log summary.
const LEVERAGE_RATE_THRESHOLDS: [f64; 3] = [0.90, 0.95, 0.99];
/// Quantiles (median, p95, p99) reported in the leverage log summary.
const LEVERAGE_PERCENTILES: [f64; 3] = [0.50, 0.95, 0.99];
/// Smallest magnitude accepted for denominators and Jacobians in the ALO solves.
const ALO_DENOMINATOR_MIN: f64 = 1e-12;
/// Memory budget, in bytes, used to size multi-block ALO work chunks.
const MULTIBLOCK_ALO_MEMORY_BUDGET_BYTES: usize = 256 * 1024 * 1024;

/// Number of design rows solved per right-hand-side block against the
/// factorized penalized Hessian.
const ALO_RHS_BLOCK_COLS: usize = 8192;

/// Relative tolerance for the symmetry check on the penalized Hessian.
const HESSIAN_SYMMETRY_REL_TOL: f64 = 1e-8;

// NOTE(review): not referenced in the visible part of this file; presumably a
// ridge for per-observation block solves in the multi-block path - confirm.
const ALO_LOCAL_BLOCK_RIDGE: f64 = 1e-6;

// NOTE(review): not referenced in the visible part of this file; presumably the
// pivot magnitude below which an LU factorization is treated as singular - confirm.
const LU_PIVOT_SINGULAR_TOL: f64 = 1e-12;
364
/// Index into a sorted sample of `sample_size` elements for `quantile`,
/// rounding `quantile * (sample_size - 1)` to the nearest index and clamping
/// to the last valid index. Returns `0` for samples of size zero or one.
#[inline]
fn percentile_index(sample_size: usize, quantile: f64) -> usize {
    match sample_size.checked_sub(1) {
        None | Some(0) => 0,
        Some(last) => {
            let position = (quantile * last as f64).round();
            // The float-to-int cast saturates, so negative or NaN positions map to 0.
            usize::min(position as usize, last)
        }
    }
}
373
374#[inline]
375fn percentile_from_sorted(sorted: &[f64], quantile: f64) -> f64 {
376 if sorted.is_empty() {
377 0.0
378 } else {
379 sorted[percentile_index(sorted.len(), quantile)]
380 }
381}
382
383#[inline]
384fn multiblock_col_offsets(block_designs: &[Array2<f64>]) -> Vec<usize> {
385 let mut offsets = Vec::with_capacity(block_designs.len());
386 let mut off = 0usize;
387 for design in block_designs {
388 offsets.push(off);
389 off += design.ncols();
390 }
391 offsets
392}
393
394#[inline]
395fn multiblock_alo_parallel_leverage_chunk_size(
396 p_tot: usize,
397 n_blocks: usize,
398 n_obs: usize,
399 max_workers: usize,
400) -> usize {
401 if p_tot == 0 || n_blocks == 0 || n_obs == 0 {
402 return 1;
403 }
404
405 let workers = max_workers.max(1);
411 let per_worker_budget = (MULTIBLOCK_ALO_MEMORY_BUDGET_BYTES / workers).max(1);
412 let elem_count_per_obs = p_tot.saturating_mul(n_blocks.saturating_add(1)).max(1);
413 let bytes_per_obs = elem_count_per_obs
414 .saturating_mul(std::mem::size_of::<f64>())
415 .max(1);
416 let budget_obs = (per_worker_budget / bytes_per_obs).max(1);
417 budget_obs.min(n_obs)
418}
419
420fn compute_alo_diagnostics_from_pirls_impl(
421 base: &pirls::PirlsResult,
422 y: ArrayView1<f64>,
423 link: LinkFunction,
424) -> Result<AloDiagnostics, EstimationError> {
425 compute_alo_diagnostics_from_pirls_inner(base, y, link).map_err(EstimationError::from)
426}
427
428fn alo_link_needs_exact_curvature_refinement(likelihood: &gam_problem::GlmLikelihoodSpec) -> bool {
441 use gam_problem::ResponseFamily;
442 matches!(
443 (&likelihood.spec.response, likelihood.link_function()),
444 (ResponseFamily::Binomial, LinkFunction::Logit)
445 | (ResponseFamily::Poisson, LinkFunction::Log)
446 )
447}
448
/// Builds an [`AloInput`] from a converged PIRLS result and computes the ALO
/// diagnostics.
///
/// Steps: densify the transformed design, pick the dispersion `phi`, fetch
/// the dense stabilized penalized Hessian, optionally prepare the per-row
/// score/curvature callback for the exact scalar solve, run the shared solver,
/// log a leverage summary, and reject any NaN in the outputs.
///
/// # Errors
///
/// * `DesignDegenerate` if the design has no dense form.
/// * `InvalidInput` if the dense Hessian is unavailable.
/// * Whatever `compute_alo_from_input_inner` returns.
/// * `InfluenceMatrixFailed` (infinite condition number) if any output is NaN.
fn compute_alo_diagnostics_from_pirls_inner(
    base: &pirls::PirlsResult,
    y: ArrayView1<f64>,
    link: LinkFunction,
) -> Result<AloDiagnostics, AloError> {
    let x_dense_arc = base
        .x_transformed
        .try_to_dense_arc("ALO diagnostics require dense transformed design")
        .map_err(|reason| AloError::DesignDegenerate { reason })?;
    let x_dense = x_dense_arc.as_ref();
    let n = x_dense.nrows();

    // Dispersion: fixed at 1 for every non-identity link; estimated from the
    // weighted residual sum of squares for the identity link.
    let phi = match link {
        LinkFunction::Log => 1.0,
        LinkFunction::Logit
        | LinkFunction::Probit
        | LinkFunction::CLogLog
        | LinkFunction::Sas
        | LinkFunction::BetaLogistic => 1.0,
        LinkFunction::Identity => {
            use rayon::iter::{IntoParallelIterator, ParallelIterator};
            let rss: f64 = (0..n)
                .into_par_iter()
                .map(|i| {
                    let r = y[i] - base.finalmu[i];
                    base.finalweights[i] * r * r
                })
                .sum();
            // Only rows with strictly positive weight count towards the
            // residual degrees of freedom, which are floored at one.
            let n_pos = (0..n).filter(|&i| base.finalweights[i] > 0.0).count();
            let dof = (n_pos as f64) - base.edf;
            let denom = dof.max(1.0);
            rss / denom
        }
    };

    let e = &base.reparam_result.e_transformed;
    // A negative ridge is clamped to zero before use.
    let ridge = base.ridge_passport.laplacehessianridge().max(0.0);

    // Any failure to obtain the dense Hessian is reported as InvalidInput;
    // non-InvalidInput errors are rendered through their Debug form.
    let h_dense_for_alo = base
        .dense_stabilizedhessian_transformed(
            "ALO diagnostics require exact dense stabilized penalized Hessian",
        )
        .map_err(|e| match e {
            EstimationError::InvalidInput(reason) => AloError::InvalidInput { reason },
            other => AloError::InvalidInput {
                reason: format!("{other:?}"),
            },
        })?;

    // Per-row scale c_i = w_i / (dmu/deta)_i, only for the family/link pairs
    // that use the exact scalar solve. Rows with a tiny or non-finite
    // derivative (or non-finite weight) get NaN, which the scalar solve later
    // reports as a non-finite score/curvature error.
    let canonical_scale: Option<Array1<f64>> =
        if alo_link_needs_exact_curvature_refinement(&base.likelihood) {
            let mut c = Array1::<f64>::zeros(n);
            for i in 0..n {
                let dmu = base.solve_dmu_deta[i];
                let w_h = base.finalweights[i];
                c[i] = if dmu.abs() <= ALO_DENOMINATOR_MIN || !dmu.is_finite() || !w_h.is_finite() {
                    f64::NAN
                } else {
                    w_h / dmu
                };
            }
            Some(c)
        } else {
            None
        };

    // Callback returning (c_i * (mu(eta) - y_i), c_i * dmu/deta(eta)); a failed
    // inverse-link evaluation yields NaN for both components.
    let inv_link_for_closure = base.likelihood.spec.link.clone();
    let score_curvature_closure = canonical_scale.as_ref().map(|scale| {
        move |i: usize, eta: f64| -> (f64, f64) {
            let (mu, dmu) = crate::mixture_link::inverse_link_mu_d1_for_inverse_link(
                &inv_link_for_closure,
                eta,
            )
            .unwrap_or((f64::NAN, f64::NAN));
            let c_i = scale[i];
            (c_i * (mu - y[i]), c_i * dmu)
        }
    });
    let score_curvature_ref: Option<&AloScalarScoreCurvature> = score_curvature_closure
        .as_ref()
        .map(|f| f as &AloScalarScoreCurvature);

    let input = AloInput {
        design: x_dense,
        penalized_hessian: &h_dense_for_alo,
        hessian_weights: base.final_weights_signed(),
        score_weights: base.solve_weights_psd(),
        working_response: &base.solveworking_response,
        eta: &base.final_eta,
        offset: &base.final_offset,
        link,
        phi,
        // An empty penalty root is passed as "no penalty root".
        penalty_root: if e.nrows() > 0 { Some(e) } else { None },
        ridge,
        score_curvature: score_curvature_ref,
    };

    let result = compute_alo_from_input_inner(&input)?;

    log_leverage_diagnostics(&result.leverage, phi);

    // Final safety net: refuse to return diagnostics that contain NaN.
    let has_nan_pred = result.eta_tilde.iter().any(|&x| x.is_nan());
    let has_nan_se_bayes = result.se_bayes.iter().any(|&x| x.is_nan());
    let has_nan_se_sandwich = result.se_sandwich.iter().any(|&x| x.is_nan());
    let has_nan_leverage = result.leverage.iter().any(|&x| x.is_nan());

    if has_nan_pred || has_nan_se_bayes || has_nan_se_sandwich || has_nan_leverage {
        log::error!("[GAM ALO] NaN values found in ALO diagnostics:");
        log::error!(
            "[GAM ALO] eta_tilde: {} NaN values",
            result.eta_tilde.iter().filter(|&&x| x.is_nan()).count()
        );
        log::error!(
            "[GAM ALO] se_bayes: {} NaN values",
            result.se_bayes.iter().filter(|&&x| x.is_nan()).count()
        );
        log::error!(
            "[GAM ALO] se_sandwich: {} NaN values",
            result.se_sandwich.iter().filter(|&&x| x.is_nan()).count()
        );
        log::error!(
            "[GAM ALO] leverage: {} NaN values",
            result.leverage.iter().filter(|&&x| x.is_nan()).count()
        );
        return Err(AloError::InfluenceMatrixFailed {
            condition_number: f64::INFINITY,
        });
    }

    Ok(result)
}
611
612fn log_leverage_diagnostics(leverage: &Array1<f64>, phi: f64) {
614 let n = leverage.len();
615 if n == 0 {
616 return;
617 }
618
619 let mut invalid_count = 0usize;
620 let mut high_leverage_count = 0usize;
621 let mut threshold_counts = [0usize; LEVERAGE_RATE_THRESHOLDS.len()];
622 let mut finite_leverage = Vec::with_capacity(n);
623
624 for (obs, &ai) in leverage.iter().enumerate() {
625 if ai.is_finite() {
626 finite_leverage.push(ai);
627 }
628
629 if !(0.0..=1.0).contains(&ai) || !ai.is_finite() {
630 invalid_count += 1;
631 log::warn!("[GAM ALO] invalid leverage at i={}, a_ii={:.6e}", obs, ai);
632 } else if ai > LEVERAGE_HIGH_THRESHOLD {
633 high_leverage_count += 1;
634 if ai > LEVERAGE_VERY_HIGH_THRESHOLD {
635 log::warn!("[GAM ALO] very high leverage at i={}, a_ii={:.6e}", obs, ai);
636 }
637 }
638
639 for (idx, threshold) in LEVERAGE_RATE_THRESHOLDS.iter().enumerate() {
640 if ai > *threshold {
641 threshold_counts[idx] += 1;
642 }
643 }
644 }
645
646 if invalid_count > 0 || high_leverage_count > 0 {
647 log::warn!(
648 "[GAM ALO] leverage diagnostics: {} invalid values, {} high values (>0.99)",
649 invalid_count,
650 high_leverage_count
651 );
652 }
653
654 finite_leverage.sort_by(f64::total_cmp);
655
656 let finite_n = finite_leverage.len();
657 let a_mean = if finite_n > 0 {
658 finite_leverage.iter().copied().sum::<f64>() / finite_n as f64
659 } else {
660 0.0
661 };
662 let a_median = percentile_from_sorted(&finite_leverage, LEVERAGE_PERCENTILES[0]);
663 let a_p95 = percentile_from_sorted(&finite_leverage, LEVERAGE_PERCENTILES[1]);
664 let a_p99 = percentile_from_sorted(&finite_leverage, LEVERAGE_PERCENTILES[2]);
665 let a_max = finite_leverage.last().copied().unwrap_or(0.0);
666
667 log::info!(
676 "[GAM ALO] leverage: n={}, mean={:.3e}, median={:.3e}, p95={:.3e}, p99={:.3e}, max={:.3e}",
677 n,
678 a_mean,
679 a_median,
680 a_p95,
681 a_p99,
682 a_max
683 );
684 log::info!(
685 "[GAM ALO] high-leverage: a>0.90: {:.2}%, a>0.95: {:.2}%, a>0.99: {:.2}%, dispersion phi={:.3e}",
686 100.0 * (threshold_counts[0] as f64) / n as f64,
687 100.0 * (threshold_counts[1] as f64) / n as f64,
688 100.0 * (threshold_counts[2] as f64) / n as f64,
689 phi
690 );
691}
692
/// Borrowed inputs for the single-predictor ALO solve.
///
/// With `n` rows and `p` columns in `design`, every vector must have length
/// `n` and the Hessian must be `p x p`; the solver validates this.
pub struct AloInput<'a> {
    /// Dense design matrix (`n x p`).
    pub design: &'a Array2<f64>,
    /// Dense penalized Hessian (`p x p`); must be finite and symmetric.
    pub penalized_hessian: &'a Array2<f64>,
    /// Hessian-side weights; negative values are clamped to zero when forming
    /// the leverage.
    pub hessian_weights: SignedWeightsView<'a>,
    /// Score-side weights multiplying the working residual in the one-step update.
    pub score_weights: PsdWeightsView<'a>,
    /// Working response `z`.
    pub working_response: &'a Array1<f64>,
    /// Fitted linear predictor.
    pub eta: &'a Array1<f64>,
    /// Offset included in `eta`.
    pub offset: &'a Array1<f64>,
    /// Link function of the fitted model.
    pub link: LinkFunction,
    /// Dispersion; must be finite and strictly positive.
    pub phi: f64,
    /// Optional penalty root with `p` columns; when present with at least one
    /// row, the sandwich variance is computed separately from the Bayesian one.
    pub penalty_root: Option<&'a Array2<f64>>,
    /// Hessian ridge entering the sandwich variance; must be finite and non-negative.
    pub ridge: f64,
    /// Optional per-row score/curvature callback. When present, the exact
    /// scalar solve replaces the one-step update for `eta_tilde`.
    pub score_curvature: Option<&'a AloScalarScoreCurvature<'a>>,
}
743
744impl<'a> AloInput<'a> {
745 pub fn from_geometry(
747 geom: &'a FitGeometry,
748 design: &'a Array2<f64>,
749 eta: &'a Array1<f64>,
750 offset: &'a Array1<f64>,
751 link: LinkFunction,
752 phi: f64,
753 ) -> Self {
754 let psd_w = PsdWeightsView::from_view_unchecked(geom.working_weights.view());
761 Self {
762 design,
763 penalized_hessian: &geom.penalized_hessian,
764 hessian_weights: psd_w.as_signed(),
765 score_weights: psd_w,
766 working_response: &geom.working_response,
767 eta,
768 offset,
769 link,
770 phi,
771 penalty_root: None,
772 ridge: 0.0,
773 score_curvature: None,
774 }
775 }
776}
777
778pub fn compute_alo_from_input(input: &AloInput) -> Result<AloDiagnostics, EstimationError> {
784 compute_alo_from_input_inner(input).map_err(EstimationError::from)
785}
786
787fn compute_alo_from_input_inner(input: &AloInput) -> Result<AloDiagnostics, AloError> {
788 let x_dense = input.design;
789 let n = x_dense.nrows();
790 let p = x_dense.ncols();
791 let w_h = input.hessian_weights.view();
795 let w_s = input.score_weights.view();
796
797 validate_alo_solve_setup(input, n, p)?;
798
799 let factor = StableSolver::new("alo penalized hessian")
800 .factorize(input.penalized_hessian)
801 .map_err(|_| AloError::InfluenceMatrixFailed {
802 condition_number: f64::INFINITY,
803 })?;
804
805 let xt = x_dense.t();
806 let phi = input.phi;
807 let ridge = input.ridge;
808
809 let e_rank = input.penalty_root.map(|e| e.nrows()).unwrap_or(0);
810
811 let mut aii = Array1::<f64>::zeros(n);
812 let mut x_hinv_x_diag = Array1::<f64>::zeros(n);
813 let mut se_bayes = Array1::<f64>::zeros(n);
814 let mut se_sandwich = Array1::<f64>::zeros(n);
815
816 let block_cols = ALO_RHS_BLOCK_COLS;
817 let mut rhs_chunk_buf = Array2::<f64>::zeros((p, block_cols).f());
822 let mut es_chunk_storage = if e_rank > 0 {
826 FaerMat::<f64>::zeros(e_rank, block_cols)
827 } else {
828 FaerMat::<f64>::zeros(0, 0)
829 };
830
831 for chunk_start in (0..n).step_by(block_cols) {
832 let chunk_end = (chunk_start + block_cols).min(n);
833 let width = chunk_end - chunk_start;
834
835 rhs_chunk_buf
836 .slice_mut(s![.., ..width])
837 .assign(&xt.slice(s![.., chunk_start..chunk_end]));
838
839 let rhs_chunkview = rhs_chunk_buf.slice(s![.., ..width]);
840 let rhs_chunk = FaerArrayView::new(&rhs_chunkview);
841 let s_chunk = factor.solve(rhs_chunk.as_ref());
845
846 if e_rank > 0
847 && let Some(e) = input.penalty_root
848 {
849 let eview = FaerArrayView::new(e);
850 let mut es_target = es_chunk_storage.as_mut().subcols_mut(0, width);
853 matmul(
854 es_target.rb_mut(),
855 Accum::Replace,
856 eview.as_ref(),
857 s_chunk.as_ref(),
858 1.0,
859 Par::Seq,
860 );
861 }
862
863 let rhs_view = rhs_chunk_buf.slice(s![.., ..width]);
864
865 for local_col in 0..width {
866 let obs = chunk_start + local_col;
867 let rhs_col = rhs_view.column(local_col);
871 let rhs_slice = rhs_col.as_slice().expect("column-major col contiguous");
872 let s_slice = s_chunk.col_as_slice(local_col);
873
874 let mut x_hinv_x = 0.0f64;
875 let mut s_norm2 = 0.0f64;
876 for k in 0..p {
878 let sval = s_slice[k];
879 let xval = rhs_slice[k];
880 x_hinv_x = sval.mul_add(xval, x_hinv_x);
881 s_norm2 = sval.mul_add(sval, s_norm2);
882 }
883 let ai = w_h[obs].max(0.0) * x_hinv_x;
884 let mut es_norm2 = 0.0f64;
885 if e_rank > 0 {
886 let es_slice = es_chunk_storage.col_as_slice(local_col);
887 for r in 0..e_rank {
888 let v = es_slice[r];
889 es_norm2 = v.mul_add(v, es_norm2);
890 }
891 }
892 aii[obs] = ai;
893 x_hinv_x_diag[obs] = x_hinv_x;
894
895 let var_bayes = bayesvar_eta(phi, x_hinv_x);
896 let var_sandwich = if e_rank > 0 {
897 sandwichvar_eta(phi, x_hinv_x, es_norm2, ridge, s_norm2)
898 } else {
899 var_bayes
900 };
901
902 if !var_bayes.is_finite() || !var_sandwich.is_finite() {
903 return Err(AloError::LooComputationFailed {
904 reason: format!(
905 "ALO variance is not finite at row {obs}: bayes={var_bayes:.6e}, sandwich={var_sandwich:.6e}"
906 ),
907 });
908 }
909 let bayes_tol = variance_negative_tolerance(phi * x_hinv_x.abs());
910 if var_bayes < -bayes_tol {
911 return Err(AloError::LooComputationFailed {
912 reason: format!(
913 "ALO Bayesian variance is materially negative at row {obs}: var={var_bayes:.6e}, tol={bayes_tol:.6e}"
914 ),
915 });
916 }
917 if e_rank > 0 {
918 let sandwich_scale =
919 phi * (x_hinv_x.abs() + es_norm2.abs() + (ridge * s_norm2).abs());
920 let sandwich_tol = variance_negative_tolerance(sandwich_scale);
921 if var_sandwich < -sandwich_tol {
922 return Err(AloError::LooComputationFailed {
923 reason: format!(
924 "ALO sandwich variance is materially negative at row {obs}: var={var_sandwich:.6e}, tol={sandwich_tol:.6e}"
925 ),
926 });
927 }
928 }
929
930 se_bayes[obs] = var_bayes.max(0.0).sqrt();
931 se_sandwich[obs] = var_sandwich.max(0.0).sqrt();
932 }
933 }
934
935 let eta_hat = input.eta;
936 let z = input.working_response;
937 let offset = input.offset;
938
939 use rayon::prelude::*;
940 let eta_tilde_vec: Vec<f64> = (0..n)
941 .into_par_iter()
942 .map(|i| {
943 let denom_raw = 1.0 - aii[i];
944 if denom_raw <= ALO_DENOMINATOR_MIN || !denom_raw.is_finite() {
945 return Err(AloError::LooComputationFailed {
946 reason: format!(
947 "ALO denominator is too small at row {i}: a_ii={:.6e}, 1-a_ii={:.6e}, min={:.1e}",
948 aii[i], denom_raw, ALO_DENOMINATOR_MIN
949 ),
950 });
951 }
952 let one_step = alo_eta_updatewith_offset(
953 eta_hat[i],
954 z[i],
955 offset[i],
956 x_hinv_x_diag[i],
957 w_s[i],
958 denom_raw,
959 );
960 let v = if let Some(score_curvature) = input.score_curvature {
968 alo_eta_exact_frozen_curvature(
969 eta_hat[i],
970 x_hinv_x_diag[i],
971 &|eta| score_curvature(i, eta),
972 )
973 .map_err(|err| AloError::LooComputationFailed {
974 reason: format!(
975 "ALO exact frozen-curvature solve failed at row {i}: {err}"
976 ),
977 })?
978 } else {
979 one_step
980 };
981 if !v.is_finite() {
982 return Err(AloError::LooComputationFailed {
983 reason: format!("ALO eta_tilde is not finite at row {i}: eta_tilde={v}"),
984 });
985 }
986 Ok(v)
987 })
988 .collect::<Result<_, _>>()?;
989 let eta_tilde = Array1::from(eta_tilde_vec);
990
991 Ok(AloDiagnostics {
992 eta_tilde,
993 se_bayes,
994 se_sandwich,
995 pred_identity: eta_hat.clone(),
996 leverage: aii,
997 fisherweights: w_h.to_owned(),
998 })
999}
1000
1001fn validate_alo_solve_setup(input: &AloInput, n: usize, p: usize) -> Result<(), AloError> {
1002 let h = input.penalized_hessian;
1003 if h.nrows() != p || h.ncols() != p {
1004 return Err(AloError::InvalidInput {
1005 reason: format!(
1006 "ALO diagnostics require a dense exact penalized Hessian with shape {p}x{p}; got {}x{}",
1007 h.nrows(),
1008 h.ncols()
1009 ),
1010 });
1011 }
1012 if h.iter().any(|v| !v.is_finite()) {
1013 return Err(AloError::InvalidInput {
1014 reason: "ALO diagnostics require a finite dense exact penalized Hessian".to_string(),
1015 });
1016 }
1017 for i in 0..p {
1018 for j in 0..i {
1019 let a = h[[i, j]];
1020 let b = h[[j, i]];
1021 let scale = a.abs().max(b.abs()).max(1.0);
1022 if (a - b).abs() > HESSIAN_SYMMETRY_REL_TOL * scale {
1023 return Err(AloError::InvalidInput {
1024 reason: format!(
1025 "ALO diagnostics require a symmetric dense exact penalized Hessian; entries ({i},{j}) and ({j},{i}) differ by {:.3e}",
1026 (a - b).abs()
1027 ),
1028 });
1029 }
1030 }
1031 }
1032
1033 let vector_lengths = [
1034 ("hessian_weights", input.hessian_weights.len()),
1035 ("score_weights", input.score_weights.len()),
1036 ("working_response", input.working_response.len()),
1037 ("eta", input.eta.len()),
1038 ("offset", input.offset.len()),
1039 ];
1040 for (name, len) in vector_lengths {
1041 if len != n {
1042 return Err(AloError::InvalidInput {
1043 reason: format!("ALO diagnostics require {name} length {n}; got {len}"),
1044 });
1045 }
1046 }
1047 if input.hessian_weights.view().iter().any(|v| !v.is_finite()) {
1048 return Err(AloError::WeightInvalid {
1049 reason: "ALO diagnostics require finite Hessian-side weights".to_string(),
1050 });
1051 }
1052 if input.score_weights.view().iter().any(|v| !v.is_finite()) {
1053 return Err(AloError::WeightInvalid {
1054 reason: "ALO diagnostics require finite score-side weights".to_string(),
1055 });
1056 }
1057 if input.working_response.iter().any(|v| !v.is_finite()) {
1058 return Err(AloError::WeightInvalid {
1059 reason: "ALO diagnostics require finite working responses".to_string(),
1060 });
1061 }
1062 if input.eta.iter().any(|v| !v.is_finite()) || input.offset.iter().any(|v| !v.is_finite()) {
1063 return Err(AloError::InvalidInput {
1064 reason: "ALO diagnostics require finite linear predictors and offsets".to_string(),
1065 });
1066 }
1067 if !input.phi.is_finite() || input.phi <= 0.0 {
1068 return Err(AloError::InvalidInput {
1069 reason: format!(
1070 "ALO diagnostics require positive finite dispersion phi; got {}",
1071 input.phi
1072 ),
1073 });
1074 }
1075 if !input.ridge.is_finite() || input.ridge < 0.0 {
1076 return Err(AloError::InvalidInput {
1077 reason: format!(
1078 "ALO diagnostics require a finite non-negative Hessian ridge; got {}",
1079 input.ridge
1080 ),
1081 });
1082 }
1083 if let Some(e) = input.penalty_root {
1084 if e.ncols() != p {
1085 return Err(AloError::InvalidInput {
1086 reason: format!(
1087 "ALO diagnostics require penalty root to have {p} columns; got {}",
1088 e.ncols()
1089 ),
1090 });
1091 }
1092 if e.iter().any(|v| !v.is_finite()) {
1093 return Err(AloError::InvalidInput {
1094 reason: "ALO diagnostics require finite penalty-root entries".to_string(),
1095 });
1096 }
1097 }
1098 Ok(())
1099}
1100
1101pub fn compute_alo_diagnostics_from_fit(
1103 fit: &UnifiedFitResult,
1104 y: ArrayView1<f64>,
1105 link: LinkFunction,
1106) -> Result<AloDiagnostics, EstimationError> {
1107 let pirls = fit
1108 .artifacts
1109 .pirls
1110 .as_ref()
1111 .ok_or_else(|| AloError::InvalidInput {
1112 reason:
1113 "ALO diagnostics require a PIRLS-backed fit; this fit does not expose PIRLS geometry"
1114 .to_string(),
1115 })
1116 .map_err(EstimationError::from)?;
1117 compute_alo_diagnostics_from_pirls_impl(pirls, y, link)
1118}
1119
1120pub fn compute_alo_diagnostics_from_unified(
1126 unified: &UnifiedFitResult,
1127 design: &Array2<f64>,
1128 eta: &Array1<f64>,
1129 offset: &Array1<f64>,
1130 link: LinkFunction,
1131 phi: f64,
1132) -> Result<AloDiagnostics, EstimationError> {
1133 let geom = unified
1134 .geometry
1135 .as_ref()
1136 .ok_or_else(|| AloError::InvalidInput {
1137 reason: "UnifiedFitResult does not contain working-set geometry; \
1138 ALO diagnostics require geometry at convergence"
1139 .to_string(),
1140 })
1141 .map_err(EstimationError::from)?;
1142 let input = AloInput::from_geometry(geom, design, eta, offset, link, phi);
1143 compute_alo_from_input(&input)
1144}
1145
1146pub fn compute_alo_diagnostics_from_pirls(
1148 base: &pirls::PirlsResult,
1149 y: ArrayView1<f64>,
1150 link: LinkFunction,
1151) -> Result<AloDiagnostics, EstimationError> {
1152 compute_alo_diagnostics_from_pirls_impl(base, y, link)
1153}
1154
1155pub fn compute_case_deletion_from_pirls(
1174 base: &pirls::PirlsResult,
1175 y: ArrayView1<f64>,
1176 link: LinkFunction,
1177) -> Result<Option<crate::sensitivity::CaseDeletionInfluence>, EstimationError> {
1178 let x_dense_arc = base
1179 .x_transformed
1180 .try_to_dense_arc("case-deletion diagnostics require dense transformed design")
1181 .map_err(|reason| EstimationError::InvalidInput(reason))?;
1182 let x_dense = x_dense_arc.as_ref();
1183 let n = x_dense.nrows();
1184 let p = x_dense.ncols();
1185 if n == 0 || p == 0 {
1186 return Ok(None);
1187 }
1188
1189 let phi = match link {
1192 LinkFunction::Identity => {
1193 use rayon::iter::{IntoParallelIterator, ParallelIterator};
1194 let rss: f64 = (0..n)
1195 .into_par_iter()
1196 .map(|i| {
1197 let r = y[i] - base.finalmu[i];
1198 base.finalweights[i] * r * r
1199 })
1200 .sum();
1201 let dof = (n as f64) - base.edf;
1202 rss / dof.max(1.0)
1203 }
1204 _ => 1.0,
1205 };
1206 if !(phi.is_finite() && phi > 0.0) {
1207 return Ok(None);
1208 }
1209
1210 let h_dense = base
1213 .dense_stabilizedhessian_transformed(
1214 "case-deletion diagnostics require exact dense stabilized penalized Hessian",
1215 )
1216 .map_err(|e| match e {
1217 EstimationError::InvalidInput(reason) => EstimationError::InvalidInput(reason),
1218 other => EstimationError::InvalidInput(format!("{other:?}")),
1219 })?;
1220
1221 let factor = match h_dense.cholesky(faer::Side::Lower) {
1222 Ok(f) => f,
1223 Err(_) => return Ok(None),
1227 };
1228
1229 let working_weights = base.finalweights.clone();
1233 let working_residual = &base.solveworking_response - &base.final_eta;
1234
1235 let sensitivity = crate::sensitivity::FitSensitivity::from_faer_cholesky(&factor, p);
1236 Ok(sensitivity.case_deletion(
1237 x_dense,
1238 working_weights.view(),
1239 working_residual.view(),
1240 phi,
1241 ))
1242}
1243
/// Results of the multi-block ALO computation, assembled from per-chunk
/// results in observation order.
#[derive(Debug, Clone)]
pub struct MultiBlockAloDiagnostics {
    /// Leave-one-out predictors, concatenated across chunks.
    // NOTE(review): presumably one length-`n_blocks` vector per observation - confirm.
    pub eta_tilde: Vec<Array1<f64>>,
    /// One leverage value per observation (length `n_obs`).
    pub leverage: Array1<f64>,
    /// ALO variances, concatenated across chunks in the same order as `eta_tilde`.
    pub alo_variance: Vec<Array1<f64>>,
    /// One Cook's-distance-style value per observation (length `n_obs`).
    pub cook_distance: Array1<f64>,
}
1263
/// Inputs for the multi-block ALO computation.
pub struct MultiBlockAloInput<'a> {
    /// Number of observations.
    pub n_obs: usize,
    /// Number of blocks; must equal `block_designs.len()`.
    pub n_blocks: usize,
    /// One design matrix per block. Their column counts must sum to the
    /// dimension of `penalized_hessian_inv`.
    pub block_designs: &'a [Array2<f64>],
    /// Inverse of the joint penalized Hessian; its row count defines the total
    /// coefficient dimension.
    pub penalized_hessian_inv: &'a Array2<f64>,
    // NOTE(review): presumably one `n_blocks x n_blocks` weight matrix per
    // observation - confirm against the chunk computation.
    pub block_weights: Vec<Array2<f64>>,
    // NOTE(review): presumably per-observation score vectors over the blocks - confirm.
    pub scores: Vec<Array1<f64>>,
    // NOTE(review): presumably the fitted per-block linear predictors - confirm layout.
    pub eta_hat: Vec<Array1<f64>>,
}
1312
1313pub fn compute_multiblock_alo(
1332 input: &MultiBlockAloInput,
1333) -> Result<MultiBlockAloDiagnostics, EstimationError> {
1334 compute_multiblock_alo_inner(input).map_err(EstimationError::from)
1335}
1336
1337fn compute_multiblock_alo_inner(
1338 input: &MultiBlockAloInput,
1339) -> Result<MultiBlockAloDiagnostics, AloError> {
1340 use rayon::prelude::*;
1341
1342 let n = input.n_obs;
1343 let b = input.n_blocks;
1344 let p_tot = input.penalized_hessian_inv.nrows();
1345
1346 if input.block_designs.len() != b {
1348 return Err(AloError::InvalidInput {
1349 reason: format!(
1350 "MultiBlockAloInput: expected {} block designs, got {}",
1351 b,
1352 input.block_designs.len()
1353 ),
1354 });
1355 }
1356
1357 let col_sum: usize = input.block_designs.iter().map(|d| d.ncols()).sum();
1359 if col_sum != p_tot {
1360 return Err(AloError::InvalidInput {
1361 reason: format!(
1362 "MultiBlockAloInput: total design columns ({}) != penalized_hessian_inv size ({})",
1363 col_sum, p_tot
1364 ),
1365 });
1366 }
1367
1368 let col_offsets = multiblock_col_offsets(input.block_designs);
1369 let (chunk_size, max_concurrent_chunks) = multiblock_alo_parallel_plan(p_tot, b, n);
1370 let chunk_starts: Vec<usize> = (0..n).step_by(chunk_size).collect();
1371
1372 let mut chunk_results: Vec<Result<MultiBlockAloChunkDiagnostics, AloError>> =
1378 Vec::with_capacity(chunk_starts.len());
1379 for chunk_wave in chunk_starts.chunks(max_concurrent_chunks) {
1380 let mut wave_results: Vec<Result<MultiBlockAloChunkDiagnostics, AloError>> = chunk_wave
1381 .par_iter()
1382 .map_init(
1383 || MultiBlockAloScratch::new(b),
1384 |scratch, &chunk_start| {
1385 let chunk_end = (chunk_start + chunk_size).min(n);
1386 compute_multiblock_alo_chunk(
1387 input,
1388 &col_offsets,
1389 chunk_start,
1390 chunk_end,
1391 scratch,
1392 )
1393 },
1394 )
1395 .collect();
1396 chunk_results.append(&mut wave_results);
1397 }
1398
1399 let mut eta_tilde = Vec::with_capacity(n);
1400 let mut leverage = Array1::<f64>::zeros(n);
1401 let mut alo_variance = Vec::with_capacity(n);
1402 let mut cook_distance = Array1::<f64>::zeros(n);
1403
1404 let mut chunks = Vec::with_capacity(chunk_results.len());
1405 for result in chunk_results {
1406 chunks.push(result?);
1407 }
1408 chunks.sort_unstable_by_key(|chunk| chunk.chunk_start);
1409
1410 for chunk in chunks {
1411 let chunk_start = chunk.chunk_start;
1412 eta_tilde.extend(chunk.eta_tilde);
1413 alo_variance.extend(chunk.alo_variance);
1414 for (local_i, lev) in chunk.leverage.into_iter().enumerate() {
1415 leverage[chunk_start + local_i] = lev;
1416 }
1417 for (local_i, cook) in chunk.cook_distance.into_iter().enumerate() {
1418 cook_distance[chunk_start + local_i] = cook;
1419 }
1420 }
1421
1422 Ok(MultiBlockAloDiagnostics {
1423 eta_tilde,
1424 leverage,
1425 alo_variance,
1426 cook_distance,
1427 })
1428}
1429
1430#[inline]
1431fn multiblock_alo_parallel_plan(p_tot: usize, n_blocks: usize, n_obs: usize) -> (usize, usize) {
1432 if p_tot == 0 || n_blocks == 0 || n_obs == 0 {
1433 return (1, 1);
1434 }
1435 let bytes_per_obs = (p_tot * n_blocks * std::mem::size_of::<f64>()).max(1);
1436 let workers = rayon::current_num_threads().max(1);
1437 let max_concurrent_chunks = (MULTIBLOCK_ALO_MEMORY_BUDGET_BYTES / bytes_per_obs)
1438 .max(1)
1439 .min(workers);
1440 let per_worker_budget =
1441 (MULTIBLOCK_ALO_MEMORY_BUDGET_BYTES / max_concurrent_chunks).max(bytes_per_obs);
1442 let budget_obs = (per_worker_budget / bytes_per_obs).max(1);
1443 (budget_obs.min(n_obs), max_concurrent_chunks)
1444}
1445
/// Reusable per-worker buffers for the multi-block ALO inner loop.
///
/// Square buffers hold `b x b` matrices in row-major order and vector buffers
/// hold `b` entries, where `b` is the number of blocks. Allocating them once
/// per worker keeps the per-observation loop allocation-free apart from the
/// returned arrays.
struct MultiBlockAloScratch {
    /// Local block matrix `A_i` (`b x b`).
    a_i: Vec<f64>,
    /// Product `W_i * A_i`.
    wa: Vec<f64>,
    /// Product `A_i * W_i`.
    aw: Vec<f64>,
    /// `I - W_i A_i`, overwritten by its LU factors.
    imwa: Vec<f64>,
    /// `I - A_i W_i`, overwritten by its LU factors.
    imaw: Vec<f64>,
    /// Row permutation recorded while factoring `imwa`.
    perm_imwa: Vec<usize>,
    /// Row permutation recorded while factoring `imaw`.
    perm_imaw: Vec<usize>,
    /// Correction added to the fitted linear predictor.
    delta_eta: Vec<f64>,
    /// Right-hand side / solution buffer for the LU solves.
    rhs_buf: Vec<f64>,
    /// Holds `W_i * u` and then the solution of the second variance solve.
    w_u: Vec<f64>,
    /// Diagonal of the ALO variance for the current observation.
    var_diag_buf: Vec<f64>,
    /// Row-major copy of the observation's `b x b` weight matrix.
    w_flat: Vec<f64>,
    /// Forward-substitution workspace handed to `lu_solve_in_place`.
    lu_scratch: Vec<f64>,
}

impl MultiBlockAloScratch {
    /// Allocates zero-initialised buffers sized for `b` blocks.
    fn new(b: usize) -> Self {
        let square = || vec![0.0f64; b * b];
        let column = || vec![0.0f64; b];
        let pivots = || vec![0usize; b];

        Self {
            a_i: square(),
            wa: square(),
            aw: square(),
            imwa: square(),
            imaw: square(),
            perm_imwa: pivots(),
            perm_imaw: pivots(),
            delta_eta: column(),
            rhs_buf: column(),
            w_u: column(),
            var_diag_buf: column(),
            w_flat: square(),
            lu_scratch: column(),
        }
    }
}
1482
/// Diagnostics for one contiguous chunk of observations, as produced by
/// `compute_multiblock_alo_chunk`. The per-observation vectors are in
/// chunk-local order (entry `k` belongs to observation `chunk_start + k`).
struct MultiBlockAloChunkDiagnostics {
    /// Global index of the first observation in the chunk.
    chunk_start: usize,
    /// Corrected linear predictor per observation (one entry per block).
    eta_tilde: Vec<Array1<f64>>,
    /// Leverage per observation.
    leverage: Vec<f64>,
    /// Non-negative ALO variance diagonal per observation (one entry per block).
    alo_variance: Vec<Array1<f64>>,
    /// Cook-type distance per observation.
    cook_distance: Vec<f64>,
}
1490
1491fn compute_multiblock_alo_chunk(
1492 input: &MultiBlockAloInput,
1493 col_offsets: &[usize],
1494 chunk_start: usize,
1495 chunk_end: usize,
1496 scratch: &mut MultiBlockAloScratch,
1497) -> Result<MultiBlockAloChunkDiagnostics, AloError> {
1498 let b = input.n_blocks;
1499 let chunk_len = chunk_end - chunk_start;
1500
1501 let mut q_blocks = Vec::with_capacity(b);
1502 for blk in 0..b {
1503 let x_chunk_t = input.block_designs[blk]
1504 .slice(s![chunk_start..chunk_end, ..])
1505 .t()
1506 .to_owned();
1507 let off_b = col_offsets[blk];
1508 let h_slice = input
1509 .penalized_hessian_inv
1510 .slice(s![.., off_b..off_b + x_chunk_t.nrows()])
1511 .to_owned();
1512 q_blocks.push(h_slice.dot(&x_chunk_t));
1513 }
1514
1515 let mut eta_tilde = Vec::with_capacity(chunk_len);
1516 let mut leverage = vec![0.0f64; chunk_len];
1517 let mut alo_variance = Vec::with_capacity(chunk_len);
1518 let mut cook_distance = vec![0.0f64; chunk_len];
1519
1520 for local_i in 0..chunk_len {
1521 let i = chunk_start + local_i;
1522 let w_i = &input.block_weights[i];
1523
1524 for r in 0..b {
1526 for c in 0..b {
1527 scratch.w_flat[r * b + c] = w_i[(r, c)];
1528 }
1529 }
1530
1531 for a in 0..b {
1533 let x_a = &input.block_designs[a];
1534 let p_a = x_a.ncols();
1535 let off_a = col_offsets[a];
1536 let xa_row = x_a.row(i);
1537 for bb in 0..b {
1538 let q_bb = &q_blocks[bb];
1539 let mut dot = 0.0f64;
1540 for k in 0..p_a {
1541 dot += xa_row[k] * q_bb[(off_a + k, local_i)];
1542 }
1543 scratch.a_i[a * b + bb] = dot;
1544 }
1545 }
1546
1547 mat_mul_flat(&scratch.w_flat, &scratch.a_i, &mut scratch.wa, b);
1549 mat_mul_flat(&scratch.a_i, &scratch.w_flat, &mut scratch.aw, b);
1551
1552 let mut tr = 0.0f64;
1555 for d in 0..b {
1556 tr += scratch.aw[d * b + d];
1557 }
1558 leverage[local_i] = tr;
1559
1560 for r in 0..b {
1562 for c in 0..b {
1563 let idx = r * b + c;
1564 let id = if r == c { 1.0 } else { 0.0 };
1565 scratch.imwa[idx] = id - scratch.wa[idx];
1566 scratch.imaw[idx] = id - scratch.aw[idx];
1567 }
1568 }
1569
1570 if !lu_factor_in_place(&mut scratch.imwa, &mut scratch.perm_imwa, b) {
1576 for r in 0..b {
1577 for c in 0..b {
1578 let idx = r * b + c;
1579 let id = if r == c { 1.0 } else { 0.0 };
1580 scratch.imwa[idx] = id - scratch.wa[idx];
1581 }
1582 }
1583 for d in 0..b {
1584 scratch.imwa[d * b + d] += ALO_LOCAL_BLOCK_RIDGE;
1585 }
1586 let refactored = lu_factor_in_place(&mut scratch.imwa, &mut scratch.perm_imwa, b);
1587 assert!(
1588 refactored,
1589 "ALO local block remained singular after ridge regularization"
1590 );
1591 }
1592 if !lu_factor_in_place(&mut scratch.imaw, &mut scratch.perm_imaw, b) {
1593 for r in 0..b {
1594 for c in 0..b {
1595 let idx = r * b + c;
1596 let id = if r == c { 1.0 } else { 0.0 };
1597 scratch.imaw[idx] = id - scratch.aw[idx];
1598 }
1599 }
1600 for d in 0..b {
1601 scratch.imaw[d * b + d] += ALO_LOCAL_BLOCK_RIDGE;
1602 }
1603 let refactored = lu_factor_in_place(&mut scratch.imaw, &mut scratch.perm_imaw, b);
1604 assert!(
1605 refactored,
1606 "ALO local variance block remained singular after ridge regularization"
1607 );
1608 }
1609
1610 let s_i = &input.scores[i];
1612 for k in 0..b {
1613 scratch.rhs_buf[k] = s_i[k];
1614 }
1615 lu_solve_in_place(
1616 &scratch.imwa,
1617 &scratch.perm_imwa,
1618 &mut scratch.rhs_buf,
1619 &mut scratch.lu_scratch,
1620 b,
1621 );
1622 for r in 0..b {
1624 let mut acc = 0.0f64;
1625 let row_off = r * b;
1626 for k in 0..b {
1627 acc += scratch.a_i[row_off + k] * scratch.rhs_buf[k];
1628 }
1629 scratch.delta_eta[r] = acc;
1630 }
1631
1632 let eta_i = &input.eta_hat[i];
1633 let mut corrected = Array1::<f64>::zeros(b);
1634 for d in 0..b {
1635 corrected[d] = eta_i[d] + scratch.delta_eta[d];
1636 }
1637 eta_tilde.push(corrected);
1638
1639 let mut cook = 0.0f64;
1641 for r in 0..b {
1642 let mut w_delta_r = 0.0f64;
1643 let row_off = r * b;
1644 for k in 0..b {
1645 w_delta_r += scratch.w_flat[row_off + k] * scratch.delta_eta[k];
1646 }
1647 cook += scratch.delta_eta[r] * w_delta_r;
1648 }
1649 cook_distance[local_i] = cook;
1650
1651 for d in 0..b {
1657 let row_off = d * b;
1658 for k in 0..b {
1660 scratch.rhs_buf[k] = scratch.a_i[row_off + k];
1661 }
1662 lu_solve_in_place(
1663 &scratch.imaw,
1664 &scratch.perm_imaw,
1665 &mut scratch.rhs_buf,
1666 &mut scratch.lu_scratch,
1667 b,
1668 );
1669 for r in 0..b {
1671 let mut acc = 0.0f64;
1672 let wr = r * b;
1673 for k in 0..b {
1674 acc += scratch.w_flat[wr + k] * scratch.rhs_buf[k];
1675 }
1676 scratch.w_u[r] = acc;
1677 }
1678 lu_solve_in_place(
1680 &scratch.imwa,
1681 &scratch.perm_imwa,
1682 &mut scratch.w_u,
1683 &mut scratch.lu_scratch,
1684 b,
1685 );
1686 let mut v_dd = 0.0f64;
1688 for k in 0..b {
1689 v_dd += scratch.a_i[row_off + k] * scratch.w_u[k];
1690 }
1691 scratch.var_diag_buf[d] = v_dd.max(0.0);
1692 }
1693 let mut var_diag = Array1::<f64>::zeros(b);
1694 for d in 0..b {
1695 var_diag[d] = scratch.var_diag_buf[d];
1696 }
1697 alo_variance.push(var_diag);
1698 }
1699
1700 Ok(MultiBlockAloChunkDiagnostics {
1701 chunk_start,
1702 eta_tilde,
1703 leverage,
1704 alo_variance,
1705 cook_distance,
1706 })
1707}
1708
/// Multiplies two `b x b` row-major matrices: `out = a * b_mat`.
///
/// Only the leading `b * b` entries of each slice are touched; the previous
/// contents of `out` are overwritten. Each output entry is accumulated in
/// ascending inner-index order starting from `0.0`.
#[inline]
fn mat_mul_flat(a: &[f64], b_mat: &[f64], out: &mut [f64], b: usize) {
    for row in 0..b {
        let base = row * b;
        let lhs_row = &a[base..base + b];
        for col in 0..b {
            out[base + col] = lhs_row
                .iter()
                .enumerate()
                .fold(0.0f64, |acc, (k, &lhs)| acc + lhs * b_mat[k * b + col]);
        }
    }
}
1724
1725fn lu_factor_in_place(m: &mut [f64], perm: &mut [usize], b: usize) -> bool {
1732 for i in 0..b {
1733 perm[i] = i;
1734 }
1735 for col in 0..b {
1736 let mut max_val = m[col * b + col].abs();
1738 let mut max_idx = col;
1739 for row in (col + 1)..b {
1740 let v = m[row * b + col].abs();
1741 if v > max_val {
1742 max_val = v;
1743 max_idx = row;
1744 }
1745 }
1746 if max_val < LU_PIVOT_SINGULAR_TOL {
1747 return false;
1748 }
1749 if max_idx != col {
1750 for k in 0..b {
1752 m.swap(col * b + k, max_idx * b + k);
1753 }
1754 perm.swap(col, max_idx);
1755 }
1756 let pivot = m[col * b + col];
1757 for row in (col + 1)..b {
1758 let factor = m[row * b + col] / pivot;
1759 m[row * b + col] = factor; for k in (col + 1)..b {
1761 let upd = factor * m[col * b + k];
1762 m[row * b + k] -= upd;
1763 }
1764 }
1765 }
1766 true
1767}
1768
/// Solves `M x = rhs` in place using the factors written by
/// `lu_factor_in_place`.
///
/// `m` holds the packed unit-lower / upper factors (row-major, `b x b`) and
/// `perm` the row permutation. The first `b` entries of `scratch` receive the
/// forward-substitution result; `rhs` is overwritten with the solution.
fn lu_solve_in_place(m: &[f64], perm: &[usize], rhs: &mut [f64], scratch: &mut [f64], b: usize) {
    let y = &mut scratch[..b];

    // Forward substitution with the unit lower factor on the permuted rhs.
    for row in 0..b {
        let (solved, pending) = y.split_at_mut(row);
        let lower = &m[row * b..row * b + row];
        pending[0] = solved
            .iter()
            .zip(lower)
            .fold(rhs[perm[row]], |acc, (&y_k, &l_rk)| acc - l_rk * y_k);
    }

    // Back substitution with the upper factor, last row first.
    for row in (0..b).rev() {
        let upper = &m[row * b..(row + 1) * b];
        let residual = upper[row + 1..]
            .iter()
            .zip(&rhs[row + 1..b])
            .fold(y[row], |acc, (&u_rk, &x_k)| acc - u_rk * x_k);
        rhs[row] = residual / upper[row];
    }
}
1790
1791pub fn compute_multiblock_alo_leverages(
1799 n_obs: usize,
1800 n_blocks: usize,
1801 block_designs: &[Array2<f64>],
1802 penalized_hessian_inv: &Array2<f64>,
1803 block_weights: &[Array2<f64>],
1804) -> Result<Array1<f64>, EstimationError> {
1805 use rayon::prelude::*;
1806
1807 let n = n_obs;
1808 let b = n_blocks;
1809 let p_tot = penalized_hessian_inv.nrows();
1810
1811 let col_offsets = multiblock_col_offsets(block_designs);
1812 let max_workers = rayon::current_num_threads();
1813 let chunk_size = multiblock_alo_parallel_leverage_chunk_size(p_tot, b, n, max_workers);
1814
1815 let mut leverage = Array1::<f64>::zeros(n);
1816
1817 let block_widths: Vec<usize> = block_designs.iter().map(|d| d.ncols()).collect();
1821 let mut h_stripes: Vec<FaerMat<f64>> = block_widths
1822 .iter()
1823 .map(|&p_blk| FaerMat::<f64>::zeros(p_tot, p_blk))
1824 .collect();
1825 for blk in 0..b {
1828 let off_b = col_offsets[blk];
1829 let p_blk = block_widths[blk];
1830 let stripe = &mut h_stripes[blk];
1831 for c in 0..p_blk {
1832 for r in 0..p_tot {
1833 stripe[(r, c)] = penalized_hessian_inv[(r, off_b + c)];
1834 }
1835 }
1836 }
1837
1838 leverage
1839 .as_slice_mut()
1840 .expect("newly allocated Array1 is contiguous")
1841 .par_chunks_mut(chunk_size)
1842 .enumerate()
1843 .for_each(|(chunk_idx, leverage_chunk)| {
1844 let chunk_start = chunk_idx * chunk_size;
1845 let chunk_len = leverage_chunk.len();
1846 let chunk_end = chunk_start + chunk_len;
1847
1848 let bb_sz = b * b;
1852 let mut a_i = vec![0.0f64; bb_sz];
1853 let mut aw = vec![0.0f64; bb_sz];
1854 let mut w_flat = vec![0.0f64; bb_sz];
1855
1856 let mut q_storage: Vec<FaerMat<f64>> = block_widths
1860 .iter()
1861 .map(|_| FaerMat::<f64>::zeros(p_tot, chunk_len))
1862 .collect();
1863
1864 let mut xt_storage: Vec<FaerMat<f64>> = block_widths
1868 .iter()
1869 .map(|&p_blk| FaerMat::<f64>::zeros(p_blk, chunk_len))
1870 .collect();
1871
1872 for blk in 0..b {
1877 let p_blk = block_widths[blk];
1878
1879 let x_chunk = block_designs[blk].slice(s![chunk_start..chunk_end, ..]);
1880 let xt = &mut xt_storage[blk];
1881 for local_i in 0..chunk_len {
1882 let row = x_chunk.row(local_i);
1883 for j in 0..p_blk {
1884 xt[(j, local_i)] = row[j];
1885 }
1886 }
1887
1888 matmul(
1889 q_storage[blk].as_mut(),
1890 Accum::Replace,
1891 h_stripes[blk].as_ref(),
1892 xt_storage[blk].as_ref(),
1893 1.0,
1894 Par::Seq,
1895 );
1896 }
1897
1898 for local_i in 0..chunk_len {
1899 let i = chunk_start + local_i;
1900 let w_i = &block_weights[i];
1901
1902 for r in 0..b {
1904 for c in 0..b {
1905 w_flat[r * b + c] = w_i[(r, c)];
1906 }
1907 }
1908
1909 for r in 0..bb_sz {
1913 a_i[r] = 0.0;
1914 }
1915 for k in 0..b {
1916 let q_k = &q_storage[k];
1917 let q_col = q_k.col_as_slice(local_i);
1918 for a in 0..b {
1919 let p_a = block_widths[a];
1920 let off_a = col_offsets[a];
1921 let xa_row = block_designs[a].row(i);
1922 let mut dot = 0.0f64;
1923 for j in 0..p_a {
1924 dot = xa_row[j].mul_add(q_col[off_a + j], dot);
1925 }
1926 a_i[a * b + k] = dot;
1927 }
1928 }
1929
1930 mat_mul_flat(&a_i, &w_flat, &mut aw, b);
1932 let mut tr = 0.0f64;
1933 for d in 0..b {
1934 tr += aw[d * b + d];
1935 }
1936 leverage_chunk[local_i] = tr;
1937 }
1938 });
1939
1940 Ok(leverage)
1941}
1942
#[cfg(test)]
mod tests {
    //! Unit tests for the scalar ALO update rules, the variance helpers, the
    //! percentile utilities and the multi-block ALO diagnostics.

    use super::{
        ALO_EXACT_SCALAR_MAX_ITERS, AloExactScalarError, AloInput, MultiBlockAloInput,
        alo_eta_exact_frozen_curvature, alo_eta_updatewith_offset, bayesvar_eta,
        compute_alo_from_input_inner, compute_multiblock_alo, compute_multiblock_alo_leverages,
        percentile_from_sorted, percentile_index, sandwichvar_eta,
    };
    use gam_linalg::matrix::{PsdWeightsView, SignedWeightsView};
    use gam_problem::LinkFunction;
    use ndarray::{Array1, Array2};

    /// With unit weights the offset-aware update equals the classic formula
    /// applied to offset-centred quantities.
    #[test]
    fn alo_offset_update_matches_centered_algebra() {
        let (eta_hat, z, offset) = (11.0, 13.0, 10.0);
        let x_hinv_x = 0.2;
        let hessian_weight = 1.0;
        let score_weight = 1.0;
        let leverage = hessian_weight * x_hinv_x;
        let denominator = 1.0 - leverage;

        let centered = ((eta_hat - offset) - leverage * (z - offset)) / denominator;
        let expected = offset + centered;
        let got =
            alo_eta_updatewith_offset(eta_hat, z, offset, x_hinv_x, score_weight, denominator);
        assert!((got - expected).abs() < 1e-12);
    }

    /// A zero offset recovers the classic leave-one-out update.
    #[test]
    fn alo_offset_update_reduces_to_classicwhen_offsetzero() {
        let (eta_hat, z) = (1.25, -0.5);
        let x_hinv_x = 0.35;
        let hessian_weight = 1.0;
        let score_weight = 1.0;
        let leverage = hessian_weight * x_hinv_x;
        let denominator = 1.0 - leverage;

        let expected = (eta_hat - leverage * z) / denominator;
        let got = alo_eta_updatewith_offset(eta_hat, z, 0.0, x_hinv_x, score_weight, denominator);
        assert!((got - expected).abs() < 1e-12);
    }

    /// The score weight scales the numerator while the Hessian weight only
    /// enters through the denominator.
    #[test]
    fn alo_offset_update_uses_distinct_score_and_hessian_weights() {
        let (eta_hat, z, offset) = (1.7, 0.4, -0.2);
        let x_hinv_x = 0.15;
        let hessian_weight = 3.0;
        let score_weight = 5.0;
        let denominator = 1.0 - hessian_weight * x_hinv_x;

        let centered = eta_hat - offset;
        let correction = x_hinv_x * score_weight * (centered - (z - offset)) / denominator;
        let expected = offset + centered + correction;
        let got =
            alo_eta_updatewith_offset(eta_hat, z, offset, x_hinv_x, score_weight, denominator);
        assert!((got - expected).abs() < 1e-12);
    }

    /// A zero Hessian weight leaves a unit denominator; the correction is
    /// still driven by the score weight.
    #[test]
    fn alo_offset_update_handles_zero_hessian_weight() {
        let (eta_hat, z, offset) = (0.8, -0.3, 0.1);
        let x_hinv_x = 0.4;
        let hessian_weight = 0.0;
        let score_weight = 2.5;
        let denominator = 1.0 - hessian_weight * x_hinv_x;

        let centered = eta_hat - offset;
        let correction = x_hinv_x * score_weight * (centered - (z - offset));
        let expected = offset + centered + correction;
        let got =
            alo_eta_updatewith_offset(eta_hat, z, offset, x_hinv_x, score_weight, denominator);
        assert!((got - expected).abs() < 1e-12);
    }

    #[test]
    fn alo_exact_frozen_curvature_converges_to_fixed_point() {
        let eta_hat = 1.0;
        let a_ii = 0.4;
        let solved =
            alo_eta_exact_frozen_curvature(eta_hat, a_ii, &|eta| (0.5 * (eta - 2.0), 0.5));
        let got = solved.expect("linear scalar fixed point should converge in one Newton step");
        assert!((got - 0.75).abs() < 1e-12);
    }

    #[test]
    fn alo_exact_frozen_curvature_reports_nonconvergence() {
        let outcome = alo_eta_exact_frozen_curvature(0.0, 1.0, &|eta| (eta + 1.0, 0.0));
        let err =
            outcome.expect_err("constant residual should exhaust the scalar iteration budget");
        let AloExactScalarError::MaxIterations { iterations, .. } = err else {
            panic!("constant residual must report MaxIterations, got {err:?}");
        };
        assert_eq!(
            iterations, ALO_EXACT_SCALAR_MAX_ITERS,
            "non-convergence must report the full scalar iteration budget"
        );
    }

    /// A non-converging exact scalar solve must surface both the row index
    /// and the cause in the error text.
    #[test]
    fn alo_input_reports_exact_scalar_nonconvergence_with_row_context() {
        let design = Array2::from_elem((1, 1), 1.0);
        let penalized_hessian = Array2::from_elem((1, 1), 1.0);
        let zeros = || Array1::from_vec(vec![0.0]);
        let hessian_weights = zeros();
        let score_weights = zeros();
        let working_response = zeros();
        let eta = zeros();
        let offset = zeros();
        let score_curvature = |_: usize, eta: f64| (eta + 1.0, 0.0);

        let input = AloInput {
            design: &design,
            penalized_hessian: &penalized_hessian,
            hessian_weights: SignedWeightsView::from_array(&hessian_weights),
            score_weights: PsdWeightsView::try_from_array(&score_weights).expect("psd weights"),
            working_response: &working_response,
            eta: &eta,
            offset: &offset,
            link: LinkFunction::Logit,
            phi: 1.0,
            penalty_root: None,
            ridge: 0.0,
            score_curvature: Some(&score_curvature),
        };

        let msg = compute_alo_from_input_inner(&input)
            .expect_err("non-converged exact ALO must error")
            .to_string();
        assert!(
            msg.contains("ALO exact frozen-curvature solve failed at row 0"),
            "missing row context in exact ALO error: {msg}"
        );
        assert!(
            msg.contains("did not converge within"),
            "missing non-convergence cause in exact ALO error: {msg}"
        );
    }

    /// Without penalty or ridge terms the sandwich and Bayesian variances agree.
    #[test]
    fn gaussian_unpenalized_sandwich_equals_bayes() {
        let phi = 2.5;
        let x_hinv_x = 0.3;
        let (es_norm2, ridge, s_norm2) = (0.0, 0.0, 0.0);

        let bayes = bayesvar_eta(phi, x_hinv_x);
        let sandwich = sandwichvar_eta(phi, x_hinv_x, es_norm2, ridge, s_norm2);
        assert!((bayes - sandwich).abs() < 1e-12);
    }

    #[test]
    fn sandwich_matches_direct_linear_gaussian_formula() {
        let phi = 1.7;
        let x_hinv_x = 0.41;
        let es_norm2 = 0.05;
        let ridge = 1e-3;
        let s_norm2 = 2.0;

        let expected = phi * (x_hinv_x - es_norm2 - ridge * s_norm2);
        let got = sandwichvar_eta(phi, x_hinv_x, es_norm2, ridge, s_norm2);
        assert!((got - expected).abs() < 1e-12);
    }

    #[test]
    fn percentile_index_matches_expected_rounding() {
        let cases = [(0, 0.95, 0), (1, 0.95, 0), (10, 0.50, 5), (10, 0.95, 9)];
        for (len, quantile, want) in cases {
            assert_eq!(percentile_index(len, quantile), want);
        }
    }

    #[test]
    fn percentile_from_sorted_returns_order_statistic() {
        let values = [1.0, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(percentile_from_sorted(&values, 0.50), 3.0);
        assert_eq!(percentile_from_sorted(&values, 0.95), 5.0);
        assert_eq!(percentile_from_sorted(&[], 0.95), 0.0);
    }

    /// With a single block the multi-block leverage must equal the scalar
    /// leverage `w_i * x_i^T H^{-1} x_i`.
    #[test]
    fn multiblock_b1_matches_scalar_leverage() {
        let n = 3;
        let p = 2;
        let x = Array2::from_shape_vec((n, p), vec![1.0, 0.5, 0.8, -0.3, 0.2, 1.1]).unwrap();
        let w = [1.0, 2.0, 0.5];

        // H = I + sum_i w_i x_i x_i^T.
        let mut h = Array2::<f64>::eye(p);
        for (i, &wi) in w.iter().enumerate() {
            for r in 0..p {
                for c in 0..p {
                    h[(r, c)] += wi * x[(i, r)] * x[(i, c)];
                }
            }
        }

        // Closed-form 2x2 inverse.
        let det = h[(0, 0)] * h[(1, 1)] - h[(0, 1)] * h[(1, 0)];
        let mut h_inv = Array2::<f64>::zeros((p, p));
        h_inv[(0, 0)] = h[(1, 1)] / det;
        h_inv[(1, 1)] = h[(0, 0)] / det;
        h_inv[(0, 1)] = -h[(0, 1)] / det;
        h_inv[(1, 0)] = -h[(1, 0)] / det;

        let scalar_lev: Vec<f64> = (0..n)
            .map(|i| {
                let mut xhx = 0.0;
                for r in 0..p {
                    for c in 0..p {
                        xhx += x[(i, r)] * h_inv[(r, c)] * x[(i, c)];
                    }
                }
                w[i] * xhx
            })
            .collect();

        let block_weights: Vec<Array2<f64>> =
            w.iter().map(|&wi| Array2::from_elem((1, 1), wi)).collect();
        let scores: Vec<Array1<f64>> = (0..n).map(|_| Array1::from_vec(vec![0.1])).collect();
        let eta_hat: Vec<Array1<f64>> = (0..n).map(|i| Array1::from_vec(vec![i as f64])).collect();
        let block_designs = vec![x];

        let input = MultiBlockAloInput {
            n_obs: n,
            n_blocks: 1,
            block_designs: &block_designs,
            penalized_hessian_inv: &h_inv,
            block_weights,
            scores,
            eta_hat,
        };

        let result = compute_multiblock_alo(&input).unwrap();
        for (i, (&got, &want)) in result.leverage.iter().zip(&scalar_lev).enumerate() {
            assert!(
                (got - want).abs() < 1e-10,
                "leverage mismatch at i={}: got {}, expected {}",
                i,
                got,
                want
            );
        }
    }

    /// The leverage-only entry point must agree with the full diagnostics.
    #[test]
    fn multiblock_leverage_only_matches_full() {
        let n = 4;
        let p1 = 2;
        let p2 = 3;
        let x1 = Array2::from_shape_fn((n, p1), |(i, j)| (i + j + 1) as f64 * 0.3);
        let x2 = Array2::from_shape_fn((n, p2), |(i, j)| (i * 2 + j) as f64 * 0.2 - 0.1);
        let p_tot = p1 + p2;
        let h_inv = Array2::<f64>::eye(p_tot);

        let block_weights: Vec<Array2<f64>> = (0..n)
            .map(|i| {
                let v = (i + 1) as f64;
                Array2::from_shape_vec((2, 2), vec![v, 0.1, 0.1, v * 0.5]).unwrap()
            })
            .collect();
        let zero_pairs =
            || -> Vec<Array1<f64>> { (0..n).map(|_| Array1::from_vec(vec![0.0, 0.0])).collect() };
        let scores = zero_pairs();
        let eta_hat = zero_pairs();
        let block_designs = vec![x1, x2];

        let input = MultiBlockAloInput {
            n_obs: n,
            n_blocks: 2,
            block_designs: &block_designs,
            penalized_hessian_inv: &h_inv,
            block_weights: block_weights.clone(),
            scores,
            eta_hat,
        };
        let full = compute_multiblock_alo(&input).unwrap();
        let lev_only =
            compute_multiblock_alo_leverages(n, 2, &block_designs, &h_inv, &block_weights).unwrap();

        for (i, (&from_full, &from_lev_only)) in
            full.leverage.iter().zip(lev_only.iter()).enumerate()
        {
            assert!(
                (from_full - from_lev_only).abs() < 1e-12,
                "leverage mismatch at i={}: full={}, lev_only={}",
                i,
                from_full,
                from_lev_only
            );
        }
    }

    /// A zero weight still yields the score-driven correction
    /// `x^T H^{-1} x * s = 1.25`, with zero Cook distance and variance.
    #[test]
    fn multiblock_singular_weight_still_corrects() {
        let n = 1;
        let p = 2;
        let x = Array2::from_shape_vec((1, p), vec![1.0, 0.5]).unwrap();
        let h_inv = Array2::<f64>::eye(p);
        let block_designs = vec![x];
        let block_weights = vec![Array2::from_elem((1, 1), 0.0)];
        let scores = vec![Array1::from_vec(vec![1.0])];
        let eta_hat = vec![Array1::from_vec(vec![std::f64::consts::PI])];

        let input = MultiBlockAloInput {
            n_obs: n,
            n_blocks: 1,
            block_designs: &block_designs,
            penalized_hessian_inv: &h_inv,
            block_weights,
            scores,
            eta_hat,
        };
        let result = compute_multiblock_alo(&input).unwrap();

        let expected = std::f64::consts::PI + 1.25;
        let corrected = result.eta_tilde[0][0];
        assert!(
            (corrected - expected).abs() < 1e-12,
            "expected {}, got {}",
            expected,
            corrected
        );
        assert!(result.cook_distance[0].abs() < 1e-14);
        assert!(result.alo_variance[0][0].abs() < 1e-14);
    }

    /// Here `W * A = 1`, so `I - W A` is exactly singular and the diagnostics
    /// must stay finite through the ridge retry.
    #[test]
    fn multiblock_cook_and_variance_basic() {
        let n = 1;
        let x = Array2::from_elem((1, 1), 1.0);
        let h_inv = Array2::from_elem((1, 1), 0.5);
        let block_designs = vec![x];
        let w_val = 2.0;
        let s_val = 0.4;
        let block_weights = vec![Array2::from_elem((1, 1), w_val)];
        let scores = vec![Array1::from_vec(vec![s_val])];
        let eta_hat = vec![Array1::from_vec(vec![1.0])];

        let input = MultiBlockAloInput {
            n_obs: n,
            n_blocks: 1,
            block_designs: &block_designs,
            penalized_hessian_inv: &h_inv,
            block_weights,
            scores,
            eta_hat,
        };
        let result = compute_multiblock_alo(&input).unwrap();

        assert!(result.eta_tilde[0][0].is_finite());
        assert!(result.cook_distance[0].is_finite());
        assert!(result.alo_variance[0][0].is_finite());
    }
}