use crate::custom_family::{
BatchedOuterGradientTerms, BlockEffectiveJacobian, BlockWorkingSet, BlockwiseFitOptions,
CustomFamily, CustomFamilyWarmStart, ExactNewtonJointGradientEvaluation,
ExactNewtonJointHessianWorkspace, ExactNewtonJointPsiSecondOrderTerms,
ExactNewtonJointPsiTerms, ExactNewtonJointPsiWorkspace, FamilyEvaluation,
FamilyLinearizationState, ParameterBlockSpec, ParameterBlockState,
build_block_spatial_psi_derivatives, custom_family_outer_derivatives,
evaluate_custom_family_joint_hyper_efs_shared, evaluate_custom_family_joint_hyper_shared,
fit_custom_family, joint_hyper_options_for_outer_tolerance,
};
use crate::estimate::UnifiedFitResult;
use crate::estimate::reml::unified::{DenseSpectralOperator, HessianOperator, HyperOperator};
use crate::families::gamlss::{ParameterBlockInput, initialize_monotone_wiggle_knots_from_seed};
use crate::families::jet_partitions::MultiDirJet;
use crate::families::lognormal_kernel::FrailtySpec;
use crate::families::marginal_slope_shared::{
CoeffSupport, ObservedDenestedCellPartials, SparsePrimaryCoeffJetView, WeightedOuterRow,
add_optional_matrix, add_optional_vector, add_two_surface_psi_outer,
build_denested_partition_cells as shared_denested_partition_cells, chunked_row_reduction,
eval_coeff4_at, is_sigma_aux_index as shared_is_sigma_aux_index,
observed_denested_cell_partials as shared_observed_denested_cell_partials, outer_row_indices,
outer_weighted_rows, parameter_block_specs_match_rows, probit_frailty_scale,
probit_frailty_scale_multi_dir_jet, psi_derivative_location, scale_coeff4,
};
use crate::families::row_kernel::{
RowKernel, RowKernelHessianWorkspace, build_row_kernel_cache, row_kernel_gradient,
row_kernel_hessian_dense, row_kernel_log_likelihood,
};
use crate::matrix::{DesignMatrix, SymmetricMatrix};
use crate::pirls::LinearInequalityConstraints;
use crate::probability::{
normal_cdf, normal_logcdf, normal_pdf, signed_probit_logcdf_and_mills_ratio,
standard_normal_quantile,
};
use crate::smooth::{
ExactJointHyperSetup, SpatialLengthScaleOptimizationOptions, SpatialLogKappaCoords,
TermCollectionDesign, TermCollectionSpec, apply_spatial_anisotropy_pilot_initializer,
build_term_collection_designs_and_freeze_joint, optimize_spatial_length_scale_exact_joint,
spatial_length_scale_term_indices,
};
use crate::types::{InverseLink, StandardLink, WigglePenaltyConfig};
use ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayViewMut1, s};
use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
use serde::{Deserialize, Serialize};
use std::cell::RefCell;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, OnceLock};
pub mod deviation_runtime;
pub(crate) mod exact_kernel;
pub use deviation_runtime::DeviationRuntime;
pub use deviation_runtime::ParametricAnchorBlock;
/// Row-count threshold associated with the flexible-spatial outer pilot.
///
/// NOTE(review): not referenced in the visible part of this module; presumably
/// fits with at least this many rows switch to a pilot strategy for the spatial
/// outer optimization — confirm against its users in the submodules.
const BMS_FLEX_SPATIAL_OUTER_PILOT_ROW_THRESHOLD: usize = 50_000;
/// Settings for a deviation block: the type used for the `score_warp` and
/// `link_dev` options of `BernoulliMarginalSlopeTermSpec`.
///
/// Normally built from a `WigglePenaltyConfig` through the `From` impl;
/// `Default` converts `WigglePenaltyConfig::cubic_triple_operator_default()`.
#[derive(Clone, Debug)]
pub struct DeviationBlockConfig {
/// Basis degree, copied from `WigglePenaltyConfig::degree`.
pub degree: usize,
/// Number of internal knots, copied from the source config.
pub num_internal_knots: usize,
/// Largest entry of `penalty_orders`; the `From` impl falls back to 2 when
/// the list is empty.
pub penalty_order: usize,
/// Penalty orders, copied from the source config.
pub penalty_orders: Vec<usize>,
/// Double-penalty flag, copied from the source config — TODO confirm exact meaning.
pub double_penalty: bool,
/// Monotonicity tolerance, copied from the source config.
pub monotonicity_eps: f64,
}
impl Default for DeviationBlockConfig {
    /// Converts the cubic triple-operator wiggle penalty configuration.
    fn default() -> Self {
        let base = WigglePenaltyConfig::cubic_triple_operator_default();
        Self::from(base)
    }
}
impl DeviationBlockConfig {
    /// Named alias for `Default::default`, for call sites that want to spell out
    /// that they rely on the triple-penalty configuration.
    pub fn triple_penalty_default() -> Self {
        <Self as Default>::default()
    }
}
impl From<WigglePenaltyConfig> for DeviationBlockConfig {
fn from(cfg: WigglePenaltyConfig) -> Self {
let penalty_order = *cfg.penalty_orders.iter().max().unwrap_or(&2);
Self {
degree: cfg.degree,
num_internal_knots: cfg.num_internal_knots,
penalty_order,
penalty_orders: cfg.penalty_orders,
double_penalty: cfg.double_penalty,
monotonicity_eps: cfg.monotonicity_eps,
}
}
}
/// A deviation parameter block bundled with its runtime.
///
/// NOTE(review): constructed outside this view; `block` and `runtime` are
/// presumed to describe the same deviation basis — confirm at the builders.
#[derive(Clone)]
pub(crate) struct DeviationPrepared {
/// Parameter-block input handed to the custom-family fitter.
pub(crate) block: ParameterBlockInput,
/// Runtime state for the same deviation block.
pub(crate) runtime: DeviationRuntime,
}
impl std::fmt::Debug for DeviationPrepared {
    /// Prints only the type name (`DeviationPrepared { .. }`); the block and
    /// runtime payloads are deliberately elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("DeviationPrepared");
        builder.finish_non_exhaustive()
    }
}
/// Inputs for fitting a Bernoulli marginal-slope model from term specifications.
#[derive(Clone)]
pub struct BernoulliMarginalSlopeTermSpec {
/// Per-row response — presumably binary 0/1 outcomes; confirm at the fitter.
pub y: Array1<f64>,
/// Per-row weights.
pub weights: Array1<f64>,
/// Per-row latent score; its checks/normalization are governed by
/// `latent_z_policy` (applied outside this view).
pub z: Array1<f64>,
/// Base inverse link of the model.
pub base_link: InverseLink,
/// Term collection for the marginal predictor.
pub marginalspec: TermCollectionSpec,
/// Term collection for the log-slope predictor.
pub logslopespec: TermCollectionSpec,
/// Per-row offset for the marginal predictor.
pub marginal_offset: Array1<f64>,
/// Per-row offset for the log-slope predictor.
pub logslope_offset: Array1<f64>,
/// Frailty specification.
pub frailty: FrailtySpec,
/// Optional score-warp deviation block; presumably `None` disables it.
pub score_warp: Option<DeviationBlockConfig>,
/// Optional link-deviation block; presumably `None` disables it.
pub link_dev: Option<DeviationBlockConfig>,
/// Policy for checking, normalizing, and choosing the latent measure of `z`.
pub latent_z_policy: LatentZPolicy,
}
/// Output of a Bernoulli marginal-slope term fit.
pub struct BernoulliMarginalSlopeFitResult {
/// The underlying unified fit result.
pub fit: UnifiedFitResult,
/// Marginal term specification after resolution — TODO confirm what "resolved" fixes (e.g. knots).
pub marginalspec_resolved: TermCollectionSpec,
/// Log-slope term specification after resolution.
pub logslopespec_resolved: TermCollectionSpec,
/// Design built for the marginal predictor.
pub marginal_design: TermCollectionDesign,
/// Design built for the log-slope predictor.
pub logslope_design: TermCollectionDesign,
/// Baseline value of the marginal predictor — presumably an intercept; confirm.
pub baseline_marginal: f64,
/// Baseline value of the log-slope predictor — presumably an intercept; confirm.
pub baseline_logslope: f64,
/// Normalization applied to the latent score `z`.
pub z_normalization: LatentZNormalization,
/// Latent measure the fit was built against.
pub latent_measure: LatentMeasureKind,
/// Score-warp runtime, when that block was used.
pub score_warp_runtime: Option<DeviationRuntime>,
/// Link-deviation runtime, when that block was used.
pub link_dev_runtime: Option<DeviationRuntime>,
/// Gaussian frailty standard deviation, when applicable.
pub gaussian_frailty_sd: Option<f64>,
/// Identifiability warnings collected across blocks.
pub cross_block_warnings: Vec<CrossBlockIdentifiabilityWarning>,
/// Rank-INT calibration of `z`, presumably present only when one was fitted.
pub latent_z_rank_int_calibration: Option<LatentZRankIntCalibration>,
}
/// How violations of the latent-z diagnostics are handled.
///
/// NOTE(review): consumed outside this view; the variant meanings below are
/// inferred from their names — confirm.
#[derive(Clone, Debug)]
pub enum LatentZCheckMode {
/// Presumably: fail when the diagnostics are violated.
Strict,
/// Presumably: log a warning and continue.
WarnOnly,
/// Presumably: skip the diagnostics.
Off,
}
/// How the latent score `z` is normalized before fitting.
///
/// NOTE(review): applied outside this view. `LatentZNormalization::apply` maps
/// `z -> (z - mean) / sd`; the variant meanings below are inferred — confirm.
#[derive(Clone, Debug)]
pub enum LatentZNormalizationMode {
/// Presumably: use `z` as given.
None,
/// Presumably: standardize with the weighted mean/sd of the fitting data.
FitWeighted,
/// Presumably: standardize with the supplied `mean` and `sd`.
Frozen { mean: f64, sd: f64 },
}
/// Default number of nodes for a compressed empirical latent grid.
pub const DEFAULT_EMPIRICAL_LATENT_GRID_SIZE: usize = 65;
/// Cap on |weighted skewness| for the `Auto` standard-normal test (combined via
/// `min` with `LatentZPolicy::max_abs_skew`).
const AUTO_Z_NORMAL_SKEW_TOL: f64 = 0.10;
/// Cap on |weighted excess kurtosis| for the `Auto` standard-normal test
/// (combined via `min` with `LatentZPolicy::max_abs_excess_kurtosis`).
const AUTO_Z_NORMAL_KURT_TOL: f64 = 0.25;
/// Cap on the weighted Kolmogorov–Smirnov distance to N(0, 1) for `Auto`.
const AUTO_Z_NORMAL_KS_TOL: f64 = 0.025;
/// `Auto` rejects the standard-normal measure once any |z| reaches this value.
const AUTO_Z_NORMAL_MAX_ABS: f64 = 8.0;

/// Requested latent measure for `z`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LatentMeasureSpec {
    /// N(0, 1) when `z` passes the normality diagnostics, otherwise N(0, 1) with a
    /// rank-INT calibration of `z`. `grid_size` is ignored by that path.
    Auto { grid_size: usize },
    /// Always N(0, 1), without calibration.
    StandardNormal,
    /// A single empirical grid of (at most) `grid_size` nodes built from `z`.
    GlobalEmpirical { grid_size: usize },
}

impl LatentMeasureSpec {
    /// `Auto` with the default empirical grid size.
    pub fn auto_default() -> Self {
        let grid_size = DEFAULT_EMPIRICAL_LATENT_GRID_SIZE;
        LatentMeasureSpec::Auto { grid_size }
    }
}

impl Default for LatentMeasureSpec {
    /// Same as [`LatentMeasureSpec::auto_default`].
    fn default() -> Self {
        LatentMeasureSpec::auto_default()
    }
}
/// A discrete latent measure: node `nodes[i]` carries probability `weights[i]`.
///
/// `EmpiricalZGrid::new` checks (via `validate_empirical_z_grid`) equal lengths,
/// at least two nodes, finite nodes, finite positive weights, and a weight sum
/// within 1e-8 of one. Deserialization constructs the struct directly and does
/// not run those checks.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct EmpiricalZGrid {
/// Support points of the measure.
pub nodes: Vec<f64>,
/// Probability mass at each node.
pub weights: Vec<f64>,
}
impl EmpiricalZGrid {
pub fn new(nodes: Vec<f64>, weights: Vec<f64>, context: &str) -> Result<Self, String> {
validate_empirical_z_grid(&nodes, &weights, context)?;
Ok(Self { nodes, weights })
}
#[inline]
pub fn pairs(&self) -> impl Iterator<Item = (f64, f64)> + '_ {
self.nodes.iter().copied().zip(self.weights.iter().copied())
}
}
/// The latent measure a fit is built against.
///
/// Serialized with an internal `kind` tag whose values are kebab-case variant names.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "kebab-case")]
#[derive(Default)]
pub enum LatentMeasureKind {
/// Standard normal N(0, 1); the default.
#[default]
StandardNormal,
/// One empirical grid shared by every row.
GlobalEmpirical {
grid: EmpiricalZGrid,
},
/// Per-center empirical grids, mixed per row.
LocalEmpirical {
/// Feature column indices defining the location space — presumably columns of the input data; confirm.
feature_cols: Vec<usize>,
/// Optional per-feature scales: one per `feature_cols` entry, finite and positive (see `validate`).
#[serde(default)]
input_scales: Option<Vec<f64>>,
/// Center coordinates, each of length `feature_cols.len()`, one per grid.
centers: Vec<Vec<f64>>,
/// One empirical grid per center.
grids: Vec<EmpiricalZGrid>,
/// Number of centers mixed per row (must be in `1..=centers.len()`) — presumably the nearest ones; confirm.
top_k: usize,
/// Positive, finite bandwidth — presumably for weighting the `top_k` centers; confirm.
bandwidth: f64,
/// Per-training-row mixtures of `(grid index, weight)`. Skipped by serde,
/// so it is empty after deserialization.
#[serde(skip)]
train_row_mixtures: Arc<Vec<Vec<(usize, f64)>>>,
},
}
impl LatentMeasureKind {
    /// Checks that the measure is internally consistent.
    ///
    /// `StandardNormal` is always valid and `GlobalEmpirical` validates its grid.
    /// `LocalEmpirical` checks, in order: non-empty feature columns, non-empty
    /// centers, matching center/grid counts, `top_k` range, bandwidth, the
    /// optional input scales, each center, and finally each grid.
    /// `train_row_mixtures` is not inspected.
    ///
    /// # Errors
    /// A message prefixed by `context` describing the first violated condition.
    pub fn validate(&self, context: &str) -> Result<(), String> {
        let (feature_cols, input_scales, centers, grids, top_k, bandwidth) = match self {
            Self::StandardNormal => return Ok(()),
            Self::GlobalEmpirical { grid } => {
                return validate_empirical_z_grid(&grid.nodes, &grid.weights, context);
            }
            Self::LocalEmpirical {
                feature_cols,
                input_scales,
                centers,
                grids,
                top_k,
                bandwidth,
                ..
            } => (feature_cols, input_scales, centers, grids, top_k, bandwidth),
        };

        let n_features = feature_cols.len();
        let n_centers = centers.len();
        if n_features == 0 {
            return Err(format!(
                "{context} local empirical latent measure needs feature columns"
            ));
        }
        if n_centers == 0 {
            return Err(format!("{context} local empirical latent measure needs centers"));
        }
        if n_centers != grids.len() {
            return Err(format!(
                "{context} local empirical latent measure center/grid length mismatch: centers={}, grids={}",
                n_centers,
                grids.len()
            ));
        }
        if !(1..=n_centers).contains(top_k) {
            return Err(format!(
                "{context} local empirical latent measure top_k must be in 1..={}, got {top_k}",
                n_centers
            ));
        }
        if !(bandwidth.is_finite() && *bandwidth > 0.0) {
            return Err(format!(
                "{context} local empirical latent measure bandwidth must be finite and positive, got {bandwidth}"
            ));
        }
        if let Some(scales) = input_scales {
            if scales.len() != n_features {
                return Err(format!(
                    "{context} local empirical latent measure input scale dimension mismatch: scales={}, features={}",
                    scales.len(),
                    n_features
                ));
            }
            let first_bad = scales
                .iter()
                .enumerate()
                .find(|(_, scale)| !(scale.is_finite() && **scale > 0.0));
            if let Some((scale_idx, scale)) = first_bad {
                return Err(format!(
                    "{context} local empirical latent measure input scale {scale_idx} must be finite and positive, got {scale}"
                ));
            }
        }
        // Dimension and finiteness are checked per center, so the first failing
        // center is reported regardless of which check it fails.
        for (center_idx, center) in centers.iter().enumerate() {
            if center.len() != n_features {
                return Err(format!(
                    "{context} local empirical latent center {center_idx} dimension mismatch: got {}, expected {}",
                    center.len(),
                    n_features
                ));
            }
            if !center.iter().all(|value| value.is_finite()) {
                return Err(format!(
                    "{context} local empirical latent center {center_idx} has non-finite coordinates"
                ));
            }
        }
        grids.iter().enumerate().try_for_each(|(grid_idx, grid)| {
            validate_empirical_z_grid(
                &grid.nodes,
                &grid.weights,
                &format!("{context} local empirical grid {grid_idx}"),
            )
        })
    }

    /// True for the empirical (global or local) measures.
    fn is_empirical(&self) -> bool {
        match self {
            Self::StandardNormal => false,
            Self::GlobalEmpirical { .. } | Self::LocalEmpirical { .. } => true,
        }
    }

    /// Empirical grid that applies to training row `row`.
    ///
    /// `None` for `StandardNormal`; a clone of the shared grid for
    /// `GlobalEmpirical`; for `LocalEmpirical`, the row's mixture of grids merged
    /// by `combine_empirical_grids`.
    ///
    /// # Errors
    /// When `row` has no stored mixture or the merge fails.
    fn empirical_grid_for_training_row(
        &self,
        row: usize,
    ) -> Result<Option<EmpiricalZGrid>, String> {
        match self {
            Self::StandardNormal => Ok(None),
            Self::GlobalEmpirical { grid } => Ok(Some(grid.clone())),
            Self::LocalEmpirical {
                grids,
                train_row_mixtures,
                ..
            } => train_row_mixtures
                .get(row)
                .ok_or_else(|| {
                    format!(
                        "local empirical latent measure is missing training mixture for row {row}"
                    )
                })
                .and_then(|mixture| combine_empirical_grids(grids, mixture))
                .map(Some),
        }
    }
}
/// Validates the raw parts of an empirical latent grid.
///
/// Checked in this order:
/// 1. equal lengths;
/// 2. at least two nodes;
/// 3. per index, a finite node and then a finite, strictly positive weight;
/// 4. a weight sum within `1e-8` of one.
///
/// # Errors
/// A message prefixed by `context` naming the first violated requirement.
fn validate_empirical_z_grid(nodes: &[f64], weights: &[f64], context: &str) -> Result<(), String> {
    let (node_count, weight_count) = (nodes.len(), weights.len());
    if node_count != weight_count {
        return Err(format!(
            "{context} empirical latent measure node/weight length mismatch: nodes={}, weights={}",
            node_count, weight_count
        ));
    }
    if node_count < 2 {
        return Err(format!(
            "{context} empirical latent measure requires at least two nodes"
        ));
    }
    let mut mass = 0.0;
    for (idx, (node, weight)) in nodes.iter().copied().zip(weights.iter().copied()).enumerate() {
        if !node.is_finite() {
            return Err(format!(
                "{context} empirical latent measure node {idx} is non-finite ({node})"
            ));
        }
        let weight_ok = weight.is_finite() && weight > 0.0;
        if !weight_ok {
            return Err(format!(
                "{context} empirical latent measure weight {idx} must be finite and positive, got {weight}"
            ));
        }
        mass += weight;
    }
    let sums_to_one = mass.is_finite() && (mass - 1.0).abs() <= 1e-8;
    if !sums_to_one {
        return Err(format!(
            "{context} empirical latent measure weights must sum to 1, got {mass}"
        ));
    }
    Ok(())
}
/// Merges the grids referenced by a row mixture into a single grid.
///
/// Each `(grid_idx, grid_weight)` entry contributes every `(node, weight)` pair
/// of `grids[grid_idx]` with the weight scaled by `grid_weight`. Nodes are
/// concatenated (not deduplicated), the merged weights are renormalized to sum
/// to one, and the result is validated.
///
/// # Errors
/// Returns an error in any of these cases:
/// - the mixture is empty;
/// - a mixture weight is not finite or not positive;
/// - a grid index is out of range;
/// - the merged total weight is not positive;
/// - the merged grid fails validation.
fn combine_empirical_grids(
    grids: &[EmpiricalZGrid],
    mixture: &[(usize, f64)],
) -> Result<EmpiricalZGrid, String> {
    if mixture.is_empty() {
        return Err("local empirical latent measure row mixture is empty".to_string());
    }
    let mut nodes = Vec::new();
    let mut weights = Vec::new();
    for (grid_idx, grid_weight) in mixture.iter().copied() {
        let usable = grid_weight.is_finite() && grid_weight > 0.0;
        if !usable {
            return Err(format!(
                "local empirical latent mixture weight must be finite and positive, got {grid_weight}"
            ));
        }
        let Some(grid) = grids.get(grid_idx) else {
            return Err(format!(
                "local empirical latent mixture references missing grid {grid_idx}"
            ));
        };
        grid.pairs().for_each(|(node, weight)| {
            nodes.push(node);
            weights.push(grid_weight * weight);
        });
    }
    let total = weights.iter().copied().sum::<f64>();
    if !total.is_finite() || total <= 0.0 {
        return Err(
            "local empirical latent combined grid has non-positive total weight".to_string(),
        );
    }
    weights.iter_mut().for_each(|weight| *weight /= total);
    EmpiricalZGrid::new(nodes, weights, "local empirical latent combined grid")
}
/// Policy controlling how the latent score `z` is checked, normalized, and
/// mapped to a latent measure.
#[derive(Clone, Debug)]
pub struct LatentZPolicy {
/// How latent-z diagnostics are enforced (consumed outside this view).
pub check_mode: LatentZCheckMode,
/// How `z` is normalized (consumed outside this view).
pub normalization: LatentZNormalizationMode,
/// Requested latent measure; resolved by `build_latent_measure_with_geometry`.
pub latent_measure: LatentMeasureSpec,
/// In the `Auto` test, the mean tolerance is `mean_tol_multiplier / sqrt(n_eff)`.
pub mean_tol_multiplier: f64,
/// In the `Auto` test, the sd tolerance is
/// `sd_tol_multiplier / sqrt(2 * max(n_eff - 1, 1))`.
pub sd_tol_multiplier: f64,
/// Skewness cap; the `Auto` test uses `min(max_abs_skew, AUTO_Z_NORMAL_SKEW_TOL)`.
pub max_abs_skew: f64,
/// Excess-kurtosis cap; the `Auto` test uses
/// `min(max_abs_excess_kurtosis, AUTO_Z_NORMAL_KURT_TOL)`.
pub max_abs_excess_kurtosis: f64,
}
impl LatentZPolicy {
pub fn frozen_transformation_normal() -> Self {
Self {
check_mode: LatentZCheckMode::WarnOnly,
normalization: LatentZNormalizationMode::Frozen { mean: 0.0, sd: 1.0 },
latent_measure: LatentMeasureSpec::auto_default(),
mean_tol_multiplier: 4.0,
sd_tol_multiplier: 4.0,
max_abs_skew: 4.0,
max_abs_excess_kurtosis: 20.0,
}
}
pub fn exploratory_fit_weighted() -> Self {
Self {
check_mode: LatentZCheckMode::WarnOnly,
normalization: LatentZNormalizationMode::FitWeighted,
latent_measure: LatentMeasureSpec::auto_default(),
mean_tol_multiplier: 8.0,
sd_tol_multiplier: 8.0,
max_abs_skew: 4.0,
max_abs_excess_kurtosis: 20.0,
}
}
}
impl Default for LatentZPolicy {
fn default() -> Self {
Self::frozen_transformation_normal()
}
}
/// Affine normalization of latent scores, applied by `apply` as
/// `z -> (z - mean) / sd`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct LatentZNormalization {
/// Location subtracted from each score.
pub mean: f64,
/// Divisor for each centered score; `apply` requires it to exceed `BMS_VARIANCE_FLOOR`.
pub sd: f64,
}
impl LatentZNormalization {
    /// Standardizes `z` element-wise as `(z - mean) / sd`.
    ///
    /// # Errors
    /// Returns an error (prefixed by `context`) when `mean` or `sd` is not finite,
    /// when `sd` does not exceed `BMS_VARIANCE_FLOOR`, or when any value of `z`
    /// is not finite.
    pub fn apply(&self, z: &Array1<f64>, context: &str) -> Result<Array1<f64>, String> {
        if !(self.mean.is_finite() && self.sd.is_finite() && self.sd > BMS_VARIANCE_FLOOR) {
            // Format the floor that is actually enforced instead of a hard-coded
            // literal, so the message cannot drift from the check above.
            return Err(format!(
                "{context} requires finite latent z normalization with sd > {:e}; got mean={} sd={}",
                BMS_VARIANCE_FLOOR, self.mean, self.sd
            ));
        }
        if z.iter().any(|value| !value.is_finite()) {
            return Err(format!("{context} requires finite z values"));
        }
        Ok(z.mapv(|zi| (zi - self.mean) / self.sd))
    }
}
/// Weighted rank-based inverse-normal transform (rank-INT) of the latent score.
///
/// Produced by `LatentZRankIntCalibration::fit`; applied with `apply_at_predict`
/// / `apply_to_training`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct LatentZRankIntCalibration {
/// Knot locations: the distinct training `z` values in ascending order (as built by `fit`).
pub sorted_z: Vec<f64>,
/// Clamped plotting-position probability at each knot.
pub weighted_cdf: Vec<f64>,
/// Weighted mean of the calibrated training values.
pub post_mean: f64,
/// Weighted (population) standard deviation of the calibrated training values.
pub post_sd: f64,
}
impl LatentZRankIntCalibration {
    /// Fits a weighted rank-INT calibration of `z`.
    ///
    /// Rows are sorted by `z` (stable, so tied values accumulate weight in input
    /// order). Each distinct value becomes a knot with probability
    /// `(W - 0.375) / (W_total + 0.25)`, where `W` is the cumulative weight through
    /// the end of that value's tie group. The probability is clamped to
    /// `[eps, 1 - eps]` with `eps = 0.5 / max(W_total, 1)`. `post_mean` and
    /// `post_sd` are the weighted mean and population sd of the calibrated
    /// training values.
    ///
    /// NOTE(review): the plotting-position offsets assume count-scale weights.
    /// When `W_total <= 1`, `eps` is 0.5 and every knot collapses to p = 0.5, so
    /// all calibrated values become 0 and `post_sd` becomes 0. Confirm callers
    /// never pass normalized weights.
    ///
    /// # Errors
    /// Returns an error in any of these cases:
    /// - `z` and `weights` have different lengths;
    /// - `z` is empty;
    /// - a `z` value is not finite;
    /// - a weight is negative or not finite;
    /// - the total weight is not positive.
    pub fn fit(z: &Array1<f64>, weights: &Array1<f64>) -> Result<Self, String> {
        if z.len() != weights.len() {
            return Err(format!(
                "rank-INT calibration: z length {} != weights length {}",
                z.len(),
                weights.len()
            ));
        }
        if z.is_empty() {
            return Err("rank-INT calibration requires at least one observation".to_string());
        }
        // Element-wise checks run before the total-weight check so a single bad
        // entry is reported by index instead of only as a non-finite total.
        for (idx, value) in z.iter().enumerate() {
            if !value.is_finite() {
                return Err(format!(
                    "rank-INT calibration: z[{idx}] = {value} not finite"
                ));
            }
        }
        for (idx, weight) in weights.iter().enumerate() {
            if !(weight.is_finite() && *weight >= 0.0) {
                return Err(format!(
                    "rank-INT calibration: weight[{idx}] = {weight} not finite/non-negative"
                ));
            }
        }
        let w_total = weights.iter().copied().sum::<f64>();
        if !(w_total.is_finite() && w_total > 0.0) {
            return Err(format!(
                "rank-INT calibration requires positive finite total weight, got {w_total}"
            ));
        }

        let mut order: Vec<usize> = (0..z.len()).collect();
        order.sort_by(|&a, &b| z[a].partial_cmp(&z[b]).unwrap_or(std::cmp::Ordering::Equal));

        let mut sorted_z: Vec<f64> = Vec::with_capacity(z.len());
        let mut weighted_cdf: Vec<f64> = Vec::with_capacity(z.len());
        let denom = w_total + 0.25;
        let eps = 0.5 / w_total.max(1.0);
        let mut cum_w = 0.0_f64;
        let mut last_z: Option<f64> = None;
        for &idx in &order {
            cum_w += weights[idx];
            let zi = z[idx];
            let p = ((cum_w - 0.375) / denom).clamp(eps, 1.0 - eps);
            if last_z == Some(zi) {
                // Tied value: advance the existing knot to the end of the tie group.
                if let Some(slot) = weighted_cdf.last_mut() {
                    *slot = p;
                }
                continue;
            }
            sorted_z.push(zi);
            weighted_cdf.push(p);
            last_z = Some(zi);
        }

        // Evaluate each row's calibrated value once and reuse it for both moments.
        // Each evaluation is a binary search plus an inverse-normal quantile.
        let calibrated: Vec<f64> = order
            .iter()
            .map(|&idx| Self::apply_with_knots(z[idx], &sorted_z, &weighted_cdf))
            .collect();

        let mut sum_wz = 0.0_f64;
        let mut sum_w = 0.0_f64;
        for (&idx, &value) in order.iter().zip(&calibrated) {
            sum_wz += weights[idx] * value;
            sum_w += weights[idx];
        }
        let post_mean = if sum_w > 0.0 { sum_wz / sum_w } else { 0.0 };
        let mut sum_w_dev = 0.0_f64;
        for (&idx, &value) in order.iter().zip(&calibrated) {
            let d = value - post_mean;
            sum_w_dev += weights[idx] * d * d;
        }
        let post_sd = if sum_w > 0.0 {
            (sum_w_dev / sum_w).sqrt()
        } else {
            1.0
        };
        Ok(Self {
            sorted_z,
            weighted_cdf,
            post_mean,
            post_sd,
        })
    }

    /// Applies the calibration to every element of `z`.
    ///
    /// # Errors
    /// When the calibration has no knots or any `z` value is not finite.
    pub fn apply_to_training(&self, z: &Array1<f64>) -> Result<Array1<f64>, String> {
        if self.sorted_z.is_empty() {
            return Err("rank-INT calibration has no knots".to_string());
        }
        let mut out = Array1::<f64>::zeros(z.len());
        for (idx, &zi) in z.iter().enumerate() {
            if !zi.is_finite() {
                return Err(format!(
                    "rank-INT calibration apply: z[{idx}] = {zi} not finite"
                ));
            }
            out[idx] = self.apply_at_predict(zi);
        }
        Ok(out)
    }

    /// Calibrates a single value (no finiteness check; see `apply_with_knots`).
    pub fn apply_at_predict(&self, z: f64) -> f64 {
        Self::apply_with_knots(z, &self.sorted_z, &self.weighted_cdf)
    }

    /// Maps `z` to a standard-normal score.
    ///
    /// The knot probabilities are linearly interpolated in `z`. Values outside
    /// the knot range take the end-knot probability. The result is passed through
    /// the standard normal quantile, falling back to -8 / +8 (by which side of 0.5
    /// `p` lies) if that evaluation fails.
    ///
    /// # Panics
    /// When the knot slices differ in length or are empty.
    fn apply_with_knots(z: f64, sorted_z: &[f64], weighted_cdf: &[f64]) -> f64 {
        assert_eq!(sorted_z.len(), weighted_cdf.len());
        assert!(!sorted_z.is_empty());
        let n = sorted_z.len();
        let p = if z <= sorted_z[0] {
            weighted_cdf[0]
        } else if z >= sorted_z[n - 1] {
            weighted_cdf[n - 1]
        } else {
            // Invariant: sorted_z[lo] <= z < sorted_z[hi].
            let mut lo = 0usize;
            let mut hi = n - 1;
            while hi - lo > 1 {
                let mid = (lo + hi) / 2;
                if sorted_z[mid] <= z {
                    lo = mid;
                } else {
                    hi = mid;
                }
            }
            let z_lo = sorted_z[lo];
            let z_hi = sorted_z[hi];
            let p_lo = weighted_cdf[lo];
            let p_hi = weighted_cdf[hi];
            if z_hi == z_lo {
                p_hi
            } else {
                let t = (z - z_lo) / (z_hi - z_lo);
                p_lo + t * (p_hi - p_lo)
            }
        };
        standard_normal_quantile(p).unwrap_or_else(|_| if p < 0.5 { -8.0 } else { 8.0 })
    }
}
/// Calibration of `z` chosen alongside the latent measure.
#[derive(Clone, Debug)]
pub enum LatentMeasureCalibration {
/// No calibration.
None,
/// Weighted rank-based inverse-normal calibration of `z`.
RankInverseNormal(LatentZRankIntCalibration),
}
/// Resolves `policy.latent_measure` into a concrete measure plus calibration.
///
/// - `StandardNormal`: N(0, 1), no calibration.
/// - `GlobalEmpirical { grid_size }`: an empirical grid built from `z`/`weights`.
/// - `Auto`: N(0, 1) if `z` passes `latent_z_is_standard_normal_enough`;
///   otherwise N(0, 1) with a fitted rank-INT calibration (logged at info level).
///
/// # Errors
/// Propagates errors from the diagnostics, the calibration fit, or the grid build.
fn build_latent_measure_with_geometry(
    z: &Array1<f64>,
    weights: &Array1<f64>,
    policy: &LatentZPolicy,
) -> Result<(LatentMeasureKind, LatentMeasureCalibration), String> {
    match policy.latent_measure {
        LatentMeasureSpec::StandardNormal => Ok((
            LatentMeasureKind::StandardNormal,
            LatentMeasureCalibration::None,
        )),
        LatentMeasureSpec::GlobalEmpirical { grid_size } => {
            build_global_empirical_latent_measure(z, weights, grid_size)
                .map(|kind| (kind, LatentMeasureCalibration::None))
        }
        LatentMeasureSpec::Auto { .. } => {
            let calibration = if latent_z_is_standard_normal_enough(z, weights, policy)? {
                LatentMeasureCalibration::None
            } else {
                let rank_int = LatentZRankIntCalibration::fit(z, weights)?;
                log::info!(
                    "[BMS latent-z] rank-INT calibrated: post_mean={:.3e} post_sd={:.3e} knots={}",
                    rank_int.post_mean,
                    rank_int.post_sd,
                    rank_int.sorted_z.len(),
                );
                LatentMeasureCalibration::RankInverseNormal(rank_int)
            };
            Ok((LatentMeasureKind::StandardNormal, calibration))
        }
    }
}
/// Decides whether the weighted sample `z` is close enough to N(0, 1) to be used
/// without calibration.
///
/// All of the following must hold:
/// - weighted mean within `mean_tol_multiplier / sqrt(n_eff)` of 0;
/// - weighted sd within `sd_tol_multiplier / sqrt(2 * max(n_eff - 1, 1))` of 1;
/// - |skew| and |excess kurtosis| within the policy caps, each further capped by
///   the `AUTO_Z_NORMAL_*` constants;
/// - weighted KS distance to N(0, 1) at most `AUTO_Z_NORMAL_KS_TOL`;
/// - weighted tail masses beyond 4 and 6 at most twice the normal tails (plus
///   slack);
/// - every |z| (all rows, regardless of weight) below `AUTO_Z_NORMAL_MAX_ABS`.
///
/// `n_eff` is the Kish effective sample size `(Σw)² / Σw²`. A non-finite or zero
/// sd yields `Ok(false)`.
///
/// # Errors
/// Returns an error in any of these cases:
/// - the lengths of `z` and `weights` differ;
/// - the weight sums are not positive and finite;
/// - `n_eff <= 1`;
/// - the KS diagnostic rejects its inputs.
fn latent_z_is_standard_normal_enough(
    z: &Array1<f64>,
    weights: &Array1<f64>,
    policy: &LatentZPolicy,
) -> Result<bool, String> {
    let (n_z, n_w) = (z.len(), weights.len());
    if n_z != n_w {
        return Err(format!(
            "latent-measure auto-detection length mismatch: z={}, weights={}",
            n_z, n_w
        ));
    }
    let weight_sum = weights.iter().copied().sum::<f64>();
    let weight_sq_sum = weights.iter().map(|&w| w * w).sum::<f64>();
    let weights_usable = weight_sum.is_finite()
        && weight_sum > 0.0
        && weight_sq_sum.is_finite()
        && weight_sq_sum > 0.0;
    if !weights_usable {
        return Err("latent-measure auto-detection requires positive finite weights".to_string());
    }
    let effective_n = weight_sum * weight_sum / weight_sq_sum;
    if !(effective_n.is_finite() && effective_n > 1.0) {
        return Err(
            "latent-measure auto-detection requires at least two effective observations"
                .to_string(),
        );
    }

    // Σ_i term(z_i, w_i) / Σ_i w_i, summed in row order.
    let weighted_average = |term: &dyn Fn(f64, f64) -> f64| -> f64 {
        z.iter()
            .zip(weights.iter())
            .map(|(&zi, &wi)| term(zi, wi))
            .sum::<f64>()
            / weight_sum
    };
    let mean = weighted_average(&|zi: f64, wi: f64| wi * zi);
    let var = weighted_average(&|zi: f64, wi: f64| wi * (zi - mean) * (zi - mean));
    let sd = var.sqrt();
    if !(mean.is_finite() && sd.is_finite() && sd > 0.0) {
        return Ok(false);
    }
    let skew = weighted_average(&|zi: f64, wi: f64| {
        let centered = (zi - mean) / sd;
        wi * centered.powi(3)
    });
    let excess_kurtosis = weighted_average(&|zi: f64, wi: f64| {
        let centered = (zi - mean) / sd;
        wi * centered.powi(4)
    }) - 3.0;

    let mean_tol = policy.mean_tol_multiplier / effective_n.sqrt();
    let sd_tol = policy.sd_tol_multiplier / (2.0 * (effective_n - 1.0).max(1.0)).sqrt();
    let ks_to_normal = weighted_ks_to_standard_normal(z, weights, weight_sum)?;
    let tail_mass_4 = weighted_tail_mass(z, weights, weight_sum, 4.0);
    let tail_mass_6 = weighted_tail_mass(z, weights, weight_sum, 6.0);
    let max_abs_z = z.iter().fold(0.0_f64, |acc, &zi| acc.max(zi.abs()));
    let normal_tail_4 = 2.0 * (1.0 - normal_cdf(4.0));
    let normal_tail_6 = 2.0 * (1.0 - normal_cdf(6.0));

    let moments_ok = mean.abs() <= mean_tol
        && (sd - 1.0).abs() <= sd_tol
        && skew.is_finite()
        && skew.abs() <= policy.max_abs_skew.min(AUTO_Z_NORMAL_SKEW_TOL)
        && excess_kurtosis.is_finite()
        && excess_kurtosis.abs() <= policy.max_abs_excess_kurtosis.min(AUTO_Z_NORMAL_KURT_TOL);
    let ks_ok = ks_to_normal.is_finite() && ks_to_normal <= AUTO_Z_NORMAL_KS_TOL;
    let tails_ok = tail_mass_4 <= 2.0 * normal_tail_4 + 1e-5
        && tail_mass_6 <= 2.0 * normal_tail_6 + 1e-8
        && max_abs_z < AUTO_Z_NORMAL_MAX_ABS;
    Ok(moments_ok && ks_ok && tails_ok)
}
/// Builds a validated `GlobalEmpirical` measure from weighted latent scores.
///
/// # Errors
/// Propagates grid-construction and validation errors (context
/// "empirical latent measure").
fn build_global_empirical_latent_measure(
    z: &Array1<f64>,
    weights: &Array1<f64>,
    grid_size: usize,
) -> Result<LatentMeasureKind, String> {
    const CONTEXT: &str = "empirical latent measure";
    let measure = LatentMeasureKind::GlobalEmpirical {
        grid: build_empirical_z_grid(z, weights, grid_size, CONTEXT)?,
    };
    measure.validate(CONTEXT)?;
    Ok(measure)
}
/// Weighted Kolmogorov–Smirnov distance between the empirical CDF of `z` and
/// the standard normal CDF.
///
/// Zero-weight rows are ignored. The empirical CDF is accumulated as
/// `w_i / total_weight` in ascending `z` order (stable for ties). At each point,
/// both the left and right limits are compared with the normal CDF.
///
/// # Errors
/// When any `z` is non-finite or any weight is negative or non-finite.
fn weighted_ks_to_standard_normal(
    z: &Array1<f64>,
    weights: &Array1<f64>,
    total_weight: f64,
) -> Result<f64, String> {
    let mut pairs = Vec::<(f64, f64)>::with_capacity(z.len());
    for (&zi, &wi) in z.iter().zip(weights.iter()) {
        let valid = zi.is_finite() && wi.is_finite() && wi >= 0.0;
        if !valid {
            return Err(
                "latent-measure KS diagnostic requires finite z and non-negative finite weights"
                    .to_string(),
            );
        }
        if wi > 0.0 {
            pairs.push((zi, wi));
        }
    }
    pairs.sort_by(|left, right| {
        left.0
            .partial_cmp(&right.0)
            .expect("validated latent z values are finite")
    });
    let (_, ks) = pairs
        .into_iter()
        .fold((0.0_f64, 0.0_f64), |(below, worst), (zi, wi)| {
            let cdf = normal_cdf(zi);
            let above = below + wi / total_weight;
            (above, worst.max((cdf - below).abs()).max((cdf - above).abs()))
        });
    Ok(ks)
}
/// Fraction of `total_weight` carried by rows with `|z| > cutoff`.
fn weighted_tail_mass(
    z: &Array1<f64>,
    weights: &Array1<f64>,
    total_weight: f64,
    cutoff: f64,
) -> f64 {
    let exceedance = z
        .iter()
        .zip(weights.iter())
        .filter_map(|(&zi, &wi)| (zi.abs() > cutoff).then_some(wi))
        .sum::<f64>();
    exceedance / total_weight
}
/// Compresses weighted latent scores into an equal-mass empirical grid.
///
/// Steps:
/// 1. Drop zero-weight rows and sort the rest by `z` (stable).
/// 2. Split the total weight into `m = min(grid_size, #positive rows)` bins of
///    equal mass, splitting a row across bins when needed.
/// 3. Place each node at its bin's weighted mean of `z`.
/// 4. Standardize the nodes to weighted mean 0 / sd 1 with
///    `recenter_rescale_empirical_grid`.
/// 5. Renormalize the weights to sum to one and validate the grid.
///
/// # Errors
/// Returns an error (prefixed by `context`) in any of these cases:
/// - `grid_size < 3`;
/// - `z` and `weights` have different lengths;
/// - a `z` is non-finite, or a weight is negative or non-finite;
/// - there are fewer than two positive-weight rows;
/// - the total weight is not positive;
/// - compression yields fewer than two nodes;
/// - the final grid fails validation.
fn build_empirical_z_grid(
    z: &Array1<f64>,
    weights: &Array1<f64>,
    grid_size: usize,
    context: &str,
) -> Result<EmpiricalZGrid, String> {
    if grid_size < 3 {
        // Prefixed with `context` like every other error in this function; the
        // text is unchanged for the "empirical latent measure" caller.
        return Err(format!(
            "{context} grid_size must be at least 3, got {grid_size}"
        ));
    }
    if z.len() != weights.len() {
        return Err(format!(
            "{context} length mismatch: z={}, weights={}",
            z.len(),
            weights.len()
        ));
    }
    let mut pairs = Vec::<(f64, f64)>::with_capacity(z.len());
    for (idx, (&zi, &wi)) in z.iter().zip(weights.iter()).enumerate() {
        if !zi.is_finite() {
            return Err(format!(
                "{context} z value at row {idx} is non-finite ({zi})"
            ));
        }
        if !wi.is_finite() || wi < 0.0 {
            return Err(format!(
                "{context} weight at row {idx} must be finite and non-negative, got {wi}"
            ));
        }
        if wi > 0.0 {
            pairs.push((zi, wi));
        }
    }
    if pairs.len() < 2 {
        return Err(format!(
            "{context} requires at least two positive-weight rows"
        ));
    }
    pairs.sort_by(|left, right| {
        left.0
            .partial_cmp(&right.0)
            .expect("validated empirical latent z values are finite")
    });
    let total_weight = pairs.iter().map(|(_, weight)| *weight).sum::<f64>();
    if !(total_weight.is_finite() && total_weight > 0.0) {
        return Err(format!("{context} requires positive finite total weight"));
    }
    let m = grid_size.min(pairs.len());
    let mut nodes = Vec::with_capacity(m);
    let mut out_weights = Vec::with_capacity(m);
    let bin_weight_target = total_weight / (m as f64);
    // `cursor` is the current row; `remaining` is its weight not yet assigned.
    let mut cursor = 0usize;
    let mut remaining = pairs[0].1;
    for _ in 0..m {
        let mut need = bin_weight_target;
        let mut bin_weight = 0.0;
        let mut bin_sum = 0.0;
        // Relative tolerances absorb floating-point residue when a bin or a row
        // is (numerically) exhausted.
        while need > 1e-14 * bin_weight_target && cursor < pairs.len() {
            let take = remaining.min(need);
            bin_sum += take * pairs[cursor].0;
            bin_weight += take;
            need -= take;
            remaining -= take;
            if remaining <= 1e-14 * pairs[cursor].1 {
                cursor += 1;
                if cursor < pairs.len() {
                    remaining = pairs[cursor].1;
                }
            }
        }
        if bin_weight > 0.0 {
            nodes.push(bin_sum / bin_weight);
            out_weights.push(bin_weight / total_weight);
        }
    }
    if nodes.len() < 2 {
        return Err(format!(
            "{context} compression produced fewer than two nodes"
        ));
    }
    recenter_rescale_empirical_grid(&mut nodes, &out_weights);
    let total = out_weights.iter().sum::<f64>();
    if total.is_finite() && total > 0.0 {
        for weight in &mut out_weights {
            *weight /= total;
        }
    }
    validate_empirical_z_grid(&nodes, &out_weights, context)?;
    Ok(EmpiricalZGrid {
        nodes,
        weights: out_weights,
    })
}
/// Standardizes grid nodes in place to weighted mean 0 and weighted sd 1.
///
/// Weights are only read and need not sum to one. Nothing changes when the total
/// weight is not positive and finite, or when the weighted sd is not finite or
/// does not exceed `BMS_VARIANCE_FLOOR`.
fn recenter_rescale_empirical_grid(nodes: &mut [f64], weights: &[f64]) {
    let total = weights.iter().sum::<f64>();
    if !total.is_finite() || total <= 0.0 {
        return;
    }
    let mean = nodes
        .iter()
        .zip(weights)
        .map(|(&node, &weight)| weight * node)
        .sum::<f64>()
        / total;
    let var = nodes
        .iter()
        .zip(weights)
        .map(|(&node, &weight)| weight * (node - mean).powi(2))
        .sum::<f64>()
        / total;
    let sd = var.sqrt();
    if !(sd.is_finite() && sd > BMS_VARIANCE_FLOOR) {
        return;
    }
    nodes.iter_mut().for_each(|node| *node = (*node - mean) / sd);
}
/// Budget for the first phase of automatic subsampling — TODO confirm units (iterations vs. evaluations).
pub(super) const BMS_AUTO_SUBSAMPLE_PHASE1_BUDGET: usize = 12;
/// Probability epsilon for the Bernoulli link — presumably keeps probabilities inside (eps, 1 - eps); confirm at use sites.
pub(super) const BERNOULLI_LINK_PROBABILITY_EPS: f64 = 1e-12;
/// Floor a standard deviation must exceed. In this file it is compared against
/// sd values (not variances) in `LatentZNormalization::apply` and
/// `recenter_rescale_empirical_grid`.
pub(super) const BMS_VARIANCE_FLOOR: f64 = 1e-12;
/// Derivative tolerance — used outside this view; confirm meaning.
pub(super) const BMS_DERIV_TOL: f64 = 1e-8;
/// Rows per processing chunk — presumably for chunked row reductions; confirm.
pub(super) const ROW_CHUNK_SIZE: usize = 1024;
/// Row count from which exact-work logging is emitted — presumably; confirm.
pub(super) const EXACT_WORK_LOG_MIN_ROWS: usize = 50_000;
/// Expected number of reuse passes for a row-primary Hessian cache — heuristic used outside this view.
pub(super) const BMS_ROW_PRIMARY_HESSIAN_EXPECTED_REUSE_PASSES: usize = 3;
/// Minimum number of reuse passes for a row-primary Hessian cache to be worthwhile — presumably; confirm.
pub(super) const BMS_ROW_PRIMARY_HESSIAN_MIN_REUSE_PASSES: usize = 2;
/// Numerator of the single-cache fraction (1/4) — presumably a share of a memory budget; confirm.
pub(super) const BMS_ROW_PRIMARY_HESSIAN_SINGLE_FRACTION_NUM: u64 = 1;
/// Denominator of the single-cache fraction (1/4).
pub(super) const BMS_ROW_PRIMARY_HESSIAN_SINGLE_FRACTION_DEN: u64 = 4;
/// Numerator of the global fraction (1/2) — presumably the share of a memory budget for all caches together; confirm.
pub(super) const BMS_ROW_PRIMARY_HESSIAN_GLOBAL_FRACTION_NUM: u64 = 1;
/// Denominator of the global fraction (1/2).
pub(super) const BMS_ROW_PRIMARY_HESSIAN_GLOBAL_FRACTION_DEN: u64 = 2;
/// Rows per chunk for early exit during the line search — presumably; confirm.
pub(super) const BERNOULLI_MARGSLOPE_LINE_SEARCH_EARLY_EXIT_CHUNK_ROWS: usize = 10_000;
pub mod audit_jacobian;
pub(crate) mod block_specs;
pub(crate) mod exact_eval_cache;
pub(crate) mod family;
pub(crate) mod gradient_paths;
pub(crate) mod hessian_paths;
pub(crate) mod install_flex;
pub(crate) mod row_kernel;
#[cfg(test)]
mod tests_inline;
pub(crate) mod workspace;
pub use block_specs::fit_bernoulli_marginal_slope_terms;
pub use gradient_paths::{
MarginalSlopeCovariance, MarginalSlopeCovarianceShape, marginal_slope_covariance_from_scores,
marginal_slope_preserving_scale, marginal_slope_probit_eta, padded_deviation_seed,
};
pub use install_flex::CrossBlockIdentifiabilityWarning;
pub(crate) use install_flex::FlexCompileOutcome;
pub(crate) use block_specs::push_deviation_aux_blockspecs;
pub use block_specs::{BmsFamilyScalars, BmsLogslopeJacobian, BmsMarginalJacobian};
pub(crate) use family::{
BernoulliMarginalLinkMap, bernoulli_marginal_link_map,
build_link_deviation_block_from_knots_design_seed_and_weights,
build_score_warp_deviation_block_from_seed,
};
pub(crate) use gradient_paths::standardize_latent_z_with_policy;
pub(crate) use gradient_paths::{
empirical_intercept_from_marginal, signed_probit_neglog_derivatives_up_to_fourth,
unary_derivatives_log, unary_derivatives_log_normal_pdf, unary_derivatives_neglog_phi,
unary_derivatives_sqrt,
};
pub(crate) use install_flex::{
install_compiled_flex_block_into_runtime, project_monotone_feasible_beta,
validate_monotone_structural_feasible,
};