1use crate::estimate::{BlockRole, EstimationError, FittedLinkState, UnifiedFitResult};
2use crate::families::bernoulli_marginal_slope::{
3 EmpiricalZGrid, LatentMeasureKind, bernoulli_marginal_link_map,
4 empirical_intercept_from_marginal,
5};
6use crate::families::lognormal_kernel::FrailtySpec;
7use crate::families::marginal_slope_shared::{
8 ObservedDenestedCellPartials, eval_coeff4_at,
9 probit_frailty_scale as marginal_slope_probit_frailty_scale, scale_coeff4,
10};
11use crate::families::strategy::{FamilyStrategy, strategy_for_family, strategy_from_fit};
12use crate::inference::model::{
13 SavedAnchoredDeviationRuntime, SavedLatentZNormalization, SavedLinkWiggleRuntime,
14};
15use crate::inference::prediction_linalg::{
16 PredictionCovarianceBackend, design_row_chunk, prediction_chunk_rows,
17 rowwise_local_covariances_parallel,
18};
19use crate::linalg::utils::predict_gam_dimension_mismatch_message;
20use crate::matrix::{DesignMatrix, SymmetricMatrix};
21use crate::mixture_link::{
22 InverseLinkJet, beta_logistic_inverse_link_jetwith_param_partials,
23 mixture_inverse_link_jetwith_rho_partials_into, sas_inverse_link_jetwith_param_partials,
24};
25use crate::probability::{normal_cdf, normal_pdf, standard_normal_quantile};
26use crate::quadrature::QuadratureContext;
27use crate::types::{InverseLink, LikelihoodFamily};
28use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
29use rayon::iter::{IntoParallelIterator, ParallelIterator};
30
thread_local! {
    /// Per-thread quadrature context for prediction-time numerical integration.
    ///
    /// NOTE(review): no use is visible in this chunk; presumably consumed by
    /// posterior-mean code elsewhere in the file so each worker thread reuses
    /// its own context instead of sharing one — confirm.
    static PREDICT_QUADRATURE_CONTEXT: QuadratureContext = QuadratureContext::new();
}
34
35pub fn se_from_covariance(cov: &Array2<f64>) -> Array1<f64> {
37 Array1::from_iter(cov.diag().iter().map(|&v| v.max(0.0).sqrt()))
38}
39
40fn apply_family_inverse_link(
41 eta: &Array1<f64>,
42 family: crate::types::LikelihoodFamily,
43 link_kind: Option<&InverseLink>,
44) -> Result<Array1<f64>, EstimationError> {
45 strategy_for_family(family, link_kind).inverse_link_array(eta.view())
46}
47
48fn local_covariances_with_backend<F>(
49 backend: &PredictionCovarianceBackend<'_>,
50 n_rows: usize,
51 local_dim: usize,
52 build_chunk: F,
53) -> Result<Vec<Vec<Array1<f64>>>, EstimationError>
54where
55 F: Fn(std::ops::Range<usize>) -> Result<Vec<Array2<f64>>, String> + Sync,
56{
57 rowwise_local_covariances_parallel(backend, n_rows, local_dim, build_chunk)
58 .map_err(EstimationError::InvalidInput)
59}
60
61fn usable_penalized_hessian<'a>(
62 fit: &'a UnifiedFitResult,
63 expected_dim: usize,
64 label: &str,
65) -> Option<&'a Array2<f64>> {
66 let hessian = fit.penalized_hessian()?;
67 if hessian.nrows() != expected_dim || hessian.ncols() != expected_dim {
68 log::warn!(
69 "{label}: ignoring penalized Hessian with shape {}x{}; expected {}x{}",
70 hessian.nrows(),
71 hessian.ncols(),
72 expected_dim,
73 expected_dim
74 );
75 return None;
76 }
77 if !hessian.iter().any(|value| value.abs() > 0.0) {
78 log::warn!("{label}: ignoring zero penalized Hessian placeholder");
79 return None;
80 }
81 Some(hessian)
82}
83
84fn conditional_prediction_backend<'a>(
85 fit: &'a UnifiedFitResult,
86 expected_dim: usize,
87 label: &str,
88) -> Option<PredictionCovarianceBackend<'a>> {
89 if let Some(covariance) = fit.beta_covariance() {
104 if covariance.nrows() == expected_dim && covariance.ncols() == expected_dim {
105 return Some(PredictionCovarianceBackend::from_dense(covariance.view()));
106 }
107 log::warn!(
108 "{label}: ignoring conditional covariance with shape {}x{}; expected {}x{}",
109 covariance.nrows(),
110 covariance.ncols(),
111 expected_dim,
112 expected_dim
113 );
114 }
115 if let Some(hessian) = usable_penalized_hessian(fit, expected_dim, label) {
116 match PredictionCovarianceBackend::from_factorized_hessian(SymmetricMatrix::Dense(
117 hessian.clone(),
118 )) {
119 Ok(backend) => return Some(backend),
120 Err(err) => {
121 log::warn!(
122 "{label}: failed to build factorized prediction precision backend: {err}"
123 );
124 }
125 }
126 }
127 None
128}
129
130fn selected_uncertainty_backend<'a>(
131 fit: &'a UnifiedFitResult,
132 expected_dim: usize,
133 requested_mode: InferenceCovarianceMode,
134 label: &str,
135) -> Result<(PredictionCovarianceBackend<'a>, bool), EstimationError> {
136 match requested_mode {
137 InferenceCovarianceMode::Conditional => {
138 conditional_prediction_backend(fit, expected_dim, label)
139 .map(|backend| (backend, false))
140 .ok_or_else(|| {
141 EstimationError::InvalidInput(
142 "fit result does not contain conditional covariance or a usable penalized Hessian"
143 .to_string(),
144 )
145 })
146 }
147 InferenceCovarianceMode::ConditionalPlusSmoothingPreferred => {
148 if let Some(covariance) = fit.beta_covariance_corrected() {
149 if covariance.nrows() != expected_dim || covariance.ncols() != expected_dim {
150 return Err(EstimationError::InvalidInput(format!(
151 "{label}: corrected covariance dimension mismatch: expected {}x{}, got {}x{}",
152 expected_dim,
153 expected_dim,
154 covariance.nrows(),
155 covariance.ncols()
156 )));
157 }
158 Ok((
159 PredictionCovarianceBackend::from_dense(covariance.view()),
160 true,
161 ))
162 } else {
163 selected_uncertainty_backend(
164 fit,
165 expected_dim,
166 InferenceCovarianceMode::Conditional,
167 label,
168 )
169 }
170 }
171 InferenceCovarianceMode::ConditionalPlusSmoothingRequired => {
172 let covariance = fit.beta_covariance_corrected().ok_or_else(|| {
173 EstimationError::InvalidInput(
174 "fit result does not contain smoothing-corrected covariance".to_string(),
175 )
176 })?;
177 if covariance.nrows() != expected_dim || covariance.ncols() != expected_dim {
178 return Err(EstimationError::InvalidInput(format!(
179 "{label}: corrected covariance dimension mismatch: expected {}x{}, got {}x{}",
180 expected_dim,
181 expected_dim,
182 covariance.nrows(),
183 covariance.ncols()
184 )));
185 }
186 Ok((
187 PredictionCovarianceBackend::from_dense(covariance.view()),
188 true,
189 ))
190 }
191 }
192}
193
194#[inline]
202fn quadratic_form(cov: &Array2<f64>, grad: &[f64]) -> Result<f64, EstimationError> {
203 let m = grad.len();
204 if cov.nrows() != m || cov.ncols() != m {
205 return Err(EstimationError::InvalidInput(format!(
206 "covariance/gradient dimension mismatch: covariance is {}x{}, gradient length is {}",
207 cov.nrows(),
208 cov.ncols(),
209 m
210 )));
211 }
212 let mut diag_acc = 0.0_f64;
213 let mut off_acc = 0.0_f64;
214 for i in 0..m {
215 let row = cov.row(i);
216 let row_slice = row.as_slice().expect("Array2 row is contiguous");
217 let gi = grad[i];
218 diag_acc += gi * gi * row_slice[i];
220 let mut row_off = 0.0_f64;
222 for j in (i + 1)..m {
223 row_off += grad[j] * row_slice[j];
224 }
225 off_acc += gi * row_off;
226 }
227 Ok((diag_acc + 2.0 * off_acc).max(0.0))
228}
229
230#[inline]
234fn quadratic_form_from_jetmu(
235 cov: &Array2<f64>,
236 partials: &[InverseLinkJet],
237) -> Result<f64, EstimationError> {
238 let m = partials.len();
239 if cov.nrows() != m || cov.ncols() != m {
240 return Err(EstimationError::InvalidInput(format!(
241 "covariance/mixture-gradient dimension mismatch: covariance is {}x{}, mixture gradient length is {}",
242 cov.nrows(),
243 cov.ncols(),
244 m
245 )));
246 }
247 let mut diag_acc = 0.0_f64;
248 let mut off_acc = 0.0_f64;
249 for i in 0..m {
250 let row = cov.row(i);
251 let row_slice = row.as_slice().expect("Array2 row is contiguous");
252 let gi = partials[i].mu;
253 diag_acc += gi * gi * row_slice[i];
254 let mut row_off = 0.0_f64;
255 for j in (i + 1)..m {
256 row_off += partials[j].mu * row_slice[j];
257 }
258 off_acc += gi * row_off;
259 }
260 Ok((diag_acc + 2.0 * off_acc).max(0.0))
261}
262
263fn linear_predictorvariance_from_backend(
264 x: &DesignMatrix,
265 backend: &PredictionCovarianceBackend<'_>,
266) -> Result<Array1<f64>, EstimationError> {
267 let local = local_covariances_with_backend(backend, x.nrows(), 1, |rows| {
268 Ok(vec![design_row_chunk(x, rows)?])
269 })?;
270 Ok(local[0][0].mapv(|v| v.max(0.0)))
271}
272
/// Variances at or below this are treated as degenerate (a point mass) when
/// choosing the posterior-mean quadrature dimension, e.g. in
/// `projected_bivariate_posterior_mean_result`.
const POSTERIOR_MEAN_VARIANCE_TOL: f64 = 1e-10;
/// Absolute cross-covariances at or below this are treated as zero when
/// deciding whether a degenerate axis can be dropped from the quadrature.
const POSTERIOR_MEAN_CROSS_TOL: f64 = 1e-10;
275
276fn posterior_mean_backend_or_warn<'a>(
277 fit: &'a UnifiedFitResult,
278 fallback: Option<&'a Array2<f64>>,
279 expected_dim: usize,
280 label: &str,
281) -> Option<PredictionCovarianceBackend<'a>> {
282 for (source, covariance) in [
283 ("fit result", fit.beta_covariance()),
284 ("predictor state", fallback),
285 ] {
286 let Some(covariance) = covariance else {
287 continue;
288 };
289 if covariance.nrows() == expected_dim && covariance.ncols() == expected_dim {
290 return Some(PredictionCovarianceBackend::from_dense(covariance.view()));
291 }
292 log::warn!(
293 "{label}: ignoring {source} covariance with shape {}x{}; expected {}x{}",
294 covariance.nrows(),
295 covariance.ncols(),
296 expected_dim,
297 expected_dim
298 );
299 }
300 if let Some(backend) = conditional_prediction_backend(fit, expected_dim, label) {
301 return Some(backend);
302 }
303 log::warn!(
304 "{label}: covariance/precision unavailable; falling back to plug-in point prediction"
305 );
306 None
307}
308
309fn require_posterior_mean_backend<'a>(
310 fit: &'a UnifiedFitResult,
311 fallback: Option<&'a Array2<f64>>,
312 expected_dim: usize,
313 label: &str,
314) -> Result<PredictionCovarianceBackend<'a>, EstimationError> {
315 posterior_mean_backend_or_warn(fit, fallback, expected_dim, label).ok_or_else(|| {
316 EstimationError::InvalidInput(format!(
317 "{label} requires covariance or penalized Hessian for posterior-mean prediction"
318 ))
319 })
320}
321
322fn project_two_block_linear_predictor_covariance(
323 design_first: &DesignMatrix,
324 design_second: &DesignMatrix,
325 backend: &PredictionCovarianceBackend<'_>,
326 p_first: usize,
327 p_second: usize,
328 label: &str,
329) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), EstimationError> {
330 let p_total = p_first + p_second;
331 if backend.nrows() != p_total {
332 return Err(EstimationError::InvalidInput(format!(
333 "{label} covariance dimension mismatch: expected parameter dimension {}, got {}",
334 p_total,
335 backend.nrows()
336 )));
337 }
338 if design_first.ncols() != p_first || design_second.ncols() != p_second {
339 return Err(EstimationError::InvalidInput(format!(
340 "{label} design dimension mismatch: threshold/location design has {} columns (expected {}), scale design has {} columns (expected {})",
341 design_first.ncols(),
342 p_first,
343 design_second.ncols(),
344 p_second
345 )));
346 }
347 let local = local_covariances_with_backend(backend, design_first.nrows(), 2, |rows| {
348 let x_first = design_row_chunk(design_first, rows.clone())?;
349 let x_second = design_row_chunk(design_second, rows.clone())?;
350 let rows_in_chunk = rows.end - rows.start;
351 let mut first = Array2::<f64>::zeros((rows_in_chunk, p_total));
352 let mut second = Array2::<f64>::zeros((rows_in_chunk, p_total));
353 first
354 .slice_mut(ndarray::s![.., 0..p_first])
355 .assign(&x_first);
356 second
357 .slice_mut(ndarray::s![.., p_first..p_total])
358 .assign(&x_second);
359 Ok(vec![first, second])
360 })?;
361 Ok((
362 local[0][0].mapv(|v| v.max(0.0)),
363 local[1][1].mapv(|v| v.max(0.0)),
364 local[0][1].clone(),
365 ))
366}
367
368fn linear_predictor_se_from_backend<F>(
369 backend: &PredictionCovarianceBackend<'_>,
370 n_rows: usize,
371 build_chunk: F,
372) -> Result<Array1<f64>, EstimationError>
373where
374 F: Fn(std::ops::Range<usize>) -> Result<Vec<Array2<f64>>, String> + Sync,
375{
376 let local = local_covariances_with_backend(backend, n_rows, 1, build_chunk)?;
377 Ok(local[0][0].mapv(|v| v.max(0.0).sqrt()))
378}
379
380fn padded_design_standard_errors_from_backend(
381 design: &DesignMatrix,
382 backend: &PredictionCovarianceBackend<'_>,
383 leading_zeros: usize,
384 trailing_zeros: usize,
385 label: &str,
386) -> Result<Array1<f64>, EstimationError> {
387 let p_design = design.ncols();
388 let p_total = leading_zeros + p_design + trailing_zeros;
389 if backend.nrows() != p_total {
390 return Err(EstimationError::InvalidInput(format!(
391 "{label} covariance dimension mismatch: expected parameter dimension {p_total}, got {}",
392 backend.nrows()
393 )));
394 }
395 linear_predictor_se_from_backend(backend, design.nrows(), |rows| {
396 let x = design_row_chunk(design, rows)?;
397 let rows_in_chunk = x.nrows();
398 let mut grad = Array2::<f64>::zeros((rows_in_chunk, p_total));
399 grad.slice_mut(ndarray::s![.., leading_zeros..leading_zeros + p_design])
400 .assign(&x);
401 Ok(vec![grad])
402 })
403}
404
405fn projected_bivariate_posterior_mean_result<F>(
406 quadctx: &crate::quadrature::QuadratureContext,
407 mu: [f64; 2],
408 cov: [[f64; 2]; 2],
409 integrand: F,
410) -> Result<f64, EstimationError>
411where
412 F: Fn(f64, f64) -> Result<f64, EstimationError>,
413{
414 let var0 = cov[0][0].max(0.0);
415 let var1 = cov[1][1].max(0.0);
416 let cov01 = cov[0][1];
417
418 if var0 <= POSTERIOR_MEAN_VARIANCE_TOL && var1 <= POSTERIOR_MEAN_VARIANCE_TOL {
419 return integrand(mu[0], mu[1]);
420 }
421 if var0 <= POSTERIOR_MEAN_VARIANCE_TOL && cov01.abs() <= POSTERIOR_MEAN_CROSS_TOL {
422 return crate::quadrature::normal_expectation_nd_adaptive_result::<1, _, _, EstimationError>(
423 quadctx,
424 [mu[1]],
425 [[var1]],
426 21,
427 |x| integrand(mu[0], x[0]),
428 );
429 }
430 if var1 <= POSTERIOR_MEAN_VARIANCE_TOL && cov01.abs() <= POSTERIOR_MEAN_CROSS_TOL {
431 return crate::quadrature::normal_expectation_nd_adaptive_result::<1, _, _, EstimationError>(
432 quadctx,
433 [mu[0]],
434 [[var0]],
435 21,
436 |x| integrand(x[0], mu[1]),
437 );
438 }
439 crate::quadrature::normal_expectation_2d_adaptive_result(quadctx, mu, cov, integrand)
440}
441
/// Plug-in (point) prediction at the fitted coefficients.
pub struct PredictResult {
    /// Linear predictor per row.
    pub eta: Array1<f64>,
    /// Response-scale prediction per row; for `StandardPredictor` this is the
    /// family inverse link applied to `eta`.
    pub mean: Array1<f64>,
}
446
/// Row-aligned inputs for prediction.
///
/// All per-row fields are sliced together by `slice_predict_input`.
pub struct PredictInput {
    /// Design matrix of the primary linear predictor (multiplied by the
    /// primary coefficients).
    pub design: DesignMatrix,
    /// Per-row offset added to `design · beta`.
    pub offset: Array1<f64>,
    /// Optional design for a secondary (noise/scale) predictor —
    /// NOTE(review): meaning inferred from the name; confirm with
    /// noise-model implementers.
    pub design_noise: Option<DesignMatrix>,
    /// Optional per-row offset paired with `design_noise`.
    pub offset_noise: Option<Array1<f64>>,
    /// Optional family-specific per-row scalar (one value per row).
    pub auxiliary_scalar: Option<Array1<f64>>,
    /// Optional family-specific per-row matrix (one row per observation).
    pub auxiliary_matrix: Option<Array2<f64>>,
}
467
468fn slice_predict_input(
469 input: &PredictInput,
470 rows: std::ops::Range<usize>,
471) -> Result<PredictInput, EstimationError> {
472 Ok(PredictInput {
473 design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
474 design_row_chunk(&input.design, rows.clone()).map_err(EstimationError::InvalidInput)?,
475 )),
476 offset: input.offset.slice(ndarray::s![rows.clone()]).to_owned(),
477 design_noise: input
478 .design_noise
479 .as_ref()
480 .map(|design| {
481 design_row_chunk(design, rows.clone())
482 .map(|d| DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(d)))
483 .map_err(EstimationError::InvalidInput)
484 })
485 .transpose()?,
486 offset_noise: input
487 .offset_noise
488 .as_ref()
489 .map(|offset| offset.slice(ndarray::s![rows.clone()]).to_owned()),
490 auxiliary_scalar: input
491 .auxiliary_scalar
492 .as_ref()
493 .map(|values| values.slice(ndarray::s![rows.clone()]).to_owned()),
494 auxiliary_matrix: input
495 .auxiliary_matrix
496 .as_ref()
497 .map(|values| values.slice(ndarray::s![rows, ..]).to_owned()),
498 })
499}
500
/// Plug-in prediction with optional standard errors.
pub struct PredictionWithSE {
    /// Linear predictor per row.
    pub eta: Array1<f64>,
    /// Response-scale prediction per row.
    pub mean: Array1<f64>,
    /// Standard error of `eta`; `None` when no covariance is available
    /// (as in `StandardPredictor::predict_with_uncertainty`).
    pub eta_se: Option<Array1<f64>>,
    /// Delta-method standard error of `mean`; `None` under the same condition.
    pub mean_se: Option<Array1<f64>>,
}
512
/// Common prediction interface implemented by fitted-model predictors.
pub trait PredictableModel {
    /// Point prediction at the fitted coefficients, with no uncertainty
    /// propagation.
    fn predict_plugin_response(
        &self,
        input: &PredictInput,
    ) -> Result<PredictResult, EstimationError>;

    /// Linear predictor only. The default returns the `eta` of
    /// [`PredictableModel::predict_plugin_response`].
    fn predict_linear_predictor(
        &self,
        input: &PredictInput,
    ) -> Result<Array1<f64>, EstimationError> {
        self.predict_plugin_response(input).map(|pred| pred.eta)
    }

    /// Plug-in prediction with optional linear-predictor and response-scale
    /// standard errors, taken from whatever covariance the predictor stores.
    fn predict_with_uncertainty(
        &self,
        input: &PredictInput,
    ) -> Result<PredictionWithSE, EstimationError>;

    /// Optional per-row noise/scale prediction. `StandardPredictor` always
    /// returns `Ok(None)`.
    fn predict_noise_scale(
        &self,
        input: &PredictInput,
    ) -> Result<Option<Array1<f64>>, EstimationError>;

    /// Prediction with standard errors and confidence bounds. The
    /// covariance source and confidence level come from `options`.
    fn predict_full_uncertainty(
        &self,
        input: &PredictInput,
        fit: &UnifiedFitResult,
        options: &PredictUncertaintyOptions,
    ) -> Result<PredictUncertaintyResult, EstimationError>;

    /// Posterior-mean prediction on the response scale, integrating over
    /// coefficient uncertainty. Implementations such as `StandardPredictor`
    /// add bounds when `confidence_level` is `Some`.
    fn predict_posterior_mean(
        &self,
        input: &PredictInput,
        fit: &UnifiedFitResult,
        confidence_level: Option<f64>,
    ) -> Result<PredictPosteriorMeanResult, EstimationError>;

    /// Number of coefficient blocks used by this predictor.
    fn n_blocks(&self) -> usize;

    /// Roles of this predictor's coefficient blocks, in block order.
    fn block_roles(&self) -> Vec<BlockRole>;
}
585
/// Predictor for single-block standard fits, optionally with a link wiggle.
pub struct StandardPredictor {
    /// Mean-block coefficients.
    pub beta: Array1<f64>,
    /// Likelihood family; selects the inverse link together with `link_kind`.
    pub family: crate::types::LikelihoodFamily,
    /// Optional explicit inverse link passed to `strategy_for_family`.
    pub link_kind: Option<InverseLink>,
    /// Conditional coefficient covariance, copied from
    /// `covariance_conditional` in `from_unified`. With a link wiggle it spans
    /// `[beta, wiggle beta]`.
    pub covariance: Option<Array2<f64>>,
    /// Optional link-wiggle runtime applied to `X beta + offset` before the
    /// inverse link.
    pub link_wiggle: Option<SavedLinkWiggleRuntime>,
}
594
595impl StandardPredictor {
596 pub(crate) fn from_unified(
599 unified: &UnifiedFitResult,
600 family: crate::types::LikelihoodFamily,
601 link_kind: Option<InverseLink>,
602 link_wiggle: Option<SavedLinkWiggleRuntime>,
603 ) -> Result<Self, String> {
604 let expected_linkwiggle = link_wiggle.is_some();
605 if !expected_linkwiggle
606 && (unified.n_blocks() != 1 || unified.block_by_role(BlockRole::LinkWiggle).is_some())
607 {
608 return Err(
609 "StandardPredictor only supports single-block standard fits without link wiggles"
610 .to_string(),
611 );
612 }
613 let beta = if expected_linkwiggle {
614 unified
615 .block_by_role(BlockRole::Mean)
616 .map(|b| b.beta.clone())
617 .ok_or_else(|| {
618 "standard link-wiggle unified fit is missing Mean coefficient block".to_string()
619 })?
620 } else {
621 unified
622 .blocks
623 .first()
624 .map(|b| b.beta.clone())
625 .ok_or_else(|| {
626 "standard unified fit is missing its sole coefficient block".to_string()
627 })?
628 };
629 let covariance = unified.covariance_conditional.clone();
630 Ok(Self {
631 beta,
632 family,
633 link_kind,
634 covariance,
635 link_wiggle,
636 })
637 }
638}
639
640impl PredictableModel for StandardPredictor {
641 fn predict_plugin_response(
642 &self,
643 input: &PredictInput,
644 ) -> Result<PredictResult, EstimationError> {
645 let eta_base = input.design.dot(&self.beta) + &input.offset;
646 let eta = if let Some(runtime) = self.link_wiggle.as_ref() {
647 runtime
648 .apply(&eta_base)
649 .map_err(EstimationError::InvalidInput)?
650 } else {
651 eta_base
652 };
653 let strategy = strategy_for_family(self.family, self.link_kind.as_ref());
654 let mean = strategy.inverse_link_array(eta.view())?;
655 Ok(PredictResult { eta, mean })
656 }
657
658 fn predict_with_uncertainty(
659 &self,
660 input: &PredictInput,
661 ) -> Result<PredictionWithSE, EstimationError> {
662 let result = self.predict_plugin_response(input)?;
663 let eta_base = input.design.dot(&self.beta) + &input.offset;
664 let (eta_se, mean_se) = if let Some(ref cov) = self.covariance {
665 let backend = PredictionCovarianceBackend::from_dense(cov.view());
666 let se = if let Some(runtime) = self.link_wiggle.as_ref() {
667 let p_main = self.beta.len();
668 let p_w = runtime.beta.len();
669 let p_total = p_main + p_w;
670 if backend.nrows() != p_total {
671 return Err(EstimationError::InvalidInput(format!(
672 "standard link-wiggle covariance dimension mismatch: expected parameter dimension {}, got {}",
673 p_total,
674 backend.nrows()
675 )));
676 }
677 linear_predictor_se_from_backend(&backend, result.eta.len(), |rows| {
678 let q0_chunk = eta_base.slice(ndarray::s![rows.clone()]).to_owned();
679 let x_main = design_row_chunk(&input.design, rows.clone())?;
680 let wiggle_design = runtime.design(&q0_chunk)?;
681 let dq_dq0 = runtime.derivative_q0(&q0_chunk)?;
682 let rows_in_chunk = q0_chunk.len();
683 let mut grad = Array2::<f64>::zeros((rows_in_chunk, p_total));
684 for i in 0..rows_in_chunk {
685 for j in 0..p_main {
686 grad[[i, j]] = dq_dq0[i] * x_main[[i, j]];
687 }
688 }
689 grad.slice_mut(ndarray::s![.., p_main..p_total])
690 .assign(&wiggle_design);
691 Ok(vec![grad])
692 })?
693 } else {
694 eta_standard_errors_from_backend(&input.design, &backend)?
695 };
696 let strategy = strategy_for_family(self.family, self.link_kind.as_ref());
697 let mean_se = delta_method_mean_se(&result.eta, &se, &strategy)?;
698 (Some(se), Some(mean_se))
699 } else {
700 (None, None)
701 };
702 Ok(PredictionWithSE {
703 eta: result.eta,
704 mean: result.mean,
705 eta_se,
706 mean_se,
707 })
708 }
709
710 fn predict_noise_scale(
711 &self,
712 _: &PredictInput,
713 ) -> Result<Option<Array1<f64>>, EstimationError> {
714 Ok(None)
715 }
716
717 fn predict_full_uncertainty(
718 &self,
719 input: &PredictInput,
720 fit: &UnifiedFitResult,
721 options: &PredictUncertaintyOptions,
722 ) -> Result<PredictUncertaintyResult, EstimationError> {
723 if self.link_wiggle.is_none() {
724 return predict_gamwith_uncertainty(
725 input.design.clone(),
726 self.beta.view(),
727 input.offset.view(),
728 self.family,
729 fit,
730 options,
731 );
732 }
733 let pred = self.predict_with_uncertainty(input)?;
734 let eta_se = pred.eta_se.clone().ok_or_else(|| {
735 EstimationError::InvalidInput(
736 "standard link-wiggle uncertainty requires covariance".to_string(),
737 )
738 })?;
739 let mean_se = pred.mean_se.clone().ok_or_else(|| {
740 EstimationError::InvalidInput(
741 "standard link-wiggle uncertainty requires covariance".to_string(),
742 )
743 })?;
744 let z = crate::probability::standard_normal_quantile(0.5 + options.confidence_level * 0.5)
745 .map_err(EstimationError::InvalidInput)?;
746 let eta_lower = &pred.eta - &eta_se.mapv(|s| z * s);
747 let eta_upper = &pred.eta + &eta_se.mapv(|s| z * s);
748 let mut mean_lower = &pred.mean - &mean_se.mapv(|s| z * s);
749 let mut mean_upper = &pred.mean + &mean_se.mapv(|s| z * s);
750 let (lo, hi) = match self.family {
751 crate::types::LikelihoodFamily::GaussianIdentity => (f64::NEG_INFINITY, f64::INFINITY),
752 crate::types::LikelihoodFamily::PoissonLog
753 | crate::types::LikelihoodFamily::GammaLog => (0.0, f64::INFINITY),
754 _ => (1e-10, 1.0 - 1e-10),
755 };
756 mean_lower.mapv_inplace(|v| v.clamp(lo, hi));
757 mean_upper.mapv_inplace(|v| v.clamp(lo, hi));
758 Ok(PredictUncertaintyResult {
759 eta: pred.eta,
760 mean: pred.mean,
761 eta_standard_error: eta_se,
762 mean_standard_error: mean_se,
763 eta_lower,
764 eta_upper,
765 mean_lower,
766 mean_upper,
767 observation_lower: None,
768 observation_upper: None,
769 covariance_mode_requested: options.covariance_mode,
770 covariance_corrected_used: false,
771 })
772 }
773
774 fn predict_posterior_mean(
775 &self,
776 input: &PredictInput,
777 fit: &UnifiedFitResult,
778 confidence_level: Option<f64>,
779 ) -> Result<PredictPosteriorMeanResult, EstimationError> {
780 let mut result = if self.link_wiggle.is_none() {
781 let backend = posterior_mean_backend_or_warn(
782 fit,
783 self.covariance.as_ref(),
784 self.beta.len(),
785 "standard posterior mean",
786 )
787 .ok_or_else(|| {
788 EstimationError::InvalidInput(
789 "posterior-mean prediction requires beta covariance or penalized Hessian"
790 .to_string(),
791 )
792 })?;
793 let strategy = strategy_from_fit(self.family, fit)?;
794 predict_gam_posterior_mean_from_backendwith_bc(
795 input.design.clone(),
796 self.beta.view(),
797 input.offset.view(),
798 &backend,
799 &strategy,
800 "standard posterior mean",
801 fit.bias_correction_beta().map(|b| b.view()),
802 )?
803 } else {
804 let runtime = self.link_wiggle.as_ref().expect("checked above");
805 let plugin = self.predict_plugin_response(input)?;
806 let eta_base = input.design.dot(&self.beta) + &input.offset;
807 let backend = posterior_mean_backend_or_warn(
808 fit,
809 self.covariance.as_ref(),
810 self.beta.len() + runtime.beta.len(),
811 "standard link-wiggle posterior mean",
812 )
813 .ok_or_else(|| {
814 EstimationError::InvalidInput(
815 "posterior-mean prediction requires beta covariance or penalized Hessian"
816 .to_string(),
817 )
818 })?;
819 let p_main = self.beta.len();
820 let p_w = runtime.beta.len();
821 let p_total = p_main + p_w;
822 if backend.nrows() != p_total {
823 return Err(EstimationError::InvalidInput(format!(
824 "standard link-wiggle posterior mean covariance mismatch: expected parameter dimension {}, got {}",
825 p_total,
826 backend.nrows()
827 )));
828 }
829 let eta_se = linear_predictor_se_from_backend(&backend, plugin.eta.len(), |rows| {
830 let q0_chunk = eta_base.slice(ndarray::s![rows.clone()]).to_owned();
831 let x_main = design_row_chunk(&input.design, rows.clone())?;
832 let wiggle_design = runtime.design(&q0_chunk)?;
833 let dq_dq0 = runtime.derivative_q0(&q0_chunk)?;
834 let rows_in_chunk = q0_chunk.len();
835 let mut grad = Array2::<f64>::zeros((rows_in_chunk, p_total));
836 for i in 0..rows_in_chunk {
837 for j in 0..p_main {
838 grad[[i, j]] = dq_dq0[i] * x_main[[i, j]];
839 }
840 }
841 grad.slice_mut(ndarray::s![.., p_main..p_total])
842 .assign(&wiggle_design);
843 Ok(vec![grad])
844 })?;
845 let strategy = strategy_for_family(self.family, self.link_kind.as_ref());
846 let quadctx = crate::quadrature::QuadratureContext::new();
847 let mean = plugin
848 .eta
849 .iter()
850 .zip(eta_se.iter())
851 .map(|(&e, &se)| strategy.posterior_mean(&quadctx, e, se))
852 .collect::<Result<Array1<f64>, _>>()?;
853 PredictPosteriorMeanResult {
854 eta: plugin.eta,
855 eta_standard_error: eta_se,
856 mean,
857 mean_lower: None,
858 mean_upper: None,
859 }
860 };
861 if let Some(level) = confidence_level {
862 enrich_posterior_mean_bounds(&mut result, level, self.family, self.link_kind.as_ref())?;
863 }
864 Ok(result)
865 }
866
867 fn n_blocks(&self) -> usize {
868 if self.link_wiggle.is_some() { 2 } else { 1 }
869 }
870
871 fn block_roles(&self) -> Vec<BlockRole> {
872 if self.link_wiggle.is_some() {
873 vec![BlockRole::Mean, BlockRole::LinkWiggle]
874 } else {
875 vec![BlockRole::Mean]
876 }
877 }
878}
879
/// Predictor for the Bernoulli marginal-slope family (probit response; see
/// `likelihood_family`).
///
/// NOTE(review): several field meanings below are inferred from names and
/// from the uses visible in this chunk; confirm against the fitting code.
pub struct BernoulliMarginalSlopePredictor {
    /// Coefficients of the marginal linear-predictor block.
    pub beta_marginal: Array1<f64>,
    /// Coefficients of the log-slope block.
    pub beta_logslope: Array1<f64>,
    /// Optional score-warp coefficients.
    pub beta_score_warp: Option<Array1<f64>>,
    /// Optional link-deviation coefficients.
    pub beta_link_dev: Option<Array1<f64>>,
    /// Base link passed to `bernoulli_marginal_link_map` for the marginal
    /// linear predictor.
    pub base_link: InverseLink,
    /// Name of the latent `z` column — presumably looked up in prediction
    /// data; verify.
    pub z_column: String,
    /// Saved normalization of the latent `z` covariate.
    pub latent_z_normalization: SavedLatentZNormalization,
    /// Kind of latent measure used for the latent variable.
    pub latent_measure: LatentMeasureKind,
    /// Baseline value of the marginal predictor — TODO confirm usage.
    pub baseline_marginal: f64,
    /// Baseline value of the log-slope predictor — TODO confirm usage.
    pub baseline_logslope: f64,
    /// Optional coefficient covariance.
    pub covariance: Option<Array2<f64>>,
    /// Optional score-warp runtime; may carry anchor-residual coefficients
    /// (see `build_anchor_correction_matrices`).
    pub score_warp_runtime: Option<SavedAnchoredDeviationRuntime>,
    /// Optional link-deviation runtime; may carry anchor-residual
    /// coefficients.
    pub link_deviation_runtime: Option<SavedAnchoredDeviationRuntime>,
    /// Optional Gaussian frailty standard deviation, fed to
    /// `probit_frailty_scale`.
    pub gaussian_frailty_sd: Option<f64>,
    /// Optional calibration applied to latent `z` at prediction time
    /// (`apply_latent_z_calibration`).
    pub(crate) latent_z_calibration:
        Option<crate::families::bernoulli_marginal_slope::LatentZRankIntCalibration>,
}
908
/// Predict-time anchor correction matrices for the marginal-slope predictor.
///
/// All fields are `None` (the `Default`) when no runtime carries
/// anchor-residual coefficients.
#[derive(Default)]
struct BmsAnchorCorrections {
    /// Dense `[marginal design | logslope design]` rows, one row per
    /// observation.
    n_anchor_rows: Option<Array2<f64>>,
    /// Score-warp anchor correction matrix, if produced by the runtime.
    score_warp: Option<Array2<f64>>,
    /// Link-deviation anchor correction matrix, if produced by the runtime.
    link_dev: Option<Array2<f64>>,
}
931
932impl BmsAnchorCorrections {
933 fn score_warp_row(&self, row: usize) -> Option<ndarray::ArrayView1<'_, f64>> {
934 self.score_warp.as_ref().map(|m| m.row(row))
935 }
936
937 fn link_dev_row(&self, row: usize) -> Option<ndarray::ArrayView1<'_, f64>> {
938 self.link_dev.as_ref().map(|m| m.row(row))
939 }
940
941 fn n_anchor_rows_view(&self) -> Option<ndarray::ArrayView2<'_, f64>> {
942 self.n_anchor_rows.as_ref().map(|m| m.view())
943 }
944}
945
946impl BernoulliMarginalSlopePredictor {
947 fn build_anchor_correction_matrices(
956 &self,
957 input: &PredictInput,
958 design_logslope: &DesignMatrix,
959 ) -> Result<BmsAnchorCorrections, EstimationError> {
960 let needs_score = self
961 .score_warp_runtime
962 .as_ref()
963 .map_or(false, |r| r.anchor_residual_coefficients.is_some());
964 let needs_link = self
965 .link_deviation_runtime
966 .as_ref()
967 .map_or(false, |r| r.anchor_residual_coefficients.is_some());
968 if !needs_score && !needs_link {
969 return Ok(BmsAnchorCorrections::default());
970 }
971 let marginal_dense = input
976 .design
977 .try_to_dense_arc(
978 "bernoulli marginal-slope predict-time marginal anchor materialisation",
979 )
980 .map_err(EstimationError::InvalidInput)?;
981 let logslope_dense = design_logslope
982 .try_to_dense_arc(
983 "bernoulli marginal-slope predict-time logslope anchor materialisation",
984 )
985 .map_err(EstimationError::InvalidInput)?;
986 let n_rows = marginal_dense.nrows();
987 if logslope_dense.nrows() != n_rows {
988 return Err(EstimationError::InvalidInput(format!(
989 "bernoulli marginal-slope predict anchor materialisation row mismatch: marginal {} vs logslope {}",
990 n_rows,
991 logslope_dense.nrows()
992 )));
993 }
994 let p_marginal = marginal_dense.ncols();
995 let p_logslope = logslope_dense.ncols();
996 let d = p_marginal + p_logslope;
997 let mut n_anchor_rows = Array2::<f64>::zeros((n_rows, d));
998 n_anchor_rows
999 .slice_mut(ndarray::s![.., 0..p_marginal])
1000 .assign(&marginal_dense.view());
1001 n_anchor_rows
1002 .slice_mut(ndarray::s![.., p_marginal..d])
1003 .assign(&logslope_dense.view());
1004 let score_warp = if needs_score {
1005 self.score_warp_runtime
1006 .as_ref()
1007 .unwrap()
1008 .anchor_correction_matrix(n_anchor_rows.view())
1009 .map_err(EstimationError::InvalidInput)?
1010 } else {
1011 None
1012 };
1013 let link_dev = if needs_link {
1014 self.link_deviation_runtime
1015 .as_ref()
1016 .unwrap()
1017 .anchor_correction_matrix(n_anchor_rows.view())
1018 .map_err(EstimationError::InvalidInput)?
1019 } else {
1020 None
1021 };
1022 Ok(BmsAnchorCorrections {
1023 n_anchor_rows: Some(n_anchor_rows),
1024 score_warp,
1025 link_dev,
1026 })
1027 }
1028
1029 fn likelihood_family(&self) -> LikelihoodFamily {
1030 LikelihoodFamily::BinomialProbit
1031 }
1032
1033 fn mean_from_eta(&self, eta: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
1034 Ok(eta.mapv(normal_cdf))
1035 }
1036
1037 fn mean_derivative_from_eta(&self, eta: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
1038 Ok(eta.mapv(normal_pdf))
1039 }
1040
1041 fn probit_frailty_scale(&self) -> f64 {
1042 marginal_slope_probit_frailty_scale(self.gaussian_frailty_sd)
1043 }
1044
1045 fn apply_latent_z_calibration(&self, z: &Array1<f64>) -> Array1<f64> {
1063 match &self.latent_z_calibration {
1064 Some(cal) => Array1::from_iter(z.iter().map(|&zi| cal.apply_at_predict(zi))),
1065 None => z.clone(),
1066 }
1067 }
1068
1069 fn rigid_intercept_from_marginal(&self, marginal_eta: f64, slope: f64) -> f64 {
1070 let probit_scale = self.probit_frailty_scale();
1071 marginal_eta * (1.0 + (probit_scale * slope).powi(2)).sqrt() / probit_scale
1072 }
1073
1074 fn empirical_rigid_intercept_and_gradient(
1075 &self,
1076 marginal_eta: f64,
1077 slope: f64,
1078 nodes: &[f64],
1079 weights: &[f64],
1080 ) -> Result<(f64, f64, f64), EstimationError> {
1081 let marginal = bernoulli_marginal_link_map(&self.base_link, marginal_eta)
1082 .map_err(EstimationError::InvalidInput)?;
1083 let scale = self.probit_frailty_scale();
1084 let intercept = empirical_intercept_from_marginal(
1085 marginal.mu,
1086 marginal.q,
1087 slope,
1088 scale,
1089 nodes,
1090 weights,
1091 None,
1092 )
1093 .map_err(EstimationError::InvalidInput)?;
1094 let observed_slope = scale * slope;
1095 let mut f_a = 0.0;
1096 let mut f_b = 0.0;
1097 for (&node, &weight) in nodes.iter().zip(weights.iter()) {
1098 let eta = intercept + observed_slope * node;
1099 let pdf = normal_pdf(eta);
1100 f_a += weight * pdf;
1101 f_b += weight * pdf * scale * node;
1102 }
1103 if !(f_a.is_finite() && f_a > 0.0 && f_b.is_finite()) {
1104 return Err(EstimationError::InvalidInput(format!(
1105 "empirical latent prediction calibration derivative is invalid: F_a={f_a}, F_b={f_b}"
1106 )));
1107 }
1108 let a_marginal_eta = marginal.mu1 / f_a;
1109 let a_slope = -f_b / f_a;
1110 Ok((intercept, a_marginal_eta, a_slope))
1111 }
1112
1113 fn local_empirical_mixture_for_point(
1114 point: &[f64],
1115 centers: &[Vec<f64>],
1116 top_k: usize,
1117 bandwidth: f64,
1118 ) -> Result<Vec<(usize, f64)>, EstimationError> {
1119 if centers.is_empty() {
1120 return Err(EstimationError::InvalidInput(
1121 "local empirical latent prediction has no centers".to_string(),
1122 ));
1123 }
1124 if top_k == 0 {
1125 return Err(EstimationError::InvalidInput(
1126 "local empirical latent prediction top_k must be positive".to_string(),
1127 ));
1128 }
1129 if !(bandwidth.is_finite() && bandwidth > 0.0) {
1130 return Err(EstimationError::InvalidInput(format!(
1131 "local empirical latent prediction bandwidth must be finite and positive, got {bandwidth}"
1132 )));
1133 }
1134 let bw2 = bandwidth * bandwidth;
1135 let mut distances = Vec::<(usize, f64)>::with_capacity(centers.len());
1136 for (idx, center) in centers.iter().enumerate() {
1137 if center.len() != point.len() {
1138 return Err(EstimationError::InvalidInput(format!(
1139 "local empirical latent prediction center {idx} dimension mismatch: center={}, point={}",
1140 center.len(),
1141 point.len()
1142 )));
1143 }
1144 let d2 = center
1145 .iter()
1146 .zip(point.iter())
1147 .map(|(&c, &x)| {
1148 let delta = x - c;
1149 delta * delta
1150 })
1151 .sum::<f64>();
1152 if !d2.is_finite() {
1153 return Err(EstimationError::InvalidInput(
1154 "local empirical latent prediction distance is non-finite".to_string(),
1155 ));
1156 }
1157 distances.push((idx, d2));
1158 }
1159 distances.sort_by(|left, right| {
1160 left.1
1161 .partial_cmp(&right.1)
1162 .expect("validated local empirical distances are finite")
1163 });
1164 let k = top_k.min(distances.len());
1165 let mut mixture = Vec::with_capacity(k);
1166 let mut total = 0.0;
1167 for &(idx, d2) in distances.iter().take(k) {
1168 let weight = (-0.5 * d2 / bw2).exp().max(1e-300);
1169 mixture.push((idx, weight));
1170 total += weight;
1171 }
1172 if !(total.is_finite() && total > 0.0) {
1173 return Err(EstimationError::InvalidInput(
1174 "local empirical latent prediction mixture has non-positive total weight"
1175 .to_string(),
1176 ));
1177 }
1178 for (_, weight) in &mut mixture {
1179 *weight /= total;
1180 }
1181 Ok(mixture)
1182 }
1183
1184 fn combine_empirical_grids(
1185 grids: &[EmpiricalZGrid],
1186 mixture: &[(usize, f64)],
1187 ) -> Result<EmpiricalZGrid, EstimationError> {
1188 let total_len = mixture
1189 .iter()
1190 .map(|&(idx, _)| grids.get(idx).map_or(0, |grid| grid.nodes.len()))
1191 .sum::<usize>();
1192 let mut nodes = Vec::with_capacity(total_len);
1193 let mut weights = Vec::with_capacity(total_len);
1194 let mut total_weight = 0.0;
1195 for &(grid_idx, grid_weight) in mixture {
1196 if !(grid_weight.is_finite() && grid_weight >= 0.0) {
1197 return Err(EstimationError::InvalidInput(format!(
1198 "local empirical latent prediction mixture weight must be finite and non-negative, got {grid_weight}"
1199 )));
1200 }
1201 let grid = grids.get(grid_idx).ok_or_else(|| {
1202 EstimationError::InvalidInput(format!(
1203 "local empirical latent prediction grid index {grid_idx} is out of bounds for {} grids",
1204 grids.len()
1205 ))
1206 })?;
1207 if grid.nodes.len() != grid.weights.len() || grid.nodes.is_empty() {
1208 return Err(EstimationError::InvalidInput(format!(
1209 "local empirical latent prediction grid {grid_idx} is invalid: nodes={}, weights={}",
1210 grid.nodes.len(),
1211 grid.weights.len()
1212 )));
1213 }
1214 for (&node, &weight) in grid.nodes.iter().zip(grid.weights.iter()) {
1215 let combined_weight = grid_weight * weight;
1216 if !(node.is_finite() && combined_weight.is_finite() && combined_weight >= 0.0) {
1217 return Err(EstimationError::InvalidInput(
1218 "local empirical latent prediction grid contains invalid node/weight"
1219 .to_string(),
1220 ));
1221 }
1222 nodes.push(node);
1223 weights.push(combined_weight);
1224 total_weight += combined_weight;
1225 }
1226 }
1227 if !(total_weight.is_finite() && total_weight > 0.0) {
1228 return Err(EstimationError::InvalidInput(
1229 "local empirical latent prediction combined grid has non-positive total weight"
1230 .to_string(),
1231 ));
1232 }
1233 for weight in &mut weights {
1234 *weight /= total_weight;
1235 }
1236 Ok(EmpiricalZGrid { nodes, weights })
1237 }
1238
1239 fn empirical_grid_for_prediction_row(
1240 &self,
1241 input: &PredictInput,
1242 row: usize,
1243 ) -> Result<Option<EmpiricalZGrid>, EstimationError> {
1244 match &self.latent_measure {
1245 LatentMeasureKind::StandardNormal => Ok(None),
1246 LatentMeasureKind::GlobalEmpirical { nodes, weights } => Ok(Some(EmpiricalZGrid {
1247 nodes: nodes.clone(),
1248 weights: weights.clone(),
1249 })),
1250 LatentMeasureKind::LocalEmpirical {
1251 centers,
1252 grids,
1253 top_k,
1254 bandwidth,
1255 ..
1256 } => {
1257 let conditioning = input.auxiliary_matrix.as_ref().ok_or_else(|| {
1258 EstimationError::InvalidInput(
1259 "bernoulli marginal-slope local empirical prediction requires auxiliary conditioning matrix"
1260 .to_string(),
1261 )
1262 })?;
1263 if row >= conditioning.nrows() {
1264 return Err(EstimationError::InvalidInput(format!(
1265 "local empirical latent prediction row {row} is out of bounds for {} conditioning rows",
1266 conditioning.nrows()
1267 )));
1268 }
1269 let expected_dim = centers.first().map_or(0, Vec::len);
1270 if conditioning.ncols() != expected_dim {
1271 return Err(EstimationError::InvalidInput(format!(
1272 "local empirical latent prediction conditioning dimension mismatch: got {}, expected {expected_dim}",
1273 conditioning.ncols()
1274 )));
1275 }
1276 let point = conditioning.row(row).to_vec();
1277 let mixture =
1278 Self::local_empirical_mixture_for_point(&point, centers, *top_k, *bandwidth)?;
1279 Self::combine_empirical_grids(grids, &mixture).map(Some)
1280 }
1281 }
1282 }
1283
1284 fn transform_internal_eta_to_base_scale(
1285 &self,
1286 internal_eta: Array1<f64>,
1287 internal_grad: Option<Array2<f64>>,
1288 ) -> Result<(Array1<f64>, Option<Array2<f64>>), EstimationError> {
1289 Ok((internal_eta, internal_grad))
1290 }
1291
1292 fn link_terms_value_d1(
1293 &self,
1294 eta0: &Array1<f64>,
1295 beta_link_dev: Option<&Array1<f64>>,
1296 link_dev_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
1297 ) -> Result<(Array1<f64>, Array1<f64>), EstimationError> {
1298 if let (Some(runtime), Some(beta)) = (&self.link_deviation_runtime, beta_link_dev) {
1299 let basis = runtime
1307 .design_uncorrected(eta0)
1308 .map_err(EstimationError::InvalidInput)?;
1309 let mut value = &basis.dot(beta) + eta0;
1310 if let Some(corr) = link_dev_correction_for_row {
1311 let offset = corr.dot(beta);
1312 for v in value.iter_mut() {
1313 *v -= offset;
1314 }
1315 } else if runtime.anchor_residual_coefficients.is_some() {
1316 return Err(EstimationError::InvalidInput(
1317 "bernoulli marginal-slope link-deviation runtime has an anchor residual but \
1318 no per-row correction was supplied to link_terms_value_d1"
1319 .to_string(),
1320 ));
1321 }
1322 let d1 = runtime
1323 .first_derivative_design(eta0)
1324 .map_err(EstimationError::InvalidInput)?;
1325 Ok((value, d1.dot(beta) + 1.0))
1326 } else {
1327 Ok((eta0.clone(), Array1::ones(eta0.len())))
1328 }
1329 }
1330
1331 fn denested_partition_cells(
1332 &self,
1333 a: f64,
1334 b: f64,
1335 beta_score_warp: Option<&Array1<f64>>,
1336 beta_link_dev: Option<&Array1<f64>>,
1337 score_warp_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
1338 link_dev_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
1339 ) -> Result<
1340 Vec<crate::families::bernoulli_marginal_slope::exact_kernel::DenestedPartitionCell>,
1341 EstimationError,
1342 > {
1343 let score_breaks = if let Some(runtime) = self.score_warp_runtime.as_ref() {
1344 runtime
1345 .breakpoints()
1346 .map_err(EstimationError::InvalidInput)?
1347 } else {
1348 Vec::new()
1349 };
1350 let link_breaks = if let Some(runtime) = self.link_deviation_runtime.as_ref() {
1351 runtime
1352 .breakpoints()
1353 .map_err(EstimationError::InvalidInput)?
1354 } else {
1355 Vec::new()
1356 };
1357 let mut cells = crate::families::bernoulli_marginal_slope::exact_kernel::build_denested_partition_cells_with_tails(
1358 a,
1359 b,
1360 &score_breaks,
1361 &link_breaks,
1362 |z| {
1363 if let (Some(runtime), Some(beta)) =
1364 (self.score_warp_runtime.as_ref(), beta_score_warp)
1365 {
1366 let mut span = runtime.local_cubic_at(beta, z)?;
1367 if let Some(corr) = score_warp_correction_for_row {
1374 span.c0 -= corr.dot(beta);
1375 }
1376 Ok(span)
1377 } else {
1378 Ok(crate::families::bernoulli_marginal_slope::exact_kernel::LocalSpanCubic {
1379 left: 0.0,
1380 right: 1.0,
1381 c0: 0.0,
1382 c1: 0.0,
1383 c2: 0.0,
1384 c3: 0.0,
1385 })
1386 }
1387 },
1388 |u| {
1389 if let (Some(runtime), Some(beta)) =
1390 (self.link_deviation_runtime.as_ref(), beta_link_dev)
1391 {
1392 let mut span = runtime.local_cubic_at(beta, u)?;
1393 if let Some(corr) = link_dev_correction_for_row {
1394 span.c0 -= corr.dot(beta);
1395 }
1396 Ok(span)
1397 } else {
1398 Ok(crate::families::bernoulli_marginal_slope::exact_kernel::LocalSpanCubic {
1399 left: 0.0,
1400 right: 1.0,
1401 c0: 0.0,
1402 c1: 0.0,
1403 c2: 0.0,
1404 c3: 0.0,
1405 })
1406 }
1407 },
1408 )
1409 .map_err(EstimationError::InvalidInput)?;
1410 let scale = self.probit_frailty_scale();
1411 if scale != 1.0 {
1412 for partition_cell in &mut cells {
1413 partition_cell.cell.c0 *= scale;
1414 partition_cell.cell.c1 *= scale;
1415 partition_cell.cell.c2 *= scale;
1416 partition_cell.cell.c3 *= scale;
1417 }
1418 }
1419 Ok(cells)
1420 }
1421
1422 fn evaluate_denested_calibration(
1423 &self,
1424 a: f64,
1425 marginal_eta: f64,
1426 slope: f64,
1427 beta_score_warp: Option<&Array1<f64>>,
1428 beta_link_dev: Option<&Array1<f64>>,
1429 score_warp_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
1430 link_dev_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
1431 ) -> Result<(f64, f64, f64), EstimationError> {
1432 let marginal = bernoulli_marginal_link_map(&self.base_link, marginal_eta)
1433 .map_err(EstimationError::InvalidInput)?;
1434 let cells = self.denested_partition_cells(
1435 a,
1436 slope,
1437 beta_score_warp,
1438 beta_link_dev,
1439 score_warp_correction_for_row,
1440 link_dev_correction_for_row,
1441 )?;
1442 let scale = self.probit_frailty_scale();
1443 let mut f = -marginal.mu;
1444 let mut f_a = 0.0;
1445 let mut f_aa = 0.0;
1446 for partition_cell in cells {
1447 let cell = partition_cell.cell;
1448 let state =
1449 crate::families::bernoulli_marginal_slope::exact_kernel::evaluate_cell_moments(
1450 cell, 7,
1451 )
1452 .map_err(EstimationError::InvalidInput)?;
1453 f += state.value;
1454 let (dc_da_raw, _) = crate::families::bernoulli_marginal_slope::exact_kernel::denested_cell_coefficient_partials(
1455 partition_cell.score_span,
1456 partition_cell.link_span,
1457 a,
1458 slope,
1459 );
1460 let (d2c_da2_raw, _, _) = crate::families::bernoulli_marginal_slope::exact_kernel::denested_cell_second_partials(
1461 partition_cell.score_span,
1462 partition_cell.link_span,
1463 a,
1464 slope,
1465 );
1466 let dc_da = scale_coeff4(dc_da_raw, scale);
1467 let d2c_da2 = scale_coeff4(d2c_da2_raw, scale);
1468 f_a += crate::families::bernoulli_marginal_slope::exact_kernel::cell_first_derivative_from_moments(
1469 &dc_da,
1470 &state.moments,
1471 )
1472 .map_err(EstimationError::InvalidInput)?;
1473 f_aa += crate::families::bernoulli_marginal_slope::exact_kernel::cell_second_derivative_from_moments(
1474 cell,
1475 &dc_da,
1476 &dc_da,
1477 &d2c_da2,
1478 &state.moments,
1479 )
1480 .map_err(EstimationError::InvalidInput)?;
1481 }
1482 Ok((f, f_a, f_aa))
1483 }
1484
1485 fn observed_denested_cell_partials_at_z(
1486 &self,
1487 z_value: f64,
1488 a: f64,
1489 b: f64,
1490 beta_score_warp: Option<&Array1<f64>>,
1491 beta_link_dev: Option<&Array1<f64>>,
1492 score_warp_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
1493 link_dev_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
1494 ) -> Result<ObservedDenestedCellPartials, EstimationError> {
1495 use crate::families::bernoulli_marginal_slope::exact_kernel as exact;
1496
1497 let zero_span = exact::LocalSpanCubic {
1498 left: 0.0,
1499 right: 1.0,
1500 c0: 0.0,
1501 c1: 0.0,
1502 c2: 0.0,
1503 c3: 0.0,
1504 };
1505 let u_value = a + b * z_value;
1506 let score_span = if let (Some(runtime), Some(beta)) =
1507 (self.score_warp_runtime.as_ref(), beta_score_warp)
1508 {
1509 let mut span = runtime
1510 .local_cubic_at(beta, z_value)
1511 .map_err(EstimationError::InvalidInput)?;
1512 if let Some(corr) = score_warp_correction_for_row {
1513 span.c0 -= corr.dot(beta);
1514 }
1515 span
1516 } else {
1517 zero_span
1518 };
1519 let link_span = if let (Some(runtime), Some(beta)) =
1520 (self.link_deviation_runtime.as_ref(), beta_link_dev)
1521 {
1522 let mut span = runtime
1523 .local_cubic_at(beta, u_value)
1524 .map_err(EstimationError::InvalidInput)?;
1525 if let Some(corr) = link_dev_correction_for_row {
1526 span.c0 -= corr.dot(beta);
1527 }
1528 span
1529 } else {
1530 zero_span
1531 };
1532 let scale = self.probit_frailty_scale();
1533 let coeff = scale_coeff4(
1534 exact::denested_cell_coefficients(score_span, link_span, a, b),
1535 scale,
1536 );
1537 let (dc_da_raw, dc_db_raw) =
1538 exact::denested_cell_coefficient_partials(score_span, link_span, a, b);
1539 let (dc_daa_raw, dc_dab_raw, dc_dbb_raw) =
1540 exact::denested_cell_second_partials(score_span, link_span, a, b);
1541 let (dc_daaa, dc_daab, dc_dabb, dc_dbbb) = exact::denested_cell_third_partials(link_span);
1542 Ok(ObservedDenestedCellPartials {
1543 coeff,
1544 dc_da: scale_coeff4(dc_da_raw, scale),
1545 dc_db: scale_coeff4(dc_db_raw, scale),
1546 dc_daa: scale_coeff4(dc_daa_raw, scale),
1547 dc_dab: scale_coeff4(dc_dab_raw, scale),
1548 dc_dbb: scale_coeff4(dc_dbb_raw, scale),
1549 dc_daaa: scale_coeff4(dc_daaa, scale),
1550 dc_daab: scale_coeff4(dc_daab, scale),
1551 dc_dabb: scale_coeff4(dc_dabb, scale),
1552 dc_dbbb: scale_coeff4(dc_dbbb, scale),
1553 })
1554 }
1555
1556 fn evaluate_empirical_denested_calibration(
1557 &self,
1558 a: f64,
1559 marginal_eta: f64,
1560 slope: f64,
1561 beta_score_warp: Option<&Array1<f64>>,
1562 beta_link_dev: Option<&Array1<f64>>,
1563 grid: &EmpiricalZGrid,
1564 score_warp_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
1565 link_dev_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
1566 ) -> Result<(f64, f64, f64), EstimationError> {
1567 let marginal = bernoulli_marginal_link_map(&self.base_link, marginal_eta)
1568 .map_err(EstimationError::InvalidInput)?;
1569 let mut f = -marginal.mu;
1570 let mut f_a = 0.0;
1571 let mut f_aa = 0.0;
1572 for (&node, &weight) in grid.nodes.iter().zip(grid.weights.iter()) {
1573 let obs = self.observed_denested_cell_partials_at_z(
1574 node,
1575 a,
1576 slope,
1577 beta_score_warp,
1578 beta_link_dev,
1579 score_warp_correction_for_row,
1580 link_dev_correction_for_row,
1581 )?;
1582 let eta = eval_coeff4_at(&obs.coeff, node);
1583 let eta_a = eval_coeff4_at(&obs.dc_da, node);
1584 let eta_aa = eval_coeff4_at(&obs.dc_daa, node);
1585 let pdf = normal_pdf(eta);
1586 f += weight * normal_cdf(eta);
1587 f_a += weight * pdf * eta_a;
1588 f_aa += weight * pdf * (eta_aa - eta * eta_a * eta_a);
1589 }
1590 Ok((f, f_a, f_aa))
1591 }
1592
1593 fn evaluate_prediction_calibration(
1594 &self,
1595 a: f64,
1596 marginal_eta: f64,
1597 slope: f64,
1598 beta_score_warp: Option<&Array1<f64>>,
1599 beta_link_dev: Option<&Array1<f64>>,
1600 empirical_grid: Option<&EmpiricalZGrid>,
1601 score_warp_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
1602 link_dev_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
1603 ) -> Result<(f64, f64, f64), EstimationError> {
1604 if let Some(grid) = empirical_grid {
1605 self.evaluate_empirical_denested_calibration(
1606 a,
1607 marginal_eta,
1608 slope,
1609 beta_score_warp,
1610 beta_link_dev,
1611 grid,
1612 score_warp_correction_for_row,
1613 link_dev_correction_for_row,
1614 )
1615 } else {
1616 self.evaluate_denested_calibration(
1617 a,
1618 marginal_eta,
1619 slope,
1620 beta_score_warp,
1621 beta_link_dev,
1622 score_warp_correction_for_row,
1623 link_dev_correction_for_row,
1624 )
1625 }
1626 }
1627
1628 pub fn from_unified(
1629 unified: &UnifiedFitResult,
1630 z_column: String,
1631 latent_z_normalization: SavedLatentZNormalization,
1632 latent_measure: LatentMeasureKind,
1633 baseline_marginal: f64,
1634 baseline_logslope: f64,
1635 base_link: InverseLink,
1636 frailty: FrailtySpec,
1637 score_warp_runtime: Option<SavedAnchoredDeviationRuntime>,
1638 link_deviation_runtime: Option<SavedAnchoredDeviationRuntime>,
1639 latent_z_calibration: Option<
1640 crate::families::bernoulli_marginal_slope::LatentZRankIntCalibration,
1641 >,
1642 ) -> Result<Self, String> {
1643 let gaussian_frailty_sd = match frailty {
1644 FrailtySpec::None => None,
1645 FrailtySpec::GaussianShift {
1646 sigma_fixed: Some(sigma),
1647 } => Some(sigma),
1648 FrailtySpec::GaussianShift { sigma_fixed: None } => {
1649 return Err(
1650 "bernoulli marginal-slope predictor requires a fixed GaussianShift sigma"
1651 .to_string(),
1652 );
1653 }
1654 FrailtySpec::HazardMultiplier { .. } => {
1655 return Err(
1656 "bernoulli marginal-slope predictor does not support HazardMultiplier frailty"
1657 .to_string(),
1658 );
1659 }
1660 };
1661 if !matches!(
1662 base_link,
1663 InverseLink::Standard(crate::types::LinkFunction::Probit)
1664 ) {
1665 return Err(
1666 "bernoulli marginal-slope predictor requires link(type=probit); saved non-probit marginal-slope models must be refit"
1667 .to_string(),
1668 );
1669 }
1670 if let Some(runtime) = score_warp_runtime.as_ref() {
1671 runtime.validate_exact_replay_contract().map_err(|e| {
1672 format!("bernoulli marginal-slope score-warp runtime is invalid: {e}")
1673 })?;
1674 }
1675 if let Some(runtime) = link_deviation_runtime.as_ref() {
1676 runtime.validate_exact_replay_contract().map_err(|e| {
1677 format!("bernoulli marginal-slope link-deviation runtime is invalid: {e}")
1678 })?;
1679 }
1680 latent_z_normalization
1684 .validate("bernoulli marginal-slope predictor")
1685 .map_err(|e| {
1686 format!("bernoulli marginal-slope predictor latent z normalization is invalid: {e}")
1687 })?;
1688 latent_measure
1689 .validate("bernoulli marginal-slope predictor latent measure")
1690 .map_err(|e| {
1691 format!("bernoulli marginal-slope predictor latent measure is invalid: {e}")
1692 })?;
1693 let blocks = &unified.blocks;
1694 let expected_blocks = 2
1695 + usize::from(score_warp_runtime.is_some())
1696 + usize::from(link_deviation_runtime.is_some());
1697 if blocks.len() != expected_blocks {
1698 return Err(format!(
1699 "bernoulli marginal-slope predictor requires exactly {expected_blocks} coefficient blocks under the current exact de-nested semantics, got {}",
1700 blocks.len()
1701 ));
1702 }
1703 let mut cursor = 2usize;
1704 let beta_score_warp = if score_warp_runtime.is_some() {
1705 let beta = blocks
1706 .get(cursor)
1707 .ok_or_else(|| "missing score-warp coefficient block".to_string())?
1708 .beta
1709 .clone();
1710 cursor += 1;
1711 Some(beta)
1712 } else {
1713 None
1714 };
1715 let beta_link_dev = if link_deviation_runtime.is_some() {
1716 Some(
1717 blocks
1718 .get(cursor)
1719 .ok_or_else(|| "missing link-deviation coefficient block".to_string())?
1720 .beta
1721 .clone(),
1722 )
1723 } else {
1724 None
1725 };
1726 Ok(Self {
1727 beta_marginal: blocks[0].beta.clone(),
1728 beta_logslope: blocks[1].beta.clone(),
1729 beta_score_warp,
1730 beta_link_dev,
1731 base_link,
1732 z_column,
1733 latent_z_normalization,
1734 latent_measure,
1735 baseline_marginal,
1736 baseline_logslope,
1737 covariance: unified.beta_covariance().cloned(),
1738 score_warp_runtime,
1739 link_deviation_runtime,
1740 gaussian_frailty_sd,
1741 latent_z_calibration,
1742 })
1743 }
1744
1745 fn theta(&self) -> Array1<f64> {
1746 let total = self.beta_marginal.len()
1747 + self.beta_logslope.len()
1748 + self.beta_score_warp.as_ref().map_or(0, |b| b.len())
1749 + self.beta_link_dev.as_ref().map_or(0, |b| b.len());
1750 let mut theta = Array1::<f64>::zeros(total);
1751 let mut cursor = 0usize;
1752 theta
1753 .slice_mut(ndarray::s![cursor..cursor + self.beta_marginal.len()])
1754 .assign(&self.beta_marginal);
1755 cursor += self.beta_marginal.len();
1756 theta
1757 .slice_mut(ndarray::s![cursor..cursor + self.beta_logslope.len()])
1758 .assign(&self.beta_logslope);
1759 cursor += self.beta_logslope.len();
1760 if let Some(beta) = self.beta_score_warp.as_ref() {
1761 theta
1762 .slice_mut(ndarray::s![cursor..cursor + beta.len()])
1763 .assign(beta);
1764 cursor += beta.len();
1765 }
1766 if let Some(beta) = self.beta_link_dev.as_ref() {
1767 theta
1768 .slice_mut(ndarray::s![cursor..cursor + beta.len()])
1769 .assign(beta);
1770 }
1771 theta
1772 }
1773
1774 fn split_theta<'a>(
1775 &'a self,
1776 theta: &'a Array1<f64>,
1777 ) -> Result<
1778 (
1779 ArrayView1<'a, f64>,
1780 ArrayView1<'a, f64>,
1781 Option<ArrayView1<'a, f64>>,
1782 Option<ArrayView1<'a, f64>>,
1783 ),
1784 EstimationError,
1785 > {
1786 let expected = self.theta().len();
1787 if theta.len() != expected {
1788 return Err(EstimationError::InvalidInput(format!(
1789 "bernoulli marginal-slope theta length mismatch: expected {expected}, got {}",
1790 theta.len()
1791 )));
1792 }
1793 let mut cursor = 0usize;
1794 let marginal = theta.slice(ndarray::s![cursor..cursor + self.beta_marginal.len()]);
1795 cursor += self.beta_marginal.len();
1796 let logslope = theta.slice(ndarray::s![cursor..cursor + self.beta_logslope.len()]);
1797 cursor += self.beta_logslope.len();
1798 let score_warp = self.beta_score_warp.as_ref().map(|beta| {
1799 let view = theta.slice(ndarray::s![cursor..cursor + beta.len()]);
1800 cursor += beta.len();
1801 view
1802 });
1803 let link_dev = self
1804 .beta_link_dev
1805 .as_ref()
1806 .map(|beta| theta.slice(ndarray::s![cursor..cursor + beta.len()]));
1807 Ok((marginal, logslope, score_warp, link_dev))
1808 }
1809
1810 fn solve_intercept_scalar(
1814 &self,
1815 marginal_eta: f64,
1816 slope: f64,
1817 link_dev_beta: Option<&Array1<f64>>,
1818 score_warp_beta: Option<&Array1<f64>>,
1819 empirical_grid: Option<&EmpiricalZGrid>,
1820 warm_start_buf: &mut Array1<f64>,
1821 score_warp_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
1822 link_dev_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
1823 ) -> Result<f64, EstimationError> {
1824 let marginal = bernoulli_marginal_link_map(&self.base_link, marginal_eta)
1825 .map_err(EstimationError::InvalidInput)?;
1826 let eval = |a: f64| -> Result<(f64, f64, f64), String> {
1827 self.evaluate_prediction_calibration(
1828 a,
1829 marginal_eta,
1830 slope,
1831 score_warp_beta,
1832 link_dev_beta,
1833 empirical_grid,
1834 score_warp_correction_for_row,
1835 link_dev_correction_for_row,
1836 )
1837 .map_err(|err| err.to_string())
1838 };
1839
1840 let probit_scale = self.probit_frailty_scale();
1841 let a_rigid = self.rigid_intercept_from_marginal(marginal.q, slope);
1842 let mut intercept = a_rigid;
1843 if let (Some(_), Some(beta)) = (self.link_deviation_runtime.as_ref(), link_dev_beta) {
1844 warm_start_buf[0] = a_rigid;
1845 let one_pt = warm_start_buf.slice(ndarray::s![0..1]).to_owned();
1846 let (l_val, l_d1) =
1847 self.link_terms_value_d1(&one_pt, Some(beta), link_dev_correction_for_row)?;
1848 let ell1 = l_d1[0];
1849 if ell1 > 1e-8 {
1850 let ell0 = l_val[0] - ell1 * a_rigid;
1851 let observed_logslope = probit_scale * ell1 * slope;
1852 intercept = (marginal.q * (1.0 + observed_logslope * observed_logslope).sqrt()
1853 / probit_scale
1854 - ell0)
1855 / ell1;
1856 }
1857 }
1858
1859 let target = marginal.mu;
1862 let abs_tol = 1e-8_f64.max(1e-4 * target.abs());
1863
1864 let (root, _, f_best) = crate::families::monotone_root::solve_monotone_root(
1865 eval,
1866 intercept,
1867 "saved bernoulli intercept",
1868 abs_tol,
1869 64,
1870 48,
1871 )
1872 .map_err(EstimationError::InvalidInput)?;
1873
1874 if f_best.abs() > abs_tol {
1875 return Err(EstimationError::InvalidInput(format!(
1876 "saved bernoulli marginal-slope intercept solve failed: residual={f_best:.3e} at a={root:.6}, target mu={target:.6}"
1877 )));
1878 }
1879 Ok(root)
1880 }
1881
1882 fn final_eta_and_gradient_from_theta(
1883 &self,
1884 input: &PredictInput,
1885 theta: &Array1<f64>,
1886 need_gradient: bool,
1887 ) -> Result<(Array1<f64>, Option<Array2<f64>>), EstimationError> {
1888 let z_raw = input.auxiliary_scalar.as_ref().ok_or_else(|| {
1889 EstimationError::InvalidInput(format!(
1890 "bernoulli marginal-slope prediction requires auxiliary z column '{}'",
1891 self.z_column
1892 ))
1893 })?;
1894 let z_normalized = self
1895 .latent_z_normalization
1896 .apply(z_raw, "bernoulli marginal-slope prediction")
1897 .map_err(EstimationError::InvalidInput)?;
1898 let z = self.apply_latent_z_calibration(&z_normalized);
1908 let design_logslope = input.design_noise.as_ref().ok_or_else(|| {
1909 EstimationError::InvalidInput(
1910 "bernoulli marginal-slope prediction requires logslope design".to_string(),
1911 )
1912 })?;
1913 let (beta_marginal, beta_logslope, beta_score_warp, beta_link_dev) =
1914 self.split_theta(theta)?;
1915 if self.score_warp_runtime.is_some() != beta_score_warp.is_some() {
1916 return Err(EstimationError::InvalidInput(
1917 "bernoulli marginal-slope saved score-warp runtime/coefficients are inconsistent"
1918 .to_string(),
1919 ));
1920 }
1921 if self.link_deviation_runtime.is_some() != beta_link_dev.is_some() {
1922 return Err(EstimationError::InvalidInput(
1923 "bernoulli marginal-slope saved link-deviation runtime/coefficients are inconsistent"
1924 .to_string(),
1925 ));
1926 }
1927 let n = z.len();
1928 if input.offset.len() != n {
1929 return Err(EstimationError::InvalidInput(format!(
1930 "bernoulli marginal-slope prediction primary offset length mismatch: rows={n}, offset={}",
1931 input.offset.len()
1932 )));
1933 }
1934 let logslope_offset = input
1935 .offset_noise
1936 .as_ref()
1937 .map_or_else(|| Array1::zeros(n), Clone::clone);
1938 if logslope_offset.len() != n {
1939 return Err(EstimationError::InvalidInput(format!(
1940 "bernoulli marginal-slope prediction logslope offset length mismatch: rows={n}, offset_noise={}",
1941 logslope_offset.len()
1942 )));
1943 }
1944 let marginal_eta = input
1945 .design
1946 .dot(&beta_marginal.to_owned())
1947 .mapv(|v| v + self.baseline_marginal)
1948 + &input.offset;
1949 let logslope_eta = design_logslope
1950 .dot(&beta_logslope.to_owned())
1951 .mapv(|v| v + self.baseline_logslope)
1952 + &logslope_offset;
1953 let flex_active =
1954 self.score_warp_runtime.is_some() || self.link_deviation_runtime.is_some();
1955 let marginal_dim = self.beta_marginal.len();
1956 let logslope_dim = self.beta_logslope.len();
1957 let score_warp_dim = self.beta_score_warp.as_ref().map_or(0, Array1::len);
1958 let link_dev_dim = self.beta_link_dev.as_ref().map_or(0, Array1::len);
1959 let logslope_offset = marginal_dim;
1960 let score_warp_offset = logslope_offset + logslope_dim;
1961 let link_dev_offset = score_warp_offset + score_warp_dim;
1962 let chunk_size = prediction_chunk_rows(theta.len(), 1, n);
1963 let num_chunks = (n + chunk_size - 1) / chunk_size;
1964 let scale = self.probit_frailty_scale();
1965 let anchor_corrections = self.build_anchor_correction_matrices(input, design_logslope)?;
1972 let marginal_map = marginal_eta
1973 .iter()
1974 .map(|&eta| {
1975 bernoulli_marginal_link_map(&self.base_link, eta)
1976 .map_err(EstimationError::InvalidInput)
1977 })
1978 .collect::<Result<Vec<_>, _>>()?;
1979
1980 if !flex_active {
1981 let (final_eta_internal, marginal_scales, logslope_scales) = match &self.latent_measure
1982 {
1983 LatentMeasureKind::StandardNormal => {
1984 let sb_vec = logslope_eta.mapv(|b| scale * b);
1985 let c_vec = sb_vec.mapv(|sb| (1.0 + sb * sb).sqrt());
1986 let final_eta_internal = Array1::from_iter(
1987 (0..n).map(|i| c_vec[i] * marginal_eta[i] + sb_vec[i] * z[i]),
1988 );
1989 let marginal_scales = c_vec;
1990 let logslope_scales = Array1::from_iter((0..n).map(|i| {
1991 marginal_eta[i] * (scale * scale) * logslope_eta[i] / marginal_scales[i]
1992 + scale * z[i]
1993 }));
1994 (final_eta_internal, marginal_scales, logslope_scales)
1995 }
1996 LatentMeasureKind::GlobalEmpirical { nodes, weights } => {
1997 let mut final_eta = Array1::<f64>::zeros(n);
1998 let mut marginal_scales = Array1::<f64>::zeros(n);
1999 let mut logslope_scales = Array1::<f64>::zeros(n);
2000 for i in 0..n {
2001 let (intercept, a_marginal, a_slope) = self
2002 .empirical_rigid_intercept_and_gradient(
2003 marginal_eta[i],
2004 logslope_eta[i],
2005 nodes,
2006 weights,
2007 )?;
2008 final_eta[i] = intercept + scale * logslope_eta[i] * z[i];
2009 marginal_scales[i] = a_marginal;
2010 logslope_scales[i] = a_slope + scale * z[i];
2011 }
2012 (final_eta, marginal_scales, logslope_scales)
2013 }
2014 LatentMeasureKind::LocalEmpirical { .. } => {
2015 let mut final_eta = Array1::<f64>::zeros(n);
2016 let mut marginal_scales = Array1::<f64>::zeros(n);
2017 let mut logslope_scales = Array1::<f64>::zeros(n);
2018 for i in 0..n {
2019 let grid = self
2020 .empirical_grid_for_prediction_row(input, i)?
2021 .ok_or_else(|| {
2022 EstimationError::InvalidInput(
2023 "local empirical latent prediction did not produce a row grid"
2024 .to_string(),
2025 )
2026 })?;
2027 let (intercept, a_marginal, a_slope) = self
2028 .empirical_rigid_intercept_and_gradient(
2029 marginal_eta[i],
2030 logslope_eta[i],
2031 &grid.nodes,
2032 &grid.weights,
2033 )?;
2034 final_eta[i] = intercept + scale * logslope_eta[i] * z[i];
2035 marginal_scales[i] = a_marginal;
2036 logslope_scales[i] = a_slope + scale * z[i];
2037 }
2038 (final_eta, marginal_scales, logslope_scales)
2039 }
2040 };
2041
2042 if !need_gradient {
2043 return self.transform_internal_eta_to_base_scale(final_eta_internal, None);
2044 }
2045
2046 let mut grad_internal = Array2::<f64>::zeros((n, theta.len()));
2048 let mut start = 0usize;
2049 while start < n {
2050 let end = (start + chunk_size).min(n);
2051 let mc = input
2052 .design
2053 .try_row_chunk(start..end)
2054 .map_err(|e| EstimationError::InvalidInput(e.to_string()))?;
2055 let lc = design_logslope
2056 .try_row_chunk(start..end)
2057 .map_err(|e| EstimationError::InvalidInput(e.to_string()))?;
2058
2059 for li in 0..(end - start) {
2060 let i = start + li;
2061 let c = marginal_scales[i];
2062 let g_scale = logslope_scales[i];
2063 let mut row = grad_internal.row_mut(i);
2064 for j in 0..marginal_dim {
2065 row[j] = c * mc[[li, j]];
2066 }
2067 for j in 0..logslope_dim {
2068 row[logslope_offset + j] = g_scale * lc[[li, j]];
2069 }
2070 }
2071
2072 start = end;
2073 }
2074 return self
2075 .transform_internal_eta_to_base_scale(final_eta_internal, Some(grad_internal));
2076 }
2077
2078 let score_warp_obs_design = self
2080 .score_warp_runtime
2081 .as_ref()
2082 .map(|runtime| {
2083 if runtime.anchor_residual_coefficients.is_some() {
2084 let anchor_rows = anchor_corrections.n_anchor_rows_view().ok_or_else(|| {
2085 EstimationError::InvalidInput(
2086 "bernoulli marginal-slope score-warp anchor residual present but \
2087 anchor_corrections bundle is missing the parametric anchor rows"
2088 .to_string(),
2089 )
2090 })?;
2091 runtime
2092 .design_with_anchor_rows(&z, anchor_rows)
2093 .map_err(EstimationError::InvalidInput)
2094 } else {
2095 runtime.design(&z).map_err(EstimationError::InvalidInput)
2096 }
2097 })
2098 .transpose()?;
2099 let score_dev_obs = if let (Some(design), Some(beta)) =
2100 (score_warp_obs_design.as_ref(), beta_score_warp.clone())
2101 {
2102 design.dot(&beta.to_owned())
2103 } else {
2104 Array1::zeros(n)
2105 };
2106
2107 let score_warp_beta_owned = beta_score_warp.as_ref().map(|v| v.to_owned());
2109 let link_dev_beta_owned = beta_link_dev.as_ref().map(|v| v.to_owned());
2110 struct FlexSolveChunk {
2111 start: usize,
2112 end: usize,
2113 intercepts: Array1<f64>,
2114 a_q: Option<Array1<f64>>,
2115 a_b: Option<Array1<f64>>,
2116 a_h: Option<Array2<f64>>,
2117 a_w: Option<Array2<f64>>,
2118 }
2119 let solve_chunks = (0..num_chunks)
2120 .into_par_iter()
2121 .map(|chunk_idx| -> Result<FlexSolveChunk, EstimationError> {
2122 let start = chunk_idx * chunk_size;
2123 let end = (start + chunk_size).min(n);
2124 let rows = end - start;
2125 let mut intercepts = Array1::<f64>::zeros(rows);
2126 let mut a_q = need_gradient.then(|| Array1::<f64>::zeros(rows));
2127 let mut a_b = need_gradient.then(|| Array1::<f64>::zeros(rows));
2128 let mut a_h = if need_gradient && score_warp_dim > 0 {
2129 Some(Array2::<f64>::zeros((rows, score_warp_dim)))
2130 } else {
2131 None
2132 };
2133 let mut a_w = if need_gradient && link_dev_dim > 0 {
2134 Some(Array2::<f64>::zeros((rows, link_dev_dim)))
2135 } else {
2136 None
2137 };
2138 let mut warm_start_buf = Array1::<f64>::zeros(1);
2139
2140 for local_row in 0..rows {
2141 let i = start + local_row;
2142 let slope = logslope_eta[i];
2143 let q = marginal_eta[i];
2144 let empirical_grid = self.empirical_grid_for_prediction_row(input, i)?;
2145 let score_corr_row = anchor_corrections.score_warp_row(i);
2146 let link_corr_row = anchor_corrections.link_dev_row(i);
2147 intercepts[local_row] = self.solve_intercept_scalar(
2148 q,
2149 slope,
2150 link_dev_beta_owned.as_ref(),
2151 score_warp_beta_owned.as_ref(),
2152 empirical_grid.as_ref(),
2153 &mut warm_start_buf,
2154 score_corr_row,
2155 link_corr_row,
2156 )?;
2157
2158 if !need_gradient {
2159 continue;
2160 }
2161
2162 let intercept = intercepts[local_row];
2163 let (_, m_a_raw, _) = self.evaluate_prediction_calibration(
2164 intercept,
2165 q,
2166 slope,
2167 score_warp_beta_owned.as_ref(),
2168 link_dev_beta_owned.as_ref(),
2169 empirical_grid.as_ref(),
2170 score_corr_row,
2171 link_corr_row,
2172 )?;
2173 let m_a = m_a_raw.max(1e-12);
2174 a_q.as_mut().unwrap()[local_row] = marginal_map[i].mu1 / m_a;
2175 let mut f_b = 0.0;
2176 let mut f_h_row = vec![0.0; score_warp_dim];
2177 let mut f_w_row = vec![0.0; link_dev_dim];
2178 if let Some(grid) = empirical_grid.as_ref() {
2179 for (&node, &weight) in grid.nodes.iter().zip(grid.weights.iter()) {
2180 let obs = self.observed_denested_cell_partials_at_z(
2181 node,
2182 intercept,
2183 slope,
2184 score_warp_beta_owned.as_ref(),
2185 link_dev_beta_owned.as_ref(),
2186 score_corr_row,
2187 link_corr_row,
2188 )?;
2189 let eta = eval_coeff4_at(&obs.coeff, node);
2190 let pdf = normal_pdf(eta);
2191 f_b += weight * pdf * eval_coeff4_at(&obs.dc_db, node);
2192
2193 if let Some(runtime) = self.score_warp_runtime.as_ref() {
2194 for j in 0..score_warp_dim {
2195 let mut basis_span = runtime
2196 .basis_cubic_at(j, node)
2197 .map_err(EstimationError::InvalidInput)?;
2198 if let Some(corr) = score_corr_row {
2205 basis_span.c0 -= corr[j];
2206 }
2207 let coeffs = crate::families::bernoulli_marginal_slope::exact_kernel::score_basis_cell_coefficients(
2208 basis_span,
2209 slope,
2210 );
2211 let coeffs = scale_coeff4(coeffs, scale);
2212 f_h_row[j] += weight * pdf * eval_coeff4_at(&coeffs, node);
2213 }
2214 }
2215
2216 if let Some(runtime) = self.link_deviation_runtime.as_ref() {
2217 for j in 0..link_dev_dim {
2218 let mut basis_span = runtime
2219 .basis_cubic_at(j, intercept + slope * node)
2220 .map_err(EstimationError::InvalidInput)?;
2221 if let Some(corr) = link_corr_row {
2222 basis_span.c0 -= corr[j];
2223 }
2224 let coeffs = crate::families::bernoulli_marginal_slope::exact_kernel::link_basis_cell_coefficients(
2225 basis_span,
2226 intercept,
2227 slope,
2228 );
2229 let coeffs = scale_coeff4(coeffs, scale);
2230 f_w_row[j] += weight * pdf * eval_coeff4_at(&coeffs, node);
2231 }
2232 }
2233 }
2234 } else {
2235 let cells = self.denested_partition_cells(
2236 intercept,
2237 slope,
2238 score_warp_beta_owned.as_ref(),
2239 link_dev_beta_owned.as_ref(),
2240 score_corr_row,
2241 link_corr_row,
2242 )?;
2243 for partition_cell in cells {
2244 let cell = partition_cell.cell;
2245 let state =
2246 crate::families::bernoulli_marginal_slope::exact_kernel::evaluate_cell_moments(
2247 cell, 9,
2248 )
2249 .map_err(EstimationError::InvalidInput)?;
2250 let (_, dc_db_raw) = crate::families::bernoulli_marginal_slope::exact_kernel::denested_cell_coefficient_partials(
2251 partition_cell.score_span,
2252 partition_cell.link_span,
2253 intercept,
2254 slope,
2255 );
2256 let dc_db = scale_coeff4(dc_db_raw, scale);
2260 f_b += crate::families::bernoulli_marginal_slope::exact_kernel::cell_first_derivative_from_moments(
2261 &dc_db,
2262 &state.moments,
2263 )
2264 .map_err(EstimationError::InvalidInput)?;
2265
2266 let mid = 0.5 * (cell.left + cell.right);
2267 if let Some(runtime) = self.score_warp_runtime.as_ref() {
2268 for j in 0..score_warp_dim {
2269 let mut basis_span = runtime
2270 .basis_cubic_at(j, mid)
2271 .map_err(EstimationError::InvalidInput)?;
2272 if let Some(corr) = score_corr_row {
2273 basis_span.c0 -= corr[j];
2274 }
2275 let coeffs = crate::families::bernoulli_marginal_slope::exact_kernel::score_basis_cell_coefficients(
2276 basis_span, slope,
2277 );
2278 let coeffs = scale_coeff4(coeffs, scale);
2279 f_h_row[j] += crate::families::bernoulli_marginal_slope::exact_kernel::cell_first_derivative_from_moments(
2280 &coeffs,
2281 &state.moments,
2282 )
2283 .map_err(EstimationError::InvalidInput)?;
2284 }
2285 }
2286
2287 if let Some(runtime) = self.link_deviation_runtime.as_ref() {
2288 for j in 0..link_dev_dim {
2289 let mut basis_span = runtime
2290 .basis_cubic_at(j, intercept + slope * mid)
2291 .map_err(EstimationError::InvalidInput)?;
2292 if let Some(corr) = link_corr_row {
2293 basis_span.c0 -= corr[j];
2294 }
2295 let coeffs = crate::families::bernoulli_marginal_slope::exact_kernel::link_basis_cell_coefficients(
2296 basis_span,
2297 intercept,
2298 slope,
2299 );
2300 let coeffs = scale_coeff4(coeffs, scale);
2301 f_w_row[j] += crate::families::bernoulli_marginal_slope::exact_kernel::cell_first_derivative_from_moments(
2302 &coeffs,
2303 &state.moments,
2304 )
2305 .map_err(EstimationError::InvalidInput)?;
2306 }
2307 }
2308 }
2309 }
2310 if let Some(a_h) = a_h.as_mut() {
2311 let factor = -1.0 / m_a;
2312 for j in 0..score_warp_dim {
2313 a_h[[local_row, j]] = factor * f_h_row[j];
2314 }
2315 }
2316 if let Some(a_w) = a_w.as_mut() {
2317 let factor = -1.0 / m_a;
2318 for j in 0..link_dev_dim {
2319 a_w[[local_row, j]] = factor * f_w_row[j];
2320 }
2321 }
2322 a_b.as_mut().unwrap()[local_row] = -f_b / m_a;
2323 }
2324
2325 Ok(FlexSolveChunk {
2326 start,
2327 end,
2328 intercepts,
2329 a_q,
2330 a_b,
2331 a_h,
2332 a_w,
2333 })
2334 })
2335 .collect::<Vec<_>>();
2336
2337 let mut intercepts = Array1::<f64>::zeros(n);
2338 let mut a_q_vec = need_gradient.then(|| Array1::<f64>::zeros(n));
2339 let mut a_b_vec = need_gradient.then(|| Array1::<f64>::zeros(n));
2340 let mut a_h_rows = if need_gradient && score_warp_dim > 0 {
2341 Some(Array2::<f64>::zeros((n, score_warp_dim)))
2342 } else {
2343 None
2344 };
2345 let mut a_w_rows = if need_gradient && link_dev_dim > 0 {
2346 Some(Array2::<f64>::zeros((n, link_dev_dim)))
2347 } else {
2348 None
2349 };
2350
2351 for solve_chunk in solve_chunks {
2352 let chunk = solve_chunk?;
2353 intercepts
2354 .slice_mut(ndarray::s![chunk.start..chunk.end])
2355 .assign(&chunk.intercepts);
2356 if let (Some(target), Some(source)) = (a_q_vec.as_mut(), chunk.a_q.as_ref()) {
2357 target
2358 .slice_mut(ndarray::s![chunk.start..chunk.end])
2359 .assign(source);
2360 }
2361 if let (Some(target), Some(source)) = (a_b_vec.as_mut(), chunk.a_b.as_ref()) {
2362 target
2363 .slice_mut(ndarray::s![chunk.start..chunk.end])
2364 .assign(source);
2365 }
2366 if let (Some(target), Some(source)) = (a_h_rows.as_mut(), chunk.a_h.as_ref()) {
2367 target
2368 .slice_mut(ndarray::s![chunk.start..chunk.end, ..])
2369 .assign(source);
2370 }
2371 if let (Some(target), Some(source)) = (a_w_rows.as_mut(), chunk.a_w.as_ref()) {
2372 target
2373 .slice_mut(ndarray::s![chunk.start..chunk.end, ..])
2374 .assign(source);
2375 }
2376 }
2377
2378 let eta_base = &intercepts + &(&logslope_eta * &z);
2379
2380 let mut link_c_obs: Option<Array1<f64>> = None;
2381 let mut link_basis_obs: Option<Array2<f64>> = None;
2382 let link_dev_obs = if let (Some(runtime), Some(beta_owned)) = (
2383 self.link_deviation_runtime.as_ref(),
2384 link_dev_beta_owned.as_ref(),
2385 ) {
2386 let basis = if runtime.anchor_residual_coefficients.is_some() {
2387 let anchor_rows = anchor_corrections.n_anchor_rows_view().ok_or_else(|| {
2388 EstimationError::InvalidInput(
2389 "bernoulli marginal-slope link-deviation anchor residual present but \
2390 anchor_corrections bundle is missing the parametric anchor rows"
2391 .to_string(),
2392 )
2393 })?;
2394 runtime
2395 .design_with_anchor_rows(&eta_base, anchor_rows)
2396 .map_err(EstimationError::InvalidInput)?
2397 } else {
2398 runtime
2399 .design(&eta_base)
2400 .map_err(EstimationError::InvalidInput)?
2401 };
2402 let dev = basis.dot(beta_owned);
2403 if need_gradient {
2404 let d1 = runtime
2405 .first_derivative_design(&eta_base)
2406 .map_err(EstimationError::InvalidInput)?;
2407 let mut c_obs = d1.dot(beta_owned);
2408 c_obs.mapv_inplace(|v| v + 1.0);
2409 link_c_obs = Some(c_obs);
2410 link_basis_obs = Some(basis);
2411 }
2412 dev
2413 } else {
2414 Array1::zeros(n)
2415 };
2416 let final_eta_internal =
2417 (&eta_base + &(&logslope_eta * &score_dev_obs) + &link_dev_obs).mapv(|v| scale * v);
2418
2419 if !need_gradient {
2420 return self.transform_internal_eta_to_base_scale(final_eta_internal, None);
2421 }
2422
2423 let a_q_vec = a_q_vec.unwrap();
2424 let a_b_vec = a_b_vec.unwrap();
2425
2426 struct FlexGradientChunk {
2428 start: usize,
2429 end: usize,
2430 grad: Array2<f64>,
2431 }
2432 let grad_chunks = (0..num_chunks)
2433 .into_par_iter()
2434 .map(|chunk_idx| -> Result<FlexGradientChunk, String> {
2435 let start = chunk_idx * chunk_size;
2436 let end = (start + chunk_size).min(n);
2437 let mc = input
2438 .design
2439 .try_row_chunk(start..end)
2440 .map_err(|e| e.to_string())?;
2441 let lc = design_logslope
2442 .try_row_chunk(start..end)
2443 .map_err(|e| e.to_string())?;
2444 let rows = end - start;
2445 let mut grad = Array2::<f64>::zeros((rows, theta.len()));
2446
2447 for li in 0..rows {
2448 let i = start + li;
2449 let mut row = grad.row_mut(li);
2450
2451 let a_q = a_q_vec[i];
2452 for j in 0..marginal_dim {
2453 row[j] = a_q * mc[[li, j]];
2454 }
2455
2456 let base_multiplier = link_c_obs.as_ref().map_or(1.0, |c| c[i]);
2457 let g_scale = base_multiplier * (a_b_vec[i] + z[i]) + score_dev_obs[i];
2458 for j in 0..logslope_dim {
2459 row[logslope_offset + j] = g_scale * lc[[li, j]];
2460 }
2461
2462 if let (Some(a_h_rows), Some(obs_design)) =
2463 (a_h_rows.as_ref(), score_warp_obs_design.as_ref())
2464 {
2465 let slope = logslope_eta[i];
2466 for j in 0..score_warp_dim {
2467 row[score_warp_offset + j] =
2468 base_multiplier * a_h_rows[[i, j]] + slope * obs_design[[i, j]];
2469 }
2470 }
2471
2472 if let Some(a_w_rows) = a_w_rows.as_ref() {
2473 for j in 0..link_dev_dim {
2474 row[link_dev_offset + j] = a_w_rows[[i, j]];
2475 }
2476 }
2477
2478 if let (Some(link_c), Some(link_basis)) =
2479 (link_c_obs.as_ref(), link_basis_obs.as_ref())
2480 {
2481 let c = link_c[i];
2482 for j in 0..marginal_dim {
2483 row[j] *= c;
2484 }
2485 for j in 0..link_dev_dim {
2486 row[link_dev_offset + j] =
2487 c * row[link_dev_offset + j] + link_basis[[i, j]];
2488 }
2489 }
2490 }
2491
2492 Ok(FlexGradientChunk { start, end, grad })
2493 })
2494 .collect::<Result<Vec<_>, String>>()
2495 .map_err(EstimationError::InvalidInput)?;
2496 let mut grad = Array2::<f64>::zeros((n, theta.len()));
2497 for chunk in grad_chunks {
2498 grad.slice_mut(ndarray::s![chunk.start..chunk.end, ..])
2499 .assign(&chunk.grad);
2500 }
2501 if scale != 1.0 {
2502 grad.mapv_inplace(|v| scale * v);
2503 }
2504 self.transform_internal_eta_to_base_scale(final_eta_internal, Some(grad))
2505 }
2506
2507 fn final_eta_from_theta(
2508 &self,
2509 input: &PredictInput,
2510 theta: &Array1<f64>,
2511 ) -> Result<Array1<f64>, EstimationError> {
2512 let (eta, _) = self.final_eta_and_gradient_from_theta(input, theta, false)?;
2513 Ok(eta)
2514 }
2515
2516 fn eta_standard_error_from_covariance(
2517 &self,
2518 input: &PredictInput,
2519 covariance: &Array2<f64>,
2520 ) -> Result<Array1<f64>, EstimationError> {
2521 let theta = self.theta();
2522 let backend = PredictionCovarianceBackend::from_dense(covariance.view());
2523 linear_predictor_se_from_backend(&backend, input.design.nrows(), |rows| {
2524 let chunk_input = slice_predict_input(input, rows).map_err(|e| e.to_string())?;
2525 let (_, grad) = self
2526 .final_eta_and_gradient_from_theta(&chunk_input, &theta, true)
2527 .map_err(|e| e.to_string())?;
2528 let grad = grad.ok_or_else(|| {
2529 "bernoulli marginal-slope analytic predictor gradient was not produced".to_string()
2530 })?;
2531 Ok(vec![grad])
2532 })
2533 }
2534
2535 fn eta_standard_error(
2536 &self,
2537 input: &PredictInput,
2538 fit: &UnifiedFitResult,
2539 ) -> Result<Array1<f64>, EstimationError> {
2540 let theta = self.theta();
2541 let backend = require_posterior_mean_backend(
2542 fit,
2543 self.covariance.as_ref(),
2544 theta.len(),
2545 "bernoulli marginal-slope posterior mean",
2546 )?;
2547 linear_predictor_se_from_backend(&backend, input.design.nrows(), |rows| {
2548 let chunk_input = slice_predict_input(input, rows).map_err(|e| e.to_string())?;
2549 let (_, grad) = self
2550 .final_eta_and_gradient_from_theta(&chunk_input, &theta, true)
2551 .map_err(|e| e.to_string())?;
2552 let grad = grad.ok_or_else(|| {
2553 "bernoulli marginal-slope analytic predictor gradient was not produced".to_string()
2554 })?;
2555 Ok(vec![grad])
2556 })
2557 }
2558
2559 pub fn predict_eta_and_q_chain(
2576 &self,
2577 input: &PredictInput,
2578 ) -> Result<(Array1<f64>, Array1<f64>), EstimationError> {
2579 let z_raw = input.auxiliary_scalar.as_ref().ok_or_else(|| {
2580 EstimationError::InvalidInput(format!(
2581 "bernoulli marginal-slope prediction requires auxiliary z column '{}'",
2582 self.z_column
2583 ))
2584 })?;
2585 let z_normalized = self
2586 .latent_z_normalization
2587 .apply(z_raw, "bernoulli marginal-slope prediction")
2588 .map_err(EstimationError::InvalidInput)?;
2589 let z = self.apply_latent_z_calibration(&z_normalized);
2595 let design_logslope = input.design_noise.as_ref().ok_or_else(|| {
2596 EstimationError::InvalidInput(
2597 "bernoulli marginal-slope prediction requires logslope design".to_string(),
2598 )
2599 })?;
2600 let n = z.len();
2601 if input.offset.len() != n {
2602 return Err(EstimationError::InvalidInput(format!(
2603 "bernoulli marginal-slope prediction primary offset length mismatch: rows={n}, offset={}",
2604 input.offset.len()
2605 )));
2606 }
2607 let logslope_offset = input
2608 .offset_noise
2609 .as_ref()
2610 .map_or_else(|| Array1::zeros(n), Clone::clone);
2611 if logslope_offset.len() != n {
2612 return Err(EstimationError::InvalidInput(format!(
2613 "bernoulli marginal-slope prediction logslope offset length mismatch: rows={n}, offset_noise={}",
2614 logslope_offset.len()
2615 )));
2616 }
2617 let marginal_eta = input
2618 .design
2619 .dot(&self.beta_marginal)
2620 .mapv(|v| v + self.baseline_marginal)
2621 + &input.offset;
2622 let logslope_eta = design_logslope
2623 .dot(&self.beta_logslope)
2624 .mapv(|v| v + self.baseline_logslope)
2625 + &logslope_offset;
2626 let scale = self.probit_frailty_scale();
2627 let flex_active =
2628 self.score_warp_runtime.is_some() || self.link_deviation_runtime.is_some();
2629
2630 if !flex_active {
2633 match &self.latent_measure {
2634 LatentMeasureKind::StandardNormal => {
2635 let sb = logslope_eta.mapv(|x| scale * x);
2638 let deta_dq = sb.mapv(|s| (1.0 + s * s).sqrt());
2639 let eta = &deta_dq * marginal_eta + &sb * z;
2640 return Ok((eta, deta_dq));
2641 }
2642 _ => {
2643 let mut eta = Array1::<f64>::zeros(n);
2644 let mut deta_dq = Array1::<f64>::zeros(n);
2645 for i in 0..n {
2646 let grid = self
2647 .empirical_grid_for_prediction_row(input, i)?
2648 .ok_or_else(|| {
2649 EstimationError::InvalidInput(
2650 "empirical latent prediction did not produce a row grid"
2651 .to_string(),
2652 )
2653 })?;
2654 let (intercept, a_marginal, _) = self
2655 .empirical_rigid_intercept_and_gradient(
2656 marginal_eta[i],
2657 logslope_eta[i],
2658 &grid.nodes,
2659 &grid.weights,
2660 )?;
2661 eta[i] = intercept + scale * logslope_eta[i] * z[i];
2662 deta_dq[i] = a_marginal;
2663 }
2664 return Ok((eta, deta_dq));
2665 }
2666 }
2667 }
2668
2669 let marginal_map = marginal_eta
2675 .iter()
2676 .map(|&eta_marg| {
2677 bernoulli_marginal_link_map(&self.base_link, eta_marg)
2678 .map_err(EstimationError::InvalidInput)
2679 })
2680 .collect::<Result<Vec<_>, _>>()?;
2681 let anchor_corrections = self.build_anchor_correction_matrices(input, design_logslope)?;
2684 use rayon::iter::{IntoParallelIterator, ParallelIterator};
2688 let pairs: Result<Vec<(f64, f64)>, EstimationError> = (0..n)
2689 .into_par_iter()
2690 .map_init(
2691 || Array1::<f64>::zeros(1),
2692 |warm_start_buf, i| {
2693 let q = marginal_eta[i];
2694 let slope = logslope_eta[i];
2695 let empirical_grid = self.empirical_grid_for_prediction_row(input, i)?;
2696 let score_corr_row = anchor_corrections.score_warp_row(i);
2697 let link_corr_row = anchor_corrections.link_dev_row(i);
2698 let intercept = self.solve_intercept_scalar(
2699 q,
2700 slope,
2701 self.beta_link_dev.as_ref(),
2702 self.beta_score_warp.as_ref(),
2703 empirical_grid.as_ref(),
2704 warm_start_buf,
2705 score_corr_row,
2706 link_corr_row,
2707 )?;
2708 let (_, m_a_raw, _) = self.evaluate_prediction_calibration(
2709 intercept,
2710 q,
2711 slope,
2712 self.beta_score_warp.as_ref(),
2713 self.beta_link_dev.as_ref(),
2714 empirical_grid.as_ref(),
2715 score_corr_row,
2716 link_corr_row,
2717 )?;
2718 let m_a = m_a_raw.max(1e-12);
2719 Ok((intercept, marginal_map[i].mu1 / m_a))
2720 },
2721 )
2722 .collect();
2723 let pairs = pairs?;
2724 let mut intercepts = Array1::<f64>::zeros(n);
2725 let mut a_q = Array1::<f64>::zeros(n);
2726 for (i, (intercept, a)) in pairs.into_iter().enumerate() {
2727 intercepts[i] = intercept;
2728 a_q[i] = a;
2729 }
2730
2731 let score_dev_obs = if let (Some(runtime), Some(beta)) = (
2732 self.score_warp_runtime.as_ref(),
2733 self.beta_score_warp.as_ref(),
2734 ) {
2735 let design = if runtime.anchor_residual_coefficients.is_some() {
2736 let anchor_rows = anchor_corrections.n_anchor_rows_view().ok_or_else(|| {
2737 EstimationError::InvalidInput(
2738 "bernoulli marginal-slope score-warp anchor residual present but \
2739 anchor_corrections bundle is missing the parametric anchor rows"
2740 .to_string(),
2741 )
2742 })?;
2743 runtime
2744 .design_with_anchor_rows(&z, anchor_rows)
2745 .map_err(EstimationError::InvalidInput)?
2746 } else {
2747 runtime.design(&z).map_err(EstimationError::InvalidInput)?
2748 };
2749 design.dot(beta)
2750 } else {
2751 Array1::zeros(n)
2752 };
2753 let eta_base = &intercepts + &(&logslope_eta * &z);
2754 let (link_dev_obs, link_c_obs) = if let (Some(runtime), Some(beta)) = (
2755 self.link_deviation_runtime.as_ref(),
2756 self.beta_link_dev.as_ref(),
2757 ) {
2758 let basis = if runtime.anchor_residual_coefficients.is_some() {
2759 let anchor_rows = anchor_corrections.n_anchor_rows_view().ok_or_else(|| {
2760 EstimationError::InvalidInput(
2761 "bernoulli marginal-slope link-deviation anchor residual present but \
2762 anchor_corrections bundle is missing the parametric anchor rows"
2763 .to_string(),
2764 )
2765 })?;
2766 runtime
2767 .design_with_anchor_rows(&eta_base, anchor_rows)
2768 .map_err(EstimationError::InvalidInput)?
2769 } else {
2770 runtime
2771 .design(&eta_base)
2772 .map_err(EstimationError::InvalidInput)?
2773 };
2774 let dev = basis.dot(beta);
2775 let d1 = runtime
2776 .first_derivative_design(&eta_base)
2777 .map_err(EstimationError::InvalidInput)?;
2778 let mut c_obs = d1.dot(beta);
2779 c_obs.mapv_inplace(|v| v + 1.0);
2780 (dev, c_obs)
2781 } else {
2782 (Array1::zeros(n), Array1::ones(n))
2783 };
2784 let final_eta_internal =
2785 (&eta_base + &(&logslope_eta * &score_dev_obs) + &link_dev_obs).mapv(|v| scale * v);
2786 let deta_dq = (&link_c_obs * &a_q).mapv(|v| scale * v);
2787 Ok((final_eta_internal, deta_dq))
2788 }
2789}
2790
2791impl PredictableModel for BernoulliMarginalSlopePredictor {
2792 fn predict_plugin_response(
2793 &self,
2794 input: &PredictInput,
2795 ) -> Result<PredictResult, EstimationError> {
2796 let eta = self.final_eta_from_theta(input, &self.theta())?;
2797 let mean = self.mean_from_eta(&eta)?;
2798 Ok(PredictResult { eta, mean })
2799 }
2800
2801 fn predict_with_uncertainty(
2802 &self,
2803 input: &PredictInput,
2804 ) -> Result<PredictionWithSE, EstimationError> {
2805 let plugin = self.predict_plugin_response(input)?;
2806 let (eta_se, mean_se) = if let Some(covariance) = self.covariance.as_ref() {
2807 let theta = self.theta();
2808 if covariance.nrows() != theta.len() || covariance.ncols() != theta.len() {
2809 return Err(EstimationError::InvalidInput(format!(
2810 "bernoulli marginal-slope covariance dimension mismatch: expected {}x{}, got {}x{}",
2811 theta.len(),
2812 theta.len(),
2813 covariance.nrows(),
2814 covariance.ncols()
2815 )));
2816 }
2817 let eta_se = self.eta_standard_error_from_covariance(input, covariance)?;
2818 let mean_se = eta_se.clone() * self.mean_derivative_from_eta(&plugin.eta)?;
2819 (Some(eta_se), Some(mean_se))
2820 } else {
2821 (None, None)
2822 };
2823 Ok(PredictionWithSE {
2824 eta: plugin.eta,
2825 mean: plugin.mean,
2826 eta_se,
2827 mean_se,
2828 })
2829 }
2830
2831 fn predict_noise_scale(
2832 &self,
2833 _: &PredictInput,
2834 ) -> Result<Option<Array1<f64>>, EstimationError> {
2835 Ok(None)
2836 }
2837
2838 fn predict_full_uncertainty(
2839 &self,
2840 input: &PredictInput,
2841 fit: &UnifiedFitResult,
2842 options: &PredictUncertaintyOptions,
2843 ) -> Result<PredictUncertaintyResult, EstimationError> {
2844 let plugin = self.predict_plugin_response(input)?;
2845 let eta_se = self.eta_standard_error(input, fit)?;
2846 let zcrit = standard_normal_quantile(0.5 + options.confidence_level * 0.5)
2847 .map_err(EstimationError::InvalidInput)?;
2848 let eta_lower = &plugin.eta - &eta_se.mapv(|s| zcrit * s);
2849 let eta_upper = &plugin.eta + &eta_se.mapv(|s| zcrit * s);
2850 let mean_lower = self.mean_from_eta(&eta_lower)?;
2851 let mean_upper = self.mean_from_eta(&eta_upper)?;
2852 let mean_se = eta_se.clone() * self.mean_derivative_from_eta(&plugin.eta)?;
2853 Ok(PredictUncertaintyResult {
2854 eta: plugin.eta,
2855 mean: plugin.mean,
2856 eta_standard_error: eta_se.clone(),
2857 mean_standard_error: mean_se,
2858 eta_lower,
2859 eta_upper,
2860 mean_lower,
2861 mean_upper,
2862 observation_lower: None,
2863 observation_upper: None,
2864 covariance_mode_requested: options.covariance_mode,
2865 covariance_corrected_used: false,
2866 })
2867 }
2868
2869 fn predict_posterior_mean(
2870 &self,
2871 input: &PredictInput,
2872 fit: &UnifiedFitResult,
2873 confidence_level: Option<f64>,
2874 ) -> Result<PredictPosteriorMeanResult, EstimationError> {
2875 let plugin = self.predict_plugin_response(input)?;
2876 let eta_se = self.eta_standard_error(input, fit)?;
2877 let strategy = strategy_for_family(self.likelihood_family(), Some(&self.base_link));
2878 let quadctx = crate::quadrature::QuadratureContext::new();
2879 let mean = Array1::from_iter(
2880 plugin
2881 .eta
2882 .iter()
2883 .zip(eta_se.iter())
2884 .map(|(&eta, &se)| strategy.posterior_mean(&quadctx, eta, se))
2885 .collect::<Result<Vec<_>, _>>()?,
2886 );
2887 let (mean_lower, mean_upper) = if let Some(level) = confidence_level {
2888 let z = standard_normal_quantile(0.5 + 0.5 * level)
2889 .map_err(EstimationError::InvalidInput)?;
2890 let eta_lower = &plugin.eta - &eta_se.mapv(|s| z * s);
2891 let eta_upper = &plugin.eta + &eta_se.mapv(|s| z * s);
2892 (
2893 Some(self.mean_from_eta(&eta_lower)?),
2894 Some(self.mean_from_eta(&eta_upper)?),
2895 )
2896 } else {
2897 (None, None)
2898 };
2899 Ok(PredictPosteriorMeanResult {
2900 eta: plugin.eta,
2901 eta_standard_error: eta_se,
2902 mean,
2903 mean_lower,
2904 mean_upper,
2905 })
2906 }
2907
2908 fn n_blocks(&self) -> usize {
2909 2 + usize::from(self.beta_score_warp.is_some()) + usize::from(self.beta_link_dev.is_some())
2910 }
2911
2912 fn block_roles(&self) -> Vec<BlockRole> {
2913 let mut roles = vec![BlockRole::Location, BlockRole::Scale];
2914 if self.beta_score_warp.is_some() {
2915 roles.push(BlockRole::Mean);
2916 }
2917 if self.beta_link_dev.is_some() {
2918 roles.push(BlockRole::LinkWiggle);
2919 }
2920 roles
2921 }
2922}
2923
2924pub struct GaussianLocationScalePredictor {
2929 pub beta_mu: Array1<f64>,
2930 pub beta_noise: Array1<f64>,
2931 pub response_scale: f64,
2932 pub covariance: Option<Array2<f64>>,
2933 pub link_wiggle: Option<SavedLinkWiggleRuntime>,
2934}
2935
2936impl GaussianLocationScalePredictor {
2937 fn compute_sigma(
2945 &self,
2946 design_noise: &DesignMatrix,
2947 offset_noise: Option<&Array1<f64>>,
2948 ) -> Result<Array1<f64>, EstimationError> {
2949 let mut eta_noise = design_noise.dot(&self.beta_noise);
2950 if let Some(offset_noise) = offset_noise {
2951 if offset_noise.len() != eta_noise.len() {
2952 return Err(EstimationError::InvalidInput(format!(
2953 "gaussian location-scale noise offset length mismatch: expected {}, got {}",
2954 eta_noise.len(),
2955 offset_noise.len()
2956 )));
2957 }
2958 eta_noise += offset_noise;
2959 }
2960 let scale = self.response_scale;
2961 Ok(eta_noise
2962 .mapv(|eta| crate::families::sigma_link::logb_sigma_from_eta_scalar(eta) * scale))
2963 }
2964
2965 fn eta_standard_error(
2966 &self,
2967 input: &PredictInput,
2968 fit: &UnifiedFitResult,
2969 eta_len: usize,
2970 ) -> Result<Array1<f64>, EstimationError> {
2971 let backend = require_posterior_mean_backend(
2972 fit,
2973 self.covariance.as_ref(),
2974 self.beta_mu.len()
2975 + self.beta_noise.len()
2976 + self.link_wiggle.as_ref().map_or(0, |w| w.beta.len()),
2977 "gaussian location-scale posterior mean",
2978 )?;
2979 let p_mu = self.beta_mu.len();
2980 let p_sigma = self.beta_noise.len();
2981 let p_w = self.link_wiggle.as_ref().map_or(0, |w| w.beta.len());
2982 let p_total = p_mu + p_sigma + p_w;
2983 if backend.nrows() != p_total {
2984 return Err(EstimationError::InvalidInput(format!(
2985 "gaussian location-scale covariance mismatch: expected parameter dimension {}, got {}",
2986 p_total,
2987 backend.nrows()
2988 )));
2989 }
2990 self.eta_standard_error_from_backend(input, &backend, eta_len, p_mu, p_sigma, p_w)
2991 }
2992
2993 fn eta_standard_error_from_backend(
2994 &self,
2995 input: &PredictInput,
2996 backend: &PredictionCovarianceBackend<'_>,
2997 eta_len: usize,
2998 p_mu: usize,
2999 p_sigma: usize,
3000 p_w: usize,
3001 ) -> Result<Array1<f64>, EstimationError> {
3002 let p_total = p_mu + p_sigma + p_w;
3003 if backend.nrows() != p_total {
3004 return Err(EstimationError::InvalidInput(format!(
3005 "gaussian location-scale covariance mismatch: expected parameter dimension {}, got {}",
3006 p_total,
3007 backend.nrows()
3008 )));
3009 }
3010 if let Some(runtime) = self.link_wiggle.as_ref() {
3011 let eta_base = input.design.dot(&self.beta_mu) + &input.offset;
3012 linear_predictor_se_from_backend(&backend, eta_len, |rows| {
3013 let q0_chunk = eta_base.slice(ndarray::s![rows.clone()]).to_owned();
3014 let x_mu = design_row_chunk(&input.design, rows.clone())?;
3015 let wiggle_design = runtime.design(&q0_chunk)?;
3016 let dq_dq0 = runtime.derivative_q0(&q0_chunk)?;
3017 let rows_in_chunk = q0_chunk.len();
3018 let mut grad = Array2::<f64>::zeros((rows_in_chunk, p_total));
3019 for i in 0..rows_in_chunk {
3020 for j in 0..p_mu {
3021 grad[[i, j]] = dq_dq0[i] * x_mu[[i, j]];
3022 }
3023 }
3024 grad.slice_mut(ndarray::s![.., p_mu + p_sigma..p_total])
3025 .assign(&wiggle_design);
3026 Ok(vec![grad])
3027 })
3028 } else {
3029 padded_design_standard_errors_from_backend(
3030 &input.design,
3031 &backend,
3032 0,
3033 p_sigma + p_w,
3034 "gaussian location-scale posterior mean",
3035 )
3036 }
3037 }
3038}
3039
3040impl PredictableModel for GaussianLocationScalePredictor {
3041 fn predict_plugin_response(
3042 &self,
3043 input: &PredictInput,
3044 ) -> Result<PredictResult, EstimationError> {
3045 let eta_base = input.design.dot(&self.beta_mu) + &input.offset;
3046 let eta = if let Some(runtime) = self.link_wiggle.as_ref() {
3047 runtime
3048 .apply(&eta_base)
3049 .map_err(EstimationError::InvalidInput)?
3050 } else {
3051 eta_base
3052 };
3053 let mean = eta.clone();
3055 Ok(PredictResult { eta, mean })
3056 }
3057
3058 fn predict_with_uncertainty(
3059 &self,
3060 input: &PredictInput,
3061 ) -> Result<PredictionWithSE, EstimationError> {
3062 let result = self.predict_plugin_response(input)?;
3063 let (eta_se, mean_se) = if let Some(covariance) = self.covariance.as_ref() {
3064 let p_mu = self.beta_mu.len();
3065 let p_sigma = self.beta_noise.len();
3066 let p_w = self.link_wiggle.as_ref().map_or(0, |w| w.beta.len());
3067 let backend = PredictionCovarianceBackend::from_dense(covariance.view());
3068 let eta_se = self.eta_standard_error_from_backend(
3069 input,
3070 &backend,
3071 result.eta.len(),
3072 p_mu,
3073 p_sigma,
3074 p_w,
3075 )?;
3076 (Some(eta_se.clone()), Some(eta_se))
3077 } else {
3078 (None, None)
3079 };
3080 Ok(PredictionWithSE {
3081 eta: result.eta,
3082 mean: result.mean,
3083 eta_se,
3084 mean_se,
3085 })
3086 }
3087
3088 fn predict_noise_scale(
3089 &self,
3090 input: &PredictInput,
3091 ) -> Result<Option<Array1<f64>>, EstimationError> {
3092 let design_noise = input.design_noise.as_ref().ok_or_else(|| {
3093 EstimationError::InvalidInput(
3094 "Gaussian location-scale prediction requires noise design matrix".to_string(),
3095 )
3096 })?;
3097 self.compute_sigma(design_noise, input.offset_noise.as_ref())
3098 .map(Some)
3099 }
3100
3101 fn predict_full_uncertainty(
3102 &self,
3103 input: &PredictInput,
3104 fit: &UnifiedFitResult,
3105 options: &PredictUncertaintyOptions,
3106 ) -> Result<PredictUncertaintyResult, EstimationError> {
3107 let pred = self.predict_plugin_response(input)?;
3108 let design_noise = input.design_noise.as_ref().ok_or_else(|| {
3109 EstimationError::InvalidInput(
3110 "Gaussian location-scale prediction requires noise design matrix".to_string(),
3111 )
3112 })?;
3113 let sigma = self.compute_sigma(design_noise, input.offset_noise.as_ref())?;
3114 let eta_se = self.eta_standard_error(input, fit, pred.eta.len())?;
3115 let z = crate::probability::standard_normal_quantile(0.5 + options.confidence_level * 0.5)
3116 .map_err(|e| EstimationError::InvalidInput(e))?;
3117 let eta_lower = &pred.eta - &eta_se.mapv(|s| z * s);
3118 let eta_upper = &pred.eta + &eta_se.mapv(|s| z * s);
3119 Ok(PredictUncertaintyResult {
3120 eta: pred.eta.clone(),
3121 mean: pred.mean.clone(),
3122 eta_standard_error: eta_se.clone(),
3123 mean_standard_error: eta_se.clone(),
3124 eta_lower: eta_lower.clone(),
3125 eta_upper: eta_upper.clone(),
3126 mean_lower: eta_lower,
3127 mean_upper: eta_upper,
3128 observation_lower: options
3129 .includeobservation_interval
3130 .then(|| &pred.mean - &sigma.mapv(|s| z * s)),
3131 observation_upper: options
3132 .includeobservation_interval
3133 .then(|| &pred.mean + &sigma.mapv(|s| z * s)),
3134 covariance_mode_requested: options.covariance_mode,
3135 covariance_corrected_used: false,
3136 })
3137 }
3138
3139 fn predict_posterior_mean(
3140 &self,
3141 input: &PredictInput,
3142 fit: &UnifiedFitResult,
3143 confidence_level: Option<f64>,
3144 ) -> Result<PredictPosteriorMeanResult, EstimationError> {
3145 let result = self.predict_plugin_response(input)?;
3146 let eta_se = self.eta_standard_error(input, fit, result.eta.len())?;
3147 let (mean_lower, mean_upper) = if let Some(level) = confidence_level {
3149 let z = standard_normal_quantile(0.5 + 0.5 * level)
3150 .map_err(EstimationError::InvalidInput)?;
3151 (
3152 Some(&result.eta - &eta_se.mapv(|s| z * s)),
3153 Some(&result.eta + &eta_se.mapv(|s| z * s)),
3154 )
3155 } else {
3156 (None, None)
3157 };
3158 Ok(PredictPosteriorMeanResult {
3159 eta: result.eta,
3160 eta_standard_error: eta_se,
3161 mean: result.mean,
3162 mean_lower,
3163 mean_upper,
3164 })
3165 }
3166
3167 fn n_blocks(&self) -> usize {
3168 if self.link_wiggle.is_some() { 3 } else { 2 }
3169 }
3170
3171 fn block_roles(&self) -> Vec<BlockRole> {
3172 if self.link_wiggle.is_some() {
3173 vec![BlockRole::Location, BlockRole::Scale, BlockRole::LinkWiggle]
3174 } else {
3175 vec![BlockRole::Location, BlockRole::Scale]
3176 }
3177 }
3178}
3179
/// Binomial location-scale predictor.
///
/// The probability is `link^{-1}(w(q0))` with `q0 = -eta_threshold / sigma` and
/// `sigma = exp(eta_noise)`. Here `w` is the optional link wiggle (the identity
/// when `link_wiggle` is `None`).
pub struct BinomialLocationScalePredictor {
    /// Coefficients for the threshold design (`input.design`).
    pub beta_threshold: Array1<f64>,
    /// Coefficients for the noise (log-sigma) design (`input.design_noise`).
    pub beta_noise: Array1<f64>,
    /// Optional joint coefficient covariance, ordered `[threshold | noise | wiggle]`.
    pub covariance: Option<Array2<f64>>,
    /// Inverse link applied to the (wiggled) predictor to obtain a probability.
    pub inverse_link: InverseLink,
    /// Optional link-wiggle runtime applied to `q0` before the inverse link.
    pub link_wiggle: Option<SavedLinkWiggleRuntime>,
}
3198
3199impl BinomialLocationScalePredictor {
3200 fn compute_q0_and_sigma(
3205 &self,
3206 input: &PredictInput,
3207 ) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), EstimationError> {
3208 let eta_t = input.design.dot(&self.beta_threshold) + &input.offset;
3209 let design_noise = input.design_noise.as_ref().ok_or_else(|| {
3210 EstimationError::InvalidInput(
3211 "Binomial location-scale prediction requires noise design matrix".to_string(),
3212 )
3213 })?;
3214 let offset_noise = input
3215 .offset_noise
3216 .as_ref()
3217 .map_or_else(|| Array1::zeros(design_noise.nrows()), |o| o.clone());
3218 let eta_s = design_noise.dot(&self.beta_noise) + &offset_noise;
3219 let sigma = eta_s.mapv(|v| v.exp().max(f64::MIN_POSITIVE));
3221 let q0 = Array1::from_shape_fn(eta_t.len(), |i| (-eta_t[i] / sigma[i]).clamp(-1e6, 1e6));
3222 Ok((q0, sigma, eta_t))
3223 }
3224
3225 fn apply_link(&self, q0: &Array1<f64>) -> Result<(Array1<f64>, Array1<f64>), EstimationError> {
3227 let eta = if let Some(runtime) = self.link_wiggle.as_ref() {
3228 runtime.apply(q0).map_err(EstimationError::InvalidInput)?
3229 } else {
3230 q0.clone()
3231 };
3232 use rayon::iter::{IntoParallelIterator, ParallelIterator};
3233 let n = eta.len();
3234 let prob_vec: Result<Vec<f64>, EstimationError> = (0..n)
3235 .into_par_iter()
3236 .map(|i| {
3237 let jet = crate::solver::mixture_link::inverse_link_jet_for_inverse_link(
3238 &self.inverse_link,
3239 eta[i],
3240 )?;
3241 Ok(jet.mu.clamp(0.0, 1.0))
3242 })
3243 .collect();
3244 let prob = Array1::from_vec(prob_vec?);
3245 Ok((eta, prob))
3246 }
3247}
3248
impl PredictableModel for BinomialLocationScalePredictor {
    /// Plug-in prediction at the point estimate.
    ///
    /// Returns the (optionally wiggled) link-scale predictor as `eta` and the
    /// inverse-link probability, clamped to `[0, 1]`, as `mean`.
    fn predict_plugin_response(
        &self,
        input: &PredictInput,
    ) -> Result<PredictResult, EstimationError> {
        let (q0_base, _, _) = self.compute_q0_and_sigma(input)?;
        let (eta, prob) = self.apply_link(&q0_base)?;
        Ok(PredictResult { eta, mean: prob })
    }

    /// Plug-in prediction plus a delta-method standard error of the probability.
    ///
    /// When `self.covariance` is stored, each row's gradient of the probability
    /// with respect to the stacked coefficients `[threshold | noise | wiggle]` is
    /// pushed through the covariance backend. `eta_se` is always `None`.
    /// `mean_se` is `Some` only when a covariance is stored.
    ///
    /// # Errors
    /// Fails on a covariance dimension mismatch, a missing noise design, or an
    /// inverse-link / wiggle evaluation failure.
    fn predict_with_uncertainty(
        &self,
        input: &PredictInput,
    ) -> Result<PredictionWithSE, EstimationError> {
        let (q0_base, sigma, eta_t) = self.compute_q0_and_sigma(input)?;
        let (eta, prob) = self.apply_link(&q0_base)?;

        let mean_se = if let Some(ref cov) = self.covariance {
            let n = eta_t.len();
            let p_t = self.beta_threshold.len();
            let p_s = self.beta_noise.len();
            let p_w = self.link_wiggle.as_ref().map_or(0, |w| w.beta.len());
            let p_total = p_t + p_s + p_w;
            let backend = PredictionCovarianceBackend::from_dense(cov.view());
            // The covariance must cover exactly the stacked [threshold | noise | wiggle] vector.
            if backend.nrows() != p_total {
                return Err(EstimationError::InvalidInput(format!(
                    "covariance dimension mismatch for binomial LS: expected parameter dimension {}, got {}",
                    p_total,
                    backend.nrows()
                )));
            }

            let design_noise = input.design_noise.as_ref().ok_or_else(|| {
                EstimationError::InvalidInput(
                    "binomial location-scale uncertainty requires noise design matrix".to_string(),
                )
            })?;
            Some(linear_predictor_se_from_backend(&backend, n, |rows| {
                let x_t = design_row_chunk(&input.design, rows.clone())?;
                let x_s = design_row_chunk(design_noise, rows.clone())?;
                let eta_chunk = eta.slice(ndarray::s![rows.clone()]).to_owned();
                let q0_chunk = q0_base.slice(ndarray::s![rows.clone()]).to_owned();
                let sigma_chunk = sigma.slice(ndarray::s![rows.clone()]).to_owned();
                let eta_t_chunk = eta_t.slice(ndarray::s![rows.clone()]).to_owned();
                // Wiggle basis rows evaluated at q0: d eta / d beta_wiggle.
                let wiggle_design = if let Some(runtime) = self.link_wiggle.as_ref() {
                    Some(runtime.design(&q0_chunk)?)
                } else {
                    None
                };
                // d eta / d q0 (identity when there is no wiggle).
                let dq_dq0 = if let Some(runtime) = self.link_wiggle.as_ref() {
                    runtime.derivative_q0(&q0_chunk)?
                } else {
                    Array1::ones(q0_chunk.len())
                };
                let rows_in_chunk = q0_chunk.len();
                let mut grad = Array2::<f64>::zeros((rows_in_chunk, p_total));
                for i in 0..rows_in_chunk {
                    let jet = crate::solver::mixture_link::inverse_link_jet_for_inverse_link(
                        &self.inverse_link,
                        eta_chunk[i],
                    )
                    .map_err(|e| e.to_string())?;
                    let dphi = jet.d1;
                    let scale = dq_dq0[i];
                    // q0 = -eta_t / sigma with sigma = exp(eta_s), so
                    // dq0/deta_t = -1/sigma and dq0/deta_s = eta_t/sigma.
                    // NOTE(review): the ±1e6 clamp on q0 is not reflected in
                    // these derivatives.
                    let dprob_deta_t = dphi * scale * (-1.0 / sigma_chunk[i]);
                    let dprob_deta_s = dphi * scale * (eta_t_chunk[i] / sigma_chunk[i]);
                    for j in 0..p_t {
                        grad[[i, j]] = dprob_deta_t * x_t[[i, j]];
                    }
                    for j in 0..p_s {
                        grad[[i, p_t + j]] = dprob_deta_s * x_s[[i, j]];
                    }
                    if let Some(wd) = wiggle_design.as_ref() {
                        for j in 0..p_w {
                            grad[[i, p_t + p_s + j]] = dphi * wd[[i, j]];
                        }
                    }
                }
                Ok(vec![grad])
            })?)
        } else {
            None
        };

        Ok(PredictionWithSE {
            eta,
            mean: prob,
            eta_se: None,
            mean_se,
        })
    }

    /// This predictor exposes no separate observation-noise scale.
    fn predict_noise_scale(
        &self,
        _: &PredictInput,
    ) -> Result<Option<Array1<f64>>, EstimationError> {
        Ok(None)
    }

    /// Wald-style intervals on the probability scale, clamped to `[0, 1]`.
    ///
    /// If no covariance is stored, `mean_se` falls back to zeros, so the mean
    /// intervals collapse to the point estimate. No observation interval is
    /// produced.
    fn predict_full_uncertainty(
        &self,
        input: &PredictInput,
        _: &UnifiedFitResult,
        options: &PredictUncertaintyOptions,
    ) -> Result<PredictUncertaintyResult, EstimationError> {
        let pred = self.predict_with_uncertainty(input)?;
        let z = standard_normal_quantile(0.5 + options.confidence_level * 0.5)
            .map_err(EstimationError::InvalidInput)?;

        let mean_se = pred
            .mean_se
            .as_ref()
            .cloned()
            .unwrap_or_else(|| Array1::zeros(pred.mean.len()));

        let mut mean_lower = &pred.mean - &mean_se.mapv(|s| z * s);
        let mut mean_upper = &pred.mean + &mean_se.mapv(|s| z * s);
        mean_lower.mapv_inplace(|v| v.clamp(0.0, 1.0));
        mean_upper.mapv_inplace(|v| v.clamp(0.0, 1.0));

        Ok(PredictUncertaintyResult {
            eta: pred.eta.clone(),
            mean: pred.mean.clone(),
            // NOTE(review): this reports the probability-scale SE as the eta SE,
            // while eta_lower/eta_upper below are the zero-width point estimate.
            // Confirm this is intended.
            eta_standard_error: mean_se.clone(),
            mean_standard_error: mean_se,
            eta_lower: pred.eta.clone(),
            eta_upper: pred.eta,
            mean_lower,
            mean_upper,
            observation_lower: None,
            observation_upper: None,
            covariance_mode_requested: options.covariance_mode,
            covariance_corrected_used: false,
        })
    }

    /// Posterior-mean probability that integrates coefficient uncertainty.
    ///
    /// Without a wiggle, the mean is a 2-D Gaussian expectation over
    /// `(eta_threshold, eta_log_sigma)`. With a wiggle, the wiggle coefficients
    /// are first conditioned on that pair. The inner 1-D link integral then uses
    /// the conditional wiggle mean and variance.
    ///
    /// The returned `eta_standard_error` is the delta-method SE of the
    /// plug-in probability (gradients w.r.t. `[threshold | noise | wiggle]`).
    /// That same SE also builds the optional clamped `[0, 1]` bounds.
    fn predict_posterior_mean(
        &self,
        input: &PredictInput,
        fit: &UnifiedFitResult,
        confidence_level: Option<f64>,
    ) -> Result<PredictPosteriorMeanResult, EstimationError> {
        let (q0_base, sigma, eta_t) = self.compute_q0_and_sigma(input)?;
        let design_noise = input.design_noise.as_ref().ok_or_else(|| {
            EstimationError::InvalidInput(
                "Binomial location-scale posterior mean requires noise design matrix".to_string(),
            )
        })?;
        let offset_noise = input
            .offset_noise
            .as_ref()
            .map_or_else(|| Array1::zeros(design_noise.nrows()), |o| o.clone());
        // Unfloored log-sigma predictor, used as the centre of the 2-D integral.
        let eta_s = design_noise.dot(&self.beta_noise) + &offset_noise;
        let (eta, _) = self.apply_link(&q0_base)?;
        let p_t = self.beta_threshold.len();
        let p_s = self.beta_noise.len();
        let p_w = self.link_wiggle.as_ref().map_or(0, |w| w.beta.len());
        let p_total = p_t + p_s + p_w;
        let backend = require_posterior_mean_backend(
            fit,
            self.covariance.as_ref(),
            p_total,
            "binomial location-scale posterior mean",
        )?;

        // Delta-method SE of the plug-in probability, using the same gradient
        // layout as `predict_with_uncertainty`, with rows computed in parallel.
        let eta_se = linear_predictor_se_from_backend(&backend, eta_t.len(), |rows| {
            let x_t = design_row_chunk(&input.design, rows.clone())?;
            let x_s = design_row_chunk(design_noise, rows.clone())?;
            let eta_chunk = eta.slice(ndarray::s![rows.clone()]).to_owned();
            let q0_chunk = q0_base.slice(ndarray::s![rows.clone()]).to_owned();
            let sigma_chunk = sigma.slice(ndarray::s![rows.clone()]).to_owned();
            let eta_t_chunk = eta_t.slice(ndarray::s![rows.clone()]).to_owned();
            let wiggle_design = if let Some(runtime) = self.link_wiggle.as_ref() {
                Some(runtime.design(&q0_chunk)?)
            } else {
                None
            };
            let dq_dq0 = if let Some(runtime) = self.link_wiggle.as_ref() {
                runtime.derivative_q0(&q0_chunk)?
            } else {
                Array1::ones(q0_chunk.len())
            };
            let rows_in_chunk = q0_chunk.len();
            let row_gradients: Result<Vec<Vec<f64>>, String> = (0..rows_in_chunk)
                .into_par_iter()
                .map(|i| {
                    let jet = crate::solver::mixture_link::inverse_link_jet_for_inverse_link(
                        &self.inverse_link,
                        eta_chunk[i],
                    )
                    .map_err(|e| e.to_string())?;
                    let dphi = jet.d1;
                    let scale = dq_dq0[i];
                    let dprob_deta_t = dphi * scale * (-1.0 / sigma_chunk[i]);
                    let dprob_deta_s = dphi * scale * (eta_t_chunk[i] / sigma_chunk[i]);
                    let mut row = vec![0.0; p_total];
                    for j in 0..p_t {
                        row[j] = dprob_deta_t * x_t[[i, j]];
                    }
                    for j in 0..p_s {
                        row[p_t + j] = dprob_deta_s * x_s[[i, j]];
                    }
                    if let Some(wd) = wiggle_design.as_ref() {
                        for j in 0..p_w {
                            row[p_t + p_s + j] = dphi * wd[[i, j]];
                        }
                    }
                    Ok(row)
                })
                .collect();
            let mut grad = Array2::<f64>::zeros((rows_in_chunk, p_total));
            for (i, row) in row_gradients?.into_iter().enumerate() {
                for (j, value) in row.into_iter().enumerate() {
                    grad[[i, j]] = value;
                }
            }
            Ok(vec![grad])
        })?;

        let mean = if self.link_wiggle.is_none() {
            // No wiggle: project the covariance onto (eta_t, eta_s) per row and
            // integrate the probability over that bivariate normal.
            let (var_t, var_s, cov_ts) = project_two_block_linear_predictor_covariance(
                &input.design,
                design_noise,
                &backend,
                p_t,
                p_s,
                "binomial location-scale posterior mean",
            )?;
            let values: Result<Vec<_>, _> = (0..eta_t.len())
                .into_par_iter()
                .map(|i| {
                    PREDICT_QUADRATURE_CONTEXT.with(|quadctx| {
                        projected_bivariate_posterior_mean_result(
                            quadctx,
                            [eta_t[i], eta_s[i]],
                            [
                                [var_t[i].max(0.0), cov_ts[i]],
                                [cov_ts[i], var_s[i].max(0.0)],
                            ],
                            |eta_threshold, eta_log_sigma| {
                                // Unclamped q0 = -eta_t * exp(-eta_s) at the quadrature node.
                                let q0 = -eta_threshold * (-eta_log_sigma).exp();
                                let jet =
                                    crate::solver::mixture_link::inverse_link_jet_for_inverse_link(
                                        &self.inverse_link,
                                        q0,
                                    )?;
                                Ok(jet.mu.clamp(0.0, 1.0))
                            },
                        )
                    })
                })
                .collect();
            Array1::from_vec(values?)
        } else {
            let runtime = self.link_wiggle.as_ref().expect("checked above");
            let betaw = Array1::from_vec(runtime.beta.clone());
            // Unit columns selecting the wiggle coefficients, so that
            // apply_columns returns the wiggle columns of the covariance.
            let mut wiggle_basis_rhs = Array2::<f64>::zeros((p_total, p_w));
            for j in 0..p_w {
                wiggle_basis_rhs[[p_t + p_s + j, j]] = 1.0;
            }
            // Sigma_ww: marginal covariance of the wiggle coefficients.
            let covww = backend
                .apply_columns(&wiggle_basis_rhs)
                .map_err(EstimationError::InvalidInput)?
                .slice(ndarray::s![p_t + p_s..p_total, ..])
                .to_owned();
            let mut out = Array1::<f64>::zeros(eta.len());
            let chunk_rows = prediction_chunk_rows(p_total, 2, eta.len());
            let mut start = 0usize;
            while start < eta.len() {
                let end = (start + chunk_rows).min(eta.len());
                let rows = start..end;
                let rows_in_chunk = end - start;
                let x_t = design_row_chunk(&input.design, rows.clone())
                    .map_err(EstimationError::InvalidInput)?;
                let x_ls = design_row_chunk(design_noise, rows.clone())
                    .map_err(EstimationError::InvalidInput)?;
                // Columns [0, r) carry the padded threshold rows and columns
                // [r, 2r) the padded log-sigma rows; one solve covers both.
                let mut rhs = Array2::<f64>::zeros((p_total, rows_in_chunk * 2));
                rhs.slice_mut(ndarray::s![0..p_t, 0..rows_in_chunk])
                    .assign(&x_t.t());
                rhs.slice_mut(ndarray::s![
                    p_t..p_t + p_s,
                    rows_in_chunk..2 * rows_in_chunk
                ])
                .assign(&x_ls.t());
                let solved = backend
                    .apply_columns(&rhs)
                    .map_err(EstimationError::InvalidInput)?;
                let compute_chunk_row = |quadctx: &QuadratureContext, local_row: usize| {
                    let i = start + local_row;
                    let solved_t = solved.slice(ndarray::s![.., local_row]);
                    let solved_ls = solved.slice(ndarray::s![.., rows_in_chunk + local_row]);
                    let var_t = x_t
                        .row(local_row)
                        .dot(&solved_t.slice(ndarray::s![0..p_t]))
                        .max(0.0);
                    let var_ls = x_ls
                        .row(local_row)
                        .dot(&solved_ls.slice(ndarray::s![p_t..p_t + p_s]))
                        .max(0.0);
                    // Two algebraically equal estimates of Cov(eta_t, eta_s); averaged.
                    let cov_tls_t = x_t
                        .row(local_row)
                        .dot(&solved_ls.slice(ndarray::s![0..p_t]));
                    let cov_tls_ls = x_ls
                        .row(local_row)
                        .dot(&solved_t.slice(ndarray::s![p_t..p_t + p_s]));
                    let cov_tls = 0.5 * (cov_tls_t + cov_tls_ls);
                    // Sigma_wu: covariance of wiggle coefficients with eta_t / eta_s.
                    let suv_t = solved_t.slice(ndarray::s![p_t + p_s..p_total]);
                    let suv_ls = solved_ls.slice(ndarray::s![p_t + p_s..p_total]);
                    // 2x2 inverse of Sigma_uu; the determinant is floored at 1e-12.
                    let det = (var_t * var_ls - cov_tls * cov_tls).max(1e-12);
                    let inv_uu = [
                        [var_ls / det, -cov_tls / det],
                        [-cov_tls / det, var_t / det],
                    ];
                    // K = Sigma_wu Sigma_uu^{-1}, stored column-wise as k0 (eta_t) and k1 (eta_s).
                    let mut k0 = Array1::<f64>::zeros(p_w);
                    let mut k1 = Array1::<f64>::zeros(p_w);
                    for j in 0..p_w {
                        k0[j] = suv_t[j] * inv_uu[0][0] + suv_ls[j] * inv_uu[1][0];
                        k1[j] = suv_t[j] * inv_uu[0][1] + suv_ls[j] * inv_uu[1][1];
                    }
                    // Conditional covariance: Sigma_ww - K Sigma_uw.
                    let mut covw_cond = covww.clone();
                    for r in 0..p_w {
                        for c in 0..p_w {
                            covw_cond[[r, c]] -= k0[r] * suv_t[c] + k1[r] * suv_ls[c];
                        }
                    }
                    crate::quadrature::normal_expectation_2d_adaptive_result(
                        quadctx,
                        [eta_t[i], eta_s[i]],
                        [[var_t, cov_tls], [cov_tls, var_ls]],
                        |t, ls| {
                            let q0 = -t * (-ls).exp();
                            let xw = runtime
                                .basis_row_scalar(q0)
                                .map_err(EstimationError::InvalidInput)?;
                            let dt = t - eta_t[i];
                            let dls = ls - eta_s[i];
                            // Conditional mean of the wiggled predictor at this node.
                            let meanw = q0 + xw.dot(&betaw) + dt * xw.dot(&k0) + dls * xw.dot(&k1);
                            // Conditional variance xw' Sigma_w|u xw.
                            let mut varw = 0.0;
                            for r in 0..p_w {
                                let xr = xw[r];
                                for c in 0..p_w {
                                    varw += xr * covw_cond[[r, c]] * xw[c];
                                }
                            }
                            // Inner 1-D integral of the inverse link over N(meanw, varw).
                            let jet = crate::quadrature::integrated_inverse_link_jetwith_state(
                                quadctx,
                                self.inverse_link.link_function(),
                                meanw,
                                varw.max(0.0).sqrt(),
                                self.inverse_link.mixture_state(),
                                self.inverse_link.sas_state(),
                            )?;
                            Ok::<f64, EstimationError>(jet.mean.clamp(0.0, 1.0))
                        },
                    )
                };
                let chunk_values: Result<Vec<f64>, EstimationError> = (0..rows_in_chunk)
                    .into_par_iter()
                    .map(|local_row| {
                        PREDICT_QUADRATURE_CONTEXT
                            .with(|quadctx| compute_chunk_row(quadctx, local_row))
                    })
                    .collect();
                for (local_row, value) in chunk_values?.into_iter().enumerate() {
                    out[start + local_row] = value;
                }
                start = end;
            }
            out
        };
        let (mean_lower, mean_upper) = if let Some(level) = confidence_level {
            let z = standard_normal_quantile(0.5 + 0.5 * level)
                .map_err(EstimationError::InvalidInput)?;
            // eta_se here is the probability-scale delta SE computed above.
            (
                Some((&mean - &eta_se.mapv(|s| z * s)).mapv(|v| v.clamp(0.0, 1.0))),
                Some((&mean + &eta_se.mapv(|s| z * s)).mapv(|v| v.clamp(0.0, 1.0))),
            )
        } else {
            (None, None)
        };
        Ok(PredictPosteriorMeanResult {
            eta,
            eta_standard_error: eta_se,
            mean,
            mean_lower,
            mean_upper,
        })
    }

    /// Threshold and scale blocks, plus one extra block when a link wiggle is present.
    fn n_blocks(&self) -> usize {
        if self.link_wiggle.is_some() { 3 } else { 2 }
    }

    /// Block roles in coefficient order: location, scale, then the optional link wiggle.
    fn block_roles(&self) -> Vec<BlockRole> {
        if self.link_wiggle.is_some() {
            vec![BlockRole::Location, BlockRole::Scale, BlockRole::LinkWiggle]
        } else {
            vec![BlockRole::Location, BlockRole::Scale]
        }
    }
}
3662
/// Upper bound on the exponent used when forming `exp(-eta_log_sigma)`.
///
/// It also serves as the threshold for deciding whether
/// `|eta_threshold| * exp(-eta_log_sigma)` should saturate.
/// `exp(500)` is about `1.4e217`, comfortably below `f64::MAX`.
const SURVIVAL_EXP_NEG_STABLE_MAX_ARG: f64 = 500.0;

/// Returns `1 / sigma = exp(-eta_log_sigma)`, with the exponent capped at
/// [`SURVIVAL_EXP_NEG_STABLE_MAX_ARG`] so the result stays finite.
///
/// A NaN `eta_log_sigma` yields NaN. `f64::min` would otherwise discard the
/// NaN and silently return `exp(500)`.
#[inline]
fn survival_inverse_sigma_from_eta_log_sigma(eta_log_sigma: f64) -> f64 {
    if eta_log_sigma.is_nan() {
        return f64::NAN;
    }
    (-eta_log_sigma).min(SURVIVAL_EXP_NEG_STABLE_MAX_ARG).exp()
}

/// Computes the standardized threshold `q0 = -eta_threshold / sigma` together
/// with `1 / sigma`.
///
/// When `log|eta_threshold| - eta_log_sigma` exceeds the stable exponent bound,
/// `q0` saturates to `∓f64::MAX`, matching the sign of `-eta_threshold`.
/// A zero threshold gives `q0 = 0`. A NaN `eta_log_sigma` yields `(NaN, NaN)`
/// rather than a silently finite value.
#[inline]
fn survival_q0_and_inverse_sigma(eta_threshold: f64, eta_log_sigma: f64) -> (f64, f64) {
    let inv_sigma = survival_inverse_sigma_from_eta_log_sigma(eta_log_sigma);
    if inv_sigma.is_nan() {
        return (f64::NAN, f64::NAN);
    }
    if eta_threshold == 0.0 {
        return (0.0, inv_sigma);
    }
    // Compare on the log scale so the product cannot overflow before the check.
    let log_abs = eta_threshold.abs().ln() + (-eta_log_sigma).min(SURVIVAL_EXP_NEG_STABLE_MAX_ARG);
    let q0 = if log_abs > SURVIVAL_EXP_NEG_STABLE_MAX_ARG {
        if eta_threshold > 0.0 {
            -f64::MAX
        } else {
            f64::MAX
        }
    } else {
        -eta_threshold * inv_sigma
    };
    (q0, inv_sigma)
}
3697
3698#[inline]
3699fn survival_tail_value_from_failure_jet(
3700 inverse_link: &InverseLink,
3701 eta: f64,
3702 failure_jet: &InverseLinkJet,
3703) -> f64 {
3704 match inverse_link {
3705 InverseLink::Standard(crate::types::LinkFunction::Probit) => {
3706 if eta.is_nan() {
3707 f64::NAN
3708 } else if eta == f64::INFINITY {
3709 0.0
3710 } else if eta == f64::NEG_INFINITY {
3711 1.0
3712 } else {
3713 0.5 * statrs::function::erf::erfc(eta / std::f64::consts::SQRT_2)
3714 }
3715 }
3716 InverseLink::Standard(crate::types::LinkFunction::Logit) => 1.0 / (1.0 + eta.exp()),
3717 InverseLink::Standard(crate::types::LinkFunction::CLogLog) => (-(eta.exp())).exp(),
3718 _ => (1.0 - failure_jet.mu).clamp(0.0, 1.0),
3719 }
3720}
3721
3722#[inline]
3723fn inverse_link_survival_tail_value_and_failure_density(
3724 inverse_link: &InverseLink,
3725 eta: f64,
3726) -> Result<(f64, f64), EstimationError> {
3727 let failure_jet =
3728 crate::solver::mixture_link::inverse_link_jet_for_inverse_link(inverse_link, eta)?;
3729 Ok((
3730 survival_tail_value_from_failure_jet(inverse_link, eta, &failure_jet).clamp(0.0, 1.0),
3731 failure_jet.d1,
3732 ))
3733}
3734
/// Survival predictor for threshold / log-sigma location-scale models.
///
/// Survival is `1 - F(q0)` with `q0 = -eta_threshold * exp(-eta_log_sigma)`.
/// `F` is the failure CDF given by `inverse_link`.
pub struct SurvivalPredictor {
    /// Coefficients for the threshold design (`input.design`).
    pub beta_threshold: Array1<f64>,
    /// Coefficients for the log-sigma design (`input.design_noise`).
    pub beta_log_sigma: Array1<f64>,
    /// Optional joint coefficient covariance, ordered `[threshold | log-sigma]`.
    pub covariance: Option<Array2<f64>>,
    /// Inverse link defining the failure distribution.
    pub inverse_link: InverseLink,
}
3741
3742impl SurvivalPredictor {
3743 pub(crate) fn from_unified(
3747 unified: &UnifiedFitResult,
3748 inverse_link: InverseLink,
3749 ) -> Result<Self, EstimationError> {
3750 let beta_threshold = unified
3751 .block_by_role(BlockRole::Threshold)
3752 .or_else(|| unified.block_by_role(BlockRole::Location))
3753 .or_else(|| unified.block_by_role(BlockRole::Mean))
3754 .map(|b| b.beta.clone())
3755 .ok_or_else(|| {
3756 EstimationError::InvalidInput("Survival model missing threshold block".to_string())
3757 })?;
3758 let beta_log_sigma = unified
3759 .block_by_role(BlockRole::Scale)
3760 .map(|b| b.beta.clone())
3761 .ok_or_else(|| {
3762 EstimationError::InvalidInput(
3763 "Survival model missing scale (log-sigma) block".to_string(),
3764 )
3765 })?;
3766 Ok(Self {
3767 beta_threshold,
3768 beta_log_sigma,
3769 covariance: unified.covariance_conditional.clone(),
3770 inverse_link,
3771 })
3772 }
3773
3774 fn compute_survival(
3776 &self,
3777 eta_threshold: &Array1<f64>,
3778 eta_log_sigma: &Array1<f64>,
3779 ) -> Result<Array1<f64>, EstimationError> {
3780 use rayon::iter::{IntoParallelIterator, ParallelIterator};
3781 let n = eta_threshold.len();
3782 let survival_prob: Result<Vec<f64>, EstimationError> = (0..n)
3783 .into_par_iter()
3784 .map(|i| {
3785 let (q0, _) = survival_q0_and_inverse_sigma(eta_threshold[i], eta_log_sigma[i]);
3786 let (survival, _) =
3787 inverse_link_survival_tail_value_and_failure_density(&self.inverse_link, q0)?;
3788 Ok(survival)
3789 })
3790 .collect();
3791 Ok(Array1::from_vec(survival_prob?))
3792 }
3793}
3794
3795impl PredictableModel for SurvivalPredictor {
3796 fn predict_plugin_response(
3797 &self,
3798 input: &PredictInput,
3799 ) -> Result<PredictResult, EstimationError> {
3800 let eta_threshold = input.design.dot(&self.beta_threshold) + &input.offset;
3801 let design_noise = input.design_noise.as_ref().ok_or_else(|| {
3802 EstimationError::InvalidInput(
3803 "Survival prediction requires noise (log-sigma) design matrix".to_string(),
3804 )
3805 })?;
3806 let offset_noise = input.offset_noise.as_ref().ok_or_else(|| {
3807 EstimationError::InvalidInput(
3808 "Survival prediction requires noise (log-sigma) offset".to_string(),
3809 )
3810 })?;
3811 let eta_log_sigma = design_noise.dot(&self.beta_log_sigma) + offset_noise;
3812 let survival_prob = self.compute_survival(&eta_threshold, &eta_log_sigma)?;
3813 Ok(PredictResult {
3814 eta: eta_threshold,
3815 mean: survival_prob,
3816 })
3817 }
3818
3819 fn predict_with_uncertainty(
3820 &self,
3821 input: &PredictInput,
3822 ) -> Result<PredictionWithSE, EstimationError> {
3823 let eta_threshold = input.design.dot(&self.beta_threshold) + &input.offset;
3824 let design_noise = input.design_noise.as_ref().ok_or_else(|| {
3825 EstimationError::InvalidInput(
3826 "Survival prediction requires noise (log-sigma) design matrix".to_string(),
3827 )
3828 })?;
3829 let offset_noise = input.offset_noise.as_ref().ok_or_else(|| {
3830 EstimationError::InvalidInput(
3831 "Survival prediction requires noise (log-sigma) offset".to_string(),
3832 )
3833 })?;
3834 let eta_log_sigma = design_noise.dot(&self.beta_log_sigma) + offset_noise;
3835 let survival_prob = self.compute_survival(&eta_threshold, &eta_log_sigma)?;
3836
3837 let (eta_se, mean_se) = if let Some(ref cov) = self.covariance {
3838 let n = eta_threshold.len();
3839 let p_t = self.beta_threshold.len();
3840 let p_s = self.beta_log_sigma.len();
3841 let backend = PredictionCovarianceBackend::from_dense(cov.view());
3842
3843 let eta_se = padded_design_standard_errors_from_backend(
3844 &input.design,
3845 &backend,
3846 0,
3847 p_s,
3848 "survival threshold uncertainty",
3849 )?;
3850
3851 let mean_se_vec = linear_predictor_se_from_backend(&backend, n, |rows| {
3853 let x_t = design_row_chunk(&input.design, rows.clone())?;
3854 let x_s = design_row_chunk(design_noise, rows.clone())?;
3855 let eta_t_chunk = eta_threshold.slice(ndarray::s![rows.clone()]).to_owned();
3856 let eta_ls_chunk = eta_log_sigma.slice(ndarray::s![rows.clone()]).to_owned();
3857 let rows_in_chunk = eta_t_chunk.len();
3858 let mut grad = Array2::<f64>::zeros((rows_in_chunk, p_t + p_s));
3859 for i in 0..rows_in_chunk {
3860 let (q0, inv_sigma) =
3861 survival_q0_and_inverse_sigma(eta_t_chunk[i], eta_ls_chunk[i]);
3862 let (_, failure_density) =
3863 inverse_link_survival_tail_value_and_failure_density(
3864 &self.inverse_link,
3865 q0,
3866 )
3867 .map_err(|e| e.to_string())?;
3868 let dsurv_deta_t = failure_density * inv_sigma;
3869 let dsurv_deta_s = failure_density * q0;
3870 for j in 0..p_t {
3871 grad[[i, j]] = dsurv_deta_t * x_t[[i, j]];
3872 }
3873 for j in 0..p_s {
3874 grad[[i, p_t + j]] = dsurv_deta_s * x_s[[i, j]];
3875 }
3876 }
3877 Ok(vec![grad])
3878 })?;
3879 (Some(eta_se), Some(mean_se_vec))
3880 } else {
3881 (None, None)
3882 };
3883
3884 Ok(PredictionWithSE {
3885 eta: eta_threshold,
3886 mean: survival_prob,
3887 eta_se,
3888 mean_se,
3889 })
3890 }
3891
3892 fn predict_noise_scale(
3893 &self,
3894 _: &PredictInput,
3895 ) -> Result<Option<Array1<f64>>, EstimationError> {
3896 Ok(None)
3897 }
3898
3899 fn predict_full_uncertainty(
3900 &self,
3901 input: &PredictInput,
3902 _: &UnifiedFitResult,
3903 options: &PredictUncertaintyOptions,
3904 ) -> Result<PredictUncertaintyResult, EstimationError> {
3905 let pred = self.predict_with_uncertainty(input)?;
3906 let z = crate::probability::standard_normal_quantile(0.5 + options.confidence_level * 0.5)
3907 .map_err(|e| EstimationError::InvalidInput(e))?;
3908
3909 let eta_se = pred.eta_se.as_ref().ok_or_else(|| {
3910 EstimationError::InvalidInput(
3911 "Survival full uncertainty requires covariance (eta_se unavailable)".to_string(),
3912 )
3913 })?;
3914 let mean_se = pred.mean_se.as_ref().ok_or_else(|| {
3915 EstimationError::InvalidInput(
3916 "Survival full uncertainty requires covariance (mean_se unavailable)".to_string(),
3917 )
3918 })?;
3919
3920 let eta_lower = &pred.eta - &eta_se.mapv(|s| z * s);
3921 let eta_upper = &pred.eta + &eta_se.mapv(|s| z * s);
3922 let mut mean_lower = &pred.mean - &mean_se.mapv(|s| z * s);
3923 let mut mean_upper = &pred.mean + &mean_se.mapv(|s| z * s);
3924 mean_lower.mapv_inplace(|v| v.clamp(0.0, 1.0));
3926 mean_upper.mapv_inplace(|v| v.clamp(0.0, 1.0));
3927
3928 Ok(PredictUncertaintyResult {
3929 eta: pred.eta,
3930 mean: pred.mean,
3931 eta_standard_error: eta_se.clone(),
3932 mean_standard_error: mean_se.clone(),
3933 eta_lower,
3934 eta_upper,
3935 mean_lower,
3936 mean_upper,
3937 observation_lower: None,
3938 observation_upper: None,
3939 covariance_mode_requested: options.covariance_mode,
3940 covariance_corrected_used: false,
3941 })
3942 }
3943
3944 fn predict_posterior_mean(
3945 &self,
3946 input: &PredictInput,
3947 fit: &UnifiedFitResult,
3948 confidence_level: Option<f64>,
3949 ) -> Result<PredictPosteriorMeanResult, EstimationError> {
3950 let eta_threshold = input.design.dot(&self.beta_threshold) + &input.offset;
3960 let design_noise = input.design_noise.as_ref().ok_or_else(|| {
3961 EstimationError::InvalidInput(
3962 "Survival posterior mean requires noise (log-sigma) design matrix".to_string(),
3963 )
3964 })?;
3965 let offset_noise = input.offset_noise.as_ref().ok_or_else(|| {
3966 EstimationError::InvalidInput(
3967 "Survival posterior mean requires noise (log-sigma) offset".to_string(),
3968 )
3969 })?;
3970 let eta_log_sigma = design_noise.dot(&self.beta_log_sigma) + offset_noise;
3971 let p_t = self.beta_threshold.len();
3972 let p_s = self.beta_log_sigma.len();
3973 let p_total = p_t + p_s;
3974 let backend = require_posterior_mean_backend(
3975 fit,
3976 self.covariance.as_ref(),
3977 p_total,
3978 "survival posterior mean",
3979 )?;
3980
3981 let eta_se = padded_design_standard_errors_from_backend(
3982 &input.design,
3983 &backend,
3984 0,
3985 p_s,
3986 "survival posterior mean",
3987 )?;
3988 let (var_t, var_s, cov_ts) = project_two_block_linear_predictor_covariance(
3989 &input.design,
3990 design_noise,
3991 &backend,
3992 p_t,
3993 p_s,
3994 "survival posterior mean",
3995 )?;
3996 let quadctx = crate::quadrature::QuadratureContext::new();
3997 let mean = Array1::from_vec(
3998 (0..eta_threshold.len())
3999 .map(|i| {
4000 projected_bivariate_posterior_mean_result(
4001 &quadctx,
4002 [eta_threshold[i], eta_log_sigma[i]],
4003 [
4004 [var_t[i].max(0.0), cov_ts[i]],
4005 [cov_ts[i], var_s[i].max(0.0)],
4006 ],
4007 |threshold, log_sigma| {
4008 let (q0, _) = survival_q0_and_inverse_sigma(threshold, log_sigma);
4009 let (survival, _) =
4010 inverse_link_survival_tail_value_and_failure_density(
4011 &self.inverse_link,
4012 q0,
4013 )?;
4014 Ok(survival)
4015 },
4016 )
4017 })
4018 .collect::<Result<Vec<_>, _>>()?,
4019 );
4020 let (mean_lower, mean_upper) = if let Some(level) = confidence_level {
4021 let z = crate::probability::standard_normal_quantile(0.5 + 0.5 * level).unwrap_or(1.96);
4022 let lo = (&mean - &eta_se.mapv(|s| z * s)).mapv(|v| v.clamp(0.0, 1.0));
4023 let hi = (&mean + &eta_se.mapv(|s| z * s)).mapv(|v| v.clamp(0.0, 1.0));
4024 (Some(lo), Some(hi))
4025 } else {
4026 (None, None)
4027 };
4028 Ok(PredictPosteriorMeanResult {
4029 eta: eta_threshold,
4030 eta_standard_error: eta_se,
4031 mean,
4032 mean_lower,
4033 mean_upper,
4034 })
4035 }
4036
4037 fn n_blocks(&self) -> usize {
4038 2
4039 }
4040
4041 fn block_roles(&self) -> Vec<BlockRole> {
4042 vec![BlockRole::Threshold, BlockRole::Scale]
4043 }
4044}
4045
/// Predictor for transformation-normal models.
///
/// Its `PredictableModel` impl returns `input.offset` unchanged as both the
/// linear predictor and the mean, and reports zero standard errors.
pub struct TransformationNormalPredictor {
    /// Optional coefficient covariance (not read by the prediction methods in this file view).
    pub covariance: Option<Array2<f64>>,
}
4054
4055impl PredictableModel for TransformationNormalPredictor {
4056 fn predict_plugin_response(
4057 &self,
4058 input: &PredictInput,
4059 ) -> Result<PredictResult, EstimationError> {
4060 let h = input.offset.clone();
4061 Ok(PredictResult {
4062 eta: h.clone(),
4063 mean: h,
4064 })
4065 }
4066
4067 fn predict_with_uncertainty(
4068 &self,
4069 input: &PredictInput,
4070 ) -> Result<PredictionWithSE, EstimationError> {
4071 let h = input.offset.clone();
4072 Ok(PredictionWithSE {
4073 eta: h.clone(),
4074 mean: h,
4075 eta_se: None,
4076 mean_se: None,
4077 })
4078 }
4079
4080 fn predict_noise_scale(
4081 &self,
4082 _: &PredictInput,
4083 ) -> Result<Option<Array1<f64>>, EstimationError> {
4084 Ok(None)
4085 }
4086
4087 fn predict_full_uncertainty(
4088 &self,
4089 input: &PredictInput,
4090 fit: &UnifiedFitResult,
4091 options: &PredictUncertaintyOptions,
4092 ) -> Result<PredictUncertaintyResult, EstimationError> {
4093 let h = input.offset.clone();
4094 let n = h.len();
4095 let zeros = Array1::zeros(n);
4096 Ok(PredictUncertaintyResult {
4097 eta: h.clone(),
4098 mean: h.clone(),
4099 eta_standard_error: zeros.clone(),
4100 mean_standard_error: zeros,
4101 eta_lower: h.clone(),
4102 eta_upper: h.clone(),
4103 mean_lower: h.clone(),
4104 mean_upper: h,
4105 observation_lower: None,
4106 observation_upper: None,
4107 covariance_mode_requested: options.covariance_mode,
4108 covariance_corrected_used: fit.covariance_corrected.is_some(),
4109 })
4110 }
4111
4112 fn predict_posterior_mean(
4113 &self,
4114 input: &PredictInput,
4115 fit: &UnifiedFitResult,
4116 confidence_level: Option<f64>,
4117 ) -> Result<PredictPosteriorMeanResult, EstimationError> {
4118 let h = input.offset.clone();
4119 let n = h.len();
4120 let has_fit_covariance =
4121 fit.covariance_corrected.is_some() || fit.covariance_conditional.is_some();
4122 let (mean_lower, mean_upper) = if confidence_level.is_some() && has_fit_covariance {
4123 (Some(h.clone()), Some(h.clone()))
4124 } else {
4125 (None, None)
4126 };
4127 Ok(PredictPosteriorMeanResult {
4128 eta: h.clone(),
4129 eta_standard_error: Array1::zeros(n),
4130 mean: h,
4131 mean_lower,
4132 mean_upper,
4133 })
4134 }
4135
4136 fn n_blocks(&self) -> usize {
4137 1
4138 }
4139 fn block_roles(&self) -> Vec<BlockRole> {
4140 vec![BlockRole::Mean]
4141 }
4142}
4143
4144fn eta_standard_errors_from_backend(
4146 x: &DesignMatrix,
4147 backend: &PredictionCovarianceBackend<'_>,
4148) -> Result<Array1<f64>, EstimationError> {
4149 let vars = linear_predictorvariance_from_backend(x, backend)?;
4150 Ok(vars.mapv(|v| v.max(0.0).sqrt()))
4151}
4152
4153fn delta_method_mean_se(
4155 eta: &Array1<f64>,
4156 eta_se: &Array1<f64>,
4157 strategy: &(dyn FamilyStrategy + Sync),
4158) -> Result<Array1<f64>, EstimationError> {
4159 use rayon::iter::{IntoParallelIterator, ParallelIterator};
4160 let n = eta.len();
4161 let values: Result<Vec<f64>, EstimationError> = (0..n)
4162 .into_par_iter()
4163 .map(|i| {
4164 let jet = strategy.inverse_link_jet(eta[i])?;
4165 Ok((jet.d1 * eta_se[i]).abs())
4166 })
4167 .collect();
4168 Ok(Array1::from_vec(values?))
4169}
4170
/// Result of a posterior-mean prediction.
pub struct PredictPosteriorMeanResult {
    /// Linear predictor reported by the predictor (e.g. the threshold predictor for survival).
    pub eta: Array1<f64>,
    /// Standard error reported alongside `eta`.
    ///
    /// NOTE(review): the binomial location-scale predictor in this file stores
    /// a probability-scale delta-method SE here. Confirm the intended scale per
    /// predictor.
    pub eta_standard_error: Array1<f64>,
    /// Posterior-mean response.
    pub mean: Array1<f64>,
    /// Optional lower bound on the mean (set by predictors or by `enrich_posterior_mean_bounds`).
    pub mean_lower: Option<Array1<f64>>,
    /// Optional upper bound on the mean (set by predictors or by `enrich_posterior_mean_bounds`).
    pub mean_upper: Option<Array1<f64>>,
}
4182
4183pub fn enrich_posterior_mean_bounds(
4193 result: &mut PredictPosteriorMeanResult,
4194 confidence_level: f64,
4195 family: crate::types::LikelihoodFamily,
4196 link_kind: Option<&InverseLink>,
4197) -> Result<(), EstimationError> {
4198 if !(confidence_level.is_finite() && confidence_level > 0.0 && confidence_level < 1.0) {
4199 return Err(EstimationError::InvalidInput(format!(
4200 "confidence_level must be in (0,1), got {confidence_level}"
4201 )));
4202 }
4203 let z = crate::probability::standard_normal_quantile(0.5 + 0.5 * confidence_level)
4204 .map_err(EstimationError::InvalidInput)?;
4205
4206 let eta_lower = &result.eta - &result.eta_standard_error.mapv(|s| z * s);
4207 let eta_upper = &result.eta + &result.eta_standard_error.mapv(|s| z * s);
4208
4209 let transformed_lower = apply_family_inverse_link(&eta_lower, family, link_kind)?;
4210 let transformed_upper = apply_family_inverse_link(&eta_upper, family, link_kind)?;
4211
4212 let mut mean_lower = Array1::from_iter(
4214 transformed_lower
4215 .iter()
4216 .zip(transformed_upper.iter())
4217 .map(|(&lo, &hi)| lo.min(hi)),
4218 );
4219 let mut mean_upper = Array1::from_iter(
4220 transformed_lower
4221 .iter()
4222 .zip(transformed_upper.iter())
4223 .map(|(&lo, &hi)| lo.max(hi)),
4224 );
4225
4226 if matches!(
4228 family,
4229 crate::types::LikelihoodFamily::BinomialLogit
4230 | crate::types::LikelihoodFamily::BinomialProbit
4231 | crate::types::LikelihoodFamily::BinomialCLogLog
4232 | crate::types::LikelihoodFamily::BinomialSas
4233 | crate::types::LikelihoodFamily::BinomialBetaLogistic
4234 | crate::types::LikelihoodFamily::BinomialMixture
4235 | crate::types::LikelihoodFamily::RoystonParmar
4236 ) {
4237 mean_lower.mapv_inplace(|v| v.clamp(0.0, 1.0));
4238 mean_upper.mapv_inplace(|v| v.clamp(0.0, 1.0));
4239 }
4240
4241 result.mean_lower = Some(mean_lower);
4242 result.mean_upper = Some(mean_upper);
4243 Ok(())
4244}
4245
/// Selects which coefficient covariance / standard errors drive uncertainty
/// quantification. `coefficient_uncertaintywith_mode` shows the selection
/// semantics for each variant.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum InferenceCovarianceMode {
    /// Use only the conditional (uncorrected) covariance; error if it is missing.
    Conditional,
    /// Prefer the smoothing-corrected covariance. Fall back to the conditional
    /// one when the fit does not carry a corrected version.
    ConditionalPlusSmoothingPreferred,
    /// Require the smoothing-corrected covariance; error if it is missing.
    ConditionalPlusSmoothingRequired,
}
4258
/// Axis-aligned bounding box of the training predictors.
///
/// Consumed by the boundary and out-of-distribution variance-inflation
/// corrections in `predict_gamwith_uncertainty`.
#[derive(Clone, Debug)]
pub struct TrainingSupport {
    /// Per-column minimum over the training rows. Empty when built from a
    /// matrix with no rows or no columns.
    pub axis_min: Array1<f64>,
    /// Per-column maximum over the training rows; same length as `axis_min`
    /// when built by `from_training_rows`.
    pub axis_max: Array1<f64>,
}
4275
4276impl TrainingSupport {
4277 pub fn from_training_rows(rows: ArrayView2<'_, f64>) -> Self {
4280 let d = rows.ncols();
4281 if rows.nrows() == 0 || d == 0 {
4282 return Self {
4283 axis_min: Array1::zeros(0),
4284 axis_max: Array1::zeros(0),
4285 };
4286 }
4287 let mut axis_min = Array1::from_elem(d, f64::INFINITY);
4288 let mut axis_max = Array1::from_elem(d, f64::NEG_INFINITY);
4289 for row in rows.outer_iter() {
4290 for k in 0..d {
4291 let v = row[k];
4292 if v < axis_min[k] {
4293 axis_min[k] = v;
4294 }
4295 if v > axis_max[k] {
4296 axis_max[k] = v;
4297 }
4298 }
4299 }
4300 Self { axis_min, axis_max }
4301 }
4302}
4303
/// Knobs controlling `predict_gamwith_uncertainty`.
pub struct PredictUncertaintyOptions {
    /// Nominal two-sided coverage; must be finite and strictly inside (0, 1).
    pub confidence_level: f64,
    /// Which coefficient covariance to use for `Var(eta)`.
    pub covariance_mode: InferenceCovarianceMode,
    /// How mean-scale intervals are built (delta method or transformed eta interval).
    pub mean_interval_method: MeanIntervalMethod,
    /// Emit observation (prediction) intervals. Only honored for
    /// `GaussianIdentity`; other families get `None`.
    pub includeobservation_interval: bool,
    /// Add `X·bc` to eta when the fit carries a bias-correction vector whose
    /// length matches beta.
    pub apply_bias_correction: bool,
    /// Skew-adjust the lower/upper critical values per row. Requires
    /// `eta_skewness_for_corrections` with one entry per prediction row.
    pub edgeworth_one_sided: bool,
    /// Inflate `Var(eta)` for rows near the edges of the training support.
    /// Requires `predictor_x_for_corrections` and `training_support`.
    pub boundary_correction: bool,
    /// Inflate `Var(eta)` for rows outside the training support.
    /// Requires `predictor_x_for_corrections` and `training_support`.
    pub ood_inflation: bool,
    /// Use Bonferroni-adjusted critical values for simultaneous coverage over
    /// `joint_query_count` rows (default: all prediction rows).
    pub multi_point_joint: bool,
    /// Raw predictor rows aligned with the prediction rows. Column count must
    /// equal the number of `training_support` axes.
    pub predictor_x_for_corrections: Option<Array2<f64>>,
    /// Bounding box of the training predictors used by the boundary/OOD corrections.
    pub training_support: Option<TrainingSupport>,
    /// Per-row skewness (third standardized cumulant) of eta for the Edgeworth
    /// adjustment.
    pub eta_skewness_for_corrections: Option<Array1<f64>>,
    /// Number of simultaneous queries for the Bonferroni adjustment; `None`
    /// means the number of prediction rows (clamped to at least 1).
    pub joint_query_count: Option<usize>,
    /// Strength of the boundary variance inflation (`1 + alpha · excess`).
    pub boundary_alpha: f64,
    /// Width of the edge band, as a fraction of each axis range, within which
    /// boundary inflation applies.
    pub boundary_band_fraction: f64,
    /// Strength of the OOD variance inflation (`1 + gamma · Σ (excess/range)²`).
    pub ood_gamma: f64,
}
4375
4376impl Default for PredictUncertaintyOptions {
4377 fn default() -> Self {
4378 Self {
4379 confidence_level: 0.95,
4380 covariance_mode: InferenceCovarianceMode::ConditionalPlusSmoothingPreferred,
4381 mean_interval_method: MeanIntervalMethod::TransformEta,
4382 includeobservation_interval: true,
4383 apply_bias_correction: true,
4384 edgeworth_one_sided: true,
4385 boundary_correction: true,
4386 ood_inflation: false,
4387 multi_point_joint: false,
4388 predictor_x_for_corrections: None,
4389 training_support: None,
4390 eta_skewness_for_corrections: None,
4391 joint_query_count: None,
4392 boundary_alpha: 0.25,
4393 boundary_band_fraction: 0.05,
4394 ood_gamma: 1.0,
4395 }
4396 }
4397}
4398
/// Per-side standard-normal multipliers for an asymmetric confidence interval.
#[derive(Clone, Copy, Debug)]
pub(crate) struct EdgeworthZ {
    /// Multiplier applied to the standard error below the point estimate.
    pub z_lower: f64,
    /// Multiplier applied to the standard error above the point estimate.
    pub z_upper: f64,
}

/// Skew-adjusts a symmetric two-sided normal quantile using the first-order
/// Cornish–Fisher (Edgeworth) term.
///
/// With `bump = (z² − 1) · κ₃ / 6`, the result is `z_lower = z − bump` and
/// `z_upper = z + bump`. Each is floored at 0 so neither side of the interval
/// can cross the point estimate.
///
/// A non-finite `skew_kappa3` carries no usable shape information, so the
/// unadjusted `z` is returned on both sides. Otherwise a NaN would propagate
/// into `bump`, and `f64::max(NaN, 0.0)` returns `0.0`, silently collapsing the
/// interval to zero width.
pub(crate) fn edgeworth_one_sided_quantile(z: f64, skew_kappa3: f64) -> EdgeworthZ {
    if !skew_kappa3.is_finite() {
        let symmetric = z.max(0.0);
        return EdgeworthZ {
            z_lower: symmetric,
            z_upper: symmetric,
        };
    }
    let bump = (z * z - 1.0) * skew_kappa3 / 6.0;
    EdgeworthZ {
        z_lower: (z - bump).max(0.0),
        z_upper: (z + bump).max(0.0),
    }
}
4431
4432pub(crate) fn boundary_variance_inflation_factor(
4437 x_row: ArrayView1<'_, f64>,
4438 axis_min: ArrayView1<'_, f64>,
4439 axis_max: ArrayView1<'_, f64>,
4440 alpha: f64,
4441 band_fraction: f64,
4442) -> f64 {
4443 let d = x_row.len();
4444 if d == 0 || axis_min.len() != d || axis_max.len() != d || band_fraction <= 0.0 {
4445 return 1.0;
4446 }
4447 let mut excess = 0.0_f64;
4448 for k in 0..d {
4449 let lo = axis_min[k];
4450 let hi = axis_max[k];
4451 let range = hi - lo;
4452 if !(range > 0.0) {
4453 continue;
4454 }
4455 let x = x_row[k];
4456 let d_edge = (x - lo).min(hi - x);
4458 if !d_edge.is_finite() || d_edge >= band_fraction * range {
4459 continue;
4460 }
4461 if d_edge <= 0.0 {
4464 excess += 1.0;
4466 } else {
4467 let shortfall = 1.0 - d_edge / (band_fraction * range);
4468 excess += shortfall * shortfall;
4469 }
4470 }
4471 (1.0 + alpha * excess).max(1.0)
4472}
4473
4474pub(crate) fn ood_variance_inflation_factor(
4479 x_row: ArrayView1<'_, f64>,
4480 axis_min: ArrayView1<'_, f64>,
4481 axis_max: ArrayView1<'_, f64>,
4482 gamma: f64,
4483) -> f64 {
4484 let d = x_row.len();
4485 if d == 0 || axis_min.len() != d || axis_max.len() != d {
4486 return 1.0;
4487 }
4488 let mut sq_excess = 0.0_f64;
4489 for k in 0..d {
4490 let lo = axis_min[k];
4491 let hi = axis_max[k];
4492 let range = hi - lo;
4493 if !(range > 0.0) {
4494 continue;
4495 }
4496 let x = x_row[k];
4497 let excess = if x < lo {
4498 lo - x
4499 } else if x > hi {
4500 x - hi
4501 } else {
4502 0.0
4503 };
4504 let frac = excess / range;
4505 sq_excess += frac * frac;
4506 }
4507 (1.0 + gamma * sq_excess).max(1.0)
4508}
4509
4510pub(crate) fn multi_point_joint_z(level: f64, m: usize) -> Result<f64, String> {
4517 if m <= 1 || !(level.is_finite() && level > 0.0 && level < 1.0) {
4518 return standard_normal_quantile(0.5 + 0.5 * level);
4519 }
4520 let alpha = 1.0 - level;
4521 let per_row_alpha = alpha / (m as f64);
4522 let per_row_level = 1.0 - per_row_alpha;
4523 standard_normal_quantile(0.5 + 0.5 * per_row_level)
4524}
4525
/// How `predict_gamwith_uncertainty` builds confidence intervals on the mean scale.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MeanIntervalMethod {
    /// `mean ∓ z · mean_standard_error`, using the per-side `z` values.
    /// Clamped to [0, 1] only for probability-scale families.
    Delta,
    /// Push the eta-scale interval endpoints through the inverse link and sort
    /// them per row.
    TransformEta,
}
4534
/// Output of `predict_gamwith_uncertainty`.
pub struct PredictUncertaintyResult {
    /// Linear predictor `X·beta + offset`, including the bias correction when applied.
    pub eta: Array1<f64>,
    /// Plug-in mean: the family inverse link evaluated at `eta`.
    pub mean: Array1<f64>,
    /// Standard error of `eta`, after any boundary/OOD variance inflation.
    pub eta_standard_error: Array1<f64>,
    /// Standard error of the mean. Includes link-parameter uncertainty for
    /// SAS / beta-logistic / mixture links when their covariance is available.
    pub mean_standard_error: Array1<f64>,
    /// Lower eta-scale confidence bound.
    pub eta_lower: Array1<f64>,
    /// Upper eta-scale confidence bound.
    pub eta_upper: Array1<f64>,
    /// Lower mean-scale confidence bound.
    pub mean_lower: Array1<f64>,
    /// Upper mean-scale confidence bound.
    pub mean_upper: Array1<f64>,
    /// Lower observation (prediction) bound. `Some` only for `GaussianIdentity`
    /// when observation intervals were requested.
    pub observation_lower: Option<Array1<f64>>,
    /// Upper observation (prediction) bound; same availability as `observation_lower`.
    pub observation_upper: Option<Array1<f64>>,
    /// The covariance mode the caller asked for.
    pub covariance_mode_requested: InferenceCovarianceMode,
    /// Flag reported by the covariance-backend selection. Presumably `true` when
    /// the smoothing-corrected covariance was actually used (TODO confirm).
    pub covariance_corrected_used: bool,
}
4552
4553fn predict_gam_posterior_mean_from_backend(
4554 x: DesignMatrix,
4555 beta: ArrayView1<'_, f64>,
4556 offset: ArrayView1<'_, f64>,
4557 backend: &PredictionCovarianceBackend<'_>,
4558 strategy: &(dyn FamilyStrategy + Sync),
4559 label: &str,
4560) -> Result<PredictPosteriorMeanResult, EstimationError> {
4561 predict_gam_posterior_mean_from_backendwith_bc(x, beta, offset, backend, strategy, label, None)
4562}
4563
4564fn predict_gam_posterior_mean_from_backendwith_bc(
4565 x: DesignMatrix,
4566 beta: ArrayView1<'_, f64>,
4567 offset: ArrayView1<'_, f64>,
4568 backend: &PredictionCovarianceBackend<'_>,
4569 strategy: &(dyn FamilyStrategy + Sync),
4570 label: &str,
4571 bias_correction_beta: Option<ArrayView1<'_, f64>>,
4572) -> Result<PredictPosteriorMeanResult, EstimationError> {
4573 if x.ncols() != beta.len() {
4574 return Err(EstimationError::InvalidInput(format!(
4575 "{label} dimension mismatch: X has {} columns but beta has length {}",
4576 x.ncols(),
4577 beta.len()
4578 )));
4579 }
4580 if x.nrows() != offset.len() {
4581 return Err(EstimationError::InvalidInput(format!(
4582 "{label} dimension mismatch: X has {} rows but offset has length {}",
4583 x.nrows(),
4584 offset.len()
4585 )));
4586 }
4587 if backend.nrows() != beta.len() {
4588 return Err(EstimationError::InvalidInput(format!(
4589 "{label} covariance/backend dimension mismatch: expected parameter dimension {}, got {}",
4590 beta.len(),
4591 backend.nrows()
4592 )));
4593 }
4594
4595 let mut eta = x.matrixvectormultiply(&beta.to_owned());
4596 eta += &offset;
4597 if let Some(bc) = bias_correction_beta {
4598 if bc.len() != beta.len() {
4599 return Err(EstimationError::InvalidInput(format!(
4600 "{label} bias-correction dimension mismatch: beta has length {} but bias_correction_beta has length {}",
4601 beta.len(),
4602 bc.len()
4603 )));
4604 }
4605 let bc_owned = bc.to_owned();
4606 let delta = x.matrixvectormultiply(&bc_owned);
4607 eta += δ
4608 }
4609 let etavar = linear_predictorvariance_from_backend(&x, backend)?;
4610 let eta_standard_error = etavar.mapv(|v| v.max(0.0).sqrt());
4611 let quadctx = crate::quadrature::QuadratureContext::new();
4612 let means: Result<Vec<f64>, EstimationError> = (0..eta.len())
4613 .into_par_iter()
4614 .map(|i| strategy.posterior_mean(&quadctx, eta[i], eta_standard_error[i]))
4615 .collect();
4616
4617 Ok(PredictPosteriorMeanResult {
4618 eta,
4619 eta_standard_error,
4620 mean: Array1::from_vec(means?),
4621 mean_lower: None,
4622 mean_upper: None,
4623 })
4624}
4625
/// Per-coefficient estimates with normal-theory confidence intervals.
pub struct CoefficientUncertaintyResult {
    /// Fitted coefficients (a copy of the fit's `beta`).
    pub estimate: Array1<f64>,
    /// Standard errors selected according to `covariance_mode_requested`.
    pub standard_error: Array1<f64>,
    /// `estimate − z · standard_error`.
    pub lower: Array1<f64>,
    /// `estimate + z · standard_error`.
    pub upper: Array1<f64>,
    /// `true` when the smoothing-corrected standard errors were used.
    pub corrected: bool,
    /// The covariance mode the caller asked for.
    pub covariance_mode_requested: InferenceCovarianceMode,
}
4634
4635pub fn predict_gam<X>(
4642 x: X,
4643 beta: ArrayView1<'_, f64>,
4644 offset: ArrayView1<'_, f64>,
4645 family: crate::types::LikelihoodFamily,
4646) -> Result<PredictResult, EstimationError>
4647where
4648 X: Into<DesignMatrix>,
4649{
4650 let x = x.into();
4651 if let Some(message) =
4652 predict_gam_dimension_mismatch_message(x.nrows(), x.ncols(), beta.len(), offset.len())
4653 {
4654 return Err(EstimationError::InvalidInput(message));
4655 }
4656
4657 let mut eta = x.matrixvectormultiply(&beta.to_owned());
4658 eta += &offset;
4659
4660 let mean = apply_family_inverse_link(&eta, family, None)?;
4661
4662 Ok(PredictResult { eta, mean })
4663}
4664
4665pub fn predict_gam_posterior_mean<X>(
4670 x: X,
4671 beta: ArrayView1<'_, f64>,
4672 offset: ArrayView1<'_, f64>,
4673 family: crate::types::LikelihoodFamily,
4674 covariance: ArrayView2<'_, f64>,
4675) -> Result<PredictPosteriorMeanResult, EstimationError>
4676where
4677 X: Into<DesignMatrix>,
4678{
4679 let x = x.into();
4680 let backend = PredictionCovarianceBackend::from_dense(covariance.view());
4681 let strategy = strategy_for_family(family, None);
4682 predict_gam_posterior_mean_from_backend(
4683 x,
4684 beta,
4685 offset,
4686 &backend,
4687 &strategy,
4688 "predict_gam_posterior_mean",
4689 )
4690}
4691
4692pub fn predict_gam_posterior_meanwith_backend<X>(
4693 x: X,
4694 beta: ArrayView1<'_, f64>,
4695 offset: ArrayView1<'_, f64>,
4696 family: crate::types::LikelihoodFamily,
4697 backend: &PredictionCovarianceBackend<'_>,
4698) -> Result<PredictPosteriorMeanResult, EstimationError>
4699where
4700 X: Into<DesignMatrix>,
4701{
4702 let x = x.into();
4703 let strategy = strategy_for_family(family, None);
4704 predict_gam_posterior_mean_from_backend(
4705 x,
4706 beta,
4707 offset,
4708 backend,
4709 &strategy,
4710 "predict_gam_posterior_meanwith_backend",
4711 )
4712}
4713
4714pub fn predict_gam_posterior_meanwith_fit<X>(
4719 x: X,
4720 beta: ArrayView1<'_, f64>,
4721 offset: ArrayView1<'_, f64>,
4722 family: crate::types::LikelihoodFamily,
4723 covariance: ArrayView2<'_, f64>,
4724 fit: &UnifiedFitResult,
4725) -> Result<PredictPosteriorMeanResult, EstimationError>
4726where
4727 X: Into<DesignMatrix>,
4728{
4729 let x = x.into();
4730 let backend = PredictionCovarianceBackend::from_dense(covariance.view());
4731 let strategy = strategy_from_fit(family, fit)?;
4732 predict_gam_posterior_mean_from_backend(
4733 x,
4734 beta,
4735 offset,
4736 &backend,
4737 &strategy,
4738 "predict_gam_posterior_meanwith_fit",
4739 )
4740}
4741
4742pub fn predict_gamwith_uncertainty<X>(
4785 x: X,
4786 beta: ArrayView1<'_, f64>,
4787 offset: ArrayView1<'_, f64>,
4788 family: crate::types::LikelihoodFamily,
4789 fit: &UnifiedFitResult,
4790 options: &PredictUncertaintyOptions,
4791) -> Result<PredictUncertaintyResult, EstimationError>
4792where
4793 X: Into<DesignMatrix>,
4794{
4795 let x = x.into();
4796 if x.ncols() != beta.len() {
4797 return Err(EstimationError::InvalidInput(format!(
4798 "predict_gamwith_uncertainty dimension mismatch: X has {} columns but beta has length {}",
4799 x.ncols(),
4800 beta.len()
4801 )));
4802 }
4803 if x.nrows() != offset.len() {
4804 return Err(EstimationError::InvalidInput(format!(
4805 "predict_gamwith_uncertainty dimension mismatch: X has {} rows but offset has length {}",
4806 x.nrows(),
4807 offset.len()
4808 )));
4809 }
4810 if !(options.confidence_level.is_finite()
4811 && options.confidence_level > 0.0
4812 && options.confidence_level < 1.0)
4813 {
4814 return Err(EstimationError::InvalidInput(format!(
4815 "confidence_level must be in (0,1), got {}",
4816 options.confidence_level
4817 )));
4818 }
4819
4820 let requested_mode = options.covariance_mode;
4821 let (backend, covariance_corrected_used) = selected_uncertainty_backend(
4822 fit,
4823 beta.len(),
4824 requested_mode,
4825 "predict_gamwith_uncertainty",
4826 )?;
4827
4828 let mut eta = x.matrixvectormultiply(&beta.to_owned());
4829 eta += &offset;
4830 if options.apply_bias_correction
4831 && let Some(bc) = fit.bias_correction_beta()
4832 {
4833 if bc.len() == beta.len() {
4834 let delta = x.matrixvectormultiply(&bc.clone());
4835 eta += δ
4836 } else {
4837 log::warn!(
4838 "predict_gamwith_uncertainty: bias-correction dimension mismatch \
4839 (beta {}, bc {}); skipping bias correction",
4840 beta.len(),
4841 bc.len()
4842 );
4843 }
4844 }
4845 let fitted_link_state = fit.fitted_link_state(family).ok();
4846 let mixture_state = match fitted_link_state.as_ref() {
4847 Some(FittedLinkState::Mixture { state, .. }) => Some(state.clone()),
4848 _ => None,
4849 };
4850 let sas_state = match fitted_link_state.as_ref() {
4851 Some(FittedLinkState::Sas { state, .. })
4852 | Some(FittedLinkState::BetaLogistic { state, .. }) => Some(*state),
4853 _ => None,
4854 };
4855 let link_kind = match fitted_link_state.as_ref() {
4856 Some(FittedLinkState::Standard(Some(link))) => Some(InverseLink::Standard(*link)),
4857 Some(FittedLinkState::LatentCLogLog { state }) => Some(InverseLink::LatentCLogLog(*state)),
4858 Some(FittedLinkState::Sas { state, .. }) => Some(InverseLink::Sas(*state)),
4859 Some(FittedLinkState::BetaLogistic { state, .. }) => {
4860 Some(InverseLink::BetaLogistic(*state))
4861 }
4862 Some(FittedLinkState::Mixture { state, .. }) => Some(InverseLink::Mixture(state.clone())),
4863 Some(FittedLinkState::Standard(None)) | None => None,
4864 };
4865 let strategy = strategy_for_family(family, link_kind.as_ref());
4866 let mean = apply_family_inverse_link(&eta, family, link_kind.as_ref())?;
4867
4868 let etavar_raw = linear_predictorvariance_from_backend(&x, &backend)?;
4869 let n_rows = etavar_raw.len();
4870
4871 let mut variance_inflation = Array1::<f64>::ones(n_rows);
4876 if (options.boundary_correction || options.ood_inflation)
4877 && let (Some(predictor_x), Some(support)) = (
4878 options.predictor_x_for_corrections.as_ref(),
4879 options.training_support.as_ref(),
4880 )
4881 && predictor_x.nrows() == n_rows
4882 && predictor_x.ncols() == support.axis_min.len()
4883 && support.axis_min.len() == support.axis_max.len()
4884 {
4885 for i in 0..n_rows {
4886 let row = predictor_x.row(i);
4887 let mut factor = 1.0_f64;
4888 if options.boundary_correction {
4889 factor *= boundary_variance_inflation_factor(
4890 row,
4891 support.axis_min.view(),
4892 support.axis_max.view(),
4893 options.boundary_alpha,
4894 options.boundary_band_fraction,
4895 );
4896 }
4897 if options.ood_inflation {
4898 factor *= ood_variance_inflation_factor(
4899 row,
4900 support.axis_min.view(),
4901 support.axis_max.view(),
4902 options.ood_gamma,
4903 );
4904 }
4905 variance_inflation[i] = factor;
4906 }
4907 }
4908 let etavar = if variance_inflation.iter().all(|&f| f == 1.0) {
4909 etavar_raw.clone()
4910 } else {
4911 Array1::from_iter(
4912 etavar_raw
4913 .iter()
4914 .zip(variance_inflation.iter())
4915 .map(|(&v, &f)| v * f),
4916 )
4917 };
4918 let eta_standard_error = etavar.mapv(|v| v.max(0.0).sqrt());
4919
4920 let level = options.confidence_level;
4923 let z_central = if options.multi_point_joint {
4924 let m = options.joint_query_count.unwrap_or(n_rows).max(1);
4925 multi_point_joint_z(level, m).map_err(EstimationError::InvalidInput)?
4926 } else {
4927 standard_normal_quantile(0.5 + 0.5 * level).map_err(EstimationError::InvalidInput)?
4928 };
4929 let mut z_lower_per_row = Array1::<f64>::from_elem(n_rows, z_central);
4930 let mut z_upper_per_row = Array1::<f64>::from_elem(n_rows, z_central);
4931 if options.edgeworth_one_sided
4932 && let Some(skew) = options.eta_skewness_for_corrections.as_ref()
4933 && skew.len() == n_rows
4934 {
4935 for i in 0..n_rows {
4936 let adj = edgeworth_one_sided_quantile(z_central, skew[i]);
4937 z_lower_per_row[i] = adj.z_lower;
4938 z_upper_per_row[i] = adj.z_upper;
4939 }
4940 }
4941 let eta_lower = Array1::from_iter(
4942 eta.iter()
4943 .zip(eta_standard_error.iter())
4944 .zip(z_lower_per_row.iter())
4945 .map(|((&e, &s), &zl)| e - zl * s),
4946 );
4947 let eta_upper = Array1::from_iter(
4948 eta.iter()
4949 .zip(eta_standard_error.iter())
4950 .zip(z_upper_per_row.iter())
4951 .map(|((&e, &s), &zu)| e + zu * s),
4952 );
4953 let quadctx = crate::quadrature::QuadratureContext::new();
4954
4955 let mean_standard_error = Array1::from_vec(
4972 (0..eta.len())
4973 .into_par_iter()
4974 .map(|i| -> Result<f64, EstimationError> {
4975 let se_i = etavar[i].max(0.0).sqrt();
4976 let (_, mut meanvar) = strategy.posterior_meanvariance(&quadctx, eta[i], se_i)?;
4977 if matches!(family, crate::types::LikelihoodFamily::BinomialSas)
4978 && let Some(cov_theta) = fitted_link_state.as_ref().and_then(|s| match s {
4979 FittedLinkState::Sas { covariance, .. } => covariance.as_ref(),
4980 _ => None,
4981 })
4982 {
4983 let sas = sas_state.ok_or_else(|| {
4984 EstimationError::InvalidInput(
4985 "BinomialSas uncertainty requires fitted sas_epsilon/sas_log_delta"
4986 .to_string(),
4987 )
4988 })?;
4989 let jets =
4990 sas_inverse_link_jetwith_param_partials(eta[i], sas.epsilon, sas.log_delta);
4991 let g = [jets.djet_depsilon.mu, jets.djet_dlog_delta.mu];
4992 meanvar += quadratic_form(cov_theta, &g)?;
4993 }
4994 if matches!(family, crate::types::LikelihoodFamily::BinomialBetaLogistic)
4995 && let Some(cov_theta) = fitted_link_state.as_ref().and_then(|s| match s {
4996 FittedLinkState::BetaLogistic { covariance, .. } => covariance.as_ref(),
4997 _ => None,
4998 })
4999 {
5000 let sas = sas_state.ok_or_else(|| {
5001 EstimationError::InvalidInput(
5002 "BinomialBetaLogistic uncertainty requires fitted parameters"
5003 .to_string(),
5004 )
5005 })?;
5006 let jets = beta_logistic_inverse_link_jetwith_param_partials(
5007 eta[i],
5008 sas.log_delta,
5009 sas.epsilon,
5010 );
5011 let g = [jets.djet_depsilon.mu, jets.djet_dlog_delta.mu];
5012 meanvar += quadratic_form(cov_theta, &g)?;
5013 }
5014 if matches!(family, crate::types::LikelihoodFamily::BinomialMixture)
5015 && let Some(cov_theta) = fitted_link_state.as_ref().and_then(|s| match s {
5016 FittedLinkState::Mixture { covariance, .. } => covariance.as_ref(),
5017 _ => None,
5018 })
5019 && let Some(state) = mixture_state.as_ref()
5020 {
5021 let mut mix_partials = vec![
5022 InverseLinkJet {
5023 mu: 0.0,
5024 d1: 0.0,
5025 d2: 0.0,
5026 d3: 0.0,
5027 };
5028 state.rho.len()
5029 ];
5030 mixture_inverse_link_jetwith_rho_partials_into(
5031 state,
5032 eta[i],
5033 &mut mix_partials,
5034 );
5035 meanvar += quadratic_form_from_jetmu(cov_theta, &mix_partials)?;
5036 }
5037 Ok(meanvar.max(0.0).sqrt())
5038 })
5039 .collect::<Result<Vec<_>, _>>()?,
5040 );
5041
5042 let (mut mean_lower, mut mean_upper) = match options.mean_interval_method {
5043 MeanIntervalMethod::Delta => (
5044 Array1::from_iter(
5045 mean.iter()
5046 .zip(mean_standard_error.iter())
5047 .zip(z_lower_per_row.iter())
5048 .map(|((&m, &s), &zl)| m - zl * s),
5049 ),
5050 Array1::from_iter(
5051 mean.iter()
5052 .zip(mean_standard_error.iter())
5053 .zip(z_upper_per_row.iter())
5054 .map(|((&m, &s), &zu)| m + zu * s),
5055 ),
5056 ),
5057 MeanIntervalMethod::TransformEta => {
5058 let transformed_lower =
5059 apply_family_inverse_link(&eta_lower, family, link_kind.as_ref())?;
5060 let transformed_upper =
5061 apply_family_inverse_link(&eta_upper, family, link_kind.as_ref())?;
5062 (
5063 Array1::from_iter(
5064 transformed_lower
5065 .iter()
5066 .zip(transformed_upper.iter())
5067 .map(|(&lo, &hi)| lo.min(hi)),
5068 ),
5069 Array1::from_iter(
5070 transformed_lower
5071 .iter()
5072 .zip(transformed_upper.iter())
5073 .map(|(&lo, &hi)| lo.max(hi)),
5074 ),
5075 )
5076 }
5077 };
5078
5079 if matches!(
5080 family,
5081 crate::types::LikelihoodFamily::BinomialLogit
5082 | crate::types::LikelihoodFamily::BinomialProbit
5083 | crate::types::LikelihoodFamily::BinomialCLogLog
5084 | crate::types::LikelihoodFamily::BinomialSas
5085 | crate::types::LikelihoodFamily::BinomialBetaLogistic
5086 | crate::types::LikelihoodFamily::BinomialMixture
5087 | crate::types::LikelihoodFamily::RoystonParmar
5088 ) {
5089 mean_lower.mapv_inplace(|v| v.clamp(0.0, 1.0));
5090 mean_upper.mapv_inplace(|v| v.clamp(0.0, 1.0));
5091 }
5092
5093 let (observation_lower, observation_upper) = if options.includeobservation_interval
5094 && matches!(family, crate::types::LikelihoodFamily::GaussianIdentity)
5095 {
5096 let obsvar = fit.standard_deviation.max(0.0).powi(2);
5097 let obs_se = etavar.mapv(|v| (v + obsvar).max(0.0).sqrt());
5098 let lower = Array1::from_iter(
5099 eta.iter()
5100 .zip(obs_se.iter())
5101 .zip(z_lower_per_row.iter())
5102 .map(|((&e, &s), &zl)| e - zl * s),
5103 );
5104 let upper = Array1::from_iter(
5105 eta.iter()
5106 .zip(obs_se.iter())
5107 .zip(z_upper_per_row.iter())
5108 .map(|((&e, &s), &zu)| e + zu * s),
5109 );
5110 (Some(lower), Some(upper))
5111 } else {
5112 (None, None)
5113 };
5114
5115 Ok(PredictUncertaintyResult {
5116 eta,
5117 mean,
5118 eta_standard_error,
5119 mean_standard_error,
5120 eta_lower,
5121 eta_upper,
5122 mean_lower,
5123 mean_upper,
5124 observation_lower,
5125 observation_upper,
5126 covariance_mode_requested: requested_mode,
5127 covariance_corrected_used,
5128 })
5129}
5130
/// Coefficient estimates with standard errors and symmetric normal-theory
/// confidence intervals `beta ± z · se`.
///
/// Delegates directly to `coefficient_uncertaintywith_mode`. That function
/// documents how `covariance_mode` selects the standard errors and which
/// errors are returned.
pub fn coefficient_uncertainty(
    fit: &UnifiedFitResult,
    confidence_level: f64,
    covariance_mode: InferenceCovarianceMode,
) -> Result<CoefficientUncertaintyResult, EstimationError> {
    coefficient_uncertaintywith_mode(fit, confidence_level, covariance_mode)
}
5139
5140pub fn coefficient_uncertaintywith_mode(
5142 fit: &UnifiedFitResult,
5143 confidence_level: f64,
5144 covariance_mode: InferenceCovarianceMode,
5145) -> Result<CoefficientUncertaintyResult, EstimationError> {
5146 if !(confidence_level.is_finite() && confidence_level > 0.0 && confidence_level < 1.0) {
5147 return Err(EstimationError::InvalidInput(format!(
5148 "confidence_level must be in (0,1), got {}",
5149 confidence_level
5150 )));
5151 }
5152 let (se, corrected) = match covariance_mode {
5156 InferenceCovarianceMode::Conditional => (
5157 fit.beta_standard_errors().cloned().ok_or_else(|| {
5158 EstimationError::InvalidInput(
5159 "fit result does not contain conditional coefficient standard errors"
5160 .to_string(),
5161 )
5162 })?,
5163 false,
5164 ),
5165 InferenceCovarianceMode::ConditionalPlusSmoothingPreferred => {
5166 if let Some(se_corr) = fit.beta_standard_errors_corrected() {
5167 (se_corr.clone(), true)
5168 } else if let Some(se_base) = fit.beta_standard_errors() {
5169 (se_base.clone(), false)
5170 } else {
5171 return Err(EstimationError::InvalidInput(
5172 "fit result does not contain coefficient standard errors".to_string(),
5173 ));
5174 }
5175 }
5176 InferenceCovarianceMode::ConditionalPlusSmoothingRequired => (
5177 fit.beta_standard_errors_corrected()
5178 .cloned()
5179 .ok_or_else(|| {
5180 EstimationError::InvalidInput(
5181 "fit result does not contain smoothing-corrected coefficient standard errors"
5182 .to_string(),
5183 )
5184 })?,
5185 true,
5186 ),
5187 };
5188
5189 if se.len() != fit.beta.len() {
5190 return Err(EstimationError::InvalidInput(format!(
5191 "standard error length mismatch: beta has {}, se has {}",
5192 fit.beta.len(),
5193 se.len()
5194 )));
5195 }
5196
5197 let z = standard_normal_quantile(0.5 + 0.5 * confidence_level)
5198 .map_err(EstimationError::InvalidInput)?;
5199 let lower = &fit.beta - &se.mapv(|s| z * s);
5200 let upper = &fit.beta + &se.mapv(|s| z * s);
5201 Ok(CoefficientUncertaintyResult {
5202 estimate: fit.beta.clone(),
5203 standard_error: se,
5204 lower,
5205 upper,
5206 corrected,
5207 covariance_mode_requested: covariance_mode,
5208 })
5209}
5210
5211#[cfg(test)]
5212mod tests {
5213 use super::*;
5214 use crate::estimate::{
5215 BlockRole, FitArtifacts, FittedBlock, FittedLinkState, UnifiedFitResult,
5216 UnifiedFitResultParts,
5217 };
5218 use crate::inference::model::SavedAnchoredDeviationRuntime;
5219 use crate::pirls::PirlsStatus;
5220 use crate::types::LinkFunction;
5221 use ndarray::{Array1, Array2, array};
5222
5223 fn saved_runtime_from_deviation_runtime(
5224 runtime: &crate::families::bernoulli_marginal_slope::DeviationRuntime,
5225 ) -> SavedAnchoredDeviationRuntime {
5226 SavedAnchoredDeviationRuntime {
5227 kernel:
5228 crate::families::bernoulli_marginal_slope::exact_kernel::ANCHORED_DEVIATION_KERNEL
5229 .to_string(),
5230 breakpoints: runtime.breakpoints().to_vec(),
5231 basis_dim: runtime.basis_dim(),
5232 span_c0: runtime
5233 .span_c0()
5234 .outer_iter()
5235 .map(|row| row.to_vec())
5236 .collect(),
5237 span_c1: runtime
5238 .span_c1()
5239 .outer_iter()
5240 .map(|row| row.to_vec())
5241 .collect(),
5242 span_c2: runtime
5243 .span_c2()
5244 .outer_iter()
5245 .map(|row| row.to_vec())
5246 .collect(),
5247 span_c3: runtime
5248 .span_c3()
5249 .outer_iter()
5250 .map(|row| row.to_vec())
5251 .collect(),
5252 anchor_residual_coefficients: None,
5253 anchor_residual_components: Vec::new(),
5254 anchor_residual_rotation: None,
5255 }
5256 }
5257
5258 fn test_fit_with_covariance(beta: Array1<f64>, covariance: Array2<f64>) -> UnifiedFitResult {
5259 UnifiedFitResult::try_from_parts(UnifiedFitResultParts {
5260 blocks: vec![FittedBlock {
5261 beta: beta.clone(),
5262 role: BlockRole::Mean,
5263 edf: 0.0,
5264 lambdas: Array1::zeros(0),
5265 }],
5266 log_lambdas: Array1::zeros(0),
5267 lambdas: Array1::zeros(0),
5268 likelihood_family: Some(crate::types::LikelihoodFamily::GaussianIdentity),
5269 likelihood_scale: crate::types::LikelihoodScaleMetadata::ProfiledGaussian,
5270 log_likelihood_normalization: crate::types::LogLikelihoodNormalization::Full,
5271 log_likelihood: 0.0,
5272 deviance: 0.0,
5273 reml_score: 0.0,
5274 stable_penalty_term: 0.0,
5275 penalized_objective: 0.0,
5276 outer_iterations: 0,
5277 outer_converged: true,
5278 outer_gradient_norm: 0.0,
5279 standard_deviation: 1.0,
5280 covariance_conditional: Some(covariance),
5281 covariance_corrected: None,
5282 inference: None,
5283 fitted_link: FittedLinkState::Standard(None),
5284 geometry: None,
5285 block_states: Vec::new(),
5286 pirls_status: PirlsStatus::Converged,
5287 max_abs_eta: 0.0,
5288 constraint_kkt: None,
5289 artifacts: FitArtifacts {
5290 pirls: None,
5291 ..Default::default()
5292 },
5293 inner_cycles: 0,
5294 })
5295 .expect("test fit")
5296 }
5297
5298 fn gaussian_location_scale_fit_with_covariance(
5299 beta_mu: Array1<f64>,
5300 beta_noise: Array1<f64>,
5301 covariance: Array2<f64>,
5302 ) -> UnifiedFitResult {
5303 UnifiedFitResult::try_from_parts(UnifiedFitResultParts {
5304 blocks: vec![
5305 FittedBlock {
5306 beta: beta_mu,
5307 role: BlockRole::Location,
5308 edf: 0.0,
5309 lambdas: Array1::zeros(0),
5310 },
5311 FittedBlock {
5312 beta: beta_noise,
5313 role: BlockRole::Scale,
5314 edf: 0.0,
5315 lambdas: Array1::zeros(0),
5316 },
5317 ],
5318 log_lambdas: Array1::zeros(0),
5319 lambdas: Array1::zeros(0),
5320 likelihood_family: Some(crate::types::LikelihoodFamily::GaussianIdentity),
5321 likelihood_scale: crate::types::LikelihoodScaleMetadata::ProfiledGaussian,
5322 log_likelihood_normalization: crate::types::LogLikelihoodNormalization::Full,
5323 log_likelihood: 0.0,
5324 deviance: 0.0,
5325 reml_score: 0.0,
5326 stable_penalty_term: 0.0,
5327 penalized_objective: 0.0,
5328 outer_iterations: 0,
5329 outer_converged: true,
5330 outer_gradient_norm: 0.0,
5331 standard_deviation: 1.0,
5332 covariance_conditional: Some(covariance),
5333 covariance_corrected: None,
5334 inference: None,
5335 fitted_link: FittedLinkState::Standard(None),
5336 geometry: None,
5337 block_states: Vec::new(),
5338 pirls_status: PirlsStatus::Converged,
5339 max_abs_eta: 0.0,
5340 constraint_kkt: None,
5341 artifacts: FitArtifacts {
5342 pirls: None,
5343 ..Default::default()
5344 },
5345 inner_cycles: 0,
5346 })
5347 .expect("gaussian location-scale fit")
5348 }
5349
5350 fn survival_fit_with_covariance(
5351 beta_threshold: Array1<f64>,
5352 beta_log_sigma: Array1<f64>,
5353 covariance: Array2<f64>,
5354 ) -> UnifiedFitResult {
5355 UnifiedFitResult::try_from_parts(UnifiedFitResultParts {
5356 blocks: vec![
5357 FittedBlock {
5358 beta: beta_threshold,
5359 role: BlockRole::Threshold,
5360 edf: 0.0,
5361 lambdas: Array1::zeros(0),
5362 },
5363 FittedBlock {
5364 beta: beta_log_sigma,
5365 role: BlockRole::Scale,
5366 edf: 0.0,
5367 lambdas: Array1::zeros(0),
5368 },
5369 ],
5370 log_lambdas: Array1::zeros(0),
5371 lambdas: Array1::zeros(0),
5372 likelihood_family: Some(crate::types::LikelihoodFamily::RoystonParmar),
5373 likelihood_scale: crate::types::LikelihoodScaleMetadata::FixedDispersion { phi: 1.0 },
5374 log_likelihood_normalization: crate::types::LogLikelihoodNormalization::Full,
5375 log_likelihood: 0.0,
5376 deviance: 0.0,
5377 reml_score: 0.0,
5378 stable_penalty_term: 0.0,
5379 penalized_objective: 0.0,
5380 outer_iterations: 0,
5381 outer_converged: true,
5382 outer_gradient_norm: 0.0,
5383 standard_deviation: 1.0,
5384 covariance_conditional: Some(covariance),
5385 covariance_corrected: None,
5386 inference: None,
5387 fitted_link: FittedLinkState::Standard(None),
5388 geometry: None,
5389 block_states: Vec::new(),
5390 pirls_status: PirlsStatus::Converged,
5391 max_abs_eta: 0.0,
5392 constraint_kkt: None,
5393 artifacts: FitArtifacts {
5394 pirls: None,
5395 ..Default::default()
5396 },
5397 inner_cycles: 0,
5398 })
5399 .expect("survival fit")
5400 }
5401
5402 #[test]
5403 fn predict_posterior_mean_probit_matches_closed_form_reference() {
5404 let x = array![[1.0], [1.0]];
5405 let beta = array![0.7];
5406 let offset = array![0.0, 0.0];
5407 let covariance = Array2::from_diag(&array![0.25]);
5408 let out = predict_gam_posterior_mean(
5409 x,
5410 beta.view(),
5411 offset.view(),
5412 crate::types::LikelihoodFamily::BinomialProbit,
5413 covariance.view(),
5414 )
5415 .expect("predict posterior mean");
5416 let expected = crate::quadrature::probit_posterior_meanwith_deriv_exact(0.7, 0.5).mean;
5417 assert!((out.mean[0] - expected).abs() <= 1e-12);
5418 assert!((out.mean[1] - expected).abs() <= 1e-12);
5419 }
5420
5421 #[test]
5422 fn predict_posterior_mean_logit_uses_integrated_dispatch() {
5423 let x = array![[1.0], [1.0]];
5424 let beta = array![0.4];
5425 let offset = array![0.0, 0.0];
5426 let covariance = Array2::from_diag(&array![0.16]);
5427 let out = predict_gam_posterior_mean(
5428 x,
5429 beta.view(),
5430 offset.view(),
5431 crate::types::LikelihoodFamily::BinomialLogit,
5432 covariance.view(),
5433 )
5434 .expect("predict posterior mean");
5435 let quadctx = crate::quadrature::QuadratureContext::new();
5436 let expected = crate::quadrature::integrated_inverse_link_mean_and_derivative(
5437 &quadctx,
5438 LinkFunction::Logit,
5439 0.4,
5440 0.4,
5441 )
5442 .expect("logit integrated inverse-link moments should evaluate")
5443 .mean;
5444 assert!((out.mean[0] - expected).abs() <= 1e-12);
5445 assert!((out.mean[1] - expected).abs() <= 1e-12);
5446 }
5447
#[test]
fn bernoulli_marginal_slope_predictor_rejects_structurally_invalid_or_unknown_runtime_kernel() {
    use crate::families::bernoulli_marginal_slope::{
        DeviationBlockConfig, build_score_warp_deviation_block_from_seed,
    };

    let seed = array![-1.5, -0.2, 0.6, 1.4];

    // A well-formed cubic runtime is the baseline every negative case perturbs.
    let cubic_config = DeviationBlockConfig {
        degree: 3,
        num_internal_knots: 3,
        ..Default::default()
    };
    let prepared = build_score_warp_deviation_block_from_seed(&seed, &cubic_config)
        .expect("production score-warp runtime");
    let production_runtime = saved_runtime_from_deviation_runtime(&prepared.runtime);

    // Case 1: a saved score-warp runtime naming an unrecognised kernel.
    let score_only = BernoulliMarginalSlopePredictor {
        beta_marginal: array![0.8],
        beta_logslope: array![1.6],
        beta_score_warp: Some(array![0.7, -0.4]),
        beta_link_dev: None,
        base_link: InverseLink::Standard(crate::types::LinkFunction::Probit),
        z_column: "z".to_string(),
        latent_z_normalization: SavedLatentZNormalization { mean: 0.0, sd: 1.0 },
        latent_measure: LatentMeasureKind::StandardNormal,
        baseline_marginal: 0.0,
        baseline_logslope: 0.0,
        covariance: None,
        score_warp_runtime: Some(SavedAnchoredDeviationRuntime {
            kernel: "OldQuadrature".to_string(),
            ..production_runtime.clone()
        }),
        link_deviation_runtime: None,
        gaussian_frailty_sd: None,
        latent_z_calibration: None,
    };
    let unknown_kernel = score_only.score_warp_runtime.as_ref().unwrap();
    let kernel_err = unknown_kernel.design(&array![0.0]).unwrap_err();
    assert!(kernel_err.contains("DenestedCubicTransport"));

    // Case 2: building a non-cubic deviation block is refused up front.
    let quadratic_config = DeviationBlockConfig {
        degree: 2,
        num_internal_knots: 3,
        ..Default::default()
    };
    let degree_err = build_score_warp_deviation_block_from_seed(&seed, &quadratic_config)
        .expect_err("non-cubic deviation runtimes should be rejected");
    assert!(degree_err.contains("degree must be 3"));

    // Case 3: a span coefficient row of the wrong width.
    let mut truncated = production_runtime.clone();
    truncated.span_c0[0].pop();
    let width_err = truncated.design(&array![0.0]).unwrap_err();
    assert!(width_err.contains("c0 row 0 has width"));

    // The untouched production runtime still evaluates.
    assert!(production_runtime.design(&array![0.0]).is_ok());
}
5511
#[test]
fn saved_anchored_deviation_runtime_local_cubic_reconstructs_values() {
    let seed = array![-2.0, -0.75, 0.0, 1.0, 3.0];
    let config = crate::families::bernoulli_marginal_slope::DeviationBlockConfig {
        num_internal_knots: 4,
        ..Default::default()
    };
    let prepared =
        crate::families::bernoulli_marginal_slope::build_score_warp_deviation_block_from_seed(
            &seed, &config,
        )
        .expect("build saved anchored deviation runtime");
    let runtime = saved_runtime_from_deviation_runtime(&prepared.runtime);

    // Alternating-sign coefficients of growing magnitude touch every basis column.
    let beta = Array1::from_iter((0..runtime.basis_dim).map(|k| {
        let magnitude = 0.02 * (k as f64 + 1.0);
        magnitude * (-1.0_f64).powi(k as i32)
    }));

    let span_total = runtime.span_count().expect("span count");
    assert!(span_total >= 2);
    for span_idx in 0..span_total {
        let piece = runtime
            .local_cubic_on_span(&beta, span_idx)
            .expect("local cubic");
        // Probe both endpoints and the midpoint of the span.
        let probes = array![piece.left, 0.5 * (piece.left + piece.right), piece.right];
        let values = runtime.design(&probes).expect("design").dot(&beta);
        let slopes = runtime
            .first_derivative_design(&probes)
            .expect("d1 design")
            .dot(&beta);
        for (k, &x) in probes.iter().enumerate() {
            assert!((piece.evaluate(x) - values[k]).abs() < 1e-10);
            assert!((piece.first_derivative(x) - slopes[k]).abs() < 1e-10);

            let selected = runtime.local_cubic_at(&beta, x).expect("local cubic at x");
            // The left endpoint of every span after the first is shared with the
            // previous span; the test expects lookup to resolve it to that earlier
            // span, and to the probed span everywhere else.
            let owner = match (k, span_idx) {
                (0, s) if s > 0 => s - 1,
                (_, s) => s,
            };
            let owner_cubic = runtime
                .local_cubic_on_span(&beta, owner)
                .expect("expected local cubic on span");
            assert_eq!(selected.left, owner_cubic.left);
            assert_eq!(selected.right, owner_cubic.right);
        }
    }
}
5559
#[test]
fn saved_anchored_deviation_runtime_design_with_anchor_rows_applies_residual() {
    use crate::families::bernoulli_marginal_slope::deviation_runtime::ParametricAnchorBlock;
    use crate::inference::model::{SavedAnchorComponent, SavedAnchorKind};

    let seed = array![-2.0, -0.75, 0.0, 1.0, 3.0];
    let prepared =
        crate::families::bernoulli_marginal_slope::build_score_warp_deviation_block_from_seed(
            &seed,
            &crate::families::bernoulli_marginal_slope::DeviationBlockConfig {
                num_internal_knots: 4,
                ..Default::default()
            },
        )
        .expect("build saved anchored deviation runtime");
    let mut runtime = saved_runtime_from_deviation_runtime(&prepared.runtime);

    // Attach a dense `d x basis_dim` residual operator M for a single parametric
    // anchor component with `d` columns and no rotation. The assertions below
    // expect the corrected design to equal `raw - anchor_rows * M`.
    let d = 3usize;
    let basis_dim = runtime.basis_dim;
    let m_dense = Array2::from_shape_fn((d, basis_dim), |(i, j)| {
        0.1 * (i as f64 + 1.0) - 0.05 * (j as f64 + 1.0)
    });
    // The runtime stores M row-major as nested vectors.
    runtime.anchor_residual_coefficients =
        Some(m_dense.outer_iter().map(|row| row.to_vec()).collect());
    runtime.anchor_residual_components = vec![SavedAnchorComponent {
        kind: SavedAnchorKind::Parametric {
            block: ParametricAnchorBlock::Marginal,
            ncols: d,
        },
    }];
    runtime.anchor_residual_rotation = None;

    let values = array![-1.0, 0.0, 0.5, 2.0];
    let n = values.len();
    let anchor_rows = Array2::from_shape_fn((n, d), |(i, j)| {
        0.3 * (i as f64 + 1.0) - 0.1 * (j as f64 + 1.0)
    });

    let raw = runtime
        .design_uncorrected(&values)
        .expect("uncorrected design");
    let corrected = runtime
        .design_with_anchor_rows(&values, anchor_rows.view())
        .expect("design with anchor rows");
    // Check shapes first, so that extra rows or columns cannot slip past the
    // element-wise loops.
    assert_eq!(raw.dim(), (n, basis_dim), "uncorrected design shape");
    assert_eq!(corrected.dim(), (n, basis_dim), "corrected design shape");

    let expected = &raw - &anchor_rows.dot(&m_dense);
    for i in 0..n {
        for j in 0..basis_dim {
            assert!(
                (corrected[[i, j]] - expected[[i, j]]).abs() < 1e-12,
                "residual-corrected design mismatch at ({i}, {j}): \
                 got {got}, expected {exp}",
                got = corrected[[i, j]],
                exp = expected[[i, j]],
            );
        }
    }

    // The standalone correction matrix must agree with the corrected design.
    let correction = runtime
        .anchor_correction_matrix(anchor_rows.view())
        .expect("anchor correction matrix")
        .expect("Some correction when residual is present");
    assert_eq!(correction.dim(), (n, basis_dim), "anchor correction shape");
    for i in 0..n {
        for j in 0..basis_dim {
            let residual = raw[[i, j]] - correction[[i, j]] - corrected[[i, j]];
            assert!(
                residual.abs() < 1e-12,
                "raw - correction != corrected at ({i}, {j}): residual {residual}"
            );
        }
    }
}
5642
#[test]
fn bernoulli_marginal_slope_rigid_gaussian_frailty_uses_scaled_closed_form() {
    let predictor = BernoulliMarginalSlopePredictor {
        beta_marginal: array![0.7],
        beta_logslope: array![-0.4],
        beta_score_warp: None,
        beta_link_dev: None,
        base_link: InverseLink::Standard(crate::types::LinkFunction::Probit),
        z_column: "z".to_string(),
        latent_z_normalization: SavedLatentZNormalization { mean: 0.0, sd: 1.0 },
        latent_measure: LatentMeasureKind::StandardNormal,
        baseline_marginal: 0.1,
        baseline_logslope: -0.2,
        covariance: None,
        score_warp_runtime: None,
        link_deviation_runtime: None,
        gaussian_frailty_sd: Some(0.8),
        latent_z_calibration: None,
    };
    let theta = predictor.theta();
    let input = PredictInput {
        design: DesignMatrix::from(array![[1.0], [1.0]]),
        offset: array![0.0, 0.05],
        design_noise: Some(DesignMatrix::from(array![[1.0], [1.0]])),
        offset_noise: Some(array![0.0, -0.1]),
        auxiliary_scalar: Some(array![-0.3, 1.2]),
        auxiliary_matrix: None,
    };

    let (eta, grad) = predictor
        .final_eta_and_gradient_from_theta(&input, &theta, true)
        .expect("rigid frailty path should evaluate");
    // Fail loudly instead of letting the per-row loop pass vacuously.
    assert_eq!(eta.len(), 2, "one eta per input row");
    // The gradient is loop-invariant, so unwrap it once.
    let grad = grad.as_ref().expect("gradient should be returned");
    assert_eq!(grad.nrows(), eta.len(), "one gradient row per eta");

    let scale = predictor.probit_frailty_scale();
    // Expected linear predictors: baseline + design * beta + offset, i.e.
    // 0.1 + 0.7 + [0.0, 0.05] and -0.2 - 0.4 + [0.0, -0.1].
    let marginal_eta = array![0.8, 0.85];
    let logslope_eta = array![-0.6, -0.7];
    let z = array![-0.3, 1.2];
    for i in 0..eta.len() {
        // Closed form: eta = q * sqrt(1 + (s b)^2) + s b z, with s the frailty scale.
        let sb = scale * logslope_eta[i];
        let c = (1.0 + sb * sb).sqrt();
        let expected_eta = marginal_eta[i] * c + sb * z[i];
        assert!((eta[i] - expected_eta).abs() <= 1e-12);
        let expected_d_marginal = c;
        let expected_d_logslope =
            marginal_eta[i] * scale * scale * logslope_eta[i] / c + scale * z[i];
        assert!((grad[[i, 0]] - expected_d_marginal).abs() <= 1e-12);
        assert!((grad[[i, 1]] - expected_d_logslope).abs() <= 1e-12);
    }
}
5693
#[test]
fn bernoulli_marginal_slope_predictor_uses_local_empirical_latent_law() {
    // One empirical latent-z grid per feature-space center (-1 and +1).
    let grids = vec![
        EmpiricalZGrid {
            nodes: vec![-1.2, -0.2, 0.7],
            weights: vec![0.45, 0.35, 0.20],
        },
        EmpiricalZGrid {
            nodes: vec![-0.4, 0.6, 2.4],
            weights: vec![0.20, 0.35, 0.45],
        },
    ];
    let local_measure = LatentMeasureKind::LocalEmpirical {
        feature_cols: vec![0],
        input_scales: None,
        centers: vec![vec![-1.0], vec![1.0]],
        grids: grids.clone(),
        top_k: 1,
        bandwidth: 0.25,
        train_row_mixtures: std::sync::Arc::new(Vec::new()),
    };
    let predictor = BernoulliMarginalSlopePredictor {
        beta_marginal: array![0.2],
        beta_logslope: array![0.9],
        beta_score_warp: None,
        beta_link_dev: None,
        base_link: InverseLink::Standard(crate::types::LinkFunction::Probit),
        z_column: "z".to_string(),
        latent_z_normalization: SavedLatentZNormalization { mean: 0.0, sd: 1.0 },
        latent_measure: local_measure,
        baseline_marginal: 0.0,
        baseline_logslope: 0.0,
        covariance: None,
        score_warp_runtime: None,
        link_deviation_runtime: None,
        gaussian_frailty_sd: None,
        latent_z_calibration: None,
    };
    // Row i's auxiliary feature sits on center i; the test expects row i to be
    // integrated against grids[i].
    let input = PredictInput {
        design: DesignMatrix::from(array![[1.0], [1.0]]),
        offset: array![0.0, 0.0],
        design_noise: Some(DesignMatrix::from(array![[1.0], [1.0]])),
        offset_noise: Some(array![0.0, 0.0]),
        auxiliary_scalar: Some(array![0.0, 0.0]),
        auxiliary_matrix: Some(array![[-1.0], [1.0]]),
    };

    let theta = predictor.theta();
    let (eta, _) = predictor
        .final_eta_and_gradient_from_theta(&input, &theta, true)
        .expect("local empirical prediction");
    let (chain_eta, deta_dq) = predictor
        .predict_eta_and_q_chain(&input)
        .expect("local empirical q chain");

    let target_probability = normal_cdf(0.2);
    for (row, grid) in grids.iter().enumerate() {
        let expected_intercept = empirical_intercept_from_marginal(
            target_probability,
            0.2,
            0.9,
            1.0,
            &grid.nodes,
            &grid.weights,
            None,
        )
        .expect("expected empirical intercept");
        assert!((eta[row] - expected_intercept).abs() <= 1e-10);
        assert!((chain_eta[row] - eta[row]).abs() <= 1e-12);
        let slope = deta_dq[row];
        assert!(slope.is_finite() && slope > 0.0);
    }
}
5763
#[test]
fn bernoulli_marginal_slope_predictor_rejects_nonprobit_base_link_scale() {
    let input = PredictInput {
        design: DesignMatrix::from(array![[1.0], [1.0]]),
        offset: array![0.0, 0.05],
        design_noise: Some(DesignMatrix::from(array![[1.0], [1.0]])),
        offset_noise: Some(array![0.0, -0.1]),
        auxiliary_scalar: Some(array![-0.3, 1.2]),
        auxiliary_matrix: None,
    };
    // A rigid Gaussian-frailty predictor configured with a logit base link.
    let predictor = BernoulliMarginalSlopePredictor {
        beta_marginal: array![0.7],
        beta_logslope: array![-0.4],
        beta_score_warp: None,
        beta_link_dev: None,
        base_link: InverseLink::Standard(crate::types::LinkFunction::Logit),
        z_column: "z".to_string(),
        latent_z_normalization: SavedLatentZNormalization { mean: 0.0, sd: 1.0 },
        latent_measure: LatentMeasureKind::StandardNormal,
        baseline_marginal: 0.1,
        baseline_logslope: -0.2,
        covariance: None,
        score_warp_runtime: None,
        link_deviation_runtime: None,
        gaussian_frailty_sd: Some(0.8),
        latent_z_calibration: None,
    };

    let rejection = predictor
        .final_eta_and_gradient_from_theta(&input, &predictor.theta(), true)
        .expect_err("non-probit marginal-slope prediction should be rejected");
    let message = rejection.to_string();
    assert!(message.contains("requires link(type=probit)"));
}
5798
#[test]
fn saved_anchored_deviation_runtime_basis_cubic_matches_basis_column() {
    let seed = array![-2.0, -0.75, 0.0, 1.0, 3.0];
    let config = crate::families::bernoulli_marginal_slope::DeviationBlockConfig {
        num_internal_knots: 4,
        ..Default::default()
    };
    let prepared =
        crate::families::bernoulli_marginal_slope::build_score_warp_deviation_block_from_seed(
            &seed, &config,
        )
        .expect("build saved anchored deviation runtime");
    let runtime = saved_runtime_from_deviation_runtime(&prepared.runtime);

    // Basis column 1 restricted to span 0, probed at both ends and the midpoint.
    let cubic = runtime.basis_span_cubic(0, 1).expect("basis span cubic");
    let probes = array![cubic.left, 0.5 * (cubic.left + cubic.right), cubic.right];
    let design = runtime.design(&probes).expect("basis design");
    let d1 = runtime
        .first_derivative_design(&probes)
        .expect("basis d1 design");
    for (k, &x) in probes.iter().enumerate() {
        assert!((cubic.evaluate(x) - design[[k, 1]]).abs() < 1e-10);
        assert!((cubic.first_derivative(x) - d1[[k, 1]]).abs() < 1e-10);
        // Every probe, including the right knot, is expected to resolve to span 0.
        let selected = runtime.basis_cubic_at(1, x).expect("basis cubic at x");
        let span_zero = runtime
            .basis_span_cubic(0, 1)
            .expect("expected basis span cubic");
        assert_eq!(selected.left, span_zero.left);
        assert_eq!(selected.right, span_zero.right);
    }
}
5831
#[test]
fn predict_royston_parmar_point_prediction_returns_survival_probability() {
    let x = array![[1.0], [1.0]];
    let beta = array![0.4];
    let offset = array![0.0, 0.8];
    let out = predict_gam(
        x,
        beta.view(),
        offset.view(),
        crate::types::LikelihoodFamily::RoystonParmar,
    )
    .expect("royston-parmar point prediction");
    // eta = x * beta + offset; the expected response is exp(-exp(eta)).
    let expected_eta = array![0.4, 1.2];
    let expected_mean = expected_eta.mapv(|eta: f64| (-(eta.exp())).exp().clamp(0.0, 1.0));
    // Length checks keep the per-row loops below from passing vacuously on short output.
    assert_eq!(out.eta.len(), expected_eta.len(), "eta length");
    assert_eq!(out.mean.len(), expected_mean.len(), "mean length");
    for i in 0..expected_eta.len() {
        assert!(
            (out.eta[i] - expected_eta[i]).abs() <= 1e-14,
            "eta[{i}] mismatch"
        );
    }
    for i in 0..expected_mean.len() {
        assert!(
            (out.mean[i] - expected_mean[i]).abs() <= 1e-12,
            "mean[{i}] mismatch"
        );
    }
}
5857
#[test]
fn predict_royston_parmar_posterior_mean_matches_quadrature_and_fit_path() {
    let x = array![[1.0], [1.0]];
    let beta = array![0.35];
    let offset = array![0.0, 0.0];
    let covariance = Array2::from_diag(&array![0.09]);
    let fit = test_fit_with_covariance(beta.clone(), covariance.clone());

    let out = predict_gam_posterior_mean(
        x.clone(),
        beta.view(),
        offset.view(),
        crate::types::LikelihoodFamily::RoystonParmar,
        covariance.view(),
    )
    .expect("royston-parmar posterior mean");
    let out_with_fit = predict_gam_posterior_meanwith_fit(
        x,
        beta.view(),
        offset.view(),
        crate::types::LikelihoodFamily::RoystonParmar,
        covariance.view(),
        &fit,
    )
    .expect("royston-parmar posterior mean with fit");

    // Reference: eta = 0.35 with standard deviation sqrt(0.09) = 0.3.
    let quadctx = crate::quadrature::QuadratureContext::new();
    let expected = crate::quadrature::survival_posterior_mean(&quadctx, 0.35, 0.3);
    // Both paths must return one value per design row; otherwise the loop
    // below would silently skip rows.
    assert_eq!(out.mean.len(), 2, "posterior mean length");
    assert_eq!(out_with_fit.mean.len(), out.mean.len(), "fit-path mean length");
    assert_eq!(out.eta_standard_error.len(), out.mean.len(), "eta SE length");
    assert_eq!(
        out_with_fit.eta_standard_error.len(),
        out.mean.len(),
        "fit-path eta SE length"
    );
    for i in 0..out.mean.len() {
        assert!((out.mean[i] - expected).abs() <= 1e-12);
        assert!((out_with_fit.mean[i] - expected).abs() <= 1e-12);
        assert!((out_with_fit.mean[i] - out.mean[i]).abs() <= 1e-12);
        assert!(
            (out_with_fit.eta_standard_error[i] - out.eta_standard_error[i]).abs() <= 1e-12
        );
    }
}
5895
#[test]
fn predict_royston_parmar_uncertainty_clamps_and_orders_intervals() {
    let eta_hat: f64 = 0.6;
    let eta_sd: f64 = 0.5;
    let beta = array![eta_hat];
    let offset = array![0.0];
    let fit = test_fit_with_covariance(beta.clone(), Array2::from_diag(&array![eta_sd * eta_sd]));
    let options = PredictUncertaintyOptions {
        confidence_level: 0.95,
        covariance_mode: InferenceCovarianceMode::Conditional,
        mean_interval_method: MeanIntervalMethod::TransformEta,
        includeobservation_interval: false,
        apply_bias_correction: false,
        edgeworth_one_sided: false,
        boundary_correction: false,
        ood_inflation: false,
        multi_point_joint: false,
        ..PredictUncertaintyOptions::default()
    };

    let out = predict_gamwith_uncertainty(
        array![[1.0]],
        beta.view(),
        offset.view(),
        crate::types::LikelihoodFamily::RoystonParmar,
        &fit,
        &options,
    )
    .expect("royston-parmar uncertainty");

    let ctx = crate::quadrature::QuadratureContext::new();
    let (_, variance) = crate::quadrature::survival_posterior_meanvariance(&ctx, eta_hat, eta_sd);
    // The point mean is the plug-in survival exp(-exp(eta_hat)).
    let plugin_survival = (-(eta_hat.exp())).exp();
    assert!((out.mean[0] - plugin_survival).abs() <= 1e-12);
    assert!((out.eta_standard_error[0] - eta_sd).abs() <= 1e-12);
    assert!((out.mean_standard_error[0] - variance.sqrt()).abs() <= 1e-12);

    // The interval must be ordered and lie inside the unit interval.
    let (lower, upper) = (out.mean_lower[0], out.mean_upper[0]);
    assert!(lower <= upper);
    let unit = 0.0..=1.0;
    assert!(unit.contains(&lower));
    assert!(unit.contains(&upper));
}
5937
#[test]
fn gaussian_location_scale_sigma_includes_noise_offset() {
    // Noise-block offsets of ln 3 and ln 5 on top of zero coefficients.
    let noise_offsets = array![(3.0f64).ln(), (5.0f64).ln()];
    let input = PredictInput {
        design: DesignMatrix::from(array![[1.0], [1.0]]),
        offset: array![0.0, 0.0],
        design_noise: Some(DesignMatrix::from(array![[1.0], [1.0]])),
        offset_noise: Some(noise_offsets),
        auxiliary_scalar: None,
        auxiliary_matrix: None,
    };
    let predictor = GaussianLocationScalePredictor {
        beta_mu: array![0.0],
        beta_noise: array![0.0],
        response_scale: 2.0,
        covariance: None,
        link_wiggle: None,
    };

    let sigma = predictor
        .predict_noise_scale(&input)
        .expect("gaussian location-scale sigma")
        .expect("sigma should be returned");
    for (row, target) in [6.02_f64, 10.02].iter().copied().enumerate() {
        assert!((sigma[row] - target).abs() <= 1e-12);
    }

    // Without a covariance there are no standard errors to report.
    let out = predictor
        .predict_with_uncertainty(&input)
        .expect("gaussian location-scale uncertainty");
    assert!(out.eta_se.is_none());
    assert!(out.mean_se.is_none());
}
5969
#[test]
fn gaussian_location_scale_eta_se_pads_scale_block_without_wiggle() {
    // Joint (mu, noise) coefficient covariance; the expected eta standard error
    // is sqrt(4.0) = 2.0, the location-block variance.
    let joint_cov = array![[4.0, 0.0], [0.0, 9.0]];
    let predictor = GaussianLocationScalePredictor {
        beta_mu: array![0.5],
        beta_noise: array![0.1],
        response_scale: 1.0,
        covariance: Some(joint_cov.clone()),
        link_wiggle: None,
    };
    let fit = gaussian_location_scale_fit_with_covariance(array![0.5], array![0.1], joint_cov);
    let input = PredictInput {
        design: DesignMatrix::from(array![[1.0]]),
        offset: array![0.0],
        design_noise: Some(DesignMatrix::from(array![[1.0]])),
        offset_noise: None,
        auxiliary_scalar: None,
        auxiliary_matrix: None,
    };

    let posterior = predictor
        .predict_posterior_mean(&input, &fit, None)
        .expect("gaussian location-scale posterior mean");
    assert!((posterior.eta_standard_error[0] - 2.0).abs() <= 1e-12);
}
5998
#[test]
fn survival_eta_se_pads_log_sigma_block() {
    let input = PredictInput {
        design: DesignMatrix::from(array![[1.0]]),
        offset: array![0.0],
        design_noise: Some(DesignMatrix::from(array![[1.0]])),
        offset_noise: Some(array![0.0]),
        auxiliary_scalar: None,
        auxiliary_matrix: None,
    };
    let predictor = SurvivalPredictor {
        beta_threshold: array![0.5],
        beta_log_sigma: array![0.0],
        inverse_link: InverseLink::Standard(LinkFunction::Probit),
        covariance: Some(array![[9.0, 0.0], [0.0, 16.0]]),
    };

    let out = predictor
        .predict_with_uncertainty(&input)
        .expect("survival uncertainty");
    let eta_se = out.eta_se.expect("eta_se should be present");
    // The expected value is sqrt(9.0), the threshold-block variance.
    assert!((eta_se[0] - 3.0).abs() <= 1e-12);
}
6022
#[test]
fn survival_predictor_cloglog_point_and_se_use_upper_tail_at_q0() {
    let input = PredictInput {
        design: DesignMatrix::from(array![[1.0]]),
        offset: array![0.0],
        design_noise: Some(DesignMatrix::from(array![[1.0]])),
        offset_noise: Some(array![0.0]),
        auxiliary_scalar: None,
        auxiliary_matrix: None,
    };
    let predictor = SurvivalPredictor {
        beta_threshold: array![-1.0],
        beta_log_sigma: array![0.0],
        inverse_link: InverseLink::Standard(LinkFunction::CLogLog),
        covariance: Some(array![[4.0, 0.0], [0.0, 0.0]]),
    };

    let out = predictor
        .predict_with_uncertainty(&input)
        .expect("cloglog survival prediction");

    // Expected values are evaluated at q0 = 1 (threshold -1, log-sigma 0).
    let q0: f64 = 1.0;
    // Survival exp(-exp(q0)); delta-method SE = sd(2) * |d/dq exp(-exp(q))| at q0.
    let survival = (-(q0.exp())).exp();
    let survival_se = 2.0 * (q0 - q0.exp()).exp();

    assert!((out.mean[0] - survival).abs() <= 1e-12);
    let mean_se = out.mean_se.expect("mean_se should be present");
    assert!((mean_se[0] - survival_se).abs() <= 1e-12);
}
6052
#[test]
fn survival_predictor_cloglog_posterior_mean_zero_covariance_matches_point_prediction() {
    // With zero coefficient covariance the posterior mean should collapse onto
    // the plug-in response.
    let zero_cov = Array2::<f64>::zeros((2, 2));
    let predictor = SurvivalPredictor {
        beta_threshold: array![-1.0],
        beta_log_sigma: array![0.0],
        inverse_link: InverseLink::Standard(LinkFunction::CLogLog),
        covariance: Some(zero_cov.clone()),
    };
    let fit = survival_fit_with_covariance(array![-1.0], array![0.0], zero_cov);
    let input = PredictInput {
        design: DesignMatrix::from(array![[1.0]]),
        offset: array![0.0],
        design_noise: Some(DesignMatrix::from(array![[1.0]])),
        offset_noise: Some(array![0.0]),
        auxiliary_scalar: None,
        auxiliary_matrix: None,
    };

    let plugin = predictor
        .predict_plugin_response(&input)
        .expect("cloglog survival point prediction");
    let posterior = predictor
        .predict_posterior_mean(&input, &fit, None)
        .expect("cloglog survival posterior mean");

    let gap = posterior.mean[0] - plugin.mean[0];
    assert!(gap.abs() <= 1e-12);
}
6080
#[test]
fn survival_predictor_zero_threshold_with_tiny_sigma_stays_finite() {
    // A zero threshold combined with a log-sigma offset of -1000 (a vanishingly
    // small sigma) must still give the finite value exp(-1).
    let input = PredictInput {
        design: DesignMatrix::from(array![[1.0]]),
        offset: array![0.0],
        design_noise: Some(DesignMatrix::from(array![[1.0]])),
        offset_noise: Some(array![-1000.0]),
        auxiliary_scalar: None,
        auxiliary_matrix: None,
    };
    let predictor = SurvivalPredictor {
        beta_threshold: array![0.0],
        beta_log_sigma: array![0.0],
        inverse_link: InverseLink::Standard(LinkFunction::CLogLog),
        covariance: None,
    };

    let survival = predictor
        .predict_plugin_response(&input)
        .expect("cloglog survival point prediction")
        .mean[0];

    assert!(survival.is_finite());
    assert!((survival - (-1.0_f64).exp()).abs() <= 1e-12);
}
6106
6107 fn test_fit_with_bias_correction(
6110 beta: Array1<f64>,
6111 covariance: Array2<f64>,
6112 bias_correction_beta: Option<Array1<f64>>,
6113 ) -> UnifiedFitResult {
6114 use crate::estimate::FitInference;
6115 let p = beta.len();
6116 let inf = FitInference {
6117 edf_by_block: vec![],
6120 edf_total: p as f64,
6121 smoothing_correction: None,
6122 penalized_hessian: Array2::<f64>::eye(p),
6123 working_weights: Array1::zeros(0),
6124 working_response: Array1::zeros(0),
6125 reparam_qs: None,
6126 beta_covariance: Some(covariance.clone()),
6127 beta_standard_errors: None,
6128 beta_covariance_corrected: None,
6129 beta_standard_errors_corrected: None,
6130 bias_correction_beta,
6131 };
6132 UnifiedFitResult::new_for_test_unchecked(UnifiedFitResultParts {
6133 blocks: vec![FittedBlock {
6134 beta: beta.clone(),
6135 role: BlockRole::Mean,
6136 edf: p as f64,
6137 lambdas: Array1::zeros(0),
6138 }],
6139 log_lambdas: Array1::zeros(0),
6140 lambdas: Array1::zeros(0),
6141 likelihood_family: Some(crate::types::LikelihoodFamily::GaussianIdentity),
6142 likelihood_scale: crate::types::LikelihoodScaleMetadata::ProfiledGaussian,
6143 log_likelihood_normalization: crate::types::LogLikelihoodNormalization::Full,
6144 log_likelihood: 0.0,
6145 deviance: 0.0,
6146 reml_score: 0.0,
6147 stable_penalty_term: 0.0,
6148 penalized_objective: 0.0,
6149 outer_iterations: 0,
6150 outer_converged: true,
6151 outer_gradient_norm: 0.0,
6152 standard_deviation: 1.0,
6153 covariance_conditional: Some(covariance),
6154 covariance_corrected: None,
6155 inference: Some(inf),
6156 fitted_link: FittedLinkState::Standard(Some(LinkFunction::Identity)),
6157 geometry: None,
6158 block_states: Vec::new(),
6159 pirls_status: PirlsStatus::Converged,
6160 max_abs_eta: 0.0,
6161 constraint_kkt: None,
6162 artifacts: FitArtifacts {
6163 pirls: None,
6164 ..Default::default()
6165 },
6166 inner_cycles: 0,
6167 })
6168 }
6169
6170 fn bc_options(apply: bool) -> PredictUncertaintyOptions {
6171 PredictUncertaintyOptions {
6172 confidence_level: 0.95,
6173 covariance_mode: InferenceCovarianceMode::Conditional,
6174 mean_interval_method: MeanIntervalMethod::TransformEta,
6175 includeobservation_interval: false,
6176 apply_bias_correction: apply,
6177 edgeworth_one_sided: false,
6178 boundary_correction: false,
6179 ood_inflation: false,
6180 multi_point_joint: false,
6181 ..PredictUncertaintyOptions::default()
6182 }
6183 }
6184
#[test]
fn test_bias_correction_idempotent_with_flag() {
    let beta = array![1.0, 2.0];
    let bc = array![0.1, -0.05];
    let fit = test_fit_with_bias_correction(beta.clone(), Array2::<f64>::eye(2), Some(bc));
    let x = array![[1.0, 0.5]];
    let offset = array![0.0];

    // Identical inputs; only `apply_bias_correction` is toggled.
    let predict = |apply: bool, label: &str| {
        predict_gamwith_uncertainty(
            x.clone(),
            beta.view(),
            offset.view(),
            crate::types::LikelihoodFamily::GaussianIdentity,
            &fit,
            &bc_options(apply),
        )
        .expect(label)
    };
    let pred_off = predict(false, "predict no-bc");
    let pred_on = predict(true, "predict bc");

    // Uncorrected eta = 1.0 * 1.0 + 0.5 * 2.0; the correction shifts it by x . bc.
    assert!((pred_off.eta[0] - 2.0).abs() < 1e-12);
    let expected_delta = 1.0 * 0.1 + 0.5 * (-0.05);
    assert!((pred_on.eta[0] - (2.0 + expected_delta)).abs() < 1e-12);
    assert!(
        (pred_off.eta_standard_error[0] - pred_on.eta_standard_error[0]).abs() < 1e-14,
        "bias correction must not affect eta standard error"
    );
}
6224
#[test]
fn test_bias_correction_zero_when_unset() {
    let beta = array![1.0, 2.0];
    // The fit carries no bias-correction vector, so the test expects the flag
    // to leave eta = x . beta unchanged.
    let fit = test_fit_with_bias_correction(beta.clone(), Array2::<f64>::eye(2), None);
    let offset = array![0.0];

    let pred = predict_gamwith_uncertainty(
        array![[1.0, 0.5]],
        beta.view(),
        offset.view(),
        crate::types::LikelihoodFamily::GaussianIdentity,
        &fit,
        &bc_options(true),
    )
    .expect("predict");
    // eta = 1.0 * 1.0 + 0.5 * 2.0
    assert!((pred.eta[0] - 2.0).abs() < 1e-12);
}
6246
#[test]
fn test_bias_correction_does_not_affect_posterior_se() {
    let beta = array![0.4, 0.9];
    let cov = array![[1.0, 0.1], [0.1, 0.5]];
    let fit_with = test_fit_with_bias_correction(beta.clone(), cov.clone(), Some(array![0.2, -0.1]));
    let fit_without = test_fit_with_bias_correction(beta.clone(), cov, None);
    let x = array![[1.0, 0.5], [0.7, -0.3]];
    let offset = array![0.0, 0.0];

    // Both predictions request bias correction; only `fit_with` stores one.
    let predict_with_bc = |fit: &UnifiedFitResult, label: &str| {
        predict_gamwith_uncertainty(
            x.clone(),
            beta.view(),
            offset.view(),
            crate::types::LikelihoodFamily::GaussianIdentity,
            fit,
            &bc_options(true),
        )
        .expect(label)
    };
    let pred_with = predict_with_bc(&fit_with, "predict with bc");
    let pred_without = predict_with_bc(&fit_without, "predict without bc");

    for i in 0..2 {
        let gap = (pred_with.eta_standard_error[i] - pred_without.eta_standard_error[i]).abs();
        assert!(gap < 1e-14, "BC must not perturb eta SE at index {i}");
    }
}
6284
#[test]
fn test_bias_correction_accessor_propagates() {
    let bc = array![0.3, -0.2];
    let fit = test_fit_with_bias_correction(
        array![1.0, 2.0],
        Array2::<f64>::eye(2),
        Some(bc.clone()),
    );
    // The accessor must hand back exactly the vector that was stored.
    let recovered = fit
        .bias_correction_beta()
        .expect("bias correction should be present");
    assert_eq!(recovered.len(), bc.len());
    for (got, want) in recovered.iter().zip(bc.iter()) {
        assert!((got - want).abs() < 1e-15);
    }
}
6300
6301 fn solve_3x3_spd(h: &Array2<f64>, r: &Array1<f64>) -> Array1<f64> {
6307 assert_eq!(h.nrows(), 3);
6308 assert_eq!(h.ncols(), 3);
6309 let m = |i: usize, j: usize| h[[i, j]];
6310 let det = m(0, 0) * (m(1, 1) * m(2, 2) - m(1, 2) * m(2, 1))
6311 - m(0, 1) * (m(1, 0) * m(2, 2) - m(1, 2) * m(2, 0))
6312 + m(0, 2) * (m(1, 0) * m(2, 1) - m(1, 1) * m(2, 0));
6313 assert!(det.abs() > 1e-12, "singular matrix in solve_3x3_spd");
6314 let cof = array![
6316 [
6317 m(1, 1) * m(2, 2) - m(1, 2) * m(2, 1),
6318 -(m(1, 0) * m(2, 2) - m(1, 2) * m(2, 0)),
6319 m(1, 0) * m(2, 1) - m(1, 1) * m(2, 0)
6320 ],
6321 [
6322 -(m(0, 1) * m(2, 2) - m(0, 2) * m(2, 1)),
6323 m(0, 0) * m(2, 2) - m(0, 2) * m(2, 0),
6324 -(m(0, 0) * m(2, 1) - m(0, 1) * m(2, 0))
6325 ],
6326 [
6327 m(0, 1) * m(1, 2) - m(0, 2) * m(1, 1),
6328 -(m(0, 0) * m(1, 2) - m(0, 2) * m(1, 0)),
6329 m(0, 0) * m(1, 1) - m(0, 1) * m(1, 0)
6330 ]
6331 ];
6332 let mut y = Array1::<f64>::zeros(3);
6334 for i in 0..3 {
6335 let mut acc = 0.0;
6336 for j in 0..3 {
6337 acc += cof[[j, i]] * r[j];
6338 }
6339 y[i] = acc / det;
6340 }
6341 y
6342 }
6343
/// Minimal deterministic pseudo-random generator for reproducible test data.
///
/// Each step updates the state as
/// `state = state * 6364136223846793005 + 1442695040888963407 (mod 2^64)`.
struct Lcg(u64);

impl Lcg {
    const MULTIPLIER: u64 = 6364136223846793005;
    const INCREMENT: u64 = 1442695040888963407;

    /// One LCG step using wrapping (mod 2^64) arithmetic.
    fn step(state: u64) -> u64 {
        state
            .wrapping_mul(Self::MULTIPLIER)
            .wrapping_add(Self::INCREMENT)
    }

    /// Creates a generator whose initial state is the seed advanced by one step.
    fn new(seed: u64) -> Self {
        Lcg(Self::step(seed))
    }

    /// Advances the state and returns the new raw 64-bit state.
    fn next_u64(&mut self) -> u64 {
        let advanced = Self::step(self.0);
        self.0 = advanced;
        advanced
    }

    /// Uniform draw in `[0, 1)` built from the top 53 bits of the next state.
    fn unif(&mut self) -> f64 {
        let top_bits = self.next_u64() >> 11;
        let denom = (1u64 << 53) as f64;
        top_bits as f64 / denom
    }

    /// Standard-normal draw from the cosine branch of the Box–Muller transform.
    fn normal(&mut self) -> f64 {
        // Floor the first uniform so `ln` never sees zero.
        let u1 = self.unif().max(1e-300);
        let u2 = self.unif();
        let radius = (-2.0 * u1.ln()).sqrt();
        let angle = 2.0 * std::f64::consts::PI * u2;
        radius * angle.cos()
    }
}
6371
#[test]
fn test_bias_correction_matches_explicit_formula() {
    // Reference first-order correction b_hat = H^{-1} S beta, computed
    // independently of the production code.
    let h = array![[4.0_f64, 0.5, 0.2], [0.5, 3.0, 0.1], [0.2, 0.1, 2.0]];
    let s_pen = array![[1.0_f64, 0.0, 0.0], [0.0, 0.5, 0.0], [0.0, 0.0, 2.0]];
    let beta = array![0.7_f64, -1.3, 0.4];
    let b_hat = solve_3x3_spd(&h, &s_pen.dot(&beta));

    let fit = test_fit_with_bias_correction(beta.clone(), Array2::<f64>::eye(3), Some(b_hat.clone()));

    // Identity design with zero offset: eta_i is exactly coefficient i.
    let x = Array2::<f64>::eye(3);
    let offset = Array1::<f64>::zeros(3);
    let predict = |apply: bool, label: &str| {
        predict_gamwith_uncertainty(
            x.clone(),
            beta.view(),
            offset.view(),
            crate::types::LikelihoodFamily::GaussianIdentity,
            &fit,
            &bc_options(apply),
        )
        .expect(label)
    };
    let pred_raw = predict(false, "raw predict");
    let pred_bc = predict(true, "bc predict");

    for i in 0..3 {
        assert!(
            (pred_raw.eta[i] - beta[i]).abs() < 1e-12,
            "raw eta[{i}] = {} expected {}",
            pred_raw.eta[i],
            beta[i]
        );
        let expected = beta[i] + b_hat[i];
        assert!(
            (pred_bc.eta[i] - expected).abs() < 1e-12,
            "BC eta[{i}] = {} expected β+b̂ = {} (b̂[{i}] = {})",
            pred_bc.eta[i],
            expected,
            b_hat[i]
        );
    }
}
6427
#[test]
fn test_bias_correction_zero_for_zero_penalty() {
    // With S = 0 the correction H^{-1} S beta vanishes, so it is stored as zeros.
    let beta = array![0.5_f64, -0.4, 1.7];
    let fit = test_fit_with_bias_correction(
        beta.clone(),
        Array2::<f64>::eye(3),
        Some(Array1::<f64>::zeros(3)),
    );
    let x = array![[1.0, 2.0, -0.5], [0.3, -0.7, 1.2], [2.0, 0.1, 0.0]];
    let offset = Array1::<f64>::zeros(3);

    let run = |apply: bool, label: &str| {
        predict_gamwith_uncertainty(
            x.clone(),
            beta.view(),
            offset.view(),
            crate::types::LikelihoodFamily::GaussianIdentity,
            &fit,
            &bc_options(apply),
        )
        .expect(label)
    };
    let pred_raw = run(false, "raw predict");
    let pred_bc = run(true, "bc predict");

    for i in 0..3 {
        let delta = pred_bc.eta[i] - pred_raw.eta[i];
        assert!(
            delta.abs() < 1e-15,
            "S=0 ⇒ BC must be a no-op; got Δ={} at i={i}",
            delta
        );
    }
}
6468
    /// Bias-correction magnitude must grow with penalty strength.
    ///
    /// For each λ ∈ {0.1, 1, 10} the Hessian is `H = H_base + λI` with penalty
    /// `S = λI`, and the stored correction is `b̂ = H⁻¹·S·β`. The Euclidean norm
    /// of `η_BC − η_raw` over three fixed probe rows must increase strictly
    /// from one λ to the next and grow by more than 10× across the sweep.
    #[test]
    fn test_bias_correction_increases_with_penalty_strength() {
        let h_base = array![[3.0_f64, 0.4, 0.1], [0.4, 2.5, 0.2], [0.1, 0.2, 4.0]];
        let beta = array![1.2_f64, -0.8, 0.5];
        let x = array![[1.0, 0.5, -0.2], [0.3, -0.4, 0.9], [0.7, 0.7, 0.7]];
        let offset = array![0.0, 0.0, 0.0];

        let lambdas = [0.1_f64, 1.0, 10.0];
        // One ‖η_BC − η_raw‖₂ entry per λ, in the same order as `lambdas`.
        let mut deltas = Vec::with_capacity(lambdas.len());
        for &lam in &lambdas {
            let mut h = h_base.clone();
            // H = H_base + λI: the penalty S = λI adds λ to the diagonal.
            for k in 0..3 {
                h[[k, k]] += lam;
            }
            // S·β with S = λI.
            let s_beta = beta.mapv(|v| lam * v);
            let b_hat = solve_3x3_spd(&h, &s_beta);

            // Identity covariance: only η (not its SE) is compared in this test.
            let cov = Array2::<f64>::eye(3);
            let fit = test_fit_with_bias_correction(beta.clone(), cov, Some(b_hat));

            let pred_raw = predict_gamwith_uncertainty(
                x.clone(),
                beta.view(),
                offset.view(),
                crate::types::LikelihoodFamily::GaussianIdentity,
                &fit,
                &bc_options(false),
            )
            .expect("raw predict");
            let pred_bc = predict_gamwith_uncertainty(
                x.clone(),
                beta.view(),
                offset.view(),
                crate::types::LikelihoodFamily::GaussianIdentity,
                &fit,
                &bc_options(true),
            )
            .expect("bc predict");

            // ‖η_BC − η_raw‖₂ over the three probe rows.
            let mut sumsq = 0.0;
            for i in 0..3 {
                let d = pred_bc.eta[i] - pred_raw.eta[i];
                sumsq += d * d;
            }
            deltas.push(sumsq.sqrt());
        }

        // Strict monotonicity across consecutive λ values.
        assert!(
            deltas[0] < deltas[1],
            "‖η_BC − η_raw‖ must grow with λ: λ={} gave {}, λ={} gave {}",
            lambdas[0],
            deltas[0],
            lambdas[1],
            deltas[1]
        );
        assert!(
            deltas[1] < deltas[2],
            "‖η_BC − η_raw‖ must grow with λ: λ={} gave {}, λ={} gave {}",
            lambdas[1],
            deltas[1],
            lambdas[2],
            deltas[2]
        );
        // λ spans two decades, so the shift should grow by at least one.
        assert!(
            deltas[2] > 10.0 * deltas[0],
            "expected order-of-magnitude growth in BC magnitude across λ ∈ {{0.1,1,10}}; got {:?}",
            deltas
        );
    }
6544
    /// Bias correction should move a shrunken fit back toward the unpenalized one.
    ///
    /// Simulates a Gaussian regression with n = 200 rows and p = 5 columns. The
    /// columns are an intercept plus four N(0,1) covariates, with noise sd 0.3,
    /// and the OLS solution is computed. A "penalized" estimate is emulated by
    /// scaling OLS by 0.6. Its correction is `b̂ = (XᵀX + λI)⁻¹·λ·β̂` with λ = 100.
    /// On 50 fresh test rows, the bias-corrected η must be strictly closer to
    /// the OLS η than the raw η is, at no fewer than 90% of the rows.
    #[test]
    fn test_bias_correction_recovers_unpenalized_in_simulation() {
        let n = 200usize;
        let p = 5usize;
        // One RNG stream feeds, in order, the design, the noise and the test
        // design; reordering those draws changes the simulated data.
        let mut rng = Lcg::new(0xC0FFEE_u64);

        let mut x_data = vec![0.0_f64; n * p];
        // Row-major fill: column 0 is the intercept, columns 1..p are N(0,1).
        for i in 0..n {
            x_data[i * p] = 1.0;
            for j in 1..p {
                x_data[i * p + j] = rng.normal();
            }
        }
        let x = Array2::from_shape_vec((n, p), x_data).expect("X shape");

        let beta_true = array![0.5_f64, 1.0, -0.7, 0.3, 0.8];
        // y = Xβ_true + 0.3·ε.
        let mut y = Array1::<f64>::zeros(n);
        for i in 0..n {
            let mut eta = 0.0;
            for j in 0..p {
                eta += x[[i, j]] * beta_true[j];
            }
            y[i] = eta + 0.3 * rng.normal();
        }
        let xtx = x.t().dot(&x);
        let xty = x.t().dot(&y);
        let beta_ols = solve_dense_spd(&xtx, &xty);

        // Stand-in for a penalized estimate: uniform shrinkage of OLS, not the
        // exact ridge solution for `lambda` below.
        let shrink = 0.6_f64;
        let beta_hat = beta_ols.mapv(|v| shrink * v);

        let lambda = 100.0_f64;
        // H = XᵀX + λI, S = λI, b̂ = H⁻¹·S·β̂.
        let mut h = xtx.clone();
        for k in 0..p {
            h[[k, k]] += lambda;
        }
        let s_beta = beta_hat.mapv(|v| lambda * v);
        let b_hat = solve_dense_spd(&h, &s_beta);

        let cov = Array2::<f64>::eye(p);
        let fit = test_fit_with_bias_correction(beta_hat.clone(), cov, Some(b_hat.clone()));

        let m = 50usize;
        // Held-out rows with the same intercept + N(0,1) layout.
        let mut xt_data = vec![0.0_f64; m * p];
        for i in 0..m {
            xt_data[i * p] = 1.0;
            for j in 1..p {
                xt_data[i * p + j] = rng.normal();
            }
        }
        let xt = Array2::from_shape_vec((m, p), xt_data).expect("Xtest shape");
        let offset = Array1::<f64>::zeros(m);

        let pred_raw = predict_gamwith_uncertainty(
            xt.clone(),
            beta_hat.view(),
            offset.view(),
            crate::types::LikelihoodFamily::GaussianIdentity,
            &fit,
            &bc_options(false),
        )
        .expect("raw predict");
        let pred_bc = predict_gamwith_uncertainty(
            xt.clone(),
            beta_hat.view(),
            offset.view(),
            crate::types::LikelihoodFamily::GaussianIdentity,
            &fit,
            &bc_options(true),
        )
        .expect("bc predict");
        // Reference: the unpenalized (OLS) linear predictor on the test rows.
        let eta_ols = xt.dot(&beta_ols);

        // Count rows where BC strictly narrows the gap to OLS.
        let mut closer = 0usize;
        for i in 0..m {
            let raw_gap = (eta_ols[i] - pred_raw.eta[i]).abs();
            let bc_gap = (eta_ols[i] - pred_bc.eta[i]).abs();
            if bc_gap < raw_gap {
                closer += 1;
            }
        }
        let frac = closer as f64 / m as f64;
        assert!(
            frac >= 0.9,
            "BC must close the OLS gap at ≥90% of test points; got {}/{} = {:.2}",
            closer,
            m,
            frac
        );
    }
6651
    /// Conditional bias of the penalized fit, with and without correction, as n grows.
    ///
    /// For n ∈ {200, 1000, 5000} the noise-free penalized estimate is
    /// `β̄ = H⁻¹XᵀXβ_true` with `H = XᵀX + λI` and λ = 5. No response noise is
    /// simulated, so this is the conditional mean of the penalized estimator.
    /// The stored correction is `b̂ = H⁻¹·λ·β̄`.
    ///
    /// Because `β_true − β̄ = H⁻¹λβ_true`, the corrected error is
    /// `β̄ + b̂ − β_true = −(H⁻¹λ)²β_true`, which is second order in `H⁻¹λ`.
    /// Biases are mean |η − η_true| over 32 fixed probe rows. The test checks
    /// three things:
    /// * the raw bias at n = 5000 is below the raw bias at n = 200;
    /// * the BC/raw ratio at n = 5000 is below 0.5;
    /// * the BC/raw ratio at n = 5000 is not larger than at n = 200 (up to 1e-6).
    #[test]
    fn test_bias_correction_bias_drops_with_n_simulation() {
        let p = 4usize;
        let beta_true = array![0.4_f64, 0.9, -0.5, 0.6];
        let lambda = 5.0_f64;
        let ns = [200usize, 1000, 5000];

        let m = 32usize;
        // Probe rows come from their own RNG so they are identical for every n.
        let mut probe_rng = Lcg::new(424242);
        let mut xt_data = vec![0.0_f64; m * p];
        for i in 0..m {
            xt_data[i * p] = 1.0;
            for j in 1..p {
                xt_data[i * p + j] = probe_rng.normal();
            }
        }
        let xt = Array2::from_shape_vec((m, p), xt_data).expect("Xtest shape");
        let eta_true = xt.dot(&beta_true);
        let offset = Array1::<f64>::zeros(m);

        // Indexed like `ns`.
        let mut mean_abs_raw_bias = [0.0_f64; 3];
        let mut mean_abs_bc_bias = [0.0_f64; 3];

        // Each n is independent, so the sizes run in parallel. The index `kn` is
        // returned with the results so the scatter below does not depend on
        // the order rayon yields them in.
        let bias_by_n: Vec<(usize, f64, f64)> = (0..ns.len())
            .into_par_iter()
            .map(|kn| {
                let n = ns[kn];
                // Same seed for every n; assuming Lcg is a deterministic stream,
                // the smaller designs are row-prefixes of the larger ones.
                let mut rng = Lcg::new(0xBEEFu64);
                let mut x_data = vec![0.0_f64; n * p];
                for i in 0..n {
                    x_data[i * p] = 1.0;
                    for j in 1..p {
                        x_data[i * p + j] = rng.normal();
                    }
                }
                let x = Array2::from_shape_vec((n, p), x_data).expect("X shape");
                let xtx = x.t().dot(&x);
                let mut h = xtx.clone();
                for k in 0..p {
                    h[[k, k]] += lambda;
                }

                let xtx_beta = xtx.dot(&beta_true);
                // β̄ = H⁻¹XᵀXβ_true: the penalized estimate with the noise removed.
                let beta_mean = solve_dense_spd(&h, &xtx_beta);
                let s_beta_mean = beta_mean.mapv(|v| lambda * v);
                // b̂ = H⁻¹·λ·β̄.
                let b_hat = solve_dense_spd(&h, &s_beta_mean);

                let cov = Array2::<f64>::eye(p);
                let fit = test_fit_with_bias_correction(beta_mean.clone(), cov, Some(b_hat));

                let pred_raw = predict_gamwith_uncertainty(
                    xt.clone(),
                    beta_mean.view(),
                    offset.view(),
                    crate::types::LikelihoodFamily::GaussianIdentity,
                    &fit,
                    &bc_options(false),
                )
                .expect("raw predict");
                let pred_bc = predict_gamwith_uncertainty(
                    xt.clone(),
                    beta_mean.view(),
                    offset.view(),
                    crate::types::LikelihoodFamily::GaussianIdentity,
                    &fit,
                    &bc_options(true),
                )
                .expect("bc predict");

                // Mean absolute deviation from the true η over the probe rows.
                let mut acc_raw = 0.0;
                let mut acc_bc = 0.0;
                for i in 0..m {
                    acc_raw += (pred_raw.eta[i] - eta_true[i]).abs();
                    acc_bc += (pred_bc.eta[i] - eta_true[i]).abs();
                }
                (kn, acc_raw / m as f64, acc_bc / m as f64)
            })
            .collect();
        for (kn, raw, bc) in bias_by_n {
            mean_abs_raw_bias[kn] = raw;
            mean_abs_bc_bias[kn] = bc;
        }

        // Sanity check on the setup: the penalty's pull weakens as XᵀX grows.
        assert!(
            mean_abs_raw_bias[2] < mean_abs_raw_bias[0],
            "raw penalized conditional bias should shrink with n: got {:?}",
            mean_abs_raw_bias
        );
        // `.max(1e-300)` only guards against dividing by an exact zero.
        let ratio_large = mean_abs_bc_bias[2] / mean_abs_raw_bias[2].max(1e-300);
        assert!(
            ratio_large < 0.5,
            "BC must reduce conditional bias by >2× at n={}; raw={}, bc={}, ratio={}",
            ns[2],
            mean_abs_raw_bias[2],
            mean_abs_bc_bias[2],
            ratio_large
        );
        let ratio_small = mean_abs_bc_bias[0] / mean_abs_raw_bias[0].max(1e-300);
        // The residual is second order, so the relative gain should not worsen.
        assert!(
            ratio_large <= ratio_small + 1e-6,
            "BC/raw ratio should not grow with n: small-n ratio={}, large-n ratio={}",
            ratio_small,
            ratio_large
        );
    }
6803
    /// The bias-corrected η must not depend on the coefficient basis.
    ///
    /// The fit is reparameterized as `β = Qθ`, where `Q` is unit upper-triangular.
    /// The reparameterized quantities are `θ = Q⁻¹β`, `b̃ = Q⁻¹b̂` and the design
    /// row `x̃ = xQ`, so that `x̃·θ = x·β` and `x̃·b̃ = x·b̂`. The two
    /// predictions, both with BC enabled, must agree to 1e-12.
    #[test]
    fn test_bias_correction_identity_in_basis_change() {
        let h = array![[4.0_f64, 0.5, 0.2], [0.5, 3.0, 0.1], [0.2, 0.1, 2.5]];
        // Non-diagonal symmetric penalty.
        let s_pen = array![[0.7_f64, 0.1, 0.0], [0.1, 0.5, 0.05], [0.0, 0.05, 1.2]];
        let beta = array![0.6_f64, -0.4, 1.1];
        let s_beta = s_pen.dot(&beta);
        let b_hat = solve_3x3_spd(&h, &s_beta);

        // Unit upper-triangular change of basis (the inverse helper requires it).
        let q = array![[1.0_f64, 0.3, -0.2], [0.0, 1.0, 0.5], [0.0, 0.0, 1.0]];
        let qinv = invert_upper_triangular_3(&q);
        // Coefficients and correction expressed in the new basis.
        let theta = qinv.dot(&beta);
        let b_tilde = qinv.dot(&b_hat);

        let x_row = array![[0.4_f64, -0.7, 0.9]];
        // x̃ = x·Q, i.e. x̃[j] = Σ_i x[i]·Q[i, j].
        let mut x_tilde = Array2::<f64>::zeros((1, 3));
        for j in 0..3 {
            let mut acc = 0.0;
            for i in 0..3 {
                acc += q[[i, j]] * x_row[[0, i]];
            }
            x_tilde[[0, j]] = acc;
        }
        let offset = array![0.0_f64];

        let cov = Array2::<f64>::eye(3);
        let fit_orig = test_fit_with_bias_correction(beta.clone(), cov.clone(), Some(b_hat));
        let fit_repar = test_fit_with_bias_correction(theta.clone(), cov, Some(b_tilde));

        let pred_orig = predict_gamwith_uncertainty(
            x_row,
            beta.view(),
            offset.view(),
            crate::types::LikelihoodFamily::GaussianIdentity,
            &fit_orig,
            &bc_options(true),
        )
        .expect("orig predict");
        let pred_repar = predict_gamwith_uncertainty(
            x_tilde,
            theta.view(),
            offset.view(),
            crate::types::LikelihoodFamily::GaussianIdentity,
            &fit_repar,
            &bc_options(true),
        )
        .expect("repar predict");

        assert!(
            (pred_orig.eta[0] - pred_repar.eta[0]).abs() < 1e-12,
            "BC must be invariant under reparameterization: orig η={} repar η={} Δ={}",
            pred_orig.eta[0],
            pred_repar.eta[0],
            (pred_orig.eta[0] - pred_repar.eta[0]).abs()
        );
    }
6870
    /// Enabling bias correction must leave η standard errors untouched.
    ///
    /// The fit uses a non-diagonal covariance and a nonzero correction vector.
    /// It is evaluated on 100 random N(0,1) rows with BC off and on, and the
    /// per-row `eta_standard_error` values must agree to a relative tolerance
    /// of 1e-14.
    #[test]
    fn test_bias_correction_does_not_inflate_se() {
        let p = 4usize;
        let beta = array![0.5_f64, -0.7, 1.1, 0.3];
        let cov = array![
            [2.0_f64, 0.3, 0.1, 0.0],
            [0.3, 1.5, 0.2, 0.05],
            [0.1, 0.2, 1.8, 0.1],
            [0.0, 0.05, 0.1, 2.2]
        ];
        let bc = array![0.2_f64, -0.15, 0.05, 0.1];
        let fit = test_fit_with_bias_correction(beta.clone(), cov, Some(bc));

        let m = 100usize;
        let mut rng = Lcg::new(0xBEEFCAFE_u64);
        // Row-major fill of an m×p design with N(0,1) entries (no intercept).
        let mut x_data = vec![0.0_f64; m * p];
        for i in 0..m {
            for j in 0..p {
                x_data[i * p + j] = rng.normal();
            }
        }
        let x = Array2::from_shape_vec((m, p), x_data).expect("X shape");
        let offset = Array1::<f64>::zeros(m);

        let pred_off = predict_gamwith_uncertainty(
            x.clone(),
            beta.view(),
            offset.view(),
            crate::types::LikelihoodFamily::GaussianIdentity,
            &fit,
            &bc_options(false),
        )
        .expect("predict no-bc");
        let pred_on = predict_gamwith_uncertainty(
            x,
            beta.view(),
            offset.view(),
            crate::types::LikelihoodFamily::GaussianIdentity,
            &fit,
            &bc_options(true),
        )
        .expect("predict bc");

        for i in 0..m {
            let a = pred_off.eta_standard_error[i];
            let b = pred_on.eta_standard_error[i];
            // Relative difference; `.max(1e-300)` only avoids dividing by zero.
            let rel = (a - b).abs() / a.abs().max(b.abs()).max(1e-300);
            assert!(
                rel < 1e-14,
                "SE leakage detected at i={}: off={}, on={}, relΔ={}",
                i,
                a,
                b,
                rel
            );
        }
    }
6933
6934 #[test]
6937 fn test_bias_correction_finite_for_pathological_inputs() {
6938 let beta = array![1.0_f64, f64::NAN, 0.5];
6939 let bc = array![0.1_f64, 0.2, f64::INFINITY];
6940 let cov = Array2::<f64>::eye(3);
6941 let fit = test_fit_with_bias_correction(beta.clone(), cov, Some(bc));
6942
6943 let x = array![[1.0_f64, 1.0, 1.0]];
6944 let offset = array![0.0_f64];
6945 let pred = predict_gamwith_uncertainty(
6946 x,
6947 beta.view(),
6948 offset.view(),
6949 crate::types::LikelihoodFamily::GaussianIdentity,
6950 &fit,
6951 &bc_options(true),
6952 )
6953 .expect("pathological predict should not error, only propagate NaN/Inf");
6954 assert!(
6955 !pred.eta[0].is_finite(),
6956 "expected non-finite η to propagate; got η = {}",
6957 pred.eta[0]
6958 );
6959 }
6960
6961 #[test]
6964 fn test_bias_correction_disabled_via_options_returns_raw() {
6965 let beta = array![1.5_f64, -0.7];
6966 let bc = array![0.4_f64, -0.3];
6967 let cov = Array2::<f64>::eye(2);
6968 let fit = test_fit_with_bias_correction(beta.clone(), cov, Some(bc.clone()));
6969
6970 let x = array![[1.0_f64, 0.5], [0.7, -0.3]];
6971 let offset = array![0.0_f64, 0.0];
6972 let pred = predict_gamwith_uncertainty(
6973 x.clone(),
6974 beta.view(),
6975 offset.view(),
6976 crate::types::LikelihoodFamily::GaussianIdentity,
6977 &fit,
6978 &bc_options(false),
6979 )
6980 .expect("predict no-bc");
6981
6982 let expected = x.dot(&beta);
6984 for i in 0..2 {
6985 let d = (pred.eta[i] - expected[i]).abs();
6986 assert!(
6987 d < 1e-15,
6988 "apply_bias_correction=false must return raw plug-in: η[{i}]={} expected={} Δ={}",
6989 pred.eta[i],
6990 expected[i],
6991 d
6992 );
6993 }
6994 }
6995
    /// Prediction must apply the stored bias correction verbatim.
    ///
    /// `b̂` is computed from a "true" Hessian, while the fit stores a covariance
    /// that deliberately differs from `H_true⁻¹` (the setup asserts an L1 gap
    /// above 0.5). On identity design rows, the BC prediction must equal
    /// `β + b̂` to 1e-12. Any value derived from the stored covariance would
    /// break this.
    #[test]
    fn test_bias_correction_with_nonidentity_covariance_uses_correct_h() {
        let h_true = array![[5.0_f64, 0.7, 0.2], [0.7, 4.0, 0.3], [0.2, 0.3, 3.5]];
        let s_pen = array![[0.8_f64, 0.0, 0.0], [0.0, 1.2, 0.0], [0.0, 0.0, 0.6]];
        let beta = array![0.9_f64, -1.1, 0.4];
        let s_beta = s_pen.dot(&beta);
        // The correction the fit is expected to carry: b̂ = H_true⁻¹·S·β.
        let b_hat_correct = solve_3x3_spd(&h_true, &s_beta);

        let cov_wrong = array![[2.0_f64, 0.4, 0.0], [0.4, 1.5, 0.3], [0.0, 0.3, 1.8]];
        // Guard the test's premise: the stored covariance must be far from H_true⁻¹.
        let h_inv = invert_3x3_spd(&h_true);
        let mut diff = 0.0;
        for i in 0..3 {
            for j in 0..3 {
                diff += (h_inv[[i, j]] - cov_wrong[[i, j]]).abs();
            }
        }
        assert!(
            diff > 0.5,
            "test setup error: cov_wrong should be far from H_true⁻¹ (diff={})",
            diff
        );

        let fit =
            test_fit_with_bias_correction(beta.clone(), cov_wrong, Some(b_hat_correct.clone()));

        // Identity rows make η[i] = β[i] (+ b̂[i] with BC), so each entry is checked directly.
        let x = array![[1.0_f64, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
        let offset = array![0.0_f64, 0.0, 0.0];
        let pred = predict_gamwith_uncertainty(
            x,
            beta.view(),
            offset.view(),
            crate::types::LikelihoodFamily::GaussianIdentity,
            &fit,
            &bc_options(true),
        )
        .expect("predict bc");

        for i in 0..3 {
            let expected = beta[i] + b_hat_correct[i];
            assert!(
                (pred.eta[i] - expected).abs() < 1e-12,
                "prediction must use the supplied bias_correction_beta verbatim: \
                 η[{i}]={} expected={} (β+b̂_correct[{i}]={})",
                pred.eta[i],
                expected,
                b_hat_correct[i]
            );
        }
    }
7059
7060 #[test]
7063 fn test_bias_correction_propagates_through_unified_fit_result() {
7064 let beta = array![0.7_f64, -0.4, 1.2];
7065 let bc = array![0.123456789_f64, -0.987654321, 0.5];
7066 let cov = Array2::<f64>::eye(3);
7067 let fit = test_fit_with_bias_correction(beta, cov, Some(bc.clone()));
7068
7069 let json = serde_json::to_string(&fit).expect("serialize unified fit");
7070 let decoded: UnifiedFitResult =
7071 serde_json::from_str(&json).expect("deserialize unified fit");
7072 let recovered = decoded
7073 .bias_correction_beta()
7074 .expect("bias_correction_beta must survive JSON round-trip");
7075 assert_eq!(
7076 recovered.len(),
7077 bc.len(),
7078 "bc length changed across round-trip"
7079 );
7080 for i in 0..bc.len() {
7081 assert!(
7082 (recovered[i] - bc[i]).abs() < 1e-15,
7083 "bc[{i}] drifted across JSON round-trip: in={}, out={}",
7084 bc[i],
7085 recovered[i]
7086 );
7087 }
7088 }
7089
7090 fn solve_dense_spd(h: &Array2<f64>, r: &Array1<f64>) -> Array1<f64> {
7096 let n = h.nrows();
7097 assert_eq!(h.ncols(), n);
7098 assert_eq!(r.len(), n);
7099 let mut a = Array2::<f64>::zeros((n, n + 1));
7100 for i in 0..n {
7101 for j in 0..n {
7102 a[[i, j]] = h[[i, j]];
7103 }
7104 a[[i, n]] = r[i];
7105 }
7106 for k in 0..n {
7107 let mut piv = k;
7109 let mut best = a[[k, k]].abs();
7110 for i in (k + 1)..n {
7111 if a[[i, k]].abs() > best {
7112 best = a[[i, k]].abs();
7113 piv = i;
7114 }
7115 }
7116 assert!(best > 1e-14, "near-singular system in solve_dense_spd");
7117 if piv != k {
7118 for j in 0..=n {
7119 let tmp = a[[k, j]];
7120 a[[k, j]] = a[[piv, j]];
7121 a[[piv, j]] = tmp;
7122 }
7123 }
7124 for i in (k + 1)..n {
7125 let factor = a[[i, k]] / a[[k, k]];
7126 for j in k..=n {
7127 a[[i, j]] -= factor * a[[k, j]];
7128 }
7129 }
7130 }
7131 let mut y = Array1::<f64>::zeros(n);
7132 for i in (0..n).rev() {
7133 let mut acc = a[[i, n]];
7134 for j in (i + 1)..n {
7135 acc -= a[[i, j]] * y[j];
7136 }
7137 y[i] = acc / a[[i, i]];
7138 }
7139 y
7140 }
7141
7142 fn invert_3x3_spd(h: &Array2<f64>) -> Array2<f64> {
7144 let mut out = Array2::<f64>::zeros((3, 3));
7145 for col in 0..3 {
7146 let mut e = Array1::<f64>::zeros(3);
7147 e[col] = 1.0;
7148 let v = solve_3x3_spd(h, &e);
7149 for row in 0..3 {
7150 out[[row, col]] = v[row];
7151 }
7152 }
7153 out
7154 }
7155
7156 fn invert_upper_triangular_3(q: &Array2<f64>) -> Array2<f64> {
7158 let a = q[[0, 1]];
7164 let b = q[[0, 2]];
7165 let c = q[[1, 2]];
7166 array![[1.0, -a, a * c - b], [0.0, 1.0, -c], [0.0, 0.0, 1.0]]
7167 }
7168
7169 fn coverage_correction_fixture() -> (UnifiedFitResult, Array2<f64>, Array1<f64>, Array1<f64>) {
7175 let beta = array![1.0];
7176 let cov = array![[0.25_f64]];
7177 let fit = test_fit_with_bias_correction(beta.clone(), cov.clone(), None);
7178 let x = array![[1.0_f64]];
7180 let offset = array![0.0_f64];
7181 (fit, x, beta, offset)
7182 }
7183
7184 fn corrections_baseline_options() -> PredictUncertaintyOptions {
7185 PredictUncertaintyOptions {
7186 confidence_level: 0.95,
7187 covariance_mode: InferenceCovarianceMode::Conditional,
7188 mean_interval_method: MeanIntervalMethod::TransformEta,
7189 includeobservation_interval: false,
7190 apply_bias_correction: false,
7191 edgeworth_one_sided: false,
7193 boundary_correction: false,
7194 ood_inflation: false,
7195 multi_point_joint: false,
7196 ..PredictUncertaintyOptions::default()
7197 }
7198 }
7199
    /// With every coverage correction disabled, the η interval must be the
    /// plain symmetric interval η̂ ± z·SE. Here z = Φ⁻¹(0.975),
    /// SE = √0.25 = 0.5 and η̂ = 1 on the single-coefficient fixture; all
    /// checks use a tolerance of 1e-12.
    #[test]
    fn coverage_corrections_all_off_matches_legacy() {
        let (fit, x, beta, offset) = coverage_correction_fixture();
        let opts = corrections_baseline_options();
        let pred = predict_gamwith_uncertainty(
            x.view(),
            beta.view(),
            offset.view(),
            crate::types::LikelihoodFamily::GaussianIdentity,
            &fit,
            &opts,
        )
        .expect("prediction baseline");

        // Two-sided 95% normal quantile.
        let z = standard_normal_quantile(0.5 + 0.5 * 0.95).unwrap();
        let expected_se = (0.25_f64).sqrt();
        assert!((pred.eta_standard_error[0] - expected_se).abs() <= 1e-12);
        // η̂ = 1.0 for the fixture (β = 1, x = 1, offset = 0).
        let expected_lower = 1.0 - z * expected_se;
        let expected_upper = 1.0 + z * expected_se;
        assert!(
            (pred.eta_lower[0] - expected_lower).abs() <= 1e-12,
            "baseline lower drifted: got {}, expected {}",
            pred.eta_lower[0],
            expected_lower
        );
        assert!(
            (pred.eta_upper[0] - expected_upper).abs() <= 1e-12,
            "baseline upper drifted: got {}, expected {}",
            pred.eta_upper[0],
            expected_upper
        );
    }
7235
    /// Edgeworth one-sided quantiles make the interval asymmetric under skew.
    ///
    /// With η skewness 0.6 the upper end must lie farther from η̂ = 1 than the
    /// lower end, by more than 1e-9. With skewness 0 and the correction still
    /// enabled, both distances must agree to 1e-12.
    #[test]
    fn edgeworth_one_sided_makes_interval_asymmetric_with_positive_skew() {
        let (fit, x, beta, offset) = coverage_correction_fixture();
        let mut opts = corrections_baseline_options();
        opts.edgeworth_one_sided = true;
        opts.eta_skewness_for_corrections = Some(array![0.6_f64]);

        let pred = predict_gamwith_uncertainty(
            x.view(),
            beta.view(),
            offset.view(),
            crate::types::LikelihoodFamily::GaussianIdentity,
            &fit,
            &opts,
        )
        .expect("edgeworth prediction");

        // Distances from the fixture's point prediction η̂ = 1.0.
        let dist_upper = pred.eta_upper[0] - 1.0;
        let dist_lower = 1.0 - pred.eta_lower[0];
        assert!(
            dist_upper > dist_lower + 1e-9,
            "positive skew should push upper tail further than lower: \
             upper-dist={dist_upper}, lower-dist={dist_lower}"
        );
        // Control: the same correction with zero skew must stay symmetric.
        opts.eta_skewness_for_corrections = Some(array![0.0_f64]);
        let pred_sym = predict_gamwith_uncertainty(
            x.view(),
            beta.view(),
            offset.view(),
            crate::types::LikelihoodFamily::GaussianIdentity,
            &fit,
            &opts,
        )
        .expect("edgeworth zero-skew prediction");
        let sym_upper = pred_sym.eta_upper[0] - 1.0;
        let sym_lower = 1.0 - pred_sym.eta_lower[0];
        assert!((sym_upper - sym_lower).abs() <= 1e-12);
    }
7279
    /// The boundary correction inflates SE only for rows near the edge of the
    /// training support.
    ///
    /// Both design rows are identical. Only the raw covariate coordinates
    /// supplied through `predictor_x_for_corrections` differ: 5.0 is interior
    /// and 9.9 lies 0.1 below the upper edge of the support [0, 10].
    /// * The interior row must keep the baseline SE of 0.5 (to 1e-12).
    /// * The near-edge row must get a strictly larger SE and a wider interval.
    #[test]
    fn boundary_correction_widens_interval_near_edge() {
        let beta = array![1.0_f64];
        let cov = array![[0.25_f64]];
        let fit = test_fit_with_bias_correction(beta.clone(), cov, None);
        // Identical design rows, so any SE difference comes from the correction.
        let x = array![[1.0_f64], [1.0_f64]];
        let offset = array![0.0_f64, 0.0_f64];

        let mut opts = corrections_baseline_options();
        opts.boundary_correction = true;
        // Covariate positions used by the correction, one row per prediction row.
        opts.predictor_x_for_corrections = Some(array![[5.0_f64], [9.9_f64]]);
        opts.training_support = Some(TrainingSupport {
            axis_min: array![0.0_f64],
            axis_max: array![10.0_f64],
        });

        let pred = predict_gamwith_uncertainty(
            x.view(),
            beta.view(),
            offset.view(),
            crate::types::LikelihoodFamily::GaussianIdentity,
            &fit,
            &opts,
        )
        .expect("boundary-corrected prediction");

        let baseline_se = (0.25_f64).sqrt();
        assert!(
            (pred.eta_standard_error[0] - baseline_se).abs() <= 1e-12,
            "interior row must not be inflated: {} vs {}",
            pred.eta_standard_error[0],
            baseline_se
        );
        assert!(
            pred.eta_standard_error[1] > baseline_se + 1e-9,
            "near-edge row must be inflated: got {}, baseline {}",
            pred.eta_standard_error[1],
            baseline_se
        );
        // The wider SE must carry through to the reported interval.
        let width0 = pred.eta_upper[0] - pred.eta_lower[0];
        let width1 = pred.eta_upper[1] - pred.eta_lower[1];
        assert!(
            width1 > width0 + 1e-9,
            "near-edge interval not wider: width0={width0}, width1={width1}"
        );
    }
7334
    /// The OOD inflation widens SE only for rows outside the training support.
    ///
    /// The two design rows are identical. Their covariate coordinates are 5.0,
    /// inside the support [0, 10], and 15.0, which is 5 beyond the upper edge.
    /// * The inside row must keep the baseline SE of 0.5.
    /// * The outside row must have SE exactly √(0.25 · 1.25), to 1e-12.
    #[test]
    fn ood_inflation_widens_interval_outside_support() {
        let beta = array![1.0_f64];
        let cov = array![[0.25_f64]];
        let fit = test_fit_with_bias_correction(beta.clone(), cov, None);
        let x = array![[1.0_f64], [1.0_f64]];
        let offset = array![0.0_f64, 0.0_f64];

        let mut opts = corrections_baseline_options();
        opts.ood_inflation = true;
        opts.predictor_x_for_corrections = Some(array![[5.0_f64], [15.0_f64]]);
        opts.training_support = Some(TrainingSupport {
            axis_min: array![0.0_f64],
            axis_max: array![10.0_f64],
        });

        let pred = predict_gamwith_uncertainty(
            x.view(),
            beta.view(),
            offset.view(),
            crate::types::LikelihoodFamily::GaussianIdentity,
            &fit,
            &opts,
        )
        .expect("ood-inflated prediction");

        let baseline_se = (0.25_f64).sqrt();
        assert!((pred.eta_standard_error[0] - baseline_se).abs() <= 1e-12);
        // Expected variance factor 1.25 is hard-coded. NOTE(review): this looks
        // like 1 + (excess / range)² = 1 + (5/10)² under the implementation's
        // distance rule. Confirm against `ood_variance_inflation_factor`.
        let expected = (0.25_f64 * 1.25).sqrt();
        assert!(
            (pred.eta_standard_error[1] - expected).abs() <= 1e-12,
            "ood inflation factor wrong: got {}, expected {}",
            pred.eta_standard_error[1],
            expected
        );
        assert!(pred.eta_standard_error[1] > baseline_se);
    }
7376
    /// The multi-point joint adjustment widens every interval to Bonferroni coverage.
    ///
    /// There are m = 5 identical rows at a 95% level (α = 0.05). The joint
    /// quantile is Φ⁻¹(1 − α/(2m)) = Φ⁻¹(0.995), which must exceed the
    /// per-row Φ⁻¹(0.975). Every row's interval width must equal
    /// 2·z_joint·SE with SE = 0.5, to 1e-12.
    #[test]
    fn multi_point_joint_widens_interval_relative_to_per_row() {
        let beta = array![1.0_f64];
        let cov = array![[0.25_f64]];
        let fit = test_fit_with_bias_correction(beta.clone(), cov, None);
        // Five identical rows: m = 5 drives the Bonferroni split below.
        let x = Array2::<f64>::from_elem((5, 1), 1.0_f64);
        let offset = Array1::zeros(5);
        let mut opts = corrections_baseline_options();
        opts.multi_point_joint = true;
        let pred = predict_gamwith_uncertainty(
            x.view(),
            beta.view(),
            offset.view(),
            crate::types::LikelihoodFamily::GaussianIdentity,
            &fit,
            &opts,
        )
        .expect("joint-adjusted prediction");

        let z_per_row = standard_normal_quantile(0.5 + 0.5 * 0.95).unwrap();
        // Two-sided level 1 − α/m with α = 0.05 and m = 5 rows.
        let z_joint = standard_normal_quantile(0.5 + 0.5 * (1.0 - 0.05_f64 / 5.0)).unwrap();
        assert!(
            z_joint > z_per_row + 1e-6,
            "Bonferroni z must exceed per-row z: joint={z_joint}, per-row={z_per_row}"
        );
        let baseline_se = (0.25_f64).sqrt();
        for i in 0..5 {
            let width = pred.eta_upper[i] - pred.eta_lower[i];
            let expected = 2.0 * z_joint * baseline_se;
            assert!(
                (width - expected).abs() <= 1e-12,
                "joint row {i} width mismatch: got {width}, expected {expected}"
            );
        }
    }
7417
7418 #[test]
7419 fn edgeworth_helper_zero_skew_returns_central_z() {
7420 let z = 1.96_f64;
7421 let adj = edgeworth_one_sided_quantile(z, 0.0);
7422 assert!((adj.z_lower - z).abs() <= 1e-12);
7423 assert!((adj.z_upper - z).abs() <= 1e-12);
7424 }
7425
7426 #[test]
7427 fn boundary_helper_returns_one_in_interior() {
7428 let f = boundary_variance_inflation_factor(
7429 array![5.0_f64].view(),
7430 array![0.0_f64].view(),
7431 array![10.0_f64].view(),
7432 0.25,
7433 0.05,
7434 );
7435 assert!((f - 1.0).abs() <= 1e-12);
7436 }
7437
7438 #[test]
7439 fn ood_helper_returns_one_inside_box() {
7440 let f = ood_variance_inflation_factor(
7441 array![5.0_f64].view(),
7442 array![0.0_f64].view(),
7443 array![10.0_f64].view(),
7444 1.0,
7445 );
7446 assert!((f - 1.0).abs() <= 1e-12);
7447 }
7448
7449 #[test]
7450 fn multi_point_joint_z_passthrough_at_m_one() {
7451 let z1 = multi_point_joint_z(0.95, 1).unwrap();
7452 let z_baseline = standard_normal_quantile(0.5 + 0.5 * 0.95).unwrap();
7453 assert!((z1 - z_baseline).abs() <= 1e-12);
7454 }
7455}