1use crate::bms::{
2 EmpiricalZGrid, LatentMeasureKind, LatentZConditionalCalibration, LatentZRankIntCalibration,
3 bernoulli_marginal_link_map, empirical_intercept_from_marginal,
4};
5use crate::marginal_slope_shared::{
6 ObservedDenestedCellPartials, eval_coeff4_at,
7 probit_frailty_scale as marginal_slope_probit_frailty_scale, scale_coeff4,
8};
9use crate::survival::lognormal_kernel::FrailtySpec;
10use crate::inference::model::{SavedCompiledFlexBlock, SavedLatentZNormalization};
11use gam_linalg::matrix::DesignMatrix;
12use gam_math::probability::{normal_cdf, normal_pdf};
13use gam_solve::estimate::{EstimationError, UnifiedFitResult};
14use gam_problem::types::{InverseLink, LikelihoodSpec};
15use ndarray::{Array1, Array2, ArrayView1};
16use rayon::iter::{IntoParallelIterator, ParallelIterator};
17
/// Output of a prediction pass: a linear-predictor vector and a mean vector.
///
/// NOTE(review): the code that fills these arrays is outside this view; both are
/// presumably one entry per prediction row — confirm against the producer.
pub struct PredictResult {
    /// Linear-predictor values.
    pub eta: Array1<f64>,
    /// Mean values (presumably the inverse link applied to `eta` — confirm).
    pub mean: Array1<f64>,
}
22
23pub struct PredictInput {
26 pub design: DesignMatrix,
28 pub offset: Array1<f64>,
30 pub design_noise: Option<DesignMatrix>,
32 pub offset_noise: Option<Array1<f64>>,
34 pub auxiliary_scalar: Option<Array1<f64>>,
36 pub auxiliary_matrix: Option<Array2<f64>>,
38}
39
/// Saved state needed to replay a Bernoulli marginal-slope fit at prediction time.
///
/// Instances are normally built by `from_unified`, which validates the saved runtimes
/// and pulls the coefficient blocks out of a `UnifiedFitResult`.
pub struct BernoulliMarginalSlopePredictor {
    /// Coefficients of the marginal block (block 0 of the unified fit).
    pub beta_marginal: Array1<f64>,
    /// Coefficients of the log-slope block (block 1 of the unified fit).
    pub beta_logslope: Array1<f64>,
    /// Score-warp coefficients; present exactly when `from_unified` was given a score-warp runtime.
    pub beta_score_warp: Option<Array1<f64>>,
    /// Link-deviation coefficients; present exactly when `from_unified` was given a
    /// link-deviation runtime.
    pub beta_link_dev: Option<Array1<f64>>,
    /// Base inverse link; `from_unified` only accepts the standard probit link.
    pub base_link: InverseLink,
    /// Name of the latent score column (presumably looked up in the prediction data — confirm).
    pub z_column: String,
    /// Saved normalization for the latent score; validated in `from_unified`.
    pub latent_z_normalization: SavedLatentZNormalization,
    /// Latent measure used for calibration: standard normal, a global empirical grid,
    /// or local empirical grids mixed per row.
    pub latent_measure: LatentMeasureKind,
    /// Baseline value for the marginal component (its use is outside this view).
    pub baseline_marginal: f64,
    /// Baseline value for the log-slope component (its use is outside this view).
    pub baseline_logslope: f64,
    /// Coefficient covariance taken from the unified fit, when available.
    pub covariance: Option<Array2<f64>>,
    /// Compiled score-warp block used to evaluate bases, breakpoints and local cubics.
    pub score_warp_runtime: Option<SavedCompiledFlexBlock>,
    /// Compiled link-deviation block used to evaluate bases, breakpoints and local cubics.
    pub link_deviation_runtime: Option<SavedCompiledFlexBlock>,
    /// Fixed Gaussian-shift frailty standard deviation; `None` when the fit had no frailty.
    pub gaussian_frailty_sd: Option<f64>,
    /// Optional element-wise latent-z calibration applied at prediction time.
    pub latent_z_calibration: Option<LatentZRankIntCalibration>,
    /// Optional latent-z calibration that also conditions on the dense marginal design.
    pub latent_z_conditional_calibration: Option<LatentZConditionalCalibration>,
}
58
/// Picks how many prediction rows to process per chunk.
///
/// The per-row working set is estimated as
/// `max(parameter_dim, 1) * max(local_dim, 1) * size_of::<f64>() * 4` bytes (saturating).
/// The chunk is sized so that about 2 MiB of such rows fit, clamped to `16..=4096`
/// rows and never larger than `n_rows`. Returns 1 when `n_rows` is 0.
fn prediction_chunk_rows(parameter_dim: usize, local_dim: usize, n_rows: usize) -> usize {
    const PREDICTION_TARGET_WORK_BYTES: usize = 2 * 1024 * 1024;
    const PREDICTION_MIN_CHUNK_ROWS: usize = 16;
    const PREDICTION_MAX_CHUNK_ROWS: usize = 4096;
    if n_rows == 0 {
        return 1;
    }
    // Every factor is at least 1, so `bytes_per_row` is never zero and the division
    // below cannot trap; saturation only makes the quotient smaller.
    let bytes_per_row = parameter_dim
        .max(1)
        .saturating_mul(local_dim.max(1))
        .saturating_mul(std::mem::size_of::<f64>())
        .saturating_mul(4);
    (PREDICTION_TARGET_WORK_BYTES / bytes_per_row)
        .clamp(PREDICTION_MIN_CHUNK_ROWS, PREDICTION_MAX_CHUNK_ROWS)
        .min(n_rows)
}
81
/// Predict-time anchor material for the score-warp and link-deviation blocks.
///
/// The `Default` value (all `None`) is what `build_anchor_correction_matrices` returns
/// when neither saved runtime carries an anchor correction.
#[derive(Default)]
struct BmsAnchorCorrections {
    /// Dense `[marginal | logslope]` design rows used as the score-warp anchor input.
    score_warp_anchor_rows: Option<Array2<f64>>,
    /// Anchor rows for the link-deviation runtime: the parametric rows, optionally
    /// followed by the score-warp basis columns when the saved layout has a
    /// FlexEvaluation tail.
    link_dev_anchor_rows: Option<Array2<f64>>,
    /// Score-warp correction matrix; row `i` is dotted with the score-warp
    /// coefficients and subtracted from the span constant for prediction row `i`.
    score_warp: Option<Array2<f64>>,
    /// Link-deviation correction matrix, used the same way with the link-deviation
    /// coefficients.
    link_dev: Option<Array2<f64>>,
}
114
115impl BmsAnchorCorrections {
116 fn score_warp_row(&self, row: usize) -> Option<ndarray::ArrayView1<'_, f64>> {
117 self.score_warp.as_ref().map(|m| m.row(row))
118 }
119
120 fn link_dev_row(&self, row: usize) -> Option<ndarray::ArrayView1<'_, f64>> {
121 self.link_dev.as_ref().map(|m| m.row(row))
122 }
123
124 fn score_warp_anchor_rows_view(&self) -> Option<ndarray::ArrayView2<'_, f64>> {
125 self.score_warp_anchor_rows.as_ref().map(|m| m.view())
126 }
127
128 fn link_dev_anchor_rows_view(&self) -> Option<ndarray::ArrayView2<'_, f64>> {
129 self.link_dev_anchor_rows.as_ref().map(|m| m.view())
130 }
131}
132
133impl BernoulliMarginalSlopePredictor {
134 fn build_anchor_correction_matrices(
143 &self,
144 input: &PredictInput,
145 design_logslope: &DesignMatrix,
146 z: &Array1<f64>,
147 ) -> Result<BmsAnchorCorrections, EstimationError> {
148 use crate::inference::model::SavedAnchorKind;
149 let needs_score = self
150 .score_warp_runtime
151 .as_ref()
152 .is_some_and(|r| r.anchor_correction.is_some());
153 let needs_link = self
154 .link_deviation_runtime
155 .as_ref()
156 .is_some_and(|r| r.anchor_correction.is_some());
157 if !needs_score && !needs_link {
158 return Ok(BmsAnchorCorrections::default());
159 }
160 let marginal_dense = input
165 .design
166 .try_to_dense_arc(
167 "bernoulli marginal-slope predict-time marginal anchor materialisation",
168 )
169 .map_err(EstimationError::InvalidInput)?;
170 let logslope_dense = design_logslope
171 .try_to_dense_arc(
172 "bernoulli marginal-slope predict-time logslope anchor materialisation",
173 )
174 .map_err(EstimationError::InvalidInput)?;
175 let n_rows = marginal_dense.nrows();
176 if logslope_dense.nrows() != n_rows {
177 return Err(EstimationError::InvalidInput(format!(
178 "bernoulli marginal-slope predict anchor materialisation row mismatch: marginal {} vs logslope {}",
179 n_rows,
180 logslope_dense.nrows()
181 )));
182 }
183 if z.len() != n_rows {
184 return Err(EstimationError::InvalidInput(format!(
185 "bernoulli marginal-slope predict anchor materialisation: z has {} entries, expected {}",
186 z.len(),
187 n_rows
188 )));
189 }
190 let p_marginal = marginal_dense.ncols();
191 let p_logslope = logslope_dense.ncols();
192 let d_parametric = p_marginal + p_logslope;
193 let mut parametric_rows = Array2::<f64>::zeros((n_rows, d_parametric));
194 parametric_rows
195 .slice_mut(ndarray::s![.., 0..p_marginal])
196 .assign(&marginal_dense.view());
197 parametric_rows
198 .slice_mut(ndarray::s![.., p_marginal..d_parametric])
199 .assign(&logslope_dense.view());
200
201 let score_warp = if needs_score {
204 let runtime = self.score_warp_runtime.as_ref().unwrap();
205 self.validate_runtime_anchor_layout_parametric_only(runtime, "score_warp")?;
206 runtime
207 .anchor_correction_matrix(parametric_rows.view())
208 .map_err(EstimationError::from)?
209 } else {
210 None
211 };
212
213 let (link_dev_anchor_rows, link_dev) = if needs_link {
218 let runtime = self.link_deviation_runtime.as_ref().unwrap();
219 let mut saw_flex_tail = false;
224 let mut flex_tail_ncols: usize = 0;
225 for (idx, component) in runtime.anchor_components.iter().enumerate() {
226 match &component.kind {
227 SavedAnchorKind::Parametric { .. } => {
228 if saw_flex_tail {
229 return Err(EstimationError::InvalidInput(format!(
230 "bernoulli marginal-slope link-deviation saved anchor components \
231 are out of order: parametric component at index {idx} follows \
232 a FlexEvaluation tail",
233 )));
234 }
235 }
236 SavedAnchorKind::FlexEvaluation { ncols } => {
237 if saw_flex_tail {
238 return Err(EstimationError::InvalidInput(
239 "bernoulli marginal-slope link-deviation saved anchor components \
240 carry more than one FlexEvaluation tail; fit-time stacking emits \
241 at most one (score-warp)"
242 .to_string(),
243 ));
244 }
245 saw_flex_tail = true;
246 flex_tail_ncols = *ncols;
247 }
248 }
249 }
250 let rows = if saw_flex_tail {
251 let score_runtime = self.score_warp_runtime.as_ref().ok_or_else(|| {
252 EstimationError::InvalidInput(
253 "bernoulli marginal-slope link-deviation saved anchor includes a \
254 FlexEvaluation tail but the saved score-warp runtime is missing"
255 .to_string(),
256 )
257 })?;
258 let score_basis = if score_runtime.anchor_correction.is_some() {
264 score_runtime
265 .design_with_anchor_rows(z, parametric_rows.view())
266 .map_err(EstimationError::from)?
267 } else {
268 score_runtime.design(z).map_err(EstimationError::from)?
269 };
270 if score_basis.ncols() != flex_tail_ncols {
271 return Err(EstimationError::InvalidInput(format!(
272 "bernoulli marginal-slope link-deviation FlexEvaluation tail expects \
273 {} score-warp basis columns at predict rows, got {}",
274 flex_tail_ncols,
275 score_basis.ncols()
276 )));
277 }
278 let mut combined = Array2::<f64>::zeros((n_rows, d_parametric + flex_tail_ncols));
279 combined
280 .slice_mut(ndarray::s![.., 0..d_parametric])
281 .assign(¶metric_rows.view());
282 combined
283 .slice_mut(ndarray::s![.., d_parametric..])
284 .assign(&score_basis.view());
285 combined
286 } else {
287 parametric_rows.clone()
288 };
289 let corr = runtime
290 .anchor_correction_matrix(rows.view())
291 .map_err(EstimationError::from)?;
292 (Some(rows), corr)
293 } else {
294 (None, None)
295 };
296
297 Ok(BmsAnchorCorrections {
298 score_warp_anchor_rows: Some(parametric_rows),
299 link_dev_anchor_rows,
300 score_warp,
301 link_dev,
302 })
303 }
304
305 fn validate_runtime_anchor_layout_parametric_only(
309 &self,
310 runtime: &SavedCompiledFlexBlock,
311 runtime_label: &str,
312 ) -> Result<(), EstimationError> {
313 use crate::inference::model::SavedAnchorKind;
314 for (idx, component) in runtime.anchor_components.iter().enumerate() {
315 match &component.kind {
316 SavedAnchorKind::Parametric { .. } => {}
317 SavedAnchorKind::FlexEvaluation { .. } => {
318 return Err(EstimationError::InvalidInput(format!(
319 "bernoulli marginal-slope {runtime_label} saved anchor component at \
320 index {idx} is FlexEvaluation; only Parametric components are \
321 expected for this runtime",
322 )));
323 }
324 }
325 }
326 Ok(())
327 }
328
    /// Returns the likelihood specification for this predictor, which is always
    /// `LikelihoodSpec::binomial_probit()`.
    pub fn likelihood_family(&self) -> LikelihoodSpec {
        LikelihoodSpec::binomial_probit()
    }
332
333 pub fn mean_from_eta(&self, eta: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
334 Ok(eta.mapv(normal_cdf))
335 }
336
337 pub fn mean_derivative_from_eta(
338 &self,
339 eta: &Array1<f64>,
340 ) -> Result<Array1<f64>, EstimationError> {
341 Ok(eta.mapv(normal_pdf))
342 }
343
    /// Scale factor obtained by passing `gaussian_frailty_sd` to the shared
    /// marginal-slope `probit_frailty_scale` helper.
    pub(crate) fn probit_frailty_scale(&self) -> f64 {
        marginal_slope_probit_frailty_scale(self.gaussian_frailty_sd)
    }
347
348 fn apply_latent_z_calibration(&self, z: &Array1<f64>) -> Array1<f64> {
366 match &self.latent_z_calibration {
367 Some(cal) => Array1::from_iter(z.iter().map(|&zi| cal.apply_at_predict(zi))),
368 None => z.clone(),
369 }
370 }
371
372 fn apply_latent_z_conditional_calibration(
383 &self,
384 z: &Array1<f64>,
385 input: &PredictInput,
386 ) -> Result<Array1<f64>, EstimationError> {
387 let Some(cal) = self.latent_z_conditional_calibration.as_ref() else {
388 return Ok(z.clone());
389 };
390 let a_block = input.design.to_dense();
391 cal.apply(z.view(), a_block.view())
392 .map_err(EstimationError::InvalidInput)
393 }
394
395 fn rigid_intercept_from_marginal(&self, marginal_eta: f64, slope: f64) -> f64 {
396 let probit_scale = self.probit_frailty_scale();
397 marginal_eta * (1.0 + (probit_scale * slope).powi(2)).sqrt() / probit_scale
398 }
399
400 fn empirical_rigid_intercept_and_gradient(
401 &self,
402 marginal_eta: f64,
403 slope: f64,
404 nodes: &[f64],
405 weights: &[f64],
406 ) -> Result<(f64, f64, f64), EstimationError> {
407 let marginal = bernoulli_marginal_link_map(&self.base_link, marginal_eta)
408 .map_err(EstimationError::InvalidInput)?;
409 let scale = self.probit_frailty_scale();
410 let intercept = empirical_intercept_from_marginal(
411 marginal.mu,
412 marginal.q,
413 slope,
414 scale,
415 nodes,
416 weights,
417 None,
418 )
419 .map_err(EstimationError::InvalidInput)?;
420 let observed_slope = scale * slope;
421 let mut f_a = 0.0;
422 let mut f_b = 0.0;
423 for (&node, &weight) in nodes.iter().zip(weights.iter()) {
424 let eta = intercept + observed_slope * node;
425 let pdf = normal_pdf(eta);
426 f_a += weight * pdf;
427 f_b += weight * pdf * scale * node;
428 }
429 if !(f_a.is_finite() && f_a > 0.0 && f_b.is_finite()) {
430 return Err(EstimationError::InvalidInput(format!(
431 "empirical latent prediction calibration derivative is invalid: F_a={f_a}, F_b={f_b}"
432 )));
433 }
434 let a_marginal_eta = marginal.mu1 / f_a;
435 let a_slope = -f_b / f_a;
436 Ok((intercept, a_marginal_eta, a_slope))
437 }
438
439 fn local_empirical_mixture_for_point(
440 point: &[f64],
441 centers: &[Vec<f64>],
442 top_k: usize,
443 bandwidth: f64,
444 ) -> Result<Vec<(usize, f64)>, EstimationError> {
445 if centers.is_empty() {
446 return Err(EstimationError::InvalidInput(
447 "local empirical latent prediction has no centers".to_string(),
448 ));
449 }
450 if top_k == 0 {
451 return Err(EstimationError::InvalidInput(
452 "local empirical latent prediction top_k must be positive".to_string(),
453 ));
454 }
455 if !(bandwidth.is_finite() && bandwidth > 0.0) {
456 return Err(EstimationError::InvalidInput(format!(
457 "local empirical latent prediction bandwidth must be finite and positive, got {bandwidth}"
458 )));
459 }
460 let bw2 = bandwidth * bandwidth;
461 let mut distances = Vec::<(usize, f64)>::with_capacity(centers.len());
462 for (idx, center) in centers.iter().enumerate() {
463 if center.len() != point.len() {
464 return Err(EstimationError::InvalidInput(format!(
465 "local empirical latent prediction center {idx} dimension mismatch: center={}, point={}",
466 center.len(),
467 point.len()
468 )));
469 }
470 let d2 = center
471 .iter()
472 .zip(point.iter())
473 .map(|(&c, &x)| {
474 let delta = x - c;
475 delta * delta
476 })
477 .sum::<f64>();
478 if !d2.is_finite() {
479 return Err(EstimationError::InvalidInput(
480 "local empirical latent prediction distance is non-finite".to_string(),
481 ));
482 }
483 distances.push((idx, d2));
484 }
485 distances.sort_by(|left, right| {
486 left.1
487 .partial_cmp(&right.1)
488 .expect("validated local empirical distances are finite")
489 });
490 let k = top_k.min(distances.len());
491 let mut mixture = Vec::with_capacity(k);
492 let mut total = 0.0;
493 for &(idx, d2) in distances.iter().take(k) {
494 let weight = (-0.5 * d2 / bw2).exp().max(1e-300);
495 mixture.push((idx, weight));
496 total += weight;
497 }
498 if !(total.is_finite() && total > 0.0) {
499 return Err(EstimationError::InvalidInput(
500 "local empirical latent prediction mixture has non-positive total weight"
501 .to_string(),
502 ));
503 }
504 for (_, weight) in &mut mixture {
505 *weight /= total;
506 }
507 Ok(mixture)
508 }
509
510 fn combine_empirical_grids(
511 grids: &[EmpiricalZGrid],
512 mixture: &[(usize, f64)],
513 ) -> Result<EmpiricalZGrid, EstimationError> {
514 let total_len = mixture
515 .iter()
516 .map(|&(idx, _)| grids.get(idx).map_or(0, |grid| grid.nodes.len()))
517 .sum::<usize>();
518 let mut nodes = Vec::with_capacity(total_len);
519 let mut weights = Vec::with_capacity(total_len);
520 let mut total_weight = 0.0;
521 for &(grid_idx, grid_weight) in mixture {
522 if !(grid_weight.is_finite() && grid_weight >= 0.0) {
523 return Err(EstimationError::InvalidInput(format!(
524 "local empirical latent prediction mixture weight must be finite and non-negative, got {grid_weight}"
525 )));
526 }
527 let grid = grids.get(grid_idx).ok_or_else(|| {
528 EstimationError::InvalidInput(format!(
529 "local empirical latent prediction grid index {grid_idx} is out of bounds for {} grids",
530 grids.len()
531 ))
532 })?;
533 if grid.nodes.len() != grid.weights.len() || grid.nodes.is_empty() {
534 return Err(EstimationError::InvalidInput(format!(
535 "local empirical latent prediction grid {grid_idx} is invalid: nodes={}, weights={}",
536 grid.nodes.len(),
537 grid.weights.len()
538 )));
539 }
540 for (node, weight) in grid.pairs() {
541 let combined_weight = grid_weight * weight;
542 if !(node.is_finite() && combined_weight.is_finite() && combined_weight >= 0.0) {
543 return Err(EstimationError::InvalidInput(
544 "local empirical latent prediction grid contains invalid node/weight"
545 .to_string(),
546 ));
547 }
548 nodes.push(node);
549 weights.push(combined_weight);
550 total_weight += combined_weight;
551 }
552 }
553 if !(total_weight.is_finite() && total_weight > 0.0) {
554 return Err(EstimationError::InvalidInput(
555 "local empirical latent prediction combined grid has non-positive total weight"
556 .to_string(),
557 ));
558 }
559 for weight in &mut weights {
560 *weight /= total_weight;
561 }
562 Ok(EmpiricalZGrid { nodes, weights })
563 }
564
565 fn empirical_grid_for_prediction_row(
566 &self,
567 input: &PredictInput,
568 row: usize,
569 ) -> Result<Option<EmpiricalZGrid>, EstimationError> {
570 match &self.latent_measure {
571 LatentMeasureKind::StandardNormal => Ok(None),
572 LatentMeasureKind::GlobalEmpirical { grid } => Ok(Some(grid.clone())),
573 LatentMeasureKind::LocalEmpirical {
574 centers,
575 grids,
576 top_k,
577 bandwidth,
578 ..
579 } => {
580 let conditioning = input.auxiliary_matrix.as_ref().ok_or_else(|| {
581 EstimationError::InvalidInput(
582 "bernoulli marginal-slope local empirical prediction requires auxiliary conditioning matrix"
583 .to_string(),
584 )
585 })?;
586 if row >= conditioning.nrows() {
587 return Err(EstimationError::InvalidInput(format!(
588 "local empirical latent prediction row {row} is out of bounds for {} conditioning rows",
589 conditioning.nrows()
590 )));
591 }
592 let expected_dim = centers.first().map_or(0, Vec::len);
593 if conditioning.ncols() != expected_dim {
594 return Err(EstimationError::InvalidInput(format!(
595 "local empirical latent prediction conditioning dimension mismatch: got {}, expected {expected_dim}",
596 conditioning.ncols()
597 )));
598 }
599 let point = conditioning.row(row).to_vec();
600 let mixture =
601 Self::local_empirical_mixture_for_point(&point, centers, *top_k, *bandwidth)?;
602 Self::combine_empirical_grids(grids, &mixture).map(Some)
603 }
604 }
605 }
606
    /// Converts an internal linear predictor (and optional gradient) to the base
    /// scale. Currently the identity: both inputs are returned unchanged and the
    /// call never fails.
    fn transform_internal_eta_to_base_scale(
        &self,
        internal_eta: Array1<f64>,
        internal_grad: Option<Array2<f64>>,
    ) -> Result<(Array1<f64>, Option<Array2<f64>>), EstimationError> {
        Ok((internal_eta, internal_grad))
    }
614
615 fn link_terms_value_d1(
616 &self,
617 eta0: &Array1<f64>,
618 beta_link_dev: Option<&Array1<f64>>,
619 link_dev_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
620 ) -> Result<(Array1<f64>, Array1<f64>), EstimationError> {
621 if let (Some(runtime), Some(beta)) = (&self.link_deviation_runtime, beta_link_dev) {
622 let basis = runtime
630 .design_uncorrected(eta0)
631 .map_err(EstimationError::from)?;
632 let mut value = &basis.dot(beta) + eta0;
633 if let Some(corr) = link_dev_correction_for_row {
634 let offset = corr.dot(beta);
635 for v in value.iter_mut() {
636 *v -= offset;
637 }
638 } else if runtime.anchor_correction.is_some() {
639 return Err(EstimationError::InvalidInput(
640 "bernoulli marginal-slope link-deviation runtime has an anchor residual but \
641 no per-row correction was supplied to link_terms_value_d1"
642 .to_string(),
643 ));
644 }
645 let d1 = runtime
646 .first_derivative_design(eta0)
647 .map_err(EstimationError::from)?;
648 Ok((value, d1.dot(beta) + 1.0))
649 } else {
650 Ok((eta0.clone(), Array1::ones(eta0.len())))
651 }
652 }
653
654 fn denested_partition_cells(
655 &self,
656 a: f64,
657 b: f64,
658 beta_score_warp: Option<&Array1<f64>>,
659 beta_link_dev: Option<&Array1<f64>>,
660 score_warp_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
661 link_dev_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
662 ) -> Result<Vec<crate::cubic_cell_kernel::DenestedPartitionCell>, EstimationError> {
663 let score_breaks = if let Some(runtime) = self.score_warp_runtime.as_ref() {
664 runtime.breakpoints().map_err(EstimationError::from)?
665 } else {
666 Vec::new()
667 };
668 let link_breaks = if let Some(runtime) = self.link_deviation_runtime.as_ref() {
669 runtime.breakpoints().map_err(EstimationError::from)?
670 } else {
671 Vec::new()
672 };
673 let mut cells =
674 crate::cubic_cell_kernel::build_denested_partition_cells_with_tails(
675 a,
676 b,
677 &score_breaks,
678 &link_breaks,
679 |z| {
680 if let (Some(runtime), Some(beta)) =
681 (self.score_warp_runtime.as_ref(), beta_score_warp)
682 {
683 let mut span = runtime.local_cubic_at(beta, z)?;
684 if let Some(corr) = score_warp_correction_for_row {
691 span.c0 -= corr.dot(beta);
692 }
693 Ok(span)
694 } else {
695 Ok(crate::cubic_cell_kernel::LocalSpanCubic {
696 left: 0.0,
697 right: 1.0,
698 c0: 0.0,
699 c1: 0.0,
700 c2: 0.0,
701 c3: 0.0,
702 })
703 }
704 },
705 |u| {
706 if let (Some(runtime), Some(beta)) =
707 (self.link_deviation_runtime.as_ref(), beta_link_dev)
708 {
709 let mut span = runtime.local_cubic_at(beta, u)?;
710 if let Some(corr) = link_dev_correction_for_row {
711 span.c0 -= corr.dot(beta);
712 }
713 Ok(span)
714 } else {
715 Ok(crate::cubic_cell_kernel::LocalSpanCubic {
716 left: 0.0,
717 right: 1.0,
718 c0: 0.0,
719 c1: 0.0,
720 c2: 0.0,
721 c3: 0.0,
722 })
723 }
724 },
725 )
726 .map_err(EstimationError::InvalidInput)?;
727 let scale = self.probit_frailty_scale();
728 if scale != 1.0 {
729 for partition_cell in &mut cells {
730 partition_cell.cell.c0 *= scale;
731 partition_cell.cell.c1 *= scale;
732 partition_cell.cell.c2 *= scale;
733 partition_cell.cell.c3 *= scale;
734 }
735 }
736 Ok(cells)
737 }
738
739 fn evaluate_denested_calibration(
740 &self,
741 a: f64,
742 marginal_eta: f64,
743 slope: f64,
744 beta_score_warp: Option<&Array1<f64>>,
745 beta_link_dev: Option<&Array1<f64>>,
746 score_warp_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
747 link_dev_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
748 ) -> Result<(f64, f64, f64), EstimationError> {
749 let marginal = bernoulli_marginal_link_map(&self.base_link, marginal_eta)
750 .map_err(EstimationError::InvalidInput)?;
751 let cells = self.denested_partition_cells(
752 a,
753 slope,
754 beta_score_warp,
755 beta_link_dev,
756 score_warp_correction_for_row,
757 link_dev_correction_for_row,
758 )?;
759 let scale = self.probit_frailty_scale();
760 let mut f = -marginal.mu;
761 let mut f_a = 0.0;
762 let mut f_aa = 0.0;
763 for partition_cell in cells {
764 let cell = partition_cell.cell;
765 let (dc_da_raw, _) =
766 crate::cubic_cell_kernel::denested_cell_coefficient_partials(
767 partition_cell.score_span,
768 partition_cell.link_span,
769 a,
770 slope,
771 );
772 let (d2c_da2_raw, _, _) =
773 crate::cubic_cell_kernel::denested_cell_second_partials(
774 partition_cell.score_span,
775 partition_cell.link_span,
776 a,
777 slope,
778 );
779 let dc_da = scale_coeff4(dc_da_raw, scale);
780 let d2c_da2 = scale_coeff4(d2c_da2_raw, scale);
781 let max_degree =
787 crate::cubic_cell_kernel::cell_second_derivative_required_max_degree(
788 &dc_da, &dc_da, &d2c_da2,
789 );
790 let state = crate::cubic_cell_kernel::evaluate_cell_moments(cell, max_degree)
791 .map_err(EstimationError::InvalidInput)?;
792 f += state.value;
793 f_a += crate::cubic_cell_kernel::cell_first_derivative_from_moments(
794 &dc_da,
795 &state.moments,
796 )
797 .map_err(EstimationError::InvalidInput)?;
798 f_aa += crate::cubic_cell_kernel::cell_second_derivative_from_moments(
799 cell,
800 &dc_da,
801 &dc_da,
802 &d2c_da2,
803 &state.moments,
804 )
805 .map_err(EstimationError::InvalidInput)?;
806 }
807 Ok((f, f_a, f_aa))
808 }
809
810 fn observed_denested_cell_partials_at_z(
811 &self,
812 z_value: f64,
813 a: f64,
814 b: f64,
815 beta_score_warp: Option<&Array1<f64>>,
816 beta_link_dev: Option<&Array1<f64>>,
817 score_warp_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
818 link_dev_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
819 ) -> Result<ObservedDenestedCellPartials, EstimationError> {
820 use crate::cubic_cell_kernel as exact;
821
822 let zero_span = exact::LocalSpanCubic {
823 left: 0.0,
824 right: 1.0,
825 c0: 0.0,
826 c1: 0.0,
827 c2: 0.0,
828 c3: 0.0,
829 };
830 let u_value = a + b * z_value;
831 let score_span = if let (Some(runtime), Some(beta)) =
832 (self.score_warp_runtime.as_ref(), beta_score_warp)
833 {
834 let mut span = runtime
835 .local_cubic_at(beta, z_value)
836 .map_err(EstimationError::from)?;
837 if let Some(corr) = score_warp_correction_for_row {
838 span.c0 -= corr.dot(beta);
839 }
840 span
841 } else {
842 zero_span
843 };
844 let link_span = if let (Some(runtime), Some(beta)) =
845 (self.link_deviation_runtime.as_ref(), beta_link_dev)
846 {
847 let mut span = runtime
848 .local_cubic_at(beta, u_value)
849 .map_err(EstimationError::from)?;
850 if let Some(corr) = link_dev_correction_for_row {
851 span.c0 -= corr.dot(beta);
852 }
853 span
854 } else {
855 zero_span
856 };
857 let scale = self.probit_frailty_scale();
858 let coeff = scale_coeff4(
859 exact::denested_cell_coefficients(score_span, link_span, a, b),
860 scale,
861 );
862 let (dc_da_raw, dc_db_raw) =
863 exact::denested_cell_coefficient_partials(score_span, link_span, a, b);
864 let (dc_daa_raw, dc_dab_raw, dc_dbb_raw) =
865 exact::denested_cell_second_partials(score_span, link_span, a, b);
866 let (dc_daaa, dc_daab, dc_dabb, dc_dbbb) = exact::denested_cell_third_partials(link_span);
867 Ok(ObservedDenestedCellPartials {
868 coeff,
869 dc_da: scale_coeff4(dc_da_raw, scale),
870 dc_db: scale_coeff4(dc_db_raw, scale),
871 dc_daa: scale_coeff4(dc_daa_raw, scale),
872 dc_dab: scale_coeff4(dc_dab_raw, scale),
873 dc_dbb: scale_coeff4(dc_dbb_raw, scale),
874 dc_daaa: scale_coeff4(dc_daaa, scale),
875 dc_daab: scale_coeff4(dc_daab, scale),
876 dc_dabb: scale_coeff4(dc_dabb, scale),
877 dc_dbbb: scale_coeff4(dc_dbbb, scale),
878 })
879 }
880
881 fn evaluate_empirical_denested_calibration(
882 &self,
883 a: f64,
884 marginal_eta: f64,
885 slope: f64,
886 beta_score_warp: Option<&Array1<f64>>,
887 beta_link_dev: Option<&Array1<f64>>,
888 grid: &EmpiricalZGrid,
889 score_warp_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
890 link_dev_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
891 ) -> Result<(f64, f64, f64), EstimationError> {
892 let marginal = bernoulli_marginal_link_map(&self.base_link, marginal_eta)
893 .map_err(EstimationError::InvalidInput)?;
894 let mut f = -marginal.mu;
895 let mut f_a = 0.0;
896 let mut f_aa = 0.0;
897 for (node, weight) in grid.pairs() {
898 let obs = self.observed_denested_cell_partials_at_z(
899 node,
900 a,
901 slope,
902 beta_score_warp,
903 beta_link_dev,
904 score_warp_correction_for_row,
905 link_dev_correction_for_row,
906 )?;
907 let eta = eval_coeff4_at(&obs.coeff, node);
908 let eta_a = eval_coeff4_at(&obs.dc_da, node);
909 let eta_aa = eval_coeff4_at(&obs.dc_daa, node);
910 let pdf = normal_pdf(eta);
911 f += weight * normal_cdf(eta);
912 f_a += weight * pdf * eta_a;
913 f_aa += weight * pdf * (eta_aa - eta * eta_a * eta_a);
914 }
915 Ok((f, f_a, f_aa))
916 }
917
918 fn evaluate_prediction_calibration(
919 &self,
920 a: f64,
921 marginal_eta: f64,
922 slope: f64,
923 beta_score_warp: Option<&Array1<f64>>,
924 beta_link_dev: Option<&Array1<f64>>,
925 empirical_grid: Option<&EmpiricalZGrid>,
926 score_warp_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
927 link_dev_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
928 ) -> Result<(f64, f64, f64), EstimationError> {
929 if let Some(grid) = empirical_grid {
930 self.evaluate_empirical_denested_calibration(
931 a,
932 marginal_eta,
933 slope,
934 beta_score_warp,
935 beta_link_dev,
936 grid,
937 score_warp_correction_for_row,
938 link_dev_correction_for_row,
939 )
940 } else {
941 self.evaluate_denested_calibration(
942 a,
943 marginal_eta,
944 slope,
945 beta_score_warp,
946 beta_link_dev,
947 score_warp_correction_for_row,
948 link_dev_correction_for_row,
949 )
950 }
951 }
952
953 pub fn from_unified(
954 unified: &UnifiedFitResult,
955 z_column: String,
956 latent_z_normalization: SavedLatentZNormalization,
957 latent_measure: LatentMeasureKind,
958 baseline_marginal: f64,
959 baseline_logslope: f64,
960 base_link: InverseLink,
961 frailty: FrailtySpec,
962 score_warp_runtime: Option<SavedCompiledFlexBlock>,
963 link_deviation_runtime: Option<SavedCompiledFlexBlock>,
964 latent_z_calibration: Option<crate::bms::LatentZRankIntCalibration>,
965 latent_z_conditional_calibration: Option<crate::bms::LatentZConditionalCalibration>,
966 ) -> Result<Self, String> {
967 let gaussian_frailty_sd = match frailty {
968 FrailtySpec::None => None,
969 FrailtySpec::GaussianShift {
970 sigma_fixed: Some(sigma),
971 } => Some(sigma),
972 FrailtySpec::GaussianShift { sigma_fixed: None } => {
973 return Err(
974 "bernoulli marginal-slope predictor requires a fixed GaussianShift sigma"
975 .to_string(),
976 );
977 }
978 FrailtySpec::HazardMultiplier { .. } => {
979 return Err(
980 "bernoulli marginal-slope predictor does not support HazardMultiplier frailty"
981 .to_string(),
982 );
983 }
984 };
985 if !matches!(
986 base_link,
987 InverseLink::Standard(gam_problem::types::StandardLink::Probit)
988 ) {
989 return Err(
990 "bernoulli marginal-slope predictor requires link(type=probit); saved non-probit marginal-slope models must be refit"
991 .to_string(),
992 );
993 }
994 if let Some(runtime) = score_warp_runtime.as_ref() {
995 runtime.validate_exact_replay_contract().map_err(|e| {
996 format!("bernoulli marginal-slope score-warp runtime is invalid: {e}")
997 })?;
998 }
999 if let Some(runtime) = link_deviation_runtime.as_ref() {
1000 runtime.validate_exact_replay_contract().map_err(|e| {
1001 format!("bernoulli marginal-slope link-deviation runtime is invalid: {e}")
1002 })?;
1003 }
1004 latent_z_normalization
1008 .validate("bernoulli marginal-slope predictor")
1009 .map_err(|e| {
1010 format!("bernoulli marginal-slope predictor latent z normalization is invalid: {e}")
1011 })?;
1012 latent_measure
1013 .validate("bernoulli marginal-slope predictor latent measure")
1014 .map_err(|e| {
1015 format!("bernoulli marginal-slope predictor latent measure is invalid: {e}")
1016 })?;
1017 let blocks = &unified.blocks;
1018 let expected_blocks = 2
1019 + usize::from(score_warp_runtime.is_some())
1020 + usize::from(link_deviation_runtime.is_some());
1021 if blocks.len() != expected_blocks {
1022 return Err(format!(
1023 "bernoulli marginal-slope predictor requires exactly {expected_blocks} coefficient blocks under the current exact de-nested semantics, got {}",
1024 blocks.len()
1025 ));
1026 }
1027 let mut cursor = 2usize;
1028 let beta_score_warp = if score_warp_runtime.is_some() {
1029 let beta = blocks
1030 .get(cursor)
1031 .ok_or_else(|| "missing score-warp coefficient block".to_string())?
1032 .beta
1033 .clone();
1034 cursor += 1;
1035 Some(beta)
1036 } else {
1037 None
1038 };
1039 let beta_link_dev = if link_deviation_runtime.is_some() {
1040 Some(
1041 blocks
1042 .get(cursor)
1043 .ok_or_else(|| "missing link-deviation coefficient block".to_string())?
1044 .beta
1045 .clone(),
1046 )
1047 } else {
1048 None
1049 };
1050 Ok(Self {
1051 beta_marginal: blocks[0].beta.clone(),
1052 beta_logslope: blocks[1].beta.clone(),
1053 beta_score_warp,
1054 beta_link_dev,
1055 base_link,
1056 z_column,
1057 latent_z_normalization,
1058 latent_measure,
1059 baseline_marginal,
1060 baseline_logslope,
1061 covariance: unified.beta_covariance().cloned(),
1062 score_warp_runtime,
1063 link_deviation_runtime,
1064 gaussian_frailty_sd,
1065 latent_z_calibration,
1066 latent_z_conditional_calibration,
1067 })
1068 }
1069
1070 pub fn theta(&self) -> Array1<f64> {
1071 let total = self.beta_marginal.len()
1072 + self.beta_logslope.len()
1073 + self.beta_score_warp.as_ref().map_or(0, |b| b.len())
1074 + self.beta_link_dev.as_ref().map_or(0, |b| b.len());
1075 let mut theta = Array1::<f64>::zeros(total);
1076 let mut cursor = 0usize;
1077 theta
1078 .slice_mut(ndarray::s![cursor..cursor + self.beta_marginal.len()])
1079 .assign(&self.beta_marginal);
1080 cursor += self.beta_marginal.len();
1081 theta
1082 .slice_mut(ndarray::s![cursor..cursor + self.beta_logslope.len()])
1083 .assign(&self.beta_logslope);
1084 cursor += self.beta_logslope.len();
1085 if let Some(beta) = self.beta_score_warp.as_ref() {
1086 theta
1087 .slice_mut(ndarray::s![cursor..cursor + beta.len()])
1088 .assign(beta);
1089 cursor += beta.len();
1090 }
1091 if let Some(beta) = self.beta_link_dev.as_ref() {
1092 theta
1093 .slice_mut(ndarray::s![cursor..cursor + beta.len()])
1094 .assign(beta);
1095 }
1096 theta
1097 }
1098
1099 fn split_theta<'a>(
1100 &'a self,
1101 theta: &'a Array1<f64>,
1102 ) -> Result<
1103 (
1104 ArrayView1<'a, f64>,
1105 ArrayView1<'a, f64>,
1106 Option<ArrayView1<'a, f64>>,
1107 Option<ArrayView1<'a, f64>>,
1108 ),
1109 EstimationError,
1110 > {
1111 let expected = self.theta().len();
1112 if theta.len() != expected {
1113 return Err(EstimationError::InvalidInput(format!(
1114 "bernoulli marginal-slope theta length mismatch: expected {expected}, got {}",
1115 theta.len()
1116 )));
1117 }
1118 let mut cursor = 0usize;
1119 let marginal = theta.slice(ndarray::s![cursor..cursor + self.beta_marginal.len()]);
1120 cursor += self.beta_marginal.len();
1121 let logslope = theta.slice(ndarray::s![cursor..cursor + self.beta_logslope.len()]);
1122 cursor += self.beta_logslope.len();
1123 let score_warp = self.beta_score_warp.as_ref().map(|beta| {
1124 let view = theta.slice(ndarray::s![cursor..cursor + beta.len()]);
1125 cursor += beta.len();
1126 view
1127 });
1128 let link_dev = self
1129 .beta_link_dev
1130 .as_ref()
1131 .map(|beta| theta.slice(ndarray::s![cursor..cursor + beta.len()]));
1132 Ok((marginal, logslope, score_warp, link_dev))
1133 }
1134
1135 fn solve_intercept_scalar(
1139 &self,
1140 marginal_eta: f64,
1141 slope: f64,
1142 link_dev_beta: Option<&Array1<f64>>,
1143 score_warp_beta: Option<&Array1<f64>>,
1144 empirical_grid: Option<&EmpiricalZGrid>,
1145 warm_start_buf: &mut Array1<f64>,
1146 score_warp_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
1147 link_dev_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
1148 ) -> Result<f64, EstimationError> {
1149 let marginal = bernoulli_marginal_link_map(&self.base_link, marginal_eta)
1150 .map_err(EstimationError::InvalidInput)?;
1151 let eval = |a: f64| -> Result<(f64, f64, f64), String> {
1152 self.evaluate_prediction_calibration(
1153 a,
1154 marginal_eta,
1155 slope,
1156 score_warp_beta,
1157 link_dev_beta,
1158 empirical_grid,
1159 score_warp_correction_for_row,
1160 link_dev_correction_for_row,
1161 )
1162 .map_err(|err| err.to_string())
1163 };
1164
1165 let probit_scale = self.probit_frailty_scale();
1166 let a_rigid = self.rigid_intercept_from_marginal(marginal.q, slope);
1167 let mut intercept = a_rigid;
1168 if let (Some(_), Some(beta)) = (self.link_deviation_runtime.as_ref(), link_dev_beta) {
1169 warm_start_buf[0] = a_rigid;
1170 let one_pt = warm_start_buf.slice(ndarray::s![0..1]).to_owned();
1171 let (l_val, l_d1) =
1172 self.link_terms_value_d1(&one_pt, Some(beta), link_dev_correction_for_row)?;
1173 let ell1 = l_d1[0];
1174 if ell1 > 1e-8 {
1175 let ell0 = l_val[0] - ell1 * a_rigid;
1176 let observed_logslope = probit_scale * ell1 * slope;
1177 intercept = (marginal.q * (1.0 + observed_logslope * observed_logslope).sqrt()
1178 / probit_scale
1179 - ell0)
1180 / ell1;
1181 }
1182 }
1183
1184 let target = marginal.mu;
1187 let abs_tol = 1e-8_f64.max(1e-4 * target.abs());
1188
1189 let (root, _, f_best) = crate::monotone_root::solve_monotone_root(
1190 eval,
1191 intercept,
1192 "saved bernoulli intercept",
1193 abs_tol,
1194 64,
1195 48,
1196 )?;
1197
1198 if f_best.abs() > abs_tol {
1199 return Err(EstimationError::InvalidInput(format!(
1200 "saved bernoulli marginal-slope intercept solve failed: residual={f_best:.3e} at a={root:.6}, target mu={target:.6}"
1201 )));
1202 }
1203 Ok(root)
1204 }
1205
1206 pub fn final_eta_and_gradient_from_theta(
1207 &self,
1208 input: &PredictInput,
1209 theta: &Array1<f64>,
1210 need_gradient: bool,
1211 ) -> Result<(Array1<f64>, Option<Array2<f64>>), EstimationError> {
1212 let z_raw = input.auxiliary_scalar.as_ref().ok_or_else(|| {
1213 EstimationError::InvalidInput(format!(
1214 "bernoulli marginal-slope prediction requires auxiliary z column '{}'",
1215 self.z_column
1216 ))
1217 })?;
1218 let z_normalized = self
1219 .latent_z_normalization
1220 .apply(z_raw, "bernoulli marginal-slope prediction")
1221 .map_err(EstimationError::from)?;
1222 let z = self.apply_latent_z_calibration(&z_normalized);
1232 let z = self.apply_latent_z_conditional_calibration(&z, input)?;
1236 let design_logslope = input.design_noise.as_ref().ok_or_else(|| {
1237 EstimationError::InvalidInput(
1238 "bernoulli marginal-slope prediction requires logslope design".to_string(),
1239 )
1240 })?;
1241 let (beta_marginal, beta_logslope, beta_score_warp, beta_link_dev) =
1242 self.split_theta(theta)?;
1243 if self.score_warp_runtime.is_some() != beta_score_warp.is_some() {
1244 return Err(EstimationError::InvalidInput(
1245 "bernoulli marginal-slope saved score-warp runtime/coefficients are inconsistent"
1246 .to_string(),
1247 ));
1248 }
1249 if self.link_deviation_runtime.is_some() != beta_link_dev.is_some() {
1250 return Err(EstimationError::InvalidInput(
1251 "bernoulli marginal-slope saved link-deviation runtime/coefficients are inconsistent"
1252 .to_string(),
1253 ));
1254 }
1255 let n = z.len();
1256 if input.offset.len() != n {
1257 return Err(EstimationError::InvalidInput(format!(
1258 "bernoulli marginal-slope prediction primary offset length mismatch: rows={n}, offset={}",
1259 input.offset.len()
1260 )));
1261 }
1262 let logslope_offset = input
1263 .offset_noise
1264 .as_ref()
1265 .map_or_else(|| Array1::zeros(n), Clone::clone);
1266 if logslope_offset.len() != n {
1267 return Err(EstimationError::InvalidInput(format!(
1268 "bernoulli marginal-slope prediction logslope offset length mismatch: rows={n}, offset_noise={}",
1269 logslope_offset.len()
1270 )));
1271 }
1272 let marginal_eta = input
1273 .design
1274 .dot(&beta_marginal.to_owned())
1275 .mapv(|v| v + self.baseline_marginal)
1276 + &input.offset;
1277 let logslope_eta = design_logslope
1278 .dot(&beta_logslope.to_owned())
1279 .mapv(|v| v + self.baseline_logslope)
1280 + &logslope_offset;
1281 let flex_active =
1282 self.score_warp_runtime.is_some() || self.link_deviation_runtime.is_some();
1283 let marginal_dim = self.beta_marginal.len();
1284 let logslope_dim = self.beta_logslope.len();
1285 let score_warp_dim = self.beta_score_warp.as_ref().map_or(0, Array1::len);
1286 let link_dev_dim = self.beta_link_dev.as_ref().map_or(0, Array1::len);
1287 let logslope_offset = marginal_dim;
1288 let score_warp_offset = logslope_offset + logslope_dim;
1289 let link_dev_offset = score_warp_offset + score_warp_dim;
1290 let chunk_size = prediction_chunk_rows(theta.len(), 1, n);
1291 let num_chunks = n.div_ceil(chunk_size);
1292 let scale = self.probit_frailty_scale();
1293 let anchor_corrections =
1300 self.build_anchor_correction_matrices(input, design_logslope, &z)?;
1301 let marginal_map = marginal_eta
1302 .iter()
1303 .map(|&eta| {
1304 bernoulli_marginal_link_map(&self.base_link, eta)
1305 .map_err(EstimationError::InvalidInput)
1306 })
1307 .collect::<Result<Vec<_>, _>>()?;
1308
1309 if !flex_active {
1310 let (final_eta_internal, marginal_scales, logslope_scales) = match &self.latent_measure
1311 {
1312 LatentMeasureKind::StandardNormal => {
1313 let sb_vec = logslope_eta.mapv(|b| scale * b);
1314 let c_vec = sb_vec.mapv(|sb| (1.0 + sb * sb).sqrt());
1315 let final_eta_internal = Array1::from_iter(
1316 (0..n).map(|i| c_vec[i] * marginal_eta[i] + sb_vec[i] * z[i]),
1317 );
1318 let marginal_scales = c_vec;
1319 let logslope_scales = Array1::from_iter((0..n).map(|i| {
1320 marginal_eta[i] * (scale * scale) * logslope_eta[i] / marginal_scales[i]
1321 + scale * z[i]
1322 }));
1323 (final_eta_internal, marginal_scales, logslope_scales)
1324 }
1325 LatentMeasureKind::GlobalEmpirical { grid } => {
1326 let mut final_eta = Array1::<f64>::zeros(n);
1327 let mut marginal_scales = Array1::<f64>::zeros(n);
1328 let mut logslope_scales = Array1::<f64>::zeros(n);
1329 for i in 0..n {
1330 let (intercept, a_marginal, a_slope) = self
1331 .empirical_rigid_intercept_and_gradient(
1332 marginal_eta[i],
1333 logslope_eta[i],
1334 &grid.nodes,
1335 &grid.weights,
1336 )?;
1337 final_eta[i] = intercept + scale * logslope_eta[i] * z[i];
1338 marginal_scales[i] = a_marginal;
1339 logslope_scales[i] = a_slope + scale * z[i];
1340 }
1341 (final_eta, marginal_scales, logslope_scales)
1342 }
1343 LatentMeasureKind::LocalEmpirical { .. } => {
1344 let mut final_eta = Array1::<f64>::zeros(n);
1345 let mut marginal_scales = Array1::<f64>::zeros(n);
1346 let mut logslope_scales = Array1::<f64>::zeros(n);
1347 for i in 0..n {
1348 let grid = self
1349 .empirical_grid_for_prediction_row(input, i)?
1350 .ok_or_else(|| {
1351 EstimationError::InvalidInput(
1352 "local empirical latent prediction did not produce a row grid"
1353 .to_string(),
1354 )
1355 })?;
1356 let (intercept, a_marginal, a_slope) = self
1357 .empirical_rigid_intercept_and_gradient(
1358 marginal_eta[i],
1359 logslope_eta[i],
1360 &grid.nodes,
1361 &grid.weights,
1362 )?;
1363 final_eta[i] = intercept + scale * logslope_eta[i] * z[i];
1364 marginal_scales[i] = a_marginal;
1365 logslope_scales[i] = a_slope + scale * z[i];
1366 }
1367 (final_eta, marginal_scales, logslope_scales)
1368 }
1369 };
1370
1371 if !need_gradient {
1372 return self.transform_internal_eta_to_base_scale(final_eta_internal, None);
1373 }
1374
1375 let mut grad_internal = Array2::<f64>::zeros((n, theta.len()));
1377 let mut start = 0usize;
1378 while start < n {
1379 let end = (start + chunk_size).min(n);
1380 let mc = input
1381 .design
1382 .try_row_chunk(start..end)
1383 .map_err(|e| EstimationError::InvalidInput(e.to_string()))?;
1384 let lc = design_logslope
1385 .try_row_chunk(start..end)
1386 .map_err(|e| EstimationError::InvalidInput(e.to_string()))?;
1387
1388 for li in 0..(end - start) {
1389 let i = start + li;
1390 let c = marginal_scales[i];
1391 let g_scale = logslope_scales[i];
1392 let mut row = grad_internal.row_mut(i);
1393 for j in 0..marginal_dim {
1394 row[j] = c * mc[[li, j]];
1395 }
1396 for j in 0..logslope_dim {
1397 row[logslope_offset + j] = g_scale * lc[[li, j]];
1398 }
1399 }
1400
1401 start = end;
1402 }
1403 return self
1404 .transform_internal_eta_to_base_scale(final_eta_internal, Some(grad_internal));
1405 }
1406
1407 let score_warp_obs_design = self
1409 .score_warp_runtime
1410 .as_ref()
1411 .map(|runtime| {
1412 if runtime.anchor_correction.is_some() {
1413 let anchor_rows = anchor_corrections
1414 .score_warp_anchor_rows_view()
1415 .ok_or_else(|| {
1416 EstimationError::InvalidInput(
1417 "bernoulli marginal-slope score-warp anchor residual present but \
1418 anchor_corrections bundle is missing the parametric anchor rows"
1419 .to_string(),
1420 )
1421 })?;
1422 runtime
1423 .design_with_anchor_rows(&z, anchor_rows)
1424 .map_err(EstimationError::from)
1425 } else {
1426 runtime.design(&z).map_err(EstimationError::from)
1427 }
1428 })
1429 .transpose()?;
1430 let score_dev_obs =
1431 if let (Some(design), Some(beta)) = (score_warp_obs_design.as_ref(), beta_score_warp) {
1432 design.dot(&beta.to_owned())
1433 } else {
1434 Array1::zeros(n)
1435 };
1436
1437 let score_warp_beta_owned = beta_score_warp.as_ref().map(|v| v.to_owned());
1442 let link_dev_beta_owned = beta_link_dev.as_ref().map(|v| v.to_owned());
1443 let mut intercepts = Array1::<f64>::zeros(n);
1444 let mut a_q_vec = need_gradient.then(|| Array1::<f64>::zeros(n));
1445 let mut a_b_vec = need_gradient.then(|| Array1::<f64>::zeros(n));
1446 let mut a_h_rows = if need_gradient && score_warp_dim > 0 {
1447 Some(Array2::<f64>::zeros((n, score_warp_dim)))
1448 } else {
1449 None
1450 };
1451 let mut a_w_rows = if need_gradient && link_dev_dim > 0 {
1452 Some(Array2::<f64>::zeros((n, link_dev_dim)))
1453 } else {
1454 None
1455 };
1456 let solve_result: Result<(), EstimationError> = {
1457 use ndarray::Axis;
1458 use rayon::iter::IndexedParallelIterator;
1459 let intercepts_chunks: Vec<ndarray::ArrayViewMut1<f64>> = intercepts
1460 .axis_chunks_iter_mut(Axis(0), chunk_size)
1461 .collect();
1462 let a_q_chunks: Option<Vec<ndarray::ArrayViewMut1<f64>>> = a_q_vec
1463 .as_mut()
1464 .map(|a| a.axis_chunks_iter_mut(Axis(0), chunk_size).collect());
1465 let a_b_chunks: Option<Vec<ndarray::ArrayViewMut1<f64>>> = a_b_vec
1466 .as_mut()
1467 .map(|a| a.axis_chunks_iter_mut(Axis(0), chunk_size).collect());
1468 let a_h_chunks: Option<Vec<ndarray::ArrayViewMut2<f64>>> = a_h_rows
1469 .as_mut()
1470 .map(|a| a.axis_chunks_iter_mut(Axis(0), chunk_size).collect());
1471 let a_w_chunks: Option<Vec<ndarray::ArrayViewMut2<f64>>> = a_w_rows
1472 .as_mut()
1473 .map(|a| a.axis_chunks_iter_mut(Axis(0), chunk_size).collect());
1474
1475 struct FlexSolveSink<'a> {
1478 intercepts: ndarray::ArrayViewMut1<'a, f64>,
1479 a_q: Option<ndarray::ArrayViewMut1<'a, f64>>,
1480 a_b: Option<ndarray::ArrayViewMut1<'a, f64>>,
1481 a_h: Option<ndarray::ArrayViewMut2<'a, f64>>,
1482 a_w: Option<ndarray::ArrayViewMut2<'a, f64>>,
1483 }
1484 let mut sinks: Vec<FlexSolveSink<'_>> = Vec::with_capacity(num_chunks);
1485 let mut intercepts_iter = intercepts_chunks.into_iter();
1487 let mut a_q_iter = a_q_chunks.map(|v| v.into_iter());
1488 let mut a_b_iter = a_b_chunks.map(|v| v.into_iter());
1489 let mut a_h_iter = a_h_chunks.map(|v| v.into_iter());
1490 let mut a_w_iter = a_w_chunks.map(|v| v.into_iter());
1491 for _ in 0..num_chunks {
1492 sinks.push(FlexSolveSink {
1493 intercepts: intercepts_iter.next().expect("chunk count matches"),
1494 a_q: a_q_iter
1495 .as_mut()
1496 .map(|it| it.next().expect("chunk count matches")),
1497 a_b: a_b_iter
1498 .as_mut()
1499 .map(|it| it.next().expect("chunk count matches")),
1500 a_h: a_h_iter
1501 .as_mut()
1502 .map(|it| it.next().expect("chunk count matches")),
1503 a_w: a_w_iter
1504 .as_mut()
1505 .map(|it| it.next().expect("chunk count matches")),
1506 });
1507 }
1508
1509 let global_score_basis_table: Option<
1520 Vec<Vec<crate::cubic_cell_kernel::LocalSpanCubic>>,
1521 > = if let (LatentMeasureKind::GlobalEmpirical { grid }, Some(runtime)) =
1522 (&self.latent_measure, self.score_warp_runtime.as_ref())
1523 {
1524 let mut table = Vec::with_capacity(score_warp_dim);
1525 for j in 0..score_warp_dim {
1526 let mut row = Vec::with_capacity(grid.nodes.len());
1527 for &node in &grid.nodes {
1528 row.push(
1529 runtime
1530 .basis_cubic_at(j, node)
1531 .map_err(EstimationError::from)?,
1532 );
1533 }
1534 table.push(row);
1535 }
1536 Some(table)
1537 } else {
1538 None
1539 };
1540 let global_score_basis_table = global_score_basis_table.as_ref();
1541
1542 sinks
1543 .into_par_iter()
1544 .enumerate()
1545 .try_for_each(|(chunk_idx, mut sink)| -> Result<(), EstimationError> {
1546 let start = chunk_idx * chunk_size;
1547 let end = (start + chunk_size).min(n);
1548 let rows = end - start;
1549 let intercepts_view = &mut sink.intercepts;
1553 let mut a_q = sink.a_q.as_mut();
1554 let mut a_b = sink.a_b.as_mut();
1555 let mut a_h = sink.a_h.as_mut();
1556 let mut a_w = sink.a_w.as_mut();
1557 let mut warm_start_buf = Array1::<f64>::zeros(1);
1558 let mut f_h_row = vec![0.0; score_warp_dim];
1559 let mut f_w_row = vec![0.0; link_dev_dim];
1560
1561 for local_row in 0..rows {
1562 let i = start + local_row;
1563 let slope = logslope_eta[i];
1564 let q = marginal_eta[i];
1565 let empirical_grid = self.empirical_grid_for_prediction_row(input, i)?;
1566 let score_corr_row = anchor_corrections.score_warp_row(i);
1567 let link_corr_row = anchor_corrections.link_dev_row(i);
1568 intercepts_view[local_row] = self.solve_intercept_scalar(
1569 q,
1570 slope,
1571 link_dev_beta_owned.as_ref(),
1572 score_warp_beta_owned.as_ref(),
1573 empirical_grid.as_ref(),
1574 &mut warm_start_buf,
1575 score_corr_row,
1576 link_corr_row,
1577 )?;
1578
1579 if !need_gradient {
1580 continue;
1581 }
1582
1583 let intercept = intercepts_view[local_row];
1584 let (_, m_a_raw, _) = self.evaluate_prediction_calibration(
1585 intercept,
1586 q,
1587 slope,
1588 score_warp_beta_owned.as_ref(),
1589 link_dev_beta_owned.as_ref(),
1590 empirical_grid.as_ref(),
1591 score_corr_row,
1592 link_corr_row,
1593 )?;
1594 let m_a = m_a_raw.max(1e-12);
1595 a_q.as_mut().expect("a_q allocated when need_gradient")[local_row] =
1596 marginal_map[i].mu1 / m_a;
1597 let mut f_b = 0.0;
1598 f_h_row.fill(0.0);
1599 f_w_row.fill(0.0);
1600 if let Some(grid) = empirical_grid.as_ref() {
1601 for (node_idx, (node, weight)) in grid.pairs().enumerate() {
1602 let obs = self.observed_denested_cell_partials_at_z(
1603 node,
1604 intercept,
1605 slope,
1606 score_warp_beta_owned.as_ref(),
1607 link_dev_beta_owned.as_ref(),
1608 score_corr_row,
1609 link_corr_row,
1610 )?;
1611 let eta = eval_coeff4_at(&obs.coeff, node);
1612 let pdf = normal_pdf(eta);
1613 f_b += weight * pdf * eval_coeff4_at(&obs.dc_db, node);
1614
1615 if let Some(runtime) = self.score_warp_runtime.as_ref() {
1616 for j in 0..score_warp_dim {
1617 let mut basis_span = if let Some(table) =
1625 global_score_basis_table
1626 {
1627 table[j][node_idx]
1628 } else {
1629 runtime
1630 .basis_cubic_at(j, node)
1631 .map_err(EstimationError::from)?
1632 };
1633 if let Some(corr) = score_corr_row {
1640 basis_span.c0 -= corr[j];
1641 }
1642 let coeffs = crate::cubic_cell_kernel::score_basis_cell_coefficients(
1643 basis_span,
1644 slope,
1645 );
1646 let coeffs = scale_coeff4(coeffs, scale);
1647 f_h_row[j] += weight * pdf * eval_coeff4_at(&coeffs, node);
1648 }
1649 }
1650
1651 if let Some(runtime) = self.link_deviation_runtime.as_ref() {
1652 for j in 0..link_dev_dim {
1653 let mut basis_span = runtime
1654 .basis_cubic_at(j, intercept + slope * node)
1655 .map_err(EstimationError::from)?;
1656 if let Some(corr) = link_corr_row {
1657 basis_span.c0 -= corr[j];
1658 }
1659 let coeffs = crate::cubic_cell_kernel::link_basis_cell_coefficients(
1660 basis_span,
1661 intercept,
1662 slope,
1663 );
1664 let coeffs = scale_coeff4(coeffs, scale);
1665 f_w_row[j] += weight * pdf * eval_coeff4_at(&coeffs, node);
1666 }
1667 }
1668 }
1669 } else {
1670 let cells = self.denested_partition_cells(
1671 intercept,
1672 slope,
1673 score_warp_beta_owned.as_ref(),
1674 link_dev_beta_owned.as_ref(),
1675 score_corr_row,
1676 link_corr_row,
1677 )?;
1678 for partition_cell in cells {
1679 let cell = partition_cell.cell;
1680 let state =
1681 crate::cubic_cell_kernel::evaluate_cell_moments(
1682 cell, 9,
1683 )
1684 .map_err(EstimationError::InvalidInput)?;
1685 let (_, dc_db_raw) = crate::cubic_cell_kernel::denested_cell_coefficient_partials(
1686 partition_cell.score_span,
1687 partition_cell.link_span,
1688 intercept,
1689 slope,
1690 );
1691 let dc_db = scale_coeff4(dc_db_raw, scale);
1695 f_b += crate::cubic_cell_kernel::cell_first_derivative_from_moments(
1696 &dc_db,
1697 &state.moments,
1698 )
1699 .map_err(EstimationError::InvalidInput)?;
1700
1701 let mid = 0.5 * (cell.left + cell.right);
1702 if let Some(runtime) = self.score_warp_runtime.as_ref() {
1703 for j in 0..score_warp_dim {
1704 let mut basis_span = runtime
1705 .basis_cubic_at(j, mid)
1706 .map_err(EstimationError::from)?;
1707 if let Some(corr) = score_corr_row {
1708 basis_span.c0 -= corr[j];
1709 }
1710 let coeffs = crate::cubic_cell_kernel::score_basis_cell_coefficients(
1711 basis_span, slope,
1712 );
1713 let coeffs = scale_coeff4(coeffs, scale);
1714 f_h_row[j] += crate::cubic_cell_kernel::cell_first_derivative_from_moments(
1715 &coeffs,
1716 &state.moments,
1717 )
1718 .map_err(EstimationError::InvalidInput)?;
1719 }
1720 }
1721
1722 if let Some(runtime) = self.link_deviation_runtime.as_ref() {
1723 for j in 0..link_dev_dim {
1724 let mut basis_span = runtime
1725 .basis_cubic_at(j, intercept + slope * mid)
1726 .map_err(EstimationError::from)?;
1727 if let Some(corr) = link_corr_row {
1728 basis_span.c0 -= corr[j];
1729 }
1730 let coeffs = crate::cubic_cell_kernel::link_basis_cell_coefficients(
1731 basis_span,
1732 intercept,
1733 slope,
1734 );
1735 let coeffs = scale_coeff4(coeffs, scale);
1736 f_w_row[j] += crate::cubic_cell_kernel::cell_first_derivative_from_moments(
1737 &coeffs,
1738 &state.moments,
1739 )
1740 .map_err(EstimationError::InvalidInput)?;
1741 }
1742 }
1743 }
1744 }
1745 if let Some(a_h_view) = a_h.as_mut() {
1746 let factor = -1.0 / m_a;
1747 for j in 0..score_warp_dim {
1748 a_h_view[[local_row, j]] = factor * f_h_row[j];
1749 }
1750 }
1751 if let Some(a_w_view) = a_w.as_mut() {
1752 let factor = -1.0 / m_a;
1753 for j in 0..link_dev_dim {
1754 a_w_view[[local_row, j]] = factor * f_w_row[j];
1755 }
1756 }
1757 a_b.as_mut().expect("a_b allocated when need_gradient")[local_row] =
1758 -f_b / m_a;
1759 }
1760 Ok(())
1761 })
1762 };
1763 solve_result?;
1764
1765 let eta_base = &intercepts + &(&logslope_eta * &z);
1766
1767 let mut link_c_obs: Option<Array1<f64>> = None;
1768 let mut link_basis_obs: Option<Array2<f64>> = None;
1769 let link_dev_obs = if let (Some(runtime), Some(beta_owned)) = (
1770 self.link_deviation_runtime.as_ref(),
1771 link_dev_beta_owned.as_ref(),
1772 ) {
1773 let basis = if runtime.anchor_correction.is_some() {
1774 let anchor_rows =
1775 anchor_corrections
1776 .link_dev_anchor_rows_view()
1777 .ok_or_else(|| {
1778 EstimationError::InvalidInput(
1779 "bernoulli marginal-slope link-deviation anchor residual present but \
1780 anchor_corrections bundle is missing the parametric anchor rows"
1781 .to_string(),
1782 )
1783 })?;
1784 runtime
1785 .design_with_anchor_rows(&eta_base, anchor_rows)
1786 .map_err(EstimationError::from)?
1787 } else {
1788 runtime.design(&eta_base).map_err(EstimationError::from)?
1789 };
1790 let dev = basis.dot(beta_owned);
1791 if need_gradient {
1792 let d1 = runtime
1793 .first_derivative_design(&eta_base)
1794 .map_err(EstimationError::from)?;
1795 let mut c_obs = d1.dot(beta_owned);
1796 c_obs.mapv_inplace(|v| v + 1.0);
1797 link_c_obs = Some(c_obs);
1798 link_basis_obs = Some(basis);
1799 }
1800 dev
1801 } else {
1802 Array1::zeros(n)
1803 };
1804 let final_eta_internal =
1805 (&eta_base + &(&logslope_eta * &score_dev_obs) + &link_dev_obs).mapv(|v| scale * v);
1806
1807 if !need_gradient {
1808 return self.transform_internal_eta_to_base_scale(final_eta_internal, None);
1809 }
1810
1811 let a_q_vec = a_q_vec.unwrap();
1812 let a_b_vec = a_b_vec.unwrap();
1813
1814 let mut grad = Array2::<f64>::zeros((n, theta.len()));
1818 {
1819 use ndarray::Axis;
1820 use rayon::iter::IndexedParallelIterator;
1821 let grad_result: Result<(), String> = grad
1822 .axis_chunks_iter_mut(Axis(0), chunk_size)
1823 .into_par_iter()
1824 .enumerate()
1825 .try_for_each(|(chunk_idx, mut grad_chunk)| -> Result<(), String> {
1826 let start = chunk_idx * chunk_size;
1827 let end = (start + chunk_size).min(n);
1828 let mc = input
1829 .design
1830 .try_row_chunk(start..end)
1831 .map_err(|e| e.to_string())?;
1832 let lc = design_logslope
1833 .try_row_chunk(start..end)
1834 .map_err(|e| e.to_string())?;
1835 let rows = end - start;
1836
1837 for li in 0..rows {
1838 let i = start + li;
1839 let mut row = grad_chunk.row_mut(li);
1840
1841 let a_q = a_q_vec[i];
1842 for j in 0..marginal_dim {
1843 row[j] = a_q * mc[[li, j]];
1844 }
1845
1846 let base_multiplier = link_c_obs.as_ref().map_or(1.0, |c| c[i]);
1847 let g_scale = base_multiplier * (a_b_vec[i] + z[i]) + score_dev_obs[i];
1848 for j in 0..logslope_dim {
1849 row[logslope_offset + j] = g_scale * lc[[li, j]];
1850 }
1851
1852 if let (Some(a_h_rows), Some(obs_design)) =
1853 (a_h_rows.as_ref(), score_warp_obs_design.as_ref())
1854 {
1855 let slope = logslope_eta[i];
1856 for j in 0..score_warp_dim {
1857 row[score_warp_offset + j] =
1858 base_multiplier * a_h_rows[[i, j]] + slope * obs_design[[i, j]];
1859 }
1860 }
1861
1862 if let Some(a_w_rows) = a_w_rows.as_ref() {
1863 for j in 0..link_dev_dim {
1864 row[link_dev_offset + j] = a_w_rows[[i, j]];
1865 }
1866 }
1867
1868 if let (Some(link_c), Some(link_basis)) =
1869 (link_c_obs.as_ref(), link_basis_obs.as_ref())
1870 {
1871 let c = link_c[i];
1872 for j in 0..marginal_dim {
1873 row[j] *= c;
1874 }
1875 for j in 0..link_dev_dim {
1876 row[link_dev_offset + j] =
1877 c * row[link_dev_offset + j] + link_basis[[i, j]];
1878 }
1879 }
1880 }
1881 Ok(())
1882 });
1883 grad_result.map_err(EstimationError::InvalidInput)?;
1884 }
1885 if scale != 1.0 {
1886 grad.mapv_inplace(|v| scale * v);
1887 }
1888 self.transform_internal_eta_to_base_scale(final_eta_internal, Some(grad))
1889 }
1890
1891 pub fn final_eta_from_theta(
1901 &self,
1902 input: &PredictInput,
1903 theta: &Array1<f64>,
1904 ) -> Result<Array1<f64>, EstimationError> {
1905 let (eta, _) = self.final_eta_and_gradient_from_theta(input, theta, false)?;
1906 Ok(eta)
1907 }
1908
1909 pub fn theta_len(&self) -> usize {
1914 self.beta_marginal.len()
1915 + self.beta_logslope.len()
1916 + self.beta_score_warp.as_ref().map_or(0, Array1::len)
1917 + self.beta_link_dev.as_ref().map_or(0, Array1::len)
1918 }
1919
1920 pub fn predict_eta_and_q_chain(
1937 &self,
1938 input: &PredictInput,
1939 ) -> Result<(Array1<f64>, Array1<f64>), EstimationError> {
1940 let z_raw = input.auxiliary_scalar.as_ref().ok_or_else(|| {
1941 EstimationError::InvalidInput(format!(
1942 "bernoulli marginal-slope prediction requires auxiliary z column '{}'",
1943 self.z_column
1944 ))
1945 })?;
1946 let z_normalized = self
1947 .latent_z_normalization
1948 .apply(z_raw, "bernoulli marginal-slope prediction")
1949 .map_err(EstimationError::from)?;
1950 let z = self.apply_latent_z_calibration(&z_normalized);
1956 let z = self.apply_latent_z_conditional_calibration(&z, input)?;
1960 let design_logslope = input.design_noise.as_ref().ok_or_else(|| {
1961 EstimationError::InvalidInput(
1962 "bernoulli marginal-slope prediction requires logslope design".to_string(),
1963 )
1964 })?;
1965 let n = z.len();
1966 if input.offset.len() != n {
1967 return Err(EstimationError::InvalidInput(format!(
1968 "bernoulli marginal-slope prediction primary offset length mismatch: rows={n}, offset={}",
1969 input.offset.len()
1970 )));
1971 }
1972 let logslope_offset = input
1973 .offset_noise
1974 .as_ref()
1975 .map_or_else(|| Array1::zeros(n), Clone::clone);
1976 if logslope_offset.len() != n {
1977 return Err(EstimationError::InvalidInput(format!(
1978 "bernoulli marginal-slope prediction logslope offset length mismatch: rows={n}, offset_noise={}",
1979 logslope_offset.len()
1980 )));
1981 }
1982 let marginal_eta = input
1983 .design
1984 .dot(&self.beta_marginal)
1985 .mapv(|v| v + self.baseline_marginal)
1986 + &input.offset;
1987 let logslope_eta = design_logslope
1988 .dot(&self.beta_logslope)
1989 .mapv(|v| v + self.baseline_logslope)
1990 + &logslope_offset;
1991 let scale = self.probit_frailty_scale();
1992 let flex_active =
1993 self.score_warp_runtime.is_some() || self.link_deviation_runtime.is_some();
1994
1995 if !flex_active {
1998 match &self.latent_measure {
1999 LatentMeasureKind::StandardNormal => {
2000 let sb = logslope_eta.mapv(|x| scale * x);
2003 let deta_dq = sb.mapv(|s| (1.0 + s * s).sqrt());
2004 let eta = &deta_dq * marginal_eta + &sb * z;
2005 return Ok((eta, deta_dq));
2006 }
2007 _ => {
2008 let mut eta = Array1::<f64>::zeros(n);
2009 let mut deta_dq = Array1::<f64>::zeros(n);
2010 for i in 0..n {
2011 let grid = self
2012 .empirical_grid_for_prediction_row(input, i)?
2013 .ok_or_else(|| {
2014 EstimationError::InvalidInput(
2015 "empirical latent prediction did not produce a row grid"
2016 .to_string(),
2017 )
2018 })?;
2019 let (intercept, a_marginal, _) = self
2020 .empirical_rigid_intercept_and_gradient(
2021 marginal_eta[i],
2022 logslope_eta[i],
2023 &grid.nodes,
2024 &grid.weights,
2025 )?;
2026 eta[i] = intercept + scale * logslope_eta[i] * z[i];
2027 deta_dq[i] = a_marginal;
2028 }
2029 return Ok((eta, deta_dq));
2030 }
2031 }
2032 }
2033
2034 let marginal_map = marginal_eta
2040 .iter()
2041 .map(|&eta_marg| {
2042 bernoulli_marginal_link_map(&self.base_link, eta_marg)
2043 .map_err(EstimationError::InvalidInput)
2044 })
2045 .collect::<Result<Vec<_>, _>>()?;
2046 let anchor_corrections =
2049 self.build_anchor_correction_matrices(input, design_logslope, &z)?;
2050 use rayon::iter::{IntoParallelIterator, ParallelIterator};
2054 let pairs: Result<Vec<(f64, f64)>, EstimationError> = (0..n)
2055 .into_par_iter()
2056 .map_init(
2057 || Array1::<f64>::zeros(1),
2058 |warm_start_buf, i| {
2059 let q = marginal_eta[i];
2060 let slope = logslope_eta[i];
2061 let empirical_grid = self.empirical_grid_for_prediction_row(input, i)?;
2062 let score_corr_row = anchor_corrections.score_warp_row(i);
2063 let link_corr_row = anchor_corrections.link_dev_row(i);
2064 let intercept = self.solve_intercept_scalar(
2065 q,
2066 slope,
2067 self.beta_link_dev.as_ref(),
2068 self.beta_score_warp.as_ref(),
2069 empirical_grid.as_ref(),
2070 warm_start_buf,
2071 score_corr_row,
2072 link_corr_row,
2073 )?;
2074 let (_, m_a_raw, _) = self.evaluate_prediction_calibration(
2075 intercept,
2076 q,
2077 slope,
2078 self.beta_score_warp.as_ref(),
2079 self.beta_link_dev.as_ref(),
2080 empirical_grid.as_ref(),
2081 score_corr_row,
2082 link_corr_row,
2083 )?;
2084 let m_a = m_a_raw.max(1e-12);
2085 Ok((intercept, marginal_map[i].mu1 / m_a))
2086 },
2087 )
2088 .collect();
2089 let pairs = pairs?;
2090 let mut intercepts = Array1::<f64>::zeros(n);
2091 let mut a_q = Array1::<f64>::zeros(n);
2092 for (i, (intercept, a)) in pairs.into_iter().enumerate() {
2093 intercepts[i] = intercept;
2094 a_q[i] = a;
2095 }
2096
2097 let score_dev_obs = if let (Some(runtime), Some(beta)) = (
2098 self.score_warp_runtime.as_ref(),
2099 self.beta_score_warp.as_ref(),
2100 ) {
2101 let design = if runtime.anchor_correction.is_some() {
2102 let anchor_rows = anchor_corrections
2103 .score_warp_anchor_rows_view()
2104 .ok_or_else(|| {
2105 EstimationError::InvalidInput(
2106 "bernoulli marginal-slope score-warp anchor residual present but \
2107 anchor_corrections bundle is missing the parametric anchor rows"
2108 .to_string(),
2109 )
2110 })?;
2111 runtime
2112 .design_with_anchor_rows(&z, anchor_rows)
2113 .map_err(EstimationError::from)?
2114 } else {
2115 runtime.design(&z).map_err(EstimationError::from)?
2116 };
2117 design.dot(beta)
2118 } else {
2119 Array1::zeros(n)
2120 };
2121 let eta_base = &intercepts + &(&logslope_eta * &z);
2122 let (link_dev_obs, link_c_obs) = if let (Some(runtime), Some(beta)) = (
2123 self.link_deviation_runtime.as_ref(),
2124 self.beta_link_dev.as_ref(),
2125 ) {
2126 let basis = if runtime.anchor_correction.is_some() {
2127 let anchor_rows =
2128 anchor_corrections
2129 .link_dev_anchor_rows_view()
2130 .ok_or_else(|| {
2131 EstimationError::InvalidInput(
2132 "bernoulli marginal-slope link-deviation anchor residual present but \
2133 anchor_corrections bundle is missing the parametric anchor rows"
2134 .to_string(),
2135 )
2136 })?;
2137 runtime
2138 .design_with_anchor_rows(&eta_base, anchor_rows)
2139 .map_err(EstimationError::from)?
2140 } else {
2141 runtime.design(&eta_base).map_err(EstimationError::from)?
2142 };
2143 let dev = basis.dot(beta);
2144 let d1 = runtime
2145 .first_derivative_design(&eta_base)
2146 .map_err(EstimationError::from)?;
2147 let mut c_obs = d1.dot(beta);
2148 c_obs.mapv_inplace(|v| v + 1.0);
2149 (dev, c_obs)
2150 } else {
2151 (Array1::zeros(n), Array1::ones(n))
2152 };
2153 let final_eta_internal =
2154 (&eta_base + &(&logslope_eta * &score_dev_obs) + &link_dev_obs).mapv(|v| scale * v);
2155 let deta_dq = (&link_c_obs * &a_q).mapv(|v| scale * v);
2156 Ok((final_eta_internal, deta_dq))
2157 }
2158}