use crate::custom_family::{
BatchedOuterGradientTerms, BlockEffectiveJacobian, BlockWorkingSet, BlockwiseFitOptions,
CustomFamily, CustomFamilyWarmStart, ExactNewtonJointGradientEvaluation,
ExactNewtonJointHessianWorkspace, ExactNewtonJointPsiSecondOrderTerms,
ExactNewtonJointPsiTerms, ExactNewtonJointPsiWorkspace, FamilyEvaluation,
FamilyLinearizationState, ParameterBlockSpec, ParameterBlockState, PenaltyMatrix,
custom_family_outer_derivatives, evaluate_custom_family_joint_hyper_efs_shared,
evaluate_custom_family_joint_hyper_shared, fit_custom_family,
joint_hyper_options_for_outer_tolerance,
};
use crate::estimate::UnifiedFitResult;
use crate::estimate::reml::unified::{DenseSpectralOperator, HessianOperator, HyperOperator};
use crate::families::cubic_cell_kernel as exact_kernel;
use crate::families::jet_partitions::MultiDirJet;
use crate::families::lognormal_kernel::FrailtySpec;
use crate::families::marginal_slope_shared::{
CoeffSupport, DirectionalScaleJets, ObservedDenestedCellPartials, SparsePrimaryCoeffJetView,
WeightedOuterRow, add_optional_matrix, add_optional_vector, add_two_surface_psi_outer,
build_denested_partition_cells as shared_denested_partition_cells, chunked_row_reduction,
directional_obj_grad_hess, eval_coeff4_at, is_sigma_aux_index as shared_is_sigma_aux_index,
observed_denested_cell_partials as shared_observed_denested_cell_partials, outer_row_indices,
outer_weighted_rows, parameter_block_specs_match_rows, probit_frailty_scale,
probit_frailty_scale_multi_dir_jet, psi_derivative_location, scale_coeff4,
};
use crate::families::parameter_block::ParameterBlockInput;
use crate::families::row_kernel::{
RowKernel, RowKernelHessianWorkspace, build_row_kernel_cache, row_kernel_gradient,
row_kernel_hessian_dense, row_kernel_log_likelihood,
};
use crate::families::spatial_psi_bridge::build_block_spatial_psi_derivatives;
use crate::families::wiggle::initializewiggle_knots_from_seed;
use crate::matrix::{DesignMatrix, SymmetricMatrix};
use crate::pirls::LinearInequalityConstraints;
use crate::probability::{
normal_cdf, normal_logcdf, normal_pdf, signed_probit_logcdf_and_mills_ratio,
standard_normal_quantile,
};
use crate::smooth::{
ExactJointHyperSetup, SpatialLengthScaleOptimizationOptions, SpatialLogKappaCoords,
TermCollectionDesign, TermCollectionSpec, apply_spatial_anisotropy_pilot_initializer,
build_term_collection_designs_and_freeze_joint, optimize_spatial_length_scale_exact_joint,
spatial_length_scale_term_indices,
};
use crate::types::{InverseLink, StandardLink, WigglePenaltyConfig};
use ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayViewMut1, s};
use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
use serde::{Deserialize, Serialize};
use std::cell::RefCell;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, OnceLock};
pub mod deviation_runtime;
pub mod gpu;
pub use deviation_runtime::DeviationRuntime;
pub use deviation_runtime::ParametricAnchorBlock;
// Row-count threshold (50,000) associated with the BMS flex spatial outer pilot.
// NOTE(review): the consumer of this constant is outside this chunk — presumably fits with at
// least this many rows use a pilot stage for the spatial outer optimization; confirm at the use site.
const BMS_FLEX_SPATIAL_OUTER_PILOT_ROW_THRESHOLD: usize = 50_000;
/// Configuration of a deviation block (the optional `score_warp` / `link_dev` entries of
/// `BernoulliMarginalSlopeTermSpec` carry this type).
///
/// The `Default` value is `WigglePenaltyConfig::cubic_triple_operator_default()` converted
/// through the `From<WigglePenaltyConfig>` impl, which copies every same-named field.
#[derive(Clone, Debug)]
pub struct DeviationBlockConfig {
/// Basis degree. NOTE(review): presumably the B-spline degree — confirm against the basis builder.
pub degree: usize,
/// Number of internal knots requested for the basis.
pub num_internal_knots: usize,
/// Single summary penalty order; the `From<WigglePenaltyConfig>` conversion sets it to the
/// largest entry of `penalty_orders`, or 2 when that list is empty.
pub penalty_order: usize,
/// Full list of penalty orders.
pub penalty_orders: Vec<usize>,
/// NOTE(review): presumably enables an additional penalty term — confirm at the use site.
pub double_penalty: bool,
/// NOTE(review): presumably the slack used when enforcing monotonicity — confirm at the use site.
pub monotonicity_eps: f64,
}
impl Default for DeviationBlockConfig {
    /// Default configuration: the cubic triple-operator wiggle preset, converted into a
    /// `DeviationBlockConfig`.
    fn default() -> Self {
        let preset = WigglePenaltyConfig::cubic_triple_operator_default();
        preset.into()
    }
}
impl DeviationBlockConfig {
    /// Named constructor for the triple-penalty preset; returns exactly the `Default` value.
    pub fn triple_penalty_default() -> Self {
        <Self as Default>::default()
    }
}
impl From<WigglePenaltyConfig> for DeviationBlockConfig {
    /// Copies every field of the wiggle configuration and derives `penalty_order` as the
    /// largest requested penalty order (falling back to 2 when none is listed).
    fn from(cfg: WigglePenaltyConfig) -> Self {
        let penalty_order = match cfg.penalty_orders.iter().max() {
            Some(&largest) => largest,
            None => 2,
        };
        Self {
            degree: cfg.degree,
            num_internal_knots: cfg.num_internal_knots,
            penalty_order,
            penalty_orders: cfg.penalty_orders,
            double_penalty: cfg.double_penalty,
            monotonicity_eps: cfg.monotonicity_eps,
        }
    }
}
/// Bundles a deviation block's `ParameterBlockInput` with the `DeviationRuntime` that
/// belongs to it.
///
/// `Debug` is implemented by hand for this type and prints no fields.
#[derive(Clone)]
pub(crate) struct DeviationPrepared {
/// Parameter-block input for the deviation block.
pub(crate) block: ParameterBlockInput,
/// Runtime object paired with `block`.
pub(crate) runtime: DeviationRuntime,
}
impl std::fmt::Debug for DeviationPrepared {
    /// Opaque debug output: prints the type name only and marks the struct as non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut opaque = f.debug_struct("DeviationPrepared");
        opaque.finish_non_exhaustive()
    }
}
/// Input specification for a Bernoulli marginal-slope fit.
///
/// NOTE(review): the fitting code that consumes this struct is outside this chunk; the
/// per-field notes below are inferred from names and types and should be confirmed there.
#[derive(Clone)]
pub struct BernoulliMarginalSlopeTermSpec {
/// Response vector (presumably one 0/1 outcome per row — confirm).
pub y: Array1<f64>,
/// Per-row observation weights.
pub weights: Array1<f64>,
/// Per-row latent score `z`.
pub z: Array1<f64>,
/// Inverse link used as the base link.
pub base_link: InverseLink,
/// Term collection for the marginal block.
pub marginalspec: TermCollectionSpec,
/// Term collection for the log-slope block.
pub logslopespec: TermCollectionSpec,
/// Offset for the marginal block.
pub marginal_offset: Array1<f64>,
/// Offset for the log-slope block.
pub logslope_offset: Array1<f64>,
/// Frailty specification.
pub frailty: FrailtySpec,
/// Optional score-warp deviation block configuration.
pub score_warp: Option<DeviationBlockConfig>,
/// Optional link-deviation block configuration.
pub link_dev: Option<DeviationBlockConfig>,
/// Policy controlling checks, normalization and latent-measure selection for `z`.
pub latent_z_policy: LatentZPolicy,
/// Optional score-influence Jacobian (shape and semantics not visible here — TODO confirm).
pub score_influence_jacobian: Option<Array2<f64>>,
}
/// Result bundle of a Bernoulli marginal-slope fit.
///
/// NOTE(review): the code that populates this struct is outside this chunk; the per-field
/// notes below are inferred from names and types.
pub struct BernoulliMarginalSlopeFitResult {
/// The underlying unified fit result.
pub fit: UnifiedFitResult,
/// Marginal term specification after resolution.
pub marginalspec_resolved: TermCollectionSpec,
/// Log-slope term specification after resolution.
pub logslopespec_resolved: TermCollectionSpec,
/// Design built for the marginal block.
pub marginal_design: TermCollectionDesign,
/// Design built for the log-slope block.
pub logslope_design: TermCollectionDesign,
/// Baseline value for the marginal block.
pub baseline_marginal: f64,
/// Baseline value for the log-slope block.
pub baseline_logslope: f64,
/// Mean/sd normalization applied to `z`.
pub z_normalization: LatentZNormalization,
/// Latent measure selected for the fit.
pub latent_measure: LatentMeasureKind,
/// Runtime of the score-warp deviation block, when one was configured.
pub score_warp_runtime: Option<DeviationRuntime>,
/// Runtime of the link-deviation block, when one was configured.
pub link_dev_runtime: Option<DeviationRuntime>,
/// Gaussian frailty standard deviation, when applicable.
pub gaussian_frailty_sd: Option<f64>,
/// Cross-block identifiability warnings collected during the fit.
pub cross_block_warnings: Vec<CrossBlockIdentifiabilityWarning>,
/// Rank-based inverse-normal calibration of `z`, when one was fitted.
pub latent_z_rank_int_calibration: Option<LatentZRankIntCalibration>,
/// Conditional location-scale calibration of `z`, when one was fitted.
pub latent_z_conditional_calibration: Option<LatentZConditionalCalibration>,
}
/// How strictly the latent-z diagnostics are enforced.
///
/// NOTE(review): the code acting on these variants is outside this chunk; the variant notes
/// are inferred from the names.
#[derive(Clone, Debug)]
pub enum LatentZCheckMode {
/// Presumably: a failed check is an error.
Strict,
/// Presumably: a failed check is reported but does not abort. Both `LatentZPolicy`
/// presets in this file use this mode.
WarnOnly,
/// Presumably: checks are skipped.
Off,
}
/// How the latent score `z` is normalized before fitting.
///
/// NOTE(review): the normalization code is outside this chunk; variant notes are inferred
/// from the names and from the `LatentZPolicy` presets in this file.
#[derive(Clone, Debug)]
pub enum LatentZNormalizationMode {
/// Presumably: no normalization.
None,
/// Presumably: mean/sd estimated from the (weighted) fitting data.
FitWeighted,
/// Use the supplied, fixed mean and standard deviation
/// (`LatentZPolicy::frozen_transformation_normal` uses mean 0, sd 1).
Frozen { mean: f64, sd: f64 },
}
/// Grid size used by `LatentMeasureSpec::auto_default()`.
pub const DEFAULT_EMPIRICAL_LATENT_GRID_SIZE: usize = 65;
// Tolerances for the automatic "is z standard-normal enough?" detection.
// NOTE(review): the comparisons that use the SKEW/KURT/KS/MAX_ABS/TAIL_MASS_SLACK/TAIL_FLOOR
// values are not visible in this chunk; the names suggest absolute skewness, excess kurtosis,
// Kolmogorov–Smirnov distance, largest |z|, and tail-mass slack/floors — confirm at the use site.
const AUTO_Z_NORMAL_SKEW_TOL: f64 = 0.10;
const AUTO_Z_NORMAL_KURT_TOL: f64 = 0.25;
const AUTO_Z_NORMAL_KS_TOL: f64 = 0.025;
const AUTO_Z_NORMAL_MAX_ABS: f64 = 8.0;
// Cut-offs (in sigma units) at which the weighted tail mass of z is compared against the
// two-sided standard-normal tail mass `2 * (1 - normal_cdf(cutoff))`.
const AUTO_Z_NORMAL_TAIL_SIGMA_INNER: f64 = 4.0;
const AUTO_Z_NORMAL_TAIL_SIGMA_OUTER: f64 = 6.0;
const AUTO_Z_NORMAL_TAIL_MASS_SLACK: f64 = 2.0;
const AUTO_Z_NORMAL_TAIL_FLOOR_INNER: f64 = 1e-5;
const AUTO_Z_NORMAL_TAIL_FLOOR_OUTER: f64 = 1e-8;
// p-value threshold below which the conditional mean / variance score test "fires" and a
// conditional location-scale calibration is fitted.
const AUTO_Z_CONDITIONAL_RAO_ALPHA: f64 = 1.0e-3;
// Multiplier applied to the diagonal penalty in the conditional-calibration ridge fits.
const AUTO_Z_CONDITIONAL_RIDGE_REL: f64 = 1.0e-8;
// Conditional variance floor expressed as a fraction of the global weighted variance of z.
const AUTO_Z_CONDITIONAL_VAR_FLOOR_FRAC: f64 = 1.0e-3;
/// User-facing request for which latent measure to integrate against.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LatentMeasureSpec {
/// Decide automatically. In `build_latent_measure_with_geometry` this path always yields
/// a standard-normal measure, optionally with a calibration of `z`; `grid_size` is ignored there.
Auto { grid_size: usize },
/// Use the standard normal measure with no calibration.
StandardNormal,
/// Build a global empirical measure from the data using `grid_size`.
GlobalEmpirical { grid_size: usize },
}
impl LatentMeasureSpec {
    /// Automatic latent-measure selection with the default empirical grid size.
    pub fn auto_default() -> Self {
        let grid_size = DEFAULT_EMPIRICAL_LATENT_GRID_SIZE;
        Self::Auto { grid_size }
    }
}
impl Default for LatentMeasureSpec {
fn default() -> Self {
Self::auto_default()
}
}
/// Discrete latent measure: support points and their probability weights.
///
/// `EmpiricalZGrid::new` validates that both vectors have the same length (at least two
/// entries), that nodes are finite, and that weights are finite, positive and sum to 1
/// within 1e-8. Direct construction or deserialization bypasses that validation.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct EmpiricalZGrid {
/// Support points of the measure.
pub nodes: Vec<f64>,
/// Probability weight attached to the node at the same index.
pub weights: Vec<f64>,
}
impl EmpiricalZGrid {
    /// Builds a grid after validating `nodes`/`weights`; `context` prefixes any error message.
    ///
    /// # Errors
    /// Returns the message produced by `validate_empirical_z_grid` when the grid is invalid.
    pub fn new(nodes: Vec<f64>, weights: Vec<f64>, context: &str) -> Result<Self, String> {
        validate_empirical_z_grid(&nodes, &weights, context).map(|()| Self { nodes, weights })
    }

    /// Iterates over `(node, weight)` pairs, stopping at the shorter of the two vectors.
    #[inline]
    pub fn pairs(&self) -> impl Iterator<Item = (f64, f64)> + '_ {
        self.nodes
            .iter()
            .zip(&self.weights)
            .map(|(&node, &weight)| (node, weight))
    }
}
/// Resolved latent measure used by a fit.
///
/// Serialized as an internally tagged enum (`"kind"` field, kebab-case variant names).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "kebab-case")]
#[derive(Default)]
pub enum LatentMeasureKind {
/// Standard normal latent measure (the default).
#[default]
StandardNormal,
/// One empirical grid shared by all rows.
GlobalEmpirical {
grid: EmpiricalZGrid,
},
/// A set of empirical grids attached to centers; each row uses a mixture of grids.
LocalEmpirical {
/// Feature columns defining the space the centers live in (must be non-empty).
feature_cols: Vec<usize>,
/// Optional per-feature scales; when present, one finite positive value per feature column.
/// Missing in serialized data defaults to `None`.
#[serde(default)]
input_scales: Option<Vec<f64>>,
/// Center coordinates, one vector of `feature_cols.len()` finite values per center.
centers: Vec<Vec<f64>>,
/// One empirical grid per center.
grids: Vec<EmpiricalZGrid>,
/// Number of centers mixed per row; validated to lie in `1..=centers.len()`.
top_k: usize,
/// Finite, positive bandwidth. NOTE(review): presumably the kernel bandwidth for the
/// center weights — the weighting code is outside this chunk.
bandwidth: f64,
/// Per-training-row mixtures as `(grid index, mixture weight)` pairs. Skipped by serde,
/// so it is not written out and takes its default value on deserialization.
#[serde(skip)]
train_row_mixtures: Arc<Vec<Vec<(usize, f64)>>>,
},
}
impl LatentMeasureKind {
    /// Checks the internal consistency of the measure; `context` prefixes every error message.
    ///
    /// `StandardNormal` is always valid, `GlobalEmpirical` validates its grid, and
    /// `LocalEmpirical` validates (in this order) feature columns, centers, center/grid
    /// counts, `top_k`, `bandwidth`, optional input scales, every center and every grid.
    ///
    /// # Errors
    /// Returns a description of the first violated constraint.
    pub fn validate(&self, context: &str) -> Result<(), String> {
        let (feature_cols, input_scales, centers, grids, top_k, bandwidth) = match self {
            Self::StandardNormal => return Ok(()),
            Self::GlobalEmpirical { grid } => {
                return validate_empirical_z_grid(&grid.nodes, &grid.weights, context);
            }
            Self::LocalEmpirical {
                feature_cols,
                input_scales,
                centers,
                grids,
                top_k,
                bandwidth,
                ..
            } => (feature_cols, input_scales, centers, grids, top_k, bandwidth),
        };

        let feature_count = feature_cols.len();
        let center_count = centers.len();
        if feature_count == 0 {
            return Err(format!(
                "{context} local empirical latent measure needs feature columns"
            ));
        }
        if center_count == 0 {
            return Err(format!(
                "{context} local empirical latent measure needs centers"
            ));
        }
        if center_count != grids.len() {
            return Err(format!(
                "{context} local empirical latent measure center/grid length mismatch: centers={}, grids={}",
                center_count,
                grids.len()
            ));
        }
        if !(1..=center_count).contains(top_k) {
            return Err(format!(
                "{context} local empirical latent measure top_k must be in 1..={}, got {top_k}",
                center_count
            ));
        }
        if !(bandwidth.is_finite() && *bandwidth > 0.0) {
            return Err(format!(
                "{context} local empirical latent measure bandwidth must be finite and positive, got {bandwidth}"
            ));
        }

        if let Some(scales) = input_scales.as_deref() {
            if scales.len() != feature_count {
                return Err(format!(
                    "{context} local empirical latent measure input scale dimension mismatch: scales={}, features={}",
                    scales.len(),
                    feature_count
                ));
            }
            // Report the first scale that is not a finite, strictly positive number.
            let first_bad = scales
                .iter()
                .enumerate()
                .find(|(_, value)| !(value.is_finite() && **value > 0.0));
            if let Some((scale_idx, scale)) = first_bad {
                return Err(format!(
                    "{context} local empirical latent measure input scale {scale_idx} must be finite and positive, got {scale}"
                ));
            }
        }

        centers
            .iter()
            .enumerate()
            .try_for_each(|(center_idx, center)| {
                if center.len() != feature_count {
                    return Err(format!(
                        "{context} local empirical latent center {center_idx} dimension mismatch: got {}, expected {}",
                        center.len(),
                        feature_count
                    ));
                }
                if !center.iter().all(|value| value.is_finite()) {
                    return Err(format!(
                        "{context} local empirical latent center {center_idx} has non-finite coordinates"
                    ));
                }
                Ok(())
            })?;

        grids.iter().enumerate().try_for_each(|(grid_idx, grid)| {
            validate_empirical_z_grid(
                &grid.nodes,
                &grid.weights,
                &format!("{context} local empirical grid {grid_idx}"),
            )
        })
    }

    /// `true` for both empirical variants, `false` for the standard normal measure.
    fn is_empirical(&self) -> bool {
        match self {
            Self::StandardNormal => false,
            Self::GlobalEmpirical { .. } | Self::LocalEmpirical { .. } => true,
        }
    }

    /// Returns the empirical grid to use for training row `row`.
    ///
    /// `None` for the standard normal measure, the shared grid (borrowed) for the global
    /// empirical measure, and a freshly combined grid (owned) for the local empirical measure.
    ///
    /// # Errors
    /// Fails when the local measure has no mixture stored for `row` or when combining the
    /// referenced grids fails.
    fn empirical_grid_for_training_row(
        &self,
        row: usize,
    ) -> Result<Option<std::borrow::Cow<'_, EmpiricalZGrid>>, String> {
        use std::borrow::Cow;

        let resolved = match self {
            Self::StandardNormal => return Ok(None),
            Self::GlobalEmpirical { grid } => Cow::Borrowed(grid),
            Self::LocalEmpirical {
                grids,
                train_row_mixtures,
                ..
            } => {
                let Some(mixture) = train_row_mixtures.get(row) else {
                    return Err(format!(
                        "local empirical latent measure is missing training mixture for row {row}"
                    ));
                };
                Cow::Owned(combine_empirical_grids(grids, mixture)?)
            }
        };
        Ok(Some(resolved))
    }
}
/// Validates an empirical latent grid given as parallel `nodes` / `weights` slices.
///
/// Requirements, checked in this order: equal lengths, at least two nodes, every node
/// finite and every weight finite and strictly positive (checked index by index), and a
/// weight total within `1e-8` of 1. `context` prefixes every error message.
///
/// # Errors
/// Returns a message describing the first violated requirement.
fn validate_empirical_z_grid(nodes: &[f64], weights: &[f64], context: &str) -> Result<(), String> {
    let (node_count, weight_count) = (nodes.len(), weights.len());
    if node_count != weight_count {
        return Err(format!(
            "{context} empirical latent measure node/weight length mismatch: nodes={}, weights={}",
            node_count, weight_count
        ));
    }
    if node_count < 2 {
        return Err(format!(
            "{context} empirical latent measure requires at least two nodes"
        ));
    }

    // Walk the pairs once, validating each entry and accumulating the weight total.
    let total = nodes.iter().zip(weights).enumerate().try_fold(
        0.0_f64,
        |running, (idx, (&node, &weight))| {
            if !node.is_finite() {
                return Err(format!(
                    "{context} empirical latent measure node {idx} is non-finite ({node})"
                ));
            }
            let weight_ok = weight.is_finite() && weight > 0.0;
            if !weight_ok {
                return Err(format!(
                    "{context} empirical latent measure weight {idx} must be finite and positive, got {weight}"
                ));
            }
            Ok(running + weight)
        },
    )?;

    let sums_to_one = total.is_finite() && (total - 1.0).abs() <= 1e-8;
    if sums_to_one {
        Ok(())
    } else {
        Err(format!(
            "{context} empirical latent measure weights must sum to 1, got {total}"
        ))
    }
}
/// Builds one grid from a mixture of existing grids.
///
/// `mixture` holds `(grid index, mixture weight)` pairs. Every node of every referenced
/// grid is kept (no de-duplication); its weight is the grid weight times the mixture
/// weight, and the combined weights are renormalized to sum to 1 before validation.
///
/// # Errors
/// Fails on an empty mixture, a non-finite or non-positive mixture weight, an
/// out-of-range grid index, a non-positive combined total, or a combined grid that does
/// not pass `EmpiricalZGrid::new`.
fn combine_empirical_grids(
    grids: &[EmpiricalZGrid],
    mixture: &[(usize, f64)],
) -> Result<EmpiricalZGrid, String> {
    if mixture.is_empty() {
        return Err("local empirical latent measure row mixture is empty".to_string());
    }

    let mut nodes = Vec::new();
    let mut weights = Vec::new();
    for &(grid_idx, grid_weight) in mixture {
        let weight_ok = grid_weight.is_finite() && grid_weight > 0.0;
        if !weight_ok {
            return Err(format!(
                "local empirical latent mixture weight must be finite and positive, got {grid_weight}"
            ));
        }
        let Some(grid) = grids.get(grid_idx) else {
            return Err(format!(
                "local empirical latent mixture references missing grid {grid_idx}"
            ));
        };
        grid.pairs().for_each(|(node, weight)| {
            nodes.push(node);
            weights.push(grid_weight * weight);
        });
    }

    let total: f64 = weights.iter().sum();
    let total_ok = total.is_finite() && total > 0.0;
    if !total_ok {
        return Err(
            "local empirical latent combined grid has non-positive total weight".to_string(),
        );
    }
    weights.iter_mut().for_each(|weight| *weight /= total);

    EmpiricalZGrid::new(nodes, weights, "local empirical latent combined grid")
}
/// Policy for checking, normalizing and modelling the latent score `z`.
#[derive(Clone, Debug)]
pub struct LatentZPolicy {
/// How diagnostic failures are treated (see `LatentZCheckMode`).
pub check_mode: LatentZCheckMode,
/// How `z` is normalized (see `LatentZNormalizationMode`).
pub normalization: LatentZNormalizationMode,
/// Which latent measure to use (see `LatentMeasureSpec`).
pub latent_measure: LatentMeasureSpec,
/// Multiplier for the mean tolerance; the automatic normality detection uses
/// `mean_tol_multiplier / sqrt(effective_n)`.
pub mean_tol_multiplier: f64,
/// Multiplier for the sd tolerance; the automatic normality detection uses
/// `sd_tol_multiplier / sqrt(2 * max(effective_n - 1, 1))`.
pub sd_tol_multiplier: f64,
/// NOTE(review): presumably the largest tolerated absolute skewness — use site not visible here.
pub max_abs_skew: f64,
/// NOTE(review): presumably the largest tolerated absolute excess kurtosis — use site not visible here.
pub max_abs_excess_kurtosis: f64,
}
impl LatentZPolicy {
pub fn frozen_transformation_normal() -> Self {
Self {
check_mode: LatentZCheckMode::WarnOnly,
normalization: LatentZNormalizationMode::Frozen { mean: 0.0, sd: 1.0 },
latent_measure: LatentMeasureSpec::auto_default(),
mean_tol_multiplier: 4.0,
sd_tol_multiplier: 4.0,
max_abs_skew: 4.0,
max_abs_excess_kurtosis: 20.0,
}
}
pub fn exploratory_fit_weighted() -> Self {
Self {
check_mode: LatentZCheckMode::WarnOnly,
normalization: LatentZNormalizationMode::FitWeighted,
latent_measure: LatentMeasureSpec::auto_default(),
mean_tol_multiplier: 8.0,
sd_tol_multiplier: 8.0,
max_abs_skew: 4.0,
max_abs_excess_kurtosis: 20.0,
}
}
}
impl Default for LatentZPolicy {
fn default() -> Self {
Self::frozen_transformation_normal()
}
}
/// Affine normalization of the latent score: `apply` maps `z` to `(z - mean) / sd`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct LatentZNormalization {
/// Location subtracted from `z`.
pub mean: f64,
/// Scale `z` is divided by; `apply` requires it to exceed `BMS_VARIANCE_FLOOR`.
pub sd: f64,
}
impl LatentZNormalization {
    /// Standardizes `z` element-wise as `(z - mean) / sd`.
    ///
    /// # Errors
    /// Fails when `mean` or `sd` is non-finite, when `sd` does not exceed
    /// `BMS_VARIANCE_FLOOR`, or when any entry of `z` is non-finite. `context` prefixes
    /// the error message.
    pub fn apply(&self, z: &Array1<f64>, context: &str) -> Result<Array1<f64>, String> {
        let Self { mean, sd } = *self;
        let usable = mean.is_finite() && sd.is_finite() && sd > BMS_VARIANCE_FLOOR;
        if !usable {
            return Err(format!(
                "{context} requires finite latent z normalization with sd > {BMS_VARIANCE_FLOOR:e}; got mean={} sd={}",
                mean, sd
            ));
        }
        if !z.iter().all(|value| value.is_finite()) {
            return Err(format!("{context} requires finite z values"));
        }
        Ok(z.mapv(|zi| (zi - mean) / sd))
    }
}
/// Rank-based inverse-normal calibration of the latent score.
///
/// Stores the knots of a piecewise-linear weighted CDF; applying the calibration
/// interpolates the CDF at `z` and passes the probability through the standard normal
/// quantile function.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct LatentZRankIntCalibration {
/// Distinct training `z` values in ascending order (ties are merged by `fit`).
pub sorted_z: Vec<f64>,
/// Clamped plotting-position probability attached to each entry of `sorted_z`.
pub weighted_cdf: Vec<f64>,
/// Weighted mean of the calibrated training scores.
pub post_mean: f64,
/// Weighted standard deviation of the calibrated training scores.
pub post_sd: f64,
}
impl LatentZRankIntCalibration {
    /// Fits the calibration from training scores `z` and observation `weights`.
    ///
    /// Observations are sorted by `z`; each distinct value becomes a knot whose probability
    /// is `(cumulative weight - 0.375) / (total weight + 0.25)`, clamped to
    /// `[eps, 1 - eps]` with `eps = 0.5 / max(total weight, 1)`. Tied values share one knot
    /// that carries the probability of the last tied observation. The weighted mean and
    /// standard deviation of the calibrated training scores are recorded as
    /// `post_mean` / `post_sd`.
    ///
    /// # Errors
    /// Fails on a length mismatch, empty input, a non-positive or non-finite total weight,
    /// a non-finite `z` entry, or a weight that is non-finite or negative.
    pub fn fit(z: &Array1<f64>, weights: &Array1<f64>) -> Result<Self, String> {
        let n_obs = z.len();
        if n_obs != weights.len() {
            return Err(format!(
                "rank-INT calibration: z length {} != weights length {}",
                n_obs,
                weights.len()
            ));
        }
        if n_obs == 0 {
            return Err("rank-INT calibration requires at least one observation".to_string());
        }
        let w_total: f64 = weights.iter().sum();
        if !(w_total.is_finite() && w_total > 0.0) {
            return Err(format!(
                "rank-INT calibration requires positive finite total weight, got {w_total}"
            ));
        }
        if let Some((idx, value)) = z.iter().enumerate().find(|(_, value)| !value.is_finite()) {
            return Err(format!(
                "rank-INT calibration: z[{idx}] = {value} not finite"
            ));
        }
        let first_bad_weight = weights
            .iter()
            .enumerate()
            .find(|(_, weight)| !(weight.is_finite() && **weight >= 0.0));
        if let Some((idx, weight)) = first_bad_weight {
            return Err(format!(
                "rank-INT calibration: weight[{idx}] = {weight} not finite/non-negative"
            ));
        }

        // Stable argsort of the (already validated, finite) scores.
        let mut order: Vec<usize> = (0..n_obs).collect();
        order.sort_by(|&a, &b| z[a].partial_cmp(&z[b]).unwrap_or(std::cmp::Ordering::Equal));

        let denom = w_total + 0.25;
        let eps = 0.5 / w_total.max(1.0);
        let plotting_position =
            |cumulative: f64| ((cumulative - 0.375) / denom).clamp(eps, 1.0 - eps);

        let mut sorted_z: Vec<f64> = Vec::with_capacity(n_obs);
        let mut weighted_cdf: Vec<f64> = Vec::with_capacity(n_obs);
        let mut cum_w = 0.0_f64;
        for &idx in &order {
            cum_w += weights[idx];
            let zi = z[idx];
            let p = plotting_position(cum_w);
            let ties_previous_knot = sorted_z.last().is_some_and(|&prev| prev == zi);
            if ties_previous_knot {
                // A tie overwrites the knot's probability with the latest cumulative weight.
                if let Some(slot) = weighted_cdf.last_mut() {
                    *slot = p;
                }
            } else {
                sorted_z.push(zi);
                weighted_cdf.push(p);
            }
        }

        // Calibrated training scores, in sort order, shared by both moment passes.
        let calibrated: Vec<f64> = order
            .iter()
            .map(|&idx| Self::apply_with_knots(z[idx], &sorted_z, &weighted_cdf))
            .collect();

        let (sum_wz, sum_w) = order.iter().zip(&calibrated).fold(
            (0.0_f64, 0.0_f64),
            |(acc_wz, acc_w), (&idx, &value)| (acc_wz + weights[idx] * value, acc_w + weights[idx]),
        );
        let post_mean = if sum_w > 0.0 { sum_wz / sum_w } else { 0.0 };

        let sum_w_dev = order
            .iter()
            .zip(&calibrated)
            .fold(0.0_f64, |acc, (&idx, &value)| {
                let d = value - post_mean;
                acc + weights[idx] * d * d
            });
        let post_sd = if sum_w > 0.0 {
            (sum_w_dev / sum_w).sqrt()
        } else {
            1.0
        };

        Ok(Self {
            sorted_z,
            weighted_cdf,
            post_mean,
            post_sd,
        })
    }

    /// Applies the calibration to a whole vector of scores.
    ///
    /// # Errors
    /// Fails when the calibration has no knots or when an entry of `z` is non-finite.
    pub fn apply_to_training(&self, z: &Array1<f64>) -> Result<Array1<f64>, String> {
        if self.sorted_z.is_empty() {
            return Err("rank-INT calibration has no knots".to_string());
        }
        z.iter()
            .enumerate()
            .map(|(idx, &zi)| {
                if zi.is_finite() {
                    Ok(self.apply_at_predict(zi))
                } else {
                    Err(format!(
                        "rank-INT calibration apply: z[{idx}] = {zi} not finite"
                    ))
                }
            })
            .collect::<Result<Vec<f64>, String>>()
            .map(Array1::from_vec)
    }

    /// Applies the calibration to a single score.
    ///
    /// Panics if the calibration has no knots or its two knot vectors differ in length.
    pub fn apply_at_predict(&self, z: f64) -> f64 {
        Self::apply_with_knots(z, &self.sorted_z, &self.weighted_cdf)
    }

    /// Interpolates the knot CDF at `z` and maps the probability through the standard
    /// normal quantile; values outside the knot range take the boundary probability.
    /// If the quantile function reports an error the result falls back to ±8.
    fn apply_with_knots(z: f64, sorted_z: &[f64], weighted_cdf: &[f64]) -> f64 {
        assert_eq!(sorted_z.len(), weighted_cdf.len());
        assert!(!sorted_z.is_empty());
        let last = sorted_z.len() - 1;

        let p = if z <= sorted_z[0] {
            weighted_cdf[0]
        } else if z >= sorted_z[last] {
            weighted_cdf[last]
        } else {
            // Bisection for the bracketing pair of adjacent knots.
            let (mut lo, mut hi) = (0usize, last);
            while hi - lo > 1 {
                let mid = lo + (hi - lo) / 2;
                if sorted_z[mid] <= z {
                    lo = mid;
                } else {
                    hi = mid;
                }
            }
            let (z_lo, z_hi) = (sorted_z[lo], sorted_z[hi]);
            let (p_lo, p_hi) = (weighted_cdf[lo], weighted_cdf[hi]);
            if z_hi == z_lo {
                p_hi
            } else {
                let t = (z - z_lo) / (z_hi - z_lo);
                p_lo + t * (p_hi - p_lo)
            }
        };

        match standard_normal_quantile(p) {
            Ok(quantile) => quantile,
            Err(_) => {
                if p < 0.5 {
                    -8.0
                } else {
                    8.0
                }
            }
        }
    }
}
/// Calibration of the latent score that accompanies a resolved latent measure
/// (second element of the pair returned by `build_latent_measure_with_geometry`).
#[derive(Clone, Debug)]
pub enum LatentMeasureCalibration {
/// No calibration.
None,
/// Rank-based inverse-normal transform of `z`.
RankInverseNormal(LatentZRankIntCalibration),
/// Conditional location-scale standardization of `z` given a conditioning basis.
ConditionalLocationScale(LatentZConditionalCalibration),
}
/// Conditional location-scale calibration: `zeta = (z - m(a)) / sqrt(v(a))`, where `m` and
/// `v` are affine in the conditioning basis row `a` (intercept first).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct LatentZConditionalCalibration {
/// Mean-model coefficients: intercept followed by one coefficient per basis column.
pub mean_coeffs: Vec<f64>,
/// Variance-model coefficients in the same layout; empty when the variance model is
/// inactive, in which case `global_var` is used.
pub var_coeffs: Vec<f64>,
/// Number of conditioning basis columns (excluding the intercept).
pub basis_ncols: usize,
/// Lower bound applied to the conditional variance.
pub var_floor: f64,
/// Global weighted variance of the training `z`.
pub global_var: f64,
/// Weighted mean of the calibrated training scores.
pub post_mean: f64,
/// Weighted standard deviation of the calibrated training scores.
pub post_sd: f64,
/// Sandwich covariance of `mean_coeffs`.
pub mean_cov: Array2<f64>,
/// Sandwich covariance of `var_coeffs` (0×0 when the variance model is inactive).
pub var_cov: Array2<f64>,
}
impl LatentZConditionalCalibration {
    /// Evaluates `coeffs[0] + Σ_j coeffs[j + 1] * a_row[j]`.
    ///
    /// The zip stops at the shorter sequence, so callers must make sure that
    /// `coeffs.len() == a_row.len() + 1`. Panics if `coeffs` is empty.
    #[inline]
    fn affine(coeffs: &[f64], a_row: ArrayView1<'_, f64>) -> f64 {
        let mut acc = coeffs[0];
        for (c, &x) in coeffs[1..].iter().zip(a_row.iter()) {
            acc += c * x;
        }
        acc
    }

    /// Conditional mean `m(a)` for one basis row.
    fn conditional_mean(&self, a_row: ArrayView1<'_, f64>) -> f64 {
        Self::affine(&self.mean_coeffs, a_row)
    }

    /// Conditional variance `v(a)` for one basis row, floored at `var_floor`.
    /// Uses `global_var` when no variance coefficients are stored.
    fn conditional_var(&self, a_row: ArrayView1<'_, f64>) -> f64 {
        if self.var_coeffs.is_empty() {
            self.global_var.max(self.var_floor)
        } else {
            Self::affine(&self.var_coeffs, a_row).max(self.var_floor)
        }
    }

    /// Verifies that the stored coefficient vectors match `basis_ncols`.
    ///
    /// `mean_coeffs` must hold `basis_ncols + 1` entries; `var_coeffs` must be empty or
    /// hold `basis_ncols + 1` entries. Without this check a malformed (for example
    /// deserialized) coefficient vector would be silently truncated by `affine`.
    fn check_coefficient_lengths(&self) -> Result<(), String> {
        let expected = self.basis_ncols + 1;
        if self.mean_coeffs.len() != expected {
            return Err(format!(
                "conditional latent calibration mean coefficient length {} != basis_ncols+1 ({})",
                self.mean_coeffs.len(),
                expected
            ));
        }
        if !self.var_coeffs.is_empty() && self.var_coeffs.len() != expected {
            return Err(format!(
                "conditional latent calibration variance coefficient length {} must be 0 or basis_ncols+1 ({})",
                self.var_coeffs.len(),
                expected
            ));
        }
        Ok(())
    }

    /// Standardizes every score: `zeta_i = (z_i - m(a_i)) / sqrt(v(a_i))`.
    ///
    /// # Errors
    /// Fails when `a_block` has the wrong number of columns or rows, when the stored
    /// coefficient vectors do not match `basis_ncols`, when a score is non-finite, when a
    /// conditional variance is non-finite or non-positive, or when a result is non-finite.
    pub fn apply(
        &self,
        z: ArrayView1<'_, f64>,
        a_block: ArrayView2<'_, f64>,
    ) -> Result<Array1<f64>, String> {
        if a_block.ncols() != self.basis_ncols {
            return Err(format!(
                "conditional latent calibration expects {} basis columns, got {}",
                self.basis_ncols,
                a_block.ncols()
            ));
        }
        if a_block.nrows() != z.len() {
            return Err(format!(
                "conditional latent calibration row mismatch: z={}, basis rows={}",
                z.len(),
                a_block.nrows()
            ));
        }
        self.check_coefficient_lengths()?;
        let mut out = Array1::<f64>::zeros(z.len());
        for i in 0..z.len() {
            let a_row = a_block.row(i);
            if !z[i].is_finite() {
                return Err(format!(
                    "conditional latent calibration: z[{i}] = {} not finite",
                    z[i]
                ));
            }
            let m = self.conditional_mean(a_row);
            let v = self.conditional_var(a_row);
            if !(v.is_finite() && v > 0.0) {
                return Err(format!(
                    "conditional latent calibration produced non-positive variance {v} at row {i}"
                ));
            }
            let zeta = (z[i] - m) / v.sqrt();
            if !zeta.is_finite() {
                return Err(format!(
                    "conditional latent calibration produced non-finite zeta at row {i}"
                ));
            }
            out[i] = zeta;
        }
        Ok(out)
    }

    /// Number of first-stage parameters: mean coefficients plus variance coefficients.
    pub fn theta1_dim(&self) -> usize {
        self.mean_coeffs.len() + self.var_coeffs.len()
    }

    /// Derivative of `zeta` with respect to the first-stage parameters for one row.
    ///
    /// Layout: mean intercept, mean slopes, then (only when the variance model is active)
    /// variance intercept and variance slopes. The variance derivatives are zero when the
    /// raw variance is at or below the floor, because the floor is then the active value.
    ///
    /// # Panics
    /// Panics if `mean_coeffs` is empty.
    pub fn zeta_theta1_jacobian_row(&self, z: f64, a_row: ArrayView1<'_, f64>) -> Vec<f64> {
        let m = self.conditional_mean(a_row);
        let v = self.conditional_var(a_row);
        let inv_sqrt_v = 1.0 / v.sqrt();
        let mut out = Vec::with_capacity(self.theta1_dim());
        // d zeta / d m = -1 / sqrt(v); the chain rule through the affine mean multiplies by
        // 1 (intercept) and by each basis value.
        let dzeta_dm = -inv_sqrt_v;
        out.push(dzeta_dm);
        for &x in a_row.iter() {
            out.push(dzeta_dm * x);
        }
        if !self.var_coeffs.is_empty() {
            let raw_v = Self::affine(&self.var_coeffs, a_row);
            // d zeta / d v = -zeta / (2 v) while the raw variance is above the floor.
            let dzeta_dv = if raw_v > self.var_floor {
                let zeta = (z - m) * inv_sqrt_v;
                -zeta / (2.0 * v)
            } else {
                0.0
            };
            out.push(dzeta_dv);
            for &x in a_row.iter() {
                out.push(dzeta_dv * x);
            }
        }
        out
    }

    /// Block-diagonal covariance of the first-stage parameters: `mean_cov` in the leading
    /// block and `var_cov` in the trailing block (when variance coefficients exist).
    ///
    /// # Panics
    /// Panics if a stored covariance cannot be assigned into its block (shape mismatch).
    pub fn theta1_covariance(&self) -> Array2<f64> {
        let dm = self.mean_coeffs.len();
        let dv = self.var_coeffs.len();
        let mut v1 = Array2::<f64>::zeros((dm + dv, dm + dv));
        v1.slice_mut(s![..dm, ..dm]).assign(&self.mean_cov);
        if dv > 0 {
            v1.slice_mut(s![dm.., dm..]).assign(&self.var_cov);
        }
        v1
    }

    /// Computes `H V1 Hᵀ` with `H = hbeta_inv_g` and `V1 = theta1_covariance()`.
    ///
    /// # Panics
    /// Panics if the column count of `hbeta_inv_g` differs from `theta1_dim()`.
    pub fn generated_regressor_term(&self, hbeta_inv_g: ArrayView2<'_, f64>) -> Array2<f64> {
        let v1 = self.theta1_covariance();
        hbeta_inv_g.dot(&v1).dot(&hbeta_inv_g.t())
    }

    /// Generated-regressor covariance correction.
    ///
    /// Accumulates `G = Σ_i s_i J_iᵀ`, where `s_i` is row `i` of `score_zeta_sensitivity`
    /// and `J_i` is the `zeta` Jacobian row for observation `i`, then returns
    /// `(vb G) V1 (vb G)ᵀ`.
    ///
    /// # Errors
    /// Fails on row-count, basis-column or `vb` shape mismatches, and when the stored
    /// coefficient vectors or covariance blocks are inconsistent with `basis_ncols`
    /// (conditions that would otherwise panic further down).
    pub fn generated_regressor_correction(
        &self,
        score_zeta_sensitivity: ArrayView2<'_, f64>,
        z: ArrayView1<'_, f64>,
        a_block: ArrayView2<'_, f64>,
        vb: ArrayView2<'_, f64>,
    ) -> Result<Array2<f64>, String> {
        let n = score_zeta_sensitivity.nrows();
        let p_beta = score_zeta_sensitivity.ncols();
        let dim_theta1 = self.theta1_dim();
        if z.len() != n || a_block.nrows() != n {
            return Err(format!(
                "generated_regressor_correction row mismatch: score_zeta_sensitivity rows={n}, \
z={}, a_block rows={}",
                z.len(),
                a_block.nrows()
            ));
        }
        if a_block.ncols() != self.basis_ncols {
            return Err(format!(
                "generated_regressor_correction expects {} basis columns, got {}",
                self.basis_ncols,
                a_block.ncols()
            ));
        }
        if vb.nrows() != p_beta || vb.ncols() != p_beta {
            return Err(format!(
                "generated_regressor_correction: vb must be {p_beta}×{p_beta}, got {}×{}",
                vb.nrows(),
                vb.ncols()
            ));
        }
        // Validate the stored state up front so the Jacobian and covariance assembly below
        // cannot panic on a malformed calibration.
        self.check_coefficient_lengths()
            .map_err(|e| format!("generated_regressor_correction: {e}"))?;
        let dm = self.mean_coeffs.len();
        let dv = self.var_coeffs.len();
        if self.mean_cov.dim() != (dm, dm) {
            return Err(format!(
                "generated_regressor_correction: mean_cov must be {dm}×{dm}, got {}×{}",
                self.mean_cov.nrows(),
                self.mean_cov.ncols()
            ));
        }
        if dv > 0 && self.var_cov.dim() != (dv, dv) {
            return Err(format!(
                "generated_regressor_correction: var_cov must be {dv}×{dv}, got {}×{}",
                self.var_cov.nrows(),
                self.var_cov.ncols()
            ));
        }
        let mut g = Array2::<f64>::zeros((p_beta, dim_theta1));
        for i in 0..n {
            let j_zeta_row = self.zeta_theta1_jacobian_row(z[i], a_block.row(i));
            debug_assert_eq!(
                j_zeta_row.len(),
                dim_theta1,
                "J_zeta row width must match the first-stage hyperparameter dimension"
            );
            let s_i = score_zeta_sensitivity.row(i);
            if j_zeta_row.iter().all(|&v| v == 0.0) {
                continue;
            }
            for (b, &s_b) in s_i.iter().enumerate() {
                if s_b == 0.0 {
                    continue;
                }
                for (k, &jz) in j_zeta_row.iter().enumerate() {
                    g[[b, k]] += s_b * jz;
                }
            }
        }
        let vb_g = vb.dot(&g);
        Ok(self.generated_regressor_term(vb_g.view()))
    }
}
/// Sandwich covariance `M⁺ (Bᵀ B) M⁺` for a weighted ridge fit, where row `i` of `B` is
/// the basis row scaled by `weights[i] * residuals[i]` and `M` is `normal_matrix`.
///
/// `M` is symmetrized and both `M` and the meat are rescaled by `1 / sqrt(diag(M))`
/// before the pseudo-inverse; the same scaling is applied to the result afterwards.
///
/// # Errors
/// Fails on length or shape mismatches, when the pseudo-inverse fails, or when the
/// resulting covariance contains a non-finite entry.
fn weighted_ridge_sandwich_cov(
    basis: ArrayView2<'_, f64>,
    residuals: &[f64],
    weights: ArrayView1<'_, f64>,
    normal_matrix: &Array2<f64>,
) -> Result<Array2<f64>, String> {
    let (n, p) = basis.dim();
    if residuals.len() != n || weights.len() != n {
        return Err(format!(
            "weighted ridge sandwich length mismatch: rows={n}, residuals={}, weights={}",
            residuals.len(),
            weights.len()
        ));
    }
    if normal_matrix.nrows() != p || normal_matrix.ncols() != p {
        return Err(format!(
            "weighted ridge sandwich normal-matrix shape mismatch: basis cols={p}, normal {}x{}",
            normal_matrix.nrows(),
            normal_matrix.ncols()
        ));
    }

    // Score rows: basis row times weight times residual. Rows with an exactly zero factor
    // are zeroed explicitly rather than multiplied.
    let mut score_rows = basis.to_owned();
    let row_factors = weights.iter().zip(residuals.iter());
    for (mut row, (&wi, &ri)) in score_rows.outer_iter_mut().zip(row_factors) {
        let factor = wi * ri;
        if factor == 0.0 {
            row.fill(0.0);
        } else {
            row.iter_mut().for_each(|value| *value *= factor);
        }
    }
    let mut meat_scaled = crate::linalg::faer_ndarray::fast_ata(&score_rows);

    let mut m_scaled = normal_matrix.clone();
    crate::linalg::matrix::symmetrize_in_place(&mut m_scaled);
    let scale: Vec<f64> = m_scaled
        .diag()
        .iter()
        .map(|&diag| 1.0 / diag.max(f64::MIN_POSITIVE).sqrt())
        .collect();

    // Symmetric diagonal rescaling of a p×p matrix: entry (i, j) times scale[i] * scale[j].
    let rescale = |matrix: &mut Array2<f64>| {
        for i in 0..p {
            for j in 0..p {
                matrix[[i, j]] *= scale[i] * scale[j];
            }
        }
    };
    rescale(&mut m_scaled);
    rescale(&mut meat_scaled);

    let (_rank, m_pinv) =
        crate::linalg::utils::block_penalty_rank_and_pinv(&m_scaled).map_err(|e| {
            format!("conditional latent calibration sandwich pseudo-inverse failed: {e}")
        })?;
    let mut cov = m_pinv.dot(&meat_scaled).dot(&m_pinv);
    rescale(&mut cov);

    if !cov.iter().all(|v| v.is_finite()) {
        return Err("conditional latent calibration sandwich covariance is non-finite".to_string());
    }
    Ok(cov)
}
/// Weighted mean `Σ w_i v_i / total_weight`.
///
/// Pairs are formed up to the shorter of `values` and `weights`; `total_weight` is taken
/// as given and is not checked against zero.
fn weighted_mean(values: &[f64], weights: ArrayView1<'_, f64>, total_weight: f64) -> f64 {
    let numerator: f64 = weights
        .iter()
        .zip(values)
        .map(|(&weight, &value)| weight * value)
        .sum();
    numerator / total_weight
}
/// Robust (sandwich) score test of `E[u | a] = 0` against the centered basis `a_centered`.
///
/// Computes the score `s = Σ_i w_i u_i a_i`, its empirical covariance
/// `Ω = Σ_i w_i² u_i² a_i a_iᵀ`, the statistic `D = sᵀ Ω⁺ s`, and returns the upper-tail
/// chi-square probability with `rank(Ω)` degrees of freedom. Rows with a non-positive
/// weight are skipped.
///
/// Returns `Ok(None)` when the test is not computable: no rows or columns, non-finite
/// accumulations, a rank-zero `Ω`, or a statistic that is non-finite or negative.
///
/// # Errors
/// Fails on a length mismatch between `a_centered`, `u` and `weights`, or when the
/// pseudo-inverse of `Ω` cannot be computed.
fn robust_conditional_score_pvalue(
    a_centered: ArrayView2<'_, f64>,
    u: &[f64],
    weights: ArrayView1<'_, f64>,
) -> Result<Option<f64>, String> {
    let n = a_centered.nrows();
    let r = a_centered.ncols();
    if r == 0 || n == 0 {
        return Ok(None);
    }
    if u.len() != n || weights.len() != n {
        return Err(format!(
            "conditional score test length mismatch: rows={n}, u={}, weights={}",
            u.len(),
            weights.len()
        ));
    }
    let mut s = Array1::<f64>::zeros(r);
    let mut omega = Array2::<f64>::zeros((r, r));
    for i in 0..n {
        let wi = weights[i];
        if wi <= 0.0 {
            continue;
        }
        let ui = u[i];
        let a_row = a_centered.row(i);
        let wu = wi * ui;
        for j in 0..r {
            s[j] += wu * a_row[j];
        }
        let w2u2 = wi * wi * ui * ui;
        if w2u2 == 0.0 {
            continue;
        }
        // Accumulate the upper triangle and mirror it so `omega` stays exactly symmetric.
        for j in 0..r {
            let aj = a_row[j];
            if aj == 0.0 {
                continue;
            }
            let scaled = w2u2 * aj;
            for k in j..r {
                let inc = scaled * a_row[k];
                omega[[j, k]] += inc;
                if k != j {
                    omega[[k, j]] += inc;
                }
            }
        }
    }
    if !s.iter().all(|v| v.is_finite()) || !omega.iter().all(|v| v.is_finite()) {
        return Ok(None);
    }
    let (rank, omega_pinv) = crate::linalg::utils::block_penalty_rank_and_pinv(&omega)
        .map_err(|e| format!("conditional score test pseudo-inverse failed: {e}"))?;
    if rank == 0 {
        return Ok(None);
    }
    let d_stat = s.dot(&omega_pinv.dot(&s));
    if !(d_stat.is_finite() && d_stat >= 0.0) {
        return Ok(None);
    }
    if d_stat == 0.0 {
        // The regularized incomplete gamma function is only defined for x in (0, +inf) and
        // `gamma_lr` panics at x = 0. A zero statistic has upper-tail probability exactly 1.
        return Ok(Some(1.0));
    }
    let p_lower = statrs::function::gamma::gamma_lr(rank as f64 / 2.0, d_stat / 2.0);
    let p_value = (1.0 - p_lower).clamp(0.0, 1.0);
    Ok(Some(p_value))
}
/// Tests whether the latent score `z` depends on the conditioning basis `a_block` and, if
/// so, fits a conditional location-scale calibration.
///
/// Two robust score tests are run against the weighted-centered basis: one on `z - mean`
/// (conditional mean) and one on `(z - mean)² - var` (conditional variance). If neither
/// p-value is below `AUTO_Z_CONDITIONAL_RAO_ALPHA` the function returns `Ok(None)`.
/// Otherwise a ridge-regularized weighted regression of `z` on an intercept plus the basis
/// gives the mean model; when the variance test fired, the squared mean residuals are
/// regressed the same way to give the variance model. Sandwich covariances of both
/// coefficient vectors are stored, and `post_mean` / `post_sd` record the weighted moments
/// of the calibrated training scores.
///
/// Returns `Ok(None)` as well when the basis has no columns, the total weight or global
/// variance is not finite and positive, or any input value is non-finite.
///
/// # Errors
/// Fails on length/row mismatches and when a ridge fit, a score test, a sandwich
/// covariance, or the final `apply` fails.
fn fit_conditional_latent_calibration_if_needed(
z: &Array1<f64>,
weights: &Array1<f64>,
a_block: ArrayView2<'_, f64>,
) -> Result<Option<LatentZConditionalCalibration>, String> {
let n = z.len();
let p = a_block.ncols();
if n != weights.len() {
return Err(format!(
"conditional latent gate length mismatch: z={n}, weights={}",
weights.len()
));
}
if a_block.nrows() != n {
return Err(format!(
"conditional latent gate row mismatch: z={n}, basis rows={}",
a_block.nrows()
));
}
// Nothing to condition on: the gate cannot fire.
if p == 0 {
return Ok(None);
}
let total_weight = weights.iter().copied().sum::<f64>();
if !(total_weight.is_finite() && total_weight > 0.0) {
return Ok(None);
}
if z.iter().any(|v| !v.is_finite()) || a_block.iter().any(|v| !v.is_finite()) {
return Ok(None);
}
// Global weighted mean and variance of z.
let z_mean = z
.iter()
.zip(weights.iter())
.map(|(&zi, &wi)| wi * zi)
.sum::<f64>()
/ total_weight;
let global_var = z
.iter()
.zip(weights.iter())
.map(|(&zi, &wi)| wi * (zi - z_mean) * (zi - z_mean))
.sum::<f64>()
/ total_weight;
if !(global_var.is_finite() && global_var > 0.0) {
return Ok(None);
}
// Center every basis column at its weighted mean for the score tests.
let mut a_centered = a_block.to_owned();
for j in 0..p {
let col = a_block.column(j);
let col_mean = col
.iter()
.zip(weights.iter())
.map(|(&v, &w)| w * v)
.sum::<f64>()
/ total_weight;
a_centered.column_mut(j).mapv_inplace(|v| v - col_mean);
}
// Score test for the conditional mean (u = z - mean) and for the conditional variance
// (u = (z - mean)^2 - var).
let u_mean: Vec<f64> = z.iter().map(|&zi| zi - z_mean).collect();
let p_mean = robust_conditional_score_pvalue(a_centered.view(), &u_mean, weights.view())?;
let u_var: Vec<f64> = u_mean.iter().map(|&e| e * e - global_var).collect();
let p_var = robust_conditional_score_pvalue(a_centered.view(), &u_var, weights.view())?;
// A test that could not be computed (`None`) never fires.
let mean_fires = p_mean.is_some_and(|p| p < AUTO_Z_CONDITIONAL_RAO_ALPHA);
let var_fires = p_var.is_some_and(|p| p < AUTO_Z_CONDITIONAL_RAO_ALPHA);
if !mean_fires && !var_fires {
return Ok(None);
}
// Regression basis: intercept column followed by the uncentered conditioning basis.
let basis = build_intercept_basis(a_block);
// Diagonal penalty holding each column's weighted sum of squares (floored at the
// smallest positive f64); it is later multiplied by AUTO_Z_CONDITIONAL_RIDGE_REL.
let mut penalty = Array2::<f64>::zeros((basis.ncols(), basis.ncols()));
for j in 0..basis.ncols() {
let diag_jj = basis
.column(j)
.iter()
.zip(weights.iter())
.map(|(&x, &w)| w * x * x)
.sum::<f64>()
.max(f64::MIN_POSITIVE);
penalty[[j, j]] = diag_jj;
}
let z_col = z.view().insert_axis(ndarray::Axis(1));
// Mean model: the fit is performed even when only the variance test fired.
let (mean_coeffs_mat, mean_fitted) = crate::linalg::utils::gaussian_weighted_ridge(
basis.view(),
z_col,
penalty.view(),
weights.view(),
AUTO_Z_CONDITIONAL_RIDGE_REL,
)?;
let mean_coeffs: Vec<f64> = mean_coeffs_mat.column(0).to_vec();
// Normal matrix Bᵀ W B + ridge * penalty, rebuilt here for the sandwich covariances.
// NOTE(review): presumably this matches the system solved inside `gaussian_weighted_ridge`
// — confirm against that helper.
let normal_matrix = {
let mut wa = basis.to_owned();
for i in 0..wa.nrows() {
let wi = weights[i];
wa.row_mut(i).iter_mut().for_each(|value| *value *= wi);
}
let mut m = basis.t().dot(&wa);
m += &(penalty.to_owned() * AUTO_Z_CONDITIONAL_RIDGE_REL);
m
};
let mean_residuals: Vec<f64> = z
.iter()
.zip(mean_fitted.column(0).iter())
.map(|(&zi, &mi)| zi - mi)
.collect();
let mean_cov = weighted_ridge_sandwich_cov(
basis.view(),
&mean_residuals,
weights.view(),
&normal_matrix,
)?;
let var_floor = (AUTO_Z_CONDITIONAL_VAR_FLOOR_FRAC * global_var).max(f64::MIN_POSITIVE);
// Variance model: only fitted when the variance test fired; otherwise the calibration
// keeps empty coefficients and falls back to `global_var`.
let (var_coeffs, var_cov): (Vec<f64>, Array2<f64>) = if var_fires {
let resid_sq: Array1<f64> = mean_residuals.iter().map(|&e| e * e).collect();
let resid_col = resid_sq.view().insert_axis(ndarray::Axis(1));
let (var_coeffs_mat, var_fitted) = crate::linalg::utils::gaussian_weighted_ridge(
basis.view(),
resid_col,
penalty.view(),
weights.view(),
AUTO_Z_CONDITIONAL_RIDGE_REL,
)?;
let var_residuals: Vec<f64> = resid_sq
.iter()
.zip(var_fitted.column(0).iter())
.map(|(&si, &vi)| si - vi)
.collect();
let cov = weighted_ridge_sandwich_cov(
basis.view(),
&var_residuals,
weights.view(),
&normal_matrix,
)?;
(var_coeffs_mat.column(0).to_vec(), cov)
} else {
(Vec::new(), Array2::<f64>::zeros((0, 0)))
};
// Placeholder post-calibration moments; overwritten below once the calibration has
// been applied to the training scores.
let mut calibration = LatentZConditionalCalibration {
mean_coeffs,
var_coeffs,
basis_ncols: p,
var_floor,
global_var,
post_mean: 0.0,
post_sd: 1.0,
mean_cov,
var_cov,
};
let calibrated = calibration.apply(z.view(), a_block)?;
let post_mean = weighted_mean(calibrated.as_slice().unwrap(), weights.view(), total_weight);
let post_var = calibrated
.iter()
.zip(weights.iter())
.map(|(&zi, &wi)| wi * (zi - post_mean) * (zi - post_mean))
.sum::<f64>()
/ total_weight;
calibration.post_mean = post_mean;
// Clamp at zero before the square root so a slightly negative variance cannot yield NaN.
calibration.post_sd = post_var.max(0.0).sqrt();
Ok(Some(calibration))
}
/// Returns `a_block` with a leading column of ones: an `n × (p + 1)` matrix whose first
/// column is the intercept and whose remaining columns copy `a_block`.
fn build_intercept_basis(a_block: ArrayView2<'_, f64>) -> Array2<f64> {
    let (rows, cols) = a_block.dim();
    Array2::from_shape_fn((rows, cols + 1), |(row, col)| {
        if col == 0 {
            1.0
        } else {
            a_block[[row, col - 1]]
        }
    })
}
/// Resolves the latent measure requested by `policy` together with the calibration of `z`
/// that goes with it.
///
/// * `StandardNormal`: standard normal measure, no calibration.
/// * `GlobalEmpirical`: a global empirical measure built from the data, no calibration.
/// * `Auto`: always a standard normal measure. If a conditioning basis is supplied and the
///   conditional gate fires, a conditional location-scale calibration is returned;
///   otherwise no calibration when `z` is already close enough to standard normal, and a
///   rank-based inverse-normal calibration when it is not.
///
/// # Errors
/// Propagates failures from the gate, the normality check, and the calibration or
/// measure builders.
fn build_latent_measure_with_geometry(
    z: &Array1<f64>,
    weights: &Array1<f64>,
    policy: &LatentZPolicy,
    conditioning: Option<ArrayView2<'_, f64>>,
) -> Result<(LatentMeasureKind, LatentMeasureCalibration), String> {
    match policy.latent_measure {
        LatentMeasureSpec::StandardNormal => Ok((
            LatentMeasureKind::StandardNormal,
            LatentMeasureCalibration::None,
        )),
        LatentMeasureSpec::GlobalEmpirical { grid_size } => {
            let kind = build_global_empirical_latent_measure(z, weights, grid_size)?;
            Ok((kind, LatentMeasureCalibration::None))
        }
        LatentMeasureSpec::Auto { .. } => {
            // The conditional gate takes precedence over the marginal normality check.
            let conditional = match conditioning {
                Some(a_block) => {
                    fit_conditional_latent_calibration_if_needed(z, weights, a_block)?
                }
                None => None,
            };
            if let Some(cal) = conditional {
                log::info!(
                    "[BMS latent-z] conditional location-scale calibrated: basis_ncols={} var_active={} post_mean={:.3e} post_sd={:.3e} (E[z|C]/Var(z|C) Rao gate fired)",
                    cal.basis_ncols,
                    !cal.var_coeffs.is_empty(),
                    cal.post_mean,
                    cal.post_sd,
                );
                return Ok((
                    LatentMeasureKind::StandardNormal,
                    LatentMeasureCalibration::ConditionalLocationScale(cal),
                ));
            }

            let marginal = if latent_z_is_standard_normal_enough(z, weights, policy)? {
                LatentMeasureCalibration::None
            } else {
                let calibration = LatentZRankIntCalibration::fit(z, weights)?;
                log::info!(
                    "[BMS latent-z] rank-INT calibrated: post_mean={:.3e} post_sd={:.3e} knots={}",
                    calibration.post_mean,
                    calibration.post_sd,
                    calibration.sorted_z.len(),
                );
                LatentMeasureCalibration::RankInverseNormal(calibration)
            };
            Ok((LatentMeasureKind::StandardNormal, marginal))
        }
    }
}
/// Decides whether the weighted sample `z` is close enough to a standard
/// normal that no calibration is required.
///
/// The sample passes only if every gate holds: weighted mean and standard
/// deviation within tolerances that shrink with the effective sample size
/// `(Σw)² / Σw²`, bounded skewness and excess kurtosis, a bounded weighted
/// Kolmogorov–Smirnov distance to the standard normal, bounded weighted tail
/// mass beyond the inner and outer cutoffs, and a bounded maximum `|z|`.
///
/// Returns `Ok(false)` early when the weighted mean or standard deviation is
/// non-finite or the standard deviation is not positive.
///
/// # Errors
/// Fails on a length mismatch, on weight sums that are not positive and
/// finite, on an effective sample size that is not greater than one, and on
/// any error reported by the KS diagnostic.
fn latent_z_is_standard_normal_enough(
    z: &Array1<f64>,
    weights: &Array1<f64>,
    policy: &LatentZPolicy,
) -> Result<bool, String> {
    // Sums `term(z_i, w_i)` over all rows in order and divides by `total`.
    fn weighted_average(
        z: &Array1<f64>,
        weights: &Array1<f64>,
        total: f64,
        term: impl Fn(f64, f64) -> f64,
    ) -> f64 {
        z.iter()
            .zip(weights.iter())
            .map(|(&zi, &wi)| term(zi, wi))
            .sum::<f64>()
            / total
    }

    if z.len() != weights.len() {
        return Err(format!(
            "latent-measure auto-detection length mismatch: z={}, weights={}",
            z.len(),
            weights.len()
        ));
    }

    let weight_sum = weights.iter().copied().sum::<f64>();
    let weight_sq_sum = weights.iter().map(|&w| w * w).sum::<f64>();
    let sums_usable = [weight_sum, weight_sq_sum]
        .iter()
        .all(|total| total.is_finite() && *total > 0.0);
    if !sums_usable {
        return Err("latent-measure auto-detection requires positive finite weights".to_string());
    }

    let effective_n = weight_sum * weight_sum / weight_sq_sum;
    let enough_observations = effective_n.is_finite() && effective_n > 1.0;
    if !enough_observations {
        return Err(
            "latent-measure auto-detection requires at least two effective observations"
                .to_string(),
        );
    }

    // First two weighted moments.
    let mean = weighted_average(z, weights, weight_sum, |zi, wi| wi * zi);
    let var = weighted_average(z, weights, weight_sum, |zi, wi| {
        wi * (zi - mean) * (zi - mean)
    });
    let sd = var.sqrt();
    let moments_usable = mean.is_finite() && sd.is_finite() && sd > 0.0;
    if !moments_usable {
        return Ok(false);
    }

    // Standardized third and fourth moments.
    let skew = weighted_average(z, weights, weight_sum, |zi, wi| {
        let centered = (zi - mean) / sd;
        wi * centered.powi(3)
    });
    let excess_kurtosis = weighted_average(z, weights, weight_sum, |zi, wi| {
        let centered = (zi - mean) / sd;
        wi * centered.powi(4)
    }) - 3.0;

    // Tolerances tighten as the effective sample size grows.
    let mean_tol = policy.mean_tol_multiplier / effective_n.sqrt();
    let sd_tol = policy.sd_tol_multiplier / (2.0 * (effective_n - 1.0).max(1.0)).sqrt();

    let ks_to_normal = weighted_ks_to_standard_normal(z, weights, weight_sum)?;
    let tail_inner = weighted_tail_mass(z, weights, weight_sum, AUTO_Z_NORMAL_TAIL_SIGMA_INNER);
    let tail_outer = weighted_tail_mass(z, weights, weight_sum, AUTO_Z_NORMAL_TAIL_SIGMA_OUTER);
    let max_abs_z = z.iter().map(|zi| zi.abs()).fold(0.0_f64, f64::max);
    // Two-sided tail mass of the reference standard normal at the same cutoffs.
    let normal_tail_inner = 2.0 * (1.0 - normal_cdf(AUTO_Z_NORMAL_TAIL_SIGMA_INNER));
    let normal_tail_outer = 2.0 * (1.0 - normal_cdf(AUTO_Z_NORMAL_TAIL_SIGMA_OUTER));

    let location_scale_ok = mean.abs() <= mean_tol && (sd - 1.0).abs() <= sd_tol;
    let skew_ok = skew.is_finite() && skew.abs() <= policy.max_abs_skew.min(AUTO_Z_NORMAL_SKEW_TOL);
    let kurtosis_ok = excess_kurtosis.is_finite()
        && excess_kurtosis.abs() <= policy.max_abs_excess_kurtosis.min(AUTO_Z_NORMAL_KURT_TOL);
    let ks_ok = ks_to_normal.is_finite() && ks_to_normal <= AUTO_Z_NORMAL_KS_TOL;
    let tails_ok = tail_inner
        <= AUTO_Z_NORMAL_TAIL_MASS_SLACK * normal_tail_inner + AUTO_Z_NORMAL_TAIL_FLOOR_INNER
        && tail_outer
            <= AUTO_Z_NORMAL_TAIL_MASS_SLACK * normal_tail_outer + AUTO_Z_NORMAL_TAIL_FLOOR_OUTER;
    let range_ok = max_abs_z < AUTO_Z_NORMAL_MAX_ABS;

    Ok(location_scale_ok && skew_ok && kurtosis_ok && ks_ok && tails_ok && range_ok)
}
/// Compresses the weighted sample `z` into an empirical grid of at most
/// `grid_size` nodes and wraps it as a validated
/// `LatentMeasureKind::GlobalEmpirical`.
///
/// # Errors
/// Propagates failures from grid construction and from measure validation.
fn build_global_empirical_latent_measure(
    z: &Array1<f64>,
    weights: &Array1<f64>,
    grid_size: usize,
) -> Result<LatentMeasureKind, String> {
    // Label shared by grid construction and validation diagnostics.
    const CONTEXT: &str = "empirical latent measure";

    let grid = build_empirical_z_grid(z, weights, grid_size, CONTEXT)?;
    let measure = LatentMeasureKind::GlobalEmpirical { grid };
    measure.validate(CONTEXT)?;
    Ok(measure)
}
/// Weighted Kolmogorov–Smirnov distance between the empirical distribution of
/// `z` and the standard normal CDF.
///
/// Rows with zero weight are ignored. Each remaining row contributes a step of
/// height `weight / total_weight` to the empirical CDF, and the distance is
/// the largest gap between `normal_cdf(z_i)` and the empirical CDF just before
/// or just after that step.
///
/// # Errors
/// Fails if any `z` value is non-finite or any weight is non-finite or
/// negative.
fn weighted_ks_to_standard_normal(
    z: &Array1<f64>,
    weights: &Array1<f64>,
    total_weight: f64,
) -> Result<f64, String> {
    let mut support = Vec::<(f64, f64)>::with_capacity(z.len());
    for (&value, &weight) in z.iter().zip(weights.iter()) {
        let row_valid = value.is_finite() && weight.is_finite() && weight >= 0.0;
        if !row_valid {
            return Err(
                "latent-measure KS diagnostic requires finite z and non-negative finite weights"
                    .to_string(),
            );
        }
        if weight > 0.0 {
            support.push((value, weight));
        }
    }

    support.sort_by(|a, b| {
        a.0.partial_cmp(&b.0)
            .expect("validated latent z values are finite")
    });

    // Walk the sorted support, carrying (empirical CDF so far, worst gap so far).
    let (_, ks) = support.into_iter().fold(
        (0.0_f64, 0.0_f64),
        |(below, worst), (value, weight)| {
            let cdf = normal_cdf(value);
            let above = below + weight / total_weight;
            let worst = worst.max((cdf - below).abs()).max((cdf - above).abs());
            (above, worst)
        },
    );
    Ok(ks)
}
/// Fraction of `total_weight` carried by rows whose `|z|` is strictly greater
/// than `cutoff`.
fn weighted_tail_mass(
    z: &Array1<f64>,
    weights: &Array1<f64>,
    total_weight: f64,
    cutoff: f64,
) -> f64 {
    let mass_beyond_cutoff: f64 = z
        .iter()
        .zip(weights.iter())
        .filter_map(|(&zi, &wi)| (zi.abs() > cutoff).then_some(wi))
        .sum();
    mass_beyond_cutoff / total_weight
}
/// Compresses a weighted sample into an equal-weight empirical grid.
///
/// Positive-weight rows are sorted by `z` and swept once. The sweep fills
/// `min(grid_size, rows)` bins, each targeting an equal share of the total
/// weight; a row whose weight straddles a bin boundary is split between the
/// neighbouring bins. Every non-empty bin becomes one node (the weighted mean
/// of the `z` it absorbed) with its share of the total weight. The nodes are
/// then passed through `recenter_rescale_empirical_grid`, the weights are
/// renormalized to sum to one, and the grid is validated.
///
/// `context` prefixes most diagnostics.
///
/// # Errors
/// Fails when `grid_size < 3`, on a length mismatch, on non-finite `z`, on
/// non-finite or negative weights, when fewer than two rows carry positive
/// weight, when the total weight is not positive and finite, when fewer than
/// two nodes are produced, or when grid validation fails.
fn build_empirical_z_grid(
    z: &Array1<f64>,
    weights: &Array1<f64>,
    grid_size: usize,
    context: &str,
) -> Result<EmpiricalZGrid, String> {
    if grid_size < 3 {
        return Err(format!(
            "empirical latent measure grid_size must be at least 3, got {grid_size}"
        ));
    }
    if z.len() != weights.len() {
        return Err(format!(
            "{context} length mismatch: z={}, weights={}",
            z.len(),
            weights.len()
        ));
    }

    // Validate every row and keep only those that carry weight.
    let mut support = Vec::<(f64, f64)>::with_capacity(z.len());
    for (idx, (&zi, &wi)) in z.iter().zip(weights.iter()).enumerate() {
        if !zi.is_finite() {
            return Err(format!(
                "{context} z value at row {idx} is non-finite ({zi})"
            ));
        }
        let weight_valid = wi.is_finite() && wi >= 0.0;
        if !weight_valid {
            return Err(format!(
                "{context} weight at row {idx} must be finite and non-negative, got {wi}"
            ));
        }
        if wi > 0.0 {
            support.push((zi, wi));
        }
    }
    if support.len() < 2 {
        return Err(format!(
            "{context} requires at least two positive-weight rows"
        ));
    }
    support.sort_by(|a, b| {
        a.0.partial_cmp(&b.0)
            .expect("validated empirical latent z values are finite")
    });

    let total_weight = support.iter().map(|&(_, row_weight)| row_weight).sum::<f64>();
    let total_usable = total_weight.is_finite() && total_weight > 0.0;
    if !total_usable {
        return Err(format!("{context} requires positive finite total weight"));
    }

    let bin_count = grid_size.min(support.len());
    let bin_weight_target = total_weight / (bin_count as f64);
    // A bin counts as full once its outstanding need drops to this level.
    let bin_full_threshold = EMPIRICAL_GRID_WEIGHT_EXHAUSTED_REL_TOL * bin_weight_target;

    let mut nodes = Vec::with_capacity(bin_count);
    let mut out_weights = Vec::with_capacity(bin_count);
    let mut cursor = 0usize;
    // Weight of the row at `cursor` not yet assigned to any bin.
    let mut left_in_row = support[0].1;
    for _ in 0..bin_count {
        let mut still_needed = bin_weight_target;
        let mut gathered = 0.0;
        let mut weighted_z = 0.0;
        while still_needed > bin_full_threshold && cursor < support.len() {
            let (row_z, row_weight) = support[cursor];
            let take = left_in_row.min(still_needed);
            weighted_z += take * row_z;
            gathered += take;
            still_needed -= take;
            left_in_row -= take;
            // Move on once the row is used up, up to a relative tolerance.
            if left_in_row <= EMPIRICAL_GRID_WEIGHT_EXHAUSTED_REL_TOL * row_weight {
                cursor += 1;
                if let Some(&(_, next_weight)) = support.get(cursor) {
                    left_in_row = next_weight;
                }
            }
        }
        if gathered > 0.0 {
            nodes.push(weighted_z / gathered);
            out_weights.push(gathered / total_weight);
        }
    }
    if nodes.len() < 2 {
        return Err(format!(
            "{context} compression produced fewer than two nodes"
        ));
    }

    recenter_rescale_empirical_grid(&mut nodes, &out_weights);

    // Renormalize the weights so they sum to one.
    let total = out_weights.iter().sum::<f64>();
    if total.is_finite() && total > 0.0 {
        out_weights.iter_mut().for_each(|weight| *weight /= total);
    }

    validate_empirical_z_grid(&nodes, &out_weights, context)?;
    Ok(EmpiricalZGrid {
        nodes,
        weights: out_weights,
    })
}
/// Standardizes `nodes` in place to weighted mean zero and weighted standard
/// deviation one under `weights`.
///
/// The nodes are left untouched when the weights do not sum to a positive
/// finite value, or when the weighted standard deviation is non-finite or does
/// not exceed `BMS_VARIANCE_FLOOR`.
fn recenter_rescale_empirical_grid(nodes: &mut [f64], weights: &[f64]) {
    let total = weights.iter().sum::<f64>();
    let total_usable = total.is_finite() && total > 0.0;
    if !total_usable {
        return;
    }

    let mean = nodes
        .iter()
        .zip(weights)
        .map(|(&value, &mass)| mass * value)
        .sum::<f64>()
        / total;
    let var = nodes
        .iter()
        .zip(weights)
        .map(|(&value, &mass)| mass * (value - mean).powi(2))
        .sum::<f64>()
        / total;
    let sd = var.sqrt();

    let spread_usable = sd.is_finite() && sd > BMS_VARIANCE_FLOOR;
    if !spread_usable {
        return;
    }
    nodes
        .iter_mut()
        .for_each(|value| *value = (*value - mean) / sd);
}
// Tuning constants shared with the sibling submodules declared below.
// NOTE(review): only `BMS_VARIANCE_FLOOR` and
// `EMPIRICAL_GRID_WEIGHT_EXHAUSTED_REL_TOL` are used in the surrounding code of
// this file section; the remaining descriptions are inferred from names and
// should be confirmed at their use sites.

/// NOTE(review): presumably a budget for the first phase of automatic
/// subsampling — confirm units and consumer.
pub(super) const BMS_AUTO_SUBSAMPLE_PHASE1_BUDGET: usize = 12;
/// NOTE(review): presumably a margin keeping Bernoulli link probabilities away
/// from 0 and 1 — confirm at the use site.
pub(super) const BERNOULLI_LINK_PROBABILITY_EPS: f64 = 1e-12;
/// Threshold below which a spread is treated as degenerate.
/// `recenter_rescale_empirical_grid` compares a weighted standard deviation
/// against it and skips standardization unless the value is exceeded.
pub(super) const BMS_VARIANCE_FLOOR: f64 = 1e-12;
/// NOTE(review): presumably a tolerance on derivative magnitudes or checks —
/// confirm at the use site.
pub(super) const BMS_DERIV_TOL: f64 = 1e-8;
/// Relative tolerance used by `build_empirical_z_grid` to decide that a bin's
/// outstanding weight (relative to the per-bin target) or a row's remaining
/// weight (relative to the row's full weight) is exhausted.
pub(super) const EMPIRICAL_GRID_WEIGHT_EXHAUSTED_REL_TOL: f64 = 1e-14;
/// NOTE(review): presumably the number of rows per chunk in chunked row
/// reductions — confirm.
pub(super) const ROW_CHUNK_SIZE: usize = 1024;
/// NOTE(review): presumably the minimum row count before exact-work logging is
/// emitted — confirm.
pub(super) const EXACT_WORK_LOG_MIN_ROWS: usize = 50_000;
// Row-primary Hessian settings.
// NOTE(review): the consumers are outside this chunk; the names suggest reuse
// pass counts, a tile height in rows, and two NUM/DEN fraction pairs — confirm.
pub(super) const BMS_ROW_PRIMARY_HESSIAN_EXPECTED_REUSE_PASSES: usize = 3;
pub(super) const BMS_ROW_PRIMARY_HESSIAN_MIN_REUSE_PASSES: usize = 2;
pub(super) const BMS_ROW_PRIMARY_HESSIAN_TILE_ROWS: usize = 8192;
// The "single" NUM/DEN pair encodes the ratio 1/4.
pub(super) const BMS_ROW_PRIMARY_HESSIAN_SINGLE_FRACTION_NUM: u64 = 1;
pub(super) const BMS_ROW_PRIMARY_HESSIAN_SINGLE_FRACTION_DEN: u64 = 4;
// The "global" NUM/DEN pair encodes the ratio 1/2.
pub(super) const BMS_ROW_PRIMARY_HESSIAN_GLOBAL_FRACTION_NUM: u64 = 1;
pub(super) const BMS_ROW_PRIMARY_HESSIAN_GLOBAL_FRACTION_DEN: u64 = 2;
/// NOTE(review): presumably the row-chunk size at which a line search may exit
/// early — confirm at the use site.
pub(super) const BERNOULLI_MARGSLOPE_LINE_SEARCH_EARLY_EXIT_CHUNK_ROWS: usize = 10_000;
pub(crate) mod block_specs;
pub(crate) mod exact_eval_cache;
pub(crate) mod family;
pub(crate) mod gradient_paths;
pub(crate) mod hessian_paths;
pub(crate) mod install_flex;
pub(crate) mod row_kernel;
#[cfg(test)]
#[path = "../../../tests/src_modules/families_bms_tests.rs"]
mod tests;
pub(crate) mod workspace;
pub use block_specs::fit_bernoulli_marginal_slope_terms;
pub use gradient_paths::{
MarginalSlopeCovariance, MarginalSlopeCovarianceShape, marginal_slope_covariance_from_scores,
marginal_slope_preserving_scale, marginal_slope_probit_eta, padded_deviation_seed,
};
pub use install_flex::CrossBlockIdentifiabilityWarning;
pub(crate) use install_flex::FlexCompileOutcome;
pub(crate) use block_specs::push_deviation_aux_blockspecs;
pub use block_specs::{BmsLogslopeJacobian, BmsMarginalJacobian};
pub(crate) use family::{
BernoulliMarginalLinkMap, bernoulli_marginal_link_map,
build_link_deviation_block_from_knots_design_seed_and_weights,
build_score_warp_deviation_block_from_seed,
};
pub(crate) use gradient_paths::standardize_latent_z_with_policy;
pub(crate) use gradient_paths::{
empirical_intercept_from_marginal, signed_probit_neglog_derivatives_up_to_fourth,
unary_derivatives_log, unary_derivatives_log_normal_pdf, unary_derivatives_neglog_phi,
unary_derivatives_sqrt,
};
pub(crate) use install_flex::{
install_compiled_flex_block_into_runtime, project_monotone_feasible_beta,
validate_monotone_structural_feasible,
};