1use crate::custom_family::{
2 BatchedOuterGradientTerms, BlockEffectiveJacobian, BlockWorkingSet, BlockwiseFitOptions,
3 CustomFamily, CustomFamilyWarmStart, ExactNewtonJointGradientEvaluation,
4 ExactNewtonJointHessianWorkspace, ExactNewtonJointPsiSecondOrderTerms,
5 ExactNewtonJointPsiTerms, ExactNewtonJointPsiWorkspace, FamilyEvaluation,
6 FamilyLinearizationState, ParameterBlockSpec, ParameterBlockState, PenaltyMatrix,
7 custom_family_outer_derivatives, evaluate_custom_family_joint_hyper_efs_shared,
8 evaluate_custom_family_joint_hyper_shared, fit_custom_family,
9 joint_hyper_options_for_outer_tolerance,
10};
11use gam_solve::estimate::reml::reml_outer_engine::{DenseSpectralOperator, HessianOperator};
12use crate::cubic_cell_kernel as exact_kernel;
13use crate::marginal_slope_shared::{
14 CoeffSupport, DirectionalScaleJets, ObservedDenestedCellPartials, SparsePrimaryCoeffJetView,
15 add_optional_matrix, add_optional_vector, add_two_surface_psi_outer,
16 build_denested_partition_cells as shared_denested_partition_cells, chunked_row_reduction,
17 directional_obj_grad_hess, eval_coeff4_at, is_sigma_aux_index as shared_is_sigma_aux_index,
18 observed_denested_cell_partials as shared_observed_denested_cell_partials, outer_row_indices,
19 outer_weighted_rows, parameter_block_specs_match_rows, probit_frailty_scale,
20 probit_frailty_scale_multi_dir_jet, psi_derivative_location, scale_coeff4,
21};
22use crate::parameter_block::ParameterBlockInput;
23use crate::row_kernel::{
24 RowKernel, RowKernelHessianWorkspace, build_row_kernel_cache, row_kernel_gradient,
25 row_kernel_hessian_dense, row_kernel_log_likelihood,
26};
27use crate::spatial_psi_bridge::build_block_spatial_psi_derivatives;
28use crate::survival::lognormal_kernel::FrailtySpec;
29use crate::wiggle::initializewiggle_knots_from_seed;
30use gam_linalg::matrix::{DesignMatrix, SymmetricMatrix};
31use crate::model_types::UnifiedFitResult;
32use crate::outer_subsample::WeightedOuterRow;
33use gam_solve::pirls::LinearInequalityConstraints;
34use crate::probability::{
35 normal_cdf, normal_logcdf, normal_pdf, signed_probit_logcdf_and_mills_ratio,
36 standard_normal_quantile,
37};
38use gam_terms::smooth::{
39 SpatialLengthScaleOptimizationOptions, SpatialLogKappaCoords, TermCollectionDesign,
40 TermCollectionSpec,
41};
42use crate::fit_orchestration::drivers::{
43 ExactJointHyperSetup, apply_spatial_anisotropy_pilot_initializer,
44 build_term_collection_designs_and_freeze_joint, optimize_spatial_length_scale_exact_joint,
45 spatial_length_scale_term_indices,
46};
47use gam_problem::{InverseLink, StandardLink, WigglePenaltyConfig};
48use gam_math::jet_partitions::MultiDirJet;
49use gam_problem::HyperOperator;
50use ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayViewMut1, s};
51use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
52use serde::{Deserialize, Serialize};
53use std::cell::RefCell;
54use std::collections::HashMap;
55use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
56use std::sync::{Arc, Mutex, OnceLock};
57
// Submodules declared by this module; their contents are not visible in this file.
pub mod deviation_runtime;
pub mod gpu;
// Re-export the two deviation runtime types so they can be named from this module.
pub use deviation_runtime::DeviationRuntime;
pub use deviation_runtime::ParametricAnchorBlock;

/// Row-count threshold (50 000).
///
/// NOTE(review): by its name this presumably decides when the flexible spatial
/// outer optimisation switches to a pilot pass — the use site is not visible
/// here, confirm before relying on that reading.
pub(crate) const BMS_FLEX_SPATIAL_OUTER_PILOT_ROW_THRESHOLD: usize = 50_000;
68
/// Configuration of a deviation block.
///
/// Values are normally obtained by converting a `WigglePenaltyConfig`
/// (see the `From` impl), which copies every field across unchanged except
/// `penalty_order`.
#[derive(Clone, Debug)]
pub struct DeviationBlockConfig {
    /// NOTE(review): presumably the spline degree of the deviation basis — confirm.
    pub degree: usize,
    /// NOTE(review): presumably the number of interior knots of that basis — confirm.
    pub num_internal_knots: usize,
    /// Single summary order. The `From<WigglePenaltyConfig>` conversion sets it
    /// to the largest entry of `penalty_orders`, or `2` when that list is empty.
    pub penalty_order: usize,
    /// Full list of penalty orders, copied verbatim from the source config.
    pub penalty_orders: Vec<usize>,
    /// NOTE(review): presumably enables an additional (null-space) penalty — confirm.
    pub double_penalty: bool,
    /// NOTE(review): presumably a slack used by a monotonicity constraint — confirm.
    pub monotonicity_eps: f64,
}
78
79impl Default for DeviationBlockConfig {
80 fn default() -> Self {
81 WigglePenaltyConfig::cubic_triple_operator_default().into()
82 }
83}
84
85impl DeviationBlockConfig {
86 pub fn triple_penalty_default() -> Self {
87 Self::default()
88 }
89}
90
91impl From<WigglePenaltyConfig> for DeviationBlockConfig {
92 fn from(cfg: WigglePenaltyConfig) -> Self {
93 let penalty_order = *cfg.penalty_orders.iter().max().unwrap_or(&2);
94 Self {
95 degree: cfg.degree,
96 num_internal_knots: cfg.num_internal_knots,
97 penalty_order,
98 penalty_orders: cfg.penalty_orders,
99 double_penalty: cfg.double_penalty,
100 monotonicity_eps: cfg.monotonicity_eps,
101 }
102 }
103}
104
/// Pairs a parameter-block input with the deviation runtime that belongs to it.
///
/// `Debug` is implemented by hand for this type and prints no field contents.
#[derive(Clone)]
pub(crate) struct DeviationPrepared {
    /// Parameter-block input for the deviation block.
    pub(crate) block: ParameterBlockInput,
    /// Runtime object associated with `block`.
    pub(crate) runtime: DeviationRuntime,
}
110
111impl std::fmt::Debug for DeviationPrepared {
112 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113 f.debug_struct("DeviationPrepared").finish_non_exhaustive()
114 }
115}
116
/// Input specification for a Bernoulli marginal-slope fit.
///
/// NOTE(review): the per-field notes below are inferred from field names and
/// types only; the code that consumes this struct is not visible here.
#[derive(Clone)]
pub struct BernoulliMarginalSlopeTermSpec {
    /// NOTE(review): presumably the per-row response — confirm.
    pub y: Array1<f64>,
    /// NOTE(review): presumably per-row observation weights — confirm.
    pub weights: Array1<f64>,
    /// NOTE(review): presumably the per-row latent score `z` — confirm.
    pub z: Array1<f64>,
    /// Inverse link used as the base link.
    pub base_link: InverseLink,
    /// Term collection for the marginal predictor.
    pub marginalspec: TermCollectionSpec,
    /// Term collection for the log-slope predictor.
    pub logslopespec: TermCollectionSpec,
    /// NOTE(review): presumably a per-row offset added to the marginal predictor — confirm.
    pub marginal_offset: Array1<f64>,
    /// NOTE(review): presumably a per-row offset added to the log-slope predictor — confirm.
    pub logslope_offset: Array1<f64>,
    /// Frailty specification.
    pub frailty: FrailtySpec,
    /// Optional deviation-block configuration for the score warp (`None` = absent).
    pub score_warp: Option<DeviationBlockConfig>,
    /// Optional deviation-block configuration for the link deviation (`None` = absent).
    pub link_dev: Option<DeviationBlockConfig>,
    /// Policy controlling checks, normalization and latent measure for `z`.
    pub latent_z_policy: LatentZPolicy,
    /// NOTE(review): optional matrix; presumably a Jacobian of the score with
    /// respect to upstream quantities — shape and meaning not visible here.
    pub score_influence_jacobian: Option<Array2<f64>>,
}
152
/// Result bundle of a Bernoulli marginal-slope fit.
///
/// NOTE(review): the per-field notes below are inferred from field names and
/// types only; the code that fills this struct is not visible here.
pub struct BernoulliMarginalSlopeFitResult {
    /// The unified fit result.
    pub fit: UnifiedFitResult,
    /// Marginal term specification after resolution.
    pub marginalspec_resolved: TermCollectionSpec,
    /// Log-slope term specification after resolution.
    pub logslopespec_resolved: TermCollectionSpec,
    /// Design built for the marginal terms.
    pub marginal_design: TermCollectionDesign,
    /// Design built for the log-slope terms.
    pub logslope_design: TermCollectionDesign,
    /// NOTE(review): presumably a baseline value on the marginal predictor scale — confirm.
    pub baseline_marginal: f64,
    /// NOTE(review): presumably a baseline value on the log-slope scale — confirm.
    pub baseline_logslope: f64,
    /// Mean/sd pair recorded for standardising `z`.
    pub z_normalization: LatentZNormalization,
    /// Latent measure recorded for this fit.
    pub latent_measure: LatentMeasureKind,
    /// Runtime of the score-warp deviation block, when one is present.
    pub score_warp_runtime: Option<DeviationRuntime>,
    /// Runtime of the link-deviation block, when one is present.
    pub link_dev_runtime: Option<DeviationRuntime>,
    /// NOTE(review): presumably the Gaussian frailty standard deviation, when applicable — confirm.
    pub gaussian_frailty_sd: Option<f64>,
    /// Cross-block identifiability warnings collected for this fit.
    pub cross_block_warnings: Vec<CrossBlockIdentifiabilityWarning>,
    /// Rank-based inverse-normal calibration of `z`, when one was recorded.
    pub latent_z_rank_int_calibration: Option<LatentZRankIntCalibration>,
    /// Conditional location-scale calibration of `z`, when one was recorded.
    pub latent_z_conditional_calibration: Option<LatentZConditionalCalibration>,
}
203
/// How diagnostics on the latent score `z` are handled.
///
/// NOTE(review): variant behaviour is inferred from the names; the code that
/// matches on this enum is not visible here.
#[derive(Clone, Debug)]
pub enum LatentZCheckMode {
    /// Presumably: a failed check is an error — confirm.
    Strict,
    /// Presumably: a failed check only produces a warning — confirm.
    WarnOnly,
    /// Presumably: checks are skipped — confirm.
    Off,
}
210
/// Selects how the latent score `z` is normalised.
///
/// NOTE(review): variant behaviour is inferred from the names; the code that
/// matches on this enum is not visible here.
#[derive(Clone, Debug)]
pub enum LatentZNormalizationMode {
    /// Presumably: `z` is used as supplied — confirm.
    None,
    /// Presumably: mean and sd are estimated from the (weighted) fit data — confirm.
    FitWeighted,
    /// Caller-supplied, fixed mean and standard deviation.
    Frozen { mean: f64, sd: f64 },
}
217
/// Grid size used by `LatentMeasureSpec::auto_default()`.
pub const DEFAULT_EMPIRICAL_LATENT_GRID_SIZE: usize = 65;
// NOTE(review): the `AUTO_Z_NORMAL_*` constants below are not used in the part
// of the file visible here. By name they look like tolerances for an automatic
// "is z close enough to standard normal" decision — confirm at the use sites.
// Presumably a tolerance on absolute skewness.
pub(crate) const AUTO_Z_NORMAL_SKEW_TOL: f64 = 0.10;
// Presumably a tolerance on absolute excess kurtosis.
pub(crate) const AUTO_Z_NORMAL_KURT_TOL: f64 = 0.25;
// Presumably a tolerance on a Kolmogorov–Smirnov-type distance.
pub(crate) const AUTO_Z_NORMAL_KS_TOL: f64 = 0.025;
// Presumably a bound on the largest |z|.
pub(crate) const AUTO_Z_NORMAL_MAX_ABS: f64 = 8.0;
// Presumably the inner sigma cut-off of a tail-mass check.
pub(crate) const AUTO_Z_NORMAL_TAIL_SIGMA_INNER: f64 = 4.0;
// Presumably the outer sigma cut-off of a tail-mass check.
pub(crate) const AUTO_Z_NORMAL_TAIL_SIGMA_OUTER: f64 = 6.0;
// Presumably a multiplicative slack on the expected tail mass.
pub(crate) const AUTO_Z_NORMAL_TAIL_MASS_SLACK: f64 = 2.0;
// Presumably an absolute floor for the inner tail-mass check.
pub(crate) const AUTO_Z_NORMAL_TAIL_FLOOR_INNER: f64 = 1e-5;
// Presumably an absolute floor for the outer tail-mass check.
pub(crate) const AUTO_Z_NORMAL_TAIL_FLOOR_OUTER: f64 = 1e-8;
/// Significance level of the conditional score tests: in
/// `fit_conditional_latent_calibration_if_needed` a test "fires" when its
/// p-value is strictly below this value.
pub(crate) const AUTO_Z_CONDITIONAL_RAO_ALPHA: f64 = 1.0e-3;
/// Relative ridge factor: passed to the weighted ridge solver and used to scale
/// the penalty that is added to the normal matrix in
/// `fit_conditional_latent_calibration_if_needed`.
pub(crate) const AUTO_Z_CONDITIONAL_RIDGE_REL: f64 = 1.0e-8;
/// Fraction of the global variance used as the variance floor in
/// `fit_conditional_latent_calibration_if_needed`
/// (`var_floor = fraction * global_var`, never below `f64::MIN_POSITIVE`).
pub(crate) const AUTO_Z_CONDITIONAL_VAR_FLOOR_FRAC: f64 = 1.0e-3;
261
/// User-facing request for which latent measure to use.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LatentMeasureSpec {
    /// NOTE(review): presumably lets the fit choose between the other two
    /// options, using `grid_size` nodes if an empirical grid is built — confirm.
    Auto { grid_size: usize },
    /// Request the standard normal measure.
    StandardNormal,
    /// Request a single global empirical grid with `grid_size` nodes.
    GlobalEmpirical { grid_size: usize },
}
268
269impl LatentMeasureSpec {
270 pub fn auto_default() -> Self {
271 Self::Auto {
272 grid_size: DEFAULT_EMPIRICAL_LATENT_GRID_SIZE,
273 }
274 }
275}
276
277impl Default for LatentMeasureSpec {
278 fn default() -> Self {
279 Self::auto_default()
280 }
281}
282
/// Discrete latent measure: paired nodes and weights.
///
/// `EmpiricalZGrid::new` only succeeds when `validate_empirical_z_grid`
/// accepts the pair (equal lengths, at least two nodes, finite nodes, finite
/// strictly positive weights summing to 1 within `1e-8`). The fields are
/// public and the type is deserializable, so a value obtained another way is
/// not guaranteed to satisfy those conditions.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct EmpiricalZGrid {
    /// Node locations.
    pub nodes: Vec<f64>,
    /// Weight of the node at the same index.
    pub weights: Vec<f64>,
}
288
289impl EmpiricalZGrid {
290 pub fn new(nodes: Vec<f64>, weights: Vec<f64>, context: &str) -> Result<Self, String> {
296 validate_empirical_z_grid(&nodes, &weights, context)?;
297 Ok(Self { nodes, weights })
298 }
299
300 #[inline]
304 pub fn pairs(&self) -> impl Iterator<Item = (f64, f64)> + '_ {
305 self.nodes.iter().copied().zip(self.weights.iter().copied())
306 }
307}
308
/// Resolved latent measure.
///
/// Serialized with serde as an internally tagged enum: the tag field is named
/// `kind` and variant names are written in kebab-case.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "kebab-case")]
#[derive(Default)]
pub enum LatentMeasureKind {
    /// Standard normal measure (the `Default` variant); carries no data.
    #[default]
    StandardNormal,
    /// One empirical grid shared by all rows.
    GlobalEmpirical {
        grid: EmpiricalZGrid,
    },
    /// A set of local grids, one per center.
    LocalEmpirical {
        /// Feature columns; `validate` requires this to be non-empty and every
        /// center to have exactly this many coordinates.
        feature_cols: Vec<usize>,
        /// Optional per-feature scales; when present `validate` requires one
        /// finite, strictly positive value per feature column. Defaults to
        /// `None` when missing from serialized input.
        #[serde(default)]
        input_scales: Option<Vec<f64>>,
        /// Center coordinates, one entry per grid.
        centers: Vec<Vec<f64>>,
        /// Grid attached to the center with the same index.
        grids: Vec<EmpiricalZGrid>,
        /// `validate` requires `1 <= top_k <= centers.len()`.
        /// NOTE(review): presumably the number of nearest centers mixed per row — confirm.
        top_k: usize,
        /// `validate` requires a finite, strictly positive value.
        /// NOTE(review): presumably a kernel bandwidth for weighting centers — confirm.
        bandwidth: f64,
        /// Per-training-row mixtures as `(grid index, weight)` pairs, read by
        /// `empirical_grid_for_training_row`. Skipped by serde, so it is not
        /// written out and is filled with its `Default` on deserialization.
        #[serde(skip)]
        train_row_mixtures: Arc<Vec<Vec<(usize, f64)>>>,
    },
}
330
331impl LatentMeasureKind {
332 pub fn validate(&self, context: &str) -> Result<(), String> {
333 match self {
334 Self::StandardNormal => Ok(()),
335 Self::GlobalEmpirical { grid } => {
336 validate_empirical_z_grid(&grid.nodes, &grid.weights, context)
337 }
338 Self::LocalEmpirical {
339 feature_cols,
340 input_scales,
341 centers,
342 grids,
343 top_k,
344 bandwidth,
345 ..
346 } => {
347 if feature_cols.is_empty() {
348 return Err(format!(
349 "{context} local empirical latent measure needs feature columns"
350 ));
351 }
352 if centers.is_empty() {
353 return Err(format!(
354 "{context} local empirical latent measure needs centers"
355 ));
356 }
357 if centers.len() != grids.len() {
358 return Err(format!(
359 "{context} local empirical latent measure center/grid length mismatch: centers={}, grids={}",
360 centers.len(),
361 grids.len()
362 ));
363 }
364 if *top_k == 0 || *top_k > centers.len() {
365 return Err(format!(
366 "{context} local empirical latent measure top_k must be in 1..={}, got {top_k}",
367 centers.len()
368 ));
369 }
370 if !(*bandwidth).is_finite() || *bandwidth <= 0.0 {
371 return Err(format!(
372 "{context} local empirical latent measure bandwidth must be finite and positive, got {bandwidth}"
373 ));
374 }
375 if let Some(scales) = input_scales.as_ref() {
376 if scales.len() != feature_cols.len() {
377 return Err(format!(
378 "{context} local empirical latent measure input scale dimension mismatch: scales={}, features={}",
379 scales.len(),
380 feature_cols.len()
381 ));
382 }
383 for (scale_idx, scale) in scales.iter().enumerate() {
384 if !(scale.is_finite() && *scale > 0.0) {
385 return Err(format!(
386 "{context} local empirical latent measure input scale {scale_idx} must be finite and positive, got {scale}"
387 ));
388 }
389 }
390 }
391 for (center_idx, center) in centers.iter().enumerate() {
392 if center.len() != feature_cols.len() {
393 return Err(format!(
394 "{context} local empirical latent center {center_idx} dimension mismatch: got {}, expected {}",
395 center.len(),
396 feature_cols.len()
397 ));
398 }
399 if center.iter().any(|value| !value.is_finite()) {
400 return Err(format!(
401 "{context} local empirical latent center {center_idx} has non-finite coordinates"
402 ));
403 }
404 }
405 for (grid_idx, grid) in grids.iter().enumerate() {
406 validate_empirical_z_grid(
407 &grid.nodes,
408 &grid.weights,
409 &format!("{context} local empirical grid {grid_idx}"),
410 )?;
411 }
412 Ok(())
413 }
414 }
415 }
416
417 pub(crate) fn is_empirical(&self) -> bool {
418 matches!(
419 self,
420 Self::GlobalEmpirical { .. } | Self::LocalEmpirical { .. }
421 )
422 }
423
424 pub(crate) fn empirical_grid_for_training_row(
432 &self,
433 row: usize,
434 ) -> Result<Option<std::borrow::Cow<'_, EmpiricalZGrid>>, String> {
435 match self {
436 Self::StandardNormal => Ok(None),
437 Self::GlobalEmpirical { grid } => Ok(Some(std::borrow::Cow::Borrowed(grid))),
438 Self::LocalEmpirical {
439 grids,
440 train_row_mixtures,
441 ..
442 } => {
443 let mixture = train_row_mixtures.get(row).ok_or_else(|| {
444 format!(
445 "local empirical latent measure is missing training mixture for row {row}"
446 )
447 })?;
448 Ok(Some(std::borrow::Cow::Owned(combine_empirical_grids(
449 grids, mixture,
450 )?)))
451 }
452 }
453 }
454}
455
/// Validates a node/weight pair describing a discrete latent measure.
///
/// Checks, in order: equal lengths, at least two nodes, then entry by entry
/// that the node is finite and the weight is finite and strictly positive, and
/// finally that the weights sum to 1 within an absolute tolerance of `1e-8`.
///
/// # Errors
/// A message prefixed with `context` describing the first failed check.
pub(crate) fn validate_empirical_z_grid(
    nodes: &[f64],
    weights: &[f64],
    context: &str,
) -> Result<(), String> {
    let node_count = nodes.len();
    if node_count != weights.len() {
        return Err(format!(
            "{context} empirical latent measure node/weight length mismatch: nodes={}, weights={}",
            node_count,
            weights.len()
        ));
    }
    if node_count < 2 {
        return Err(format!(
            "{context} empirical latent measure requires at least two nodes"
        ));
    }

    // Accumulate the mass while scanning so a single pass does both jobs.
    let mut total = 0.0;
    for idx in 0..node_count {
        let node = nodes[idx];
        let weight = weights[idx];
        if !node.is_finite() {
            return Err(format!(
                "{context} empirical latent measure node {idx} is non-finite ({node})"
            ));
        }
        if !weight.is_finite() || weight <= 0.0 {
            return Err(format!(
                "{context} empirical latent measure weight {idx} must be finite and positive, got {weight}"
            ));
        }
        total += weight;
    }

    let sums_to_one = total.is_finite() && (total - 1.0).abs() <= 1e-8;
    if sums_to_one {
        Ok(())
    } else {
        Err(format!(
            "{context} empirical latent measure weights must sum to 1, got {total}"
        ))
    }
}
494
495pub(crate) fn combine_empirical_grids(
496 grids: &[EmpiricalZGrid],
497 mixture: &[(usize, f64)],
498) -> Result<EmpiricalZGrid, String> {
499 if mixture.is_empty() {
500 return Err("local empirical latent measure row mixture is empty".to_string());
501 }
502 let mut nodes = Vec::new();
503 let mut weights = Vec::new();
504 for &(grid_idx, grid_weight) in mixture {
505 if !grid_weight.is_finite() || grid_weight <= 0.0 {
506 return Err(format!(
507 "local empirical latent mixture weight must be finite and positive, got {grid_weight}"
508 ));
509 }
510 let grid = grids.get(grid_idx).ok_or_else(|| {
511 format!("local empirical latent mixture references missing grid {grid_idx}")
512 })?;
513 for (node, weight) in grid.pairs() {
514 nodes.push(node);
515 weights.push(grid_weight * weight);
516 }
517 }
518 let total = weights.iter().copied().sum::<f64>();
519 if !(total.is_finite() && total > 0.0) {
520 return Err(
521 "local empirical latent combined grid has non-positive total weight".to_string(),
522 );
523 }
524 for weight in &mut weights {
525 *weight /= total;
526 }
527 EmpiricalZGrid::new(nodes, weights, "local empirical latent combined grid")
528}
529
/// Policy bundle for handling the latent score `z`.
///
/// Two presets exist: `frozen_transformation_normal()` (also the `Default`)
/// and `exploratory_fit_weighted()`.
///
/// NOTE(review): the tolerance fields are not read in the part of the file
/// visible here; the notes on them are inferred from their names.
#[derive(Clone, Debug)]
pub struct LatentZPolicy {
    /// How diagnostic check outcomes are handled.
    pub check_mode: LatentZCheckMode,
    /// How `z` is normalised.
    pub normalization: LatentZNormalizationMode,
    /// Which latent measure is requested.
    pub latent_measure: LatentMeasureSpec,
    /// Presumably a multiplier on the tolerance of a mean check — confirm.
    pub mean_tol_multiplier: f64,
    /// Presumably a multiplier on the tolerance of a standard-deviation check — confirm.
    pub sd_tol_multiplier: f64,
    /// Presumably the largest accepted absolute skewness — confirm.
    pub max_abs_skew: f64,
    /// Presumably the largest accepted absolute excess kurtosis — confirm.
    pub max_abs_excess_kurtosis: f64,
}
540
541impl LatentZPolicy {
542 pub fn frozen_transformation_normal() -> Self {
543 Self {
556 check_mode: LatentZCheckMode::WarnOnly,
557 normalization: LatentZNormalizationMode::Frozen { mean: 0.0, sd: 1.0 },
558 latent_measure: LatentMeasureSpec::auto_default(),
559 mean_tol_multiplier: 4.0,
560 sd_tol_multiplier: 4.0,
561 max_abs_skew: 4.0,
562 max_abs_excess_kurtosis: 20.0,
563 }
564 }
565
566 pub fn exploratory_fit_weighted() -> Self {
567 Self {
568 check_mode: LatentZCheckMode::WarnOnly,
569 normalization: LatentZNormalizationMode::FitWeighted,
570 latent_measure: LatentMeasureSpec::auto_default(),
571 mean_tol_multiplier: 8.0,
572 sd_tol_multiplier: 8.0,
573 max_abs_skew: 4.0,
574 max_abs_excess_kurtosis: 20.0,
575 }
576 }
577}
578
579impl Default for LatentZPolicy {
580 fn default() -> Self {
581 Self::frozen_transformation_normal()
582 }
583}
584
/// Location/scale pair used to standardise the latent score:
/// `apply` maps each `z` to `(z - mean) / sd`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct LatentZNormalization {
    /// Value subtracted from `z`.
    pub mean: f64,
    /// Value the centred `z` is divided by; `apply` rejects it unless it is
    /// finite and greater than `BMS_VARIANCE_FLOOR`.
    pub sd: f64,
}
590
591impl LatentZNormalization {
592 pub fn apply(&self, z: &Array1<f64>, context: &str) -> Result<Array1<f64>, String> {
593 if !(self.mean.is_finite() && self.sd.is_finite() && self.sd > BMS_VARIANCE_FLOOR) {
594 return Err(format!(
595 "{context} requires finite latent z normalization with sd > {BMS_VARIANCE_FLOOR:e}; got mean={} sd={}",
596 self.mean, self.sd
597 ));
598 }
599 if z.iter().any(|value| !value.is_finite()) {
600 return Err(format!("{context} requires finite z values"));
601 }
602 Ok(z.mapv(|zi| (zi - self.mean) / self.sd))
603 }
604}
605
/// Rank-based inverse-normal calibration of the latent score.
///
/// `fit` produces one knot per distinct training `z` value; `sorted_z` and
/// `weighted_cdf` therefore have equal length and `sorted_z` is increasing.
/// The type is deserializable with public fields, so values obtained another
/// way are not guaranteed to satisfy that.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct LatentZRankIntCalibration {
    /// Distinct training scores in increasing order (the knots).
    pub sorted_z: Vec<f64>,
    /// Clamped plotting-position probability attached to the knot with the
    /// same index; for ties it reflects the cumulative weight after the last
    /// tied observation.
    pub weighted_cdf: Vec<f64>,
    /// Weighted mean of the calibrated training scores, as computed by `fit`.
    pub post_mean: f64,
    /// Weighted standard deviation of the calibrated training scores around
    /// `post_mean`, as computed by `fit`.
    pub post_sd: f64,
}
649
650impl LatentZRankIntCalibration {
651 pub fn fit(z: &Array1<f64>, weights: &Array1<f64>) -> Result<Self, String> {
664 if z.len() != weights.len() {
665 return Err(format!(
666 "rank-INT calibration: z length {} != weights length {}",
667 z.len(),
668 weights.len()
669 ));
670 }
671 if z.is_empty() {
672 return Err("rank-INT calibration requires at least one observation".to_string());
673 }
674 let w_total = weights.iter().copied().sum::<f64>();
675 if !(w_total.is_finite() && w_total > 0.0) {
676 return Err(format!(
677 "rank-INT calibration requires positive finite total weight, got {w_total}"
678 ));
679 }
680 for (idx, value) in z.iter().enumerate() {
681 if !value.is_finite() {
682 return Err(format!(
683 "rank-INT calibration: z[{idx}] = {value} not finite"
684 ));
685 }
686 }
687 for (idx, weight) in weights.iter().enumerate() {
688 if !(weight.is_finite() && *weight >= 0.0) {
689 return Err(format!(
690 "rank-INT calibration: weight[{idx}] = {weight} not finite/non-negative"
691 ));
692 }
693 }
694 let mut order: Vec<usize> = (0..z.len()).collect();
695 order.sort_by(|&a, &b| z[a].partial_cmp(&z[b]).unwrap_or(std::cmp::Ordering::Equal));
696
697 let mut sorted_z: Vec<f64> = Vec::with_capacity(z.len());
698 let mut weighted_cdf: Vec<f64> = Vec::with_capacity(z.len());
699 let denom = w_total + 0.25;
700 let eps = 0.5 / w_total.max(1.0);
701 let mut cum_w = 0.0_f64;
702 let mut last_z: Option<f64> = None;
703 for &idx in &order {
704 cum_w += weights[idx];
705 let zi = z[idx];
706 if let Some(prev) = last_z
708 && zi == prev
709 {
710 if let Some(slot) = weighted_cdf.last_mut() {
711 let p = ((cum_w - 0.375) / denom).clamp(eps, 1.0 - eps);
712 *slot = p;
713 }
714 continue;
715 }
716 let p = ((cum_w - 0.375) / denom).clamp(eps, 1.0 - eps);
717 sorted_z.push(zi);
718 weighted_cdf.push(p);
719 last_z = Some(zi);
720 }
721
722 let mut sum_wz = 0.0_f64;
725 let mut sum_w = 0.0_f64;
726 for &idx in &order {
727 let zi = z[idx];
728 let calibrated = Self::apply_with_knots(zi, &sorted_z, &weighted_cdf);
729 sum_wz += weights[idx] * calibrated;
730 sum_w += weights[idx];
731 }
732 let post_mean = if sum_w > 0.0 { sum_wz / sum_w } else { 0.0 };
733 let mut sum_w_dev = 0.0_f64;
734 for &idx in &order {
735 let zi = z[idx];
736 let calibrated = Self::apply_with_knots(zi, &sorted_z, &weighted_cdf);
737 let d = calibrated - post_mean;
738 sum_w_dev += weights[idx] * d * d;
739 }
740 let post_sd = if sum_w > 0.0 {
741 (sum_w_dev / sum_w).sqrt()
742 } else {
743 1.0
744 };
745
746 Ok(Self {
747 sorted_z,
748 weighted_cdf,
749 post_mean,
750 post_sd,
751 })
752 }
753
754 pub fn apply_to_training(&self, z: &Array1<f64>) -> Result<Array1<f64>, String> {
758 if self.sorted_z.is_empty() {
759 return Err("rank-INT calibration has no knots".to_string());
760 }
761 let mut out = Array1::<f64>::zeros(z.len());
762 for (idx, &zi) in z.iter().enumerate() {
763 if !zi.is_finite() {
764 return Err(format!(
765 "rank-INT calibration apply: z[{idx}] = {zi} not finite"
766 ));
767 }
768 out[idx] = self.apply_at_predict(zi);
769 }
770 Ok(out)
771 }
772
773 pub fn apply_at_predict(&self, z: f64) -> f64 {
781 Self::apply_with_knots(z, &self.sorted_z, &self.weighted_cdf)
782 }
783
784 pub(crate) fn apply_with_knots(z: f64, sorted_z: &[f64], weighted_cdf: &[f64]) -> f64 {
785 assert_eq!(sorted_z.len(), weighted_cdf.len());
786 assert!(!sorted_z.is_empty());
787 let n = sorted_z.len();
788 let p = if z <= sorted_z[0] {
789 weighted_cdf[0]
790 } else if z >= sorted_z[n - 1] {
791 weighted_cdf[n - 1]
792 } else {
793 let mut lo = 0usize;
795 let mut hi = n - 1;
796 while hi - lo > 1 {
797 let mid = (lo + hi) / 2;
798 if sorted_z[mid] <= z {
799 lo = mid;
800 } else {
801 hi = mid;
802 }
803 }
804 let z_lo = sorted_z[lo];
805 let z_hi = sorted_z[hi];
806 let p_lo = weighted_cdf[lo];
807 let p_hi = weighted_cdf[hi];
808 if z_hi == z_lo {
809 p_hi
810 } else {
811 let t = (z - z_lo) / (z_hi - z_lo);
812 p_lo + t * (p_hi - p_lo)
813 }
814 };
815 standard_normal_quantile(p).unwrap_or_else(|_| if p < 0.5 { -8.0 } else { 8.0 })
817 }
818}
819
/// Which calibration, if any, is attached to the latent score.
#[derive(Clone, Debug)]
pub enum LatentMeasureCalibration {
    /// No calibration.
    None,
    /// Rank-based inverse-normal calibration.
    RankInverseNormal(LatentZRankIntCalibration),
    /// Conditional location-scale calibration.
    ConditionalLocationScale(LatentZConditionalCalibration),
}
830
/// Conditional location-scale calibration of the latent score.
///
/// For a basis row `a`, the calibrated score computed by `apply` is
/// `(z - m(a)) / sqrt(v(a))`, where `m` is affine in `a` with coefficients
/// `mean_coeffs` and `v` is either affine in `a` with coefficients
/// `var_coeffs` or the constant `global_var`, in both cases floored at
/// `var_floor`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct LatentZConditionalCalibration {
    /// Intercept followed by one coefficient per basis column for the
    /// conditional mean; `apply` requires length `basis_ncols + 1`.
    pub mean_coeffs: Vec<f64>,
    /// Intercept followed by one coefficient per basis column for the
    /// conditional variance. Empty means "use `global_var`".
    pub var_coeffs: Vec<f64>,
    /// Number of basis columns expected in the `a_block` passed to `apply`.
    pub basis_ncols: usize,
    /// Lower bound applied to the conditional variance.
    pub var_floor: f64,
    /// Variance used when `var_coeffs` is empty.
    pub global_var: f64,
    /// NOTE(review): presumably the mean of the calibrated training scores;
    /// not read in the visible part of the file — confirm.
    pub post_mean: f64,
    /// NOTE(review): presumably the standard deviation of the calibrated
    /// training scores; not read in the visible part of the file — confirm.
    pub post_sd: f64,
    /// Covariance of `mean_coeffs`; placed in the leading diagonal block by
    /// `theta1_covariance`.
    pub mean_cov: Array2<f64>,
    /// Covariance of `var_coeffs`; placed in the trailing diagonal block by
    /// `theta1_covariance` when `var_coeffs` is non-empty.
    pub var_cov: Array2<f64>,
}
891
892impl LatentZConditionalCalibration {
893 #[inline]
894 pub(crate) fn affine(coeffs: &[f64], a_row: ArrayView1<'_, f64>) -> f64 {
895 let mut acc = coeffs[0];
896 for (c, &x) in coeffs[1..].iter().zip(a_row.iter()) {
897 acc += c * x;
898 }
899 acc
900 }
901
902 pub(crate) fn conditional_mean(&self, a_row: ArrayView1<'_, f64>) -> f64 {
903 Self::affine(&self.mean_coeffs, a_row)
904 }
905
906 pub(crate) fn conditional_var(&self, a_row: ArrayView1<'_, f64>) -> f64 {
907 if self.var_coeffs.is_empty() {
908 self.global_var.max(self.var_floor)
909 } else {
910 Self::affine(&self.var_coeffs, a_row).max(self.var_floor)
911 }
912 }
913
914 pub fn apply(
918 &self,
919 z: ArrayView1<'_, f64>,
920 a_block: ArrayView2<'_, f64>,
921 ) -> Result<Array1<f64>, String> {
922 if a_block.ncols() != self.basis_ncols {
923 return Err(format!(
924 "conditional latent calibration expects {} basis columns, got {}",
925 self.basis_ncols,
926 a_block.ncols()
927 ));
928 }
929 if a_block.nrows() != z.len() {
930 return Err(format!(
931 "conditional latent calibration row mismatch: z={}, basis rows={}",
932 z.len(),
933 a_block.nrows()
934 ));
935 }
936 if self.mean_coeffs.len() != self.basis_ncols + 1 {
937 return Err(format!(
938 "conditional latent calibration mean coefficient length {} != basis_ncols+1 ({})",
939 self.mean_coeffs.len(),
940 self.basis_ncols + 1
941 ));
942 }
943 let mut out = Array1::<f64>::zeros(z.len());
944 for i in 0..z.len() {
945 let a_row = a_block.row(i);
946 if !z[i].is_finite() {
947 return Err(format!(
948 "conditional latent calibration: z[{i}] = {} not finite",
949 z[i]
950 ));
951 }
952 let m = self.conditional_mean(a_row);
953 let v = self.conditional_var(a_row);
954 if !(v.is_finite() && v > 0.0) {
955 return Err(format!(
956 "conditional latent calibration produced non-positive variance {v} at row {i}"
957 ));
958 }
959 let zeta = (z[i] - m) / v.sqrt();
960 if !zeta.is_finite() {
961 return Err(format!(
962 "conditional latent calibration produced non-finite zeta at row {i}"
963 ));
964 }
965 out[i] = zeta;
966 }
967 Ok(out)
968 }
969
970 pub fn theta1_dim(&self) -> usize {
975 self.mean_coeffs.len() + self.var_coeffs.len()
976 }
977
978 pub fn zeta_theta1_jacobian_row(&self, z: f64, a_row: ArrayView1<'_, f64>) -> Vec<f64> {
991 let m = self.conditional_mean(a_row);
992 let v = self.conditional_var(a_row);
993 let inv_sqrt_v = 1.0 / v.sqrt();
994 let mut out = Vec::with_capacity(self.theta1_dim());
996 let dzeta_dm = -inv_sqrt_v;
997 out.push(dzeta_dm); for &x in a_row.iter() {
999 out.push(dzeta_dm * x);
1000 }
1001 if !self.var_coeffs.is_empty() {
1002 let raw_v = Self::affine(&self.var_coeffs, a_row);
1005 let dzeta_dv = if raw_v > self.var_floor {
1006 let zeta = (z - m) * inv_sqrt_v;
1007 -zeta / (2.0 * v)
1008 } else {
1009 0.0
1010 };
1011 out.push(dzeta_dv);
1012 for &x in a_row.iter() {
1013 out.push(dzeta_dv * x);
1014 }
1015 }
1016 out
1017 }
1018
1019 pub fn theta1_covariance(&self) -> Array2<f64> {
1027 let dm = self.mean_coeffs.len();
1028 let dv = self.var_coeffs.len();
1029 let mut v1 = Array2::<f64>::zeros((dm + dv, dm + dv));
1030 v1.slice_mut(s![..dm, ..dm]).assign(&self.mean_cov);
1031 if dv > 0 {
1032 v1.slice_mut(s![dm.., dm..]).assign(&self.var_cov);
1033 }
1034 v1
1035 }
1036
1037 pub fn generated_regressor_term(&self, hbeta_inv_g: ArrayView2<'_, f64>) -> Array2<f64> {
1051 let v1 = self.theta1_covariance();
1052 hbeta_inv_g.dot(&v1).dot(&hbeta_inv_g.t())
1053 }
1054
1055 pub fn generated_regressor_correction(
1087 &self,
1088 score_zeta_sensitivity: ArrayView2<'_, f64>,
1089 z: ArrayView1<'_, f64>,
1090 a_block: ArrayView2<'_, f64>,
1091 vb: ArrayView2<'_, f64>,
1092 ) -> Result<Array2<f64>, String> {
1093 let n = score_zeta_sensitivity.nrows();
1094 let p_beta = score_zeta_sensitivity.ncols();
1095 if z.len() != n || a_block.nrows() != n {
1096 return Err(format!(
1097 "generated_regressor_correction row mismatch: score_zeta_sensitivity rows={n}, \
1098 z={}, a_block rows={}",
1099 z.len(),
1100 a_block.nrows()
1101 ));
1102 }
1103 if a_block.ncols() != self.basis_ncols {
1104 return Err(format!(
1105 "generated_regressor_correction expects {} basis columns, got {}",
1106 self.basis_ncols,
1107 a_block.ncols()
1108 ));
1109 }
1110 if vb.nrows() != p_beta || vb.ncols() != p_beta {
1111 return Err(format!(
1112 "generated_regressor_correction: vb must be {p_beta}×{p_beta}, got {}×{}",
1113 vb.nrows(),
1114 vb.ncols()
1115 ));
1116 }
1117 let j_mat = self.build_zeta_theta1_jacobian(z, a_block);
1128 let vb_g = self.beta_theta1_sensitivity(score_zeta_sensitivity, j_mat.view(), vb)?;
1129 Ok(self.generated_regressor_term(vb_g.view()))
1130 }
1131
1132 fn build_zeta_theta1_jacobian(
1137 &self,
1138 z: ArrayView1<'_, f64>,
1139 a_block: ArrayView2<'_, f64>,
1140 ) -> Array2<f64> {
1141 let n = a_block.nrows();
1142 let dim_theta1 = self.theta1_dim();
1143 let mut j_mat = Array2::<f64>::zeros((n, dim_theta1));
1144 for i in 0..n {
1145 let j_zeta_row = self.zeta_theta1_jacobian_row(z[i], a_block.row(i));
1146 assert_eq!(
1147 j_zeta_row.len(),
1148 dim_theta1,
1149 "J_zeta row width must match the first-stage hyperparameter dimension"
1150 );
1151 let mut dst = j_mat.row_mut(i);
1152 for (slot, jz) in dst.iter_mut().zip(j_zeta_row.into_iter()) {
1153 *slot = jz;
1154 }
1155 }
1156 j_mat
1157 }
1158
1159 fn beta_theta1_sensitivity(
1174 &self,
1175 score_zeta_sensitivity: ArrayView2<'_, f64>,
1176 j_zeta: ArrayView2<'_, f64>,
1177 vb: ArrayView2<'_, f64>,
1178 ) -> Result<Array2<f64>, String> {
1179 let g = gam_linalg::faer_ndarray::fast_atb(&score_zeta_sensitivity, &j_zeta);
1181 Ok(vb.dot(&g))
1184 }
1185}
1186
1187pub(crate) fn weighted_ridge_sandwich_cov(
1206 basis: ArrayView2<'_, f64>,
1207 residuals: &[f64],
1208 weights: ArrayView1<'_, f64>,
1209 normal_matrix: &Array2<f64>,
1210) -> Result<Array2<f64>, String> {
1211 let n = basis.nrows();
1212 let p = basis.ncols();
1213 if residuals.len() != n || weights.len() != n {
1214 return Err(format!(
1215 "weighted ridge sandwich length mismatch: rows={n}, residuals={}, weights={}",
1216 residuals.len(),
1217 weights.len()
1218 ));
1219 }
1220 if normal_matrix.nrows() != p || normal_matrix.ncols() != p {
1221 return Err(format!(
1222 "weighted ridge sandwich normal-matrix shape mismatch: basis cols={p}, normal {}x{}",
1223 normal_matrix.nrows(),
1224 normal_matrix.ncols()
1225 ));
1226 }
1227 let mut b = basis.to_owned();
1234 for i in 0..n {
1235 let wi = weights[i];
1236 let ri = residuals[i];
1237 let scale = wi * ri;
1238 if scale == 0.0 {
1239 b.row_mut(i).fill(0.0);
1240 continue;
1241 }
1242 b.row_mut(i).iter_mut().for_each(|value| *value *= scale);
1243 }
1244 let meat = gam_linalg::faer_ndarray::fast_ata(&b);
1245 let mut m_sym = normal_matrix.clone();
1249 gam_linalg::matrix::symmetrize_in_place(&mut m_sym);
1250 let scale: Vec<f64> = (0..p)
1273 .map(|j| 1.0 / m_sym[[j, j]].max(f64::MIN_POSITIVE).sqrt())
1274 .collect();
1275 let mut m_scaled = m_sym;
1276 let mut meat_scaled = meat;
1277 for i in 0..p {
1278 for j in 0..p {
1279 let s = scale[i] * scale[j];
1280 m_scaled[[i, j]] *= s;
1281 meat_scaled[[i, j]] *= s;
1282 }
1283 }
1284 let (_rank, m_pinv) =
1285 gam_linalg::utils::block_penalty_rank_and_pinv(&m_scaled).map_err(|e| {
1286 format!("conditional latent calibration sandwich pseudo-inverse failed: {e}")
1287 })?;
1288 let mut cov = m_pinv.dot(&meat_scaled).dot(&m_pinv);
1289 for i in 0..p {
1291 for j in 0..p {
1292 cov[[i, j]] *= scale[i] * scale[j];
1293 }
1294 }
1295 if cov.iter().any(|v| !v.is_finite()) {
1296 return Err("conditional latent calibration sandwich covariance is non-finite".to_string());
1297 }
1298 Ok(cov)
1299}
1300
1301pub(crate) fn weighted_mean(
1303 values: &[f64],
1304 weights: ArrayView1<'_, f64>,
1305 total_weight: f64,
1306) -> f64 {
1307 values
1308 .iter()
1309 .zip(weights.iter())
1310 .map(|(&v, &w)| w * v)
1311 .sum::<f64>()
1312 / total_weight
1313}
1314
1315pub(crate) fn robust_conditional_score_pvalue(
1326 a_centered: ArrayView2<'_, f64>,
1327 u: &[f64],
1328 weights: ArrayView1<'_, f64>,
1329) -> Result<Option<f64>, String> {
1330 let n = a_centered.nrows();
1331 let r = a_centered.ncols();
1332 if r == 0 || n == 0 {
1333 return Ok(None);
1334 }
1335 if u.len() != n || weights.len() != n {
1336 return Err(format!(
1337 "conditional score test length mismatch: rows={n}, u={}, weights={}",
1338 u.len(),
1339 weights.len()
1340 ));
1341 }
1342 let mut b = a_centered.to_owned();
1353 for i in 0..n {
1354 let wi = weights[i];
1355 let scale = if wi > 0.0 { wi * u[i] } else { 0.0 };
1356 if scale == 0.0 {
1357 b.row_mut(i).fill(0.0);
1358 continue;
1359 }
1360 b.row_mut(i).iter_mut().for_each(|value| *value *= scale);
1361 }
1362 let s = b.sum_axis(ndarray::Axis(0));
1363 let omega = gam_linalg::faer_ndarray::fast_ata(&b);
1364 if !s.iter().all(|v| v.is_finite()) || !omega.iter().all(|v| v.is_finite()) {
1365 return Ok(None);
1366 }
1367 let (rank, omega_pinv) = gam_linalg::utils::block_penalty_rank_and_pinv(&omega)
1368 .map_err(|e| format!("conditional score test pseudo-inverse failed: {e}"))?;
1369 if rank == 0 {
1370 return Ok(None);
1371 }
1372 let d_stat = s.dot(&omega_pinv.dot(&s));
1373 if !(d_stat.is_finite() && d_stat >= 0.0) {
1374 return Ok(None);
1375 }
1376 let p_lower = statrs::function::gamma::gamma_lr(rank as f64 / 2.0, d_stat / 2.0);
1378 let p_value = (1.0 - p_lower).clamp(0.0, 1.0);
1379 Ok(Some(p_value))
1380}
1381
1382pub(crate) fn fit_conditional_latent_calibration_if_needed(
1389 z: &Array1<f64>,
1390 weights: &Array1<f64>,
1391 a_block: ArrayView2<'_, f64>,
1392) -> Result<Option<LatentZConditionalCalibration>, String> {
1393 let n = z.len();
1394 let p = a_block.ncols();
1395 if n != weights.len() {
1396 return Err(format!(
1397 "conditional latent gate length mismatch: z={n}, weights={}",
1398 weights.len()
1399 ));
1400 }
1401 if a_block.nrows() != n {
1402 return Err(format!(
1403 "conditional latent gate row mismatch: z={n}, basis rows={}",
1404 a_block.nrows()
1405 ));
1406 }
1407 if p == 0 {
1408 return Ok(None);
1409 }
1410 let total_weight = weights.iter().copied().sum::<f64>();
1411 if !(total_weight.is_finite() && total_weight > 0.0) {
1412 return Ok(None);
1413 }
1414 if z.iter().any(|v| !v.is_finite()) || a_block.iter().any(|v| !v.is_finite()) {
1415 return Ok(None);
1416 }
1417
1418 let z_mean = z
1419 .iter()
1420 .zip(weights.iter())
1421 .map(|(&zi, &wi)| wi * zi)
1422 .sum::<f64>()
1423 / total_weight;
1424 let global_var = z
1425 .iter()
1426 .zip(weights.iter())
1427 .map(|(&zi, &wi)| wi * (zi - z_mean) * (zi - z_mean))
1428 .sum::<f64>()
1429 / total_weight;
1430 if !(global_var.is_finite() && global_var > 0.0) {
1431 return Ok(None);
1432 }
1433
1434 let mut a_centered = a_block.to_owned();
1439 for j in 0..p {
1440 let col = a_block.column(j);
1441 let col_mean = col
1442 .iter()
1443 .zip(weights.iter())
1444 .map(|(&v, &w)| w * v)
1445 .sum::<f64>()
1446 / total_weight;
1447 a_centered.column_mut(j).mapv_inplace(|v| v - col_mean);
1448 }
1449
1450 let u_mean: Vec<f64> = z.iter().map(|&zi| zi - z_mean).collect();
1452 let p_mean = robust_conditional_score_pvalue(a_centered.view(), &u_mean, weights.view())?;
1453 let u_var: Vec<f64> = u_mean.iter().map(|&e| e * e - global_var).collect();
1455 let p_var = robust_conditional_score_pvalue(a_centered.view(), &u_var, weights.view())?;
1456
1457 let mean_fires = p_mean.is_some_and(|p| p < AUTO_Z_CONDITIONAL_RAO_ALPHA);
1458 let var_fires = p_var.is_some_and(|p| p < AUTO_Z_CONDITIONAL_RAO_ALPHA);
1459 if !mean_fires && !var_fires {
1460 return Ok(None);
1461 }
1462
1463 let basis = build_intercept_basis(a_block);
1470 let mut penalty = Array2::<f64>::zeros((basis.ncols(), basis.ncols()));
1477 for j in 0..basis.ncols() {
1478 let diag_jj = basis
1479 .column(j)
1480 .iter()
1481 .zip(weights.iter())
1482 .map(|(&x, &w)| w * x * x)
1483 .sum::<f64>()
1484 .max(f64::MIN_POSITIVE);
1485 penalty[[j, j]] = diag_jj;
1486 }
1487 let z_col = z.view().insert_axis(ndarray::Axis(1));
1488 let (mean_coeffs_mat, mean_fitted) = gam_linalg::utils::gaussian_weighted_ridge(
1489 basis.view(),
1490 z_col,
1491 penalty.view(),
1492 weights.view(),
1493 AUTO_Z_CONDITIONAL_RIDGE_REL,
1494 )?;
1495 let mean_coeffs: Vec<f64> = mean_coeffs_mat.column(0).to_vec();
1496
1497 let normal_matrix = {
1503 let mut wa = basis.to_owned();
1504 for i in 0..wa.nrows() {
1505 let wi = weights[i];
1506 wa.row_mut(i).iter_mut().for_each(|value| *value *= wi);
1507 }
1508 let mut m = basis.t().dot(&wa);
1509 m += &(penalty.to_owned() * AUTO_Z_CONDITIONAL_RIDGE_REL);
1510 m
1511 };
1512 let mean_residuals: Vec<f64> = z
1513 .iter()
1514 .zip(mean_fitted.column(0).iter())
1515 .map(|(&zi, &mi)| zi - mi)
1516 .collect();
1517 let mean_cov = weighted_ridge_sandwich_cov(
1518 basis.view(),
1519 &mean_residuals,
1520 weights.view(),
1521 &normal_matrix,
1522 )?;
1523
1524 let var_floor = (AUTO_Z_CONDITIONAL_VAR_FLOOR_FRAC * global_var).max(f64::MIN_POSITIVE);
1525 let (var_coeffs, var_cov): (Vec<f64>, Array2<f64>) = if var_fires {
1526 let resid_sq: Array1<f64> = mean_residuals.iter().map(|&e| e * e).collect();
1529 let resid_col = resid_sq.view().insert_axis(ndarray::Axis(1));
1530 let (var_coeffs_mat, var_fitted) = gam_linalg::utils::gaussian_weighted_ridge(
1531 basis.view(),
1532 resid_col,
1533 penalty.view(),
1534 weights.view(),
1535 AUTO_Z_CONDITIONAL_RIDGE_REL,
1536 )?;
1537 let var_residuals: Vec<f64> = resid_sq
1542 .iter()
1543 .zip(var_fitted.column(0).iter())
1544 .map(|(&si, &vi)| si - vi)
1545 .collect();
1546 let cov = weighted_ridge_sandwich_cov(
1547 basis.view(),
1548 &var_residuals,
1549 weights.view(),
1550 &normal_matrix,
1551 )?;
1552 (var_coeffs_mat.column(0).to_vec(), cov)
1553 } else {
1554 (Vec::new(), Array2::<f64>::zeros((0, 0)))
1555 };
1556
1557 let mut calibration = LatentZConditionalCalibration {
1558 mean_coeffs,
1559 var_coeffs,
1560 basis_ncols: p,
1561 var_floor,
1562 global_var,
1563 post_mean: 0.0,
1564 post_sd: 1.0,
1565 mean_cov,
1566 var_cov,
1567 };
1568
1569 let calibrated = calibration.apply(z.view(), a_block)?;
1571 let post_mean = weighted_mean(calibrated.as_slice().unwrap(), weights.view(), total_weight);
1572 let post_var = calibrated
1573 .iter()
1574 .zip(weights.iter())
1575 .map(|(&zi, &wi)| wi * (zi - post_mean) * (zi - post_mean))
1576 .sum::<f64>()
1577 / total_weight;
1578 calibration.post_mean = post_mean;
1579 calibration.post_sd = post_var.max(0.0).sqrt();
1580
1581 Ok(Some(calibration))
1582}
1583
1584pub(crate) fn build_intercept_basis(a_block: ArrayView2<'_, f64>) -> Array2<f64> {
1587 let n = a_block.nrows();
1588 let p = a_block.ncols();
1589 let mut basis = Array2::<f64>::ones((n, p + 1));
1590 basis.slice_mut(s![.., 1..]).assign(&a_block);
1591 basis
1592}
1593
1594pub(crate) fn build_latent_measure_with_geometry(
1595 z: &Array1<f64>,
1596 weights: &Array1<f64>,
1597 policy: &LatentZPolicy,
1598 conditioning: Option<ArrayView2<'_, f64>>,
1599) -> Result<(LatentMeasureKind, LatentMeasureCalibration), String> {
1600 match policy.latent_measure {
1601 LatentMeasureSpec::Auto { grid_size: _ } => {
1602 if let Some(a_block) = conditioning
1609 && let Some(cal) =
1610 fit_conditional_latent_calibration_if_needed(z, weights, a_block)?
1611 {
1612 log::info!(
1613 "[BMS latent-z] conditional location-scale calibrated: basis_ncols={} var_active={} post_mean={:.3e} post_sd={:.3e} (E[z|C]/Var(z|C) Rao gate fired)",
1614 cal.basis_ncols,
1615 !cal.var_coeffs.is_empty(),
1616 cal.post_mean,
1617 cal.post_sd,
1618 );
1619 return Ok((
1620 LatentMeasureKind::StandardNormal,
1621 LatentMeasureCalibration::ConditionalLocationScale(cal),
1622 ));
1623 }
1624 if latent_z_is_standard_normal_enough(z, weights, policy)? {
1625 Ok((
1626 LatentMeasureKind::StandardNormal,
1627 LatentMeasureCalibration::None,
1628 ))
1629 } else {
1630 let calibration = LatentZRankIntCalibration::fit(z, weights)?;
1639 log::info!(
1640 "[BMS latent-z] rank-INT calibrated: post_mean={:.3e} post_sd={:.3e} knots={}",
1641 calibration.post_mean,
1642 calibration.post_sd,
1643 calibration.sorted_z.len(),
1644 );
1645 Ok((
1646 LatentMeasureKind::StandardNormal,
1647 LatentMeasureCalibration::RankInverseNormal(calibration),
1648 ))
1649 }
1650 }
1651 LatentMeasureSpec::StandardNormal => Ok((
1652 LatentMeasureKind::StandardNormal,
1653 LatentMeasureCalibration::None,
1654 )),
1655 LatentMeasureSpec::GlobalEmpirical { grid_size } => {
1656 let kind = build_global_empirical_latent_measure(z, weights, grid_size)?;
1657 Ok((kind, LatentMeasureCalibration::None))
1658 }
1659 }
1660}
1661
1662pub(crate) fn latent_z_is_standard_normal_enough(
1663 z: &Array1<f64>,
1664 weights: &Array1<f64>,
1665 policy: &LatentZPolicy,
1666) -> Result<bool, String> {
1667 if z.len() != weights.len() {
1668 return Err(format!(
1669 "latent-measure auto-detection length mismatch: z={}, weights={}",
1670 z.len(),
1671 weights.len()
1672 ));
1673 }
1674 let weight_sum = weights.iter().copied().sum::<f64>();
1675 let weight_sq_sum = weights.iter().map(|&w| w * w).sum::<f64>();
1676 if !(weight_sum.is_finite()
1677 && weight_sum > 0.0
1678 && weight_sq_sum.is_finite()
1679 && weight_sq_sum > 0.0)
1680 {
1681 return Err("latent-measure auto-detection requires positive finite weights".to_string());
1682 }
1683 let effective_n = weight_sum * weight_sum / weight_sq_sum;
1684 if !(effective_n.is_finite() && effective_n > 1.0) {
1685 return Err(
1686 "latent-measure auto-detection requires at least two effective observations"
1687 .to_string(),
1688 );
1689 }
1690 let mean = z
1691 .iter()
1692 .zip(weights.iter())
1693 .map(|(&zi, &wi)| wi * zi)
1694 .sum::<f64>()
1695 / weight_sum;
1696 let var = z
1697 .iter()
1698 .zip(weights.iter())
1699 .map(|(&zi, &wi)| wi * (zi - mean) * (zi - mean))
1700 .sum::<f64>()
1701 / weight_sum;
1702 let sd = var.sqrt();
1703 if !(mean.is_finite() && sd.is_finite() && sd > 0.0) {
1704 return Ok(false);
1705 }
1706 let skew = z
1707 .iter()
1708 .zip(weights.iter())
1709 .map(|(&zi, &wi)| {
1710 let centered = (zi - mean) / sd;
1711 wi * centered.powi(3)
1712 })
1713 .sum::<f64>()
1714 / weight_sum;
1715 let excess_kurtosis = z
1716 .iter()
1717 .zip(weights.iter())
1718 .map(|(&zi, &wi)| {
1719 let centered = (zi - mean) / sd;
1720 wi * centered.powi(4)
1721 })
1722 .sum::<f64>()
1723 / weight_sum
1724 - 3.0;
1725 let mean_tol = policy.mean_tol_multiplier / effective_n.sqrt();
1726 let sd_tol = policy.sd_tol_multiplier / (2.0 * (effective_n - 1.0).max(1.0)).sqrt();
1727 let ks_to_normal = weighted_ks_to_standard_normal(z, weights, weight_sum)?;
1728 let tail_mass_4 = weighted_tail_mass(z, weights, weight_sum, AUTO_Z_NORMAL_TAIL_SIGMA_INNER);
1729 let tail_mass_6 = weighted_tail_mass(z, weights, weight_sum, AUTO_Z_NORMAL_TAIL_SIGMA_OUTER);
1730 let max_abs_z = z.iter().fold(0.0_f64, |acc, &zi| acc.max(zi.abs()));
1731 let normal_tail_4 = 2.0 * (1.0 - normal_cdf(AUTO_Z_NORMAL_TAIL_SIGMA_INNER));
1732 let normal_tail_6 = 2.0 * (1.0 - normal_cdf(AUTO_Z_NORMAL_TAIL_SIGMA_OUTER));
1733 Ok(mean.abs() <= mean_tol
1734 && (sd - 1.0).abs() <= sd_tol
1735 && skew.is_finite()
1736 && skew.abs() <= policy.max_abs_skew.min(AUTO_Z_NORMAL_SKEW_TOL)
1737 && excess_kurtosis.is_finite()
1738 && excess_kurtosis.abs() <= policy.max_abs_excess_kurtosis.min(AUTO_Z_NORMAL_KURT_TOL)
1739 && ks_to_normal.is_finite()
1740 && ks_to_normal <= AUTO_Z_NORMAL_KS_TOL
1741 && tail_mass_4
1742 <= AUTO_Z_NORMAL_TAIL_MASS_SLACK * normal_tail_4 + AUTO_Z_NORMAL_TAIL_FLOOR_INNER
1743 && tail_mass_6
1744 <= AUTO_Z_NORMAL_TAIL_MASS_SLACK * normal_tail_6 + AUTO_Z_NORMAL_TAIL_FLOOR_OUTER
1745 && max_abs_z < AUTO_Z_NORMAL_MAX_ABS)
1746}
1747
1748pub(crate) fn build_global_empirical_latent_measure(
1749 z: &Array1<f64>,
1750 weights: &Array1<f64>,
1751 grid_size: usize,
1752) -> Result<LatentMeasureKind, String> {
1753 let grid = build_empirical_z_grid(z, weights, grid_size, "empirical latent measure")?;
1754 let measure = LatentMeasureKind::GlobalEmpirical { grid };
1755 measure.validate("empirical latent measure")?;
1756 Ok(measure)
1757}
1758
1759pub(crate) fn weighted_ks_to_standard_normal(
1760 z: &Array1<f64>,
1761 weights: &Array1<f64>,
1762 total_weight: f64,
1763) -> Result<f64, String> {
1764 let mut pairs = Vec::<(f64, f64)>::with_capacity(z.len());
1765 for (&zi, &wi) in z.iter().zip(weights.iter()) {
1766 if !zi.is_finite() || !wi.is_finite() || wi < 0.0 {
1767 return Err(
1768 "latent-measure KS diagnostic requires finite z and non-negative finite weights"
1769 .to_string(),
1770 );
1771 }
1772 if wi > 0.0 {
1773 pairs.push((zi, wi));
1774 }
1775 }
1776 pairs.sort_by(|left, right| {
1777 left.0
1778 .partial_cmp(&right.0)
1779 .expect("validated latent z values are finite")
1780 });
1781 let mut prev = 0.0;
1782 let mut ks = 0.0_f64;
1783 for (zi, wi) in pairs {
1784 let cdf = normal_cdf(zi);
1785 let next = prev + wi / total_weight;
1786 ks = ks.max((cdf - prev).abs()).max((cdf - next).abs());
1787 prev = next;
1788 }
1789 Ok(ks)
1790}
1791
1792pub(crate) fn weighted_tail_mass(
1793 z: &Array1<f64>,
1794 weights: &Array1<f64>,
1795 total_weight: f64,
1796 cutoff: f64,
1797) -> f64 {
1798 z.iter()
1799 .zip(weights.iter())
1800 .filter(|&(&zi, _)| zi.abs() > cutoff)
1801 .map(|(_, &wi)| wi)
1802 .sum::<f64>()
1803 / total_weight
1804}
1805
1806pub(crate) fn build_empirical_z_grid(
1807 z: &Array1<f64>,
1808 weights: &Array1<f64>,
1809 grid_size: usize,
1810 context: &str,
1811) -> Result<EmpiricalZGrid, String> {
1812 if grid_size < 3 {
1813 return Err(format!(
1814 "empirical latent measure grid_size must be at least 3, got {grid_size}"
1815 ));
1816 }
1817 if z.len() != weights.len() {
1818 return Err(format!(
1819 "{context} length mismatch: z={}, weights={}",
1820 z.len(),
1821 weights.len()
1822 ));
1823 }
1824 let mut pairs = Vec::<(f64, f64)>::with_capacity(z.len());
1825 for (idx, (&zi, &wi)) in z.iter().zip(weights.iter()).enumerate() {
1826 if !zi.is_finite() {
1827 return Err(format!(
1828 "{context} z value at row {idx} is non-finite ({zi})"
1829 ));
1830 }
1831 if !wi.is_finite() || wi < 0.0 {
1832 return Err(format!(
1833 "{context} weight at row {idx} must be finite and non-negative, got {wi}"
1834 ));
1835 }
1836 if wi > 0.0 {
1837 pairs.push((zi, wi));
1838 }
1839 }
1840 if pairs.len() < 2 {
1841 return Err(format!(
1842 "{context} requires at least two positive-weight rows"
1843 ));
1844 }
1845 pairs.sort_by(|left, right| {
1846 left.0
1847 .partial_cmp(&right.0)
1848 .expect("validated empirical latent z values are finite")
1849 });
1850 let total_weight = pairs.iter().map(|(_, weight)| *weight).sum::<f64>();
1851 if !(total_weight.is_finite() && total_weight > 0.0) {
1852 return Err(format!("{context} requires positive finite total weight"));
1853 }
1854
1855 let m = grid_size.min(pairs.len());
1856 let mut nodes = Vec::with_capacity(m);
1857 let mut out_weights = Vec::with_capacity(m);
1858 let bin_weight_target = total_weight / (m as f64);
1859 let mut cursor = 0usize;
1860 let mut remaining = pairs[0].1;
1861 for _ in 0..m {
1862 let mut need = bin_weight_target;
1863 let mut bin_weight = 0.0;
1864 let mut bin_sum = 0.0;
1865 while need > EMPIRICAL_GRID_WEIGHT_EXHAUSTED_REL_TOL * bin_weight_target
1866 && cursor < pairs.len()
1867 {
1868 let take = remaining.min(need);
1869 bin_sum += take * pairs[cursor].0;
1870 bin_weight += take;
1871 need -= take;
1872 remaining -= take;
1873 if remaining <= EMPIRICAL_GRID_WEIGHT_EXHAUSTED_REL_TOL * pairs[cursor].1 {
1874 cursor += 1;
1875 if cursor < pairs.len() {
1876 remaining = pairs[cursor].1;
1877 }
1878 }
1879 }
1880 if bin_weight > 0.0 {
1881 nodes.push(bin_sum / bin_weight);
1882 out_weights.push(bin_weight / total_weight);
1883 }
1884 }
1885 if nodes.len() < 2 {
1886 return Err(format!(
1887 "{context} compression produced fewer than two nodes"
1888 ));
1889 }
1890 recenter_rescale_empirical_grid(&mut nodes, &out_weights);
1891 let total = out_weights.iter().sum::<f64>();
1892 if total.is_finite() && total > 0.0 {
1893 for weight in &mut out_weights {
1894 *weight /= total;
1895 }
1896 }
1897 validate_empirical_z_grid(&nodes, &out_weights, context)?;
1898 Ok(EmpiricalZGrid {
1899 nodes,
1900 weights: out_weights,
1901 })
1902}
1903
1904pub(crate) fn recenter_rescale_empirical_grid(nodes: &mut [f64], weights: &[f64]) {
1905 let total = weights.iter().sum::<f64>();
1906 if !(total.is_finite() && total > 0.0) {
1907 return;
1908 }
1909 let mean = nodes
1910 .iter()
1911 .zip(weights.iter())
1912 .map(|(&node, &weight)| weight * node)
1913 .sum::<f64>()
1914 / total;
1915 let var = nodes
1916 .iter()
1917 .zip(weights.iter())
1918 .map(|(&node, &weight)| weight * (node - mean).powi(2))
1919 .sum::<f64>()
1920 / total;
1921 let sd = var.sqrt();
1922 if sd.is_finite() && sd > BMS_VARIANCE_FLOOR {
1923 for node in nodes {
1924 *node = (*node - mean) / sd;
1925 }
1926 }
1927}
1928
/// NOTE(review): not referenced in the visible part of this file; the name
/// suggests a budget for the first phase of automatic subsampling — confirm
/// the unit at the use site.
1929pub(super) const BMS_AUTO_SUBSAMPLE_PHASE1_BUDGET: usize = 12;
/// NOTE(review): not referenced in the visible part of this file; presumably a
/// margin that keeps Bernoulli link probabilities away from 0 and 1 — confirm
/// at the use site.
1934pub(super) const BERNOULLI_LINK_PROBABILITY_EPS: f64 = 1e-12;
/// Threshold that the weighted standard deviation of the grid nodes must
/// exceed before `recenter_rescale_empirical_grid` standardizes them.
1935pub(super) const BMS_VARIANCE_FLOOR: f64 = 1e-12;
/// NOTE(review): not referenced in the visible part of this file; presumably a
/// tolerance used in derivative computations or checks — confirm at the use
/// site.
1936pub(super) const BMS_DERIV_TOL: f64 = 1e-8;
/// Relative tolerance used by `build_empirical_z_grid`: a bin stops filling
/// once its outstanding weight is at most this fraction of the bin target, and
/// a row counts as consumed once its unassigned weight is at most this
/// fraction of the row's own weight.
1937pub(super) const EMPIRICAL_GRID_WEIGHT_EXHAUSTED_REL_TOL: f64 = 1e-14;
/// Upper bound on the chunk size returned by `bms_row_chunk_size`, and the
/// value it returns when called with `n == 0`.
1943pub(super) const ROW_CHUNK_SIZE: usize = 1024;
/// Lower bound on the chunk size returned by `bms_row_chunk_size`.
1952pub(super) const ROW_CHUNK_MIN: usize = 64;
/// Multiplier applied to the rayon thread count in `bms_row_chunk_size` to get
/// the number of chunks it aims for.
1956pub(super) const ROW_CHUNKS_PER_WORKER: usize = 4;
1964
1965#[inline]
1990pub(super) fn bms_row_chunk_size(n: usize) -> usize {
1991 if n == 0 {
1992 return ROW_CHUNK_SIZE;
1993 }
1994 let workers = rayon::current_num_threads().max(1);
1995 let target_chunks = workers.saturating_mul(ROW_CHUNKS_PER_WORKER).max(1);
1996 n.div_ceil(target_chunks)
1999 .clamp(ROW_CHUNK_MIN, ROW_CHUNK_SIZE)
2000}
// NOTE(review): none of the constants below is referenced in the visible part
// of this file. The remarks are read off the names and values only — confirm
// each against its use site.

/// Presumably the minimum row count at which exact-work logging is emitted —
/// TODO confirm.
2001pub(super) const EXACT_WORK_LOG_MIN_ROWS: usize = 50_000;
/// Presumably the number of reuse passes assumed when deciding whether to keep
/// row-primary Hessian data — TODO confirm.
2002pub(super) const BMS_ROW_PRIMARY_HESSIAN_EXPECTED_REUSE_PASSES: usize = 3;
/// Presumably the smallest number of reuse passes for which that data is kept
/// — TODO confirm.
2003pub(super) const BMS_ROW_PRIMARY_HESSIAN_MIN_REUSE_PASSES: usize = 2;
/// Presumably the number of rows per tile in the row-primary Hessian path —
/// TODO confirm.
2004pub(super) const BMS_ROW_PRIMARY_HESSIAN_TILE_ROWS: usize = 8192;
/// Numerator of a 1/4 ratio (paired with `..._SINGLE_FRACTION_DEN`); what the
/// ratio is taken of is not visible here — TODO confirm.
2005pub(super) const BMS_ROW_PRIMARY_HESSIAN_SINGLE_FRACTION_NUM: u64 = 1;
/// Denominator of the 1/4 ratio above.
2006pub(super) const BMS_ROW_PRIMARY_HESSIAN_SINGLE_FRACTION_DEN: u64 = 4;
/// Numerator of a 1/2 ratio (paired with `..._GLOBAL_FRACTION_DEN`); what the
/// ratio is taken of is not visible here — TODO confirm.
2007pub(super) const BMS_ROW_PRIMARY_HESSIAN_GLOBAL_FRACTION_NUM: u64 = 1;
/// Denominator of the 1/2 ratio above.
2008pub(super) const BMS_ROW_PRIMARY_HESSIAN_GLOBAL_FRACTION_DEN: u64 = 2;
/// Presumably the number of rows per chunk between early-exit checks in the
/// Bernoulli marginal-slope line search — TODO confirm.
2009pub(super) const BERNOULLI_MARGSLOPE_LINE_SEARCH_EARLY_EXIT_CHUNK_ROWS: usize = 10_000;
2010
2011pub(crate) mod block_specs;
2015pub(crate) mod exact_eval_cache;
2016pub(crate) mod family;
2017pub(crate) mod gradient_paths;
2018pub(crate) mod hessian_paths;
2019pub(crate) mod install_flex;
2020pub(crate) mod row_kernel;
// Test-only module. Its body is not written here: two external test source
// files are spliced in textually with `include!`, so their items are compiled
// as members of this module.
2021#[cfg(test)]
2022mod tests {
2023 include!("../../../../tests/src_modules/misc/families_bms_identifiability_rigid_tests.rs");
2024 include!(
2025 "../../../../tests/src_modules/optimization/families_bms_joint_hessian_hvp_correction_tests.rs"
2026 );
2027}
2028pub(crate) mod axis_direction_search;
2029pub(crate) mod cell_moment_assembly;
2030pub(crate) mod custom_family_impl;
2031pub(crate) mod row_primary_hessian;
2032
2033pub use block_specs::fit_bernoulli_marginal_slope_terms;
2034pub use gradient_paths::{
2035 MarginalSlopeCovariance, MarginalSlopeCovarianceShape, marginal_slope_covariance_from_scores,
2036 marginal_slope_preserving_scale, marginal_slope_probit_eta, padded_deviation_seed,
2037};
2038pub use install_flex::CrossBlockIdentifiabilityWarning;
2039pub(crate) use install_flex::FlexCompileOutcome;
2040
2041pub(crate) use block_specs::push_deviation_aux_blockspecs;
2043pub use block_specs::{BmsLogslopeJacobian, BmsMarginalJacobian};
2044pub(crate) use family::{
2045 BernoulliMarginalLinkMap, bernoulli_marginal_link_map,
2046 build_link_deviation_block_from_knots_design_seed_and_weights,
2047 build_score_warp_deviation_block_from_seed,
2048};
2049pub(crate) use gradient_paths::standardize_latent_z_with_policy;
2050pub(crate) use gradient_paths::{
2051 empirical_intercept_from_marginal, signed_probit_neglog_derivatives_up_to_fourth,
2052 unary_derivatives_log, unary_derivatives_log_normal_pdf, unary_derivatives_neglog_phi,
2053 unary_derivatives_sqrt,
2054};
2055pub(crate) use install_flex::{
2056 install_compiled_flex_block_into_runtime, project_monotone_feasible_beta,
2057 validate_monotone_structural_feasible,
2058};