1use crate::estimate::EstimationError;
2use crate::estimate::{FitGeometry, UnifiedFitResult};
3use crate::pirls;
4use faer::Mat as FaerMat;
5use faer::linalg::matmul::matmul;
6use faer::prelude::ReborrowMut;
7use faer::{Accum, Par};
8use gam_linalg::faer_ndarray::{FaerArrayView, FaerCholesky};
9use gam_linalg::matrix::{PsdWeightsView, SignedWeightsView};
10use gam_linalg::utils::StableSolver;
11use gam_problem::LinkFunction;
12use ndarray::{Array1, Array2, ArrayView1, ShapeBuilder, s};
13use std::fmt;
14
/// Failure modes of the ALO (approximate leave-one-out) diagnostics.
///
/// Every variant except [`AloError::InfluenceMatrixFailed`] carries a human-readable
/// `reason`, and its `Display` output is exactly that reason.
#[derive(Debug, Clone)]
pub enum AloError {
    /// Inputs have inconsistent shapes or non-finite / out-of-range values.
    InvalidInput { reason: String },
    /// Hessian-side weights, score-side weights, or working responses are unusable.
    WeightInvalid { reason: String },
    /// The design could not be materialized in the form ALO requires.
    DesignDegenerate { reason: String },
    /// The penalized Hessian could not be factorized, or the outputs contained NaN.
    InfluenceMatrixFailed { condition_number: f64 },
    /// A per-observation leave-one-out quantity could not be computed.
    LooComputationFailed { reason: String },
}

impl fmt::Display for AloError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message: &str = match self {
            Self::InfluenceMatrixFailed { condition_number } => {
                // The only variant without a stored reason: render it from the number.
                return write!(
                    f,
                    "ALO influence matrix failed (condition number {condition_number:.3e})"
                );
            }
            Self::InvalidInput { reason }
            | Self::WeightInvalid { reason }
            | Self::DesignDegenerate { reason }
            | Self::LooComputationFailed { reason } => reason,
        };
        f.write_str(message)
    }
}

impl std::error::Error for AloError {}
61
62impl From<AloError> for EstimationError {
63 fn from(err: AloError) -> EstimationError {
64 match err {
65 AloError::InvalidInput { reason }
66 | AloError::WeightInvalid { reason }
67 | AloError::DesignDegenerate { reason }
68 | AloError::LooComputationFailed { reason } => EstimationError::InvalidInput(reason),
69 AloError::InfluenceMatrixFailed { condition_number } => {
70 EstimationError::ModelIsIllConditioned { condition_number }
71 }
72 }
73 }
74}
75
76impl From<AloError> for String {
77 fn from(err: AloError) -> String {
78 err.to_string()
79 }
80}
81
/// Per-observation outputs of the single-block ALO computation.
///
/// All arrays have one entry per observation (row of the design).
#[derive(Debug, Clone)]
pub struct AloDiagnostics {
    /// Approximate leave-one-out linear predictor for each observation.
    pub eta_tilde: Array1<f64>,
    /// Bayesian standard error of the linear predictor: `sqrt(phi * x_i^T H^{-1} x_i)`
    /// (clamped at zero before the square root).
    pub se_bayes: Array1<f64>,
    /// Sandwich standard error: `sqrt(phi * sum_r w_s[r] * (x_r^T H^{-1} x_i)^2)`,
    /// using the score-side weights `w_s` (clamped at zero before the square root).
    pub se_sandwich: Array1<f64>,
    /// Copy of the fitted (in-sample) linear predictor supplied as input.
    pub pred_identity: Array1<f64>,
    /// Leverage `a_ii = max(w_h[i], 0) * x_i^T H^{-1} x_i` with Hessian-side weights `w_h`.
    pub leverage: Array1<f64>,
    /// Copy of the Hessian-side weights used for the leverage.
    pub fisherweights: Array1<f64>,
}
96
/// One-step ALO update of a single linear predictor, evaluated relative to `offset`.
///
/// With `eta_c = eta_hat - offset` and `z_c = z - offset`, returns
/// `offset + eta_c + x_hinv_x * (score_weight * (eta_c - z_c)) / denom`.
/// `denom` is expected to be `1 - a_ii`; the caller checks that it is safely positive.
#[inline]
fn alo_eta_updatewith_offset(
    eta_hat: f64,
    z: f64,
    offset: f64,
    x_hinv_x: f64,
    score_weight: f64,
    denom: f64,
) -> f64 {
    let (eta_c, z_c) = (eta_hat - offset, z - offset);
    let weighted_score = score_weight * (eta_c - z_c);
    let correction = x_hinv_x * weighted_score / denom;
    offset + eta_c + correction
}
113
/// Per-observation callback `(row, eta) -> (ell_prime, ell_double)` used by the exact
/// frozen-curvature ALO refinement.
///
/// The solver uses `ell_prime` in the residual `eta - eta_hat - a * ell_prime` and
/// `ell_double` in the Newton Jacobian `1 - a * ell_double`. So `ell_double` is expected to
/// be the derivative of `ell_prime` with respect to `eta`. Non-finite values are treated
/// as failures by the solver.
pub type AloScalarScoreCurvature<'a> = dyn Fn(usize, f64) -> (f64, f64) + Sync + 'a;

/// Maximum damped-Newton iterations for the exact frozen-curvature solve.
const ALO_EXACT_SCALAR_MAX_ITERS: usize = 64;

/// Absolute residual tolerance at which the exact frozen-curvature solve stops.
const ALO_EXACT_SCALAR_TOL: f64 = 1e-12;
136
/// Failure modes of the per-observation exact frozen-curvature Newton solve.
///
/// Each variant records the iterate (`eta`) and the numbers that triggered the failure,
/// so the `Display` message is self-contained.
#[derive(Debug, Clone, Copy, PartialEq)]
enum AloExactScalarError {
    /// The score/curvature callback returned a non-finite value at `eta`.
    NonFiniteScoreCurvature {
        eta: f64,
        ell_prime: f64,
        ell_double: f64,
    },
    /// `|jacobian| <= ALO_DENOMINATOR_MIN`, or the Jacobian is non-finite.
    DegenerateJacobian {
        eta: f64,
        jacobian: f64,
    },
    /// `residual / jacobian` is not finite; `next` is the would-be iterate.
    NonFiniteStep {
        eta: f64,
        residual: f64,
        jacobian: f64,
        next: f64,
    },
    /// No convergence: either the iteration budget ran out or no backtracked step
    /// reduced the residual.
    MaxIterations {
        iterations: usize,
        residual: f64,
        eta: f64,
    },
}
180
181impl fmt::Display for AloExactScalarError {
182 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
183 match *self {
184 AloExactScalarError::NonFiniteScoreCurvature {
185 eta,
186 ell_prime,
187 ell_double,
188 } => write!(
189 f,
190 "non-finite score/curvature at eta={eta:.6e}: ell_prime={ell_prime:.6e}, ell_double={ell_double:.6e}"
191 ),
192 AloExactScalarError::DegenerateJacobian { eta, jacobian } => write!(
193 f,
194 "degenerate Newton Jacobian at eta={eta:.6e}: jacobian={jacobian:.6e}, min={ALO_DENOMINATOR_MIN:.1e}"
195 ),
196 AloExactScalarError::NonFiniteStep {
197 eta,
198 residual,
199 jacobian,
200 next,
201 } => write!(
202 f,
203 "non-finite Newton step from eta={eta:.6e}: residual={residual:.6e}, jacobian={jacobian:.6e}, next={next:.6e}"
204 ),
205 AloExactScalarError::MaxIterations {
206 iterations,
207 residual,
208 eta,
209 } => write!(
210 f,
211 "did not converge within {iterations} iterations: residual={residual:.6e}, eta={eta:.6e}, tol={ALO_EXACT_SCALAR_TOL:.1e}"
212 ),
213 }
214 }
215}
216
/// Maximum number of step halvings per Newton iteration in the exact frozen-curvature
/// solve before the iteration is declared stagnant.
const ALO_EXACT_SCALAR_BACKTRACKS: usize = 40;
223
224#[inline]
225fn alo_eta_exact_frozen_curvature(
226 eta_hat: f64,
227 a_ii: f64,
228 score_curvature: &dyn Fn(f64) -> (f64, f64),
229) -> Result<f64, AloExactScalarError> {
230 let residual_and_jac = |eta: f64| -> Result<(f64, f64), AloExactScalarError> {
254 let (ell_prime, ell_double) = score_curvature(eta);
255 if !ell_prime.is_finite() || !ell_double.is_finite() {
256 return Err(AloExactScalarError::NonFiniteScoreCurvature {
257 eta,
258 ell_prime,
259 ell_double,
260 });
261 }
262 Ok((eta - eta_hat - a_ii * ell_prime, 1.0 - a_ii * ell_double))
263 };
264
265 let mut eta = eta_hat;
266 let (mut residual, mut jac) = residual_and_jac(eta)?;
267 for _ in 0..ALO_EXACT_SCALAR_MAX_ITERS {
268 if residual.abs() <= ALO_EXACT_SCALAR_TOL {
269 return Ok(eta);
270 }
271 if jac.abs() <= ALO_DENOMINATOR_MIN || !jac.is_finite() {
272 return Err(AloExactScalarError::DegenerateJacobian { eta, jacobian: jac });
273 }
274 let step = residual / jac;
275 if !step.is_finite() {
276 return Err(AloExactScalarError::NonFiniteStep {
277 eta,
278 residual,
279 jacobian: jac,
280 next: eta - step,
281 });
282 }
283 let mut t = 1.0;
288 let mut advanced = false;
289 for _ in 0..ALO_EXACT_SCALAR_BACKTRACKS {
290 let trial = eta - t * step;
291 if let Ok((r_trial, j_trial)) = residual_and_jac(trial) {
292 if r_trial.abs() < residual.abs() {
293 eta = trial;
294 residual = r_trial;
295 jac = j_trial;
296 advanced = true;
297 break;
298 }
299 }
300 t *= 0.5;
301 }
302 if !advanced {
303 break;
304 }
305 }
306 Err(AloExactScalarError::MaxIterations {
307 iterations: ALO_EXACT_SCALAR_MAX_ITERS,
308 residual,
309 eta,
310 })
311}
312
/// Bayesian (posterior) variance of a linear predictor: `phi * x_i^T H^{-1} x_i`.
#[inline]
fn bayesvar_eta(phi: f64, x_hinv_x: f64) -> f64 {
    x_hinv_x * phi
}

/// Sandwich variance of a linear predictor from its "meat" quadratic form:
/// `phi * sum_r w_s[r] * (x_r^T H^{-1} x_i)^2`.
#[inline]
fn sandwichvar_eta_from_meat(phi: f64, meat_quad: f64) -> f64 {
    meat_quad * phi
}

/// Slack below zero tolerated for a computed variance before it counts as materially
/// negative: `1e-12 * max(|scale|, 1)`. A NaN scale falls back to the floor of 1.
#[inline]
fn variance_negative_tolerance(scale: f64) -> f64 {
    const RELATIVE: f64 = 1e-12;
    let magnitude = f64::max(scale.abs(), 1.0);
    RELATIVE * magnitude
}
328
/// Leverages above this value are counted as "high" in the aggregate log warning.
const LEVERAGE_HIGH_THRESHOLD: f64 = 0.99;
/// Leverages above this value are additionally logged one by one.
const LEVERAGE_VERY_HIGH_THRESHOLD: f64 = 0.999;
/// Thresholds for which the share of observations exceeding them is logged.
const LEVERAGE_RATE_THRESHOLDS: [f64; 3] = [0.90, 0.95, 0.99];
/// Quantiles (median, p95, p99) of finite leverages reported in the log.
const LEVERAGE_PERCENTILES: [f64; 3] = [0.50, 0.95, 0.99];
/// Minimum allowed `1 - a_ii`, `|Newton Jacobian|`, and `|dmu/deta|` before a row is
/// treated as numerically degenerate.
const ALO_DENOMINATOR_MIN: f64 = 1e-12;
/// Memory budget (bytes) used to size chunks and concurrency in multi-block ALO.
const MULTIBLOCK_ALO_MEMORY_BUDGET_BYTES: usize = 256 * 1024 * 1024;

/// Maximum number of observations whose `H^{-1} x_i` columns are solved per chunk.
const ALO_RHS_BLOCK_COLS: usize = 8192;

/// Tolerance for the Hessian symmetry check. It is scaled by `max(|h_ij|, |h_ji|, 1)`, so it
/// is absolute for small entries.
const HESSIAN_SYMMETRY_REL_TOL: f64 = 1e-8;

/// Ridge for local block solves. Not referenced in the code visible here — presumably used
/// by the multi-block chunk solver; confirm.
const ALO_LOCAL_BLOCK_RIDGE: f64 = 1e-6;

/// Pivot magnitude below which an LU factorization is treated as singular. Not referenced
/// in the code visible here — presumably used by the multi-block chunk solver; confirm.
const LU_PIVOT_SINGULAR_TOL: f64 = 1e-12;
361
/// Nearest-rank index of `quantile` in a sorted sample of `sample_size` elements.
///
/// Rounds `quantile * (sample_size - 1)` and clamps into range. Samples with zero or one
/// element always map to index 0.
#[inline]
fn percentile_index(sample_size: usize, quantile: f64) -> usize {
    match sample_size.checked_sub(1) {
        None | Some(0) => 0,
        Some(last) => {
            let scaled = (quantile * last as f64).round() as usize;
            scaled.min(last)
        }
    }
}

/// Value at `quantile` of an ascending-sorted slice, or `0.0` when the slice is empty.
#[inline]
fn percentile_from_sorted(sorted: &[f64], quantile: f64) -> f64 {
    let idx = percentile_index(sorted.len(), quantile);
    sorted.get(idx).copied().unwrap_or(0.0)
}
379
380#[inline]
381fn multiblock_col_offsets(block_designs: &[Array2<f64>]) -> Vec<usize> {
382 let mut offsets = Vec::with_capacity(block_designs.len());
383 let mut off = 0usize;
384 for design in block_designs {
385 offsets.push(off);
386 off += design.ncols();
387 }
388 offsets
389}
390
391#[inline]
392fn multiblock_alo_parallel_leverage_chunk_size(
393 p_tot: usize,
394 n_blocks: usize,
395 n_obs: usize,
396 max_workers: usize,
397) -> usize {
398 if p_tot == 0 || n_blocks == 0 || n_obs == 0 {
399 return 1;
400 }
401
402 let workers = max_workers.max(1);
408 let per_worker_budget = (MULTIBLOCK_ALO_MEMORY_BUDGET_BYTES / workers).max(1);
409 let elem_count_per_obs = p_tot.saturating_mul(n_blocks.saturating_add(1)).max(1);
410 let bytes_per_obs = elem_count_per_obs
411 .saturating_mul(std::mem::size_of::<f64>())
412 .max(1);
413 let budget_obs = (per_worker_budget / bytes_per_obs).max(1);
414 budget_obs.min(n_obs)
415}
416
417fn compute_alo_diagnostics_from_pirls_impl(
418 base: &pirls::PirlsResult,
419 y: ArrayView1<f64>,
420 link: LinkFunction,
421) -> Result<AloDiagnostics, EstimationError> {
422 compute_alo_diagnostics_from_pirls_inner(base, y, link).map_err(EstimationError::from)
423}
424
425fn alo_link_needs_exact_curvature_refinement(likelihood: &gam_problem::GlmLikelihoodSpec) -> bool {
438 use gam_problem::ResponseFamily;
439 matches!(
440 (&likelihood.spec.response, likelihood.link_function()),
441 (ResponseFamily::Binomial, LinkFunction::Logit)
442 | (ResponseFamily::Poisson, LinkFunction::Log)
443 )
444}
445
/// Computes single-block ALO diagnostics from a converged PIRLS result.
///
/// Steps:
/// 1. Materializes the dense transformed design (fails with `DesignDegenerate`).
/// 2. Chooses the dispersion `phi`: fixed at 1.0 for the listed non-identity links. For
///    `Identity` it is the weighted RSS divided by `max(n_pos - edf, 1)`, where `n_pos`
///    counts rows with strictly positive final weight.
/// 3. Fetches the dense stabilized penalized Hessian. Any error is mapped to `InvalidInput`.
/// 4. For Binomial/logit and Poisson/log fits, builds a per-row scale
///    `c_i = w_i / (dmu/deta)_i` and a score/curvature callback
///    `eta -> (c_i * (mu(eta) - y_i), c_i * dmu/deta(eta))` for the exact frozen-curvature
///    refinement. Other fits use the one-step update.
/// 5. Runs `compute_alo_from_input_inner`, logs leverage statistics, and rejects any NaN
///    in the outputs with `InfluenceMatrixFailed { condition_number: inf }`.
fn compute_alo_diagnostics_from_pirls_inner(
    base: &pirls::PirlsResult,
    y: ArrayView1<f64>,
    link: LinkFunction,
) -> Result<AloDiagnostics, AloError> {
    let x_dense_arc = base
        .x_transformed
        .try_to_dense_arc("ALO diagnostics require dense transformed design")
        .map_err(|reason| AloError::DesignDegenerate { reason })?;
    let x_dense = x_dense_arc.as_ref();
    let n = x_dense.nrows();

    // Dispersion: fixed at 1 for these links, estimated from residuals for Identity.
    let phi = match link {
        LinkFunction::Log => 1.0,
        LinkFunction::Logit
        | LinkFunction::Probit
        | LinkFunction::CLogLog
        | LinkFunction::LogLog
        | LinkFunction::Cauchit
        | LinkFunction::Sas
        | LinkFunction::BetaLogistic => 1.0,
        LinkFunction::Identity => {
            use rayon::iter::{IntoParallelIterator, ParallelIterator};
            // Weighted residual sum of squares against the final fitted mean.
            let rss: f64 = (0..n)
                .into_par_iter()
                .map(|i| {
                    let r = y[i] - base.finalmu[i];
                    base.finalweights[i] * r * r
                })
                .sum();
            // Only rows with positive weight contribute residual degrees of freedom.
            let n_pos = (0..n).filter(|&i| base.finalweights[i] > 0.0).count();
            let dof = (n_pos as f64) - base.edf;
            // Floor at 1 so edf >= n_pos never yields a zero or negative divisor.
            let denom = dof.max(1.0);
            rss / denom
        }
    };

    let e = &base.reparam_result.e_transformed;
    // Negative ridge values are clamped to zero.
    let ridge = base.ridge_passport.laplacehessianridge().max(0.0);

    let h_dense_for_alo = base
        .dense_stabilizedhessian_transformed(
            "ALO diagnostics require exact dense stabilized penalized Hessian",
        )
        .map_err(|e| match e {
            EstimationError::InvalidInput(reason) => AloError::InvalidInput { reason },
            other => AloError::InvalidInput {
                reason: format!("{other:?}"),
            },
        })?;

    // Per-row scale c_i = w_i / (dmu/deta)_i, only for fits that use the exact refinement.
    // Rows with tiny or non-finite dmu/deta, or non-finite weight, get NaN. The callback
    // then returns NaN at eta_hat, the exact solve reports a non-finite score/curvature,
    // and the ALO computation fails with LooComputationFailed.
    let canonical_scale: Option<Array1<f64>> =
        if alo_link_needs_exact_curvature_refinement(&base.likelihood) {
            let mut c = Array1::<f64>::zeros(n);
            for i in 0..n {
                let dmu = base.solve_dmu_deta[i];
                let w_h = base.finalweights[i];
                c[i] = if dmu.abs() <= ALO_DENOMINATOR_MIN || !dmu.is_finite() || !w_h.is_finite() {
                    f64::NAN
                } else {
                    w_h / dmu
                };
            }
            Some(c)
        } else {
            None
        };

    let inv_link_for_closure = base.likelihood.spec.link.clone();
    // Score/curvature callback (row, eta) -> (c_i * (mu - y_i), c_i * dmu/deta).
    // A failing inverse-link evaluation yields NaN, which the solver treats as non-finite.
    let score_curvature_closure = canonical_scale.as_ref().map(|scale| {
        move |i: usize, eta: f64| -> (f64, f64) {
            let (mu, dmu) = crate::mixture_link::inverse_link_mu_d1_for_inverse_link(
                &inv_link_for_closure,
                eta,
            )
            .unwrap_or((f64::NAN, f64::NAN));
            let c_i = scale[i];
            (c_i * (mu - y[i]), c_i * dmu)
        }
    });
    let score_curvature_ref: Option<&AloScalarScoreCurvature> = score_curvature_closure
        .as_ref()
        .map(|f| f as &AloScalarScoreCurvature);

    let input = AloInput {
        design: x_dense,
        penalized_hessian: &h_dense_for_alo,
        hessian_weights: base.final_weights_signed(),
        score_weights: base.solve_weights_psd(),
        working_response: &base.solveworking_response,
        eta: &base.final_eta,
        offset: &base.final_offset,
        link,
        phi,
        // An empty penalty root (no rows) is passed as "no penalty root".
        penalty_root: if e.nrows() > 0 { Some(e) } else { None },
        ridge,
        score_curvature: score_curvature_ref,
    };

    let result = compute_alo_from_input_inner(&input)?;

    log_leverage_diagnostics(&result.leverage, phi);

    // Defensive final check: any NaN in the outputs is reported as an influence failure.
    let has_nan_pred = result.eta_tilde.iter().any(|&x| x.is_nan());
    let has_nan_se_bayes = result.se_bayes.iter().any(|&x| x.is_nan());
    let has_nan_se_sandwich = result.se_sandwich.iter().any(|&x| x.is_nan());
    let has_nan_leverage = result.leverage.iter().any(|&x| x.is_nan());

    if has_nan_pred || has_nan_se_bayes || has_nan_se_sandwich || has_nan_leverage {
        log::error!("[GAM ALO] NaN values found in ALO diagnostics:");
        log::error!(
            "[GAM ALO] eta_tilde: {} NaN values",
            result.eta_tilde.iter().filter(|&&x| x.is_nan()).count()
        );
        log::error!(
            "[GAM ALO] se_bayes: {} NaN values",
            result.se_bayes.iter().filter(|&&x| x.is_nan()).count()
        );
        log::error!(
            "[GAM ALO] se_sandwich: {} NaN values",
            result.se_sandwich.iter().filter(|&&x| x.is_nan()).count()
        );
        log::error!(
            "[GAM ALO] leverage: {} NaN values",
            result.leverage.iter().filter(|&&x| x.is_nan()).count()
        );
        return Err(AloError::InfluenceMatrixFailed {
            condition_number: f64::INFINITY,
        });
    }

    Ok(result)
}
610
611fn log_leverage_diagnostics(leverage: &Array1<f64>, phi: f64) {
613 let n = leverage.len();
614 if n == 0 {
615 return;
616 }
617
618 let mut invalid_count = 0usize;
619 let mut high_leverage_count = 0usize;
620 let mut threshold_counts = [0usize; LEVERAGE_RATE_THRESHOLDS.len()];
621 let mut finite_leverage = Vec::with_capacity(n);
622
623 for (obs, &ai) in leverage.iter().enumerate() {
624 if ai.is_finite() {
625 finite_leverage.push(ai);
626 }
627
628 if !(0.0..=1.0).contains(&ai) || !ai.is_finite() {
629 invalid_count += 1;
630 log::warn!("[GAM ALO] invalid leverage at i={}, a_ii={:.6e}", obs, ai);
631 } else if ai > LEVERAGE_HIGH_THRESHOLD {
632 high_leverage_count += 1;
633 if ai > LEVERAGE_VERY_HIGH_THRESHOLD {
634 log::warn!("[GAM ALO] very high leverage at i={}, a_ii={:.6e}", obs, ai);
635 }
636 }
637
638 for (idx, threshold) in LEVERAGE_RATE_THRESHOLDS.iter().enumerate() {
639 if ai > *threshold {
640 threshold_counts[idx] += 1;
641 }
642 }
643 }
644
645 if invalid_count > 0 || high_leverage_count > 0 {
646 log::warn!(
647 "[GAM ALO] leverage diagnostics: {} invalid values, {} high values (>0.99)",
648 invalid_count,
649 high_leverage_count
650 );
651 }
652
653 finite_leverage.sort_by(f64::total_cmp);
654
655 let finite_n = finite_leverage.len();
656 let a_mean = if finite_n > 0 {
657 finite_leverage.iter().copied().sum::<f64>() / finite_n as f64
658 } else {
659 0.0
660 };
661 let a_median = percentile_from_sorted(&finite_leverage, LEVERAGE_PERCENTILES[0]);
662 let a_p95 = percentile_from_sorted(&finite_leverage, LEVERAGE_PERCENTILES[1]);
663 let a_p99 = percentile_from_sorted(&finite_leverage, LEVERAGE_PERCENTILES[2]);
664 let a_max = finite_leverage.last().copied().unwrap_or(0.0);
665
666 log::info!(
675 "[GAM ALO] leverage: n={}, mean={:.3e}, median={:.3e}, p95={:.3e}, p99={:.3e}, max={:.3e}",
676 n,
677 a_mean,
678 a_median,
679 a_p95,
680 a_p99,
681 a_max
682 );
683 log::info!(
684 "[GAM ALO] high-leverage: a>0.90: {:.2}%, a>0.95: {:.2}%, a>0.99: {:.2}%, dispersion phi={:.3e}",
685 100.0 * (threshold_counts[0] as f64) / n as f64,
686 100.0 * (threshold_counts[1] as f64) / n as f64,
687 100.0 * (threshold_counts[2] as f64) / n as f64,
688 phi
689 );
690}
691
/// Borrowed inputs for the single-block ALO computation.
///
/// All per-observation vectors must have length `n = design.nrows()` and be finite. This
/// is enforced by `validate_alo_solve_setup`.
pub struct AloInput<'a> {
    /// Dense design matrix `X` (`n x p`).
    pub design: &'a Array2<f64>,
    /// Dense penalized Hessian `H` (`p x p`). Must be finite and symmetric within tolerance.
    pub penalized_hessian: &'a Array2<f64>,
    /// Hessian-side weights used for the leverage `a_ii = max(w_h[i], 0) * x_i^T H^{-1} x_i`.
    pub hessian_weights: SignedWeightsView<'a>,
    /// Score-side weights used for the sandwich "meat" and the one-step score.
    pub score_weights: PsdWeightsView<'a>,
    /// Working response `z` used by the one-step ALO update.
    pub working_response: &'a Array1<f64>,
    /// Fitted linear predictor `eta_hat`.
    pub eta: &'a Array1<f64>,
    /// Offset included in `eta`.
    pub offset: &'a Array1<f64>,
    /// Link function. Not read by `compute_alo_from_input_inner` itself.
    pub link: LinkFunction,
    /// Dispersion; must be positive and finite.
    pub phi: f64,
    /// Optional penalty root. It is validated (p columns, finite) but not otherwise used
    /// by the single-block computation visible here.
    pub penalty_root: Option<&'a Array2<f64>>,
    /// Hessian ridge. It is validated (finite, non-negative) but not otherwise used by the
    /// single-block computation visible here.
    pub ridge: f64,
    /// When set, `eta_tilde` comes from the exact frozen-curvature solve instead of the
    /// one-step update.
    pub score_curvature: Option<&'a AloScalarScoreCurvature<'a>>,
}
742
743impl<'a> AloInput<'a> {
744 pub fn from_geometry(
746 geom: &'a FitGeometry,
747 design: &'a Array2<f64>,
748 eta: &'a Array1<f64>,
749 offset: &'a Array1<f64>,
750 link: LinkFunction,
751 phi: f64,
752 ) -> Self {
753 let psd_w = PsdWeightsView::from_view_unchecked(geom.working_weights.view());
760 Self {
761 design,
762 penalized_hessian: &geom.penalized_hessian,
763 hessian_weights: psd_w.as_signed(),
764 score_weights: psd_w,
765 working_response: &geom.working_response,
766 eta,
767 offset,
768 link,
769 phi,
770 penalty_root: None,
771 ridge: 0.0,
772 score_curvature: None,
773 }
774 }
775
776 pub fn from_geometry_with_working_state(
796 geom: &'a FitGeometry,
797 design: &'a Array2<f64>,
798 eta: &'a Array1<f64>,
799 offset: &'a Array1<f64>,
800 link: LinkFunction,
801 phi: f64,
802 working_weights: &'a Array1<f64>,
803 working_response: &'a Array1<f64>,
804 ) -> Self {
805 let psd_w = PsdWeightsView::from_view_unchecked(working_weights.view());
806 Self {
807 design,
808 penalized_hessian: &geom.penalized_hessian,
809 hessian_weights: psd_w.as_signed(),
810 score_weights: psd_w,
811 working_response,
812 eta,
813 offset,
814 link,
815 phi,
816 penalty_root: None,
817 ridge: 0.0,
818 score_curvature: None,
819 }
820 }
821}
822
823pub fn compute_alo_from_input(input: &AloInput) -> Result<AloDiagnostics, EstimationError> {
829 compute_alo_from_input_inner(input).map_err(EstimationError::from)
830}
831
832fn compute_alo_from_input_inner(input: &AloInput) -> Result<AloDiagnostics, AloError> {
833 let x_dense = input.design;
834 let n = x_dense.nrows();
835 let p = x_dense.ncols();
836 let w_h = input.hessian_weights.view();
840 let w_s = input.score_weights.view();
841
842 validate_alo_solve_setup(input, n, p)?;
843
844 let factor = StableSolver::new("alo penalized hessian")
845 .factorize(input.penalized_hessian)
846 .map_err(|_| AloError::InfluenceMatrixFailed {
847 condition_number: f64::INFINITY,
848 })?;
849
850 let xt = x_dense.t();
851 let phi = input.phi;
852
853 let mut aii = Array1::<f64>::zeros(n);
854 let mut x_hinv_x_diag = Array1::<f64>::zeros(n);
855 let mut se_bayes = Array1::<f64>::zeros(n);
856 let mut se_sandwich = Array1::<f64>::zeros(n);
857
858 let block_cols = ALO_RHS_BLOCK_COLS;
859 let mut rhs_chunk_buf = Array2::<f64>::zeros((p, block_cols).f());
864 let mut xs_chunk_storage = FaerMat::<f64>::zeros(n, block_cols);
871 let x_dense_view = FaerArrayView::new(x_dense);
872
873 for chunk_start in (0..n).step_by(block_cols) {
874 let chunk_end = (chunk_start + block_cols).min(n);
875 let width = chunk_end - chunk_start;
876
877 rhs_chunk_buf
878 .slice_mut(s![.., ..width])
879 .assign(&xt.slice(s![.., chunk_start..chunk_end]));
880
881 let rhs_chunkview = rhs_chunk_buf.slice(s![.., ..width]);
882 let rhs_chunk = FaerArrayView::new(&rhs_chunkview);
883 let s_chunk = factor.solve(rhs_chunk.as_ref());
887
888 let mut xs_target = xs_chunk_storage.as_mut().subcols_mut(0, width);
889 matmul(
890 xs_target.rb_mut(),
891 Accum::Replace,
892 x_dense_view.as_ref(),
893 s_chunk.as_ref(),
894 1.0,
895 Par::Seq,
896 );
897
898 let rhs_view = rhs_chunk_buf.slice(s![.., ..width]);
899
900 for local_col in 0..width {
901 let obs = chunk_start + local_col;
902 let rhs_col = rhs_view.column(local_col);
906 let rhs_slice = rhs_col.as_slice().expect("column-major col contiguous");
907 let s_slice = s_chunk.col_as_slice(local_col);
908
909 let mut x_hinv_x = 0.0f64;
910 for k in 0..p {
912 let sval = s_slice[k];
913 let xval = rhs_slice[k];
914 x_hinv_x = sval.mul_add(xval, x_hinv_x);
915 }
916 let ai = w_h[obs].max(0.0) * x_hinv_x;
917 aii[obs] = ai;
918 x_hinv_x_diag[obs] = x_hinv_x;
919
920 let var_bayes = bayesvar_eta(phi, x_hinv_x);
921 let xs_slice = xs_chunk_storage.col_as_slice(local_col);
922 let mut meat_quad = 0.0f64;
923 for row in 0..n {
924 let xs = xs_slice[row];
925 meat_quad += w_s[row] * xs * xs;
932 }
933 let var_sandwich = sandwichvar_eta_from_meat(phi, meat_quad);
934
935 if !var_bayes.is_finite() || !var_sandwich.is_finite() {
936 return Err(AloError::LooComputationFailed {
937 reason: format!(
938 "ALO variance is not finite at row {obs}: bayes={var_bayes:.6e}, sandwich={var_sandwich:.6e}"
939 ),
940 });
941 }
942 let bayes_tol = variance_negative_tolerance(phi * x_hinv_x.abs());
943 if var_bayes < -bayes_tol {
944 return Err(AloError::LooComputationFailed {
945 reason: format!(
946 "ALO Bayesian variance is materially negative at row {obs}: var={var_bayes:.6e}, tol={bayes_tol:.6e}"
947 ),
948 });
949 }
950 let sandwich_scale = phi * meat_quad.abs().max(x_hinv_x.abs());
951 let sandwich_tol = variance_negative_tolerance(sandwich_scale);
952 if var_sandwich < -sandwich_tol {
953 return Err(AloError::LooComputationFailed {
954 reason: format!(
955 "ALO sandwich variance is materially negative at row {obs}: var={var_sandwich:.6e}, tol={sandwich_tol:.6e}"
956 ),
957 });
958 }
959
960 se_bayes[obs] = var_bayes.max(0.0).sqrt();
961 se_sandwich[obs] = var_sandwich.max(0.0).sqrt();
962 }
963 }
964
965 let eta_hat = input.eta;
966 let z = input.working_response;
967 let offset = input.offset;
968
969 use rayon::prelude::*;
970 let eta_tilde_vec: Vec<f64> = (0..n)
971 .into_par_iter()
972 .map(|i| {
973 let denom_raw = 1.0 - aii[i];
974 if denom_raw <= ALO_DENOMINATOR_MIN || !denom_raw.is_finite() {
975 return Err(AloError::LooComputationFailed {
976 reason: format!(
977 "ALO denominator is too small at row {i}: a_ii={:.6e}, 1-a_ii={:.6e}, min={:.1e}",
978 aii[i], denom_raw, ALO_DENOMINATOR_MIN
979 ),
980 });
981 }
982 let one_step = alo_eta_updatewith_offset(
983 eta_hat[i],
984 z[i],
985 offset[i],
986 x_hinv_x_diag[i],
987 w_s[i],
988 denom_raw,
989 );
990 let v = if let Some(score_curvature) = input.score_curvature {
998 alo_eta_exact_frozen_curvature(
999 eta_hat[i],
1000 x_hinv_x_diag[i],
1001 &|eta| score_curvature(i, eta),
1002 )
1003 .map_err(|err| AloError::LooComputationFailed {
1004 reason: format!(
1005 "ALO exact frozen-curvature solve failed at row {i}: {err}"
1006 ),
1007 })?
1008 } else {
1009 one_step
1010 };
1011 if !v.is_finite() {
1012 return Err(AloError::LooComputationFailed {
1013 reason: format!("ALO eta_tilde is not finite at row {i}: eta_tilde={v}"),
1014 });
1015 }
1016 Ok(v)
1017 })
1018 .collect::<Result<_, _>>()?;
1019 let eta_tilde = Array1::from(eta_tilde_vec);
1020
1021 Ok(AloDiagnostics {
1022 eta_tilde,
1023 se_bayes,
1024 se_sandwich,
1025 pred_identity: eta_hat.clone(),
1026 leverage: aii,
1027 fisherweights: w_h.to_owned(),
1028 })
1029}
1030
/// Validates an [`AloInput`] before factorization, for a design with `n` rows and `p` columns.
///
/// Checks run in this order, and the first failure is returned:
/// 1. Hessian is `p x p` (`InvalidInput`).
/// 2. Hessian entries are finite (`InvalidInput`).
/// 3. Hessian is symmetric: `|h_ij - h_ji| <= HESSIAN_SYMMETRY_REL_TOL * max(|h_ij|, |h_ji|, 1)`
///    (`InvalidInput`).
/// 4. All per-observation vectors have length `n` (`InvalidInput`).
/// 5. Hessian-side weights, score-side weights, and working responses are finite
///    (`WeightInvalid`).
/// 6. `eta` and `offset` are finite (`InvalidInput`).
/// 7. `phi` is finite and positive, and `ridge` is finite and non-negative (`InvalidInput`).
/// 8. If present, the penalty root has `p` columns and finite entries (`InvalidInput`).
fn validate_alo_solve_setup(input: &AloInput, n: usize, p: usize) -> Result<(), AloError> {
    let h = input.penalized_hessian;
    if h.nrows() != p || h.ncols() != p {
        return Err(AloError::InvalidInput {
            reason: format!(
                "ALO diagnostics require a dense exact penalized Hessian with shape {p}x{p}; got {}x{}",
                h.nrows(),
                h.ncols()
            ),
        });
    }
    if h.iter().any(|v| !v.is_finite()) {
        return Err(AloError::InvalidInput {
            reason: "ALO diagnostics require a finite dense exact penalized Hessian".to_string(),
        });
    }
    // Strict lower triangle vs. its mirror. The scale is floored at 1, so the tolerance is
    // absolute for small entries.
    for i in 0..p {
        for j in 0..i {
            let a = h[[i, j]];
            let b = h[[j, i]];
            let scale = a.abs().max(b.abs()).max(1.0);
            if (a - b).abs() > HESSIAN_SYMMETRY_REL_TOL * scale {
                return Err(AloError::InvalidInput {
                    reason: format!(
                        "ALO diagnostics require a symmetric dense exact penalized Hessian; entries ({i},{j}) and ({j},{i}) differ by {:.3e}",
                        (a - b).abs()
                    ),
                });
            }
        }
    }

    let vector_lengths = [
        ("hessian_weights", input.hessian_weights.len()),
        ("score_weights", input.score_weights.len()),
        ("working_response", input.working_response.len()),
        ("eta", input.eta.len()),
        ("offset", input.offset.len()),
    ];
    for (name, len) in vector_lengths {
        if len != n {
            return Err(AloError::InvalidInput {
                reason: format!("ALO diagnostics require {name} length {n}; got {len}"),
            });
        }
    }
    if input.hessian_weights.view().iter().any(|v| !v.is_finite()) {
        return Err(AloError::WeightInvalid {
            reason: "ALO diagnostics require finite Hessian-side weights".to_string(),
        });
    }
    if input.score_weights.view().iter().any(|v| !v.is_finite()) {
        return Err(AloError::WeightInvalid {
            reason: "ALO diagnostics require finite score-side weights".to_string(),
        });
    }
    if input.working_response.iter().any(|v| !v.is_finite()) {
        return Err(AloError::WeightInvalid {
            reason: "ALO diagnostics require finite working responses".to_string(),
        });
    }
    if input.eta.iter().any(|v| !v.is_finite()) || input.offset.iter().any(|v| !v.is_finite()) {
        return Err(AloError::InvalidInput {
            reason: "ALO diagnostics require finite linear predictors and offsets".to_string(),
        });
    }
    if !input.phi.is_finite() || input.phi <= 0.0 {
        return Err(AloError::InvalidInput {
            reason: format!(
                "ALO diagnostics require positive finite dispersion phi; got {}",
                input.phi
            ),
        });
    }
    if !input.ridge.is_finite() || input.ridge < 0.0 {
        return Err(AloError::InvalidInput {
            reason: format!(
                "ALO diagnostics require a finite non-negative Hessian ridge; got {}",
                input.ridge
            ),
        });
    }
    if let Some(e) = input.penalty_root {
        if e.ncols() != p {
            return Err(AloError::InvalidInput {
                reason: format!(
                    "ALO diagnostics require penalty root to have {p} columns; got {}",
                    e.ncols()
                ),
            });
        }
        if e.iter().any(|v| !v.is_finite()) {
            return Err(AloError::InvalidInput {
                reason: "ALO diagnostics require finite penalty-root entries".to_string(),
            });
        }
    }
    Ok(())
}
1130
1131pub fn compute_alo_diagnostics_from_fit(
1133 fit: &UnifiedFitResult,
1134 y: ArrayView1<f64>,
1135 link: LinkFunction,
1136) -> Result<AloDiagnostics, EstimationError> {
1137 let pirls = fit
1138 .artifacts
1139 .pirls
1140 .as_ref()
1141 .ok_or_else(|| AloError::InvalidInput {
1142 reason:
1143 "ALO diagnostics require a PIRLS-backed fit; this fit does not expose PIRLS geometry"
1144 .to_string(),
1145 })
1146 .map_err(EstimationError::from)?;
1147 compute_alo_diagnostics_from_pirls_impl(pirls, y, link)
1148}
1149
1150pub fn compute_alo_diagnostics_from_unified(
1156 unified: &UnifiedFitResult,
1157 design: &Array2<f64>,
1158 eta: &Array1<f64>,
1159 offset: &Array1<f64>,
1160 link: LinkFunction,
1161 phi: f64,
1162) -> Result<AloDiagnostics, EstimationError> {
1163 let geom = unified
1164 .geometry
1165 .as_ref()
1166 .ok_or_else(|| AloError::InvalidInput {
1167 reason: "UnifiedFitResult does not contain working-set geometry; \
1168 ALO diagnostics require geometry at convergence"
1169 .to_string(),
1170 })
1171 .map_err(EstimationError::from)?;
1172 let input = AloInput::from_geometry(geom, design, eta, offset, link, phi);
1173 compute_alo_from_input(&input)
1174}
1175
1176pub fn compute_alo_diagnostics_from_pirls(
1178 base: &pirls::PirlsResult,
1179 y: ArrayView1<f64>,
1180 link: LinkFunction,
1181) -> Result<AloDiagnostics, EstimationError> {
1182 compute_alo_diagnostics_from_pirls_impl(base, y, link)
1183}
1184
1185pub fn compute_case_deletion_from_pirls(
1204 base: &pirls::PirlsResult,
1205 y: ArrayView1<f64>,
1206 link: LinkFunction,
1207) -> Result<Option<crate::sensitivity::CaseDeletionInfluence>, EstimationError> {
1208 let x_dense_arc = base
1209 .x_transformed
1210 .try_to_dense_arc("case-deletion diagnostics require dense transformed design")
1211 .map_err(|reason| EstimationError::InvalidInput(reason))?;
1212 let x_dense = x_dense_arc.as_ref();
1213 let n = x_dense.nrows();
1214 let p = x_dense.ncols();
1215 if n == 0 || p == 0 {
1216 return Ok(None);
1217 }
1218
1219 let phi = match link {
1222 LinkFunction::Identity => {
1223 use rayon::iter::{IntoParallelIterator, ParallelIterator};
1224 let rss: f64 = (0..n)
1225 .into_par_iter()
1226 .map(|i| {
1227 let r = y[i] - base.finalmu[i];
1228 base.finalweights[i] * r * r
1229 })
1230 .sum();
1231 let dof = (n as f64) - base.edf;
1232 rss / dof.max(1.0)
1233 }
1234 _ => 1.0,
1235 };
1236 if !(phi.is_finite() && phi > 0.0) {
1237 return Ok(None);
1238 }
1239
1240 let h_dense = base
1243 .dense_stabilizedhessian_transformed(
1244 "case-deletion diagnostics require exact dense stabilized penalized Hessian",
1245 )
1246 .map_err(|e| match e {
1247 EstimationError::InvalidInput(reason) => EstimationError::InvalidInput(reason),
1248 other => EstimationError::InvalidInput(format!("{other:?}")),
1249 })?;
1250
1251 let factor = match h_dense.cholesky(faer::Side::Lower) {
1252 Ok(f) => f,
1253 Err(_) => return Ok(None),
1257 };
1258
1259 let working_weights = base.finalweights.clone();
1263 let working_residual = &base.solveworking_response - &base.final_eta;
1264
1265 let sensitivity = crate::sensitivity::FitSensitivity::from_faer_cholesky(&factor, p);
1266 Ok(sensitivity.case_deletion(
1267 x_dense,
1268 working_weights.view(),
1269 working_residual.view(),
1270 phi,
1271 ))
1272}
1273
/// Per-observation outputs of the multi-block ALO computation, in observation order.
#[derive(Debug, Clone)]
pub struct MultiBlockAloDiagnostics {
    /// Approximate leave-one-out linear predictors, one array per observation.
    /// Presumably of length `n_blocks` (one entry per block) — confirm against the chunk
    /// routine.
    pub eta_tilde: Vec<Array1<f64>>,
    /// Scalar leverage per observation.
    pub leverage: Array1<f64>,
    /// ALO variance per observation, one array per observation (presumably per block —
    /// confirm).
    pub alo_variance: Vec<Array1<f64>>,
    /// Cook's-distance-style influence measure per observation.
    pub cook_distance: Array1<f64>,
}
1293
/// Borrowed/owned inputs for the multi-block ALO computation.
///
/// The total parameter dimension `p_tot` is `penalized_hessian_inv.nrows()`, and it must
/// equal the sum of the block design column counts.
pub struct MultiBlockAloInput<'a> {
    /// Number of observations.
    pub n_obs: usize,
    /// Number of linear-predictor blocks; must equal `block_designs.len()`.
    pub n_blocks: usize,
    /// One design per block. Their columns are concatenated in order to form the joint
    /// parameter vector.
    pub block_designs: &'a [Array2<f64>],
    /// Inverse of the joint penalized Hessian (`p_tot x p_tot`).
    pub penalized_hessian_inv: &'a Array2<f64>,
    /// Weight matrices. Presumably one `n_blocks x n_blocks` matrix per observation —
    /// confirm against `compute_multiblock_alo_chunk`.
    pub block_weights: Vec<Array2<f64>>,
    /// Score vectors. Layout (per block or per observation) is not visible here — confirm.
    pub scores: Vec<Array1<f64>>,
    /// Fitted linear predictors. Layout (per block or per observation) is not visible
    /// here — confirm.
    pub eta_hat: Vec<Array1<f64>>,
}
1342
1343pub fn compute_multiblock_alo(
1362 input: &MultiBlockAloInput,
1363) -> Result<MultiBlockAloDiagnostics, EstimationError> {
1364 compute_multiblock_alo_inner(input).map_err(EstimationError::from)
1365}
1366
1367fn compute_multiblock_alo_inner(
1368 input: &MultiBlockAloInput,
1369) -> Result<MultiBlockAloDiagnostics, AloError> {
1370 use rayon::prelude::*;
1371
1372 let n = input.n_obs;
1373 let b = input.n_blocks;
1374 let p_tot = input.penalized_hessian_inv.nrows();
1375
1376 if input.block_designs.len() != b {
1378 return Err(AloError::InvalidInput {
1379 reason: format!(
1380 "MultiBlockAloInput: expected {} block designs, got {}",
1381 b,
1382 input.block_designs.len()
1383 ),
1384 });
1385 }
1386
1387 let col_sum: usize = input.block_designs.iter().map(|d| d.ncols()).sum();
1389 if col_sum != p_tot {
1390 return Err(AloError::InvalidInput {
1391 reason: format!(
1392 "MultiBlockAloInput: total design columns ({}) != penalized_hessian_inv size ({})",
1393 col_sum, p_tot
1394 ),
1395 });
1396 }
1397
1398 let col_offsets = multiblock_col_offsets(input.block_designs);
1399 let (chunk_size, max_concurrent_chunks) = multiblock_alo_parallel_plan(p_tot, b, n);
1400 let chunk_starts: Vec<usize> = (0..n).step_by(chunk_size).collect();
1401
1402 let mut chunk_results: Vec<Result<MultiBlockAloChunkDiagnostics, AloError>> =
1408 Vec::with_capacity(chunk_starts.len());
1409 for chunk_wave in chunk_starts.chunks(max_concurrent_chunks) {
1410 let mut wave_results: Vec<Result<MultiBlockAloChunkDiagnostics, AloError>> = chunk_wave
1411 .par_iter()
1412 .map_init(
1413 || MultiBlockAloScratch::new(b),
1414 |scratch, &chunk_start| {
1415 let chunk_end = (chunk_start + chunk_size).min(n);
1416 compute_multiblock_alo_chunk(
1417 input,
1418 &col_offsets,
1419 chunk_start,
1420 chunk_end,
1421 scratch,
1422 )
1423 },
1424 )
1425 .collect();
1426 chunk_results.append(&mut wave_results);
1427 }
1428
1429 let mut eta_tilde = Vec::with_capacity(n);
1430 let mut leverage = Array1::<f64>::zeros(n);
1431 let mut alo_variance = Vec::with_capacity(n);
1432 let mut cook_distance = Array1::<f64>::zeros(n);
1433
1434 let mut chunks = Vec::with_capacity(chunk_results.len());
1435 for result in chunk_results {
1436 chunks.push(result?);
1437 }
1438 chunks.sort_unstable_by_key(|chunk| chunk.chunk_start);
1439
1440 for chunk in chunks {
1441 let chunk_start = chunk.chunk_start;
1442 eta_tilde.extend(chunk.eta_tilde);
1443 alo_variance.extend(chunk.alo_variance);
1444 for (local_i, lev) in chunk.leverage.into_iter().enumerate() {
1445 leverage[chunk_start + local_i] = lev;
1446 }
1447 for (local_i, cook) in chunk.cook_distance.into_iter().enumerate() {
1448 cook_distance[chunk_start + local_i] = cook;
1449 }
1450 }
1451
1452 Ok(MultiBlockAloDiagnostics {
1453 eta_tilde,
1454 leverage,
1455 alo_variance,
1456 cook_distance,
1457 })
1458}
1459
1460#[inline]
1461fn multiblock_alo_parallel_plan(p_tot: usize, n_blocks: usize, n_obs: usize) -> (usize, usize) {
1462 if p_tot == 0 || n_blocks == 0 || n_obs == 0 {
1463 return (1, 1);
1464 }
1465 let bytes_per_obs = (p_tot * n_blocks * std::mem::size_of::<f64>()).max(1);
1466 let workers = rayon::current_num_threads().max(1);
1467 let max_concurrent_chunks = (MULTIBLOCK_ALO_MEMORY_BUDGET_BYTES / bytes_per_obs)
1468 .max(1)
1469 .min(workers);
1470 let per_worker_budget =
1471 (MULTIBLOCK_ALO_MEMORY_BUDGET_BYTES / max_concurrent_chunks).max(bytes_per_obs);
1472 let budget_obs = (per_worker_budget / bytes_per_obs).max(1);
1473 (budget_obs.min(n_obs), max_concurrent_chunks)
1474}
1475
/// Reusable per-worker buffers for the multi-block ALO kernel.
///
/// Matrix buffers store row-major `b x b` matrices (`b` = number of blocks),
/// vector buffers store `b` entries. One instance is created per rayon split
/// and reused for every observation that split processes, so the per-row
/// inner loop does not allocate these temporaries.
struct MultiBlockAloScratch {
    /// Local influence block `A_i`.
    a_i: Vec<f64>,
    /// Product `W_i A_i`.
    wa: Vec<f64>,
    /// Product `A_i W_i`.
    aw: Vec<f64>,
    /// `I - W_i A_i`, overwritten in place by its LU factors.
    imwa: Vec<f64>,
    /// `I - A_i W_i`, overwritten in place by its LU factors.
    imaw: Vec<f64>,
    /// Row permutation belonging to the LU factors in `imwa`.
    perm_imwa: Vec<usize>,
    /// Row permutation belonging to the LU factors in `imaw`.
    perm_imaw: Vec<usize>,
    /// Correction added to the fitted linear predictor.
    delta_eta: Vec<f64>,
    /// Right-hand side / solution vector for the LU solves.
    rhs_buf: Vec<f64>,
    /// Intermediate `W_i u` vector of the variance computation.
    w_u: Vec<f64>,
    /// Diagonal of the ALO variance for the current observation.
    var_diag_buf: Vec<f64>,
    /// Row-major copy of the observation's weight matrix `W_i`.
    w_flat: Vec<f64>,
    /// Forward-substitution workspace for `lu_solve_in_place`.
    lu_scratch: Vec<f64>,
}

impl MultiBlockAloScratch {
    /// Allocates zero-filled scratch storage for `b` linear-predictor blocks.
    fn new(b: usize) -> Self {
        let square = || vec![0.0f64; b * b];
        let vector = || vec![0.0f64; b];
        let indices = || vec![0usize; b];
        Self {
            a_i: square(),
            wa: square(),
            aw: square(),
            imwa: square(),
            imaw: square(),
            perm_imwa: indices(),
            perm_imaw: indices(),
            delta_eta: vector(),
            rhs_buf: vector(),
            w_u: vector(),
            var_diag_buf: vector(),
            w_flat: square(),
            lu_scratch: vector(),
        }
    }
}
1512
1513struct MultiBlockAloChunkDiagnostics {
1514 chunk_start: usize,
1515 eta_tilde: Vec<Array1<f64>>,
1516 leverage: Vec<f64>,
1517 alo_variance: Vec<Array1<f64>>,
1518 cook_distance: Vec<f64>,
1519}
1520
1521fn compute_multiblock_alo_chunk(
1522 input: &MultiBlockAloInput,
1523 col_offsets: &[usize],
1524 chunk_start: usize,
1525 chunk_end: usize,
1526 scratch: &mut MultiBlockAloScratch,
1527) -> Result<MultiBlockAloChunkDiagnostics, AloError> {
1528 let b = input.n_blocks;
1529 let chunk_len = chunk_end - chunk_start;
1530
1531 let mut q_blocks = Vec::with_capacity(b);
1532 for blk in 0..b {
1533 let x_chunk_t = input.block_designs[blk]
1534 .slice(s![chunk_start..chunk_end, ..])
1535 .t()
1536 .to_owned();
1537 let off_b = col_offsets[blk];
1538 let h_slice = input
1539 .penalized_hessian_inv
1540 .slice(s![.., off_b..off_b + x_chunk_t.nrows()])
1541 .to_owned();
1542 q_blocks.push(h_slice.dot(&x_chunk_t));
1543 }
1544
1545 let mut eta_tilde = Vec::with_capacity(chunk_len);
1546 let mut leverage = vec![0.0f64; chunk_len];
1547 let mut alo_variance = Vec::with_capacity(chunk_len);
1548 let mut cook_distance = vec![0.0f64; chunk_len];
1549
1550 for local_i in 0..chunk_len {
1551 let i = chunk_start + local_i;
1552 let w_i = &input.block_weights[i];
1553
1554 for r in 0..b {
1556 for c in 0..b {
1557 scratch.w_flat[r * b + c] = w_i[(r, c)];
1558 }
1559 }
1560
1561 for a in 0..b {
1563 let x_a = &input.block_designs[a];
1564 let p_a = x_a.ncols();
1565 let off_a = col_offsets[a];
1566 let xa_row = x_a.row(i);
1567 for bb in 0..b {
1568 let q_bb = &q_blocks[bb];
1569 let mut dot = 0.0f64;
1570 for k in 0..p_a {
1571 dot += xa_row[k] * q_bb[(off_a + k, local_i)];
1572 }
1573 scratch.a_i[a * b + bb] = dot;
1574 }
1575 }
1576
1577 mat_mul_flat(&scratch.w_flat, &scratch.a_i, &mut scratch.wa, b);
1579 mat_mul_flat(&scratch.a_i, &scratch.w_flat, &mut scratch.aw, b);
1581
1582 let mut tr = 0.0f64;
1585 for d in 0..b {
1586 tr += scratch.aw[d * b + d];
1587 }
1588 leverage[local_i] = tr;
1589
1590 for r in 0..b {
1592 for c in 0..b {
1593 let idx = r * b + c;
1594 let id = if r == c { 1.0 } else { 0.0 };
1595 scratch.imwa[idx] = id - scratch.wa[idx];
1596 scratch.imaw[idx] = id - scratch.aw[idx];
1597 }
1598 }
1599
1600 if !lu_factor_in_place(&mut scratch.imwa, &mut scratch.perm_imwa, b) {
1606 for r in 0..b {
1607 for c in 0..b {
1608 let idx = r * b + c;
1609 let id = if r == c { 1.0 } else { 0.0 };
1610 scratch.imwa[idx] = id - scratch.wa[idx];
1611 }
1612 }
1613 for d in 0..b {
1614 scratch.imwa[d * b + d] += ALO_LOCAL_BLOCK_RIDGE;
1615 }
1616 let refactored = lu_factor_in_place(&mut scratch.imwa, &mut scratch.perm_imwa, b);
1617 assert!(
1618 refactored,
1619 "ALO local block remained singular after ridge regularization"
1620 );
1621 }
1622 if !lu_factor_in_place(&mut scratch.imaw, &mut scratch.perm_imaw, b) {
1623 for r in 0..b {
1624 for c in 0..b {
1625 let idx = r * b + c;
1626 let id = if r == c { 1.0 } else { 0.0 };
1627 scratch.imaw[idx] = id - scratch.aw[idx];
1628 }
1629 }
1630 for d in 0..b {
1631 scratch.imaw[d * b + d] += ALO_LOCAL_BLOCK_RIDGE;
1632 }
1633 let refactored = lu_factor_in_place(&mut scratch.imaw, &mut scratch.perm_imaw, b);
1634 assert!(
1635 refactored,
1636 "ALO local variance block remained singular after ridge regularization"
1637 );
1638 }
1639
1640 let s_i = &input.scores[i];
1642 for k in 0..b {
1643 scratch.rhs_buf[k] = s_i[k];
1644 }
1645 lu_solve_in_place(
1646 &scratch.imwa,
1647 &scratch.perm_imwa,
1648 &mut scratch.rhs_buf,
1649 &mut scratch.lu_scratch,
1650 b,
1651 );
1652 for r in 0..b {
1654 let mut acc = 0.0f64;
1655 let row_off = r * b;
1656 for k in 0..b {
1657 acc += scratch.a_i[row_off + k] * scratch.rhs_buf[k];
1658 }
1659 scratch.delta_eta[r] = acc;
1660 }
1661
1662 let eta_i = &input.eta_hat[i];
1663 let mut corrected = Array1::<f64>::zeros(b);
1664 for d in 0..b {
1665 corrected[d] = eta_i[d] + scratch.delta_eta[d];
1666 }
1667 eta_tilde.push(corrected);
1668
1669 let mut cook = 0.0f64;
1671 for r in 0..b {
1672 let mut w_delta_r = 0.0f64;
1673 let row_off = r * b;
1674 for k in 0..b {
1675 w_delta_r += scratch.w_flat[row_off + k] * scratch.delta_eta[k];
1676 }
1677 cook += scratch.delta_eta[r] * w_delta_r;
1678 }
1679 cook_distance[local_i] = cook;
1680
1681 for d in 0..b {
1687 let row_off = d * b;
1688 for k in 0..b {
1690 scratch.rhs_buf[k] = scratch.a_i[row_off + k];
1691 }
1692 lu_solve_in_place(
1693 &scratch.imaw,
1694 &scratch.perm_imaw,
1695 &mut scratch.rhs_buf,
1696 &mut scratch.lu_scratch,
1697 b,
1698 );
1699 for r in 0..b {
1701 let mut acc = 0.0f64;
1702 let wr = r * b;
1703 for k in 0..b {
1704 acc += scratch.w_flat[wr + k] * scratch.rhs_buf[k];
1705 }
1706 scratch.w_u[r] = acc;
1707 }
1708 lu_solve_in_place(
1710 &scratch.imwa,
1711 &scratch.perm_imwa,
1712 &mut scratch.w_u,
1713 &mut scratch.lu_scratch,
1714 b,
1715 );
1716 let mut v_dd = 0.0f64;
1718 for k in 0..b {
1719 v_dd += scratch.a_i[row_off + k] * scratch.w_u[k];
1720 }
1721 scratch.var_diag_buf[d] = v_dd.max(0.0);
1722 }
1723 let mut var_diag = Array1::<f64>::zeros(b);
1724 for d in 0..b {
1725 var_diag[d] = scratch.var_diag_buf[d];
1726 }
1727 alo_variance.push(var_diag);
1728 }
1729
1730 Ok(MultiBlockAloChunkDiagnostics {
1731 chunk_start,
1732 eta_tilde,
1733 leverage,
1734 alo_variance,
1735 cook_distance,
1736 })
1737}
1738
/// Multiplies two row-major `b x b` matrices: `out = a * b_mat`.
///
/// Each output entry is accumulated left-to-right over the shared index,
/// starting from `0.0`. Only the first `b * b` entries of `out` are written.
#[inline]
fn mat_mul_flat(a: &[f64], b_mat: &[f64], out: &mut [f64], b: usize) {
    for row in 0..b {
        let lhs_row = &a[row * b..(row + 1) * b];
        for col in 0..b {
            out[row * b + col] = lhs_row
                .iter()
                .enumerate()
                .fold(0.0f64, |sum, (inner, &lhs)| sum + lhs * b_mat[inner * b + col]);
        }
    }
}
1754
1755fn lu_factor_in_place(m: &mut [f64], perm: &mut [usize], b: usize) -> bool {
1762 for i in 0..b {
1763 perm[i] = i;
1764 }
1765 for col in 0..b {
1766 let mut max_val = m[col * b + col].abs();
1768 let mut max_idx = col;
1769 for row in (col + 1)..b {
1770 let v = m[row * b + col].abs();
1771 if v > max_val {
1772 max_val = v;
1773 max_idx = row;
1774 }
1775 }
1776 if max_val < LU_PIVOT_SINGULAR_TOL {
1777 return false;
1778 }
1779 if max_idx != col {
1780 for k in 0..b {
1782 m.swap(col * b + k, max_idx * b + k);
1783 }
1784 perm.swap(col, max_idx);
1785 }
1786 let pivot = m[col * b + col];
1787 for row in (col + 1)..b {
1788 let factor = m[row * b + col] / pivot;
1789 m[row * b + col] = factor; for k in (col + 1)..b {
1791 let upd = factor * m[col * b + k];
1792 m[row * b + k] -= upd;
1793 }
1794 }
1795 }
1796 true
1797}
1798
/// Solves `A x = rhs` in place from the packed LU factors of `A`.
///
/// `m` is the row-major `b x b` output of `lu_factor_in_place`, with a
/// unit-diagonal `L` below the diagonal and `U` on and above it, and `perm` is
/// its row permutation (`P A = L U`). `scratch` needs at least `b` entries and
/// receives the forward-substitution vector. On return `rhs[..b]` holds `x`.
fn lu_solve_in_place(m: &[f64], perm: &[usize], rhs: &mut [f64], scratch: &mut [f64], b: usize) {
    let y = &mut scratch[..b];

    // Forward substitution on the permuted right-hand side: L y = P rhs.
    for row in 0..b {
        let lower = &m[row * b..row * b + row];
        let value = lower
            .iter()
            .zip(y.iter())
            .fold(rhs[perm[row]], |acc, (&l, &prev)| acc - l * prev);
        y[row] = value;
    }

    // Back substitution U x = y, overwriting rhs from the last row upwards.
    for row in (0..b).rev() {
        let upper = &m[row * b..(row + 1) * b];
        let value = upper[row + 1..]
            .iter()
            .zip(&rhs[row + 1..b])
            .fold(y[row], |acc, (&u, &solved)| acc - u * solved);
        rhs[row] = value / upper[row];
    }
}
1820
1821pub fn compute_multiblock_alo_leverages(
1829 n_obs: usize,
1830 n_blocks: usize,
1831 block_designs: &[Array2<f64>],
1832 penalized_hessian_inv: &Array2<f64>,
1833 block_weights: &[Array2<f64>],
1834) -> Result<Array1<f64>, EstimationError> {
1835 use rayon::prelude::*;
1836
1837 let n = n_obs;
1838 let b = n_blocks;
1839 let p_tot = penalized_hessian_inv.nrows();
1840
1841 let col_offsets = multiblock_col_offsets(block_designs);
1842 let max_workers = rayon::current_num_threads();
1843 let chunk_size = multiblock_alo_parallel_leverage_chunk_size(p_tot, b, n, max_workers);
1844
1845 let mut leverage = Array1::<f64>::zeros(n);
1846
1847 let block_widths: Vec<usize> = block_designs.iter().map(|d| d.ncols()).collect();
1851 let mut h_stripes: Vec<FaerMat<f64>> = block_widths
1852 .iter()
1853 .map(|&p_blk| FaerMat::<f64>::zeros(p_tot, p_blk))
1854 .collect();
1855 for blk in 0..b {
1858 let off_b = col_offsets[blk];
1859 let p_blk = block_widths[blk];
1860 let stripe = &mut h_stripes[blk];
1861 for c in 0..p_blk {
1862 for r in 0..p_tot {
1863 stripe[(r, c)] = penalized_hessian_inv[(r, off_b + c)];
1864 }
1865 }
1866 }
1867
1868 leverage
1869 .as_slice_mut()
1870 .expect("newly allocated Array1 is contiguous")
1871 .par_chunks_mut(chunk_size)
1872 .enumerate()
1873 .for_each(|(chunk_idx, leverage_chunk)| {
1874 let chunk_start = chunk_idx * chunk_size;
1875 let chunk_len = leverage_chunk.len();
1876 let chunk_end = chunk_start + chunk_len;
1877
1878 let bb_sz = b * b;
1882 let mut a_i = vec![0.0f64; bb_sz];
1883 let mut aw = vec![0.0f64; bb_sz];
1884 let mut w_flat = vec![0.0f64; bb_sz];
1885
1886 let mut q_storage: Vec<FaerMat<f64>> = block_widths
1890 .iter()
1891 .map(|_| FaerMat::<f64>::zeros(p_tot, chunk_len))
1892 .collect();
1893
1894 let mut xt_storage: Vec<FaerMat<f64>> = block_widths
1898 .iter()
1899 .map(|&p_blk| FaerMat::<f64>::zeros(p_blk, chunk_len))
1900 .collect();
1901
1902 for blk in 0..b {
1907 let p_blk = block_widths[blk];
1908
1909 let x_chunk = block_designs[blk].slice(s![chunk_start..chunk_end, ..]);
1910 let xt = &mut xt_storage[blk];
1911 for local_i in 0..chunk_len {
1912 let row = x_chunk.row(local_i);
1913 for j in 0..p_blk {
1914 xt[(j, local_i)] = row[j];
1915 }
1916 }
1917
1918 matmul(
1919 q_storage[blk].as_mut(),
1920 Accum::Replace,
1921 h_stripes[blk].as_ref(),
1922 xt_storage[blk].as_ref(),
1923 1.0,
1924 Par::Seq,
1925 );
1926 }
1927
1928 for local_i in 0..chunk_len {
1929 let i = chunk_start + local_i;
1930 let w_i = &block_weights[i];
1931
1932 for r in 0..b {
1934 for c in 0..b {
1935 w_flat[r * b + c] = w_i[(r, c)];
1936 }
1937 }
1938
1939 for r in 0..bb_sz {
1943 a_i[r] = 0.0;
1944 }
1945 for k in 0..b {
1946 let q_k = &q_storage[k];
1947 let q_col = q_k.col_as_slice(local_i);
1948 for a in 0..b {
1949 let p_a = block_widths[a];
1950 let off_a = col_offsets[a];
1951 let xa_row = block_designs[a].row(i);
1952 let mut dot = 0.0f64;
1953 for j in 0..p_a {
1954 dot = xa_row[j].mul_add(q_col[off_a + j], dot);
1955 }
1956 a_i[a * b + k] = dot;
1957 }
1958 }
1959
1960 mat_mul_flat(&a_i, &w_flat, &mut aw, b);
1962 let mut tr = 0.0f64;
1963 for d in 0..b {
1964 tr += aw[d * b + d];
1965 }
1966 leverage_chunk[local_i] = tr;
1967 }
1968 });
1969
1970 Ok(leverage)
1971}
1972
#[cfg(test)]
mod tests {
    use super::{
        ALO_EXACT_SCALAR_MAX_ITERS, AloExactScalarError, AloInput, alo_eta_exact_frozen_curvature,
        alo_eta_updatewith_offset, bayesvar_eta, compute_alo_from_input_inner,
        percentile_from_sorted, percentile_index, sandwichvar_eta_from_meat,
    };
    use super::{MultiBlockAloInput, compute_multiblock_alo, compute_multiblock_alo_leverages};
    use gam_linalg::matrix::{PsdWeightsView, SignedWeightsView};
    use gam_problem::LinkFunction;
    use ndarray::{Array1, Array2};

    #[test]
    fn alo_offset_update_matches_centered_algebra() {
        let eta_hat = 11.0;
        let z = 13.0;
        let offset = 10.0;
        let x_hinv_x = 0.2;
        let hessian_weight = 1.0;
        let score_weight = 1.0;
        let leverage = hessian_weight * x_hinv_x;
        let expected = offset + ((eta_hat - offset) - leverage * (z - offset)) / (1.0 - leverage);
        let got =
            alo_eta_updatewith_offset(eta_hat, z, offset, x_hinv_x, score_weight, 1.0 - leverage);
        assert!((got - expected).abs() < 1e-12);
    }

    #[test]
    fn alo_offset_update_reduces_to_classicwhen_offsetzero() {
        let eta_hat = 1.25;
        let z = -0.5;
        let x_hinv_x = 0.35;
        let hessian_weight = 1.0;
        let score_weight = 1.0;
        let leverage = hessian_weight * x_hinv_x;
        let expected = (eta_hat - leverage * z) / (1.0 - leverage);
        let got =
            alo_eta_updatewith_offset(eta_hat, z, 0.0, x_hinv_x, score_weight, 1.0 - leverage);
        assert!((got - expected).abs() < 1e-12);
    }

    #[test]
    fn alo_offset_update_uses_distinct_score_and_hessian_weights() {
        let eta_hat = 1.7;
        let z = 0.4;
        let offset = -0.2;
        let x_hinv_x = 0.15;
        let hessian_weight = 3.0;
        let score_weight = 5.0;
        let expected = offset
            + (eta_hat - offset)
            + x_hinv_x * score_weight * ((eta_hat - offset) - (z - offset))
                / (1.0 - hessian_weight * x_hinv_x);
        let got = alo_eta_updatewith_offset(
            eta_hat,
            z,
            offset,
            x_hinv_x,
            score_weight,
            1.0 - hessian_weight * x_hinv_x,
        );
        assert!((got - expected).abs() < 1e-12);
    }

    #[test]
    fn alo_offset_update_handles_zero_hessian_weight() {
        let eta_hat = 0.8;
        let z = -0.3;
        let offset = 0.1;
        let x_hinv_x = 0.4;
        let hessian_weight = 0.0;
        let score_weight = 2.5;
        let expected = offset
            + (eta_hat - offset)
            + x_hinv_x * score_weight * ((eta_hat - offset) - (z - offset));
        let got = alo_eta_updatewith_offset(
            eta_hat,
            z,
            offset,
            x_hinv_x,
            score_weight,
            1.0 - hessian_weight * x_hinv_x,
        );
        assert!((got - expected).abs() < 1e-12);
    }

    #[test]
    fn alo_exact_frozen_curvature_converges_to_fixed_point() {
        let eta_hat = 1.0;
        let a_ii = 0.4;
        let got = alo_eta_exact_frozen_curvature(eta_hat, a_ii, &|eta| (0.5 * (eta - 2.0), 0.5))
            .expect("linear scalar fixed point should converge in one Newton step");
        assert!((got - 0.75).abs() < 1e-12);
    }

    #[test]
    fn alo_exact_frozen_curvature_reports_nonconvergence() {
        let err = alo_eta_exact_frozen_curvature(0.0, 1.0, &|eta| (eta + 1.0, 0.0))
            .expect_err("constant residual should exhaust the scalar iteration budget");
        let AloExactScalarError::MaxIterations { iterations, .. } = err else {
            panic!("constant residual must report MaxIterations, got {err:?}");
        };
        assert_eq!(
            iterations, ALO_EXACT_SCALAR_MAX_ITERS,
            "non-convergence must report the full scalar iteration budget"
        );
    }

    #[test]
    fn alo_input_reports_exact_scalar_nonconvergence_with_row_context() {
        let design = Array2::from_elem((1, 1), 1.0);
        let penalized_hessian = Array2::from_elem((1, 1), 1.0);
        let hessian_weights = Array1::from_vec(vec![0.0]);
        let score_weights = Array1::from_vec(vec![0.0]);
        let working_response = Array1::from_vec(vec![0.0]);
        let eta = Array1::from_vec(vec![0.0]);
        let offset = Array1::from_vec(vec![0.0]);
        let score_curvature = |_: usize, eta: f64| (eta + 1.0, 0.0);
        let input = AloInput {
            design: &design,
            penalized_hessian: &penalized_hessian,
            hessian_weights: SignedWeightsView::from_array(&hessian_weights),
            score_weights: PsdWeightsView::try_from_array(&score_weights).expect("psd weights"),
            working_response: &working_response,
            eta: &eta,
            offset: &offset,
            link: LinkFunction::Logit,
            phi: 1.0,
            penalty_root: None,
            ridge: 0.0,
            score_curvature: Some(&score_curvature),
        };

        let err =
            compute_alo_from_input_inner(&input).expect_err("non-converged exact ALO must error");
        let msg = err.to_string();
        assert!(
            msg.contains("ALO exact frozen-curvature solve failed at row 0"),
            "missing row context in exact ALO error: {msg}"
        );
        assert!(
            msg.contains("did not converge within"),
            "missing non-convergence cause in exact ALO error: {msg}"
        );
    }

    #[test]
    fn gaussian_unpenalized_direct_sandwich_equals_bayes() {
        let phi = 2.5;
        let x_hinv_x = 0.3;
        let vb = bayesvar_eta(phi, x_hinv_x);
        let vs = sandwichvar_eta_from_meat(phi, x_hinv_x);
        assert!((vb - vs).abs() < 1e-12);
    }

    #[test]
    fn sandwich_from_direct_meat_scales_by_phi() {
        let phi = 1.7;
        let meat_quad = 0.358;
        let got = sandwichvar_eta_from_meat(phi, meat_quad);
        let expected = phi * meat_quad;
        assert!((got - expected).abs() < 1e-12);
    }

    #[test]
    fn sandwich_meat_uses_score_weights_not_hessian_weights_noncanonical() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 1.0, 2.0, 1.0]).unwrap();
        let w_h_vec = Array1::from_vec(vec![1.0, -1.0, 1.0, -1.0, 0.5]);
        let w_s_vec = Array1::from_vec(vec![1.0, 0.8, 1.2, 0.6, 0.9]);
        let phi = 1.3;

        let n = x.nrows();
        let sum_wh_x2: f64 = (0..n).map(|i| w_h_vec[i] * x[[i, 0]] * x[[i, 0]]).sum();
        let sum_ws_x2: f64 = (0..n).map(|i| w_s_vec[i] * x[[i, 0]] * x[[i, 0]]).sum();
        assert!(sum_wh_x2 < 0.0, "fixture must exercise a negative W_H meat");
        assert!(sum_ws_x2 > 0.0);

        let s0 = 8.0_f64;
        let h = s0 + sum_wh_x2;
        assert!(h > 0.0, "penalized Hessian must stay PD");
        let penalized_hessian = Array2::from_elem((1, 1), h);

        let old_meat_obs1 = x[[1, 0]] * x[[1, 0]] / (h * h) * sum_wh_x2;
        assert!(
            phi * old_meat_obs1 < -super::variance_negative_tolerance(phi * old_meat_obs1.abs()),
            "the pre-fix W_H meat must be materially negative (guard would trip)"
        );

        let working_response = Array1::from_vec(vec![0.3, -0.2, 0.5, 0.1, -0.4]);
        let eta = Array1::from_vec(vec![0.2, 0.1, 0.4, -0.1, 0.05]);
        let offset = Array1::zeros(n);
        let input = AloInput {
            design: &x,
            penalized_hessian: &penalized_hessian,
            hessian_weights: SignedWeightsView::from_array(&w_h_vec),
            score_weights: PsdWeightsView::try_from_array(&w_s_vec).expect("psd weights"),
            working_response: &working_response,
            eta: &eta,
            offset: &offset,
            link: LinkFunction::Probit,
            phi,
            penalty_root: None,
            ridge: 0.0,
            score_curvature: None,
        };

        let diag = compute_alo_from_input_inner(&input)
            .expect("fixed sandwich meat (W_S) must not trip the negative-variance guard");

        for obs in 0..n {
            let expected = (phi * x[[obs, 0]] * x[[obs, 0]] / (h * h) * sum_ws_x2).sqrt();
            assert!(
                (diag.se_sandwich[obs] - expected).abs() <= 1e-10 * expected.max(1.0),
                "row {obs}: se_sandwich={} expected={expected}",
                diag.se_sandwich[obs]
            );
        }
    }

    #[test]
    fn percentile_index_matches_expected_rounding() {
        assert_eq!(percentile_index(0, 0.95), 0);
        assert_eq!(percentile_index(1, 0.95), 0);
        assert_eq!(percentile_index(10, 0.50), 5);
        assert_eq!(percentile_index(10, 0.95), 9);
    }

    #[test]
    fn percentile_from_sorted_returns_order_statistic() {
        let values = [1.0, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(percentile_from_sorted(&values, 0.50), 3.0);
        assert_eq!(percentile_from_sorted(&values, 0.95), 5.0);
        assert_eq!(percentile_from_sorted(&[], 0.95), 0.0);
    }

    #[test]
    fn multiblock_b1_matches_scalar_leverage() {
        let n = 3;
        let p = 2;
        let x = Array2::from_shape_vec((n, p), vec![1.0, 0.5, 0.8, -0.3, 0.2, 1.1]).unwrap();
        let w = [1.0, 2.0, 0.5];
        let mut h = Array2::<f64>::eye(p);
        for i in 0..n {
            for r in 0..p {
                for c in 0..p {
                    h[(r, c)] += w[i] * x[(i, r)] * x[(i, c)];
                }
            }
        }
        let det = h[(0, 0)] * h[(1, 1)] - h[(0, 1)] * h[(1, 0)];
        let mut h_inv = Array2::<f64>::zeros((p, p));
        h_inv[(0, 0)] = h[(1, 1)] / det;
        h_inv[(1, 1)] = h[(0, 0)] / det;
        h_inv[(0, 1)] = -h[(0, 1)] / det;
        h_inv[(1, 0)] = -h[(1, 0)] / det;

        let mut scalar_lev = vec![0.0f64; n];
        for i in 0..n {
            let mut xhx = 0.0;
            for r in 0..p {
                for c in 0..p {
                    xhx += x[(i, r)] * h_inv[(r, c)] * x[(i, c)];
                }
            }
            scalar_lev[i] = w[i] * xhx;
        }

        let block_designs = vec![x.clone()];
        let block_weights: Vec<Array2<f64>> =
            w.iter().map(|&wi| Array2::from_elem((1, 1), wi)).collect();
        let scores: Vec<Array1<f64>> = (0..n).map(|_| Array1::from_vec(vec![0.1])).collect();
        let eta_hat: Vec<Array1<f64>> = (0..n).map(|i| Array1::from_vec(vec![i as f64])).collect();

        let input = MultiBlockAloInput {
            n_obs: n,
            n_blocks: 1,
            block_designs: &block_designs,
            penalized_hessian_inv: &h_inv,
            block_weights,
            scores,
            eta_hat,
        };

        let result = compute_multiblock_alo(&input).unwrap();
        for i in 0..n {
            assert!(
                (result.leverage[i] - scalar_lev[i]).abs() < 1e-10,
                "leverage mismatch at i={}: got {}, expected {}",
                i,
                result.leverage[i],
                scalar_lev[i]
            );
        }
    }

    #[test]
    fn multiblock_leverage_only_matches_full() {
        let n = 4;
        let p1 = 2;
        let p2 = 3;
        let x1 = Array2::from_shape_fn((n, p1), |(i, j)| (i + j + 1) as f64 * 0.3);
        let x2 = Array2::from_shape_fn((n, p2), |(i, j)| (i * 2 + j) as f64 * 0.2 - 0.1);
        let p_tot = p1 + p2;
        let h_inv = Array2::<f64>::eye(p_tot);
        let block_weights: Vec<Array2<f64>> = (0..n)
            .map(|i| {
                let v = (i + 1) as f64;
                Array2::from_shape_vec((2, 2), vec![v, 0.1, 0.1, v * 0.5]).unwrap()
            })
            .collect();
        let scores: Vec<Array1<f64>> = (0..n).map(|_| Array1::from_vec(vec![0.0, 0.0])).collect();
        let eta_hat: Vec<Array1<f64>> = (0..n).map(|_| Array1::from_vec(vec![0.0, 0.0])).collect();
        let block_designs = vec![x1.clone(), x2.clone()];

        let input = MultiBlockAloInput {
            n_obs: n,
            n_blocks: 2,
            block_designs: &block_designs,
            penalized_hessian_inv: &h_inv,
            block_weights: block_weights.clone(),
            scores,
            eta_hat,
        };
        let full = compute_multiblock_alo(&input).unwrap();
        let lev_only =
            compute_multiblock_alo_leverages(n, 2, &block_designs, &h_inv, &block_weights).unwrap();

        for i in 0..n {
            assert!(
                (full.leverage[i] - lev_only[i]).abs() < 1e-12,
                "leverage mismatch at i={}: full={}, lev_only={}",
                i,
                full.leverage[i],
                lev_only[i]
            );
        }
    }

    #[test]
    fn multiblock_singular_weight_still_corrects() {
        let n = 1;
        let p = 2;
        let x = Array2::from_shape_vec((1, p), vec![1.0, 0.5]).unwrap();
        let h_inv = Array2::eye(p);
        let block_designs = vec![x.clone()];
        let block_weights = vec![Array2::from_elem((1, 1), 0.0)];
        let scores = vec![Array1::from_vec(vec![1.0])];
        let eta_hat = vec![Array1::from_vec(vec![std::f64::consts::PI])];

        let input = MultiBlockAloInput {
            n_obs: n,
            n_blocks: 1,
            block_designs: &block_designs,
            penalized_hessian_inv: &h_inv,
            block_weights,
            scores,
            eta_hat,
        };
        let result = compute_multiblock_alo(&input).unwrap();
        let expected = std::f64::consts::PI + 1.25;
        assert!(
            (result.eta_tilde[0][0] - expected).abs() < 1e-12,
            "expected {}, got {}",
            expected,
            result.eta_tilde[0][0]
        );
        assert!(result.cook_distance[0].abs() < 1e-14);
        assert!(result.alo_variance[0][0].abs() < 1e-14);
    }

    #[test]
    fn multiblock_cook_and_variance_basic() {
        // x = 1 and H^{-1} = 0.5 give A = 0.5. With W = 2 both local blocks
        // I - W A and I - A W are exactly zero, so the kernel has to take the
        // ridge-regularized factorization path.
        let n = 1;
        let x = Array2::from_elem((1, 1), 1.0);
        let h_inv = Array2::from_elem((1, 1), 0.5);
        let block_designs = vec![x.clone()];
        let w_val = 2.0;
        let s_val = 0.4;
        let block_weights = vec![Array2::from_elem((1, 1), w_val)];
        let scores = vec![Array1::from_vec(vec![s_val])];
        let eta_hat = vec![Array1::from_vec(vec![1.0])];

        let input = MultiBlockAloInput {
            n_obs: n,
            n_blocks: 1,
            block_designs: &block_designs,
            penalized_hessian_inv: &h_inv,
            block_weights,
            scores,
            eta_hat,
        };
        let result = compute_multiblock_alo(&input).unwrap();

        assert!(result.eta_tilde[0][0].is_finite());
        assert!(result.cook_distance[0].is_finite());
        assert!(result.alo_variance[0][0].is_finite());

        let ridge = super::ALO_LOCAL_BLOCK_RIDGE;
        let a = 0.5;
        let close =
            |got: f64, expected: f64| (got - expected).abs() <= 1e-12 * expected.abs().max(1.0);

        assert!(
            close(result.leverage[0], w_val * a),
            "leverage {} != {}",
            result.leverage[0],
            w_val * a
        );
        // Δη = A (ridge)^{-1} s
        let delta = a * (s_val / ridge);
        assert!(
            close(result.eta_tilde[0][0], 1.0 + delta),
            "eta_tilde {} != {}",
            result.eta_tilde[0][0],
            1.0 + delta
        );
        // Cook = Δη W Δη
        let expected_cook = delta * (w_val * delta);
        assert!(
            close(result.cook_distance[0], expected_cook),
            "cook {} != {}",
            result.cook_distance[0],
            expected_cook
        );
        // Var = A (ridge)^{-1} W (ridge)^{-1} A
        let expected_var = a * (w_val * (a / ridge) / ridge);
        assert!(
            close(result.alo_variance[0][0], expected_var),
            "variance {} != {}",
            result.alo_variance[0][0],
            expected_var
        );
    }
}