use crate::custom_family::{
BatchedOuterGradientTerms, BlockWorkingSet, BlockwiseFitOptions, CustomFamily,
CustomFamilyWarmStart, ExactNewtonJointGradientEvaluation, ExactNewtonJointHessianWorkspace,
ExactNewtonJointPsiSecondOrderTerms, ExactNewtonJointPsiTerms, ExactNewtonJointPsiWorkspace,
FamilyEvaluation, ParameterBlockSpec, ParameterBlockState, build_block_spatial_psi_derivatives,
custom_family_outer_derivatives, evaluate_custom_family_joint_hyper_efs_shared,
evaluate_custom_family_joint_hyper_shared, fit_custom_family,
};
use crate::estimate::UnifiedFitResult;
use crate::estimate::reml::unified::{DenseSpectralOperator, HessianOperator, HyperOperator};
use crate::families::gamlss::{ParameterBlockInput, initialize_monotone_wiggle_knots_from_seed};
use crate::families::lognormal_kernel::FrailtySpec;
use crate::families::marginal_slope_shared::{
CoeffSupport, ObservedDenestedCellPartials, SparsePrimaryCoeffJetView, WeightedOuterRow,
add_optional_matrix, add_optional_vector, add_two_surface_psi_outer,
build_denested_partition_cells as shared_denested_partition_cells, chunked_row_reduction,
eval_coeff4_at, is_sigma_aux_index as shared_is_sigma_aux_index,
observed_denested_cell_partials as shared_observed_denested_cell_partials, outer_row_indices,
outer_weighted_rows, probit_frailty_scale, probit_frailty_scale_multi_dir_jet,
psi_derivative_location, scale_coeff4,
};
use crate::families::row_kernel::{
RowKernel, RowKernelHessianWorkspace, build_row_kernel_cache, row_kernel_gradient,
row_kernel_hessian_dense, row_kernel_log_likelihood,
};
use crate::matrix::{DesignMatrix, SymmetricMatrix};
use crate::pirls::LinearInequalityConstraints;
use crate::probability::{
normal_cdf, normal_logcdf, normal_pdf, signed_probit_logcdf_and_mills_ratio,
standard_normal_quantile,
};
use crate::smooth::{
ExactJointHyperSetup, SpatialLengthScaleOptimizationOptions,
SpatialLogKappaCoords, TermCollectionDesign, TermCollectionSpec,
apply_spatial_anisotropy_pilot_initializer, build_term_collection_designs_and_freeze_joint,
optimize_spatial_length_scale_exact_joint, spatial_length_scale_term_indices,
};
use crate::types::{InverseLink, LinkFunction, WigglePenaltyConfig};
use ndarray::{Array1, Array2, ArrayView2, Axis, s};
use rayon::iter::{
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator,
};
use serde::{Deserialize, Serialize};
use std::cell::RefCell;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
pub mod deviation_runtime;
pub(crate) mod exact_kernel;
pub use deviation_runtime::DeviationRuntime;
pub use deviation_runtime::ParametricAnchorBlock;
/// Configuration for a structural deviation spline block (used for the
/// optional `score_warp` / `link_dev` terms of a marginal-slope fit).
#[derive(Clone, Debug)]
pub struct DeviationBlockConfig {
    /// Spline degree; downstream construction requires 3 (cubic) — see
    /// `build_deviation_block_from_knots_and_design_seed`.
    pub degree: usize,
    /// Number of interior knots of the deviation basis.
    pub num_internal_knots: usize,
    /// Scalar summary of `penalty_orders` (their maximum; see the
    /// `From<WigglePenaltyConfig>` impl).
    pub penalty_order: usize,
    /// Full set of difference-penalty orders applied to the block.
    pub penalty_orders: Vec<usize>,
    /// Whether a second (null-space) penalty is added.
    pub double_penalty: bool,
    /// Slack used when enforcing monotonicity constraints.
    pub monotonicity_eps: f64,
}
impl Default for DeviationBlockConfig {
    /// Defaults are derived from the cubic triple-operator wiggle-penalty
    /// preset via the `From<WigglePenaltyConfig>` conversion below.
    fn default() -> Self {
        WigglePenaltyConfig::cubic_triple_operator_default().into()
    }
}
impl DeviationBlockConfig {
    /// Named alias for [`Default::default`], for call sites that want to
    /// spell out the triple-penalty preset explicitly.
    pub fn triple_penalty_default() -> Self {
        Self::default()
    }
}
impl From<WigglePenaltyConfig> for DeviationBlockConfig {
    /// Maps a wiggle-penalty configuration onto a deviation-block config.
    /// The scalar `penalty_order` summarizes `penalty_orders` as its maximum,
    /// falling back to 2 when the list is empty.
    fn from(cfg: WigglePenaltyConfig) -> Self {
        let highest_order = cfg.penalty_orders.iter().copied().max().unwrap_or(2);
        Self {
            degree: cfg.degree,
            num_internal_knots: cfg.num_internal_knots,
            penalty_order: highest_order,
            penalty_orders: cfg.penalty_orders,
            double_penalty: cfg.double_penalty,
            monotonicity_eps: cfg.monotonicity_eps,
        }
    }
}
/// A deviation block that is ready for fitting: the parameter-block input fed
/// to the blockwise solver plus the runtime that evaluates the spline.
#[derive(Clone)]
pub(crate) struct DeviationPrepared {
    pub(crate) block: ParameterBlockInput,
    pub(crate) runtime: DeviationRuntime,
}
impl std::fmt::Debug for DeviationPrepared {
    /// Manual Debug that prints only the type name; field contents are
    /// intentionally omitted (`finish_non_exhaustive` renders `..`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DeviationPrepared").finish_non_exhaustive()
    }
}
/// Full specification of a Bernoulli marginal-slope term fit: responses,
/// weights, latent scores, the two term collections (marginal level and
/// log-slope), their offsets, frailty, and optional deviation blocks.
#[derive(Clone)]
pub struct BernoulliMarginalSlopeTermSpec {
    /// Binary responses.
    pub y: Array1<f64>,
    /// Observation weights (same length as `y`).
    pub weights: Array1<f64>,
    /// Latent scores; normalized/calibrated per `latent_z_policy`.
    pub z: Array1<f64>,
    /// Base link; only probit is accepted downstream
    /// (see `require_probit_marginal_slope_link`).
    pub base_link: InverseLink,
    pub marginalspec: TermCollectionSpec,
    pub logslopespec: TermCollectionSpec,
    pub marginal_offset: Array1<f64>,
    pub logslope_offset: Array1<f64>,
    pub frailty: FrailtySpec,
    /// Optional monotone score-warp deviation block.
    pub score_warp: Option<DeviationBlockConfig>,
    /// Optional link-deviation block.
    pub link_dev: Option<DeviationBlockConfig>,
    /// How the latent z is checked, normalized, and measured.
    pub latent_z_policy: LatentZPolicy,
}
impl BernoulliMarginalSlopeTermSpec {
    /// Builds a spec from raw inputs plus a calibration protocol; the
    /// protocol supplies the base link, the deviation-block configs, and the
    /// latent-score policy, while all data arrays pass through unchanged.
    pub fn calibrated_probit(
        y: Array1<f64>,
        weights: Array1<f64>,
        z: Array1<f64>,
        marginalspec: TermCollectionSpec,
        logslopespec: TermCollectionSpec,
        marginal_offset: Array1<f64>,
        logslope_offset: Array1<f64>,
        frailty: FrailtySpec,
        protocol: crate::solver::protocol::MarginalSlopeCalibrationProtocol,
    ) -> Self {
        Self {
            y,
            weights,
            z,
            base_link: protocol.base_link,
            marginalspec,
            logslopespec,
            marginal_offset,
            logslope_offset,
            frailty,
            score_warp: protocol.score_warp,
            link_dev: protocol.link_deviation,
            latent_z_policy: protocol.latent_score.into_policy(),
        }
    }
}
/// Output of a Bernoulli marginal-slope fit: the unified fit result plus the
/// resolved specs/designs and all latent-score calibration artifacts needed
/// to reproduce predictions.
pub struct BernoulliMarginalSlopeFitResult {
    pub fit: UnifiedFitResult,
    /// Term-collection specs after resolution (e.g. frozen knots/scales).
    pub marginalspec_resolved: TermCollectionSpec,
    pub logslopespec_resolved: TermCollectionSpec,
    pub marginal_design: TermCollectionDesign,
    pub logslope_design: TermCollectionDesign,
    pub baseline_marginal: f64,
    pub baseline_logslope: f64,
    /// Mean/sd applied to latent z before fitting.
    pub z_normalization: LatentZNormalization,
    /// Latent integration measure actually used (normal or empirical grid).
    pub latent_measure: LatentMeasureKind,
    pub score_warp_runtime: Option<DeviationRuntime>,
    pub link_dev_runtime: Option<DeviationRuntime>,
    pub gaussian_frailty_sd: Option<f64>,
    pub cross_block_warnings: Vec<CrossBlockIdentifiabilityWarning>,
    /// Present when the Auto latent-measure path fitted a rank-INT transform.
    pub latent_z_rank_int_calibration: Option<LatentZRankIntCalibration>,
}
/// How latent-z diagnostic failures are reported: hard error, warning, or
/// no checking at all.
#[derive(Clone, Debug)]
pub enum LatentZCheckMode {
    Strict,
    WarnOnly,
    Off,
}
/// How latent z is standardized before fitting: not at all, from the
/// weighted sample moments of the fit data, or with frozen parameters.
#[derive(Clone, Debug)]
pub enum LatentZNormalizationMode {
    None,
    FitWeighted,
    Frozen { mean: f64, sd: f64 },
}
/// Default node count for a compressed empirical latent-z grid.
pub const DEFAULT_EMPIRICAL_LATENT_GRID_SIZE: usize = 65;
// Acceptance thresholds used by `latent_z_is_standard_normal_enough` when the
// `Auto` latent measure decides whether z is close enough to N(0, 1):
// max |skewness|, max |excess kurtosis|, max weighted KS distance, max |z|.
const AUTO_Z_NORMAL_SKEW_TOL: f64 = 0.10;
const AUTO_Z_NORMAL_KURT_TOL: f64 = 0.25;
const AUTO_Z_NORMAL_KS_TOL: f64 = 0.025;
const AUTO_Z_NORMAL_MAX_ABS: f64 = 8.0;
/// Requested latent integration measure: auto-detect (normal vs. rank-INT
/// calibration), force standard normal, or a global empirical grid of the
/// given size.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LatentMeasureSpec {
    Auto { grid_size: usize },
    StandardNormal,
    GlobalEmpirical { grid_size: usize },
}
impl LatentMeasureSpec {
    /// Global empirical measure with the default grid size.
    pub fn global_empirical_default() -> Self {
        Self::GlobalEmpirical {
            grid_size: DEFAULT_EMPIRICAL_LATENT_GRID_SIZE,
        }
    }
    /// Auto-detected measure with the default grid size.
    pub fn auto_default() -> Self {
        Self::Auto {
            grid_size: DEFAULT_EMPIRICAL_LATENT_GRID_SIZE,
        }
    }
}
impl Default for LatentMeasureSpec {
    /// Default to auto-detection.
    fn default() -> Self {
        Self::auto_default()
    }
}
/// A discrete latent-z quadrature grid: node locations plus positive weights
/// that sum to 1 (enforced by `validate_empirical_z_grid`).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct EmpiricalZGrid {
    pub nodes: Vec<f64>,
    pub weights: Vec<f64>,
}
/// Resolved latent integration measure, serialized with an internal
/// `kind` tag in kebab-case.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "kebab-case")]
pub enum LatentMeasureKind {
    StandardNormal,
    /// One shared empirical grid for all rows.
    GlobalEmpirical {
        nodes: Vec<f64>,
        weights: Vec<f64>,
    },
    /// Per-row mixtures of center-specific grids, selected by feature-space
    /// proximity.
    LocalEmpirical {
        feature_cols: Vec<usize>,
        #[serde(default)]
        input_scales: Option<Vec<f64>>,
        centers: Vec<Vec<f64>>,
        grids: Vec<EmpiricalZGrid>,
        top_k: usize,
        bandwidth: f64,
        // Training-row mixtures are transient (not serialized); a
        // deserialized value starts with an empty Arc.
        #[serde(skip)]
        train_row_mixtures: Arc<Vec<Vec<(usize, f64)>>>,
    },
}
impl Default for LatentMeasureKind {
    /// Default to the analytic standard-normal measure.
    fn default() -> Self {
        Self::StandardNormal
    }
}
impl LatentMeasureKind {
    /// Structural validation; `context` prefixes every error message.
    /// `StandardNormal` is always valid; `GlobalEmpirical` validates its
    /// grid; `LocalEmpirical` additionally checks center geometry, mixture
    /// parameters, and every per-center grid.
    pub fn validate(&self, context: &str) -> Result<(), String> {
        match self {
            Self::StandardNormal => Ok(()),
            Self::GlobalEmpirical { nodes, weights } => {
                validate_empirical_z_grid(nodes, weights, context)
            }
            Self::LocalEmpirical {
                feature_cols,
                input_scales,
                centers,
                grids,
                top_k,
                bandwidth,
                ..
            } => {
                if feature_cols.is_empty() {
                    return Err(format!(
                        "{context} local empirical latent measure needs feature columns"
                    ));
                }
                if centers.is_empty() {
                    return Err(format!(
                        "{context} local empirical latent measure needs centers"
                    ));
                }
                // Each center must own exactly one grid.
                if centers.len() != grids.len() {
                    return Err(format!(
                        "{context} local empirical latent measure center/grid length mismatch: centers={}, grids={}",
                        centers.len(),
                        grids.len()
                    ));
                }
                // top_k selects among centers, so it must be a valid count.
                if *top_k == 0 || *top_k > centers.len() {
                    return Err(format!(
                        "{context} local empirical latent measure top_k must be in 1..={}, got {top_k}",
                        centers.len()
                    ));
                }
                if !(*bandwidth).is_finite() || *bandwidth <= 0.0 {
                    return Err(format!(
                        "{context} local empirical latent measure bandwidth must be finite and positive, got {bandwidth}"
                    ));
                }
                // Optional per-feature scales must align with feature_cols
                // and be strictly positive.
                if let Some(scales) = input_scales.as_ref() {
                    if scales.len() != feature_cols.len() {
                        return Err(format!(
                            "{context} local empirical latent measure input scale dimension mismatch: scales={}, features={}",
                            scales.len(),
                            feature_cols.len()
                        ));
                    }
                    for (scale_idx, scale) in scales.iter().enumerate() {
                        if !(scale.is_finite() && *scale > 0.0) {
                            return Err(format!(
                                "{context} local empirical latent measure input scale {scale_idx} must be finite and positive, got {scale}"
                            ));
                        }
                    }
                }
                // Centers live in the feature space spanned by feature_cols.
                for (center_idx, center) in centers.iter().enumerate() {
                    if center.len() != feature_cols.len() {
                        return Err(format!(
                            "{context} local empirical latent center {center_idx} dimension mismatch: got {}, expected {}",
                            center.len(),
                            feature_cols.len()
                        ));
                    }
                    if center.iter().any(|value| !value.is_finite()) {
                        return Err(format!(
                            "{context} local empirical latent center {center_idx} has non-finite coordinates"
                        ));
                    }
                }
                for (grid_idx, grid) in grids.iter().enumerate() {
                    validate_empirical_z_grid(
                        &grid.nodes,
                        &grid.weights,
                        &format!("{context} local empirical grid {grid_idx}"),
                    )?;
                }
                Ok(())
            }
        }
    }
    /// True for either empirical variant.
    fn is_empirical(&self) -> bool {
        matches!(
            self,
            Self::GlobalEmpirical { .. } | Self::LocalEmpirical { .. }
        )
    }
    /// Returns the empirical grid to integrate over for training row `row`:
    /// `None` for the analytic normal measure, a clone of the shared grid for
    /// the global variant, or the row's mixture of center grids for the local
    /// variant (erroring if the transient mixtures are missing, e.g. after
    /// deserialization).
    fn empirical_grid_for_training_row(
        &self,
        row: usize,
    ) -> Result<Option<EmpiricalZGrid>, String> {
        match self {
            Self::StandardNormal => Ok(None),
            Self::GlobalEmpirical { nodes, weights } => Ok(Some(EmpiricalZGrid {
                nodes: nodes.clone(),
                weights: weights.clone(),
            })),
            Self::LocalEmpirical {
                grids,
                train_row_mixtures,
                ..
            } => {
                let mixture = train_row_mixtures.get(row).ok_or_else(|| {
                    format!(
                        "local empirical latent measure is missing training mixture for row {row}"
                    )
                })?;
                Ok(Some(combine_empirical_grids(grids, mixture)?))
            }
        }
    }
}
/// Checks that `nodes`/`weights` form a valid empirical latent grid: matched
/// lengths, at least two nodes, finite nodes, strictly positive finite
/// weights summing to 1 (within 1e-8). Error messages are prefixed with
/// `context`.
fn validate_empirical_z_grid(nodes: &[f64], weights: &[f64], context: &str) -> Result<(), String> {
    if nodes.len() != weights.len() {
        return Err(format!(
            "{context} empirical latent measure node/weight length mismatch: nodes={}, weights={}",
            nodes.len(),
            weights.len()
        ));
    }
    if nodes.len() < 2 {
        return Err(format!(
            "{context} empirical latent measure requires at least two nodes"
        ));
    }
    let mut mass = 0.0_f64;
    for (idx, (&node, &weight)) in nodes.iter().zip(weights).enumerate() {
        // Node check comes first so a row with both problems reports the node.
        if !node.is_finite() {
            return Err(format!(
                "{context} empirical latent measure node {idx} is non-finite ({node})"
            ));
        }
        if !(weight.is_finite() && weight > 0.0) {
            return Err(format!(
                "{context} empirical latent measure weight {idx} must be finite and positive, got {weight}"
            ));
        }
        mass += weight;
    }
    let normalized = mass.is_finite() && (mass - 1.0).abs() <= 1e-8;
    if !normalized {
        return Err(format!(
            "{context} empirical latent measure weights must sum to 1, got {mass}"
        ));
    }
    Ok(())
}
/// Forms the weighted union of several empirical grids according to
/// `mixture` (pairs of grid index and positive mixture weight), renormalizes
/// the combined weights to sum to 1, and validates the result.
fn combine_empirical_grids(
    grids: &[EmpiricalZGrid],
    mixture: &[(usize, f64)],
) -> Result<EmpiricalZGrid, String> {
    if mixture.is_empty() {
        return Err("local empirical latent measure row mixture is empty".to_string());
    }
    let mut nodes = Vec::new();
    let mut weights = Vec::new();
    for &(grid_idx, grid_weight) in mixture {
        if !(grid_weight.is_finite() && grid_weight > 0.0) {
            return Err(format!(
                "local empirical latent mixture weight must be finite and positive, got {grid_weight}"
            ));
        }
        let grid = grids.get(grid_idx).ok_or_else(|| {
            format!("local empirical latent mixture references missing grid {grid_idx}")
        })?;
        // Append the grid's nodes, scaling its weights by the mixture weight.
        nodes.extend_from_slice(&grid.nodes);
        weights.extend(grid.weights.iter().map(|&w| grid_weight * w));
    }
    let mass: f64 = weights.iter().sum();
    if !(mass.is_finite() && mass > 0.0) {
        return Err(
            "local empirical latent combined grid has non-positive total weight".to_string(),
        );
    }
    for weight in weights.iter_mut() {
        *weight /= mass;
    }
    validate_empirical_z_grid(&nodes, &weights, "local empirical latent combined grid")?;
    Ok(EmpiricalZGrid { nodes, weights })
}
/// Policy bundle controlling latent-z handling: diagnostic strictness,
/// normalization mode, integration measure, and the tolerance multipliers /
/// hard moment bounds used by the diagnostics.
#[derive(Clone, Debug)]
pub struct LatentZPolicy {
    pub check_mode: LatentZCheckMode,
    pub normalization: LatentZNormalizationMode,
    pub latent_measure: LatentMeasureSpec,
    /// Multiplier on the 1/sqrt(n_eff) standard error for the mean check.
    pub mean_tol_multiplier: f64,
    /// Multiplier on the asymptotic sd standard error for the sd check.
    pub sd_tol_multiplier: f64,
    pub max_abs_skew: f64,
    pub max_abs_excess_kurtosis: f64,
}
impl LatentZPolicy {
    /// Production preset: warn-only checks with a frozen N(0, 1)
    /// normalization and auto-detected latent measure.
    pub fn frozen_transformation_normal() -> Self {
        Self {
            check_mode: LatentZCheckMode::WarnOnly,
            normalization: LatentZNormalizationMode::Frozen { mean: 0.0, sd: 1.0 },
            latent_measure: LatentMeasureSpec::auto_default(),
            mean_tol_multiplier: 4.0,
            sd_tol_multiplier: 4.0,
            max_abs_skew: 4.0,
            max_abs_excess_kurtosis: 20.0,
        }
    }
    /// Exploratory preset: same as the frozen preset but with fit-weighted
    /// normalization and looser mean/sd tolerances (8x instead of 4x).
    pub fn exploratory_fit_weighted() -> Self {
        Self {
            normalization: LatentZNormalizationMode::FitWeighted,
            mean_tol_multiplier: 8.0,
            sd_tol_multiplier: 8.0,
            ..Self::frozen_transformation_normal()
        }
    }
    /// Exploratory preset forced onto a global empirical latent measure.
    pub fn empirical_fit_weighted() -> Self {
        Self {
            latent_measure: LatentMeasureSpec::global_empirical_default(),
            ..Self::exploratory_fit_weighted()
        }
    }
}
impl Default for LatentZPolicy {
    /// Default to the frozen production preset.
    fn default() -> Self {
        Self::frozen_transformation_normal()
    }
}
/// Affine standardization parameters for latent z: z_std = (z - mean) / sd.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct LatentZNormalization {
    pub mean: f64,
    pub sd: f64,
}
impl LatentZNormalization {
    /// Standardizes `z` element-wise as `(z - mean) / sd`.
    ///
    /// Errors (prefixed with `context`) when the parameters are non-finite,
    /// when `sd <= 1e-12`, or when any input value is non-finite.
    pub fn apply(&self, z: &Array1<f64>, context: &str) -> Result<Array1<f64>, String> {
        let params_ok = self.mean.is_finite() && self.sd.is_finite() && self.sd > 1e-12;
        if !params_ok {
            return Err(format!(
                "{context} requires finite latent z normalization with sd > 1e-12; got mean={} sd={}",
                self.mean, self.sd
            ));
        }
        if !z.iter().all(|value| value.is_finite()) {
            return Err(format!("{context} requires finite z values"));
        }
        let (mean, sd) = (self.mean, self.sd);
        Ok(z.mapv(|zi| (zi - mean) / sd))
    }
}
/// Fitted rank-based inverse-normal (rank-INT) transform for latent z:
/// distinct sorted z knots with their weighted plotting positions, plus the
/// weighted mean/sd of the calibrated training scores.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct LatentZRankIntCalibration {
    /// Distinct z values in increasing order (interpolation knots).
    pub sorted_z: Vec<f64>,
    /// Plotting-position probability at each knot, strictly inside (0, 1).
    pub weighted_cdf: Vec<f64>,
    /// Weighted mean of the calibrated training scores.
    pub post_mean: f64,
    /// Weighted standard deviation of the calibrated training scores.
    pub post_sd: f64,
}
impl LatentZRankIntCalibration {
    /// Fits a weighted rank-based inverse-normal (rank-INT) transform.
    ///
    /// Observations are ranked by `z`; each distinct z value gets a
    /// Blom-style plotting position `p = (W(z) - 0.375) / (W_total + 0.25)`
    /// from the cumulative weight `W(z)`, clamped strictly inside (0, 1).
    /// The resulting knots (`sorted_z`, `weighted_cdf`) define a
    /// piecewise-linear weighted CDF whose standard-normal quantile is the
    /// calibrated score; `post_mean` / `post_sd` are the weighted moments of
    /// the calibrated training scores.
    ///
    /// Errors on length mismatch, empty input, non-finite z, negative or
    /// non-finite weights, or non-positive total weight.
    pub fn fit(z: &Array1<f64>, weights: &Array1<f64>) -> Result<Self, String> {
        if z.len() != weights.len() {
            return Err(format!(
                "rank-INT calibration: z length {} != weights length {}",
                z.len(),
                weights.len()
            ));
        }
        if z.is_empty() {
            return Err("rank-INT calibration requires at least one observation".to_string());
        }
        let w_total = weights.iter().copied().sum::<f64>();
        if !(w_total.is_finite() && w_total > 0.0) {
            return Err(format!(
                "rank-INT calibration requires positive finite total weight, got {w_total}"
            ));
        }
        for (idx, value) in z.iter().enumerate() {
            if !value.is_finite() {
                return Err(format!(
                    "rank-INT calibration: z[{idx}] = {value} not finite"
                ));
            }
        }
        for (idx, weight) in weights.iter().enumerate() {
            if !(weight.is_finite() && *weight >= 0.0) {
                return Err(format!(
                    "rank-INT calibration: weight[{idx}] = {weight} not finite/non-negative"
                ));
            }
        }
        // Stable rank order by z (ties keep original index order).
        let mut order: Vec<usize> = (0..z.len()).collect();
        order.sort_by(|&a, &b| z[a].partial_cmp(&z[b]).unwrap_or(std::cmp::Ordering::Equal));
        let mut sorted_z: Vec<f64> = Vec::with_capacity(z.len());
        let mut weighted_cdf: Vec<f64> = Vec::with_capacity(z.len());
        // Blom offsets (0.375 / 0.25); eps keeps p strictly inside (0, 1).
        let denom = w_total + 0.25;
        let eps = 0.5 / w_total.max(1.0);
        let mut cum_w = 0.0_f64;
        let mut last_z: Option<f64> = None;
        for &idx in &order {
            cum_w += weights[idx];
            let zi = z[idx];
            if let Some(prev) = last_z {
                if zi == prev {
                    // Tied z value: fold its weight into the existing knot by
                    // updating that knot's cumulative plotting position.
                    if let Some(slot) = weighted_cdf.last_mut() {
                        let p = ((cum_w - 0.375) / denom).clamp(eps, 1.0 - eps);
                        *slot = p;
                    }
                    continue;
                }
            }
            let p = ((cum_w - 0.375) / denom).clamp(eps, 1.0 - eps);
            sorted_z.push(zi);
            weighted_cdf.push(p);
            last_z = Some(zi);
        }
        // Weighted mean of the calibrated training scores.
        // (Previously iterated with `enumerate()` and discarded the index.)
        let mut sum_wz = 0.0_f64;
        let mut sum_w = 0.0_f64;
        for &idx in &order {
            let zi = z[idx];
            let calibrated = Self::apply_with_knots(zi, &sorted_z, &weighted_cdf);
            sum_wz += weights[idx] * calibrated;
            sum_w += weights[idx];
        }
        let post_mean = if sum_w > 0.0 { sum_wz / sum_w } else { 0.0 };
        // Weighted standard deviation around that mean (second pass).
        let mut sum_w_dev = 0.0_f64;
        for &idx in &order {
            let zi = z[idx];
            let calibrated = Self::apply_with_knots(zi, &sorted_z, &weighted_cdf);
            let d = calibrated - post_mean;
            sum_w_dev += weights[idx] * d * d;
        }
        let post_sd = if sum_w > 0.0 {
            (sum_w_dev / sum_w).sqrt()
        } else {
            1.0
        };
        Ok(Self {
            sorted_z,
            weighted_cdf,
            post_mean,
            post_sd,
        })
    }
    /// Applies the fitted transform to a whole training vector, rejecting
    /// non-finite inputs. Errors if the calibration has no knots.
    pub fn apply_to_training(&self, z: &Array1<f64>) -> Result<Array1<f64>, String> {
        if self.sorted_z.is_empty() {
            return Err("rank-INT calibration has no knots".to_string());
        }
        let mut out = Array1::<f64>::zeros(z.len());
        for (idx, &zi) in z.iter().enumerate() {
            if !zi.is_finite() {
                return Err(format!(
                    "rank-INT calibration apply: z[{idx}] = {zi} not finite"
                ));
            }
            out[idx] = self.apply_at_predict(zi);
        }
        Ok(out)
    }
    /// Applies the fitted transform to a single score at prediction time.
    pub fn apply_at_predict(&self, z: f64) -> f64 {
        Self::apply_with_knots(z, &self.sorted_z, &self.weighted_cdf)
    }
    /// Piecewise-linear interpolation of the weighted CDF at `z` (flat
    /// extrapolation beyond the end knots), followed by the standard-normal
    /// quantile. If the quantile evaluation fails, falls back to ±8
    /// depending on the tail side.
    fn apply_with_knots(z: f64, sorted_z: &[f64], weighted_cdf: &[f64]) -> f64 {
        debug_assert_eq!(sorted_z.len(), weighted_cdf.len());
        debug_assert!(!sorted_z.is_empty());
        let n = sorted_z.len();
        let p = if z <= sorted_z[0] {
            weighted_cdf[0]
        } else if z >= sorted_z[n - 1] {
            weighted_cdf[n - 1]
        } else {
            // Binary search for the bracketing knot interval [lo, hi].
            let mut lo = 0usize;
            let mut hi = n - 1;
            while hi - lo > 1 {
                let mid = (lo + hi) / 2;
                if sorted_z[mid] <= z {
                    lo = mid;
                } else {
                    hi = mid;
                }
            }
            let z_lo = sorted_z[lo];
            let z_hi = sorted_z[hi];
            let p_lo = weighted_cdf[lo];
            let p_hi = weighted_cdf[hi];
            if z_hi == z_lo {
                p_hi
            } else {
                let t = (z - z_lo) / (z_hi - z_lo);
                p_lo + t * (p_hi - p_lo)
            }
        };
        standard_normal_quantile(p).unwrap_or_else(|_| if p < 0.5 { -8.0 } else { 8.0 })
    }
}
/// Which calibration (if any) the latent-measure builder attached to z.
#[derive(Clone, Debug)]
pub enum LatentMeasureCalibration {
    None,
    RankInverseNormal(LatentZRankIntCalibration),
}
/// Resolves the policy's `LatentMeasureSpec` into a concrete measure plus an
/// optional calibration.
///
/// In `Auto` mode, z that passes the normality gate keeps the analytic
/// standard-normal measure; otherwise a rank-INT calibration is fitted and
/// the (calibrated) measure stays standard normal. `data` and `specs` are
/// accepted for geometry-aware selection but currently unused — the Auto
/// branch discards them explicitly.
fn build_latent_measure_with_geometry(
    z: &Array1<f64>,
    weights: &Array1<f64>,
    policy: &LatentZPolicy,
    data: Option<ArrayView2<'_, f64>>,
    specs: &[&TermCollectionSpec],
) -> Result<(LatentMeasureKind, LatentMeasureCalibration), String> {
    match policy.latent_measure {
        LatentMeasureSpec::Auto { grid_size: _ } => {
            if latent_z_is_standard_normal_enough(z, weights, policy)? {
                Ok((
                    LatentMeasureKind::StandardNormal,
                    LatentMeasureCalibration::None,
                ))
            } else {
                let calibration = LatentZRankIntCalibration::fit(z, weights)?;
                log::info!(
                    "[BMS latent-z] rank-INT calibrated: post_mean={:.3e} post_sd={:.3e} knots={}",
                    calibration.post_mean,
                    calibration.post_sd,
                    calibration.sorted_z.len(),
                );
                // Geometry inputs are reserved for a future local-empirical
                // path; silence the unused-parameter warnings here.
                let _ = data;
                let _ = specs;
                Ok((
                    LatentMeasureKind::StandardNormal,
                    LatentMeasureCalibration::RankInverseNormal(calibration),
                ))
            }
        }
        LatentMeasureSpec::StandardNormal => Ok((
            LatentMeasureKind::StandardNormal,
            LatentMeasureCalibration::None,
        )),
        LatentMeasureSpec::GlobalEmpirical { grid_size } => {
            let kind = build_global_empirical_latent_measure(z, weights, grid_size)?;
            Ok((kind, LatentMeasureCalibration::None))
        }
    }
}
/// Weighted normality gate for the `Auto` latent measure: returns `Ok(true)`
/// only when the weighted sample passes every diagnostic — mean and sd close
/// to (0, 1) within effective-sample-size tolerances, small skew and excess
/// kurtosis, small weighted KS distance to N(0, 1), tail mass at |z| > 4 and
/// |z| > 6 not far above the normal tails, and max |z| below a hard cap.
/// Errors on invalid inputs; returns `Ok(false)` for degenerate moments.
fn latent_z_is_standard_normal_enough(
    z: &Array1<f64>,
    weights: &Array1<f64>,
    policy: &LatentZPolicy,
) -> Result<bool, String> {
    if z.len() != weights.len() {
        return Err(format!(
            "latent-measure auto-detection length mismatch: z={}, weights={}",
            z.len(),
            weights.len()
        ));
    }
    let weight_sum = weights.iter().copied().sum::<f64>();
    let weight_sq_sum = weights.iter().map(|&w| w * w).sum::<f64>();
    if !(weight_sum.is_finite()
        && weight_sum > 0.0
        && weight_sq_sum.is_finite()
        && weight_sq_sum > 0.0)
    {
        return Err("latent-measure auto-detection requires positive finite weights".to_string());
    }
    // Kish effective sample size: (sum w)^2 / sum w^2.
    let effective_n = weight_sum * weight_sum / weight_sq_sum;
    if !(effective_n.is_finite() && effective_n > 1.0) {
        return Err(
            "latent-measure auto-detection requires at least two effective observations"
                .to_string(),
        );
    }
    // Weighted first and second central moments.
    let mean = z
        .iter()
        .zip(weights.iter())
        .map(|(&zi, &wi)| wi * zi)
        .sum::<f64>()
        / weight_sum;
    let var = z
        .iter()
        .zip(weights.iter())
        .map(|(&zi, &wi)| wi * (zi - mean) * (zi - mean))
        .sum::<f64>()
        / weight_sum;
    let sd = var.sqrt();
    if !(mean.is_finite() && sd.is_finite() && sd > 0.0) {
        return Ok(false);
    }
    // Weighted standardized third and fourth moments.
    let skew = z
        .iter()
        .zip(weights.iter())
        .map(|(&zi, &wi)| {
            let centered = (zi - mean) / sd;
            wi * centered.powi(3)
        })
        .sum::<f64>()
        / weight_sum;
    let excess_kurtosis = z
        .iter()
        .zip(weights.iter())
        .map(|(&zi, &wi)| {
            let centered = (zi - mean) / sd;
            wi * centered.powi(4)
        })
        .sum::<f64>()
        / weight_sum
        - 3.0;
    // Policy multipliers scale the asymptotic standard errors of the mean
    // (1/sqrt(n)) and sd (1/sqrt(2(n-1))).
    let mean_tol = policy.mean_tol_multiplier / effective_n.sqrt();
    let sd_tol = policy.sd_tol_multiplier / (2.0 * (effective_n - 1.0).max(1.0)).sqrt();
    let ks_to_normal = weighted_ks_to_standard_normal(z, weights, weight_sum)?;
    let tail_mass_4 = weighted_tail_mass(z, weights, weight_sum, 4.0);
    let tail_mass_6 = weighted_tail_mass(z, weights, weight_sum, 6.0);
    let max_abs_z = z.iter().fold(0.0_f64, |acc, &zi| acc.max(zi.abs()));
    // Two-sided normal tail probabilities beyond 4 and 6 sd for comparison.
    let normal_tail_4 = 2.0 * (1.0 - normal_cdf(4.0));
    let normal_tail_6 = 2.0 * (1.0 - normal_cdf(6.0));
    // All gates must pass; skew/kurtosis use the tighter of the policy bound
    // and the fixed AUTO_* tolerance.
    Ok(mean.abs() <= mean_tol
        && (sd - 1.0).abs() <= sd_tol
        && skew.is_finite()
        && skew.abs() <= policy.max_abs_skew.min(AUTO_Z_NORMAL_SKEW_TOL)
        && excess_kurtosis.is_finite()
        && excess_kurtosis.abs() <= policy.max_abs_excess_kurtosis.min(AUTO_Z_NORMAL_KURT_TOL)
        && ks_to_normal.is_finite()
        && ks_to_normal <= AUTO_Z_NORMAL_KS_TOL
        && tail_mass_4 <= 2.0 * normal_tail_4 + 1e-5
        && tail_mass_6 <= 2.0 * normal_tail_6 + 1e-8
        && max_abs_z < AUTO_Z_NORMAL_MAX_ABS)
}
fn build_global_empirical_latent_measure(
z: &Array1<f64>,
weights: &Array1<f64>,
grid_size: usize,
) -> Result<LatentMeasureKind, String> {
let grid = build_empirical_z_grid(z, weights, grid_size, "empirical latent measure")?;
let measure = LatentMeasureKind::GlobalEmpirical {
nodes: grid.nodes,
weights: grid.weights,
};
measure.validate("empirical latent measure")?;
Ok(measure)
}
/// Weighted Kolmogorov–Smirnov distance between the empirical CDF of `z`
/// (weights normalized by `total_weight`) and the standard-normal CDF.
/// Zero-weight rows are ignored; non-finite z or invalid weights error.
fn weighted_ks_to_standard_normal(
    z: &Array1<f64>,
    weights: &Array1<f64>,
    total_weight: f64,
) -> Result<f64, String> {
    // Keep only positive-weight rows after validating every entry.
    let mut pairs = Vec::<(f64, f64)>::with_capacity(z.len());
    for (&zi, &wi) in z.iter().zip(weights.iter()) {
        if !zi.is_finite() || !wi.is_finite() || wi < 0.0 {
            return Err(
                "latent-measure KS diagnostic requires finite z and non-negative finite weights"
                    .to_string(),
            );
        }
        if wi > 0.0 {
            pairs.push((zi, wi));
        }
    }
    pairs.sort_by(|a, b| {
        a.0.partial_cmp(&b.0)
            .expect("validated latent z values are finite")
    });
    // The empirical CDF is a step function; the supremum distance to the
    // continuous normal CDF is attained on one side of some jump, so check
    // both sides at every sorted point.
    let mut below = 0.0_f64;
    let mut worst = 0.0_f64;
    for (zi, wi) in pairs {
        let phi = normal_cdf(zi);
        let above = below + wi / total_weight;
        worst = worst.max((phi - below).abs()).max((phi - above).abs());
        below = above;
    }
    Ok(worst)
}
/// Fraction of the total weight carried by observations with |z| > cutoff.
fn weighted_tail_mass(
    z: &Array1<f64>,
    weights: &Array1<f64>,
    total_weight: f64,
    cutoff: f64,
) -> f64 {
    let mut tail = 0.0_f64;
    for (&zi, &wi) in z.iter().zip(weights.iter()) {
        if zi.abs() > cutoff {
            tail += wi;
        }
    }
    tail / total_weight
}
/// Compresses a weighted z sample into at most `grid_size` equal-weight bins.
///
/// Rows are validated (finite z, finite non-negative weights), zero-weight
/// rows dropped, and the remainder sorted by z. Each bin absorbs an equal
/// share of the total weight — splitting an observation's weight across bin
/// boundaries when needed — and contributes its weighted-mean z as a node.
/// The node set is then recentered/rescaled to mean 0, sd 1, renormalized,
/// and validated. Errors are prefixed with `context`.
fn build_empirical_z_grid(
    z: &Array1<f64>,
    weights: &Array1<f64>,
    grid_size: usize,
    context: &str,
) -> Result<EmpiricalZGrid, String> {
    if grid_size < 3 {
        return Err(format!(
            "empirical latent measure grid_size must be at least 3, got {grid_size}"
        ));
    }
    if z.len() != weights.len() {
        return Err(format!(
            "{context} length mismatch: z={}, weights={}",
            z.len(),
            weights.len()
        ));
    }
    let mut pairs = Vec::<(f64, f64)>::with_capacity(z.len());
    for (idx, (&zi, &wi)) in z.iter().zip(weights.iter()).enumerate() {
        if !zi.is_finite() {
            return Err(format!(
                "{context} z value at row {idx} is non-finite ({zi})"
            ));
        }
        if !wi.is_finite() || wi < 0.0 {
            return Err(format!(
                "{context} weight at row {idx} must be finite and non-negative, got {wi}"
            ));
        }
        if wi > 0.0 {
            pairs.push((zi, wi));
        }
    }
    if pairs.len() < 2 {
        return Err(format!(
            "{context} requires at least two positive-weight rows"
        ));
    }
    pairs.sort_by(|left, right| {
        left.0
            .partial_cmp(&right.0)
            .expect("validated empirical latent z values are finite")
    });
    let total_weight = pairs.iter().map(|(_, weight)| *weight).sum::<f64>();
    if !(total_weight.is_finite() && total_weight > 0.0) {
        return Err(format!("{context} requires positive finite total weight"));
    }
    // Never produce more bins than observations.
    let m = grid_size.min(pairs.len());
    let mut nodes = Vec::with_capacity(m);
    let mut out_weights = Vec::with_capacity(m);
    let bin_weight_target = total_weight / (m as f64);
    // `cursor` walks the sorted pairs; `remaining` is the unconsumed weight
    // of the current observation (observations may straddle bin boundaries).
    let mut cursor = 0usize;
    let mut remaining = pairs[0].1;
    for _ in 0..m {
        let mut need = bin_weight_target;
        let mut bin_weight = 0.0;
        let mut bin_sum = 0.0;
        // Fill the bin up to the target weight (with a relative tolerance),
        // taking partial weight from the current observation as needed.
        while need > 1e-14 * bin_weight_target && cursor < pairs.len() {
            let take = remaining.min(need);
            bin_sum += take * pairs[cursor].0;
            bin_weight += take;
            need -= take;
            remaining -= take;
            // Advance once this observation's weight is (numerically) spent.
            if remaining <= 1e-14 * pairs[cursor].1 {
                cursor += 1;
                if cursor < pairs.len() {
                    remaining = pairs[cursor].1;
                }
            }
        }
        if bin_weight > 0.0 {
            // Node = weighted mean of the z mass captured by this bin.
            nodes.push(bin_sum / bin_weight);
            out_weights.push(bin_weight / total_weight);
        }
    }
    if nodes.len() < 2 {
        return Err(format!(
            "{context} compression produced fewer than two nodes"
        ));
    }
    recenter_rescale_empirical_grid(&mut nodes, &out_weights);
    // Renormalize to absorb any round-off in the bin weights.
    let total = out_weights.iter().sum::<f64>();
    if total.is_finite() && total > 0.0 {
        for weight in &mut out_weights {
            *weight /= total;
        }
    }
    validate_empirical_z_grid(&nodes, &out_weights, context)?;
    Ok(EmpiricalZGrid {
        nodes,
        weights: out_weights,
    })
}
/// Standardizes grid nodes in place to weighted mean 0 and sd 1. A no-op
/// when the total weight is non-positive/non-finite or the weighted sd is
/// degenerate (<= 1e-12).
fn recenter_rescale_empirical_grid(nodes: &mut [f64], weights: &[f64]) {
    let mass: f64 = weights.iter().sum();
    if !(mass.is_finite() && mass > 0.0) {
        return;
    }
    let weighted_sum: f64 = nodes
        .iter()
        .zip(weights)
        .map(|(&node, &weight)| weight * node)
        .sum();
    let mean = weighted_sum / mass;
    let weighted_sq_dev: f64 = nodes
        .iter()
        .zip(weights)
        .map(|(&node, &weight)| weight * (node - mean).powi(2))
        .sum();
    let sd = (weighted_sq_dev / mass).sqrt();
    if sd.is_finite() && sd > 1e-12 {
        for node in nodes.iter_mut() {
            *node = (*node - mean) / sd;
        }
    }
}
/// Internal family object driving the Bernoulli marginal-slope fit: data,
/// designs, latent measure, optional deviation runtimes, and the shared
/// caches used across outer-optimization evaluations.
#[derive(Clone)]
struct BernoulliMarginalSlopeFamily {
    y: Arc<Array1<f64>>,
    weights: Arc<Array1<f64>>,
    z: Arc<Array1<f64>>,
    latent_measure: LatentMeasureKind,
    gaussian_frailty_sd: Option<f64>,
    base_link: InverseLink,
    marginal_design: DesignMatrix,
    logslope_design: DesignMatrix,
    score_warp: Option<DeviationRuntime>,
    link_dev: Option<DeviationRuntime>,
    policy: crate::resource::ResourcePolicy,
    // Per-cell moment cache with hit/miss stats (see `exact_kernel`).
    cell_moment_lru: Arc<exact_kernel::CellMomentLruCache>,
    cell_moment_cache_stats: Arc<exact_kernel::CellMomentCacheStats>,
    // Optional per-row intercept warm starts shared across evaluations.
    intercept_warm_starts: Option<Arc<BernoulliInterceptWarmStartCache>>,
    // Auto-subsampling phase bookkeeping — NOTE(review): consumers are not
    // visible in this chunk; confirm semantics against the outer fit loop.
    auto_subsample_phase_counter: Arc<std::sync::atomic::AtomicUsize>,
    auto_subsample_last_rho: Arc<Mutex<Option<Array1<f64>>>>,
    // Single-entry cache of the most recent shared evaluation, keyed by a
    // fingerprint of the block betas and subsample identity.
    shared_eval_cache: Arc<Mutex<Option<SharedEvalCacheEntry>>>,
}
/// Cache key for the shared evaluation cache: the block coefficient vectors
/// (compared bitwise) plus the pointer identity of the optional outer-score
/// subsample.
#[derive(Clone)]
struct SharedEvalCacheFingerprint {
    betas: Vec<Array1<f64>>,
    subsample: Option<usize>,
}
impl SharedEvalCacheFingerprint {
    /// Snapshots the current block betas and the subsample's Arc pointer.
    fn from_inputs(block_states: &[ParameterBlockState], options: &BlockwiseFitOptions) -> Self {
        Self {
            betas: block_states
                .iter()
                .map(|state| state.beta.clone())
                .collect(),
            subsample: options
                .outer_score_subsample
                .as_ref()
                .map(|shared| Arc::as_ptr(shared) as usize),
        }
    }
    /// True when the subsample identity matches and every beta vector is
    /// bitwise-identical (exact f64 bit comparison, so NaN-safe).
    fn matches(&self, other: &Self) -> bool {
        self.subsample == other.subsample
            && self.betas.len() == other.betas.len()
            && self.betas.iter().zip(other.betas.iter()).all(|(a, b)| {
                a.len() == b.len()
                    && a.iter()
                        .zip(b.iter())
                        .all(|(x, y)| x.to_bits() == y.to_bits())
            })
    }
}
/// One cached shared evaluation: the fingerprint it was computed under and
/// the resulting exact-eval cache.
struct SharedEvalCacheEntry {
    fingerprint: SharedEvalCacheFingerprint,
    cache: Arc<BernoulliMarginalSlopeExactEvalCache>,
}
/// Hash-based fingerprint of an evaluation state: beta bits, block count,
/// subsample mask, and the relevant fit options.
#[derive(Clone, Debug, PartialEq, Eq)]
struct CacheFingerprint {
    beta_hash: u64,
    block_count: usize,
    subsample_mask_hash: u64,
    options_hash: u64,
}
impl CacheFingerprint {
    /// Hashes the current betas (by f64 bit pattern), the optional subsample
    /// mask, and the cache-relevant `BlockwiseFitOptions` fields into three
    /// independent 64-bit digests. `None` mask/options hash to fixed values
    /// so absent inputs still produce a stable fingerprint.
    fn compute(
        block_states: &[ParameterBlockState],
        subsample_mask: Option<&[usize]>,
        options: Option<&BlockwiseFitOptions>,
    ) -> Self {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let mut beta_hasher = DefaultHasher::new();
        for st in block_states {
            for v in st.beta.iter() {
                // Bit-level hash: distinguishes -0.0/0.0 and NaN payloads.
                v.to_bits().hash(&mut beta_hasher);
            }
        }
        let mut mask_hasher = DefaultHasher::new();
        if let Some(mask) = subsample_mask {
            mask.hash(&mut mask_hasher);
        } else {
            0u64.hash(&mut mask_hasher);
        }
        let mut opts_hasher = DefaultHasher::new();
        if let Some(opts) = options {
            opts.inner_max_cycles.hash(&mut opts_hasher);
            opts.inner_tol.to_bits().hash(&mut opts_hasher);
            opts.outer_max_iter.hash(&mut opts_hasher);
            opts.outer_tol.to_bits().hash(&mut opts_hasher);
            opts.minweight.to_bits().hash(&mut opts_hasher);
            opts.ridge_floor.to_bits().hash(&mut opts_hasher);
            opts.use_remlobjective.hash(&mut opts_hasher);
            opts.use_outer_hessian.hash(&mut opts_hasher);
            opts.compute_covariance.hash(&mut opts_hasher);
            opts.line_search_prefer_workspace.hash(&mut opts_hasher);
            opts.early_exit_threshold
                .map(|v| v.to_bits())
                .hash(&mut opts_hasher);
            opts.auto_outer_subsample.hash(&mut opts_hasher);
            // Subsample participates by Arc pointer identity, not contents.
            opts.outer_score_subsample
                .as_ref()
                .map(|a| Arc::as_ptr(a) as usize)
                .hash(&mut opts_hasher);
        }
        Self {
            beta_hash: beta_hasher.finish(),
            block_count: block_states.len(),
            subsample_mask_hash: mask_hasher.finish(),
            options_hash: opts_hasher.finish(),
        }
    }
}
// Budget for phase 1 of outer-score auto-subsampling. NOTE(review): the
// consuming fit loop is outside this chunk — confirm the unit (iterations
// vs. evaluations) against `auto_subsample_phase_counter` usage.
const BMS_AUTO_SUBSAMPLE_PHASE1_BUDGET: usize = 12;
/// Warm-start record for a per-row intercept solve: the intercept found at
/// `primary_point`, plus its gradient with respect to the primary
/// coefficients so a first-order seed can be predicted at a nearby point.
#[derive(Clone)]
struct BernoulliInterceptPredictorWarmStart {
    intercept: f64,
    primary_point: Vec<f64>,
    intercept_primary_deriv: Vec<f64>,
}
/// Per-row warm-start storage: lock-free intercept bits (NaN = unset, see
/// `new_intercept_warm_start_cache`) plus mutex-guarded predictor records.
struct BernoulliInterceptWarmStartCache {
    intercepts: Vec<AtomicU64>,
    predictors: Vec<Mutex<Option<BernoulliInterceptPredictorWarmStart>>>,
}
impl std::ops::Deref for BernoulliInterceptWarmStartCache {
    type Target = [AtomicU64];
    /// Deref straight to the intercept slots, so existing call sites can
    /// index the cache like the plain atomic slice it replaced.
    fn deref(&self) -> &Self::Target {
        &self.intercepts
    }
}
impl BernoulliInterceptWarmStartCache {
    /// Predicts an intercept seed for `row` at `current_point` via a
    /// first-order Taylor step from the stored warm start:
    /// seed = intercept + grad . (current - stored).
    /// Returns `None` when there is no usable record (missing row, poisoned
    /// lock, dimension mismatch, or any non-finite quantity).
    fn predictor_seed(&self, row: usize, current_point: &[f64]) -> Option<f64> {
        let slot = self.predictors.get(row)?;
        let guard = slot.lock().ok()?;
        let warm = guard.as_ref()?.clone();
        let dim = current_point.len();
        if warm.primary_point.len() != dim
            || warm.intercept_primary_deriv.len() != dim
            || !warm.intercept.is_finite()
        {
            return None;
        }
        let correction = warm
            .intercept_primary_deriv
            .iter()
            .zip(current_point.iter().zip(warm.primary_point.iter()))
            .map(|(a_u, (new, old))| a_u * (new - old))
            .sum::<f64>();
        let seed = warm.intercept + correction;
        seed.is_finite().then_some(seed)
    }
    /// Records a solved intercept together with the point and gradient it was
    /// solved at. Silently ignores non-finite data, out-of-range rows, and
    /// poisoned locks (warm starts are best-effort).
    fn store_predictor(
        &self,
        row: usize,
        intercept: f64,
        primary_point: Vec<f64>,
        intercept_primary_deriv: Vec<f64>,
    ) {
        let all_finite = intercept.is_finite()
            && primary_point.iter().all(|value| value.is_finite())
            && intercept_primary_deriv.iter().all(|value| value.is_finite());
        if !all_finite {
            return;
        }
        if let Some(slot) = self.predictors.get(row) {
            if let Ok(mut guard) = slot.lock() {
                *guard = Some(BernoulliInterceptPredictorWarmStart {
                    intercept,
                    primary_point,
                    intercept_primary_deriv,
                });
            }
        }
    }
}
/// Allocates a warm-start cache for `n` rows: every intercept slot starts as
/// NaN bits (the "unset" sentinel) and every predictor slot as `None`.
fn new_intercept_warm_start_cache(n: usize) -> Arc<BernoulliInterceptWarmStartCache> {
    let nan_bits = f64::NAN.to_bits();
    let intercepts = (0..n).map(|_| AtomicU64::new(nan_bits)).collect();
    let predictors = (0..n).map(|_| Mutex::new(None)).collect();
    Arc::new(BernoulliInterceptWarmStartCache {
        intercepts,
        predictors,
    })
}
/// Optional per-block coefficient seeds used to warm-start a fit; any field
/// left `None` falls back to the default initialization.
#[derive(Clone, Default)]
struct ThetaHints {
    marginal_beta: Option<Array1<f64>>,
    logslope_beta: Option<Array1<f64>>,
    score_warp_beta: Option<Array1<f64>>,
    link_dev_beta: Option<Array1<f64>>,
}
/// Builds a score-warp deviation block where the same seed vector supplies
/// both the knot locations and the design evaluation points.
pub(crate) fn build_score_warp_deviation_block_from_seed(
    seed: &Array1<f64>,
    cfg: &DeviationBlockConfig,
) -> Result<DeviationPrepared, String> {
    build_deviation_block_from_knots_and_design_seed(seed, seed, cfg)
}
/// Half-width of the open probability interval used by the probit link maps.
const BERNOULLI_LINK_PROBABILITY_EPS: f64 = 1e-12;
/// Probit marginal link map evaluated at one eta: the clamped probability
/// `mu` with its first four eta-derivatives, and the internal probit score
/// `q = Phi^-1(mu)` with its first four eta-derivatives.
#[derive(Clone, Copy, Debug)]
pub(crate) struct BernoulliMarginalLinkMap {
    pub mu: f64,
    pub mu1: f64,
    pub mu2: f64,
    pub mu3: f64,
    pub mu4: f64,
    pub q: f64,
    pub q1: f64,
    pub q2: f64,
    pub q3: f64,
    pub q4: f64,
}
/// Clamps a probability into the open interval (eps, 1 - eps) so that probit
/// inversion stays finite.
#[inline]
fn clamp_bernoulli_link_probability(probability: f64) -> f64 {
    let lower = BERNOULLI_LINK_PROBABILITY_EPS;
    let upper = 1.0 - BERNOULLI_LINK_PROBABILITY_EPS;
    probability.clamp(lower, upper)
}
/// Inverts the probit link: returns the eta whose probit probability equals
/// `probability` (after clamping away from {0, 1}). Requires the probit base
/// link; errors mention `context`.
pub(crate) fn bernoulli_marginal_slope_eta_from_probability(
    base_link: &InverseLink,
    probability: f64,
    context: &str,
) -> Result<f64, String> {
    require_probit_marginal_slope_link(base_link, context)?;
    let clamped = clamp_bernoulli_link_probability(probability);
    match standard_normal_quantile(clamped) {
        Ok(eta) => Ok(eta),
        Err(e) => Err(format!(
            "{context} failed to invert probit probability {clamped}: {e}"
        )),
    }
}
/// Evaluates the probit marginal link map at `eta`.
///
/// Computes mu = Phi(eta) (clamped) with its first four derivatives, and the
/// internal score q(eta) = Phi^-1(mu) with its first four derivatives by
/// implicitly differentiating Phi(q) = mu, using phi'(q) = -q phi(q),
/// phi''(q) = (q^2 - 1) phi(q), phi'''(q) = -(q^3 - 3q) phi(q). In exact
/// arithmetic q = eta; the clamp makes the two differ in the extreme tails,
/// where all derivatives are reported as zero (saturated link).
pub(crate) fn bernoulli_marginal_link_map(
    base_link: &InverseLink,
    eta: f64,
) -> Result<BernoulliMarginalLinkMap, String> {
    require_probit_marginal_slope_link(base_link, "bernoulli marginal-slope")?;
    let raw_mu = normal_cdf(eta);
    let mu = clamp_bernoulli_link_probability(raw_mu);
    let q = standard_normal_quantile(mu).map_err(|e| {
        format!("bernoulli marginal-slope probit target inversion failed at mu={mu}: {e}")
    })?;
    // Saturated tails: freeze the map (all derivatives zero) once the raw
    // probability has hit the clamp boundary.
    if raw_mu <= BERNOULLI_LINK_PROBABILITY_EPS || raw_mu >= 1.0 - BERNOULLI_LINK_PROBABILITY_EPS {
        return Ok(BernoulliMarginalLinkMap {
            mu,
            mu1: 0.0,
            mu2: 0.0,
            mu3: 0.0,
            mu4: 0.0,
            q,
            q1: 0.0,
            q2: 0.0,
            q3: 0.0,
            q4: 0.0,
        });
    }
    let phi_eta = normal_pdf(eta);
    let phi_q = normal_pdf(q);
    if !phi_q.is_finite() || phi_q <= 0.0 {
        return Err(format!(
            "bernoulli marginal-slope internal probit density must be positive, got phi(q)={phi_q} at eta={eta}, q={q}"
        ));
    }
    // Derivatives of mu = Phi(eta): repeated differentiation of phi(eta).
    let mu1 = phi_eta;
    let mu2 = -eta * phi_eta;
    let mu3 = (eta * eta - 1.0) * phi_eta;
    let mu4 = -(eta.powi(3) - 3.0 * eta) * phi_eta;
    // Implicit differentiation of Phi(q(eta)) = mu(eta), solved order by
    // order for q1..q4.
    let q1 = mu1 / phi_q;
    let q2 = mu2 / phi_q + q * q1 * q1;
    let q3 = mu3 / phi_q + 3.0 * q * q1 * q2 - (q * q - 1.0) * q1.powi(3);
    let q4 =
        mu4 / phi_q + (q.powi(3) - 3.0 * q) * q1.powi(4) + 4.0 * q * q1 * q3 + 3.0 * q * q2 * q2
            - 6.0 * (q * q - 1.0) * q1 * q1 * q2;
    Ok(BernoulliMarginalLinkMap {
        mu,
        mu1,
        mu2,
        mu3,
        mu4,
        q,
        q1,
        q2,
        q3,
        q4,
    })
}
/// Guard: the marginal-slope kernel is only derived for the probit link.
/// Any other base link is rejected with a `context`-prefixed error.
fn require_probit_marginal_slope_link(
    base_link: &InverseLink,
    context: &str,
) -> Result<(), String> {
    match base_link {
        InverseLink::Standard(LinkFunction::Probit) => Ok(()),
        _ => Err(format!(
            "{context} requires link(type=probit); non-probit marginal-slope base links are not supported by the calibrated de-nested probit kernel"
        )),
    }
}
/// Builds a link-deviation block from separate knot and design seeds.
///
/// NOTE(review): `_anchor_weights` is accepted but currently unused — the
/// parameter presumably anticipates weighted anchoring; confirm before
/// relying on the weights having any effect here.
pub(crate) fn build_link_deviation_block_from_knots_design_seed_and_weights(
    knot_seed: &Array1<f64>,
    design_seed: &Array1<f64>,
    _anchor_weights: &Array1<f64>,
    cfg: &DeviationBlockConfig,
) -> Result<DeviationPrepared, String> {
    build_deviation_block_from_knots_and_design_seed(knot_seed, design_seed, cfg)
}
/// Core constructor for a structural deviation block: derives the monotone
/// cubic knot layout from `knot_seed`, evaluates the design at `design_seed`,
/// and attaches one integrated-derivative penalty per resolved order.
///
/// # Errors
/// Rejects non-cubic configurations, empty/invalid penalty-order lists, and
/// a degenerate basis with no free columns.
fn build_deviation_block_from_knots_and_design_seed(
    knot_seed: &Array1<f64>,
    design_seed: &Array1<f64>,
    cfg: &DeviationBlockConfig,
) -> Result<DeviationPrepared, String> {
    // The runtime below only implements the cubic case.
    if cfg.degree != 3 {
        return Err(format!(
            "structural deviation runtime is cubic; degree must be 3, got {}",
            cfg.degree
        ));
    }
    let penalty_orders = resolve_deviation_operator_orders(cfg)?;
    let knots = initialize_monotone_wiggle_knots_from_seed(
        knot_seed.view(),
        cfg.degree,
        cfg.num_internal_knots,
    )?;
    // `resolve_deviation_operator_orders` already guarantees a non-empty
    // positive list, so this error path is effectively unreachable.
    let max_penalty_order = penalty_orders.iter().copied().max().ok_or_else(|| {
        "deviation block requires at least one positive function-penalty derivative order"
            .to_string()
    })?;
    let runtime = DeviationRuntime::try_new(knots, cfg.monotonicity_eps, max_penalty_order)?;
    let design = runtime.design(design_seed)?;
    let p = design.ncols();
    if p == 0 {
        return Err("structural deviation basis has no free derivative controls".to_string());
    }
    let mut block = ParameterBlockInput {
        design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(design)),
        offset: Array1::zeros(design_seed.len()),
        penalties: Vec::new(),
        nullspace_dims: Vec::new(),
        initial_log_lambdas: None,
        // Coefficients start at zero (no deviation).
        initial_beta: Some(Array1::zeros(p)),
    };
    for order in penalty_orders {
        append_deviation_function_penalty(&mut block, &runtime, order)?;
    }
    // Optional order-0 penalty: shrinks the whole block toward zero
    // (double-penalty construction).
    if cfg.double_penalty {
        append_deviation_function_penalty(&mut block, &runtime, 0)?;
    }
    Ok(DeviationPrepared { block, runtime })
}
/// Resolves the configured penalty derivative orders into a deduplicated,
/// order-preserving list of positive orders.
///
/// Falls back to the single legacy `penalty_order` when `penalty_orders` is
/// empty; silently drops zeros; rejects orders above the basis degree and an
/// all-zero request.
fn resolve_deviation_operator_orders(cfg: &DeviationBlockConfig) -> Result<Vec<usize>, String> {
    let requested: &[usize] = if cfg.penalty_orders.is_empty() {
        std::slice::from_ref(&cfg.penalty_order)
    } else {
        &cfg.penalty_orders
    };
    let mut resolved: Vec<usize> = Vec::new();
    // Zeros are skipped (an order-0 penalty is handled separately via the
    // double-penalty flag); duplicates keep their first position.
    for &order in requested.iter().filter(|&&order| order != 0) {
        if order > cfg.degree {
            return Err(format!(
                "deviation function penalty derivative order {order} exceeds basis degree {}",
                cfg.degree
            ));
        }
        if !resolved.contains(&order) {
            resolved.push(order);
        }
    }
    match resolved.is_empty() {
        true => Err(
            "deviation block requires at least one positive function-penalty derivative order"
                .to_string(),
        ),
        false => Ok(resolved),
    }
}
/// Appends one integrated-derivative penalty of the given order — together
/// with its null-space dimension — to `block`.
fn append_deviation_function_penalty(
    block: &mut ParameterBlockInput,
    runtime: &DeviationRuntime,
    derivative_order: usize,
) -> Result<(), String> {
    let (penalty, nullity) =
        runtime.integrated_derivative_penalty_with_nullity(derivative_order)?;
    block
        .penalties
        .push(crate::solver::estimate::PenaltySpec::Dense(penalty));
    block.nullspace_dims.push(nullity);
    Ok(())
}
/// A design against which a candidate flex block must be identifiable:
/// either a parametric design matrix or a dense row evaluation of another
/// flex block.
pub(crate) enum CrossBlockAnchor<'a> {
    Parametric(&'a DesignMatrix),
    FlexEvaluation(&'a Array2<f64>),
}
/// Result of the cross-block identifiability pass for one flex block.
#[derive(Debug)]
pub(crate) enum CrossBlockIdentifiabilityOutcome {
    /// The block was rotated onto the directions the anchors cannot
    /// reproduce; `kept` + `dropped` equals the original column count.
    Reparameterised {
        #[allow(dead_code)]
        kept: usize,
        #[allow(dead_code)]
        dropped: usize,
    },
    /// Every direction of the block is numerically reproducible by the
    /// anchors; `reason` carries the diagnostic message.
    FullyAliased {
        reason: String,
    },
}
/// User-facing warning emitted when a flex block is dropped or reparameterised
/// by the cross-block identifiability pass.
#[derive(Clone, Debug)]
pub struct CrossBlockIdentifiabilityWarning {
    // Static label of the affected candidate block.
    pub candidate_label: &'static str,
    // Human-readable description of the anchors involved.
    pub anchor_summary: String,
    // Why the candidate was flagged.
    pub reason: String,
}
/// Makes a candidate flex (deviation) block identifiable against a set of
/// anchor designs by removing the candidate directions that the anchors'
/// parametric columns already reproduce at the training rows.
///
/// Procedure (all in the √weight row metric):
/// 1. stack the parametric anchor columns into `N_train`;
/// 2. eigendecompose G_N = NᵀWN, rank-reveal it, and whiten the non-null
///    part into an orthonormal column basis Q;
/// 3. residualise the candidate design, C̃ = C − Q QᵀC, and eigendecompose
///    G_C̃ = C̃ᵀWC̃ to find the surviving directions;
/// 4. if none survive, report `FullyAliased`; otherwise compose the selecting
///    rotation (plus the anchor-space residual map) into the candidate
///    runtime and rebuild its design, penalties, and warm start in place.
///
/// Flex anchors are only row-count-validated here; they contribute no
/// columns to the null-space union (only `Parametric` anchors do).
pub(crate) fn enforce_cross_block_identifiability_for_flex_block(
    candidate: &mut DeviationPrepared,
    candidate_arg_at_training_rows: &Array1<f64>,
    candidate_cfg: &DeviationBlockConfig,
    anchors: &[CrossBlockAnchor<'_>],
    parametric_anchor_blocks: &[Option<
        crate::families::bernoulli_marginal_slope::deviation_runtime::ParametricAnchorBlock,
    >],
    training_row_weights: &Array1<f64>,
) -> Result<CrossBlockIdentifiabilityOutcome, String> {
    use crate::faer_ndarray::FaerEigh;
    use crate::families::bernoulli_marginal_slope::deviation_runtime::{
        AnchorNullSpaceComponent, AnchorNullSpaceEvaluator, AnchorResidual,
    };
    if parametric_anchor_blocks.len() != anchors.len() {
        return Err(format!(
            "cross-block identifiability: parametric_anchor_blocks length {} does not match anchors length {}",
            parametric_anchor_blocks.len(),
            anchors.len(),
        ));
    }
    let candidate_design = candidate.runtime.design(candidate_arg_at_training_rows)?;
    let n = candidate_design.nrows();
    let p_candidate = candidate_design.ncols();
    // An empty candidate basis is trivially "reparameterised" (nothing to do).
    if p_candidate == 0 {
        return Ok(CrossBlockIdentifiabilityOutcome::Reparameterised {
            kept: 0,
            dropped: 0,
        });
    }
    if training_row_weights.len() != n {
        return Err(format!(
            "cross-block identifiability: training_row_weights length {} does not match candidate row count {}",
            training_row_weights.len(),
            n,
        ));
    }
    // Collect the dense parametric anchor blocks and their provenance tags.
    let mut anchor_dense_blocks: Vec<ndarray::Array2<f64>> = Vec::new();
    let mut anchor_components: Vec<AnchorNullSpaceComponent> = Vec::new();
    let mut total_parametric_cols = 0usize;
    for (anchor_idx, anchor) in anchors.iter().enumerate() {
        match anchor {
            CrossBlockAnchor::FlexEvaluation(a) => {
                // Row-count sanity check only; flex anchors add no columns.
                if a.nrows() != n {
                    return Err(format!(
                        "cross-block identifiability: flex anchor has {} rows, candidate has {}",
                        a.nrows(),
                        n,
                    ));
                }
            }
            CrossBlockAnchor::Parametric(d) => {
                if d.nrows() != n {
                    return Err(format!(
                        "cross-block identifiability: parametric anchor has {} rows, candidate has {}",
                        d.nrows(),
                        n,
                    ));
                }
                let p_a = d.ncols();
                if p_a == 0 {
                    continue;
                }
                // Each parametric anchor must carry a provenance tag so the
                // residual evaluator can reconstruct its columns later.
                let block_tag = parametric_anchor_blocks[anchor_idx].ok_or_else(|| {
                    format!(
                        "cross-block identifiability: anchor {anchor_idx} is Parametric but parametric_anchor_blocks tag is None",
                    )
                })?;
                let dense = d.try_to_dense_arc("cross-block anchor")?.as_ref().clone();
                anchor_dense_blocks.push(dense);
                anchor_components.push(AnchorNullSpaceComponent::Parametric {
                    block: block_tag,
                    ncols: p_a,
                });
                total_parametric_cols += p_a;
            }
        }
    }
    // No parametric columns: nothing can alias the candidate; keep it whole.
    if total_parametric_cols == 0 {
        return Ok(CrossBlockIdentifiabilityOutcome::Reparameterised {
            kept: p_candidate,
            dropped: 0,
        });
    }
    // Stack all parametric anchor columns side by side into N_train (n × d).
    let d_total = total_parametric_cols;
    let mut n_train = Array2::<f64>::zeros((n, d_total));
    {
        let mut col_offset = 0usize;
        for block in &anchor_dense_blocks {
            let bc = block.ncols();
            n_train
                .slice_mut(ndarray::s![.., col_offset..col_offset + bc])
                .assign(block);
            col_offset += bc;
        }
    }
    // √weight per row; weights must be finite and non-negative.
    let mut sqrt_w = Array1::<f64>::zeros(n);
    for (i, &w) in training_row_weights.iter().enumerate() {
        if !w.is_finite() || w < 0.0 {
            return Err(format!(
                "cross-block identifiability: training_row_weights[{i}] = {w} is not finite/non-negative",
            ));
        }
        sqrt_w[i] = w.sqrt();
    }
    // Row-scale both matrices by √w so plain Gram products realise the
    // weighted metric.
    let n_train_sqw = {
        let mut out = n_train.clone();
        for i in 0..n {
            let s = sqrt_w[i];
            for j in 0..d_total {
                out[[i, j]] *= s;
            }
        }
        out
    };
    let c_sqw = {
        let mut out = candidate_design.clone();
        for i in 0..n {
            let s = sqrt_w[i];
            for j in 0..p_candidate {
                out[[i, j]] *= s;
            }
        }
        out
    };
    // Rank-reveal the anchor Gram matrix G_N = NᵀWN.
    let g_n = n_train_sqw.t().dot(&n_train_sqw);
    let (g_n_evals, g_n_evecs) = g_n
        .eigh(faer::Side::Lower)
        .map_err(|e| format!("cross-block identifiability G_N eigh failed: {e}"))?;
    let g_n_evals_slice = g_n_evals
        .as_slice()
        .ok_or_else(|| "G_N eigenvalues not contiguous".to_string())?;
    let lambda_max_n = g_n_evals_slice.iter().copied().fold(0.0_f64, f64::max);
    let n_train_f = n as f64;
    // Relative eigenvalue cutoff scaled by row count and machine epsilon.
    let g_n_threshold = lambda_max_n * (64.0 * n_train_f.max(1.0) * f64::EPSILON);
    let g_n_kept: Vec<usize> = g_n_evals_slice
        .iter()
        .enumerate()
        .filter_map(|(idx, &v)| (v > g_n_threshold).then_some(idx))
        .collect();
    let r = g_n_kept.len();
    // Anchors are numerically rank-zero under the weights: no aliasing possible.
    if r == 0 {
        return Ok(CrossBlockIdentifiabilityOutcome::Reparameterised {
            kept: p_candidate,
            dropped: 0,
        });
    }
    // Whitening rotation: columns eᵢ/√λᵢ so Q = (√W N)·R has orthonormal columns.
    let mut rotation_r = Array2::<f64>::zeros((d_total, r));
    for (out_col, &in_col) in g_n_kept.iter().enumerate() {
        let lam = g_n_evals_slice[in_col];
        let inv_sqrt = 1.0 / lam.sqrt();
        let evec = g_n_evecs.column(in_col);
        for i in 0..d_total {
            rotation_r[[i, out_col]] = evec[i] * inv_sqrt;
        }
    }
    let q_w_sqw = n_train_sqw.dot(&rotation_r);
    // K = QᵀC holds the anchor-space coordinates of each candidate column.
    let k_w = q_w_sqw.t().dot(&c_sqw);
    // Residualised candidate: C̃ = (I − QQᵀ)C in the weighted metric.
    let c_tilde_sqw = &c_sqw - &q_w_sqw.dot(&k_w);
    let g_c_tilde = c_tilde_sqw.t().dot(&c_tilde_sqw);
    let (gc_evals, gc_evecs) = g_c_tilde
        .eigh(faer::Side::Lower)
        .map_err(|e| format!("cross-block identifiability G_C̃ eigh failed: {e}"))?;
    let gc_evals_slice = gc_evals
        .as_slice()
        .ok_or_else(|| "G_C̃ eigenvalues not contiguous".to_string())?;
    let lambda_max_c = gc_evals_slice.iter().copied().fold(0.0_f64, f64::max);
    // Reference the drop tolerance to the pre-residualisation energy too, so
    // directions that were large before projection are not kept by accident.
    let c_sqw_norm_sq: f64 = c_sqw.iter().map(|v| v * v).sum();
    let lambda_max_ref = lambda_max_c.max(c_sqw_norm_sq);
    let drop_tol = lambda_max_ref * (64.0 * n_train_f.max(1.0) * f64::EPSILON);
    let pos_indices: Vec<usize> = gc_evals_slice
        .iter()
        .enumerate()
        .filter_map(|(idx, &v)| (v > drop_tol).then_some(idx))
        .collect();
    let k_kept = pos_indices.len();
    // Nothing survives: every candidate direction is in the anchors' span.
    if k_kept == 0 {
        let reason = format!(
            "candidate flex basis ({p_c} cols) has zero directions remaining after \
             residualisation against the unpenalised null-space of the anchor union \
             ({d_local} parametric anchor cols, weighted training-row rank {r}). The \
             residualised basis (I - P_A) C has numerical rank zero at the {n} training \
             rows under the joint-Hessian row metric, so every direction in span(C) is \
             reproducible by the parametric anchors up to numerical tolerance \
             {drop_tol:.3e}. The candidate flex block carries no information the \
             parametric blocks do not already capture in their unpenalised span; \
             disable this flex block or drop a parametric term that exactly reproduces \
             the flex argument. Knot count is NOT the relevant lever for this failure \
             mode.",
            p_c = p_candidate,
            d_local = d_total,
            r = r,
            n = n,
            drop_tol = drop_tol,
        );
        return Ok(CrossBlockIdentifiabilityOutcome::FullyAliased { reason });
    }
    // Selector onto the surviving eigen-directions of the candidate.
    let mut v_selector = Array2::<f64>::zeros((p_candidate, k_kept));
    for (out_col, &in_col) in pos_indices.iter().enumerate() {
        v_selector
            .column_mut(out_col)
            .assign(&gc_evecs.column(in_col));
    }
    // Anchor-space coefficients of the projected-away part, mapped back to
    // raw anchor coordinates through the whitening rotation.
    let k_w_v = k_w.dot(&v_selector); let residual_coefficients = rotation_r.dot(&k_w_v);
    let null_basis_evaluator = AnchorNullSpaceEvaluator::Stacked {
        components: anchor_components,
        orthonormalising_rotation: Array2::<f64>::eye(d_total),
    };
    let residual = AnchorResidual {
        residual_coefficients,
        null_basis_evaluator,
    };
    // Install the reparameterisation into the candidate runtime and rebuild
    // the design/penalties/warm start from scratch in the new coordinates.
    candidate
        .runtime
        .set_anchor_rows_at_training(n_train.clone());
    candidate
        .runtime
        .compose_anchor_orthogonalisation(&v_selector, Some(residual))?;
    let new_design = candidate
        .runtime
        .design_at_training_with_residual(candidate_arg_at_training_rows)?;
    let new_p = new_design.ncols();
    debug_assert_eq!(new_p, k_kept);
    debug_assert_eq!(new_design.nrows(), n);
    candidate.block.design =
        DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(new_design));
    candidate.block.penalties.clear();
    candidate.block.nullspace_dims.clear();
    let penalty_orders = resolve_deviation_operator_orders(candidate_cfg)?;
    for order in penalty_orders {
        append_deviation_function_penalty(&mut candidate.block, &candidate.runtime, order)?;
    }
    if candidate_cfg.double_penalty {
        append_deviation_function_penalty(&mut candidate.block, &candidate.runtime, 0)?;
    }
    candidate.block.initial_beta = Some(Array1::zeros(new_p));
    log::info!(
        "[BMS cross-block identifiability] flex block reparameterised: \
         kept {kept}/{p_candidate} directions (parametric anchor cols={d_total}, \
         weighted_anchor_null_rank={r}, kept_after_rank_reveal={kept}, \
         dropped={dropped}, training rows={n}, λ_max(C̃ᵀWC̃)={lambda_max_c:.3e}, \
         drop tol={drop_tol:.3e})",
        kept = new_p,
        p_candidate = p_candidate,
        d_total = d_total,
        r = r,
        dropped = p_candidate - new_p,
        n = n,
        lambda_max_c = lambda_max_c,
        drop_tol = drop_tol,
    );
    Ok(CrossBlockIdentifiabilityOutcome::Reparameterised {
        kept: new_p,
        dropped: p_candidate - new_p,
    })
}
/// Projects `proposed` back into the monotonicity-feasible set by bisecting
/// along the segment from `current` (which must already be feasible) toward
/// `proposed`.
///
/// Returns `proposed` unchanged when it is already feasible; otherwise the
/// largest feasible point `current + t·(proposed − current)` found by 64
/// bisection steps on t ∈ [0, 1] (worst case t = 0, i.e. `current` itself).
///
/// # Errors
/// Length mismatches, non-finite coefficients, and an infeasible `current`
/// are all rejected before any projection is attempted.
pub(crate) fn project_monotone_feasible_beta(
    runtime: &DeviationRuntime,
    current: &Array1<f64>,
    proposed: &Array1<f64>,
    label: &str,
) -> Result<Array1<f64>, String> {
    if current.len() != runtime.basis_dim() {
        return Err(format!(
            "{label} monotone projection current length mismatch: current={}, expected={}",
            current.len(),
            runtime.basis_dim()
        ));
    }
    if proposed.len() != runtime.basis_dim() {
        return Err(format!(
            "{label} monotone projection length mismatch: proposed={}, expected={}",
            proposed.len(),
            runtime.basis_dim()
        ));
    }
    for (idx, value) in current.iter().enumerate() {
        if !value.is_finite() {
            return Err(format!("{label} current coefficient {idx} is non-finite"));
        }
    }
    for (idx, value) in proposed.iter().enumerate() {
        if !value.is_finite() {
            return Err(format!("{label} coefficient {idx} is non-finite"));
        }
    }
    // The bisection invariant below (lo feasible) requires `current` feasible.
    runtime.monotonicity_feasible(current, &format!("{label} current beta"))?;
    if runtime
        .monotonicity_feasible(proposed, &format!("{label} proposed beta"))
        .is_ok()
    {
        return Ok(proposed.clone());
    }
    let direction = proposed - current;
    // Invariant: current + lo·direction is feasible, current + hi·direction
    // is not; 64 halvings pin lo to machine precision.
    let mut lo = 0.0_f64;
    let mut hi = 1.0_f64;
    for _ in 0..64 {
        let mid = 0.5 * (lo + hi);
        let candidate = current + &direction.mapv(|value| value * mid);
        if runtime
            .monotonicity_feasible(&candidate, &format!("{label} projected beta"))
            .is_ok()
        {
            lo = mid;
        } else {
            hi = mid;
        }
    }
    Ok(current + &direction.mapv(|value| value * lo))
}
/// Validates a Bernoulli marginal-slope term spec against the data: all row
/// vectors must match `data`'s row count, `y` must be binary (to 1e-9),
/// weights finite and non-negative, all offsets and z finite, the base link
/// probit, and the frailty spec admissible for the marginal-slope family.
fn validate_spec(
    data: ArrayView2<'_, f64>,
    spec: &BernoulliMarginalSlopeTermSpec,
) -> Result<(), String> {
    let n = data.nrows();
    if spec.y.len() != n
        || spec.weights.len() != n
        || spec.z.len() != n
        || spec.marginal_offset.len() != n
        || spec.logslope_offset.len() != n
    {
        return Err(format!(
            "bernoulli-marginal-slope row mismatch: data={}, y={}, weights={}, z={}, marginal_offset={}, logslope_offset={}",
            n,
            spec.y.len(),
            spec.weights.len(),
            spec.z.len(),
            spec.marginal_offset.len(),
            spec.logslope_offset.len()
        ));
    }
    // Accept tiny numerical jitter around 0/1, nothing else.
    if spec
        .y
        .iter()
        .any(|&yi| !yi.is_finite() || ((yi - 0.0).abs() > 1e-9 && (yi - 1.0).abs() > 1e-9))
    {
        return Err("bernoulli-marginal-slope requires binary y in {0,1}".to_string());
    }
    if spec.weights.iter().any(|&w| !w.is_finite() || w < 0.0) {
        return Err("bernoulli-marginal-slope requires finite non-negative weights".to_string());
    }
    if spec.z.iter().any(|&zi| !zi.is_finite()) {
        return Err("bernoulli-marginal-slope requires finite z values".to_string());
    }
    if spec.marginal_offset.iter().any(|&value| !value.is_finite()) {
        return Err("bernoulli-marginal-slope requires finite marginal offsets".to_string());
    }
    if spec.logslope_offset.iter().any(|&value| !value.is_finite()) {
        return Err("bernoulli-marginal-slope requires finite logslope offsets".to_string());
    }
    require_probit_marginal_slope_link(&spec.base_link, "bernoulli-marginal-slope")?;
    spec.frailty.validate_for_marginal_slope()?;
    match &spec.frailty {
        FrailtySpec::None => {}
        FrailtySpec::GaussianShift { sigma_fixed } => {
            // A fixed sigma, if provided, must be finite and non-negative.
            if let Some(sigma) = sigma_fixed
                && (!sigma.is_finite() || *sigma < 0.0)
            {
                return Err(format!(
                    "bernoulli-marginal-slope requires GaussianShift sigma >= 0, got {sigma}"
                ));
            }
        }
        // `validate_for_marginal_slope` above already rejects this variant.
        FrailtySpec::HazardMultiplier { .. } => unreachable!(),
    }
    Ok(())
}
/// Standardizes the latent score vector `z` for identification, governed by
/// `policy`: computes weighted mean/sd, checks that z is already close to
/// latent N(0,1) (mean/sd within effective-sample-size tolerances), applies
/// the normalization chosen by `policy.normalization`, and then checks the
/// standardized scores for approximate Gaussianity via weighted skewness and
/// excess kurtosis.
///
/// Returns the standardized scores together with the normalization used.
/// Check failures are errors, warnings, or ignored depending on
/// `policy.check_mode`.
pub(crate) fn standardize_latent_z_with_policy(
    z: &Array1<f64>,
    weights: &Array1<f64>,
    context: &str,
    policy: &LatentZPolicy,
) -> Result<(Array1<f64>, LatentZNormalization), String> {
    if z.len() != weights.len() {
        return Err(format!(
            "{context} latent-score normalization length mismatch: z={}, weights={}",
            z.len(),
            weights.len()
        ));
    }
    let weight_sum = weights.iter().copied().sum::<f64>();
    let weight_sq_sum = weights.iter().map(|&w| w * w).sum::<f64>();
    if !(weight_sum.is_finite()
        && weight_sum > 0.0
        && weight_sq_sum.is_finite()
        && weight_sq_sum > 0.0)
    {
        return Err(format!("{context} requires positive finite total weight"));
    }
    // Kish effective sample size: (Σw)² / Σw².
    let effective_n = weight_sum * weight_sum / weight_sq_sum;
    if !(effective_n.is_finite() && effective_n > 1.0) {
        return Err(format!(
            "{context} requires at least two effective observations for latent-score normalization"
        ));
    }
    // Weighted first and second moments of z.
    let mean = z
        .iter()
        .zip(weights.iter())
        .map(|(&zi, &wi)| wi * zi)
        .sum::<f64>()
        / weight_sum;
    let var = z
        .iter()
        .zip(weights.iter())
        .map(|(&zi, &wi)| wi * (zi - mean) * (zi - mean))
        .sum::<f64>()
        / weight_sum;
    let sd = var.sqrt();
    if !(sd.is_finite() && sd > 1e-12) {
        return Err(format!(
            "{context} requires z with positive finite weighted standard deviation"
        ));
    }
    // The normalization actually applied depends on the policy mode; the
    // freshly computed moments are only used in FitWeighted mode.
    let target_norm = match policy.normalization {
        LatentZNormalizationMode::None => LatentZNormalization { mean: 0.0, sd: 1.0 },
        LatentZNormalizationMode::FitWeighted => LatentZNormalization { mean, sd },
        LatentZNormalizationMode::Frozen {
            mean: frozen_mean,
            sd: frozen_sd,
        } => LatentZNormalization {
            mean: frozen_mean,
            sd: frozen_sd,
        },
    };
    // Tolerances shrink with effective sample size: the sd tolerance uses the
    // asymptotic sd of a sample standard deviation, 1/√(2(n−1)).
    let mean_tol = policy.mean_tol_multiplier / effective_n.sqrt();
    let sd_tol = policy.sd_tol_multiplier / (2.0 * (effective_n - 1.0).max(1.0)).sqrt();
    let check_msg = || {
        format!(
            "{context} requires z to already be approximately latent N(0,1) before identification normalization; got mean={mean:.6e}, sd={sd:.6e}, effective_n={effective_n:.1}, allowed_mean={mean_tol:.3e}, allowed_sd={sd_tol:.3e}"
        )
    };
    if mean.abs() > mean_tol || (sd - 1.0).abs() > sd_tol {
        match policy.check_mode {
            LatentZCheckMode::Strict => return Err(check_msg()),
            LatentZCheckMode::WarnOnly => log::warn!("{}", check_msg()),
            LatentZCheckMode::Off => {}
        }
    }
    let normalization = target_norm;
    let z_std = normalization.apply(z, context)?;
    // Weighted third/fourth moments of the standardized scores.
    let skew = z_std
        .iter()
        .zip(weights.iter())
        .map(|(&zi, &wi)| wi * zi.powi(3))
        .sum::<f64>()
        / weight_sum;
    let kurt = z_std
        .iter()
        .zip(weights.iter())
        .map(|(&zi, &wi)| wi * zi.powi(4))
        .sum::<f64>()
        / weight_sum
        - 3.0;
    if skew.abs() > policy.max_abs_skew || kurt.abs() > policy.max_abs_excess_kurtosis {
        let msg = format!(
            "{context} requires z to be approximately Gaussian after identification normalization; got skewness={skew:.3}, excess_kurtosis={kurt:.3}"
        );
        match policy.check_mode {
            LatentZCheckMode::Strict => return Err(msg),
            LatentZCheckMode::WarnOnly => log::warn!("{}", msg),
            LatentZCheckMode::Off => {}
        }
    }
    // Advisory only: flags strong non-Gaussianity even when the policy check
    // above passed or was disabled.
    if skew.abs() > 0.75 || kurt.abs() > 2.0 {
        log::warn!(
            "{context}: z has skewness={skew:.3} and excess kurtosis={kurt:.3}; latent-measure auto-selection will use empirical calibration unless stricter diagnostics pass"
        );
    }
    Ok((z_std, normalization))
}
/// Extends a knot seed with two guard points beyond the observed range.
///
/// The pad width is `pad_fraction` times the (index-based) interquartile
/// range, floored at `min_iqr` so degenerate seeds still get a non-zero pad.
/// Seeds with fewer than four points are returned unchanged. The two guard
/// points are appended (not inserted in order).
pub fn padded_deviation_seed(seed: &Array1<f64>, min_iqr: f64, pad_fraction: f64) -> Array1<f64> {
    let mut ordered = seed.to_vec();
    // NaNs compare as Equal, leaving them wherever the sort puts them.
    ordered.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    let n = ordered.len();
    if n < 4 {
        return seed.clone();
    }
    let iqr = (ordered[3 * n / 4] - ordered[n / 4]).max(min_iqr);
    let pad = pad_fraction * iqr;
    let mut padded = seed.to_vec();
    padded.push(ordered[0] - pad);
    padded.push(ordered[n - 1] + pad);
    Array1::from_vec(padded)
}
/// Fits a pooled two-parameter probit model P(y=1) = Φ(β0 + β1·z) by damped
/// Newton with backtracking line search, and returns the pair
/// `(β0/√(1+β1²), β1)` — i.e. the intercept converted to the marginal scale
/// (consistent with the rigid map where the observed intercept equals the
/// marginal η times √(1+slope²)), together with the slope floored away from
/// zero at magnitude 1e-6 (sign-preserving).
///
/// Initialization is method-of-moments: β0 from the probit of the weighted
/// prevalence, β1 from the weighted cov(y,z)/var(z).
fn pooled_probit_baseline(
    y: &Array1<f64>,
    z: &Array1<f64>,
    weights: &Array1<f64>,
) -> Result<(f64, f64), String> {
    if y.len() != z.len() || y.len() != weights.len() {
        return Err(format!(
            "pooled bernoulli-marginal-slope pilot length mismatch: y={}, z={}, weights={}",
            y.len(),
            z.len(),
            weights.len()
        ));
    }
    let weight_sum = weights.iter().copied().sum::<f64>();
    if !weight_sum.is_finite() || weight_sum <= 0.0 {
        return Err(
            "pooled bernoulli-marginal-slope pilot requires positive finite total weight"
                .to_string(),
        );
    }
    let prevalence = y
        .iter()
        .zip(weights.iter())
        .map(|(&yi, &wi)| yi * wi)
        .sum::<f64>()
        / weight_sum;
    // Keep the probit of the prevalence finite for all-0/all-1 samples.
    let prevalence = prevalence.clamp(1e-6, 1.0 - 1e-6);
    let z_mean = z
        .iter()
        .zip(weights.iter())
        .map(|(&zi, &wi)| zi * wi)
        .sum::<f64>()
        / weight_sum;
    let z_var = z
        .iter()
        .zip(weights.iter())
        .map(|(&zi, &wi)| wi * (zi - z_mean) * (zi - z_mean))
        .sum::<f64>()
        / weight_sum;
    let yz_cov = y
        .iter()
        .zip(z.iter())
        .zip(weights.iter())
        .map(|((&yi, &zi), &wi)| wi * (yi - prevalence) * (zi - z_mean))
        .sum::<f64>()
        / weight_sum;
    let mut beta0 = standard_normal_quantile(prevalence).map_err(|e| {
        format!("failed to initialize pooled bernoulli-marginal-slope pilot intercept: {e}")
    })?;
    let mut beta1 = if z_var > 1e-12 { yz_cov / z_var } else { 0.0 };
    // Negative log-likelihood, gradient, and Hessian of the 2-parameter
    // probit model, written via the signed margin s·η (s = ±1) and the
    // Mills ratio λ: ∂(−logΦ)/∂η = −sλ, ∂²(−logΦ)/∂η² = λ(sη + λ).
    let objective_grad_hess =
        |intercept: f64, slope: f64| -> Result<(f64, f64, f64, f64, f64, f64), String> {
            let mut obj = 0.0;
            let mut g0 = 0.0;
            let mut g1 = 0.0;
            let mut h00 = 0.0;
            let mut h01 = 0.0;
            let mut h11 = 0.0;
            for ((&yi, &zi), &wi) in y.iter().zip(z.iter()).zip(weights.iter()) {
                if wi == 0.0 {
                    continue;
                }
                let eta = intercept + slope * zi;
                let s = 2.0 * yi - 1.0;
                let margin = s * eta;
                let (logcdf, lambda) = signed_probit_logcdf_and_mills_ratio(margin);
                let g_eta = -wi * s * lambda;
                let h_eta = wi * lambda * (margin + lambda);
                obj -= wi * logcdf;
                g0 += g_eta;
                g1 += g_eta * zi;
                h00 += h_eta;
                h01 += h_eta * zi;
                h11 += h_eta * zi * zi;
            }
            Ok((obj, g0, g1, h00, h01, h11))
        };
    let mut obj_prev = f64::INFINITY;
    // Damped Newton: at most 50 outer iterations.
    for _ in 0..50 {
        let (obj, g0, g1, h00, h01, h11) = objective_grad_hess(beta0, beta1)?;
        if !obj.is_finite() || !g0.is_finite() || !g1.is_finite() {
            return Err(
                "pooled bernoulli-marginal-slope pilot produced non-finite objective or gradient"
                    .to_string(),
            );
        }
        let grad_max = g0.abs().max(g1.abs());
        if grad_max < 1e-8 {
            break;
        }
        // Solve the 2×2 Newton system directly via the determinant,
        // escalating a diagonal ridge until the solve is well-posed.
        let mut ridge = 1e-8;
        let (step0, step1) = loop {
            let h00_r = h00 + ridge;
            let h11_r = h11 + ridge;
            let det = h00_r * h11_r - h01 * h01;
            if det.is_finite() && det.abs() > 1e-18 {
                let s0 = (h11_r * g0 - h01 * g1) / det;
                let s1 = (-h01 * g0 + h00_r * g1) / det;
                if s0.is_finite() && s1.is_finite() {
                    break (s0, s1);
                }
            }
            ridge *= 10.0;
            if ridge > 1e6 {
                return Err(
                    "pooled bernoulli-marginal-slope pilot Hessian solve failed".to_string()
                );
            }
        };
        // Backtracking line search: accept any non-increase of the objective.
        let mut accepted = false;
        let mut step_scale = 1.0;
        for _ in 0..25 {
            let cand0 = beta0 - step_scale * step0;
            let cand1 = beta1 - step_scale * step1;
            let (cand_obj, _, _, _, _, _) = objective_grad_hess(cand0, cand1)?;
            if cand_obj.is_finite() && cand_obj <= obj {
                beta0 = cand0;
                beta1 = cand1;
                obj_prev = cand_obj;
                accepted = true;
                break;
            }
            step_scale *= 0.5;
        }
        if !accepted {
            // Treat a stalled objective as convergence; anything else is a
            // genuine line-search failure.
            if (obj_prev - obj).abs() < 1e-10 {
                break;
            }
            return Err("pooled bernoulli-marginal-slope pilot line search failed".to_string());
        }
    }
    let a = beta0;
    // Floor |β1| at 1e-6, preserving sign (−0.0 counts as negative).
    let b = if beta1.abs() < 1e-6 {
        if beta1.is_sign_negative() {
            -1e-6
        } else {
            1e-6
        }
    } else {
        beta1
    };
    Ok((a / (1.0 + b * b).sqrt(), b))
}
/// Produces a pilot linear predictor for orthogonalising the link-deviation
/// block: starts from the rigid observed η implied by the baselines plus
/// offsets, then applies one ridge-damped IRLS correction step on the
/// marginal design (probit working weights w·φ²/(μ(1−μ)) and working
/// residual (y−μ)/φ).
///
/// When the marginal design has no columns, the uncorrected working η is
/// returned as-is.
fn pilot_eta_for_link_dev_orthogonalisation(
    base_link: &InverseLink,
    y: &Array1<f64>,
    z: &Array1<f64>,
    weights: &Array1<f64>,
    marginal_design: &DesignMatrix,
    marginal_offset: &Array1<f64>,
    logslope_offset: &Array1<f64>,
    baseline_marginal: f64,
    baseline_logslope: f64,
    probit_scale: f64,
) -> Result<Array1<f64>, String> {
    use crate::faer_ndarray::FaerCholesky;
    let n = y.len();
    if marginal_design.nrows() != n {
        return Err(format!(
            "pilot_eta_for_link_dev_orthogonalisation: marginal design has {} rows, expected {}",
            marginal_design.nrows(),
            n,
        ));
    }
    let mut working_eta = Array1::<f64>::zeros(n);
    let mut w_irls = Array1::<f64>::zeros(n);
    let mut residual = Array1::<f64>::zeros(n);
    for i in 0..n {
        // Baseline + per-row offsets on the marginal and logslope scales.
        let a_pre = baseline_marginal + marginal_offset[i];
        let b_pre = baseline_logslope + logslope_offset[i];
        let q_marg = bernoulli_marginal_link_map(base_link, a_pre)
            .map_err(|e| {
                format!("pilot_eta_for_link_dev_orthogonalisation marginal link map: {e}")
            })?
            .q;
        // Rigid observed η from the (marginal q, logslope, z) triple.
        let eta = rigid_observed_eta(q_marg, b_pre, z[i], probit_scale);
        working_eta[i] = eta;
        let mu = clamp_bernoulli_link_probability(normal_cdf(eta));
        // Floors guard against underflow in the tails.
        let phi = normal_pdf(eta).max(1e-300);
        let var = (mu * (1.0 - mu)).max(1e-300);
        w_irls[i] = weights[i] * (phi * phi) / var;
        residual[i] = (y[i] - mu) / phi;
    }
    let p_marg = marginal_design.ncols();
    if p_marg == 0 {
        return Ok(working_eta);
    }
    // One weighted least-squares step: solve (XᵀWX + ridge·I) Δβ = XᵀWr.
    let xtwr = marginal_design.compute_xtwy(&w_irls, &residual)?;
    let mut xtwx = marginal_design.compute_xtwx(&w_irls)?;
    // Ridge scaled to the mean diagonal of XᵀWX, keeping the solve well-posed.
    let trace_diag: f64 = (0..p_marg).map(|i| xtwx[[i, i]]).sum();
    let ridge = (trace_diag / p_marg as f64).max(1e-12) * 1e-6;
    for i in 0..p_marg {
        xtwx[[i, i]] += ridge;
    }
    let factor = xtwx
        .cholesky(faer::Side::Lower)
        .map_err(|e| format!("pilot_eta_for_link_dev_orthogonalisation Cholesky failed: {e}"))?;
    let delta_beta_marg = factor.solvevec(&xtwr);
    let marg_contrib = marginal_design.dot(&delta_beta_marg);
    Ok(&working_eta + &marg_contrib)
}
/// Assembles the joint hyperparameter setup for the exact joint optimizer.
///
/// Layout of the ρ (log smoothing parameter) vector:
/// `[marginal penalties | logslope penalties | extra_rho0]` — the first two
/// segments start at zero, the trailing segment is seeded from `extra_rho0`.
/// All ρ bounds are ±12. The spatial log-κ coordinates for the marginal and
/// logslope term collections are concatenated (marginal first), with per-term
/// dimension bookkeeping preserved, data-driven bounds, and the initial
/// values clamped into those bounds.
fn joint_setup(
    data: ArrayView2<'_, f64>,
    marginalspec: &TermCollectionSpec,
    logslopespec: &TermCollectionSpec,
    marginal_penalties: usize,
    logslope_penalties: usize,
    extra_rho0: &[f64],
    kappa_options: &SpatialLengthScaleOptimizationOptions,
) -> ExactJointHyperSetup {
    let marginal_terms = spatial_length_scale_term_indices(marginalspec);
    let logslope_terms = spatial_length_scale_term_indices(logslopespec);
    let rho_dim = marginal_penalties + logslope_penalties + extra_rho0.len();
    let mut rho0vec = Array1::<f64>::zeros(rho_dim);
    // Only the extra trailing entries get non-zero initial values.
    for (idx, &value) in extra_rho0.iter().enumerate() {
        rho0vec[marginal_penalties + logslope_penalties + idx] = value;
    }
    let rho_lower = Array1::<f64>::from_elem(rho_dim, -12.0);
    let rho_upper = Array1::<f64>::from_elem(rho_dim, 12.0);
    // Anisotropic log-κ seeds, re-seeded from the data for each collection.
    let marginal_kappa = SpatialLogKappaCoords::from_length_scales_aniso(
        marginalspec,
        &marginal_terms,
        kappa_options,
    )
    .reseed_from_data(data, marginalspec, &marginal_terms, kappa_options);
    let logslope_kappa = SpatialLogKappaCoords::from_length_scales_aniso(
        logslopespec,
        &logslope_terms,
        kappa_options,
    )
    .reseed_from_data(data, logslopespec, &logslope_terms, kappa_options);
    // Concatenate: marginal coordinates first, then logslope; the per-term
    // dimension list follows the same order so indices stay aligned.
    let mut values = marginal_kappa.as_array().to_vec();
    values.extend(logslope_kappa.as_array().iter());
    let marginal_dims = marginal_kappa.dims_per_term().to_vec();
    let logslope_dims = logslope_kappa.dims_per_term().to_vec();
    let mut dims = marginal_dims.clone();
    dims.extend(logslope_dims.iter().copied());
    let log_kappa0 = SpatialLogKappaCoords::new_with_dims(Array1::from_vec(values), dims.clone());
    // Data-driven lower/upper bounds, concatenated in the same order.
    let marginal_lower = SpatialLogKappaCoords::lower_bounds_aniso_from_data(
        data,
        marginalspec,
        &marginal_terms,
        &marginal_dims,
        kappa_options,
    );
    let logslope_lower = SpatialLogKappaCoords::lower_bounds_aniso_from_data(
        data,
        logslopespec,
        &logslope_terms,
        &logslope_dims,
        kappa_options,
    );
    let mut lower_vals = marginal_lower.as_array().to_vec();
    lower_vals.extend(logslope_lower.as_array().iter());
    let log_kappa_lower =
        SpatialLogKappaCoords::new_with_dims(Array1::from_vec(lower_vals), dims.clone());
    let marginal_upper = SpatialLogKappaCoords::upper_bounds_aniso_from_data(
        data,
        marginalspec,
        &marginal_terms,
        &marginal_dims,
        kappa_options,
    );
    let logslope_upper = SpatialLogKappaCoords::upper_bounds_aniso_from_data(
        data,
        logslopespec,
        &logslope_terms,
        &logslope_dims,
        kappa_options,
    );
    let mut upper_vals = marginal_upper.as_array().to_vec();
    upper_vals.extend(logslope_upper.as_array().iter());
    let log_kappa_upper = SpatialLogKappaCoords::new_with_dims(Array1::from_vec(upper_vals), dims);
    // Make sure the seed respects the data-driven box constraints.
    let log_kappa0 = log_kappa0.clamp_to_bounds(&log_kappa_lower, &log_kappa_upper);
    ExactJointHyperSetup::new(
        rho0vec,
        rho_lower,
        rho_upper,
        log_kappa0,
        log_kappa_lower,
        log_kappa_upper,
    )
}
/// First four derivatives of m ↦ −log Φ(m), scaled by `weight`, expressed in
/// the Mills ratio λ = φ(m)/Φ(m).
///
/// Non-finite margins are handled with limiting/sentinel values:
/// - zero weight or m = +∞: all derivatives vanish (log Φ(+∞) = 0);
/// - m = −∞: k1 → −∞ and k2 → weight·1 (the analytic limits of −λ and
///   λ(m+λ)), with k3, k4 reported as 0;
/// - NaN propagates NaN in all four slots.
#[inline]
fn signed_probit_neglog_derivatives_up_to_fourth_numeric(
    signed_margin: f64,
    weight: f64,
) -> (f64, f64, f64, f64) {
    if weight == 0.0 || signed_margin == f64::INFINITY {
        return (0.0, 0.0, 0.0, 0.0);
    }
    if signed_margin == f64::NEG_INFINITY {
        return (f64::NEG_INFINITY, weight, 0.0, 0.0);
    }
    if signed_margin.is_nan() {
        return (f64::NAN, f64::NAN, f64::NAN, f64::NAN);
    }
    let (_, lambda) = signed_probit_logcdf_and_mills_ratio(signed_margin);
    // k_{n+1} follows from k_n via λ' = −λ(m + λ).
    let k1 = -lambda;
    let k2 = lambda * (signed_margin + lambda);
    let k3 = lambda
        * (1.0
            - signed_margin * signed_margin
            - 3.0 * signed_margin * lambda
            - 2.0 * lambda * lambda);
    let k4 = lambda
        * ((signed_margin.powi(3) - 3.0 * signed_margin)
            + (7.0 * signed_margin * signed_margin - 4.0) * lambda
            + 12.0 * signed_margin * lambda * lambda
            + 6.0 * lambda.powi(3));
    (weight * k1, weight * k2, weight * k3, weight * k4)
}
/// Validating wrapper around the `_numeric` twin: zero weight or a +∞ margin
/// contributes nothing; −∞ and NaN margins are rejected as errors here
/// (the `_numeric` helper instead returns sentinel values for them).
pub(crate) fn signed_probit_neglog_derivatives_up_to_fourth(
    signed_margin: f64,
    weight: f64,
) -> Result<(f64, f64, f64, f64), String> {
    if weight == 0.0 || signed_margin == f64::INFINITY {
        return Ok((0.0, 0.0, 0.0, 0.0));
    }
    if signed_margin.is_finite() {
        Ok(signed_probit_neglog_derivatives_up_to_fourth_numeric(
            signed_margin,
            weight,
        ))
    } else {
        Err(format!(
            "non-finite signed margin in exact probit derivative helper: {signed_margin}"
        ))
    }
}
/// Converts a pre-scale logslope to the observed scale by multiplying with
/// the probit scale factor.
#[inline]
fn rigid_observed_logslope(logslope: f64, probit_scale: f64) -> f64 {
    logslope * probit_scale
}
/// c(b) = √(1 + (probit_scale·b)²): the factor relating the marginal and
/// observed probit scales in the rigid map.
#[inline]
fn rigid_observed_scale(logslope: f64, probit_scale: f64) -> f64 {
    let observed_logslope = rigid_observed_logslope(logslope, probit_scale);
    (1.0 + observed_logslope * observed_logslope).sqrt()
}
/// Observed-scale intercept implied by a marginal η under the rigid map:
/// marginal η times the observed scale factor c = √(1 + (ps·b)²).
#[inline]
fn rigid_intercept_from_marginal(marginal_eta: f64, logslope: f64, probit_scale: f64) -> f64 {
    marginal_eta * rigid_observed_scale(logslope, probit_scale)
}
/// Same as `rigid_intercept_from_marginal` but expressed on the pre-scale
/// axis (observed intercept divided by the probit scale factor).
#[inline]
fn rigid_prescale_intercept_from_marginal(
    marginal_eta: f64,
    logslope: f64,
    probit_scale: f64,
) -> f64 {
    rigid_intercept_from_marginal(marginal_eta, logslope, probit_scale) / probit_scale
}
/// Magnitude of the sensitivity linking the pre-scale intercept and the
/// marginal probability at fixed slope: probit_scale·φ(η_marginal)/c.
///
/// NOTE(review): this equals |∂Φ(η_marginal)/∂(pre-scale intercept)| under
/// the rigid map (pre-scale intercept = η·c/probit_scale) — confirm the
/// intended direction of the derivative against the call sites.
#[inline]
fn rigid_prescale_intercept_derivative_abs(
    marginal_eta: f64,
    logslope: f64,
    probit_scale: f64,
) -> f64 {
    let c = rigid_observed_scale(logslope, probit_scale);
    probit_scale * normal_pdf(marginal_eta) / c
}
/// Full rigid observed linear predictor: observed intercept (from the
/// marginal η) plus the observed-scale slope times the latent score z.
#[inline]
fn rigid_observed_eta(marginal_eta: f64, logslope: f64, z: f64, probit_scale: f64) -> f64 {
    rigid_intercept_from_marginal(marginal_eta, logslope, probit_scale)
        + rigid_observed_logslope(logslope, probit_scale) * z
}
/// Value and first four derivatives of the standard normal CDF Φ at x:
/// [Φ, φ, −xφ, (x²−1)φ, (3x−x³)φ] (Hermite-polynomial factors of φ).
fn unary_derivatives_normal_cdf(x: f64) -> [f64; 5] {
    let pdf = normal_pdf(x);
    [
        normal_cdf(x),
        pdf,
        -x * pdf,
        (x * x - 1.0) * pdf,
        (-x.powi(3) + 3.0 * x) * pdf,
    ]
}
/// Value and first four derivatives of the standard normal density φ at x:
/// [φ, −xφ, (x²−1)φ, (3x−x³)φ, (x⁴−6x²+3)φ].
fn unary_derivatives_normal_pdf(x: f64) -> [f64; 5] {
    let pdf = normal_pdf(x);
    [
        pdf,
        -x * pdf,
        (x * x - 1.0) * pdf,
        (-x.powi(3) + 3.0 * x) * pdf,
        (x.powi(4) - 6.0 * x * x + 3.0) * pdf,
    ]
}
fn unary_derivatives_reciprocal(x: f64) -> [f64; 5] {
let x1 = x.max(1e-300);
let x2 = x1 * x1;
let x3 = x2 * x1;
let x4 = x3 * x1;
let x5 = x4 * x1;
[1.0 / x1, -1.0 / x2, 2.0 / x3, -6.0 / x4, 24.0 / x5]
}
/// One step of a streaming log-sum-exp accumulator.
///
/// Invariant: the running total equals `exp(log_max) * sum`. Non-finite
/// terms are dropped (−∞ contributes zero; +∞/NaN are treated as noise).
/// When a new maximum arrives, the running `sum` is rescaled onto the new
/// reference point; otherwise the term is folded in directly.
#[inline]
fn lse_accumulate(log_max: &mut f64, sum: &mut f64, log_term: f64) {
    if !log_term.is_finite() {
        return;
    }
    if log_term <= *log_max {
        // Term fits under the current scale: accumulate directly.
        *sum += (log_term - *log_max).exp();
        return;
    }
    // New maximum: rescale (a fresh accumulator has log_max = −∞, sum = 0,
    // so the first finite term simply initialises sum to 1).
    *sum = if log_max.is_finite() {
        *sum * (*log_max - log_term).exp() + 1.0
    } else {
        1.0
    };
    *log_max = log_term;
}
/// Evaluates the empirical latent-calibration residual and its first two
/// intercept derivatives at a candidate `intercept`, using quadrature
/// `nodes`/`weights` over the latent score.
///
/// With η = intercept + b·node (b the observed logslope) and quadrature sums
/// S_cdf = Σ wΦ(η), S_φ = Σ wφ(η), S_ηφ = Σ wηφ(η), all accumulated in
/// log space (split by sign for the η-weighted sum), it returns
///   f   = log S_cdf − log(target mu),
///   f'  = S_φ / S_cdf,
///   f'' = −S_ηφ / S_cdf − (f')²
/// (the exact derivatives of log S_cdf, since φ'(η) = −ηφ(η)).
///
/// Non-positive-weight nodes are skipped; non-finite intermediate state is
/// an error.
fn empirical_rigid_calibration_eval(
    intercept: f64,
    log_target_mu: f64,
    slope: f64,
    probit_scale: f64,
    nodes: &[f64],
    weights: &[f64],
) -> Result<(f64, f64, f64), String> {
    if !intercept.is_finite() {
        return Err(format!(
            "empirical latent calibration: non-finite intercept {intercept}"
        ));
    }
    let observed_slope = rigid_observed_logslope(slope, probit_scale);
    // 0.5·ln(2π), for the log of the normal density.
    const HALF_LOG_2PI: f64 = 0.918_938_533_204_672_8;
    // Streaming log-sum-exp state for S_φ, S_cdf, and the positive/negative
    // parts of S_ηφ (kept separate because log-space cannot hold signs).
    let mut log_max_phi = f64::NEG_INFINITY;
    let mut sum_phi = 0.0_f64;
    let mut log_max_cdf = f64::NEG_INFINITY;
    let mut sum_cdf = 0.0_f64;
    let mut log_max_pos = f64::NEG_INFINITY;
    let mut sum_pos = 0.0_f64;
    let mut log_max_neg = f64::NEG_INFINITY;
    let mut sum_neg = 0.0_f64;
    for (&node, &weight) in nodes.iter().zip(weights.iter()) {
        if !(weight.is_finite() && weight > 0.0) {
            continue;
        }
        let eta = intercept + observed_slope * node;
        if !eta.is_finite() {
            return Err(format!(
                "empirical latent calibration: non-finite η at intercept={intercept}, slope={slope}, node={node}"
            ));
        }
        let log_w = weight.ln();
        let log_phi = -0.5 * eta * eta - HALF_LOG_2PI;
        let log_term_phi = log_w + log_phi;
        let log_term_cdf = log_w + normal_logcdf(eta);
        lse_accumulate(&mut log_max_phi, &mut sum_phi, log_term_phi);
        lse_accumulate(&mut log_max_cdf, &mut sum_cdf, log_term_cdf);
        // η-weighted density terms, routed by the sign of η.
        if eta != 0.0 {
            let log_term_eta_phi = log_term_phi + eta.abs().ln();
            if eta > 0.0 {
                lse_accumulate(&mut log_max_pos, &mut sum_pos, log_term_eta_phi);
            } else {
                lse_accumulate(&mut log_max_neg, &mut sum_neg, log_term_eta_phi);
            }
        }
    }
    if !(sum_phi.is_finite() && sum_cdf.is_finite() && sum_phi > 0.0 && sum_cdf > 0.0) {
        return Err(format!(
            "empirical latent calibration: log-space accumulation failed (sum_phi={sum_phi}, sum_cdf={sum_cdf}, intercept={intercept})"
        ));
    }
    let log_s_phi = log_max_phi + sum_phi.ln();
    let log_s_cdf = log_max_cdf + sum_cdf.ln();
    let f = log_s_cdf - log_target_mu;
    let log_f_prime = log_s_phi - log_s_cdf;
    // exp underflows below ~−745; substitute the smallest positive double so
    // f' stays strictly positive for downstream Newton steps.
    let f_prime = if log_f_prime > -740.0 {
        log_f_prime.exp()
    } else {
        f64::MIN_POSITIVE
    };
    let exp_safe = |log_x: f64| -> f64 { if log_x > -740.0 { log_x.exp() } else { 0.0 } };
    let pos_over_cdf = if sum_pos > 0.0 {
        exp_safe(log_max_pos + sum_pos.ln() - log_s_cdf)
    } else {
        0.0
    };
    let neg_over_cdf = if sum_neg > 0.0 {
        exp_safe(log_max_neg + sum_neg.ln() - log_s_cdf)
    } else {
        0.0
    };
    // Recombine the signed halves: S_ηφ / S_cdf.
    let s_etaphi_over_s_cdf = pos_over_cdf - neg_over_cdf;
    let f_double_prime = -s_etaphi_over_s_cdf - f_prime * f_prime;
    if !(f.is_finite() && f_prime.is_finite() && f_prime > 0.0 && f_double_prime.is_finite()) {
        return Err(format!(
            "empirical latent calibration: non-finite log-space state f={f}, f'={f_prime}, f''={f_double_prime} at intercept={intercept}"
        ));
    }
    Ok((f, f_prime, f_double_prime))
}
/// Solves for the latent intercept `a` so the quadrature-averaged marginal
/// probability `Σ w·Φ(a + observed_slope·node)` equals `target_mu`.
///
/// The solve runs in log space via `empirical_rigid_calibration_eval` and a
/// monotone root finder; `initial` seeds the iteration, otherwise the rigid
/// closed-form intercept at quantile `target_q` is used as the seed.
///
/// Errors when `target_mu` is outside (0,1), when the inner evaluation fails,
/// or when the solver's best residual still exceeds the tolerance.
pub(crate) fn empirical_intercept_from_marginal(
    target_mu: f64,
    target_q: f64,
    slope: f64,
    probit_scale: f64,
    nodes: &[f64],
    weights: &[f64],
    initial: Option<f64>,
) -> Result<f64, String> {
    if !(target_mu.is_finite() && target_mu > 0.0 && target_mu < 1.0) {
        return Err(format!(
            "empirical latent calibration requires target mu in (0,1), got {target_mu}"
        ));
    }
    let log_target_mu = target_mu.ln();
    let seed =
        initial.unwrap_or_else(|| rigid_intercept_from_marginal(target_q, slope, probit_scale));
    let eval = |a: f64| {
        empirical_rigid_calibration_eval(a, log_target_mu, slope, probit_scale, nodes, weights)
    };
    let abs_tol = 1e-13_f64.max(4.0 * f64::EPSILON);
    // 64 outer iterations, 48 bisection fallback steps.
    let (root, _, f_best) = super::monotone_root::solve_monotone_root(
        eval,
        seed,
        "empirical latent intercept",
        abs_tol,
        64,
        48,
    )?;
    if f_best.abs() > abs_tol {
        return Err(format!(
            "empirical latent intercept solve failed: log-residual={f_best:.3e} at a={root:.6}, target mu={target_mu:.6}"
        ));
    }
    Ok(root)
}
/// Precomputed per-row quantities for the rigid probit kernel −w·log Φ(s·η).
/// Built once per (q, g, z, y, w) tuple by `new` and reused by the
/// gradient/Hessian/third/fourth contraction helpers below.
struct RigidProbitKernel {
    /// log Φ(s·η) at the evaluation point.
    logcdf: f64,
    /// First..fourth derivatives of the negative log-CDF in η, with the sign
    /// factor s = 2y−1 folded into the odd orders (see `new`).
    u1: f64,
    u2: f64,
    u3: f64,
    u4: f64,
    /// Chain-rule coefficients for derivatives of the slope-dependent scale
    /// c = sqrt(1 + g_obs²) w.r.t. g, in increasing order (powers of 1/c
    /// scaled by `probit_scale`; see `new`).
    c1: f64,
    c2: f64,
    c3: f64,
    c4: f64,
    /// ∂η/∂q = c and ∂η/∂g = q·c1 + probit_scale·z.
    eta_q: f64,
    eta_g: f64,
}
impl RigidProbitKernel {
    /// Builds the kernel at (q, g) with covariate z, response y ∈ {0,1} and
    /// weight w. Internally η = q·c + g_obs·z with c = sqrt(1 + g_obs²), and
    /// m = s·η with s = 2y − 1.
    #[inline]
    fn new(q: f64, g: f64, z: f64, y: f64, w: f64, probit_scale: f64) -> Result<Self, String> {
        let s = 2.0 * y - 1.0;
        let observed_logslope = rigid_observed_logslope(g, probit_scale);
        let g2 = observed_logslope * observed_logslope;
        let c = (1.0 + g2).sqrt();
        // c1 = dc/dg; c2..c4 follow the c⁻³, c⁻⁵, c⁻⁷ derivative pattern.
        let c1 = probit_scale * observed_logslope / c;
        let c_inv3 = 1.0 / (c * c * c);
        let c_inv5 = c_inv3 / (c * c);
        let c_inv7 = c_inv5 / (c * c);
        let eta = q * c + observed_logslope * z;
        let m = s * eta;
        let (logcdf, _) = signed_probit_logcdf_and_mills_ratio(m);
        let (k1, k2, k3, k4) = signed_probit_neglog_derivatives_up_to_fourth(m, w)?;
        Ok(Self {
            logcdf,
            // Odd-order derivatives in η pick up a factor of s from d m/dη.
            u1: s * k1,
            u2: k2,
            u3: s * k3,
            u4: k4,
            c1,
            c2: probit_scale * probit_scale * c_inv3,
            c3: -3.0 * probit_scale.powi(3) * observed_logslope * c_inv5,
            c4: probit_scale.powi(4) * (12.0 * g2 - 3.0) * c_inv7,
            eta_q: c,
            eta_g: q * c1 + probit_scale * z,
        })
    }
    /// Cheap path: just −w·log Φ(s·η), without any derivative state.
    #[inline]
    fn neglog_only(
        q: f64,
        g: f64,
        z: f64,
        y: f64,
        w: f64,
        probit_scale: f64,
    ) -> Result<f64, String> {
        let s = 2.0 * y - 1.0;
        let observed_logslope = rigid_observed_logslope(g, probit_scale);
        let c = (1.0 + observed_logslope * observed_logslope).sqrt();
        let eta = q * c + observed_logslope * z;
        let m = s * eta;
        let (logcdf, _) = signed_probit_logcdf_and_mills_ratio(m);
        if !logcdf.is_finite() {
            return Err(format!(
                "rigid probit neglog_only: non-finite log Φ at q={q}, g={g}, z={z}, y={y}"
            ));
        }
        Ok(-w * logcdf)
    }
    /// 2×2 Hessian in the internal (q, g) coordinates at primary value `q`.
    #[inline]
    fn primary_hessian(&self, q: f64) -> [[f64; 2]; 2] {
        let h00 = self.u2 * self.eta_q * self.eta_q;
        let h01 = self.u2 * self.eta_q * self.eta_g + self.u1 * self.c1;
        let h11 = self.u2 * self.eta_g * self.eta_g + self.u1 * q * self.c2;
        [[h00, h01], [h01, h11]]
    }
    /// Third-derivative tensor contracted once with direction (dq, dg),
    /// returned as a symmetric 2×2 matrix in (q, g).
    #[inline]
    fn third_contracted(&self, q: f64, dq: f64, dg: f64) -> [[f64; 2]; 2] {
        // dd = directional derivative of η; dd_* = derivatives of dd in (q, g).
        let dd = self.eta_q * dq + self.eta_g * dg;
        let dd_q = self.c1 * dg;
        let dd_g = self.c1 * dq + q * self.c2 * dg;
        let dd_qg = self.c2 * dg;
        let dd_gg = self.c2 * dq + q * self.c3 * dg;
        let t00 = self.u3 * self.eta_q * self.eta_q * dd + self.u2 * 2.0 * self.eta_q * dd_q;
        let t01 = self.u3 * self.eta_q * self.eta_g * dd
            + self.u2 * (self.c1 * dd + self.eta_q * dd_g + self.eta_g * dd_q)
            + self.u1 * dd_qg;
        let t11 = self.u3 * self.eta_g * self.eta_g * dd
            + self.u2 * (q * self.c2 * dd + 2.0 * self.eta_g * dd_g)
            + self.u1 * dd_gg;
        [[t00, t01], [t01, t11]]
    }
    /// Fourth-derivative tensor contracted with two directions (uq, ug) and
    /// (vq, vg), returned as a symmetric 2×2 matrix in (q, g).
    #[inline]
    fn fourth_contracted(&self, q: f64, uq: f64, ug: f64, vq: f64, vg: f64) -> [[f64; 2]; 2] {
        // du/dv: directional derivatives of η; *_a / *_ab: their first and
        // second derivatives in (q, g); dduv*: the mixed second-directional term.
        let du = self.eta_q * uq + self.eta_g * ug;
        let dv = self.eta_q * vq + self.eta_g * vg;
        let du_a = [self.c1 * ug, self.c1 * uq + q * self.c2 * ug];
        let dv_a = [self.c1 * vg, self.c1 * vq + q * self.c2 * vg];
        let du_ab = [
            [0.0, self.c2 * ug],
            [self.c2 * ug, self.c2 * uq + q * self.c3 * ug],
        ];
        let dv_ab = [
            [0.0, self.c2 * vg],
            [self.c2 * vg, self.c2 * vq + q * self.c3 * vg],
        ];
        let dduv = self.c1 * (uq * vg + ug * vq) + q * self.c2 * ug * vg;
        let dduv_a = [
            self.c2 * ug * vg,
            self.c2 * (uq * vg + ug * vq) + q * self.c3 * ug * vg,
        ];
        let dduv_ab = [
            [0.0, self.c3 * ug * vg],
            [
                self.c3 * ug * vg,
                self.c3 * (uq * vg + ug * vq) + q * self.c4 * ug * vg,
            ],
        ];
        let eta_a = [self.eta_q, self.eta_g];
        let eta_ab = [[0.0, self.c1], [self.c1, q * self.c2]];
        let mut f = [[0.0f64; 2]; 2];
        // Assemble upper triangle via the order-4 Faà di Bruno expansion and
        // mirror into the lower triangle.
        for a in 0..2 {
            for b in a..2 {
                let val = self.u4 * eta_a[a] * eta_a[b] * du * dv
                    + self.u3
                        * (eta_ab[a][b] * du * dv
                            + du_a[a] * eta_a[b] * dv
                            + dv_a[a] * eta_a[b] * du
                            + du_a[b] * eta_a[a] * dv
                            + dv_a[b] * eta_a[a] * du
                            + dduv * eta_a[a] * eta_a[b])
                    + self.u2
                        * (eta_ab[a][b] * dduv
                            + du_a[a] * dv_a[b]
                            + dv_a[a] * du_a[b]
                            + du_ab[a][b] * dv
                            + dv_ab[a][b] * du
                            + eta_a[b] * dduv_a[a]
                            + eta_a[a] * dduv_a[b])
                    + self.u1 * dduv_ab[a][b];
                f[a][b] = val;
                f[b][a] = val;
            }
        }
        f
    }
}
#[inline]
fn rigid_transformed_gradient(
    marginal: BernoulliMarginalLinkMap,
    kernel: &RigidProbitKernel,
) -> [f64; 2] {
    // Chain rule through the marginal link: the η component scales the
    // internal q-gradient by q'(η) = marginal.q1; the g component is direct.
    let d_eta = kernel.u1 * kernel.eta_q * marginal.q1;
    let d_g = kernel.u1 * kernel.eta_g;
    [d_eta, d_g]
}
#[inline]
fn rigid_transformed_hessian(
    marginal: BernoulliMarginalLinkMap,
    kernel: &RigidProbitKernel,
) -> [[f64; 2]; 2] {
    // Pull the internal (q, g) Hessian back through the marginal link q(η):
    // the (η,η) entry picks up a q'' term from the curvature of the link.
    let h_q = kernel.primary_hessian(marginal.q);
    let grad_q = kernel.u1 * kernel.eta_q;
    let h_ee = h_q[0][0] * marginal.q1 * marginal.q1 + grad_q * marginal.q2;
    let h_eg = h_q[0][1] * marginal.q1;
    let h_ge = h_q[1][0] * marginal.q1;
    let h_gg = h_q[1][1];
    [[h_ee, h_eg], [h_ge, h_gg]]
}
#[inline]
fn rigid_internal_third_components(
    marginal: BernoulliMarginalLinkMap,
    kernel: &RigidProbitKernel,
) -> (f64, f64, f64, f64) {
    // Contract the internal third-derivative tensor along the q and g unit
    // directions; by symmetry these four entries determine the whole tensor.
    let along_q = kernel.third_contracted(marginal.q, 1.0, 0.0);
    let along_g = kernel.third_contracted(marginal.q, 0.0, 1.0);
    (along_q[0][0], along_q[0][1], along_q[1][1], along_g[1][1])
}
#[inline]
fn rigid_transformed_third_full(
    marginal: BernoulliMarginalLinkMap,
    kernel: &RigidProbitKernel,
) -> [[[f64; 2]; 2]; 2] {
    // Third-order pullback of the internal (q, g) derivatives through the
    // marginal link q(η), per Faà di Bruno: terms in q', q'' and q''' appear
    // only in components carrying η indices.
    let h_q = kernel.primary_hessian(marginal.q);
    let grad_q = kernel.u1 * kernel.eta_q;
    let (t_qqq, t_qqg, t_qgg, t_ggg) = rigid_internal_third_components(marginal, kernel);
    let d_eee = t_qqq * marginal.q1.powi(3)
        + 3.0 * h_q[0][0] * marginal.q1 * marginal.q2
        + grad_q * marginal.q3;
    let d_eeg = t_qqg * marginal.q1 * marginal.q1 + h_q[0][1] * marginal.q2;
    let d_egg = t_qgg * marginal.q1;
    third_full_from_symmetric_components(d_eee, d_eeg, d_egg, t_ggg)
}
#[inline]
fn third_full_from_symmetric_components(
    t_qqq: f64,
    t_qqg: f64,
    t_qgg: f64,
    t_ggg: f64,
) -> [[[f64; 2]; 2]; 2] {
    // Expand a fully symmetric 2×2×2 tensor from its four distinct entries:
    // each component depends only on how many of its indices equal 1 (g).
    let by_g_count = [t_qqq, t_qqg, t_qgg, t_ggg];
    let mut t = [[[0.0; 2]; 2]; 2];
    for a in 0..2 {
        for b in 0..2 {
            for c in 0..2 {
                t[a][b][c] = by_g_count[a + b + c];
            }
        }
    }
    t
}
#[inline]
fn contract_third_full(t: &[[[f64; 2]; 2]; 2], d_eta: f64, d_g: f64) -> [[f64; 2]; 2] {
    // Contract the last index of the third-derivative tensor with the
    // direction (d_eta, d_g), leaving a 2×2 matrix.
    let dir = [d_eta, d_g];
    let mut out = [[0.0; 2]; 2];
    for a in 0..2 {
        for b in 0..2 {
            out[a][b] = t[a][b][0] * dir[0] + t[a][b][1] * dir[1];
        }
    }
    out
}
/// Fourth-order pullback of the internal (q, g) derivatives through the
/// marginal link q(η), assembled from the kernel's contracted fourth tensor
/// and the lower-order pieces via Faà di Bruno (coefficients 6, 3, 4 are the
/// multiplicities of the corresponding index partitions).
fn rigid_transformed_fourth_full(
    marginal: BernoulliMarginalLinkMap,
    kernel: &RigidProbitKernel,
) -> [[[[f64; 2]; 2]; 2]; 2] {
    let h_q = kernel.primary_hessian(marginal.q);
    let grad_q = kernel.u1 * kernel.eta_q;
    let (f_qqq, f_qqg, f_qgg, _) = rigid_internal_third_components(marginal, kernel);
    // Distinct fourth-order components recovered from pairwise contractions
    // of the symmetric internal tensor.
    let qq = kernel.fourth_contracted(marginal.q, 1.0, 0.0, 1.0, 0.0);
    let qg = kernel.fourth_contracted(marginal.q, 1.0, 0.0, 0.0, 1.0);
    let gg = kernel.fourth_contracted(marginal.q, 0.0, 1.0, 0.0, 1.0);
    let f_qqqq = qq[0][0];
    let f_qqqg = qq[0][1];
    let f_qqgg = qq[1][1];
    let f_qggg = qg[1][1];
    let f_gggg = gg[1][1];
    let f_eta4 = f_qqqq * marginal.q1.powi(4)
        + 6.0 * f_qqq * marginal.q1 * marginal.q1 * marginal.q2
        + 3.0 * h_q[0][0] * marginal.q2 * marginal.q2
        + 4.0 * h_q[0][0] * marginal.q1 * marginal.q3
        + grad_q * marginal.q4;
    let f_eta3g = f_qqqg * marginal.q1.powi(3)
        + 3.0 * f_qqg * marginal.q1 * marginal.q2
        + h_q[0][1] * marginal.q3;
    let f_eta2g2 = f_qqgg * marginal.q1 * marginal.q1 + f_qgg * marginal.q2;
    let f_etag3 = f_qggg * marginal.q1;
    fourth_full_from_symmetric_components(f_eta4, f_eta3g, f_eta2g2, f_etag3, f_gggg)
}
#[inline]
fn fourth_full_from_symmetric_components(
    t_qqqq: f64,
    t_qqqg: f64,
    t_qqgg: f64,
    t_qggg: f64,
    t_gggg: f64,
) -> [[[[f64; 2]; 2]; 2]; 2] {
    // Expand a fully symmetric 2×2×2×2 tensor from its five distinct
    // entries: each component depends only on how many indices equal 1 (g).
    let by_g_count = [t_qqqq, t_qqqg, t_qqgg, t_qggg, t_gggg];
    let mut t = [[[[0.0; 2]; 2]; 2]; 2];
    for a in 0..2 {
        for b in 0..2 {
            for c in 0..2 {
                for d in 0..2 {
                    t[a][b][c][d] = by_g_count[a + b + c + d];
                }
            }
        }
    }
    t
}
#[inline]
fn contract_fourth_full(
    t: &[[[[f64; 2]; 2]; 2]; 2],
    u_eta: f64,
    u_g: f64,
    v_eta: f64,
    v_g: f64,
) -> [[f64; 2]; 2] {
    // Contract the last two indices of the fourth-derivative tensor with the
    // directions u and v, leaving a 2×2 matrix.
    let u = [u_eta, u_g];
    let v = [v_eta, v_g];
    let mut out = [[0.0; 2]; 2];
    for a in 0..2 {
        for b in 0..2 {
            let mut acc = 0.0;
            for c in 0..2 {
                for d in 0..2 {
                    acc += t[a][b][c][d] * u[c] * v[d];
                }
            }
            out[a][b] = acc;
        }
    }
    out
}
pub(crate) fn unary_derivatives_sqrt(x: f64) -> [f64; 5] {
    // Value and first four derivatives of sqrt(x); x is clamped away from
    // zero so the divisions stay finite.
    let clamped = x.max(1e-300);
    let root = clamped.sqrt();
    let c2 = clamped * clamped;
    let c3 = c2 * clamped;
    let d1 = 0.5 / root;
    let d2 = -0.25 / (clamped * root);
    let d3 = 3.0 / (8.0 * c2 * root);
    let d4 = -15.0 / (16.0 * c3 * root);
    [root, d1, d2, d3, d4]
}
/// Value and first four derivatives of −weight·log Φ(x), with explicit
/// handling of the degenerate inputs before delegating to the numeric
/// signed-probit routine.
pub(crate) fn unary_derivatives_neglog_phi(x: f64, weight: f64) -> [f64; 5] {
    // Zero weight, or x → +∞ where log Φ → 0: the whole jet vanishes.
    if weight == 0.0 || x == f64::INFINITY {
        return [0.0, 0.0, 0.0, 0.0, 0.0];
    }
    // x → −∞ limit of the jet. NOTE(review): the value diverges to +∞ while
    // d1 is reported as −∞ and d2 as `weight`; confirm this sign convention
    // matches `signed_probit_neglog_derivatives_up_to_fourth_numeric` in the
    // large-|x| tail.
    if x == f64::NEG_INFINITY {
        return [f64::INFINITY, f64::NEG_INFINITY, weight, 0.0, 0.0];
    }
    // NaN propagates to every component.
    if x.is_nan() {
        return [f64::NAN; 5];
    }
    let (d1, d2, d3, d4) = signed_probit_neglog_derivatives_up_to_fourth_numeric(x, weight);
    let (log_cdf, _) = signed_probit_logcdf_and_mills_ratio(x);
    [-weight * log_cdf, d1, d2, d3, d4]
}
pub(crate) fn unary_derivatives_log(x: f64) -> [f64; 5] {
    // ln(x) and its derivatives (-1)^{k+1}·(k−1)!/x^k; x is clamped away
    // from zero so the logarithm and divisions stay finite.
    let clamped = x.max(1e-300);
    let mut out = [clamped.ln(), 0.0, 0.0, 0.0, 0.0];
    let mut power = clamped;
    for (slot, coeff) in out.iter_mut().skip(1).zip([1.0, -1.0, 2.0, -6.0]) {
        *slot = coeff / power;
        power *= clamped;
    }
    out
}
pub(crate) fn unary_derivatives_log_normal_pdf(x: f64) -> [f64; 5] {
    // log φ(x) = −x²/2 − ½·log(2π); derivatives are −x, −1, then zero.
    let half_log_two_pi = 0.5 * (2.0 * std::f64::consts::PI).ln();
    let value = -0.5 * x * x - half_log_two_pi;
    [value, -x, -1.0, 0.0, 0.0]
}
/// A per-row ψ contribution scoped to one coefficient block: `local_vec`
/// holds the values for columns `range` of block `block_idx`.
struct BlockPsiRow {
    block_idx: usize,
    range: std::ops::Range<usize>,
    local_vec: Array1<f64>,
}
/// Identifies one ψ design axis: the owning coefficient block, the primary
/// coordinate it maps to, and the design map used to evaluate it.
struct PsiAxisSpec {
    block_idx: usize,
    idx_primary: usize,
    psi_map: crate::families::custom_family::PsiDesignMap,
}
/// Column ranges of the stacked coefficient vector, laid out by
/// `block_slices` as marginal | logslope | optional h | optional w.
#[derive(Clone)]
struct BlockSlices {
    marginal: std::ops::Range<usize>,
    logslope: std::ops::Range<usize>,
    // Present only when the family has a score-warp block.
    h: Option<std::ops::Range<usize>>,
    // Present only when the family has a link-dev block.
    w: Option<std::ops::Range<usize>>,
    // Total stacked dimension (end of the last range).
    total: usize,
}
fn block_slices(family: &BernoulliMarginalSlopeFamily) -> BlockSlices {
    // Lay out coefficient sub-ranges in a fixed order:
    // marginal | logslope | optional score-warp (h) | optional link-dev (w).
    let mut offset = 0usize;
    let mut next_range = |len: usize| {
        let range = offset..offset + len;
        offset = range.end;
        range
    };
    let marginal = next_range(family.marginal_design.ncols());
    let logslope = next_range(family.logslope_design.ncols());
    let h = family
        .score_warp
        .as_ref()
        .map(|runtime| next_range(runtime.basis_dim()));
    let w = family
        .link_dev
        .as_ref()
        .map(|runtime| next_range(runtime.basis_dim()));
    BlockSlices {
        marginal,
        logslope,
        h,
        w,
        total: offset,
    }
}
/// Index layout of the per-row primary coordinate vector produced by
/// `primary_slices`: scalar q, scalar logslope, then optional h/w ranges.
#[derive(Clone)]
struct PrimarySlices {
    // Position of the scalar q coordinate (always 0).
    q: usize,
    // Position of the scalar logslope coordinate (always 1).
    logslope: usize,
    h: Option<std::ops::Range<usize>>,
    w: Option<std::ops::Range<usize>>,
    // Total primary dimension.
    total: usize,
}
fn primary_slices(slices: &BlockSlices) -> PrimarySlices {
    // Primary layout: scalar q at index 0, scalar logslope at index 1, then
    // h/w coefficient ranges mirroring the lengths in `slices`.
    let mut cursor = 2usize;
    let mut mirror = |range: &std::ops::Range<usize>| {
        let local = cursor..cursor + range.len();
        cursor = local.end;
        local
    };
    let h = slices.h.as_ref().map(&mut mirror);
    let w = slices.w.as_ref().map(&mut mirror);
    PrimarySlices {
        q: 0,
        logslope: 1,
        h,
        w,
        total: cursor,
    }
}
/// Blockwise Hessian accumulator for the Bernoulli marginal-slope family:
/// dense marginal/marginal (`h_mm`), logslope/logslope (`h_gg`) and
/// marginal/logslope (`h_mg`) blocks, plus a full-size dense correction that
/// is allocated only when h/w blocks exist (see `new`).
struct BernoulliBlockHessianAccumulator {
    h_mm: Array2<f64>,
    h_gg: Array2<f64>,
    h_mg: Array2<f64>,
    dense_correction: Option<Array2<f64>>,
}
impl BernoulliBlockHessianAccumulator {
    /// Zero-initialized accumulator sized from the block layout in `slices`.
    /// The dense correction is only allocated when an h or w block exists.
    fn new(slices: &BlockSlices) -> Self {
        let p_m = slices.marginal.len();
        let p_g = slices.logslope.len();
        let has_hw = slices.h.is_some() || slices.w.is_some();
        Self {
            h_mm: Array2::zeros((p_m, p_m)),
            h_gg: Array2::zeros((p_g, p_g)),
            h_mg: Array2::zeros((p_m, p_g)),
            dense_correction: if has_hw {
                Some(Array2::zeros((slices.total, slices.total)))
            } else {
                None
            },
        }
    }
    /// Scatters one row's primary-coordinate Hessian into the block matrices
    /// via rank-1 design-row updates; h/w cross terms go to the dense
    /// correction through `add_pullback_primary_hessian_hw_only`.
    fn add_pullback(
        &mut self,
        family: &BernoulliMarginalSlopeFamily,
        row: usize,
        slices: &BlockSlices,
        primary: &PrimarySlices,
        primary_hessian: &Array2<f64>,
    ) {
        let h = primary_hessian;
        // h[[0,0]]·x_m x_mᵀ into the marginal block.
        family
            .marginal_design
            .syr_row_into(row, h[[0, 0]], &mut self.h_mm)
            .expect("marginal syr_row_into dimension mismatch");
        // h[[1,1]]·x_g x_gᵀ into the logslope block.
        family
            .logslope_design
            .syr_row_into(row, h[[1, 1]], &mut self.h_gg)
            .expect("logslope syr_row_into dimension mismatch");
        // Cross block only when the mixed entry is non-zero.
        if h[[0, 1]] != 0.0 {
            family
                .marginal_design
                .row_outer_into_view(
                    row,
                    &family.logslope_design,
                    h[[0, 1]],
                    self.h_mg.view_mut(),
                )
                .expect("marginal-logslope row_outer_into dimension mismatch");
        }
        if let Some(ref mut dc) = self.dense_correction {
            family.add_pullback_primary_hessian_hw_only(dc, row, slices, primary, h.view());
        }
    }
    /// Like `add_pullback` but only the h/w dense-correction part; the
    /// marginal/logslope blocks are left untouched.
    fn add_hw_pullback_only(
        &mut self,
        family: &BernoulliMarginalSlopeFamily,
        row: usize,
        slices: &BlockSlices,
        primary: &PrimarySlices,
        primary_hessian: &Array2<f64>,
    ) {
        if let Some(ref mut dc) = self.dense_correction {
            family.add_pullback_primary_hessian_hw_only(
                dc,
                row,
                slices,
                primary,
                primary_hessian.view(),
            );
        }
    }
    /// Adds weighted Gram matrices for a contiguous row chunk:
    /// h_mm += XᵀW_mm X, h_mg += XᵀW_mg G, h_gg += GᵀW_gg G.
    fn add_weighted_design_grams(
        &mut self,
        family: &BernoulliMarginalSlopeFamily,
        rows: std::ops::Range<usize>,
        w_mm: &Array1<f64>,
        w_mg: &Array1<f64>,
        w_gg: &Array1<f64>,
    ) -> Result<(), String> {
        let x = family
            .marginal_design
            .try_row_chunk(rows.clone())
            .map_err(|e| format!("bernoulli marginal_design try_row_chunk: {e}"))?;
        let g = family
            .logslope_design
            .try_row_chunk(rows)
            .map_err(|e| format!("bernoulli logslope_design try_row_chunk: {e}"))?;
        self.h_mm += &crate::faer_ndarray::fast_xt_diag_x(&x, w_mm);
        // Skip the cross product when every mixed weight is zero.
        if w_mg.iter().any(|value| *value != 0.0) {
            self.h_mg += &crate::faer_ndarray::fast_xt_diag_y(&x, w_mg, &g);
        }
        self.h_gg += &crate::faer_ndarray::fast_xt_diag_x(&g, w_gg);
        Ok(())
    }
    /// Adds the symmetric rank-1 cross term between a ψ row (living in block
    /// `psi_block_idx`: 0 = marginal, 1 = logslope) and the design row scaled
    /// by `right_primary`. Entries 0/1 of `right_primary` hit the q/logslope
    /// blocks; the h/w components go to the dense correction.
    fn add_rank1_psi_cross(
        &mut self,
        family: &BernoulliMarginalSlopeFamily,
        row: usize,
        slices: &BlockSlices,
        primary: &PrimarySlices,
        psi_block_idx: usize,
        psi_row: &Array1<f64>,
        right_primary: &Array1<f64>,
    ) {
        if right_primary[0] != 0.0 {
            match psi_block_idx {
                0 => {
                    // ψ and design both in the marginal block: symmetric
                    // update (column and row of h_mm).
                    for (idx, &value) in psi_row.iter().enumerate() {
                        if value == 0.0 {
                            continue;
                        }
                        let scale = right_primary[0] * value;
                        {
                            let mut col = self.h_mm.column_mut(idx);
                            family
                                .marginal_design
                                .axpy_row_into(row, scale, &mut col)
                                .expect("marginal axpy column mismatch");
                        }
                        {
                            let mut row_view = self.h_mm.row_mut(idx);
                            family
                                .marginal_design
                                .axpy_row_into(row, scale, &mut row_view)
                                .expect("marginal axpy row mismatch");
                        }
                    }
                }
                1 => {
                    // ψ in logslope, design in marginal: one-sided update of
                    // the rectangular cross block h_mg.
                    for (idx, &value) in psi_row.iter().enumerate() {
                        if value == 0.0 {
                            continue;
                        }
                        let mut col = self.h_mg.column_mut(idx);
                        family
                            .marginal_design
                            .axpy_row_into(row, right_primary[0] * value, &mut col)
                            .expect("marginal axpy column mismatch");
                    }
                }
                _ => {}
            }
        }
        if right_primary[1] != 0.0 {
            match psi_block_idx {
                0 => {
                    // ψ in marginal, design in logslope: rows of h_mg.
                    for (idx, &value) in psi_row.iter().enumerate() {
                        if value == 0.0 {
                            continue;
                        }
                        let mut row_view = self.h_mg.row_mut(idx);
                        family
                            .logslope_design
                            .axpy_row_into(row, right_primary[1] * value, &mut row_view)
                            .expect("logslope axpy row mismatch");
                    }
                }
                1 => {
                    // Both in the logslope block: symmetric update of h_gg.
                    for (idx, &value) in psi_row.iter().enumerate() {
                        if value == 0.0 {
                            continue;
                        }
                        let scale = right_primary[1] * value;
                        {
                            let mut col = self.h_gg.column_mut(idx);
                            family
                                .logslope_design
                                .axpy_row_into(row, scale, &mut col)
                                .expect("logslope axpy column mismatch");
                        }
                        {
                            let mut row_view = self.h_gg.row_mut(idx);
                            family
                                .logslope_design
                                .axpy_row_into(row, scale, &mut row_view)
                                .expect("logslope axpy row mismatch");
                        }
                    }
                }
                _ => {}
            }
        }
        // ψ × h and ψ × w cross terms land symmetrically in the dense
        // correction, indexed by global coefficient positions.
        if let Some(ref mut dc) = self.dense_correction {
            let psi_range = if psi_block_idx == 0 {
                slices.marginal.clone()
            } else {
                slices.logslope.clone()
            };
            if let (Some(ph), Some(bh)) = (primary.h.as_ref(), slices.h.as_ref()) {
                let h_part = right_primary.slice(ndarray::s![ph.start..ph.end]);
                for (li, gi) in psi_range.clone().enumerate() {
                    for (lj, gj) in bh.clone().enumerate() {
                        let val = psi_row[li] * h_part[lj];
                        dc[[gi, gj]] += val;
                        dc[[gj, gi]] += val;
                    }
                }
            }
            if let (Some(pw), Some(bw)) = (primary.w.as_ref(), slices.w.as_ref()) {
                let w_part = right_primary.slice(ndarray::s![pw.start..pw.end]);
                for (li, gi) in psi_range.enumerate() {
                    for (lj, gj) in bw.clone().enumerate() {
                        let val = psi_row[li] * w_part[lj];
                        dc[[gi, gj]] += val;
                        dc[[gj, gi]] += val;
                    }
                }
            }
        }
    }
    /// Adds α·ψ_i ψ_jᵀ (and its transpose) across the two surface blocks via
    /// the shared helper; block ids 0/1 map to marginal/logslope here.
    fn add_psi_psi_outer(
        &mut self,
        block_i: usize,
        psi_row_i: &Array1<f64>,
        block_j: usize,
        psi_row_j: &Array1<f64>,
        alpha: f64,
    ) {
        add_two_surface_psi_outer(
            block_i,
            psi_row_i,
            block_j,
            psi_row_j,
            alpha,
            0,
            1,
            &mut self.h_mm,
            &mut self.h_gg,
            &mut self.h_mg,
        );
    }
    /// Merges another accumulator (e.g. from a parallel worker) into this one.
    fn add(&mut self, other: &BernoulliBlockHessianAccumulator) {
        self.h_mm += &other.h_mm;
        self.h_gg += &other.h_gg;
        self.h_mg += &other.h_mg;
        if let (Some(ref mut dc), Some(ref odc)) = (
            self.dense_correction.as_mut(),
            other.dense_correction.as_ref(),
        ) {
            dc.scaled_add(1.0, odc);
        }
    }
    /// Assembles the full symmetric `total × total` Hessian from the blocks
    /// and the dense correction.
    fn to_dense(&self, slices: &BlockSlices) -> Array2<f64> {
        let mut out = Array2::zeros((slices.total, slices.total));
        out.slice_mut(s![slices.marginal.clone(), slices.marginal.clone()])
            .assign(&self.h_mm);
        out.slice_mut(s![slices.logslope.clone(), slices.logslope.clone()])
            .assign(&self.h_gg);
        out.slice_mut(s![slices.marginal.clone(), slices.logslope.clone()])
            .assign(&self.h_mg);
        out.slice_mut(s![slices.logslope.clone(), slices.marginal.clone()])
            .assign(&self.h_mg.t());
        if let Some(ref dc) = self.dense_correction {
            out += dc;
        }
        out
    }
    /// Converts into the matrix-free operator form, carrying the block ranges
    /// needed for slicing input/output vectors.
    fn into_operator(self, slices: &BlockSlices) -> BernoulliBlockHessianOperator {
        BernoulliBlockHessianOperator {
            h_mm: self.h_mm,
            h_gg: self.h_gg,
            h_mg: self.h_mg,
            dense_correction: self.dense_correction,
            marginal: slices.marginal.clone(),
            logslope: slices.logslope.clone(),
            total: slices.total,
        }
    }
}
/// Operator form of `BernoulliBlockHessianAccumulator`: the three dense
/// blocks plus optional full-size correction, together with the coefficient
/// ranges needed to slice vectors in `mul_vec`/`bilinear`.
struct BernoulliBlockHessianOperator {
    h_mm: Array2<f64>,
    h_gg: Array2<f64>,
    h_mg: Array2<f64>,
    dense_correction: Option<Array2<f64>>,
    marginal: std::ops::Range<usize>,
    logslope: std::ops::Range<usize>,
    total: usize,
}
impl HyperOperator for BernoulliBlockHessianOperator {
    fn dim(&self) -> usize {
        self.total
    }
    /// Block matrix-vector product: marginal and logslope slices of `v` are
    /// multiplied through the block matrices, then the dense correction (if
    /// any) is applied to the whole vector.
    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
        let v_m = v.slice(s![self.marginal.clone()]);
        let v_g = v.slice(s![self.logslope.clone()]);
        let mut out = Array1::zeros(self.total);
        {
            let mut o_m = out.slice_mut(s![self.marginal.clone()]);
            o_m += &self.h_mm.dot(&v_m);
            o_m += &self.h_mg.dot(&v_g);
        }
        {
            let mut o_g = out.slice_mut(s![self.logslope.clone()]);
            o_g += &self.h_mg.t().dot(&v_m);
            o_g += &self.h_gg.dot(&v_g);
        }
        if let Some(ref dc) = self.dense_correction {
            out += &dc.dot(v);
        }
        out
    }
    /// vᵀ H u computed blockwise without materializing the full matrix.
    fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
        let v_m = v.slice(s![self.marginal.clone()]);
        let v_g = v.slice(s![self.logslope.clone()]);
        let u_m = u.slice(s![self.marginal.clone()]);
        let u_g = u.slice(s![self.logslope.clone()]);
        let mut total = v_m.dot(&self.h_mm.dot(&u_m));
        total += v_g.dot(&self.h_gg.dot(&u_g));
        total += v_m.dot(&self.h_mg.dot(&u_g));
        total += v_g.dot(&self.h_mg.t().dot(&u_m));
        if let Some(ref dc) = self.dense_correction {
            total += v.dot(&dc.dot(u));
        }
        total
    }
    /// Materializes the full symmetric matrix (mirrors
    /// `BernoulliBlockHessianAccumulator::to_dense`).
    fn to_dense(&self) -> Array2<f64> {
        let mut out = Array2::zeros((self.total, self.total));
        out.slice_mut(s![self.marginal.clone(), self.marginal.clone()])
            .assign(&self.h_mm);
        out.slice_mut(s![self.logslope.clone(), self.logslope.clone()])
            .assign(&self.h_gg);
        out.slice_mut(s![self.marginal.clone(), self.logslope.clone()])
            .assign(&self.h_mg);
        out.slice_mut(s![self.logslope.clone(), self.marginal.clone()])
            .assign(&self.h_mg.t());
        if let Some(ref dc) = self.dense_correction {
            out += dc;
        }
        out
    }
    // All blocks are stored explicitly, so the operator is not implicit.
    fn is_implicit(&self) -> bool {
        false
    }
}
/// One denested partition cell together with its cached derivative-moment
/// state, stored per row in `RowCellMomentsBundle`.
#[derive(Clone)]
struct CachedDenestedCellMoments {
    partition_cell: exact_kernel::DenestedPartitionCell,
    state: exact_kernel::CellDerivativeMomentState,
}
/// Per-row cache of denested-cell moment states, valid for derivative
/// queries up to `max_degree`; rows without cached cells hold `None`.
#[derive(Clone)]
struct RowCellMomentsBundle {
    max_degree: usize,
    rows: Vec<Option<Vec<CachedDenestedCellMoments>>>,
}
impl RowCellMomentsBundle {
    /// Cached cells for `row`, or `None` when the row has no cache or the
    /// bundle was built for a lower degree than `required_degree` (the
    /// debug assertion flags the latter as a programming error).
    #[inline]
    fn row(&self, row: usize, required_degree: usize) -> Option<&[CachedDenestedCellMoments]> {
        debug_assert!(
            self.max_degree >= required_degree,
            "row cell moments bundle max_degree={} required_degree={}",
            self.max_degree,
            required_degree
        );
        (self.max_degree >= required_degree)
            .then(|| {
                self.rows
                    .get(row)
                    .and_then(Option::as_ref)
                    .map(Vec::as_slice)
            })
            .flatten()
    }
    /// Rough resident-memory estimate (bytes) for budgeting: per-row Option
    /// vectors + per-cell records + out-of-line moment storage, which is only
    /// heap-allocated once the degree exceeds the inline capacity.
    fn estimated_resident_bytes(n_rows: usize, n_cells: usize, max_degree: usize) -> usize {
        let row_vecs =
            n_rows.saturating_mul(std::mem::size_of::<Option<Vec<CachedDenestedCellMoments>>>());
        let cell_records = n_cells.saturating_mul(std::mem::size_of::<CachedDenestedCellMoments>());
        let required_moments = max_degree.saturating_add(1);
        let moment_payload = if required_moments > exact_kernel::CELL_MOMENT_INLINE_CAPACITY {
            n_cells
                .saturating_mul(required_moments)
                .saturating_mul(std::mem::size_of::<f64>())
        } else {
            0
        };
        row_vecs
            .saturating_add(cell_records)
            .saturating_add(moment_payload)
    }
}
/// Per-row context captured for the exact evaluation path: the calibrated
/// intercept, the `m_a` value, whether the intercept fast path applies, and
/// optionally the degree-9 denested cells for this row.
#[derive(Clone)]
struct BernoulliMarginalSlopeRowExactContext {
    intercept: f64,
    m_a: f64,
    intercept_fast_path: bool,
    degree9_cells: Option<Vec<CachedDenestedCellMoments>>,
}
/// Reusable per-row scratch buffers for the flexible (h/w) evaluation path,
/// all sized by the primary dimension; cleared between rows via `reset`.
struct BernoulliMarginalSlopeFlexRowScratch {
    m_u: Array1<f64>,
    m_au: Array1<f64>,
    m_uv: Array2<f64>,
    a_u: Array1<f64>,
    a_uv: Array2<f64>,
    rho: Array1<f64>,
    tau: Array1<f64>,
    du: Array1<f64>,
    // Per-row objective gradient in primary coordinates.
    grad: Array1<f64>,
    // Per-row objective Hessian in primary coordinates.
    hess: Array2<f64>,
}
impl BernoulliMarginalSlopeFlexRowScratch {
    /// Allocates zeroed scratch buffers sized for `primary_dim`.
    fn new(primary_dim: usize) -> Self {
        let vec = || Array1::zeros(primary_dim);
        let mat = || Array2::zeros((primary_dim, primary_dim));
        Self {
            m_u: vec(),
            m_au: vec(),
            m_uv: mat(),
            a_u: vec(),
            a_uv: mat(),
            rho: vec(),
            tau: vec(),
            du: vec(),
            grad: vec(),
            hess: mat(),
        }
    }
    /// Zeroes the first-order buffers; second-order buffers are only cleared
    /// when a Hessian pass will actually write into them.
    fn reset(&mut self, need_hessian: bool) {
        for buffer in [
            &mut self.m_u,
            &mut self.a_u,
            &mut self.rho,
            &mut self.tau,
            &mut self.du,
            &mut self.grad,
        ] {
            buffer.fill(0.0);
        }
        if need_hessian {
            self.m_au.fill(0.0);
            self.m_uv.fill(0.0);
            self.a_uv.fill(0.0);
            self.hess.fill(0.0);
        }
    }
}
#[inline]
fn accumulate_flex_block_grad_hess(
    primary_range: &std::ops::Range<usize>,
    scratch: &BernoulliMarginalSlopeFlexRowScratch,
    grad: &mut Array1<f64>,
    hess: &mut Array2<f64>,
) {
    // The gradient accumulates the *negated* slice of the per-row objective
    // gradient (score convention); the Hessian block adds directly.
    let (lo, hi) = (primary_range.start, primary_range.end);
    grad.scaled_add(-1.0, &scratch.grad.slice(s![lo..hi]));
    *hess += &scratch.hess.slice(s![lo..hi, lo..hi]);
}
use crate::families::jet_partitions::MultiDirJet;
// Coefficient-support masks used by the shared marginal-slope helpers:
// primary + h + w blocks all active.
const COEFF_SUPPORT_BHW: CoeffSupport = CoeffSupport {
    include_primary: true,
    include_h: true,
    include_w: true,
};
// Primary + w blocks only (no score-warp h block).
const COEFF_SUPPORT_BW: CoeffSupport = CoeffSupport {
    include_primary: true,
    include_h: false,
    include_w: true,
};
// w block only.
const COEFF_SUPPORT_W: CoeffSupport = CoeffSupport {
    include_primary: false,
    include_h: false,
    include_w: true,
};
/// Accumulates the exact-Newton log-likelihood, per-block gradients, and
/// block-diagonal Hessians; h/w parts are present only when those blocks
/// exist in the layout.
struct BernoulliExactNewtonAccumulator {
    ll: f64,
    grad_marginal: Array1<f64>,
    grad_logslope: Array1<f64>,
    hess_marginal: Array2<f64>,
    hess_logslope: Array2<f64>,
    grad_h: Option<Array1<f64>>,
    grad_w: Option<Array1<f64>>,
    hess_h: Option<Array2<f64>>,
    hess_w: Option<Array2<f64>>,
}
impl BernoulliExactNewtonAccumulator {
    /// Zero-initialized accumulator sized from the block layout; h/w parts
    /// are allocated only when the corresponding slices exist.
    fn new(slices: &BlockSlices) -> Self {
        Self {
            ll: 0.0,
            grad_marginal: Array1::zeros(slices.marginal.len()),
            grad_logslope: Array1::zeros(slices.logslope.len()),
            hess_marginal: Array2::zeros((slices.marginal.len(), slices.marginal.len())),
            hess_logslope: Array2::zeros((slices.logslope.len(), slices.logslope.len())),
            grad_h: slices.h.as_ref().map(|range| Array1::zeros(range.len())),
            grad_w: slices.w.as_ref().map(|range| Array1::zeros(range.len())),
            hess_h: slices
                .h
                .as_ref()
                .map(|range| Array2::zeros((range.len(), range.len()))),
            hess_w: slices
                .w
                .as_ref()
                .map(|range| Array2::zeros((range.len(), range.len()))),
        }
    }
    /// Folds one row's negative log-likelihood, primary gradient and primary
    /// Hessian diagonal blocks into the accumulator; the objective gradient
    /// is converted to score convention by the family helper, and the h/w
    /// slices are handled by `accumulate_flex_block_grad_hess`.
    fn add_pullback_block_diagonals(
        &mut self,
        family: &BernoulliMarginalSlopeFamily,
        row: usize,
        primary: &PrimarySlices,
        row_neglog: f64,
        scratch: &BernoulliMarginalSlopeFlexRowScratch,
    ) -> Result<(), String> {
        self.ll -= row_neglog;
        {
            let mut marginal = self.grad_marginal.view_mut();
            family.marginal_design.axpy_row_into(
                row,
                BernoulliMarginalSlopeFamily::exact_newton_score_component_from_objective_gradient(
                    scratch.grad[0],
                ),
                &mut marginal,
            )?;
        }
        {
            let mut logslope = self.grad_logslope.view_mut();
            family.logslope_design.axpy_row_into(
                row,
                BernoulliMarginalSlopeFamily::exact_newton_score_component_from_objective_gradient(
                    scratch.grad[1],
                ),
                &mut logslope,
            )?;
        }
        // Rank-1 updates with the per-row primary Hessian diagonal entries.
        family
            .marginal_design
            .syr_row_into(row, scratch.hess[[0, 0]], &mut self.hess_marginal)?;
        family
            .logslope_design
            .syr_row_into(row, scratch.hess[[1, 1]], &mut self.hess_logslope)?;
        if let (Some(primary_h), Some(grad_h), Some(hess_h)) = (
            primary.h.as_ref(),
            self.grad_h.as_mut(),
            self.hess_h.as_mut(),
        ) {
            accumulate_flex_block_grad_hess(primary_h, scratch, grad_h, hess_h);
        }
        if let (Some(primary_w), Some(grad_w), Some(hess_w)) = (
            primary.w.as_ref(),
            self.grad_w.as_mut(),
            self.hess_w.as_mut(),
        ) {
            accumulate_flex_block_grad_hess(primary_w, scratch, grad_w, hess_w);
        }
        Ok(())
    }
    /// Merges another accumulator (e.g. from a parallel worker) into this one.
    fn add(&mut self, other: &Self) {
        self.ll += other.ll;
        self.grad_marginal += &other.grad_marginal;
        self.grad_logslope += &other.grad_logslope;
        self.hess_marginal += &other.hess_marginal;
        self.hess_logslope += &other.hess_logslope;
        add_optional_vector(&mut self.grad_h, &other.grad_h);
        add_optional_vector(&mut self.grad_w, &other.grad_w);
        add_optional_matrix(&mut self.hess_h, &other.hess_h);
        add_optional_matrix(&mut self.hess_w, &other.hess_w);
    }
}
fn add_weighted_chunk_gradient(
    chunk: &Array2<f64>,
    weights: &Array1<f64>,
    target: &mut Array1<f64>,
) {
    // target += chunkᵀ · weights
    let contribution = crate::faer_ndarray::fast_atv(chunk, weights);
    *target += &contribution;
}
fn new_cell_moment_lru_cache(
    policy: &crate::resource::ResourcePolicy,
) -> Arc<exact_kernel::CellMomentLruCache> {
    // Bound the cache by the policy's single-materialization byte budget.
    Arc::new(exact_kernel::CellMomentLruCache::new(
        policy.max_single_materialization_bytes,
    ))
}
fn new_cell_moment_cache_stats() -> Arc<exact_kernel::CellMomentCacheStats> {
    // Fresh default-initialized stats, shareable across threads via Arc.
    let stats = exact_kernel::CellMomentCacheStats::default();
    Arc::new(stats)
}
fn add_weighted_chunk_gram(chunk: &Array2<f64>, weights: &Array1<f64>, target: &mut Array2<f64>) {
    // target += chunkᵀ · diag(weights) · chunk
    let gram = crate::faer_ndarray::fast_xt_diag_x(chunk, weights);
    *target += &gram;
}
// Rows processed per chunk in the blocked evaluation loops.
const ROW_CHUNK_SIZE: usize = 1024;
// Below this many rows, progress logging for exact work is suppressed.
const EXACT_WORK_LOG_MIN_ROWS: usize = 50_000;
/// Whether a dataset of `n` rows is large enough to warrant progress logging.
#[inline]
fn log_exact_work(n: usize) -> bool {
    !(n < EXACT_WORK_LOG_MIN_ROWS)
}
/// Snapshot cache for exact evaluation at a fixed β: block layouts, per-row
/// contexts, optional cell-moment and primary-Hessian caches, and lazily
/// built (once-per-snapshot) rigid third/fourth tensor caches keyed by
/// `fingerprint`.
#[derive(Clone)]
struct BernoulliMarginalSlopeExactEvalCache {
    slices: BlockSlices,
    primary: PrimarySlices,
    row_contexts: Vec<BernoulliMarginalSlopeRowExactContext>,
    row_cell_moments: Option<RowCellMomentsBundle>,
    row_primary_hessians: Option<Array2<f64>>,
    // Lazily materialized per-row rigid third-order tensors.
    rigid_third_full: crate::resource::RayonSafeOnce<Result<Vec<[[[f64; 2]; 2]; 2]>, String>>,
    fingerprint: CacheFingerprint,
    // Lazily materialized per-row rigid fourth-order tensors.
    rigid_fourth_full:
        crate::resource::RayonSafeOnce<Result<Vec<[[[[f64; 2]; 2]; 2]; 2]>, String>>,
}
/// `RowKernel` adapter for the rigid Bernoulli marginal-slope path: owns the
/// family, the converged block states, the coefficient layout, and lazily
/// built per-row third/fourth-order tensor caches.
struct BernoulliRigidRowKernel {
    family: BernoulliMarginalSlopeFamily,
    block_states: Vec<ParameterBlockState>,
    slices: BlockSlices,
    third_full_cache: crate::resource::RayonSafeOnce<Vec<[[[f64; 2]; 2]; 2]>>,
    fourth_full_cache: crate::resource::RayonSafeOnce<Vec<[[[[f64; 2]; 2]; 2]; 2]>>,
}
impl BernoulliRigidRowKernel {
    /// Builds the kernel, deriving the coefficient layout from the family;
    /// tensor caches start empty and are filled on first use.
    fn new(family: BernoulliMarginalSlopeFamily, block_states: Vec<ParameterBlockState>) -> Self {
        let slices = block_slices(&family);
        Self {
            family,
            block_states,
            slices,
            third_full_cache: crate::resource::RayonSafeOnce::new(),
            fourth_full_cache: crate::resource::RayonSafeOnce::new(),
        }
    }
    /// Per-row third-order tensors at the converged β, built once in
    /// parallel. Panics (by design) if a row jet errors at the snapshot.
    fn third_full_cache(&self) -> &[[[[f64; 2]; 2]; 2]] {
        self.third_full_cache
            .get_or_init(|| {
                (0..self.family.y.len())
                    .into_par_iter()
                    .map(|row| {
                        let marginal_eta = self.block_states[0].eta[row];
                        let marginal = self.family.marginal_link_map(marginal_eta)?;
                        let slope = self.block_states[1].eta[row];
                        self.family
                            .rigid_row_third_full(row, marginal_eta, marginal, slope)
                    })
                    .collect::<Result<Vec<_>, String>>()
                    .expect(
                        "BernoulliRigidRowKernel third-full cache build failed; \
                         per-row jet should not error at the converged β snapshot",
                    )
            })
            .as_slice()
    }
    /// Per-row fourth-order tensors; same construction and panic policy as
    /// `third_full_cache`.
    fn fourth_full_cache(&self) -> &[[[[[f64; 2]; 2]; 2]; 2]] {
        self.fourth_full_cache
            .get_or_init(|| {
                (0..self.family.y.len())
                    .into_par_iter()
                    .map(|row| {
                        let marginal_eta = self.block_states[0].eta[row];
                        let marginal = self.family.marginal_link_map(marginal_eta)?;
                        let slope = self.block_states[1].eta[row];
                        self.family
                            .rigid_row_fourth_full(row, marginal_eta, marginal, slope)
                    })
                    .collect::<Result<Vec<_>, String>>()
                    .expect(
                        "BernoulliRigidRowKernel fourth-full cache build failed; \
                         per-row jet should not error at the converged β snapshot",
                    )
            })
            .as_slice()
    }
}
impl RowKernel<2> for BernoulliRigidRowKernel {
fn n_rows(&self) -> usize {
self.family.y.len()
}
fn n_coefficients(&self) -> usize {
self.slices.total
}
fn row_kernel(&self, row: usize) -> Result<(f64, [f64; 2], [[f64; 2]; 2]), String> {
let marginal_eta = self.block_states[0].eta[row];
let marginal = self.family.marginal_link_map(marginal_eta)?;
let g = self.block_states[1].eta[row];
self.family
.rigid_row_kernel_eval(row, marginal_eta, marginal, g)
}
fn jacobian_action(&self, row: usize, d_beta: &[f64]) -> [f64; 2] {
let d_beta = ndarray::ArrayView1::from(d_beta);
[
self.family
.marginal_design
.dot_row_view(row, d_beta.slice(s![self.slices.marginal.clone()])),
self.family
.logslope_design
.dot_row_view(row, d_beta.slice(s![self.slices.logslope.clone()])),
]
}
fn jacobian_transpose_action(&self, row: usize, v: &[f64; 2], out: &mut [f64]) {
{
let mut m = ndarray::ArrayViewMut1::from(&mut out[self.slices.marginal.clone()]);
self.family
.marginal_design
.axpy_row_into(row, v[0], &mut m)
.expect("marginal axpy dim mismatch");
}
{
let mut g = ndarray::ArrayViewMut1::from(&mut out[self.slices.logslope.clone()]);
self.family
.logslope_design
.axpy_row_into(row, v[1], &mut g)
.expect("logslope axpy dim mismatch");
}
}
fn add_pullback_hessian(&self, row: usize, h: &[[f64; 2]; 2], target: &mut Array2<f64>) {
self.family
.marginal_design
.syr_row_into_view(
row,
h[0][0],
target.slice_mut(s![
self.slices.marginal.clone(),
self.slices.marginal.clone()
]),
)
.expect("marginal syr dim mismatch");
if h[0][1] != 0.0 {
self.family
.marginal_design
.row_outer_into_view(
row,
&self.family.logslope_design,
h[0][1],
target.slice_mut(s![
self.slices.marginal.clone(),
self.slices.logslope.clone()
]),
)
.expect("marginal-logslope outer dim mismatch");
self.family
.logslope_design
.row_outer_into_view(
row,
&self.family.marginal_design,
h[0][1],
target.slice_mut(s![
self.slices.logslope.clone(),
self.slices.marginal.clone()
]),
)
.expect("logslope-marginal outer dim mismatch");
}
self.family
.logslope_design
.syr_row_into_view(
row,
h[1][1],
target.slice_mut(s![
self.slices.logslope.clone(),
self.slices.logslope.clone()
]),
)
.expect("logslope syr dim mismatch");
}
fn add_diagonal_quadratic(&self, row: usize, h: &[[f64; 2]; 2], diag: &mut [f64]) {
    // Add only the diagonal of the pullback Hessian: h[0][0] and h[1][1] each
    // weight the element-wise square of their own block's design row; the
    // mixed term h[0][1] contributes nothing to the diagonal blocks here.
    let mut marginal_diag = ndarray::ArrayViewMut1::from(&mut diag[self.slices.marginal.clone()]);
    self.family
        .marginal_design
        .squared_axpy_row_into(row, h[0][0], &mut marginal_diag)
        .expect("marginal squared_axpy dim mismatch");
    let mut logslope_diag = ndarray::ArrayViewMut1::from(&mut diag[self.slices.logslope.clone()]);
    self.family
        .logslope_design
        .squared_axpy_row_into(row, h[1][1], &mut logslope_diag)
        .expect("logslope squared_axpy dim mismatch");
}
fn row_third_contracted(&self, row: usize, dir: &[f64; 2]) -> Result<[[f64; 2]; 2], String> {
    // Look up the cached full third-derivative tensor for this row and
    // contract it against the primary direction.
    let table = self.third_full_cache();
    let contracted = contract_third_full(&table[row], dir[0], dir[1]);
    Ok(contracted)
}
fn warm_up_directional_caches(&self) -> Result<(), String> {
    // Force both lazily-built derivative tables now so that subsequent
    // per-row contractions are pure lookups.
    let _third = self.third_full_cache();
    let _fourth = self.fourth_full_cache();
    Ok(())
}
fn row_fourth_contracted(
    &self,
    row: usize,
    dir_u: &[f64; 2],
    dir_v: &[f64; 2],
) -> Result<[[f64; 2]; 2], String> {
    // Look up the cached full fourth-derivative tensor and contract it
    // against both primary directions.
    let table = self.fourth_full_cache();
    let entry = &table[row];
    Ok(contract_fourth_full(
        entry, dir_u[0], dir_u[1], dir_v[0], dir_v[1],
    ))
}
}
/// Full exact-Newton joint Hessian workspace for the Bernoulli marginal-slope
/// family: owns a clone of the family, the block states it was built at, the
/// shared per-row evaluation cache, and the fit options in effect.
struct BernoulliMarginalSlopeExactNewtonJointHessianWorkspace {
    family: BernoulliMarginalSlopeFamily,
    block_states: Vec<ParameterBlockState>,
    cache: Arc<BernoulliMarginalSlopeExactEvalCache>,
    // Atomic so it can be bumped through &self; presumably counts Hessian
    // matrix-vector products for diagnostics — confirm at the increment sites.
    matvec_calls: AtomicUsize,
    options: BlockwiseFitOptions,
}
/// Lightweight workspace built during line search: carries the log-likelihood
/// already computed at `block_states` plus the per-row cache, and can lazily
/// upgrade to the full exact-Newton Hessian workspace exactly once.
struct BernoulliMarginalSlopeLineSearchWorkspace {
    family: BernoulliMarginalSlopeFamily,
    block_states: Vec<ParameterBlockState>,
    cache: BernoulliMarginalSlopeExactEvalCache,
    options: BlockwiseFitOptions,
    // Log-likelihood evaluated while the cache was being built.
    log_likelihood: f64,
    // Built at most once (RayonSafeOnce) when full Hessian access is needed.
    full_workspace: crate::resource::RayonSafeOnce<
        Result<Arc<BernoulliMarginalSlopeExactNewtonJointHessianWorkspace>, String>,
    >,
}
/// Workspace for exact-Newton joint hyperparameter (psi) derivative
/// evaluation: block states/specs, per-block psi-derivative descriptors, and
/// the shared per-row evaluation cache.
struct BernoulliMarginalSlopeExactNewtonJointPsiWorkspace {
    family: BernoulliMarginalSlopeFamily,
    block_states: Vec<ParameterBlockState>,
    specs: Vec<ParameterBlockSpec>,
    // One Vec of psi derivatives per parameter block.
    derivative_blocks: Vec<Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>>,
    cache: Arc<BernoulliMarginalSlopeExactEvalCache>,
    options: BlockwiseFitOptions,
}
// Rows are scored in chunks of this size so the line-search early-exit
// threshold can be checked between chunks instead of after a full pass.
const BERNOULLI_MARGSLOPE_LINE_SEARCH_EARLY_EXIT_CHUNK_ROWS: usize = 10_000;
/// Weighted sum of per-row log-likelihoods with chunked early exit: rows are
/// reduced in parallel chunk by chunk, and as soon as the running negative
/// log-likelihood exceeds `threshold` an `Err` describing the rejection is
/// returned instead of finishing the pass.
fn bernoulli_margslope_line_search_ll_with_early_exit<F>(
    weighted_rows: &[WeightedOuterRow],
    threshold: f64,
    row_ll: F,
) -> Result<f64, String>
where
    F: Fn(usize) -> Result<f64, String> + Sync,
{
    if !threshold.is_finite() {
        return Err(format!(
            "bernoulli marginal-slope early-exit threshold must be finite, got {threshold}"
        ));
    }
    let mut running_ll = 0.0;
    for rows in weighted_rows.chunks(BERNOULLI_MARGSLOPE_LINE_SEARCH_EARLY_EXIT_CHUNK_ROWS) {
        // Parallel weighted reduction over one chunk.
        let chunk_sum: f64 = rows
            .into_par_iter()
            .try_fold(
                || 0.0,
                |acc, wr| -> Result<_, String> { Ok(acc + wr.weight * row_ll(wr.index)?) },
            )
            .try_reduce(|| 0.0, |a, b| -> Result<_, String> { Ok(a + b) })?;
        running_ll += chunk_sum;
        // Reject as soon as the partial NLL already exceeds the threshold.
        if -running_ll > threshold {
            return Err(format!(
                "bernoulli marginal-slope line-search rejected early: partial_nll={} threshold={}",
                -running_ll, threshold
            ));
        }
    }
    Ok(running_ll)
}
impl BernoulliMarginalSlopeFamily {
#[inline]
fn probit_frailty_scale(&self) -> f64 {
    // Thin wrapper: derive the probit frailty scale from the family's
    // configured Gaussian frailty sd via the shared helper.
    probit_frailty_scale(self.gaussian_frailty_sd)
}
fn empirical_rigid_intercept_for_row(
    &self,
    row: usize,
    marginal: BernoulliMarginalLinkMap,
    slope: f64,
    nodes: &[f64],
    measure_weights: &[f64],
) -> Result<f64, String> {
    // Solve the calibrated intercept for one row under the empirical latent
    // measure, seeding the root solve from the per-row warm-start cache when
    // it holds a finite value, and writing the solved root back afterwards.
    let warm_start = self
        .intercept_warm_starts
        .as_ref()
        .and_then(|cache| cache.get(row))
        .map(|slot| f64::from_bits(slot.load(Ordering::Relaxed)))
        .filter(|value| value.is_finite());
    let root = empirical_intercept_from_marginal(
        marginal.mu,
        marginal.q,
        slope,
        self.probit_frailty_scale(),
        nodes,
        measure_weights,
        warm_start,
    )?;
    // Persist the root so nearby future solves converge faster.
    if let Some(cache) = self.intercept_warm_starts.as_ref() {
        if let Some(slot) = cache.get(row) {
            slot.store(root.to_bits(), Ordering::Relaxed);
        }
    }
    Ok(root)
}
fn empirical_rigid_calibration_jets(
    &self,
    intercept: &MultiDirJet,
    mu: &MultiDirJet,
    slope: &MultiDirJet,
    nodes: &[f64],
    measure_weights: &[f64],
) -> (MultiDirJet, MultiDirJet) {
    // Quadrature accumulation, in jet arithmetic, of the calibration residual
    //   F(a) = Σ_k w_k Φ(a + s·slope·z_k) − μ
    // together with its intercept derivative
    //   F_a(a) = Σ_k w_k φ(a + s·slope·z_k).
    let n_dirs = intercept.coeffs.len().trailing_zeros() as usize;
    let scaled_slope = slope.scale(self.probit_frailty_scale());
    let mut residual = mu.scale(-1.0);
    let mut residual_a = MultiDirJet::zero(n_dirs);
    for (&node, &weight) in nodes.iter().zip(measure_weights.iter()) {
        let eta = intercept.add(&scaled_slope.scale(node));
        let cdf = eta.compose_unary(unary_derivatives_normal_cdf(eta.coeff(0)));
        let pdf = eta.compose_unary(unary_derivatives_normal_pdf(eta.coeff(0)));
        residual = residual.add(&cdf.scale(weight));
        residual_a = residual_a.add(&pdf.scale(weight));
    }
    (residual, residual_a)
}
fn empirical_rigid_neglog_only(
    &self,
    row: usize,
    marginal: BernoulliMarginalLinkMap,
    slope: f64,
    nodes: &[f64],
    measure_weights: &[f64],
) -> Result<f64, String> {
    // Weighted negative log-likelihood of one row under the empirical latent
    // measure: solve the calibrated intercept, form the observed predictor,
    // and evaluate -w · log Φ of the y-signed predictor.
    let intercept =
        self.empirical_rigid_intercept_for_row(row, marginal, slope, nodes, measure_weights)?;
    let observed_eta = intercept + slope * self.probit_frailty_scale() * self.z[row];
    let signed = (2.0 * self.y[row] - 1.0) * observed_eta;
    let (logcdf, _) = signed_probit_logcdf_and_mills_ratio(signed);
    if !logcdf.is_finite() {
        return Err(format!(
            "empirical rigid neglog_only: non-finite log Φ at row {row}"
        ));
    }
    Ok(-self.weights[row] * logcdf)
}
fn rigid_row_neglog_only(
    &self,
    row: usize,
    marginal: BernoulliMarginalLinkMap,
    slope: f64,
) -> Result<f64, String> {
    // Dispatch on the latent measure: rows with an empirical grid integrate
    // over it; otherwise the closed-form rigid probit kernel is used.
    if let Some(grid) = self.latent_measure.empirical_grid_for_training_row(row)? {
        self.empirical_rigid_neglog_only(row, marginal, slope, &grid.nodes, &grid.weights)
    } else {
        RigidProbitKernel::neglog_only(
            marginal.q,
            slope,
            self.z[row],
            self.y[row],
            self.weights[row],
            self.probit_frailty_scale(),
        )
    }
}
fn empirical_rigid_neglog_jet(
    &self,
    row: usize,
    marginal_eta: f64,
    marginal: BernoulliMarginalLinkMap,
    slope: f64,
    directions: &[[f64; 2]],
    nodes: &[f64],
    measure_weights: &[f64],
) -> Result<MultiDirJet, String> {
    // Multi-directional jet of the row's negative log-likelihood under the
    // empirical latent measure.  Each direction perturbs (marginal_eta, slope);
    // the calibrated intercept is an implicit function of those inputs and its
    // derivative coefficients are recovered by Newton steps in jet arithmetic
    // around the already-solved scalar root.
    let n_dirs = directions.len();
    let marginal_first = directions.iter().map(|dir| dir[0]).collect::<Vec<_>>();
    let slope_first = directions.iter().map(|dir| dir[1]).collect::<Vec<_>>();
    let marginal_eta_jet = MultiDirJet::linear(n_dirs, marginal_eta, &marginal_first);
    // Compose eta -> mu through the marginal link, using derivatives up to
    // fourth order supplied by the link map.
    let mu_jet = marginal_eta_jet.compose_unary([
        marginal.mu,
        marginal.mu1,
        marginal.mu2,
        marginal.mu3,
        marginal.mu4,
    ]);
    let slope_jet = MultiDirJet::linear(n_dirs, slope, &slope_first);
    let intercept_root =
        self.empirical_rigid_intercept_for_row(row, marginal, slope, nodes, measure_weights)?;
    let mut intercept_jet = MultiDirJet::constant(n_dirs, intercept_root);
    // Fixed number of Newton corrections a <- a - F/F_a in jet arithmetic.
    // The zeroth coefficient is pinned back to the solved scalar root every
    // iteration so only the derivative coefficients are refined.
    for _ in 0..6 {
        let (f, f_a) = self.empirical_rigid_calibration_jets(
            &intercept_jet,
            &mu_jet,
            &slope_jet,
            nodes,
            measure_weights,
        );
        let inv_f_a = f_a.compose_unary(unary_derivatives_reciprocal(f_a.coeff(0)));
        intercept_jet = intercept_jet.add(&f.mul(&inv_f_a).scale(-1.0));
        intercept_jet.coeffs[0] = intercept_root;
    }
    // Observed predictor eta = a + scale·slope·z, signed by the ±1 outcome,
    // then composed with -w·log Φ.
    let observed_slope = slope_jet.scale(self.probit_frailty_scale());
    let observed_eta = intercept_jet.add(&observed_slope.scale(self.z[row]));
    let signed = observed_eta.scale(2.0 * self.y[row] - 1.0);
    Ok(signed.compose_unary(unary_derivatives_neglog_phi(
        signed.coeff(0),
        self.weights[row],
    )))
}
fn primary_component_jet(
    n_dirs: usize,
    base: f64,
    directions: &[Array1<f64>],
    idx: usize,
) -> Result<MultiDirJet, String> {
    // Build a first-order jet for primary coefficient `idx`: value `base`,
    // with each direction contributing its idx-th component as the linear term.
    let mut first_order = Vec::with_capacity(directions.len());
    for dir in directions {
        let component = dir.get(idx).copied().ok_or_else(|| {
            format!(
                "bernoulli empirical flex direction length {} is too short for primary index {idx}",
                dir.len()
            )
        })?;
        first_order.push(component);
    }
    Ok(MultiDirJet::linear(n_dirs, base, &first_order))
}
fn local_cubic_value_jet(cubic: exact_kernel::LocalSpanCubic, x: &MultiDirJet) -> MultiDirJet {
    // Evaluate c0 + c1·t + c2·t² + c3·t³ with t = x − left, in jet arithmetic.
    // Powers are built explicitly (t, t·t, t²·t) — same operation order as the
    // scalar evaluation, no Horner rearrangement.
    let n_dirs = x.coeffs.len().trailing_zeros() as usize;
    let shifted = x.add(&MultiDirJet::constant(n_dirs, -cubic.left));
    let shifted_sq = shifted.mul(&shifted);
    let shifted_cu = shifted_sq.mul(&shifted);
    MultiDirJet::constant(n_dirs, cubic.c0)
        .add(&shifted.scale(cubic.c1))
        .add(&shifted_sq.scale(cubic.c2))
        .add(&shifted_cu.scale(cubic.c3))
}
fn local_cubic_first_derivative_jet(
    cubic: exact_kernel::LocalSpanCubic,
    x: &MultiDirJet,
) -> MultiDirJet {
    // Derivative of the local cubic: c1 + 2·c2·t + 3·c3·t², t = x − left.
    let n_dirs = x.coeffs.len().trailing_zeros() as usize;
    let shifted = x.add(&MultiDirJet::constant(n_dirs, -cubic.left));
    let shifted_sq = shifted.mul(&shifted);
    MultiDirJet::constant(n_dirs, cubic.c1)
        .add(&shifted.scale(2.0 * cubic.c2))
        .add(&shifted_sq.scale(3.0 * cubic.c3))
}
fn empirical_flex_eta_and_eta_a_jet_at_z(
    &self,
    primary: &PrimarySlices,
    a_jet: &MultiDirJet,
    b_jet: &MultiDirJet,
    beta_h: Option<&Array1<f64>>,
    beta_w: Option<&Array1<f64>>,
    directions: &[Array1<f64>],
    z: f64,
) -> Result<(MultiDirJet, MultiDirJet), String> {
    // Jet of the flexible observed predictor eta(z) and of its derivative
    // with respect to the intercept a, evaluated at latent value z:
    //   inside = a + b·z (+ b·h(z) when a score-warp range is active)
    //   eta    = scale · (inside + w(u)),   u = a + b·z
    //   eta_a  = scale · (1 + w'(u))   (inside is linear in a and du/da = 1)
    // where h is the score-warp deviation and w the link deviation.
    let n_dirs = directions.len();
    let mut inside = a_jet.add(&b_jet.scale(z));
    if let Some(h_range) = primary.h.as_ref() {
        // Score-warp term: h(z) is a cubic-spline deviation whose coefficients
        // carry jet components from the primary directions.
        let runtime = self.score_warp.as_ref().ok_or_else(|| {
            "empirical flex score-warp primary range without runtime".to_string()
        })?;
        let beta_h = beta_h.ok_or_else(|| {
            "empirical flex score-warp primary range without beta".to_string()
        })?;
        let mut h_jet = MultiDirJet::zero(n_dirs);
        Self::for_each_deviation_basis_cubic_at(
            runtime,
            h_range,
            z,
            "empirical flex score-warp",
            |local_idx, idx, basis_span| {
                // Basis is evaluated at the fixed scalar z, so only the
                // coefficient carries jet structure.
                let basis_value = basis_span.evaluate(z);
                let beta_jet =
                    Self::primary_component_jet(n_dirs, beta_h[local_idx], directions, idx)?;
                h_jet = h_jet.add(&beta_jet.scale(basis_value));
                Ok(())
            },
        )?;
        inside = inside.add(&b_jet.mul(&h_jet));
    }
    // u is the undeviated predictor; the link deviation w is evaluated at u,
    // so here both the basis value and its first derivative are jets in u.
    let u_jet = a_jet.add(&b_jet.scale(z));
    let mut w_jet = MultiDirJet::zero(n_dirs);
    let mut w_prime_jet = MultiDirJet::zero(n_dirs);
    if let Some(w_range) = primary.w.as_ref() {
        let runtime = self.link_dev.as_ref().ok_or_else(|| {
            "empirical flex link-deviation primary range without runtime".to_string()
        })?;
        let beta_w = beta_w.ok_or_else(|| {
            "empirical flex link-deviation primary range without beta".to_string()
        })?;
        // Basis span selection uses the scalar value of u.
        let u0 = u_jet.coeff(0);
        Self::for_each_deviation_basis_cubic_at(
            runtime,
            w_range,
            u0,
            "empirical flex link-deviation",
            |local_idx, idx, basis_span| {
                let beta_jet =
                    Self::primary_component_jet(n_dirs, beta_w[local_idx], directions, idx)?;
                let basis_value = Self::local_cubic_value_jet(basis_span, &u_jet);
                let basis_derivative =
                    Self::local_cubic_first_derivative_jet(basis_span, &u_jet);
                w_jet = w_jet.add(&beta_jet.mul(&basis_value));
                w_prime_jet = w_prime_jet.add(&beta_jet.mul(&basis_derivative));
                Ok(())
            },
        )?;
    }
    let scale = self.probit_frailty_scale();
    let eta = inside.add(&w_jet).scale(scale);
    let eta_a = MultiDirJet::constant(n_dirs, 1.0)
        .add(&w_prime_jet)
        .scale(scale);
    Ok((eta, eta_a))
}
fn empirical_flex_calibration_jets(
    &self,
    primary: &PrimarySlices,
    a_jet: &MultiDirJet,
    mu_jet: &MultiDirJet,
    b_jet: &MultiDirJet,
    beta_h: Option<&Array1<f64>>,
    beta_w: Option<&Array1<f64>>,
    directions: &[Array1<f64>],
    grid: &EmpiricalZGrid,
) -> Result<(MultiDirJet, MultiDirJet), String> {
    // Flexible-kernel calibration residual and its intercept derivative,
    // accumulated over the empirical grid in jet arithmetic:
    //   F   = Σ_k w_k Φ(η(z_k)) − μ
    //   F_a = Σ_k w_k φ(η(z_k)) · η_a(z_k)
    let n_dirs = directions.len();
    let mut residual = mu_jet.scale(-1.0);
    let mut residual_a = MultiDirJet::zero(n_dirs);
    for (&node, &weight) in grid.nodes.iter().zip(grid.weights.iter()) {
        let (eta, eta_a) = self.empirical_flex_eta_and_eta_a_jet_at_z(
            primary, a_jet, b_jet, beta_h, beta_w, directions, node,
        )?;
        let cdf = eta.compose_unary(unary_derivatives_normal_cdf(eta.coeff(0)));
        let pdf = eta.compose_unary(unary_derivatives_normal_pdf(eta.coeff(0)));
        residual = residual.add(&cdf.scale(weight));
        residual_a = residual_a.add(&pdf.mul(&eta_a).scale(weight));
    }
    Ok((residual, residual_a))
}
fn empirical_flex_neglog_jet(
    &self,
    row: usize,
    primary: &PrimarySlices,
    q: f64,
    b: f64,
    beta_h: Option<&Array1<f64>>,
    beta_w: Option<&Array1<f64>>,
    row_ctx: &BernoulliMarginalSlopeRowExactContext,
    directions: &[Array1<f64>],
    grid: &EmpiricalZGrid,
) -> Result<MultiDirJet, String> {
    // Multi-directional jet of the flexible-kernel negative log-likelihood for
    // one row.  Directions live in the full primary coordinate space
    // (dimension primary.total); the calibrated intercept is treated as an
    // implicit function and its jet is refined by Newton steps around the
    // precomputed root in `row_ctx`.
    let n_dirs = directions.len();
    // Coefficient masks are indexed by direction bits, so the jet storage is
    // 2^n_dirs; cap at 6 directions.
    if n_dirs > 6 {
        return Err(format!(
            "bernoulli empirical flex jet supports at most 6 directions, got {n_dirs}"
        ));
    }
    for dir in directions {
        if dir.len() != primary.total {
            return Err(format!(
                "bernoulli empirical flex direction length {} != primary dimension {}",
                dir.len(),
                primary.total
            ));
        }
    }
    // The precomputed root and calibration slope must be usable (m_a > 0).
    if !(row_ctx.intercept.is_finite() && row_ctx.m_a.is_finite() && row_ctx.m_a > 0.0) {
        return Err("non-finite empirical flexible row context in jet contraction".to_string());
    }
    let marginal = self.marginal_link_map(q)?;
    // q and b enter linearly through their own primary coordinates; mu is the
    // link composition of q up to fourth order.
    let q_jet = Self::primary_component_jet(n_dirs, q, directions, primary.q)?;
    let mu_jet = q_jet.compose_unary([
        marginal.mu,
        marginal.mu1,
        marginal.mu2,
        marginal.mu3,
        marginal.mu4,
    ]);
    let b_jet = Self::primary_component_jet(n_dirs, b, directions, primary.logslope)?;
    let intercept_root = row_ctx.intercept;
    let mut a_jet = MultiDirJet::constant(n_dirs, intercept_root);
    // Newton corrections a <- a - F/F_a in jet arithmetic; the zeroth
    // coefficient is pinned back to the solved root each iteration.
    for _ in 0..6 {
        let (f, f_a) = self.empirical_flex_calibration_jets(
            primary, &a_jet, &mu_jet, &b_jet, beta_h, beta_w, directions, grid,
        )?;
        if !(f_a.coeff(0).is_finite() && f_a.coeff(0) > 0.0) {
            return Err(format!(
                "empirical flex calibration jet has invalid F_a={}",
                f_a.coeff(0)
            ));
        }
        let inv_f_a = f_a.compose_unary(unary_derivatives_reciprocal(f_a.coeff(0)));
        a_jet = a_jet.add(&f.mul(&inv_f_a).scale(-1.0));
        a_jet.coeffs[0] = intercept_root;
    }
    // Observed predictor at the row's own z, signed by the ±1 outcome, then
    // composed with -w·log Φ.
    let (eta_observed, _) = self.empirical_flex_eta_and_eta_a_jet_at_z(
        primary,
        &a_jet,
        &b_jet,
        beta_h,
        beta_w,
        directions,
        self.z[row],
    )?;
    let signed = eta_observed.scale(2.0 * self.y[row] - 1.0);
    Ok(signed.compose_unary(unary_derivatives_neglog_phi(
        signed.coeff(0),
        self.weights[row],
    )))
}
fn unit_primary_direction(dim: usize, idx: usize) -> Array1<f64> {
    // Standard basis vector e_idx in R^dim.
    Array1::from_shape_fn(dim, |i| if i == idx { 1.0 } else { 0.0 })
}
fn empirical_flex_row_third_contracted_recompute(
    &self,
    row: usize,
    primary: &PrimarySlices,
    q: f64,
    b: f64,
    beta_h: Option<&Array1<f64>>,
    beta_w: Option<&Array1<f64>>,
    row_ctx: &BernoulliMarginalSlopeRowExactContext,
    dir: &Array1<f64>,
    grid: &EmpiricalZGrid,
) -> Result<Array2<f64>, String> {
    // Recompute T(·, ·, dir): for each unordered pair (u, v) of primary unit
    // directions, evaluate a 3-direction jet (e_u, e_v, dir) and read the fully
    // mixed third-order coefficient (mask 1|2|4 = one derivative per direction).
    let r = primary.total;
    if dir.len() != r {
        return Err(format!(
            "bernoulli empirical flex third contraction direction length {} != primary dimension {r}",
            dir.len()
        ));
    }
    // A zero contraction direction yields the zero matrix without any jet work.
    if dir.iter().all(|value| *value == 0.0) {
        return Ok(Array2::<f64>::zeros((r, r)));
    }
    let unit_dirs: Vec<Array1<f64>> = (0..r)
        .map(|idx| Self::unit_primary_direction(r, idx))
        .collect();
    let contraction_dir = dir.to_owned();
    let mut tensor = Array2::<f64>::zeros((r, r));
    for u in 0..r {
        for v in u..r {
            let jet_dirs = vec![
                unit_dirs[u].clone(),
                unit_dirs[v].clone(),
                contraction_dir.clone(),
            ];
            let jet = self.empirical_flex_neglog_jet(
                row, primary, q, b, beta_h, beta_w, row_ctx, &jet_dirs, grid,
            )?;
            let value = jet.coeff(1 | 2 | 4);
            // Fill both triangles — the contracted tensor is symmetric in (u, v).
            tensor[[u, v]] = value;
            tensor[[v, u]] = value;
        }
    }
    Ok(tensor)
}
fn empirical_flex_row_fourth_contracted_recompute(
    &self,
    row: usize,
    primary: &PrimarySlices,
    q: f64,
    b: f64,
    beta_h: Option<&Array1<f64>>,
    beta_w: Option<&Array1<f64>>,
    row_ctx: &BernoulliMarginalSlopeRowExactContext,
    dir_u: &Array1<f64>,
    dir_v: &Array1<f64>,
    grid: &EmpiricalZGrid,
) -> Result<Array2<f64>, String> {
    // Recompute T(·, ·, dir_u, dir_v): for each unordered pair (p, q_idx) of
    // primary unit directions, evaluate a 4-direction jet and read the fully
    // mixed fourth-order coefficient (mask 1|2|4|8 = one derivative per
    // direction).
    let r = primary.total;
    if dir_u.len() != r || dir_v.len() != r {
        return Err(format!(
            "bernoulli empirical flex fourth contraction direction lengths ({},{}) != primary dimension {r}",
            dir_u.len(),
            dir_v.len()
        ));
    }
    // Either contraction direction being zero makes the result the zero matrix.
    if dir_u.iter().all(|value| *value == 0.0) || dir_v.iter().all(|value| *value == 0.0) {
        return Ok(Array2::<f64>::zeros((r, r)));
    }
    let basis_dirs = (0..r)
        .map(|idx| Self::unit_primary_direction(r, idx))
        .collect::<Vec<_>>();
    let dir_u_owned = dir_u.to_owned();
    let dir_v_owned = dir_v.to_owned();
    let mut out = Array2::<f64>::zeros((r, r));
    for p in 0..r {
        for q_idx in p..r {
            let directions = vec![
                basis_dirs[p].clone(),
                basis_dirs[q_idx].clone(),
                dir_u_owned.clone(),
                dir_v_owned.clone(),
            ];
            let jet = self.empirical_flex_neglog_jet(
                row,
                primary,
                q,
                b,
                beta_h,
                beta_w,
                row_ctx,
                &directions,
                grid,
            )?;
            let val = jet.coeff(1 | 2 | 4 | 8);
            // Symmetric in (p, q_idx): fill both triangles.
            out[[p, q_idx]] = val;
            out[[q_idx, p]] = val;
        }
    }
    Ok(out)
}
fn rigid_row_kernel_eval(
    &self,
    row: usize,
    marginal_eta: f64,
    marginal: BernoulliMarginalLinkMap,
    slope: f64,
) -> Result<(f64, [f64; 2], [[f64; 2]; 2]), String> {
    // Value, gradient and Hessian of the row's negative log-likelihood in the
    // (q, g) primary coordinates.  Gaussian latent measure uses the
    // closed-form rigid probit kernel; empirical rows differentiate a jet.
    match self.latent_measure.empirical_grid_for_training_row(row)? {
        None => {
            let kernel = RigidProbitKernel::new(
                marginal.q,
                slope,
                self.z[row],
                self.y[row],
                self.weights[row],
                self.probit_frailty_scale(),
            )?;
            Ok((
                -self.weights[row] * kernel.logcdf,
                rigid_transformed_gradient(marginal, &kernel),
                rigid_transformed_hessian(marginal, &kernel),
            ))
        }
        Some(grid) => {
            // Four directions alternating q, g (bits 1,2,4,8): first-order
            // terms come from bits 1 (q) and 2 (g); second-order terms pair a
            // repeated coordinate via the duplicate directions, e.g.
            // coeff(1|4) = d²/dq², coeff(1|2) = d²/dq dg, coeff(2|8) = d²/dg².
            let jet = self.empirical_rigid_neglog_jet(
                row,
                marginal_eta,
                marginal,
                slope,
                &[[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]],
                &grid.nodes,
                &grid.weights,
            )?;
            Ok((
                jet.coeff(0),
                [jet.coeff(1), jet.coeff(2)],
                [
                    [jet.coeff(1 | 4), jet.coeff(1 | 2)],
                    [jet.coeff(1 | 2), jet.coeff(2 | 8)],
                ],
            ))
        }
    }
}
fn rigid_row_third_contracted(
    &self,
    row: usize,
    marginal_eta: f64,
    marginal: BernoulliMarginalLinkMap,
    slope: f64,
    dir_q: f64,
    dir_g: f64,
) -> Result<[[f64; 2]; 2], String> {
    // Build the full 2x2x2 third-derivative tensor for this row, then
    // contract it against the primary direction (dir_q, dir_g).
    let tensor = self.rigid_row_third_full(row, marginal_eta, marginal, slope)?;
    Ok(contract_third_full(&tensor, dir_q, dir_g))
}
fn rigid_third_full_cached<'a>(
    &self,
    block_states: &[ParameterBlockState],
    cache: &'a BernoulliMarginalSlopeExactEvalCache,
    row: usize,
) -> Result<&'a [[[f64; 2]; 2]; 2], String> {
    // Build the per-row third-derivative table once (in parallel) and serve
    // all subsequent lookups straight from the cache.
    let table = cache
        .rigid_third_full
        .get_or_init(|| {
            (0..self.y.len())
                .into_par_iter()
                .map(|idx| {
                    let eta_marginal = block_states[0].eta[idx];
                    let link = self.marginal_link_map(eta_marginal)?;
                    let eta_slope = block_states[1].eta[idx];
                    self.rigid_row_third_full(idx, eta_marginal, link, eta_slope)
                })
                .collect::<Result<Vec<_>, String>>()
        })
        .as_ref()
        .map_err(|err| err.clone())?;
    Ok(&table[row])
}
fn rigid_fourth_full_cached<'a>(
    &self,
    block_states: &[ParameterBlockState],
    cache: &'a BernoulliMarginalSlopeExactEvalCache,
    row: usize,
) -> Result<&'a [[[[f64; 2]; 2]; 2]; 2], String> {
    // Build the per-row fourth-derivative table once (in parallel) on first
    // use; later calls just index into the cached Vec.
    let stored = cache.rigid_fourth_full.get_or_init(|| {
        (0..self.y.len())
            .into_par_iter()
            .map(|r| {
                let marginal_eta = block_states[0].eta[r];
                let marginal = self.marginal_link_map(marginal_eta)?;
                let slope = block_states[1].eta[r];
                self.rigid_row_fourth_full(r, marginal_eta, marginal, slope)
            })
            .collect::<Result<Vec<_>, String>>()
    });
    // A build error is cached too; clone it out for this caller.
    let table = stored.as_ref().map_err(|err| err.clone())?;
    Ok(&table[row])
}
fn rigid_row_third_full(
    &self,
    row: usize,
    marginal_eta: f64,
    marginal: BernoulliMarginalLinkMap,
    slope: f64,
) -> Result<[[[f64; 2]; 2]; 2], String> {
    // Full third-derivative tensor of the row's negative log-likelihood in
    // the (q, g) primary coordinates.
    let grid = match self.latent_measure.empirical_grid_for_training_row(row)? {
        Some(grid) => grid,
        None => {
            // Gaussian latent measure: closed-form rigid probit kernel.
            let kernel = RigidProbitKernel::new(
                marginal.q,
                slope,
                self.z[row],
                self.y[row],
                self.weights[row],
                self.probit_frailty_scale(),
            )?;
            return Ok(rigid_transformed_third_full(marginal, &kernel));
        }
    };
    // Six directions alternating q, g; bit k of a coefficient mask selects
    // direction k, so the q copies are bits {1, 4, 16} and the g copies
    // {2, 8, 32}.  Each distinct symmetric component reads one mixed mask.
    let jet = self.empirical_rigid_neglog_jet(
        row,
        marginal_eta,
        marginal,
        slope,
        &[
            [1.0, 0.0],
            [0.0, 1.0],
            [1.0, 0.0],
            [0.0, 1.0],
            [1.0, 0.0],
            [0.0, 1.0],
        ],
        &grid.nodes,
        &grid.weights,
    )?;
    let t_qqq = jet.coeff(1 | 4 | 16);
    let t_qqg = jet.coeff(1 | 4 | 2);
    let t_qgg = jet.coeff(1 | 2 | 8);
    let t_ggg = jet.coeff(2 | 8 | 32);
    Ok(third_full_from_symmetric_components(
        t_qqq, t_qqg, t_qgg, t_ggg,
    ))
}
fn rigid_row_fourth_full(
    &self,
    row: usize,
    marginal_eta: f64,
    marginal: BernoulliMarginalLinkMap,
    slope: f64,
) -> Result<[[[[f64; 2]; 2]; 2]; 2], String> {
    // Full fourth-derivative tensor of the row's negative log-likelihood in
    // the (q, g) primary coordinates.
    let grid = match self.latent_measure.empirical_grid_for_training_row(row)? {
        Some(grid) => grid,
        None => {
            // Gaussian latent measure: closed-form rigid probit kernel.
            let kernel = RigidProbitKernel::new(
                marginal.q,
                slope,
                self.z[row],
                self.y[row],
                self.weights[row],
                self.probit_frailty_scale(),
            )?;
            return Ok(rigid_transformed_fourth_full(marginal, &kernel));
        }
    };
    // Eight directions alternating q, g; q copies are mask bits
    // {1, 4, 16, 64} and g copies {2, 8, 32, 128}.  One mixed mask per
    // distinct symmetric component.
    let jet = self.empirical_rigid_neglog_jet(
        row,
        marginal_eta,
        marginal,
        slope,
        &[
            [1.0, 0.0],
            [0.0, 1.0],
            [1.0, 0.0],
            [0.0, 1.0],
            [1.0, 0.0],
            [0.0, 1.0],
            [1.0, 0.0],
            [0.0, 1.0],
        ],
        &grid.nodes,
        &grid.weights,
    )?;
    let t_qqqq = jet.coeff(1 | 4 | 16 | 64);
    let t_qqqg = jet.coeff(1 | 4 | 16 | 2);
    let t_qqgg = jet.coeff(1 | 4 | 2 | 8);
    let t_qggg = jet.coeff(1 | 2 | 8 | 32);
    let t_gggg = jet.coeff(2 | 8 | 32 | 128);
    Ok(fourth_full_from_symmetric_components(
        t_qqqq, t_qqqg, t_qqgg, t_qggg, t_gggg,
    ))
}
// Below this many observations the line search scores every row; at or above
// it an automatic outer-score subsample may be installed for the trial pass.
pub(crate) const AUTO_LINE_SEARCH_SUBSAMPLE_N: usize = 30_000;
pub(crate) fn log_likelihood_only_with_options(
    &self,
    block_states: &[ParameterBlockState],
    options: &BlockwiseFitOptions,
) -> Result<f64, String> {
    // Weighted log-likelihood at `block_states`.  Supports an early-exit
    // threshold (line search) and, for large n with no caller-provided
    // subsample, installs an automatic y-stratified row subsample: the cheap
    // subsampled pass screens the step, and only if it passes is the exact
    // full-data likelihood recomputed.
    self.validate_exact_monotonicity(block_states)?;
    let flex_active = self.effective_flex_active(block_states)?;
    let n = self.y.len();
    let (effective_options, trial_subsample_installed) =
        if options.early_exit_threshold.is_some()
            && options.outer_score_subsample.is_none()
            && n >= Self::AUTO_LINE_SEARCH_SUBSAMPLE_N
        {
            // Stratify on the binary outcome so both classes stay represented
            // in the automatic subsample.
            let stratum_secondary: Vec<u8> = self
                .y
                .iter()
                .map(|v| if *v > 0.5 { 1u8 } else { 0u8 })
                .collect();
            let z_slice = self
                .z
                .as_slice()
                .expect("BMS family z must be contiguous for line-search subsample");
            let auto_opts =
                crate::families::marginal_slope_shared::AutoOuterSubsampleOptions::default();
            match crate::families::marginal_slope_shared::auto_outer_score_subsample(
                z_slice,
                Some(&stratum_secondary),
                &auto_opts,
            ) {
                Some(subsample) => {
                    // Clone the options and install the subsample for the trial pass.
                    let mut cloned = options.clone();
                    cloned.outer_score_subsample = Some(std::sync::Arc::new(subsample));
                    (std::borrow::Cow::Owned(cloned), true)
                }
                None => (std::borrow::Cow::Borrowed(options), false),
            }
        } else {
            (std::borrow::Cow::Borrowed(options), false)
        };
    let options: &BlockwiseFitOptions = &effective_options;
    let weighted_rows = outer_weighted_rows(options, n);
    if !flex_active {
        // Rigid-kernel path: per-row closed-form / empirical-grid neglog.
        let b = &block_states[1].eta;
        let row_ll = |i: usize| -> Result<f64, String> {
            let marginal_eta = block_states[0].eta[i];
            let marginal = self.marginal_link_map(marginal_eta)?;
            let neglog = self.rigid_row_neglog_only(i, marginal, b[i])?;
            Ok(-neglog)
        };
        if let Some(threshold) = options.early_exit_threshold {
            let trial_result = bernoulli_margslope_line_search_ll_with_early_exit(
                &weighted_rows,
                threshold,
                row_ll,
            );
            // If the (weighted) subsampled screen passed, recompute the exact
            // unweighted full-data likelihood — the weighted estimate is only
            // used for rejection.
            if trial_subsample_installed {
                if let Ok(_subsample_ll) = trial_result {
                    let full_total: Result<f64, String> = (0..n)
                        .into_par_iter()
                        .try_fold(
                            || 0.0,
                            |mut ll, i| -> Result<_, String> {
                                ll += row_ll(i)?;
                                Ok(ll)
                            },
                        )
                        .try_reduce(
                            || 0.0,
                            |left, right| -> Result<_, String> { Ok(left + right) },
                        );
                    return full_total;
                }
            }
            return trial_result;
        }
        // No early exit: single parallel weighted reduction.
        let total: Result<f64, String> = weighted_rows
            .into_par_iter()
            .try_fold(
                || 0.0,
                |mut ll, wr| -> Result<_, String> {
                    ll += wr.weight * row_ll(wr.index)?;
                    Ok(ll)
                },
            )
            .try_reduce(
                || 0.0,
                |left, right| -> Result<_, String> { Ok(left + right) },
            );
        return total;
    }
    // Flexible-kernel path: per-row calibrated intercept solve, then the
    // observed denested-cell predictor evaluated at the row's z.
    let beta_h = self.score_beta(block_states)?;
    let beta_w = self.link_beta(block_states)?;
    let row_ll = |row: usize| -> Result<f64, String> {
        let intercept = self
            .solve_row_intercept_base(
                row,
                block_states[0].eta[row],
                block_states[1].eta[row],
                beta_h,
                beta_w,
                None,
            )?
            .0;
        let slope = block_states[1].eta[row];
        let obs =
            self.observed_denested_cell_partials(row, intercept, slope, beta_h, beta_w)?;
        let s_i = eval_coeff4_at(&obs.coeff, self.z[row]);
        let signed = (2.0 * self.y[row] - 1.0) * s_i;
        let (log_cdf, _) = signed_probit_logcdf_and_mills_ratio(signed);
        Ok(self.weights[row] * log_cdf)
    };
    if let Some(threshold) = options.early_exit_threshold {
        let trial_result = bernoulli_margslope_line_search_ll_with_early_exit(
            &weighted_rows,
            threshold,
            row_ll,
        );
        // Same screen-then-recompute structure as the rigid path above.
        if trial_subsample_installed {
            if let Ok(_subsample_ll) = trial_result {
                let full_total: Result<f64, String> = (0..n)
                    .into_par_iter()
                    .try_fold(
                        || 0.0,
                        |mut ll, i| -> Result<_, String> {
                            ll += row_ll(i)?;
                            Ok(ll)
                        },
                    )
                    .try_reduce(
                        || 0.0,
                        |left, right| -> Result<_, String> { Ok(left + right) },
                    );
                return full_total;
            }
        }
        return trial_result;
    }
    let total: Result<f64, String> = weighted_rows
        .into_par_iter()
        .try_fold(
            || 0.0,
            |mut ll, wr| -> Result<_, String> {
                ll += wr.weight * row_ll(wr.index)?;
                Ok(ll)
            },
        )
        .try_reduce(
            || 0.0,
            |left, right| -> Result<_, String> { Ok(left + right) },
        );
    total
}
fn line_search_log_likelihood_workspace(
    &self,
    block_states: &[ParameterBlockState],
    line_search_options: &BlockwiseFitOptions,
    workspace_options: &BlockwiseFitOptions,
) -> Result<Option<(f64, Arc<dyn ExactNewtonJointHessianWorkspace>)>, String> {
    // Flex-only fast path for line search: build the per-row exact contexts
    // once, computing the (possibly subsample-weighted) log-likelihood along
    // the way with chunked early exit, and wrap the result in a lazy
    // line-search workspace.  Returns Ok(None) when the flexible kernel is
    // inactive so the caller falls back to the plain likelihood path.
    self.validate_exact_monotonicity(block_states)?;
    self.validate_exact_block_state_shapes(block_states)?;
    if !self.effective_flex_active(block_states)? {
        return Ok(None);
    }
    let n = self.y.len();
    let slices = block_slices(self);
    let primary = primary_slices(&slices);
    let p_total = slices.total;
    let started = std::time::Instant::now();
    let subsample_active = line_search_options.outer_score_subsample.is_some();
    if log_exact_work(n) {
        log::info!(
            "[BMS line-search cache] build start n={} p={} flex=true subsample={}",
            n,
            p_total,
            subsample_active
        );
    }
    self.preseed_intercept_warm_starts(block_states)?;
    // Fresh per-build caches/statistics for the cell-moment machinery.
    exact_kernel::reset_tail_cell_moment_cache();
    let stats = BernoulliInterceptSolveStats::default();
    let cell_cache_before = self.cell_moment_cache_stats.snapshot();
    let beta_h = self.score_beta(block_states)?;
    let beta_w = self.link_beta(block_states)?;
    let mut row_contexts = Vec::with_capacity(n);
    let mut log_likelihood = 0.0_f64;
    // Dense per-row weights: subsampled rows carry their weight, unsampled
    // rows get 0 (context is still built; only the LL term is skipped).
    let ll_row_weights: Vec<f64> = match line_search_options.outer_score_subsample.as_ref() {
        Some(s) => {
            let mut weights = vec![0.0; n];
            for r in s.rows.iter() {
                if r.index < n {
                    weights[r.index] = r.weight;
                }
            }
            weights
        }
        None => vec![1.0; n],
    };
    // Process rows in fixed-size chunks so the early-exit threshold can be
    // checked between chunks.
    for start in (0..n).step_by(BERNOULLI_MARGSLOPE_LINE_SEARCH_EARLY_EXIT_CHUNK_ROWS) {
        let end = (start + BERNOULLI_MARGSLOPE_LINE_SEARCH_EARLY_EXIT_CHUNK_ROWS).min(n);
        let ll_weights_slice = ll_row_weights.as_slice();
        let mut chunk_rows = (start..end)
            .into_par_iter()
            .map(|row| -> Result<_, String> {
                let row_ctx = self.build_row_exact_context_with_stats_and_cell_cache(
                    row,
                    block_states,
                    Some(&stats),
                    false,
                )?;
                let ll_weight = ll_weights_slice[row];
                let row_ll = if ll_weight == 0.0 {
                    0.0
                } else {
                    let slope = block_states[1].eta[row];
                    let obs = self.observed_denested_cell_partials(
                        row,
                        row_ctx.intercept,
                        slope,
                        beta_h,
                        beta_w,
                    )?;
                    let s_i = eval_coeff4_at(&obs.coeff, self.z[row]);
                    let signed = (2.0 * self.y[row] - 1.0) * s_i;
                    let (log_cdf, _) = signed_probit_logcdf_and_mills_ratio(signed);
                    ll_weight * self.weights[row] * log_cdf
                };
                Ok((row, row_ctx, row_ll))
            })
            .collect::<Result<Vec<_>, String>>()?;
        // Parallel collection order is arbitrary: restore row order so the
        // context Vec is indexable by row.
        chunk_rows.sort_unstable_by_key(|(row, _, _)| *row);
        for (_, row_ctx, row_ll) in chunk_rows {
            log_likelihood += row_ll;
            row_contexts.push(row_ctx);
        }
        if let Some(threshold) = line_search_options.early_exit_threshold
            && -log_likelihood > threshold
        {
            return Err(format!(
                "bernoulli marginal-slope line-search rejected early: partial_nll={} threshold={}",
                -log_likelihood, threshold
            ));
        }
    }
    if log_exact_work(n) {
        let (cell_hits, cell_misses, cell_hit_rate) = self
            .cell_moment_cache_stats
            .hit_rate_delta(cell_cache_before);
        log::info!(
            "[BMS line-search cache] cell moments hits={} misses={} hit_rate={:.1}% entries={} resident_mib={:.1}/{:.1}",
            cell_hits,
            cell_misses,
            100.0 * cell_hit_rate,
            self.cell_moment_lru.len(),
            self.cell_moment_lru.resident_bytes() as f64 / (1024.0 * 1024.0),
            self.cell_moment_lru.max_bytes() as f64 / (1024.0 * 1024.0),
        );
    }
    // Note: the cache fingerprint is taken against `workspace_options`, not
    // the (possibly subsampled) line-search options, so the cached workspace
    // stays valid for later full-options use.
    let workspace: Arc<dyn ExactNewtonJointHessianWorkspace> =
        Arc::new(BernoulliMarginalSlopeLineSearchWorkspace {
            family: self.clone(),
            block_states: block_states.to_vec(),
            cache: BernoulliMarginalSlopeExactEvalCache {
                slices,
                primary,
                row_contexts,
                row_cell_moments: None,
                row_primary_hessians: None,
                rigid_third_full: crate::resource::RayonSafeOnce::new(),
                fingerprint: CacheFingerprint::compute(
                    block_states,
                    workspace_options
                        .outer_score_subsample
                        .as_ref()
                        .map(|s| s.mask.as_slice()),
                    Some(workspace_options),
                ),
                rigid_fourth_full: crate::resource::RayonSafeOnce::new(),
            },
            options: workspace_options.clone(),
            log_likelihood,
            full_workspace: crate::resource::RayonSafeOnce::new(),
        });
    if log_exact_work(n) {
        log::info!(
            "[BMS line-search cache] build done n={} p={} elapsed={:.3}s",
            n,
            p_total,
            started.elapsed().as_secs_f64()
        );
    }
    Ok(Some((log_likelihood, workspace)))
}
fn is_sigma_aux_index(
    &self,
    derivative_blocks: &[Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>],
    psi_index: usize,
) -> bool {
    // Thin wrapper over the shared marginal-slope helper, passing this
    // family's Gaussian frailty sd through.
    shared_is_sigma_aux_index(self.gaussian_frailty_sd, derivative_blocks, psi_index)
}
fn sigma_scale_jet(
    &self,
    n_dirs: usize,
    first_masks: &[usize],
    second_masks: &[usize],
) -> Result<MultiDirJet, String> {
    // Jet of the probit frailty scale for log-sigma directional derivatives;
    // errs (with the message below) when no GaussianShift sigma is configured.
    // NOTE(review): first_masks/second_masks presumably mark the coefficient
    // masks carrying first- and second-order sigma dependence — confirm in
    // probit_frailty_scale_multi_dir_jet.
    probit_frailty_scale_multi_dir_jet(
        self.gaussian_frailty_sd,
        "bernoulli marginal-slope log-sigma auxiliary requested without GaussianShift sigma",
        n_dirs,
        first_masks,
        second_masks,
    )
}
fn row_neglog_directional_with_scale_jet(
    &self,
    row: usize,
    block_states: &[ParameterBlockState],
    dirs: &[Array1<f64>],
    scale_jet: &MultiDirJet,
) -> Result<f64, String> {
    // Fully mixed directional derivative (one differentiation per entry of
    // `dirs`) of the row's rigid negative log-likelihood, with the frailty
    // scale carried as a jet so sigma directions propagate through it.
    let k = dirs.len();
    if k > 4 {
        return Err(format!(
            "bernoulli marginal-slope sigma row directional expects 0..=4 directions, got {k}"
        ));
    }
    // The scale jet must live in the same 2^k coefficient space as the dirs.
    if scale_jet.coeffs.len() != (1usize << k) {
        return Err(format!(
            "bernoulli marginal-slope sigma scale jet dimension mismatch: coeffs={}, dirs={k}",
            scale_jet.coeffs.len()
        ));
    }
    // Component 0 of each direction perturbs the marginal eta, component 1
    // the slope-block eta.
    let first = |idx: usize| -> Vec<f64> { dirs.iter().map(|dir| dir[idx]).collect() };
    let marginal = self.marginal_link_map(block_states[0].eta[row])?;
    let eta_jet = MultiDirJet::linear(k, block_states[0].eta[row], &first(0));
    // Compose eta -> q through the link, derivatives up to fourth order.
    let q_jet = eta_jet.compose_unary([
        marginal.q,
        marginal.q1,
        marginal.q2,
        marginal.q3,
        marginal.q4,
    ]);
    let g_jet = MultiDirJet::linear(k, block_states[1].eta[row], &first(1));
    let observed_g_jet = g_jet.mul(scale_jet);
    // c = sqrt(1 + (scale·g)^2); observed predictor eta = q·c + scale·g·z.
    let one_plus_b2 = MultiDirJet::constant(k, 1.0).add(&observed_g_jet.mul(&observed_g_jet));
    let c_jet = one_plus_b2.compose_unary(unary_derivatives_sqrt(one_plus_b2.coeff(0)));
    let z_jet = MultiDirJet::constant(k, self.z[row]);
    let eta_observed_jet = q_jet.mul(&c_jet).add(&observed_g_jet.mul(&z_jet));
    let signed_jet = eta_observed_jet.scale(2.0 * self.y[row] - 1.0);
    // The all-ones mask (2^k − 1) selects the fully mixed coefficient.
    Ok(signed_jet
        .compose_unary(unary_derivatives_neglog_phi(
            signed_jet.coeff(0),
            self.weights[row],
        ))
        .coeff((1usize << k) - 1))
}
fn row_sigma_primary_terms(
    &self,
    row: usize,
    block_states: &[ParameterBlockState],
    second_sigma: bool,
) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
    // Sigma-contracted objective, primary gradient, and primary Hessian for
    // one row: every evaluation carries one (second_sigma == false) or two
    // sigma directions (zero primary perturbation) plus zero, one, or two
    // primary unit directions for the gradient/Hessian entries.
    let primary_dim = 2usize;
    let zero = Array1::<f64>::zeros(primary_dim);
    // Objective: contract only over the sigma direction(s).
    let objective = if second_sigma {
        let scale = self.sigma_scale_jet(2, &[1, 2], &[3])?;
        self.row_neglog_directional_with_scale_jet(
            row,
            block_states,
            &[zero.clone(), zero.clone()],
            &scale,
        )?
    } else {
        let scale = self.sigma_scale_jet(1, &[1], &[])?;
        self.row_neglog_directional_with_scale_jet(row, block_states, &[zero.clone()], &scale)?
    };
    // The scale jet depends only on the direction count and masks, both fixed
    // for a given `second_sigma`, so it is loop-invariant: build it once for
    // the whole gradient loop instead of once per primary coordinate.
    let grad_scale = if second_sigma {
        self.sigma_scale_jet(3, &[1, 2], &[3])?
    } else {
        self.sigma_scale_jet(2, &[1], &[])?
    };
    let mut grad = Array1::<f64>::zeros(primary_dim);
    for a in 0..primary_dim {
        let mut da = Array1::<f64>::zeros(primary_dim);
        da[a] = 1.0;
        grad[a] = if second_sigma {
            self.row_neglog_directional_with_scale_jet(
                row,
                block_states,
                &[zero.clone(), zero.clone(), da],
                &grad_scale,
            )?
        } else {
            self.row_neglog_directional_with_scale_jet(
                row,
                block_states,
                &[zero.clone(), da],
                &grad_scale,
            )?
        };
    }
    // Likewise hoist the Hessian-order scale jet out of the double loop.
    let hess_scale = if second_sigma {
        self.sigma_scale_jet(4, &[1, 2], &[3])?
    } else {
        self.sigma_scale_jet(3, &[1], &[])?
    };
    let mut hess = Array2::<f64>::zeros((primary_dim, primary_dim));
    for a in 0..primary_dim {
        let mut da = Array1::<f64>::zeros(primary_dim);
        da[a] = 1.0;
        for b in a..primary_dim {
            let mut db = Array1::<f64>::zeros(primary_dim);
            db[b] = 1.0;
            let value = if second_sigma {
                self.row_neglog_directional_with_scale_jet(
                    row,
                    block_states,
                    &[zero.clone(), zero.clone(), da.clone(), db],
                    &hess_scale,
                )?
            } else {
                self.row_neglog_directional_with_scale_jet(
                    row,
                    block_states,
                    &[zero.clone(), da.clone(), db],
                    &hess_scale,
                )?
            };
            // Symmetric fill.
            hess[[a, b]] = value;
            hess[[b, a]] = value;
        }
    }
    Ok((objective, grad, hess))
}
fn accumulate_rigid_sigma_pullback(
    &self,
    row: usize,
    slices: &BlockSlices,
    primary_grad: &Array1<f64>,
    primary_hessian: &Array2<f64>,
    score: &mut Array1<f64>,
    hessian: &mut BernoulliBlockHessianAccumulator,
) -> Result<(), String> {
    // Pull the 2-dim primary gradient back through the row's design rows into
    // the coefficient-space score, then hand the 2x2 primary Hessian to the
    // block accumulator for its own pullback.
    let mut marginal_score = score.slice_mut(s![slices.marginal.clone()]);
    self.marginal_design
        .axpy_row_into(row, primary_grad[0], &mut marginal_score)?;
    let mut logslope_score = score.slice_mut(s![slices.logslope.clone()]);
    self.logslope_design
        .axpy_row_into(row, primary_grad[1], &mut logslope_score)?;
    hessian.add_pullback(self, row, slices, &primary_slices(slices), primary_hessian);
    Ok(())
}
/// First-order log-sigma hyperderivative terms with default fit options.
fn sigma_exact_joint_psi_terms(
    &self,
    block_states: &[ParameterBlockState],
    specs: &[ParameterBlockSpec],
) -> Result<Option<ExactNewtonJointPsiTerms>, String> {
    // Delegate to the configurable variant using default blockwise options.
    let default_options = BlockwiseFitOptions::default();
    self.sigma_exact_joint_psi_terms_with_options(block_states, specs, &default_options)
}
/// First-order ∂/∂(log σ) terms (objective, score, Hessian operator) of the
/// joint objective for the rigid probit marginal-slope kernel, accumulated
/// over optionally subsampled/weighted rows.
///
/// Errors for flexible (score-warp / link-deviation) kernels; returns
/// `Ok(None)` when no Gaussian frailty SD is configured (no sigma to
/// differentiate).
pub(crate) fn sigma_exact_joint_psi_terms_with_options(
    &self,
    block_states: &[ParameterBlockState],
    _specs: &[ParameterBlockSpec],
    options: &BlockwiseFitOptions,
) -> Result<Option<ExactNewtonJointPsiTerms>, String> {
    // Only the rigid kernel has this analytic sigma path.
    if self.effective_flex_active(block_states)? {
        return Err(
            "bernoulli marginal-slope log-sigma hyperderivatives are implemented for the rigid probit marginal-slope kernel; flexible score/link kernels require the analytic denested cell-tensor sigma path"
                .to_string(),
        );
    }
    if self.gaussian_frailty_sd.is_none() {
        return Ok(None);
    }
    let slices = block_slices(self);
    let n = self.y.len();
    // Row subset and per-row weights come from the outer subsampling options.
    let row_iter = outer_row_indices(options, n).to_vec();
    let row_weights =
        crate::families::marginal_slope_shared::outer_row_weights_by_index(options, n);
    // Chunked map-reduce over rows: each chunk accumulates
    // (objective, score, block-Hessian) locally, then chunks are merged.
    let (objective_psi, score_psi, acc) = chunked_row_reduction(
        row_iter.as_slice(),
        || {
            (
                0.0,
                Array1::<f64>::zeros(slices.total),
                BernoulliBlockHessianAccumulator::new(&slices),
            )
        },
        |row, acc| -> Result<(), String> {
            // `false`: first-order sigma derivative (single log-sigma direction).
            let (mut obj, mut grad, mut hess) =
                self.row_sigma_primary_terms(row, block_states, false)?;
            let w = row_weights[row];
            if w != 1.0 {
                obj *= w;
                grad.mapv_inplace(|v| v * w);
                hess.mapv_inplace(|v| v * w);
            }
            acc.0 += obj;
            // Pull primary-space terms back into full coefficient blocks.
            self.accumulate_rigid_sigma_pullback(
                row, &slices, &grad, &hess, &mut acc.1, &mut acc.2,
            )?;
            Ok(())
        },
        |total, chunk| {
            total.0 += chunk.0;
            total.1 += &chunk.1;
            total.2.add(&chunk.2);
        },
    )?;
    // The Hessian is exposed only as an operator; the dense field stays empty.
    Ok(Some(ExactNewtonJointPsiTerms {
        objective_psi,
        score_psi,
        hessian_psi: Array2::zeros((0, 0)),
        hessian_psi_operator: Some(Arc::new(acc.into_operator(&slices))),
    }))
}
/// Second-order log-sigma hyperderivative terms with default fit options.
fn sigma_exact_joint_psisecond_order_terms(
    &self,
    block_states: &[ParameterBlockState],
) -> Result<Option<ExactNewtonJointPsiSecondOrderTerms>, String> {
    // Delegate to the configurable variant using default blockwise options.
    let default_options = BlockwiseFitOptions::default();
    self.sigma_exact_joint_psisecond_order_terms_with_options(block_states, &default_options)
}
/// Second-order ∂²/∂(log σ)² terms of the joint objective for the rigid
/// probit marginal-slope kernel; same structure as the first-order variant
/// but with `row_sigma_primary_terms(.., true)` requesting the second
/// sigma derivative.
pub(crate) fn sigma_exact_joint_psisecond_order_terms_with_options(
    &self,
    block_states: &[ParameterBlockState],
    options: &BlockwiseFitOptions,
) -> Result<Option<ExactNewtonJointPsiSecondOrderTerms>, String> {
    // Only the rigid kernel has this analytic sigma path.
    if self.effective_flex_active(block_states)? {
        return Err(
            "bernoulli marginal-slope second log-sigma hyperderivatives are implemented for the rigid probit marginal-slope kernel; flexible score/link kernels require the analytic denested cell-tensor sigma path"
                .to_string(),
        );
    }
    if self.gaussian_frailty_sd.is_none() {
        return Ok(None);
    }
    let slices = block_slices(self);
    let n = self.y.len();
    let row_iter = outer_row_indices(options, n).to_vec();
    let row_weights =
        crate::families::marginal_slope_shared::outer_row_weights_by_index(options, n);
    // Chunked map-reduce; each chunk holds (objective, score, block-Hessian).
    let (objective_psi_psi, score_psi_psi, acc) = chunked_row_reduction(
        row_iter.as_slice(),
        || {
            (
                0.0,
                Array1::<f64>::zeros(slices.total),
                BernoulliBlockHessianAccumulator::new(&slices),
            )
        },
        |row, acc| -> Result<(), String> {
            // `true`: second-order sigma derivative (two log-sigma directions).
            let (mut obj, mut grad, mut hess) =
                self.row_sigma_primary_terms(row, block_states, true)?;
            let w = row_weights[row];
            if w != 1.0 {
                obj *= w;
                grad.mapv_inplace(|v| v * w);
                hess.mapv_inplace(|v| v * w);
            }
            acc.0 += obj;
            self.accumulate_rigid_sigma_pullback(
                row, &slices, &grad, &hess, &mut acc.1, &mut acc.2,
            )?;
            Ok(())
        },
        |total, chunk| {
            total.0 += chunk.0;
            total.1 += &chunk.1;
            total.2.add(&chunk.2);
        },
    )?;
    // Dense Hessian field intentionally empty; consumers use the operator.
    Ok(Some(ExactNewtonJointPsiSecondOrderTerms {
        objective_psi_psi,
        score_psi_psi,
        hessian_psi_psi: Array2::zeros((0, 0)),
        hessian_psi_psi_operator: Some(Box::new(acc.into_operator(&slices))),
    }))
}
/// Directional derivative of the sigma Hessian with default fit options.
fn sigma_exact_joint_psihessian_directional_derivative(
    &self,
    block_states: &[ParameterBlockState],
    d_beta_flat: &Array1<f64>,
) -> Result<Option<Array2<f64>>, String> {
    // Delegate to the configurable variant using default blockwise options.
    let default_options = BlockwiseFitOptions::default();
    self.sigma_exact_joint_psihessian_directional_derivative_with_options(
        block_states,
        d_beta_flat,
        &default_options,
    )
}
/// Directional derivative of the log-sigma Hessian along a flat coefficient
/// direction `d_beta_flat`, returned as a dense matrix, for the rigid probit
/// marginal-slope kernel.
///
/// Per row, builds the mixed third-order tensor contraction by evaluating the
/// directional-jet negative log-likelihood with a [zero, row_dir, e_a, e_b]
/// direction stack and the matching sigma scale jet, then pulls the primary
/// Hessian back through the design rows.
pub(crate) fn sigma_exact_joint_psihessian_directional_derivative_with_options(
    &self,
    block_states: &[ParameterBlockState],
    d_beta_flat: &Array1<f64>,
    options: &BlockwiseFitOptions,
) -> Result<Option<Array2<f64>>, String> {
    // Only the rigid kernel has this analytic sigma path.
    if self.effective_flex_active(block_states)? {
        return Err(
            "bernoulli marginal-slope log-sigma Hessian directional derivatives are implemented for the rigid probit marginal-slope kernel; flexible score/link kernels require the analytic denested cell-tensor sigma path"
                .to_string(),
        );
    }
    if self.gaussian_frailty_sd.is_none() {
        return Ok(None);
    }
    let slices = block_slices(self);
    // The direction must span the full flat coefficient vector.
    if d_beta_flat.len() != slices.total {
        return Err(format!(
            "bernoulli marginal-slope d_beta length mismatch for sigma Hessian derivative: got {}, expected {}",
            d_beta_flat.len(),
            slices.total
        ));
    }
    let n = self.y.len();
    let primary = primary_slices(&slices);
    let row_iter = outer_row_indices(options, n).to_vec();
    let row_weights =
        crate::families::marginal_slope_shared::outer_row_weights_by_index(options, n);
    let acc = chunked_row_reduction(
        row_iter.as_slice(),
        || BernoulliBlockHessianAccumulator::new(&slices),
        |row, acc| -> Result<(), String> {
            // Project the flat direction into this row's primary space.
            let row_dir =
                self.row_primary_direction_from_flat(row, &slices, &primary, d_beta_flat)?;
            let zero = Array1::<f64>::zeros(primary.total);
            // NOTE: `grad` is computed but only `hess` is accumulated below;
            // the jet evaluations populate the gradient as a by-product.
            let mut grad = Array1::<f64>::zeros(primary.total);
            for a in 0..primary.total {
                let mut da = Array1::<f64>::zeros(primary.total);
                da[a] = 1.0;
                // Jet of order 3: sigma direction at slot 1, primary at slots 2-3.
                let scale = self.sigma_scale_jet(3, &[1], &[])?;
                grad[a] = self.row_neglog_directional_with_scale_jet(
                    row,
                    block_states,
                    &[zero.clone(), row_dir.clone(), da],
                    &scale,
                )?;
            }
            // Symmetric primary-space Hessian of the directional derivative.
            let mut hess = Array2::<f64>::zeros((primary.total, primary.total));
            for a in 0..primary.total {
                let mut da = Array1::<f64>::zeros(primary.total);
                da[a] = 1.0;
                for b in a..primary.total {
                    let mut db = Array1::<f64>::zeros(primary.total);
                    db[b] = 1.0;
                    let scale = self.sigma_scale_jet(4, &[1], &[])?;
                    let value = self.row_neglog_directional_with_scale_jet(
                        row,
                        block_states,
                        &[zero.clone(), row_dir.clone(), da.clone(), db],
                        &scale,
                    )?;
                    hess[[a, b]] = value;
                    hess[[b, a]] = value;
                }
            }
            let w = row_weights[row];
            if w != 1.0 {
                hess.mapv_inplace(|v| v * w);
            }
            acc.add_pullback(self, row, &slices, &primary, &hess);
            Ok(())
        },
        |total, chunk| {
            total.add(&chunk);
        },
    )?;
    Ok(Some(acc.into_operator(&slices).to_dense()))
}
/// Map a marginal linear predictor through the family's base link,
/// delegating to the shared bernoulli link-map helper.
#[inline]
fn marginal_link_map(&self, eta: f64) -> Result<BernoulliMarginalLinkMap, String> {
    bernoulli_marginal_link_map(&self.base_link, eta)
}
/// Convert one objective-gradient component into the matching score component.
/// The objective is a negative log-likelihood, so the score is its negation.
#[inline]
fn exact_newton_score_component_from_objective_gradient(
    objective_gradient_component: f64,
) -> f64 {
    0.0 - objective_gradient_component
}
/// Convert a full objective gradient into the score vector (element-wise
/// negation; the objective is a negative log-likelihood).
#[inline]
fn exact_newton_score_from_objective_gradient(objective_gradient: Array1<f64>) -> Array1<f64> {
    objective_gradient.mapv(|component| -component)
}
/// The objective is already a negative log-likelihood, so its Hessian equals
/// the observed information; this identity exists for naming symmetry with
/// the score conversions above.
#[inline]
fn exact_newton_observed_information_from_objective_hessian(
    objective_hessian: Array2<f64>,
) -> Array2<f64> {
    objective_hessian
}
/// Index of the score-warp block, when configured: always 2, after the
/// marginal (0) and logslope (1) blocks.
#[inline]
fn score_block_index(&self) -> Option<usize> {
    if self.score_warp.is_some() { Some(2) } else { None }
}
/// Index of the link-deviation block, when configured: it follows the fixed
/// marginal/logslope blocks and the optional score-warp block.
#[inline]
fn link_block_index(&self) -> Option<usize> {
    match &self.link_dev {
        Some(_) => Some(2 + usize::from(self.score_warp.is_some())),
        None => None,
    }
}
/// Resolve an optional block index to its state. `None` index → `Ok(None)`;
/// a present index with no matching state is a wiring error.
fn optional_exact_block_state<'a>(
    &self,
    block_states: &'a [ParameterBlockState],
    block_idx: Option<usize>,
    label: &str,
) -> Result<Option<&'a ParameterBlockState>, String> {
    let Some(idx) = block_idx else {
        return Ok(None);
    };
    match block_states.get(idx) {
        Some(state) => Ok(Some(state)),
        None => Err(format!("missing {label} block state")),
    }
}
/// Optional score-warp block state, resolved via its configured index.
fn score_block_state<'a>(
    &self,
    block_states: &'a [ParameterBlockState],
) -> Result<Option<&'a ParameterBlockState>, String> {
    let idx = self.score_block_index();
    self.optional_exact_block_state(block_states, idx, "score-warp")
}
/// Optional link-deviation block state, resolved via its configured index.
fn link_block_state<'a>(
    &self,
    block_states: &'a [ParameterBlockState],
) -> Result<Option<&'a ParameterBlockState>, String> {
    let idx = self.link_block_index();
    self.optional_exact_block_state(block_states, idx, "link deviation")
}
/// Coefficient vector of the optional score-warp block.
fn score_beta<'a>(
    &self,
    block_states: &'a [ParameterBlockState],
) -> Result<Option<&'a Array1<f64>>, String> {
    let state = self.score_block_state(block_states)?;
    Ok(state.map(|s| &s.beta))
}
/// Coefficient vector of the optional link-deviation block.
fn link_beta<'a>(
    &self,
    block_states: &'a [ParameterBlockState],
) -> Result<Option<&'a Array1<f64>>, String> {
    let state = self.link_block_state(block_states)?;
    Ok(state.map(|s| &s.beta))
}
/// Validate that the supplied block states match this family's configuration:
/// correct block count, and per-block beta/eta lengths consistent with the
/// designs (marginal, logslope) and optional deviation runtimes.
fn validate_exact_block_state_shapes(
    &self,
    block_states: &[ParameterBlockState],
) -> Result<(), String> {
    // Two mandatory blocks plus one per configured deviation runtime.
    let expected_blocks =
        2usize + usize::from(self.score_warp.is_some()) + usize::from(self.link_dev.is_some());
    if block_states.len() != expected_blocks {
        return Err(format!(
            "bernoulli marginal-slope block count mismatch: got {}, expected {}",
            block_states.len(),
            expected_blocks
        ));
    }
    let n_rows = self.y.len();
    let marginal = &block_states[0];
    let marginal_ncols = self.marginal_design.ncols();
    // Beta length is only checked against non-empty designs.
    if marginal_ncols > 0 && marginal.beta.len() != marginal_ncols {
        return Err(format!(
            "bernoulli marginal-slope marginal beta length mismatch: got {}, expected {}",
            marginal.beta.len(),
            marginal_ncols
        ));
    }
    if marginal.eta.len() != n_rows {
        return Err(format!(
            "bernoulli marginal-slope marginal eta length mismatch: got {}, expected {}",
            marginal.eta.len(),
            n_rows
        ));
    }
    let logslope = &block_states[1];
    let logslope_ncols = self.logslope_design.ncols();
    if logslope_ncols > 0 && logslope.beta.len() != logslope_ncols {
        return Err(format!(
            "bernoulli marginal-slope logslope beta length mismatch: got {}, expected {}",
            logslope.beta.len(),
            logslope_ncols
        ));
    }
    if logslope.eta.len() != n_rows {
        return Err(format!(
            "bernoulli marginal-slope logslope eta length mismatch: got {}, expected {}",
            logslope.eta.len(),
            n_rows
        ));
    }
    // Optional score-warp block must match its runtime's basis dimension.
    if let Some(runtime) = &self.score_warp {
        let score = self
            .score_block_state(block_states)?
            .expect("score-warp block should exist when runtime is present");
        if score.beta.len() != runtime.basis_dim() {
            return Err(format!(
                "bernoulli marginal-slope score-warp beta length mismatch: got {}, expected {}",
                score.beta.len(),
                runtime.basis_dim()
            ));
        }
        if score.eta.len() != n_rows {
            return Err(format!(
                "bernoulli marginal-slope score-warp eta length mismatch: got {}, expected {}",
                score.eta.len(),
                n_rows
            ));
        }
    }
    // Optional link-deviation block must match its runtime's basis dimension.
    if let Some(runtime) = &self.link_dev {
        let link = self
            .link_block_state(block_states)?
            .expect("link-deviation block should exist when runtime is present");
        if link.beta.len() != runtime.basis_dim() {
            return Err(format!(
                "bernoulli marginal-slope link-deviation beta length mismatch: got {}, expected {}",
                link.beta.len(),
                runtime.basis_dim()
            ));
        }
        if link.eta.len() != n_rows {
            return Err(format!(
                "bernoulli marginal-slope link-deviation eta length mismatch: got {}, expected {}",
                link.eta.len(),
                n_rows
            ));
        }
    }
    Ok(())
}
/// Build the denested partition cells for intercept `a` and slope `b` by
/// delegating to the shared helper with this family's score-warp and
/// link-deviation runtimes and its probit frailty scale.
fn denested_partition_cells(
    &self,
    a: f64,
    b: f64,
    beta_h: Option<&Array1<f64>>,
    beta_w: Option<&Array1<f64>>,
) -> Result<Vec<exact_kernel::DenestedPartitionCell>, String> {
    shared_denested_partition_cells(
        a,
        b,
        self.score_warp.as_ref(),
        beta_h,
        self.link_dev.as_ref(),
        beta_w,
        self.probit_frailty_scale(),
    )
}
/// Evaluate a cell's moments up to `max_degree`.
/// NOTE(review): despite the `_lru` name, this path unconditionally records a
/// cache miss and computes uncached — presumably the LRU lookup lives in
/// another code path or was removed; confirm against `cell_moment_lru` usage.
#[inline]
fn evaluate_cell_moments_lru(
    &self,
    cell: exact_kernel::DenestedCubicCell,
    max_degree: usize,
) -> Result<exact_kernel::CellMomentState, String> {
    self.cell_moment_cache_stats.record_miss();
    exact_kernel::evaluate_cell_moments_uncached(cell, max_degree)
}
/// Evaluate a cell's derivative moments up to `max_degree`.
/// NOTE(review): like `evaluate_cell_moments_lru`, this always records a miss
/// and computes uncached despite the `_lru` name — confirm intent.
#[inline]
fn evaluate_cell_derivative_moments_lru(
    &self,
    cell: exact_kernel::DenestedCubicCell,
    max_degree: usize,
) -> Result<exact_kernel::CellDerivativeMomentState, String> {
    self.cell_moment_cache_stats.record_miss();
    exact_kernel::evaluate_cell_derivative_moments_uncached(cell, max_degree)
}
/// Visit each deviation basis function's local cubic span at `value`,
/// passing the visitor both the runtime-local basis index and the absolute
/// index inside `primary_range`. Errors when the primary range length does
/// not equal the runtime's basis dimension.
#[inline]
fn for_each_deviation_basis_cubic_at<F>(
    runtime: &DeviationRuntime,
    primary_range: &std::ops::Range<usize>,
    value: f64,
    label: &str,
    mut visit: F,
) -> Result<(), String>
where
    F: FnMut(usize, usize, exact_kernel::LocalSpanCubic) -> Result<(), String>,
{
    if primary_range.len() != runtime.basis_dim() {
        return Err(format!(
            "{label} primary range length {} does not match deviation basis dimension {}",
            primary_range.len(),
            runtime.basis_dim()
        ));
    }
    // Offset the runtime-local index into the caller's absolute primary range.
    runtime.for_each_basis_cubic_at(value, |local_idx, basis_span| {
        visit(local_idx, primary_range.start + local_idx, basis_span)
    })
}
/// Evaluate the denested calibration residual f(a) (cell-moment sum minus the
/// marginal target mu) and its first derivative f_a for the Newton intercept
/// solve. The third tuple component (second derivative) is not computed on
/// this path and is always 0.0.
fn evaluate_denested_calibration_newton(
    &self,
    a: f64,
    marginal_eta: f64,
    slope: f64,
    beta_h: Option<&Array1<f64>>,
    beta_w: Option<&Array1<f64>>,
) -> Result<(f64, f64, f64), String> {
    let marginal = self.marginal_link_map(marginal_eta)?;
    let cells = self.denested_partition_cells(a, slope, beta_h, beta_w)?;
    let scale = self.probit_frailty_scale();
    // Start from -mu so that summing cell values yields the residual directly.
    let mut f = -marginal.mu;
    let mut f_a = 0.0;
    for partition_cell in cells {
        let cell = partition_cell.cell;
        // Degree 4 suffices for the value and first derivative here.
        let state = self.evaluate_cell_moments_lru(cell, 4)?;
        f += state.value;
        // Partial of the cell coefficients w.r.t. the intercept a.
        let (dc_da_raw, _) = exact_kernel::denested_cell_coefficient_partials(
            partition_cell.score_span,
            partition_cell.link_span,
            a,
            slope,
        );
        let dc_da = scale_coeff4(dc_da_raw, scale);
        f_a += exact_kernel::cell_first_derivative_from_moments(&dc_da, &state.moments)?;
    }
    Ok((f, f_a, 0.0))
}
/// Evaluate the calibration residual f(a) and its first two derivatives by
/// quadrature over an empirical latent grid: f = Σ w·Φ(η(z)) − μ, with
/// f_a = Σ w·φ(η)·η_a and f_aa = Σ w·φ(η)·(η_aa − η·η_a²) (chain rule with
/// φ'(η) = −η·φ(η)). Requires a finite state with strictly positive f_a.
fn evaluate_empirical_grid_calibration_newton(
    &self,
    a: f64,
    marginal_eta: f64,
    slope: f64,
    beta_h: Option<&Array1<f64>>,
    beta_w: Option<&Array1<f64>>,
    grid: &EmpiricalZGrid,
) -> Result<(f64, f64, f64), String> {
    let marginal = self.marginal_link_map(marginal_eta)?;
    // Start from -mu so the quadrature sum yields the residual directly.
    let mut f = -marginal.mu;
    let mut f_a = 0.0;
    let mut f_aa = 0.0;
    for (&node, &weight) in grid.nodes.iter().zip(grid.weights.iter()) {
        let obs = self.observed_denested_cell_partials_at_z(node, a, slope, beta_h, beta_w)?;
        let eta = eval_coeff4_at(&obs.coeff, node);
        let eta_a = eval_coeff4_at(&obs.dc_da, node);
        let eta_aa = eval_coeff4_at(&obs.dc_daa, node);
        let pdf = normal_pdf(eta);
        f += weight * normal_cdf(eta);
        f_a += weight * pdf * eta_a;
        f_aa += weight * pdf * (eta_aa - eta * eta_a * eta_a);
    }
    // Monotone root-finding needs finite values and an increasing residual.
    if !(f.is_finite() && f_a.is_finite() && f_a > 0.0 && f_aa.is_finite()) {
        return Err(format!(
            "empirical latent denested calibration produced invalid root state: f={f}, f_a={f_a}, f_aa={f_aa}"
        ));
    }
    Ok((f, f_a, f_aa))
}
/// Evaluate the row's calibration residual and derivatives for the Newton
/// intercept solve, dispatching on the row's latent measure: quadrature over
/// an empirical grid when one exists, otherwise the analytic denested-cell
/// evaluation.
fn evaluate_calibration_newton(
    &self,
    row: usize,
    a: f64,
    marginal_eta: f64,
    slope: f64,
    beta_h: Option<&Array1<f64>>,
    beta_w: Option<&Array1<f64>>,
) -> Result<(f64, f64, f64), String> {
    if let Some(grid) = self.latent_measure.empirical_grid_for_training_row(row)? {
        self.evaluate_empirical_grid_calibration_newton(
            a,
            marginal_eta,
            slope,
            beta_h,
            beta_w,
            &grid,
        )
    } else {
        self.evaluate_denested_calibration_newton(a, marginal_eta, slope, beta_h, beta_w)
    }
}
/// True when either flexible component (score-warp or link-deviation) is
/// configured.
fn flex_active(&self) -> bool {
    !(self.score_warp.is_none() && self.link_dev.is_none())
}
/// Like `flex_active`, but additionally verifies that each configured
/// deviation runtime has its matching block state; a configured runtime
/// without a state is a wiring error.
fn effective_flex_active(&self, block_states: &[ParameterBlockState]) -> Result<bool, String> {
    let score_missing = self.score_warp.is_some() && self.score_beta(block_states)?.is_none();
    if score_missing {
        return Err("missing bernoulli score-warp block state".to_string());
    }
    let link_missing = self.link_dev.is_some() && self.link_beta(block_states)?.is_none();
    if link_missing {
        return Err("missing bernoulli link-deviation block state".to_string());
    }
    Ok(self.flex_active())
}
/// Validate block-state shapes, then check that each configured deviation's
/// coefficients keep its warp monotone (a feasibility requirement for the
/// exact kernel).
fn validate_exact_monotonicity(
    &self,
    block_states: &[ParameterBlockState],
) -> Result<(), String> {
    self.validate_exact_block_state_shapes(block_states)?;
    if let Some(runtime) = &self.score_warp {
        if let Some(score) = self.score_block_state(block_states)? {
            runtime.monotonicity_feasible(
                &score.beta,
                "bernoulli marginal-slope score-warp deviation",
            )?;
        }
    }
    if let Some(runtime) = &self.link_dev {
        if let Some(beta_w) = self.link_beta(block_states)? {
            runtime.monotonicity_feasible(beta_w, "bernoulli marginal-slope link deviation")?;
        }
    }
    Ok(())
}
/// Apply the optional link deviation to a vector of linear predictors.
/// Returns (warped values l(eta0) = eta0 + B(eta0)·beta, first derivatives
/// dl/deta0 = 1 + B'(eta0)·beta); without a configured deviation this is the
/// identity map with unit derivative.
fn link_terms_value_d1(
    &self,
    eta0: &Array1<f64>,
    beta_w: Option<&Array1<f64>>,
) -> Result<(Array1<f64>, Array1<f64>), String> {
    if let (Some(runtime), Some(beta)) = (&self.link_dev, beta_w) {
        let basis = runtime.design(eta0)?;
        let d1 = runtime.first_derivative_design(eta0)?;
        // `+ 1.0` broadcasts the identity part's unit derivative.
        Ok((eta0 + &basis.dot(beta), d1.dot(beta) + 1.0))
    } else {
        Ok((eta0.clone(), Array1::ones(eta0.len())))
    }
}
/// Closed-form seed for the row intercept Newton solve.
///
/// Starts from the rigid probit intercept (pre-scale). When a link deviation
/// is active, applies a one-step local linearization of the link map around
/// that rigid point (l(a) ≈ ell0 + ell1·a) and inverts it, provided the local
/// slope ell1 is safely positive; otherwise falls back to the rigid value.
fn row_intercept_closed_form_seed(
    &self,
    marginal: BernoulliMarginalLinkMap,
    slope: f64,
    beta_w: Option<&Array1<f64>>,
) -> Result<f64, String> {
    let probit_scale = self.probit_frailty_scale();
    let a_rigid_pre_scale =
        rigid_intercept_from_marginal(marginal.q, slope, probit_scale) / probit_scale;
    if beta_w.is_some() {
        // Evaluate the link map and its derivative at the rigid seed point.
        let v = Array1::from_vec(vec![a_rigid_pre_scale]);
        let (l_val, l_d1) = self.link_terms_value_d1(&v, beta_w)?;
        let ell1 = l_d1[0];
        // Only invert the linearization when its slope is non-degenerate.
        if ell1 > 1e-8 {
            let ell0 = l_val[0] - ell1 * a_rigid_pre_scale;
            let observed_logslope = probit_scale * ell1 * slope;
            return Ok(
                (marginal.q * (1.0 + observed_logslope * observed_logslope).sqrt()
                    / probit_scale
                    - ell0)
                    / ell1,
            );
        }
    }
    Ok(a_rigid_pre_scale)
}
/// Pre-seed the intercept warm-start cache with closed-form seeds.
///
/// Only cold slots (still holding the NaN sentinel bits) are filled; slots
/// carrying a finite value from a previous PIRLS pass are kept. No-op for the
/// rigid kernel, when no cache is configured, or on a cache-length mismatch.
fn preseed_intercept_warm_starts(
    &self,
    block_states: &[ParameterBlockState],
) -> Result<(), String> {
    if !self.effective_flex_active(block_states)? {
        return Ok(());
    }
    let Some(cache) = self.intercept_warm_starts.as_ref() else {
        return Ok(());
    };
    let beta_w = self.link_beta(block_states)?;
    let n = self.y.len();
    if cache.len() != n {
        return Ok(());
    }
    let marginal_eta = &block_states[0].eta;
    let slope_eta = &block_states[1].eta;
    // Closed-form seeds are independent per row; compute them in parallel.
    let seeds: Result<Vec<f64>, String> = (0..n)
        .into_par_iter()
        .map(|row| {
            let marginal = self.marginal_link_map(marginal_eta[row])?;
            self.row_intercept_closed_form_seed(marginal, slope_eta[row], beta_w)
        })
        .collect();
    let seeds = seeds?;
    // NaN bits act as the "empty slot" sentinel for the atomic cache.
    let nan_bits = f64::NAN.to_bits();
    let mut preseeded = 0usize;
    let mut kept_warm = 0usize;
    for (row, seed) in seeds.iter().enumerate() {
        let slot = &cache[row];
        if !seed.is_finite() {
            continue;
        }
        // CAS only replaces the sentinel; a finite previous value wins.
        match slot.compare_exchange(
            nan_bits,
            seed.to_bits(),
            Ordering::Relaxed,
            Ordering::Relaxed,
        ) {
            Ok(_) => preseeded += 1,
            Err(prev) => {
                if f64::from_bits(prev).is_finite() {
                    kept_warm += 1;
                }
            }
        }
    }
    log::info!(
        "[bernoulli intercept warm-start] preseeded={} (cold), kept_warm={} (carried over from previous PIRLS)",
        preseeded,
        kept_warm,
    );
    Ok(())
}
/// Convergence test for the row intercept Newton iteration: accept when the
/// residual is within `abs_tol` or the Newton step is negligible relative to
/// the iterate's magnitude. Non-finite state or a vanishing derivative is
/// never converged.
#[inline]
fn row_intercept_newton_is_converged(a: f64, f: f64, f_a: f64, abs_tol: f64) -> bool {
    let finite = a.is_finite() && f.is_finite() && f_a.is_finite();
    if !finite || f_a == 0.0 {
        return false;
    }
    let newton_step = (f / f_a).abs();
    f.abs() <= abs_tol || newton_step <= 1e-10 * (1.0 + a.abs())
}
}
/// Counters describing how bernoulli intercept solves resolved during one
/// cache build: which short-circuit path fired, how accurate warm-start
/// seed residuals were (histogram buckets), and how hard the full solver
/// had to work.
#[derive(Default)]
struct BernoulliInterceptSolveStats {
    cached_short_circuit: AtomicUsize,
    closed_form_short_circuit: AtomicUsize,
    full_solver: AtomicUsize,
    seed_residual_le_1e12: AtomicUsize,
    seed_residual_le_1e10: AtomicUsize,
    seed_residual_le_1e8: AtomicUsize,
    seed_residual_le_abs_tol: AtomicUsize,
    seed_residual_gt_abs_tol: AtomicUsize,
    max_full_solver_iters: AtomicUsize,
}
impl BernoulliInterceptSolveStats {
    /// Bucket the seed residual magnitude into the matching histogram counter.
    fn record_seed_residual(&self, residual: f64, abs_tol: f64) {
        let abs = residual.abs();
        let bucket = if abs <= 1e-12 {
            &self.seed_residual_le_1e12
        } else if abs <= 1e-10 {
            &self.seed_residual_le_1e10
        } else if abs <= 1e-8 {
            &self.seed_residual_le_1e8
        } else if abs <= abs_tol {
            &self.seed_residual_le_abs_tol
        } else {
            &self.seed_residual_gt_abs_tol
        };
        bucket.fetch_add(1, Ordering::Relaxed);
    }
    /// Count a full-solver invocation and track the maximum refine-iteration
    /// count seen so far.
    fn record_full_solver(&self, refine_iters: usize) {
        self.full_solver.fetch_add(1, Ordering::Relaxed);
        // Atomic max replaces the original manual CAS loop; semantics match.
        self.max_full_solver_iters
            .fetch_max(refine_iters, Ordering::Relaxed);
    }
}
impl BernoulliMarginalSlopeFamily {
/// Flatten the intercept solve's primary inputs into one point vector:
/// [q, b, beta_h..., beta_w...], with absent deviation blocks contributing
/// nothing.
fn intercept_primary_point(
    q: f64,
    b: f64,
    beta_h: Option<&Array1<f64>>,
    beta_w: Option<&Array1<f64>>,
) -> Vec<f64> {
    let extra = beta_h.map_or(0, |beta| beta.len()) + beta_w.map_or(0, |beta| beta.len());
    let mut point = Vec::with_capacity(2 + extra);
    point.push(q);
    point.push(b);
    // Append score-warp then link-deviation coefficients, in that order.
    for beta in beta_h.into_iter().chain(beta_w) {
        point.extend(beta.iter().copied());
    }
    point
}
#[inline]
fn cache_row_intercept(&self, row: usize, a: f64) {
if let Some(cache) = self.intercept_warm_starts.as_ref()
&& let Some(slot) = cache.get(row)
{
slot.store(a.to_bits(), Ordering::Relaxed);
}
}
/// Record a first-order predictor (root `a`, primary point, tangent `a_u`)
/// for the row in the warm-start cache so future solves can extrapolate a
/// seed. Skipped without a cache, or when the tangent length does not match
/// the primary point layout.
fn cache_row_intercept_predictor(
    &self,
    row: usize,
    a: f64,
    q: f64,
    b: f64,
    beta_h: Option<&Array1<f64>>,
    beta_w: Option<&Array1<f64>>,
    a_u: &Array1<f64>,
) {
    let cache = match self.intercept_warm_starts.as_ref() {
        Some(cache) => cache,
        None => return,
    };
    let primary_point = Self::intercept_primary_point(q, b, beta_h, beta_w);
    if primary_point.len() == a_u.len() {
        cache.store_predictor(row, a, primary_point, a_u.iter().copied().collect());
    }
}
/// L-infinity norm of an optional coefficient vector; both `None` and an
/// empty vector yield 0.0.
#[inline]
fn beta_linf(beta: Option<&Array1<f64>>) -> f64 {
    match beta {
        Some(values) => values.iter().map(|v| v.abs()).fold(0.0_f64, f64::max),
        None => 0.0,
    }
}
/// First-order bound on how far small score/link deviations can move the
/// calibration residual away from the rigid solution, built from the basis
/// sup-norms, the coefficient L-infinity norms, and the density peak phi(0).
fn near_zero_deviation_residual_bound(
    &self,
    slope: f64,
    beta_h_linf: f64,
    beta_w_linf: f64,
) -> f64 {
    let score_basis_sup = self
        .score_warp
        .as_ref()
        .map_or(0.0, |runtime| runtime.value_basis_l1_sup_norm());
    let link_basis_sup = self
        .link_dev
        .as_ref()
        .map_or(0.0, |runtime| runtime.value_basis_l1_sup_norm());
    let deviation_mass =
        slope.abs() * score_basis_sup * beta_h_linf + link_basis_sup * beta_w_linf;
    normal_pdf(0.0) * self.probit_frailty_scale() * deviation_mass
}
/// Solve the row-level intercept calibration `f(a) = 0` for the flexible
/// (score-warp / link-deviation) kernel.
///
/// Strategy, cheapest path first:
/// 1. exact zero deviation + standard-normal latent law → rigid closed form;
/// 2. near-zero deviations whose first-order residual bound already meets the
///    tolerance → accept the rigid intercept after one residual check;
/// 3. a short Newton probe from the best available seed (predictor cache,
///    then plain warm-start cache, then closed-form seed);
/// 4. the full safeguarded monotone root solver, retried once from the
///    closed-form seed if a cached seed led the solver astray.
///
/// Returns `(a, |f'(a)|, fast_path)` where `fast_path` marks the rigid
/// short-circuit; the accepted root is cached for the next PIRLS pass.
fn solve_row_intercept_base(
    &self,
    row: usize,
    marginal_eta: f64,
    slope: f64,
    beta_h: Option<&Array1<f64>>,
    beta_w: Option<&Array1<f64>>,
    stats: Option<&BernoulliInterceptSolveStats>,
) -> Result<(f64, f64, bool), String> {
    let marginal = self.marginal_link_map(marginal_eta)?;
    let probit_scale = self.probit_frailty_scale();
    let target = marginal.mu;
    // Absolute tolerance scales with the target probability magnitude.
    let abs_tol = 1e-8_f64.max(1e-4 * target.abs());
    let rigid_a = rigid_prescale_intercept_from_marginal(marginal.q, slope, probit_scale);
    let rigid_abs_deriv =
        rigid_prescale_intercept_derivative_abs(marginal.q, slope, probit_scale);
    let beta_h_linf = Self::beta_linf(beta_h);
    let beta_w_linf = Self::beta_linf(beta_w);
    let exact_zero_deviation = beta_h_linf == 0.0 && beta_w_linf == 0.0;
    let standard_normal_law = matches!(self.latent_measure, LatentMeasureKind::StandardNormal);
    // Path 1: rigid kernel, intercept available in closed form.
    if exact_zero_deviation && standard_normal_law {
        self.cache_row_intercept(row, rigid_a);
        return Ok((rigid_a, rigid_abs_deriv, true));
    }
    // Path 2: deviations so small that the rigid intercept provably satisfies
    // the tolerance up to first order; verify with a single residual check.
    let near_zero_bound =
        self.near_zero_deviation_residual_bound(slope, beta_h_linf, beta_w_linf);
    let beta_linf_max = beta_h_linf.max(beta_w_linf);
    if standard_normal_law && near_zero_bound <= abs_tol && beta_linf_max <= f64::EPSILON.sqrt()
    {
        let (f_rigid, _, _) = self.evaluate_calibration_newton(
            row,
            rigid_a,
            marginal_eta,
            slope,
            beta_h,
            beta_w,
        )?;
        if f_rigid.abs() <= abs_tol {
            self.cache_row_intercept(row, rigid_a);
            return Ok((rigid_a, rigid_abs_deriv, true));
        }
    }
    let eval = |a: f64| -> Result<(f64, f64, f64), String> {
        self.evaluate_calibration_newton(row, a, marginal_eta, slope, beta_h, beta_w)
    };
    // Best seed: predictor cache, then warm-start cache, then closed form.
    let a_closed_form = self.row_intercept_closed_form_seed(marginal, slope, beta_w)?;
    let current_primary_point =
        Self::intercept_primary_point(marginal_eta, slope, beta_h, beta_w);
    // FIX: `&current_primary_point` had been mangled into the HTML entity
    // `&curren;` (rendered `¤t_primary_point`), which does not compile.
    let predictor_a = self
        .intercept_warm_starts
        .as_ref()
        .and_then(|cache| cache.predictor_seed(row, &current_primary_point));
    let cached_a = self.intercept_warm_starts.as_ref().and_then(|cache| {
        let value = f64::from_bits(cache.get(row)?.load(Ordering::Relaxed));
        value.is_finite().then_some(value)
    });
    let a_init = predictor_a.or(cached_a).unwrap_or(a_closed_form);
    // Path 3: short unguarded Newton probe; the first residual is recorded so
    // seed quality can be binned in the stats.
    let probe_result = (|| -> Result<(Option<(f64, f64, f64)>, f64), String> {
        let mut a = a_init;
        let mut seed_residual = None;
        for _ in 0..6 {
            let (f, f_a, _) = eval(a)?;
            if seed_residual.is_none() {
                seed_residual = Some(f);
            }
            if Self::row_intercept_newton_is_converged(a, f, f_a, abs_tol) {
                return Ok((Some((a, f_a.abs(), f)), seed_residual.unwrap_or(f)));
            }
            if !(f_a.is_finite() && f_a != 0.0) {
                break;
            }
            let next_a = a - f / f_a;
            if !next_a.is_finite() {
                break;
            }
            a = next_a;
        }
        Ok((None, seed_residual.unwrap_or(f64::INFINITY)))
    })();
    if let Ok((accepted, seed_residual)) = &probe_result {
        if let Some(stats) = stats {
            stats.record_seed_residual(*seed_residual, abs_tol);
        }
        if let Some((a, abs_deriv, _)) = accepted {
            if let Some(stats) = stats {
                if predictor_a.is_some() || cached_a.is_some() {
                    stats.cached_short_circuit.fetch_add(1, Ordering::Relaxed);
                } else {
                    stats
                        .closed_form_short_circuit
                        .fetch_add(1, Ordering::Relaxed);
                }
            }
            self.cache_row_intercept(row, *a);
            return Ok((*a, *abs_deriv, false));
        }
    }
    // Path 4: safeguarded monotone root solve from the chosen seed.
    let mut solve_result = super::monotone_root::solve_monotone_root_detailed(
        &eval,
        a_init,
        "bernoulli intercept",
        abs_tol,
        64,
        48,
    );
    // A stale cached seed may strand the solver; retry from the closed form.
    if (predictor_a.is_some() || cached_a.is_some()) && solve_result.is_err() {
        solve_result = super::monotone_root::solve_monotone_root_detailed(
            &eval,
            a_closed_form,
            "bernoulli intercept",
            abs_tol,
            64,
            48,
        );
    }
    let solve_solution = solve_result?;
    if let Some(stats) = stats {
        stats.record_full_solver(solve_solution.refine_iters);
    }
    let (a, abs_deriv, f_best) = (
        solve_solution.root,
        solve_solution.abs_deriv,
        solve_solution.residual,
    );
    if f_best.abs() > abs_tol {
        return Err(format!(
            "bernoulli marginal-slope intercept solve failed: \
residual={f_best:.3e} at a={a:.6}, target mu={target:.6}"
        ));
    }
    self.cache_row_intercept(row, a);
    Ok((a, abs_deriv, false))
}
/// Test-only convenience: build a row's exact-evaluation context without
/// collecting intercept-solve statistics.
#[cfg(test)]
fn build_row_exact_context(
    &self,
    row: usize,
    block_states: &[ParameterBlockState],
) -> Result<BernoulliMarginalSlopeRowExactContext, String> {
    self.build_row_exact_context_with_stats(row, block_states, None)
}
/// Build a row's exact-evaluation context with optional solve statistics,
/// with degree-9 cell-moment caching enabled.
fn build_row_exact_context_with_stats(
    &self,
    row: usize,
    block_states: &[ParameterBlockState],
    stats: Option<&BernoulliInterceptSolveStats>,
) -> Result<BernoulliMarginalSlopeRowExactContext, String> {
    self.build_row_exact_context_with_stats_and_cell_cache(row, block_states, stats, true)
}
/// Build one row's exact-evaluation context: solve (or closed-form) the row
/// intercept, and — for the flexible standard-normal case, when requested —
/// precompute deduplicated degree-9 derivative moments for the row's
/// denested partition cells.
fn build_row_exact_context_with_stats_and_cell_cache(
    &self,
    row: usize,
    block_states: &[ParameterBlockState],
    stats: Option<&BernoulliInterceptSolveStats>,
    cache_degree9_cells: bool,
) -> Result<BernoulliMarginalSlopeRowExactContext, String> {
    let marginal_eta = block_states[0].eta[row];
    let marginal = self.marginal_link_map(marginal_eta)?;
    let slope = block_states[1].eta[row];
    let beta_h = self.score_beta(block_states)?;
    let beta_w = self.link_beta(block_states)?;
    // Flexible kernel: numeric intercept solve (returns |f'| and fast-path flag).
    // Rigid kernel: closed-form/empirical intercept; m_a is not defined (NaN).
    let (intercept, m_a, intercept_fast_path) = if self.effective_flex_active(block_states)? {
        self.solve_row_intercept_base(row, marginal_eta, slope, beta_h, beta_w, stats)?
    } else {
        let intercept = match self.latent_measure.empirical_grid_for_training_row(row)? {
            None => {
                rigid_intercept_from_marginal(marginal.q, slope, self.probit_frailty_scale())
            }
            Some(grid) => self.empirical_rigid_intercept_for_row(
                row,
                marginal,
                slope,
                &grid.nodes,
                &grid.weights,
            )?,
        };
        (intercept, f64::NAN, false)
    };
    // Degree-9 moments are only cached for the flexible standard-normal path.
    let degree9_cells = if cache_degree9_cells
        && self.effective_flex_active(block_states)?
        && matches!(self.latent_measure, LatentMeasureKind::StandardNormal)
    {
        let cells = self.denested_partition_cells(intercept, slope, beta_h, beta_w)?;
        // Deduplicate identical cells by fingerprint so each is computed once.
        let mut dedup: HashMap<
            exact_kernel::CellFingerprint,
            exact_kernel::CellDerivativeMomentState,
        > = HashMap::new();
        let mut out: Vec<CachedDenestedCellMoments> = Vec::with_capacity(cells.len());
        for partition_cell in cells.into_iter() {
            let key = exact_kernel::CellFingerprint::new(partition_cell.cell);
            let state: exact_kernel::CellDerivativeMomentState =
                if let Some(existing) = dedup.get(&key) {
                    existing.clone()
                } else {
                    let computed =
                        self.evaluate_cell_derivative_moments_lru(partition_cell.cell, 9)?;
                    dedup.insert(key, computed.clone());
                    computed
                };
            out.push(CachedDenestedCellMoments {
                partition_cell,
                state,
            });
        }
        Some(out)
    } else {
        None
    };
    Ok(BernoulliMarginalSlopeRowExactContext {
        intercept,
        m_a,
        intercept_fast_path,
        degree9_cells,
    })
}
/// Fetch the precomputed exact-evaluation context for `row` from the cache.
/// Panics on an out-of-range row (indexing), which would indicate a cache
/// built for a different data size.
#[inline]
fn row_ctx(
    cache: &BernoulliMarginalSlopeExactEvalCache,
    row: usize,
) -> &BernoulliMarginalSlopeRowExactContext {
    &cache.row_contexts[row]
}
/// Build the exact evaluation cache with default options.
/// NOTE(review): the `_with_order` suffix no longer corresponds to any
/// parameter here — confirm whether an order argument was removed and the
/// name is stale.
fn build_exact_eval_cache_with_order(
    &self,
    block_states: &[ParameterBlockState],
) -> Result<BernoulliMarginalSlopeExactEvalCache, String> {
    self.build_exact_eval_cache_with_options(block_states, None)
}
/// Build the full exact-evaluation cache for the current block states:
/// validates shapes, preseeds intercept warm starts, solves every row
/// context in parallel, logs solver/cache statistics, and precomputes the
/// row cell-moment bundle (possibly restricted to a subsample mask).
fn build_exact_eval_cache_with_options(
    &self,
    block_states: &[ParameterBlockState],
    options: Option<&BlockwiseFitOptions>,
) -> Result<BernoulliMarginalSlopeExactEvalCache, String> {
    self.validate_exact_block_state_shapes(block_states)?;
    let slices = block_slices(self);
    let primary = primary_slices(&slices);
    let n = self.y.len();
    let flex_active = self.effective_flex_active(block_states)?;
    let started = std::time::Instant::now();
    if log_exact_work(n) {
        log::info!(
            "[BMS exact-cache] build start n={} p={} flex={}",
            n,
            slices.total,
            flex_active
        );
    }
    // Seed intercept warm starts before the parallel per-row solves.
    self.preseed_intercept_warm_starts(block_states)?;
    if flex_active {
        exact_kernel::reset_tail_cell_moment_cache();
    }
    let stats = BernoulliInterceptSolveStats::default();
    let cell_cache_before = self.cell_moment_cache_stats.snapshot();
    // Row contexts are independent; build them in parallel.
    let row_contexts: Result<Vec<_>, String> = (0..n)
        .into_par_iter()
        .map(|row| self.build_row_exact_context_with_stats(row, block_states, Some(&stats)))
        .collect();
    let row_contexts = row_contexts?;
    let fast_path_rows = row_contexts
        .iter()
        .filter(|ctx| ctx.intercept_fast_path)
        .count();
    log::debug!(
        "[BMS exact-cache] row-intercept zero-deviation fast path rows={}/{}",
        fast_path_rows,
        n
    );
    // Diagnostic logging of the intercept-solve statistics.
    if flex_active {
        log::debug!(
            "bernoulli marginal-slope intercept seed short-circuit: cached={}, closed_form={}, full_solver={}, max_full_solver_iters={}, seed_residual_bins={{<=1e-12:{}, <=1e-10:{}, <=1e-8:{}, <=abs_tol:{}, >abs_tol:{}}}",
            stats.cached_short_circuit.load(Ordering::Relaxed),
            stats.closed_form_short_circuit.load(Ordering::Relaxed),
            stats.full_solver.load(Ordering::Relaxed),
            stats.max_full_solver_iters.load(Ordering::Relaxed),
            stats.seed_residual_le_1e12.load(Ordering::Relaxed),
            stats.seed_residual_le_1e10.load(Ordering::Relaxed),
            stats.seed_residual_le_1e8.load(Ordering::Relaxed),
            stats.seed_residual_le_abs_tol.load(Ordering::Relaxed),
            stats.seed_residual_gt_abs_tol.load(Ordering::Relaxed),
        );
    }
    // Cell-moment cache statistics for this build cycle.
    if flex_active {
        let (cell_hits, cell_misses, cell_hit_rate) = self
            .cell_moment_cache_stats
            .hit_rate_delta(cell_cache_before);
        log::info!(
            "[BMS cell-moment LRU] cycle hits={} misses={} hit_rate={:.1}% entries={} resident_mib={:.1}/{:.1}",
            cell_hits,
            cell_misses,
            100.0 * cell_hit_rate,
            self.cell_moment_lru.len(),
            self.cell_moment_lru.resident_bytes() as f64 / (1024.0 * 1024.0),
            self.cell_moment_lru.max_bytes() as f64 / (1024.0 * 1024.0),
        );
        let tail_stats = exact_kernel::tail_cell_moment_cache_stats();
        log::info!(
            "[BMS exact-cache] affine tail-cell memo: hits={} misses={} entries={} hit_rate={:.3}%",
            tail_stats.hits,
            tail_stats.misses,
            tail_stats.entries,
            100.0 * tail_stats.hit_rate(),
        );
    }
    if log_exact_work(n) {
        log::info!(
            "[BMS exact-cache] build done n={} p={} flex={} elapsed={:.3}s",
            n,
            slices.total,
            flex_active,
            started.elapsed().as_secs_f64()
        );
    }
    // Optional outer-score subsample restricts the cell-moment precompute.
    let row_cell_mask = options
        .and_then(|opts| opts.outer_score_subsample.as_ref())
        .map(|subsample| subsample.mask.as_slice());
    let row_cell_moments =
        self.build_row_cell_moments_bundle(block_states, &row_contexts, 21, row_cell_mask)?;
    Ok(BernoulliMarginalSlopeExactEvalCache {
        slices,
        primary,
        row_contexts,
        row_cell_moments,
        row_primary_hessians: None,
        rigid_third_full: crate::resource::RayonSafeOnce::new(),
        fingerprint: CacheFingerprint::compute(block_states, row_cell_mask, options),
        rigid_fourth_full: crate::resource::RayonSafeOnce::new(),
    })
}
/// Precompute per-row denested-cell derivative moments up to `max_degree` for
/// the flexible standard-normal kernel, optionally restricted to `row_mask`.
///
/// Returns `Ok(None)` when the precompute does not apply (rigid kernel,
/// non-standard latent law, empty selection) or when the estimated resident
/// size would exceed the operator-cache byte budget.
fn build_row_cell_moments_bundle(
    &self,
    block_states: &[ParameterBlockState],
    row_contexts: &[BernoulliMarginalSlopeRowExactContext],
    max_degree: usize,
    row_mask: Option<&[usize]>,
) -> Result<Option<RowCellMomentsBundle>, String> {
    if !self.effective_flex_active(block_states)? {
        return Ok(None);
    }
    if !matches!(self.latent_measure, LatentMeasureKind::StandardNormal) {
        return Ok(None);
    }
    let n = self.y.len();
    let beta_h = self.score_beta(block_states)?;
    let beta_w = self.link_beta(block_states)?;
    // Out-of-range mask entries are silently dropped.
    let selected_rows: Vec<usize> = match row_mask {
        Some(mask) => mask.iter().copied().filter(|&row| row < n).collect(),
        None => (0..n).collect(),
    };
    if selected_rows.is_empty() {
        return Ok(None);
    }
    // First pass: partition each selected row into cells (parallel).
    let partitions: Vec<(usize, Vec<exact_kernel::DenestedPartitionCell>)> = selected_rows
        .into_par_iter()
        .map(|row| {
            self.denested_partition_cells(
                row_contexts[row].intercept,
                block_states[1].eta[row],
                beta_h,
                beta_w,
            )
            .map(|cells| (row, cells))
        })
        .collect::<Result<Vec<_>, String>>()?;
    let selected_n = partitions.len();
    let n_cells = partitions
        .iter()
        .map(|(_, cells)| cells.len())
        .sum::<usize>();
    // Budget check before committing to the moment computation.
    let estimated_bytes =
        RowCellMomentsBundle::estimated_resident_bytes(n, n_cells, max_degree);
    let limit_bytes = self.policy.max_operator_cache_bytes;
    if estimated_bytes > limit_bytes {
        log::warn!(
            "[BMS row-cell-moments] skip precompute n={} selected_rows={} cells={} degree={} estimated_bytes={} limit_bytes={}",
            n,
            selected_n,
            n_cells,
            max_degree,
            estimated_bytes,
            limit_bytes
        );
        return Ok(None);
    }
    let started = std::time::Instant::now();
    // Second pass: compute each cell's derivative moments (parallel over rows).
    let computed_rows = partitions
        .into_par_iter()
        .map(|(row, cells)| {
            let moments = cells
                .into_iter()
                .map(|partition_cell| {
                    self.evaluate_cell_derivative_moments_lru(partition_cell.cell, max_degree)
                        .map(|state| CachedDenestedCellMoments {
                            partition_cell,
                            state,
                        })
                })
                .collect::<Result<Vec<_>, String>>()?;
            Ok((row, moments))
        })
        .collect::<Result<Vec<_>, String>>()?;
    // Scatter results into a dense per-row table; unselected rows stay None.
    let mut rows = vec![None; n];
    for (row, moments) in computed_rows {
        rows[row] = Some(moments);
    }
    if log_exact_work(n) {
        log::info!(
            "[BMS row-cell-moments] precomputed n={} selected_rows={} cells={} degree={} estimated_bytes={} elapsed={:.3}s",
            n,
            selected_n,
            n_cells,
            max_degree,
            estimated_bytes,
            started.elapsed().as_secs_f64()
        );
    }
    Ok(Some(RowCellMomentsBundle { max_degree, rows }))
}
/// Materializes the per-row primary Hessians (flattened `r*r` per row) used
/// by [`Self::cached_row_primary_hessian`].
///
/// Returns `Ok(None)` when flex terms are inactive or when the packed matrix
/// would exceed the single-materialization byte budget (callers then stream
/// row Hessians on demand instead).
fn build_row_primary_hessian_cache(
    &self,
    block_states: &[ParameterBlockState],
    cache: &BernoulliMarginalSlopeExactEvalCache,
) -> Result<Option<Array2<f64>>, String> {
    if !self.effective_flex_active(block_states)? {
        return Ok(None);
    }
    let n = self.y.len();
    let primary = &cache.primary;
    let r = primary.total;
    // n * r * r * sizeof(f64), saturating so a huge product cannot wrap.
    let cache_bytes = n
        .saturating_mul(r)
        .saturating_mul(r)
        .saturating_mul(std::mem::size_of::<f64>());
    if cache_bytes > self.policy.max_single_materialization_bytes {
        if log_exact_work(n) {
            log::info!(
                "[BMS row-primary-hessian-cache] stream rows n={} r={} bytes={} limit_bytes={}",
                n,
                r,
                cache_bytes,
                self.policy.max_single_materialization_bytes
            );
        }
        return Ok(None);
    }
    // One flattened r*r Hessian per row, filled in parallel over rows.
    let mut packed = Array2::<f64>::zeros((n, r * r));
    packed
        .axis_iter_mut(Axis(0))
        .into_par_iter()
        .enumerate()
        .try_for_each(|(row, mut packed_row)| -> Result<(), String> {
            let row_ctx = Self::row_ctx(cache, row);
            let mut scratch = BernoulliMarginalSlopeFlexRowScratch::new(r);
            // Degree-9 moments are requested here since the Hessian is needed.
            let row_moments = cache
                .row_cell_moments
                .as_ref()
                .and_then(|bundle| bundle.row(row, 9));
            self.compute_row_analytic_flex_into_with_moments(
                row,
                block_states,
                primary,
                row_ctx,
                row_moments,
                true,
                &mut scratch,
            )?;
            // Copy the row Hessian into the packed row in iteration order.
            for (dst, src) in packed_row.iter_mut().zip(scratch.hess.iter()) {
                *dst = *src;
            }
            Ok(())
        })?;
    Ok(Some(packed))
}
#[inline]
/// Returns a borrowed `(r, r)` view into the packed per-row Hessian cache,
/// or `None` when the cache is absent, the row is out of range, or the
/// packed storage cannot be viewed contiguously.
fn cached_row_primary_hessian<'a>(
    cache: &'a BernoulliMarginalSlopeExactEvalCache,
    row: usize,
) -> Option<ArrayView2<'a, f64>> {
    let packed = cache.row_primary_hessians.as_ref()?;
    let dim = cache.primary.total;
    if row >= packed.nrows() {
        return None;
    }
    // Each row stores one flattened dim*dim Hessian; all index arithmetic is
    // checked so an oversized product yields None instead of wrapping.
    let block_len = dim.checked_mul(dim)?;
    let begin = row.checked_mul(block_len)?;
    let finish = begin.checked_add(block_len)?;
    let slab = packed.as_slice()?.get(begin..finish)?;
    ArrayView2::from_shape((dim, dim), slab).ok()
}
/// Builds the exact evaluation cache with default settings; thin forwarder
/// to `build_exact_eval_cache_with_order` (defined elsewhere in this impl).
fn build_exact_eval_cache(
    &self,
    block_states: &[ParameterBlockState],
) -> Result<BernoulliMarginalSlopeExactEvalCache, String> {
    self.build_exact_eval_cache_with_order(block_states)
}
/// Returns the shared (Arc'd) exact evaluation cache for the given state and
/// options, rebuilding it only when the fingerprint changed.
///
/// Locking protocol: the mutex is taken briefly for the fingerprint check,
/// released during the (expensive) rebuild, then re-taken to publish the new
/// entry. Two threads may therefore race to rebuild; the later publisher
/// simply overwrites — correctness is preserved, only work is duplicated.
fn build_shared_eval_cache_with_options(
    &self,
    block_states: &[ParameterBlockState],
    options: &BlockwiseFitOptions,
) -> Result<Arc<BernoulliMarginalSlopeExactEvalCache>, String> {
    let fingerprint = SharedEvalCacheFingerprint::from_inputs(block_states, options);
    {
        // Fast path: reuse the cached entry when fingerprints match.
        let guard = self.shared_eval_cache.lock().map_err(|e| e.to_string())?;
        if let Some(entry) = guard.as_ref() {
            if entry.fingerprint.matches(&fingerprint) {
                return Ok(Arc::clone(&entry.cache));
            }
        }
    }
    // Slow path: rebuild outside the lock.
    let mut cache = self.build_exact_eval_cache_with_options(block_states, Some(options))?;
    cache.row_primary_hessians =
        self.build_row_primary_hessian_cache(block_states, &cache)?;
    if !self.effective_flex_active(block_states)? {
        // Rigid path: eagerly warm the third/fourth-order full tensors
        // (results are stored in the cache; the returned values are unused).
        let _ = self.rigid_third_full_cached(block_states, &cache, 0)?;
        let _ = self.rigid_fourth_full_cached(block_states, &cache, 0)?;
    }
    let arc = Arc::new(cache);
    let mut guard = self.shared_eval_cache.lock().map_err(|e| e.to_string())?;
    *guard = Some(SharedEvalCacheEntry {
        fingerprint,
        cache: Arc::clone(&arc),
    });
    Ok(arc)
}
/// Projects a flat block-coefficient direction onto one row's primary
/// coordinates `(q, logslope, [h], [w])`.
///
/// The marginal and log-slope entries are row-wise design-matrix products;
/// the optional score-warp (`h`) and link-wiggle (`w`) segments are copied
/// through unchanged.
///
/// # Errors
/// Returns an error when `d_beta_flat` does not match the flat block total.
fn row_primary_direction_from_flat(
    &self,
    row: usize,
    slices: &BlockSlices,
    primary: &PrimarySlices,
    d_beta_flat: &Array1<f64>,
) -> Result<Array1<f64>, String> {
    if d_beta_flat.len() != slices.total {
        return Err(format!(
            "bernoulli marginal-slope d_beta length mismatch: got {}, expected {}",
            d_beta_flat.len(),
            slices.total
        ));
    }
    let mut out = Array1::<f64>::zeros(primary.total);
    out[primary.q] = self
        .marginal_design
        .dot_row_view(row, d_beta_flat.slice(s![slices.marginal.clone()]));
    out[primary.logslope] = self
        .logslope_design
        .dot_row_view(row, d_beta_flat.slice(s![slices.logslope.clone()]));
    if let (Some(block_range), Some(primary_range)) = (slices.h.as_ref(), primary.h.as_ref()) {
        // assign accepts a view directly; no intermediate owned copy needed.
        out.slice_mut(s![primary_range.start..primary_range.end])
            .assign(&d_beta_flat.slice(s![block_range.clone()]));
    }
    if let (Some(block_range), Some(primary_range)) = (slices.w.as_ref(), primary.w.as_ref()) {
        out.slice_mut(s![primary_range.start..primary_range.end])
            .assign(&d_beta_flat.slice(s![block_range.clone()]));
    }
    Ok(out)
}
/// Embeds the psi-map row product `x_row · beta` of the given spatial block
/// into the row's primary coordinate vector (slot `q` for block 0, slot
/// `logslope` for block 1); all other entries stay zero.
fn row_primary_psi_direction_from_map(
    &self,
    row: usize,
    block_idx: usize,
    psi_map: &crate::families::custom_family::PsiDesignMap,
    block_states: &[ParameterBlockState],
    primary: &PrimarySlices,
) -> Result<Array1<f64>, String> {
    // Resolve the target slot first; unsupported blocks error out before
    // the psi-map row is materialized.
    let target = match block_idx {
        0 => primary.q,
        1 => primary.logslope,
        _ => {
            return Err(format!(
                "bernoulli marginal-slope psi direction only supports spatial marginal/logslope blocks, got block {block_idx}"
            ));
        }
    };
    let x_row = psi_map
        .row_vector(row)
        .map_err(|e| format!("survival rowwise psi map: {e}"))?;
    let mut out = Array1::<f64>::zeros(primary.total);
    out[target] = x_row.dot(&block_states[block_idx].beta);
    Ok(out)
}
/// Embeds the psi-map row product against a flat coefficient direction into
/// the row's primary coordinates: slot `q` (block 0) or `logslope` (block 1)
/// receives `x_row · d_beta_flat[block range]`; other entries stay zero.
///
/// # Errors
/// Returns an error for block indices other than 0/1, or when the psi map
/// cannot produce the row vector.
fn row_primary_psi_action_on_direction_from_map(
    &self,
    row: usize,
    block_idx: usize,
    psi_map: &crate::families::custom_family::PsiDesignMap,
    slices: &BlockSlices,
    d_beta_flat: &Array1<f64>,
    primary: &PrimarySlices,
) -> Result<Array1<f64>, String> {
    let mut out = Array1::<f64>::zeros(primary.total);
    match block_idx {
        0 => {
            let x_row = psi_map
                .row_vector(row)
                .map_err(|e| format!("survival rowwise psi map: {e}"))?;
            // Dot against the slice view directly; no owned copy is needed.
            out[primary.q] = x_row.dot(&d_beta_flat.slice(s![slices.marginal.clone()]));
        }
        1 => {
            let x_row = psi_map
                .row_vector(row)
                .map_err(|e| format!("survival rowwise psi map: {e}"))?;
            out[primary.logslope] = x_row.dot(&d_beta_flat.slice(s![slices.logslope.clone()]));
        }
        _ => {
            return Err(format!(
                "bernoulli marginal-slope psi action only supports marginal/logslope blocks, got block {block_idx}"
            ));
        }
    }
    Ok(out)
}
/// Second-order psi direction for one row. Off-diagonal (`block_i != block_j`)
/// terms vanish, returning a zero vector; on the diagonal the psi-map row
/// product with the block's beta is embedded in slot `q` (block 0) or
/// `logslope` (block 1).
///
/// # Errors
/// Returns an error (rather than panicking) when `psi_map_ij` is missing on
/// the diagonal, when the block index is unsupported, or when the psi map
/// cannot produce the row vector.
fn row_primary_psi_second_direction_from_map(
    &self,
    row: usize,
    block_i: usize,
    block_j: usize,
    psi_map_ij: Option<&crate::families::custom_family::PsiDesignMap>,
    block_states: &[ParameterBlockState],
    primary: &PrimarySlices,
) -> Result<Array1<f64>, String> {
    if block_i != block_j {
        return Ok(Array1::<f64>::zeros(primary.total));
    }
    // Recoverable error instead of a panic: this is a Result-returning API.
    let psi_map_ij = psi_map_ij.ok_or_else(|| {
        "bernoulli marginal-slope psi second direction: psi_map_ij must be provided when block_i == block_j"
            .to_string()
    })?;
    let target = match block_i {
        0 => primary.q,
        1 => primary.logslope,
        _ => {
            return Err(format!(
                "bernoulli marginal-slope psi second direction only supports marginal/logslope blocks, got block {block_i}"
            ));
        }
    };
    let x_row = psi_map_ij
        .row_vector(row)
        .map_err(|e| format!("survival rowwise psi map: {e}"))?;
    let mut out = Array1::<f64>::zeros(primary.total);
    out[target] = x_row.dot(&block_states[block_i].beta);
    Ok(out)
}
/// Pulls a per-row primary-coordinate vector back to flat block-coefficient
/// space: the `q` and `logslope` entries are scattered through the
/// corresponding design-matrix rows (axpy), while the optional `h`/`w`
/// segments are copied through unchanged.
fn pullback_primary_vector(
    &self,
    row: usize,
    slices: &BlockSlices,
    primary: &PrimarySlices,
    primary_vec: &Array1<f64>,
) -> Result<Array1<f64>, String> {
    let mut out = Array1::<f64>::zeros(slices.total);
    {
        let mut marginal = out.slice_mut(s![slices.marginal.clone()]);
        self.marginal_design
            .axpy_row_into(row, primary_vec[primary.q], &mut marginal)?;
    }
    {
        let mut logslope = out.slice_mut(s![slices.logslope.clone()]);
        self.logslope_design.axpy_row_into(
            row,
            primary_vec[primary.logslope],
            &mut logslope,
        )?;
    }
    // Segment copies only happen when both sides declare the range; assign
    // from the view directly — no intermediate owned copy needed.
    if let (Some(primary_h), Some(block_h)) = (primary.h.as_ref(), slices.h.as_ref()) {
        out.slice_mut(s![block_h.clone()])
            .assign(&primary_vec.slice(s![primary_h.start..primary_h.end]));
    }
    if let (Some(primary_w), Some(block_w)) = (primary.w.as_ref(), slices.w.as_ref()) {
        out.slice_mut(s![block_w.clone()])
            .assign(&primary_vec.slice(s![primary_w.start..primary_w.end]));
    }
    Ok(out)
}
/// Builds a `BlockPsiRow` embedding for the marginal (0) or logslope (1)
/// block: the psi-map row vector paired with that block's flat-coefficient
/// range.
fn block_psi_row_from_map(
    &self,
    row: usize,
    block_idx: usize,
    psi_map: &crate::families::custom_family::PsiDesignMap,
    slices: &BlockSlices,
) -> Result<BlockPsiRow, String> {
    // Pick the flat range first so unsupported blocks fail before the
    // psi-map row is materialized.
    let range = match block_idx {
        0 => slices.marginal.clone(),
        1 => slices.logslope.clone(),
        _ => {
            return Err(format!(
                "bernoulli marginal-slope psi embedding only supports marginal/logslope blocks, got block {block_idx}"
            ));
        }
    };
    let local_vec = psi_map
        .row_vector(row)
        .map_err(|e| format!("survival rowwise psi map: {e}"))?;
    Ok(BlockPsiRow {
        block_idx,
        range,
        local_vec,
    })
}
/// Second-order `BlockPsiRow` embedding. Off-diagonal pairs contribute
/// nothing (`Ok(None)`); on the diagonal the psi-map row vector is paired
/// with the block's flat-coefficient range.
///
/// # Errors
/// Returns an error (rather than panicking) when `psi_map_ij` is missing on
/// the diagonal, when the block index is unsupported, or when the psi map
/// cannot produce the row vector.
fn block_psi_second_row_from_map(
    &self,
    row: usize,
    block_i: usize,
    block_j: usize,
    psi_map_ij: Option<&crate::families::custom_family::PsiDesignMap>,
    slices: &BlockSlices,
) -> Result<Option<BlockPsiRow>, String> {
    if block_i != block_j {
        return Ok(None);
    }
    // Recoverable error instead of a panic: this is a Result-returning API.
    let psi_map_ij = psi_map_ij.ok_or_else(|| {
        "bernoulli marginal-slope psi second embedding: psi_map_ij must be provided when block_i == block_j"
            .to_string()
    })?;
    let range = match block_i {
        0 => slices.marginal.clone(),
        1 => slices.logslope.clone(),
        _ => {
            return Err(format!(
                "bernoulli marginal-slope psi second embedding only supports marginal/logslope blocks, got block {block_i}"
            ));
        }
    };
    let local_vec = psi_map_ij
        .row_vector(row)
        .map_err(|e| format!("survival rowwise psi map: {e}"))?;
    Ok(Some(BlockPsiRow {
        block_idx: block_i,
        range,
        local_vec,
    }))
}
/// Full per-row evaluation: negative log-likelihood, gradient, and Hessian
/// over the primary coordinates. Dispatches to the analytic flexible path
/// when flex terms are active, otherwise to the rigid 2x2 row kernel.
fn compute_row_primary_gradient_hessian(
    &self,
    row: usize,
    block_states: &[ParameterBlockState],
    primary: &PrimarySlices,
    row_ctx: &BernoulliMarginalSlopeRowExactContext,
) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
    if self.effective_flex_active(block_states)? {
        // Flexible path: analytic evaluation over all primary coordinates.
        let mut scratch = BernoulliMarginalSlopeFlexRowScratch::new(primary.total);
        let neglog = self.compute_row_analytic_flex_into(
            row,
            block_states,
            primary,
            row_ctx,
            true,
            &mut scratch,
        )?;
        return Ok((neglog, scratch.grad, scratch.hess));
    }
    // Rigid path: two coordinates (marginal, logslope) only.
    let marginal_eta = block_states[0].eta[row];
    let marginal = self.marginal_link_map(marginal_eta)?;
    let g = block_states[1].eta[row];
    let (neglog, grad_pair, h) = self.rigid_row_kernel_eval(row, marginal_eta, marginal, g)?;
    let grad = Array1::from(vec![grad_pair[0], grad_pair[1]]);
    let hess = Array2::from_shape_fn((2, 2), |(i, j)| h[i][j]);
    Ok((neglog, grad, hess))
}
/// Gradient and Hessian for one row, reusing the packed per-row Hessian
/// cache when available so only the gradient must be recomputed.
///
/// On a cache hit (flex active + cached Hessian present), the analytic
/// evaluation runs with `need_hessian = false` — degree-3 cell moments
/// suffice for the gradient — and the cached Hessian block is cloned out.
/// Otherwise falls back to the full recompute.
fn compute_row_primary_gradient_hessian_reusing_cache(
    &self,
    row: usize,
    block_states: &[ParameterBlockState],
    primary: &PrimarySlices,
    row_ctx: &BernoulliMarginalSlopeRowExactContext,
    cache: &BernoulliMarginalSlopeExactEvalCache,
) -> Result<(Array1<f64>, Array2<f64>), String> {
    if self.effective_flex_active(block_states)?
        && let Some(cached_hessian) = Self::cached_row_primary_hessian(cache, row)
    {
        let mut scratch = BernoulliMarginalSlopeFlexRowScratch::new(primary.total);
        // Gradient-only evaluation needs only degree-3 moments.
        let row_moments = cache
            .row_cell_moments
            .as_ref()
            .and_then(|bundle| bundle.row(row, 3));
        self.compute_row_analytic_flex_into_with_moments(
            row,
            block_states,
            primary,
            row_ctx,
            row_moments,
            false,
            &mut scratch,
        )?;
        return Ok((scratch.grad, cached_hessian.to_owned()));
    }
    // Cache miss or rigid path: discard the neglog and recompute both terms.
    let (_, grad, hess) =
        self.compute_row_primary_gradient_hessian(row, block_states, primary, row_ctx)?;
    Ok((grad, hess))
}
/// Convenience wrapper over `compute_row_analytic_flex_into_with_moments`
/// with no precomputed cell moments (they are evaluated on demand through
/// the LRU memo inside the moments-aware path).
fn compute_row_analytic_flex_into(
    &self,
    row: usize,
    block_states: &[ParameterBlockState],
    primary: &PrimarySlices,
    row_ctx: &BernoulliMarginalSlopeRowExactContext,
    need_hessian: bool,
    scratch: &mut BernoulliMarginalSlopeFlexRowScratch,
) -> Result<f64, String> {
    self.compute_row_analytic_flex_into_with_moments(
        row,
        block_states,
        primary,
        row_ctx,
        None,
        need_hessian,
        scratch,
    )
}
/// Extracts the row's linear predictors and the optional deviation
/// coefficient vectors from `block_states`, then defers to the parts-level
/// analytic evaluator.
fn compute_row_analytic_flex_into_with_moments(
    &self,
    row: usize,
    block_states: &[ParameterBlockState],
    primary: &PrimarySlices,
    row_ctx: &BernoulliMarginalSlopeRowExactContext,
    row_cell_moments: Option<&[CachedDenestedCellMoments]>,
    need_hessian: bool,
    scratch: &mut BernoulliMarginalSlopeFlexRowScratch,
) -> Result<f64, String> {
    let q_eta = block_states[0].eta[row];
    let slope_eta = block_states[1].eta[row];
    let beta_h = self.score_beta(block_states)?;
    let beta_w = self.link_beta(block_states)?;
    self.compute_row_analytic_flex_from_parts_into(
        row,
        primary,
        q_eta,
        slope_eta,
        beta_h,
        beta_w,
        row_ctx,
        row_cell_moments,
        need_hessian,
        scratch,
    )
}
/// Core per-row analytic evaluation on the flexible path.
///
/// Computes the weighted negative log-likelihood of row `row` under the
/// signed-probit model and writes its gradient — and, when `need_hessian`,
/// its symmetric Hessian — over the primary coordinates
/// `(q, logslope, [h], [w])` into `scratch`.
///
/// The row intercept `a = row_ctx.intercept` is defined implicitly; its
/// sensitivities `a_u`, `a_uv` are recovered by implicit differentiation
/// using the accumulated moment derivatives and `inv_ma = 1 / row_ctx.m_a`.
/// Moment accumulation runs through one of two paths: an empirical
/// quadrature grid (when the latent measure supplies one for this row), or
/// the denested-cell moment path, which reuses cached per-cell moments from
/// `row_cell_moments` or `row_ctx.degree9_cells` when available.
fn compute_row_analytic_flex_from_parts_into(
    &self,
    row: usize,
    primary: &PrimarySlices,
    q: f64,
    b: f64,
    beta_h: Option<&Array1<f64>>,
    beta_w: Option<&Array1<f64>>,
    row_ctx: &BernoulliMarginalSlopeRowExactContext,
    row_cell_moments: Option<&[CachedDenestedCellMoments]>,
    need_hessian: bool,
    scratch: &mut BernoulliMarginalSlopeFlexRowScratch,
) -> Result<f64, String> {
    use crate::families::bernoulli_marginal_slope::exact_kernel as exact;
    let r = primary.total;
    scratch.reset(need_hessian);
    let a = row_ctx.intercept;
    let f_a = row_ctx.m_a;
    let y_i = self.y[row];
    let w_i = self.weights[row];
    // Signed response in {-1, +1} for the signed-probit likelihood.
    let s_y = 2.0 * y_i - 1.0;
    let marginal = self.marginal_link_map(q)?;
    let inv_ma = 1.0 / f_a;
    let h_range = primary.h.as_ref();
    let w_range = primary.w.as_ref();
    let score_runtime = self.score_warp.as_ref();
    let link_runtime = self.link_dev.as_ref();
    let scale = self.probit_frailty_scale();
    // Shared all-zero cubic-coefficient family used to pad unused jet slots.
    let zero_family = vec![[0.0; 4]; r];
    // Moment-derivative accumulators (aliased into scratch). Naming scheme:
    // f_u ~ dF/du, f_au ~ d2F/(da du), f_uv ~ d2F/(du dv); f_aa is local.
    let f_u = &mut scratch.m_u;
    let f_au = &mut scratch.m_au;
    let f_uv = &mut scratch.m_uv;
    let mut f_aa = 0.0f64;
    // Per-node/per-cell cubic coefficient families, refilled each iteration.
    let mut coeff_u = vec![[0.0; 4]; r];
    let mut coeff_au = vec![[0.0; 4]; r];
    let mut coeff_bu = vec![[0.0; 4]; r];
    if let Some(empirical_grid) = self.latent_measure.empirical_grid_for_training_row(row)? {
        // ---- Path 1: empirical quadrature over latent nodes. ----
        for (&node, &weight) in empirical_grid
            .nodes
            .iter()
            .zip(empirical_grid.weights.iter())
        {
            coeff_u.fill([0.0; 4]);
            coeff_au.fill([0.0; 4]);
            coeff_bu.fill([0.0; 4]);
            let obs = self.observed_denested_cell_partials_at_z(node, a, b, beta_h, beta_w)?;
            let eta = eval_coeff4_at(&obs.coeff, node);
            let eta_a = eval_coeff4_at(&obs.dc_da, node);
            let eta_aa = eval_coeff4_at(&obs.dc_daa, node);
            let phi = normal_pdf(eta);
            if need_hessian {
                // Gaussian CDF second derivative: phi'(eta) = -eta * phi(eta).
                f_aa += weight * phi * (eta_aa - eta * eta_a * eta_a);
            }
            // Slot 1 is the logslope coordinate.
            coeff_u[1] = obs.dc_db;
            if need_hessian {
                coeff_au[1] = obs.dc_dab;
                coeff_bu[1] = obs.dc_dbb;
            }
            // Score-warp basis contributions at this node (if active).
            if let (Some(h_range), Some(runtime)) = (h_range, score_runtime) {
                Self::for_each_deviation_basis_cubic_at(
                    runtime,
                    h_range,
                    node,
                    "score-warp",
                    |_, idx, basis_span| {
                        coeff_u[idx] = scale_coeff4(
                            exact::score_basis_cell_coefficients(basis_span, b),
                            scale,
                        );
                        if need_hessian {
                            coeff_bu[idx] = scale_coeff4(
                                exact::score_basis_cell_coefficients(basis_span, 1.0),
                                scale,
                            );
                        }
                        Ok(())
                    },
                )?;
            }
            // Link-wiggle basis contributions, evaluated at u = a + b*node.
            if let (Some(w_range), Some(runtime)) = (w_range, link_runtime) {
                let u_node = a + b * node;
                Self::for_each_deviation_basis_cubic_at(
                    runtime,
                    w_range,
                    u_node,
                    "link-wiggle",
                    |_, idx, basis_span| {
                        coeff_u[idx] = scale_coeff4(
                            exact::link_basis_cell_coefficients(basis_span, a, b),
                            scale,
                        );
                        if need_hessian {
                            let (dc_aw_raw, dc_bw_raw) =
                                exact::link_basis_cell_coefficient_partials(basis_span, a, b);
                            coeff_au[idx] = scale_coeff4(dc_aw_raw, scale);
                            coeff_bu[idx] = scale_coeff4(dc_bw_raw, scale);
                        }
                        Ok(())
                    },
                )?;
            }
            let eta_u = (0..r)
                .map(|idx| eval_coeff4_at(&coeff_u[idx], node))
                .collect::<Vec<_>>();
            // Accumulate first and mixed (a,u) derivatives; slot 0 is handled
            // separately below via the marginal link moments.
            for u in 1..r {
                f_u[u] += weight * phi * eta_u[u];
                if need_hessian {
                    let eta_au = eval_coeff4_at(&coeff_au[u], node);
                    f_au[u] += weight * phi * (eta_au - eta * eta_a * eta_u[u]);
                }
            }
            if need_hessian {
                let coeff_jet = SparsePrimaryCoeffJetView::new(
                    1,
                    h_range,
                    w_range,
                    &coeff_u,
                    &coeff_au,
                    &coeff_bu,
                    &zero_family,
                    &zero_family,
                    &zero_family,
                    &zero_family,
                    &zero_family,
                    &zero_family,
                    &zero_family,
                );
                // Symmetric (u, v) second-derivative accumulation.
                for u in 1..r {
                    for v in u..r {
                        let second_coeff = coeff_jet.pair_from_b_family(
                            coeff_jet.b_first,
                            u,
                            v,
                            COEFF_SUPPORT_BHW,
                        );
                        let eta_uv = eval_coeff4_at(&second_coeff, node);
                        let val = weight * phi * (eta_uv - eta * eta_u[u] * eta_u[v]);
                        f_uv[[u, v]] += val;
                        if u != v {
                            f_uv[[v, u]] += val;
                        }
                    }
                }
            }
        }
    } else {
        // ---- Path 2: denested-cell moments. ----
        // Cell source priority: caller-supplied bundle > row context's
        // degree-9 cells > fresh partition + LRU moment evaluation.
        let owned_cells;
        let cached_cells: Vec<(
            exact::DenestedPartitionCell,
            std::borrow::Cow<'_, exact::CellDerivativeMomentState>,
        )> = if let Some(cached) = row_cell_moments {
            debug_assert!(
                !cached.is_empty(),
                "row cell moments bundle was selected but row {row} has no cells"
            );
            cached
                .iter()
                .map(|entry| {
                    (
                        entry.partition_cell,
                        std::borrow::Cow::Borrowed(&entry.state),
                    )
                })
                .collect()
        } else if let Some(cached) = row_ctx.degree9_cells.as_ref() {
            cached
                .iter()
                .map(|entry| {
                    (
                        entry.partition_cell,
                        std::borrow::Cow::Borrowed(&entry.state),
                    )
                })
                .collect()
        } else {
            // NOTE(review): owned_cells is consumed immediately by into_iter;
            // the outer declaration looks like a leftover from a borrowing
            // variant of this branch.
            owned_cells = self.denested_partition_cells(a, b, beta_h, beta_w)?;
            owned_cells
                .into_iter()
                .map(|partition_cell| {
                    // Gradient-only callers get by with degree-3 moments.
                    let degree = if need_hessian { 9 } else { 3 };
                    self.evaluate_cell_derivative_moments_lru(partition_cell.cell, degree)
                        .map(|state| (partition_cell, std::borrow::Cow::Owned(state)))
                })
                .collect::<Result<Vec<_>, String>>()?
        };
        for (partition_cell, state) in cached_cells {
            coeff_u.fill([0.0; 4]);
            coeff_au.fill([0.0; 4]);
            coeff_bu.fill([0.0; 4]);
            let cell = partition_cell.cell;
            // Probe point inside the cell used to select the active basis spans.
            let z_mid = exact::interval_probe_point(cell.left, cell.right)?;
            let u_mid = a + b * z_mid;
            let state: &exact::CellDerivativeMomentState = &state;
            let (dc_da_raw, dc_db_raw) = exact::denested_cell_coefficient_partials(
                partition_cell.score_span,
                partition_cell.link_span,
                a,
                b,
            );
            let dc_da = scale_coeff4(dc_da_raw, scale);
            let dc_db = scale_coeff4(dc_db_raw, scale);
            coeff_u[1] = dc_db;
            coeff_au[1] = [0.0; 4];
            coeff_bu[1] = [0.0; 4];
            if need_hessian {
                let (dc_daa_raw, dc_dab_raw, dc_dbb_raw) = exact::denested_cell_second_partials(
                    partition_cell.score_span,
                    partition_cell.link_span,
                    a,
                    b,
                );
                let dc_daa = scale_coeff4(dc_daa_raw, scale);
                let dc_dab = scale_coeff4(dc_dab_raw, scale);
                let dc_dbb = scale_coeff4(dc_dbb_raw, scale);
                f_aa += exact::cell_second_derivative_from_moments(
                    cell,
                    &dc_da,
                    &dc_da,
                    &dc_daa,
                    &state.moments,
                )?;
                coeff_au[1] = dc_dab;
                coeff_bu[1] = dc_dbb;
            }
            // Score-warp basis contributions on this cell (if active).
            if let (Some(h_range), Some(runtime)) = (h_range, score_runtime) {
                Self::for_each_deviation_basis_cubic_at(
                    runtime,
                    h_range,
                    z_mid,
                    "score-warp",
                    |_, idx, basis_span| {
                        coeff_u[idx] = scale_coeff4(
                            exact::score_basis_cell_coefficients(basis_span, b),
                            scale,
                        );
                        if need_hessian {
                            coeff_bu[idx] = scale_coeff4(
                                exact::score_basis_cell_coefficients(basis_span, 1.0),
                                scale,
                            );
                        }
                        Ok(())
                    },
                )?;
            }
            // Link-wiggle basis contributions on this cell (if active).
            if let (Some(w_range), Some(runtime)) = (w_range, link_runtime) {
                Self::for_each_deviation_basis_cubic_at(
                    runtime,
                    w_range,
                    u_mid,
                    "link-wiggle",
                    |_, idx, basis_span| {
                        coeff_u[idx] = scale_coeff4(
                            exact::link_basis_cell_coefficients(basis_span, a, b),
                            scale,
                        );
                        if need_hessian {
                            let (dc_aw_raw, dc_bw_raw) =
                                exact::link_basis_cell_coefficient_partials(basis_span, a, b);
                            coeff_au[idx] = scale_coeff4(dc_aw_raw, scale);
                            coeff_bu[idx] = scale_coeff4(dc_bw_raw, scale);
                        }
                        Ok(())
                    },
                )?;
            }
            // First and mixed (a,u) derivatives from the cell moments.
            for u in 1..r {
                f_u[u] +=
                    exact::cell_first_derivative_from_moments(&coeff_u[u], &state.moments)?;
                if need_hessian {
                    f_au[u] += exact::cell_second_derivative_from_moments(
                        cell,
                        &dc_da,
                        &coeff_u[u],
                        &coeff_au[u],
                        &state.moments,
                    )?;
                }
            }
            if need_hessian {
                let coeff_jet = SparsePrimaryCoeffJetView::new(
                    1,
                    h_range,
                    w_range,
                    &coeff_u,
                    &coeff_au,
                    &coeff_bu,
                    &zero_family,
                    &zero_family,
                    &zero_family,
                    &zero_family,
                    &zero_family,
                    &zero_family,
                    &zero_family,
                );
                // Symmetric (u, v) second-derivative accumulation.
                for u in 1..r {
                    for v in u..r {
                        let second_coeff = coeff_jet.pair_from_b_family(
                            coeff_jet.b_first,
                            u,
                            v,
                            COEFF_SUPPORT_BHW,
                        );
                        let val = exact::cell_second_derivative_from_moments(
                            cell,
                            &coeff_jet.first[u],
                            &coeff_jet.first[v],
                            &second_coeff,
                            &state.moments,
                        )?;
                        f_uv[[u, v]] += val;
                        if u != v {
                            f_uv[[v, u]] += val;
                        }
                    }
                }
            }
        }
    }
    // Slot 0 (q) enters through the marginal link's moment derivatives.
    f_u[0] = -marginal.mu1;
    if need_hessian {
        f_uv[[0, 0]] = -marginal.mu2;
    }
    // Intercept sensitivities a_u = da/du by implicit differentiation.
    let a_u = &mut scratch.a_u;
    for u in 0..r {
        a_u[u] = -f_u[u] * inv_ma;
    }
    self.cache_row_intercept_predictor(row, a, q, b, beta_h, beta_w, a_u);
    // Second-order intercept sensitivities a_uv (symmetric).
    let a_uv = &mut scratch.a_uv;
    if need_hessian {
        for u in 0..r {
            for v in u..r {
                let val = -(f_uv[[u, v]]
                    + f_au[u] * a_u[v]
                    + f_au[v] * a_u[u]
                    + f_aa * a_u[u] * a_u[v])
                    * inv_ma;
                a_uv[[u, v]] = val;
                a_uv[[v, u]] = val;
            }
        }
    }
    // ---- Observed-point evaluation at z_obs (and u_obs = a + b*z_obs). ----
    let z_obs = self.z[row];
    let u_obs = a + b * z_obs;
    let obs = self.observed_denested_cell_partials(row, a, b, beta_h, beta_w)?;
    let chi_obs = eval_coeff4_at(&obs.dc_da, z_obs);
    let eta_aa_obs = eval_coeff4_at(&obs.dc_daa, z_obs);
    let eta_val = eval_coeff4_at(&obs.coeff, z_obs);
    // Fixed-a coefficient families at the observed point.
    let mut g_u_fixed = vec![[0.0; 4]; r];
    let mut g_au_fixed = vec![[0.0; 4]; r];
    let mut g_bu_fixed = vec![[0.0; 4]; r];
    g_u_fixed[1] = obs.dc_db;
    g_au_fixed[1] = obs.dc_dab;
    g_bu_fixed[1] = obs.dc_dbb;
    if let (Some(h_range), Some(runtime)) = (h_range, score_runtime) {
        Self::for_each_deviation_basis_cubic_at(
            runtime,
            h_range,
            z_obs,
            "score-warp observed",
            |_, idx, basis_span| {
                g_u_fixed[idx] =
                    scale_coeff4(exact::score_basis_cell_coefficients(basis_span, b), scale);
                g_bu_fixed[idx] =
                    scale_coeff4(exact::score_basis_cell_coefficients(basis_span, 1.0), scale);
                Ok(())
            },
        )?;
    }
    if let (Some(w_range), Some(runtime)) = (w_range, link_runtime) {
        Self::for_each_deviation_basis_cubic_at(
            runtime,
            w_range,
            u_obs,
            "link-wiggle observed",
            |_, idx, basis_span| {
                g_u_fixed[idx] =
                    scale_coeff4(exact::link_basis_cell_coefficients(basis_span, a, b), scale);
                let (dc_aw_raw, dc_bw_raw) =
                    exact::link_basis_cell_coefficient_partials(basis_span, a, b);
                g_au_fixed[idx] = scale_coeff4(dc_aw_raw, scale);
                g_bu_fixed[idx] = scale_coeff4(dc_bw_raw, scale);
                Ok(())
            },
        )?;
    }
    let g_jet = SparsePrimaryCoeffJetView::new(
        1,
        h_range,
        w_range,
        &g_u_fixed,
        &g_au_fixed,
        &g_bu_fixed,
        &zero_family,
        &zero_family,
        &zero_family,
        &zero_family,
        &zero_family,
        &zero_family,
        &zero_family,
    );
    // rho[u] = fixed-a partial of eta wrt u; tau[u] = its mixed (a,u) partial.
    let rho = &mut scratch.rho;
    let tau = &mut scratch.tau;
    rho.fill(0.0);
    tau.fill(0.0);
    for u in 1..r {
        rho[u] = eval_coeff4_at(&g_jet.first[u], z_obs);
        tau[u] = eval_coeff4_at(&g_jet.a_first[u], z_obs);
    }
    // Total derivative of eta through the implicit intercept: chain rule.
    let eta_u = &mut scratch.grad;
    for u in 0..r {
        eta_u[u] = chi_obs * a_u[u] + rho[u];
    }
    // ---- Signed-probit likelihood pieces. ----
    let signed_margin = s_y * eta_val;
    let (log_cdf, lambda) = signed_probit_logcdf_and_mills_ratio(signed_margin);
    let neglog_val = -w_i * log_cdf;
    // d1_m / d2_m: first/second derivatives of the weighted neglog wrt the margin.
    let d1_m = -w_i * lambda;
    let d2_m = w_i * lambda * (signed_margin + lambda);
    if need_hessian {
        let hess = &mut scratch.hess;
        hess.fill(0.0);
        for u in 0..r {
            for v in u..r {
                let r_uv = eval_coeff4_at(
                    &g_jet.pair_from_b_family(g_jet.b_first, u, v, COEFF_SUPPORT_BHW),
                    z_obs,
                );
                // Full second derivative of eta through the implicit intercept.
                let eta_uv = chi_obs * a_uv[[u, v]]
                    + eta_aa_obs * a_u[u] * a_u[v]
                    + tau[u] * a_u[v]
                    + a_u[u] * tau[v]
                    + r_uv;
                let val = d2_m * eta_u[u] * eta_u[v] + d1_m * s_y * eta_uv;
                hess[[u, v]] = val;
                hess[[v, u]] = val;
            }
        }
    }
    // Finalize the gradient in place: scratch.grad held eta_u until now.
    eta_u.mapv_inplace(|eu| d1_m * s_y * eu);
    Ok(neglog_val)
}
/// Assembles one row's primary coordinate point `(q, logslope, [h], [w])`
/// from the current block states; errors when an active deviation block is
/// missing its coefficient vector.
fn primary_point_from_block_states(
    &self,
    row: usize,
    block_states: &[ParameterBlockState],
    primary: &PrimarySlices,
) -> Result<Array1<f64>, String> {
    let mut point = Array1::<f64>::zeros(primary.total);
    point[primary.q] = block_states[0].eta[row];
    point[primary.logslope] = block_states[1].eta[row];
    if let Some(range_h) = primary.h.as_ref() {
        let state = self
            .score_block_state(block_states)?
            .ok_or_else(|| "missing score-warp beta".to_string())?;
        point
            .slice_mut(s![range_h.start..range_h.end])
            .assign(&state.beta);
    }
    if let Some(range_w) = primary.w.as_ref() {
        let state = self
            .link_block_state(block_states)?
            .ok_or_else(|| "missing link deviation beta".to_string())?;
        point
            .slice_mut(s![range_w.start..range_w.end])
            .assign(&state.beta);
    }
    Ok(point)
}
/// Splits a packed primary point back into its components:
/// `(q, logslope, beta_h, beta_w)`, cloning the optional segments out.
fn primary_point_components(
    &self,
    point: &Array1<f64>,
    primary: &PrimarySlices,
) -> (f64, f64, Option<Array1<f64>>, Option<Array1<f64>>) {
    // Shared extractor for the optional h/w segments.
    let slice_owned = |range: &_| point.slice(s![range.start..range.end]).to_owned();
    (
        point[primary.q],
        point[primary.logslope],
        primary.h.as_ref().map(slice_owned),
        primary.w.as_ref().map(slice_owned),
    )
}
fn observed_denested_cell_partials(
&self,
row: usize,
a: f64,
b: f64,
beta_h: Option<&Array1<f64>>,
beta_w: Option<&Array1<f64>>,
) -> Result<ObservedDenestedCellPartials, String> {
shared_observed_denested_cell_partials(
self.z[row],
a,
b,
self.score_warp.as_ref(),
beta_h,
self.link_dev.as_ref(),
beta_w,
self.probit_frailty_scale(),
)
}
/// Denested-cell partials at an arbitrary latent coordinate `z_value`
/// (used by the empirical quadrature path as well as observed-row callers);
/// forwards to the shared marginal-slope helper with this family's warp and
/// deviation runtimes and its probit frailty scale.
fn observed_denested_cell_partials_at_z(
    &self,
    z_value: f64,
    a: f64,
    b: f64,
    beta_h: Option<&Array1<f64>>,
    beta_w: Option<&Array1<f64>>,
) -> Result<ObservedDenestedCellPartials, String> {
    shared_observed_denested_cell_partials(
        z_value,
        a,
        b,
        self.score_warp.as_ref(),
        beta_h,
        self.link_dev.as_ref(),
        beta_w,
        self.probit_frailty_scale(),
    )
}
/// Third-order tensor contraction against direction `dir` for one row;
/// thin forwarder to the moments-aware implementation (which will consult
/// the cache's moment bundle internally when it is available).
fn row_primary_third_contracted_recompute(
    &self,
    row: usize,
    block_states: &[ParameterBlockState],
    cache: &BernoulliMarginalSlopeExactEvalCache,
    row_ctx: &BernoulliMarginalSlopeRowExactContext,
    dir: &Array1<f64>,
) -> Result<Array2<f64>, String> {
    self.row_primary_third_contracted_recompute_with_moments(
        row,
        block_states,
        cache,
        row_ctx,
        dir,
    )
}
fn row_primary_third_contracted_recompute_with_moments(
&self,
row: usize,
block_states: &[ParameterBlockState],
cache: &BernoulliMarginalSlopeExactEvalCache,
row_ctx: &BernoulliMarginalSlopeRowExactContext,
dir: &Array1<f64>,
) -> Result<Array2<f64>, String> {
if !self.effective_flex_active(block_states)? {
let t = self.rigid_third_full_cached(block_states, cache, row)?;
let m = contract_third_full(t, dir[0], dir[1]);
let mut out = Array2::<f64>::zeros((2, 2));
out[[0, 0]] = m[0][0];
out[[0, 1]] = m[0][1];
out[[1, 0]] = m[1][0];
out[[1, 1]] = m[1][1];
return Ok(out);
}
if dir.iter().all(|value| value.abs() <= 0.0) {
return Ok(Array2::<f64>::zeros((
cache.primary.total,
cache.primary.total,
)));
}
if !row_ctx.intercept.is_finite() || !row_ctx.m_a.is_finite() || row_ctx.m_a <= 0.0 {
return Err(
"non-finite flexible row context in third-order directional contraction"
.to_string(),
);
}
use crate::families::bernoulli_marginal_slope::exact_kernel as exact;
let primary = &cache.primary;
let point = self.primary_point_from_block_states(row, block_states, primary)?;
let (q, b, beta_h_owned, beta_w_owned) = self.primary_point_components(&point, primary);
let beta_h = beta_h_owned.as_ref();
let beta_w = beta_w_owned.as_ref();
if let Some(grid) = self.latent_measure.empirical_grid_for_training_row(row)? {
return self.empirical_flex_row_third_contracted_recompute(
row, primary, q, b, beta_h, beta_w, row_ctx, dir, &grid,
);
}
let a = row_ctx.intercept;
let r = primary.total;
let marginal = self.marginal_link_map(q)?;
let h_range = primary.h.as_ref();
let w_range = primary.w.as_ref();
let score_runtime = self.score_warp.as_ref();
let link_runtime = self.link_dev.as_ref();
let scale = self.probit_frailty_scale();
let zero_family = vec![[0.0; 4]; r];
let mut f_a = 0.0;
let mut f_aa = 0.0;
let mut f_a_dir = 0.0;
let mut f_aa_dir = 0.0;
let mut f_u = Array1::<f64>::zeros(r);
let mut f_au = Array1::<f64>::zeros(r);
let mut f_au_dir = Array1::<f64>::zeros(r);
let mut f_uv = Array2::<f64>::zeros((r, r));
let mut f_uv_dir = Array2::<f64>::zeros((r, r));
let owned_cells;
let cells: &[CachedDenestedCellMoments] = if let Some(cached) = cache
.row_cell_moments
.as_ref()
.and_then(|bundle| bundle.row(row, 15))
{
cached
} else {
let partitions = self.denested_partition_cells(a, b, beta_h, beta_w)?;
owned_cells = partitions
.into_iter()
.map(|partition_cell| {
self.evaluate_cell_derivative_moments_lru(partition_cell.cell, 15)
.map(|state| CachedDenestedCellMoments {
partition_cell,
state,
})
})
.collect::<Result<Vec<_>, String>>()?;
&owned_cells
};
for entry in cells {
let partition_cell = entry.partition_cell;
let cell = partition_cell.cell;
let z_mid = exact::interval_probe_point(cell.left, cell.right)?;
let u_mid = a + b * z_mid;
let state = &entry.state;
let (dc_da_raw, dc_db_raw) = exact::denested_cell_coefficient_partials(
partition_cell.score_span,
partition_cell.link_span,
a,
b,
);
let (dc_daa_raw, dc_dab_raw, dc_dbb_raw) = exact::denested_cell_second_partials(
partition_cell.score_span,
partition_cell.link_span,
a,
b,
);
let denested_third = exact::denested_cell_third_partials(partition_cell.link_span);
let dc_da = scale_coeff4(dc_da_raw, scale);
let dc_db = scale_coeff4(dc_db_raw, scale);
let dc_daa = scale_coeff4(dc_daa_raw, scale);
let dc_dab = scale_coeff4(dc_dab_raw, scale);
let dc_dbb = scale_coeff4(dc_dbb_raw, scale);
let dc_daab = scale_coeff4(denested_third.1, scale);
let dc_dabb = scale_coeff4(denested_third.2, scale);
let dc_dbbb = scale_coeff4(denested_third.3, scale);
let mut coeff_u = vec![[0.0; 4]; r];
let mut coeff_au = vec![[0.0; 4]; r];
let mut coeff_bu = vec![[0.0; 4]; r];
let mut coeff_aau = vec![[0.0; 4]; r];
let mut coeff_abu = vec![[0.0; 4]; r];
let mut coeff_bbu = vec![[0.0; 4]; r];
coeff_u[1] = dc_db;
coeff_au[1] = dc_dab;
coeff_bu[1] = dc_dbb;
coeff_aau[1] = dc_daab;
coeff_abu[1] = dc_dabb;
coeff_bbu[1] = dc_dbbb;
if let (Some(h_range), Some(runtime)) = (h_range, score_runtime) {
Self::for_each_deviation_basis_cubic_at(
runtime,
h_range,
z_mid,
"score-warp third-direction",
|_, idx, basis_span| {
coeff_u[idx] = scale_coeff4(
exact::score_basis_cell_coefficients(basis_span, b),
scale,
);
coeff_bu[idx] = scale_coeff4(
exact::score_basis_cell_coefficients(basis_span, 1.0),
scale,
);
Ok(())
},
)?;
}
if let (Some(w_range), Some(runtime)) = (w_range, link_runtime) {
Self::for_each_deviation_basis_cubic_at(
runtime,
w_range,
u_mid,
"link-wiggle third-direction",
|_, idx, basis_span| {
coeff_u[idx] = scale_coeff4(
exact::link_basis_cell_coefficients(basis_span, a, b),
scale,
);
let (dc_aw_raw, dc_bw_raw) =
exact::link_basis_cell_coefficient_partials(basis_span, a, b);
let (dc_aaw_raw, dc_abw_raw, dc_bbw_raw) =
exact::link_basis_cell_second_partials(basis_span, a, b);
coeff_au[idx] = scale_coeff4(dc_aw_raw, scale);
coeff_bu[idx] = scale_coeff4(dc_bw_raw, scale);
coeff_aau[idx] = scale_coeff4(dc_aaw_raw, scale);
coeff_abu[idx] = scale_coeff4(dc_abw_raw, scale);
coeff_bbu[idx] = scale_coeff4(dc_bbw_raw, scale);
Ok(())
},
)?;
}
let coeff_jet = SparsePrimaryCoeffJetView::new(
1,
h_range,
w_range,
&coeff_u,
&coeff_au,
&coeff_bu,
&coeff_aau,
&coeff_abu,
&coeff_bbu,
&zero_family,
&zero_family,
&zero_family,
&zero_family,
);
f_a += exact::cell_first_derivative_from_moments(&dc_da, &state.moments)?;
f_aa += exact::cell_second_derivative_from_moments(
cell,
&dc_da,
&dc_da,
&dc_daa,
&state.moments,
)?;
for u in 1..r {
f_u[u] +=
exact::cell_first_derivative_from_moments(&coeff_jet.first[u], &state.moments)?;
f_au[u] += exact::cell_second_derivative_from_moments(
cell,
&dc_da,
&coeff_jet.first[u],
&coeff_jet.a_first[u],
&state.moments,
)?;
}
let coeff_dir = coeff_jet.directional_family(coeff_jet.first, dir, COEFF_SUPPORT_BHW);
let coeff_a_dir =
coeff_jet.directional_family(coeff_jet.a_first, dir, COEFF_SUPPORT_BW);
let coeff_aa_dir =
coeff_jet.directional_family(coeff_jet.aa_first, dir, COEFF_SUPPORT_BW);
f_a_dir += exact::cell_second_derivative_from_moments(
cell,
&dc_da,
&coeff_dir,
&coeff_a_dir,
&state.moments,
)?;
f_aa_dir += exact::cell_third_derivative_from_moments(
cell,
&dc_da,
&dc_da,
&coeff_dir,
&dc_daa,
&coeff_a_dir,
&coeff_a_dir,
&coeff_aa_dir,
&state.moments,
)?;
let mut coeff_u_dir = vec![[0.0; 4]; r];
let mut coeff_au_dir = vec![[0.0; 4]; r];
for u in 1..r {
coeff_u_dir[u] = coeff_jet.param_directional_from_b_family(
coeff_jet.b_first,
u,
dir,
COEFF_SUPPORT_BHW,
);
coeff_au_dir[u] = coeff_jet.param_directional_from_b_family(
coeff_jet.ab_first,
u,
dir,
COEFF_SUPPORT_BW,
);
}
for u in 1..r {
f_au_dir[u] += exact::cell_third_derivative_from_moments(
cell,
&dc_da,
&coeff_jet.first[u],
&coeff_dir,
&coeff_jet.a_first[u],
&coeff_a_dir,
&coeff_u_dir[u],
&coeff_au_dir[u],
&state.moments,
)?;
}
for u in 1..r {
for v in u..r {
let second_coeff =
coeff_jet.pair_from_b_family(coeff_jet.b_first, u, v, COEFF_SUPPORT_BHW);
let val = exact::cell_second_derivative_from_moments(
cell,
&coeff_jet.first[u],
&coeff_jet.first[v],
&second_coeff,
&state.moments,
)?;
f_uv[[u, v]] += val;
if u != v {
f_uv[[v, u]] += val;
}
let third_coeff = coeff_jet.pair_directional_from_bb_family(
coeff_jet.bb_first,
u,
v,
dir,
COEFF_SUPPORT_BW,
);
let dir_val = exact::cell_third_derivative_from_moments(
cell,
&coeff_jet.first[u],
&coeff_jet.first[v],
&coeff_dir,
&second_coeff,
&coeff_u_dir[u],
&coeff_u_dir[v],
&third_coeff,
&state.moments,
)?;
f_uv_dir[[u, v]] += dir_val;
if u != v {
f_uv_dir[[v, u]] += dir_val;
}
}
}
}
f_u[0] = -marginal.mu1;
f_uv[[0, 0]] = -marginal.mu2;
f_uv_dir[[0, 0]] = -dir[0] * marginal.mu3;
let inv_f_a = 1.0 / f_a;
let mut a_u = Array1::<f64>::zeros(r);
for u in 0..r {
a_u[u] = -f_u[u] * inv_f_a;
}
let mut a_uv = Array2::<f64>::zeros((r, r));
for u in 0..r {
for v in u..r {
let val =
-(f_uv[[u, v]] + f_au[u] * a_u[v] + f_au[v] * a_u[u] + f_aa * a_u[u] * a_u[v])
* inv_f_a;
a_uv[[u, v]] = val;
a_uv[[v, u]] = val;
}
}
let a_dir = a_u.dot(dir);
let a_u_dir = a_uv.dot(dir);
let mut a_uv_dir = Array2::<f64>::zeros((r, r));
for u in 0..r {
for v in u..r {
let n_dir = f_uv_dir[[u, v]]
+ f_au_dir[u] * a_u[v]
+ f_au[u] * a_u_dir[v]
+ f_au_dir[v] * a_u[u]
+ f_au[v] * a_u_dir[u]
+ f_aa_dir * a_u[u] * a_u[v]
+ f_aa * (a_u_dir[u] * a_u[v] + a_u[u] * a_u_dir[v]);
let val = -(n_dir + f_a_dir * a_uv[[u, v]]) * inv_f_a;
a_uv_dir[[u, v]] = val;
a_uv_dir[[v, u]] = val;
}
}
let z_obs = self.z[row];
let u_obs = a + b * z_obs;
let obs = self.observed_denested_cell_partials(row, a, b, beta_h, beta_w)?;
let eta_val = eval_coeff4_at(&obs.coeff, z_obs);
let mut g_u_fixed = vec![[0.0; 4]; r];
let mut g_au_fixed = vec![[0.0; 4]; r];
let mut g_bu_fixed = vec![[0.0; 4]; r];
let mut g_aau_fixed = vec![[0.0; 4]; r];
let mut g_abu_fixed = vec![[0.0; 4]; r];
let mut g_bbu_fixed = vec![[0.0; 4]; r];
g_u_fixed[1] = obs.dc_db;
g_au_fixed[1] = obs.dc_dab;
g_bu_fixed[1] = obs.dc_dbb;
g_aau_fixed[1] = obs.dc_daab;
g_abu_fixed[1] = obs.dc_dabb;
g_bbu_fixed[1] = obs.dc_dbbb;
let scale = self.probit_frailty_scale();
if let (Some(h_range), Some(runtime)) = (h_range, score_runtime) {
Self::for_each_deviation_basis_cubic_at(
runtime,
h_range,
z_obs,
"score-warp third-direction observed",
|_, idx, basis_span| {
g_u_fixed[idx] =
scale_coeff4(exact::score_basis_cell_coefficients(basis_span, b), scale);
g_bu_fixed[idx] =
scale_coeff4(exact::score_basis_cell_coefficients(basis_span, 1.0), scale);
Ok(())
},
)?;
}
if let (Some(w_range), Some(runtime)) = (w_range, link_runtime) {
Self::for_each_deviation_basis_cubic_at(
runtime,
w_range,
u_obs,
"link-wiggle third-direction observed",
|_, idx, basis_span| {
g_u_fixed[idx] =
scale_coeff4(exact::link_basis_cell_coefficients(basis_span, a, b), scale);
let (dc_aw_raw, dc_bw_raw) =
exact::link_basis_cell_coefficient_partials(basis_span, a, b);
let (dc_aaw_raw, dc_abw_raw, dc_bbw_raw) =
exact::link_basis_cell_second_partials(basis_span, a, b);
g_au_fixed[idx] = scale_coeff4(dc_aw_raw, scale);
g_bu_fixed[idx] = scale_coeff4(dc_bw_raw, scale);
g_aau_fixed[idx] = scale_coeff4(dc_aaw_raw, scale);
g_abu_fixed[idx] = scale_coeff4(dc_abw_raw, scale);
g_bbu_fixed[idx] = scale_coeff4(dc_bbw_raw, scale);
Ok(())
},
)?;
}
let g_jet = SparsePrimaryCoeffJetView::new(
1,
h_range,
w_range,
&g_u_fixed,
&g_au_fixed,
&g_bu_fixed,
&g_aau_fixed,
&g_abu_fixed,
&g_bbu_fixed,
&zero_family,
&zero_family,
&zero_family,
&zero_family,
);
let g_a = eval_coeff4_at(&obs.dc_da, z_obs);
let g_aa = eval_coeff4_at(&obs.dc_daa, z_obs);
let g_aaa = eval_coeff4_at(&obs.dc_daaa, z_obs);
let mut g_u = Array1::<f64>::zeros(r);
let mut g_au = Array1::<f64>::zeros(r);
let mut g_aau = Array1::<f64>::zeros(r);
let mut g_uv = Array2::<f64>::zeros((r, r));
let mut g_auv = Array2::<f64>::zeros((r, r));
for u in 1..r {
g_u[u] = eval_coeff4_at(&g_jet.first[u], z_obs);
g_au[u] = eval_coeff4_at(&g_jet.a_first[u], z_obs);
g_aau[u] = eval_coeff4_at(&g_jet.aa_first[u], z_obs);
}
for u in 1..r {
for v in u..r {
let second_coeff = g_jet.pair_from_b_family(g_jet.b_first, u, v, COEFF_SUPPORT_BHW);
let val = eval_coeff4_at(&second_coeff, z_obs);
g_uv[[u, v]] = val;
g_uv[[v, u]] = val;
let third_coeff = g_jet.pair_from_b_family(g_jet.ab_first, u, v, COEFF_SUPPORT_BW);
let third_val = eval_coeff4_at(&third_coeff, z_obs);
g_auv[[u, v]] = third_val;
g_auv[[v, u]] = third_val;
}
}
let mut g_u_dir_fixed = vec![[0.0; 4]; r];
let mut g_au_dir_fixed = vec![[0.0; 4]; r];
let g_dir_fixed = g_jet.directional_family(g_jet.first, dir, COEFF_SUPPORT_BHW);
let g_a_dir_fixed = g_jet.directional_family(g_jet.a_first, dir, COEFF_SUPPORT_BW);
let g_aa_dir_fixed = g_jet.directional_family(g_jet.aa_first, dir, COEFF_SUPPORT_BW);
let g_dir = eval_coeff4_at(&g_dir_fixed, z_obs);
let g_a_dir = eval_coeff4_at(&g_a_dir_fixed, z_obs);
let g_aa_dir = eval_coeff4_at(&g_aa_dir_fixed, z_obs);
for u in 1..r {
g_u_dir_fixed[u] =
g_jet.param_directional_from_b_family(g_jet.b_first, u, dir, COEFF_SUPPORT_BHW);
g_au_dir_fixed[u] =
g_jet.param_directional_from_b_family(g_jet.ab_first, u, dir, COEFF_SUPPORT_BW);
}
let mut g_u_dir = Array1::<f64>::zeros(r);
let mut g_uv_dir = Array2::<f64>::zeros((r, r));
for u in 1..r {
g_u_dir[u] = eval_coeff4_at(&g_u_dir_fixed[u], z_obs);
}
for u in 1..r {
for v in u..r {
let third_coeff = g_jet.pair_directional_from_bb_family(
g_jet.bb_first,
u,
v,
dir,
COEFF_SUPPORT_BW,
);
let val = eval_coeff4_at(&third_coeff, z_obs);
g_uv_dir[[u, v]] = val;
g_uv_dir[[v, u]] = val;
}
}
let eta_u = g_a * &a_u + &g_u;
let mut eta_uv = Array2::<f64>::zeros((r, r));
for u in 0..r {
for v in u..r {
let val = g_a * a_uv[[u, v]]
+ g_aa * a_u[u] * a_u[v]
+ g_au[u] * a_u[v]
+ g_au[v] * a_u[u]
+ g_uv[[u, v]];
eta_uv[[u, v]] = val;
eta_uv[[v, u]] = val;
}
}
let eta_dir = g_a * a_dir + g_dir;
let eta_u_dir = eta_uv.dot(dir);
let dg_a_dir = g_aa * a_dir + g_a_dir;
let dg_aa_dir = g_aaa * a_dir + g_aa_dir;
let mut dg_au_dir = Array1::<f64>::zeros(r);
let mut dg_uv_dir = Array2::<f64>::zeros((r, r));
for u in 0..r {
dg_au_dir[u] = g_aau[u] * a_dir + eval_coeff4_at(&g_au_dir_fixed[u], z_obs);
}
for u in 0..r {
for v in u..r {
let val = g_auv[[u, v]] * a_dir + g_uv_dir[[u, v]];
dg_uv_dir[[u, v]] = val;
dg_uv_dir[[v, u]] = val;
}
}
let mut eta_uv_dir = Array2::<f64>::zeros((r, r));
for u in 0..r {
for v in u..r {
let val = dg_a_dir * a_uv[[u, v]]
+ g_a * a_uv_dir[[u, v]]
+ dg_aa_dir * a_u[u] * a_u[v]
+ g_aa * (a_u_dir[u] * a_u[v] + a_u[u] * a_u_dir[v])
+ dg_au_dir[u] * a_u[v]
+ g_au[u] * a_u_dir[v]
+ dg_au_dir[v] * a_u[u]
+ g_au[v] * a_u_dir[u]
+ dg_uv_dir[[u, v]];
eta_uv_dir[[u, v]] = val;
eta_uv_dir[[v, u]] = val;
}
}
let y_i = self.y[row];
let w_i = self.weights[row];
let s_y = 2.0 * y_i - 1.0;
let m = s_y * eta_val;
let (k1, k2, k3, _) = signed_probit_neglog_derivatives_up_to_fourth(m, w_i)?;
let u1 = s_y * k1;
let u3 = s_y * k3;
let mut out = Array2::<f64>::zeros((r, r));
for u in 0..r {
for v in u..r {
let val = u3 * eta_u[u] * eta_u[v] * eta_dir
+ k2 * (eta_uv[[u, v]] * eta_dir
+ eta_u[u] * eta_u_dir[v]
+ eta_u[v] * eta_u_dir[u])
+ u1 * eta_uv_dir[[u, v]];
out[[u, v]] = val;
out[[v, u]] = val;
}
}
Ok(out)
}
#[inline]
fn coeff4_eval_adjoint(z: f64, scalar_adjoint: f64) -> [f64; 4] {
    // Reverse-mode seed for evaluating a cubic c0 + c1*z + c2*z^2 + c3*z^3
    // at `z`: the partial of the value with respect to coefficient k is z^k,
    // so the scalar adjoint is scattered onto the four coefficient slots.
    let zz = z * z;
    let mut out = [0.0; 4];
    out[0] = scalar_adjoint;
    out[1] = scalar_adjoint * z;
    out[2] = scalar_adjoint * zz;
    out[3] = scalar_adjoint * zz * z;
    out
}
#[inline]
fn add_coeff4_adjoint(target: &mut [f64; 4], source: &[f64; 4]) {
    // Componentwise accumulation of one cubic-coefficient adjoint into
    // another.
    for (dst, &src) in target.iter_mut().zip(source.iter()) {
        *dst += src;
    }
}
#[inline]
fn add_eval_directional_family_adjoint(
    jet: &SparsePrimaryCoeffJetView<'_>,
    family: &[[f64; 4]],
    support: CoeffSupport,
    z: f64,
    scalar_adjoint: f64,
    direction_adjoint: &mut [f64],
) {
    // Reverse chain rule: expand the scalar adjoint of a cubic evaluated at
    // `z` into a coefficient adjoint, then push it through the jet's
    // directional-family map onto the direction slots.
    jet.add_directional_family_adjoint(
        family,
        &Self::coeff4_eval_adjoint(z, scalar_adjoint),
        support,
        direction_adjoint,
    );
}
#[inline]
fn add_eval_param_directional_adjoint(
    jet: &SparsePrimaryCoeffJetView<'_>,
    family: &[[f64; 4]],
    param: usize,
    support: CoeffSupport,
    z: f64,
    scalar_adjoint: f64,
    direction_adjoint: &mut [f64],
) {
    // Same reverse chain rule as the directional-family variant, routed
    // through the single-parameter b-family adjoint instead.
    let seed = Self::coeff4_eval_adjoint(z, scalar_adjoint);
    jet.add_param_directional_from_b_family_adjoint(
        family,
        param,
        &seed,
        support,
        direction_adjoint,
    );
}
#[inline]
fn add_eval_pair_directional_adjoint(
    jet: &SparsePrimaryCoeffJetView<'_>,
    family: &[[f64; 4]],
    u: usize,
    v: usize,
    support: CoeffSupport,
    z: f64,
    scalar_adjoint: f64,
    direction_adjoint: &mut [f64],
) {
    // Reverse chain rule for a (u, v) pair term: seed the coefficient
    // adjoint from the scalar adjoint at `z`, then hand it to the jet's
    // bb-family pair-directional map.
    let seed = Self::coeff4_eval_adjoint(z, scalar_adjoint);
    jet.add_pair_directional_from_bb_family_adjoint(
        family,
        u,
        v,
        &seed,
        support,
        direction_adjoint,
    );
}
/// Adjoint counterpart of the cell second-derivative-from-moments evaluation:
/// given the scalar adjoint of that value, accumulates the adjoints of the
/// second-order coefficient vector and of the `s`-direction first-order
/// coefficients.  `first_r` is the other first-order factor; `moments` must
/// hold at least 10 reduced moments.
fn add_cell_second_direction_adjoint(
    cell: exact_kernel::DenestedCubicCell,
    first_r: &[f64; 4],
    moments: &[f64],
    scalar_adjoint: f64,
    first_s_adjoint: &mut [f64; 4],
    second_adjoint: &mut [f64; 4],
) -> Result<(), String> {
    if moments.len() < 10 {
        return Err(format!(
            "insufficient reduced moments for second-derivative adjoint: need 10, have {}",
            moments.len()
        ));
    }
    // All contributions share the same 1/TAU normalization as the forward
    // evaluation.
    let scale = scalar_adjoint / std::f64::consts::TAU;
    let eta = [cell.c0, cell.c1, cell.c2, cell.c3];
    // Linear term: second-order coefficients pair directly with moments 0..4.
    for (k, slot) in second_adjoint.iter_mut().enumerate() {
        *slot += scale * moments[k];
    }
    // Bilinear term: eta * first_r * first_s contracted against the moment
    // table; accumulation order matches the forward sweep exactly.
    for s_idx in 0..4 {
        let mut acc = 0.0;
        for e_idx in 0..4 {
            for r_idx in 0..4 {
                acc += eta[e_idx] * first_r[r_idx] * moments[e_idx + r_idx + s_idx];
            }
        }
        first_s_adjoint[s_idx] -= scale * acc;
    }
    Ok(())
}
/// Adjoint counterpart of the cell third-derivative-from-moments evaluation.
///
/// Given the scalar adjoint of a third derivative built from first-order
/// coefficient vectors `first_r`, `first_s` and the mixed second-order vector
/// `second_rs`, accumulates the adjoints of the `t`-direction first-order
/// coefficients, the two mixed second-order vectors (`rt`, `st`) and the full
/// third-order vector (`rst`).  Requires at least 16 reduced moments.
fn add_cell_third_direction_adjoint(
cell: exact_kernel::DenestedCubicCell,
first_r: &[f64; 4],
first_s: &[f64; 4],
second_rs: &[f64; 4],
moments: &[f64],
scalar_adjoint: f64,
first_t_adjoint: &mut [f64; 4],
second_rt_adjoint: &mut [f64; 4],
second_st_adjoint: &mut [f64; 4],
third_rst_adjoint: &mut [f64; 4],
) -> Result<(), String> {
if moments.len() < 16 {
return Err(format!(
"insufficient reduced moments for third-derivative adjoint: need 16, have {}",
moments.len()
));
}
// Same 1/TAU normalization as the corresponding forward evaluation.
let scale = scalar_adjoint / std::f64::consts::TAU;
// Cubic coefficients of the cell's eta polynomial.
let eta = [cell.c0, cell.c1, cell.c2, cell.c3];
// Coefficients of eta(z)^2 - 1 as a degree-6 polynomial (7 slots), built
// by convolving eta with itself and subtracting 1 from the constant term.
let mut eta_sq_minus_one = [0.0; 7];
for (i, &eta_i) in eta.iter().enumerate() {
for (j, &eta_j) in eta.iter().enumerate() {
eta_sq_minus_one[i + j] += eta_i * eta_j;
}
}
eta_sq_minus_one[0] -= 1.0;
// Linear term: the third-order coefficients pair directly with moments 0..4.
for k in 0..4 {
third_rst_adjoint[k] += scale * moments[k];
}
// Mixed-second-order terms: eta * first_s feeds the rt adjoint and
// eta * first_r feeds the st adjoint (both with a negative sign below).
for coeff_idx in 0..4 {
let mut eta_s_moment = 0.0;
let mut eta_r_moment = 0.0;
for (eta_idx, &eta_value) in eta.iter().enumerate() {
for basis_idx in 0..4 {
eta_s_moment +=
eta_value * first_s[basis_idx] * moments[eta_idx + coeff_idx + basis_idx];
eta_r_moment +=
eta_value * first_r[basis_idx] * moments[eta_idx + coeff_idx + basis_idx];
}
}
second_rt_adjoint[coeff_idx] -= scale * eta_s_moment;
second_st_adjoint[coeff_idx] -= scale * eta_r_moment;
}
// First-order t-direction term: combines the eta * second_rs contraction
// with the cubic (eta^2 - 1) * first_r * first_s product.
for t_idx in 0..4 {
let mut linear_second = 0.0;
for (eta_idx, &eta_value) in eta.iter().enumerate() {
for (second_idx, &second_value) in second_rs.iter().enumerate() {
linear_second +=
eta_value * second_value * moments[eta_idx + second_idx + t_idx];
}
}
let mut cubic_product = 0.0;
for (eta_idx, &eta_value) in eta_sq_minus_one.iter().enumerate() {
for (r_idx, &r_value) in first_r.iter().enumerate() {
for (s_idx, &s_value) in first_s.iter().enumerate() {
cubic_product += eta_value
* r_value
* s_value
* moments[eta_idx + r_idx + s_idx + t_idx];
}
}
}
first_t_adjoint[t_idx] += scale * (cubic_product - linear_second);
}
Ok(())
}
/// Gradient, with respect to all `r` primary parameters, of the trace
/// contraction between this row's third-derivative tensor and `gram`
/// (a row-major r x r buffer), computed with one reverse-mode (adjoint)
/// sweep instead of one forward directional evaluation per parameter.
///
/// Mirrors the forward directional computation above: a forward pass builds
/// the implicit-function quantities (f_*, a_u, a_uv, eta_*), the likelihood
/// curvature terms are contracted against `gram`, and the adjoints are then
/// propagated back through the observed-point jet and a second pass over the
/// integration cells into `direction_adjoint`, which is returned as the
/// gradient.
///
/// Returns an error when `gram` has the wrong length or when the flexible
/// row context is non-finite.
fn row_primary_third_trace_gradient_with_moments(
&self,
row: usize,
block_states: &[ParameterBlockState],
cache: &BernoulliMarginalSlopeExactEvalCache,
row_ctx: &BernoulliMarginalSlopeRowExactContext,
gram: &[f64],
) -> Result<Array1<f64>, String> {
let primary = &cache.primary;
let r = primary.total;
if gram.len() != r * r {
return Err(format!(
"bernoulli marginal-slope row trace gram length {} != {}",
gram.len(),
r * r
));
}
// Rigid (non-flexible) path: contract the cached 2x2x2 third tensor
// against the top-left 2x2 corner of `gram` directly.
if !self.effective_flex_active(block_states)? {
let tensor = self.rigid_third_full_cached(block_states, cache, row)?;
let mut grad = Array1::<f64>::zeros(r);
for a_idx in 0..2 {
for b_idx in 0..2 {
let weight = gram[a_idx * r + b_idx];
for dir_idx in 0..2 {
grad[dir_idx] += weight * tensor[a_idx][b_idx][dir_idx];
}
}
}
return Ok(grad);
}
if !row_ctx.intercept.is_finite() || !row_ctx.m_a.is_finite() || row_ctx.m_a <= 0.0 {
return Err(
"non-finite flexible row context in third-order trace-gradient contraction"
.to_string(),
);
}
let point = self.primary_point_from_block_states(row, block_states, primary)?;
let (q, b, beta_h_owned, beta_w_owned) = self.primary_point_components(&point, primary);
let beta_h = beta_h_owned.as_ref();
let beta_w = beta_w_owned.as_ref();
// Empirical-grid path: no adjoint shortcut here — recompute one forward
// directional contraction per basis direction and contract with `gram`.
if let Some(grid) = self.latent_measure.empirical_grid_for_training_row(row)? {
let mut grad = Array1::<f64>::zeros(r);
for dir_idx in 0..r {
let mut basis = Array1::<f64>::zeros(r);
basis[dir_idx] = 1.0;
let third = self.empirical_flex_row_third_contracted_recompute(
row, primary, q, b, beta_h, beta_w, row_ctx, &basis, &grid,
)?;
grad[dir_idx] = Self::row_primary_trace_contract(&third, gram);
}
return Ok(grad);
}
use crate::families::bernoulli_marginal_slope::exact_kernel as exact;
let a = row_ctx.intercept;
let marginal = self.marginal_link_map(q)?;
let h_range = primary.h.as_ref();
let w_range = primary.w.as_ref();
let score_runtime = self.score_warp.as_ref();
let link_runtime = self.link_dev.as_ref();
let scale = self.probit_frailty_scale();
let zero_family = vec![[0.0; 4]; r];
// Forward-pass accumulators for the implicit normalization equation:
// derivatives in the intercept `a` and in the primary parameters u, v.
let mut f_a = 0.0;
let mut f_aa = 0.0;
let mut f_u = Array1::<f64>::zeros(r);
let mut f_au = Array1::<f64>::zeros(r);
let mut f_uv = Array2::<f64>::zeros((r, r));
let owned_cells;
// Use cached per-row cell moments (15 reduced moments per cell) when
// available; otherwise rebuild the denested partition and evaluate each
// cell through the LRU moment cache.
let cells: &[CachedDenestedCellMoments] = if let Some(cached) = cache
.row_cell_moments
.as_ref()
.and_then(|bundle| bundle.row(row, 15))
{
cached
} else {
let partitions = self.denested_partition_cells(a, b, beta_h, beta_w)?;
owned_cells = partitions
.into_iter()
.map(|partition_cell| {
self.evaluate_cell_derivative_moments_lru(partition_cell.cell, 15)
.map(|state| CachedDenestedCellMoments {
partition_cell,
state,
})
})
.collect::<Result<Vec<_>, String>>()?;
&owned_cells
};
// First (forward) pass over cells: accumulate f_a, f_aa, f_u, f_au, f_uv.
for entry in cells {
let partition_cell = entry.partition_cell;
let cell = partition_cell.cell;
let z_mid = exact::interval_probe_point(cell.left, cell.right)?;
let u_mid = a + b * z_mid;
let state = &entry.state;
let (dc_da_raw, dc_db_raw) = exact::denested_cell_coefficient_partials(
partition_cell.score_span,
partition_cell.link_span,
a,
b,
);
let (dc_daa_raw, dc_dab_raw, dc_dbb_raw) = exact::denested_cell_second_partials(
partition_cell.score_span,
partition_cell.link_span,
a,
b,
);
let denested_third = exact::denested_cell_third_partials(partition_cell.link_span);
let dc_da = scale_coeff4(dc_da_raw, scale);
let dc_db = scale_coeff4(dc_db_raw, scale);
let dc_daa = scale_coeff4(dc_daa_raw, scale);
let dc_dab = scale_coeff4(dc_dab_raw, scale);
let dc_dbb = scale_coeff4(dc_dbb_raw, scale);
let dc_daab = scale_coeff4(denested_third.1, scale);
let dc_dabb = scale_coeff4(denested_third.2, scale);
let dc_dbbb = scale_coeff4(denested_third.3, scale);
// Coefficient families: slot 1 is the slope parameter b; deviation
// basis slots are filled below from the score-warp / link runtimes.
let mut coeff_u = vec![[0.0; 4]; r];
let mut coeff_au = vec![[0.0; 4]; r];
let mut coeff_bu = vec![[0.0; 4]; r];
let mut coeff_aau = vec![[0.0; 4]; r];
let mut coeff_abu = vec![[0.0; 4]; r];
let mut coeff_bbu = vec![[0.0; 4]; r];
coeff_u[1] = dc_db;
coeff_au[1] = dc_dab;
coeff_bu[1] = dc_dbb;
coeff_aau[1] = dc_daab;
coeff_abu[1] = dc_dabb;
coeff_bbu[1] = dc_dbbb;
if let (Some(h_range), Some(runtime)) = (h_range, score_runtime) {
Self::for_each_deviation_basis_cubic_at(
runtime,
h_range,
z_mid,
"score-warp trace-gradient base",
|_, idx, basis_span| {
coeff_u[idx] = scale_coeff4(
exact::score_basis_cell_coefficients(basis_span, b),
scale,
);
coeff_bu[idx] = scale_coeff4(
exact::score_basis_cell_coefficients(basis_span, 1.0),
scale,
);
Ok(())
},
)?;
}
if let (Some(w_range), Some(runtime)) = (w_range, link_runtime) {
Self::for_each_deviation_basis_cubic_at(
runtime,
w_range,
u_mid,
"link-wiggle trace-gradient base",
|_, idx, basis_span| {
coeff_u[idx] = scale_coeff4(
exact::link_basis_cell_coefficients(basis_span, a, b),
scale,
);
let (dc_aw_raw, dc_bw_raw) =
exact::link_basis_cell_coefficient_partials(basis_span, a, b);
let (dc_aaw_raw, dc_abw_raw, dc_bbw_raw) =
exact::link_basis_cell_second_partials(basis_span, a, b);
coeff_au[idx] = scale_coeff4(dc_aw_raw, scale);
coeff_bu[idx] = scale_coeff4(dc_bw_raw, scale);
coeff_aau[idx] = scale_coeff4(dc_aaw_raw, scale);
coeff_abu[idx] = scale_coeff4(dc_abw_raw, scale);
coeff_bbu[idx] = scale_coeff4(dc_bbw_raw, scale);
Ok(())
},
)?;
}
let coeff_jet = SparsePrimaryCoeffJetView::new(
1,
h_range,
w_range,
&coeff_u,
&coeff_au,
&coeff_bu,
&coeff_aau,
&coeff_abu,
&coeff_bbu,
&zero_family,
&zero_family,
&zero_family,
&zero_family,
);
f_a += exact::cell_first_derivative_from_moments(&dc_da, &state.moments)?;
f_aa += exact::cell_second_derivative_from_moments(
cell,
&dc_da,
&dc_da,
&dc_daa,
&state.moments,
)?;
for u in 1..r {
f_u[u] +=
exact::cell_first_derivative_from_moments(&coeff_jet.first[u], &state.moments)?;
f_au[u] += exact::cell_second_derivative_from_moments(
cell,
&dc_da,
&coeff_jet.first[u],
&coeff_jet.a_first[u],
&state.moments,
)?;
}
for u in 1..r {
for v in u..r {
let second_coeff =
coeff_jet.pair_from_b_family(coeff_jet.b_first, u, v, COEFF_SUPPORT_BHW);
let val = exact::cell_second_derivative_from_moments(
cell,
&coeff_jet.first[u],
&coeff_jet.first[v],
&second_coeff,
&state.moments,
)?;
f_uv[[u, v]] += val;
if u != v {
f_uv[[v, u]] += val;
}
}
}
}
// Slot 0 (the q parameter) enters through the marginal link map moments.
f_u[0] = -marginal.mu1;
f_uv[[0, 0]] = -marginal.mu2;
// Implicit-function derivatives of the intercept a(u): first and second
// order.
let inv_f_a = 1.0 / f_a;
let mut a_u = Array1::<f64>::zeros(r);
for u in 0..r {
a_u[u] = -f_u[u] * inv_f_a;
}
let mut a_uv = Array2::<f64>::zeros((r, r));
for u in 0..r {
for v in u..r {
let val =
-(f_uv[[u, v]] + f_au[u] * a_u[v] + f_au[v] * a_u[u] + f_aa * a_u[u] * a_u[v])
* inv_f_a;
a_uv[[u, v]] = val;
a_uv[[v, u]] = val;
}
}
// Observed-point forward pass: build the eta derivatives at z_obs.
let z_obs = self.z[row];
let u_obs = a + b * z_obs;
let obs = self.observed_denested_cell_partials(row, a, b, beta_h, beta_w)?;
let eta_val = eval_coeff4_at(&obs.coeff, z_obs);
let mut g_u_fixed = vec![[0.0; 4]; r];
let mut g_au_fixed = vec![[0.0; 4]; r];
let mut g_bu_fixed = vec![[0.0; 4]; r];
let mut g_aau_fixed = vec![[0.0; 4]; r];
let mut g_abu_fixed = vec![[0.0; 4]; r];
let mut g_bbu_fixed = vec![[0.0; 4]; r];
g_u_fixed[1] = obs.dc_db;
g_au_fixed[1] = obs.dc_dab;
g_bu_fixed[1] = obs.dc_dbb;
g_aau_fixed[1] = obs.dc_daab;
g_abu_fixed[1] = obs.dc_dabb;
g_bbu_fixed[1] = obs.dc_dbbb;
if let (Some(h_range), Some(runtime)) = (h_range, score_runtime) {
Self::for_each_deviation_basis_cubic_at(
runtime,
h_range,
z_obs,
"score-warp trace-gradient observed",
|_, idx, basis_span| {
g_u_fixed[idx] =
scale_coeff4(exact::score_basis_cell_coefficients(basis_span, b), scale);
g_bu_fixed[idx] =
scale_coeff4(exact::score_basis_cell_coefficients(basis_span, 1.0), scale);
Ok(())
},
)?;
}
if let (Some(w_range), Some(runtime)) = (w_range, link_runtime) {
Self::for_each_deviation_basis_cubic_at(
runtime,
w_range,
u_obs,
"link-wiggle trace-gradient observed",
|_, idx, basis_span| {
g_u_fixed[idx] =
scale_coeff4(exact::link_basis_cell_coefficients(basis_span, a, b), scale);
let (dc_aw_raw, dc_bw_raw) =
exact::link_basis_cell_coefficient_partials(basis_span, a, b);
let (dc_aaw_raw, dc_abw_raw, dc_bbw_raw) =
exact::link_basis_cell_second_partials(basis_span, a, b);
g_au_fixed[idx] = scale_coeff4(dc_aw_raw, scale);
g_bu_fixed[idx] = scale_coeff4(dc_bw_raw, scale);
g_aau_fixed[idx] = scale_coeff4(dc_aaw_raw, scale);
g_abu_fixed[idx] = scale_coeff4(dc_abw_raw, scale);
g_bbu_fixed[idx] = scale_coeff4(dc_bbw_raw, scale);
Ok(())
},
)?;
}
let g_jet = SparsePrimaryCoeffJetView::new(
1,
h_range,
w_range,
&g_u_fixed,
&g_au_fixed,
&g_bu_fixed,
&g_aau_fixed,
&g_abu_fixed,
&g_bbu_fixed,
&zero_family,
&zero_family,
&zero_family,
&zero_family,
);
let g_a = eval_coeff4_at(&obs.dc_da, z_obs);
let g_aa = eval_coeff4_at(&obs.dc_daa, z_obs);
let g_aaa = eval_coeff4_at(&obs.dc_daaa, z_obs);
let mut g_u = Array1::<f64>::zeros(r);
let mut g_au = Array1::<f64>::zeros(r);
let mut g_aau = Array1::<f64>::zeros(r);
let mut g_uv = Array2::<f64>::zeros((r, r));
let mut g_auv = Array2::<f64>::zeros((r, r));
for u in 1..r {
g_u[u] = eval_coeff4_at(&g_jet.first[u], z_obs);
g_au[u] = eval_coeff4_at(&g_jet.a_first[u], z_obs);
g_aau[u] = eval_coeff4_at(&g_jet.aa_first[u], z_obs);
}
for u in 1..r {
for v in u..r {
let second_coeff = g_jet.pair_from_b_family(g_jet.b_first, u, v, COEFF_SUPPORT_BHW);
let val = eval_coeff4_at(&second_coeff, z_obs);
g_uv[[u, v]] = val;
g_uv[[v, u]] = val;
let third_coeff = g_jet.pair_from_b_family(g_jet.ab_first, u, v, COEFF_SUPPORT_BW);
let third_val = eval_coeff4_at(&third_coeff, z_obs);
g_auv[[u, v]] = third_val;
g_auv[[v, u]] = third_val;
}
}
// Total eta derivatives: chain g through the implicit intercept a(u).
let eta_u = g_a * &a_u + &g_u;
let mut eta_uv = Array2::<f64>::zeros((r, r));
for u in 0..r {
for v in u..r {
let val = g_a * a_uv[[u, v]]
+ g_aa * a_u[u] * a_u[v]
+ g_au[u] * a_u[v]
+ g_au[v] * a_u[u]
+ g_uv[[u, v]];
eta_uv[[u, v]] = val;
eta_uv[[v, u]] = val;
}
}
let y_i = self.y[row];
let w_i = self.weights[row];
let s_y = 2.0 * y_i - 1.0;
let m = s_y * eta_val;
let (k1, k2, k3, _) = signed_probit_neglog_derivatives_up_to_fourth(m, w_i)?;
let u1 = s_y * k1;
let u3 = s_y * k3;
// Reverse sweep: seed the adjoints from the gram-weighted likelihood
// curvature terms and propagate them back through the observed-point
// quantities.
let mut direction_adjoint = vec![0.0; r];
let mut adj_eta_dir = 0.0;
let mut adj_eta_u_dir = vec![0.0; r];
let mut adj_a_u_dir = vec![0.0; r];
let mut adj_a_uv_dir = Array2::<f64>::zeros((r, r));
let mut adj_dg_a_dir = 0.0;
let mut adj_dg_aa_dir = 0.0;
let mut adj_dg_au_dir = vec![0.0; r];
let mut adj_a_dir = 0.0;
for u in 0..r {
for v in u..r {
// Symmetric off-diagonal terms pick up both gram entries.
let weight = if u == v {
gram[u * r + v]
} else {
gram[u * r + v] + gram[v * r + u]
};
if weight == 0.0 {
continue;
}
adj_eta_dir += weight * (u3 * eta_u[u] * eta_u[v] + k2 * eta_uv[[u, v]]);
adj_eta_u_dir[v] += weight * k2 * eta_u[u];
adj_eta_u_dir[u] += weight * k2 * eta_u[v];
let adj_eta_uv_dir = weight * u1;
adj_dg_a_dir += adj_eta_uv_dir * a_uv[[u, v]];
adj_a_uv_dir[[u, v]] += adj_eta_uv_dir * g_a;
adj_dg_aa_dir += adj_eta_uv_dir * a_u[u] * a_u[v];
adj_a_u_dir[u] += adj_eta_uv_dir * g_aa * a_u[v];
adj_a_u_dir[v] += adj_eta_uv_dir * g_aa * a_u[u];
adj_dg_au_dir[u] += adj_eta_uv_dir * a_u[v];
adj_a_u_dir[v] += adj_eta_uv_dir * g_au[u];
adj_dg_au_dir[v] += adj_eta_uv_dir * a_u[u];
adj_a_u_dir[u] += adj_eta_uv_dir * g_au[v];
adj_a_dir += adj_eta_uv_dir * g_auv[[u, v]];
Self::add_eval_pair_directional_adjoint(
&g_jet,
g_jet.bb_first,
u,
v,
COEFF_SUPPORT_BW,
z_obs,
adj_eta_uv_dir,
&mut direction_adjoint,
);
}
}
for u in 0..r {
let adj = adj_dg_au_dir[u];
if adj != 0.0 {
adj_a_dir += adj * g_aau[u];
Self::add_eval_param_directional_adjoint(
&g_jet,
g_jet.ab_first,
u,
COEFF_SUPPORT_BW,
z_obs,
adj,
&mut direction_adjoint,
);
}
}
adj_a_dir += adj_eta_dir * g_a + adj_dg_a_dir * g_aa + adj_dg_aa_dir * g_aaa;
Self::add_eval_directional_family_adjoint(
&g_jet,
g_jet.first,
COEFF_SUPPORT_BHW,
z_obs,
adj_eta_dir,
&mut direction_adjoint,
);
Self::add_eval_directional_family_adjoint(
&g_jet,
g_jet.a_first,
COEFF_SUPPORT_BW,
z_obs,
adj_dg_a_dir,
&mut direction_adjoint,
);
Self::add_eval_directional_family_adjoint(
&g_jet,
g_jet.aa_first,
COEFF_SUPPORT_BW,
z_obs,
adj_dg_aa_dir,
&mut direction_adjoint,
);
for u in 0..r {
let adj = adj_eta_u_dir[u];
if adj != 0.0 {
for dir_idx in 0..r {
direction_adjoint[dir_idx] += adj * eta_uv[[u, dir_idx]];
}
}
}
// Propagate a_uv_dir adjoints back into the f_* directional adjoints
// through the implicit-function relation.
let mut adj_f_a_dir = 0.0;
let mut adj_f_aa_dir = 0.0;
let mut adj_f_au_dir = vec![0.0; r];
let mut adj_f_uv_dir = Array2::<f64>::zeros((r, r));
for u in 0..r {
for v in u..r {
let adj = adj_a_uv_dir[[u, v]];
if adj == 0.0 {
continue;
}
let adj_n = -adj * inv_f_a;
adj_f_uv_dir[[u, v]] += adj_n;
adj_f_au_dir[u] += adj_n * a_u[v];
adj_a_u_dir[v] += adj_n * f_au[u];
adj_f_au_dir[v] += adj_n * a_u[u];
adj_a_u_dir[u] += adj_n * f_au[v];
adj_f_aa_dir += adj_n * a_u[u] * a_u[v];
adj_a_u_dir[u] += adj_n * f_aa * a_u[v];
adj_a_u_dir[v] += adj_n * f_aa * a_u[u];
adj_f_a_dir += adj_n * a_uv[[u, v]];
}
}
// Slot-0 contribution of the marginal third moment (forward code sets
// f_uv_dir[[0, 0]] = -dir[0] * mu3).
direction_adjoint[0] -= adj_f_uv_dir[[0, 0]] * marginal.mu3;
for u in 0..r {
let adj = adj_a_u_dir[u];
if adj != 0.0 {
for dir_idx in 0..r {
direction_adjoint[dir_idx] += adj * a_uv[[u, dir_idx]];
}
}
}
if adj_a_dir != 0.0 {
for dir_idx in 0..r {
direction_adjoint[dir_idx] += adj_a_dir * a_u[dir_idx];
}
}
// Second (adjoint) pass over cells: rebuild each cell's coefficient jet
// and push the f_* directional adjoints back onto the direction slots.
for entry in cells {
let partition_cell = entry.partition_cell;
let cell = partition_cell.cell;
let z_mid = exact::interval_probe_point(cell.left, cell.right)?;
let u_mid = a + b * z_mid;
let state = &entry.state;
let (dc_da_raw, dc_db_raw) = exact::denested_cell_coefficient_partials(
partition_cell.score_span,
partition_cell.link_span,
a,
b,
);
let (dc_daa_raw, dc_dab_raw, dc_dbb_raw) = exact::denested_cell_second_partials(
partition_cell.score_span,
partition_cell.link_span,
a,
b,
);
let denested_third = exact::denested_cell_third_partials(partition_cell.link_span);
let dc_da = scale_coeff4(dc_da_raw, scale);
let dc_db = scale_coeff4(dc_db_raw, scale);
let dc_daa = scale_coeff4(dc_daa_raw, scale);
let dc_dab = scale_coeff4(dc_dab_raw, scale);
let dc_dbb = scale_coeff4(dc_dbb_raw, scale);
let dc_daab = scale_coeff4(denested_third.1, scale);
let dc_dabb = scale_coeff4(denested_third.2, scale);
let dc_dbbb = scale_coeff4(denested_third.3, scale);
let mut coeff_u = vec![[0.0; 4]; r];
let mut coeff_au = vec![[0.0; 4]; r];
let mut coeff_bu = vec![[0.0; 4]; r];
let mut coeff_aau = vec![[0.0; 4]; r];
let mut coeff_abu = vec![[0.0; 4]; r];
let mut coeff_bbu = vec![[0.0; 4]; r];
coeff_u[1] = dc_db;
coeff_au[1] = dc_dab;
coeff_bu[1] = dc_dbb;
coeff_aau[1] = dc_daab;
coeff_abu[1] = dc_dabb;
coeff_bbu[1] = dc_dbbb;
if let (Some(h_range), Some(runtime)) = (h_range, score_runtime) {
Self::for_each_deviation_basis_cubic_at(
runtime,
h_range,
z_mid,
"score-warp trace-gradient adjoint",
|_, idx, basis_span| {
coeff_u[idx] = scale_coeff4(
exact::score_basis_cell_coefficients(basis_span, b),
scale,
);
coeff_bu[idx] = scale_coeff4(
exact::score_basis_cell_coefficients(basis_span, 1.0),
scale,
);
Ok(())
},
)?;
}
if let (Some(w_range), Some(runtime)) = (w_range, link_runtime) {
Self::for_each_deviation_basis_cubic_at(
runtime,
w_range,
u_mid,
"link-wiggle trace-gradient adjoint",
|_, idx, basis_span| {
coeff_u[idx] = scale_coeff4(
exact::link_basis_cell_coefficients(basis_span, a, b),
scale,
);
let (dc_aw_raw, dc_bw_raw) =
exact::link_basis_cell_coefficient_partials(basis_span, a, b);
let (dc_aaw_raw, dc_abw_raw, dc_bbw_raw) =
exact::link_basis_cell_second_partials(basis_span, a, b);
coeff_au[idx] = scale_coeff4(dc_aw_raw, scale);
coeff_bu[idx] = scale_coeff4(dc_bw_raw, scale);
coeff_aau[idx] = scale_coeff4(dc_aaw_raw, scale);
coeff_abu[idx] = scale_coeff4(dc_abw_raw, scale);
coeff_bbu[idx] = scale_coeff4(dc_bbw_raw, scale);
Ok(())
},
)?;
}
let coeff_jet = SparsePrimaryCoeffJetView::new(
1,
h_range,
w_range,
&coeff_u,
&coeff_au,
&coeff_bu,
&coeff_aau,
&coeff_abu,
&coeff_bbu,
&zero_family,
&zero_family,
&zero_family,
&zero_family,
);
let mut coeff_dir_adj = [0.0; 4];
let mut coeff_a_dir_adj = [0.0; 4];
let mut coeff_aa_dir_adj = [0.0; 4];
let mut coeff_u_dir_adj = vec![[0.0; 4]; r];
let mut coeff_au_dir_adj = vec![[0.0; 4]; r];
if adj_f_a_dir != 0.0 {
Self::add_cell_second_direction_adjoint(
cell,
&dc_da,
&state.moments,
adj_f_a_dir,
&mut coeff_dir_adj,
&mut coeff_a_dir_adj,
)?;
}
if adj_f_aa_dir != 0.0 {
// The f_aa term uses dc_da in both first-order slots, so both mixed
// second-order adjoints fold into the same coeff_a_dir_adj.
let mut a_rt_adj = [0.0; 4];
let mut a_st_adj = [0.0; 4];
Self::add_cell_third_direction_adjoint(
cell,
&dc_da,
&dc_da,
&dc_daa,
&state.moments,
adj_f_aa_dir,
&mut coeff_dir_adj,
&mut a_rt_adj,
&mut a_st_adj,
&mut coeff_aa_dir_adj,
)?;
Self::add_coeff4_adjoint(&mut coeff_a_dir_adj, &a_rt_adj);
Self::add_coeff4_adjoint(&mut coeff_a_dir_adj, &a_st_adj);
}
for u in 1..r {
let adj = adj_f_au_dir[u];
if adj == 0.0 {
continue;
}
Self::add_cell_third_direction_adjoint(
cell,
&dc_da,
&coeff_jet.first[u],
&coeff_jet.a_first[u],
&state.moments,
adj,
&mut coeff_dir_adj,
&mut coeff_a_dir_adj,
&mut coeff_u_dir_adj[u],
&mut coeff_au_dir_adj[u],
)?;
}
for u in 1..r {
for v in u..r {
let adj = adj_f_uv_dir[[u, v]];
if adj == 0.0 {
continue;
}
let second_coeff =
coeff_jet.pair_from_b_family(coeff_jet.b_first, u, v, COEFF_SUPPORT_BHW);
let mut u_dir_adj = [0.0; 4];
let mut v_dir_adj = [0.0; 4];
let mut third_coeff_adj = [0.0; 4];
Self::add_cell_third_direction_adjoint(
cell,
&coeff_jet.first[u],
&coeff_jet.first[v],
&second_coeff,
&state.moments,
adj,
&mut coeff_dir_adj,
&mut u_dir_adj,
&mut v_dir_adj,
&mut third_coeff_adj,
)?;
Self::add_coeff4_adjoint(&mut coeff_u_dir_adj[u], &u_dir_adj);
Self::add_coeff4_adjoint(&mut coeff_u_dir_adj[v], &v_dir_adj);
coeff_jet.add_pair_directional_from_bb_family_adjoint(
coeff_jet.bb_first,
u,
v,
&third_coeff_adj,
COEFF_SUPPORT_BW,
&mut direction_adjoint,
);
}
}
// Flush the accumulated per-cell coefficient adjoints through the jet
// onto the direction slots.
coeff_jet.add_directional_family_adjoint(
coeff_jet.first,
&coeff_dir_adj,
COEFF_SUPPORT_BHW,
&mut direction_adjoint,
);
coeff_jet.add_directional_family_adjoint(
coeff_jet.a_first,
&coeff_a_dir_adj,
COEFF_SUPPORT_BW,
&mut direction_adjoint,
);
coeff_jet.add_directional_family_adjoint(
coeff_jet.aa_first,
&coeff_aa_dir_adj,
COEFF_SUPPORT_BW,
&mut direction_adjoint,
);
for u in 1..r {
coeff_jet.add_param_directional_from_b_family_adjoint(
coeff_jet.b_first,
u,
&coeff_u_dir_adj[u],
COEFF_SUPPORT_BHW,
&mut direction_adjoint,
);
coeff_jet.add_param_directional_from_b_family_adjoint(
coeff_jet.ab_first,
u,
&coeff_au_dir_adj[u],
COEFF_SUPPORT_BW,
&mut direction_adjoint,
);
}
}
Ok(Array1::from_vec(direction_adjoint))
}
fn row_primary_third_trace_many_with_moments(
&self,
row: usize,
block_states: &[ParameterBlockState],
cache: &BernoulliMarginalSlopeExactEvalCache,
row_ctx: &BernoulliMarginalSlopeRowExactContext,
row_dirs: &[Array1<f64>],
gram: &[f64],
) -> Result<Vec<f64>, String> {
let primary = &cache.primary;
let r = primary.total;
if row_dirs.is_empty() {
return Ok(Vec::new());
}
if gram.len() != r * r {
return Err(format!(
"bernoulli marginal-slope row trace gram length {} != {}",
gram.len(),
r * r
));
}
if let Some((idx, dir)) = row_dirs.iter().enumerate().find(|(_, dir)| dir.len() != r) {
return Err(format!(
"bernoulli marginal-slope row trace direction {idx} length {} != {r}",
dir.len()
));
}
if row_dirs.len() > 1 {
let trace_gradient = self.row_primary_third_trace_gradient_with_moments(
row,
block_states,
cache,
row_ctx,
gram,
)?;
let traces = row_dirs
.iter()
.map(|dir| trace_gradient.dot(dir))
.collect::<Vec<_>>();
return Ok(traces);
}
if !self.effective_flex_active(block_states)? {
let t = self.rigid_third_full_cached(block_states, cache, row)?;
let mut traces = vec![0.0; row_dirs.len()];
for (dir_idx, dir) in row_dirs.iter().enumerate() {
let m = contract_third_full(t, dir[0], dir[1]);
traces[dir_idx] = m[0][0] * gram[0]
+ m[0][1] * gram[1]
+ m[1][0] * gram[r]
+ m[1][1] * gram[r + 1];
}
return Ok(traces);
}
if !row_ctx.intercept.is_finite() || !row_ctx.m_a.is_finite() || row_ctx.m_a <= 0.0 {
return Err(
"non-finite flexible row context in batched third-order trace contraction"
.to_string(),
);
}
let point = self.primary_point_from_block_states(row, block_states, primary)?;
let (q, b, beta_h_owned, beta_w_owned) = self.primary_point_components(&point, primary);
let beta_h = beta_h_owned.as_ref();
let beta_w = beta_w_owned.as_ref();
let a = row_ctx.intercept;
if let Some(grid) = self.latent_measure.empirical_grid_for_training_row(row)? {
let mut traces = vec![0.0; row_dirs.len()];
for (dir_idx, dir) in row_dirs.iter().enumerate() {
let third = self.empirical_flex_row_third_contracted_recompute(
row, primary, q, b, beta_h, beta_w, row_ctx, dir, &grid,
)?;
traces[dir_idx] = Self::row_primary_trace_contract(&third, gram);
}
return Ok(traces);
}
use crate::families::bernoulli_marginal_slope::exact_kernel as exact;
let marginal = self.marginal_link_map(q)?;
let h_range = primary.h.as_ref();
let w_range = primary.w.as_ref();
let score_runtime = self.score_warp.as_ref();
let link_runtime = self.link_dev.as_ref();
let scale = self.probit_frailty_scale();
let zero_family = vec![[0.0; 4]; r];
let n_dirs = row_dirs.len();
let mut f_a = 0.0;
let mut f_aa = 0.0;
let mut f_u = Array1::<f64>::zeros(r);
let mut f_au = Array1::<f64>::zeros(r);
let mut f_uv = Array2::<f64>::zeros((r, r));
let mut f_a_dir = vec![0.0; n_dirs];
let mut f_aa_dir = vec![0.0; n_dirs];
let mut f_au_dir = vec![0.0; n_dirs * r];
let mut f_uv_dir = vec![0.0; n_dirs * r * r];
let owned_cells;
let cells: &[CachedDenestedCellMoments] = if let Some(cached) = cache
.row_cell_moments
.as_ref()
.and_then(|bundle| bundle.row(row, 15))
{
cached
} else {
let partitions = self.denested_partition_cells(a, b, beta_h, beta_w)?;
owned_cells = partitions
.into_iter()
.map(|partition_cell| {
self.evaluate_cell_derivative_moments_lru(partition_cell.cell, 15)
.map(|state| CachedDenestedCellMoments {
partition_cell,
state,
})
})
.collect::<Result<Vec<_>, String>>()?;
&owned_cells
};
for entry in cells {
let partition_cell = entry.partition_cell;
let cell = partition_cell.cell;
let z_mid = exact::interval_probe_point(cell.left, cell.right)?;
let u_mid = a + b * z_mid;
let state = &entry.state;
let (dc_da_raw, dc_db_raw) = exact::denested_cell_coefficient_partials(
partition_cell.score_span,
partition_cell.link_span,
a,
b,
);
let (dc_daa_raw, dc_dab_raw, dc_dbb_raw) = exact::denested_cell_second_partials(
partition_cell.score_span,
partition_cell.link_span,
a,
b,
);
let denested_third = exact::denested_cell_third_partials(partition_cell.link_span);
let dc_da = scale_coeff4(dc_da_raw, scale);
let dc_db = scale_coeff4(dc_db_raw, scale);
let dc_daa = scale_coeff4(dc_daa_raw, scale);
let dc_dab = scale_coeff4(dc_dab_raw, scale);
let dc_dbb = scale_coeff4(dc_dbb_raw, scale);
let dc_daab = scale_coeff4(denested_third.1, scale);
let dc_dabb = scale_coeff4(denested_third.2, scale);
let dc_dbbb = scale_coeff4(denested_third.3, scale);
let mut coeff_u = vec![[0.0; 4]; r];
let mut coeff_au = vec![[0.0; 4]; r];
let mut coeff_bu = vec![[0.0; 4]; r];
let mut coeff_aau = vec![[0.0; 4]; r];
let mut coeff_abu = vec![[0.0; 4]; r];
let mut coeff_bbu = vec![[0.0; 4]; r];
coeff_u[1] = dc_db;
coeff_au[1] = dc_dab;
coeff_bu[1] = dc_dbb;
coeff_aau[1] = dc_daab;
coeff_abu[1] = dc_dabb;
coeff_bbu[1] = dc_dbbb;
if let (Some(h_range), Some(runtime)) = (h_range, score_runtime) {
Self::for_each_deviation_basis_cubic_at(
runtime,
h_range,
z_mid,
"score-warp batched third-trace direction",
|_, idx, basis_span| {
coeff_u[idx] = scale_coeff4(
exact::score_basis_cell_coefficients(basis_span, b),
scale,
);
coeff_bu[idx] = scale_coeff4(
exact::score_basis_cell_coefficients(basis_span, 1.0),
scale,
);
Ok(())
},
)?;
}
if let (Some(w_range), Some(runtime)) = (w_range, link_runtime) {
Self::for_each_deviation_basis_cubic_at(
runtime,
w_range,
u_mid,
"link-wiggle batched third-trace direction",
|_, idx, basis_span| {
coeff_u[idx] = scale_coeff4(
exact::link_basis_cell_coefficients(basis_span, a, b),
scale,
);
let (dc_aw_raw, dc_bw_raw) =
exact::link_basis_cell_coefficient_partials(basis_span, a, b);
let (dc_aaw_raw, dc_abw_raw, dc_bbw_raw) =
exact::link_basis_cell_second_partials(basis_span, a, b);
coeff_au[idx] = scale_coeff4(dc_aw_raw, scale);
coeff_bu[idx] = scale_coeff4(dc_bw_raw, scale);
coeff_aau[idx] = scale_coeff4(dc_aaw_raw, scale);
coeff_abu[idx] = scale_coeff4(dc_abw_raw, scale);
coeff_bbu[idx] = scale_coeff4(dc_bbw_raw, scale);
Ok(())
},
)?;
}
let coeff_jet = SparsePrimaryCoeffJetView::new(
1,
h_range,
w_range,
&coeff_u,
&coeff_au,
&coeff_bu,
&coeff_aau,
&coeff_abu,
&coeff_bbu,
&zero_family,
&zero_family,
&zero_family,
&zero_family,
);
f_a += exact::cell_first_derivative_from_moments(&dc_da, &state.moments)?;
f_aa += exact::cell_second_derivative_from_moments(
cell,
&dc_da,
&dc_da,
&dc_daa,
&state.moments,
)?;
for u in 1..r {
f_u[u] +=
exact::cell_first_derivative_from_moments(&coeff_jet.first[u], &state.moments)?;
f_au[u] += exact::cell_second_derivative_from_moments(
cell,
&dc_da,
&coeff_jet.first[u],
&coeff_jet.a_first[u],
&state.moments,
)?;
}
for u in 1..r {
for v in u..r {
let second_coeff =
coeff_jet.pair_from_b_family(coeff_jet.b_first, u, v, COEFF_SUPPORT_BHW);
let val = exact::cell_second_derivative_from_moments(
cell,
&coeff_jet.first[u],
&coeff_jet.first[v],
&second_coeff,
&state.moments,
)?;
f_uv[[u, v]] += val;
if u != v {
f_uv[[v, u]] += val;
}
}
}
for (dir_idx, dir) in row_dirs.iter().enumerate() {
let coeff_dir =
coeff_jet.directional_family(coeff_jet.first, dir, COEFF_SUPPORT_BHW);
let coeff_a_dir =
coeff_jet.directional_family(coeff_jet.a_first, dir, COEFF_SUPPORT_BW);
let coeff_aa_dir =
coeff_jet.directional_family(coeff_jet.aa_first, dir, COEFF_SUPPORT_BW);
f_a_dir[dir_idx] += exact::cell_second_derivative_from_moments(
cell,
&dc_da,
&coeff_dir,
&coeff_a_dir,
&state.moments,
)?;
f_aa_dir[dir_idx] += exact::cell_third_derivative_from_moments(
cell,
&dc_da,
&dc_da,
&coeff_dir,
&dc_daa,
&coeff_a_dir,
&coeff_a_dir,
&coeff_aa_dir,
&state.moments,
)?;
let mut coeff_u_dir = vec![[0.0; 4]; r];
let mut coeff_au_dir = vec![[0.0; 4]; r];
for u in 1..r {
coeff_u_dir[u] = coeff_jet.param_directional_from_b_family(
coeff_jet.b_first,
u,
dir,
COEFF_SUPPORT_BHW,
);
coeff_au_dir[u] = coeff_jet.param_directional_from_b_family(
coeff_jet.ab_first,
u,
dir,
COEFF_SUPPORT_BW,
);
}
for u in 1..r {
f_au_dir[dir_idx * r + u] += exact::cell_third_derivative_from_moments(
cell,
&dc_da,
&coeff_jet.first[u],
&coeff_dir,
&coeff_jet.a_first[u],
&coeff_a_dir,
&coeff_u_dir[u],
&coeff_au_dir[u],
&state.moments,
)?;
}
let dir_base = dir_idx * r * r;
for u in 1..r {
for v in u..r {
let second_coeff = coeff_jet.pair_from_b_family(
coeff_jet.b_first,
u,
v,
COEFF_SUPPORT_BHW,
);
let third_coeff = coeff_jet.pair_directional_from_bb_family(
coeff_jet.bb_first,
u,
v,
dir,
COEFF_SUPPORT_BW,
);
let dir_val = exact::cell_third_derivative_from_moments(
cell,
&coeff_jet.first[u],
&coeff_jet.first[v],
&coeff_dir,
&second_coeff,
&coeff_u_dir[u],
&coeff_u_dir[v],
&third_coeff,
&state.moments,
)?;
f_uv_dir[dir_base + u * r + v] += dir_val;
if u != v {
f_uv_dir[dir_base + v * r + u] += dir_val;
}
}
}
}
}
f_u[0] = -marginal.mu1;
f_uv[[0, 0]] = -marginal.mu2;
let inv_f_a = 1.0 / f_a;
let mut a_u = Array1::<f64>::zeros(r);
for u in 0..r {
a_u[u] = -f_u[u] * inv_f_a;
}
let mut a_uv = Array2::<f64>::zeros((r, r));
for u in 0..r {
for v in u..r {
let val =
-(f_uv[[u, v]] + f_au[u] * a_u[v] + f_au[v] * a_u[u] + f_aa * a_u[u] * a_u[v])
* inv_f_a;
a_uv[[u, v]] = val;
a_uv[[v, u]] = val;
}
}
let z_obs = self.z[row];
let u_obs = a + b * z_obs;
let obs = self.observed_denested_cell_partials(row, a, b, beta_h, beta_w)?;
let eta_val = eval_coeff4_at(&obs.coeff, z_obs);
let mut g_u_fixed = vec![[0.0; 4]; r];
let mut g_au_fixed = vec![[0.0; 4]; r];
let mut g_bu_fixed = vec![[0.0; 4]; r];
let mut g_aau_fixed = vec![[0.0; 4]; r];
let mut g_abu_fixed = vec![[0.0; 4]; r];
let mut g_bbu_fixed = vec![[0.0; 4]; r];
g_u_fixed[1] = obs.dc_db;
g_au_fixed[1] = obs.dc_dab;
g_bu_fixed[1] = obs.dc_dbb;
g_aau_fixed[1] = obs.dc_daab;
g_abu_fixed[1] = obs.dc_dabb;
g_bbu_fixed[1] = obs.dc_dbbb;
if let (Some(h_range), Some(runtime)) = (h_range, score_runtime) {
Self::for_each_deviation_basis_cubic_at(
runtime,
h_range,
z_obs,
"score-warp batched third-trace observed",
|_, idx, basis_span| {
g_u_fixed[idx] =
scale_coeff4(exact::score_basis_cell_coefficients(basis_span, b), scale);
g_bu_fixed[idx] =
scale_coeff4(exact::score_basis_cell_coefficients(basis_span, 1.0), scale);
Ok(())
},
)?;
}
if let (Some(w_range), Some(runtime)) = (w_range, link_runtime) {
Self::for_each_deviation_basis_cubic_at(
runtime,
w_range,
u_obs,
"link-wiggle batched third-trace observed",
|_, idx, basis_span| {
g_u_fixed[idx] =
scale_coeff4(exact::link_basis_cell_coefficients(basis_span, a, b), scale);
let (dc_aw_raw, dc_bw_raw) =
exact::link_basis_cell_coefficient_partials(basis_span, a, b);
let (dc_aaw_raw, dc_abw_raw, dc_bbw_raw) =
exact::link_basis_cell_second_partials(basis_span, a, b);
g_au_fixed[idx] = scale_coeff4(dc_aw_raw, scale);
g_bu_fixed[idx] = scale_coeff4(dc_bw_raw, scale);
g_aau_fixed[idx] = scale_coeff4(dc_aaw_raw, scale);
g_abu_fixed[idx] = scale_coeff4(dc_abw_raw, scale);
g_bbu_fixed[idx] = scale_coeff4(dc_bbw_raw, scale);
Ok(())
},
)?;
}
let g_jet = SparsePrimaryCoeffJetView::new(
1,
h_range,
w_range,
&g_u_fixed,
&g_au_fixed,
&g_bu_fixed,
&g_aau_fixed,
&g_abu_fixed,
&g_bbu_fixed,
&zero_family,
&zero_family,
&zero_family,
&zero_family,
);
let g_a = eval_coeff4_at(&obs.dc_da, z_obs);
let g_aa = eval_coeff4_at(&obs.dc_daa, z_obs);
let g_aaa = eval_coeff4_at(&obs.dc_daaa, z_obs);
let mut g_u = Array1::<f64>::zeros(r);
let mut g_au = Array1::<f64>::zeros(r);
let mut g_aau = Array1::<f64>::zeros(r);
let mut g_uv = Array2::<f64>::zeros((r, r));
let mut g_auv = Array2::<f64>::zeros((r, r));
for u in 1..r {
g_u[u] = eval_coeff4_at(&g_jet.first[u], z_obs);
g_au[u] = eval_coeff4_at(&g_jet.a_first[u], z_obs);
g_aau[u] = eval_coeff4_at(&g_jet.aa_first[u], z_obs);
}
for u in 1..r {
for v in u..r {
let second_coeff = g_jet.pair_from_b_family(g_jet.b_first, u, v, COEFF_SUPPORT_BHW);
let val = eval_coeff4_at(&second_coeff, z_obs);
g_uv[[u, v]] = val;
g_uv[[v, u]] = val;
let third_coeff = g_jet.pair_from_b_family(g_jet.ab_first, u, v, COEFF_SUPPORT_BW);
let third_val = eval_coeff4_at(&third_coeff, z_obs);
g_auv[[u, v]] = third_val;
g_auv[[v, u]] = third_val;
}
}
let eta_u = g_a * &a_u + &g_u;
let mut eta_uv = Array2::<f64>::zeros((r, r));
for u in 0..r {
for v in u..r {
let val = g_a * a_uv[[u, v]]
+ g_aa * a_u[u] * a_u[v]
+ g_au[u] * a_u[v]
+ g_au[v] * a_u[u]
+ g_uv[[u, v]];
eta_uv[[u, v]] = val;
eta_uv[[v, u]] = val;
}
}
let y_i = self.y[row];
let w_i = self.weights[row];
let s_y = 2.0 * y_i - 1.0;
let m = s_y * eta_val;
let (k1, k2, k3, _) = signed_probit_neglog_derivatives_up_to_fourth(m, w_i)?;
let u1 = s_y * k1;
let u3 = s_y * k3;
let mut traces = vec![0.0; n_dirs];
for (dir_idx, dir) in row_dirs.iter().enumerate() {
let dir_base = dir_idx * r * r;
f_uv_dir[dir_base] = -dir[0] * marginal.mu3;
let a_dir = a_u.dot(dir);
let a_u_dir = a_uv.dot(dir);
let g_dir_fixed = g_jet.directional_family(g_jet.first, dir, COEFF_SUPPORT_BHW);
let g_a_dir_fixed = g_jet.directional_family(g_jet.a_first, dir, COEFF_SUPPORT_BW);
let g_aa_dir_fixed = g_jet.directional_family(g_jet.aa_first, dir, COEFF_SUPPORT_BW);
let g_dir = eval_coeff4_at(&g_dir_fixed, z_obs);
let g_a_dir = eval_coeff4_at(&g_a_dir_fixed, z_obs);
let g_aa_dir = eval_coeff4_at(&g_aa_dir_fixed, z_obs);
let eta_dir = g_a * a_dir + g_dir;
let eta_u_dir = eta_uv.dot(dir);
let dg_a_dir = g_aa * a_dir + g_a_dir;
let dg_aa_dir = g_aaa * a_dir + g_aa_dir;
let mut dg_au_dir = Array1::<f64>::zeros(r);
for u in 0..r {
let coeff =
g_jet.param_directional_from_b_family(g_jet.ab_first, u, dir, COEFF_SUPPORT_BW);
dg_au_dir[u] = g_aau[u] * a_dir + eval_coeff4_at(&coeff, z_obs);
}
let mut trace = 0.0;
for u in 0..r {
for v in u..r {
let fuvd = f_uv_dir[dir_base + u * r + v];
let n_dir = fuvd
+ f_au_dir[dir_idx * r + u] * a_u[v]
+ f_au[u] * a_u_dir[v]
+ f_au_dir[dir_idx * r + v] * a_u[u]
+ f_au[v] * a_u_dir[u]
+ f_aa_dir[dir_idx] * a_u[u] * a_u[v]
+ f_aa * (a_u_dir[u] * a_u[v] + a_u[u] * a_u_dir[v]);
let a_uv_dir = -(n_dir + f_a_dir[dir_idx] * a_uv[[u, v]]) * inv_f_a;
let third_coeff = g_jet.pair_directional_from_bb_family(
g_jet.bb_first,
u,
v,
dir,
COEFF_SUPPORT_BW,
);
let dg_uv_dir = g_auv[[u, v]] * a_dir + eval_coeff4_at(&third_coeff, z_obs);
let eta_uv_dir = dg_a_dir * a_uv[[u, v]]
+ g_a * a_uv_dir
+ dg_aa_dir * a_u[u] * a_u[v]
+ g_aa * (a_u_dir[u] * a_u[v] + a_u[u] * a_u_dir[v])
+ dg_au_dir[u] * a_u[v]
+ g_au[u] * a_u_dir[v]
+ dg_au_dir[v] * a_u[u]
+ g_au[v] * a_u_dir[u]
+ dg_uv_dir;
let val = u3 * eta_u[u] * eta_u[v] * eta_dir
+ k2 * (eta_uv[[u, v]] * eta_dir
+ eta_u[u] * eta_u_dir[v]
+ eta_u[v] * eta_u_dir[u])
+ u1 * eta_uv_dir;
if u == v {
trace += val * gram[u * r + v];
} else {
trace += val * (gram[u * r + v] + gram[v * r + u]);
}
}
}
traces[dir_idx] = trace;
}
Ok(traces)
}
fn row_primary_fourth_contracted_recompute_ordered(
&self,
row: usize,
block_states: &[ParameterBlockState],
cache: &BernoulliMarginalSlopeExactEvalCache,
row_ctx: &BernoulliMarginalSlopeRowExactContext,
dir_u: &Array1<f64>,
dir_v: &Array1<f64>,
) -> Result<Array2<f64>, String> {
let flex_active = self.effective_flex_active(block_states)?;
let expected_dir_len = if flex_active { cache.primary.total } else { 2 };
if dir_u.len() != expected_dir_len || dir_v.len() != expected_dir_len {
return Err(format!(
"bernoulli fourth contracted row {row}: direction lengths ({},{}) != {expected_dir_len}",
dir_u.len(),
dir_v.len()
));
}
if !flex_active {
let t = self.rigid_fourth_full_cached(block_states, cache, row)?;
let f = contract_fourth_full(t, dir_u[0], dir_u[1], dir_v[0], dir_v[1]);
let mut out = Array2::<f64>::zeros((2, 2));
out[[0, 0]] = f[0][0];
out[[0, 1]] = f[0][1];
out[[1, 0]] = f[1][0];
out[[1, 1]] = f[1][1];
return Ok(out);
}
if dir_u.iter().all(|value| *value == 0.0) || dir_v.iter().all(|value| *value == 0.0) {
return Ok(Array2::<f64>::zeros((
cache.primary.total,
cache.primary.total,
)));
}
if !row_ctx.intercept.is_finite() || !row_ctx.m_a.is_finite() || row_ctx.m_a <= 0.0 {
return Err(
"non-finite flexible row context in fourth-order directional contraction"
.to_string(),
);
}
use crate::families::bernoulli_marginal_slope::exact_kernel as exact;
let primary = &cache.primary;
let point = self.primary_point_from_block_states(row, block_states, primary)?;
let (q, b, beta_h_owned, beta_w_owned) = self.primary_point_components(&point, primary);
let beta_h = beta_h_owned.as_ref();
let beta_w = beta_w_owned.as_ref();
if let Some(grid) = self.latent_measure.empirical_grid_for_training_row(row)? {
return self.empirical_flex_row_fourth_contracted_recompute(
row, primary, q, b, beta_h, beta_w, row_ctx, dir_u, dir_v, &grid,
);
}
let a = row_ctx.intercept;
let r = primary.total;
let marginal = self.marginal_link_map(q)?;
let h_range = primary.h.as_ref();
let w_range = primary.w.as_ref();
let score_runtime = self.score_warp.as_ref();
let link_runtime = self.link_dev.as_ref();
let scale = self.probit_frailty_scale();
let mut f_a = 0.0;
let mut f_aa = 0.0;
let mut f_u = Array1::<f64>::zeros(r);
let mut f_au = Array1::<f64>::zeros(r);
let mut f_uv = Array2::<f64>::zeros((r, r));
let mut f_a_u = 0.0;
let mut f_aa_u = 0.0;
let mut f_au_u = Array1::<f64>::zeros(r);
let mut f_uv_u = Array2::<f64>::zeros((r, r));
let mut f_a_v = 0.0;
let mut f_aa_v = 0.0;
let mut f_au_v = Array1::<f64>::zeros(r);
let mut f_uv_v = Array2::<f64>::zeros((r, r));
let mut f_a_uv = 0.0;
let mut f_aa_uv = 0.0;
let mut f_au_uv = Array1::<f64>::zeros(r);
let mut f_uv_uv = Array2::<f64>::zeros((r, r));
let owned_cells;
let cells: &[CachedDenestedCellMoments] = if let Some(cached) = cache
.row_cell_moments
.as_ref()
.and_then(|bundle| bundle.row(row, 21))
{
cached
} else {
let partitions = self.denested_partition_cells(a, b, beta_h, beta_w)?;
owned_cells = partitions
.into_iter()
.map(|partition_cell| {
self.evaluate_cell_derivative_moments_lru(partition_cell.cell, 21)
.map(|state| CachedDenestedCellMoments {
partition_cell,
state,
})
})
.collect::<Result<Vec<_>, String>>()?;
&owned_cells
};
for entry in cells {
let partition_cell = entry.partition_cell;
let cell = partition_cell.cell;
let z_mid = exact::interval_probe_point(cell.left, cell.right)?;
let u_mid = a + b * z_mid;
let state = &entry.state;
let (dc_da_raw, dc_db_raw) = exact::denested_cell_coefficient_partials(
partition_cell.score_span,
partition_cell.link_span,
a,
b,
);
let (dc_daa_raw, dc_dab_raw, dc_dbb_raw) = exact::denested_cell_second_partials(
partition_cell.score_span,
partition_cell.link_span,
a,
b,
);
let denested_third = exact::denested_cell_third_partials(partition_cell.link_span);
let dc_da = scale_coeff4(dc_da_raw, scale);
let dc_db = scale_coeff4(dc_db_raw, scale);
let dc_daa = scale_coeff4(dc_daa_raw, scale);
let dc_dab = scale_coeff4(dc_dab_raw, scale);
let dc_dbb = scale_coeff4(dc_dbb_raw, scale);
let dc_daab = scale_coeff4(denested_third.1, scale);
let dc_dabb = scale_coeff4(denested_third.2, scale);
let dc_dbbb = scale_coeff4(denested_third.3, scale);
let mut coeff_u = vec![[0.0; 4]; r];
let mut coeff_au = vec![[0.0; 4]; r];
let mut coeff_bu = vec![[0.0; 4]; r];
let mut coeff_aau = vec![[0.0; 4]; r];
let mut coeff_abu = vec![[0.0; 4]; r];
let mut coeff_bbu = vec![[0.0; 4]; r];
let mut coeff_aaau = vec![[0.0; 4]; r];
let mut coeff_aabu = vec![[0.0; 4]; r];
let mut coeff_abbu = vec![[0.0; 4]; r];
let mut coeff_bbbu = vec![[0.0; 4]; r];
coeff_u[1] = dc_db;
coeff_au[1] = dc_dab;
coeff_bu[1] = dc_dbb;
coeff_aau[1] = dc_daab;
coeff_abu[1] = dc_dabb;
coeff_bbu[1] = dc_dbbb;
if let (Some(h_range), Some(runtime)) = (h_range, score_runtime) {
Self::for_each_deviation_basis_cubic_at(
runtime,
h_range,
z_mid,
"score-warp fourth-direction",
|_, idx, basis_span| {
coeff_u[idx] = scale_coeff4(
exact::score_basis_cell_coefficients(basis_span, b),
scale,
);
coeff_bu[idx] = scale_coeff4(
exact::score_basis_cell_coefficients(basis_span, 1.0),
scale,
);
Ok(())
},
)?;
}
if let (Some(w_range), Some(runtime)) = (w_range, link_runtime) {
Self::for_each_deviation_basis_cubic_at(
runtime,
w_range,
u_mid,
"link-wiggle fourth-direction",
|_, idx, basis_span| {
coeff_u[idx] = scale_coeff4(
exact::link_basis_cell_coefficients(basis_span, a, b),
scale,
);
let (dc_aw_raw, dc_bw_raw) =
exact::link_basis_cell_coefficient_partials(basis_span, a, b);
let (dc_aaw_raw, dc_abw_raw, dc_bbw_raw) =
exact::link_basis_cell_second_partials(basis_span, a, b);
let (dc_aaaw, dc_aabw, dc_abbw, dc_bbbw) =
exact::link_basis_cell_third_partials(basis_span);
coeff_au[idx] = scale_coeff4(dc_aw_raw, scale);
coeff_bu[idx] = scale_coeff4(dc_bw_raw, scale);
coeff_aau[idx] = scale_coeff4(dc_aaw_raw, scale);
coeff_abu[idx] = scale_coeff4(dc_abw_raw, scale);
coeff_bbu[idx] = scale_coeff4(dc_bbw_raw, scale);
coeff_aaau[idx] = scale_coeff4(dc_aaaw, scale);
coeff_aabu[idx] = scale_coeff4(dc_aabw, scale);
coeff_abbu[idx] = scale_coeff4(dc_abbw, scale);
coeff_bbbu[idx] = scale_coeff4(dc_bbbw, scale);
Ok(())
},
)?;
}
let coeff_jet = SparsePrimaryCoeffJetView::new(
1,
h_range,
w_range,
&coeff_u,
&coeff_au,
&coeff_bu,
&coeff_aau,
&coeff_abu,
&coeff_bbu,
&coeff_aaau,
&coeff_aabu,
&coeff_abbu,
&coeff_bbbu,
);
f_a += exact::cell_first_derivative_from_moments(&dc_da, &state.moments)?;
f_aa += exact::cell_second_derivative_from_moments(
cell,
&dc_da,
&dc_da,
&dc_daa,
&state.moments,
)?;
for u in 1..r {
f_u[u] +=
exact::cell_first_derivative_from_moments(&coeff_jet.first[u], &state.moments)?;
f_au[u] += exact::cell_second_derivative_from_moments(
cell,
&dc_da,
&coeff_jet.first[u],
&coeff_jet.a_first[u],
&state.moments,
)?;
}
let coeff_dir_u =
coeff_jet.directional_family(coeff_jet.first, dir_u, COEFF_SUPPORT_BHW);
let coeff_dir_v =
coeff_jet.directional_family(coeff_jet.first, dir_v, COEFF_SUPPORT_BHW);
let coeff_a_dir_u =
coeff_jet.directional_family(coeff_jet.a_first, dir_u, COEFF_SUPPORT_BW);
let coeff_a_dir_v =
coeff_jet.directional_family(coeff_jet.a_first, dir_v, COEFF_SUPPORT_BW);
let coeff_aa_dir_u =
coeff_jet.directional_family(coeff_jet.aa_first, dir_u, COEFF_SUPPORT_BW);
let coeff_aa_dir_v =
coeff_jet.directional_family(coeff_jet.aa_first, dir_v, COEFF_SUPPORT_BW);
f_a_u += exact::cell_second_derivative_from_moments(
cell,
&dc_da,
&coeff_dir_u,
&coeff_a_dir_u,
&state.moments,
)?;
f_a_v += exact::cell_second_derivative_from_moments(
cell,
&dc_da,
&coeff_dir_v,
&coeff_a_dir_v,
&state.moments,
)?;
f_aa_u += exact::cell_third_derivative_from_moments(
cell,
&dc_da,
&dc_da,
&coeff_dir_u,
&dc_daa,
&coeff_a_dir_u,
&coeff_a_dir_u,
&coeff_aa_dir_u,
&state.moments,
)?;
f_aa_v += exact::cell_third_derivative_from_moments(
cell,
&dc_da,
&dc_da,
&coeff_dir_v,
&dc_daa,
&coeff_a_dir_v,
&coeff_a_dir_v,
&coeff_aa_dir_v,
&state.moments,
)?;
let coeff_dir_uv = coeff_jet.mixed_directional_from_b_family(
coeff_jet.b_first,
dir_u,
dir_v,
COEFF_SUPPORT_BHW,
);
let coeff_a_dir_uv = coeff_jet.mixed_directional_from_b_family(
coeff_jet.ab_first,
dir_u,
dir_v,
COEFF_SUPPORT_BW,
);
let coeff_aa_dir_uv = coeff_jet.mixed_directional_from_b_family(
coeff_jet.aab_first,
dir_u,
dir_v,
COEFF_SUPPORT_W,
);
f_a_uv += exact::cell_third_derivative_from_moments(
cell,
&dc_da,
&coeff_dir_u,
&coeff_dir_v,
&coeff_a_dir_u,
&coeff_a_dir_v,
&coeff_dir_uv,
&coeff_a_dir_uv,
&state.moments,
)?;
f_aa_uv += exact::cell_fourth_derivative_from_moments(
cell,
&dc_da,
&dc_da,
&coeff_dir_u,
&coeff_dir_v,
&dc_daa,
&coeff_a_dir_u,
&coeff_a_dir_v,
&coeff_a_dir_u,
&coeff_a_dir_v,
&coeff_dir_uv,
&coeff_aa_dir_u,
&coeff_aa_dir_v,
&coeff_a_dir_uv,
&coeff_a_dir_uv,
&coeff_aa_dir_uv,
&state.moments,
)?;
let mut coeff_u_dir_u = vec![[0.0; 4]; r];
let mut coeff_u_dir_v = vec![[0.0; 4]; r];
let mut coeff_u_dir_uv = vec![[0.0; 4]; r];
let mut coeff_au_dir_u = vec![[0.0; 4]; r];
let mut coeff_au_dir_v = vec![[0.0; 4]; r];
let mut coeff_au_dir_uv = vec![[0.0; 4]; r];
for u in 1..r {
coeff_u_dir_u[u] = coeff_jet.param_directional_from_b_family(
coeff_jet.b_first,
u,
dir_u,
COEFF_SUPPORT_BHW,
);
coeff_u_dir_v[u] = coeff_jet.param_directional_from_b_family(
coeff_jet.b_first,
u,
dir_v,
COEFF_SUPPORT_BHW,
);
coeff_u_dir_uv[u] = coeff_jet.param_mixed_from_bb_family(
coeff_jet.bb_first,
u,
dir_u,
dir_v,
COEFF_SUPPORT_BW,
);
coeff_au_dir_u[u] = coeff_jet.param_directional_from_b_family(
coeff_jet.ab_first,
u,
dir_u,
COEFF_SUPPORT_BW,
);
coeff_au_dir_v[u] = coeff_jet.param_directional_from_b_family(
coeff_jet.ab_first,
u,
dir_v,
COEFF_SUPPORT_BW,
);
coeff_au_dir_uv[u] = coeff_jet.param_mixed_from_bb_family(
coeff_jet.abb_first,
u,
dir_u,
dir_v,
COEFF_SUPPORT_W,
);
}
for u in 1..r {
f_au_u[u] += exact::cell_third_derivative_from_moments(
cell,
&dc_da,
&coeff_u[u],
&coeff_dir_u,
&coeff_au[u],
&coeff_a_dir_u,
&coeff_u_dir_u[u],
&coeff_au_dir_u[u],
&state.moments,
)?;
f_au_v[u] += exact::cell_third_derivative_from_moments(
cell,
&dc_da,
&coeff_u[u],
&coeff_dir_v,
&coeff_au[u],
&coeff_a_dir_v,
&coeff_u_dir_v[u],
&coeff_au_dir_v[u],
&state.moments,
)?;
f_au_uv[u] += exact::cell_fourth_derivative_from_moments(
cell,
&dc_da,
&coeff_u[u],
&coeff_dir_u,
&coeff_dir_v,
&coeff_au[u],
&coeff_a_dir_u,
&coeff_a_dir_v,
&coeff_u_dir_u[u],
&coeff_u_dir_v[u],
&coeff_dir_uv,
&coeff_au_dir_u[u],
&coeff_au_dir_v[u],
&coeff_a_dir_uv,
&coeff_u_dir_uv[u],
&coeff_au_dir_uv[u],
&state.moments,
)?;
}
for u in 1..r {
for v in u..r {
let second_coeff =
coeff_jet.pair_from_b_family(coeff_jet.b_first, u, v, COEFF_SUPPORT_BHW);
let base_val = exact::cell_second_derivative_from_moments(
cell,
&coeff_jet.first[u],
&coeff_jet.first[v],
&second_coeff,
&state.moments,
)?;
f_uv[[u, v]] += base_val;
if u != v {
f_uv[[v, u]] += base_val;
}
let third_u = coeff_jet.pair_directional_from_bb_family(
coeff_jet.bb_first,
u,
v,
dir_u,
COEFF_SUPPORT_BW,
);
let third_v = coeff_jet.pair_directional_from_bb_family(
coeff_jet.bb_first,
u,
v,
dir_v,
COEFF_SUPPORT_BW,
);
let fourth_uv = coeff_jet.pair_mixed_from_bbb_family(
coeff_jet.bbb_first,
u,
v,
dir_u,
dir_v,
COEFF_SUPPORT_W,
);
let dir_u_val = exact::cell_third_derivative_from_moments(
cell,
&coeff_jet.first[u],
&coeff_jet.first[v],
&coeff_dir_u,
&second_coeff,
&coeff_u_dir_u[u],
&coeff_u_dir_u[v],
&third_u,
&state.moments,
)?;
let dir_v_val = exact::cell_third_derivative_from_moments(
cell,
&coeff_jet.first[u],
&coeff_jet.first[v],
&coeff_dir_v,
&second_coeff,
&coeff_u_dir_v[u],
&coeff_u_dir_v[v],
&third_v,
&state.moments,
)?;
let mix_val = exact::cell_fourth_derivative_from_moments(
cell,
&coeff_jet.first[u],
&coeff_jet.first[v],
&coeff_dir_u,
&coeff_dir_v,
&second_coeff,
&coeff_u_dir_u[u],
&coeff_u_dir_v[u],
&coeff_u_dir_u[v],
&coeff_u_dir_v[v],
&coeff_dir_uv,
&third_u,
&third_v,
&coeff_u_dir_uv[u],
&coeff_u_dir_uv[v],
&fourth_uv,
&state.moments,
)?;
f_uv_u[[u, v]] += dir_u_val;
f_uv_v[[u, v]] += dir_v_val;
f_uv_uv[[u, v]] += mix_val;
if u != v {
f_uv_u[[v, u]] += dir_u_val;
f_uv_v[[v, u]] += dir_v_val;
f_uv_uv[[v, u]] += mix_val;
}
}
}
}
f_u[0] = -marginal.mu1;
f_uv[[0, 0]] = -marginal.mu2;
f_uv_u[[0, 0]] = -dir_u[0] * marginal.mu3;
f_uv_v[[0, 0]] = -dir_v[0] * marginal.mu3;
f_uv_uv[[0, 0]] = -dir_u[0] * dir_v[0] * marginal.mu4;
let inv_f_a = 1.0 / f_a;
let mut a_u = Array1::<f64>::zeros(r);
for u in 0..r {
a_u[u] = -f_u[u] * inv_f_a;
}
let mut a_uv = Array2::<f64>::zeros((r, r));
for u in 0..r {
for v in u..r {
let val =
-(f_uv[[u, v]] + f_au[u] * a_u[v] + f_au[v] * a_u[u] + f_aa * a_u[u] * a_u[v])
* inv_f_a;
a_uv[[u, v]] = val;
a_uv[[v, u]] = val;
}
}
let a_u_dir_u = a_uv.dot(dir_u);
let a_u_dir_v = a_uv.dot(dir_v);
let mut a_uv_u = Array2::<f64>::zeros((r, r));
let mut a_uv_v = Array2::<f64>::zeros((r, r));
for u in 0..r {
for v in u..r {
let n_u = f_uv_u[[u, v]]
+ f_au_u[u] * a_u[v]
+ f_au[u] * a_u_dir_u[v]
+ f_au_u[v] * a_u[u]
+ f_au[v] * a_u_dir_u[u]
+ f_aa_u * a_u[u] * a_u[v]
+ f_aa * (a_u_dir_u[u] * a_u[v] + a_u[u] * a_u_dir_u[v]);
let val_u = -(n_u + f_a_u * a_uv[[u, v]]) * inv_f_a;
a_uv_u[[u, v]] = val_u;
a_uv_u[[v, u]] = val_u;
let n_v = f_uv_v[[u, v]]
+ f_au_v[u] * a_u[v]
+ f_au[u] * a_u_dir_v[v]
+ f_au_v[v] * a_u[u]
+ f_au[v] * a_u_dir_v[u]
+ f_aa_v * a_u[u] * a_u[v]
+ f_aa * (a_u_dir_v[u] * a_u[v] + a_u[u] * a_u_dir_v[v]);
let val_v = -(n_v + f_a_v * a_uv[[u, v]]) * inv_f_a;
a_uv_v[[u, v]] = val_v;
a_uv_v[[v, u]] = val_v;
}
}
let a_u_uv = a_uv_u.dot(dir_v);
let mut a_uv_uv = Array2::<f64>::zeros((r, r));
for u in 0..r {
for v in u..r {
let n_uv = f_uv_uv[[u, v]]
+ f_au_uv[u] * a_u[v]
+ f_au_u[u] * a_u_dir_v[v]
+ f_au_v[u] * a_u_dir_u[v]
+ f_au[u] * a_u_uv[v]
+ f_au_uv[v] * a_u[u]
+ f_au_u[v] * a_u_dir_v[u]
+ f_au_v[v] * a_u_dir_u[u]
+ f_au[v] * a_u_uv[u]
+ f_aa_uv * a_u[u] * a_u[v]
+ f_aa_u * (a_u_dir_v[u] * a_u[v] + a_u[u] * a_u_dir_v[v])
+ f_aa_v * (a_u_dir_u[u] * a_u[v] + a_u[u] * a_u_dir_u[v])
+ f_aa
* (a_u_uv[u] * a_u[v]
+ a_u_dir_u[u] * a_u_dir_v[v]
+ a_u_dir_v[u] * a_u_dir_u[v]
+ a_u[u] * a_u_uv[v]);
let val = -(n_uv
+ f_a_v * a_uv_u[[u, v]]
+ f_a_u * a_uv_v[[u, v]]
+ f_a_uv * a_uv[[u, v]])
* inv_f_a;
a_uv_uv[[u, v]] = val;
a_uv_uv[[v, u]] = val;
}
}
let z_obs = self.z[row];
let u_obs = a + b * z_obs;
let obs = self.observed_denested_cell_partials(row, a, b, beta_h, beta_w)?;
let eta_val = eval_coeff4_at(&obs.coeff, z_obs);
let mut g_u_fixed = vec![[0.0; 4]; r];
let mut g_au_fixed = vec![[0.0; 4]; r];
let mut g_bu_fixed = vec![[0.0; 4]; r];
let mut g_aau_fixed = vec![[0.0; 4]; r];
let mut g_abu_fixed = vec![[0.0; 4]; r];
let mut g_bbu_fixed = vec![[0.0; 4]; r];
let mut g_aaau_fixed = vec![[0.0; 4]; r];
let mut g_aabu_fixed = vec![[0.0; 4]; r];
let mut g_abbu_fixed = vec![[0.0; 4]; r];
let mut g_bbbu_fixed = vec![[0.0; 4]; r];
g_u_fixed[1] = obs.dc_db;
g_au_fixed[1] = obs.dc_dab;
g_bu_fixed[1] = obs.dc_dbb;
g_aau_fixed[1] = obs.dc_daab;
g_abu_fixed[1] = obs.dc_dabb;
g_bbu_fixed[1] = obs.dc_dbbb;
if let (Some(h_range), Some(runtime)) = (h_range, score_runtime) {
Self::for_each_deviation_basis_cubic_at(
runtime,
h_range,
z_obs,
"score-warp fourth-direction observed",
|_, idx, basis_span| {
g_u_fixed[idx] =
scale_coeff4(exact::score_basis_cell_coefficients(basis_span, b), scale);
g_bu_fixed[idx] =
scale_coeff4(exact::score_basis_cell_coefficients(basis_span, 1.0), scale);
Ok(())
},
)?;
}
if let (Some(w_range), Some(runtime)) = (w_range, link_runtime) {
Self::for_each_deviation_basis_cubic_at(
runtime,
w_range,
u_obs,
"link-wiggle fourth-direction observed",
|_, idx, basis_span| {
g_u_fixed[idx] =
scale_coeff4(exact::link_basis_cell_coefficients(basis_span, a, b), scale);
let (dc_aw_raw, dc_bw_raw) =
exact::link_basis_cell_coefficient_partials(basis_span, a, b);
let (dc_aaw_raw, dc_abw_raw, dc_bbw_raw) =
exact::link_basis_cell_second_partials(basis_span, a, b);
let (dc_aaaw, dc_aabw, dc_abbw, dc_bbbw) =
exact::link_basis_cell_third_partials(basis_span);
g_au_fixed[idx] = scale_coeff4(dc_aw_raw, scale);
g_bu_fixed[idx] = scale_coeff4(dc_bw_raw, scale);
g_aau_fixed[idx] = scale_coeff4(dc_aaw_raw, scale);
g_abu_fixed[idx] = scale_coeff4(dc_abw_raw, scale);
g_bbu_fixed[idx] = scale_coeff4(dc_bbw_raw, scale);
g_aaau_fixed[idx] = scale_coeff4(dc_aaaw, scale);
g_aabu_fixed[idx] = scale_coeff4(dc_aabw, scale);
g_abbu_fixed[idx] = scale_coeff4(dc_abbw, scale);
g_bbbu_fixed[idx] = scale_coeff4(dc_bbbw, scale);
Ok(())
},
)?;
}
let g_jet = SparsePrimaryCoeffJetView::new(
1,
h_range,
w_range,
&g_u_fixed,
&g_au_fixed,
&g_bu_fixed,
&g_aau_fixed,
&g_abu_fixed,
&g_bbu_fixed,
&g_aaau_fixed,
&g_aabu_fixed,
&g_abbu_fixed,
&g_bbbu_fixed,
);
let g_a = eval_coeff4_at(&obs.dc_da, z_obs);
let g_aa = eval_coeff4_at(&obs.dc_daa, z_obs);
let g_aaa = eval_coeff4_at(&obs.dc_daaa, z_obs);
let mut g_u = Array1::<f64>::zeros(r);
let mut g_au = Array1::<f64>::zeros(r);
let mut g_aau = Array1::<f64>::zeros(r);
let mut g_aaau = Array1::<f64>::zeros(r);
let mut g_uv = Array2::<f64>::zeros((r, r));
let mut g_auv = Array2::<f64>::zeros((r, r));
let mut g_aauv = Array2::<f64>::zeros((r, r));
for u in 1..r {
g_u[u] = eval_coeff4_at(&g_jet.first[u], z_obs);
g_au[u] = eval_coeff4_at(&g_jet.a_first[u], z_obs);
g_aau[u] = eval_coeff4_at(&g_jet.aa_first[u], z_obs);
g_aaau[u] = eval_coeff4_at(&g_jet.aaa_first[u], z_obs);
}
for u in 1..r {
for v in u..r {
let second_coeff = g_jet.pair_from_b_family(g_jet.b_first, u, v, COEFF_SUPPORT_BHW);
let val = eval_coeff4_at(&second_coeff, z_obs);
g_uv[[u, v]] = val;
g_uv[[v, u]] = val;
let third_coeff = g_jet.pair_from_b_family(g_jet.ab_first, u, v, COEFF_SUPPORT_BW);
let fourth_coeff = g_jet.pair_from_b_family(g_jet.aab_first, u, v, COEFF_SUPPORT_W);
let third_val = eval_coeff4_at(&third_coeff, z_obs);
let fourth_val = eval_coeff4_at(&fourth_coeff, z_obs);
g_auv[[u, v]] = third_val;
g_auv[[v, u]] = third_val;
g_aauv[[u, v]] = fourth_val;
g_aauv[[v, u]] = fourth_val;
}
}
let g_dir_u_fixed = g_jet.directional_family(g_jet.first, dir_u, COEFF_SUPPORT_BHW);
let g_dir_v_fixed = g_jet.directional_family(g_jet.first, dir_v, COEFF_SUPPORT_BHW);
let g_a_dir_u_fixed = g_jet.directional_family(g_jet.a_first, dir_u, COEFF_SUPPORT_BW);
let g_a_dir_v_fixed = g_jet.directional_family(g_jet.a_first, dir_v, COEFF_SUPPORT_BW);
let g_aa_dir_u_fixed = g_jet.directional_family(g_jet.aa_first, dir_u, COEFF_SUPPORT_BW);
let g_aa_dir_v_fixed = g_jet.directional_family(g_jet.aa_first, dir_v, COEFF_SUPPORT_BW);
let g_dir_uv_fixed =
g_jet.mixed_directional_from_b_family(g_jet.b_first, dir_u, dir_v, COEFF_SUPPORT_BHW);
let g_a_dir_uv_fixed =
g_jet.mixed_directional_from_b_family(g_jet.ab_first, dir_u, dir_v, COEFF_SUPPORT_BW);
let g_aa_dir_uv_fixed =
g_jet.mixed_directional_from_b_family(g_jet.aab_first, dir_u, dir_v, COEFF_SUPPORT_W);
let g_dir_u = eval_coeff4_at(&g_dir_u_fixed, z_obs);
let g_dir_v = eval_coeff4_at(&g_dir_v_fixed, z_obs);
let g_dir_uv = eval_coeff4_at(&g_dir_uv_fixed, z_obs);
let g_a_u_fixed = eval_coeff4_at(&g_a_dir_u_fixed, z_obs);
let g_a_v_fixed = eval_coeff4_at(&g_a_dir_v_fixed, z_obs);
let g_aa_u_fixed = eval_coeff4_at(&g_aa_dir_u_fixed, z_obs);
let g_aa_v_fixed = eval_coeff4_at(&g_aa_dir_v_fixed, z_obs);
let g_a_uv_fixed = eval_coeff4_at(&g_a_dir_uv_fixed, z_obs);
let g_aa_uv_fixed = eval_coeff4_at(&g_aa_dir_uv_fixed, z_obs);
let mut g_u_u_fixed = Array1::<f64>::zeros(r);
let mut g_u_v_fixed = Array1::<f64>::zeros(r);
let mut g_u_uv_fixed = Array1::<f64>::zeros(r);
let mut g_au_u_fixed = Array1::<f64>::zeros(r);
let mut g_au_v_fixed = Array1::<f64>::zeros(r);
let mut g_au_uv_fixed = Array1::<f64>::zeros(r);
let mut g_uv_u_fixed = Array2::<f64>::zeros((r, r));
let mut g_uv_v_fixed = Array2::<f64>::zeros((r, r));
let mut g_uv_uv_fixed = Array2::<f64>::zeros((r, r));
let mut g_auv_u_fixed = Array2::<f64>::zeros((r, r));
let mut g_auv_v_fixed = Array2::<f64>::zeros((r, r));
for u in 1..r {
let tmp_u =
g_jet.param_directional_from_b_family(g_jet.b_first, u, dir_u, COEFF_SUPPORT_BHW);
let tmp_v =
g_jet.param_directional_from_b_family(g_jet.b_first, u, dir_v, COEFF_SUPPORT_BHW);
let tmp_uv =
g_jet.param_mixed_from_bb_family(g_jet.bb_first, u, dir_u, dir_v, COEFF_SUPPORT_BW);
let tmp_au_u =
g_jet.param_directional_from_b_family(g_jet.ab_first, u, dir_u, COEFF_SUPPORT_BW);
let tmp_au_v =
g_jet.param_directional_from_b_family(g_jet.ab_first, u, dir_v, COEFF_SUPPORT_BW);
let tmp_au_uv =
g_jet.param_mixed_from_bb_family(g_jet.abb_first, u, dir_u, dir_v, COEFF_SUPPORT_W);
g_u_u_fixed[u] = eval_coeff4_at(&tmp_u, z_obs);
g_u_v_fixed[u] = eval_coeff4_at(&tmp_v, z_obs);
g_u_uv_fixed[u] = eval_coeff4_at(&tmp_uv, z_obs);
g_au_u_fixed[u] = eval_coeff4_at(&tmp_au_u, z_obs);
g_au_v_fixed[u] = eval_coeff4_at(&tmp_au_v, z_obs);
g_au_uv_fixed[u] = eval_coeff4_at(&tmp_au_uv, z_obs);
}
for u in 1..r {
for v in u..r {
let third_u = g_jet.pair_directional_from_bb_family(
g_jet.bb_first,
u,
v,
dir_u,
COEFF_SUPPORT_BW,
);
let third_v = g_jet.pair_directional_from_bb_family(
g_jet.bb_first,
u,
v,
dir_v,
COEFF_SUPPORT_BW,
);
let fourth_uv = g_jet.pair_mixed_from_bbb_family(
g_jet.bbb_first,
u,
v,
dir_u,
dir_v,
COEFF_SUPPORT_W,
);
let a_third_u = g_jet.pair_directional_from_bb_family(
g_jet.abb_first,
u,
v,
dir_u,
COEFF_SUPPORT_W,
);
let a_third_v = g_jet.pair_directional_from_bb_family(
g_jet.abb_first,
u,
v,
dir_v,
COEFF_SUPPORT_W,
);
let vu = eval_coeff4_at(&third_u, z_obs);
let vv = eval_coeff4_at(&third_v, z_obs);
let vuv = eval_coeff4_at(&fourth_uv, z_obs);
g_uv_u_fixed[[u, v]] = vu;
g_uv_v_fixed[[u, v]] = vv;
g_uv_uv_fixed[[u, v]] = vuv;
g_uv_u_fixed[[v, u]] = vu;
g_uv_v_fixed[[v, u]] = vv;
g_uv_uv_fixed[[v, u]] = vuv;
let atu = eval_coeff4_at(&a_third_u, z_obs);
let atv = eval_coeff4_at(&a_third_v, z_obs);
g_auv_u_fixed[[u, v]] = atu;
g_auv_v_fixed[[u, v]] = atv;
g_auv_u_fixed[[v, u]] = atu;
g_auv_v_fixed[[v, u]] = atv;
}
}
let eta_u = g_a * &a_u + &g_u;
let mut eta_uv = Array2::<f64>::zeros((r, r));
for u in 0..r {
for v in u..r {
let val = g_a * a_uv[[u, v]]
+ g_aa * a_u[u] * a_u[v]
+ g_au[u] * a_u[v]
+ g_au[v] * a_u[u]
+ g_uv[[u, v]];
eta_uv[[u, v]] = val;
eta_uv[[v, u]] = val;
}
}
let a_dir_u = a_u.dot(dir_u);
let a_dir_v = a_u.dot(dir_v);
let g_a_u = g_aa * a_dir_u + g_a_u_fixed;
let g_a_v = g_aa * a_dir_v + g_a_v_fixed;
let g_aa_u = g_aaa * a_dir_u + g_aa_u_fixed;
let g_aa_v = g_aaa * a_dir_v + g_aa_v_fixed;
let mut g_u_u = Array1::<f64>::zeros(r);
let mut g_u_v = Array1::<f64>::zeros(r);
let mut g_au_u = Array1::<f64>::zeros(r);
let mut g_au_v = Array1::<f64>::zeros(r);
for u in 0..r {
g_u_u[u] = g_au[u] * a_dir_u + g_u_u_fixed[u];
g_u_v[u] = g_au[u] * a_dir_v + g_u_v_fixed[u];
g_au_u[u] = g_aau[u] * a_dir_u + g_au_u_fixed[u];
g_au_v[u] = g_aau[u] * a_dir_v + g_au_v_fixed[u];
}
let mut eta_uv_u = Array2::<f64>::zeros((r, r));
let mut eta_uv_v = Array2::<f64>::zeros((r, r));
for u in 0..r {
for v in u..r {
let g_uv_u = g_auv[[u, v]] * a_dir_u + g_uv_u_fixed[[u, v]];
let g_uv_v = g_auv[[u, v]] * a_dir_v + g_uv_v_fixed[[u, v]];
let val_u = g_a_u * a_uv[[u, v]]
+ g_a * a_uv_u[[u, v]]
+ g_aa_u * a_u[u] * a_u[v]
+ g_aa * (a_u_dir_u[u] * a_u[v] + a_u[u] * a_u_dir_u[v])
+ g_au_u[u] * a_u[v]
+ g_au[u] * a_u_dir_u[v]
+ g_au_u[v] * a_u[u]
+ g_au[v] * a_u_dir_u[u]
+ g_uv_u;
eta_uv_u[[u, v]] = val_u;
eta_uv_u[[v, u]] = val_u;
let val_v = g_a_v * a_uv[[u, v]]
+ g_a * a_uv_v[[u, v]]
+ g_aa_v * a_u[u] * a_u[v]
+ g_aa * (a_u_dir_v[u] * a_u[v] + a_u[u] * a_u_dir_v[v])
+ g_au_v[u] * a_u[v]
+ g_au[u] * a_u_dir_v[v]
+ g_au_v[v] * a_u[u]
+ g_au[v] * a_u_dir_v[u]
+ g_uv_v;
eta_uv_v[[u, v]] = val_v;
eta_uv_v[[v, u]] = val_v;
}
}
let a_dir_uv = a_u_dir_u.dot(dir_v);
let g_a_uv = g_aaa * a_dir_u * a_dir_v
+ g_aa * a_dir_uv
+ g_aa_u_fixed * a_dir_v
+ g_aa_v_fixed * a_dir_u
+ g_a_uv_fixed;
let g_aa_uv = g_aaau.dot(dir_u) * a_dir_v
+ g_aaau.dot(dir_v) * a_dir_u
+ g_aaa * a_dir_uv
+ g_aa_uv_fixed;
let mut g_u_uv = Array1::<f64>::zeros(r);
let mut g_au_uv = Array1::<f64>::zeros(r);
for u in 0..r {
g_u_uv[u] = g_aau[u] * a_dir_u * a_dir_v
+ g_au[u] * a_dir_uv
+ g_au_u_fixed[u] * a_dir_v
+ g_au_v_fixed[u] * a_dir_u
+ g_u_uv_fixed[u];
let row_u_u = g_aauv.row(u).dot(dir_u);
let row_u_v = g_aauv.row(u).dot(dir_v);
g_au_uv[u] = g_aaau[u] * a_dir_u * a_dir_v
+ g_aau[u] * a_dir_uv
+ row_u_u * a_dir_v
+ row_u_v * a_dir_u
+ g_au_uv_fixed[u];
}
let mut eta_uv_uv = Array2::<f64>::zeros((r, r));
for u in 0..r {
for v in u..r {
let g_uv_uv = g_aauv[[u, v]] * a_dir_u * a_dir_v
+ g_auv[[u, v]] * a_dir_uv
+ g_auv_u_fixed[[u, v]] * a_dir_v
+ g_auv_v_fixed[[u, v]] * a_dir_u
+ g_uv_uv_fixed[[u, v]];
let val = g_a_uv * a_uv[[u, v]]
+ g_a_u * a_uv_v[[u, v]]
+ g_a_v * a_uv_u[[u, v]]
+ g_a * a_uv_uv[[u, v]]
+ g_aa_uv * a_u[u] * a_u[v]
+ g_aa_u * (a_u_dir_v[u] * a_u[v] + a_u[u] * a_u_dir_v[v])
+ g_aa_v * (a_u_dir_u[u] * a_u[v] + a_u[u] * a_u_dir_u[v])
+ g_aa
* (a_u_uv[u] * a_u[v]
+ a_u_dir_u[u] * a_u_dir_v[v]
+ a_u_dir_v[u] * a_u_dir_u[v]
+ a_u[u] * a_u_uv[v])
+ g_au_uv[u] * a_u[v]
+ g_au_u[u] * a_u_dir_v[v]
+ g_au_v[u] * a_u_dir_u[v]
+ g_au[u] * a_u_uv[v]
+ g_au_uv[v] * a_u[u]
+ g_au_u[v] * a_u_dir_v[u]
+ g_au_v[v] * a_u_dir_u[u]
+ g_au[v] * a_u_uv[u]
+ g_uv_uv;
eta_uv_uv[[u, v]] = val;
eta_uv_uv[[v, u]] = val;
}
}
let eta_dir_u = g_a * a_dir_u + g_dir_u;
let eta_dir_v = g_a * a_dir_v + g_dir_v;
let eta_u_dir_u = eta_uv.dot(dir_u);
let eta_u_dir_v = eta_uv.dot(dir_v);
let eta_dir_uv = g_a_v * a_dir_u + g_a_u_fixed * a_dir_v + g_a * a_dir_uv + g_dir_uv;
let eta_u_uv = eta_uv_u.dot(dir_v);
let y_i = self.y[row];
let w_i = self.weights[row];
let s_y = 2.0 * y_i - 1.0;
let m = s_y * eta_val;
let (k1, k2, k3, k4) = signed_probit_neglog_derivatives_up_to_fourth(m, w_i)?;
let u1 = s_y * k1;
let u3 = s_y * k3;
let mut out = Array2::<f64>::zeros((r, r));
for u in 0..r {
for v in u..r {
let a_term = eta_u[u] * eta_u[v] * eta_dir_u;
let a_term_v = eta_u_dir_v[u] * eta_u[v] * eta_dir_u
+ eta_u[u] * eta_u_dir_v[v] * eta_dir_u
+ eta_u[u] * eta_u[v] * eta_dir_uv;
let b_term = eta_uv_u[[u, v]];
let b_term_v = eta_uv_uv[[u, v]];
let c_term = eta_uv[[u, v]] * eta_dir_u
+ eta_u[u] * eta_u_dir_u[v]
+ eta_u[v] * eta_u_dir_u[u];
let c_term_v = eta_uv_v[[u, v]] * eta_dir_u
+ eta_uv[[u, v]] * eta_dir_uv
+ eta_u_dir_v[u] * eta_u_dir_u[v]
+ eta_u[u] * eta_u_uv[v]
+ eta_u_dir_v[v] * eta_u_dir_u[u]
+ eta_u[v] * eta_u_uv[u];
let val = k4 * eta_dir_v * a_term
+ u3 * a_term_v
+ u3 * eta_dir_v * c_term
+ k2 * c_term_v
+ k2 * eta_dir_v * b_term
+ u1 * b_term_v;
out[[u, v]] = val;
out[[v, u]] = val;
}
}
Ok(out)
}
/// Fourth-order directional contraction for one row, symmetrized over the
/// two probe directions when the flexible part of the model is active.
///
/// The ordered evaluation is exact for the rigid model; with flex active,
/// the result is averaged with the direction-swapped evaluation so the
/// returned matrix is symmetric in `(dir_u, dir_v)`.
fn row_primary_fourth_contracted_recompute(
    &self,
    row: usize,
    block_states: &[ParameterBlockState],
    cache: &BernoulliMarginalSlopeExactEvalCache,
    row_ctx: &BernoulliMarginalSlopeRowExactContext,
    dir_u: &Array1<f64>,
    dir_v: &Array1<f64>,
) -> Result<Array2<f64>, String> {
    let forward = self.row_primary_fourth_contracted_recompute_ordered(
        row,
        block_states,
        cache,
        row_ctx,
        dir_u,
        dir_v,
    )?;
    if !self.effective_flex_active(block_states)? {
        // Rigid case: the ordered evaluation is already what callers need.
        return Ok(forward);
    }
    // Flex case: evaluate with the directions exchanged and average the two.
    let backward = self.row_primary_fourth_contracted_recompute_ordered(
        row,
        block_states,
        cache,
        row_ctx,
        dir_v,
        dir_u,
    )?;
    let mut averaged = forward;
    averaged.zip_mut_with(&backward, |a, &b| *a = 0.5 * (*a + b));
    Ok(averaged)
}
/// Adds the h/w-auxiliary portion of one row's primary Hessian pullback
/// into the joint Hessian `target`.
///
/// Layout visible here: `primary_hessian` row/column 0 pairs with
/// `marginal_design` and row/column 1 with `logslope_design`, while columns
/// in the `primary.h` / `primary.w` ranges are the auxiliary h/w
/// coefficients. For each auxiliary coefficient, the cross terms against
/// the marginal and log-slope blocks are scattered through `axpy_row_into`
/// into both the column strip and the symmetric row strip, and the dense
/// aux-aux sub-blocks are added with `scaled_add`. The 2x2
/// marginal/log-slope corner is intentionally not touched here ("hw only").
fn add_pullback_primary_hessian_hw_only(
    &self,
    target: &mut Array2<f64>,
    row: usize,
    slices: &BlockSlices,
    primary: &PrimarySlices,
    primary_hessian: ArrayView2<'_, f64>,
) {
    let h = primary_hessian;
    if let (Some(primary_h), Some(block_h)) = (primary.h.as_ref(), slices.h.as_ref()) {
        for (local_idx, global_idx) in block_h.clone().enumerate() {
            // Marginal x h cross term, written to both strips for symmetry.
            let h_q = h[[0, primary_h.start + local_idx]];
            if h_q != 0.0 {
                {
                    let mut col = target.slice_mut(s![slices.marginal.clone(), global_idx]);
                    self.marginal_design
                        .axpy_row_into(row, h_q, &mut col)
                        .expect("marginal axpy column mismatch");
                }
                {
                    let mut row_view =
                        target.slice_mut(s![global_idx, slices.marginal.clone()]);
                    self.marginal_design
                        .axpy_row_into(row, h_q, &mut row_view)
                        .expect("marginal axpy row mismatch");
                }
            }
            // Log-slope x h cross term, mirrored the same way.
            let h_g = h[[1, primary_h.start + local_idx]];
            if h_g != 0.0 {
                {
                    let mut col = target.slice_mut(s![slices.logslope.clone(), global_idx]);
                    self.logslope_design
                        .axpy_row_into(row, h_g, &mut col)
                        .expect("logslope axpy column mismatch");
                }
                {
                    let mut row_view =
                        target.slice_mut(s![global_idx, slices.logslope.clone()]);
                    self.logslope_design
                        .axpy_row_into(row, h_g, &mut row_view)
                        .expect("logslope axpy row mismatch");
                }
            }
        }
        // Dense h x h sub-block added in one shot.
        target
            .slice_mut(s![block_h.clone(), block_h.clone()])
            .scaled_add(
                1.0,
                &h.slice(s![
                    primary_h.start..primary_h.end,
                    primary_h.start..primary_h.end
                ]),
            );
    }
    if let (Some(primary_w), Some(block_w)) = (primary.w.as_ref(), slices.w.as_ref()) {
        for (local_idx, global_idx) in block_w.clone().enumerate() {
            // Marginal x w cross term, mirrored for symmetry.
            let w_q = h[[0, primary_w.start + local_idx]];
            if w_q != 0.0 {
                {
                    let mut col = target.slice_mut(s![slices.marginal.clone(), global_idx]);
                    self.marginal_design
                        .axpy_row_into(row, w_q, &mut col)
                        .expect("marginal axpy column mismatch");
                }
                {
                    let mut row_view =
                        target.slice_mut(s![global_idx, slices.marginal.clone()]);
                    self.marginal_design
                        .axpy_row_into(row, w_q, &mut row_view)
                        .expect("marginal axpy row mismatch");
                }
            }
            // Log-slope x w cross term, mirrored for symmetry.
            let w_g = h[[1, primary_w.start + local_idx]];
            if w_g != 0.0 {
                {
                    let mut col = target.slice_mut(s![slices.logslope.clone(), global_idx]);
                    self.logslope_design
                        .axpy_row_into(row, w_g, &mut col)
                        .expect("logslope axpy column mismatch");
                }
                {
                    let mut row_view =
                        target.slice_mut(s![global_idx, slices.logslope.clone()]);
                    self.logslope_design
                        .axpy_row_into(row, w_g, &mut row_view)
                        .expect("logslope axpy row mismatch");
                }
            }
        }
        // h x w and w x h dense cross blocks (only when both blocks exist).
        if let (Some(primary_h), Some(block_h)) = (primary.h.as_ref(), slices.h.as_ref()) {
            target
                .slice_mut(s![block_h.clone(), block_w.clone()])
                .scaled_add(
                    1.0,
                    &h.slice(s![
                        primary_h.start..primary_h.end,
                        primary_w.start..primary_w.end
                    ]),
                );
            target
                .slice_mut(s![block_w.clone(), block_h.clone()])
                .scaled_add(
                    1.0,
                    &h.slice(s![
                        primary_w.start..primary_w.end,
                        primary_h.start..primary_h.end
                    ]),
                );
        }
        // Dense w x w sub-block.
        target
            .slice_mut(s![block_w.clone(), block_w.clone()])
            .scaled_add(
                1.0,
                &h.slice(s![
                    primary_w.start..primary_w.end,
                    primary_w.start..primary_w.end
                ]),
            );
    }
}
/// Builds the dense exact-Newton joint Hessian from cached per-row data.
///
/// Rows are processed in `ROW_CHUNK_SIZE` chunks in parallel. Each chunk
/// collects the three distinct entries of the symmetric 2x2
/// (marginal, log-slope) primary Hessian into per-row weight vectors that
/// feed `add_weighted_design_grams`, while the optional `dense_correction`
/// buffer receives the auxiliary h/w coupling via
/// `add_pullback_primary_hessian_hw_only`.
fn exact_newton_joint_hessian_dense_from_cache(
    &self,
    block_states: &[ParameterBlockState],
    cache: &BernoulliMarginalSlopeExactEvalCache,
) -> Result<Array2<f64>, String> {
    let slices = &cache.slices;
    let primary = &cache.primary;
    let n = self.y.len();
    let started = std::time::Instant::now();
    if log_exact_work(n) {
        log::info!(
            "[BMS dense-H] build start n={} p={} source=cache",
            n,
            slices.total
        );
    }
    // Chunked parallel fold: each rayon split owns its accumulator, so the
    // reduction needs no locking.
    let acc = (0..n.div_ceil(ROW_CHUNK_SIZE))
        .into_par_iter()
        .try_fold(
            || BernoulliBlockHessianAccumulator::new(slices),
            |mut acc, chunk_idx| -> Result<_, String> {
                let start = chunk_idx * ROW_CHUNK_SIZE;
                let end = (start + ROW_CHUNK_SIZE).min(n);
                let chunk_len = end - start;
                // Per-row weights: marginal-marginal, marginal-logslope,
                // logslope-logslope Hessian entries.
                let mut w_mm = Array1::<f64>::zeros(chunk_len);
                let mut w_mg = Array1::<f64>::zeros(chunk_len);
                let mut w_gg = Array1::<f64>::zeros(chunk_len);
                let mut scratch = BernoulliMarginalSlopeFlexRowScratch::new(primary.total);
                for (local, row) in (start..end).enumerate() {
                    // Prefer the cached row Hessian; on a miss, recompute it
                    // into `scratch` (requesting 9 moment columns).
                    let hess_view =
                        if let Some(cached) = Self::cached_row_primary_hessian(cache, row) {
                            cached
                        } else {
                            let row_ctx = Self::row_ctx(cache, row);
                            let row_moments = cache
                                .row_cell_moments
                                .as_ref()
                                .and_then(|bundle| bundle.row(row, 9));
                            self.compute_row_analytic_flex_into_with_moments(
                                row,
                                block_states,
                                primary,
                                row_ctx,
                                row_moments,
                                true,
                                &mut scratch,
                            )?;
                            scratch.hess.view()
                        };
                    w_mm[local] = hess_view[[0, 0]];
                    w_mg[local] = hess_view[[0, 1]];
                    w_gg[local] = hess_view[[1, 1]];
                    // Auxiliary h/w coupling goes through the dense
                    // correction buffer when one is present.
                    if let Some(ref mut dc) = acc.dense_correction {
                        self.add_pullback_primary_hessian_hw_only(
                            dc, row, slices, primary, hess_view,
                        );
                    }
                }
                acc.add_weighted_design_grams(self, start..end, &w_mm, &w_mg, &w_gg)?;
                Ok(acc)
            },
        )
        .try_reduce(
            || BernoulliBlockHessianAccumulator::new(slices),
            |mut left, right| -> Result<_, String> {
                left.add(&right);
                Ok(left)
            },
        )?;
    let dense = acc.to_dense(slices);
    if log_exact_work(n) {
        log::info!(
            "[BMS dense-H] build done n={} p={} source=cache elapsed={:.3}s",
            n,
            slices.total,
            started.elapsed().as_secs_f64()
        );
    }
    Ok(dense)
}
/// Evaluates the weighted log-likelihood using the exact per-row cache.
///
/// Falls back to the standard blockwise evaluation when the flexible part
/// of the model is inactive. Otherwise, per row: the denested-cell
/// coefficient is evaluated at `z[row]`, signed by the +/-1 transform of
/// the 0/1 response, mapped through the signed probit log-CDF, and summed
/// with the row weight in a parallel fold.
fn log_likelihood_from_exact_cache(
    &self,
    block_states: &[ParameterBlockState],
    cache: &BernoulliMarginalSlopeExactEvalCache,
) -> Result<f64, String> {
    if !self.effective_flex_active(block_states)? {
        // Rigid model: the generic blockwise likelihood path suffices.
        return self
            .log_likelihood_only_with_options(block_states, &BlockwiseFitOptions::default());
    }
    let n = self.y.len();
    let started = std::time::Instant::now();
    if log_exact_work(n) {
        log::info!(
            "[BMS exact-loglik] eval start n={} p={} source=cache",
            n,
            cache.slices.total
        );
    }
    // Coefficient vectors shared by all rows in the pass.
    let beta_h = self.score_beta(block_states)?;
    let beta_w = self.link_beta(block_states)?;
    let total: Result<f64, String> = (0..n)
        .into_par_iter()
        .try_fold(
            || 0.0,
            |mut log_likelihood, row| -> Result<_, String> {
                let row_ctx = Self::row_ctx(cache, row);
                let slope = block_states[1].eta[row];
                let obs = self.observed_denested_cell_partials(
                    row,
                    row_ctx.intercept,
                    slope,
                    beta_h,
                    beta_w,
                )?;
                // Evaluate the cell coefficient at this row's covariate and
                // apply the +/-1 sign derived from the 0/1 response.
                let s_i = eval_coeff4_at(&obs.coeff, self.z[row]);
                let signed = (2.0 * self.y[row] - 1.0) * s_i;
                let (log_cdf, _) = signed_probit_logcdf_and_mills_ratio(signed);
                log_likelihood += self.weights[row] * log_cdf;
                Ok(log_likelihood)
            },
        )
        .try_reduce(
            || 0.0,
            |left, right| -> Result<_, String> { Ok(left + right) },
        );
    let log_likelihood = total?;
    if log_exact_work(n) {
        log::info!(
            "[BMS exact-loglik] eval done n={} p={} source=cache elapsed={:.3}s",
            n,
            cache.slices.total,
            started.elapsed().as_secs_f64()
        );
    }
    Ok(log_likelihood)
}
/// Evaluates the log-likelihood and the joint score vector in a single
/// chunked parallel pass over the rows, using the exact per-row cache.
///
/// Each chunk folds into a five-part accumulator: (log-likelihood,
/// marginal-block gradient, log-slope-block gradient, optional h-block
/// gradient, optional w-block gradient). Per row,
/// `compute_row_analytic_flex_into_with_moments` fills `scratch.grad`
/// (gradient only, 3 moment columns); each objective-gradient entry is
/// converted to a score component and either scattered through the design
/// matrices (marginal/log-slope) or added directly (h/w coefficients).
fn exact_newton_joint_gradient_evaluation_from_cache(
    &self,
    block_states: &[ParameterBlockState],
    cache: &BernoulliMarginalSlopeExactEvalCache,
) -> Result<ExactNewtonJointGradientEvaluation, String> {
    let slices = &cache.slices;
    let primary = &cache.primary;
    let n = self.y.len();
    let started = std::time::Instant::now();
    if log_exact_work(n) {
        log::info!(
            "[BMS exact-gradient] eval start n={} p={} source=cache",
            n,
            slices.total
        );
    }
    // Fresh accumulator for each rayon split; the h/w buffers only exist
    // when the corresponding slices do.
    let make_acc = || {
        (
            0.0_f64,
            Array1::<f64>::zeros(slices.marginal.len()),
            Array1::<f64>::zeros(slices.logslope.len()),
            slices
                .h
                .as_ref()
                .map(|range| Array1::<f64>::zeros(range.len())),
            slices
                .w
                .as_ref()
                .map(|range| Array1::<f64>::zeros(range.len())),
        )
    };
    let (log_likelihood, grad_marginal, grad_logslope, grad_h, grad_w) = (0..n
        .div_ceil(ROW_CHUNK_SIZE))
        .into_par_iter()
        .try_fold(make_acc, |mut acc, chunk_idx| -> Result<_, String> {
            let start = chunk_idx * ROW_CHUNK_SIZE;
            let end = (start + ROW_CHUNK_SIZE).min(n);
            let mut scratch = BernoulliMarginalSlopeFlexRowScratch::new(primary.total);
            for row in start..end {
                let row_ctx = Self::row_ctx(cache, row);
                let row_moments = cache
                    .row_cell_moments
                    .as_ref()
                    .and_then(|bundle| bundle.row(row, 3));
                // `false`: no row Hessian needed here, gradient only.
                let neglog = self.compute_row_analytic_flex_into_with_moments(
                    row,
                    block_states,
                    primary,
                    row_ctx,
                    row_moments,
                    false,
                    &mut scratch,
                )?;
                // The row returns a negative log term; subtract to build up
                // the log-likelihood itself.
                acc.0 -= neglog;
                {
                    let mut marginal = acc.1.view_mut();
                    self.marginal_design.axpy_row_into(
                        row,
                        Self::exact_newton_score_component_from_objective_gradient(
                            scratch.grad[0],
                        ),
                        &mut marginal,
                    )?;
                }
                {
                    let mut logslope = acc.2.view_mut();
                    self.logslope_design.axpy_row_into(
                        row,
                        Self::exact_newton_score_component_from_objective_gradient(
                            scratch.grad[1],
                        ),
                        &mut logslope,
                    )?;
                }
                // h/w coefficients map one-to-one, so no design scatter.
                if let (Some(primary_h), Some(grad_h)) = (primary.h.as_ref(), acc.3.as_mut()) {
                    for idx in 0..primary_h.len() {
                        grad_h[idx] +=
                            Self::exact_newton_score_component_from_objective_gradient(
                                scratch.grad[primary_h.start + idx],
                            );
                    }
                }
                if let (Some(primary_w), Some(grad_w)) = (primary.w.as_ref(), acc.4.as_mut()) {
                    for idx in 0..primary_w.len() {
                        grad_w[idx] +=
                            Self::exact_newton_score_component_from_objective_gradient(
                                scratch.grad[primary_w.start + idx],
                            );
                    }
                }
            }
            Ok(acc)
        })
        .try_reduce(make_acc, |mut left, right| -> Result<_, String> {
            left.0 += right.0;
            left.1 += &right.1;
            left.2 += &right.2;
            if let (Some(lhs), Some(rhs)) = (left.3.as_mut(), right.3.as_ref()) {
                *lhs += rhs;
            }
            if let (Some(lhs), Some(rhs)) = (left.4.as_mut(), right.4.as_ref()) {
                *lhs += rhs;
            }
            Ok(left)
        })?;
    // Assemble the flat joint gradient in the block layout described by
    // `slices`.
    let mut gradient = Array1::<f64>::zeros(slices.total);
    gradient
        .slice_mut(s![slices.marginal.clone()])
        .assign(&grad_marginal);
    gradient
        .slice_mut(s![slices.logslope.clone()])
        .assign(&grad_logslope);
    if let (Some(range), Some(grad_h)) = (slices.h.as_ref(), grad_h.as_ref()) {
        gradient.slice_mut(s![range.clone()]).assign(grad_h);
    }
    if let (Some(range), Some(grad_w)) = (slices.w.as_ref(), grad_w.as_ref()) {
        gradient.slice_mut(s![range.clone()]).assign(grad_w);
    }
    if log_exact_work(n) {
        log::info!(
            "[BMS exact-gradient] eval done n={} p={} source=cache elapsed={:.3}s",
            n,
            slices.total,
            started.elapsed().as_secs_f64()
        );
    }
    Ok(ExactNewtonJointGradientEvaluation {
        log_likelihood,
        gradient,
    })
}
/// Applies the exact-Newton joint Hessian to `direction` (matrix-free
/// `H · direction`) using cached per-row context.
///
/// Rigid path (flex inactive): per row, the 2x2 kernel Hessian from
/// `rigid_row_kernel_eval` acts on the projected (marginal, log-slope)
/// components of `direction` and is scattered back through the two design
/// matrices. Flex path: each row's full primary Hessian (cached, or
/// recomputed into scratch) acts on the row's primary direction and the
/// product is pulled back into the joint coordinate layout.
fn exact_newton_joint_hessian_matvec_from_cache(
    &self,
    direction: &Array1<f64>,
    block_states: &[ParameterBlockState],
    cache: &BernoulliMarginalSlopeExactEvalCache,
) -> Result<Array1<f64>, String> {
    let slices = &cache.slices;
    let primary = &cache.primary;
    let n = self.y.len();
    if !self.effective_flex_active(block_states)? {
        // `div_ceil` replaces the manual `(n + C - 1) / C` to match the
        // chunking idiom used by the other cache-based passes.
        let out = (0..n.div_ceil(ROW_CHUNK_SIZE))
            .into_par_iter()
            .try_fold(
                || Array1::<f64>::zeros(slices.total),
                |mut chunk_out, chunk_idx| -> Result<_, String> {
                    let start = chunk_idx * ROW_CHUNK_SIZE;
                    let end = (start + ROW_CHUNK_SIZE).min(n);
                    for row in start..end {
                        let marginal_eta = block_states[0].eta[row];
                        let marginal = self.marginal_link_map(marginal_eta)?;
                        let g = block_states[1].eta[row];
                        let (_, _, h) =
                            self.rigid_row_kernel_eval(row, marginal_eta, marginal, g)?;
                        // Project the direction onto this row's two primary
                        // coordinates, apply the 2x2 Hessian, scatter back.
                        let v_q = self
                            .marginal_design
                            .dot_row_view(row, direction.slice(s![slices.marginal.clone()]));
                        let v_g = self
                            .logslope_design
                            .dot_row_view(row, direction.slice(s![slices.logslope.clone()]));
                        let a_q = h[0][0] * v_q + h[0][1] * v_g;
                        let a_g = h[1][0] * v_q + h[1][1] * v_g;
                        {
                            let mut m = chunk_out.slice_mut(s![slices.marginal.clone()]);
                            self.marginal_design.axpy_row_into(row, a_q, &mut m)?;
                        }
                        {
                            let mut l = chunk_out.slice_mut(s![slices.logslope.clone()]);
                            self.logslope_design.axpy_row_into(row, a_g, &mut l)?;
                        }
                    }
                    Ok(chunk_out)
                },
            )
            .try_reduce(
                || Array1::<f64>::zeros(slices.total),
                |mut left, right| -> Result<_, String> {
                    left += &right;
                    Ok(left)
                },
            )?;
        return Ok(out);
    }
    let out = (0..n.div_ceil(ROW_CHUNK_SIZE))
        .into_par_iter()
        .try_fold(
            || Array1::<f64>::zeros(slices.total),
            |mut chunk_out, chunk_idx| -> Result<_, String> {
                let start = chunk_idx * ROW_CHUNK_SIZE;
                let end = (start + ROW_CHUNK_SIZE).min(n);
                let mut scratch = BernoulliMarginalSlopeFlexRowScratch::new(primary.total);
                for row in start..end {
                    let row_ctx = Self::row_ctx(cache, row);
                    let row_dir =
                        self.row_primary_direction_from_flat(row, slices, primary, direction)?;
                    // Use the cached row Hessian when available; otherwise
                    // recompute into `scratch` (requesting 9 moment columns).
                    let row_action =
                        if let Some(row_hess) = Self::cached_row_primary_hessian(cache, row) {
                            row_hess.dot(&row_dir)
                        } else {
                            let row_moments = cache
                                .row_cell_moments
                                .as_ref()
                                .and_then(|bundle| bundle.row(row, 9));
                            self.compute_row_analytic_flex_into_with_moments(
                                row,
                                block_states,
                                primary,
                                row_ctx,
                                row_moments,
                                true,
                                &mut scratch,
                            )?;
                            scratch.hess.dot(&row_dir)
                        };
                    chunk_out +=
                        &self.pullback_primary_vector(row, slices, primary, &row_action)?;
                }
                Ok(chunk_out)
            },
        )
        .try_reduce(
            || Array1::<f64>::zeros(slices.total),
            |mut left, right| -> Result<_, String> {
                left += &right;
                Ok(left)
            },
        )?;
    Ok(out)
}
/// Single-threaded mirror of `exact_newton_joint_hessian_matvec_from_cache`,
/// kept as a reference implementation for test comparisons.
#[cfg(test)]
fn exact_newton_joint_hessian_matvec_from_cache_serial_reference(
    &self,
    direction: &Array1<f64>,
    block_states: &[ParameterBlockState],
    cache: &BernoulliMarginalSlopeExactEvalCache,
) -> Result<Array1<f64>, String> {
    let slices = &cache.slices;
    let primary = &cache.primary;
    let n_rows = self.y.len();
    let mut result = Array1::<f64>::zeros(slices.total);
    if !self.effective_flex_active(block_states)? {
        // Rigid path: apply the per-row 2x2 kernel Hessian to the projected
        // (marginal, log-slope) direction components and scatter back.
        for row in 0..n_rows {
            let marginal_eta = block_states[0].eta[row];
            let marginal = self.marginal_link_map(marginal_eta)?;
            let g = block_states[1].eta[row];
            let (_, _, h) = self.rigid_row_kernel_eval(row, marginal_eta, marginal, g)?;
            let proj_q = self
                .marginal_design
                .dot_row_view(row, direction.slice(s![slices.marginal.clone()]));
            let proj_g = self
                .logslope_design
                .dot_row_view(row, direction.slice(s![slices.logslope.clone()]));
            let act_q = h[0][0] * proj_q + h[0][1] * proj_g;
            let act_g = h[1][0] * proj_q + h[1][1] * proj_g;
            {
                let mut marginal_slice = result.slice_mut(s![slices.marginal.clone()]);
                self.marginal_design
                    .axpy_row_into(row, act_q, &mut marginal_slice)?;
            }
            {
                let mut logslope_slice = result.slice_mut(s![slices.logslope.clone()]);
                self.logslope_design
                    .axpy_row_into(row, act_g, &mut logslope_slice)?;
            }
        }
        return Ok(result);
    }
    // Flex path: each row's primary Hessian acts on the row direction and
    // the product is pulled back into the joint coordinate layout.
    let mut workspace = BernoulliMarginalSlopeFlexRowScratch::new(primary.total);
    for row in 0..n_rows {
        let row_ctx = Self::row_ctx(cache, row);
        let row_dir = self.row_primary_direction_from_flat(row, slices, primary, direction)?;
        let row_action = match Self::cached_row_primary_hessian(cache, row) {
            Some(row_hess) => row_hess.dot(&row_dir),
            None => {
                let row_moments = cache
                    .row_cell_moments
                    .as_ref()
                    .and_then(|bundle| bundle.row(row, 9));
                self.compute_row_analytic_flex_into_with_moments(
                    row,
                    block_states,
                    primary,
                    row_ctx,
                    row_moments,
                    true,
                    &mut workspace,
                )?;
                workspace.hess.dot(&row_dir)
            }
        };
        result += &self.pullback_primary_vector(row, slices, primary, &row_action)?;
    }
    Ok(result)
}
/// Computes the diagonal of the exact-Newton joint Hessian from cached
/// per-row data.
///
/// Rigid path (flex inactive): per row, the squared design rows scaled by
/// the 2x2 kernel Hessian's diagonal entries are accumulated for the
/// marginal and log-slope blocks. Flex path: each row's full primary
/// Hessian (cached, or recomputed into scratch) contributes its
/// marginal/log-slope diagonal through squared design rows, and its h/w
/// auxiliary diagonal entries are added directly (those coefficients map
/// one-to-one onto joint coordinates).
fn exact_newton_joint_hessian_diagonal_from_cache(
    &self,
    block_states: &[ParameterBlockState],
    cache: &BernoulliMarginalSlopeExactEvalCache,
) -> Result<Array1<f64>, String> {
    let slices = &cache.slices;
    let primary = &cache.primary;
    let n = self.y.len();
    if !self.effective_flex_active(block_states)? {
        // `div_ceil` replaces the manual `(n + C - 1) / C` to match the
        // chunking idiom used by the other cache-based passes.
        let diagonal = (0..n.div_ceil(ROW_CHUNK_SIZE))
            .into_par_iter()
            .try_fold(
                || Array1::<f64>::zeros(slices.total),
                |mut chunk_diag, chunk_idx| -> Result<_, String> {
                    let start = chunk_idx * ROW_CHUNK_SIZE;
                    let end = (start + ROW_CHUNK_SIZE).min(n);
                    for row in start..end {
                        let marginal_eta = block_states[0].eta[row];
                        let marginal = self.marginal_link_map(marginal_eta)?;
                        let g = block_states[1].eta[row];
                        let (_, _, h) =
                            self.rigid_row_kernel_eval(row, marginal_eta, marginal, g)?;
                        {
                            let mut m = chunk_diag.slice_mut(s![slices.marginal.clone()]);
                            self.marginal_design
                                .squared_axpy_row_into(row, h[0][0], &mut m)?;
                        }
                        {
                            let mut l = chunk_diag.slice_mut(s![slices.logslope.clone()]);
                            self.logslope_design
                                .squared_axpy_row_into(row, h[1][1], &mut l)?;
                        }
                    }
                    Ok(chunk_diag)
                },
            )
            .try_reduce(
                || Array1::<f64>::zeros(slices.total),
                |mut left, right| -> Result<_, String> {
                    left += &right;
                    Ok(left)
                },
            )?;
        return Ok(diagonal);
    }
    let diagonal = (0..n.div_ceil(ROW_CHUNK_SIZE))
        .into_par_iter()
        .try_fold(
            || Array1::<f64>::zeros(slices.total),
            |mut chunk_diag, chunk_idx| -> Result<_, String> {
                let start = chunk_idx * ROW_CHUNK_SIZE;
                let end = (start + ROW_CHUNK_SIZE).min(n);
                let mut scratch = BernoulliMarginalSlopeFlexRowScratch::new(primary.total);
                for row in start..end {
                    let row_ctx = Self::row_ctx(cache, row);
                    let cached_hess = Self::cached_row_primary_hessian(cache, row);
                    if cached_hess.is_none() {
                        // Cache miss: recompute this row's primary Hessian
                        // into `scratch` (requesting 9 moment columns).
                        let row_moments = cache
                            .row_cell_moments
                            .as_ref()
                            .and_then(|bundle| bundle.row(row, 9));
                        self.compute_row_analytic_flex_into_with_moments(
                            row,
                            block_states,
                            primary,
                            row_ctx,
                            row_moments,
                            true,
                            &mut scratch,
                        )?;
                    }
                    // Both primary diagonal entries come from the same
                    // source (cache hit, or the freshly filled scratch).
                    let (h00, h11) = if let Some(row_hess) = cached_hess {
                        (row_hess[[0, 0]], row_hess[[1, 1]])
                    } else {
                        (scratch.hess[[0, 0]], scratch.hess[[1, 1]])
                    };
                    {
                        let mut marginal_diag =
                            chunk_diag.slice_mut(s![slices.marginal.clone()]);
                        self.marginal_design.squared_axpy_row_into(
                            row,
                            h00,
                            &mut marginal_diag,
                        )?;
                    }
                    {
                        let mut logslope_diag =
                            chunk_diag.slice_mut(s![slices.logslope.clone()]);
                        self.logslope_design.squared_axpy_row_into(
                            row,
                            h11,
                            &mut logslope_diag,
                        )?;
                    }
                    // Auxiliary h coefficients: diagonal entries add directly.
                    if let (Some(primary_h), Some(block_h)) =
                        (primary.h.as_ref(), slices.h.as_ref())
                    {
                        for (local_idx, global_idx) in block_h.clone().enumerate() {
                            let ii = primary_h.start + local_idx;
                            chunk_diag[global_idx] += if let Some(row_hess) = cached_hess {
                                row_hess[[ii, ii]]
                            } else {
                                scratch.hess[[ii, ii]]
                            };
                        }
                    }
                    // Same for the auxiliary w coefficients.
                    if let (Some(primary_w), Some(block_w)) =
                        (primary.w.as_ref(), slices.w.as_ref())
                    {
                        for (local_idx, global_idx) in block_w.clone().enumerate() {
                            let ii = primary_w.start + local_idx;
                            chunk_diag[global_idx] += if let Some(row_hess) = cached_hess {
                                row_hess[[ii, ii]]
                            } else {
                                scratch.hess[[ii, ii]]
                            };
                        }
                    }
                }
                Ok(chunk_diag)
            },
        )
        .try_reduce(
            || Array1::<f64>::zeros(slices.total),
            |mut left, right| -> Result<_, String> {
                left += &right;
                Ok(left)
            },
        )?;
    Ok(diagonal)
}
/// Convenience wrapper: evaluates the psi terms with default fit options.
fn exact_newton_joint_psi_terms_from_cache(
    &self,
    block_states: &[ParameterBlockState],
    derivative_blocks: &[Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>],
    psi_index: usize,
    cache: &BernoulliMarginalSlopeExactEvalCache,
) -> Result<Option<ExactNewtonJointPsiTerms>, String> {
    let default_options = BlockwiseFitOptions::default();
    self.exact_newton_joint_psi_terms_from_cache_with_options(
        block_states,
        derivative_blocks,
        psi_index,
        cache,
        &default_options,
    )
}
/// Locates the psi derivative addressed by `psi_index`, builds its axis
/// spec, and evaluates the psi terms with a single-axis row pass.
/// Returns `Ok(None)` when the index does not resolve to any derivative.
pub(crate) fn exact_newton_joint_psi_terms_from_cache_with_options(
    &self,
    block_states: &[ParameterBlockState],
    derivative_blocks: &[Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>],
    psi_index: usize,
    cache: &BernoulliMarginalSlopeExactEvalCache,
    options: &BlockwiseFitOptions,
) -> Result<Option<ExactNewtonJointPsiTerms>, String> {
    match psi_derivative_location(derivative_blocks, psi_index) {
        None => Ok(None),
        Some((block_idx, local_idx)) => {
            let axis = self.resolve_psi_axis_spec(derivative_blocks, block_idx, local_idx)?;
            let mut computed =
                self.run_psi_row_pass_for_axes(block_states, cache, options, &[axis])?;
            Ok(Some(computed.remove(0)))
        }
    }
}
/// Resolves a psi axis for `(block_idx, local_idx)`: picks the design
/// width and label for the marginal (block 0) or log-slope (block 1)
/// block, builds the row-level psi map, and records which primary
/// coordinate (0 or 1) the axis acts on. Other block indices are an error.
fn resolve_psi_axis_spec(
    &self,
    derivative_blocks: &[Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>],
    block_idx: usize,
    local_idx: usize,
) -> Result<PsiAxisSpec, String> {
    let n_rows = self.y.len();
    let derivative = &derivative_blocks[block_idx][local_idx];
    let (psi_cols, label) = if block_idx == 0 {
        (
            self.marginal_design.ncols(),
            "BernoulliMarginalSlopeFamily marginal",
        )
    } else if block_idx == 1 {
        (
            self.logslope_design.ncols(),
            "BernoulliMarginalSlopeFamily log-slope",
        )
    } else {
        return Err(format!(
            "bernoulli marginal-slope psi terms only support marginal/logslope blocks, got block {block_idx}"
        ));
    };
    let psi_map = crate::families::custom_family::resolve_custom_family_x_psi_map(
        derivative,
        n_rows,
        psi_cols,
        0..n_rows,
        label,
        &self.policy,
    )?;
    // Primary coordinate index: 0 for the marginal block, 1 otherwise.
    let idx_primary = usize::from(block_idx != 0);
    Ok(PsiAxisSpec {
        block_idx,
        idx_primary,
        psi_map,
    })
}
/// Shared row pass that evaluates exact-Newton psi terms for several psi
/// axes in one sweep over the data.
///
/// Per row, the primary gradient/Hessian pair is computed once (via
/// `compute_row_primary_gradient_hessian_reusing_cache`) and reused for
/// every axis; each axis then contracts it with its own psi direction,
/// third-derivative contraction, and psi design row. Each axis accumulates
/// a triple: objective psi-derivative scalar, psi-score vector, and a
/// block Hessian accumulator that is exposed as an operator
/// (`hessian_psi_operator`) rather than a dense matrix.
fn run_psi_row_pass_for_axes(
    &self,
    block_states: &[ParameterBlockState],
    cache: &BernoulliMarginalSlopeExactEvalCache,
    options: &BlockwiseFitOptions,
    axes: &[PsiAxisSpec],
) -> Result<Vec<ExactNewtonJointPsiTerms>, String> {
    let slices = &cache.slices;
    let primary = &cache.primary;
    let n = self.y.len();
    let k = axes.len();
    if !self.effective_flex_active(block_states)? {
        // NOTE(review): result deliberately discarded — this appears to
        // pre-warm the rigid third-derivative cache before the parallel
        // pass; confirm against `rigid_third_full_cached`.
        let _ = self.rigid_third_full_cached(block_states, cache, 0)?;
    }
    // Row subsample and per-row weights as dictated by the fit options.
    let weighted_rows = outer_weighted_rows(options, n);
    // One (scalar, score, Hessian-accumulator) slot per axis.
    let make_acc = || -> Vec<(f64, Array1<f64>, BernoulliBlockHessianAccumulator)> {
        (0..k)
            .map(|_| {
                (
                    0.0f64,
                    Array1::<f64>::zeros(slices.total),
                    BernoulliBlockHessianAccumulator::new(slices),
                )
            })
            .collect()
    };
    let folded = weighted_rows
        .into_par_iter()
        .try_fold(
            make_acc,
            |mut acc, wr| -> Result<_, String> {
                let row = wr.index;
                let w = wr.weight;
                let row_ctx = Self::row_ctx(cache, row);
                // Computed once per row, shared across all axes below.
                let (f_pi, f_pipi_base) = self
                    .compute_row_primary_gradient_hessian_reusing_cache(
                        row,
                        block_states,
                        primary,
                        row_ctx,
                        cache,
                    )?;
                for (axis_idx, axis) in axes.iter().enumerate() {
                    let dir = self.row_primary_psi_direction_from_map(
                        row,
                        axis.block_idx,
                        &axis.psi_map,
                        block_states,
                        primary,
                    )?;
                    let mut f_pipi = f_pipi_base.clone();
                    let mut third = self.row_primary_third_contracted_recompute(
                        row,
                        block_states,
                        cache,
                        row_ctx,
                        &dir,
                    )?;
                    let psi_row = self.block_psi_row_from_map(
                        row,
                        axis.block_idx,
                        &axis.psi_map,
                        slices,
                    )?;
                    let mut f_pipi_dir = f_pipi.dot(&dir);
                    // Fold the row weight into every per-row quantity before
                    // accumulation (skipped when the weight is exactly 1).
                    if w != 1.0 {
                        f_pipi.mapv_inplace(|v| v * w);
                        third.mapv_inplace(|v| v * w);
                        f_pipi_dir.mapv_inplace(|v| v * w);
                    }
                    let slot = &mut acc[axis_idx];
                    // Scalar objective psi-derivative.
                    slot.0 += w * f_pi.dot(&dir);
                    // Psi-score: direct psi-design contribution plus the
                    // pulled-back Hessian-direction product.
                    slot.1
                        .slice_mut(s![psi_row.range.clone()])
                        .scaled_add(w * f_pi[axis.idx_primary], &psi_row.local_vec);
                    slot.1 +=
                        &self.pullback_primary_vector(row, slices, primary, &f_pipi_dir)?;
                    // Hessian-psi: rank-1 psi/primary cross term plus the
                    // pullback of the contracted third derivative.
                    let right_primary = f_pipi.row(axis.idx_primary).to_owned();
                    slot.2.add_rank1_psi_cross(
                        self,
                        row,
                        slices,
                        primary,
                        axis.block_idx,
                        &psi_row.local_vec,
                        &right_primary,
                    );
                    slot.2.add_pullback(self, row, slices, primary, &third);
                }
                Ok(acc)
            },
        )
        .try_reduce(
            make_acc,
            |mut left, right| -> Result<_, String> {
                for (l, r) in left.iter_mut().zip(right.into_iter()) {
                    l.0 += r.0;
                    l.1 += &r.1;
                    l.2.add(&r.2);
                }
                Ok(left)
            },
        )?;
    let mut out = Vec::with_capacity(k);
    for (objective_psi, score_psi, block_acc) in folded.into_iter() {
        out.push(ExactNewtonJointPsiTerms {
            objective_psi,
            score_psi,
            // Dense matrix left empty; consumers use the operator form.
            hessian_psi: Array2::zeros((0, 0)),
            hessian_psi_operator: Some(std::sync::Arc::new(block_acc.into_operator(slices))),
        });
    }
    Ok(out)
}
/// Convenience wrapper: evaluates the second-order psi terms with default
/// fit options.
fn exact_newton_joint_psisecond_order_terms_from_cache(
    &self,
    block_states: &[ParameterBlockState],
    derivative_blocks: &[Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>],
    psi_i: usize,
    psi_j: usize,
    cache: &BernoulliMarginalSlopeExactEvalCache,
) -> Result<Option<ExactNewtonJointPsiSecondOrderTerms>, String> {
    let default_options = BlockwiseFitOptions::default();
    self.exact_newton_joint_psisecond_order_terms_from_cache_with_options(
        block_states,
        derivative_blocks,
        psi_i,
        psi_j,
        cache,
        &default_options,
    )
}
pub(crate) fn exact_newton_joint_psisecond_order_terms_from_cache_with_options(
&self,
block_states: &[ParameterBlockState],
derivative_blocks: &[Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>],
psi_i: usize,
psi_j: usize,
cache: &BernoulliMarginalSlopeExactEvalCache,
options: &BlockwiseFitOptions,
) -> Result<Option<ExactNewtonJointPsiSecondOrderTerms>, String> {
let slices = &cache.slices;
let primary = &cache.primary;
let Some((block_i, local_i)) = psi_derivative_location(derivative_blocks, psi_i) else {
return Ok(None);
};
let Some((block_j, local_j)) = psi_derivative_location(derivative_blocks, psi_j) else {
return Ok(None);
};
let idx_i = if block_i == 0 { 0 } else { 1 };
let idx_j = if block_j == 0 { 0 } else { 1 };
let n = self.y.len();
let deriv_i = &derivative_blocks[block_i][local_i];
let deriv_j = &derivative_blocks[block_j][local_j];
let (p_psi_i, label_i) = match block_i {
0 => (
self.marginal_design.ncols(),
"BernoulliMarginalSlopeFamily marginal",
),
1 => (
self.logslope_design.ncols(),
"BernoulliMarginalSlopeFamily log-slope",
),
_ => {
return Err(format!(
"bernoulli marginal-slope psi second-order only supports marginal/logslope blocks, got block {block_i}"
));
}
};
let (p_psi_j, label_j) = match block_j {
0 => (
self.marginal_design.ncols(),
"BernoulliMarginalSlopeFamily marginal",
),
1 => (
self.logslope_design.ncols(),
"BernoulliMarginalSlopeFamily log-slope",
),
_ => {
return Err(format!(
"bernoulli marginal-slope psi second-order only supports marginal/logslope blocks, got block {block_j}"
));
}
};
let psi_map_i = crate::families::custom_family::resolve_custom_family_x_psi_map(
deriv_i,
n,
p_psi_i,
0..n,
label_i,
&self.policy,
)?;
let psi_map_j = crate::families::custom_family::resolve_custom_family_x_psi_map(
deriv_j,
n,
p_psi_j,
0..n,
label_j,
&self.policy,
)?;
let psi_map_ij = if block_i == block_j {
Some(
crate::families::custom_family::resolve_custom_family_x_psi_psi_map(
deriv_i,
deriv_j,
local_j,
n,
p_psi_i,
0..n,
label_i,
&self.policy,
)?,
)
} else {
None
};
let weighted_rows = outer_weighted_rows(options, n);
let (objective_psi_psi, score_psi_psi, block_acc) = weighted_rows
.into_par_iter()
.try_fold(
|| {
(
0.0f64,
Array1::<f64>::zeros(slices.total),
BernoulliBlockHessianAccumulator::new(slices),
)
},
|mut acc, wr| -> Result<_, String> {
let row = wr.index;
let w = wr.weight;
{
let dir_i = self.row_primary_psi_direction_from_map(
row,
block_i,
&psi_map_i,
block_states,
primary,
)?;
let dir_j = self.row_primary_psi_direction_from_map(
row,
block_j,
&psi_map_j,
block_states,
primary,
)?;
let dir_ij = self.row_primary_psi_second_direction_from_map(
row,
block_i,
block_j,
psi_map_ij.as_ref(),
block_states,
primary,
)?;
let row_ctx = Self::row_ctx(cache, row);
let (mut f_pi, mut f_pipi) = self
.compute_row_primary_gradient_hessian_reusing_cache(
row,
block_states,
primary,
row_ctx,
cache,
)?;
let mut third_i = self.row_primary_third_contracted_recompute(
row,
block_states,
cache,
row_ctx,
&dir_i,
)?;
let mut third_j = self.row_primary_third_contracted_recompute(
row,
block_states,
cache,
row_ctx,
&dir_j,
)?;
let mut fourth = self.row_primary_fourth_contracted_recompute(
row,
block_states,
cache,
row_ctx,
&dir_i,
&dir_j,
)?;
if w != 1.0 {
f_pi.mapv_inplace(|v| v * w);
f_pipi.mapv_inplace(|v| v * w);
third_i.mapv_inplace(|v| v * w);
third_j.mapv_inplace(|v| v * w);
fourth.mapv_inplace(|v| v * w);
}
let br_i = self.block_psi_row_from_map(row, block_i, &psi_map_i, slices)?;
let br_j = self.block_psi_row_from_map(row, block_j, &psi_map_j, slices)?;
let br_ij = self.block_psi_second_row_from_map(
row,
block_i,
block_j,
psi_map_ij.as_ref(),
slices,
)?;
acc.0 += dir_i.dot(&f_pipi.dot(&dir_j)) + f_pi.dot(&dir_ij);
if let Some(ref bij) = br_ij {
let idx_ij = if bij.block_idx == 0 { 0 } else { 1 };
acc.1
.slice_mut(s![bij.range.clone()])
.scaled_add(f_pi[idx_ij], &bij.local_vec);
}
acc.1
.slice_mut(s![br_i.range.clone()])
.scaled_add(f_pipi.row(idx_i).dot(&dir_j), &br_i.local_vec);
acc.1
.slice_mut(s![br_j.range.clone()])
.scaled_add(f_pipi.row(idx_j).dot(&dir_i), &br_j.local_vec);
acc.1 += &self.pullback_primary_vector(
row,
slices,
primary,
&f_pipi.dot(&dir_ij),
)?;
acc.1 += &self.pullback_primary_vector(
row,
slices,
primary,
&third_i.dot(&dir_j),
)?;
if let Some(ref bij) = br_ij {
let idx_ij = if bij.block_idx == 0 { 0 } else { 1 };
let right_primary_ij = f_pipi.row(idx_ij).to_owned();
acc.2.add_rank1_psi_cross(
self,
row,
slices,
primary,
bij.block_idx,
&bij.local_vec,
&right_primary_ij,
);
}
let scalar_ij = f_pipi[[idx_i, idx_j]];
acc.2.add_psi_psi_outer(
block_i,
&br_i.local_vec,
block_j,
&br_j.local_vec,
scalar_ij,
);
let right_primary_i = third_j.row(idx_i).to_owned();
acc.2.add_rank1_psi_cross(
self,
row,
slices,
primary,
block_i,
&br_i.local_vec,
&right_primary_i,
);
let right_primary_j = third_i.row(idx_j).to_owned();
acc.2.add_rank1_psi_cross(
self,
row,
slices,
primary,
block_j,
&br_j.local_vec,
&right_primary_j,
);
acc.2.add_pullback(self, row, slices, primary, &fourth);
let mut third_ij = self.row_primary_third_contracted_recompute(
row,
block_states,
cache,
row_ctx,
&dir_ij,
)?;
if w != 1.0 {
third_ij.mapv_inplace(|v| v * w);
}
acc.2.add_pullback(self, row, slices, primary, &third_ij);
}
Ok(acc)
},
)
.try_reduce(
|| {
(
0.0f64,
Array1::<f64>::zeros(slices.total),
BernoulliBlockHessianAccumulator::new(slices),
)
},
|mut left, right| -> Result<_, String> {
left.0 += right.0;
left.1 += &right.1;
left.2.add(&right.2);
Ok(left)
},
)?;
Ok(Some(ExactNewtonJointPsiSecondOrderTerms {
objective_psi_psi,
score_psi_psi,
hessian_psi_psi: Array2::zeros((0, 0)),
hessian_psi_psi_operator: Some(Box::new(block_acc.into_operator(slices))),
}))
}
// Convenience wrapper: psi-Hessian directional derivative using default
// `BlockwiseFitOptions` (i.e. the default row set / weighting produced by
// `outer_weighted_rows`). All work is delegated to the `_with_options` variant.
fn exact_newton_joint_psihessian_directional_derivative_from_cache(
    &self,
    block_states: &[ParameterBlockState],
    derivative_blocks: &[Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>],
    psi_index: usize,
    d_beta_flat: &Array1<f64>,
    cache: &BernoulliMarginalSlopeExactEvalCache,
) -> Result<Option<Array2<f64>>, String> {
    self.exact_newton_joint_psihessian_directional_derivative_from_cache_with_options(
        block_states,
        derivative_blocks,
        psi_index,
        d_beta_flat,
        cache,
        &BlockwiseFitOptions::default(),
    )
}
// Directional derivative of the joint coefficient Hessian with respect to the
// hyper-parameter `psi_index`, contracted against the coefficient direction
// `d_beta_flat`, returned as a dense matrix over the flattened coefficient
// layout in `cache.slices`.
//
// Returns `Ok(None)` when `psi_index` does not locate a derivative in
// `derivative_blocks` (i.e. this family has nothing to contribute for that
// hyper-parameter). Only block 0 (marginal) and block 1 (log-slope) are
// supported; any other block index is an error.
//
// Per weighted row, three contributions are accumulated:
//   1. a rank-1 psi-cross term built from the third derivative contracted
//      with the beta direction,
//   2. the pullback of the fourth derivative contracted with both the beta
//      direction and the psi direction,
//   3. the pullback of the third derivative contracted with the psi action
//      applied to the beta direction.
// Rows are processed in parallel (rayon try_fold/try_reduce) and merged into
// a single `BernoulliBlockHessianAccumulator`.
pub(crate) fn exact_newton_joint_psihessian_directional_derivative_from_cache_with_options(
    &self,
    block_states: &[ParameterBlockState],
    derivative_blocks: &[Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>],
    psi_index: usize,
    d_beta_flat: &Array1<f64>,
    cache: &BernoulliMarginalSlopeExactEvalCache,
    options: &BlockwiseFitOptions,
) -> Result<Option<Array2<f64>>, String> {
    let slices = &cache.slices;
    let primary = &cache.primary;
    // Map the global psi index to (block, local-derivative) coordinates.
    let Some((block_idx, local_idx)) = psi_derivative_location(derivative_blocks, psi_index)
    else {
        return Ok(None);
    };
    // Primary-axis index: 0 = marginal row, 1 = log-slope row (2x2 per-row tensors).
    let idx_primary = if block_idx == 0 { 0 } else { 1 };
    let n = self.y.len();
    let deriv = &derivative_blocks[block_idx][local_idx];
    let (p_psi, psi_label) = match block_idx {
        0 => (
            self.marginal_design.ncols(),
            "BernoulliMarginalSlopeFamily marginal",
        ),
        1 => (
            self.logslope_design.ncols(),
            "BernoulliMarginalSlopeFamily log-slope",
        ),
        _ => {
            return Err(format!(
                "bernoulli marginal-slope psi hessian only supports marginal/logslope blocks, got block {block_idx}"
            ));
        }
    };
    // Resolve the per-row map from psi to design-column perturbations.
    let psi_map = crate::families::custom_family::resolve_custom_family_x_psi_map(
        deriv,
        n,
        p_psi,
        0..n,
        psi_label,
        &self.policy,
    )?;
    let weighted_rows = outer_weighted_rows(options, n);
    let block_acc = weighted_rows
        .into_par_iter()
        .try_fold(
            || BernoulliBlockHessianAccumulator::new(slices),
            |mut acc, wr| -> Result<_, String> {
                let row = wr.index;
                let w = wr.weight;
                // Per-row projections: beta direction, psi direction, and the
                // psi action applied to the beta direction.
                let row_dir =
                    self.row_primary_direction_from_flat(row, slices, primary, d_beta_flat)?;
                let psi_dir = self.row_primary_psi_direction_from_map(
                    row,
                    block_idx,
                    &psi_map,
                    block_states,
                    primary,
                )?;
                let psi_action = self.row_primary_psi_action_on_direction_from_map(
                    row,
                    block_idx,
                    &psi_map,
                    slices,
                    d_beta_flat,
                    primary,
                )?;
                let row_ctx = Self::row_ctx(cache, row);
                let mut third_beta = self.row_primary_third_contracted_recompute(
                    row,
                    block_states,
                    cache,
                    row_ctx,
                    &row_dir,
                )?;
                let mut fourth = self.row_primary_fourth_contracted_recompute(
                    row,
                    block_states,
                    cache,
                    row_ctx,
                    &row_dir,
                    &psi_dir,
                )?;
                // Apply the row weight only when it is non-trivial.
                if w != 1.0 {
                    third_beta.mapv_inplace(|v| v * w);
                    fourth.mapv_inplace(|v| v * w);
                }
                let psi_row = self.block_psi_row_from_map(row, block_idx, &psi_map, slices)?;
                let right_primary = third_beta.row(idx_primary).to_owned();
                acc.add_rank1_psi_cross(
                    self,
                    row,
                    slices,
                    primary,
                    psi_row.block_idx,
                    &psi_row.local_vec,
                    &right_primary,
                );
                acc.add_pullback(self, row, slices, primary, &fourth);
                let mut third_action = self.row_primary_third_contracted_recompute(
                    row,
                    block_states,
                    cache,
                    row_ctx,
                    &psi_action,
                )?;
                if w != 1.0 {
                    third_action.mapv_inplace(|v| v * w);
                }
                acc.add_pullback(self, row, slices, primary, &third_action);
                Ok(acc)
            },
        )
        .try_reduce(
            || BernoulliBlockHessianAccumulator::new(slices),
            |mut left, right| -> Result<_, String> {
                left.add(&right);
                Ok(left)
            },
        )?;
    Ok(Some(block_acc.to_dense(slices)))
}
// Operator form of `exact_newton_joint_psihessian_directional_derivative_from_cache_with_options`:
// identical per-row accumulation, but the result is returned as a matrix-free
// `HyperOperator` (via `into_operator`) instead of a dense matrix.
// NOTE(review): the accumulation loop duplicates the dense variant above;
// keep the two in sync when editing either.
pub(crate) fn exact_newton_joint_psihessian_directional_derivative_operator_from_cache_with_options(
    &self,
    block_states: &[ParameterBlockState],
    derivative_blocks: &[Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>],
    psi_index: usize,
    d_beta_flat: &Array1<f64>,
    cache: &BernoulliMarginalSlopeExactEvalCache,
    options: &BlockwiseFitOptions,
) -> Result<Option<Arc<dyn HyperOperator>>, String> {
    let slices = &cache.slices;
    let primary = &cache.primary;
    // Map the global psi index to (block, local-derivative) coordinates;
    // `None` means this family contributes nothing for that hyper-parameter.
    let Some((block_idx, local_idx)) = psi_derivative_location(derivative_blocks, psi_index)
    else {
        return Ok(None);
    };
    // Primary-axis index: 0 = marginal row, 1 = log-slope row.
    let idx_primary = if block_idx == 0 { 0 } else { 1 };
    let n = self.y.len();
    let deriv = &derivative_blocks[block_idx][local_idx];
    let (p_psi, psi_label) = match block_idx {
        0 => (
            self.marginal_design.ncols(),
            "BernoulliMarginalSlopeFamily marginal",
        ),
        1 => (
            self.logslope_design.ncols(),
            "BernoulliMarginalSlopeFamily log-slope",
        ),
        _ => {
            return Err(format!(
                "bernoulli marginal-slope psi hessian operator only supports marginal/logslope blocks, got block {block_idx}"
            ));
        }
    };
    let psi_map = crate::families::custom_family::resolve_custom_family_x_psi_map(
        deriv,
        n,
        p_psi,
        0..n,
        psi_label,
        &self.policy,
    )?;
    let weighted_rows = outer_weighted_rows(options, n);
    let block_acc = weighted_rows
        .into_par_iter()
        .try_fold(
            || BernoulliBlockHessianAccumulator::new(slices),
            |mut acc, wr| -> Result<_, String> {
                let row = wr.index;
                let w = wr.weight;
                // Per-row projections of the beta direction and psi direction.
                let row_dir =
                    self.row_primary_direction_from_flat(row, slices, primary, d_beta_flat)?;
                let psi_dir = self.row_primary_psi_direction_from_map(
                    row,
                    block_idx,
                    &psi_map,
                    block_states,
                    primary,
                )?;
                let psi_action = self.row_primary_psi_action_on_direction_from_map(
                    row,
                    block_idx,
                    &psi_map,
                    slices,
                    d_beta_flat,
                    primary,
                )?;
                let row_ctx = Self::row_ctx(cache, row);
                let mut third_beta = self.row_primary_third_contracted_recompute(
                    row,
                    block_states,
                    cache,
                    row_ctx,
                    &row_dir,
                )?;
                let mut fourth = self.row_primary_fourth_contracted_recompute(
                    row,
                    block_states,
                    cache,
                    row_ctx,
                    &row_dir,
                    &psi_dir,
                )?;
                if w != 1.0 {
                    third_beta.mapv_inplace(|v| v * w);
                    fourth.mapv_inplace(|v| v * w);
                }
                let psi_row = self.block_psi_row_from_map(row, block_idx, &psi_map, slices)?;
                let right_primary = third_beta.row(idx_primary).to_owned();
                acc.add_rank1_psi_cross(
                    self,
                    row,
                    slices,
                    primary,
                    psi_row.block_idx,
                    &psi_row.local_vec,
                    &right_primary,
                );
                acc.add_pullback(self, row, slices, primary, &fourth);
                let mut third_action = self.row_primary_third_contracted_recompute(
                    row,
                    block_states,
                    cache,
                    row_ctx,
                    &psi_action,
                )?;
                if w != 1.0 {
                    third_action.mapv_inplace(|v| v * w);
                }
                acc.add_pullback(self, row, slices, primary, &third_action);
                Ok(acc)
            },
        )
        .try_reduce(
            || BernoulliBlockHessianAccumulator::new(slices),
            |mut left, right| -> Result<_, String> {
                left.add(&right);
                Ok(left)
            },
        )?;
    Ok(Some(
        Arc::new(block_acc.into_operator(slices)) as Arc<dyn HyperOperator>
    ))
}
// Convenience wrapper: joint-Hessian directional derivative using default
// `BlockwiseFitOptions`; delegates to the `_with_options` variant.
fn exact_newton_joint_hessian_directional_derivative_from_cache(
    &self,
    block_states: &[ParameterBlockState],
    d_beta_flat: &Array1<f64>,
    cache: &BernoulliMarginalSlopeExactEvalCache,
) -> Result<Option<Array2<f64>>, String> {
    self.exact_newton_joint_hessian_directional_derivative_from_cache_with_options(
        block_states,
        d_beta_flat,
        cache,
        &BlockwiseFitOptions::default(),
    )
}
// Directional derivative of the joint coefficient Hessian along `d_beta_flat`,
// returned dense over the flattened layout in `cache.slices`.
//
// Two paths:
//   * rigid (flex inactive): per row, project the direction onto the marginal
//     (dq) and log-slope (dg) predictors and contract the closed-form rigid
//     third-derivative tensor;
//   * flex: recompute the contracted per-row third derivative via
//     `row_primary_third_contracted_recompute`.
// Both paths pull the resulting 2x2 per-row tensor back into the block
// accumulator and reduce in parallel.
pub(crate) fn exact_newton_joint_hessian_directional_derivative_from_cache_with_options(
    &self,
    block_states: &[ParameterBlockState],
    d_beta_flat: &Array1<f64>,
    cache: &BernoulliMarginalSlopeExactEvalCache,
    options: &BlockwiseFitOptions,
) -> Result<Option<Array2<f64>>, String> {
    let slices = &cache.slices;
    let primary = &cache.primary;
    let n = self.y.len();
    let weighted_rows = outer_weighted_rows(options, n);
    // Rigid fast path.
    if !self.effective_flex_active(block_states)? {
        let block_acc = weighted_rows
            .into_par_iter()
            .try_fold(
                || BernoulliBlockHessianAccumulator::new(slices),
                |mut acc, wr| -> Result<_, String> {
                    let row = wr.index;
                    let w = wr.weight;
                    let marginal_eta = block_states[0].eta[row];
                    let marginal = self.marginal_link_map(marginal_eta)?;
                    let g = block_states[1].eta[row];
                    // Directional projections onto the two linear predictors.
                    let dq = self
                        .marginal_design
                        .dot_row_view(row, d_beta_flat.slice(s![slices.marginal.clone()]));
                    let dg = self
                        .logslope_design
                        .dot_row_view(row, d_beta_flat.slice(s![slices.logslope.clone()]));
                    let t = self.rigid_row_third_contracted(
                        row,
                        marginal_eta,
                        marginal,
                        g,
                        dq,
                        dg,
                    )?;
                    let mut t_arr = Array2::from_shape_fn((2, 2), |(a, b)| t[a][b]);
                    if w != 1.0 {
                        t_arr.mapv_inplace(|v| v * w);
                    }
                    acc.add_pullback(self, row, slices, primary, &t_arr);
                    Ok(acc)
                },
            )
            .try_reduce(
                || BernoulliBlockHessianAccumulator::new(slices),
                |mut left, right| {
                    left.add(&right);
                    Ok(left)
                },
            )?;
        return Ok(Some(block_acc.to_dense(slices)));
    }
    // Flex path: per-row recomputation of the contracted third derivative.
    let block_acc = weighted_rows
        .into_par_iter()
        .try_fold(
            || BernoulliBlockHessianAccumulator::new(slices),
            |mut acc, wr| -> Result<_, String> {
                let row = wr.index;
                let w = wr.weight;
                let row_dir =
                    self.row_primary_direction_from_flat(row, slices, primary, d_beta_flat)?;
                let row_ctx = Self::row_ctx(cache, row);
                let mut third = self.row_primary_third_contracted_recompute(
                    row,
                    block_states,
                    cache,
                    row_ctx,
                    &row_dir,
                )?;
                if w != 1.0 {
                    third.mapv_inplace(|v| v * w);
                }
                acc.add_pullback(self, row, slices, primary, &third);
                Ok(acc)
            },
        )
        .try_reduce(
            || BernoulliBlockHessianAccumulator::new(slices),
            |mut left, right| -> Result<_, String> {
                left.add(&right);
                Ok(left)
            },
        )?;
    Ok(Some(block_acc.to_dense(slices)))
}
// Operator form of the joint-Hessian directional derivative: identical
// accumulation to the dense variant above, but the result is wrapped as a
// matrix-free `HyperOperator`. NOTE(review): keep the loop bodies in sync
// with the dense variant when editing.
pub(crate) fn exact_newton_joint_hessian_directional_derivative_operator_from_cache_with_options(
    &self,
    block_states: &[ParameterBlockState],
    d_beta_flat: &Array1<f64>,
    cache: &BernoulliMarginalSlopeExactEvalCache,
    options: &BlockwiseFitOptions,
) -> Result<Option<Arc<dyn HyperOperator>>, String> {
    let slices = &cache.slices;
    let primary = &cache.primary;
    let n = self.y.len();
    let weighted_rows = outer_weighted_rows(options, n);
    // Rigid fast path (flex inactive): closed-form third-derivative tensor.
    if !self.effective_flex_active(block_states)? {
        let block_acc = weighted_rows
            .into_par_iter()
            .try_fold(
                || BernoulliBlockHessianAccumulator::new(slices),
                |mut acc, wr| -> Result<_, String> {
                    let row = wr.index;
                    let w = wr.weight;
                    let marginal_eta = block_states[0].eta[row];
                    let marginal = self.marginal_link_map(marginal_eta)?;
                    let g = block_states[1].eta[row];
                    let dq = self
                        .marginal_design
                        .dot_row_view(row, d_beta_flat.slice(s![slices.marginal.clone()]));
                    let dg = self
                        .logslope_design
                        .dot_row_view(row, d_beta_flat.slice(s![slices.logslope.clone()]));
                    let t = self.rigid_row_third_contracted(
                        row,
                        marginal_eta,
                        marginal,
                        g,
                        dq,
                        dg,
                    )?;
                    let mut t_arr = Array2::from_shape_fn((2, 2), |(a, b)| t[a][b]);
                    if w != 1.0 {
                        t_arr.mapv_inplace(|v| v * w);
                    }
                    acc.add_pullback(self, row, slices, primary, &t_arr);
                    Ok(acc)
                },
            )
            .try_reduce(
                || BernoulliBlockHessianAccumulator::new(slices),
                |mut left, right| -> Result<_, String> {
                    left.add(&right);
                    Ok(left)
                },
            )?;
        return Ok(Some(
            Arc::new(block_acc.into_operator(slices)) as Arc<dyn HyperOperator>
        ));
    }
    // Flex path: per-row recomputation.
    let block_acc = weighted_rows
        .into_par_iter()
        .try_fold(
            || BernoulliBlockHessianAccumulator::new(slices),
            |mut acc, wr| -> Result<_, String> {
                let row = wr.index;
                let w = wr.weight;
                let row_dir =
                    self.row_primary_direction_from_flat(row, slices, primary, d_beta_flat)?;
                let row_ctx = Self::row_ctx(cache, row);
                let mut third = self.row_primary_third_contracted_recompute(
                    row,
                    block_states,
                    cache,
                    row_ctx,
                    &row_dir,
                )?;
                if w != 1.0 {
                    third.mapv_inplace(|v| v * w);
                }
                acc.add_pullback(self, row, slices, primary, &third);
                Ok(acc)
            },
        )
        .try_reduce(
            || BernoulliBlockHessianAccumulator::new(slices),
            |mut left, right| -> Result<_, String> {
                left.add(&right);
                Ok(left)
            },
        )?;
    Ok(Some(
        Arc::new(block_acc.into_operator(slices)) as Arc<dyn HyperOperator>
    ))
}
// Batched variant: computes one joint-Hessian directional-derivative operator
// per direction in `d_beta_flats` in a single pass over the rows, sharing the
// per-row context/moment work across directions. Emits coarse progress logs
// (about 8 ticks) because this is one of the heaviest evaluations.
//
// Three paths:
//   * flex inactive: rigid closed-form third-derivative contraction per
//     (row, direction);
//   * flex active + rows are exactly 0..n with unit weights: chunked path that
//     batches per-direction scalar weights (w_mm/w_mg/w_gg) and folds them
//     into weighted design Grams per chunk (`add_weighted_design_grams`),
//     with only the non-Gram part accumulated row-by-row;
//   * flex active, arbitrary weighted rows: straightforward per-row fold.
pub(crate) fn exact_newton_joint_hessian_directional_derivative_operators_from_cache_with_options(
    &self,
    block_states: &[ParameterBlockState],
    d_beta_flats: &[Array1<f64>],
    cache: &BernoulliMarginalSlopeExactEvalCache,
    options: &BlockwiseFitOptions,
) -> Result<Vec<Option<Arc<dyn HyperOperator>>>, String> {
    if d_beta_flats.is_empty() {
        return Ok(Vec::new());
    }
    let slices = &cache.slices;
    let primary = &cache.primary;
    let n = self.y.len();
    let weighted_rows = outer_weighted_rows(options, n);
    // One accumulator per direction.
    let make_accs = || {
        (0..d_beta_flats.len())
            .map(|_| BernoulliBlockHessianAccumulator::new(slices))
            .collect::<Vec<_>>()
    };
    let started = std::time::Instant::now();
    let n_rows = weighted_rows.len();
    let n_dirs = d_beta_flats.len();
    let flex_active = self.effective_flex_active(block_states)?;
    let bundle_present = cache.row_cell_moments.is_some();
    log::info!(
        "[BMS batched dH start] n={} rows={} p={} dirs={} flex={} cell_moments_bundle={}",
        n,
        n_rows,
        slices.total,
        n_dirs,
        flex_active,
        bundle_present,
    );
    // Shared progress counter; logs roughly every n_rows/8 rows.
    let progress = Arc::new(AtomicUsize::new(0));
    let progress_step = (n_rows / 8).max(1);
    let progress_started = started;
    let bump_progress = |progress: &AtomicUsize| {
        let now = progress.fetch_add(1, Ordering::Relaxed) + 1;
        if now == n_rows || now % progress_step == 0 {
            log::info!(
                "[BMS batched dH progress] rows={}/{} dirs={} elapsed={:.3}s",
                now,
                n_rows,
                n_dirs,
                progress_started.elapsed().as_secs_f64(),
            );
        }
    };
    // True when the weighted rows are exactly 0..n with unit weights, which
    // enables the contiguous chunked Gram path below.
    let dense_contiguous_rows = weighted_rows.len() == n
        && weighted_rows
            .iter()
            .enumerate()
            .all(|(row, wr)| wr.index == row && wr.weight == 1.0);
    let mut accs = if !flex_active {
        weighted_rows
            .clone()
            .into_par_iter()
            .try_fold(make_accs, |mut accs, wr| -> Result<_, String> {
                let row = wr.index;
                let w = wr.weight;
                let marginal_eta = block_states[0].eta[row];
                let marginal = self.marginal_link_map(marginal_eta)?;
                let g = block_states[1].eta[row];
                for (idx, d_beta_flat) in d_beta_flats.iter().enumerate() {
                    let dq = self
                        .marginal_design
                        .dot_row_view(row, d_beta_flat.slice(s![slices.marginal.clone()]));
                    let dg = self
                        .logslope_design
                        .dot_row_view(row, d_beta_flat.slice(s![slices.logslope.clone()]));
                    let t = self.rigid_row_third_contracted(
                        row,
                        marginal_eta,
                        marginal,
                        g,
                        dq,
                        dg,
                    )?;
                    let mut t_arr = Array2::from_shape_fn((2, 2), |(a, b)| t[a][b]);
                    if w != 1.0 {
                        t_arr.mapv_inplace(|v| v * w);
                    }
                    accs[idx].add_pullback(self, row, slices, primary, &t_arr);
                }
                bump_progress(&progress);
                Ok(accs)
            })
            .try_reduce(make_accs, |mut left, right| -> Result<_, String> {
                for (l, r) in left.iter_mut().zip(right.iter()) {
                    l.add(r);
                }
                Ok(left)
            })?
    } else if dense_contiguous_rows {
        // Chunk size targets a bounded scratch footprint: 3 weight vectors
        // per direction, clamped to [1024, n].
        let chunk_rows = {
            const TARGET_CHUNK_FLOATS: usize = 1 << 17;
            (TARGET_CHUNK_FLOATS / (3 * d_beta_flats.len()).max(1)).clamp(1024, n.max(1))
        };
        let chunks = (0..n)
            .step_by(chunk_rows)
            .map(|start| (start, (start + chunk_rows).min(n)))
            .collect::<Vec<_>>();
        chunks
            .into_par_iter()
            .map(
                |(start, end)| -> Result<Vec<BernoulliBlockHessianAccumulator>, String> {
                    let n_dirs = d_beta_flats.len();
                    let len = end - start;
                    let mut accs = make_accs();
                    // Per-direction scalar weights for the marginal/marginal,
                    // marginal/logslope and logslope/logslope Gram blocks.
                    let mut w_mm = (0..n_dirs)
                        .map(|_| Array1::<f64>::zeros(len))
                        .collect::<Vec<_>>();
                    let mut w_mg = (0..n_dirs)
                        .map(|_| Array1::<f64>::zeros(len))
                        .collect::<Vec<_>>();
                    let mut w_gg = (0..n_dirs)
                        .map(|_| Array1::<f64>::zeros(len))
                        .collect::<Vec<_>>();
                    for row in start..end {
                        let local = row - start;
                        let row_ctx = Self::row_ctx(cache, row);
                        let row_dirs = d_beta_flats
                            .iter()
                            .map(|d_beta_flat| {
                                self.row_primary_direction_from_flat(
                                    row,
                                    slices,
                                    primary,
                                    d_beta_flat,
                                )
                            })
                            .collect::<Result<Vec<_>, String>>()?;
                        // One shared moment evaluation for all directions.
                        let thirds = self.row_primary_third_contracted_many_with_moments(
                            row,
                            block_states,
                            cache,
                            row_ctx,
                            &row_dirs,
                        )?;
                        for (idx, third) in thirds.iter().enumerate() {
                            w_mm[idx][local] = third[[0, 0]];
                            w_mg[idx][local] = third[[0, 1]];
                            w_gg[idx][local] = third[[1, 1]];
                            accs[idx].add_hw_pullback_only(self, row, slices, primary, third);
                        }
                        bump_progress(&progress);
                    }
                    // Fold the batched scalar weights into design Grams once
                    // per chunk.
                    for idx in 0..n_dirs {
                        accs[idx].add_weighted_design_grams(
                            self,
                            start..end,
                            &w_mm[idx],
                            &w_mg[idx],
                            &w_gg[idx],
                        )?;
                    }
                    Ok(accs)
                },
            )
            .try_reduce(make_accs, |mut left, right| -> Result<_, String> {
                for (l, r) in left.iter_mut().zip(right.iter()) {
                    l.add(r);
                }
                Ok(left)
            })?
    } else {
        // General flex path: arbitrary row subset / weights.
        weighted_rows
            .into_par_iter()
            .try_fold(make_accs, |mut accs, wr| -> Result<_, String> {
                let row = wr.index;
                let w = wr.weight;
                let row_ctx = Self::row_ctx(cache, row);
                let row_dirs = d_beta_flats
                    .iter()
                    .map(|d_beta_flat| {
                        self.row_primary_direction_from_flat(row, slices, primary, d_beta_flat)
                    })
                    .collect::<Result<Vec<_>, String>>()?;
                let mut thirds = self.row_primary_third_contracted_many_with_moments(
                    row,
                    block_states,
                    cache,
                    row_ctx,
                    &row_dirs,
                )?;
                for (idx, third) in thirds.iter_mut().enumerate() {
                    if w != 1.0 {
                        third.mapv_inplace(|v| v * w);
                    }
                    accs[idx].add_pullback(self, row, slices, primary, &third);
                }
                bump_progress(&progress);
                Ok(accs)
            })
            .try_reduce(make_accs, |mut left, right| -> Result<_, String> {
                for (l, r) in left.iter_mut().zip(right.iter()) {
                    l.add(r);
                }
                Ok(left)
            })?
    };
    let elapsed = started.elapsed().as_secs_f64();
    log::info!(
        "[BMS batched dH] n={} rows={} p={} dirs={} elapsed={:.3}s",
        n,
        n_rows,
        slices.total,
        n_dirs,
        elapsed
    );
    Ok(accs
        .drain(..)
        .map(|acc| Some(Arc::new(acc.into_operator(slices)) as Arc<dyn HyperOperator>))
        .collect())
}
// Convenience wrapper: second directional derivative of the joint Hessian
// (directions u and v) using default `BlockwiseFitOptions`.
fn exact_newton_joint_hessiansecond_directional_derivative_from_cache(
    &self,
    block_states: &[ParameterBlockState],
    d_beta_u_flat: &Array1<f64>,
    d_beta_v_flat: &Array1<f64>,
    cache: &BernoulliMarginalSlopeExactEvalCache,
) -> Result<Option<Array2<f64>>, String> {
    self.exact_newton_joint_hessiansecond_directional_derivative_from_cache_with_options(
        block_states,
        d_beta_u_flat,
        d_beta_v_flat,
        cache,
        &BlockwiseFitOptions::default(),
    )
}
// Second directional derivative of the joint coefficient Hessian along the
// two coefficient directions `d_beta_u_flat` and `d_beta_v_flat`, returned
// dense over the flattened layout in `cache.slices`.
//
// Two paths:
//   * rigid (flex inactive): project u and v onto the marginal (q) and
//     log-slope (g) predictors and contract the cached rigid fourth-derivative
//     tensor (`rigid_fourth_full_cached` + `contract_fourth_full`);
//   * flex: recompute the contracted per-row fourth derivative.
// Rows are processed in parallel and reduced into one accumulator.
pub(crate) fn exact_newton_joint_hessiansecond_directional_derivative_from_cache_with_options(
    &self,
    block_states: &[ParameterBlockState],
    d_beta_u_flat: &Array1<f64>,
    d_beta_v_flat: &Array1<f64>,
    cache: &BernoulliMarginalSlopeExactEvalCache,
    options: &BlockwiseFitOptions,
) -> Result<Option<Array2<f64>>, String> {
    let slices = &cache.slices;
    let primary = &cache.primary;
    let n = self.y.len();
    let make_acc = || BernoulliBlockHessianAccumulator::new(slices);
    let weighted_rows = outer_weighted_rows(options, n);
    if !self.effective_flex_active(block_states)? {
        // Warm the rigid fourth-derivative cache on row 0 before entering the
        // parallel loop. (Previously this lived under its own, redundant
        // `effective_flex_active` check; the two guards are merged here so the
        // fallible activity probe runs once.)
        let _ = self.rigid_fourth_full_cached(block_states, cache, 0)?;
        let block_acc = weighted_rows
            .into_par_iter()
            .try_fold(make_acc, |mut acc, wr| -> Result<_, String> {
                let row = wr.index;
                let w = wr.weight;
                // Directional projections of u and v onto the two predictors.
                let uq = self
                    .marginal_design
                    .dot_row_view(row, d_beta_u_flat.slice(s![slices.marginal.clone()]));
                let ug = self
                    .logslope_design
                    .dot_row_view(row, d_beta_u_flat.slice(s![slices.logslope.clone()]));
                let vq = self
                    .marginal_design
                    .dot_row_view(row, d_beta_v_flat.slice(s![slices.marginal.clone()]));
                let vg = self
                    .logslope_design
                    .dot_row_view(row, d_beta_v_flat.slice(s![slices.logslope.clone()]));
                let t = self.rigid_fourth_full_cached(block_states, cache, row)?;
                let f = contract_fourth_full(t, uq, ug, vq, vg);
                let mut f_arr = Array2::from_shape_fn((2, 2), |(a, b)| f[a][b]);
                if w != 1.0 {
                    f_arr.mapv_inplace(|v| v * w);
                }
                acc.add_pullback(self, row, slices, primary, &f_arr);
                Ok(acc)
            })
            .try_reduce(make_acc, |mut left, right| -> Result<_, String> {
                left.add(&right);
                Ok(left)
            })?;
        return Ok(Some(block_acc.to_dense(slices)));
    }
    // Flex path: per-row recomputation of the contracted fourth derivative.
    let block_acc = weighted_rows
        .into_par_iter()
        .try_fold(make_acc, |mut acc, wr| -> Result<_, String> {
            let row = wr.index;
            let w = wr.weight;
            let row_u =
                self.row_primary_direction_from_flat(row, slices, primary, d_beta_u_flat)?;
            let row_v =
                self.row_primary_direction_from_flat(row, slices, primary, d_beta_v_flat)?;
            let row_ctx = Self::row_ctx(cache, row);
            let mut fourth = self.row_primary_fourth_contracted_recompute(
                row,
                block_states,
                cache,
                row_ctx,
                &row_u,
                &row_v,
            )?;
            if w != 1.0 {
                fourth.mapv_inplace(|v| v * w);
            }
            acc.add_pullback(self, row, slices, primary, &fourth);
            Ok(acc)
        })
        .try_reduce(make_acc, |mut left, right| -> Result<_, String> {
            left.add(&right);
            Ok(left)
        })?;
    Ok(Some(block_acc.to_dense(slices)))
}
// Operator form of the second directional derivative of the joint Hessian:
// identical accumulation to the dense variant, but the result is wrapped as a
// matrix-free `HyperOperator`.
pub(crate) fn exact_newton_joint_hessiansecond_directional_derivative_operator_from_cache_with_options(
    &self,
    block_states: &[ParameterBlockState],
    d_beta_u_flat: &Array1<f64>,
    d_beta_v_flat: &Array1<f64>,
    cache: &BernoulliMarginalSlopeExactEvalCache,
    options: &BlockwiseFitOptions,
) -> Result<Option<Arc<dyn HyperOperator>>, String> {
    let slices = &cache.slices;
    let primary = &cache.primary;
    let n = self.y.len();
    let make_acc = || BernoulliBlockHessianAccumulator::new(slices);
    let weighted_rows = outer_weighted_rows(options, n);
    if !self.effective_flex_active(block_states)? {
        // Warm the rigid fourth-derivative cache on row 0 before the parallel
        // loop. (Previously guarded by its own, redundant
        // `effective_flex_active` check; merged into this single guard.)
        let _ = self.rigid_fourth_full_cached(block_states, cache, 0)?;
        let block_acc = weighted_rows
            .into_par_iter()
            .try_fold(make_acc, |mut acc, wr| -> Result<_, String> {
                let row = wr.index;
                let w = wr.weight;
                // Directional projections of u and v onto the two predictors.
                let uq = self
                    .marginal_design
                    .dot_row_view(row, d_beta_u_flat.slice(s![slices.marginal.clone()]));
                let ug = self
                    .logslope_design
                    .dot_row_view(row, d_beta_u_flat.slice(s![slices.logslope.clone()]));
                let vq = self
                    .marginal_design
                    .dot_row_view(row, d_beta_v_flat.slice(s![slices.marginal.clone()]));
                let vg = self
                    .logslope_design
                    .dot_row_view(row, d_beta_v_flat.slice(s![slices.logslope.clone()]));
                let t = self.rigid_fourth_full_cached(block_states, cache, row)?;
                let f = contract_fourth_full(t, uq, ug, vq, vg);
                let mut f_arr = Array2::from_shape_fn((2, 2), |(a, b)| f[a][b]);
                if w != 1.0 {
                    f_arr.mapv_inplace(|v| v * w);
                }
                acc.add_pullback(self, row, slices, primary, &f_arr);
                Ok(acc)
            })
            .try_reduce(make_acc, |mut left, right| -> Result<_, String> {
                left.add(&right);
                Ok(left)
            })?;
        return Ok(Some(
            Arc::new(block_acc.into_operator(slices)) as Arc<dyn HyperOperator>
        ));
    }
    // Flex path: per-row recomputation of the contracted fourth derivative.
    let block_acc = weighted_rows
        .into_par_iter()
        .try_fold(make_acc, |mut acc, wr| -> Result<_, String> {
            let row = wr.index;
            let w = wr.weight;
            let row_u =
                self.row_primary_direction_from_flat(row, slices, primary, d_beta_u_flat)?;
            let row_v =
                self.row_primary_direction_from_flat(row, slices, primary, d_beta_v_flat)?;
            let row_ctx = Self::row_ctx(cache, row);
            let mut fourth = self.row_primary_fourth_contracted_recompute(
                row,
                block_states,
                cache,
                row_ctx,
                &row_u,
                &row_v,
            )?;
            if w != 1.0 {
                fourth.mapv_inplace(|v| v * w);
            }
            acc.add_pullback(self, row, slices, primary, &fourth);
            Ok(acc)
        })
        .try_reduce(make_acc, |mut left, right| -> Result<_, String> {
            left.add(&right);
            Ok(left)
        })?;
    Ok(Some(
        Arc::new(block_acc.into_operator(slices)) as Arc<dyn HyperOperator>
    ))
}
// Flex-mode family evaluation producing block-diagonal exact-Newton working
// sets: log-likelihood plus per-block (marginal, log-slope, and optional
// h / w blocks) gradients and dense Hessian diagonals. Rows are processed in
// chunks of ROW_CHUNK_SIZE in parallel, each chunk reusing one scratch buffer.
fn evaluate_flex_block_diagonals_from_cache(
    &self,
    block_states: &[ParameterBlockState],
    slices: &BlockSlices,
    cache: &BernoulliMarginalSlopeExactEvalCache,
) -> Result<FamilyEvaluation, String> {
    let primary = cache.primary.clone();
    let n = self.y.len();
    let n_chunks = n.div_ceil(ROW_CHUNK_SIZE);
    let reduced = (0..n_chunks)
        .into_par_iter()
        .try_fold(
            || BernoulliExactNewtonAccumulator::new(slices),
            |mut acc, chunk_idx| -> Result<_, String> {
                let start = chunk_idx * ROW_CHUNK_SIZE;
                let end = (start + ROW_CHUNK_SIZE).min(n);
                // One scratch allocation per chunk, reused across its rows.
                let mut scratch = BernoulliMarginalSlopeFlexRowScratch::new(primary.total);
                for row in start..end {
                    let row_ctx = Self::row_ctx(cache, row);
                    // Optional precomputed per-row cell moments (9 entries
                    // requested here).
                    let row_moments = cache
                        .row_cell_moments
                        .as_ref()
                        .and_then(|bundle| bundle.row(row, 9));
                    let row_neglog = self.compute_row_analytic_flex_into_with_moments(
                        row,
                        block_states,
                        &primary,
                        row_ctx,
                        row_moments,
                        true,
                        &mut scratch,
                    )?;
                    acc.add_pullback_block_diagonals(
                        self, row, &primary, row_neglog, &scratch,
                    )?;
                }
                Ok(acc)
            },
        )
        .try_reduce(
            || BernoulliExactNewtonAccumulator::new(slices),
            |mut left, right| -> Result<_, String> {
                left.add(&right);
                Ok(left)
            },
        )?;
    // Unpack the reduced accumulator into the per-block working sets.
    let BernoulliExactNewtonAccumulator {
        ll,
        grad_marginal,
        grad_logslope,
        hess_marginal,
        hess_logslope,
        grad_h,
        grad_w,
        hess_h,
        hess_w,
    } = reduced;
    let mut blockworking_sets = vec![
        BlockWorkingSet::ExactNewton {
            gradient: grad_marginal,
            hessian: SymmetricMatrix::Dense(hess_marginal),
        },
        BlockWorkingSet::ExactNewton {
            gradient: grad_logslope,
            hessian: SymmetricMatrix::Dense(hess_logslope),
        },
    ];
    // h / w blocks are appended only when both gradient and Hessian exist.
    if let (Some(gradient), Some(hessian)) = (grad_h, hess_h) {
        blockworking_sets.push(BlockWorkingSet::ExactNewton {
            gradient,
            hessian: SymmetricMatrix::Dense(hessian),
        });
    }
    if let (Some(gradient), Some(hessian)) = (grad_w, hess_w) {
        blockworking_sets.push(BlockWorkingSet::ExactNewton {
            gradient,
            hessian: SymmetricMatrix::Dense(hessian),
        });
    }
    Ok(FamilyEvaluation {
        log_likelihood: ll,
        blockworking_sets,
    })
}
fn evaluate_blockwise_exact_newton(
&self,
block_states: &[ParameterBlockState],
) -> Result<FamilyEvaluation, String> {
let slices = block_slices(self);
let flex_active = self.effective_flex_active(block_states)?;
if !flex_active && slices.total < 512 {
let kern = BernoulliRigidRowKernel::new(self.clone(), block_states.to_vec());
let cache = build_row_kernel_cache(&kern, &crate::families::row_kernel::RowSet::All)?;
let ll = row_kernel_log_likelihood(&cache, &crate::families::row_kernel::RowSet::All);
let joint_gradient = Self::exact_newton_score_from_objective_gradient(
row_kernel_gradient(&kern, &cache, &crate::families::row_kernel::RowSet::All),
);
let n = cache.n;
let p_marginal = slices.marginal.len();
let p_logslope = slices.logslope.len();
let make_pair = || {
(
Array2::<f64>::zeros((p_marginal, p_marginal)),
Array2::<f64>::zeros((p_logslope, p_logslope)),
)
};
let (hess_marginal, hess_logslope) = (0..((n + ROW_CHUNK_SIZE - 1) / ROW_CHUNK_SIZE))
.into_par_iter()
.try_fold(
make_pair,
|(mut hm, mut hl), chunk_idx| -> Result<(Array2<f64>, Array2<f64>), String> {
let start = chunk_idx * ROW_CHUNK_SIZE;
let end = (start + ROW_CHUNK_SIZE).min(n);
let rows = end - start;
let marginal_chunk = self
.marginal_design
.try_row_chunk(start..end)
.map_err(|e| format!("bernoulli marginal_design try_row_chunk: {e}"))?;
let logslope_chunk = self
.logslope_design
.try_row_chunk(start..end)
.map_err(|e| format!("bernoulli logslope_design try_row_chunk: {e}"))?;
let mut hm_w = Array1::<f64>::zeros(rows);
let mut hl_w = Array1::<f64>::zeros(rows);
for local_row in 0..rows {
let h = &cache.hessians[start + local_row];
hm_w[local_row] = h[0][0];
hl_w[local_row] = h[1][1];
}
add_weighted_chunk_gram(&marginal_chunk, &hm_w, &mut hm);
add_weighted_chunk_gram(&logslope_chunk, &hl_w, &mut hl);
Ok((hm, hl))
},
)
.try_reduce(
make_pair,
|(mut lhm, mut lhl),
(rhm, rhl)|
-> Result<(Array2<f64>, Array2<f64>), String> {
lhm += &rhm;
lhl += &rhl;
Ok((lhm, lhl))
},
)?;
let hess_marginal =
Self::exact_newton_observed_information_from_objective_hessian(hess_marginal);
let hess_logslope =
Self::exact_newton_observed_information_from_objective_hessian(hess_logslope);
let grad_marginal = joint_gradient.slice(s![slices.marginal.clone()]).to_owned();
let grad_logslope = joint_gradient.slice(s![slices.logslope.clone()]).to_owned();
let mut sets = vec![
BlockWorkingSet::ExactNewton {
gradient: grad_marginal,
hessian: SymmetricMatrix::Dense(hess_marginal),
},
BlockWorkingSet::ExactNewton {
gradient: grad_logslope,
hessian: SymmetricMatrix::Dense(hess_logslope),
},
];
if let Some(range) = slices.h.as_ref() {
sets.push(BlockWorkingSet::ExactNewton {
gradient: Array1::zeros(range.len()),
hessian: SymmetricMatrix::Dense(Array2::zeros((range.len(), range.len()))),
});
}
if let Some(range) = slices.w.as_ref() {
sets.push(BlockWorkingSet::ExactNewton {
gradient: Array1::zeros(range.len()),
hessian: SymmetricMatrix::Dense(Array2::zeros((range.len(), range.len()))),
});
}
return Ok(FamilyEvaluation {
log_likelihood: ll,
blockworking_sets: sets,
});
}
if flex_active {
let cache = self.build_exact_eval_cache_with_order(block_states)?;
return self.evaluate_flex_block_diagonals_from_cache(block_states, &slices, &cache);
}
let n = self.y.len();
let p_marginal = slices.marginal.len();
let p_logslope = slices.logslope.len();
let make_acc = || {
(
0.0_f64,
Array1::<f64>::zeros(p_marginal),
Array1::<f64>::zeros(p_logslope),
Array2::<f64>::zeros((p_marginal, p_marginal)),
Array2::<f64>::zeros((p_logslope, p_logslope)),
)
};
let (ll, grad_marginal, grad_logslope, hess_marginal, hess_logslope) =
(0..((n + ROW_CHUNK_SIZE - 1) / ROW_CHUNK_SIZE))
.into_par_iter()
.try_fold(
make_acc,
|(mut ll, mut gm, mut gl, mut hm, mut hl), chunk_idx| -> Result<_, String> {
let start = chunk_idx * ROW_CHUNK_SIZE;
let end = (start + ROW_CHUNK_SIZE).min(n);
let rows = end - start;
let marginal_chunk = self
.marginal_design
.try_row_chunk(start..end)
.map_err(|e| format!("bernoulli marginal_design try_row_chunk: {e}"))?;
let logslope_chunk = self
.logslope_design
.try_row_chunk(start..end)
.map_err(|e| format!("bernoulli logslope_design try_row_chunk: {e}"))?;
let mut gm_w = Array1::<f64>::zeros(rows);
let mut gl_w = Array1::<f64>::zeros(rows);
let mut hm_w = Array1::<f64>::zeros(rows);
let mut hl_w = Array1::<f64>::zeros(rows);
for local_row in 0..rows {
let row = start + local_row;
let marginal_eta = block_states[0].eta[row];
let marginal = self.marginal_link_map(marginal_eta)?;
let g = block_states[1].eta[row];
let (neglog, grad, h) =
self.rigid_row_kernel_eval(row, marginal_eta, marginal, g)?;
ll -= neglog;
gm_w[local_row] =
Self::exact_newton_score_component_from_objective_gradient(grad[0]);
gl_w[local_row] =
Self::exact_newton_score_component_from_objective_gradient(grad[1]);
hm_w[local_row] = h[0][0];
hl_w[local_row] = h[1][1];
}
add_weighted_chunk_gradient(&marginal_chunk, &gm_w, &mut gm);
add_weighted_chunk_gradient(&logslope_chunk, &gl_w, &mut gl);
add_weighted_chunk_gram(&marginal_chunk, &hm_w, &mut hm);
add_weighted_chunk_gram(&logslope_chunk, &hl_w, &mut hl);
Ok((ll, gm, gl, hm, hl))
},
)
.try_reduce(
make_acc,
|(lll, mut lgm, mut lgl, mut lhm, mut lhl),
(rll, rgm, rgl, rhm, rhl)|
-> Result<_, String> {
lgm += &rgm;
lgl += &rgl;
lhm += &rhm;
lhl += &rhl;
Ok((lll + rll, lgm, lgl, lhm, lhl))
},
)?;
Ok(FamilyEvaluation {
log_likelihood: ll,
blockworking_sets: {
let mut sets = vec![
BlockWorkingSet::ExactNewton {
gradient: grad_marginal,
hessian: SymmetricMatrix::Dense(hess_marginal),
},
BlockWorkingSet::ExactNewton {
gradient: grad_logslope,
hessian: SymmetricMatrix::Dense(hess_logslope),
},
];
if let Some(range) = slices.h.as_ref() {
sets.push(BlockWorkingSet::ExactNewton {
gradient: Array1::zeros(range.len()),
hessian: SymmetricMatrix::Dense(Array2::zeros((range.len(), range.len()))),
});
}
if let Some(range) = slices.w.as_ref() {
sets.push(BlockWorkingSet::ExactNewton {
gradient: Array1::zeros(range.len()),
hessian: SymmetricMatrix::Dense(Array2::zeros((range.len(), range.len()))),
});
}
sets
},
})
}
}
impl CustomFamily for BernoulliMarginalSlopeFamily {
    /// Reports that this family's joint Hessian varies with the coefficient
    /// vector, so callers must rebuild it after every beta update.
    fn exact_newton_joint_hessian_beta_dependent(&self) -> bool {
        true
    }
    /// Log-determinants of possibly singular matrices use the hard
    /// pseudo-determinant convention for this family.
    fn pseudo_logdet_mode(&self) -> crate::custom_family::PseudoLogdetMode {
        crate::custom_family::PseudoLogdetMode::HardPseudo
    }
    /// Relative cost estimate for one coefficient-Hessian evaluation:
    /// `n * p_total` on the matrix-free path, otherwise the shared coupled
    /// dense-cost model.
    fn coefficient_hessian_cost(&self, specs: &[ParameterBlockSpec]) -> u64 {
        let n = self.y.len() as u64;
        // Saturating sum guards against overflow for very wide joint designs.
        let p_total: u64 = specs
            .iter()
            .map(|s| s.design.ncols() as u64)
            .fold(0u64, |a, p| a.saturating_add(p));
        if crate::custom_family::use_joint_matrix_free_path(p_total as usize, n as usize) {
            n.saturating_mul(p_total)
        } else {
            crate::custom_family::joint_coupled_coefficient_hessian_cost(n, specs)
        }
    }
    /// Highest derivative order of the outer criterion this family can supply
    /// exactly, given the outer-hyper Hessian machinery that is available.
    fn exact_outer_derivative_order(
        &self,
        specs: &[ParameterBlockSpec],
        _: &BlockwiseFitOptions,
    ) -> crate::custom_family::ExactOuterDerivativeOrder {
        use crate::custom_family::ExactOuterDerivativeOrder;
        // Dominant per-evaluation cost: the pricier of gradient vs Hessian.
        let coefficient_work = self
            .coefficient_hessian_cost(specs)
            .max(self.coefficient_gradient_cost(specs));
        if !self.outer_hyper_hessian_dense_available(specs)
            && !self.outer_hyper_hessian_hvp_available(specs)
        {
            // No second-order outer machinery at all: only first-order
            // derivatives can be exact.
            return ExactOuterDerivativeOrder::First;
        }
        crate::custom_family::exact_outer_order_with_outer_hvp(
            specs,
            coefficient_work,
            self.outer_hyper_hessian_hvp_available(specs),
        )
    }
    /// Predicts relative outer gradient/Hessian work so the outer optimizer
    /// can choose a derivative strategy. Work scales with n, the total outer
    /// dimension k (rho + psi), and a per-row factor that depends on whether
    /// the flexible (score-warp/link-deviation) path is active.
    fn outer_derivative_policy(
        &self,
        specs: &[ParameterBlockSpec],
        psi_dim: usize,
        options: &BlockwiseFitOptions,
    ) -> crate::custom_family::OuterDerivativePolicy {
        use crate::custom_family::OuterDerivativePolicy;
        let capability = self.exact_outer_derivative_order(specs, options);
        let n = self.y.len() as u128;
        let rho_dim_from_specs = specs
            .iter()
            .map(|spec| spec.penalties.len() as u128)
            .sum::<u128>();
        // Total outer dimension: smoothing parameters plus auxiliary psi.
        let k = rho_dim_from_specs.saturating_add(psi_dim as u128).max(1);
        let predicted_hessian_work = if self.flex_active() {
            // Flexible path: quadratic in the primary dimension
            // (two etas plus any active deviation basis dimensions).
            let primary_total = 2u128
                + self
                    .score_warp
                    .as_ref()
                    .map(|runtime| runtime.basis_dim() as u128)
                    .unwrap_or(0)
                + self
                    .link_dev
                    .as_ref()
                    .map(|runtime| runtime.basis_dim() as u128)
                    .unwrap_or(0);
            n.saturating_mul(k)
                .saturating_mul(primary_total.saturating_mul(primary_total))
        } else {
            // Rigid path: roughly linear in the total coefficient count.
            let p_total = specs
                .iter()
                .map(|spec| spec.design.ncols() as u128)
                .sum::<u128>();
            n.saturating_mul(k)
                .saturating_mul(p_total.saturating_add(1))
        };
        // Heuristic: a gradient pass costs about half a Hessian pass.
        let predicted_gradient_work = predicted_hessian_work / 2;
        OuterDerivativePolicy {
            capability,
            predicted_gradient_work,
            predicted_hessian_work,
        }
    }
fn outer_seed_config(&self, n_params: usize) -> crate::seeding::SeedConfig {
let mut config = crate::seeding::SeedConfig::default();
if n_params == 0 {
return config;
}
config.max_seeds = 1;
config.seed_budget = 1;
config.screen_max_inner_iterations = 2;
config
}
    /// Opt in to routing first-order psi terms through the joint psi
    /// workspace as well (not only second-order terms).
    fn exact_newton_joint_psi_workspace_for_first_order_terms(&self) -> bool {
        true
    }
    /// Batched first-order outer-gradient terms (penalty objective
    /// derivative, log|H| trace term and penalty log-det trace term per rho
    /// coordinate) for the pure-rho case. Returns `Ok(None)` whenever a
    /// precondition fails, signalling that the generic unbatched path must be
    /// used instead; only returns an error for genuine internal
    /// inconsistencies.
    fn batched_outer_gradient_terms(
        &self,
        block_states: &[ParameterBlockState],
        specs: &[ParameterBlockSpec],
        derivative_blocks: &[Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>],
        rho: &Array1<f64>,
        options: &BlockwiseFitOptions,
        hessian_workspace: Option<Arc<dyn ExactNewtonJointHessianWorkspace>>,
    ) -> Result<Option<BatchedOuterGradientTerms>, String> {
        // This fast path only handles the case with no auxiliary psi
        // parameters.
        let psi_dim: usize = derivative_blocks.iter().map(Vec::len).sum();
        if psi_dim != 0 {
            return Ok(None);
        }
        if block_states.len() != specs.len() {
            return Ok(None);
        }
        // Binary outcome labels used to stratify auto-subsample draws.
        let stratum_secondary: Vec<u8> = self
            .y
            .iter()
            .map(|v| if *v > 0.5 { 1u8 } else { 0u8 })
            .collect();
        // Optionally install auto-configured outer subsampling. The
        // `owned_options` binding keeps the cloned options alive so the
        // rebound `options` reference stays valid.
        let owned_options;
        let options: &BlockwiseFitOptions =
            match crate::families::marginal_slope_shared::maybe_install_auto_outer_subsample(
                options,
                self.z.as_slice().expect("z must be contiguous"),
                Some(&stratum_secondary),
                rho,
                &self.auto_subsample_phase_counter,
                &self.auto_subsample_last_rho,
                BMS_AUTO_SUBSAMPLE_PHASE1_BUDGET,
                "BMS",
            ) {
                Some(cloned) => {
                    owned_options = cloned;
                    &owned_options
                }
                None => options,
            };
        let ranges = Self::block_ranges_from_specs(specs);
        let total = ranges.last().map(|(_, end)| *end).unwrap_or(0);
        if total == 0 {
            // Degenerate model with no coefficients: all terms are empty.
            return Ok(Some(BatchedOuterGradientTerms {
                objective_theta: Array1::zeros(0),
                trace_h_inv_hdot: Array1::zeros(0),
                trace_s_pinv_sdot: Array1::zeros(0),
            }));
        }
        if rho.len() != specs.iter().map(|spec| spec.penalties.len()).sum::<usize>() {
            return Ok(None);
        }
        // The dense spectral machinery below is only used for small joint
        // dimensions; larger problems fall back to the generic path.
        if total >= 512 {
            return Ok(None);
        }
        let beta = Self::flatten_block_state_betas_for_specs(block_states, specs)?;
        // Start from the (unpenalized) dense exact joint Hessian, preferring
        // the caller-supplied workspace to avoid rebuilding caches.
        let mut h = if let Some(workspace) = hessian_workspace.as_ref() {
            workspace.hessian_dense()?.ok_or_else(|| {
                "bernoulli marginal-slope batched gradient requires dense exact joint Hessian below p=512"
                    .to_string()
            })?
        } else {
            self.exact_newton_joint_hessian(block_states)?.ok_or_else(|| {
                "bernoulli marginal-slope batched gradient could not build dense exact joint Hessian"
                    .to_string()
            })?
        };
        if h.nrows() != total || h.ncols() != total {
            return Err(format!(
                "bernoulli marginal-slope batched gradient Hessian shape {}x{} != {total}x{total}",
                h.nrows(),
                h.ncols()
            ));
        }
        let ridge = options.ridge_floor.max(1e-15);
        // Diagonal ridge added before the trace computations, depending on
        // which ridge contributions the configured policy includes.
        let trace_diagonal_ridge = if options.ridge_policy.include_quadratic_penalty
            || options.ridge_policy.include_penalty_logdet
        {
            ridge
        } else {
            0.0
        };
        let mut objective_theta = Array1::<f64>::zeros(rho.len());
        let mut trace_s_pinv_sdot = Array1::<f64>::zeros(rho.len());
        let mut penalty_cursor = 0usize;
        let mut per_block_rho: Vec<Array1<f64>> = Vec::with_capacity(specs.len());
        let mut penalties_dense: Vec<Vec<Array2<f64>>> = Vec::with_capacity(specs.len());
        // Per block: record 0.5 * lambda * beta' S beta for each penalty and
        // add the block's S_lambda into the matching diagonal block of H.
        for (block_idx, spec) in specs.iter().enumerate() {
            let count = spec.penalties.len();
            let block_rho = rho
                .slice(s![penalty_cursor..penalty_cursor + count])
                .to_owned();
            let lambdas = block_rho.mapv(f64::exp);
            per_block_rho.push(block_rho);
            let (start, end) = ranges[block_idx];
            let beta_block = beta.slice(s![start..end]);
            let mut s_lambda = Array2::<f64>::zeros((end - start, end - start));
            let mut block_penalties = Vec::with_capacity(count);
            for (local_idx, penalty) in spec.penalties.iter().enumerate() {
                let dense = penalty.to_dense();
                let lambda = lambdas[local_idx];
                let s_beta = dense.dot(&beta_block);
                objective_theta[penalty_cursor + local_idx] =
                    0.5 * lambda * beta_block.dot(&s_beta);
                s_lambda.scaled_add(lambda, &dense);
                block_penalties.push(dense);
            }
            h.slice_mut(s![start..end, start..end])
                .scaled_add(1.0, &s_lambda);
            penalties_dense.push(block_penalties);
            penalty_cursor += count;
        }
        if trace_diagonal_ridge != 0.0 {
            for diag in 0..total {
                h[[diag, diag]] += trace_diagonal_ridge;
            }
        }
        let penalty_logdet_ridge = if options.ridge_policy.include_penalty_logdet {
            ridge
        } else {
            0.0
        };
        let per_block_penalty_refs: Vec<&[Array2<f64>]> =
            penalties_dense.iter().map(Vec::as_slice).collect();
        let per_block_nullspace_dims: Vec<&[usize]> = specs
            .iter()
            .map(|spec| spec.nullspace_dims.as_slice())
            .collect();
        // First derivative of the penalty log-determinant term w.r.t. rho.
        let penalty_logdet = crate::estimate::reml::unified::compute_block_penalty_logdet_derivs(
            &per_block_rho,
            &per_block_penalty_refs,
            &per_block_nullspace_dims,
            penalty_logdet_ridge,
        )?;
        trace_s_pinv_sdot.assign(&penalty_logdet.first);
        // Factorize the penalized Hessian once; every trace and solve below
        // reuses this factorization.
        let spectral =
            DenseSpectralOperator::from_symmetric_with_mode(&h, self.pseudo_logdet_mode())?;
        let factor = spectral.logdet_gradient_factor();
        let mut trace_h_inv_hdot = Array1::<f64>::zeros(rho.len());
        let mut directions = Array2::<f64>::zeros((total, rho.len()));
        penalty_cursor = 0;
        for (block_idx, spec) in specs.iter().enumerate() {
            let (start, end) = ranges[block_idx];
            let beta_block = beta.slice(s![start..end]);
            for (local_idx, _penalty) in spec.penalties.iter().enumerate() {
                let idx = penalty_cursor + local_idx;
                let lambda = rho[idx].exp();
                let dense = &penalties_dense[block_idx][local_idx];
                // Block-local contribution of dS_lambda/d rho_idx to the
                // log|H| trace.
                trace_h_inv_hdot[idx] +=
                    spectral.trace_logdet_block_local(dense, lambda, start, end);
                // Direction column idx: -H^{-1} (lambda S beta), padded to
                // the full joint dimension.
                let curvature_rhs = dense.dot(&beta_block).mapv(|value| lambda * value);
                let mut rhs = Array1::<f64>::zeros(total);
                rhs.slice_mut(s![start..end]).assign(&curvature_rhs);
                let v = spectral.solve(&rhs);
                directions.column_mut(idx).assign(&(-&v));
            }
            penalty_cursor += spec.penalties.len();
        }
        let started = std::time::Instant::now();
        // Beta-dependence correction of the log|H| traces: the workspace's
        // batched path is only usable without subsample re-weighting.
        let workspace_traces = if options.outer_score_subsample.is_some() {
            None
        } else if let Some(workspace) = hessian_workspace.as_ref() {
            workspace.projected_directional_derivative_traces(factor, &directions)?
        } else {
            None
        };
        let correction_traces = if let Some(traces) = workspace_traces {
            traces
        } else {
            let owned_cache =
                self.build_exact_eval_cache_with_options(block_states, Some(options))?;
            if options.outer_score_subsample.is_some() {
                let weighted_rows = outer_weighted_rows(options, self.y.len());
                self.batched_rho_correction_logdet_traces_for_rows(
                    block_states,
                    &owned_cache,
                    factor,
                    &directions,
                    &weighted_rows,
                )?
            } else {
                self.batched_rho_correction_logdet_traces_full_rows(
                    block_states,
                    &owned_cache,
                    factor,
                    &directions,
                )?
            }
        };
        trace_h_inv_hdot += &correction_traces;
        if log_exact_work(self.y.len()) {
            log::info!(
                "[BMS batched outer-gradient] n={} p={} rho={} trace_elapsed={:.3}s",
                self.y.len(),
                total,
                rho.len(),
                started.elapsed().as_secs_f64()
            );
        }
        Ok(Some(BatchedOuterGradientTerms {
            objective_theta,
            trace_h_inv_hdot,
            trace_s_pinv_sdot,
        }))
    }
    /// Full blockwise evaluation entry point: validates monotonicity of the
    /// flexible-deviation blocks, then runs the exact-Newton evaluation.
    fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
        self.validate_exact_monotonicity(block_states)?;
        self.evaluate_blockwise_exact_newton(block_states)
    }
    /// Likelihood-only path with default fit options.
    fn log_likelihood_only(&self, block_states: &[ParameterBlockState]) -> Result<f64, String> {
        Self::log_likelihood_only_with_options(self, block_states, &BlockwiseFitOptions::default())
    }
    /// Likelihood-only path; the `Self::` call resolves to the inherent
    /// method of the same name (inherent impls take precedence), so this is
    /// not self-recursive.
    fn log_likelihood_only_with_options(
        &self,
        block_states: &[ParameterBlockState],
        options: &BlockwiseFitOptions,
    ) -> Result<f64, String> {
        Self::log_likelihood_only_with_options(self, block_states, options)
    }
    /// Line searches may abandon a likelihood evaluation early once it is
    /// known to be worse than the incumbent.
    fn supports_log_likelihood_early_exit(&self) -> bool {
        true
    }
fn joint_line_search_log_likelihood_workspace(
&self,
block_states: &[ParameterBlockState],
specs: &[ParameterBlockSpec],
line_search_options: &BlockwiseFitOptions,
workspace_options: &BlockwiseFitOptions,
) -> Result<Option<(f64, Arc<dyn ExactNewtonJointHessianWorkspace>)>, String> {
let Some(workspace) = self.exact_newton_joint_hessian_workspace_with_options(
block_states,
specs,
workspace_options,
)?
else {
return Ok(None);
};
let log_likelihood = match workspace.joint_log_likelihood_evaluation()? {
Some(value) => value,
None => {
Self::log_likelihood_only_with_options(self, block_states, line_search_options)?
}
};
Ok(Some((log_likelihood, workspace)))
}
    /// Line-search likelihood evaluation backed by the lazy line-search
    /// workspace; the workspace is returned alongside the likelihood so
    /// follow-up derivative calls can reuse its caches.
    fn joint_line_search_log_likelihood_evaluation(
        &self,
        block_states: &[ParameterBlockState],
        _: &[ParameterBlockSpec],
        line_search_options: &BlockwiseFitOptions,
        workspace_options: &BlockwiseFitOptions,
    ) -> Result<Option<(f64, Option<Arc<dyn ExactNewtonJointHessianWorkspace>>)>, String> {
        self.line_search_log_likelihood_workspace(
            block_states,
            line_search_options,
            workspace_options,
        )
        .map(|maybe| maybe.map(|(log_likelihood, workspace)| (log_likelihood, Some(workspace))))
    }
fn max_feasible_step_size(
&self,
block_states: &[ParameterBlockState],
block_idx: usize,
delta: &Array1<f64>,
) -> Result<Option<f64>, String> {
self.validate_exact_block_state_shapes(block_states)?;
let beta = block_states.get(block_idx).ok_or_else(|| {
format!(
"bernoulli marginal-slope block index {block_idx} is out of bounds for {} states",
block_states.len()
)
})?;
if delta.len() != beta.beta.len() {
return Err(format!(
"bernoulli marginal-slope step length mismatch for block {block_idx}: delta={}, beta={}",
delta.len(),
beta.beta.len()
));
}
let beta_norm_sq: f64 = beta.beta.iter().map(|v| v * v).sum();
let delta_norm_sq: f64 = delta.iter().map(|v| v * v).sum();
if delta_norm_sq <= 0.0 || !delta_norm_sq.is_finite() {
return Ok(None);
}
let delta_norm = delta_norm_sq.sqrt();
let beta_norm = beta_norm_sq.sqrt();
const BMS_STEP_CAP_FRACTION: f64 = 0.5;
const BMS_STEP_CAP_BETA_FLOOR: f64 = 1.0e-6;
const BMS_STEP_CAP_FALLBACK: f64 = 1.0;
let cap = if beta_norm < BMS_STEP_CAP_BETA_FLOOR {
BMS_STEP_CAP_FALLBACK / delta_norm
} else {
(BMS_STEP_CAP_FRACTION * beta_norm) / delta_norm
};
if !cap.is_finite() || cap <= 0.0 {
return Ok(None);
}
Ok(Some(cap))
}
    /// A dense joint Hessian can be produced directly (subject to the size
    /// cap in `exact_newton_joint_hessian`).
    fn has_explicit_joint_hessian(&self) -> bool {
        true
    }
    /// Dense exact joint Hessian over all coefficient blocks, or `None` when
    /// the joint dimension reaches the dense-path cap of 512.
    fn exact_newton_joint_hessian(
        &self,
        block_states: &[ParameterBlockState],
    ) -> Result<Option<Array2<f64>>, String> {
        let slices = block_slices(self);
        if slices.total >= 512 {
            return Ok(None);
        }
        if !self.effective_flex_active(block_states)? {
            // Rigid model: the generic row-kernel machinery assembles the
            // dense Hessian from per-row contributions.
            let kern = BernoulliRigidRowKernel::new(self.clone(), block_states.to_vec());
            let cache = build_row_kernel_cache(&kern, &crate::families::row_kernel::RowSet::All)?;
            return Ok(Some(row_kernel_hessian_dense(
                &kern,
                &cache,
                &crate::families::row_kernel::RowSet::All,
            )));
        }
        // Flexible model: go through the exact evaluation cache.
        let cache = self.build_exact_eval_cache_with_order(block_states)?;
        self.exact_newton_joint_hessian_dense_from_cache(block_states, &cache)
            .map(Some)
    }
    /// Joint log-likelihood and score over all blocks; uses the cheaper
    /// row-kernel path when no flexible deviations are active.
    fn exact_newton_joint_gradient_evaluation(
        &self,
        block_states: &[ParameterBlockState],
        _: &[ParameterBlockSpec],
    ) -> Result<Option<ExactNewtonJointGradientEvaluation>, String> {
        self.validate_exact_monotonicity(block_states)?;
        if !self.effective_flex_active(block_states)? {
            let kern = BernoulliRigidRowKernel::new(self.clone(), block_states.to_vec());
            let cache = build_row_kernel_cache(&kern, &crate::families::row_kernel::RowSet::All)?;
            return Ok(Some(ExactNewtonJointGradientEvaluation {
                log_likelihood: row_kernel_log_likelihood(
                    &cache,
                    &crate::families::row_kernel::RowSet::All,
                ),
                // Convert the row-kernel objective gradient to the score
                // convention expected by the exact-Newton machinery.
                gradient: Self::exact_newton_score_from_objective_gradient(row_kernel_gradient(
                    &kern,
                    &cache,
                    &crate::families::row_kernel::RowSet::All,
                )),
            }));
        }
        let cache = self.build_exact_eval_cache_with_order(block_states)?;
        self.exact_newton_joint_gradient_evaluation_from_cache(block_states, &cache)
            .map(Some)
    }
    /// Outer hyperparameter derivatives must be routed through the joint
    /// (all-blocks) path for this family.
    fn requires_joint_outer_hyper_path(&self) -> bool {
        true
    }
    /// Joint Hessian workspace with default options: the rigid path wraps the
    /// generic row-kernel workspace, the flexible path builds this family's
    /// cache-backed workspace.
    fn exact_newton_joint_hessian_workspace(
        &self,
        block_states: &[ParameterBlockState],
        _: &[ParameterBlockSpec],
    ) -> Result<Option<Arc<dyn ExactNewtonJointHessianWorkspace>>, String> {
        if !self.effective_flex_active(block_states)? {
            let kern = BernoulliRigidRowKernel::new(self.clone(), block_states.to_vec());
            Ok(Some(Arc::new(RowKernelHessianWorkspace::new(kern)?)))
        } else {
            Ok(Some(Arc::new(
                BernoulliMarginalSlopeExactNewtonJointHessianWorkspace::new(
                    self.clone(),
                    block_states.to_vec(),
                    BlockwiseFitOptions::default(),
                )?,
            )))
        }
    }
    /// Options-aware joint Hessian workspace: on the rigid path the row set
    /// is derived from `options` (honoring any outer row subsampling); on the
    /// flexible path the options are forwarded into the workspace cache.
    fn exact_newton_joint_hessian_workspace_with_options(
        &self,
        block_states: &[ParameterBlockState],
        _: &[ParameterBlockSpec],
        options: &BlockwiseFitOptions,
    ) -> Result<Option<Arc<dyn ExactNewtonJointHessianWorkspace>>, String> {
        if !self.effective_flex_active(block_states)? {
            let kern = BernoulliRigidRowKernel::new(self.clone(), block_states.to_vec());
            let rows = crate::families::row_kernel::RowSet::from_options(options, self.y.len());
            Ok(Some(Arc::new(RowKernelHessianWorkspace::with_rows(
                kern, rows,
            )?)))
        } else {
            Ok(Some(Arc::new(
                BernoulliMarginalSlopeExactNewtonJointHessianWorkspace::new(
                    self.clone(),
                    block_states.to_vec(),
                    options.clone(),
                )?,
            )))
        }
    }
    /// Hessian-vector products against the coefficient Hessian are available.
    fn inner_coefficient_hessian_hvp_available(&self, _specs: &[ParameterBlockSpec]) -> bool {
        true
    }
    /// The joint workspace can serve gradient evaluations directly.
    fn inner_joint_workspace_gradient_available(&self, _specs: &[ParameterBlockSpec]) -> bool {
        true
    }
    /// The joint workspace can serve log-likelihood evaluations directly.
    fn inner_joint_workspace_log_likelihood_available(
        &self,
        _specs: &[ParameterBlockSpec],
    ) -> bool {
        true
    }
    /// Directional derivative of the joint Hessian along the flattened
    /// coefficient perturbation `d_beta_flat`.
    fn exact_newton_joint_hessian_directional_derivative(
        &self,
        block_states: &[ParameterBlockState],
        d_beta_flat: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        if !self.effective_flex_active(block_states)? {
            let kern = BernoulliRigidRowKernel::new(self.clone(), block_states.to_vec());
            // The row-kernel helper needs a contiguous slice view.
            let sl = d_beta_flat.as_slice().ok_or("non-contiguous d_beta")?;
            crate::families::row_kernel::row_kernel_directional_derivative(
                &kern,
                &crate::families::row_kernel::RowSet::All,
                sl,
            )
            .map(Some)
        } else {
            let cache = self.build_exact_eval_cache(block_states)?;
            self.exact_newton_joint_hessian_directional_derivative_from_cache(
                block_states,
                d_beta_flat,
                &cache,
            )
        }
    }
    /// Second directional derivative of the joint Hessian along the two
    /// flattened coefficient perturbations `d_beta_u_flat` and
    /// `d_beta_v_flat`.
    fn exact_newton_joint_hessiansecond_directional_derivative(
        &self,
        block_states: &[ParameterBlockState],
        d_beta_u_flat: &Array1<f64>,
        d_beta_v_flat: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        if !self.effective_flex_active(block_states)? {
            let kern = BernoulliRigidRowKernel::new(self.clone(), block_states.to_vec());
            // The row-kernel helper needs contiguous slice views.
            let su = d_beta_u_flat.as_slice().ok_or("non-contiguous d_beta_u")?;
            let sv = d_beta_v_flat.as_slice().ok_or("non-contiguous d_beta_v")?;
            crate::families::row_kernel::row_kernel_second_directional_derivative(
                &kern,
                &crate::families::row_kernel::RowSet::All,
                su,
                sv,
            )
            .map(Some)
        } else {
            let cache = self.build_exact_eval_cache(block_states)?;
            self.exact_newton_joint_hessiansecond_directional_derivative_from_cache(
                block_states,
                d_beta_u_flat,
                d_beta_v_flat,
                &cache,
            )
        }
    }
    /// First-order terms for auxiliary parameter `psi_index`: log-sigma
    /// auxiliary indices are served by the dedicated sigma path, all other
    /// psi indices by the exact evaluation cache.
    fn exact_newton_joint_psi_terms(
        &self,
        block_states: &[ParameterBlockState],
        specs: &[ParameterBlockSpec],
        derivative_blocks: &[Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>],
        psi_index: usize,
    ) -> Result<Option<ExactNewtonJointPsiTerms>, String> {
        if self.is_sigma_aux_index(derivative_blocks, psi_index) {
            return self.sigma_exact_joint_psi_terms(block_states, specs);
        }
        let cache = self.build_exact_eval_cache(block_states)?;
        self.exact_newton_joint_psi_terms_from_cache(
            block_states,
            derivative_blocks,
            psi_index,
            &cache,
        )
    }
fn exact_newton_joint_psisecond_order_terms(
&self,
block_states: &[ParameterBlockState],
_: &[ParameterBlockSpec],
derivative_blocks: &[Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>],
psi_i: usize,
psi_j: usize,
) -> Result<Option<ExactNewtonJointPsiSecondOrderTerms>, String> {
if self.is_sigma_aux_index(derivative_blocks, psi_i)
|| self.is_sigma_aux_index(derivative_blocks, psi_j)
{
if self.is_sigma_aux_index(derivative_blocks, psi_i)
&& self.is_sigma_aux_index(derivative_blocks, psi_j)
{
return self.sigma_exact_joint_psisecond_order_terms(block_states);
}
return Err(
"bernoulli marginal-slope mixed log-sigma/spatial psi second derivatives require cross auxiliary terms; only pure log-sigma second derivatives are supported"
.to_string(),
);
}
let cache = self.build_exact_eval_cache(block_states)?;
self.exact_newton_joint_psisecond_order_terms_from_cache(
block_states,
derivative_blocks,
psi_i,
psi_j,
&cache,
)
}
    /// Directional derivative of the psi-Hessian coupling for `psi_index`
    /// along `d_beta_flat`; log-sigma auxiliary indices use the dedicated
    /// sigma path, all other indices the exact evaluation cache.
    fn exact_newton_joint_psihessian_directional_derivative(
        &self,
        block_states: &[ParameterBlockState],
        _: &[ParameterBlockSpec],
        derivative_blocks: &[Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>],
        psi_index: usize,
        d_beta_flat: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        if self.is_sigma_aux_index(derivative_blocks, psi_index) {
            return self
                .sigma_exact_joint_psihessian_directional_derivative(block_states, d_beta_flat);
        }
        let cache = self.build_exact_eval_cache(block_states)?;
        self.exact_newton_joint_psihessian_directional_derivative_from_cache(
            block_states,
            derivative_blocks,
            psi_index,
            d_beta_flat,
            &cache,
        )
    }
fn exact_newton_joint_psi_workspace(
&self,
block_states: &[ParameterBlockState],
specs: &[ParameterBlockSpec],
derivative_blocks: &[Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>],
) -> Result<Option<Arc<dyn ExactNewtonJointPsiWorkspace>>, String> {
Ok(Some(Arc::new(
BernoulliMarginalSlopeExactNewtonJointPsiWorkspace::new(
self.clone(),
block_states.to_vec(),
specs.to_vec(),
derivative_blocks.to_vec(),
BlockwiseFitOptions::default(),
)?,
)))
}
    /// Options-aware joint psi workspace; the provided fit options are
    /// cloned into the workspace so later psi-term requests see them.
    fn exact_newton_joint_psi_workspace_with_options(
        &self,
        block_states: &[ParameterBlockState],
        specs: &[ParameterBlockSpec],
        derivative_blocks: &[Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>],
        options: &BlockwiseFitOptions,
    ) -> Result<Option<Arc<dyn ExactNewtonJointPsiWorkspace>>, String> {
        Ok(Some(Arc::new(
            BernoulliMarginalSlopeExactNewtonJointPsiWorkspace::new(
                self.clone(),
                block_states.to_vec(),
                specs.to_vec(),
                derivative_blocks.to_vec(),
                options.clone(),
            )?,
        )))
    }
    /// Linear inequality constraints for one block: only the score-warp and
    /// link-deviation blocks carry (structural monotonicity) constraints;
    /// every other block is unconstrained.
    fn block_linear_constraints(
        &self,
        block_states: &[ParameterBlockState],
        block_idx: usize,
        spec: &ParameterBlockSpec,
    ) -> Result<Option<LinearInequalityConstraints>, String> {
        // NOTE(review): this guard cannot fire for realistic inputs; it
        // appears to exist only to keep all three parameters referenced.
        // Confirm the intent before simplifying.
        if block_states.len() == usize::MAX
            || block_idx == usize::MAX
            || spec.design.ncols() == usize::MAX
        {
            return Err("unreachable bernoulli marginal-slope constraint state".to_string());
        }
        if self.score_block_index().is_some_and(|idx| block_idx == idx) {
            return Ok(self
                .score_warp
                .as_ref()
                .map(DeviationRuntime::structural_monotonicity_constraints));
        }
        if self.link_block_index().is_some_and(|idx| block_idx == idx) {
            return Ok(self
                .link_dev
                .as_ref()
                .map(DeviationRuntime::structural_monotonicity_constraints));
        }
        Ok(None)
    }
    /// Post-processes a proposed coefficient update: the score-warp and
    /// link-deviation blocks are projected back into their monotone-feasible
    /// region relative to the current state; all other blocks pass through
    /// unchanged.
    fn post_update_block_beta(
        &self,
        block_states: &[ParameterBlockState],
        block_idx: usize,
        _: &ParameterBlockSpec,
        beta: Array1<f64>,
    ) -> Result<Array1<f64>, String> {
        self.validate_exact_block_state_shapes(block_states)?;
        if block_idx >= block_states.len() {
            return Err(format!(
                "post-update block index {} out of range for {} blocks",
                block_idx,
                block_states.len()
            ));
        }
        if self.score_block_index().is_some_and(|idx| block_idx == idx) {
            if let (Some(runtime), Some(score)) =
                (&self.score_warp, self.score_block_state(block_states)?)
            {
                // Project relative to the current accepted score-warp beta.
                let current = &score.beta;
                if current.len() != beta.len() {
                    return Err(format!(
                        "score-warp post-update beta length mismatch: current={}, proposed={}",
                        current.len(),
                        beta.len()
                    ));
                }
                return project_monotone_feasible_beta(runtime, current, &beta, "score_warp_dev");
            }
        }
        if self.link_block_index().is_some_and(|idx| block_idx == idx) {
            if let (Some(runtime), Some(link)) =
                (&self.link_dev, self.link_block_state(block_states)?)
            {
                // Project relative to the current accepted link-deviation beta.
                let current = &link.beta;
                if current.len() != beta.len() {
                    return Err(format!(
                        "link-deviation post-update beta length mismatch: current={}, proposed={}",
                        current.len(),
                        beta.len()
                    ));
                }
                return project_monotone_feasible_beta(runtime, current, &beta, "link_dev");
            }
        }
        Ok(beta)
    }
}
impl BernoulliMarginalSlopeExactNewtonJointHessianWorkspace {
    /// Builds a workspace by constructing a fresh shared evaluation cache for
    /// the given coefficient state and options.
    fn new(
        family: BernoulliMarginalSlopeFamily,
        block_states: Vec<ParameterBlockState>,
        options: BlockwiseFitOptions,
    ) -> Result<Self, String> {
        let cache = family.build_shared_eval_cache_with_options(&block_states, &options)?;
        Self::from_arc_cache(family, block_states, cache, options)
    }
    /// Builds a workspace from an already-populated cache, verifying the
    /// cache fingerprint against the (block_states, options) pair before
    /// augmenting it with the per-row primary Hessian cache.
    fn from_cache(
        family: BernoulliMarginalSlopeFamily,
        block_states: Vec<ParameterBlockState>,
        mut cache: BernoulliMarginalSlopeExactEvalCache,
        options: BlockwiseFitOptions,
    ) -> Result<Self, String> {
        let expected = CacheFingerprint::compute(
            &block_states,
            options
                .outer_score_subsample
                .as_ref()
                .map(|s| s.mask.as_slice()),
            Some(&options),
        );
        if cache.fingerprint != expected {
            // Reusing a cache built for a different state would silently
            // produce wrong derivatives, so fail loudly instead.
            // (`.to_string()` replaces a `format!` that had no arguments.)
            return Err(
                "BernoulliMarginalSlopeExactEvalCache fingerprint mismatch in from_cache: \
                 cache was built for a different (β, options) than the workspace being constructed"
                    .to_string(),
            );
        }
        cache.row_primary_hessians =
            family.build_row_primary_hessian_cache(&block_states, &cache)?;
        Self::from_arc_cache(family, block_states, Arc::new(cache), options)
    }
    /// Final assembly step shared by `new` and `from_cache`.
    fn from_arc_cache(
        family: BernoulliMarginalSlopeFamily,
        block_states: Vec<ParameterBlockState>,
        cache: Arc<BernoulliMarginalSlopeExactEvalCache>,
        options: BlockwiseFitOptions,
    ) -> Result<Self, String> {
        Ok(Self {
            family,
            block_states,
            cache,
            matvec_calls: AtomicUsize::new(0),
            options,
        })
    }
}
impl BernoulliMarginalSlopeLineSearchWorkspace {
    /// Lazily builds (at most once) the fully materialized joint-Hessian
    /// workspace from the line-search cache, memoizing either the workspace
    /// or the construction error for all later calls.
    fn materialized(
        &self,
    ) -> Result<&Arc<BernoulliMarginalSlopeExactNewtonJointHessianWorkspace>, String> {
        match self.full_workspace.get_or_init(|| {
            // Clone the cached evaluation state so it can be augmented here.
            let mut cache = self.cache.clone();
            let row_cell_mask = self
                .options
                .outer_score_subsample
                .as_ref()
                .map(|subsample| subsample.mask.as_slice());
            // Only the subsampled path needs the row-cell moments bundle.
            // NOTE(review): the literal 21 looks like a moment/order argument
            // for build_row_cell_moments_bundle — confirm against its
            // definition before changing.
            cache.row_cell_moments = match row_cell_mask {
                Some(mask) => self.family.build_row_cell_moments_bundle(
                    &self.block_states,
                    &cache.row_contexts,
                    21,
                    Some(mask),
                )?,
                None => None,
            };
            BernoulliMarginalSlopeExactNewtonJointHessianWorkspace::from_cache(
                self.family.clone(),
                self.block_states.clone(),
                cache,
                self.options.clone(),
            )
            .map(Arc::new)
        }) {
            Ok(workspace) => Ok(workspace),
            // The memoized value may itself be an error; clone it out.
            Err(err) => Err(err.clone()),
        }
    }
}
impl ExactNewtonJointHessianWorkspace for BernoulliMarginalSlopeExactNewtonJointHessianWorkspace {
    /// Dense Hessian from the cached state, or `None` at/above the
    /// dense-path cap of 512 joint coefficients.
    fn hessian_dense(&self) -> Result<Option<Array2<f64>>, String> {
        if self.cache.slices.total >= 512 {
            return Ok(None);
        }
        self.family
            .exact_newton_joint_hessian_dense_from_cache(&self.block_states, &self.cache)
            .map(Some)
    }
    /// Log-likelihood at the cached coefficient state.
    fn joint_log_likelihood_evaluation(&self) -> Result<Option<f64>, String> {
        self.family
            .log_likelihood_from_exact_cache(&self.block_states, &self.cache)
            .map(Some)
    }
    /// Joint likelihood and score at the cached coefficient state.
    fn joint_gradient_evaluation(
        &self,
    ) -> Result<Option<ExactNewtonJointGradientEvaluation>, String> {
        self.family
            .exact_newton_joint_gradient_evaluation_from_cache(&self.block_states, &self.cache)
            .map(Some)
    }
    /// Hessian-vector product; logs timing for the first three calls and at
    /// power-of-two call counts to keep log volume bounded.
    fn hessian_matvec(&self, beta_flat: &Array1<f64>) -> Result<Option<Array1<f64>>, String> {
        let call = self.matvec_calls.fetch_add(1, Ordering::Relaxed) + 1;
        let started = std::time::Instant::now();
        let result = self
            .family
            .exact_newton_joint_hessian_matvec_from_cache(
                beta_flat,
                &self.block_states,
                &self.cache,
            )
            .map(Some);
        if log_exact_work(self.family.y.len()) && (call <= 3 || call.is_power_of_two()) {
            log::info!(
                "[BMS Hessian-Hv] call={} n={} p={} elapsed={:.3}s",
                call,
                self.family.y.len(),
                self.cache.slices.total,
                started.elapsed().as_secs_f64()
            );
        }
        result
    }
    /// Diagonal of the joint Hessian at the cached state.
    fn hessian_diagonal(&self) -> Result<Option<Array1<f64>>, String> {
        self.family
            .exact_newton_joint_hessian_diagonal_from_cache(&self.block_states, &self.cache)
            .map(Some)
    }
    /// Batched log-det trace corrections for a set of direction columns,
    /// honoring outer row subsampling when configured in the options.
    fn projected_directional_derivative_traces(
        &self,
        factor: &Array2<f64>,
        directions: &Array2<f64>,
    ) -> Result<Option<Array1<f64>>, String> {
        let traces = if self.options.outer_score_subsample.is_some() {
            let weighted_rows = outer_weighted_rows(&self.options, self.family.y.len());
            self.family.batched_rho_correction_logdet_traces_for_rows(
                &self.block_states,
                &self.cache,
                factor,
                directions,
                &weighted_rows,
            )?
        } else {
            self.family.batched_rho_correction_logdet_traces_full_rows(
                &self.block_states,
                &self.cache,
                factor,
                directions,
            )?
        };
        Ok(Some(traces))
    }
    /// Operator form of the Hessian directional derivative along one
    /// coefficient perturbation.
    fn directional_derivative_operator(
        &self,
        d_beta_flat: &Array1<f64>,
    ) -> Result<Option<Arc<dyn HyperOperator>>, String> {
        self.family
            .exact_newton_joint_hessian_directional_derivative_operator_from_cache_with_options(
                &self.block_states,
                d_beta_flat,
                &self.cache,
                &self.options,
            )
    }
    /// Batched operator form for several perturbations at once.
    fn directional_derivative_operators(
        &self,
        d_beta_flats: &[Array1<f64>],
    ) -> Result<Vec<Option<Arc<dyn HyperOperator>>>, String> {
        self.family
            .exact_newton_joint_hessian_directional_derivative_operators_from_cache_with_options(
                &self.block_states,
                d_beta_flats,
                &self.cache,
                &self.options,
            )
    }
    /// Dense Hessian directional derivative along one perturbation.
    fn directional_derivative(
        &self,
        d_beta_flat: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        self.family
            .exact_newton_joint_hessian_directional_derivative_from_cache_with_options(
                &self.block_states,
                d_beta_flat,
                &self.cache,
                &self.options,
            )
    }
    /// Operator form of the second directional derivative along two
    /// perturbations.
    fn second_directional_derivative_operator(
        &self,
        d_beta_u_flat: &Array1<f64>,
        d_beta_v_flat: &Array1<f64>,
    ) -> Result<Option<Arc<dyn HyperOperator>>, String> {
        self.family
            .exact_newton_joint_hessiansecond_directional_derivative_operator_from_cache_with_options(
                &self.block_states,
                d_beta_u_flat,
                d_beta_v_flat,
                &self.cache,
                &self.options,
            )
    }
    /// Dense second directional derivative along two perturbations.
    fn second_directional_derivative(
        &self,
        d_beta_u_flat: &Array1<f64>,
        d_beta_v_flat: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        self.family
            .exact_newton_joint_hessiansecond_directional_derivative_from_cache_with_options(
                &self.block_states,
                d_beta_u_flat,
                d_beta_v_flat,
                &self.cache,
                &self.options,
            )
    }
}
impl ExactNewtonJointHessianWorkspace for BernoulliMarginalSlopeLineSearchWorkspace {
    /// Eagerly builds the fully materialized workspace so later outer-phase
    /// calls do not pay the construction cost.
    fn warm_up_outer_caches(&self) -> Result<(), String> {
        self.materialized().map(|_| ())
    }
    /// All methods below lazily materialize the full workspace and forward
    /// to it; only the log-likelihood is answered from local state.
    fn hessian_dense(&self) -> Result<Option<Array2<f64>>, String> {
        let inner = self.materialized()?;
        inner.hessian_dense()
    }
    fn joint_log_likelihood_evaluation(&self) -> Result<Option<f64>, String> {
        // The likelihood was computed when this line-search workspace was
        // created, so no materialization is needed to report it.
        Ok(Some(self.log_likelihood))
    }
    fn joint_gradient_evaluation(
        &self,
    ) -> Result<Option<ExactNewtonJointGradientEvaluation>, String> {
        let inner = self.materialized()?;
        inner.joint_gradient_evaluation()
    }
    fn hessian_matvec(&self, beta_flat: &Array1<f64>) -> Result<Option<Array1<f64>>, String> {
        let inner = self.materialized()?;
        inner.hessian_matvec(beta_flat)
    }
    fn hessian_matvec_into(&self, v: &Array1<f64>, out: &mut Array1<f64>) -> Result<bool, String> {
        let inner = self.materialized()?;
        inner.hessian_matvec_into(v, out)
    }
    fn hessian_diagonal(&self) -> Result<Option<Array1<f64>>, String> {
        let inner = self.materialized()?;
        inner.hessian_diagonal()
    }
    fn projected_directional_derivative_traces(
        &self,
        factor: &Array2<f64>,
        directions: &Array2<f64>,
    ) -> Result<Option<Array1<f64>>, String> {
        let inner = self.materialized()?;
        inner.projected_directional_derivative_traces(factor, directions)
    }
    fn directional_derivative_operator(
        &self,
        d_beta_flat: &Array1<f64>,
    ) -> Result<Option<Arc<dyn HyperOperator>>, String> {
        let inner = self.materialized()?;
        inner.directional_derivative_operator(d_beta_flat)
    }
    fn directional_derivative_operators(
        &self,
        d_beta_flats: &[Array1<f64>],
    ) -> Result<Vec<Option<Arc<dyn HyperOperator>>>, String> {
        let inner = self.materialized()?;
        inner.directional_derivative_operators(d_beta_flats)
    }
    fn directional_derivative(
        &self,
        d_beta_flat: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        let inner = self.materialized()?;
        inner.directional_derivative(d_beta_flat)
    }
    fn second_directional_derivative_operator(
        &self,
        d_beta_u_flat: &Array1<f64>,
        d_beta_v_flat: &Array1<f64>,
    ) -> Result<Option<Arc<dyn HyperOperator>>, String> {
        let inner = self.materialized()?;
        inner.second_directional_derivative_operator(d_beta_u_flat, d_beta_v_flat)
    }
    fn second_directional_derivative(
        &self,
        d_beta_u_flat: &Array1<f64>,
        d_beta_v_flat: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        let inner = self.materialized()?;
        inner.second_directional_derivative(d_beta_u_flat, d_beta_v_flat)
    }
}
impl BernoulliMarginalSlopeFamily {
fn block_ranges_from_specs(specs: &[ParameterBlockSpec]) -> Vec<(usize, usize)> {
let mut cursor = 0usize;
specs
.iter()
.map(|spec| {
let start = cursor;
cursor += spec.design.ncols();
(start, cursor)
})
.collect()
}
    /// Concatenates the per-block coefficient vectors into one flat vector
    /// laid out in spec order, validating that each state's beta length
    /// matches its spec's design width.
    fn flatten_block_state_betas_for_specs(
        block_states: &[ParameterBlockState],
        specs: &[ParameterBlockSpec],
    ) -> Result<Array1<f64>, String> {
        if block_states.len() != specs.len() {
            return Err(format!(
                "bernoulli marginal-slope batched gradient state/spec mismatch: states={}, specs={}",
                block_states.len(),
                specs.len()
            ));
        }
        let total = specs.iter().map(|spec| spec.design.ncols()).sum::<usize>();
        let mut beta = Array1::<f64>::zeros(total);
        let mut cursor = 0usize;
        for (idx, (state, spec)) in block_states.iter().zip(specs.iter()).enumerate() {
            let width = spec.design.ncols();
            if state.beta.len() != width {
                return Err(format!(
                    "bernoulli marginal-slope batched gradient block {idx} beta length {} != spec width {}",
                    state.beta.len(),
                    width
                ));
            }
            // Copy this block's coefficients into its flat segment.
            beta.slice_mut(s![cursor..cursor + width])
                .assign(&state.beta);
            cursor += width;
        }
        Ok(beta)
    }
    /// Projects one observation row onto the columns of `factor`, writing a
    /// `primary.total x rank` row-major result into `out`. The marginal and
    /// log-slope slots are design-row dot products against the corresponding
    /// coefficient slices of `factor`; the optional h/w tail slots copy the
    /// matching `factor` rows verbatim.
    fn row_factor_primary_projection(
        &self,
        row: usize,
        slices: &BlockSlices,
        primary: &PrimarySlices,
        factor: &Array2<f64>,
        out: &mut [f64],
    ) -> Result<(), String> {
        let rank = factor.ncols();
        if out.len() != primary.total * rank {
            return Err(format!(
                "primary projection scratch length {} != {}",
                out.len(),
                primary.total * rank
            ));
        }
        // Scratch may be reused across rows; clear stale values first.
        out.fill(0.0);
        for col in 0..rank {
            out[primary.q * rank + col] = self
                .marginal_design
                .dot_row_view(row, factor.slice(s![slices.marginal.clone(), col]));
            out[primary.logslope * rank + col] = self
                .logslope_design
                .dot_row_view(row, factor.slice(s![slices.logslope.clone(), col]));
        }
        if let (Some(block_range), Some(primary_range)) = (slices.h.as_ref(), primary.h.as_ref()) {
            for (local, block_idx) in block_range.clone().enumerate() {
                let primary_idx = primary_range.start + local;
                for col in 0..rank {
                    out[primary_idx * rank + col] = factor[[block_idx, col]];
                }
            }
        }
        if let (Some(block_range), Some(primary_range)) = (slices.w.as_ref(), primary.w.as_ref()) {
            for (local, block_idx) in block_range.clone().enumerate() {
                let primary_idx = primary_range.start + local;
                for col in 0..rank {
                    out[primary_idx * rank + col] = factor[[block_idx, col]];
                }
            }
        }
        Ok(())
    }
fn row_primary_gram_from_projection(
primary_total: usize,
rank: usize,
projection: &[f64],
) -> Vec<f64> {
let mut gram = vec![0.0; primary_total * primary_total];
for a in 0..primary_total {
for b in 0..=a {
let mut sum = 0.0;
let a_base = a * rank;
let b_base = b * rank;
for col in 0..rank {
sum += projection[a_base + col] * projection[b_base + col];
}
gram[a * primary_total + b] = sum;
gram[b * primary_total + a] = sum;
}
}
gram
}
fn primary_tail_block_pairs(
slices: &BlockSlices,
primary: &PrimarySlices,
) -> Vec<(usize, usize)> {
let mut out = Vec::new();
if let (Some(block_range), Some(primary_range)) = (slices.h.as_ref(), primary.h.as_ref()) {
out.extend(
block_range
.clone()
.enumerate()
.map(|(offset, block_idx)| (primary_range.start + offset, block_idx)),
);
}
if let (Some(block_range), Some(primary_range)) = (slices.w.as_ref(), primary.w.as_ref()) {
out.extend(
block_range
.clone()
.enumerate()
.map(|(offset, block_idx)| (primary_range.start + offset, block_idx)),
);
}
out
}
fn primary_tail_tail_gram(
primary_total: usize,
rank: usize,
factor: &Array2<f64>,
tail_pairs: &[(usize, usize)],
) -> Vec<f64> {
let mut gram = vec![0.0; primary_total * primary_total];
for (a_pos, &(primary_a, block_a)) in tail_pairs.iter().enumerate() {
for &(primary_b, block_b) in tail_pairs.iter().take(a_pos + 1) {
let mut sum = 0.0;
for col in 0..rank {
sum += factor[[block_a, col]] * factor[[block_b, col]];
}
gram[primary_a * primary_total + primary_b] = sum;
gram[primary_b * primary_total + primary_a] = sum;
}
}
gram
}
fn row_primary_trace_contract(third: &Array2<f64>, gram: &[f64]) -> f64 {
let r = third.nrows();
debug_assert_eq!(third.ncols(), r);
debug_assert_eq!(gram.len(), r * r);
let mut total = 0.0;
for a in 0..r {
for b in 0..r {
total += third[[a, b]] * gram[a * r + b];
}
}
total
}
    /// Batched third-order directional contraction for a single training row.
    ///
    /// For each direction in `row_dirs` (each of length `r = primary.total`),
    /// returns the `r x r` matrix obtained by contracting the row's
    /// third-derivative tensor with that direction. Cheaper specialised paths
    /// handle: a single direction (delegated recompute), no active flexible
    /// blocks (cached rigid 2x2 tensor), and an empirical latent-measure grid.
    fn row_primary_third_contracted_many_with_moments(
        &self,
        row: usize,
        block_states: &[ParameterBlockState],
        cache: &BernoulliMarginalSlopeExactEvalCache,
        row_ctx: &BernoulliMarginalSlopeRowExactContext,
        row_dirs: &[Array1<f64>],
    ) -> Result<Vec<Array2<f64>>, String> {
        let primary = &cache.primary;
        let r = primary.total;
        if row_dirs.is_empty() {
            return Ok(Vec::new());
        }
        // Validate every direction length before doing any heavy work.
        if let Some((idx, dir)) = row_dirs.iter().enumerate().find(|(_, dir)| dir.len() != r) {
            return Err(format!(
                "bernoulli marginal-slope row third direction {idx} length {} != {r}",
                dir.len()
            ));
        }
        // Single direction: the dedicated recompute path avoids the batched
        // scratch buffers entirely.
        if row_dirs.len() == 1 {
            return Ok(vec![
                self.row_primary_third_contracted_recompute_with_moments(
                    row,
                    block_states,
                    cache,
                    row_ctx,
                    &row_dirs[0],
                )?,
            ]);
        }
        // Rigid path: with no flexible (score-warp/link-wiggle) blocks active,
        // the full third tensor is 2x2x2 and cached; contract it per direction.
        if !self.effective_flex_active(block_states)? {
            let t = self.rigid_third_full_cached(block_states, cache, row)?;
            return row_dirs
                .iter()
                .map(|dir| {
                    let m = contract_third_full(t, dir[0], dir[1]);
                    let mut out = Array2::<f64>::zeros((2, 2));
                    out[[0, 0]] = m[0][0];
                    out[[0, 1]] = m[0][1];
                    out[[1, 0]] = m[1][0];
                    out[[1, 1]] = m[1][1];
                    Ok(out)
                })
                .collect();
        }
        // Flexible path below needs a finite row context; m_a must be strictly
        // positive (presumably a scale quantity — see row context builder).
        if !row_ctx.intercept.is_finite() || !row_ctx.m_a.is_finite() || row_ctx.m_a <= 0.0 {
            return Err(
                "non-finite flexible row context in batched third-order contraction".to_string(),
            );
        }
        // Decompose the row's primary point into (q, b, beta_h, beta_w).
        let point = self.primary_point_from_block_states(row, block_states, primary)?;
        let (q, b, beta_h_owned, beta_w_owned) = self.primary_point_components(&point, primary);
        let beta_h = beta_h_owned.as_ref();
        let beta_w = beta_w_owned.as_ref();
        // Empirical latent measure: per-row grid recompute path.
        if let Some(grid) = self.latent_measure.empirical_grid_for_training_row(row)? {
            return row_dirs
                .iter()
                .map(|dir| {
                    self.empirical_flex_row_third_contracted_recompute(
                        row, primary, q, b, beta_h, beta_w, row_ctx, dir, &grid,
                    )
                })
                .collect();
        }
        use crate::families::bernoulli_marginal_slope::exact_kernel as exact;
        let a = row_ctx.intercept;
        let marginal = self.marginal_link_map(q)?;
        let h_range = primary.h.as_ref();
        let w_range = primary.w.as_ref();
        let score_runtime = self.score_warp.as_ref();
        let link_runtime = self.link_dev.as_ref();
        let scale = self.probit_frailty_scale();
        let zero_family = vec![[0.0; 4]; r];
        let n_dirs = row_dirs.len();
        // Accumulators for cell-moment integrals: f_a/f_aa are intercept
        // partials, f_u/f_au/f_uv carry primary-coordinate partials, and the
        // *_dir buffers hold one directional slice per requested direction.
        let mut f_a = 0.0;
        let mut f_aa = 0.0;
        let mut f_u = Array1::<f64>::zeros(r);
        let mut f_au = Array1::<f64>::zeros(r);
        let mut f_uv = Array2::<f64>::zeros((r, r));
        let mut f_a_dir = vec![0.0; n_dirs];
        let mut f_aa_dir = vec![0.0; n_dirs];
        let mut f_au_dir = vec![0.0; n_dirs * r];
        let mut f_uv_dir = vec![0.0; n_dirs * r * r];
        // Reuse cached per-row denested-cell moments (order 15) when present;
        // otherwise rebuild them through the LRU moment evaluator.
        let owned_cells;
        let cells: &[CachedDenestedCellMoments] = if let Some(cached) = cache
            .row_cell_moments
            .as_ref()
            .and_then(|bundle| bundle.row(row, 15))
        {
            cached
        } else {
            let partitions = self.denested_partition_cells(a, b, beta_h, beta_w)?;
            owned_cells = partitions
                .into_iter()
                .map(|partition_cell| {
                    self.evaluate_cell_derivative_moments_lru(partition_cell.cell, 15)
                        .map(|state| CachedDenestedCellMoments {
                            partition_cell,
                            state,
                        })
                })
                .collect::<Result<Vec<_>, String>>()?;
            &owned_cells
        };
        for entry in cells {
            let partition_cell = entry.partition_cell;
            let cell = partition_cell.cell;
            let z_mid = exact::interval_probe_point(cell.left, cell.right)?;
            let u_mid = a + b * z_mid;
            let state = &entry.state;
            // Raw cubic-coefficient partials of the cell integrand in (a, b),
            // up to third order, then rescaled by the probit frailty scale.
            let (dc_da_raw, dc_db_raw) = exact::denested_cell_coefficient_partials(
                partition_cell.score_span,
                partition_cell.link_span,
                a,
                b,
            );
            let (dc_daa_raw, dc_dab_raw, dc_dbb_raw) = exact::denested_cell_second_partials(
                partition_cell.score_span,
                partition_cell.link_span,
                a,
                b,
            );
            let denested_third = exact::denested_cell_third_partials(partition_cell.link_span);
            let dc_da = scale_coeff4(dc_da_raw, scale);
            let dc_db = scale_coeff4(dc_db_raw, scale);
            let dc_daa = scale_coeff4(dc_daa_raw, scale);
            let dc_dab = scale_coeff4(dc_dab_raw, scale);
            let dc_dbb = scale_coeff4(dc_dbb_raw, scale);
            let dc_daab = scale_coeff4(denested_third.1, scale);
            let dc_dabb = scale_coeff4(denested_third.2, scale);
            let dc_dbbb = scale_coeff4(denested_third.3, scale);
            // Coefficient families per primary coordinate: slot 1 is the slope
            // b; the flexible h/w ranges are filled in below. Slot 0 (the
            // q coordinate) enters via the marginal link moments after the loop.
            let mut coeff_u = vec![[0.0; 4]; r];
            let mut coeff_au = vec![[0.0; 4]; r];
            let mut coeff_bu = vec![[0.0; 4]; r];
            let mut coeff_aau = vec![[0.0; 4]; r];
            let mut coeff_abu = vec![[0.0; 4]; r];
            let mut coeff_bbu = vec![[0.0; 4]; r];
            coeff_u[1] = dc_db;
            coeff_au[1] = dc_dab;
            coeff_bu[1] = dc_dbb;
            coeff_aau[1] = dc_daab;
            coeff_abu[1] = dc_dabb;
            coeff_bbu[1] = dc_dbbb;
            // Score-warp basis contributions evaluated at the cell probe z_mid.
            if let (Some(h_range), Some(runtime)) = (h_range, score_runtime) {
                Self::for_each_deviation_basis_cubic_at(
                    runtime,
                    h_range,
                    z_mid,
                    "score-warp batched third direction",
                    |_, idx, basis_span| {
                        coeff_u[idx] = scale_coeff4(
                            exact::score_basis_cell_coefficients(basis_span, b),
                            scale,
                        );
                        coeff_bu[idx] = scale_coeff4(
                            exact::score_basis_cell_coefficients(basis_span, 1.0),
                            scale,
                        );
                        Ok(())
                    },
                )?;
            }
            // Link-wiggle basis contributions evaluated at u_mid = a + b*z_mid.
            if let (Some(w_range), Some(runtime)) = (w_range, link_runtime) {
                Self::for_each_deviation_basis_cubic_at(
                    runtime,
                    w_range,
                    u_mid,
                    "link-wiggle batched third direction",
                    |_, idx, basis_span| {
                        coeff_u[idx] = scale_coeff4(
                            exact::link_basis_cell_coefficients(basis_span, a, b),
                            scale,
                        );
                        let (dc_aw_raw, dc_bw_raw) =
                            exact::link_basis_cell_coefficient_partials(basis_span, a, b);
                        let (dc_aaw_raw, dc_abw_raw, dc_bbw_raw) =
                            exact::link_basis_cell_second_partials(basis_span, a, b);
                        coeff_au[idx] = scale_coeff4(dc_aw_raw, scale);
                        coeff_bu[idx] = scale_coeff4(dc_bw_raw, scale);
                        coeff_aau[idx] = scale_coeff4(dc_aaw_raw, scale);
                        coeff_abu[idx] = scale_coeff4(dc_abw_raw, scale);
                        coeff_bbu[idx] = scale_coeff4(dc_bbw_raw, scale);
                        Ok(())
                    },
                )?;
            }
            let coeff_jet = SparsePrimaryCoeffJetView::new(
                1,
                h_range,
                w_range,
                &coeff_u,
                &coeff_au,
                &coeff_bu,
                &coeff_aau,
                &coeff_abu,
                &coeff_bbu,
                &zero_family,
                &zero_family,
                &zero_family,
                &zero_family,
            );
            // Accumulate first- and second-order cell-moment integrals.
            f_a += exact::cell_first_derivative_from_moments(&dc_da, &state.moments)?;
            f_aa += exact::cell_second_derivative_from_moments(
                cell,
                &dc_da,
                &dc_da,
                &dc_daa,
                &state.moments,
            )?;
            for u in 1..r {
                f_u[u] +=
                    exact::cell_first_derivative_from_moments(&coeff_jet.first[u], &state.moments)?;
                f_au[u] += exact::cell_second_derivative_from_moments(
                    cell,
                    &dc_da,
                    &coeff_jet.first[u],
                    &coeff_jet.a_first[u],
                    &state.moments,
                )?;
            }
            for u in 1..r {
                for v in u..r {
                    let second_coeff =
                        coeff_jet.pair_from_b_family(coeff_jet.b_first, u, v, COEFF_SUPPORT_BHW);
                    let val = exact::cell_second_derivative_from_moments(
                        cell,
                        &coeff_jet.first[u],
                        &coeff_jet.first[v],
                        &second_coeff,
                        &state.moments,
                    )?;
                    f_uv[[u, v]] += val;
                    if u != v {
                        f_uv[[v, u]] += val;
                    }
                }
            }
            // Directional (third-order) accumulation: one pass per direction.
            for (dir_idx, dir) in row_dirs.iter().enumerate() {
                let coeff_dir =
                    coeff_jet.directional_family(coeff_jet.first, dir, COEFF_SUPPORT_BHW);
                let coeff_a_dir =
                    coeff_jet.directional_family(coeff_jet.a_first, dir, COEFF_SUPPORT_BW);
                let coeff_aa_dir =
                    coeff_jet.directional_family(coeff_jet.aa_first, dir, COEFF_SUPPORT_BW);
                f_a_dir[dir_idx] += exact::cell_second_derivative_from_moments(
                    cell,
                    &dc_da,
                    &coeff_dir,
                    &coeff_a_dir,
                    &state.moments,
                )?;
                f_aa_dir[dir_idx] += exact::cell_third_derivative_from_moments(
                    cell,
                    &dc_da,
                    &dc_da,
                    &coeff_dir,
                    &dc_daa,
                    &coeff_a_dir,
                    &coeff_a_dir,
                    &coeff_aa_dir,
                    &state.moments,
                )?;
                let mut coeff_u_dir = vec![[0.0; 4]; r];
                let mut coeff_au_dir = vec![[0.0; 4]; r];
                for u in 1..r {
                    coeff_u_dir[u] = coeff_jet.param_directional_from_b_family(
                        coeff_jet.b_first,
                        u,
                        dir,
                        COEFF_SUPPORT_BHW,
                    );
                    coeff_au_dir[u] = coeff_jet.param_directional_from_b_family(
                        coeff_jet.ab_first,
                        u,
                        dir,
                        COEFF_SUPPORT_BW,
                    );
                }
                for u in 1..r {
                    f_au_dir[dir_idx * r + u] += exact::cell_third_derivative_from_moments(
                        cell,
                        &dc_da,
                        &coeff_jet.first[u],
                        &coeff_dir,
                        &coeff_jet.a_first[u],
                        &coeff_a_dir,
                        &coeff_u_dir[u],
                        &coeff_au_dir[u],
                        &state.moments,
                    )?;
                }
                let dir_base = dir_idx * r * r;
                for u in 1..r {
                    for v in u..r {
                        let second_coeff = coeff_jet.pair_from_b_family(
                            coeff_jet.b_first,
                            u,
                            v,
                            COEFF_SUPPORT_BHW,
                        );
                        let third_coeff = coeff_jet.pair_directional_from_bb_family(
                            coeff_jet.bb_first,
                            u,
                            v,
                            dir,
                            COEFF_SUPPORT_BW,
                        );
                        let dir_val = exact::cell_third_derivative_from_moments(
                            cell,
                            &coeff_jet.first[u],
                            &coeff_jet.first[v],
                            &coeff_dir,
                            &second_coeff,
                            &coeff_u_dir[u],
                            &coeff_u_dir[v],
                            &third_coeff,
                            &state.moments,
                        )?;
                        f_uv_dir[dir_base + u * r + v] += dir_val;
                        if u != v {
                            f_uv_dir[dir_base + v * r + u] += dir_val;
                        }
                    }
                }
            }
        }
        // The q coordinate (index 0) enters through the marginal link map
        // moments mu1/mu2 (mu3 appears in the directional loop below).
        f_u[0] = -marginal.mu1;
        f_uv[[0, 0]] = -marginal.mu2;
        // Implicit-function derivatives of the intercept a(beta):
        // a_u = -f_u / f_a, with the matching second-order formula for a_uv.
        let inv_f_a = 1.0 / f_a;
        let mut a_u = Array1::<f64>::zeros(r);
        for u in 0..r {
            a_u[u] = -f_u[u] * inv_f_a;
        }
        let mut a_uv = Array2::<f64>::zeros((r, r));
        for u in 0..r {
            for v in u..r {
                let val =
                    -(f_uv[[u, v]] + f_au[u] * a_u[v] + f_au[v] * a_u[u] + f_aa * a_u[u] * a_u[v])
                        * inv_f_a;
                a_uv[[u, v]] = val;
                a_uv[[v, u]] = val;
            }
        }
        // Observed-point coefficient jet at z_obs (and u_obs for the w block).
        let z_obs = self.z[row];
        let u_obs = a + b * z_obs;
        let obs = self.observed_denested_cell_partials(row, a, b, beta_h, beta_w)?;
        let eta_val = eval_coeff4_at(&obs.coeff, z_obs);
        let mut g_u_fixed = vec![[0.0; 4]; r];
        let mut g_au_fixed = vec![[0.0; 4]; r];
        let mut g_bu_fixed = vec![[0.0; 4]; r];
        let mut g_aau_fixed = vec![[0.0; 4]; r];
        let mut g_abu_fixed = vec![[0.0; 4]; r];
        let mut g_bbu_fixed = vec![[0.0; 4]; r];
        g_u_fixed[1] = obs.dc_db;
        g_au_fixed[1] = obs.dc_dab;
        g_bu_fixed[1] = obs.dc_dbb;
        g_aau_fixed[1] = obs.dc_daab;
        g_abu_fixed[1] = obs.dc_dabb;
        g_bbu_fixed[1] = obs.dc_dbbb;
        if let (Some(h_range), Some(runtime)) = (h_range, score_runtime) {
            Self::for_each_deviation_basis_cubic_at(
                runtime,
                h_range,
                z_obs,
                "score-warp batched third observed",
                |_, idx, basis_span| {
                    g_u_fixed[idx] =
                        scale_coeff4(exact::score_basis_cell_coefficients(basis_span, b), scale);
                    g_bu_fixed[idx] =
                        scale_coeff4(exact::score_basis_cell_coefficients(basis_span, 1.0), scale);
                    Ok(())
                },
            )?;
        }
        if let (Some(w_range), Some(runtime)) = (w_range, link_runtime) {
            Self::for_each_deviation_basis_cubic_at(
                runtime,
                w_range,
                u_obs,
                "link-wiggle batched third observed",
                |_, idx, basis_span| {
                    g_u_fixed[idx] =
                        scale_coeff4(exact::link_basis_cell_coefficients(basis_span, a, b), scale);
                    let (dc_aw_raw, dc_bw_raw) =
                        exact::link_basis_cell_coefficient_partials(basis_span, a, b);
                    let (dc_aaw_raw, dc_abw_raw, dc_bbw_raw) =
                        exact::link_basis_cell_second_partials(basis_span, a, b);
                    g_au_fixed[idx] = scale_coeff4(dc_aw_raw, scale);
                    g_bu_fixed[idx] = scale_coeff4(dc_bw_raw, scale);
                    g_aau_fixed[idx] = scale_coeff4(dc_aaw_raw, scale);
                    g_abu_fixed[idx] = scale_coeff4(dc_abw_raw, scale);
                    g_bbu_fixed[idx] = scale_coeff4(dc_bbw_raw, scale);
                    Ok(())
                },
            )?;
        }
        let g_jet = SparsePrimaryCoeffJetView::new(
            1,
            h_range,
            w_range,
            &g_u_fixed,
            &g_au_fixed,
            &g_bu_fixed,
            &g_aau_fixed,
            &g_abu_fixed,
            &g_bbu_fixed,
            &zero_family,
            &zero_family,
            &zero_family,
            &zero_family,
        );
        // Point evaluations of the observed coefficient families and partials.
        let g_a = eval_coeff4_at(&obs.dc_da, z_obs);
        let g_aa = eval_coeff4_at(&obs.dc_daa, z_obs);
        let g_aaa = eval_coeff4_at(&obs.dc_daaa, z_obs);
        let mut g_u = Array1::<f64>::zeros(r);
        let mut g_au = Array1::<f64>::zeros(r);
        let mut g_aau = Array1::<f64>::zeros(r);
        let mut g_uv = Array2::<f64>::zeros((r, r));
        let mut g_auv = Array2::<f64>::zeros((r, r));
        for u in 1..r {
            g_u[u] = eval_coeff4_at(&g_jet.first[u], z_obs);
            g_au[u] = eval_coeff4_at(&g_jet.a_first[u], z_obs);
            g_aau[u] = eval_coeff4_at(&g_jet.aa_first[u], z_obs);
        }
        for u in 1..r {
            for v in u..r {
                let second_coeff = g_jet.pair_from_b_family(g_jet.b_first, u, v, COEFF_SUPPORT_BHW);
                let val = eval_coeff4_at(&second_coeff, z_obs);
                g_uv[[u, v]] = val;
                g_uv[[v, u]] = val;
                let third_coeff = g_jet.pair_from_b_family(g_jet.ab_first, u, v, COEFF_SUPPORT_BW);
                let third_val = eval_coeff4_at(&third_coeff, z_obs);
                g_auv[[u, v]] = third_val;
                g_auv[[v, u]] = third_val;
            }
        }
        // Chain rule: eta derivatives combine explicit g-terms with the
        // implicit intercept derivatives a_u / a_uv.
        let eta_u = g_a * &a_u + &g_u;
        let mut eta_uv = Array2::<f64>::zeros((r, r));
        for u in 0..r {
            for v in u..r {
                let val = g_a * a_uv[[u, v]]
                    + g_aa * a_u[u] * a_u[v]
                    + g_au[u] * a_u[v]
                    + g_au[v] * a_u[u]
                    + g_uv[[u, v]];
                eta_uv[[u, v]] = val;
                eta_uv[[v, u]] = val;
            }
        }
        // Signed-probit negative-log-likelihood derivatives at m = s_y * eta,
        // where s_y maps y in {0, 1} to {-1, +1}.
        let y_i = self.y[row];
        let w_i = self.weights[row];
        let s_y = 2.0 * y_i - 1.0;
        let m = s_y * eta_val;
        let (k1, k2, k3, _) = signed_probit_neglog_derivatives_up_to_fourth(m, w_i)?;
        let u1 = s_y * k1;
        let u3 = s_y * k3;
        let mut out = (0..n_dirs)
            .map(|_| Array2::<f64>::zeros((r, r)))
            .collect::<Vec<_>>();
        // Final assembly per direction: third directional derivative of the
        // likelihood via the chain rule over eta (Faà di Bruno style).
        for (dir_idx, dir) in row_dirs.iter().enumerate() {
            let dir_base = dir_idx * r * r;
            // The q-coordinate directional piece comes from the mu3 moment.
            f_uv_dir[dir_base] = -dir[0] * marginal.mu3;
            let a_dir = a_u.dot(dir);
            let a_u_dir = a_uv.dot(dir);
            let g_dir_fixed = g_jet.directional_family(g_jet.first, dir, COEFF_SUPPORT_BHW);
            let g_a_dir_fixed = g_jet.directional_family(g_jet.a_first, dir, COEFF_SUPPORT_BW);
            let g_aa_dir_fixed = g_jet.directional_family(g_jet.aa_first, dir, COEFF_SUPPORT_BW);
            let g_dir = eval_coeff4_at(&g_dir_fixed, z_obs);
            let g_a_dir = eval_coeff4_at(&g_a_dir_fixed, z_obs);
            let g_aa_dir = eval_coeff4_at(&g_aa_dir_fixed, z_obs);
            let eta_dir = g_a * a_dir + g_dir;
            let eta_u_dir = eta_uv.dot(dir);
            let dg_a_dir = g_aa * a_dir + g_a_dir;
            let dg_aa_dir = g_aaa * a_dir + g_aa_dir;
            let mut dg_au_dir = Array1::<f64>::zeros(r);
            for u in 0..r {
                let coeff =
                    g_jet.param_directional_from_b_family(g_jet.ab_first, u, dir, COEFF_SUPPORT_BW);
                dg_au_dir[u] = g_aau[u] * a_dir + eval_coeff4_at(&coeff, z_obs);
            }
            for u in 0..r {
                for v in u..r {
                    let fuvd = f_uv_dir[dir_base + u * r + v];
                    // Directional derivative of a_uv via the implicit relation.
                    let n_dir = fuvd
                        + f_au_dir[dir_idx * r + u] * a_u[v]
                        + f_au[u] * a_u_dir[v]
                        + f_au_dir[dir_idx * r + v] * a_u[u]
                        + f_au[v] * a_u_dir[u]
                        + f_aa_dir[dir_idx] * a_u[u] * a_u[v]
                        + f_aa * (a_u_dir[u] * a_u[v] + a_u[u] * a_u_dir[v]);
                    let a_uv_dir = -(n_dir + f_a_dir[dir_idx] * a_uv[[u, v]]) * inv_f_a;
                    let third_coeff = g_jet.pair_directional_from_bb_family(
                        g_jet.bb_first,
                        u,
                        v,
                        dir,
                        COEFF_SUPPORT_BW,
                    );
                    let dg_uv_dir = g_auv[[u, v]] * a_dir + eval_coeff4_at(&third_coeff, z_obs);
                    // Directional derivative of eta_uv.
                    let eta_uv_dir = dg_a_dir * a_uv[[u, v]]
                        + g_a * a_uv_dir
                        + dg_aa_dir * a_u[u] * a_u[v]
                        + g_aa * (a_u_dir[u] * a_u[v] + a_u[u] * a_u_dir[v])
                        + dg_au_dir[u] * a_u[v]
                        + g_au[u] * a_u_dir[v]
                        + dg_au_dir[v] * a_u[u]
                        + g_au[v] * a_u_dir[u]
                        + dg_uv_dir;
                    let val = u3 * eta_u[u] * eta_u[v] * eta_dir
                        + k2 * (eta_uv[[u, v]] * eta_dir
                            + eta_u[u] * eta_u_dir[v]
                            + eta_u[v] * eta_u_dir[u])
                        + u1 * eta_uv_dir;
                    out[dir_idx][[u, v]] = val;
                    out[dir_idx][[v, u]] = val;
                }
            }
        }
        Ok(out)
    }
    /// Accumulates weighted third-order trace corrections over a selected set
    /// of rows, one total per direction column of `directions`.
    ///
    /// `factor` and `directions` are in full block coordinates (`slices.total`
    /// rows). Each weighted row contributes
    /// `weight * trace(row third tensor, Gram(factor), direction)`; work is
    /// parallelised over `weighted_rows` with a per-thread partial sum.
    fn batched_rho_correction_logdet_traces_for_rows(
        &self,
        block_states: &[ParameterBlockState],
        cache: &BernoulliMarginalSlopeExactEvalCache,
        factor: &Array2<f64>,
        directions: &Array2<f64>,
        weighted_rows: &[WeightedOuterRow],
    ) -> Result<Array1<f64>, String> {
        let slices = &cache.slices;
        let primary = &cache.primary;
        let rank = factor.ncols();
        let n_dirs = directions.ncols();
        // Both inputs must span the full block coordinate space.
        if factor.nrows() != slices.total || directions.nrows() != slices.total {
            return Err(format!(
                "bernoulli marginal-slope batched trace shape mismatch: factor={}x{}, directions={}x{}, p={}",
                factor.nrows(),
                factor.ncols(),
                directions.nrows(),
                directions.ncols(),
                slices.total
            ));
        }
        // Parallel fold: each rayon task carries one partial-sum vector
        // (length n_dirs); partials are combined elementwise in try_reduce.
        let traces = weighted_rows
            .par_iter()
            .try_fold(
                || vec![0.0; n_dirs],
                |mut acc, wr| -> Result<_, String> {
                    let row = wr.index;
                    let row_ctx = Self::row_ctx(cache, row);
                    // Project the factor into this row's primary coordinates,
                    // then form its Gram matrix for the trace contraction.
                    let mut projection = vec![0.0; primary.total * rank];
                    self.row_factor_primary_projection(
                        row,
                        slices,
                        primary,
                        factor,
                        &mut projection,
                    )?;
                    let gram =
                        Self::row_primary_gram_from_projection(primary.total, rank, &projection);
                    // Single-direction fast path: contract directly along the
                    // projected row direction.
                    if n_dirs == 1 {
                        let d_beta = directions.column(0).to_owned();
                        let row_dir =
                            self.row_primary_direction_from_flat(row, slices, primary, &d_beta)?;
                        let row_traces = self.row_primary_third_trace_many_with_moments(
                            row,
                            block_states,
                            cache,
                            row_ctx,
                            &[row_dir],
                            &gram,
                        )?;
                        acc[0] += wr.weight * row_traces[0];
                        return Ok(acc);
                    }
                    // Multi-direction path: compute the trace gradient once per
                    // row, then dot it with each projected direction.
                    let trace_gradient = self.row_primary_third_trace_gradient_with_moments(
                        row,
                        block_states,
                        cache,
                        row_ctx,
                        &gram,
                    )?;
                    for dir_idx in 0..n_dirs {
                        let mut trace = trace_gradient[primary.q]
                            * self.marginal_design.dot_row_view(
                                row,
                                directions.slice(s![slices.marginal.clone(), dir_idx]),
                            );
                        trace += trace_gradient[primary.logslope]
                            * self.logslope_design.dot_row_view(
                                row,
                                directions.slice(s![slices.logslope.clone(), dir_idx]),
                            );
                        // h/w tail coordinates map one-to-one from block to
                        // primary indices, so no design product is needed.
                        if let (Some(block_range), Some(primary_range)) =
                            (slices.h.as_ref(), primary.h.as_ref())
                        {
                            for (offset, block_idx) in block_range.clone().enumerate() {
                                trace += trace_gradient[primary_range.start + offset]
                                    * directions[[block_idx, dir_idx]];
                            }
                        }
                        if let (Some(block_range), Some(primary_range)) =
                            (slices.w.as_ref(), primary.w.as_ref())
                        {
                            for (offset, block_idx) in block_range.clone().enumerate() {
                                trace += trace_gradient[primary_range.start + offset]
                                    * directions[[block_idx, dir_idx]];
                            }
                        }
                        acc[dir_idx] += wr.weight * trace;
                    }
                    Ok(acc)
                },
            )
            .try_reduce(
                || vec![0.0; n_dirs],
                |mut left, right| -> Result<_, String> {
                    for (l, r) in left.iter_mut().zip(right.iter()) {
                        *l += *r;
                    }
                    Ok(left)
                },
            )?;
        Ok(Array1::from_vec(traces))
    }
    /// Full-data variant of the batched trace correction: processes every row
    /// (unit weight) in contiguous chunks so the design-row projections can be
    /// computed as dense matrix products per chunk rather than per row.
    ///
    /// Returns one accumulated trace per direction column of `directions`.
    fn batched_rho_correction_logdet_traces_full_rows(
        &self,
        block_states: &[ParameterBlockState],
        cache: &BernoulliMarginalSlopeExactEvalCache,
        factor: &Array2<f64>,
        directions: &Array2<f64>,
    ) -> Result<Array1<f64>, String> {
        let slices = &cache.slices;
        let primary = &cache.primary;
        let n = self.y.len();
        let rank = factor.ncols();
        let n_dirs = directions.ncols();
        if n == 0 || rank == 0 || n_dirs == 0 {
            return Ok(Array1::zeros(n_dirs));
        }
        // Size chunks so the four per-chunk projection panels (rank + n_dirs
        // f64 columns each) stay near an 8 MiB working set, with at least 512
        // rows per chunk, capped at n.
        let rows_per_chunk = {
            const TARGET_BYTES: usize = 8 * 1024 * 1024;
            let panels = 4usize;
            let width = rank + n_dirs;
            (TARGET_BYTES / (panels * width.max(1) * 8)).max(512).min(n)
        };
        let factor_m = factor.slice(s![slices.marginal.clone(), ..]);
        let factor_g = factor.slice(s![slices.logslope.clone(), ..]);
        let dir_m = directions.slice(s![slices.marginal.clone(), ..]);
        let dir_g = directions.slice(s![slices.logslope.clone(), ..]);
        // The tail-tail Gram block is row-independent, so compute it once and
        // copy it into each per-row Gram buffer.
        let tail_pairs = Self::primary_tail_block_pairs(slices, primary);
        let tail_tail_gram = Self::primary_tail_tail_gram(primary.total, rank, factor, &tail_pairs);
        let n_chunks = n.div_ceil(rows_per_chunk);
        let traces = (0..n_chunks)
            .into_par_iter()
            .map(|chunk_idx| -> Result<Vec<f64>, String> {
                let start = chunk_idx * rows_per_chunk;
                let end = (start + rows_per_chunk).min(n);
                let rows = start..end;
                // Materialise the chunk's design rows and project them against
                // the factor and direction slices in one dense product each.
                let x_chunk = self
                    .marginal_design
                    .try_row_chunk(rows.clone())
                    .map_err(|err| format!("marginal trace row chunk failed: {err}"))?;
                let g_chunk = self
                    .logslope_design
                    .try_row_chunk(rows.clone())
                    .map_err(|err| format!("logslope trace row chunk failed: {err}"))?;
                let proj_m = crate::faer_ndarray::fast_ab(&x_chunk, &factor_m);
                let proj_g = crate::faer_ndarray::fast_ab(&g_chunk, &factor_g);
                let dir_proj_m = crate::faer_ndarray::fast_ab(&x_chunk, &dir_m);
                let dir_proj_g = crate::faer_ndarray::fast_ab(&g_chunk, &dir_g);
                let mut acc = vec![0.0; n_dirs];
                // Reused per-row scratch: Gram buffer seeded from the shared
                // tail-tail block, plus a primary-space direction vector.
                let mut gram = vec![0.0; primary.total * primary.total];
                let mut row_dir = Array1::<f64>::zeros(primary.total);
                for local in 0..(end - start) {
                    let row = start + local;
                    gram.copy_from_slice(&tail_tail_gram);
                    // Fill the q/logslope 2x2 head of the Gram matrix from the
                    // projected factor rows.
                    let mut qq = 0.0;
                    let mut qg = 0.0;
                    let mut gg = 0.0;
                    for col in 0..rank {
                        let qv = proj_m[[local, col]];
                        let gv = proj_g[[local, col]];
                        qq += qv * qv;
                        qg += qv * gv;
                        gg += gv * gv;
                    }
                    gram[primary.q * primary.total + primary.q] = qq;
                    gram[primary.q * primary.total + primary.logslope] = qg;
                    gram[primary.logslope * primary.total + primary.q] = qg;
                    gram[primary.logslope * primary.total + primary.logslope] = gg;
                    // Cross terms between the head coordinates and each tail
                    // coordinate, mirrored symmetrically.
                    for &(primary_idx, block_idx) in &tail_pairs {
                        let mut q_tail = 0.0;
                        let mut g_tail = 0.0;
                        for col in 0..rank {
                            let tail = factor[[block_idx, col]];
                            q_tail += proj_m[[local, col]] * tail;
                            g_tail += proj_g[[local, col]] * tail;
                        }
                        gram[primary.q * primary.total + primary_idx] = q_tail;
                        gram[primary_idx * primary.total + primary.q] = q_tail;
                        gram[primary.logslope * primary.total + primary_idx] = g_tail;
                        gram[primary_idx * primary.total + primary.logslope] = g_tail;
                    }
                    let row_ctx = Self::row_ctx(cache, row);
                    // Single-direction fast path: build the projected row
                    // direction in place and contract directly.
                    if n_dirs == 1 {
                        row_dir.fill(0.0);
                        row_dir[primary.q] = dir_proj_m[[local, 0]];
                        row_dir[primary.logslope] = dir_proj_g[[local, 0]];
                        if let (Some(block_range), Some(primary_range)) =
                            (slices.h.as_ref(), primary.h.as_ref())
                        {
                            for (offset, block_idx) in block_range.clone().enumerate() {
                                row_dir[primary_range.start + offset] = directions[[block_idx, 0]];
                            }
                        }
                        if let (Some(block_range), Some(primary_range)) =
                            (slices.w.as_ref(), primary.w.as_ref())
                        {
                            for (offset, block_idx) in block_range.clone().enumerate() {
                                row_dir[primary_range.start + offset] = directions[[block_idx, 0]];
                            }
                        }
                        let row_traces = self.row_primary_third_trace_many_with_moments(
                            row,
                            block_states,
                            cache,
                            row_ctx,
                            std::slice::from_ref(&row_dir),
                            &gram,
                        )?;
                        acc[0] += row_traces[0];
                        continue;
                    }
                    // Multi-direction path: one trace gradient per row, dotted
                    // with each projected direction.
                    let trace_gradient = self.row_primary_third_trace_gradient_with_moments(
                        row,
                        block_states,
                        cache,
                        row_ctx,
                        &gram,
                    )?;
                    for dir_idx in 0..n_dirs {
                        let mut trace = trace_gradient[primary.q] * dir_proj_m[[local, dir_idx]]
                            + trace_gradient[primary.logslope] * dir_proj_g[[local, dir_idx]];
                        if let (Some(block_range), Some(primary_range)) =
                            (slices.h.as_ref(), primary.h.as_ref())
                        {
                            for (offset, block_idx) in block_range.clone().enumerate() {
                                trace += trace_gradient[primary_range.start + offset]
                                    * directions[[block_idx, dir_idx]];
                            }
                        }
                        if let (Some(block_range), Some(primary_range)) =
                            (slices.w.as_ref(), primary.w.as_ref())
                        {
                            for (offset, block_idx) in block_range.clone().enumerate() {
                                trace += trace_gradient[primary_range.start + offset]
                                    * directions[[block_idx, dir_idx]];
                            }
                        }
                        acc[dir_idx] += trace;
                    }
                }
                Ok(acc)
            })
            .try_reduce(
                || vec![0.0; n_dirs],
                |mut left, right| -> Result<_, String> {
                    for (l, r) in left.iter_mut().zip(right.iter()) {
                        *l += *r;
                    }
                    Ok(left)
                },
            )?;
        Ok(Array1::from_vec(traces))
    }
}
impl BernoulliMarginalSlopeExactNewtonJointPsiWorkspace {
    /// Builds the workspace, precomputing the shared evaluation cache once so
    /// that all later psi-derivative queries can reuse it.
    fn new(
        family: BernoulliMarginalSlopeFamily,
        block_states: Vec<ParameterBlockState>,
        specs: Vec<ParameterBlockSpec>,
        derivative_blocks: Vec<Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>>,
        options: BlockwiseFitOptions,
    ) -> Result<Self, String> {
        Ok(Self {
            // The cache borrows `family` transiently; the borrow ends before
            // `family` is moved into the struct below.
            cache: family.build_shared_eval_cache_with_options(&block_states, &options)?,
            family,
            block_states,
            specs,
            derivative_blocks,
            options,
        })
    }
}
// Trait plumbing for exact-Newton joint psi derivatives: every method first
// checks whether the psi index is the auxiliary log-sigma coordinate and, if
// so, routes to the sigma-specific evaluators; otherwise it uses the shared
// evaluation cache built at workspace construction time.
impl ExactNewtonJointPsiWorkspace for BernoulliMarginalSlopeExactNewtonJointPsiWorkspace {
    /// First-order psi terms for a single index, dispatching between the
    /// sigma-auxiliary path and the cached per-block path.
    fn first_order_terms(
        &self,
        psi_index: usize,
    ) -> Result<Option<ExactNewtonJointPsiTerms>, String> {
        if self
            .family
            .is_sigma_aux_index(&self.derivative_blocks, psi_index)
        {
            return self.family.sigma_exact_joint_psi_terms_with_options(
                &self.block_states,
                &self.specs,
                &self.options,
            );
        }
        self.family
            .exact_newton_joint_psi_terms_from_cache_with_options(
                &self.block_states,
                &self.derivative_blocks,
                psi_index,
                &self.cache,
                &self.options,
            )
    }
    /// Batched variant over all psi indices. Returns `None` when any index is
    /// the sigma-auxiliary one or cannot be located, so the caller falls back
    /// to per-index evaluation via `first_order_terms`.
    fn first_order_terms_all(&self) -> Result<Option<Vec<ExactNewtonJointPsiTerms>>, String> {
        let total: usize = self.derivative_blocks.iter().map(Vec::len).sum();
        if total == 0 {
            return Ok(Some(Vec::new()));
        }
        // The single row pass below cannot handle the sigma coordinate.
        for psi_index in 0..total {
            if self
                .family
                .is_sigma_aux_index(&self.derivative_blocks, psi_index)
            {
                return Ok(None);
            }
        }
        // Resolve every psi index to an axis spec up front; bail to the
        // per-index fallback if any index cannot be located.
        let mut axes: Vec<PsiAxisSpec> = Vec::with_capacity(total);
        for psi_index in 0..total {
            let Some((block_idx, local_idx)) =
                psi_derivative_location(&self.derivative_blocks, psi_index)
            else {
                return Ok(None);
            };
            axes.push(self.family.resolve_psi_axis_spec(
                &self.derivative_blocks,
                block_idx,
                local_idx,
            )?);
        }
        // One shared row pass evaluates all axes together.
        let results = self.family.run_psi_row_pass_for_axes(
            &self.block_states,
            &self.cache,
            &self.options,
            &axes,
        )?;
        Ok(Some(results))
    }
    /// Second-order psi terms. Pure sigma/sigma pairs use the dedicated sigma
    /// evaluator; mixed sigma/spatial pairs are unsupported and rejected.
    fn second_order_terms(
        &self,
        psi_i: usize,
        psi_j: usize,
    ) -> Result<Option<ExactNewtonJointPsiSecondOrderTerms>, String> {
        if self
            .family
            .is_sigma_aux_index(&self.derivative_blocks, psi_i)
            || self
                .family
                .is_sigma_aux_index(&self.derivative_blocks, psi_j)
        {
            if self
                .family
                .is_sigma_aux_index(&self.derivative_blocks, psi_i)
                && self
                    .family
                    .is_sigma_aux_index(&self.derivative_blocks, psi_j)
            {
                return self
                    .family
                    .sigma_exact_joint_psisecond_order_terms_with_options(
                        &self.block_states,
                        &self.options,
                    );
            }
            return Err(
                "bernoulli marginal-slope mixed log-sigma/spatial psi second derivatives require cross auxiliary terms; only pure log-sigma second derivatives are supported"
                    .to_string(),
            );
        }
        self.family
            .exact_newton_joint_psisecond_order_terms_from_cache_with_options(
                &self.block_states,
                &self.derivative_blocks,
                psi_i,
                psi_j,
                &self.cache,
                &self.options,
            )
    }
    /// Hessian directional derivative along `d_beta_flat`. The sigma path
    /// yields a dense result; the cached path yields an operator — each is
    /// wrapped in the corresponding `DriftDerivResult` variant.
    fn hessian_directional_derivative(
        &self,
        psi_index: usize,
        d_beta_flat: &Array1<f64>,
    ) -> Result<Option<crate::solver::estimate::reml::unified::DriftDerivResult>, String> {
        if self
            .family
            .is_sigma_aux_index(&self.derivative_blocks, psi_index)
        {
            return self
                .family
                .sigma_exact_joint_psihessian_directional_derivative_with_options(
                    &self.block_states,
                    d_beta_flat,
                    &self.options,
                )
                .map(|result| {
                    result.map(crate::solver::estimate::reml::unified::DriftDerivResult::Dense)
                });
        }
        self.family
            .exact_newton_joint_psihessian_directional_derivative_operator_from_cache_with_options(
                &self.block_states,
                &self.derivative_blocks,
                psi_index,
                d_beta_flat,
                &self.cache,
                &self.options,
            )
            .map(|result| {
                result.map(crate::solver::estimate::reml::unified::DriftDerivResult::Operator)
            })
    }
}
/// Assembles a `ParameterBlockSpec` for one term collection, shifting the
/// per-row offset by the scalar `baseline` so the block's linear predictor is
/// centred before fitting.
fn build_blockspec(
    name: &str,
    design: &TermCollectionDesign,
    baseline: f64,
    offset: &Array1<f64>,
    rho: Array1<f64>,
    beta_hint: Option<Array1<f64>>,
) -> ParameterBlockSpec {
    let shifted_offset = offset + baseline;
    ParameterBlockSpec {
        name: name.to_string(),
        design: design.design.clone(),
        offset: shifted_offset,
        penalties: design.penalties_as_penalty_matrix(),
        nullspace_dims: design.nullspace_dims.clone(),
        initial_log_lambdas: rho,
        initial_beta: beta_hint,
    }
}
/// Turns a prepared deviation block into a `ParameterBlockSpec` with the given
/// smoothing parameters, projecting the starting coefficients onto the block's
/// monotone-feasible set (a zero hint is used when none is supplied).
pub(crate) fn build_deviation_aux_blockspec(
    name: &str,
    prepared: &DeviationPrepared,
    rho: Array1<f64>,
    beta_hint: Option<Array1<f64>>,
) -> Result<ParameterBlockSpec, String> {
    let mut block = prepared.block.clone();
    block.initial_log_lambdas = Some(rho);
    // Fall back to a zero coefficient vector when no hint is given, then
    // project whichever vector we have into the feasible region.
    let beta = beta_hint.unwrap_or_else(|| Array1::<f64>::zeros(block.design.ncols()));
    let reference = Array1::<f64>::zeros(beta.len());
    block.initial_beta = Some(project_monotone_feasible_beta(
        &prepared.runtime,
        &reference,
        &beta,
        name,
    )?);
    block.intospec(name)
}
/// Appends the optional score-warp and link-deviation auxiliary block specs,
/// consuming their smoothing parameters from `rho` starting at `*cursor`.
///
/// `cursor` is advanced past each consumed slice so the caller can keep
/// reading subsequent entries of `rho` after this returns.
///
/// Fix: the link-deviation branch previously sliced `rho` at `*cursor` but
/// never advanced the cursor (unlike the symmetric score-warp branch), leaving
/// `*cursor` stale for any later consumer.
pub(crate) fn push_deviation_aux_blockspecs(
    blocks: &mut Vec<ParameterBlockSpec>,
    rho: &Array1<f64>,
    cursor: &mut usize,
    score_warp_prepared: Option<&DeviationPrepared>,
    link_dev_prepared: Option<&DeviationPrepared>,
    score_warp_beta_hint: Option<Array1<f64>>,
    link_dev_beta_hint: Option<Array1<f64>>,
) -> Result<(), String> {
    if let Some(prepared) = score_warp_prepared {
        let rho_h = rho
            .slice(s![*cursor..*cursor + prepared.block.penalties.len()])
            .to_owned();
        *cursor += prepared.block.penalties.len();
        blocks.push(build_deviation_aux_blockspec(
            "score_warp_dev",
            prepared,
            rho_h,
            score_warp_beta_hint,
        )?);
    }
    if let Some(prepared) = link_dev_prepared {
        let rho_w = rho
            .slice(s![*cursor..*cursor + prepared.block.penalties.len()])
            .to_owned();
        // Keep the cursor in sync with the consumed rho entries (this was
        // missing, so the cursor stayed stale after the link-dev block).
        *cursor += prepared.block.penalties.len();
        blocks.push(build_deviation_aux_blockspec(
            "link_dev",
            prepared,
            rho_w,
            link_dev_beta_hint,
        )?);
    }
    Ok(())
}
/// Thin adapter around the generic blockwise fitter that stringifies its
/// error type for this module's `Result<_, String>` convention.
fn inner_fit(
    family: &BernoulliMarginalSlopeFamily,
    blocks: &[ParameterBlockSpec],
    options: &BlockwiseFitOptions,
) -> Result<UnifiedFitResult, String> {
    match fit_custom_family(family, blocks, options) {
        Ok(fit) => Ok(fit),
        Err(err) => Err(err.to_string()),
    }
}
pub fn fit_bernoulli_marginal_slope_terms(
data: ArrayView2<'_, f64>,
spec: BernoulliMarginalSlopeTermSpec,
options: &BlockwiseFitOptions,
kappa_options: &SpatialLengthScaleOptimizationOptions,
policy: &crate::resource::ResourcePolicy,
) -> Result<BernoulliMarginalSlopeFitResult, String> {
let mut spec = spec;
let data_view = data;
validate_spec(data_view, &spec)?;
let mut effective_kappa_options = kappa_options.clone();
let flex_spatial_scale_path = (spec.score_warp.is_some() || spec.link_dev.is_some())
&& effective_kappa_options.pilot_subsample_threshold > 0
&& spec.y.len()
>= effective_kappa_options
.pilot_subsample_threshold
.saturating_mul(2)
&& effective_kappa_options.enabled;
if flex_spatial_scale_path {
let marginal_terms = spatial_length_scale_term_indices(&spec.marginalspec);
let logslope_terms = spatial_length_scale_term_indices(&spec.logslopespec);
let marginal_updates = apply_spatial_anisotropy_pilot_initializer(
data_view,
&mut spec.marginalspec,
&marginal_terms,
effective_kappa_options.pilot_subsample_threshold,
&effective_kappa_options,
);
let logslope_updates = apply_spatial_anisotropy_pilot_initializer(
data_view,
&mut spec.logslopespec,
&logslope_terms,
effective_kappa_options.pilot_subsample_threshold,
&effective_kappa_options,
);
effective_kappa_options.enabled = false;
log::info!(
"[BMS spatial] n={} flex=true pilot_geometry_updates={} iterative_spatial_outer=false",
spec.y.len(),
marginal_updates + logslope_updates,
);
}
let (z_standardized, z_normalization) = standardize_latent_z_with_policy(
&spec.z,
&spec.weights,
"bernoulli-marginal-slope",
&spec.latent_z_policy,
)?;
spec.z = z_standardized;
let pilot_baseline = pooled_probit_baseline(&spec.y, &spec.z, &spec.weights)?;
let sigma_learnable = matches!(
&spec.frailty,
FrailtySpec::GaussianShift { sigma_fixed: None }
);
let initial_sigma = match &spec.frailty {
FrailtySpec::GaussianShift {
sigma_fixed: Some(s),
} => Some(*s),
FrailtySpec::GaussianShift { sigma_fixed: None } => Some(0.5),
FrailtySpec::None => None,
FrailtySpec::HazardMultiplier { .. } => {
unreachable!("validate_spec rejects unsupported marginal-slope frailty")
}
};
let probit_scale = probit_frailty_scale(initial_sigma);
let baseline = (
bernoulli_marginal_slope_eta_from_probability(
&spec.base_link,
normal_cdf(pilot_baseline.0),
"bernoulli marginal-slope baseline link inversion",
)?,
pilot_baseline.1 / probit_scale,
);
let (mut joint_designs, mut joint_specs) = build_term_collection_designs_and_freeze_joint(
data_view,
&[spec.marginalspec.clone(), spec.logslopespec.clone()],
)
.map_err(|e| e.to_string())?;
let marginal_design = joint_designs.remove(0);
let logslope_design = joint_designs.remove(0);
let marginalspec_boot = joint_specs.remove(0);
let logslopespec_boot = joint_specs.remove(0);
let (latent_measure, latent_z_calibration) = build_latent_measure_with_geometry(
&spec.z,
&spec.weights,
&spec.latent_z_policy,
Some(data_view),
&[&marginalspec_boot, &logslopespec_boot],
)?;
if latent_measure.is_empirical() && sigma_learnable {
return Err(
"empirical latent-measure marginal-slope calibration requires fixed GaussianShift sigma; learnable sigma derivatives must be fit under the standard-normal latent measure"
.to_string(),
);
}
let y = Arc::new(spec.y.clone());
let weights = Arc::new(spec.weights.clone());
let z = match &latent_z_calibration {
LatentMeasureCalibration::None => Arc::new(spec.z.clone()),
LatentMeasureCalibration::RankInverseNormal(cal) => {
Arc::new(cal.apply_to_training(&spec.z)?)
}
};
let mut cross_block_warnings: Vec<CrossBlockIdentifiabilityWarning> = Vec::new();
let score_warp_prepared = if let Some(cfg) = spec.score_warp.as_ref() {
use crate::families::bernoulli_marginal_slope::deviation_runtime::ParametricAnchorBlock;
let mut prepared = build_score_warp_deviation_block_from_seed(&spec.z, cfg)?;
let outcome = enforce_cross_block_identifiability_for_flex_block(
&mut prepared,
&spec.z,
cfg,
&[
CrossBlockAnchor::Parametric(&marginal_design.design),
CrossBlockAnchor::Parametric(&logslope_design.design),
],
&[
Some(ParametricAnchorBlock::Marginal),
Some(ParametricAnchorBlock::Logslope),
],
&spec.weights,
)?;
match outcome {
CrossBlockIdentifiabilityOutcome::Reparameterised { .. } => Some(prepared),
CrossBlockIdentifiabilityOutcome::FullyAliased { reason } => {
log::warn!(
"[BMS cross-block identifiability] score-warp block fully aliased \
by marginal+logslope anchors; dropping the block. {reason}"
);
cross_block_warnings.push(CrossBlockIdentifiabilityWarning {
candidate_label: "score_warp",
anchor_summary: "marginal+logslope".to_string(),
reason,
});
None
}
}
} else {
None
};
let link_dev_prepared = if let Some(cfg) = spec.link_dev.as_ref() {
let eta_pilot = pilot_eta_for_link_dev_orthogonalisation(
&spec.base_link,
&spec.y,
&spec.z,
&spec.weights,
&marginal_design.design,
&spec.marginal_offset,
&spec.logslope_offset,
baseline.0,
baseline.1,
probit_scale,
)?;
let link_dev_seed = padded_deviation_seed(&eta_pilot, 1.0, 0.5);
let mut prepared = build_link_deviation_block_from_knots_design_seed_and_weights(
&link_dev_seed,
&eta_pilot,
&spec.weights,
cfg,
)?;
let score_warp_anchor_design = score_warp_prepared
.as_ref()
.map(|sw| sw.runtime.design(&spec.z))
.transpose()?;
use crate::families::bernoulli_marginal_slope::deviation_runtime::ParametricAnchorBlock;
let mut anchors = vec![
CrossBlockAnchor::Parametric(&marginal_design.design),
CrossBlockAnchor::Parametric(&logslope_design.design),
];
let mut anchor_tags: Vec<Option<ParametricAnchorBlock>> = vec![
Some(ParametricAnchorBlock::Marginal),
Some(ParametricAnchorBlock::Logslope),
];
if let Some(ref a) = score_warp_anchor_design {
anchors.push(CrossBlockAnchor::FlexEvaluation(a));
anchor_tags.push(None);
}
let _ = link_dev_seed; let outcome = enforce_cross_block_identifiability_for_flex_block(
&mut prepared,
&eta_pilot,
cfg,
&anchors,
&anchor_tags,
&spec.weights,
)?;
match outcome {
CrossBlockIdentifiabilityOutcome::Reparameterised { .. } => Some(prepared),
CrossBlockIdentifiabilityOutcome::FullyAliased { reason } => {
log::warn!(
"[BMS cross-block identifiability] link-deviation block fully aliased \
by parametric + score-warp anchors; dropping the block. {reason}"
);
cross_block_warnings.push(CrossBlockIdentifiabilityWarning {
candidate_label: "link_deviation",
anchor_summary: "marginal+logslope+score_warp".to_string(),
reason,
});
None
}
}
} else {
None
};
let extra_rho0 = {
let mut out = Vec::new();
if let Some(ref prepared) = score_warp_prepared {
out.extend(std::iter::repeat(0.0).take(prepared.block.penalties.len()));
}
if let Some(ref prepared) = link_dev_prepared {
out.extend(std::iter::repeat(0.0).take(prepared.block.penalties.len()));
}
out
};
let setup = joint_setup(
data_view,
&marginalspec_boot,
&logslopespec_boot,
marginal_design.penalties.len(),
logslope_design.penalties.len(),
&extra_rho0,
&effective_kappa_options,
);
let setup = if sigma_learnable {
setup.with_auxiliary(
Array1::from_vec(vec![initial_sigma.expect("learnable sigma seed").ln()]),
Array1::from_vec(vec![0.01_f64.ln()]),
Array1::from_vec(vec![5.0_f64.ln()]),
)
} else {
setup
};
let final_sigma_cell = std::cell::Cell::new(initial_sigma);
let exact_warm_start = RefCell::new(None::<CustomFamilyWarmStart>);
let hints = RefCell::new(ThetaHints::default());
let score_warp_runtime = score_warp_prepared.as_ref().map(|p| p.runtime.clone());
let link_dev_runtime = link_dev_prepared.as_ref().map(|p| p.runtime.clone());
let build_blocks = |rho: &Array1<f64>,
marginal_design: &TermCollectionDesign,
logslope_design: &TermCollectionDesign|
-> Result<Vec<ParameterBlockSpec>, String> {
let hints = hints.borrow();
let mut cursor = 0usize;
let rho_marginal = rho
.slice(s![cursor..cursor + marginal_design.penalties.len()])
.to_owned();
cursor += marginal_design.penalties.len();
let rho_logslope = rho
.slice(s![cursor..cursor + logslope_design.penalties.len()])
.to_owned();
cursor += logslope_design.penalties.len();
let mut blocks = vec![
build_blockspec(
"marginal_surface",
marginal_design,
baseline.0,
&spec.marginal_offset,
rho_marginal,
hints.marginal_beta.clone(),
),
build_blockspec(
"logslope_surface",
logslope_design,
baseline.1,
&spec.logslope_offset,
rho_logslope,
hints.logslope_beta.clone(),
),
];
push_deviation_aux_blockspecs(
&mut blocks,
rho,
&mut cursor,
score_warp_prepared.as_ref(),
link_dev_prepared.as_ref(),
hints.score_warp_beta.clone(),
hints.link_dev_beta.clone(),
)?;
Ok(blocks)
};
let intercept_warm_starts = new_intercept_warm_start_cache(y.len());
let cell_moment_lru = new_cell_moment_lru_cache(policy);
let cell_moment_cache_stats = new_cell_moment_cache_stats();
let make_family = |marginal_design: &TermCollectionDesign,
logslope_design: &TermCollectionDesign,
sigma: Option<f64>|
-> BernoulliMarginalSlopeFamily {
BernoulliMarginalSlopeFamily {
y: Arc::clone(&y),
weights: Arc::clone(&weights),
z: Arc::clone(&z),
latent_measure: latent_measure.clone(),
gaussian_frailty_sd: sigma,
base_link: spec.base_link.clone(),
marginal_design: marginal_design.design.clone(),
logslope_design: logslope_design.design.clone(),
score_warp: score_warp_runtime.clone(),
link_dev: link_dev_runtime.clone(),
policy: policy.clone(),
cell_moment_lru: Arc::clone(&cell_moment_lru),
cell_moment_cache_stats: Arc::clone(&cell_moment_cache_stats),
intercept_warm_starts: Some(Arc::clone(&intercept_warm_starts)),
auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
auto_subsample_last_rho: Arc::new(Mutex::new(None)),
shared_eval_cache: Arc::new(Mutex::new(None)),
}
};
let marginal_terms = spatial_length_scale_term_indices(&marginalspec_boot);
let logslope_terms = spatial_length_scale_term_indices(&logslopespec_boot);
let marginal_has_spatial = !marginal_terms.is_empty();
let logslope_has_spatial = !logslope_terms.is_empty();
let analytic_joint_derivatives_available =
marginal_has_spatial || logslope_has_spatial || setup.log_kappa_dim() == 0;
if setup.log_kappa_dim() > 0 && !analytic_joint_derivatives_available {
return Err(
"exact bernoulli marginal-slope spatial optimization requires analytic joint psi derivatives"
.to_string(),
);
}
let initial_rho = setup.theta0().slice(s![..setup.rho_dim()]).to_owned();
let initial_blocks = build_blocks(&initial_rho, &marginal_design, &logslope_design)?;
let initial_family = make_family(&marginal_design, &logslope_design, initial_sigma);
let (joint_gradient, joint_hessian) =
custom_family_outer_derivatives(&initial_family, &initial_blocks, options);
let analytic_joint_gradient_available = analytic_joint_derivatives_available
&& matches!(
joint_gradient,
crate::solver::outer_strategy::Derivative::Analytic
);
let analytic_joint_hessian_available =
analytic_joint_derivatives_available && joint_hessian.is_analytic();
let kappa_options_ref: &SpatialLengthScaleOptimizationOptions = &effective_kappa_options;
let sigma_from_theta = |theta: &Array1<f64>| -> Option<f64> {
if sigma_learnable {
Some(theta[setup.rho_dim() + setup.log_kappa_dim()].exp())
} else {
initial_sigma
}
};
let derivative_block_cache = RefCell::new(
None::<(
Array1<f64>,
Arc<Vec<Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>>>,
)>,
);
let theta_matches = |left: &Array1<f64>, right: &Array1<f64>| -> bool {
left.len() == right.len()
&& left
.iter()
.zip(right.iter())
.all(|(lhs, rhs)| (*lhs - *rhs).abs() <= 1e-12 * (1.0 + lhs.abs().max(rhs.abs())))
};
let get_derivative_blocks = |theta: &Array1<f64>,
specs: &[TermCollectionSpec],
designs: &[TermCollectionDesign]|
-> Result<
Arc<Vec<Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>>>,
String,
> {
if let Some((cached_theta, cached_blocks)) = derivative_block_cache.borrow().as_ref()
&& theta_matches(cached_theta, theta)
{
return Ok(Arc::clone(cached_blocks));
}
let built = |specs: &[TermCollectionSpec],
designs: &[TermCollectionDesign]|
-> Result<
Vec<Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>>,
String,
> {
let marginal_psi_derivs = if marginal_has_spatial {
build_block_spatial_psi_derivatives(data_view, &specs[0], &designs[0])?.ok_or_else(
|| {
"bernoulli marginal-slope: marginal block has spatial terms \
but spatial psi derivatives are unavailable"
.to_string()
},
)?
} else {
Vec::new()
};
let logslope_psi_derivs = if logslope_has_spatial {
build_block_spatial_psi_derivatives(data_view, &specs[1], &designs[1])?.ok_or_else(
|| {
"bernoulli marginal-slope: logslope block has spatial terms \
but spatial psi derivatives are unavailable"
.to_string()
},
)?
} else {
Vec::new()
};
let mut derivative_blocks = vec![marginal_psi_derivs, logslope_psi_derivs];
if score_warp_runtime.is_some() {
derivative_blocks.push(Vec::new());
}
if link_dev_runtime.is_some() {
derivative_blocks.push(Vec::new());
}
if sigma_learnable {
derivative_blocks
.last_mut()
.expect("bernoulli derivative block list is non-empty")
.push(crate::custom_family::CustomFamilyBlockPsiDerivative::new(
None,
Array2::zeros((0, 0)),
Array2::zeros((0, 0)),
None,
None,
None,
None,
));
}
Ok(derivative_blocks)
}(specs, designs)?;
let built = Arc::new(built);
derivative_block_cache.replace(Some((theta.clone(), Arc::clone(&built))));
Ok(built)
};
let outer_policy = {
let psi_dim = setup.theta0().len() - setup.rho_dim();
initial_family.outer_derivative_policy(&initial_blocks, psi_dim, options)
};
let solved = optimize_spatial_length_scale_exact_joint(
data_view,
&[marginalspec_boot.clone(), logslopespec_boot.clone()],
&[marginal_terms.clone(), logslope_terms.clone()],
kappa_options_ref,
&setup,
crate::seeding::SeedRiskProfile::GeneralizedLinear,
analytic_joint_gradient_available,
analytic_joint_hessian_available,
true,
None,
outer_policy,
|theta, _: &[TermCollectionSpec], designs: &[TermCollectionDesign]| {
let rho = theta.slice(s![..setup.rho_dim()]).to_owned();
let blocks = build_blocks(&rho, &designs[0], &designs[1])?;
let sigma = sigma_from_theta(theta);
final_sigma_cell.set(sigma);
let family = make_family(&designs[0], &designs[1], sigma);
let fit = inner_fit(&family, &blocks, options)?;
let mut hints_mut = hints.borrow_mut();
let mut bidx = 0usize;
if let Some(block) = fit.block_states.get(bidx) {
hints_mut.marginal_beta = Some(block.beta.clone());
}
bidx += 1;
if let Some(block) = fit.block_states.get(bidx) {
hints_mut.logslope_beta = Some(block.beta.clone());
}
bidx += 1;
if score_warp_prepared.is_some() {
if let Some(block) = fit.block_states.get(bidx) {
hints_mut.score_warp_beta = Some(block.beta.clone());
}
bidx += 1;
}
if link_dev_prepared.is_some() {
if let Some(block) = fit.block_states.get(bidx) {
hints_mut.link_dev_beta = Some(block.beta.clone());
}
}
Ok(fit)
},
|theta,
specs: &[TermCollectionSpec],
designs: &[TermCollectionDesign],
eval_mode,
_row_set: &crate::families::row_kernel::RowSet| {
use crate::solver::estimate::reml::unified::EvalMode;
let rho = theta.slice(s![..setup.rho_dim()]).to_owned();
let blocks = build_blocks(&rho, &designs[0], &designs[1])?;
let sigma = sigma_from_theta(theta);
final_sigma_cell.set(sigma);
let family = make_family(&designs[0], &designs[1], sigma);
let derivative_blocks = get_derivative_blocks(theta, specs, designs)?;
let effective_mode = match eval_mode {
EvalMode::ValueGradientHessian if !analytic_joint_hessian_available => {
EvalMode::ValueAndGradient
}
other => other,
};
let eval = evaluate_custom_family_joint_hyper_shared(
&family,
&blocks,
options,
&rho,
derivative_blocks,
exact_warm_start.borrow().as_ref(),
effective_mode,
)?;
exact_warm_start.replace(Some(eval.warm_start.clone()));
if !eval.inner_converged {
return Err(
"exact bernoulli marginal-slope inner solve did not converge".to_string(),
);
}
if matches!(eval_mode, EvalMode::ValueGradientHessian)
&& analytic_joint_hessian_available
&& !eval.outer_hessian.is_analytic()
{
return Err(
"exact bernoulli marginal-slope joint [rho, psi] objective did not return an outer Hessian"
.to_string(),
);
}
Ok((eval.objective, eval.gradient, eval.outer_hessian))
},
|theta, specs: &[TermCollectionSpec], designs: &[TermCollectionDesign]| {
let rho = theta.slice(s![..setup.rho_dim()]).to_owned();
let blocks = build_blocks(&rho, &designs[0], &designs[1])?;
let sigma = sigma_from_theta(theta);
final_sigma_cell.set(sigma);
let family = make_family(&designs[0], &designs[1], sigma);
let derivative_blocks = get_derivative_blocks(theta, specs, designs)?;
let eval = evaluate_custom_family_joint_hyper_efs_shared(
&family,
&blocks,
options,
&rho,
derivative_blocks,
exact_warm_start.borrow().as_ref(),
)?;
exact_warm_start.replace(Some(eval.warm_start.clone()));
if !eval.inner_converged {
return Err(
"exact bernoulli marginal-slope EFS inner solve did not converge".to_string(),
);
}
Ok(eval.efs_eval)
},
)?;
let mut resolved_specs = solved.resolved_specs;
let mut designs = solved.designs;
let latent_z_rank_int_calibration = match latent_z_calibration {
LatentMeasureCalibration::None => None,
LatentMeasureCalibration::RankInverseNormal(cal) => Some(cal),
};
Ok(BernoulliMarginalSlopeFitResult {
fit: solved.fit,
marginalspec_resolved: resolved_specs.remove(0),
logslopespec_resolved: resolved_specs.remove(0),
marginal_design: designs.remove(0),
logslope_design: designs.remove(0),
baseline_marginal: baseline.0,
baseline_logslope: baseline.1,
z_normalization,
latent_measure,
score_warp_runtime,
link_dev_runtime,
gaussian_frailty_sd: final_sigma_cell.get(),
cross_block_warnings,
latent_z_rank_int_calibration,
})
}
#[cfg(test)]
mod tests {
use super::*;
use crate::custom_family::{CustomFamily, ExactOuterDerivativeOrder};
use crate::families::bernoulli_marginal_slope::exact_kernel::{
DenestedCubicCell as ExactDenestedCubicCell, ExactCellBranch as ExactCellBranchShared,
LocalSpanCubic, branch_cell as branch_exact_cell, build_denested_partition_cells,
denested_cell_coefficient_partials as exact_denested_cell_coefficient_partials,
global_cubic_from_local as exact_global_cubic_from_local,
transformed_link_cubic as exact_transformed_link_cubic,
};
use ndarray::array;
/// Probit inverse link shared by the marginal-slope test fixtures.
#[inline]
fn bernoulli_marginal_slope_probit_link() -> InverseLink {
    let probit = LinkFunction::Probit;
    InverseLink::Standard(probit)
}
/// Builds a `TermCollectionSpec` containing no linear, random-effect,
/// or smooth terms.
fn empty_termspec() -> TermCollectionSpec {
    TermCollectionSpec {
        linear_terms: Vec::new(),
        random_effect_terms: Vec::new(),
        smooth_terms: Vec::new(),
    }
}
/// Builds a penalty-free `ParameterBlockSpec` named "dummy" with an all-zero
/// dense `n_rows` x `p` design, zero offset, and a zero initial beta.
fn dummy_blockspec(p: usize, n_rows: usize) -> ParameterBlockSpec {
    let zero_design = Array2::<f64>::zeros((n_rows, p));
    ParameterBlockSpec {
        name: "dummy".to_string(),
        design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(zero_design)),
        offset: Array1::zeros(n_rows),
        penalties: Vec::new(),
        nullspace_dims: Vec::new(),
        initial_log_lambdas: Array1::zeros(0),
        initial_beta: Some(Array1::zeros(p)),
    }
}
/// Wraps `beta` in a `ParameterBlockState` whose linear predictor is the
/// zero vector of length `n_rows`.
fn dummy_block_state(beta: Array1<f64>, n_rows: usize) -> ParameterBlockState {
    let eta = Array1::zeros(n_rows);
    ParameterBlockState { beta, eta }
}
/// Builds a deterministic fixture for the flex Hessian-matvec tests.
///
/// Returns a small `BernoulliMarginalSlopeFamily` (probit base link, score-warp
/// and link-deviation blocks enabled, no Gaussian frailty), the four parameter
/// block states (marginal, logslope, score-warp, link-deviation), an exact
/// evaluation cache with the per-row primary Hessians populated, and a smooth
/// probe direction of length `cache.slices.total`.
fn flex_hessian_matvec_fixture(
    n: usize,
) -> Result<
    (
        BernoulliMarginalSlopeFamily,
        Vec<ParameterBlockState>,
        BernoulliMarginalSlopeExactEvalCache,
        Array1<f64>,
    ),
    String,
> {
    // Smooth deterministic latent covariate (sin/cos mix, amplitude <= 1.3).
    let z = Array1::from_iter((0..n).map(|i| {
        let t = (i as f64 + 0.5) / n as f64;
        (10.0 * t).sin() + 0.3 * (31.0 * t).cos()
    }));
    // Deterministic pseudo-random 0/1 outcomes (roughly 43/101 of them ones)
    // and non-uniform positive weights in [0.75, 1.25].
    let y = Array1::from_iter((0..n).map(|i| if (i * 37 + 11) % 101 < 43 { 1.0 } else { 0.0 }));
    let weights = Array1::from_iter((0..n).map(|i| 0.75 + 0.5 * ((i % 7) as f64) / 6.0));
    // Intercept + slope-in-z design, shared by the marginal and logslope blocks.
    let design = Array2::from_shape_fn((n, 2), |(row, col)| match col {
        0 => 1.0,
        1 => z[row],
        _ => unreachable!(),
    });
    let cfg = DeviationBlockConfig {
        num_internal_knots: 4,
        ..DeviationBlockConfig::default()
    };
    let score_prepared = build_score_warp_deviation_block_from_seed(&z, &cfg)?;
    // Link-deviation block seeded from a padded linear transform of z.
    let q_seed = Array1::from_iter(z.iter().map(|zi| 0.05 + 0.2 * zi));
    let link_seed = padded_deviation_seed(&q_seed, 1.0, 0.5);
    let link_prepared = build_link_deviation_block_from_knots_design_seed_and_weights(
        &link_seed, &q_seed, &weights, &cfg,
    )?;
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(y),
        weights: Arc::new(weights),
        z: Arc::new(z.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            design.clone(),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(design)),
        score_warp: Some(score_prepared.runtime.clone()),
        link_dev: Some(link_prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
        auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        auto_subsample_last_rho: Arc::new(Mutex::new(None)),
        shared_eval_cache: Arc::new(Mutex::new(None)),
    };
    // Small fixed coefficients and their implied linear predictors.
    let marginal_beta = array![0.05, 0.08];
    let logslope_beta = array![-0.15, 0.04];
    let marginal_eta =
        Array1::from_iter(z.iter().map(|zi| marginal_beta[0] + marginal_beta[1] * zi));
    let logslope_eta =
        Array1::from_iter(z.iter().map(|zi| logslope_beta[0] + logslope_beta[1] * zi));
    // Block order: marginal, logslope, score-warp, link-deviation; the two
    // deviation blocks start at zero coefficients and zero eta.
    let states = vec![
        ParameterBlockState {
            beta: marginal_beta,
            eta: marginal_eta,
        },
        ParameterBlockState {
            beta: logslope_beta,
            eta: logslope_eta,
        },
        ParameterBlockState {
            beta: Array1::zeros(score_prepared.block.design.ncols()),
            eta: Array1::zeros(n),
        },
        ParameterBlockState {
            beta: Array1::zeros(link_prepared.block.design.ncols()),
            eta: Array1::zeros(n),
        },
    ];
    // Build the exact eval cache, then fill in the per-row primary Hessians
    // that the Hessian-matvec code paths read from the cache.
    let mut cache = family.build_exact_eval_cache(&states)?;
    cache.row_primary_hessians = family.build_row_primary_hessian_cache(&states, &cache)?;
    // Smooth deterministic probe direction over the full flat dimension.
    let direction = Array1::from_iter((0..cache.slices.total).map(|j| {
        let x = j as f64 + 1.0;
        0.03 * x.sin() + 0.01 * (0.37 * x).cos()
    }));
    Ok((family, states, cache, direction))
}
/// Panics unless `actual` and `expected` have the same length and agree
/// entrywise to relative tolerance `tol`, with the denominator floored at 1.0
/// so near-zero entries are compared absolutely.
fn assert_allclose_relative(actual: &Array1<f64>, expected: &Array1<f64>, tol: f64) {
    assert_eq!(actual.len(), expected.len());
    for idx in 0..actual.len() {
        let a = actual[idx];
        let e = expected[idx];
        let denom = 1.0_f64.max(a.abs()).max(e.abs());
        let rel = (a - e).abs() / denom;
        assert!(
            rel <= tol,
            "entry {idx}: actual={a:.17e}, expected={e:.17e}, rel={rel:.3e}, tol={tol:.3e}"
        );
    }
}
/// A link-deviation candidate wider than a FlexEvaluation (score-warp) anchor
/// must not lose dimensions: residualising against flex anchors alone should
/// leave `basis_dim` unchanged, with all penalties and the rebuilt design
/// staying consistent with it.
///
/// Fix: removed a dead `let _ = link_seed;` that was jammed onto the same
/// line as the next statement — `link_seed` is already consumed by reference
/// when building the block, so the suppression binding was pure noise.
#[test]
fn cross_block_identifiability_when_candidate_wider_than_flex_anchor_keeps_kept_dim_positive() {
    let n = 64usize;
    // Evenly spaced latent covariate on [-1.5, 1.5].
    let z = Array1::from_iter((0..n).map(|i| {
        let t = (i as f64) / (n as f64 - 1.0);
        -1.5 + 3.0 * t
    }));
    let weights = Array1::from_elem(n, 1.0);
    // Narrow score-warp anchor (3 knots) vs wider link-deviation candidate (5 knots).
    let score_cfg = DeviationBlockConfig {
        num_internal_knots: 3,
        ..DeviationBlockConfig::default()
    };
    let link_cfg = DeviationBlockConfig {
        num_internal_knots: 5,
        ..DeviationBlockConfig::default()
    };
    let score_prepared =
        build_score_warp_deviation_block_from_seed(&z, &score_cfg).expect("score-warp fixture");
    let q0_seed = Array1::from_iter(z.iter().map(|zi| 0.1 + 0.45 * zi));
    let link_seed = padded_deviation_seed(&q0_seed, 1.0, 0.5);
    let mut link_prepared = build_link_deviation_block_from_knots_design_seed_and_weights(
        &link_seed, &q0_seed, &weights, &link_cfg,
    )
    .expect("link-deviation fixture");
    let p_link_before = link_prepared.runtime.basis_dim();
    let anchor_design_for_test = score_prepared
        .runtime
        .design(&z)
        .expect("score-warp anchor design");
    enforce_cross_block_identifiability_for_flex_block(
        &mut link_prepared,
        &q0_seed,
        &link_cfg,
        &[CrossBlockAnchor::FlexEvaluation(&anchor_design_for_test)],
        &[None],
        &weights,
    )
    .expect("cross-block reparameterisation should succeed for non-degenerate overlap");
    let p_link_after = link_prepared.runtime.basis_dim();
    assert!(
        p_link_after > 0 && p_link_after == p_link_before,
        "FlexEvaluation-only anchors should leave basis_dim unchanged, got {} -> {}",
        p_link_before,
        p_link_after,
    );
    // Every penalty must remain square with the (unchanged) basis dimension.
    for (idx, penalty) in link_prepared.block.penalties.iter().enumerate() {
        let dense = penalty.to_dense();
        assert_eq!(
            dense.ncols(),
            p_link_after,
            "penalty {idx} ncols {} does not match new basis_dim {}",
            dense.ncols(),
            p_link_after,
        );
        assert_eq!(
            dense.nrows(),
            p_link_after,
            "penalty {idx} nrows {} does not match new basis_dim {}",
            dense.nrows(),
            p_link_after,
        );
    }
    assert_eq!(
        link_prepared.block.design.ncols(),
        p_link_after,
        "rebuilt design column count must match new basis_dim",
    );
}
/// An anchor with more columns (20) than the candidate basis must not trip a
/// false "fully aliased" verdict: the residualised candidate keeps strictly
/// positive rank and ends up numerically orthogonal to the anchor columns.
///
/// Fix: removed a dead `let _ = link_seed;` — `link_seed` is already consumed
/// by reference when building the block.
#[test]
fn cross_block_identifiability_anchor_wider_than_candidate_no_false_alias() {
    use crate::linalg::matrix::{DenseDesignMatrix, DesignMatrix};
    let n = 64usize;
    let z = Array1::from_iter((0..n).map(|i| {
        let t = (i as f64) / (n as f64 - 1.0);
        -1.5 + 3.0 * t
    }));
    let weights = Array1::from_elem(n, 1.0);
    // Deterministic dense 20-column anchor built from sin/cos mixtures.
    let mut anchor_dense = ndarray::Array2::<f64>::zeros((n, 20));
    for i in 0..n {
        for j in 0..20 {
            let x = (i as f64 * 0.13 + j as f64 * 0.91 + 1.0).sin();
            let y = (i as f64 * 0.07 - j as f64 * 0.31 + 2.0).cos();
            anchor_dense[[i, j]] = 0.5 * x + 0.5 * y;
        }
    }
    let anchor_design = DesignMatrix::Dense(DenseDesignMatrix::from(anchor_dense.clone()));
    let link_cfg = DeviationBlockConfig {
        num_internal_knots: 2,
        ..DeviationBlockConfig::default()
    };
    let q0_seed = Array1::from_iter(z.iter().map(|zi| 0.05 + 0.4 * zi));
    let link_seed = padded_deviation_seed(&q0_seed, 1.0, 0.5);
    let mut link_prepared = build_link_deviation_block_from_knots_design_seed_and_weights(
        &link_seed, &q0_seed, &weights, &link_cfg,
    )
    .expect("link-dev fixture");
    let p_before = link_prepared.runtime.basis_dim();
    assert!(p_before > 0, "fixture must have positive basis_dim");
    use crate::families::bernoulli_marginal_slope::deviation_runtime::ParametricAnchorBlock;
    enforce_cross_block_identifiability_for_flex_block(
        &mut link_prepared,
        &q0_seed,
        &link_cfg,
        &[CrossBlockAnchor::Parametric(&anchor_design)],
        &[Some(ParametricAnchorBlock::Marginal)],
        &weights,
    )
    .expect("wider anchor should not false-positive full alias");
    let p_after = link_prepared.runtime.basis_dim();
    assert!(
        p_after > 0,
        "new algorithm must keep strictly positive basis_dim when (I-P_A)C has rank > 0"
    );
    let new_design = link_prepared
        .runtime
        .design_at_training_with_residual(&q0_seed)
        .expect("design after orthogonalisation");
    assert_eq!(new_design.nrows(), n);
    assert_eq!(new_design.ncols(), p_after);
    // A^T * C-tilde should vanish relative to the Frobenius norms of the factors.
    let cross = anchor_dense.t().dot(&new_design);
    let max_abs = cross.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
    let anchor_norm = anchor_dense.iter().map(|v| v * v).sum::<f64>().sqrt();
    let cand_norm = new_design.iter().map(|v| v * v).sum::<f64>().sqrt();
    let scale = (anchor_norm * cand_norm).max(1.0);
    assert!(
        max_abs <= 1.0e-9 * scale,
        "Aᵀ C̃ should be at noise floor; max|.|={max_abs:.3e}, scale={scale:.3e}",
    );
}
/// Minimal counterexample: when the anchor reproduces exactly one candidate
/// column, residualisation must drop exactly that one direction (keeping
/// `p_before - 1`), not collapse the block to zero.
///
/// Fix: removed a dead `let _ = link_seed;` — `link_seed` is already consumed
/// by reference when building the block.
#[test]
fn cross_block_identifiability_minimal_counterexample_keeps_orthogonal_complement() {
    use crate::linalg::matrix::{DenseDesignMatrix, DesignMatrix};
    let n = 64usize;
    let z = Array1::from_iter((0..n).map(|i| {
        let t = (i as f64) / (n as f64 - 1.0);
        -1.5 + 3.0 * t
    }));
    let weights = Array1::from_elem(n, 1.0);
    let link_cfg = DeviationBlockConfig {
        num_internal_knots: 3,
        ..DeviationBlockConfig::default()
    };
    let q0_seed = Array1::from_iter(z.iter().map(|zi| 0.05 + 0.4 * zi));
    let link_seed = padded_deviation_seed(&q0_seed, 1.0, 0.5);
    let mut link_prepared = build_link_deviation_block_from_knots_design_seed_and_weights(
        &link_seed, &q0_seed, &weights, &link_cfg,
    )
    .expect("link-dev fixture");
    let p_before = link_prepared.runtime.basis_dim();
    assert!(
        p_before >= 2,
        "fixture must have at least two basis columns"
    );
    // The anchor is a single column copied verbatim from the candidate design,
    // so exactly one candidate direction is aliased.
    let candidate_design = link_prepared
        .runtime
        .design(&q0_seed)
        .expect("candidate training-row design");
    let mut anchor_dense = ndarray::Array2::<f64>::zeros((n, 1));
    anchor_dense
        .column_mut(0)
        .assign(&candidate_design.column(0));
    let anchor_design = DesignMatrix::Dense(DenseDesignMatrix::from(anchor_dense.clone()));
    use crate::families::bernoulli_marginal_slope::deviation_runtime::ParametricAnchorBlock;
    enforce_cross_block_identifiability_for_flex_block(
        &mut link_prepared,
        &q0_seed,
        &link_cfg,
        &[CrossBlockAnchor::Parametric(&anchor_design)],
        &[Some(ParametricAnchorBlock::Marginal)],
        &weights,
    )
    .expect("minimal counterexample must keep p_before - 1 directions, not collapse to 0");
    let p_after = link_prepared.runtime.basis_dim();
    assert_eq!(
        p_after,
        p_before - 1,
        "(I−P_A)C should keep exactly {} directions when one column of C is reproduced by A; got {}",
        p_before - 1,
        p_after,
    );
    let new_design = link_prepared
        .runtime
        .design_at_training_with_residual(&q0_seed)
        .expect("design after orthogonalisation");
    // Surviving columns must be numerically orthogonal to the anchor.
    let cross = anchor_dense.t().dot(&new_design);
    let max_abs = cross.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
    let anchor_norm = anchor_dense.iter().map(|v| v * v).sum::<f64>().sqrt();
    let cand_norm = new_design.iter().map(|v| v * v).sum::<f64>().sqrt();
    let scale = (anchor_norm * cand_norm).max(1.0);
    assert!(
        max_abs <= 1.0e-9 * scale,
        "AᵀC̃ should be at noise floor after residualisation; max|.|={max_abs:.3e}, scale={scale:.3e}",
    );
}
/// When the anchor equals the candidate design exactly, every direction is
/// aliased and the routine must report a structured `FullyAliased` outcome
/// whose reason mentions "zero directions remaining" — not an error and not
/// a bogus reparameterisation.
///
/// Fix: removed a dead `let _ = link_seed;` — `link_seed` is already consumed
/// by reference when building the block.
#[test]
fn cross_block_identifiability_true_alias_returns_fully_aliased_outcome() {
    use crate::linalg::matrix::{DenseDesignMatrix, DesignMatrix};
    let n = 64usize;
    let z = Array1::from_iter((0..n).map(|i| {
        let t = (i as f64) / (n as f64 - 1.0);
        -1.5 + 3.0 * t
    }));
    let weights = Array1::from_elem(n, 1.0);
    let link_cfg = DeviationBlockConfig {
        num_internal_knots: 2,
        ..DeviationBlockConfig::default()
    };
    let q0_seed = Array1::from_iter(z.iter().map(|zi| 0.05 + 0.4 * zi));
    let link_seed = padded_deviation_seed(&q0_seed, 1.0, 0.5);
    let mut link_prepared = build_link_deviation_block_from_knots_design_seed_and_weights(
        &link_seed, &q0_seed, &weights, &link_cfg,
    )
    .expect("link-dev fixture");
    // The anchor is the candidate's own training-row design: total overlap.
    let candidate_design = link_prepared
        .runtime
        .design(&q0_seed)
        .expect("candidate training-row design");
    let anchor_design = DesignMatrix::Dense(DenseDesignMatrix::from(candidate_design));
    use crate::families::bernoulli_marginal_slope::deviation_runtime::ParametricAnchorBlock;
    let outcome = enforce_cross_block_identifiability_for_flex_block(
        &mut link_prepared,
        &q0_seed,
        &link_cfg,
        &[CrossBlockAnchor::Parametric(&anchor_design)],
        &[Some(ParametricAnchorBlock::Marginal)],
        &weights,
    )
    .expect("true alias must produce a structured FullyAliased outcome");
    match outcome {
        CrossBlockIdentifiabilityOutcome::FullyAliased { reason } => {
            assert!(
                reason.contains("zero directions remaining"),
                "expected FullyAliased reason mentioning 'zero directions remaining', got: {reason}",
            );
        }
        CrossBlockIdentifiabilityOutcome::Reparameterised { kept, dropped } => {
            panic!(
                "expected FullyAliased outcome but got Reparameterised(kept={kept}, dropped={dropped})"
            );
        }
    }
}
/// Smoke test: a `FullyAliased` outcome round-trips its reason string
/// through pattern matching.
#[test]
fn cross_block_identifiability_outcome_fully_aliased_extracts_reason() {
    let outcome = CrossBlockIdentifiabilityOutcome::FullyAliased {
        reason: "candidate has zero directions remaining after residualisation".to_string(),
    };
    // Extract via a match expression rather than asserting inside the arm.
    let extracted = match outcome {
        CrossBlockIdentifiabilityOutcome::FullyAliased { reason } => reason,
        CrossBlockIdentifiabilityOutcome::Reparameterised { .. } => {
            panic!("constructed FullyAliased; cannot pattern-match as Reparameterised")
        }
    };
    assert!(extracted.contains("zero directions remaining"));
}
/// Partial alias: the anchor spans `k_alias` of the candidate's orthonormal
/// directions plus one column explicitly projected orthogonal to the
/// candidate span. Exactly the `k_alias` overlapping directions must be
/// dropped, and the surviving design must be orthogonal to the anchor.
///
/// Fix: removed a dead `let _ = link_seed;` — `link_seed` is already consumed
/// by reference when building the block.
#[test]
fn cross_block_identifiability_partial_alias_keeps_residual_rank() {
    use crate::faer_ndarray::FaerQr;
    use crate::linalg::matrix::{DenseDesignMatrix, DesignMatrix};
    let n = 96usize;
    let z = Array1::from_iter((0..n).map(|i| {
        let t = (i as f64) / (n as f64 - 1.0);
        -1.5 + 3.0 * t
    }));
    let weights = Array1::from_elem(n, 1.0);
    let link_cfg = DeviationBlockConfig {
        num_internal_knots: 4,
        ..DeviationBlockConfig::default()
    };
    let q0_seed = Array1::from_iter(z.iter().map(|zi| 0.03 + 0.42 * zi));
    let link_seed = padded_deviation_seed(&q0_seed, 1.0, 0.5);
    let mut link_prepared = build_link_deviation_block_from_knots_design_seed_and_weights(
        &link_seed, &q0_seed, &weights, &link_cfg,
    )
    .expect("link-dev fixture");
    let candidate_design = link_prepared
        .runtime
        .design(&q0_seed)
        .expect("candidate training-row design");
    // Orthonormal basis for the candidate column space via thin QR.
    let (q, _r) = candidate_design.qr().expect("thin QR of candidate");
    let p_c = candidate_design.ncols();
    let k_alias = (p_c / 2).max(1);
    assert!(
        k_alias + 1 <= p_c,
        "partial-alias test needs p_c > k_alias + 1, got p_c={p_c}, k_alias={k_alias}",
    );
    // Anchor = first k_alias candidate directions plus one extra column
    // projected onto the orthogonal complement of the candidate span.
    let p_a = k_alias + 1;
    let mut anchor_dense = ndarray::Array2::<f64>::zeros((n, p_a));
    for j in 0..k_alias {
        anchor_dense.column_mut(j).assign(&q.column(j));
    }
    let mut extra = Array1::<f64>::zeros(n);
    for i in 0..n {
        extra[i] = ((i as f64 * 0.17).sin() + (i as f64 * 0.41).cos()) * 0.5;
    }
    let q_t_extra = q.t().dot(&extra);
    let extra_orth = &extra - &q.dot(&q_t_extra);
    anchor_dense.column_mut(k_alias).assign(&extra_orth);
    let anchor_design = DesignMatrix::Dense(DenseDesignMatrix::from(anchor_dense.clone()));
    use crate::families::bernoulli_marginal_slope::deviation_runtime::ParametricAnchorBlock;
    let p_before = link_prepared.runtime.basis_dim();
    enforce_cross_block_identifiability_for_flex_block(
        &mut link_prepared,
        &q0_seed,
        &link_cfg,
        &[CrossBlockAnchor::Parametric(&anchor_design)],
        &[Some(ParametricAnchorBlock::Marginal)],
        &weights,
    )
    .expect("partial alias must keep the surviving rank");
    let p_after = link_prepared.runtime.basis_dim();
    assert_eq!(
        p_after,
        p_before - k_alias,
        "partial alias should drop exactly the {} aliased directions; got {} -> {}",
        k_alias,
        p_before,
        p_after,
    );
    let new_design = link_prepared
        .runtime
        .design_at_training_with_residual(&q0_seed)
        .expect("design after orthogonalisation");
    // Surviving columns must be numerically orthogonal to the anchor.
    let cross = anchor_dense.t().dot(&new_design);
    let max_abs = cross.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
    let anchor_norm = anchor_dense.iter().map(|v| v * v).sum::<f64>().sqrt();
    let cand_norm = new_design.iter().map(|v| v * v).sum::<f64>().sqrt();
    let scale = (anchor_norm * cand_norm).max(1.0);
    assert!(
        max_abs <= 1.0e-9 * scale,
        "AᵀC̃ should be at noise floor; max|.|={max_abs:.3e}, scale={scale:.3e}",
    );
}
/// The parallel chunked Hessian-vector product must match the serial
/// reference implementation to near machine precision (1e-13 relative).
///
/// Fix: restored `&parallel` in the final assertion — the source contained
/// the mojibake `¶llel` (HTML-entity decoding turned `&para` into `¶`),
/// which is not valid Rust.
#[test]
fn flex_hessian_matvec_parallel_chunks_match_serial_reference() {
    let (family, states, cache, direction) =
        flex_hessian_matvec_fixture(96).expect("flex Hv fixture");
    let serial = family
        .exact_newton_joint_hessian_matvec_from_cache_serial_reference(
            &direction, &states, &cache,
        )
        .expect("serial reference Hv");
    let parallel = family
        .exact_newton_joint_hessian_matvec_from_cache(&direction, &states, &cache)
        .expect("parallel chunked Hv");
    assert_allclose_relative(&parallel, &serial, 1.0e-13);
}
/// Checks that the batched many-direction third-order trace path agrees with
/// recomputing and contracting each direction's third-order term one at a
/// time, over the first six rows of the fixture.
#[test]
fn row_primary_third_trace_many_matches_single_direction_contracts() {
    let (family, states, cache, direction_a) =
        flex_hessian_matvec_fixture(24).expect("flex trace fixture");
    // Second deterministic probe direction of the same flat dimension.
    let direction_b = Array1::from_iter((0..cache.slices.total).map(|j| {
        let x = j as f64 + 0.5;
        0.02 * (0.7 * x).sin() - 0.015 * (0.31 * x).cos()
    }));
    // Deterministic r x r matrix (flattened row-major) used as the `gram`
    // argument of the trace contraction.
    let r = cache.primary.total;
    let gram = (0..r * r)
        .map(|idx| {
            let x = idx as f64 + 1.0;
            0.03 * x.sin() + 0.01 * (0.17 * x).cos()
        })
        .collect::<Vec<_>>();
    for row in 0..6 {
        let row_ctx = BernoulliMarginalSlopeFamily::row_ctx(&cache, row);
        // Restrict both flat directions to this row's primary coefficients.
        let row_dirs = vec![
            family
                .row_primary_direction_from_flat(
                    row,
                    &cache.slices,
                    &cache.primary,
                    &direction_a,
                )
                .expect("row direction a"),
            family
                .row_primary_direction_from_flat(
                    row,
                    &cache.slices,
                    &cache.primary,
                    &direction_b,
                )
                .expect("row direction b"),
        ];
        // Batched path: all directions at once.
        let many = family
            .row_primary_third_trace_many_with_moments(
                row, &states, &cache, row_ctx, &row_dirs, &gram,
            )
            .expect("many-direction row trace");
        // Reference path: recompute the contracted third-order term per
        // direction, then contract against the same gram matrix.
        for (dir_idx, row_dir) in row_dirs.iter().enumerate() {
            let third = family
                .row_primary_third_contracted_recompute(row, &states, &cache, row_ctx, row_dir)
                .expect("single-direction third contraction");
            let single =
                BernoulliMarginalSlopeFamily::row_primary_trace_contract(&third, &gram);
            let denom = single.abs().max(many[dir_idx].abs()).max(1.0);
            let rel = (single - many[dir_idx]).abs() / denom;
            assert!(
                rel <= 1.0e-12,
                "row {row} dir {dir_idx}: many={:.17e} single={:.17e} rel={:.3e}",
                many[dir_idx],
                single,
                rel
            );
        }
    }
}
// The per-row intercept warm-start cache must (1) be populated with finite,
// converged intercepts after the first eval-cache build, and (2) when reused
// on a second build, reproduce the same intercepts as a cold (cache-less)
// family to 1e-9.
#[test]
fn bernoulli_margslope_warm_start_cache_persists_across_eval_cache_builds() {
let n = 256usize;
let z = Array1::from_iter((0..n).map(|i| {
let t = (i as f64 + 0.5) / n as f64;
(12.0 * t).sin() + 0.25 * (37.0 * t).cos()
}));
let y = Array1::from_iter((0..n).map(|i| if i % 3 == 0 { 1.0 } else { 0.0 }));
let weights = Array1::ones(n);
let cfg = DeviationBlockConfig {
num_internal_knots: 4,
..DeviationBlockConfig::default()
};
let score_prepared = build_score_warp_deviation_block_from_seed(&z, &cfg)
.expect("score-warp deviation block");
let q_seed = Array1::from_iter(z.iter().map(|zi| 0.1 + 0.25 * zi));
let link_seed = padded_deviation_seed(&q_seed, 1.0, 0.5);
let link_prepared = build_link_deviation_block_from_knots_design_seed_and_weights(
&link_seed, &q_seed, &weights, &cfg,
)
.expect("link-wiggle deviation block");
let cache = new_intercept_warm_start_cache(n);
// Factory so warm (Some(cache)) and cold (None) families are identical
// except for the warm-start slot array.
let make_family =
|cache: Option<Arc<BernoulliInterceptWarmStartCache>>| BernoulliMarginalSlopeFamily {
y: Arc::new(y.clone()),
weights: Arc::new(weights.clone()),
z: Arc::new(z.clone()),
latent_measure: LatentMeasureKind::StandardNormal,
gaussian_frailty_sd: None,
base_link: bernoulli_marginal_slope_probit_link(),
marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
Array2::ones((n, 1)),
)),
logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
Array2::ones((n, 1)),
)),
score_warp: Some(score_prepared.runtime.clone()),
link_dev: Some(link_prepared.runtime.clone()),
policy: crate::resource::ResourcePolicy::default_library(),
cell_moment_lru: new_cell_moment_lru_cache(
&crate::resource::ResourcePolicy::default_library(),
),
cell_moment_cache_stats: new_cell_moment_cache_stats(),
intercept_warm_starts: cache,
auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
auto_subsample_last_rho: Arc::new(Mutex::new(None)),
shared_eval_cache: Arc::new(Mutex::new(None)),
};
let marginal_eta = Array1::from_iter((0..n).map(|i| 0.15 * ((i as f64) * 0.001).sin()));
let slope_eta = Array1::from_iter((0..n).map(|i| 0.35 + 0.02 * ((i as f64) * 0.003).cos()));
// Block order: marginal, log-slope, score-warp, link-wiggle.
let states = vec![
ParameterBlockState {
beta: array![0.0],
eta: marginal_eta,
},
ParameterBlockState {
beta: array![0.0],
eta: slope_eta,
},
ParameterBlockState {
beta: Array1::zeros(score_prepared.block.design.ncols()),
eta: Array1::zeros(n),
},
ParameterBlockState {
beta: Array1::zeros(link_prepared.block.design.ncols()),
eta: Array1::zeros(n),
},
];
let warm_family = make_family(Some(Arc::clone(&cache)));
let first = warm_family
.build_exact_eval_cache(&states)
.expect("first warm eval cache");
// Empty slots are stored as NaN bit patterns; after the first build every
// slot must hold a finite (non-NaN-bits) intercept.
let nan_bits = f64::NAN.to_bits();
for slot in cache.iter() {
let bits = slot.load(Ordering::Relaxed);
let v = f64::from_bits(bits);
assert!(
v.is_finite(),
"cache slot should be populated with converged intercept after first build"
);
assert_ne!(bits, nan_bits);
}
// Second warm build reuses the populated cache.
let second = warm_family
.build_exact_eval_cache(&states)
.expect("second warm eval cache");
let cold_family = make_family(None);
let cold = cold_family
.build_exact_eval_cache(&states)
.expect("cold reference eval cache");
// Warm-started intercepts (both builds) must match the cold reference.
for ((warm_a, warm_b), cold_ctx) in first
.row_contexts
.iter()
.zip(second.row_contexts.iter())
.zip(cold.row_contexts.iter())
{
assert!((warm_a.intercept - cold_ctx.intercept).abs() < 1e-9);
assert!((warm_b.intercept - cold_ctx.intercept).abs() < 1e-9);
}
}
// The early-exit threshold on the full FLEX log-likelihood sweep must be
// either inert (permissive threshold -> exact result) or a provable rejection
// (tight threshold -> explicit "line-search rejected early" error), never a
// silently truncated value.
#[test]
fn bernoulli_margslope_flex_ll_early_exit_is_exact_or_provably_rejected() {
// Large n so the row sweep is long enough for early exit to trigger.
let n = 24_000usize;
let z = Array1::from_iter((0..n).map(|i| {
let t = (i as f64 + 0.5) / n as f64;
(10.0 * t).sin() + 0.2 * (29.0 * t).cos()
}));
let y = Array1::from_iter((0..n).map(|i| if (i * 17 + 3) % 7 >= 4 { 1.0 } else { 0.0 }));
let weights = Array1::ones(n);
let cfg = DeviationBlockConfig {
num_internal_knots: 4,
..DeviationBlockConfig::default()
};
let score_prepared = build_score_warp_deviation_block_from_seed(&z, &cfg)
.expect("score-warp deviation block");
let q_seed = Array1::from_iter(z.iter().map(|zi| 0.1 + 0.2 * zi));
let link_seed = padded_deviation_seed(&q_seed, 1.0, 0.5);
let link_prepared = build_link_deviation_block_from_knots_design_seed_and_weights(
&link_seed, &q_seed, &weights, &cfg,
)
.expect("link-wiggle deviation block");
let family = BernoulliMarginalSlopeFamily {
y: Arc::new(y),
weights: Arc::new(weights),
z: Arc::new(z),
latent_measure: LatentMeasureKind::StandardNormal,
gaussian_frailty_sd: None,
base_link: bernoulli_marginal_slope_probit_link(),
marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
Array2::ones((n, 1)),
)),
logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
Array2::ones((n, 1)),
)),
score_warp: Some(score_prepared.runtime.clone()),
link_dev: Some(link_prepared.runtime.clone()),
policy: crate::resource::ResourcePolicy::default_library(),
cell_moment_lru: new_cell_moment_lru_cache(
&crate::resource::ResourcePolicy::default_library(),
),
cell_moment_cache_stats: new_cell_moment_cache_stats(),
intercept_warm_starts: None,
auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
auto_subsample_last_rho: Arc::new(Mutex::new(None)),
shared_eval_cache: Arc::new(Mutex::new(None)),
};
// Block order: marginal, log-slope, score-warp, link-wiggle.
let states = vec![
ParameterBlockState {
beta: array![0.0],
eta: Array1::from_iter((0..n).map(|i| 0.08 * ((i as f64) * 0.002).sin())),
},
ParameterBlockState {
beta: array![0.0],
eta: Array1::from_iter((0..n).map(|i| 0.30 + 0.03 * ((i as f64) * 0.003).cos())),
},
ParameterBlockState {
beta: Array1::zeros(score_prepared.block.design.ncols()),
eta: Array1::zeros(n),
},
ParameterBlockState {
beta: Array1::zeros(link_prepared.block.design.ncols()),
eta: Array1::zeros(n),
},
];
// Reference: full sweep with no early exit.
let exact = family
.log_likelihood_only_with_options(&states, &BlockwiseFitOptions::default())
.expect("exact full FLEX ll");
// Threshold one unit beyond the exact negative LL can never trigger, so
// the result must be the exact value.
let mut permissive = BlockwiseFitOptions::default();
permissive.early_exit_threshold = Some((-exact) + 1.0);
let accepted = family
.log_likelihood_only_with_options(&states, &permissive)
.expect("permissive threshold should compute full FLEX ll");
let rel = ((accepted - exact) / exact.abs().max(1.0)).abs();
assert!(
rel < 1e-10,
"accepted early-exit-enabled FLEX LL {accepted} differs from exact {exact} by rel {rel}"
);
// An unattainably tight threshold must reject mid-sweep with an explicit
// error rather than return a partial sum.
let mut rejecting = BlockwiseFitOptions::default();
rejecting.early_exit_threshold = Some(1e-6);
let err = family
.log_likelihood_only_with_options(&states, &rejecting)
.expect_err("tight threshold should reject before the full FLEX row sweep");
assert!(
err.contains("line-search rejected early"),
"unexpected early-exit error: {err}"
);
}
// The ordered fourth-derivative contraction must validate direction vector
// lengths up front and return a descriptive error before any indexing.
#[test]
fn row_primary_fourth_contracted_rejects_bad_direction_lengths() {
// Minimal single-row family with zero-column designs, so the primary
// coefficient space is just the two intercept-like coordinates.
let family = BernoulliMarginalSlopeFamily {
y: Arc::new(array![1.0]),
weights: Arc::new(array![1.0]),
z: Arc::new(array![0.25]),
latent_measure: LatentMeasureKind::StandardNormal,
gaussian_frailty_sd: None,
base_link: bernoulli_marginal_slope_probit_link(),
marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
Array2::zeros((1, 0)),
)),
logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
Array2::zeros((1, 0)),
)),
score_warp: None,
link_dev: None,
policy: crate::resource::ResourcePolicy::default_library(),
cell_moment_lru: new_cell_moment_lru_cache(
&crate::resource::ResourcePolicy::default_library(),
),
cell_moment_cache_stats: new_cell_moment_cache_stats(),
intercept_warm_starts: None,
auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
auto_subsample_last_rho: Arc::new(Mutex::new(None)),
shared_eval_cache: Arc::new(Mutex::new(None)),
};
let block_states = vec![
ParameterBlockState {
beta: Array1::zeros(0),
eta: array![0.15],
},
ParameterBlockState {
beta: Array1::zeros(0),
eta: array![0.2],
},
];
let cache = family
.build_exact_eval_cache(&block_states)
.expect("exact eval cache");
let row_ctx = family
.build_row_exact_context(0, &block_states)
.expect("row context");
// Primary dimension is 2; bad_dir has length 1 and must be rejected.
let bad_dir = array![1.0];
let good_dir = array![0.0, 1.0];
let err = family
.row_primary_fourth_contracted_recompute_ordered(
0,
&block_states,
&cache,
&row_ctx,
&bad_dir,
&good_dir,
)
.expect_err("bad direction length should be rejected before indexing");
assert!(
err.contains("direction lengths (1,2) != 2"),
"unexpected error: {err}"
);
}
/// Builds a minimal marginal-slope term spec for tests: probit base link,
/// empty term specs, zero offsets, no frailty/warp/deviation blocks.
/// Offsets are sized to the number of observations in `y`.
fn base_spec(
y: Array1<f64>,
weights: Array1<f64>,
z: Array1<f64>,
) -> BernoulliMarginalSlopeTermSpec {
let n = y.len();
BernoulliMarginalSlopeTermSpec {
y,
weights,
z,
base_link: bernoulli_marginal_slope_probit_link(),
marginalspec: empty_termspec(),
logslopespec: empty_termspec(),
marginal_offset: Array1::zeros(n),
logslope_offset: Array1::zeros(n),
frailty: FrailtySpec::None,
score_warp: None,
link_dev: None,
latent_z_policy: LatentZPolicy::default(),
}
}
// In the clamped probability tails (|eta| far into the link's saturation
// region), the link map must pin mu to the epsilon-clamped probabilities and
// report all first-through-fourth derivatives of mu and q as exactly zero.
#[test]
fn bernoulli_marginal_link_map_zeroes_derivatives_on_clamped_tails() {
let link = bernoulli_marginal_slope_probit_link();
let lower = bernoulli_marginal_link_map(&link, -8.0).expect("lower tail map");
let upper = bernoulli_marginal_link_map(&link, 8.0).expect("upper tail map");
// Expected clamped quantiles at the probability epsilon boundaries.
let lower_q = standard_normal_quantile(BERNOULLI_LINK_PROBABILITY_EPS).unwrap();
let upper_q = standard_normal_quantile(1.0 - BERNOULLI_LINK_PROBABILITY_EPS).unwrap();
assert_eq!(lower.mu, BERNOULLI_LINK_PROBABILITY_EPS);
assert_eq!(upper.mu, 1.0 - BERNOULLI_LINK_PROBABILITY_EPS);
assert!((lower.q - lower_q).abs() < 1e-12);
assert!((upper.q - upper_q).abs() < 1e-12);
// Clamping freezes the map, so every derivative order must vanish.
assert_eq!([lower.mu1, lower.mu2, lower.mu3, lower.mu4], [0.0; 4]);
assert_eq!([upper.mu1, upper.mu2, upper.mu3, upper.mu4], [0.0; 4]);
assert_eq!([lower.q1, lower.q2, lower.q3, lower.q4], [0.0; 4]);
assert_eq!([upper.q1, upper.q2, upper.q3, upper.q4], [0.0; 4]);
}
// The analytic (eta, g) gradient of the rigid probit negative log-likelihood
// must match central finite differences of the scalar objective, and the eta
// component must have the correct sign for y = 1.
#[test]
fn rigid_transformed_gradient_matches_negative_log_likelihood_derivative() {
let link = bernoulli_marginal_slope_probit_link();
let eta = 0.25;
let g = -0.15;
let z = 0.7;
let y = 1.0;
let weight = 1.3;
let probit_scale = 1.0;
// Scalar objective: weighted negative log CDF of the rigid probit kernel
// built at (eta, g), used as the finite-difference reference.
let objective = |eta_value: f64, g_value: f64| {
let marginal = bernoulli_marginal_link_map(&link, eta_value).unwrap();
let kernel =
RigidProbitKernel::new(marginal.q, g_value, z, y, weight, probit_scale).unwrap();
-weight * kernel.logcdf
};
let marginal = bernoulli_marginal_link_map(&link, eta).expect("marginal map");
let kernel =
RigidProbitKernel::new(marginal.q, g, z, y, weight, probit_scale).expect("kernel");
let grad = rigid_transformed_gradient(marginal, &kernel);
// Central differences in each coordinate.
let step = 1e-6;
let finite_eta = (objective(eta + step, g) - objective(eta - step, g)) / (2.0 * step);
let finite_g = (objective(eta, g + step) - objective(eta, g - step)) / (2.0 * step);
assert!(
(grad[0] - finite_eta).abs() < 1e-7,
"eta gradient {} != finite difference {}",
grad[0],
finite_eta
);
assert!(
(grad[1] - finite_g).abs() < 1e-7,
"g gradient {} != finite difference {}",
grad[1],
finite_g
);
// Sanity: for y = 1, raising eta raises the success probability, so the
// negative log-likelihood must be decreasing in eta.
assert!(
grad[0] < 0.0,
"y=1 probit nll should decrease as eta increases"
);
}
/// L1 (Manhattan) distance between two `(f64, f64)` pairs.
fn pair_distance(lhs: (f64, f64), rhs: (f64, f64)) -> f64 {
let first = (lhs.0 - rhs.0).abs();
let second = (lhs.1 - rhs.1).abs();
first + second
}
/// Builds a deterministic rigid (no score-warp, no link-deviation) test
/// family of size `n`: pseudo-random binary outcomes and positive weights,
/// `z` spread over roughly [-1.5, 1.5], and intercept-only (single ones
/// column) marginal and log-slope designs.
fn make_rigid_test_family(n: usize) -> BernoulliMarginalSlopeFamily {
// Deterministic pseudo-random 0/1 outcomes.
let y: Array1<f64> =
Array1::from_iter((0..n).map(|i| if (i * 31 + 7) % 5 >= 3 { 1.0 } else { 0.0 }));
// Positive weights in {0.5, 0.6, ..., 1.1}.
let weights: Array1<f64> =
Array1::from_iter((0..n).map(|i| 0.5 + ((i * 13 + 4) % 7) as f64 * 0.1));
// Scrambled covariate grid over approximately [-1.5, 1.5].
let z: Array1<f64> = Array1::from_iter(
(0..n).map(|i| -1.5 + 3.0 * (((i * 17 + 5) % n) as f64 + 0.5) / (n as f64)),
);
let ones_col = Array2::from_shape_fn((n, 1), |_| 1.0);
BernoulliMarginalSlopeFamily {
y: Arc::new(y),
weights: Arc::new(weights),
z: Arc::new(z),
latent_measure: LatentMeasureKind::StandardNormal,
gaussian_frailty_sd: None,
base_link: bernoulli_marginal_slope_probit_link(),
marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
ones_col.clone(),
)),
logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(ones_col)),
score_warp: None,
link_dev: None,
policy: crate::resource::ResourcePolicy::default_library(),
cell_moment_lru: new_cell_moment_lru_cache(
&crate::resource::ResourcePolicy::default_library(),
),
cell_moment_cache_stats: new_cell_moment_cache_stats(),
intercept_warm_starts: None,
auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
auto_subsample_last_rho: Arc::new(Mutex::new(None)),
shared_eval_cache: Arc::new(Mutex::new(None)),
}
}
/// Builds constant-eta block states for the rigid test family: block 0
/// (marginal) at level `q`, block 1 (log-slope) at level `b`, each with a
/// matching scalar beta and an eta vector filled with the same level.
fn rigid_block_states(
family: &BernoulliMarginalSlopeFamily,
q: f64,
b: f64,
) -> Vec<ParameterBlockState> {
let rows = family.y.len();
[q, b]
.into_iter()
.map(|level| ParameterBlockState {
beta: array![level],
eta: Array1::from_elem(rows, level),
})
.collect()
}
// A subsample mask covering every row must reproduce the unsampled
// log-likelihood exactly (up to floating-point roundoff).
#[test]
fn bernoulli_log_likelihood_subsample_full_equals_unsampled() {
use crate::families::marginal_slope_shared::OuterScoreSubsample;
let n = 200usize;
let family = make_rigid_test_family(n);
let states = rigid_block_states(&family, 0.3, 0.4);
// Reference: no subsample machinery involved at all.
let baseline = family
.log_likelihood_only(&states)
.expect("baseline ll (no subsample)");
// Full mask 0..n with population size n — scaling factor is 1.
let mut opts_full = BlockwiseFitOptions::default();
opts_full.outer_score_subsample = Some(Arc::new(OuterScoreSubsample::new(
(0..n).collect(),
n,
0xDEADBEEF,
)));
let with_full_mask = family
.log_likelihood_only_with_options(&states, &opts_full)
.expect("ll with mask=full");
let rel = ((with_full_mask - baseline) / baseline.abs().max(1.0)).abs();
assert!(
rel < 1e-12,
"subsample(mask=full) {} differs from baseline {} by rel {}",
with_full_mask,
baseline,
rel
);
}
// A half (even-row) subsample must apply the Horvitz-Thompson scaling
// n/m exactly, and the scaled estimate should land near the full-data
// baseline for this well-behaved fixture.
#[test]
fn bernoulli_log_likelihood_subsample_half_scales_correctly() {
use crate::families::marginal_slope_shared::OuterScoreSubsample;
let n = 200usize;
let family = make_rigid_test_family(n);
let states = rigid_block_states(&family, 0.3, 0.4);
let even_mask: Vec<usize> = (0..n).filter(|i| i % 2 == 0).collect();
let m = even_mask.len();
// Half mask against population n -> implied scale factor n/m.
let mut opts_half = BlockwiseFitOptions::default();
opts_half.outer_score_subsample = Some(Arc::new(OuterScoreSubsample::new(
even_mask.clone(),
n,
0xCAFE,
)));
let scaled = family
.log_likelihood_only_with_options(&states, &opts_half)
.expect("ll with mask=even");
// Same mask but with uniform weight 1.0 -> raw (unscaled) even-row sum.
let mut opts_even_unscaled = BlockwiseFitOptions::default();
opts_even_unscaled.outer_score_subsample = Some(Arc::new(
OuterScoreSubsample::with_uniform_weight(even_mask, m, 0, 1.0),
));
let raw_even_sum = family
.log_likelihood_only_with_options(&states, &opts_even_unscaled)
.expect("raw even-row ll sum");
// Exactness check: scaled result must be (n/m) * raw sum.
let expected_scaled = (n as f64 / m as f64) * raw_even_sum;
let rel = ((scaled - expected_scaled) / expected_scaled.abs().max(1.0)).abs();
assert!(
rel < 1e-12,
"scaled {} != 2*even_sum {} (rel {})",
scaled,
expected_scaled,
rel
);
// Coarse statistical sanity: HT estimate within 5% of the full baseline.
let baseline = family.log_likelihood_only(&states).expect("baseline ll");
let ht_rel = ((scaled - baseline) / baseline.abs().max(1.0)).abs();
assert!(
ht_rel < 0.05,
"Horvitz-Thompson scaled {} not near baseline {} (rel {})",
scaled,
baseline,
ht_rel
);
}
fn make_sigma_aware_test_family(n: usize) -> BernoulliMarginalSlopeFamily {
let mut family = make_rigid_test_family(n);
family.gaussian_frailty_sd = Some(0.7);
family
}
/// Maximum elementwise relative difference between two vectors, with the
/// denominator floored at 1.0 so near-zero reference entries are compared
/// absolutely. Panics if `a` is longer than `b`.
fn rel_diff_array1(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
(0..a.len())
.map(|i| (a[i] - b[i]).abs() / b[i].abs().max(1.0))
.fold(0.0f64, f64::max)
}
/// Maximum elementwise relative difference between two dense matrices, with
/// the denominator floored at 1.0 (absolute comparison near zero). Iterates
/// over `a`'s index set; panics if `b` is smaller in either dimension.
fn rel_diff_array2(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
a.indexed_iter()
.map(|((i, j), &av)| {
let bv = b[[i, j]];
(av - bv).abs() / bv.abs().max(1.0)
})
.fold(0.0f64, f64::max)
}
// Sigma-block psi terms (objective, score, Hessian operator) computed with a
// full-coverage subsample mask must equal the unsampled computation.
#[test]
fn bernoulli_sigma_psi_terms_subsample_full_equals_unsampled() {
use crate::families::marginal_slope_shared::OuterScoreSubsample;
let n = 200usize;
let family = make_sigma_aware_test_family(n);
let states = rigid_block_states(&family, 0.3, 0.4);
let specs = vec![dummy_blockspec(1, n), dummy_blockspec(1, n)];
// Reference: no subsample options.
let baseline = family
.sigma_exact_joint_psi_terms(&states, &specs)
.expect("baseline psi terms")
.expect("baseline some");
// Full mask 0..n with population n -> scale factor 1.
let mut opts_full = BlockwiseFitOptions::default();
opts_full.outer_score_subsample = Some(Arc::new(OuterScoreSubsample::new(
(0..n).collect(),
n,
0xDEADBEEF,
)));
let with_full = family
.sigma_exact_joint_psi_terms_with_options(&states, &specs, &opts_full)
.expect("psi terms with full mask")
.expect("some");
// Compare the scalar objective, the score vector, and the densified
// Hessian operator elementwise.
let obj_rel = ((with_full.objective_psi - baseline.objective_psi)
/ baseline.objective_psi.abs().max(1.0))
.abs();
assert!(obj_rel < 1e-12, "objective_psi rel {}", obj_rel);
let score_rel = rel_diff_array1(&with_full.score_psi, &baseline.score_psi);
assert!(score_rel < 1e-12, "score_psi rel {}", score_rel);
let h_full = with_full
.hessian_psi_operator
.as_ref()
.expect("op")
.to_dense();
let h_baseline = baseline
.hessian_psi_operator
.as_ref()
.expect("op")
.to_dense();
let h_rel = rel_diff_array2(&h_full, &h_baseline);
assert!(h_rel < 1e-12, "hessian rel {}", h_rel);
}
// With a half (even-row) mask, every sigma psi term — objective, score, and
// Hessian operator — must equal exactly n/m times the same computation run
// with uniform weight 1 on the same rows.
#[test]
fn bernoulli_sigma_psi_terms_subsample_half_scales_correctly() {
use crate::families::marginal_slope_shared::OuterScoreSubsample;
let n = 200usize;
let family = make_sigma_aware_test_family(n);
let states = rigid_block_states(&family, 0.3, 0.4);
let specs = vec![dummy_blockspec(1, n), dummy_blockspec(1, n)];
let even_mask: Vec<usize> = (0..n).filter(|i| i % 2 == 0).collect();
let m = even_mask.len();
// Scaled run: half mask against population n.
let mut opts_half = BlockwiseFitOptions::default();
opts_half.outer_score_subsample = Some(Arc::new(OuterScoreSubsample::new(
even_mask.clone(),
n,
0xCAFE,
)));
let scaled = family
.sigma_exact_joint_psi_terms_with_options(&states, &specs, &opts_half)
.expect("scaled")
.expect("some");
// Raw run: same rows, uniform weight 1.0 (no HT scaling).
let mut opts_raw = BlockwiseFitOptions::default();
opts_raw.outer_score_subsample = Some(Arc::new(OuterScoreSubsample::with_uniform_weight(
even_mask, m, 0, 1.0,
)));
let raw = family
.sigma_exact_joint_psi_terms_with_options(&states, &specs, &opts_raw)
.expect("raw")
.expect("some");
let factor = n as f64 / m as f64;
let exp_obj = factor * raw.objective_psi;
let obj_rel = ((scaled.objective_psi - exp_obj) / exp_obj.abs().max(1.0)).abs();
assert!(obj_rel < 1e-12, "objective_psi rel {}", obj_rel);
let exp_score = &raw.score_psi * factor;
let score_rel = rel_diff_array1(&scaled.score_psi, &exp_score);
assert!(score_rel < 1e-12, "score_psi rel {}", score_rel);
let h_scaled = scaled.hessian_psi_operator.as_ref().expect("op").to_dense();
let h_raw = raw.hessian_psi_operator.as_ref().expect("op").to_dense();
let h_exp = &h_raw * factor;
let h_rel = rel_diff_array2(&h_scaled, &h_exp);
assert!(h_rel < 1e-12, "hessian rel {}", h_rel);
}
// Second-order sigma psi terms (psi-psi objective and score) with a
// full-coverage mask must equal the unsampled computation.
#[test]
fn bernoulli_sigma_psi_second_order_subsample_full_equals_unsampled() {
use crate::families::marginal_slope_shared::OuterScoreSubsample;
let n = 200usize;
let family = make_sigma_aware_test_family(n);
let states = rigid_block_states(&family, 0.3, 0.4);
let baseline = family
.sigma_exact_joint_psisecond_order_terms(&states)
.expect("baseline")
.expect("some");
// Full mask 0..n -> scale factor 1.
let mut opts_full = BlockwiseFitOptions::default();
opts_full.outer_score_subsample = Some(Arc::new(OuterScoreSubsample::new(
(0..n).collect(),
n,
0xDEADBEEF,
)));
let with_full = family
.sigma_exact_joint_psisecond_order_terms_with_options(&states, &opts_full)
.expect("with full mask")
.expect("some");
let obj_rel = ((with_full.objective_psi_psi - baseline.objective_psi_psi)
/ baseline.objective_psi_psi.abs().max(1.0))
.abs();
assert!(obj_rel < 1e-12, "objective rel {}", obj_rel);
let score_rel = rel_diff_array1(&with_full.score_psi_psi, &baseline.score_psi_psi);
assert!(score_rel < 1e-12, "score rel {}", score_rel);
}
// Second-order sigma psi terms with a half mask must equal exactly n/m times
// the uniform-weight-1 computation on the same rows.
#[test]
fn bernoulli_sigma_psi_second_order_subsample_half_scales_correctly() {
use crate::families::marginal_slope_shared::OuterScoreSubsample;
let n = 200usize;
let family = make_sigma_aware_test_family(n);
let states = rigid_block_states(&family, 0.3, 0.4);
let even_mask: Vec<usize> = (0..n).filter(|i| i % 2 == 0).collect();
let m = even_mask.len();
// Scaled: even rows against population n (HT factor n/m).
let mut opts_half = BlockwiseFitOptions::default();
opts_half.outer_score_subsample = Some(Arc::new(OuterScoreSubsample::new(
even_mask.clone(),
n,
0xCAFE,
)));
let scaled = family
.sigma_exact_joint_psisecond_order_terms_with_options(&states, &opts_half)
.expect("scaled")
.expect("some");
// Raw: same rows with uniform weight 1.0.
let mut opts_raw = BlockwiseFitOptions::default();
opts_raw.outer_score_subsample = Some(Arc::new(OuterScoreSubsample::with_uniform_weight(
even_mask, m, 0, 1.0,
)));
let raw = family
.sigma_exact_joint_psisecond_order_terms_with_options(&states, &opts_raw)
.expect("raw")
.expect("some");
let factor = n as f64 / m as f64;
let exp_obj = factor * raw.objective_psi_psi;
let obj_rel = ((scaled.objective_psi_psi - exp_obj) / exp_obj.abs().max(1.0)).abs();
assert!(obj_rel < 1e-12, "objective rel {}", obj_rel);
let exp_score = &raw.score_psi_psi * factor;
let score_rel = rel_diff_array1(&scaled.score_psi_psi, &exp_score);
assert!(score_rel < 1e-12, "score rel {}", score_rel);
}
// The sigma psi-Hessian directional derivative (a matrix) with a
// full-coverage mask must equal the unsampled computation.
#[test]
fn bernoulli_sigma_psihessian_directional_derivative_subsample_full_equals_unsampled() {
use crate::families::marginal_slope_shared::OuterScoreSubsample;
let n = 200usize;
let family = make_sigma_aware_test_family(n);
let states = rigid_block_states(&family, 0.3, 0.4);
// Arbitrary fixed direction in the two-block coefficient space.
let dir = array![0.1, -0.2];
let baseline = family
.sigma_exact_joint_psihessian_directional_derivative(&states, &dir)
.expect("baseline")
.expect("some");
// Full mask 0..n -> scale factor 1.
let mut opts_full = BlockwiseFitOptions::default();
opts_full.outer_score_subsample = Some(Arc::new(OuterScoreSubsample::new(
(0..n).collect(),
n,
0xDEADBEEF,
)));
let with_full = family
.sigma_exact_joint_psihessian_directional_derivative_with_options(
&states, &dir, &opts_full,
)
.expect("with full")
.expect("some");
let rel = rel_diff_array2(&with_full, &baseline);
assert!(rel < 1e-12, "drift rel {}", rel);
}
// The sigma psi-Hessian directional derivative with a half mask must equal
// exactly n/m times the uniform-weight-1 computation on the same rows.
#[test]
fn bernoulli_sigma_psihessian_directional_derivative_subsample_half_scales_correctly() {
use crate::families::marginal_slope_shared::OuterScoreSubsample;
let n = 200usize;
let family = make_sigma_aware_test_family(n);
let states = rigid_block_states(&family, 0.3, 0.4);
let dir = array![0.1, -0.2];
let even_mask: Vec<usize> = (0..n).filter(|i| i % 2 == 0).collect();
let m = even_mask.len();
// Scaled: even rows against population n (HT factor n/m).
let mut opts_half = BlockwiseFitOptions::default();
opts_half.outer_score_subsample = Some(Arc::new(OuterScoreSubsample::new(
even_mask.clone(),
n,
0xCAFE,
)));
let scaled = family
.sigma_exact_joint_psihessian_directional_derivative_with_options(
&states, &dir, &opts_half,
)
.expect("scaled")
.expect("some");
// Raw: same rows with uniform weight 1.0.
let mut opts_raw = BlockwiseFitOptions::default();
opts_raw.outer_score_subsample = Some(Arc::new(OuterScoreSubsample::with_uniform_weight(
even_mask, m, 0, 1.0,
)));
let raw = family
.sigma_exact_joint_psihessian_directional_derivative_with_options(
&states, &dir, &opts_raw,
)
.expect("raw")
.expect("some");
let factor = n as f64 / m as f64;
let exp = &raw * factor;
let rel = rel_diff_array2(&scaled, &exp);
assert!(rel < 1e-12, "drift rel {}", rel);
}
// The joint psi workspace built with options must thread the outer-score
// subsample through to its first-order terms: those terms must match the
// direct with-options computation bit-for-bit, and must visibly differ from
// (while staying within a coarse HT bound of) the full-data workspace.
#[test]
fn bernoulli_psi_workspace_with_options_threads_subsample_to_first_order() {
use crate::custom_family::CustomFamilyBlockPsiDerivative;
use crate::families::marginal_slope_shared::OuterScoreSubsample;
let n = 200usize;
let family = make_sigma_aware_test_family(n);
let states = rigid_block_states(&family, 0.3, 0.4);
let specs = vec![dummy_blockspec(1, n), dummy_blockspec(1, n)];
// One degenerate (empty-matrix) psi derivative on block 1 only, so the
// workspace has exactly one block-level psi plus the sigma psi.
let derivative_blocks: Vec<Vec<CustomFamilyBlockPsiDerivative>> = vec![
Vec::new(),
vec![CustomFamilyBlockPsiDerivative::new(
None,
Array2::zeros((0, 0)),
Array2::zeros((0, 0)),
None,
None,
None,
None,
)],
];
let even_mask: Vec<usize> = (0..n).filter(|i| i % 2 == 0).collect();
let m = even_mask.len();
let mut opts_half = BlockwiseFitOptions::default();
opts_half.outer_score_subsample = Some(Arc::new(OuterScoreSubsample::new(
even_mask.clone(),
n,
0xBEEF_CAFE,
)));
// Direct with-options sigma terms: the ground truth the workspace must
// reproduce exactly.
let direct = family
.sigma_exact_joint_psi_terms_with_options(&states, &specs, &opts_half)
.expect("direct sigma terms with options")
.expect("direct some");
let ws = family
.exact_newton_joint_psi_workspace_with_options(
&states,
&specs,
&derivative_blocks,
&opts_half,
)
.expect("workspace with options")
.expect("workspace some");
// The sigma psi is the last psi index in the workspace layout.
let psi_total: usize = derivative_blocks.iter().map(Vec::len).sum();
let sigma_psi = psi_total - 1;
let via_ws = ws
.first_order_terms(sigma_psi)
.expect("ws first_order_terms")
.expect("some");
// Exact (not approximate) equality: the workspace must reuse the same
// subsampled computation, not re-derive it differently.
assert_eq!(via_ws.objective_psi, direct.objective_psi);
let score_rel = rel_diff_array1(&via_ws.score_psi, &direct.score_psi);
assert!(score_rel == 0.0, "score_psi diverged: rel {}", score_rel);
let h_ws = via_ws
.hessian_psi_operator
.as_ref()
.expect("ws hessian op")
.to_dense();
let h_direct = direct
.hessian_psi_operator
.as_ref()
.expect("direct hessian op")
.to_dense();
let h_rel = rel_diff_array2(&h_ws, &h_direct);
assert!(h_rel == 0.0, "hessian diverged: rel {}", h_rel);
// Control: a full-data workspace must give a visibly different objective,
// proving the subsample actually reached the workspace internals.
let ws_full = family
.exact_newton_joint_psi_workspace_with_options(
&states,
&specs,
&derivative_blocks,
&BlockwiseFitOptions::default(),
)
.expect("workspace full")
.expect("some");
let via_ws_full = ws_full
.first_order_terms(sigma_psi)
.expect("ws_full first_order_terms")
.expect("some");
assert!(
(via_ws.objective_psi - via_ws_full.objective_psi).abs() > 1e-9,
"subsample objective {} too close to full-data {} (subsample not threaded?)",
via_ws.objective_psi,
via_ws_full.objective_psi
);
// Coarse sanity bound: the HT-scaled subsample objective should be within
// roughly a factor of 2*(n/m)+1 of the full-data objective.
let m_f = m as f64;
let n_f = n as f64;
let ratio_bound = (n_f / m_f) * 2.0 + 1.0;
let ratio = (via_ws.objective_psi.abs() + 1.0) / (via_ws_full.objective_psi.abs() + 1.0);
assert!(
ratio < ratio_bound && (1.0 / ratio) < ratio_bound,
"subsample/full ratio {} outside coarse bound {}",
ratio,
ratio_bound
);
}
/// Test convenience wrapper: builds a link deviation block where the seed
/// doubles as the design knots and every observation weight is one.
fn build_test_link_deviation_block_from_seed(
seed: &Array1<f64>,
cfg: &DeviationBlockConfig,
) -> Result<DeviationPrepared, String> {
let unit_weights = Array1::ones(seed.len());
build_link_deviation_block_from_knots_design_seed_and_weights(seed, seed, &unit_weights, cfg)
}
// After the smoothness-null-space drop, the score-warp basis's integrated
// derivative penalty must have zero nullity and every eigenvalue strictly
// above the positive-eigenvalue threshold (i.e. strictly full rank).
#[test]
fn score_warp_basis_smoothness_penalty_is_full_rank() {
let seed = array![-2.0, -1.0, 0.0, 1.0, 2.0];
let cfg = DeviationBlockConfig {
num_internal_knots: 5,
..DeviationBlockConfig::default()
};
let prepared = build_score_warp_deviation_block_from_seed(&seed, &cfg)
.expect("build smoothness-null-space-drop score-warp");
// Use the maximum configured penalty order; fall back to the scalar
// `penalty_order` when the list is empty, and require it to be positive.
let max_order = cfg
.penalty_orders
.iter()
.copied()
.max()
.or(Some(cfg.penalty_order))
.filter(|o| *o > 0)
.expect("test config has a positive penalty order");
let (penalty, nullity) = prepared
.runtime
.integrated_derivative_penalty_with_nullity(max_order)
.expect("integrated penalty on transformed basis");
assert_eq!(
nullity, 0,
"smoothness-null-space-drop basis must have zero nullity at the max configured \
derivative order; got {nullity}"
);
// Strict check: the full spectrum must sit above the positive threshold.
use crate::faer_ndarray::FaerEigh;
let (evals, _) = penalty
.eigh(faer::Side::Lower)
.expect("eigendecomposition of transformed-basis penalty");
let evals_slice = evals
.as_slice()
.expect("contiguous transformed-basis penalty eigenvalues");
let threshold = crate::estimate::reml::unified::positive_eigenvalue_threshold(evals_slice);
let smallest = evals_slice.iter().copied().fold(f64::INFINITY, f64::min);
assert!(
smallest > threshold,
"smallest eigenvalue {smallest} of transformed-basis penalty must exceed positive \
threshold {threshold} — null space was not fully dropped"
);
}
// Same full-rank requirement as the score-warp case, but for the
// link-deviation basis built with non-uniform observation weights.
#[test]
fn link_deviation_basis_smoothness_penalty_is_full_rank() {
let q = array![-2.0, -0.8, -0.1, 0.4, 1.3, 2.1];
let weights = array![0.2, 1.7, 0.5, 2.3, 0.8, 1.1];
let cfg = DeviationBlockConfig {
num_internal_knots: 5,
..DeviationBlockConfig::default()
};
let prepared =
build_link_deviation_block_from_knots_design_seed_and_weights(&q, &q, &weights, &cfg)
.expect("build smoothness-null-space-drop link-deviation");
// Max configured penalty order, falling back to the scalar field.
let max_order = cfg
.penalty_orders
.iter()
.copied()
.max()
.or(Some(cfg.penalty_order))
.filter(|o| *o > 0)
.expect("test config has a positive penalty order");
let (penalty, nullity) = prepared
.runtime
.integrated_derivative_penalty_with_nullity(max_order)
.expect("integrated penalty on transformed basis");
assert_eq!(
nullity, 0,
"smoothness-null-space-drop basis must have zero nullity at the max configured \
derivative order; got {nullity}"
);
// Strict spectral check against the positive-eigenvalue threshold.
use crate::faer_ndarray::FaerEigh;
let (evals, _) = penalty
.eigh(faer::Side::Lower)
.expect("eigendecomposition of transformed-basis penalty");
let evals_slice = evals
.as_slice()
.expect("contiguous transformed-basis penalty eigenvalues");
let threshold = crate::estimate::reml::unified::positive_eigenvalue_threshold(evals_slice);
let smallest = evals_slice.iter().copied().fold(f64::INFINITY, f64::min);
assert!(
smallest > threshold,
"smallest eigenvalue {smallest} of transformed-basis penalty must exceed positive \
threshold {threshold} — null space was not fully dropped"
);
}
// Both the spec validator and the eta-from-probability inverse must reject a
// non-probit base link (logit here) with the "requires link(type=probit)"
// message.
#[test]
fn bernoulli_marginal_slope_rejects_nonprobit_base_link() {
let y = array![0.0, 1.0];
let weights = array![1.0, 1.0];
let z = array![-0.4, 0.9];
let design =
DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(array![[1.0], [1.0]]));
let spec = BernoulliMarginalSlopeTermSpec {
y,
weights,
z,
// Deliberately wrong: marginal-slope requires a probit base link.
base_link: InverseLink::Standard(LinkFunction::Logit),
marginalspec: empty_termspec(),
logslopespec: empty_termspec(),
marginal_offset: Array1::zeros(2),
logslope_offset: Array1::zeros(2),
frailty: FrailtySpec::None,
score_warp: None,
link_dev: None,
latent_z_policy: LatentZPolicy::default(),
};
let err = validate_spec(design.to_dense().view(), &spec)
.expect_err("non-probit marginal-slope link should be rejected");
assert!(err.contains("requires link(type=probit)"));
// The probability-inverse helper enforces the same constraint.
let err = bernoulli_marginal_slope_eta_from_probability(
&InverseLink::Standard(LinkFunction::Logit),
0.5,
"test logit inverse",
)
.expect_err("non-probit marginal-slope inverse should be rejected");
assert!(err.contains("requires link(type=probit)"));
}
/// Expands weighted rows into unit-weight rows: each `(y[i], z[i])` pair is
/// repeated `weights[i]` times. Panics (test assertion) when a weight is not
/// an integer within 1e-12, and on index-out-of-bounds when `z` or `weights`
/// is shorter than `y`.
fn expand_integer_weight_rows(
y: &Array1<f64>,
z: &Array1<f64>,
weights: &Array1<f64>,
) -> (Array1<f64>, Array1<f64>) {
let mut y_out = Vec::new();
let mut z_out = Vec::new();
for i in 0..y.len() {
let w = weights[i];
let reps = w as usize;
assert!(
(w - reps as f64).abs() < 1e-12,
"test helper expects integer weights, got {}",
w
);
y_out.extend(std::iter::repeat(y[i]).take(reps));
z_out.extend(std::iter::repeat(z[i]).take(reps));
}
(Array1::from_vec(y_out), Array1::from_vec(z_out))
}
// With only a link-deviation block (no score-warp, zero-column primary
// designs), the coefficient slice layout, the analytic row gradient/Hessian,
// and the link block's monotonicity constraints must all be structurally
// consistent: the link slice owns all coefficients, the primary link slice
// starts after the two intercept coordinates, and the anchored constraint
// rows use the raw derivative-control bound (monotonicity_eps - 1).
#[test]
fn link_dev_without_score_warp_exposes_structural_derivative_lower_bounds() {
let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
let prepared = build_test_link_deviation_block_from_seed(
&seed,
&DeviationBlockConfig {
num_internal_knots: 4,
..DeviationBlockConfig::default()
},
)
.expect("build link deviation block");
let link_dim = prepared
.block
.initial_beta
.as_ref()
.expect("link block initial beta")
.len();
// Nonzero link coefficients so derivative paths are exercised.
let beta_link = Array1::from_iter((0..link_dim).map(|idx| 0.1 * (idx as f64 + 1.0)));
let family = BernoulliMarginalSlopeFamily {
y: Arc::new(Array1::zeros(seed.len())),
weights: Arc::new(Array1::ones(seed.len())),
z: Arc::new(seed.clone()),
latent_measure: LatentMeasureKind::StandardNormal,
gaussian_frailty_sd: None,
base_link: bernoulli_marginal_slope_probit_link(),
// Zero-column designs: no primary coefficients at all.
marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
Array2::zeros((seed.len(), 0)),
)),
logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
Array2::zeros((seed.len(), 0)),
)),
score_warp: None,
link_dev: Some(prepared.runtime.clone()),
policy: crate::resource::ResourcePolicy::default_library(),
cell_moment_lru: new_cell_moment_lru_cache(
&crate::resource::ResourcePolicy::default_library(),
),
cell_moment_cache_stats: new_cell_moment_cache_stats(),
intercept_warm_starts: None,
auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
auto_subsample_last_rho: Arc::new(Mutex::new(None)),
shared_eval_cache: Arc::new(Mutex::new(None)),
};
// Blocks: marginal (empty), logslope (empty), link-deviation.
let block_states = vec![
dummy_block_state(array![0.0], seed.len()),
dummy_block_state(array![0.0], seed.len()),
dummy_block_state(beta_link.clone(), seed.len()),
];
// Flat-coefficient slice layout: link coefficients occupy everything.
let slices = block_slices(&family);
assert!(slices.h.is_none(), "score-warp slice should be absent");
let link_slice = slices.w.as_ref().expect("link slice");
assert_eq!(
slices.marginal.len(),
0,
"zero-column marginal design should not contribute coefficient coordinates"
);
assert_eq!(
slices.logslope.len(),
0,
"zero-column logslope design should not contribute coefficient coordinates"
);
assert_eq!(
link_slice.start, 0,
"link-only coefficients should start at 0"
);
assert_eq!(link_slice.len(), link_dim);
// Primary layout prepends the two intercept-like coordinates.
let primary = primary_slices(&slices);
assert!(primary.h.is_none(), "primary h slice should be absent");
let primary_w = primary.w.as_ref().expect("primary link slice");
assert_eq!(primary_w.start, 2, "primary link slice should start at 2");
assert_eq!(primary.total, 2 + link_dim);
// Analytic row gradient/Hessian must be finite in this configuration.
family
.build_exact_eval_cache(&block_states)
.expect("eval cache");
let row_ctx = family
.build_row_exact_context(0, &block_states)
.expect("row context");
let (nll, grad, hess) = family
.compute_row_primary_gradient_hessian(0, &block_states, &primary, &row_ctx)
.expect("analytic flex eval");
assert!(nll.is_finite(), "neglog should be finite for link-dev-only");
assert!(
grad.iter().all(|v| v.is_finite()),
"gradient should be finite"
);
assert!(
hess.iter().all(|v| v.is_finite()),
"Hessian should be finite"
);
// Only the link block (index 2) exposes monotonicity constraints.
let dummy_spec = dummy_blockspec(link_dim, seed.len());
assert!(
family
.block_linear_constraints(&block_states, 1, &dummy_spec)
.expect("non-link constraint lookup")
.is_none(),
"non-link block should not expose auxiliary monotonicity constraints"
);
let constraints = family
.block_linear_constraints(&block_states, 2, &dummy_spec)
.expect("link constraint lookup")
.expect("link constraints");
assert_eq!(constraints.a.ncols(), link_dim);
assert_eq!(constraints.b.len(), constraints.a.nrows());
assert!(
constraints.a.nrows() >= link_dim,
"anchored link constraints should be expressed in raw derivative-control rows"
);
// Every constraint row shares the derivative lower bound eps - 1.
assert_eq!(
constraints.b,
Array1::<f64>::from_elem(
constraints.a.nrows(),
prepared.runtime.monotonicity_eps() - 1.0
)
);
}
// With both deviation blocks present but all deviation coefficients at zero,
// the row-intercept solver should take the closed-form rigid fast path and
// agree with the denested Newton calibration: near-zero residual at the rigid
// intercept, derivative equal to the analytic rigid derivative, and the
// fast-path flag set with value/derivative exactly equal to the closed forms.
#[test]
fn zero_deviation_intercept_fast_path_matches_denested_calibration() {
let seed = Array1::from_iter((0..25).map(|i| -2.4 + 4.8 * i as f64 / 24.0));
let score_prepared = build_score_warp_deviation_block_from_seed(
&seed,
&DeviationBlockConfig {
num_internal_knots: 8,
..DeviationBlockConfig::default()
},
)
.expect("build score-warp block");
let link_prepared = build_test_link_deviation_block_from_seed(
&seed,
&DeviationBlockConfig {
num_internal_knots: 8,
..DeviationBlockConfig::default()
},
)
.expect("build link-deviation block");
let n = seed.len();
// Gaussian frailty present so the probit frailty scale is nontrivial.
let family = BernoulliMarginalSlopeFamily {
y: Arc::new(Array1::zeros(n)),
weights: Arc::new(Array1::ones(n)),
z: Arc::new(seed.clone()),
latent_measure: LatentMeasureKind::StandardNormal,
gaussian_frailty_sd: Some(0.65),
base_link: bernoulli_marginal_slope_probit_link(),
marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
Array2::zeros((n, 0)),
)),
logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
Array2::zeros((n, 0)),
)),
score_warp: Some(score_prepared.runtime.clone()),
link_dev: Some(link_prepared.runtime.clone()),
policy: crate::resource::ResourcePolicy::default_library(),
cell_moment_lru: new_cell_moment_lru_cache(
&crate::resource::ResourcePolicy::default_library(),
),
cell_moment_cache_stats: new_cell_moment_cache_stats(),
intercept_warm_starts: Some(new_intercept_warm_start_cache(n)),
auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
auto_subsample_last_rho: Arc::new(Mutex::new(None)),
shared_eval_cache: Arc::new(Mutex::new(None)),
};
let marginal_eta = 0.35;
let slope = -0.8;
// Zero coefficients: both deviations are identities, so the rigid closed
// form applies exactly.
let beta_h = Array1::zeros(score_prepared.runtime.basis_dim());
let beta_w = Array1::zeros(link_prepared.runtime.basis_dim());
let marginal = family
.marginal_link_map(marginal_eta)
.expect("marginal map");
let scale = family.probit_frailty_scale();
let rigid_a = rigid_prescale_intercept_from_marginal(marginal.q, slope, scale);
let (f_rigid, f_a_rigid, _) = family
.evaluate_denested_calibration_newton(
rigid_a,
marginal_eta,
slope,
Some(&beta_h),
Some(&beta_w),
)
.expect("denested zero-deviation calibration");
assert!(
f_rigid.abs() <= 5e-13,
"closed-form rigid intercept residual should be at machine epsilon, got {f_rigid}"
);
let analytic_deriv = rigid_prescale_intercept_derivative_abs(marginal.q, slope, scale);
assert!(
(f_a_rigid - analytic_deriv).abs() <= 5e-13,
"denested derivative {f_a_rigid} != analytic {analytic_deriv}"
);
let (a_fast, deriv_fast, fast_path) = family
.solve_row_intercept_base(0, marginal_eta, slope, Some(&beta_h), Some(&beta_w), None)
.expect("zero-deviation solve");
assert!(
fast_path,
"zero coefficients should take analytic fast path"
);
// Exact equality: the fast path must return the closed forms verbatim.
assert_eq!(a_fast, rigid_a);
assert_eq!(deriv_fast, analytic_deriv);
}
// The exact-evaluation layout must size its coefficient slices from the
// designs/runtimes themselves, not from the (dummy) per-block beta widths
// supplied for the zero-column marginal and logslope blocks: score-warp
// coefficients come first, then link coefficients, and the primary layout
// prepends the two scalar coordinates (q at 0, logslope at 1).
#[test]
fn exact_layout_ignores_dummy_beta_widths_for_empty_design_blocks() {
let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
let score_prepared = build_score_warp_deviation_block_from_seed(
&seed,
&DeviationBlockConfig {
num_internal_knots: 4,
..DeviationBlockConfig::default()
},
)
.expect("build score-warp block");
let link_prepared = build_test_link_deviation_block_from_seed(
&seed,
&DeviationBlockConfig {
num_internal_knots: 4,
..DeviationBlockConfig::default()
},
)
.expect("build link deviation block");
let family = BernoulliMarginalSlopeFamily {
y: Arc::new(Array1::zeros(seed.len())),
weights: Arc::new(Array1::ones(seed.len())),
z: Arc::new(seed.clone()),
latent_measure: LatentMeasureKind::StandardNormal,
gaussian_frailty_sd: None,
base_link: bernoulli_marginal_slope_probit_link(),
marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
Array2::zeros((seed.len(), 0)),
)),
logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
Array2::zeros((seed.len(), 0)),
)),
score_warp: Some(score_prepared.runtime.clone()),
link_dev: Some(link_prepared.runtime.clone()),
policy: crate::resource::ResourcePolicy::default_library(),
cell_moment_lru: new_cell_moment_lru_cache(
&crate::resource::ResourcePolicy::default_library(),
),
cell_moment_cache_stats: new_cell_moment_cache_stats(),
intercept_warm_starts: None,
auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
auto_subsample_last_rho: Arc::new(Mutex::new(None)),
shared_eval_cache: Arc::new(Mutex::new(None)),
};
// Deliberately hand the two empty design blocks width-1 dummy betas; the
// layout must not pick these widths up.
let block_states = vec![
dummy_block_state(array![0.0], seed.len()),
dummy_block_state(array![0.0], seed.len()),
dummy_block_state(
Array1::zeros(score_prepared.runtime.basis_dim()),
seed.len(),
),
dummy_block_state(Array1::zeros(link_prepared.runtime.basis_dim()), seed.len()),
];
let cache = family
.build_exact_eval_cache(&block_states)
.expect("exact eval cache");
assert_eq!(cache.slices.marginal.len(), 0);
assert_eq!(cache.slices.logslope.len(), 0);
assert_eq!(cache.slices.h.as_ref().expect("h slice").start, 0);
assert_eq!(
cache.slices.w.as_ref().expect("w slice").start,
score_prepared.runtime.basis_dim()
);
assert_eq!(
cache.slices.total,
score_prepared.runtime.basis_dim() + link_prepared.runtime.basis_dim()
);
// Primary layout: q, logslope, then h (score warp) and w (link deviation).
assert_eq!(cache.primary.q, 0);
assert_eq!(cache.primary.logslope, 1);
assert_eq!(cache.primary.h.as_ref().expect("primary h").start, 2);
assert_eq!(
cache.primary.w.as_ref().expect("primary w").start,
2 + score_prepared.runtime.basis_dim()
);
}
// Mirror of the link-deviation constraint test for a score-warp-only family:
// the score-warp block (index 2) must expose at least one derivative-control
// constraint row per coefficient, with rhs `monotonicity_eps() - 1.0`.
#[test]
fn score_warp_block_exposes_structural_derivative_lower_bounds() {
let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
let prepared = build_score_warp_deviation_block_from_seed(
&seed,
&DeviationBlockConfig {
num_internal_knots: 4,
..DeviationBlockConfig::default()
},
)
.expect("build score-warp block");
let score_dim = prepared
.block
.initial_beta
.as_ref()
.expect("score-warp initial beta")
.len();
// Family with a score warp only: no frailty, no link deviation, and
// zero-column marginal/logslope designs.
let family = BernoulliMarginalSlopeFamily {
y: Arc::new(Array1::zeros(seed.len())),
weights: Arc::new(Array1::ones(seed.len())),
z: Arc::new(seed.clone()),
latent_measure: LatentMeasureKind::StandardNormal,
gaussian_frailty_sd: None,
base_link: bernoulli_marginal_slope_probit_link(),
marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
Array2::zeros((seed.len(), 0)),
)),
logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
Array2::zeros((seed.len(), 0)),
)),
score_warp: Some(prepared.runtime.clone()),
link_dev: None,
policy: crate::resource::ResourcePolicy::default_library(),
cell_moment_lru: new_cell_moment_lru_cache(
&crate::resource::ResourcePolicy::default_library(),
),
cell_moment_cache_stats: new_cell_moment_cache_stats(),
intercept_warm_starts: None,
auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
auto_subsample_last_rho: Arc::new(Mutex::new(None)),
shared_eval_cache: Arc::new(Mutex::new(None)),
};
let block_states = vec![
dummy_block_state(array![0.0], seed.len()),
dummy_block_state(array![0.0], seed.len()),
dummy_block_state(Array1::zeros(score_dim), seed.len()),
];
let dummy_spec = dummy_blockspec(score_dim, seed.len());
let constraints = family
.block_linear_constraints(&block_states, 2, &dummy_spec)
.expect("constraint lookup")
.expect("score-warp constraints");
assert_eq!(constraints.a.ncols(), score_dim);
assert_eq!(constraints.b.len(), constraints.a.nrows());
assert!(
constraints.a.nrows() >= score_dim,
"anchored score-warp constraints should be expressed in raw derivative-control rows"
);
assert_eq!(
constraints.b,
Array1::<f64>::from_elem(
constraints.a.nrows(),
prepared.runtime.monotonicity_eps() - 1.0
)
);
}
// A wildly infeasible proposed score-warp beta (one coordinate at -128) must
// be projected back by `post_update_block_beta` to a point that satisfies the
// runtime's monotonicity guard — and the projection must actually change the
// proposal.
#[test]
fn post_update_block_beta_projects_score_warp_to_feasible_step() {
let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
let prepared = build_score_warp_deviation_block_from_seed(
&seed,
&DeviationBlockConfig {
num_internal_knots: 4,
..DeviationBlockConfig::default()
},
)
.expect("build score-warp block");
let score_dim = prepared.block.design.ncols();
let family = BernoulliMarginalSlopeFamily {
y: Arc::new(Array1::zeros(seed.len())),
weights: Arc::new(Array1::ones(seed.len())),
z: Arc::new(seed.clone()),
latent_measure: LatentMeasureKind::StandardNormal,
gaussian_frailty_sd: None,
base_link: bernoulli_marginal_slope_probit_link(),
marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
Array2::zeros((seed.len(), 0)),
)),
logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
Array2::zeros((seed.len(), 0)),
)),
score_warp: Some(prepared.runtime.clone()),
link_dev: None,
policy: crate::resource::ResourcePolicy::default_library(),
cell_moment_lru: new_cell_moment_lru_cache(
&crate::resource::ResourcePolicy::default_library(),
),
cell_moment_cache_stats: new_cell_moment_cache_stats(),
intercept_warm_starts: None,
auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
auto_subsample_last_rho: Arc::new(Mutex::new(None)),
shared_eval_cache: Arc::new(Mutex::new(None)),
};
// Current iterate is the feasible zero deviation; the proposal violates
// the structural derivative lower bound by a large margin.
let current = Array1::<f64>::zeros(score_dim);
let mut proposed = current.clone();
proposed[0] = -128.0;
let block_states = vec![
dummy_block_state(array![0.0], seed.len()),
dummy_block_state(array![0.0], seed.len()),
dummy_block_state(current.clone(), seed.len()),
];
let spec = dummy_blockspec(score_dim, seed.len());
let updated = family
.post_update_block_beta(&block_states, 2, &spec, proposed.clone())
.expect("projected beta");
prepared
.runtime
.monotonicity_feasible(&updated, "projected score-warp")
.expect("post-update beta should remain feasible");
assert_ne!(updated, proposed);
}
#[test]
fn structural_deviation_runtime_is_piecewise_cubic() {
    // A default-config deviation block must be genuinely cubic: degree 3 for
    // both the runtime and its value spans, and at least one span carrying a
    // non-negligible cubic coefficient.
    let anchor = array![-1.0, 0.0, 1.0];
    let config = DeviationBlockConfig {
        ..DeviationBlockConfig::default()
    };
    let built = build_score_warp_deviation_block_from_seed(&anchor, &config)
        .expect("structural deviation basis");
    assert_eq!(built.runtime.degree(), 3);
    assert_eq!(built.runtime.value_span_degree(), 3);
    let curvature_present = built
        .runtime
        .span_c3()
        .iter()
        .map(|coefficient| coefficient.abs())
        .any(|magnitude| magnitude > 1e-12);
    assert!(
        curvature_present,
        "structural deviation basis must expose true cubic span coefficients"
    );
}
#[test]
fn structural_deviation_runtime_is_c2_at_internal_breakpoints() {
    // Adjacent span cubics must agree in value, slope, and curvature at every
    // interior breakpoint — i.e. the deviation is C^2 across the partition.
    let anchor = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let built = build_score_warp_deviation_block_from_seed(
        &anchor,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("structural deviation basis");
    let width = built.block.design.ncols();
    let coefficients = Array1::from_iter((0..width).map(|k| 0.015 * (k as f64 + 1.0)));
    let span_count = built.runtime.breakpoints().len().saturating_sub(1);
    for span_idx in 1..span_count {
        let lhs = built
            .runtime
            .local_cubic_on_span(&coefficients, span_idx - 1)
            .expect("left span cubic");
        let rhs = built
            .runtime
            .local_cubic_on_span(&coefficients, span_idx)
            .expect("right span cubic");
        let knot = built.runtime.breakpoints()[span_idx];
        let value_gap = (lhs.evaluate(knot) - rhs.evaluate(knot)).abs();
        assert!(
            value_gap <= 1e-10,
            "deviation value should be continuous at breakpoint {span_idx}"
        );
        let slope_gap = (lhs.first_derivative(knot) - rhs.first_derivative(knot)).abs();
        assert!(
            slope_gap <= 1e-10,
            "deviation first derivative should be continuous at breakpoint {span_idx}"
        );
        let curvature_gap = (lhs.second_derivative(knot) - rhs.second_derivative(knot)).abs();
        assert!(
            curvature_gap <= 1e-10,
            "deviation second derivative should be continuous at breakpoint {span_idx}"
        );
    }
}
#[test]
fn structural_deviation_rejects_noncubic_degree() {
    // Requesting a non-cubic (degree-4) deviation basis must be refused with
    // an explicit error message.
    let anchor = array![-1.0, 0.0, 1.0];
    let config = DeviationBlockConfig {
        degree: 4,
        ..DeviationBlockConfig::default()
    };
    let failure = build_score_warp_deviation_block_from_seed(&anchor, &config)
        .expect_err("structural deviation block should reject non-cubic degree");
    assert!(failure.contains("degree must be 3"));
}
#[test]
fn deviation_runtime_replays_exact_training_design() {
    // Re-evaluating the runtime basis at the original training seed must
    // reproduce the stored training design entry-for-entry.
    let anchor = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let built = build_score_warp_deviation_block_from_seed(
        &anchor,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build deviation block");
    let from_runtime = built.runtime.design(&anchor).expect("replayed design");
    let from_training = built.block.design.to_dense();
    assert_eq!(from_runtime.dim(), from_training.dim());
    for ((i, j), &entry) in from_runtime.indexed_iter() {
        assert!(
            (entry - from_training[[i, j]]).abs() <= 1e-10,
            "training-basis replay mismatch at ({i},{j})"
        );
    }
}
#[test]
fn structural_constraints_match_exact_monotonicity_guard() {
    // The structural constraint system must (a) carry three rows per span
    // (quadratic-derivative Bernstein controls) with rhs eps - 1, (b) accept
    // the zero deviation, and (c) reject a beta engineered to push one
    // derivative row to -2 (below the lower bound).
    let anchor = array![-1.0, 0.0, 1.0, 2.0];
    let built =
        build_score_warp_deviation_block_from_seed(&anchor, &DeviationBlockConfig::default())
            .expect("build deviation block");
    let constraints = built.runtime.structural_monotonicity_constraints();
    let width = constraints.a.ncols();
    assert_eq!(width, built.runtime.basis_dim());
    let span_count = built.runtime.breakpoints().len().saturating_sub(1);
    assert_eq!(constraints.a.nrows(), 3 * span_count);
    assert_eq!(
        constraints.b,
        Array1::<f64>::from_elem(
            constraints.a.nrows(),
            built.runtime.monotonicity_eps() - 1.0
        )
    );
    // The zero deviation leaves the warp rigid, hence feasible.
    let neutral = Array1::<f64>::zeros(width);
    built
        .runtime
        .monotonicity_feasible(&neutral, "feasible structural beta")
        .expect("zero deviation should be feasible");
    let d1_design = built
        .runtime
        .first_derivative_design(&anchor)
        .expect("derivative design");
    let row_idx = (0..d1_design.nrows())
        .find(|&idx| d1_design.row(idx).dot(&d1_design.row(idx)) > 0.0)
        .expect("derivative design should include a nonzero row");
    let probe_row = d1_design.row(row_idx).to_owned();
    let norm_sq = probe_row.dot(&probe_row);
    // Scale so the chosen derivative row evaluates to exactly -2.
    let violating = probe_row.mapv(|entry| -2.0 * entry / norm_sq);
    assert!(
        built
            .runtime
            .monotonicity_feasible(&violating, "infeasible structural beta")
            .is_err()
    );
}
#[test]
fn structural_constraints_are_quadratic_derivative_bernstein_controls() {
    // Each consecutive triple of constraint rows, applied to a beta, must
    // yield the Bernstein controls of that span's (quadratic) first
    // derivative: endpoints match the exact derivative, and the standard
    // 1/4-1/2-1/4 combination reconstructs the derivative at the midpoint.
    let anchor = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let built = build_score_warp_deviation_block_from_seed(
        &anchor,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build deviation block");
    let constraints = built.runtime.structural_monotonicity_constraints();
    let width = built.runtime.basis_dim();
    let beta = Array1::from_iter((0..width).map(|k| {
        let centered = k as f64 - 0.5 * (width as f64 - 1.0);
        0.025 * centered
    }));
    let controls = constraints.a.dot(&beta);
    let span_count = built.runtime.breakpoints().len().saturating_sub(1);
    for span in 0..span_count {
        let cubic = built
            .runtime
            .local_cubic_on_span(&beta, span)
            .expect("local cubic");
        let midpoint = 0.5 * (cubic.left + cubic.right);
        let (b0, b1, b2) = (
            controls[3 * span],
            controls[3 * span + 1],
            controls[3 * span + 2],
        );
        assert!(
            (b0 - cubic.first_derivative(cubic.left)).abs() <= 1e-10,
            "left Bernstein control should equal derivative at span start"
        );
        assert!(
            (b2 - cubic.first_derivative(cubic.right)).abs() <= 1e-10,
            "right Bernstein control should equal derivative at span end"
        );
        let mid_value = 0.25 * b0 + 0.5 * b1 + 0.25 * b2;
        assert!(
            (mid_value - cubic.first_derivative(midpoint)).abs() <= 1e-10,
            "quadratic Bernstein controls should reconstruct derivative at span midpoint"
        );
    }
}
#[test]
fn deviation_penalties_are_integrated_function_penalties() {
    // Configured with penalty orders [1, 2, 3] plus `double_penalty`, the
    // block should expose four dense penalties in order [1, 2, 3, 0], each
    // equal to the runtime's integrated derivative penalty (a function-space
    // Gram matrix) with matching nullspace dimension.
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let prepared = build_score_warp_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            penalty_order: 2,
            penalty_orders: vec![1, 2, 3],
            double_penalty: true,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build deviation block");
    // The double penalty is appended as an order-0 (integrated L2) entry.
    let expected_orders = [1, 2, 3, 0];
    assert_eq!(prepared.block.penalties.len(), expected_orders.len());
    for ((penalty, &nullity), &order) in prepared
        .block
        .penalties
        .iter()
        .zip(prepared.block.nullspace_dims.iter())
        .zip(expected_orders.iter())
    {
        let crate::solver::estimate::PenaltySpec::Dense(actual) = penalty else {
            panic!("deviation penalties should be dense local Gram matrices");
        };
        let (expected, expected_nullity) = prepared
            .runtime
            .integrated_derivative_penalty_with_nullity(order)
            .expect("integrated function penalty");
        assert_eq!(nullity, expected_nullity);
        assert_eq!(actual.dim(), expected.dim());
        for i in 0..actual.nrows() {
            for j in 0..actual.ncols() {
                assert!(
                    (actual[[i, j]] - expected[[i, j]]).abs() <= 1e-10,
                    "penalty order {order} mismatch at ({i},{j}): got {}, expected {}",
                    actual[[i, j]],
                    expected[[i, j]]
                );
            }
        }
    }
    // BUG FIX: per `expected_orders`, the double (order-0, integrated-L2)
    // penalty is the LAST entry — index 3, not index 1. Index 1 is the
    // order-2 penalty, which trivially differs from the identity, so the
    // original check at [1] could never fail for the intended reason. Check
    // the actual order-0 penalty so the assertion genuinely distinguishes
    // integrated L2 from a coefficient-space ridge identity.
    let crate::solver::estimate::PenaltySpec::Dense(l2_penalty) = &prepared.block.penalties[3]
    else {
        panic!("deviation double penalty should be dense");
    };
    let mut max_identity_diff = 0.0_f64;
    for i in 0..l2_penalty.nrows() {
        for j in 0..l2_penalty.ncols() {
            let identity = if i == j { 1.0 } else { 0.0 };
            max_identity_diff = max_identity_diff.max((l2_penalty[[i, j]] - identity).abs());
        }
    }
    assert!(
        max_identity_diff > 1e-6,
        "deviation double penalty must be integrated L2, not coefficient identity"
    );
}
// Verifies that per-span cubics reconstruct the full basis expansion exactly:
// value, first, and second derivatives agree with the design-matrix products
// at each span's endpoints and midpoint, and that the point lookup
// `local_cubic_at` selects the matching span cubic — or a saturated constant
// tail outside the support.
#[test]
fn local_cubic_span_reconstructs_deviation_exactly() {
let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
let prepared = build_score_warp_deviation_block_from_seed(
&seed,
&DeviationBlockConfig {
num_internal_knots: 4,
..DeviationBlockConfig::default()
},
)
.expect("build deviation block");
let dim = prepared.block.design.ncols();
let beta = Array1::from_iter((0..dim).map(|idx| 0.025 * (idx as f64 + 1.0)));
let n_spans = prepared.runtime.breakpoints().len().saturating_sub(1);
let support_left = prepared.runtime.breakpoints()[0];
let support_right =
prepared.runtime.breakpoints()[prepared.runtime.breakpoints().len() - 1];
for span_idx in 0..n_spans {
let cubic = prepared
.runtime
.local_cubic_on_span(&beta, span_idx)
.expect("local cubic coefficients");
let left = cubic.left;
let right = cubic.right;
// Probe span endpoints and midpoint against the full designs.
let x_eval = array![left, 0.5 * (left + right), right];
let value_design = prepared.runtime.design(&x_eval).expect("value design");
let d1_design = prepared
.runtime
.first_derivative_design(&x_eval)
.expect("first derivative design");
let d2_design = prepared
.runtime
.second_derivative_design(&x_eval)
.expect("second derivative design");
let expected = value_design.dot(&beta);
let expected_d1 = d1_design.dot(&beta);
let expected_d2 = d2_design.dot(&beta);
for i in 0..x_eval.len() {
let x = x_eval[i];
assert!(
(cubic.evaluate(x) - expected[i]).abs() < 1e-10,
"span {span_idx}, x={x:.6}: cubic value mismatch"
);
assert!(
(cubic.first_derivative(x) - expected_d1[i]).abs() < 1e-10,
"span {span_idx}, x={x:.6}: cubic first-derivative mismatch"
);
assert!(
(cubic.second_derivative(x) - expected_d2[i]).abs() < 1e-10,
"span {span_idx}, x={x:.6}: cubic second-derivative mismatch"
);
let selected = prepared
.runtime
.local_cubic_at(&beta, x)
.expect("lookup cubic at x");
if x < support_left || x > support_right {
// Outside the support the deviation saturates: only the constant
// term may be nonzero.
assert!(selected.c1.abs() < 1e-12);
assert!(selected.c2.abs() < 1e-12);
assert!(selected.c3.abs() < 1e-12);
assert!((selected.evaluate(x) - expected[i]).abs() < 1e-10);
} else {
// At an interior span's left endpoint the lookup resolves to the
// previous span (the shared breakpoint is owned by the left span).
let expected_span_idx = if i == 0 && span_idx > 0 {
span_idx - 1
} else {
span_idx
};
let expected_cubic = prepared
.runtime
.local_cubic_on_span(&beta, expected_span_idx)
.expect("expected lookup cubic on span");
assert_eq!(selected, expected_cubic);
}
}
}
}
#[test]
fn basis_span_cubic_reconstructs_basis_column_exactly() {
    // The per-span cubic for a single basis column must agree with the full
    // design (values and first derivatives) at the span's endpoints and
    // midpoint, and `basis_cubic_at` must select the matching span cubic or
    // a saturated constant tail outside the support.
    let anchor = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let built = build_score_warp_deviation_block_from_seed(
        &anchor,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build deviation block");
    let column = 0usize;
    let support_left = built.runtime.breakpoints()[0];
    let support_right = built.runtime.breakpoints()[built.runtime.breakpoints().len() - 1];
    let cubic = built
        .runtime
        .basis_span_cubic(0, column)
        .expect("basis span cubic");
    let probes = array![cubic.left, 0.5 * (cubic.left + cubic.right), cubic.right];
    let design = built.runtime.design(&probes).expect("basis design");
    let d1 = built
        .runtime
        .first_derivative_design(&probes)
        .expect("basis d1 design");
    for (i, &x) in probes.iter().enumerate() {
        assert!((cubic.evaluate(x) - design[[i, column]]).abs() < 1e-10);
        assert!((cubic.first_derivative(x) - d1[[i, column]]).abs() < 1e-10);
        let selected = built
            .runtime
            .basis_cubic_at(column, x)
            .expect("basis cubic at x");
        if x < support_left || x > support_right {
            // Outside the support only the constant term survives.
            assert!(selected.c1.abs() < 1e-12);
            assert!(selected.c2.abs() < 1e-12);
            assert!(selected.c3.abs() < 1e-12);
            assert!((selected.evaluate(x) - design[[i, column]]).abs() < 1e-10);
        } else {
            let expected_cubic = built
                .runtime
                .basis_span_cubic(0, column)
                .expect("expected basis span cubic");
            assert_eq!(selected, expected_cubic);
        }
    }
}
#[test]
fn deviation_runtime_saturates_outside_support() {
    // Past the outermost breakpoints every tail cubic must be a constant
    // (zero linear/quadratic/cubic terms), and the constant must not depend
    // on how far beyond the support the query point sits.
    let anchor = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let built = build_score_warp_deviation_block_from_seed(
        &anchor,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build deviation block");
    let width = built.block.design.ncols();
    let beta = Array1::from_iter((0..width).map(|k| 0.02 * (k as f64 + 1.0)));
    let lo = built.runtime.breakpoints()[0];
    let hi = built.runtime.breakpoints()[built.runtime.breakpoints().len() - 1];
    let near_left = built
        .runtime
        .local_cubic_at(&beta, lo - 0.25)
        .expect("left tail");
    let far_left = built
        .runtime
        .local_cubic_at(&beta, lo - 3.0)
        .expect("left far tail");
    let near_right = built
        .runtime
        .local_cubic_at(&beta, hi + 0.25)
        .expect("right tail");
    let far_right = built
        .runtime
        .local_cubic_at(&beta, hi + 3.0)
        .expect("right far tail");
    for tail in [near_left, far_left, near_right, far_right] {
        assert!(tail.c1.abs() < 1e-12);
        assert!(tail.c2.abs() < 1e-12);
        assert!(tail.c3.abs() < 1e-12);
    }
    assert!((near_left.c0 - far_left.c0).abs() < 1e-12);
    assert!((near_right.c0 - far_right.c0).abs() < 1e-12);
}
#[test]
fn deviation_runtime_replays_the_exact_training_basis() {
    // Replay check on an anchored six-point seed: evaluating the runtime
    // design at the training seed must reproduce the stored training design
    // entry-for-entry.
    let anchor = array![-2.0, -1.0, -0.25, 0.25, 1.0, 2.0];
    let built = build_score_warp_deviation_block_from_seed(
        &anchor,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build deviation block");
    let from_runtime = built
        .runtime
        .design(&anchor)
        .expect("replay anchored deviation design");
    let from_training = built.block.design.to_dense();
    assert_eq!(from_runtime.dim(), from_training.dim());
    for ((i, j), &entry) in from_runtime.indexed_iter() {
        assert!(
            (entry - from_training[[i, j]]).abs() <= 1e-10,
            "replayed anchored deviation design mismatch at ({i},{j})"
        );
    }
}
// Checks the denested microcell partition: it must refine the score-warp
// spans, tile the real line contiguously with infinite outer edges, and its
// link-induced breakpoints must move when the intercept `a` changes (since
// the link deviation is evaluated at u = a + b*z).
#[test]
fn denested_microcells_follow_score_and_link_breaks() {
let score_seed = array![-2.0, -1.0, 0.0, 1.0, 2.0];
let link_seed = array![-1.5, -0.5, 0.5, 1.5];
let score_prepared = build_score_warp_deviation_block_from_seed(
&score_seed,
&DeviationBlockConfig {
num_internal_knots: 3,
..DeviationBlockConfig::default()
},
)
.expect("build score warp block");
let link_prepared = build_test_link_deviation_block_from_seed(
&link_seed,
&DeviationBlockConfig {
num_internal_knots: 3,
..DeviationBlockConfig::default()
},
)
.expect("build link deviation block");
let beta_h = Array1::from_iter(
(0..score_prepared.block.design.ncols()).map(|idx| 0.02 * (idx as f64 + 1.0)),
);
let beta_w = Array1::from_iter(
(0..link_prepared.block.design.ncols()).map(|idx| 0.015 * (idx as f64 + 1.0)),
);
// Same slope b = 0.9 for both partitions; only the intercept differs.
let exact_cells_a0 = build_denested_partition_cells(
0.25,
0.9,
score_prepared
.runtime
.breakpoints()
.as_slice()
.expect("score breaks"),
link_prepared
.runtime
.breakpoints()
.as_slice()
.expect("link breaks"),
|z| score_prepared.runtime.local_cubic_at(&beta_h, z),
|u| link_prepared.runtime.local_cubic_at(&beta_w, u),
)
.expect("exact module microcells for a=0.25");
let exact_cells_a1 = build_denested_partition_cells(
0.55,
0.9,
score_prepared
.runtime
.breakpoints()
.as_slice()
.expect("score breaks"),
link_prepared
.runtime
.breakpoints()
.as_slice()
.expect("link breaks"),
|z| score_prepared.runtime.local_cubic_at(&beta_h, z),
|u| link_prepared.runtime.local_cubic_at(&beta_w, u),
)
.expect("exact module microcells for a=0.55");
assert!(
exact_cells_a0.len() >= score_prepared.runtime.breakpoints().len().saturating_sub(1),
"microcell partition should refine the score spans"
);
// Contiguity: each cell's right edge coincides with its successor's left.
assert!(
exact_cells_a0
.windows(2)
.all(|w| (w[0].cell.right - w[1].cell.left).abs() <= 1e-12),
"microcells should tile the partition contiguously"
);
assert!(exact_cells_a0.first().unwrap().cell.left.is_infinite());
assert!(exact_cells_a0.last().unwrap().cell.right.is_infinite());
assert!(
exact_cells_a0
.iter()
.zip(exact_cells_a1.iter())
.any(|(lhs, rhs)| (lhs.cell.left - rhs.cell.left).abs() > 1e-10),
"changing the intercept should move at least one link-induced breakpoint"
);
}
// On each microcell, the cell's eta polynomial must agree with the direct
// de-nested predictor eta(z) = a + b*z + b*h(z) + w(a + b*z) evaluated from
// the full score-warp and link-deviation designs at an interior probe point.
#[test]
fn denested_microcell_eta_matches_direct_denested_formula() {
let score_seed = array![-2.0, -1.0, 0.0, 1.0, 2.0];
let link_seed = array![-1.5, -0.5, 0.5, 1.5];
let score_prepared = build_score_warp_deviation_block_from_seed(
&score_seed,
&DeviationBlockConfig {
num_internal_knots: 3,
..DeviationBlockConfig::default()
},
)
.expect("build score warp block");
let link_prepared = build_test_link_deviation_block_from_seed(
&link_seed,
&DeviationBlockConfig {
num_internal_knots: 3,
..DeviationBlockConfig::default()
},
)
.expect("build link deviation block");
let beta_h = Array1::from_iter(
(0..score_prepared.block.design.ncols()).map(|idx| 0.01 * (idx as f64 + 1.0)),
);
let beta_w = Array1::from_iter(
(0..link_prepared.block.design.ncols()).map(|idx| 0.02 * (idx as f64 + 1.0)),
);
let a = 0.35;
let b = -0.7;
let cells = build_denested_partition_cells(
a,
b,
score_prepared
.runtime
.breakpoints()
.as_slice()
.expect("score breaks"),
link_prepared
.runtime
.breakpoints()
.as_slice()
.expect("link breaks"),
|z| score_prepared.runtime.local_cubic_at(&beta_h, z),
|u| link_prepared.runtime.local_cubic_at(&beta_w, u),
)
.expect("microcells");
for cell in &cells {
// Probe strictly inside the (possibly half-infinite) cell.
let z = exact_kernel::interval_probe_point(cell.cell.left, cell.cell.right)
.expect("finite microcell probe");
let h = score_prepared
.runtime
.design(&array![z])
.expect("score design")
.row(0)
.dot(&beta_h);
let link = link_prepared
.runtime
.design(&array![a + b * z])
.expect("link design")
.row(0)
.dot(&beta_w);
// Direct de-nested predictor: a + b*z + b*h(z) + w(a + b*z).
let expected = a + b * z + b * h + link;
assert!(
(cell.cell.eta(z) - expected).abs() < 1e-10,
"microcell eta should equal the direct de-nested predictor at z={z:.6}"
);
}
}
#[test]
fn local_cubic_global_transform_reconstructs_same_function() {
    // Converting a span-local cubic to global-coordinate coefficients must
    // preserve the polynomial's values at the span edges and interior points.
    let span = exact_kernel::LocalSpanCubic {
        left: -1.3,
        right: 0.7,
        c0: 0.4,
        c1: -0.2,
        c2: 0.15,
        c3: -0.05,
    };
    let (g0, g1, g2, g3) = exact_global_cubic_from_local(LocalSpanCubic {
        left: span.left,
        right: span.right,
        c0: span.c0,
        c1: span.c1,
        c2: span.c2,
        c3: span.c3,
    });
    let probes = [-1.3, -0.8, -0.1, 0.5, 0.7];
    for &x in probes.iter() {
        let local_value = span.evaluate(x);
        let global_value = g0 + g1 * x + g2 * x * x + g3 * x * x * x;
        assert!(
            (local_value - global_value).abs() < 1e-12,
            "global cubic transform should preserve the span polynomial at x={x}"
        );
    }
}
#[test]
fn denested_branch_selection_uses_normalized_cell_coefficients() {
    // Branch classification should ignore ~1e-13 coefficient noise: near-zero
    // c2/c3 stays affine, a genuine c2 alone selects the quartic branch, and
    // a genuine c3 promotes the cell to the sextic branch.
    let affine = ExactDenestedCubicCell {
        left: -1.0,
        right: 1.0,
        c0: 0.2,
        c1: -0.4,
        c2: 1e-13,
        c3: -1e-13,
    };
    let quartic = ExactDenestedCubicCell {
        c2: 2e-4,
        c3: 1e-13,
        ..affine
    };
    let sextic = ExactDenestedCubicCell {
        c2: 2e-4,
        c3: 5e-3,
        ..affine
    };
    let expectations = [
        (affine, ExactCellBranchShared::Affine, "affine branch"),
        (quartic, ExactCellBranchShared::Quartic, "quartic branch"),
        (sextic, ExactCellBranchShared::Sextic, "sextic branch"),
    ];
    for (cell, branch, label) in expectations {
        assert_eq!(branch_exact_cell(cell).expect(label), branch);
    }
}
// Validates the analytic partial derivatives of the denested cell
// coefficients with respect to the intercept a and slope b against central
// finite differences of the directly assembled global coefficients
// c(a, b) = [a + b*h0 + d0, b + b*h1 + d1, b*h2 + d2, b*h3 + d3], where h is
// the score span in global form and d the link span transformed by (a, b).
#[test]
fn denested_cell_coefficient_partials_match_finite_differences() {
let score_span = exact_kernel::LocalSpanCubic {
left: -0.75,
right: 0.25,
c0: 0.08,
c1: -0.03,
c2: 0.02,
c3: -0.01,
};
let link_span = exact_kernel::LocalSpanCubic {
left: -0.6,
right: 0.9,
c0: -0.05,
c1: 0.04,
c2: -0.02,
c3: 0.015,
};
let a = 0.3;
let b = -0.7;
let eps = 1e-6;
// Assemble the four global denested coefficients directly from the span
// transforms; this is the reference the analytic partials must match.
let coeffs = |aa: f64, bb: f64| {
let (h0, h1, h2, h3) = exact_global_cubic_from_local(LocalSpanCubic {
left: score_span.left,
right: score_span.right,
c0: score_span.c0,
c1: score_span.c1,
c2: score_span.c2,
c3: score_span.c3,
});
let (d0, d1, d2, d3) = exact_transformed_link_cubic(
LocalSpanCubic {
left: link_span.left,
right: link_span.right,
c0: link_span.c0,
c1: link_span.c1,
c2: link_span.c2,
c3: link_span.c3,
},
aa,
bb,
);
[
aa + bb * h0 + d0,
bb + bb * h1 + d1,
bb * h2 + d2,
bb * h3 + d3,
]
};
let (dc_da, dc_db) = exact_denested_cell_coefficient_partials(
LocalSpanCubic {
left: score_span.left,
right: score_span.right,
c0: score_span.c0,
c1: score_span.c1,
c2: score_span.c2,
c3: score_span.c3,
},
LocalSpanCubic {
left: link_span.left,
right: link_span.right,
c0: link_span.c0,
c1: link_span.c1,
c2: link_span.c2,
c3: link_span.c3,
},
a,
b,
);
// Central differences in a and b, coefficient by coefficient.
let plus_a = coeffs(a + eps, b);
let minus_a = coeffs(a - eps, b);
let plus_b = coeffs(a, b + eps);
let minus_b = coeffs(a, b - eps);
for j in 0..4 {
let fd_a = (plus_a[j] - minus_a[j]) / (2.0 * eps);
let fd_b = (plus_b[j] - minus_b[j]) / (2.0 * eps);
assert!(
(dc_da[j] - fd_a).abs() < 1e-6,
"dc/da mismatch at coefficient {j}: analytic={}, fd={fd_a}",
dc_da[j]
);
assert!(
(dc_db[j] - fd_b).abs() < 1e-6,
"dc/db mismatch at coefficient {j}: analytic={}, fd={fd_b}",
dc_db[j]
);
}
}
// With a cubic link-deviation span active at the observed point u = a + b*z,
// the observed denested partials must include a nonzero third a-derivative
// term equal to the kernel's `denested_cell_third_partials` for that span.
#[test]
fn observed_denested_partials_include_third_a_derivative_for_piecewise_cubic_link() {
let z = array![-0.8, 0.2, 1.1];
let link_seed = array![-2.0, -0.5, 0.0, 0.5, 2.0];
let link_prepared = build_test_link_deviation_block_from_seed(
&link_seed,
&DeviationBlockConfig {
num_internal_knots: 4,
..DeviationBlockConfig::default()
},
)
.expect("link block");
let beta_w = Array1::from_iter(
(0..link_prepared.block.design.ncols()).map(|idx| 0.01 * (idx as f64 + 1.0)),
);
// Three-row family with a link deviation but no score warp; single-column
// intercept-style marginal/logslope designs.
let family =
BernoulliMarginalSlopeFamily {
y: Arc::new(array![0.0, 1.0, 1.0]),
weights: Arc::new(array![1.0, 0.7, 1.3]),
z: Arc::new(z.clone()),
latent_measure: LatentMeasureKind::StandardNormal,
gaussian_frailty_sd: None,
base_link: bernoulli_marginal_slope_probit_link(),
marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
array![[1.0], [1.0], [1.0]],
)),
logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
array![[1.0], [1.0], [1.0]],
)),
score_warp: None,
link_dev: Some(link_prepared.runtime.clone()),
policy: crate::resource::ResourcePolicy::default_library(),
cell_moment_lru: new_cell_moment_lru_cache(
&crate::resource::ResourcePolicy::default_library(),
),
cell_moment_cache_stats: new_cell_moment_cache_stats(),
intercept_warm_starts: None,
auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
auto_subsample_last_rho: Arc::new(Mutex::new(None)),
shared_eval_cache: Arc::new(Mutex::new(None)),
};
let a = 0.35;
let b = 0.6;
let row = 1usize;
let obs = family
.observed_denested_cell_partials(row, a, b, None, Some(&beta_w))
.expect("observed denested partials");
// The third a-derivative comes entirely from the link span at u = a + b*z.
let u_obs = a + b * z[row];
let link_span = link_prepared
.runtime
.local_cubic_at(&beta_w, u_obs)
.expect("local cubic at observed point");
let expected_daaa = exact_kernel::denested_cell_third_partials(link_span).0;
assert_eq!(obs.dc_daaa, expected_daaa);
assert!(
eval_coeff4_at(&obs.dc_daaa, z[row]).abs() > 1e-12,
"piecewise-cubic link spans should contribute a third a-derivative"
);
}
#[test]
fn pooled_probit_baseline_matches_expanded_integer_weight_fit() {
    // Integer case weights must reproduce the fit obtained by literally
    // duplicating each row `weight` times — and the data are chosen so the
    // unweighted fit is visibly different, proving the weights matter.
    let outcomes = array![0.0, 1.0, 0.0, 1.0];
    let scores = array![-1.5, -0.2, 0.4, 1.4];
    let case_weights = array![25.0, 2.0, 1.0, 20.0];
    let weighted =
        pooled_probit_baseline(&outcomes, &scores, &case_weights).expect("weighted baseline");
    let unit = Array1::ones(outcomes.len());
    let unweighted =
        pooled_probit_baseline(&outcomes, &scores, &unit).expect("unweighted baseline");
    let (dup_y, dup_z) = expand_integer_weight_rows(&outcomes, &scores, &case_weights);
    let expanded = pooled_probit_baseline(&dup_y, &dup_z, &Array1::ones(dup_y.len()))
        .expect("expanded baseline");
    assert!(
        pair_distance(expanded, unweighted) > 1e-2,
        "test data should distinguish weighted from unweighted seeding"
    );
    assert!(
        pair_distance(weighted, expanded) < 1e-8,
        "weighted pilot baseline should match the expanded integer-weight fit"
    );
}
#[test]
fn validate_spec_rejects_nonfinite_or_negative_weights() {
    // A NaN weight and a negative weight must both fail validation with the
    // same diagnostic text.
    let covariates = Array2::<f64>::zeros((3, 0));
    let outcomes = array![0.0, 1.0, 0.0];
    let scores = array![-1.0, 0.0, 1.0];
    let nan_err = validate_spec(
        covariates.view(),
        &base_spec(outcomes.clone(), array![1.0, f64::NAN, 1.0], scores.clone()),
    )
    .expect_err("non-finite weights should be rejected");
    assert!(nan_err.contains("finite non-negative weights"));
    let neg_err = validate_spec(
        covariates.view(),
        &base_spec(outcomes, array![1.0, -0.5, 1.0], scores),
    )
    .expect_err("negative weights should be rejected");
    assert!(neg_err.contains("finite non-negative weights"));
}
#[test]
fn validate_spec_rejects_nonfinite_z_values() {
    // An infinite score entry must be caught at spec-validation time.
    let covariates = Array2::<f64>::zeros((3, 0));
    let bad_spec = base_spec(
        array![0.0, 1.0, 0.0],
        array![1.0, 1.0, 1.0],
        array![-1.0, f64::INFINITY, 1.0],
    );
    let err = validate_spec(covariates.view(), &bad_spec)
        .expect_err("non-finite z should be rejected");
    assert!(err.contains("finite z values"));
}
#[test]
fn validate_spec_accepts_learnable_gaussian_shift_sigma() {
    // A GaussianShift frailty with an unset (learnable) sigma is valid.
    let covariates = Array2::<f64>::zeros((3, 0));
    let spec = {
        let mut s = base_spec(
            array![0.0, 1.0, 0.0],
            array![1.0, 1.0, 1.0],
            array![-1.0, 0.0, 1.0],
        );
        s.frailty = FrailtySpec::GaussianShift { sigma_fixed: None };
        s
    };
    validate_spec(covariates.view(), &spec).expect("learnable GaussianShift sigma should validate");
}
#[test]
fn signed_probit_helpers_handle_nonfinite_boundaries_explicitly() {
    // +inf: log-CDF saturates at 0 and the Mills ratio vanishes.
    let (log_pos, mills_pos) = signed_probit_logcdf_and_mills_ratio(f64::INFINITY);
    assert_eq!((log_pos, mills_pos), (0.0, 0.0));
    // -inf: log-CDF diverges to -inf while the Mills ratio blows up.
    let (log_neg, mills_neg) = signed_probit_logcdf_and_mills_ratio(f64::NEG_INFINITY);
    assert_eq!(log_neg, f64::NEG_INFINITY);
    assert_eq!(mills_neg, f64::INFINITY);
    // NaN input propagates NaN through both outputs.
    let (log_nan, mills_nan) = signed_probit_logcdf_and_mills_ratio(f64::NAN);
    assert!(log_nan.is_nan() && mills_nan.is_nan());
}
#[test]
fn signed_probit_exact_derivative_helper_rejects_invalid_nonfinite_margins() {
    // +inf is the benign tail: all four derivative terms collapse to zero.
    let tail = signed_probit_neglog_derivatives_up_to_fourth(f64::INFINITY, 2.5)
        .expect("+inf should use the zero tail");
    assert_eq!(tail, (0.0, 0.0, 0.0, 0.0));
    // -inf and NaN are both invalid for the exact derivative path and must
    // report the same diagnostic.
    for (margin, reason) in [
        (
            f64::NEG_INFINITY,
            "-inf should be rejected in the exact derivative path",
        ),
        (
            f64::NAN,
            "NaN should be rejected in the exact derivative path",
        ),
    ] {
        let err = signed_probit_neglog_derivatives_up_to_fourth(margin, 2.5).expect_err(reason);
        assert!(err.contains("non-finite signed margin"));
    }
}
#[test]
fn unary_neglog_phi_preserves_negative_infinity_and_nan_boundaries() {
    // +inf collapses every term to zero; -inf yields the fixed boundary
    // pattern (the third entry echoes the scale argument); NaN propagates
    // through all five terms.
    let at_plus_inf = unary_derivatives_neglog_phi(f64::INFINITY, 1.75);
    assert_eq!(at_plus_inf, [0.0, 0.0, 0.0, 0.0, 0.0]);
    let at_minus_inf = unary_derivatives_neglog_phi(f64::NEG_INFINITY, 1.75);
    assert_eq!(
        at_minus_inf,
        [f64::INFINITY, f64::NEG_INFINITY, 1.75, 0.0, 0.0]
    );
    let at_nan = unary_derivatives_neglog_phi(f64::NAN, 1.75);
    assert!(at_nan.iter().all(|term| term.is_nan()));
}
#[test]
fn flexible_family_uses_second_order_only_for_bounded_row_third_work() {
    // The family must declare exact second-order outer derivatives across
    // small, large (50k+ row), and high-work "flex" configurations, while the
    // outer-derivative *policy* (not the declared order) downgrades Hessian
    // availability once the predicted work exceeds the budget.
    let seed = array![-1.0, 0.0, 1.0];
    let score_prepared = build_score_warp_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 3,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("score warp block");
    // Baseline: three-row family with an active score warp, no link deviation.
    let family =
        BernoulliMarginalSlopeFamily {
            y: Arc::new(array![0.0, 1.0, 0.0]),
            weights: Arc::new(Array1::ones(3)),
            z: Arc::new(seed.clone()),
            latent_measure: LatentMeasureKind::StandardNormal,
            gaussian_frailty_sd: None,
            base_link: bernoulli_marginal_slope_probit_link(),
            marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
                array![[1.0], [1.0], [1.0]],
            )),
            logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
                array![[1.0], [1.0], [1.0]],
            )),
            score_warp: Some(score_prepared.runtime.clone()),
            link_dev: None,
            policy: crate::resource::ResourcePolicy::default_library(),
            cell_moment_lru: new_cell_moment_lru_cache(
                &crate::resource::ResourcePolicy::default_library(),
            ),
            cell_moment_cache_stats: new_cell_moment_cache_stats(),
            intercept_warm_starts: None,
            auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
            auto_subsample_last_rho: Arc::new(Mutex::new(None)),
            shared_eval_cache: Arc::new(Mutex::new(None)),
        };
    let specs = vec![
        dummy_blockspec(1, 3),
        dummy_blockspec(1, 3),
        dummy_blockspec(2, 3),
    ];
    assert_eq!(
        family.exact_outer_derivative_order(&specs, &BlockwiseFitOptions::default()),
        ExactOuterDerivativeOrder::Second
    );
    assert!(family.exact_newton_joint_psi_workspace_for_first_order_terms());
    // Scale the same family past 50k rows: the declared order must not change.
    let n_large = 50_001usize;
    let mut large_flex_family = family.clone();
    large_flex_family.y = Arc::new(Array1::zeros(n_large));
    large_flex_family.weights = Arc::new(Array1::ones(n_large));
    large_flex_family.z = Arc::new(Array1::zeros(n_large));
    let large_flex_specs = vec![
        dummy_blockspec(1, n_large),
        dummy_blockspec(1, n_large),
        dummy_blockspec(2, n_large),
    ];
    assert_eq!(
        large_flex_family
            .exact_outer_derivative_order(&large_flex_specs, &BlockwiseFitOptions::default()),
        ExactOuterDerivativeOrder::Second
    );
    // Helper building a heavily penalized block spec (n_penalties penalties,
    // p coefficients) used to push the predicted Hessian work over budget.
    let penalized_spec = |p: usize, n_rows: usize, n_penalties: usize| ParameterBlockSpec {
        name: "penalized".to_string(),
        design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::<f64>::zeros((n_rows, p)),
        )),
        offset: Array1::zeros(n_rows),
        penalties: (0..n_penalties)
            .map(|_| crate::custom_family::PenaltyMatrix::Dense(Array2::zeros((p, p))))
            .collect(),
        nullspace_dims: vec![0; n_penalties],
        initial_log_lambdas: Array1::zeros(n_penalties),
        initial_beta: Some(Array1::zeros(p)),
    };
    // High-work configuration: add a link deviation too, and use wide,
    // heavily penalized blocks at 50k+ rows.
    let mut high_work_flex_family = large_flex_family.clone();
    high_work_flex_family.link_dev = Some(score_prepared.runtime.clone());
    let flex_p = score_prepared.runtime.basis_dim();
    let high_work_specs = vec![
        penalized_spec(1, n_large, 12),
        penalized_spec(1, n_large, 12),
        penalized_spec(flex_p, n_large, 12),
        penalized_spec(flex_p, n_large, 12),
    ];
    // The *declared* order stays Second even in the high-work regime …
    assert_eq!(
        high_work_flex_family
            .exact_outer_derivative_order(&high_work_specs, &BlockwiseFitOptions::default()),
        ExactOuterDerivativeOrder::Second
    );
    // … but the policy must see the work estimate exceed the budget.
    let high_work_policy = high_work_flex_family.outer_derivative_policy(
        &high_work_specs,
        0,
        &BlockwiseFitOptions::default(),
    );
    assert!(
        high_work_policy.predicted_hessian_work
            > crate::custom_family::OuterDerivativePolicy::OUTER_HESSIAN_WORK_BUDGET,
        "high-work flex configuration must exceed the outer-Hessian work budget; \
         got predicted={} budget={}",
        high_work_policy.predicted_hessian_work,
        crate::custom_family::OuterDerivativePolicy::OUTER_HESSIAN_WORK_BUDGET,
    );
    // Over budget ⇒ the policy declares the outer Hessian Unavailable …
    assert!(
        matches!(
            high_work_policy.declared_hessian_form(),
            crate::solver::outer_strategy::DeclaredHessianForm::Unavailable
        ),
        "policy must declare the outer Hessian Unavailable when predicted work \
         exceeds the budget; got {:?}",
        high_work_policy.declared_hessian_form(),
    );
    // … and clamps any ValueGradientHessian request down to ValueAndGradient.
    assert_eq!(
        high_work_policy.order_for_evaluation(
            crate::solver::outer_strategy::OuterEvalOrder::ValueGradientHessian,
        ),
        crate::solver::outer_strategy::OuterEvalOrder::ValueAndGradient,
        "policy must clamp ValueGradientHessian to ValueAndGradient in the high-work regime",
    );
    // Rigid variant (no score warp) at large n also stays second order.
    let mut large_rigid_family = large_flex_family.clone();
    large_rigid_family.score_warp = None;
    let large_rigid_specs = vec![dummy_blockspec(1, n_large), dummy_blockspec(1, n_large)];
    assert_eq!(
        large_rigid_family
            .exact_outer_derivative_order(&large_rigid_specs, &BlockwiseFitOptions::default()),
        ExactOuterDerivativeOrder::Second
    );
}
#[test]
fn exact_outer_order_stays_second_order_at_biobank_work_scale() {
    // The capability-based exact outer order must remain Second regardless of
    // problem scale: small blocks, big (5000 x 500) blocks, a biobank-scale
    // synthetic cost, and an extreme penalty count K — no cost-based demotion.
    use crate::custom_family::{
        default_coefficient_hessian_cost, exact_outer_order_from_capability,
    };
    use crate::matrix::DesignMatrix;
    use ndarray::Array2;
    // Small two-block configuration (10 rows, 8 coefficients, 2 penalties each).
    let small_design = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
        Array2::zeros((10, 8)),
    ));
    let small_specs: Vec<ParameterBlockSpec> = (0..2)
        .map(|i| ParameterBlockSpec {
            name: format!("block_{i}"),
            design: small_design.clone(),
            offset: ndarray::Array1::zeros(10),
            penalties: (0..2)
                .map(|_| crate::custom_family::PenaltyMatrix::Dense(Array2::zeros((8, 8))))
                .collect(),
            nullspace_dims: vec![0; 2],
            initial_log_lambdas: ndarray::Array1::zeros(2),
            initial_beta: None,
        })
        .collect();
    let small_cost = default_coefficient_hessian_cost(&small_specs);
    assert_eq!(
        exact_outer_order_from_capability(&small_specs, small_cost),
        ExactOuterDerivativeOrder::Second
    );
    // Much larger blocks (5000 rows, 500 coefficients, 10 penalties each).
    let big_design = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
        Array2::zeros((5_000, 500)),
    ));
    let big_specs: Vec<ParameterBlockSpec> = (0..2)
        .map(|i| ParameterBlockSpec {
            name: format!("block_{i}"),
            design: big_design.clone(),
            offset: ndarray::Array1::zeros(5_000),
            penalties: (0..10)
                .map(|_| crate::custom_family::PenaltyMatrix::Dense(Array2::zeros((500, 500))))
                .collect(),
            nullspace_dims: vec![0; 10],
            initial_log_lambdas: ndarray::Array1::zeros(10),
            initial_beta: None,
        })
        .collect();
    let big_cost = default_coefficient_hessian_cost(&big_specs);
    assert_eq!(
        exact_outer_order_from_capability(&big_specs, big_cost),
        ExactOuterDerivativeOrder::Second
    );
    // Tiny design but 22 penalties, paired with a synthetic biobank-scale
    // cost (400k rows x 300 x 300) fed in directly.
    let ctn_specs = vec![ParameterBlockSpec {
        name: "ctn".to_string(),
        design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(Array2::zeros((
            1, 1,
        )))),
        offset: ndarray::Array1::zeros(1),
        penalties: (0..22)
            .map(|_| crate::custom_family::PenaltyMatrix::Dense(Array2::zeros((1, 1))))
            .collect(),
        nullspace_dims: vec![0; 22],
        initial_log_lambdas: ndarray::Array1::zeros(22),
        initial_beta: None,
    }];
    let ctn_cost: u64 = 400_000u64 * 300 * 300;
    assert_eq!(
        exact_outer_order_from_capability(&ctn_specs, ctn_cost),
        ExactOuterDerivativeOrder::Second
    );
    // Same conclusion via the outer-HVP-aware variant with HVP disabled.
    use crate::custom_family::exact_outer_order_with_outer_hvp;
    assert_eq!(
        exact_outer_order_with_outer_hvp(&ctn_specs, ctn_cost, false),
        ExactOuterDerivativeOrder::Second,
        "exact outer order must not be cost-demoted"
    );
    // With outer-HVP support, even K = 5000 penalties keeps Second order.
    let huge_k_specs = vec![ParameterBlockSpec {
        name: "k".to_string(),
        design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(Array2::zeros((
            1, 1,
        )))),
        offset: ndarray::Array1::zeros(1),
        penalties: (0..5_000)
            .map(|_| crate::custom_family::PenaltyMatrix::Dense(Array2::zeros((1, 1))))
            .collect(),
        nullspace_dims: vec![0; 5_000],
        initial_log_lambdas: ndarray::Array1::zeros(5_000),
        initial_beta: None,
    }];
    assert_eq!(
        exact_outer_order_with_outer_hvp(&huge_k_specs, 0, true),
        ExactOuterDerivativeOrder::Second,
        "outer HVP support keeps the exact second-order declaration regardless of K"
    );
}
#[test]
fn bernoulli_marginal_slope_coefficient_cost_uses_joint_coupled_formula() {
    // The family's coefficient-Hessian cost must use the joint coupled
    // formula n * (p_marg + p_log)^2 — strictly larger than the default
    // block-diagonal estimate n * (p_marg^2 + p_log^2).
    use crate::custom_family::default_coefficient_hessian_cost;
    use crate::matrix::DesignMatrix;
    let n = 1000usize;
    let p_marg = 20usize;
    let p_log = 8usize;
    let marg_design = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
        Array2::zeros((n, p_marg)),
    ));
    let log_design = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
        Array2::zeros((n, p_log)),
    ));
    // Rigid family (no score warp, no link deviation): only the two primary
    // design blocks contribute to the cost.
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(Array1::zeros(n)),
        weights: Arc::new(Array1::from_elem(n, 1.0)),
        z: Arc::new(Array1::zeros(n)),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: marg_design.clone(),
        logslope_design: log_design.clone(),
        score_warp: None,
        link_dev: None,
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
        auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        auto_subsample_last_rho: Arc::new(Mutex::new(None)),
        shared_eval_cache: Arc::new(Mutex::new(None)),
    };
    let specs = vec![
        ParameterBlockSpec {
            name: "marginal".to_string(),
            design: marg_design,
            offset: Array1::zeros(n),
            penalties: vec![crate::custom_family::PenaltyMatrix::Dense(Array2::zeros((
                p_marg, p_marg,
            )))],
            nullspace_dims: vec![0],
            initial_log_lambdas: Array1::zeros(1),
            initial_beta: None,
        },
        ParameterBlockSpec {
            name: "logslope".to_string(),
            design: log_design,
            offset: Array1::zeros(n),
            penalties: vec![crate::custom_family::PenaltyMatrix::Dense(Array2::zeros((
                p_log, p_log,
            )))],
            nullspace_dims: vec![0],
            initial_log_lambdas: Array1::zeros(1),
            initial_beta: None,
        },
    ];
    // Expected costs under the two formulas.
    let p_total = (p_marg + p_log) as u64;
    let expected_joint = (n as u64) * p_total * p_total;
    let expected_block_diag = (n as u64) * ((p_marg * p_marg + p_log * p_log) as u64);
    assert_eq!(family.coefficient_hessian_cost(&specs), expected_joint);
    assert_eq!(
        default_coefficient_hessian_cost(&specs),
        expected_block_diag
    );
    // Sanity: the joint coupled cost dominates the block-diagonal default.
    assert!(family.coefficient_hessian_cost(&specs) > default_coefficient_hessian_cost(&specs));
}
#[test]
fn rigid_fast_path_matches_loglik_finite_differences() {
    // The rigid fast path (per-block working sets from `evaluate`) must agree
    // with the exact per-row gradient/Hessian path to ~1e-10 on a single-row,
    // single-coefficient problem.
    // NOTE(review): despite the test name, the reference values come from the
    // exact row-derivative path, not numeric finite differences — consider
    // renaming or adding an FD check. TODO confirm intent.
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![1.0]),
        weights: Arc::new(array![1.2]),
        z: Arc::new(array![0.3]),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(array![[
            1.0
        ]])),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(array![[
            1.0
        ]])),
        score_warp: None,
        link_dev: None,
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
        auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        auto_subsample_last_rho: Arc::new(Mutex::new(None)),
        shared_eval_cache: Arc::new(Mutex::new(None)),
    };
    // Block states at (q, g): scalar beta equals eta since designs are [[1.0]].
    let states_at = |q: f64, g: f64| {
        vec![
            ParameterBlockState {
                beta: array![q],
                eta: array![q],
            },
            ParameterBlockState {
                beta: array![g],
                eta: array![g],
            },
        ]
    };
    let q = 0.4;
    let g = 0.7;
    let block_states = states_at(q, g);
    let eval = family
        .evaluate(&block_states)
        .expect("rigid family evaluation");
    // Extract the fast-path gradient and Hessian entries for both blocks;
    // both must come back as ExactNewton working sets with dense Hessians.
    let grad_q = match &eval.blockworking_sets[0] {
        BlockWorkingSet::ExactNewton { gradient, .. } => gradient[0],
        BlockWorkingSet::Diagonal { .. } => {
            panic!("expected exact-newton marginal block")
        }
    };
    let grad_g = match &eval.blockworking_sets[1] {
        BlockWorkingSet::ExactNewton { gradient, .. } => gradient[0],
        BlockWorkingSet::Diagonal { .. } => {
            panic!("expected exact-newton log-slope block")
        }
    };
    let hess_qq = match &eval.blockworking_sets[0] {
        BlockWorkingSet::ExactNewton { hessian, .. } => match hessian {
            SymmetricMatrix::Dense(h) => h[[0, 0]],
            _ => panic!("expected dense marginal Hessian"),
        },
        BlockWorkingSet::Diagonal { .. } => {
            panic!("expected exact-newton marginal block")
        }
    };
    let hess_gg = match &eval.blockworking_sets[1] {
        BlockWorkingSet::ExactNewton { hessian, .. } => match hessian {
            SymmetricMatrix::Dense(h) => h[[0, 0]],
            _ => panic!("expected dense log-slope Hessian"),
        },
        BlockWorkingSet::Diagonal { .. } => {
            panic!("expected exact-newton log-slope block")
        }
    };
    // Reference: exact per-row gradient/Hessian from the exact-eval cache.
    let cache = family
        .build_exact_eval_cache(&block_states)
        .expect("rigid exact eval cache");
    let row_ctx = family
        .build_row_exact_context(0, &block_states)
        .expect("rigid row context");
    let (_, primary_grad, primary_hess) = family
        .compute_row_primary_gradient_hessian(0, &block_states, &cache.primary, &row_ctx)
        .expect("rigid exact row derivatives");
    // Map objective-gradient components to score components before comparing.
    let expected_score_q =
        BernoulliMarginalSlopeFamily::exact_newton_score_component_from_objective_gradient(
            primary_grad[0],
        );
    let expected_score_g =
        BernoulliMarginalSlopeFamily::exact_newton_score_component_from_objective_gradient(
            primary_grad[1],
        );
    assert!(
        (grad_q - expected_score_q).abs() < 1e-10,
        "marginal gradient mismatch: fast={grad_q:.12e}, exact={expected_score_q:.12e}"
    );
    assert!(
        (grad_g - expected_score_g).abs() < 1e-10,
        "logslope gradient mismatch: fast={grad_g:.12e}, exact={expected_score_g:.12e}"
    );
    assert!(
        (hess_qq - primary_hess[[0, 0]]).abs() < 1e-10,
        "marginal Hessian mismatch: fast={hess_qq:.12e}, exact={:.12e}",
        primary_hess[[0, 0]]
    );
    assert!(
        (hess_gg - primary_hess[[1, 1]]).abs() < 1e-10,
        "logslope Hessian mismatch: fast={hess_gg:.12e}, exact={:.12e}",
        primary_hess[[1, 1]]
    );
}
#[test]
fn w_only_gradient_hessian_finite_and_symmetric() {
    // "w-only" configuration: a link-deviation block is active but the
    // marginal/log-slope designs are empty (0 columns) and there is no score
    // warp. Every per-row primary gradient/Hessian must be finite, correctly
    // shaped, and symmetric.
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let prepared = build_test_link_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build link deviation block");
    let link_dim = prepared
        .block
        .initial_beta
        .as_ref()
        .expect("link initial beta")
        .len();
    // Nonzero link coefficients so the deviation block contributes.
    let beta_link = Array1::from_iter((0..link_dim).map(|idx| 0.05 * (idx as f64 + 1.0)));
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: None,
        link_dev: Some(prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
        auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        auto_subsample_last_rho: Arc::new(Mutex::new(None)),
        shared_eval_cache: Arc::new(Mutex::new(None)),
    };
    let block_states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(beta_link.clone(), seed.len()),
    ];
    // No score warp ⇒ the h slice must be absent in both slice layouts;
    // the primary dimension is q + logslope (2) plus the link dimension.
    let slices = block_slices(&family);
    assert!(slices.h.is_none(), "score-warp absent → no h slice");
    let primary = primary_slices(&slices);
    assert!(primary.h.is_none(), "primary h absent");
    assert_eq!(primary.total, 2 + link_dim);
    for row in 0..seed.len() {
        let row_ctx = family
            .build_row_exact_context(row, &block_states)
            .unwrap_or_else(|e| panic!("row {row}: build_row_exact_context failed: {e}"));
        let (_, grad, hess) = family
            .compute_row_primary_gradient_hessian(row, &block_states, &primary, &row_ctx)
            .unwrap_or_else(|e| {
                panic!("row {row}: compute_row_primary_gradient_hessian failed: {e}")
            });
        assert_eq!(
            grad.len(),
            primary.total,
            "row {row}: gradient length mismatch"
        );
        assert_eq!(
            hess.dim(),
            (primary.total, primary.total),
            "row {row}: hessian shape mismatch"
        );
        assert!(
            grad.iter().all(|v| v.is_finite()),
            "row {row}: non-finite gradient entry: {grad:?}"
        );
        assert!(
            hess.iter().all(|v| v.is_finite()),
            "row {row}: non-finite hessian entry"
        );
        // Symmetry check over the strict lower triangle.
        for a in 0..primary.total {
            for b in 0..a {
                let diff = (hess[[a, b]] - hess[[b, a]]).abs();
                assert!(
                    diff < 1e-10,
                    "row {row}: hessian asymmetry at ({a},{b}): \
                     H[{a},{b}]={:.6e} vs H[{b},{a}]={:.6e}, diff={diff:.3e}",
                    hess[[a, b]],
                    hess[[b, a]]
                );
            }
        }
    }
}
#[test]
fn h_only_gradient_hessian_finite_and_symmetric() {
    // "h-only" configuration: mirror of the w-only test — a score-warp block
    // is active, there is no link deviation, and the marginal/log-slope
    // designs are empty. Per-row primary gradients/Hessians must be finite,
    // correctly shaped, and symmetric.
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let prepared = build_score_warp_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build score-warp block");
    let score_dim = prepared
        .block
        .initial_beta
        .as_ref()
        .expect("score-warp initial beta")
        .len();
    // Nonzero warp coefficients so the score warp contributes.
    let beta_score = Array1::from_iter((0..score_dim).map(|idx| 0.04 * (idx as f64 + 1.0)));
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: Some(prepared.runtime.clone()),
        link_dev: None,
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
        auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        auto_subsample_last_rho: Arc::new(Mutex::new(None)),
        shared_eval_cache: Arc::new(Mutex::new(None)),
    };
    let block_states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(beta_score, seed.len()),
    ];
    // No link deviation ⇒ the w slice must be absent in both slice layouts;
    // the primary dimension is q + logslope (2) plus the warp dimension.
    let slices = block_slices(&family);
    assert!(slices.w.is_none(), "link-dev absent → no w slice");
    let primary = primary_slices(&slices);
    assert!(primary.w.is_none(), "primary w absent");
    assert_eq!(primary.total, 2 + score_dim);
    for row in 0..seed.len() {
        let row_ctx = family
            .build_row_exact_context(row, &block_states)
            .unwrap_or_else(|e| panic!("row {row}: build_row_exact_context failed: {e}"));
        let (_, grad, hess) = family
            .compute_row_primary_gradient_hessian(row, &block_states, &primary, &row_ctx)
            .unwrap_or_else(|e| {
                panic!("row {row}: compute_row_primary_gradient_hessian failed: {e}")
            });
        assert_eq!(
            grad.len(),
            primary.total,
            "row {row}: gradient length mismatch"
        );
        assert_eq!(
            hess.dim(),
            (primary.total, primary.total),
            "row {row}: hessian shape mismatch"
        );
        assert!(
            grad.iter().all(|v| v.is_finite()),
            "row {row}: non-finite gradient entry: {grad:?}"
        );
        assert!(
            hess.iter().all(|v| v.is_finite()),
            "row {row}: non-finite hessian entry"
        );
        // Symmetry check over the strict lower triangle.
        for a in 0..primary.total {
            for b in 0..a {
                let diff = (hess[[a, b]] - hess[[b, a]]).abs();
                assert!(
                    diff < 1e-10,
                    "row {row}: hessian asymmetry at ({a},{b}): \
                     H[{a},{b}]={:.6e} vs H[{b},{a}]={:.6e}, diff={diff:.3e}",
                    hess[[a, b]],
                    hess[[b, a]]
                );
            }
        }
    }
}
#[test]
fn w_only_exact_outer_directional_derivatives_are_present_and_finite() {
    // In the w-only (link-deviation) configuration, the joint third and
    // fourth directional derivatives of the Hessian must be present (Some),
    // finite, symmetric, and not identically zero when the perturbation
    // directions live in the w slice.
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let prepared = build_test_link_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build link deviation block");
    let link_dim = prepared
        .block
        .initial_beta
        .as_ref()
        .expect("link initial beta")
        .len();
    let beta_link = Array1::from_iter((0..link_dim).map(|idx| 0.05 * (idx as f64 + 1.0)));
    // Same empty-design, link-dev-only family as the w-only Hessian test.
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: None,
        link_dev: Some(prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
        auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        auto_subsample_last_rho: Arc::new(Mutex::new(None)),
        shared_eval_cache: Arc::new(Mutex::new(None)),
    };
    let block_states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(beta_link, seed.len()),
    ];
    // Build two directions supported only on the w (link-deviation) slice.
    let slices = block_slices(&family);
    let total = slices.total;
    let mut dir_u = Array1::<f64>::zeros(total);
    let mut dir_v = Array1::<f64>::zeros(total);
    let w_range = slices.w.as_ref().expect("w slice");
    dir_u[w_range.start] = 0.15;
    if w_range.len() > 1 {
        dir_u[w_range.start + 1] = -0.07;
    }
    dir_v[w_range.start] = 0.09;
    if w_range.len() > 1 {
        dir_v[w_range.start + 1] = 0.03;
    }
    // Both APIs return Result<Option<_>>; unwrap both layers explicitly.
    let third = family
        .exact_newton_joint_hessian_directional_derivative(&block_states, &dir_u)
        .expect("w-only third directional derivative")
        .expect("w-only third directional derivative matrix");
    let fourth = family
        .exact_newton_joint_hessiansecond_directional_derivative(&block_states, &dir_u, &dir_v)
        .expect("w-only fourth directional derivative")
        .expect("w-only fourth directional derivative matrix");
    assert_eq!(third.dim(), (total, total));
    assert_eq!(fourth.dim(), (total, total));
    assert!(third.iter().all(|value| value.is_finite()));
    assert!(fourth.iter().all(|value| value.is_finite()));
    // Magnitude checks: the contractions must actually respond to w-slice
    // perturbations rather than being structurally zero.
    let max_abs_third = third
        .iter()
        .fold(0.0_f64, |acc, value| acc.max(value.abs()));
    let max_abs_fourth = fourth
        .iter()
        .fold(0.0_f64, |acc, value| acc.max(value.abs()));
    assert!(
        max_abs_third > 1e-10,
        "expected nonzero w-only third directional derivative"
    );
    assert!(
        max_abs_fourth > 1e-10,
        "expected nonzero w-only fourth directional derivative"
    );
    // Both contracted matrices must be symmetric.
    for i in 0..total {
        for j in 0..i {
            assert!((third[[i, j]] - third[[j, i]]).abs() < 1e-8);
            assert!((fourth[[i, j]] - fourth[[j, i]]).abs() < 1e-8);
        }
    }
}
#[test]
fn h_only_exact_outer_directional_derivatives_are_present_and_finite() {
    // Mirror of the w-only directional-derivative test for the score-warp
    // ("h") configuration. Here the marginal/log-slope designs are scalar
    // intercept columns (not empty) and the directions perturb all three
    // slices: marginal, log-slope, and h.
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let prepared = build_score_warp_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build score-warp block");
    let score_dim = prepared
        .block
        .initial_beta
        .as_ref()
        .expect("score-warp initial beta")
        .len();
    let beta_score = Array1::from_iter((0..score_dim).map(|idx| 0.04 * (idx as f64 + 1.0)));
    // One intercept column of ones for each primary design.
    let scalar_design = || {
        DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(Array2::from_elem(
            (seed.len(), 1),
            1.0,
        )))
    };
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: scalar_design(),
        logslope_design: scalar_design(),
        score_warp: Some(prepared.runtime.clone()),
        link_dev: None,
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
        auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        auto_subsample_last_rho: Arc::new(Mutex::new(None)),
        shared_eval_cache: Arc::new(Mutex::new(None)),
    };
    // Nonzero intercepts (eta constant across rows matches the scalar
    // intercept design) plus the warp coefficients.
    let block_states = vec![
        ParameterBlockState {
            beta: array![0.25],
            eta: Array1::from_elem(seed.len(), 0.25),
        },
        ParameterBlockState {
            beta: array![0.15],
            eta: Array1::from_elem(seed.len(), 0.15),
        },
        ParameterBlockState {
            beta: beta_score,
            eta: Array1::zeros(seed.len()),
        },
    ];
    // Directions with support on the marginal, log-slope, and h slices.
    let slices = block_slices(&family);
    let total = slices.total;
    let mut dir_u = Array1::<f64>::zeros(total);
    let mut dir_v = Array1::<f64>::zeros(total);
    dir_u[slices.marginal.start] = -0.35;
    dir_u[slices.logslope.start] = 0.28;
    let h_range = slices.h.as_ref().expect("h slice");
    dir_u[h_range.start] = 0.12;
    if h_range.len() > 1 {
        dir_u[h_range.start + 1] = -0.06;
    }
    dir_v[slices.marginal.start] = 0.18;
    dir_v[slices.logslope.start] = -0.22;
    dir_v[h_range.start] = 0.07;
    if h_range.len() > 1 {
        dir_v[h_range.start + 1] = 0.05;
    }
    // Both APIs return Result<Option<_>>; unwrap both layers explicitly.
    let third = family
        .exact_newton_joint_hessian_directional_derivative(&block_states, &dir_u)
        .expect("h-only third directional derivative")
        .expect("h-only third directional derivative matrix");
    let fourth = family
        .exact_newton_joint_hessiansecond_directional_derivative(&block_states, &dir_u, &dir_v)
        .expect("h-only fourth directional derivative")
        .expect("h-only fourth directional derivative matrix");
    assert_eq!(third.dim(), (total, total));
    assert_eq!(fourth.dim(), (total, total));
    assert!(third.iter().all(|value| value.is_finite()));
    assert!(fourth.iter().all(|value| value.is_finite()));
    // The contractions must be nonzero and symmetric.
    let max_abs_third = third
        .iter()
        .fold(0.0_f64, |acc, value| acc.max(value.abs()));
    let max_abs_fourth = fourth
        .iter()
        .fold(0.0_f64, |acc, value| acc.max(value.abs()));
    assert!(
        max_abs_third > 1e-10,
        "expected nonzero h-only third directional derivative"
    );
    assert!(
        max_abs_fourth > 1e-10,
        "expected nonzero h-only fourth directional derivative"
    );
    for i in 0..total {
        for j in 0..i {
            assert!((third[[i, j]] - third[[j, i]]).abs() < 1e-8);
            assert!((fourth[[i, j]] - fourth[[j, i]]).abs() < 1e-8);
        }
    }
}
#[test]
fn h_only_row_primary_higher_order_contractions_are_finite_and_symmetric() {
    // With only the score-warp (h) deviation block active, the per-row third
    // and fourth contracted derivatives must be finite, symmetric matrices,
    // and at least one row must produce a nonzero contraction.
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let prepared = build_score_warp_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build score-warp block");
    let score_dim = prepared
        .block
        .initial_beta
        .as_ref()
        .expect("score-warp initial beta")
        .len();
    let beta_score = Array1::from_iter((0..score_dim).map(|idx| 0.04 * (idx as f64 + 1.0)));
    let fam = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: Some(prepared.runtime.clone()),
        link_dev: None,
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
        auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        auto_subsample_last_rho: Arc::new(Mutex::new(None)),
        shared_eval_cache: Arc::new(Mutex::new(None)),
    };
    let states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(beta_score, seed.len()),
    ];
    let eval_cache = fam
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let total = eval_cache.primary.total;
    let h_range = eval_cache.primary.h.as_ref().expect("h slice");
    // Probe directions touching the marginal, log-slope, and h coordinates.
    let mut u = Array1::<f64>::zeros(total);
    let mut v = Array1::<f64>::zeros(total);
    u[eval_cache.primary.q] = -0.35;
    u[eval_cache.primary.logslope] = 0.28;
    u[h_range.start] = 0.12;
    v[eval_cache.primary.q] = 0.18;
    v[eval_cache.primary.logslope] = -0.22;
    v[h_range.start] = 0.07;
    if h_range.len() > 1 {
        u[h_range.start + 1] = -0.06;
        v[h_range.start + 1] = 0.05;
    }
    let mut peak_third = 0.0_f64;
    let mut peak_fourth = 0.0_f64;
    for row in 0..seed.len() {
        let ctx = fam
            .build_row_exact_context(row, &states)
            .unwrap_or_else(|e| panic!("row {row}: build_row_exact_context failed: {e}"));
        let d3 = fam
            .row_primary_third_contracted_recompute(row, &states, &eval_cache, &ctx, &u)
            .unwrap_or_else(|e| {
                panic!("row {row}: row_primary_third_contracted_recompute failed: {e}")
            });
        let d4 = fam
            .row_primary_fourth_contracted_recompute(row, &states, &eval_cache, &ctx, &u, &v)
            .unwrap_or_else(|e| {
                panic!("row {row}: row_primary_fourth_contracted_recompute failed: {e}")
            });
        assert_eq!(d3.dim(), (total, total));
        assert_eq!(d4.dim(), (total, total));
        assert!(d3.iter().all(|x| x.is_finite()));
        assert!(d4.iter().all(|x| x.is_finite()));
        peak_third = peak_third.max(d3.iter().fold(0.0_f64, |acc, x| acc.max(x.abs())));
        peak_fourth = peak_fourth.max(d4.iter().fold(0.0_f64, |acc, x| acc.max(x.abs())));
        // Symmetry over the strict upper triangle (covers the same pairs as
        // iterating the strict lower triangle).
        for i in 0..total {
            for j in (i + 1)..total {
                assert!((d3[[i, j]] - d3[[j, i]]).abs() < 1e-8);
                assert!((d4[[i, j]] - d4[[j, i]]).abs() < 1e-8);
            }
        }
    }
    assert!(
        peak_third > 1e-10,
        "expected nonzero h-only third contraction"
    );
    assert!(
        peak_fourth > 1e-10,
        "expected nonzero h-only fourth contraction"
    );
}
#[test]
fn w_only_row_primary_higher_order_contractions_are_finite_and_symmetric() {
    // Mirror of the h-only case with only the link-deviation (w) block
    // active: per-row third/fourth contractions must be finite, symmetric,
    // and nonzero for at least one row.
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let prepared = build_test_link_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build link deviation block");
    let link_dim = prepared
        .block
        .initial_beta
        .as_ref()
        .expect("link initial beta")
        .len();
    let beta_link = Array1::from_iter((0..link_dim).map(|idx| 0.05 * (idx as f64 + 1.0)));
    let fam = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: None,
        link_dev: Some(prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
        auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        auto_subsample_last_rho: Arc::new(Mutex::new(None)),
        shared_eval_cache: Arc::new(Mutex::new(None)),
    };
    let states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(beta_link, seed.len()),
    ];
    let eval_cache = fam
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let total = eval_cache.primary.total;
    let w_range = eval_cache.primary.w.as_ref().expect("w slice");
    // Probe directions touching the marginal, log-slope, and w coordinates.
    let mut u = Array1::<f64>::zeros(total);
    let mut v = Array1::<f64>::zeros(total);
    u[eval_cache.primary.q] = 0.4;
    u[eval_cache.primary.logslope] = -0.3;
    u[w_range.start] = 0.15;
    v[eval_cache.primary.q] = -0.2;
    v[eval_cache.primary.logslope] = 0.25;
    v[w_range.start] = 0.09;
    if w_range.len() > 1 {
        u[w_range.start + 1] = -0.07;
        v[w_range.start + 1] = 0.03;
    }
    let mut peak_third = 0.0_f64;
    let mut peak_fourth = 0.0_f64;
    for row in 0..seed.len() {
        let ctx = fam
            .build_row_exact_context(row, &states)
            .unwrap_or_else(|e| panic!("row {row}: build_row_exact_context failed: {e}"));
        let d3 = fam
            .row_primary_third_contracted_recompute(row, &states, &eval_cache, &ctx, &u)
            .unwrap_or_else(|e| {
                panic!("row {row}: row_primary_third_contracted_recompute failed: {e}")
            });
        let d4 = fam
            .row_primary_fourth_contracted_recompute(row, &states, &eval_cache, &ctx, &u, &v)
            .unwrap_or_else(|e| {
                panic!("row {row}: row_primary_fourth_contracted_recompute failed: {e}")
            });
        assert_eq!(d3.dim(), (total, total));
        assert_eq!(d4.dim(), (total, total));
        assert!(d3.iter().all(|x| x.is_finite()));
        assert!(d4.iter().all(|x| x.is_finite()));
        peak_third = peak_third.max(d3.iter().fold(0.0_f64, |acc, x| acc.max(x.abs())));
        peak_fourth = peak_fourth.max(d4.iter().fold(0.0_f64, |acc, x| acc.max(x.abs())));
        // Symmetry over the strict upper triangle.
        for i in 0..total {
            for j in (i + 1)..total {
                assert!((d3[[i, j]] - d3[[j, i]]).abs() < 1e-8);
                assert!((d4[[i, j]] - d4[[j, i]]).abs() < 1e-8);
            }
        }
    }
    assert!(
        peak_third > 1e-10,
        "expected nonzero w-only third contraction"
    );
    assert!(
        peak_fourth > 1e-10,
        "expected nonzero w-only fourth contraction"
    );
}
#[test]
fn dual_flex_row_primary_higher_order_contractions_are_finite_and_symmetric() {
    // Both flexible blocks (score-warp h and link-deviation w) active at
    // once: per-row third/fourth contractions must stay finite, symmetric,
    // and nonzero overall.
    let z = array![-0.8, 0.2, 1.1];
    let y = array![0.0, 1.0, 1.0];
    let weights = array![1.0, 0.7, 1.3];
    let score_prepared = build_score_warp_deviation_block_from_seed(
        &z,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("score warp block");
    let link_seed = array![-2.0, -0.5, 0.0, 0.5, 2.0];
    let link_prepared = build_test_link_deviation_block_from_seed(
        &link_seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("link block");
    let fam = BernoulliMarginalSlopeFamily {
        y: Arc::new(y.clone()),
        weights: Arc::new(weights.clone()),
        z: Arc::new(z.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        score_warp: Some(score_prepared.runtime.clone()),
        link_dev: Some(link_prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
        auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        auto_subsample_last_rho: Arc::new(Mutex::new(None)),
        shared_eval_cache: Arc::new(Mutex::new(None)),
    };
    let states = vec![
        ParameterBlockState {
            beta: array![0.25],
            eta: Array1::from_elem(z.len(), 0.25),
        },
        ParameterBlockState {
            beta: array![0.6],
            eta: Array1::from_elem(z.len(), 0.6),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..score_prepared.block.design.ncols()).map(|idx| 0.015 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(z.len()),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..link_prepared.block.design.ncols()).map(|idx| 0.01 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(z.len()),
        },
    ];
    let eval_cache = fam
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let total = eval_cache.primary.total;
    let h_range = eval_cache.primary.h.as_ref().expect("h slice");
    let w_range = eval_cache.primary.w.as_ref().expect("w slice");
    // Probe directions that mix marginal, log-slope, h, and w coordinates.
    let mut u = Array1::<f64>::zeros(total);
    let mut v = Array1::<f64>::zeros(total);
    u[eval_cache.primary.q] = 0.7;
    u[eval_cache.primary.logslope] = -0.2;
    u[h_range.start] = 0.1;
    if h_range.len() > 1 {
        u[h_range.start + 1] = -0.05;
    }
    u[w_range.start] = 0.08;
    v[eval_cache.primary.q] = -0.4;
    v[eval_cache.primary.logslope] = 0.3;
    v[h_range.start] = -0.03;
    v[w_range.start] = 0.06;
    if w_range.len() > 1 {
        v[w_range.start + 1] = -0.02;
    }
    let mut peak_third = 0.0_f64;
    let mut peak_fourth = 0.0_f64;
    for row in 0..z.len() {
        let ctx = fam
            .build_row_exact_context(row, &states)
            .unwrap_or_else(|e| panic!("row {row}: build_row_exact_context failed: {e}"));
        let d3 = fam
            .row_primary_third_contracted_recompute(row, &states, &eval_cache, &ctx, &u)
            .unwrap_or_else(|e| {
                panic!("row {row}: row_primary_third_contracted_recompute failed: {e}")
            });
        let d4 = fam
            .row_primary_fourth_contracted_recompute(row, &states, &eval_cache, &ctx, &u, &v)
            .unwrap_or_else(|e| {
                panic!("row {row}: row_primary_fourth_contracted_recompute failed: {e}")
            });
        assert_eq!(d3.dim(), (total, total));
        assert_eq!(d4.dim(), (total, total));
        assert!(d3.iter().all(|x| x.is_finite()));
        assert!(d4.iter().all(|x| x.is_finite()));
        peak_third = peak_third.max(d3.iter().fold(0.0_f64, |acc, x| acc.max(x.abs())));
        peak_fourth = peak_fourth.max(d4.iter().fold(0.0_f64, |acc, x| acc.max(x.abs())));
        // Symmetry over the strict upper triangle.
        for i in 0..total {
            for j in (i + 1)..total {
                assert!((d3[[i, j]] - d3[[j, i]]).abs() < 1e-8);
                assert!((d4[[i, j]] - d4[[j, i]]).abs() < 1e-8);
            }
        }
    }
    assert!(
        peak_third > 1e-10,
        "expected nonzero dual-flex third contraction"
    );
    assert!(
        peak_fourth > 1e-10,
        "expected nonzero dual-flex fourth contraction"
    );
}
#[test]
fn dual_flex_row_primary_higher_order_zero_direction_returns_zero() {
    // Contracting against the zero direction must yield exactly-zero
    // matrices for every row when both flexible blocks are active.
    let z = array![-0.8, 0.2, 1.1];
    let y = array![0.0, 1.0, 1.0];
    let weights = array![1.0, 0.7, 1.3];
    let score_prepared = build_score_warp_deviation_block_from_seed(
        &z,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("score warp block");
    let link_seed = array![-2.0, -0.5, 0.0, 0.5, 2.0];
    let link_prepared = build_test_link_deviation_block_from_seed(
        &link_seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("link block");
    let fam = BernoulliMarginalSlopeFamily {
        y: Arc::new(y.clone()),
        weights: Arc::new(weights.clone()),
        z: Arc::new(z.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        score_warp: Some(score_prepared.runtime.clone()),
        link_dev: Some(link_prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
        auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        auto_subsample_last_rho: Arc::new(Mutex::new(None)),
        shared_eval_cache: Arc::new(Mutex::new(None)),
    };
    let states = vec![
        ParameterBlockState {
            beta: array![0.25],
            eta: Array1::from_elem(z.len(), 0.25),
        },
        ParameterBlockState {
            beta: array![0.6],
            eta: Array1::from_elem(z.len(), 0.6),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..score_prepared.block.design.ncols()).map(|idx| 0.015 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(z.len()),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..link_prepared.block.design.ncols()).map(|idx| 0.01 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(z.len()),
        },
    ];
    let eval_cache = fam
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let zero = Array1::<f64>::zeros(eval_cache.primary.total);
    for row in 0..z.len() {
        let ctx = fam
            .build_row_exact_context(row, &states)
            .unwrap_or_else(|e| panic!("row {row}: build_row_exact_context failed: {e}"));
        let d3 = fam
            .row_primary_third_contracted_recompute(row, &states, &eval_cache, &ctx, &zero)
            .unwrap_or_else(|e| {
                panic!("row {row}: row_primary_third_contracted_recompute failed: {e}")
            });
        let d4 = fam
            .row_primary_fourth_contracted_recompute(row, &states, &eval_cache, &ctx, &zero, &zero)
            .unwrap_or_else(|e| {
                panic!("row {row}: row_primary_fourth_contracted_recompute failed: {e}")
            });
        // `== 0.0` matches ±0.0 and rejects NaN, same as `abs() <= 0.0`.
        assert!(
            d3.iter().all(|x| *x == 0.0),
            "row {row}: expected zero third contraction for zero direction"
        );
        assert!(
            d4.iter().all(|x| *x == 0.0),
            "row {row}: expected zero fourth contraction for zero directions"
        );
    }
}
#[test]
fn h_only_row_primary_higher_order_zero_direction_returns_zero() {
    // With only the score-warp (h) block active, a zero contraction
    // direction must produce exactly-zero third/fourth matrices per row.
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let prepared = build_score_warp_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build score-warp block");
    let score_dim = prepared
        .block
        .initial_beta
        .as_ref()
        .expect("score-warp initial beta")
        .len();
    let states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(
            Array1::from_iter((0..score_dim).map(|idx| 0.04 * (idx as f64 + 1.0))),
            seed.len(),
        ),
    ];
    let fam = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: Some(prepared.runtime.clone()),
        link_dev: None,
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
        auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        auto_subsample_last_rho: Arc::new(Mutex::new(None)),
        shared_eval_cache: Arc::new(Mutex::new(None)),
    };
    let eval_cache = fam
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let zero = Array1::<f64>::zeros(eval_cache.primary.total);
    for row in 0..seed.len() {
        let ctx = fam
            .build_row_exact_context(row, &states)
            .unwrap_or_else(|e| panic!("row {row}: build_row_exact_context failed: {e}"));
        let d3 = fam
            .row_primary_third_contracted_recompute(row, &states, &eval_cache, &ctx, &zero)
            .unwrap_or_else(|e| {
                panic!("row {row}: row_primary_third_contracted_recompute failed: {e}")
            });
        let d4 = fam
            .row_primary_fourth_contracted_recompute(row, &states, &eval_cache, &ctx, &zero, &zero)
            .unwrap_or_else(|e| {
                panic!("row {row}: row_primary_fourth_contracted_recompute failed: {e}")
            });
        // `== 0.0` matches ±0.0 and rejects NaN, same as `abs() <= 0.0`.
        assert!(
            d3.iter().all(|x| *x == 0.0),
            "row {row}: expected zero h-only third contraction for zero direction"
        );
        assert!(
            d4.iter().all(|x| *x == 0.0),
            "row {row}: expected zero h-only fourth contraction for zero directions"
        );
    }
}
#[test]
fn w_only_row_primary_higher_order_zero_direction_returns_zero() {
    // With only the link-deviation (w) block active, a zero contraction
    // direction must produce exactly-zero third/fourth matrices per row.
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let prepared = build_test_link_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build link deviation block");
    let link_dim = prepared
        .block
        .initial_beta
        .as_ref()
        .expect("link initial beta")
        .len();
    let states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(
            Array1::from_iter((0..link_dim).map(|idx| 0.05 * (idx as f64 + 1.0))),
            seed.len(),
        ),
    ];
    let fam = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: None,
        link_dev: Some(prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
        auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        auto_subsample_last_rho: Arc::new(Mutex::new(None)),
        shared_eval_cache: Arc::new(Mutex::new(None)),
    };
    let eval_cache = fam
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let zero = Array1::<f64>::zeros(eval_cache.primary.total);
    for row in 0..seed.len() {
        let ctx = fam
            .build_row_exact_context(row, &states)
            .unwrap_or_else(|e| panic!("row {row}: build_row_exact_context failed: {e}"));
        let d3 = fam
            .row_primary_third_contracted_recompute(row, &states, &eval_cache, &ctx, &zero)
            .unwrap_or_else(|e| {
                panic!("row {row}: row_primary_third_contracted_recompute failed: {e}")
            });
        let d4 = fam
            .row_primary_fourth_contracted_recompute(row, &states, &eval_cache, &ctx, &zero, &zero)
            .unwrap_or_else(|e| {
                panic!("row {row}: row_primary_fourth_contracted_recompute failed: {e}")
            });
        // `== 0.0` matches ±0.0 and rejects NaN, same as `abs() <= 0.0`.
        assert!(
            d3.iter().all(|x| *x == 0.0),
            "row {row}: expected zero w-only third contraction for zero direction"
        );
        assert!(
            d4.iter().all(|x| *x == 0.0),
            "row {row}: expected zero w-only fourth contraction for zero directions"
        );
    }
}
#[test]
fn dual_flex_exact_outer_zero_direction_returns_zero() {
    // The joint (outer) third/fourth directional derivatives must vanish
    // identically when contracted against the zero direction.
    let z = array![-0.8, 0.2, 1.1];
    let y = array![0.0, 1.0, 1.0];
    let weights = array![1.0, 0.7, 1.3];
    let score_prepared = build_score_warp_deviation_block_from_seed(
        &z,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("score warp block");
    let link_seed = array![-2.0, -0.5, 0.0, 0.5, 2.0];
    let link_prepared = build_test_link_deviation_block_from_seed(
        &link_seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("link block");
    let fam = BernoulliMarginalSlopeFamily {
        y: Arc::new(y.clone()),
        weights: Arc::new(weights.clone()),
        z: Arc::new(z.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        score_warp: Some(score_prepared.runtime.clone()),
        link_dev: Some(link_prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
        auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        auto_subsample_last_rho: Arc::new(Mutex::new(None)),
        shared_eval_cache: Arc::new(Mutex::new(None)),
    };
    let states = vec![
        ParameterBlockState {
            beta: array![0.25],
            eta: Array1::from_elem(z.len(), 0.25),
        },
        ParameterBlockState {
            beta: array![0.6],
            eta: Array1::from_elem(z.len(), 0.6),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..score_prepared.block.design.ncols()).map(|idx| 0.015 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(z.len()),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..link_prepared.block.design.ncols()).map(|idx| 0.01 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(z.len()),
        },
    ];
    let slices = block_slices(&fam);
    let zero = Array1::<f64>::zeros(slices.total);
    let d3 = fam
        .exact_newton_joint_hessian_directional_derivative(&states, &zero)
        .expect("dual-flex third directional derivative")
        .expect("dual-flex third directional derivative matrix");
    let d4 = fam
        .exact_newton_joint_hessiansecond_directional_derivative(&states, &zero, &zero)
        .expect("dual-flex fourth directional derivative")
        .expect("dual-flex fourth directional derivative matrix");
    // `== 0.0` matches ±0.0 and rejects NaN, same as `abs() <= 0.0`.
    assert!(
        d3.iter().all(|x| *x == 0.0),
        "expected zero dual-flex third directional derivative for zero direction"
    );
    assert!(
        d4.iter().all(|x| *x == 0.0),
        "expected zero dual-flex fourth directional derivative for zero directions"
    );
}
#[test]
fn dual_flex_exact_outer_fourth_direction_swap_is_symmetric() {
    // The fourth joint directional derivative must be symmetric in its two
    // direction arguments: D4(u, v) == D4(v, u) entrywise.
    let z = array![-0.8, 0.2, 1.1];
    let y = array![0.0, 1.0, 1.0];
    let weights = array![1.0, 0.7, 1.3];
    let score_prepared = build_score_warp_deviation_block_from_seed(
        &z,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("score warp block");
    let link_seed = array![-2.0, -0.5, 0.0, 0.5, 2.0];
    let link_prepared = build_test_link_deviation_block_from_seed(
        &link_seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("link block");
    let fam = BernoulliMarginalSlopeFamily {
        y: Arc::new(y.clone()),
        weights: Arc::new(weights.clone()),
        z: Arc::new(z.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        score_warp: Some(score_prepared.runtime.clone()),
        link_dev: Some(link_prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
        auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        auto_subsample_last_rho: Arc::new(Mutex::new(None)),
        shared_eval_cache: Arc::new(Mutex::new(None)),
    };
    let states = vec![
        ParameterBlockState {
            beta: array![0.25],
            eta: Array1::from_elem(z.len(), 0.25),
        },
        ParameterBlockState {
            beta: array![0.6],
            eta: Array1::from_elem(z.len(), 0.6),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..score_prepared.block.design.ncols()).map(|idx| 0.015 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(z.len()),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..link_prepared.block.design.ncols()).map(|idx| 0.01 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(z.len()),
        },
    ];
    let slices = block_slices(&fam);
    let total = slices.total;
    let h_range = slices.h.as_ref().expect("h slice");
    let w_range = slices.w.as_ref().expect("w slice");
    // Probe directions spanning marginal, log-slope, h, and w coordinates.
    let mut u = Array1::<f64>::zeros(total);
    let mut v = Array1::<f64>::zeros(total);
    u[slices.marginal.start] = 0.7;
    u[slices.logslope.start] = -0.2;
    u[h_range.start] = 0.1;
    if h_range.len() > 1 {
        u[h_range.start + 1] = -0.05;
    }
    u[w_range.start] = 0.08;
    v[slices.marginal.start] = -0.4;
    v[slices.logslope.start] = 0.3;
    v[h_range.start] = -0.03;
    v[w_range.start] = 0.06;
    if w_range.len() > 1 {
        v[w_range.start + 1] = -0.02;
    }
    let forward = fam
        .exact_newton_joint_hessiansecond_directional_derivative(&states, &u, &v)
        .expect("dual-flex fourth directional derivative")
        .expect("dual-flex fourth directional derivative matrix");
    let swapped = fam
        .exact_newton_joint_hessiansecond_directional_derivative(&states, &v, &u)
        .expect("dual-flex swapped fourth directional derivative")
        .expect("dual-flex swapped fourth directional derivative matrix");
    assert_eq!(forward.dim(), (total, total));
    assert_eq!(swapped.dim(), (total, total));
    for i in 0..total {
        for j in 0..total {
            assert!(
                (forward[[i, j]] - swapped[[i, j]]).abs() < 1e-8,
                "fourth directional derivative should be symmetric in direction arguments at ({i},{j})"
            );
        }
    }
}
#[test]
fn dual_flex_row_primary_fourth_direction_swap_is_symmetric() {
    // Per-row fourth contractions must be symmetric in the two direction
    // arguments: contract(u, v) == contract(v, u) entrywise for every row.
    let z = array![-0.8, 0.2, 1.1];
    let y = array![0.0, 1.0, 1.0];
    let weights = array![1.0, 0.7, 1.3];
    let score_prepared = build_score_warp_deviation_block_from_seed(
        &z,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("score warp block");
    let link_seed = array![-2.0, -0.5, 0.0, 0.5, 2.0];
    let link_prepared = build_test_link_deviation_block_from_seed(
        &link_seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("link block");
    let fam = BernoulliMarginalSlopeFamily {
        y: Arc::new(y.clone()),
        weights: Arc::new(weights.clone()),
        z: Arc::new(z.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        score_warp: Some(score_prepared.runtime.clone()),
        link_dev: Some(link_prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
        auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        auto_subsample_last_rho: Arc::new(Mutex::new(None)),
        shared_eval_cache: Arc::new(Mutex::new(None)),
    };
    let states = vec![
        ParameterBlockState {
            beta: array![0.25],
            eta: Array1::from_elem(z.len(), 0.25),
        },
        ParameterBlockState {
            beta: array![0.6],
            eta: Array1::from_elem(z.len(), 0.6),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..score_prepared.block.design.ncols()).map(|idx| 0.015 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(z.len()),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..link_prepared.block.design.ncols()).map(|idx| 0.01 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(z.len()),
        },
    ];
    let eval_cache = fam
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let total = eval_cache.primary.total;
    let h_range = eval_cache.primary.h.as_ref().expect("h slice");
    let w_range = eval_cache.primary.w.as_ref().expect("w slice");
    // Probe directions spanning marginal, log-slope, h, and w coordinates.
    let mut u = Array1::<f64>::zeros(total);
    let mut v = Array1::<f64>::zeros(total);
    u[eval_cache.primary.q] = 0.7;
    u[eval_cache.primary.logslope] = -0.2;
    u[h_range.start] = 0.1;
    if h_range.len() > 1 {
        u[h_range.start + 1] = -0.05;
    }
    u[w_range.start] = 0.08;
    v[eval_cache.primary.q] = -0.4;
    v[eval_cache.primary.logslope] = 0.3;
    v[h_range.start] = -0.03;
    v[w_range.start] = 0.06;
    if w_range.len() > 1 {
        v[w_range.start + 1] = -0.02;
    }
    for row in 0..z.len() {
        let ctx = fam
            .build_row_exact_context(row, &states)
            .unwrap_or_else(|e| panic!("row {row}: build_row_exact_context failed: {e}"));
        let forward = fam
            .row_primary_fourth_contracted_recompute(row, &states, &eval_cache, &ctx, &u, &v)
            .unwrap_or_else(|e| panic!("row {row}: forward fourth contraction failed: {e}"));
        let swapped = fam
            .row_primary_fourth_contracted_recompute(row, &states, &eval_cache, &ctx, &v, &u)
            .unwrap_or_else(|e| panic!("row {row}: swapped fourth contraction failed: {e}"));
        assert_eq!(forward.dim(), (total, total));
        assert_eq!(swapped.dim(), (total, total));
        for i in 0..total {
            for j in 0..total {
                assert!(
                    (forward[[i, j]] - swapped[[i, j]]).abs() < 1e-8,
                    "row {row}: fourth contraction should be symmetric in direction arguments at ({i},{j})"
                );
            }
        }
    }
}
#[test]
fn dual_flex_row_primary_higher_order_direction_sign_rules_hold() {
    // Multilinearity sign rules: the third contraction is odd in its single
    // direction (negating u negates the result), and the fourth contraction
    // is linear in each direction argument (negating u negates the result).
    let z = array![-0.8, 0.2, 1.1];
    let y = array![0.0, 1.0, 1.0];
    let weights = array![1.0, 0.7, 1.3];
    let score_prepared = build_score_warp_deviation_block_from_seed(
        &z,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("score warp block");
    let link_seed = array![-2.0, -0.5, 0.0, 0.5, 2.0];
    let link_prepared = build_test_link_deviation_block_from_seed(
        &link_seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("link block");
    let fam = BernoulliMarginalSlopeFamily {
        y: Arc::new(y.clone()),
        weights: Arc::new(weights.clone()),
        z: Arc::new(z.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        score_warp: Some(score_prepared.runtime.clone()),
        link_dev: Some(link_prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
        auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        auto_subsample_last_rho: Arc::new(Mutex::new(None)),
        shared_eval_cache: Arc::new(Mutex::new(None)),
    };
    let states = vec![
        ParameterBlockState {
            beta: array![0.25],
            eta: Array1::from_elem(z.len(), 0.25),
        },
        ParameterBlockState {
            beta: array![0.6],
            eta: Array1::from_elem(z.len(), 0.6),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..score_prepared.block.design.ncols()).map(|idx| 0.015 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(z.len()),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..link_prepared.block.design.ncols()).map(|idx| 0.01 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(z.len()),
        },
    ];
    let eval_cache = fam
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let total = eval_cache.primary.total;
    let h_range = eval_cache.primary.h.as_ref().expect("h slice");
    let w_range = eval_cache.primary.w.as_ref().expect("w slice");
    // Probe directions spanning marginal, log-slope, h, and w coordinates.
    let mut u = Array1::<f64>::zeros(total);
    let mut v = Array1::<f64>::zeros(total);
    u[eval_cache.primary.q] = 0.7;
    u[eval_cache.primary.logslope] = -0.2;
    u[h_range.start] = 0.1;
    if h_range.len() > 1 {
        u[h_range.start + 1] = -0.05;
    }
    u[w_range.start] = 0.08;
    v[eval_cache.primary.q] = -0.4;
    v[eval_cache.primary.logslope] = 0.3;
    v[h_range.start] = -0.03;
    v[w_range.start] = 0.06;
    if w_range.len() > 1 {
        v[w_range.start + 1] = -0.02;
    }
    let u_neg = u.mapv(|x| -x);
    for row in 0..z.len() {
        let ctx = fam
            .build_row_exact_context(row, &states)
            .unwrap_or_else(|e| panic!("row {row}: build_row_exact_context failed: {e}"));
        let d3 = fam
            .row_primary_third_contracted_recompute(row, &states, &eval_cache, &ctx, &u)
            .unwrap_or_else(|e| panic!("row {row}: third contraction failed: {e}"));
        let d3_neg = fam
            .row_primary_third_contracted_recompute(row, &states, &eval_cache, &ctx, &u_neg)
            .unwrap_or_else(|e| panic!("row {row}: negated third contraction failed: {e}"));
        let d4 = fam
            .row_primary_fourth_contracted_recompute(row, &states, &eval_cache, &ctx, &u, &v)
            .unwrap_or_else(|e| panic!("row {row}: fourth contraction failed: {e}"));
        let d4_neg_u = fam
            .row_primary_fourth_contracted_recompute(row, &states, &eval_cache, &ctx, &u_neg, &v)
            .unwrap_or_else(|e| panic!("row {row}: negated-u fourth contraction failed: {e}"));
        for i in 0..total {
            for j in 0..total {
                // Negating the direction must flip the sign of each entry.
                assert!(
                    (d3_neg[[i, j]] + d3[[i, j]]).abs() < 1e-8,
                    "row {row}: third contraction should be odd in its direction at ({i},{j})"
                );
                assert!(
                    (d4_neg_u[[i, j]] + d4[[i, j]]).abs() < 1e-8,
                    "row {row}: fourth contraction should be linear in dir_u sign at ({i},{j})"
                );
            }
        }
    }
}
#[test]
fn h_only_row_primary_fourth_direction_swap_is_symmetric() {
    // Score-warp ("h") deviations only: for every row, the fourth-order contracted
    // tensor must not change when the two contraction directions are exchanged.
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let config = DeviationBlockConfig {
        num_internal_knots: 4,
        ..DeviationBlockConfig::default()
    };
    let prepared = build_score_warp_deviation_block_from_seed(&seed, &config)
        .expect("build score-warp block");
    let score_dim = prepared
        .block
        .initial_beta
        .as_ref()
        .expect("score-warp initial beta")
        .len();
    // Two trivial intercept blocks plus a non-trivial score-warp coefficient block.
    let block_states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(
            Array1::from_shape_fn(score_dim, |idx| 0.04 * (idx as f64 + 1.0)),
            seed.len(),
        ),
    ];
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: Some(prepared.runtime.clone()),
        link_dev: None,
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
        auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        auto_subsample_last_rho: Arc::new(Mutex::new(None)),
        shared_eval_cache: Arc::new(Mutex::new(None)),
    };
    let cache = family
        .build_exact_eval_cache(&block_states)
        .expect("exact eval cache");
    let total = cache.primary.total;
    let h_range = cache.primary.h.as_ref().expect("h slice");
    // Build two distinct directions touching q, logslope and the h-coefficients.
    let mut u = Array1::<f64>::zeros(total);
    u[cache.primary.q] = -0.35;
    u[cache.primary.logslope] = 0.28;
    u[h_range.start] = 0.12;
    if h_range.len() > 1 {
        u[h_range.start + 1] = -0.06;
    }
    let mut v = Array1::<f64>::zeros(total);
    v[cache.primary.q] = 0.18;
    v[cache.primary.logslope] = -0.22;
    v[h_range.start] = 0.07;
    if h_range.len() > 1 {
        v[h_range.start + 1] = 0.05;
    }
    for row in 0..seed.len() {
        let row_ctx = family
            .build_row_exact_context(row, &block_states)
            .unwrap_or_else(|e| panic!("row {row}: build_row_exact_context failed: {e}"));
        let forward = family
            .row_primary_fourth_contracted_recompute(
                row,
                &block_states,
                &cache,
                &row_ctx,
                &u,
                &v,
            )
            .unwrap_or_else(|e| panic!("row {row}: forward fourth contraction failed: {e}"));
        let swapped = family
            .row_primary_fourth_contracted_recompute(
                row,
                &block_states,
                &cache,
                &row_ctx,
                &v,
                &u,
            )
            .unwrap_or_else(|e| panic!("row {row}: swapped fourth contraction failed: {e}"));
        // Symmetry in the direction arguments, entry by entry.
        for ((i, j), &fwd) in forward.indexed_iter() {
            assert!(
                (fwd - swapped[[i, j]]).abs() < 1e-8,
                "row {row}: h-only fourth contraction should be symmetric in direction arguments at ({i},{j})"
            );
        }
    }
}
#[test]
fn w_only_row_primary_fourth_direction_swap_is_symmetric() {
    // Link-deviation ("w") block only: the row-wise fourth-order contracted tensor
    // must be invariant under swapping the two contraction directions.
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let config = DeviationBlockConfig {
        num_internal_knots: 4,
        ..DeviationBlockConfig::default()
    };
    let prepared = build_test_link_deviation_block_from_seed(&seed, &config)
        .expect("build link deviation block");
    let link_dim = prepared
        .block
        .initial_beta
        .as_ref()
        .expect("link initial beta")
        .len();
    // Two trivial intercept blocks plus a non-trivial link-deviation coefficient block.
    let block_states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(
            Array1::from_shape_fn(link_dim, |idx| 0.05 * (idx as f64 + 1.0)),
            seed.len(),
        ),
    ];
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: None,
        link_dev: Some(prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
        auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        auto_subsample_last_rho: Arc::new(Mutex::new(None)),
        shared_eval_cache: Arc::new(Mutex::new(None)),
    };
    let cache = family
        .build_exact_eval_cache(&block_states)
        .expect("exact eval cache");
    let total = cache.primary.total;
    let w_range = cache.primary.w.as_ref().expect("w slice");
    // Two distinct directions touching q, logslope and the w-coefficients.
    let mut u = Array1::<f64>::zeros(total);
    u[cache.primary.q] = 0.4;
    u[cache.primary.logslope] = -0.3;
    u[w_range.start] = 0.15;
    if w_range.len() > 1 {
        u[w_range.start + 1] = -0.07;
    }
    let mut v = Array1::<f64>::zeros(total);
    v[cache.primary.q] = -0.2;
    v[cache.primary.logslope] = 0.25;
    v[w_range.start] = 0.09;
    if w_range.len() > 1 {
        v[w_range.start + 1] = 0.03;
    }
    for row in 0..seed.len() {
        let row_ctx = family
            .build_row_exact_context(row, &block_states)
            .unwrap_or_else(|e| panic!("row {row}: build_row_exact_context failed: {e}"));
        let forward = family
            .row_primary_fourth_contracted_recompute(
                row,
                &block_states,
                &cache,
                &row_ctx,
                &u,
                &v,
            )
            .unwrap_or_else(|e| panic!("row {row}: forward fourth contraction failed: {e}"));
        let swapped = family
            .row_primary_fourth_contracted_recompute(
                row,
                &block_states,
                &cache,
                &row_ctx,
                &v,
                &u,
            )
            .unwrap_or_else(|e| panic!("row {row}: swapped fourth contraction failed: {e}"));
        // Check entry-wise symmetry of the contraction in its direction arguments.
        for ((i, j), &fwd) in forward.indexed_iter() {
            assert!(
                (fwd - swapped[[i, j]]).abs() < 1e-8,
                "row {row}: w-only fourth contraction should be symmetric in direction arguments at ({i},{j})"
            );
        }
    }
}
#[test]
fn h_only_row_primary_higher_order_direction_sign_rules_hold() {
    // Score-warp ("h") deviations only. Flipping the direction's sign must negate
    // the third-order contraction (odd) and leave the fourth-order contraction
    // unchanged when both direction slots are flipped (even).
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let config = DeviationBlockConfig {
        num_internal_knots: 4,
        ..DeviationBlockConfig::default()
    };
    let prepared = build_score_warp_deviation_block_from_seed(&seed, &config)
        .expect("build score-warp block");
    let score_dim = prepared
        .block
        .initial_beta
        .as_ref()
        .expect("score-warp initial beta")
        .len();
    let block_states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(
            Array1::from_shape_fn(score_dim, |idx| 0.04 * (idx as f64 + 1.0)),
            seed.len(),
        ),
    ];
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: Some(prepared.runtime.clone()),
        link_dev: None,
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
        auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        auto_subsample_last_rho: Arc::new(Mutex::new(None)),
        shared_eval_cache: Arc::new(Mutex::new(None)),
    };
    let cache = family
        .build_exact_eval_cache(&block_states)
        .expect("exact eval cache");
    let total = cache.primary.total;
    let h_range = cache.primary.h.as_ref().expect("h slice");
    // A single direction touching q, logslope, and the h-coefficients.
    let mut dir = Array1::<f64>::zeros(total);
    dir[cache.primary.q] = -0.35;
    dir[cache.primary.logslope] = 0.28;
    dir[h_range.start] = 0.12;
    if h_range.len() > 1 {
        dir[h_range.start + 1] = -0.06;
    }
    let neg_dir = -&dir;
    for row in 0..seed.len() {
        let row_ctx = family
            .build_row_exact_context(row, &block_states)
            .unwrap_or_else(|e| panic!("row {row}: build_row_exact_context failed: {e}"));
        let third = family
            .row_primary_third_contracted_recompute(row, &block_states, &cache, &row_ctx, &dir)
            .unwrap_or_else(|e| panic!("row {row}: third contraction failed: {e}"));
        let third_neg = family
            .row_primary_third_contracted_recompute(
                row,
                &block_states,
                &cache,
                &row_ctx,
                &neg_dir,
            )
            .unwrap_or_else(|e| panic!("row {row}: negated third contraction failed: {e}"));
        let fourth = family
            .row_primary_fourth_contracted_recompute(
                row,
                &block_states,
                &cache,
                &row_ctx,
                &dir,
                &dir,
            )
            .unwrap_or_else(|e| panic!("row {row}: fourth contraction failed: {e}"));
        let fourth_neg = family
            .row_primary_fourth_contracted_recompute(
                row,
                &block_states,
                &cache,
                &row_ctx,
                &neg_dir,
                &neg_dir,
            )
            .unwrap_or_else(|e| {
                panic!("row {row}: doubly-negated fourth contraction failed: {e}")
            });
        for ((i, j), &t) in third.indexed_iter() {
            // Third contraction: odd in its direction.
            assert!(
                (third_neg[[i, j]] + t).abs() < 1e-8,
                "row {row}: h-only third contraction should be odd at ({i},{j})"
            );
            // Fourth contraction: even under flipping both direction slots.
            assert!(
                (fourth_neg[[i, j]] - fourth[[i, j]]).abs() < 1e-8,
                "row {row}: h-only fourth contraction should be invariant under flipping both directions at ({i},{j})"
            );
        }
    }
}
#[test]
fn w_only_row_primary_higher_order_direction_sign_rules_hold() {
    // Link-deviation ("w") block only. The third-order contraction must be odd in
    // its direction, and the fourth-order contraction even under flipping both
    // direction slots, on every row.
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let config = DeviationBlockConfig {
        num_internal_knots: 4,
        ..DeviationBlockConfig::default()
    };
    let prepared = build_test_link_deviation_block_from_seed(&seed, &config)
        .expect("build link deviation block");
    let link_dim = prepared
        .block
        .initial_beta
        .as_ref()
        .expect("link initial beta")
        .len();
    let block_states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(
            Array1::from_shape_fn(link_dim, |idx| 0.05 * (idx as f64 + 1.0)),
            seed.len(),
        ),
    ];
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: None,
        link_dev: Some(prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
        auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        auto_subsample_last_rho: Arc::new(Mutex::new(None)),
        shared_eval_cache: Arc::new(Mutex::new(None)),
    };
    let cache = family
        .build_exact_eval_cache(&block_states)
        .expect("exact eval cache");
    let total = cache.primary.total;
    let w_range = cache.primary.w.as_ref().expect("w slice");
    // A single direction touching q, logslope, and the w-coefficients.
    let mut dir = Array1::<f64>::zeros(total);
    dir[cache.primary.q] = 0.4;
    dir[cache.primary.logslope] = -0.3;
    dir[w_range.start] = 0.15;
    if w_range.len() > 1 {
        dir[w_range.start + 1] = -0.07;
    }
    let neg_dir = -&dir;
    for row in 0..seed.len() {
        let row_ctx = family
            .build_row_exact_context(row, &block_states)
            .unwrap_or_else(|e| panic!("row {row}: build_row_exact_context failed: {e}"));
        let third = family
            .row_primary_third_contracted_recompute(row, &block_states, &cache, &row_ctx, &dir)
            .unwrap_or_else(|e| panic!("row {row}: third contraction failed: {e}"));
        let third_neg = family
            .row_primary_third_contracted_recompute(
                row,
                &block_states,
                &cache,
                &row_ctx,
                &neg_dir,
            )
            .unwrap_or_else(|e| panic!("row {row}: negated third contraction failed: {e}"));
        let fourth = family
            .row_primary_fourth_contracted_recompute(
                row,
                &block_states,
                &cache,
                &row_ctx,
                &dir,
                &dir,
            )
            .unwrap_or_else(|e| panic!("row {row}: fourth contraction failed: {e}"));
        let fourth_neg = family
            .row_primary_fourth_contracted_recompute(
                row,
                &block_states,
                &cache,
                &row_ctx,
                &neg_dir,
                &neg_dir,
            )
            .unwrap_or_else(|e| {
                panic!("row {row}: doubly-negated fourth contraction failed: {e}")
            });
        for ((i, j), &t) in third.indexed_iter() {
            // Odd symmetry of the third-order contraction.
            assert!(
                (third_neg[[i, j]] + t).abs() < 1e-8,
                "row {row}: w-only third contraction should be odd at ({i},{j})"
            );
            // Even symmetry of the fourth-order contraction under a double flip.
            assert!(
                (fourth_neg[[i, j]] - fourth[[i, j]]).abs() < 1e-8,
                "row {row}: w-only fourth contraction should be invariant under flipping both directions at ({i},{j})"
            );
        }
    }
}
#[test]
fn dual_flex_exact_outer_direction_sign_rules_hold() {
    // Both flexible blocks (score-warp "h" AND link-deviation "w") active. The
    // joint third directional derivative must be odd in its direction; the fourth
    // must be linear in the sign of the first direction slot.
    let z = array![-0.8, 0.2, 1.1];
    let y = array![0.0, 1.0, 1.0];
    let weights = array![1.0, 0.7, 1.3];
    let dev_config = DeviationBlockConfig {
        num_internal_knots: 4,
        ..DeviationBlockConfig::default()
    };
    let score_prepared =
        build_score_warp_deviation_block_from_seed(&z, &dev_config).expect("score warp block");
    let link_seed = array![-2.0, -0.5, 0.0, 0.5, 2.0];
    let link_prepared =
        build_test_link_deviation_block_from_seed(&link_seed, &dev_config).expect("link block");
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(y.clone()),
        weights: Arc::new(weights.clone()),
        z: Arc::new(z.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        score_warp: Some(score_prepared.runtime.clone()),
        link_dev: Some(link_prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
        auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        auto_subsample_last_rho: Arc::new(Mutex::new(None)),
        shared_eval_cache: Arc::new(Mutex::new(None)),
    };
    // Marginal and logslope intercepts plus both deviation coefficient blocks.
    let block_states = vec![
        ParameterBlockState {
            beta: array![0.25],
            eta: Array1::from_elem(z.len(), 0.25),
        },
        ParameterBlockState {
            beta: array![0.6],
            eta: Array1::from_elem(z.len(), 0.6),
        },
        ParameterBlockState {
            beta: Array1::from_shape_fn(score_prepared.block.design.ncols(), |idx| {
                0.015 * (idx as f64 + 1.0)
            }),
            eta: Array1::zeros(z.len()),
        },
        ParameterBlockState {
            beta: Array1::from_shape_fn(link_prepared.block.design.ncols(), |idx| {
                0.01 * (idx as f64 + 1.0)
            }),
            eta: Array1::zeros(z.len()),
        },
    ];
    let slices = block_slices(&family);
    let total = slices.total;
    let h_range = slices.h.as_ref().expect("h slice");
    let w_range = slices.w.as_ref().expect("w slice");
    // Two directions with support on every block.
    let mut dir_u = Array1::<f64>::zeros(total);
    dir_u[slices.marginal.start] = 0.7;
    dir_u[slices.logslope.start] = -0.2;
    dir_u[h_range.start] = 0.1;
    if h_range.len() > 1 {
        dir_u[h_range.start + 1] = -0.05;
    }
    dir_u[w_range.start] = 0.08;
    let mut dir_v = Array1::<f64>::zeros(total);
    dir_v[slices.marginal.start] = -0.4;
    dir_v[slices.logslope.start] = 0.3;
    dir_v[h_range.start] = -0.03;
    dir_v[w_range.start] = 0.06;
    if w_range.len() > 1 {
        dir_v[w_range.start + 1] = -0.02;
    }
    let neg_dir_u = -&dir_u;
    let third = family
        .exact_newton_joint_hessian_directional_derivative(&block_states, &dir_u)
        .expect("dual-flex third directional derivative")
        .expect("dual-flex third directional derivative matrix");
    let third_neg = family
        .exact_newton_joint_hessian_directional_derivative(&block_states, &neg_dir_u)
        .expect("dual-flex negated third directional derivative")
        .expect("dual-flex negated third directional derivative matrix");
    let fourth = family
        .exact_newton_joint_hessiansecond_directional_derivative(&block_states, &dir_u, &dir_v)
        .expect("dual-flex fourth directional derivative")
        .expect("dual-flex fourth directional derivative matrix");
    let fourth_neg_u = family
        .exact_newton_joint_hessiansecond_directional_derivative(
            &block_states,
            &neg_dir_u,
            &dir_v,
        )
        .expect("dual-flex negated-u fourth directional derivative")
        .expect("dual-flex negated-u fourth directional derivative matrix");
    assert_eq!(third.dim(), (total, total));
    assert_eq!(third_neg.dim(), (total, total));
    assert_eq!(fourth.dim(), (total, total));
    assert_eq!(fourth_neg_u.dim(), (total, total));
    for ((i, j), &t) in third.indexed_iter() {
        // Odd in the (single) direction of the third derivative.
        assert!(
            (third_neg[[i, j]] + t).abs() < 1e-8,
            "third directional derivative should be odd in its direction at ({i},{j})"
        );
        // Linear in the sign of dir_u for the fourth derivative.
        assert!(
            (fourth_neg_u[[i, j]] + fourth[[i, j]]).abs() < 1e-8,
            "fourth directional derivative should be linear in dir_u sign at ({i},{j})"
        );
    }
}
#[test]
fn dual_flex_exact_outer_fourth_double_sign_flip_is_invariant() {
    // Both flexible blocks active: negating BOTH direction arguments of the joint
    // fourth directional derivative must leave the result unchanged (even order).
    let z = array![-0.8, 0.2, 1.1];
    let y = array![0.0, 1.0, 1.0];
    let weights = array![1.0, 0.7, 1.3];
    let dev_config = DeviationBlockConfig {
        num_internal_knots: 4,
        ..DeviationBlockConfig::default()
    };
    let score_prepared =
        build_score_warp_deviation_block_from_seed(&z, &dev_config).expect("score warp block");
    let link_seed = array![-2.0, -0.5, 0.0, 0.5, 2.0];
    let link_prepared =
        build_test_link_deviation_block_from_seed(&link_seed, &dev_config).expect("link block");
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(y.clone()),
        weights: Arc::new(weights.clone()),
        z: Arc::new(z.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        score_warp: Some(score_prepared.runtime.clone()),
        link_dev: Some(link_prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
        auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        auto_subsample_last_rho: Arc::new(Mutex::new(None)),
        shared_eval_cache: Arc::new(Mutex::new(None)),
    };
    let block_states = vec![
        ParameterBlockState {
            beta: array![0.25],
            eta: Array1::from_elem(z.len(), 0.25),
        },
        ParameterBlockState {
            beta: array![0.6],
            eta: Array1::from_elem(z.len(), 0.6),
        },
        ParameterBlockState {
            beta: Array1::from_shape_fn(score_prepared.block.design.ncols(), |idx| {
                0.015 * (idx as f64 + 1.0)
            }),
            eta: Array1::zeros(z.len()),
        },
        ParameterBlockState {
            beta: Array1::from_shape_fn(link_prepared.block.design.ncols(), |idx| {
                0.01 * (idx as f64 + 1.0)
            }),
            eta: Array1::zeros(z.len()),
        },
    ];
    let slices = block_slices(&family);
    let total = slices.total;
    let h_range = slices.h.as_ref().expect("h slice");
    let w_range = slices.w.as_ref().expect("w slice");
    // Two directions with support on every block.
    let mut dir_u = Array1::<f64>::zeros(total);
    dir_u[slices.marginal.start] = 0.7;
    dir_u[slices.logslope.start] = -0.2;
    dir_u[h_range.start] = 0.1;
    if h_range.len() > 1 {
        dir_u[h_range.start + 1] = -0.05;
    }
    dir_u[w_range.start] = 0.08;
    let mut dir_v = Array1::<f64>::zeros(total);
    dir_v[slices.marginal.start] = -0.4;
    dir_v[slices.logslope.start] = 0.3;
    dir_v[h_range.start] = -0.03;
    dir_v[w_range.start] = 0.06;
    if w_range.len() > 1 {
        dir_v[w_range.start + 1] = -0.02;
    }
    let neg_dir_u = -&dir_u;
    let neg_dir_v = -&dir_v;
    let forward = family
        .exact_newton_joint_hessiansecond_directional_derivative(&block_states, &dir_u, &dir_v)
        .expect("dual-flex fourth directional derivative")
        .expect("dual-flex fourth directional derivative matrix");
    let flipped = family
        .exact_newton_joint_hessiansecond_directional_derivative(
            &block_states,
            &neg_dir_u,
            &neg_dir_v,
        )
        .expect("dual-flex doubly-negated fourth directional derivative")
        .expect("dual-flex doubly-negated fourth directional derivative matrix");
    assert_eq!(forward.dim(), (total, total));
    assert_eq!(flipped.dim(), (total, total));
    for ((i, j), &fwd) in forward.indexed_iter() {
        assert!(
            (fwd - flipped[[i, j]]).abs() < 1e-8,
            "fourth directional derivative should be invariant under flipping both directions at ({i},{j})"
        );
    }
}
#[test]
fn dual_flex_exact_outer_third_direction_is_linear() {
    // Both flexible blocks active: the joint third directional derivative must be
    // additive (hence linear) in its direction argument.
    let z = array![-0.8, 0.2, 1.1];
    let y = array![0.0, 1.0, 1.0];
    let weights = array![1.0, 0.7, 1.3];
    let dev_config = DeviationBlockConfig {
        num_internal_knots: 4,
        ..DeviationBlockConfig::default()
    };
    let score_prepared =
        build_score_warp_deviation_block_from_seed(&z, &dev_config).expect("score warp block");
    let link_seed = array![-2.0, -0.5, 0.0, 0.5, 2.0];
    let link_prepared =
        build_test_link_deviation_block_from_seed(&link_seed, &dev_config).expect("link block");
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(y.clone()),
        weights: Arc::new(weights.clone()),
        z: Arc::new(z.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        score_warp: Some(score_prepared.runtime.clone()),
        link_dev: Some(link_prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
        auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        auto_subsample_last_rho: Arc::new(Mutex::new(None)),
        shared_eval_cache: Arc::new(Mutex::new(None)),
    };
    let block_states = vec![
        ParameterBlockState {
            beta: array![0.25],
            eta: Array1::from_elem(z.len(), 0.25),
        },
        ParameterBlockState {
            beta: array![0.6],
            eta: Array1::from_elem(z.len(), 0.6),
        },
        ParameterBlockState {
            beta: Array1::from_shape_fn(score_prepared.block.design.ncols(), |idx| {
                0.015 * (idx as f64 + 1.0)
            }),
            eta: Array1::zeros(z.len()),
        },
        ParameterBlockState {
            beta: Array1::from_shape_fn(link_prepared.block.design.ncols(), |idx| {
                0.01 * (idx as f64 + 1.0)
            }),
            eta: Array1::zeros(z.len()),
        },
    ];
    let slices = block_slices(&family);
    let total = slices.total;
    let h_range = slices.h.as_ref().expect("h slice");
    let w_range = slices.w.as_ref().expect("w slice");
    // Two directions with support on every block.
    let mut dir_u = Array1::<f64>::zeros(total);
    dir_u[slices.marginal.start] = 0.7;
    dir_u[slices.logslope.start] = -0.2;
    dir_u[h_range.start] = 0.1;
    if h_range.len() > 1 {
        dir_u[h_range.start + 1] = -0.05;
    }
    dir_u[w_range.start] = 0.08;
    let mut dir_v = Array1::<f64>::zeros(total);
    dir_v[slices.marginal.start] = -0.4;
    dir_v[slices.logslope.start] = 0.3;
    dir_v[h_range.start] = -0.03;
    dir_v[w_range.start] = 0.06;
    if w_range.len() > 1 {
        dir_v[w_range.start + 1] = -0.02;
    }
    let dir_sum = &dir_u + &dir_v;
    let third_u = family
        .exact_newton_joint_hessian_directional_derivative(&block_states, &dir_u)
        .expect("dual-flex third directional derivative u")
        .expect("dual-flex third directional derivative u matrix");
    let third_v = family
        .exact_newton_joint_hessian_directional_derivative(&block_states, &dir_v)
        .expect("dual-flex third directional derivative v")
        .expect("dual-flex third directional derivative v matrix");
    let third_sum = family
        .exact_newton_joint_hessian_directional_derivative(&block_states, &dir_sum)
        .expect("dual-flex third directional derivative sum")
        .expect("dual-flex third directional derivative sum matrix");
    // D(u + v) must equal D(u) + D(v), entry by entry.
    for ((i, j), &combined) in third_sum.indexed_iter() {
        let expected = third_u[[i, j]] + third_v[[i, j]];
        assert!(
            (combined - expected).abs() < 1e-8,
            "third directional derivative should be linear in its direction at ({i},{j})"
        );
    }
}
#[test]
fn dual_flex_row_primary_third_direction_is_linear() {
    // Both flexible blocks active: the per-row third-order contraction must be
    // additive (hence linear) in its direction argument.
    let z = array![-0.8, 0.2, 1.1];
    let y = array![0.0, 1.0, 1.0];
    let weights = array![1.0, 0.7, 1.3];
    let dev_config = DeviationBlockConfig {
        num_internal_knots: 4,
        ..DeviationBlockConfig::default()
    };
    let score_prepared =
        build_score_warp_deviation_block_from_seed(&z, &dev_config).expect("score warp block");
    let link_seed = array![-2.0, -0.5, 0.0, 0.5, 2.0];
    let link_prepared =
        build_test_link_deviation_block_from_seed(&link_seed, &dev_config).expect("link block");
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(y.clone()),
        weights: Arc::new(weights.clone()),
        z: Arc::new(z.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        score_warp: Some(score_prepared.runtime.clone()),
        link_dev: Some(link_prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
        auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        auto_subsample_last_rho: Arc::new(Mutex::new(None)),
        shared_eval_cache: Arc::new(Mutex::new(None)),
    };
    let block_states = vec![
        ParameterBlockState {
            beta: array![0.25],
            eta: Array1::from_elem(z.len(), 0.25),
        },
        ParameterBlockState {
            beta: array![0.6],
            eta: Array1::from_elem(z.len(), 0.6),
        },
        ParameterBlockState {
            beta: Array1::from_shape_fn(score_prepared.block.design.ncols(), |idx| {
                0.015 * (idx as f64 + 1.0)
            }),
            eta: Array1::zeros(z.len()),
        },
        ParameterBlockState {
            beta: Array1::from_shape_fn(link_prepared.block.design.ncols(), |idx| {
                0.01 * (idx as f64 + 1.0)
            }),
            eta: Array1::zeros(z.len()),
        },
    ];
    let cache = family
        .build_exact_eval_cache(&block_states)
        .expect("exact eval cache");
    let total = cache.primary.total;
    let h_range = cache.primary.h.as_ref().expect("h slice");
    let w_range = cache.primary.w.as_ref().expect("w slice");
    // Two directions with support on every block.
    let mut dir_u = Array1::<f64>::zeros(total);
    dir_u[cache.primary.q] = 0.7;
    dir_u[cache.primary.logslope] = -0.2;
    dir_u[h_range.start] = 0.1;
    if h_range.len() > 1 {
        dir_u[h_range.start + 1] = -0.05;
    }
    dir_u[w_range.start] = 0.08;
    let mut dir_v = Array1::<f64>::zeros(total);
    dir_v[cache.primary.q] = -0.4;
    dir_v[cache.primary.logslope] = 0.3;
    dir_v[h_range.start] = -0.03;
    dir_v[w_range.start] = 0.06;
    if w_range.len() > 1 {
        dir_v[w_range.start + 1] = -0.02;
    }
    let dir_sum = &dir_u + &dir_v;
    for row in 0..z.len() {
        let row_ctx = family
            .build_row_exact_context(row, &block_states)
            .unwrap_or_else(|e| panic!("row {row}: build_row_exact_context failed: {e}"));
        let third_u = family
            .row_primary_third_contracted_recompute(
                row,
                &block_states,
                &cache,
                &row_ctx,
                &dir_u,
            )
            .unwrap_or_else(|e| panic!("row {row}: third contraction u failed: {e}"));
        let third_v = family
            .row_primary_third_contracted_recompute(
                row,
                &block_states,
                &cache,
                &row_ctx,
                &dir_v,
            )
            .unwrap_or_else(|e| panic!("row {row}: third contraction v failed: {e}"));
        let third_sum = family
            .row_primary_third_contracted_recompute(
                row,
                &block_states,
                &cache,
                &row_ctx,
                &dir_sum,
            )
            .unwrap_or_else(|e| panic!("row {row}: third contraction sum failed: {e}"));
        // T(u + v) must equal T(u) + T(v), entry by entry.
        for ((i, j), &combined) in third_sum.indexed_iter() {
            let expected = third_u[[i, j]] + third_v[[i, j]];
            assert!(
                (combined - expected).abs() < 1e-8,
                "row {row}: third contraction should be linear in its direction at ({i},{j})"
            );
        }
    }
}
#[test]
fn h_only_row_primary_third_direction_is_linear() {
    // Score-warp ("h") deviations only: the per-row third-order contraction must
    // be additive (hence linear) in its direction argument.
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let config = DeviationBlockConfig {
        num_internal_knots: 4,
        ..DeviationBlockConfig::default()
    };
    let prepared = build_score_warp_deviation_block_from_seed(&seed, &config)
        .expect("build score-warp block");
    let score_dim = prepared
        .block
        .initial_beta
        .as_ref()
        .expect("score-warp initial beta")
        .len();
    let block_states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(
            Array1::from_shape_fn(score_dim, |idx| 0.04 * (idx as f64 + 1.0)),
            seed.len(),
        ),
    ];
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: Some(prepared.runtime.clone()),
        link_dev: None,
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
        auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        auto_subsample_last_rho: Arc::new(Mutex::new(None)),
        shared_eval_cache: Arc::new(Mutex::new(None)),
    };
    let cache = family
        .build_exact_eval_cache(&block_states)
        .expect("exact eval cache");
    let total = cache.primary.total;
    let h_range = cache.primary.h.as_ref().expect("h slice");
    // Two directions touching q, logslope, and the h-coefficients.
    let mut dir_u = Array1::<f64>::zeros(total);
    dir_u[cache.primary.q] = -0.35;
    dir_u[cache.primary.logslope] = 0.28;
    dir_u[h_range.start] = 0.12;
    if h_range.len() > 1 {
        dir_u[h_range.start + 1] = -0.06;
    }
    let mut dir_v = Array1::<f64>::zeros(total);
    dir_v[cache.primary.q] = 0.18;
    dir_v[cache.primary.logslope] = -0.22;
    dir_v[h_range.start] = 0.07;
    if h_range.len() > 1 {
        dir_v[h_range.start + 1] = 0.05;
    }
    let dir_sum = &dir_u + &dir_v;
    for row in 0..seed.len() {
        let row_ctx = family
            .build_row_exact_context(row, &block_states)
            .unwrap_or_else(|e| panic!("row {row}: build_row_exact_context failed: {e}"));
        let third_u = family
            .row_primary_third_contracted_recompute(
                row,
                &block_states,
                &cache,
                &row_ctx,
                &dir_u,
            )
            .unwrap_or_else(|e| panic!("row {row}: third contraction u failed: {e}"));
        let third_v = family
            .row_primary_third_contracted_recompute(
                row,
                &block_states,
                &cache,
                &row_ctx,
                &dir_v,
            )
            .unwrap_or_else(|e| panic!("row {row}: third contraction v failed: {e}"));
        let third_sum = family
            .row_primary_third_contracted_recompute(
                row,
                &block_states,
                &cache,
                &row_ctx,
                &dir_sum,
            )
            .unwrap_or_else(|e| panic!("row {row}: third contraction sum failed: {e}"));
        // T(u + v) must equal T(u) + T(v), entry by entry.
        for ((i, j), &combined) in third_sum.indexed_iter() {
            let expected = third_u[[i, j]] + third_v[[i, j]];
            assert!(
                (combined - expected).abs() < 1e-8,
                "row {row}: h-only third contraction should be linear at ({i},{j})"
            );
        }
    }
}
// Per-row linearity check for the third-order contraction: for every data row,
// the recomputed contraction T(d) must be additive in its direction argument,
// i.e. T(u + v) == T(u) + T(v) entry-wise, when the direction is supported on
// the q, logslope, and link-deviation ("w") coordinates.
#[test]
fn w_only_row_primary_third_direction_is_linear() {
let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
// Link-deviation block built from the seed grid; its runtime is installed on
// the family below as `link_dev` (score_warp stays None — "w-only").
let prepared = build_test_link_deviation_block_from_seed(
&seed,
&DeviationBlockConfig {
num_internal_knots: 4,
..DeviationBlockConfig::default()
},
)
.expect("build link deviation block");
let link_dim = prepared
.block
.initial_beta
.as_ref()
.expect("link initial beta")
.len();
// Three parameter blocks: two scalar blocks at 0.0 plus a non-trivial
// link-deviation coefficient vector (0.05, 0.10, ...).
let block_states = vec![
dummy_block_state(array![0.0], seed.len()),
dummy_block_state(array![0.0], seed.len()),
dummy_block_state(
Array1::from_iter((0..link_dim).map(|idx| 0.05 * (idx as f64 + 1.0))),
seed.len(),
),
];
// Minimal family: empty (0-column) marginal/logslope designs so only the
// intercepts and the link-deviation block carry coefficients.
let family = BernoulliMarginalSlopeFamily {
y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
weights: Arc::new(Array1::ones(seed.len())),
z: Arc::new(seed.clone()),
latent_measure: LatentMeasureKind::StandardNormal,
gaussian_frailty_sd: None,
base_link: bernoulli_marginal_slope_probit_link(),
marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
Array2::zeros((seed.len(), 0)),
)),
logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
Array2::zeros((seed.len(), 0)),
)),
score_warp: None,
link_dev: Some(prepared.runtime.clone()),
policy: crate::resource::ResourcePolicy::default_library(),
cell_moment_lru: new_cell_moment_lru_cache(
&crate::resource::ResourcePolicy::default_library(),
),
cell_moment_cache_stats: new_cell_moment_cache_stats(),
intercept_warm_starts: None,
auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
auto_subsample_last_rho: Arc::new(Mutex::new(None)),
shared_eval_cache: Arc::new(Mutex::new(None)),
};
let cache = family
.build_exact_eval_cache(&block_states)
.expect("exact eval cache");
let total = cache.primary.total;
// Two directions touching q, logslope, and the first one or two w
// coordinates; the second w entry is only set when the w slice has more
// than one coefficient.
let mut dir_u = Array1::<f64>::zeros(total);
let mut dir_v = Array1::<f64>::zeros(total);
dir_u[cache.primary.q] = 0.4;
dir_u[cache.primary.logslope] = -0.3;
let w_range = cache.primary.w.as_ref().expect("w slice");
dir_u[w_range.start] = 0.15;
if w_range.len() > 1 {
dir_u[w_range.start + 1] = -0.07;
}
dir_v[cache.primary.q] = -0.2;
dir_v[cache.primary.logslope] = 0.25;
dir_v[w_range.start] = 0.09;
if w_range.len() > 1 {
dir_v[w_range.start + 1] = 0.03;
}
let dir_sum = &dir_u + &dir_v;
// For each row, recompute the contraction along u, v, and u + v, then
// verify additivity entry-by-entry.
for row in 0..seed.len() {
let row_ctx = family
.build_row_exact_context(row, &block_states)
.unwrap_or_else(|e| panic!("row {row}: build_row_exact_context failed: {e}"));
let third_u = family
.row_primary_third_contracted_recompute(
row,
&block_states,
&cache,
&row_ctx,
&dir_u,
)
.unwrap_or_else(|e| panic!("row {row}: third contraction u failed: {e}"));
let third_v = family
.row_primary_third_contracted_recompute(
row,
&block_states,
&cache,
&row_ctx,
&dir_v,
)
.unwrap_or_else(|e| panic!("row {row}: third contraction v failed: {e}"));
let third_sum = family
.row_primary_third_contracted_recompute(
row,
&block_states,
&cache,
&row_ctx,
&dir_sum,
)
.unwrap_or_else(|e| panic!("row {row}: third contraction sum failed: {e}"));
for i in 0..total {
for j in 0..total {
let expected = third_u[[i, j]] + third_v[[i, j]];
assert!(
(third_sum[[i, j]] - expected).abs() < 1e-8,
"row {row}: w-only third contraction should be linear at ({i},{j})"
);
}
}
}
}
#[test]
fn dual_flex_exact_outer_fourth_first_direction_is_linear() {
let z = array![-0.8, 0.2, 1.1];
let y = array![0.0, 1.0, 1.0];
let weights = array![1.0, 0.7, 1.3];
let score_prepared = build_score_warp_deviation_block_from_seed(
&z,
&DeviationBlockConfig {
num_internal_knots: 4,
..DeviationBlockConfig::default()
},
)
.expect("score warp block");
let link_seed = array![-2.0, -0.5, 0.0, 0.5, 2.0];
let link_prepared = build_test_link_deviation_block_from_seed(
&link_seed,
&DeviationBlockConfig {
num_internal_knots: 4,
..DeviationBlockConfig::default()
},
)
.expect("link block");
let family =
BernoulliMarginalSlopeFamily {
y: Arc::new(y.clone()),
weights: Arc::new(weights.clone()),
z: Arc::new(z.clone()),
latent_measure: LatentMeasureKind::StandardNormal,
gaussian_frailty_sd: None,
base_link: bernoulli_marginal_slope_probit_link(),
marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
array![[1.0], [1.0], [1.0]],
)),
logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
array![[1.0], [1.0], [1.0]],
)),
score_warp: Some(score_prepared.runtime.clone()),
link_dev: Some(link_prepared.runtime.clone()),
policy: crate::resource::ResourcePolicy::default_library(),
cell_moment_lru: new_cell_moment_lru_cache(
&crate::resource::ResourcePolicy::default_library(),
),
cell_moment_cache_stats: new_cell_moment_cache_stats(),
intercept_warm_starts: None,
auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
auto_subsample_last_rho: Arc::new(Mutex::new(None)),
shared_eval_cache: Arc::new(Mutex::new(None)),
};
let block_states = vec![
ParameterBlockState {
beta: array![0.25],
eta: Array1::from_elem(z.len(), 0.25),
},
ParameterBlockState {
beta: array![0.6],
eta: Array1::from_elem(z.len(), 0.6),
},
ParameterBlockState {
beta: Array1::from_iter(
(0..score_prepared.block.design.ncols()).map(|idx| 0.015 * (idx as f64 + 1.0)),
),
eta: Array1::zeros(z.len()),
},
ParameterBlockState {
beta: Array1::from_iter(
(0..link_prepared.block.design.ncols()).map(|idx| 0.01 * (idx as f64 + 1.0)),
),
eta: Array1::zeros(z.len()),
},
];
let slices = block_slices(&family);
let total = slices.total;
let mut dir_u = Array1::<f64>::zeros(total);
let mut dir_v = Array1::<f64>::zeros(total);
let mut dir_w = Array1::<f64>::zeros(total);
dir_u[slices.marginal.start] = 0.7;
dir_u[slices.logslope.start] = -0.2;
let h_range = slices.h.as_ref().expect("h slice");
dir_u[h_range.start] = 0.1;
if h_range.len() > 1 {
dir_u[h_range.start + 1] = -0.05;
}
let w_range = slices.w.as_ref().expect("w slice");
dir_u[w_range.start] = 0.08;
dir_v[slices.marginal.start] = -0.4;
dir_v[slices.logslope.start] = 0.3;
dir_v[h_range.start] = -0.03;
dir_v[w_range.start] = 0.06;
if w_range.len() > 1 {
dir_v[w_range.start + 1] = -0.02;
}
dir_w[slices.marginal.start] = 0.11;
dir_w[slices.logslope.start] = -0.09;
dir_w[h_range.start] = 0.04;
dir_w[w_range.start] = -0.05;
let dir_sum = &dir_u + &dir_v;
let fourth_u = family
.exact_newton_joint_hessiansecond_directional_derivative(&block_states, &dir_u, &dir_w)
.expect("dual-flex fourth directional derivative u,w")
.expect("dual-flex fourth directional derivative u,w matrix");
let fourth_v = family
.exact_newton_joint_hessiansecond_directional_derivative(&block_states, &dir_v, &dir_w)
.expect("dual-flex fourth directional derivative v,w")
.expect("dual-flex fourth directional derivative v,w matrix");
let fourth_sum = family
.exact_newton_joint_hessiansecond_directional_derivative(
&block_states,
&dir_sum,
&dir_w,
)
.expect("dual-flex fourth directional derivative (u+v),w")
.expect("dual-flex fourth directional derivative (u+v),w matrix");
for i in 0..total {
for j in 0..total {
let expected = fourth_u[[i, j]] + fourth_v[[i, j]];
assert!(
(fourth_sum[[i, j]] - expected).abs() < 1e-8,
"fourth directional derivative should be linear in its first direction at ({i},{j})"
);
}
}
}
#[test]
fn h_only_exact_outer_third_direction_is_linear() {
    // The joint exact-Newton third directional derivative contracts the third
    // derivative tensor against a single direction, so it must satisfy
    // T(u + v) == T(u) + T(v) entry-wise. Exercised with only the score-warp
    // ("h") flexible block active.
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let warp = build_score_warp_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build score-warp block");
    let score_dim = warp
        .block
        .initial_beta
        .as_ref()
        .expect("score-warp initial beta")
        .len();
    let n = seed.len();
    // Two zeroed scalar blocks plus a non-trivial score-warp coefficient ramp.
    let block_states = vec![
        dummy_block_state(array![0.0], n),
        dummy_block_state(array![0.0], n),
        dummy_block_state(
            Array1::from_iter((0..score_dim).map(|k| 0.04 * (k as f64 + 1.0))),
            n,
        ),
    ];
    // Empty marginal/logslope designs; only the score-warp runtime installed.
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(n)),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((n, 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((n, 0)),
        )),
        score_warp: Some(warp.runtime.clone()),
        link_dev: None,
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
        auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        auto_subsample_last_rho: Arc::new(Mutex::new(None)),
        shared_eval_cache: Arc::new(Mutex::new(None)),
    };
    let slices = block_slices(&family);
    let total = slices.total;
    let h_range = slices.h.as_ref().expect("h slice");
    // Build a direction supported only on the leading h-block coordinates;
    // the second entry is written only when the slice holds more than one.
    let h_dir = |first: f64, second: f64| {
        let mut dir = Array1::<f64>::zeros(total);
        dir[h_range.start] = first;
        if h_range.len() > 1 {
            dir[h_range.start + 1] = second;
        }
        dir
    };
    let dir_u = h_dir(0.12, -0.06);
    let dir_v = h_dir(0.07, 0.05);
    let third_at = |dir: &Array1<f64>, label: &str| {
        family
            .exact_newton_joint_hessian_directional_derivative(&block_states, dir)
            .unwrap_or_else(|e| panic!("h-only third directional derivative {label}: {e:?}"))
            .unwrap_or_else(|| panic!("h-only third directional derivative {label} matrix"))
    };
    let third_u = third_at(&dir_u, "u");
    let third_v = third_at(&dir_v, "v");
    let third_sum = third_at(&(&dir_u + &dir_v), "sum");
    // Entry-wise additivity over the full total x total matrix.
    for ((i, j), &actual) in third_sum.indexed_iter() {
        let expected = third_u[[i, j]] + third_v[[i, j]];
        assert!(
            (actual - expected).abs() < 1e-8,
            "h-only third directional derivative should be linear at ({i},{j})"
        );
    }
}
#[test]
fn w_only_exact_outer_third_direction_is_linear() {
    // Same linearity contract as the h-only variant, but with only the
    // link-deviation ("w") flexible block active: the third directional
    // derivative must satisfy T(u + v) == T(u) + T(v) entry-wise.
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let link = build_test_link_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build link deviation block");
    let link_dim = link
        .block
        .initial_beta
        .as_ref()
        .expect("link initial beta")
        .len();
    let n = seed.len();
    // Two zeroed scalar blocks plus a non-trivial link-deviation ramp.
    let block_states = vec![
        dummy_block_state(array![0.0], n),
        dummy_block_state(array![0.0], n),
        dummy_block_state(
            Array1::from_iter((0..link_dim).map(|k| 0.05 * (k as f64 + 1.0))),
            n,
        ),
    ];
    // Empty marginal/logslope designs; only the link-deviation runtime installed.
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(n)),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((n, 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((n, 0)),
        )),
        score_warp: None,
        link_dev: Some(link.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
        auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        auto_subsample_last_rho: Arc::new(Mutex::new(None)),
        shared_eval_cache: Arc::new(Mutex::new(None)),
    };
    let slices = block_slices(&family);
    let total = slices.total;
    let w_range = slices.w.as_ref().expect("w slice");
    // Direction supported only on the leading w-block coordinates; the second
    // entry is written only when the slice holds more than one coefficient.
    let w_dir = |first: f64, second: f64| {
        let mut dir = Array1::<f64>::zeros(total);
        dir[w_range.start] = first;
        if w_range.len() > 1 {
            dir[w_range.start + 1] = second;
        }
        dir
    };
    let dir_u = w_dir(0.15, -0.07);
    let dir_v = w_dir(0.09, 0.03);
    let third_at = |dir: &Array1<f64>, label: &str| {
        family
            .exact_newton_joint_hessian_directional_derivative(&block_states, dir)
            .unwrap_or_else(|e| panic!("w-only third directional derivative {label}: {e:?}"))
            .unwrap_or_else(|| panic!("w-only third directional derivative {label} matrix"))
    };
    let third_u = third_at(&dir_u, "u");
    let third_v = third_at(&dir_v, "v");
    let third_sum = third_at(&(&dir_u + &dir_v), "sum");
    // Entry-wise additivity over the full total x total matrix.
    for ((i, j), &actual) in third_sum.indexed_iter() {
        let expected = third_u[[i, j]] + third_v[[i, j]];
        assert!(
            (actual - expected).abs() < 1e-8,
            "w-only third directional derivative should be linear at ({i},{j})"
        );
    }
}
// Sign rules for the h-only directional derivatives: the third directional
// derivative is contracted against ONE direction, so it must be odd,
// T(-d) == -T(d); the fourth is contracted against TWO directions, so flipping
// both signs must leave it unchanged, D4(-d, -d) == D4(d, d).
#[test]
fn h_only_exact_outer_direction_sign_rules_hold() {
let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
// Score-warp ("h") block only; link_dev stays None below.
let prepared = build_score_warp_deviation_block_from_seed(
&seed,
&DeviationBlockConfig {
num_internal_knots: 4,
..DeviationBlockConfig::default()
},
)
.expect("build score-warp block");
let score_dim = prepared
.block
.initial_beta
.as_ref()
.expect("score-warp initial beta")
.len();
let block_states = vec![
dummy_block_state(array![0.0], seed.len()),
dummy_block_state(array![0.0], seed.len()),
dummy_block_state(
Array1::from_iter((0..score_dim).map(|idx| 0.04 * (idx as f64 + 1.0))),
seed.len(),
),
];
// Empty (0-column) marginal/logslope designs: only intercept blocks plus
// the score-warp coefficients are in play.
let family = BernoulliMarginalSlopeFamily {
y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
weights: Arc::new(Array1::ones(seed.len())),
z: Arc::new(seed.clone()),
latent_measure: LatentMeasureKind::StandardNormal,
gaussian_frailty_sd: None,
base_link: bernoulli_marginal_slope_probit_link(),
marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
Array2::zeros((seed.len(), 0)),
)),
logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
Array2::zeros((seed.len(), 0)),
)),
score_warp: Some(prepared.runtime.clone()),
link_dev: None,
policy: crate::resource::ResourcePolicy::default_library(),
cell_moment_lru: new_cell_moment_lru_cache(
&crate::resource::ResourcePolicy::default_library(),
),
cell_moment_cache_stats: new_cell_moment_cache_stats(),
intercept_warm_starts: None,
auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
auto_subsample_last_rho: Arc::new(Mutex::new(None)),
shared_eval_cache: Arc::new(Mutex::new(None)),
};
let slices = block_slices(&family);
let total = slices.total;
// Single test direction on the h-block coordinates, plus its negation.
let mut dir = Array1::<f64>::zeros(total);
let h_range = slices.h.as_ref().expect("h slice");
dir[h_range.start] = 0.12;
if h_range.len() > 1 {
dir[h_range.start + 1] = -0.06;
}
let neg_dir = dir.mapv(|value| -value);
let third = family
.exact_newton_joint_hessian_directional_derivative(&block_states, &dir)
.expect("h-only third directional derivative")
.expect("h-only third directional derivative matrix");
let third_neg = family
.exact_newton_joint_hessian_directional_derivative(&block_states, &neg_dir)
.expect("h-only negated third directional derivative")
.expect("h-only negated third directional derivative matrix");
let fourth = family
.exact_newton_joint_hessiansecond_directional_derivative(&block_states, &dir, &dir)
.expect("h-only fourth directional derivative")
.expect("h-only fourth directional derivative matrix");
let fourth_neg = family
.exact_newton_joint_hessiansecond_directional_derivative(
&block_states,
&neg_dir,
&neg_dir,
)
.expect("h-only doubly-negated fourth directional derivative")
.expect("h-only doubly-negated fourth directional derivative matrix");
for i in 0..total {
for j in 0..total {
// Odd in a single direction: T(-d) + T(d) == 0.
assert!(
(third_neg[[i, j]] + third[[i, j]]).abs() < 1e-8,
"h-only third directional derivative should be odd at ({i},{j})"
);
// Even under flipping both directions: D4(-d,-d) == D4(d,d).
assert!(
(fourth_neg[[i, j]] - fourth[[i, j]]).abs() < 1e-8,
"h-only fourth directional derivative should be invariant under flipping both directions at ({i},{j})"
);
}
}
}
// Sign rules for the w-only (link-deviation) directional derivatives: the
// third directional derivative must be odd in its single direction,
// T(-d) == -T(d), and the fourth must be unchanged when both of its direction
// arguments are negated, D4(-d, -d) == D4(d, d).
#[test]
fn w_only_exact_outer_direction_sign_rules_hold() {
let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
// Link-deviation ("w") block only; score_warp stays None below.
let prepared = build_test_link_deviation_block_from_seed(
&seed,
&DeviationBlockConfig {
num_internal_knots: 4,
..DeviationBlockConfig::default()
},
)
.expect("build link deviation block");
let link_dim = prepared
.block
.initial_beta
.as_ref()
.expect("link initial beta")
.len();
let block_states = vec![
dummy_block_state(array![0.0], seed.len()),
dummy_block_state(array![0.0], seed.len()),
dummy_block_state(
Array1::from_iter((0..link_dim).map(|idx| 0.05 * (idx as f64 + 1.0))),
seed.len(),
),
];
// Empty (0-column) marginal/logslope designs: only intercept blocks plus
// the link-deviation coefficients are in play.
let family = BernoulliMarginalSlopeFamily {
y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
weights: Arc::new(Array1::ones(seed.len())),
z: Arc::new(seed.clone()),
latent_measure: LatentMeasureKind::StandardNormal,
gaussian_frailty_sd: None,
base_link: bernoulli_marginal_slope_probit_link(),
marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
Array2::zeros((seed.len(), 0)),
)),
logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
Array2::zeros((seed.len(), 0)),
)),
score_warp: None,
link_dev: Some(prepared.runtime.clone()),
policy: crate::resource::ResourcePolicy::default_library(),
cell_moment_lru: new_cell_moment_lru_cache(
&crate::resource::ResourcePolicy::default_library(),
),
cell_moment_cache_stats: new_cell_moment_cache_stats(),
intercept_warm_starts: None,
auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
auto_subsample_last_rho: Arc::new(Mutex::new(None)),
shared_eval_cache: Arc::new(Mutex::new(None)),
};
let slices = block_slices(&family);
let total = slices.total;
// Single test direction on the w-block coordinates, plus its negation.
let mut dir = Array1::<f64>::zeros(total);
let w_range = slices.w.as_ref().expect("w slice");
dir[w_range.start] = 0.15;
if w_range.len() > 1 {
dir[w_range.start + 1] = -0.07;
}
let neg_dir = dir.mapv(|value| -value);
let third = family
.exact_newton_joint_hessian_directional_derivative(&block_states, &dir)
.expect("w-only third directional derivative")
.expect("w-only third directional derivative matrix");
let third_neg = family
.exact_newton_joint_hessian_directional_derivative(&block_states, &neg_dir)
.expect("w-only negated third directional derivative")
.expect("w-only negated third directional derivative matrix");
let fourth = family
.exact_newton_joint_hessiansecond_directional_derivative(&block_states, &dir, &dir)
.expect("w-only fourth directional derivative")
.expect("w-only fourth directional derivative matrix");
let fourth_neg = family
.exact_newton_joint_hessiansecond_directional_derivative(
&block_states,
&neg_dir,
&neg_dir,
)
.expect("w-only doubly-negated fourth directional derivative")
.expect("w-only doubly-negated fourth directional derivative matrix");
for i in 0..total {
for j in 0..total {
// Odd in a single direction: T(-d) + T(d) == 0.
assert!(
(third_neg[[i, j]] + third[[i, j]]).abs() < 1e-8,
"w-only third directional derivative should be odd at ({i},{j})"
);
// Even under flipping both directions: D4(-d,-d) == D4(d,d).
assert!(
(fourth_neg[[i, j]] - fourth[[i, j]]).abs() < 1e-8,
"w-only fourth directional derivative should be invariant under flipping both directions at ({i},{j})"
);
}
}
}
// The fourth-order directional derivative takes two direction arguments;
// since mixed higher-order derivatives commute, swapping the two directions
// must yield the same matrix: D4(u, v) == D4(v, u). Checked h-only.
#[test]
fn h_only_exact_outer_fourth_direction_swap_is_symmetric() {
let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
// Score-warp ("h") block only; link_dev stays None below.
let prepared = build_score_warp_deviation_block_from_seed(
&seed,
&DeviationBlockConfig {
num_internal_knots: 4,
..DeviationBlockConfig::default()
},
)
.expect("build score-warp block");
let score_dim = prepared
.block
.initial_beta
.as_ref()
.expect("score-warp initial beta")
.len();
let block_states = vec![
dummy_block_state(array![0.0], seed.len()),
dummy_block_state(array![0.0], seed.len()),
dummy_block_state(
Array1::from_iter((0..score_dim).map(|idx| 0.04 * (idx as f64 + 1.0))),
seed.len(),
),
];
// Empty (0-column) marginal/logslope designs, score-warp runtime installed.
let family = BernoulliMarginalSlopeFamily {
y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
weights: Arc::new(Array1::ones(seed.len())),
z: Arc::new(seed.clone()),
latent_measure: LatentMeasureKind::StandardNormal,
gaussian_frailty_sd: None,
base_link: bernoulli_marginal_slope_probit_link(),
marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
Array2::zeros((seed.len(), 0)),
)),
logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
Array2::zeros((seed.len(), 0)),
)),
score_warp: Some(prepared.runtime.clone()),
link_dev: None,
policy: crate::resource::ResourcePolicy::default_library(),
cell_moment_lru: new_cell_moment_lru_cache(
&crate::resource::ResourcePolicy::default_library(),
),
cell_moment_cache_stats: new_cell_moment_cache_stats(),
intercept_warm_starts: None,
auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
auto_subsample_last_rho: Arc::new(Mutex::new(None)),
shared_eval_cache: Arc::new(Mutex::new(None)),
};
let slices = block_slices(&family);
let total = slices.total;
// Two distinct directions on the h-block coordinates.
let mut dir_u = Array1::<f64>::zeros(total);
let mut dir_v = Array1::<f64>::zeros(total);
let h_range = slices.h.as_ref().expect("h slice");
dir_u[h_range.start] = 0.12;
if h_range.len() > 1 {
dir_u[h_range.start + 1] = -0.06;
}
dir_v[h_range.start] = 0.07;
if h_range.len() > 1 {
dir_v[h_range.start + 1] = 0.05;
}
let forward = family
.exact_newton_joint_hessiansecond_directional_derivative(&block_states, &dir_u, &dir_v)
.expect("h-only fourth directional derivative")
.expect("h-only fourth directional derivative matrix");
let swapped = family
.exact_newton_joint_hessiansecond_directional_derivative(&block_states, &dir_v, &dir_u)
.expect("h-only swapped fourth directional derivative")
.expect("h-only swapped fourth directional derivative matrix");
for i in 0..total {
for j in 0..total {
assert!(
(forward[[i, j]] - swapped[[i, j]]).abs() < 1e-8,
"h-only fourth directional derivative should be symmetric in direction arguments at ({i},{j})"
);
}
}
}
// Same direction-swap symmetry as the h-only variant, but with only the
// link-deviation ("w") block active: D4(u, v) == D4(v, u) entry-wise.
#[test]
fn w_only_exact_outer_fourth_direction_swap_is_symmetric() {
let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
// Link-deviation ("w") block only; score_warp stays None below.
let prepared = build_test_link_deviation_block_from_seed(
&seed,
&DeviationBlockConfig {
num_internal_knots: 4,
..DeviationBlockConfig::default()
},
)
.expect("build link deviation block");
let link_dim = prepared
.block
.initial_beta
.as_ref()
.expect("link initial beta")
.len();
let block_states = vec![
dummy_block_state(array![0.0], seed.len()),
dummy_block_state(array![0.0], seed.len()),
dummy_block_state(
Array1::from_iter((0..link_dim).map(|idx| 0.05 * (idx as f64 + 1.0))),
seed.len(),
),
];
// Empty (0-column) marginal/logslope designs, link-deviation runtime installed.
let family = BernoulliMarginalSlopeFamily {
y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
weights: Arc::new(Array1::ones(seed.len())),
z: Arc::new(seed.clone()),
latent_measure: LatentMeasureKind::StandardNormal,
gaussian_frailty_sd: None,
base_link: bernoulli_marginal_slope_probit_link(),
marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
Array2::zeros((seed.len(), 0)),
)),
logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
Array2::zeros((seed.len(), 0)),
)),
score_warp: None,
link_dev: Some(prepared.runtime.clone()),
policy: crate::resource::ResourcePolicy::default_library(),
cell_moment_lru: new_cell_moment_lru_cache(
&crate::resource::ResourcePolicy::default_library(),
),
cell_moment_cache_stats: new_cell_moment_cache_stats(),
intercept_warm_starts: None,
auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
auto_subsample_last_rho: Arc::new(Mutex::new(None)),
shared_eval_cache: Arc::new(Mutex::new(None)),
};
let slices = block_slices(&family);
let total = slices.total;
// Two distinct directions on the w-block coordinates.
let mut dir_u = Array1::<f64>::zeros(total);
let mut dir_v = Array1::<f64>::zeros(total);
let w_range = slices.w.as_ref().expect("w slice");
dir_u[w_range.start] = 0.15;
if w_range.len() > 1 {
dir_u[w_range.start + 1] = -0.07;
}
dir_v[w_range.start] = 0.09;
if w_range.len() > 1 {
dir_v[w_range.start + 1] = 0.03;
}
let forward = family
.exact_newton_joint_hessiansecond_directional_derivative(&block_states, &dir_u, &dir_v)
.expect("w-only fourth directional derivative")
.expect("w-only fourth directional derivative matrix");
let swapped = family
.exact_newton_joint_hessiansecond_directional_derivative(&block_states, &dir_v, &dir_u)
.expect("w-only swapped fourth directional derivative")
.expect("w-only swapped fourth directional derivative matrix");
for i in 0..total {
for j in 0..total {
assert!(
(forward[[i, j]] - swapped[[i, j]]).abs() < 1e-8,
"w-only fourth directional derivative should be symmetric in direction arguments at ({i},{j})"
);
}
}
}
#[test]
fn h_only_exact_outer_zero_direction_returns_zero() {
    // The directional contractions are homogeneous in their direction
    // arguments, so a zero direction must yield exactly-zero third and
    // fourth directional derivatives — every entry is compared to 0 with
    // no tolerance.
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let warp = build_score_warp_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build score-warp block");
    let score_dim = warp
        .block
        .initial_beta
        .as_ref()
        .expect("score-warp initial beta")
        .len();
    let n = seed.len();
    let block_states = vec![
        dummy_block_state(array![0.0], n),
        dummy_block_state(array![0.0], n),
        dummy_block_state(
            Array1::from_iter((0..score_dim).map(|k| 0.04 * (k as f64 + 1.0))),
            n,
        ),
    ];
    // Empty marginal/logslope designs; only the score-warp runtime installed.
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(n)),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((n, 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((n, 0)),
        )),
        score_warp: Some(warp.runtime.clone()),
        link_dev: None,
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
        auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        auto_subsample_last_rho: Arc::new(Mutex::new(None)),
        shared_eval_cache: Arc::new(Mutex::new(None)),
    };
    let slices = block_slices(&family);
    let zero = Array1::<f64>::zeros(slices.total);
    let third = family
        .exact_newton_joint_hessian_directional_derivative(&block_states, &zero)
        .expect("h-only third directional derivative")
        .expect("h-only third directional derivative matrix");
    let fourth = family
        .exact_newton_joint_hessiansecond_directional_derivative(&block_states, &zero, &zero)
        .expect("h-only fourth directional derivative")
        .expect("h-only fourth directional derivative matrix");
    // Exact-zero checks (equivalent to `value.abs() <= 0.0`, incl. ±0.0).
    assert!(
        third.iter().all(|value| *value == 0.0),
        "expected zero h-only third directional derivative for zero direction"
    );
    assert!(
        fourth.iter().all(|value| *value == 0.0),
        "expected zero h-only fourth directional derivative for zero directions"
    );
}
#[test]
fn w_only_exact_outer_zero_direction_returns_zero() {
    // Mirror of the h-only zero-direction test for the link-deviation ("w")
    // block: a zero direction must produce exactly-zero third and fourth
    // directional derivatives (no tolerance).
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let link = build_test_link_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build link deviation block");
    let link_dim = link
        .block
        .initial_beta
        .as_ref()
        .expect("link initial beta")
        .len();
    let n = seed.len();
    let block_states = vec![
        dummy_block_state(array![0.0], n),
        dummy_block_state(array![0.0], n),
        dummy_block_state(
            Array1::from_iter((0..link_dim).map(|k| 0.05 * (k as f64 + 1.0))),
            n,
        ),
    ];
    // Empty marginal/logslope designs; only the link-deviation runtime installed.
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(n)),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((n, 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((n, 0)),
        )),
        score_warp: None,
        link_dev: Some(link.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
        auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        auto_subsample_last_rho: Arc::new(Mutex::new(None)),
        shared_eval_cache: Arc::new(Mutex::new(None)),
    };
    let slices = block_slices(&family);
    let zero = Array1::<f64>::zeros(slices.total);
    let third = family
        .exact_newton_joint_hessian_directional_derivative(&block_states, &zero)
        .expect("w-only third directional derivative")
        .expect("w-only third directional derivative matrix");
    let fourth = family
        .exact_newton_joint_hessiansecond_directional_derivative(&block_states, &zero, &zero)
        .expect("w-only fourth directional derivative")
        .expect("w-only fourth directional derivative matrix");
    // Exact-zero checks (equivalent to `value.abs() <= 0.0`, incl. ±0.0).
    assert!(
        third.iter().all(|value| *value == 0.0),
        "expected zero w-only third directional derivative for zero direction"
    );
    assert!(
        fourth.iter().all(|value| *value == 0.0),
        "expected zero w-only fourth directional derivative for zero directions"
    );
}
// Finite-difference check of the analytic per-block gradients: the gradient
// entries reported by `evaluate` (BlockWorkingSet::ExactNewton) for the
// marginal intercept (q), logslope intercept (b), and the first link-deviation
// coefficient (w0) must match central differences of `log_likelihood_only`.
#[test]
fn w_only_gradient_matches_loglik_finite_differences() {
let z = array![-0.8, 0.2, 1.1];
let y = array![0.0, 1.0, 1.0];
let weights = array![1.0, 0.7, 1.3];
// Link-deviation block only (score_warp stays None below).
let link_seed = array![-2.0, -0.5, 0.0, 0.5, 2.0];
let link_prepared = build_test_link_deviation_block_from_seed(
&link_seed,
&DeviationBlockConfig {
num_internal_knots: 4,
..DeviationBlockConfig::default()
},
)
.expect("link block");
// Intercept-only (single column of ones) marginal and logslope designs.
let family =
BernoulliMarginalSlopeFamily {
y: Arc::new(y.clone()),
weights: Arc::new(weights.clone()),
z: Arc::new(z.clone()),
latent_measure: LatentMeasureKind::StandardNormal,
gaussian_frailty_sd: None,
base_link: bernoulli_marginal_slope_probit_link(),
marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
array![[1.0], [1.0], [1.0]],
)),
logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
array![[1.0], [1.0], [1.0]],
)),
score_warp: None,
link_dev: Some(link_prepared.runtime.clone()),
policy: crate::resource::ResourcePolicy::default_library(),
cell_moment_lru: new_cell_moment_lru_cache(
&crate::resource::ResourcePolicy::default_library(),
),
cell_moment_cache_stats: new_cell_moment_cache_stats(),
intercept_warm_starts: None,
auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
auto_subsample_last_rho: Arc::new(Mutex::new(None)),
shared_eval_cache: Arc::new(Mutex::new(None)),
};
// Baseline link-deviation coefficients: a small ramp 0.01, 0.02, ...
let beta_w = Array1::from_iter(
(0..link_prepared.block.design.ncols()).map(|idx| 0.01 * (idx as f64 + 1.0)),
);
// Builder for the three-block state at given (q, b, beta_w); the scalar
// blocks broadcast their value into eta, the w block's eta starts at zero.
let states_at = |q: f64, b: f64, bw: Array1<f64>| {
vec![
ParameterBlockState {
beta: array![q],
eta: Array1::from_elem(z.len(), q),
},
ParameterBlockState {
beta: array![b],
eta: Array1::from_elem(z.len(), b),
},
ParameterBlockState {
beta: bw,
eta: Array1::zeros(z.len()),
},
]
};
let q0 = 0.25;
let b0 = 0.6;
let block_states = states_at(q0, b0, beta_w.clone());
let eval = family.evaluate(&block_states).expect("family evaluation");
// Pull the first analytic gradient entry out of each block's working set.
let grad_q = match &eval.blockworking_sets[0] {
BlockWorkingSet::ExactNewton { gradient, .. } => gradient[0],
_ => panic!("expected exact-newton q block"),
};
let grad_b = match &eval.blockworking_sets[1] {
BlockWorkingSet::ExactNewton { gradient, .. } => gradient[0],
_ => panic!("expected exact-newton b block"),
};
let grad_w0 = match &eval.blockworking_sets[2] {
BlockWorkingSet::ExactNewton { gradient, .. } => gradient[0],
_ => panic!("expected exact-newton w block"),
};
// Central finite difference of the log-likelihood with respect to the
// selected coordinate ("q", "b", or the first w coefficient "w0").
let fd = |which: &str, eps: f64| match which {
"q" => {
let plus = family
.log_likelihood_only(&states_at(q0 + eps, b0, beta_w.clone()))
.expect("ll plus q");
let minus = family
.log_likelihood_only(&states_at(q0 - eps, b0, beta_w.clone()))
.expect("ll minus q");
(plus - minus) / (2.0 * eps)
}
"b" => {
let plus = family
.log_likelihood_only(&states_at(q0, b0 + eps, beta_w.clone()))
.expect("ll plus b");
let minus = family
.log_likelihood_only(&states_at(q0, b0 - eps, beta_w.clone()))
.expect("ll minus b");
(plus - minus) / (2.0 * eps)
}
"w0" => {
// Perturb only the first link-deviation coefficient.
let mut plus_w = beta_w.clone();
plus_w[0] += eps;
let mut minus_w = beta_w.clone();
minus_w[0] -= eps;
let plus = family
.log_likelihood_only(&states_at(q0, b0, plus_w))
.expect("ll plus w");
let minus = family
.log_likelihood_only(&states_at(q0, b0, minus_w))
.expect("ll minus w");
(plus - minus) / (2.0 * eps)
}
_ => panic!("unknown derivative target"),
};
// Tolerance 2e-4 absorbs the O(eps^2) central-difference truncation error.
let eps = 1e-5;
assert!((grad_q - fd("q", eps)).abs() < 2e-4);
assert!((grad_b - fd("b", eps)).abs() < 2e-4);
assert!((grad_w0 - fd("w0", eps)).abs() < 2e-4);
}
// Finite-difference oracle for the score-warp-only ("h") configuration:
// link_dev is None, so the family exposes three coefficient blocks —
// marginal intercept q, log-slope b, and the score-warp deviation vector h.
// The analytic exact-Newton gradients must match central differences of the
// log-likelihood in each block's leading coordinate.
#[test]
fn h_only_gradient_matches_loglik_finite_differences() {
// Three weighted Bernoulli observations with distinct latent scores.
let z = array![-0.8, 0.2, 1.1];
let y = array![0.0, 1.0, 1.0];
let weights = array![1.0, 0.7, 1.3];
let score_prepared = build_score_warp_deviation_block_from_seed(
&z,
&DeviationBlockConfig {
num_internal_knots: 4,
..DeviationBlockConfig::default()
},
)
.expect("score warp block");
let family =
BernoulliMarginalSlopeFamily {
y: Arc::new(y.clone()),
weights: Arc::new(weights.clone()),
z: Arc::new(z.clone()),
latent_measure: LatentMeasureKind::StandardNormal,
gaussian_frailty_sd: None,
base_link: bernoulli_marginal_slope_probit_link(),
// Intercept-only designs keep the q and b blocks scalar.
marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
array![[1.0], [1.0], [1.0]],
)),
logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
array![[1.0], [1.0], [1.0]],
)),
score_warp: Some(score_prepared.runtime.clone()),
// No link-deviation block: this test isolates the h block.
link_dev: None,
policy: crate::resource::ResourcePolicy::default_library(),
cell_moment_lru: new_cell_moment_lru_cache(
&crate::resource::ResourcePolicy::default_library(),
),
cell_moment_cache_stats: new_cell_moment_cache_stats(),
intercept_warm_starts: None,
auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
auto_subsample_last_rho: Arc::new(Mutex::new(None)),
shared_eval_cache: Arc::new(Mutex::new(None)),
};
// Small nonzero coefficients so the score-warp contribution is active.
let beta_h = Array1::from_iter(
(0..score_prepared.block.design.ncols()).map(|idx| 0.015 * (idx as f64 + 1.0)),
);
// Assembles the (q, b, h) block states for scalar q, b and vector h.
let states_at = |q: f64, b: f64, bh: Array1<f64>| {
vec![
ParameterBlockState {
beta: array![q],
eta: Array1::from_elem(z.len(), q),
},
ParameterBlockState {
beta: array![b],
eta: Array1::from_elem(z.len(), b),
},
ParameterBlockState {
beta: bh,
eta: Array1::zeros(z.len()),
},
]
};
let q0 = 0.25;
let b0 = 0.6;
let block_states = states_at(q0, b0, beta_h.clone());
let eval = family.evaluate(&block_states).expect("family evaluation");
// Analytic gradients: leading coordinate of each exact-Newton block.
let grad_q = match &eval.blockworking_sets[0] {
BlockWorkingSet::ExactNewton { gradient, .. } => gradient[0],
_ => panic!("expected exact-newton q block"),
};
let grad_b = match &eval.blockworking_sets[1] {
BlockWorkingSet::ExactNewton { gradient, .. } => gradient[0],
_ => panic!("expected exact-newton b block"),
};
let grad_h0 = match &eval.blockworking_sets[2] {
BlockWorkingSet::ExactNewton { gradient, .. } => gradient[0],
_ => panic!("expected exact-newton h block"),
};
// Central finite difference of the log-likelihood in one coordinate.
let fd = |which: &str, eps: f64| match which {
"q" => {
let plus = family
.log_likelihood_only(&states_at(q0 + eps, b0, beta_h.clone()))
.expect("ll plus q");
let minus = family
.log_likelihood_only(&states_at(q0 - eps, b0, beta_h.clone()))
.expect("ll minus q");
(plus - minus) / (2.0 * eps)
}
"b" => {
let plus = family
.log_likelihood_only(&states_at(q0, b0 + eps, beta_h.clone()))
.expect("ll plus b");
let minus = family
.log_likelihood_only(&states_at(q0, b0 - eps, beta_h.clone()))
.expect("ll minus b");
(plus - minus) / (2.0 * eps)
}
"h0" => {
// Perturb only the first score-warp coefficient.
let mut plus_h = beta_h.clone();
plus_h[0] += eps;
let mut minus_h = beta_h.clone();
minus_h[0] -= eps;
let plus = family
.log_likelihood_only(&states_at(q0, b0, plus_h))
.expect("ll plus h");
let minus = family
.log_likelihood_only(&states_at(q0, b0, minus_h))
.expect("ll minus h");
(plus - minus) / (2.0 * eps)
}
_ => panic!("unknown derivative target"),
};
// 2e-4 tolerance absorbs the O(eps^2) truncation error of the central FD.
let eps = 1e-5;
assert!((grad_q - fd("q", eps)).abs() < 2e-4);
assert!((grad_b - fd("b", eps)).abs() < 2e-4);
assert!((grad_h0 - fd("h0", eps)).abs() < 2e-4);
}
// Finite-difference oracle with BOTH flexible blocks enabled (score-warp h
// and link-deviation w): four coefficient blocks (q, b, h, w). Each analytic
// exact-Newton gradient's leading coordinate must match a central finite
// difference of the log-likelihood.
#[test]
fn flexible_denested_gradient_matches_loglik_finite_differences() {
// Three weighted Bernoulli observations with distinct latent scores.
let z = array![-0.8, 0.2, 1.1];
let y = array![0.0, 1.0, 1.0];
let weights = array![1.0, 0.7, 1.3];
let score_prepared = build_score_warp_deviation_block_from_seed(
&z,
&DeviationBlockConfig {
num_internal_knots: 4,
..DeviationBlockConfig::default()
},
)
.expect("score warp block");
// Seed points spanning the link axis for the link-deviation spline.
let link_seed = array![-2.0, -0.5, 0.0, 0.5, 2.0];
let link_prepared = build_test_link_deviation_block_from_seed(
&link_seed,
&DeviationBlockConfig {
num_internal_knots: 4,
..DeviationBlockConfig::default()
},
)
.expect("link block");
let family =
BernoulliMarginalSlopeFamily {
y: Arc::new(y.clone()),
weights: Arc::new(weights.clone()),
z: Arc::new(z.clone()),
latent_measure: LatentMeasureKind::StandardNormal,
gaussian_frailty_sd: None,
base_link: bernoulli_marginal_slope_probit_link(),
// Intercept-only designs keep the q and b blocks scalar.
marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
array![[1.0], [1.0], [1.0]],
)),
logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
array![[1.0], [1.0], [1.0]],
)),
// Both flexible blocks active in this test.
score_warp: Some(score_prepared.runtime.clone()),
link_dev: Some(link_prepared.runtime.clone()),
policy: crate::resource::ResourcePolicy::default_library(),
cell_moment_lru: new_cell_moment_lru_cache(
&crate::resource::ResourcePolicy::default_library(),
),
cell_moment_cache_stats: new_cell_moment_cache_stats(),
intercept_warm_starts: None,
auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
auto_subsample_last_rho: Arc::new(Mutex::new(None)),
shared_eval_cache: Arc::new(Mutex::new(None)),
};
// Small nonzero coefficients so both flexible contributions are active.
let beta_h = Array1::from_iter(
(0..score_prepared.block.design.ncols()).map(|idx| 0.015 * (idx as f64 + 1.0)),
);
let beta_w = Array1::from_iter(
(0..link_prepared.block.design.ncols()).map(|idx| 0.01 * (idx as f64 + 1.0)),
);
// Assembles the (q, b, h, w) block states.
let states_at = |q: f64, b: f64, bh: Array1<f64>, bw: Array1<f64>| {
vec![
ParameterBlockState {
beta: array![q],
eta: Array1::from_elem(z.len(), q),
},
ParameterBlockState {
beta: array![b],
eta: Array1::from_elem(z.len(), b),
},
ParameterBlockState {
beta: bh,
eta: Array1::zeros(z.len()),
},
ParameterBlockState {
beta: bw,
eta: Array1::zeros(z.len()),
},
]
};
let q0 = 0.25;
let b0 = 0.6;
let block_states = states_at(q0, b0, beta_h.clone(), beta_w.clone());
let eval = family.evaluate(&block_states).expect("family evaluation");
// Analytic gradients: leading coordinate of each exact-Newton block.
let grad_q = match &eval.blockworking_sets[0] {
BlockWorkingSet::ExactNewton { gradient, .. } => gradient[0],
_ => panic!("expected exact-newton q block"),
};
let grad_b = match &eval.blockworking_sets[1] {
BlockWorkingSet::ExactNewton { gradient, .. } => gradient[0],
_ => panic!("expected exact-newton b block"),
};
let grad_h0 = match &eval.blockworking_sets[2] {
BlockWorkingSet::ExactNewton { gradient, .. } => gradient[0],
_ => panic!("expected exact-newton h block"),
};
let grad_w0 = match &eval.blockworking_sets[3] {
BlockWorkingSet::ExactNewton { gradient, .. } => gradient[0],
_ => panic!("expected exact-newton w block"),
};
// Central finite difference of the log-likelihood in one coordinate.
let fd = |which: &str, eps: f64| match which {
"q" => {
let plus = family
.log_likelihood_only(&states_at(q0 + eps, b0, beta_h.clone(), beta_w.clone()))
.expect("ll plus q");
let minus = family
.log_likelihood_only(&states_at(q0 - eps, b0, beta_h.clone(), beta_w.clone()))
.expect("ll minus q");
(plus - minus) / (2.0 * eps)
}
"b" => {
let plus = family
.log_likelihood_only(&states_at(q0, b0 + eps, beta_h.clone(), beta_w.clone()))
.expect("ll plus b");
let minus = family
.log_likelihood_only(&states_at(q0, b0 - eps, beta_h.clone(), beta_w.clone()))
.expect("ll minus b");
(plus - minus) / (2.0 * eps)
}
"h0" => {
// Perturb only the first score-warp coefficient.
let mut plus_h = beta_h.clone();
plus_h[0] += eps;
let mut minus_h = beta_h.clone();
minus_h[0] -= eps;
let plus = family
.log_likelihood_only(&states_at(q0, b0, plus_h, beta_w.clone()))
.expect("ll plus h");
let minus = family
.log_likelihood_only(&states_at(q0, b0, minus_h, beta_w.clone()))
.expect("ll minus h");
(plus - minus) / (2.0 * eps)
}
"w0" => {
// Perturb only the first link-deviation coefficient.
let mut plus_w = beta_w.clone();
plus_w[0] += eps;
let mut minus_w = beta_w.clone();
minus_w[0] -= eps;
let plus = family
.log_likelihood_only(&states_at(q0, b0, beta_h.clone(), plus_w))
.expect("ll plus w");
let minus = family
.log_likelihood_only(&states_at(q0, b0, beta_h.clone(), minus_w))
.expect("ll minus w");
(plus - minus) / (2.0 * eps)
}
_ => panic!("unknown derivative target"),
};
// 2e-4 tolerance absorbs the O(eps^2) truncation error of the central FD.
let eps = 1e-5;
assert!((grad_q - fd("q", eps)).abs() < 2e-4);
assert!((grad_b - fd("b", eps)).abs() < 2e-4);
assert!((grad_h0 - fd("h0", eps)).abs() < 2e-4);
assert!((grad_w0 - fd("w0", eps)).abs() < 2e-4);
}
// Structure/sanity test for the higher-order outer derivatives with BOTH
// flexible blocks (score-warp h and link-deviation w) enabled: the third and
// fourth directional derivatives of the joint Hessian must be returned
// (Some), have the full joint dimension, contain only finite values, be
// nonzero overall, and be symmetric.
#[test]
fn flexible_exact_outer_directional_derivatives_are_present_and_finite() {
let z = array![-0.8, 0.2, 1.1];
let y = array![0.0, 1.0, 1.0];
let weights = array![1.0, 0.7, 1.3];
let score_prepared = build_score_warp_deviation_block_from_seed(
&z,
&DeviationBlockConfig {
num_internal_knots: 4,
..DeviationBlockConfig::default()
},
)
.expect("score warp block");
let link_seed = array![-2.0, -0.5, 0.0, 0.5, 2.0];
let link_prepared = build_test_link_deviation_block_from_seed(
&link_seed,
&DeviationBlockConfig {
num_internal_knots: 4,
..DeviationBlockConfig::default()
},
)
.expect("link block");
let family =
BernoulliMarginalSlopeFamily {
y: Arc::new(y.clone()),
weights: Arc::new(weights.clone()),
z: Arc::new(z.clone()),
latent_measure: LatentMeasureKind::StandardNormal,
gaussian_frailty_sd: None,
base_link: bernoulli_marginal_slope_probit_link(),
marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
array![[1.0], [1.0], [1.0]],
)),
logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
array![[1.0], [1.0], [1.0]],
)),
// Both flexible blocks active so all cross-terms are exercised.
score_warp: Some(score_prepared.runtime.clone()),
link_dev: Some(link_prepared.runtime.clone()),
policy: crate::resource::ResourcePolicy::default_library(),
cell_moment_lru: new_cell_moment_lru_cache(
&crate::resource::ResourcePolicy::default_library(),
),
cell_moment_cache_stats: new_cell_moment_cache_stats(),
intercept_warm_starts: None,
auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
auto_subsample_last_rho: Arc::new(Mutex::new(None)),
shared_eval_cache: Arc::new(Mutex::new(None)),
};
let beta_h = Array1::from_iter(
(0..score_prepared.block.design.ncols()).map(|idx| 0.015 * (idx as f64 + 1.0)),
);
let beta_w = Array1::from_iter(
(0..link_prepared.block.design.ncols()).map(|idx| 0.01 * (idx as f64 + 1.0)),
);
let block_states = vec![
ParameterBlockState {
beta: array![0.25],
eta: Array1::from_elem(z.len(), 0.25),
},
ParameterBlockState {
beta: array![0.6],
eta: Array1::from_elem(z.len(), 0.6),
},
ParameterBlockState {
beta: beta_h.clone(),
eta: Array1::zeros(z.len()),
},
ParameterBlockState {
beta: beta_w.clone(),
eta: Array1::zeros(z.len()),
},
];
let slices = block_slices(&family);
let total = slices.total;
// Two distinct directions with nonzero components in every block, so the
// directional derivatives pick up all block/cross-block contributions.
let mut dir_u = Array1::<f64>::zeros(total);
let mut dir_v = Array1::<f64>::zeros(total);
dir_u[slices.marginal.start] = 0.7;
dir_u[slices.logslope.start] = -0.2;
if let Some(h_range) = slices.h.as_ref() {
dir_u[h_range.start] = 0.1;
if h_range.len() > 1 {
dir_u[h_range.start + 1] = -0.05;
}
}
if let Some(w_range) = slices.w.as_ref() {
dir_u[w_range.start] = 0.08;
}
dir_v[slices.marginal.start] = -0.4;
dir_v[slices.logslope.start] = 0.3;
if let Some(h_range) = slices.h.as_ref() {
dir_v[h_range.start] = -0.03;
}
if let Some(w_range) = slices.w.as_ref() {
dir_v[w_range.start] = 0.06;
if w_range.len() > 1 {
dir_v[w_range.start + 1] = -0.02;
}
}
// Third-order: derivative of the joint Hessian along dir_u.
let third = family
.exact_newton_joint_hessian_directional_derivative(&block_states, &dir_u)
.expect("flex third directional derivative")
.expect("flex third directional derivative matrix");
// Fourth-order: second derivative of the joint Hessian along (dir_u, dir_v).
let fourth = family
.exact_newton_joint_hessiansecond_directional_derivative(&block_states, &dir_u, &dir_v)
.expect("flex fourth directional derivative")
.expect("flex fourth directional derivative matrix");
assert_eq!(third.dim(), (total, total));
assert_eq!(fourth.dim(), (total, total));
assert!(third.iter().all(|value| value.is_finite()));
assert!(fourth.iter().all(|value| value.is_finite()));
let max_abs_third = third
.iter()
.fold(0.0_f64, |acc, value| acc.max(value.abs()));
let max_abs_fourth = fourth
.iter()
.fold(0.0_f64, |acc, value| acc.max(value.abs()));
assert!(
max_abs_third > 1e-10,
"expected nonzero dual-flex third directional derivative"
);
assert!(
max_abs_fourth > 1e-10,
"expected nonzero dual-flex fourth directional derivative"
);
// Both derivative matrices must be symmetric (they derive from a Hessian).
for i in 0..total {
for j in 0..i {
assert!((third[[i, j]] - third[[j, i]]).abs() < 1e-8);
assert!((fourth[[i, j]] - fourth[[j, i]]).abs() < 1e-8);
}
}
}
// Consistency oracle: the per-block gradients and Hessians returned by
// `evaluate` must equal the corresponding slices / diagonal sub-blocks of the
// joint exact-Newton gradient and Hessian, and all three log-likelihood
// evaluation paths must agree.
#[test]
fn flexible_evaluate_block_diagonals_match_joint_exact_oracle() {
// Four weighted Bernoulli observations; two-column designs so the q and b
// blocks are genuinely multi-dimensional.
let z = array![-1.1, -0.25, 0.35, 1.2];
let y = array![0.0, 1.0, 0.0, 1.0];
let weights = array![1.0, 0.8, 1.3, 0.7];
let score_prepared = build_score_warp_deviation_block_from_seed(
&z,
&DeviationBlockConfig {
num_internal_knots: 4,
..DeviationBlockConfig::default()
},
)
.expect("score warp block");
let link_seed = array![-1.8, -0.7, 0.1, 0.9, 1.7];
let link_prepared = build_test_link_deviation_block_from_seed(
&link_seed,
&DeviationBlockConfig {
num_internal_knots: 4,
..DeviationBlockConfig::default()
},
)
.expect("link block");
let marginal_x = array![[1.0, -0.4], [1.0, 0.2], [1.0, 0.7], [1.0, 1.1]];
let logslope_x = array![[1.0, 0.3], [1.0, -0.6], [1.0, 0.5], [1.0, -1.0]];
let family = BernoulliMarginalSlopeFamily {
y: Arc::new(y.clone()),
weights: Arc::new(weights.clone()),
z: Arc::new(z.clone()),
latent_measure: LatentMeasureKind::StandardNormal,
gaussian_frailty_sd: None,
base_link: bernoulli_marginal_slope_probit_link(),
marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
marginal_x.clone(),
)),
logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
logslope_x.clone(),
)),
// Both flexible blocks active so all four diagonals are checked.
score_warp: Some(score_prepared.runtime.clone()),
link_dev: Some(link_prepared.runtime.clone()),
policy: crate::resource::ResourcePolicy::default_library(),
cell_moment_lru: new_cell_moment_lru_cache(
&crate::resource::ResourcePolicy::default_library(),
),
cell_moment_cache_stats: new_cell_moment_cache_stats(),
intercept_warm_starts: None,
auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
auto_subsample_last_rho: Arc::new(Mutex::new(None)),
shared_eval_cache: Arc::new(Mutex::new(None)),
};
let beta_m = array![0.18, -0.07];
let beta_g = array![0.42, 0.05];
let beta_h = Array1::from_iter(
(0..score_prepared.block.design.ncols()).map(|idx| 0.006 * (idx as f64 + 1.0)),
);
let beta_w = Array1::from_iter(
(0..link_prepared.block.design.ncols()).map(|idx| -0.004 * (idx as f64 + 1.0)),
);
let block_states = vec![
ParameterBlockState {
beta: beta_m.clone(),
eta: marginal_x.dot(&beta_m),
},
ParameterBlockState {
beta: beta_g.clone(),
eta: logslope_x.dot(&beta_g),
},
ParameterBlockState {
beta: beta_h,
eta: Array1::zeros(z.len()),
},
ParameterBlockState {
beta: beta_w,
eta: Array1::zeros(z.len()),
},
];
let eval = family.evaluate(&block_states).expect("block evaluation");
// Joint-oracle quantities to compare each block against.
let joint_hessian = family
.exact_newton_joint_hessian(&block_states)
.expect("joint hessian result")
.expect("dense joint hessian");
let joint_gradient = family
.exact_newton_joint_gradient_evaluation(&block_states, &[])
.expect("joint gradient result")
.expect("joint gradient");
// Coefficient index ranges of the four blocks inside the joint vector.
let slices = block_slices(&family);
let ranges = vec![
slices.marginal.clone(),
slices.logslope.clone(),
slices.h.clone().expect("score-warp block"),
slices.w.clone().expect("link-deviation block"),
];
// All log-likelihood paths must agree to near machine precision.
assert!((eval.log_likelihood - joint_gradient.log_likelihood).abs() < 2e-12);
assert!(
(eval.log_likelihood
- family
.log_likelihood_only(&block_states)
.expect("log likelihood only"))
.abs()
< 2e-12
);
for (block_idx, range) in ranges.iter().enumerate() {
let (gradient, hessian) = match &eval.blockworking_sets[block_idx] {
BlockWorkingSet::ExactNewton { gradient, hessian } => (gradient, hessian),
_ => panic!("expected exact-newton block {block_idx}"),
};
// Block gradient == matching slice of the joint gradient.
let expected_gradient = joint_gradient.gradient.slice(s![range.clone()]);
for idx in 0..gradient.len() {
assert!(
(gradient[idx] - expected_gradient[idx]).abs() < 2e-10,
"gradient mismatch block {block_idx} idx {idx}: got {} expected {}",
gradient[idx],
expected_gradient[idx]
);
}
let dense_hessian = match hessian {
SymmetricMatrix::Dense(h) => h,
_ => panic!("expected dense hessian block {block_idx}"),
};
// Block Hessian == matching diagonal sub-block of the joint Hessian.
let expected_hessian = joint_hessian.slice(s![range.clone(), range.clone()]);
for i in 0..dense_hessian.nrows() {
for j in 0..dense_hessian.ncols() {
assert!(
(dense_hessian[[i, j]] - expected_hessian[[i, j]]).abs() < 2e-9,
"hessian mismatch block {block_idx} ({i},{j}): got {} expected {}",
dense_hessian[[i, j]],
expected_hessian[[i, j]]
);
}
}
}
}
// A finite sample of plausible Gaussian scores must standardize cleanly: the
// replayed normalization is identical to the standardized output, and the
// standardized values have mean 0 and standard deviation 1.
#[test]
fn latent_z_normalization_accepts_finite_sample_gaussian_scores() {
    let z = array![
        -0.85, -0.12, 0.31, 1.04, -1.21, 0.56, 0.77, -0.44, 1.33, -0.09, 0.28, -0.67
    ];
    let weights = Array1::from_elem(12, 1.0);
    let (standardized, normalization) = standardize_latent_z_with_policy(
        &z,
        &weights,
        "bernoulli-marginal-slope",
        &LatentZPolicy::exploratory_fit_weighted(),
    )
    .expect("normalize z");
    // Replaying the captured normalization on the raw scores must reproduce
    // the standardized output exactly.
    let replayed = normalization
        .apply(&z, "bernoulli-marginal-slope replay")
        .expect("replay normalized z");
    assert_eq!(replayed, standardized);
    let n = standardized.len() as f64;
    let mean = standardized.sum() / n;
    // Centered (population) variance. The previous version computed the
    // uncentered second moment under the misleading name `var`; with
    // |mean| < 1e-12 the two differ by < 1e-24, so the tolerance is
    // unaffected, but the quantity now matches its name.
    let variance = standardized
        .iter()
        .map(|v| (v - mean).powi(2))
        .sum::<f64>()
        / n;
    assert!(mean.abs() < 1e-12);
    assert!((variance.sqrt() - 1.0).abs() < 1e-12);
}
// Four zeros plus two extreme outliers are nowhere near an N(0,1) sample;
// under the strict check mode the standardizer must refuse them with a
// diagnostic mentioning the latent-normal requirement.
#[test]
fn latent_z_normalization_rejects_extreme_non_gaussian_scores() {
    let z = array![0.0, 0.0, 0.0, 0.0, 10.0, -10.0];
    let weights = Array1::from_elem(6, 1.0);
    // Strict mode: diagnostics failures become hard errors.
    let strict_policy = LatentZPolicy {
        check_mode: LatentZCheckMode::Strict,
        ..LatentZPolicy::default()
    };
    let failure = standardize_latent_z_with_policy(
        &z,
        &weights,
        "bernoulli-marginal-slope",
        &strict_policy,
    )
    .expect_err("expected non-gaussian rejection");
    assert!(failure.contains("approximately latent N(0,1)"));
}
// With Auto latent-measure selection and badly non-normal z diagnostics, the
// builder must fall back to rank inverse-normal (rank-INT) calibration while
// still targeting the standard-normal kernel, and the calibration table must
// be structurally valid (non-empty, strictly increasing knots, monotone CDF,
// sane post-transform moments).
#[test]
fn auto_latent_measure_uses_rank_int_calibration_for_bad_normal_diagnostics() {
// Deliberately pathological scores: point mass at 0 plus extreme outliers.
let z = array![0.0, 0.0, 0.0, 0.0, 10.0, -10.0];
let weights = Array1::from_elem(6, 1.0);
// Checks off and normalization disabled so the Auto measure selection is
// driven purely by the raw diagnostics.
let policy = LatentZPolicy {
check_mode: LatentZCheckMode::Off,
normalization: LatentZNormalizationMode::None,
latent_measure: LatentMeasureSpec::Auto { grid_size: 5 },
..LatentZPolicy::default()
};
let (measure, calibration) =
build_latent_measure_with_geometry(&z, &weights, &policy, None, &[])
.expect("auto latent measure");
assert!(
matches!(measure, LatentMeasureKind::StandardNormal),
"bad-normal latent z must route through rank-INT to the standard-normal kernel"
);
match calibration {
LatentMeasureCalibration::RankInverseNormal(cal) => {
assert!(
!cal.sorted_z.is_empty(),
"rank-INT knot table must be non-empty"
);
// Knots and CDF values are parallel arrays.
assert_eq!(cal.sorted_z.len(), cal.weighted_cdf.len());
for w in cal.sorted_z.windows(2) {
assert!(w[0] < w[1], "sorted_z must be strictly increasing");
}
for w in cal.weighted_cdf.windows(2) {
assert!(
w[0] <= w[1],
"weighted_cdf must be non-decreasing (got {} -> {})",
w[0],
w[1]
);
}
// Post-transform moments should be near standard-normal.
assert!(
cal.post_mean.abs() < 0.5,
"rank-INT post-mean too far from 0: {}",
cal.post_mean
);
assert!(
cal.post_sd > 0.0 && cal.post_sd.is_finite(),
"rank-INT post-sd must be positive finite, got {}",
cal.post_sd
);
}
LatentMeasureCalibration::None => {
panic!("bad-normal latent z must produce a RankInverseNormal calibration")
}
}
}
// The empirically-solved intercept must reproduce the requested marginal
// probability exactly when the probit mixture is re-evaluated over the same
// node/weight law.
#[test]
fn empirical_intercept_calibrates_marginal_probability() {
    let nodes = vec![-2.0, -0.25, 0.5, 3.0];
    let weights = vec![0.2, 0.3, 0.1, 0.4];
    let q_target = -0.35;
    let mu_target = normal_cdf(q_target);
    let beta_slope = 0.8;
    let frailty_scale = 0.9;
    let alpha = empirical_intercept_from_marginal(
        mu_target,
        q_target,
        beta_slope,
        frailty_scale,
        &nodes,
        &weights,
        None,
    )
    .expect("empirical intercept");
    // Re-evaluate the weighted probit mixture at the solved intercept.
    let mut mixture_mu = 0.0_f64;
    for (&node, &weight) in nodes.iter().zip(weights.iter()) {
        mixture_mu += weight * normal_cdf(alpha + frailty_scale * beta_slope * node);
    }
    assert!((mixture_mu - mu_target).abs() <= 1e-10);
}
// A heavily right-skewed weighted sample: compressing it onto a 7-node
// empirical grid must still let the empirical intercept calibrate the
// marginal probability exactly over the grid law.
#[test]
fn skewed_rigid_empirical_grid_calibrates_marginal_probability() {
    let z = array![-1.15, -1.05, -0.95, -0.8, -0.65, -0.45, 0.1, 0.9, 2.4, 4.7];
    let weights = array![1.0, 1.0, 1.0, 0.9, 0.9, 0.8, 0.5, 0.35, 0.2, 0.1];
    // Compress the raw weighted sample onto a small quadrature grid.
    let grid = build_empirical_z_grid(&z, &weights, 7, "test skewed grid").expect("grid");
    let q_target = 0.25;
    let mu_target = normal_cdf(q_target);
    let beta_slope = 1.35;
    let frailty_scale = 0.82;
    let alpha = empirical_intercept_from_marginal(
        mu_target,
        q_target,
        beta_slope,
        frailty_scale,
        &grid.nodes,
        &grid.weights,
        None,
    )
    .expect("empirical intercept");
    // Marginal probability implied by the solved intercept over the grid law.
    let mut mixture_mu = 0.0_f64;
    for (&node, &weight) in grid.nodes.iter().zip(grid.weights.iter()) {
        mixture_mu += weight * normal_cdf(alpha + frailty_scale * beta_slope * node);
    }
    assert!((mixture_mu - mu_target).abs() <= 1e-10);
}
// Under a skewed (non-Gaussian) mixing law, the closed-form Gaussian "rigid"
// intercept is measurably miscalibrated, while the empirical solver hits the
// target marginal exactly over the same node/weight law.
#[test]
fn gaussian_rigid_intercept_miscalibrates_skewed_empirical_law() {
    let nodes = vec![-0.95, -0.7, -0.45, -0.2, 0.4, 1.3, 3.1];
    let weights = vec![0.28, 0.22, 0.17, 0.13, 0.1, 0.07, 0.03];
    let q_target = -0.15;
    let mu_target = normal_cdf(q_target);
    let beta_slope = 1.4;
    let frailty_scale = 0.9;
    // Marginal probability implied by an intercept over this mixing law.
    let implied_mu = |alpha: f64| -> f64 {
        nodes
            .iter()
            .zip(weights.iter())
            .map(|(&node, &weight)| {
                weight * normal_cdf(alpha + frailty_scale * beta_slope * node)
            })
            .sum::<f64>()
    };
    let rigid = rigid_intercept_from_marginal(q_target, beta_slope, frailty_scale);
    assert!(
        (implied_mu(rigid) - mu_target).abs() > 1e-3,
        "skewed empirical law should not be calibrated by Gaussian identity"
    );
    let empirical = empirical_intercept_from_marginal(
        mu_target,
        q_target,
        beta_slope,
        frailty_scale,
        &nodes,
        &weights,
        None,
    )
    .expect("empirical intercept");
    assert!((implied_mu(empirical) - mu_target).abs() <= 1e-10);
}
// A badly stale warm start deep in the left tail must not trap the intercept
// solver: it still converges to the exactly-calibrating intercept.
#[test]
fn empirical_intercept_recovers_from_deep_tail_warm_start() {
    let nodes = vec![-2.0, -0.25, 0.5, 3.0];
    let weights = vec![0.2, 0.3, 0.1, 0.4];
    let target_q = -0.35;
    let target_mu = normal_cdf(target_q);
    let beta_slope = 0.8;
    let frailty_scale = 0.9;
    // Seed the solver far into the left tail, where gradients are tiny.
    let stale_seed = Some(-100.0_f64);
    let alpha = empirical_intercept_from_marginal(
        target_mu,
        target_q,
        beta_slope,
        frailty_scale,
        &nodes,
        &weights,
        stale_seed,
    )
    .expect("empirical intercept must recover from deep-tail warm start");
    // Re-evaluate the weighted probit mixture at the solved intercept.
    let mut calibrated = 0.0_f64;
    for (&node, &weight) in nodes.iter().zip(weights.iter()) {
        calibrated += weight * normal_cdf(alpha + frailty_scale * beta_slope * node);
    }
    assert!(
        (calibrated - target_mu).abs() <= 1e-10,
        "calibrated mu={calibrated} should match target_mu={target_mu} from any seed"
    );
}
// Mirror of the deep-tail test on the other side: a stale warm start far to
// the right must also not prevent convergence to the calibrating intercept.
#[test]
fn empirical_intercept_recovers_from_far_right_warm_start() {
    let nodes = vec![-2.0, -0.25, 0.5, 3.0];
    let weights = vec![0.2, 0.3, 0.1, 0.4];
    let target_q = -0.35;
    let target_mu = normal_cdf(target_q);
    let beta_slope = 0.8;
    let frailty_scale = 0.9;
    // Seed the solver far into the right tail.
    let stale_seed = Some(50.0_f64);
    let alpha = empirical_intercept_from_marginal(
        target_mu,
        target_q,
        beta_slope,
        frailty_scale,
        &nodes,
        &weights,
        stale_seed,
    )
    .expect("empirical intercept must recover from far-right warm start");
    // Re-evaluate the weighted probit mixture at the solved intercept.
    let mut calibrated = 0.0_f64;
    for (&node, &weight) in nodes.iter().zip(weights.iter()) {
        calibrated += weight * normal_cdf(alpha + frailty_scale * beta_slope * node);
    }
    assert!(
        (calibrated - target_mu).abs() <= 1e-10,
        "calibrated mu={calibrated} should match target_mu={target_mu} from any seed"
    );
}
// With a large, well-conditioned sample of exact normal quantiles, the Auto
// latent-measure selection must take the standard-normal fast path: no
// rank-INT calibration, and the rigid intercept keeps its closed-form
// Gaussian identity.
#[test]
fn auto_latent_measure_preserves_standard_normal_fast_path() {
// Midpoint-probability normal quantiles: an essentially perfect N(0,1)
// sample of size 2001.
let n = 2001usize;
let z = Array1::from_iter((0..n).map(|idx| {
let p = (idx as f64 + 0.5) / n as f64;
standard_normal_quantile(p).expect("normal quantile")
}));
let weights = Array1::ones(n);
let policy = LatentZPolicy {
latent_measure: LatentMeasureSpec::Auto { grid_size: 17 },
..LatentZPolicy::default()
};
let (measure, calibration) =
build_latent_measure_with_geometry(&z, &weights, &policy, None, &[]).expect("measure");
assert!(matches!(measure, LatentMeasureKind::StandardNormal));
assert!(
matches!(calibration, LatentMeasureCalibration::None),
"well-conditioned standard-normal z must skip rank-INT calibration"
);
let slope = 0.7;
let scale = 0.85;
let target_q = 0.4;
// NOTE(review): exact float equality here assumes
// rigid_intercept_from_marginal evaluates this exact closed-form expression
// q * sqrt(1 + (scale*slope)^2) — confirm if its implementation changes.
assert_eq!(
rigid_intercept_from_marginal(target_q, slope, scale),
target_q * (1.0 + (scale * slope).powi(2)).sqrt()
);
}
// With both flexible blocks (score-warp h and link-deviation w) enabled, the
// family must expose both exact-Newton joint workspaces: the Hessian
// workspace and the psi workspace must each be Some.
#[test]
fn flexible_family_exposes_exact_newton_workspaces() {
let z = array![-0.8, 0.2, 1.1];
let y = array![0.0, 1.0, 1.0];
let weights = array![1.0, 0.7, 1.3];
let score_prepared = build_score_warp_deviation_block_from_seed(
&z,
&DeviationBlockConfig {
num_internal_knots: 4,
..DeviationBlockConfig::default()
},
)
.expect("score warp block");
let link_seed = array![-2.0, -0.5, 0.0, 0.5, 2.0];
let link_prepared = build_test_link_deviation_block_from_seed(
&link_seed,
&DeviationBlockConfig {
num_internal_knots: 4,
..DeviationBlockConfig::default()
},
)
.expect("link block");
let family =
BernoulliMarginalSlopeFamily {
y: Arc::new(y.clone()),
weights: Arc::new(weights.clone()),
z: Arc::new(z.clone()),
latent_measure: LatentMeasureKind::StandardNormal,
gaussian_frailty_sd: None,
base_link: bernoulli_marginal_slope_probit_link(),
marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
array![[1.0], [1.0], [1.0]],
)),
logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
array![[1.0], [1.0], [1.0]],
)),
// Both flexible blocks active so both workspaces must be offered.
score_warp: Some(score_prepared.runtime.clone()),
link_dev: Some(link_prepared.runtime.clone()),
policy: crate::resource::ResourcePolicy::default_library(),
cell_moment_lru: new_cell_moment_lru_cache(
&crate::resource::ResourcePolicy::default_library(),
),
cell_moment_cache_stats: new_cell_moment_cache_stats(),
intercept_warm_starts: None,
auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
auto_subsample_last_rho: Arc::new(Mutex::new(None)),
shared_eval_cache: Arc::new(Mutex::new(None)),
};
// Four-block state: scalar q and b, plus the two flexible coefficient
// vectors, all at mildly perturbed values.
let block_states = vec![
ParameterBlockState {
beta: array![0.25],
eta: Array1::from_elem(z.len(), 0.25),
},
ParameterBlockState {
beta: array![0.6],
eta: Array1::from_elem(z.len(), 0.6),
},
ParameterBlockState {
beta: Array1::from_iter(
(0..score_prepared.block.design.ncols()).map(|idx| 0.015 * (idx as f64 + 1.0)),
),
eta: Array1::zeros(z.len()),
},
ParameterBlockState {
beta: Array1::from_iter(
(0..link_prepared.block.design.ncols()).map(|idx| 0.01 * (idx as f64 + 1.0)),
),
eta: Array1::zeros(z.len()),
},
];
// Placeholder specs matching each block's coefficient count.
let specs = vec![
dummy_blockspec(1, z.len()),
dummy_blockspec(1, z.len()),
dummy_blockspec(score_prepared.block.design.ncols(), z.len()),
dummy_blockspec(link_prepared.block.design.ncols(), z.len()),
];
// No psi derivatives needed just to request the workspace.
let derivative_blocks = vec![Vec::new(), Vec::new(), Vec::new(), Vec::new()];
assert!(
family
.exact_newton_joint_hessian_workspace(&block_states, &specs)
.expect("flex hessian workspace")
.is_some()
);
assert!(
family
.exact_newton_joint_psi_workspace(&block_states, &specs, &derivative_blocks)
.expect("flex psi workspace")
.is_some()
);
}
// With a Gaussian frailty sd present, the family must return analytic sigma
// psi terms (first- and second-order, plus a Hessian directional drift) of
// the right shapes, and the analytic objective_psi must match a central
// finite difference of the negative log-likelihood in tau = ln(sigma).
#[test]
fn sigma_exact_joint_psi_terms_returns_analytic_terms() {
let z = array![-0.8, 0.2, 1.1];
let y = array![0.0, 1.0, 1.0];
let weights = array![1.0, 0.7, 1.3];
let sigma = 0.7;
// Factory so the same fixture can be rebuilt at perturbed sigma values
// for the finite-difference check below.
let make_family =
|sigma: f64| BernoulliMarginalSlopeFamily {
y: Arc::new(y.clone()),
weights: Arc::new(weights.clone()),
z: Arc::new(z.clone()),
latent_measure: LatentMeasureKind::StandardNormal,
// The quantity under test: the Gaussian frailty scale.
gaussian_frailty_sd: Some(sigma),
base_link: bernoulli_marginal_slope_probit_link(),
marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
array![[1.0], [1.0], [1.0]],
)),
logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
array![[1.0], [1.0], [1.0]],
)),
// No flexible blocks: only the two scalar (q, b) blocks.
score_warp: None,
link_dev: None,
policy: crate::resource::ResourcePolicy::default_library(),
cell_moment_lru: new_cell_moment_lru_cache(
&crate::resource::ResourcePolicy::default_library(),
),
cell_moment_cache_stats: new_cell_moment_cache_stats(),
intercept_warm_starts: None,
auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
auto_subsample_last_rho: Arc::new(Mutex::new(None)),
shared_eval_cache: Arc::new(Mutex::new(None)),
};
let family = make_family(sigma);
let block_states = vec![
ParameterBlockState {
beta: array![0.25],
eta: Array1::from_elem(z.len(), 0.25),
},
ParameterBlockState {
beta: array![0.6],
eta: Array1::from_elem(z.len(), 0.6),
},
];
let specs = vec![dummy_blockspec(1, z.len()), dummy_blockspec(1, z.len())];
// First-order sigma terms: objective derivative, per-coefficient score
// derivatives, and a 2x2 Hessian-psi operator for the (q, b) joint space.
let terms = family
.sigma_exact_joint_psi_terms(&block_states, &specs)
.expect("analytic sigma psi terms")
.expect("sigma terms present");
assert!(terms.objective_psi.is_finite());
assert_eq!(terms.score_psi.len(), 2);
assert!(terms.score_psi.iter().all(|value| value.is_finite()));
assert_eq!(
terms
.hessian_psi_operator
.as_ref()
.expect("sigma Hessian operator")
.to_dense()
.dim(),
(2, 2)
);
// Second-order sigma terms.
let second = family
.sigma_exact_joint_psisecond_order_terms(&block_states)
.expect("analytic second sigma terms")
.expect("second sigma terms present");
assert!(second.objective_psi_psi.is_finite());
assert_eq!(second.score_psi_psi.len(), 2);
// Directional derivative of the sigma-Hessian along a coefficient
// direction must also be finite and full-size.
let drift = family
.sigma_exact_joint_psihessian_directional_derivative(&block_states, &array![0.1, -0.2])
.expect("analytic sigma Hessian directional derivative")
.expect("sigma drift present");
assert_eq!(drift.dim(), (2, 2));
assert!(drift.iter().all(|value| value.is_finite()));
// FD oracle in tau = ln(sigma): objective is the NEGATIVE log-likelihood,
// hence the leading minus sign on the central difference.
let tau = sigma.ln();
let eps = 1e-5;
let ll_plus = make_family((tau + eps).exp())
.log_likelihood_only(&block_states)
.expect("ll plus sigma");
let ll_minus = make_family((tau - eps).exp())
.log_likelihood_only(&block_states)
.expect("ll minus sigma");
let objective_fd = -(ll_plus - ll_minus) / (2.0 * eps);
assert!((terms.objective_psi - objective_fd).abs() < 1e-5);
}
// Builds a deterministic n-observation Bernoulli test family for the
// block-psi tests. All inputs (responses, weights, latent scores, designs)
// are generated with fixed modular-arithmetic patterns so the fixture is
// reproducible without any RNG. Single-column designs keep both coefficient
// blocks scalar; no frailty and no flexible blocks.
fn make_block_psi_test_family(n: usize) -> BernoulliMarginalSlopeFamily {
// Pseudo-random 0/1 responses from a fixed modular pattern.
let y: Array1<f64> =
Array1::from_iter((0..n).map(|i| if (i * 31 + 7) % 5 >= 3 { 1.0 } else { 0.0 }));
// Weights cycle through {0.5, 0.6, ..., 1.1}.
let weights: Array1<f64> =
Array1::from_iter((0..n).map(|i| 0.5 + ((i * 13 + 4) % 7) as f64 * 0.1));
// Latent scores spread over (-1.5, 1.5) in a shuffled order.
let z: Array1<f64> = Array1::from_iter(
(0..n).map(|i| -1.5 + 3.0 * (((i * 17 + 5) % n) as f64 + 0.5) / (n as f64)),
);
// Single-column (scalar-coefficient) designs with deterministic entries.
let marginal_design = Array2::from_shape_fn((n, 1), |(i, _)| {
0.3 + 0.4 * (((i * 29 + 11) % n) as f64) / (n as f64)
});
let logslope_design = Array2::from_shape_fn((n, 1), |(i, _)| {
0.2 + 0.5 * (((i * 37 + 9) % n) as f64) / (n as f64)
});
BernoulliMarginalSlopeFamily {
y: Arc::new(y),
weights: Arc::new(weights),
z: Arc::new(z),
latent_measure: LatentMeasureKind::StandardNormal,
gaussian_frailty_sd: None,
base_link: bernoulli_marginal_slope_probit_link(),
marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
marginal_design,
)),
logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
logslope_design,
)),
// No flexible blocks in the block-psi fixtures.
score_warp: None,
link_dev: None,
policy: crate::resource::ResourcePolicy::default_library(),
cell_moment_lru: new_cell_moment_lru_cache(
&crate::resource::ResourcePolicy::default_library(),
),
cell_moment_cache_stats: new_cell_moment_cache_stats(),
intercept_warm_starts: None,
auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
auto_subsample_last_rho: Arc::new(Mutex::new(None)),
shared_eval_cache: Arc::new(Mutex::new(None)),
}
}
/// Builds the two-block (marginal, log-slope) coefficient/eta states used by
/// the block-psi tests, with one scalar coefficient per block and
/// eta = X * beta computed from each block's dense design.
fn block_psi_test_block_states(
    family: &BernoulliMarginalSlopeFamily,
    m_beta: f64,
    g_beta: f64,
) -> Vec<ParameterBlockState> {
    let marginal_beta = array![m_beta];
    let logslope_beta = array![g_beta];
    let marginal_eta = family
        .marginal_design
        .to_dense()
        .to_owned()
        .dot(&marginal_beta);
    let logslope_eta = family
        .logslope_design
        .to_dense()
        .to_owned()
        .dot(&logslope_beta);
    vec![
        ParameterBlockState {
            beta: marginal_beta,
            eta: marginal_eta,
        },
        ParameterBlockState {
            beta: logslope_beta,
            eta: logslope_eta,
        },
    ]
}
/// One psi-derivative entry for the marginal block (a deterministic dense
/// design perturbation) and none for the log-slope block.
fn block_psi_test_marginal_derivative_blocks(
    n: usize,
) -> Vec<Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>> {
    let denom = n as f64;
    // Deterministic modular pattern; no RNG so the fixture is reproducible.
    let design_psi = Array2::from_shape_fn((n, 1), |(row, _)| {
        0.4 + 0.3 * (((row * 41 + 13) % n) as f64) / denom
    });
    let marginal_entry = crate::custom_family::CustomFamilyBlockPsiDerivative::new(
        None,
        design_psi,
        Array2::zeros((1, 1)),
        None,
        None,
        None,
        None,
    );
    vec![vec![marginal_entry], Vec::new()]
}
fn block_psi_test_dual_derivative_blocks(
n: usize,
) -> Vec<Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>> {
let x_psi_m = Array2::from_shape_fn((n, 1), |(i, _)| {
0.4 + 0.3 * (((i * 41 + 13) % n) as f64) / (n as f64)
});
let x_psi_g = Array2::from_shape_fn((n, 1), |(i, _)| {
0.2 + 0.5 * (((i * 43 + 17) % n) as f64) / (n as f64)
});
vec![
vec![crate::custom_family::CustomFamilyBlockPsiDerivative::new(
None,
x_psi_m,
Array2::zeros((1, 1)),
None,
None,
None,
None,
)],
vec![crate::custom_family::CustomFamilyBlockPsiDerivative::new(
None,
x_psi_g,
Array2::zeros((1, 1)),
None,
None,
None,
None,
)],
]
}
// A full-population subsample (every index retained) must reproduce the
// unsampled psi terms to floating-point roundoff: with all rows present the
// subsample path is an identity reweighting.
#[test]
fn bernoulli_psi_terms_from_cache_subsample_full_equals_unsampled() {
    use crate::families::marginal_slope_shared::OuterScoreSubsample;
    let n = 200usize;
    let family = make_block_psi_test_family(n);
    let states = block_psi_test_block_states(&family, 0.15, 0.25);
    let derivative_blocks = block_psi_test_marginal_derivative_blocks(n);
    let cache = family
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    // Reference: the unsampled psi terms.
    let baseline = family
        .exact_newton_joint_psi_terms_from_cache(&states, &derivative_blocks, 0, &cache)
        .expect("baseline psi terms")
        .expect("some");
    // Struct-update syntax instead of mutate-after-default (fixes clippy's
    // field_reassign_with_default).
    let opts_full = BlockwiseFitOptions {
        outer_score_subsample: Some(Arc::new(OuterScoreSubsample::new(
            (0..n).collect(),
            n,
            0xDEADBEEF,
        ))),
        ..BlockwiseFitOptions::default()
    };
    let with_full = family
        .exact_newton_joint_psi_terms_from_cache_with_options(
            &states,
            &derivative_blocks,
            0,
            &cache,
            &opts_full,
        )
        .expect("with full")
        .expect("some");
    // Relative comparison; the max(1.0) floor guards a near-zero baseline.
    let obj_rel = ((with_full.objective_psi - baseline.objective_psi)
        / baseline.objective_psi.abs().max(1.0))
    .abs();
    assert!(obj_rel < 1e-12, "objective_psi rel {}", obj_rel);
    let score_rel = rel_diff_array1(&with_full.score_psi, &baseline.score_psi);
    assert!(score_rel < 1e-12, "score_psi rel {}", score_rel);
}
#[test]
fn bernoulli_psi_terms_from_cache_subsample_half_scales_correctly() {
    use crate::families::marginal_slope_shared::OuterScoreSubsample;
    // Horvitz-Thompson weighting over the even rows must equal the raw
    // (unit-weight) even-row sums scaled by n / m.
    let n = 200usize;
    let family = make_block_psi_test_family(n);
    let states = block_psi_test_block_states(&family, 0.15, 0.25);
    let derivative_blocks = block_psi_test_marginal_derivative_blocks(n);
    let cache = family
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let even_rows: Vec<usize> = (0..n).step_by(2).collect();
    let m = even_rows.len();
    let opts_half = {
        let mut o = BlockwiseFitOptions::default();
        o.outer_score_subsample =
            Some(Arc::new(OuterScoreSubsample::new(even_rows.clone(), n, 0xCAFE)));
        o
    };
    let scaled = family
        .exact_newton_joint_psi_terms_from_cache_with_options(
            &states,
            &derivative_blocks,
            0,
            &cache,
            &opts_half,
        )
        .expect("scaled")
        .expect("some");
    let opts_raw = {
        let mut o = BlockwiseFitOptions::default();
        o.outer_score_subsample = Some(Arc::new(OuterScoreSubsample::with_uniform_weight(
            even_rows, m, 0, 1.0,
        )));
        o
    };
    let raw = family
        .exact_newton_joint_psi_terms_from_cache_with_options(
            &states,
            &derivative_blocks,
            0,
            &cache,
            &opts_raw,
        )
        .expect("raw")
        .expect("some");
    let factor = n as f64 / m as f64;
    let exp_obj = factor * raw.objective_psi;
    let obj_rel = ((scaled.objective_psi - exp_obj) / exp_obj.abs().max(1.0)).abs();
    assert!(obj_rel < 1e-12, "objective_psi rel {}", obj_rel);
    let exp_score = &raw.score_psi * factor;
    let score_rel = rel_diff_array1(&scaled.score_psi, &exp_score);
    assert!(score_rel < 1e-12, "score_psi rel {}", score_rel);
}
#[test]
fn bernoulli_psi_second_order_terms_from_cache_subsample_full_equals_unsampled() {
    use crate::families::marginal_slope_shared::OuterScoreSubsample;
    // An all-rows subsample must leave the second-order psi terms for the
    // block pair (0, 1) identical to the unsampled evaluation.
    let n = 200usize;
    let family = make_block_psi_test_family(n);
    let states = block_psi_test_block_states(&family, 0.15, 0.25);
    let derivative_blocks = block_psi_test_dual_derivative_blocks(n);
    let cache = family
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let unsampled = family
        .exact_newton_joint_psisecond_order_terms_from_cache(
            &states,
            &derivative_blocks,
            0,
            1,
            &cache,
        )
        .expect("baseline psi second-order")
        .expect("some");
    let opts_full = {
        let mut o = BlockwiseFitOptions::default();
        o.outer_score_subsample = Some(Arc::new(OuterScoreSubsample::new(
            (0..n).collect(),
            n,
            0xDEADBEEF,
        )));
        o
    };
    let full_mask = family
        .exact_newton_joint_psisecond_order_terms_from_cache_with_options(
            &states,
            &derivative_blocks,
            0,
            1,
            &cache,
            &opts_full,
        )
        .expect("with full")
        .expect("some");
    let denom = unsampled.objective_psi_psi.abs().max(1.0);
    let obj_rel = ((full_mask.objective_psi_psi - unsampled.objective_psi_psi) / denom).abs();
    assert!(obj_rel < 1e-12, "objective rel {}", obj_rel);
    let score_rel = rel_diff_array1(&full_mask.score_psi_psi, &unsampled.score_psi_psi);
    assert!(score_rel < 1e-12, "score rel {}", score_rel);
}
#[test]
fn bernoulli_psi_second_order_terms_from_cache_subsample_half_scales_correctly() {
    use crate::families::marginal_slope_shared::OuterScoreSubsample;
    // HT-weighted even-row second-order psi terms must equal the raw
    // even-row sums scaled by n / m.
    let n = 200usize;
    let family = make_block_psi_test_family(n);
    let states = block_psi_test_block_states(&family, 0.15, 0.25);
    let derivative_blocks = block_psi_test_dual_derivative_blocks(n);
    let cache = family
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let even_rows: Vec<usize> = (0..n).step_by(2).collect();
    let m = even_rows.len();
    let opts_half = {
        let mut o = BlockwiseFitOptions::default();
        o.outer_score_subsample =
            Some(Arc::new(OuterScoreSubsample::new(even_rows.clone(), n, 0xCAFE)));
        o
    };
    let scaled = family
        .exact_newton_joint_psisecond_order_terms_from_cache_with_options(
            &states,
            &derivative_blocks,
            0,
            1,
            &cache,
            &opts_half,
        )
        .expect("scaled")
        .expect("some");
    let opts_raw = {
        let mut o = BlockwiseFitOptions::default();
        o.outer_score_subsample = Some(Arc::new(OuterScoreSubsample::with_uniform_weight(
            even_rows, m, 0, 1.0,
        )));
        o
    };
    let raw = family
        .exact_newton_joint_psisecond_order_terms_from_cache_with_options(
            &states,
            &derivative_blocks,
            0,
            1,
            &cache,
            &opts_raw,
        )
        .expect("raw")
        .expect("some");
    let factor = n as f64 / m as f64;
    let exp_obj = factor * raw.objective_psi_psi;
    let obj_rel = ((scaled.objective_psi_psi - exp_obj) / exp_obj.abs().max(1.0)).abs();
    assert!(obj_rel < 1e-12, "objective rel {}", obj_rel);
    let exp_score = &raw.score_psi_psi * factor;
    let score_rel = rel_diff_array1(&scaled.score_psi_psi, &exp_score);
    assert!(score_rel < 1e-12, "score rel {}", score_rel);
}
#[test]
fn bernoulli_psihessian_directional_derivative_from_cache_subsample_full_equals_unsampled() {
    use crate::families::marginal_slope_shared::OuterScoreSubsample;
    // An all-rows subsample must reproduce the unsampled psi-Hessian
    // directional derivative exactly.
    let n = 200usize;
    let family = make_block_psi_test_family(n);
    let states = block_psi_test_block_states(&family, 0.15, 0.25);
    let derivative_blocks = block_psi_test_marginal_derivative_blocks(n);
    let cache = family
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let mut d_beta_flat = Array1::<f64>::zeros(cache.slices.total);
    d_beta_flat[cache.slices.marginal.start] = 0.05;
    d_beta_flat[cache.slices.logslope.start] = -0.04;
    let unsampled = family
        .exact_newton_joint_psihessian_directional_derivative_from_cache(
            &states,
            &derivative_blocks,
            0,
            &d_beta_flat,
            &cache,
        )
        .expect("baseline")
        .expect("some");
    let opts_full = {
        let mut o = BlockwiseFitOptions::default();
        o.outer_score_subsample = Some(Arc::new(OuterScoreSubsample::new(
            (0..n).collect(),
            n,
            0xDEADBEEF,
        )));
        o
    };
    let full_mask = family
        .exact_newton_joint_psihessian_directional_derivative_from_cache_with_options(
            &states,
            &derivative_blocks,
            0,
            &d_beta_flat,
            &cache,
            &opts_full,
        )
        .expect("with full")
        .expect("some");
    let rel = rel_diff_array2(&full_mask, &unsampled);
    assert!(rel < 1e-12, "drift rel {}", rel);
}
#[test]
fn bernoulli_psihessian_directional_derivative_from_cache_subsample_half_scales_correctly() {
    use crate::families::marginal_slope_shared::OuterScoreSubsample;
    // HT-weighted even-row psi-Hessian directional derivative must equal the
    // raw even-row version scaled by n / m.
    let n = 200usize;
    let family = make_block_psi_test_family(n);
    let states = block_psi_test_block_states(&family, 0.15, 0.25);
    let derivative_blocks = block_psi_test_marginal_derivative_blocks(n);
    let cache = family
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let mut d_beta_flat = Array1::<f64>::zeros(cache.slices.total);
    d_beta_flat[cache.slices.marginal.start] = 0.05;
    d_beta_flat[cache.slices.logslope.start] = -0.04;
    let even_rows: Vec<usize> = (0..n).step_by(2).collect();
    let m = even_rows.len();
    let opts_half = {
        let mut o = BlockwiseFitOptions::default();
        o.outer_score_subsample =
            Some(Arc::new(OuterScoreSubsample::new(even_rows.clone(), n, 0xCAFE)));
        o
    };
    let scaled = family
        .exact_newton_joint_psihessian_directional_derivative_from_cache_with_options(
            &states,
            &derivative_blocks,
            0,
            &d_beta_flat,
            &cache,
            &opts_half,
        )
        .expect("scaled")
        .expect("some");
    let opts_raw = {
        let mut o = BlockwiseFitOptions::default();
        o.outer_score_subsample = Some(Arc::new(OuterScoreSubsample::with_uniform_weight(
            even_rows, m, 0, 1.0,
        )));
        o
    };
    let raw = family
        .exact_newton_joint_psihessian_directional_derivative_from_cache_with_options(
            &states,
            &derivative_blocks,
            0,
            &d_beta_flat,
            &cache,
            &opts_raw,
        )
        .expect("raw")
        .expect("some");
    let factor = n as f64 / m as f64;
    let expected = &raw * factor;
    let rel = rel_diff_array2(&scaled, &expected);
    assert!(rel < 1e-12, "drift rel {}", rel);
}
#[test]
fn bernoulli_psihessian_operator_from_cache_subsample_full_equals_unsampled() {
    use crate::families::marginal_slope_shared::OuterScoreSubsample;
    // The operator form under an all-rows subsample must densify to the same
    // matrix as the default-options baseline.
    let n = 200usize;
    let family = make_block_psi_test_family(n);
    let states = block_psi_test_block_states(&family, 0.15, 0.25);
    let derivative_blocks = block_psi_test_marginal_derivative_blocks(n);
    let cache = family
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let mut d_beta_flat = Array1::<f64>::zeros(cache.slices.total);
    d_beta_flat[cache.slices.marginal.start] = 0.05;
    d_beta_flat[cache.slices.logslope.start] = -0.04;
    let baseline_dense = family
        .exact_newton_joint_psihessian_directional_derivative_operator_from_cache_with_options(
            &states,
            &derivative_blocks,
            0,
            &d_beta_flat,
            &cache,
            &BlockwiseFitOptions::default(),
        )
        .expect("baseline operator")
        .expect("some")
        .to_dense();
    let opts_full = {
        let mut o = BlockwiseFitOptions::default();
        o.outer_score_subsample = Some(Arc::new(OuterScoreSubsample::new(
            (0..n).collect(),
            n,
            0xDEADBEEF,
        )));
        o
    };
    let full_dense = family
        .exact_newton_joint_psihessian_directional_derivative_operator_from_cache_with_options(
            &states,
            &derivative_blocks,
            0,
            &d_beta_flat,
            &cache,
            &opts_full,
        )
        .expect("with full")
        .expect("some")
        .to_dense();
    let rel = rel_diff_array2(&full_dense, &baseline_dense);
    assert!(rel < 1e-12, "operator drift rel {}", rel);
}
#[test]
fn bernoulli_psihessian_operator_from_cache_subsample_half_scales_correctly() {
    use crate::families::marginal_slope_shared::OuterScoreSubsample;
    // The HT-weighted even-row operator must densify to the raw even-row
    // operator scaled by n / m.
    let n = 200usize;
    let family = make_block_psi_test_family(n);
    let states = block_psi_test_block_states(&family, 0.15, 0.25);
    let derivative_blocks = block_psi_test_marginal_derivative_blocks(n);
    let cache = family
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let mut d_beta_flat = Array1::<f64>::zeros(cache.slices.total);
    d_beta_flat[cache.slices.marginal.start] = 0.05;
    d_beta_flat[cache.slices.logslope.start] = -0.04;
    let even_rows: Vec<usize> = (0..n).step_by(2).collect();
    let m = even_rows.len();
    let opts_half = {
        let mut o = BlockwiseFitOptions::default();
        o.outer_score_subsample =
            Some(Arc::new(OuterScoreSubsample::new(even_rows.clone(), n, 0xCAFE)));
        o
    };
    let scaled_dense = family
        .exact_newton_joint_psihessian_directional_derivative_operator_from_cache_with_options(
            &states,
            &derivative_blocks,
            0,
            &d_beta_flat,
            &cache,
            &opts_half,
        )
        .expect("scaled")
        .expect("some")
        .to_dense();
    let opts_raw = {
        let mut o = BlockwiseFitOptions::default();
        o.outer_score_subsample = Some(Arc::new(OuterScoreSubsample::with_uniform_weight(
            even_rows, m, 0, 1.0,
        )));
        o
    };
    let raw_dense = family
        .exact_newton_joint_psihessian_directional_derivative_operator_from_cache_with_options(
            &states,
            &derivative_blocks,
            0,
            &d_beta_flat,
            &cache,
            &opts_raw,
        )
        .expect("raw")
        .expect("some")
        .to_dense();
    let factor = n as f64 / m as f64;
    let expected = &raw_dense * factor;
    let rel = rel_diff_array2(&scaled_dense, &expected);
    assert!(rel < 1e-12, "operator drift rel {}", rel);
}
#[test]
fn bernoulli_jointhessian_directional_derivative_from_cache_subsample_full_equals_unsampled() {
    use crate::families::marginal_slope_shared::OuterScoreSubsample;
    // An all-rows subsample must reproduce the unsampled joint-Hessian
    // directional derivative exactly.
    let n = 200usize;
    let family = make_block_psi_test_family(n);
    let states = block_psi_test_block_states(&family, 0.15, 0.25);
    let cache = family
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let mut d_beta_flat = Array1::<f64>::zeros(cache.slices.total);
    d_beta_flat[cache.slices.marginal.start] = 0.05;
    d_beta_flat[cache.slices.logslope.start] = -0.04;
    let unsampled = family
        .exact_newton_joint_hessian_directional_derivative_from_cache(
            &states,
            &d_beta_flat,
            &cache,
        )
        .expect("baseline")
        .expect("some");
    let opts_full = {
        let mut o = BlockwiseFitOptions::default();
        o.outer_score_subsample = Some(Arc::new(OuterScoreSubsample::new(
            (0..n).collect(),
            n,
            0xDEADBEEF,
        )));
        o
    };
    let full_mask = family
        .exact_newton_joint_hessian_directional_derivative_from_cache_with_options(
            &states,
            &d_beta_flat,
            &cache,
            &opts_full,
        )
        .expect("with full")
        .expect("some");
    let rel = rel_diff_array2(&full_mask, &unsampled);
    assert!(rel < 1e-12, "joint Hessian dH drift rel {}", rel);
}
#[test]
fn bernoulli_jointhessian_batched_directional_operators_match_single_direction_path() {
    use crate::families::marginal_slope_shared::OuterScoreSubsample;
    // The batched directional-derivative operator path must agree with the
    // single-direction path for every direction under the same subsample.
    let n = 200usize;
    let family = make_block_psi_test_family(n);
    let states = block_psi_test_block_states(&family, 0.15, 0.25);
    let cache = family
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let slices = &cache.slices;
    // Three scaled copies of the same two-entry direction pattern.
    let directions: Vec<Array1<f64>> = (1..=3)
        .map(|rep| {
            let mut d = Array1::<f64>::zeros(slices.total);
            d[slices.marginal.start] = 0.03 * rep as f64;
            d[slices.logslope.start] = -0.02 * rep as f64;
            d
        })
        .collect();
    let opts = {
        let mut o = BlockwiseFitOptions::default();
        o.outer_score_subsample = Some(Arc::new(OuterScoreSubsample::new(
            (0..n).step_by(2).collect(),
            n,
            0xB47C,
        )));
        o
    };
    let batched = family
        .exact_newton_joint_hessian_directional_derivative_operators_from_cache_with_options(
            &states,
            &directions,
            &cache,
            &opts,
        )
        .expect("batched operators");
    assert_eq!(batched.len(), directions.len());
    for (idx, direction) in directions.iter().enumerate() {
        let single = family
            .exact_newton_joint_hessian_directional_derivative_operator_from_cache_with_options(
                &states, direction, &cache, &opts,
            )
            .expect("single operator")
            .expect("single operator some")
            .to_dense();
        let batched_dense = batched[idx]
            .as_ref()
            .expect("batched operator some")
            .to_dense();
        let rel = rel_diff_array2(&batched_dense, &single);
        assert!(rel < 1e-12, "batched operator {idx} drift rel {rel}");
    }
}
/// Builds a Bernoulli marginal-slope family with the flexible extras enabled
/// (Gaussian frailty, score-warp block, link-deviation block) plus four
/// matching parameter-block states, for the HVP-cache tests.
///
/// Every input is a deterministic function of the row index, so repeated
/// calls with the same `n` produce identical fixtures.
fn make_flex_hvp_cache_test_family(
    n: usize,
) -> (BernoulliMarginalSlopeFamily, Vec<ParameterBlockState>) {
    // Seed curves for the two deviation blocks; `n.max(6)` guarantees at
    // least six points even for tiny `n` (presumably the minimum the block
    // builders need with 3 internal knots — TODO confirm).
    let score_seed = Array1::linspace(-2.0, 2.0, n.max(6));
    let link_seed = Array1::linspace(-1.8, 1.8, n.max(6));
    let score_prepared = build_score_warp_deviation_block_from_seed(
        &score_seed,
        &DeviationBlockConfig {
            num_internal_knots: 3,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build score warp block");
    let link_prepared = build_test_link_deviation_block_from_seed(
        &link_seed,
        &DeviationBlockConfig {
            num_internal_knots: 3,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build link deviation block");
    // Deterministic pseudo-random 0/1 outcomes (3 of every 7 residues map to
    // 1.0), strictly positive weights in [0.75, 0.95], and an evenly spaced
    // exposure grid inside (-1.7, 1.7).
    let y: Array1<f64> =
        Array1::from_iter((0..n).map(|i| if (i * 17 + 3) % 7 >= 4 { 1.0 } else { 0.0 }));
    let weights: Array1<f64> =
        Array1::from_iter((0..n).map(|i| 0.75 + ((i * 11 + 5) % 5) as f64 * 0.05));
    let z: Array1<f64> =
        Array1::from_iter((0..n).map(|i| -1.7 + 3.4 * (i as f64 + 0.5) / n as f64));
    // Two-column designs: an intercept column plus one deterministic
    // covariate column each.
    let marginal_x = Array2::from_shape_fn((n, 2), |(i, j)| {
        if j == 0 {
            1.0
        } else {
            -0.4 + 0.8 * ((i * 19 + 7) % n) as f64 / n as f64
        }
    });
    let logslope_x = Array2::from_shape_fn((n, 2), |(i, j)| {
        if j == 0 {
            1.0
        } else {
            0.3 - 0.6 * ((i * 23 + 11) % n) as f64 / n as f64
        }
    });
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(y),
        weights: Arc::new(weights),
        z: Arc::new(z.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        // Non-zero frailty SD turns the Gaussian-frailty path on.
        gaussian_frailty_sd: Some(0.15),
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            marginal_x.clone(),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            logslope_x.clone(),
        )),
        // Both optional deviation blocks are populated, unlike the simpler
        // block-psi fixture.
        score_warp: Some(score_prepared.runtime.clone()),
        link_dev: Some(link_prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
        auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        auto_subsample_last_rho: Arc::new(Mutex::new(None)),
        shared_eval_cache: Arc::new(Mutex::new(None)),
    };
    // Small non-zero coefficients; the deviation-block coefficients grow
    // linearly with the basis index so no block is exactly zero.
    let beta_m = array![0.12, -0.04];
    let beta_g = array![0.35, 0.03];
    let beta_h = Array1::from_iter(
        (0..score_prepared.runtime.basis_dim()).map(|idx| 0.0015 * (idx as f64 + 1.0)),
    );
    let beta_w = Array1::from_iter(
        (0..link_prepared.runtime.basis_dim()).map(|idx| -0.001 * (idx as f64 + 1.0)),
    );
    // Block order: marginal, log-slope, score-warp, link deviation. The two
    // deviation blocks start with all-zero etas.
    let states = vec![
        ParameterBlockState {
            beta: beta_m.clone(),
            eta: marginal_x.dot(&beta_m),
        },
        ParameterBlockState {
            beta: beta_g.clone(),
            eta: logslope_x.dot(&beta_g),
        },
        ParameterBlockState {
            beta: beta_h,
            eta: Array1::zeros(z.len()),
        },
        ParameterBlockState {
            beta: beta_w,
            eta: Array1::zeros(z.len()),
        },
    ];
    (family, states)
}
#[test]
fn bernoulli_flex_paired_subsample_ll_delta_sign_matches_full_ll() {
    use crate::families::marginal_slope_shared::OuterScoreSubsample;
    // Perturb the marginal intercept and the log-slope covariate coefficient,
    // then check that the half-subsample log-likelihood delta has the same
    // sign as the full-data delta.
    let (family, old_states) = make_flex_hvp_cache_test_family(96);
    let mut trial_states = old_states.clone();
    trial_states[0].beta[0] += 0.015;
    trial_states[0].eta += 0.015;
    trial_states[1].beta[1] -= 0.01;
    // Second column of the log-slope design, rebuilt from the fixture's
    // deterministic formula so eta stays consistent with the beta change.
    let logslope_col =
        Array1::from_iter((0..96).map(|i| 0.3 - 0.6 * ((i * 23 + 11) % 96) as f64 / 96.0));
    trial_states[1].eta.scaled_add(-0.01, &logslope_col);
    let full_old = family
        .log_likelihood_only(&old_states)
        .expect("full old ll");
    let full_trial = family
        .log_likelihood_only(&trial_states)
        .expect("full trial ll");
    let opts = {
        let mut o = BlockwiseFitOptions::default();
        o.outer_score_subsample = Some(Arc::new(OuterScoreSubsample::new(
            (0..96).step_by(2).collect(),
            96,
            0x5EED5EED,
        )));
        o
    };
    let sub_old = family
        .log_likelihood_only_with_options(&old_states, &opts)
        .expect("subsample old ll");
    let sub_trial = family
        .log_likelihood_only_with_options(&trial_states, &opts)
        .expect("subsample trial ll");
    let full_delta = full_trial - full_old;
    let sub_delta = sub_trial - sub_old;
    assert!(
        full_delta.abs() > 1e-8,
        "synthetic beta-pair should produce a non-degenerate full-LL delta: {full_delta}"
    );
    assert_eq!(
        full_delta.is_sign_positive(),
        sub_delta.is_sign_positive(),
        "paired subsample LL delta sign ({sub_delta}) should match full LL delta sign ({full_delta})"
    );
}
#[test]
fn bernoulli_flex_hvp_cache_matches_uncached_path_small_case() {
    // Hessian-vector products and Hessian diagonals computed with the full
    // row-level caches must match a cache whose optional row-moment /
    // row-Hessian entries are cleared (forcing the fallback path).
    let (family, states) = make_flex_hvp_cache_test_family(12);
    let mut cached = family
        .build_exact_eval_cache(&states)
        .expect("cached exact eval cache");
    cached.row_primary_hessians = family
        .build_row_primary_hessian_cache(&states, &cached)
        .expect("row Hessian cache");
    // Same slices/primary/row-context data, but with the optional row-level
    // caches set to None and fresh (unset) rigid-derivative cells.
    let uncached = BernoulliMarginalSlopeExactEvalCache {
        slices: cached.slices.clone(),
        primary: cached.primary.clone(),
        row_contexts: cached.row_contexts.clone(),
        row_cell_moments: None,
        row_primary_hessians: None,
        rigid_third_full: crate::resource::RayonSafeOnce::new(),
        fingerprint: cached.fingerprint.clone(),
        rigid_fourth_full: crate::resource::RayonSafeOnce::new(),
    };
    // Deterministic probe direction over the full flattened coefficient
    // vector, cycling through {-0.04, -0.02, 0.0, 0.02, 0.04}.
    let direction =
        Array1::from_iter((0..cached.slices.total).map(|idx| 0.02 * ((idx % 5) as f64 - 2.0)));
    let hv_cached = family
        .exact_newton_joint_hessian_matvec_from_cache(&direction, &states, &cached)
        .expect("cached Hv");
    let hv_uncached = family
        .exact_newton_joint_hessian_matvec_from_cache(&direction, &states, &uncached)
        .expect("uncached Hv");
    let rel = rel_diff_array1(&hv_cached, &hv_uncached);
    assert!(rel < 5e-11, "cached Hv drift rel {rel}");
    let diag_cached = family
        .exact_newton_joint_hessian_diagonal_from_cache(&states, &cached)
        .expect("cached diag");
    let diag_uncached = family
        .exact_newton_joint_hessian_diagonal_from_cache(&states, &uncached)
        .expect("uncached diag");
    let rel_diag = rel_diff_array1(&diag_cached, &diag_uncached);
    assert!(rel_diag < 5e-11, "cached diag drift rel {rel_diag}");
}
#[test]
#[ignore]
fn bernoulli_flex_hvp_cache_timing_biobank_shape_pattern() {
    // Timing comparison: repeated Hessian-vector products with the row-level
    // Hessian cache should beat the fallback path that lacks it. Wall-clock
    // based and therefore flaky on loaded machines, so it is #[ignore]d and
    // must be run explicitly.
    let (family, states) = make_flex_hvp_cache_test_family(96);
    let mut cached = family
        .build_exact_eval_cache(&states)
        .expect("cached exact eval cache");
    cached.row_primary_hessians = family
        .build_row_primary_hessian_cache(&states, &cached)
        .expect("row Hessian cache");
    // Same cache with the optional row-level entries removed, forcing the
    // uncached computation path.
    let uncached = BernoulliMarginalSlopeExactEvalCache {
        slices: cached.slices.clone(),
        primary: cached.primary.clone(),
        row_contexts: cached.row_contexts.clone(),
        row_cell_moments: None,
        row_primary_hessians: None,
        rigid_third_full: crate::resource::RayonSafeOnce::new(),
        fingerprint: cached.fingerprint.clone(),
        rigid_fourth_full: crate::resource::RayonSafeOnce::new(),
    };
    // Four deterministic probe directions over the flattened coefficients.
    let directions: Vec<_> = (0..4)
        .map(|rep| {
            Array1::from_iter(
                (0..cached.slices.total)
                    .map(|idx| 0.01 * (((idx * 13 + rep * 7) % 11) as f64 - 5.0)),
            )
        })
        .collect();
    let start_uncached = std::time::Instant::now();
    for direction in &directions {
        let _ = family
            .exact_newton_joint_hessian_matvec_from_cache(direction, &states, &uncached)
            .expect("uncached Hv");
    }
    let uncached_elapsed = start_uncached.elapsed();
    let start_cached = std::time::Instant::now();
    for direction in &directions {
        let _ = family
            .exact_newton_joint_hessian_matvec_from_cache(direction, &states, &cached)
            .expect("cached Hv");
    }
    let cached_elapsed = start_cached.elapsed();
    eprintln!("flex Hv cache timing: uncached={uncached_elapsed:?} cached={cached_elapsed:?}");
    assert!(
        cached_elapsed < uncached_elapsed,
        "expected cached Hv loop to beat uncached: cached={cached_elapsed:?} uncached={uncached_elapsed:?}"
    );
}
#[test]
fn bernoulli_jointhessian_directional_operator_matches_dense_small_case() {
    // The operator form of the joint-Hessian directional derivative must
    // densify to the same matrix as the dense path on a small problem.
    let n = 17usize;
    let family = make_block_psi_test_family(n);
    let states = block_psi_test_block_states(&family, 0.15, 0.25);
    let cache = family
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let mut d_beta_flat = Array1::<f64>::zeros(cache.slices.total);
    d_beta_flat[cache.slices.marginal.start] = 0.05;
    d_beta_flat[cache.slices.logslope.start] = -0.04;
    let default_opts = BlockwiseFitOptions::default();
    let dense = family
        .exact_newton_joint_hessian_directional_derivative_from_cache_with_options(
            &states,
            &d_beta_flat,
            &cache,
            &default_opts,
        )
        .expect("dense dH")
        .expect("dH present");
    let operator = family
        .exact_newton_joint_hessian_directional_derivative_operator_from_cache_with_options(
            &states,
            &d_beta_flat,
            &cache,
            &default_opts,
        )
        .expect("operator dH")
        .expect("dH operator present");
    let rel = rel_diff_array2(&operator.to_dense(), &dense);
    assert!(rel < 1e-12, "operator dH rel {}", rel);
}
#[test]
fn bernoulli_jointhessian_second_directional_operator_matches_dense_small_case() {
    // The operator form of the second directional derivative must densify to
    // the same matrix as the dense path on a small problem.
    let n = 17usize;
    let family = make_block_psi_test_family(n);
    let states = block_psi_test_block_states(&family, 0.15, 0.25);
    let cache = family
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let mut d_beta_u = Array1::<f64>::zeros(cache.slices.total);
    d_beta_u[cache.slices.marginal.start] = 0.05;
    d_beta_u[cache.slices.logslope.start] = -0.04;
    let mut d_beta_v = Array1::<f64>::zeros(cache.slices.total);
    d_beta_v[cache.slices.marginal.start] = -0.03;
    d_beta_v[cache.slices.logslope.start] = 0.02;
    let default_opts = BlockwiseFitOptions::default();
    let dense = family
        .exact_newton_joint_hessiansecond_directional_derivative_from_cache_with_options(
            &states,
            &d_beta_u,
            &d_beta_v,
            &cache,
            &default_opts,
        )
        .expect("dense d2H")
        .expect("d2H present");
    let operator = family
        .exact_newton_joint_hessiansecond_directional_derivative_operator_from_cache_with_options(
            &states,
            &d_beta_u,
            &d_beta_v,
            &cache,
            &default_opts,
        )
        .expect("operator d2H")
        .expect("d2H operator present");
    let rel = rel_diff_array2(&operator.to_dense(), &dense);
    assert!(rel < 1e-12, "operator d2H rel {}", rel);
}
#[test]
fn bernoulli_large_scale_outer_derivatives_keep_analytic_hessian_route() {
    // Even past 50k rows the family must keep advertising analytic gradients
    // and the "either" Hessian form to the outer optimizer.
    let n = 50_001usize;
    let family = make_block_psi_test_family(n);
    let specs: Vec<_> = (0..2).map(|_| dummy_blockspec(1, n)).collect();
    let (gradient, hessian) = crate::custom_family::custom_family_outer_derivatives(
        &family,
        &specs,
        &BlockwiseFitOptions::default(),
    );
    assert_eq!(
        gradient,
        crate::solver::outer_strategy::Derivative::Analytic
    );
    assert_eq!(
        hessian,
        crate::solver::outer_strategy::DeclaredHessianForm::Either
    );
}
#[test]
fn bernoulli_jointhessian_directional_derivative_from_cache_subsample_half_scales_correctly() {
    use crate::families::marginal_slope_shared::OuterScoreSubsample;
    // HT-weighted even-row joint-Hessian directional derivative must equal
    // the raw even-row version scaled by n / m.
    let n = 200usize;
    let family = make_block_psi_test_family(n);
    let states = block_psi_test_block_states(&family, 0.15, 0.25);
    let cache = family
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let mut d_beta_flat = Array1::<f64>::zeros(cache.slices.total);
    d_beta_flat[cache.slices.marginal.start] = 0.05;
    d_beta_flat[cache.slices.logslope.start] = -0.04;
    let even_rows: Vec<usize> = (0..n).step_by(2).collect();
    let m = even_rows.len();
    let opts_half = {
        let mut o = BlockwiseFitOptions::default();
        o.outer_score_subsample =
            Some(Arc::new(OuterScoreSubsample::new(even_rows.clone(), n, 0xCAFE)));
        o
    };
    let scaled = family
        .exact_newton_joint_hessian_directional_derivative_from_cache_with_options(
            &states,
            &d_beta_flat,
            &cache,
            &opts_half,
        )
        .expect("scaled")
        .expect("some");
    let opts_raw = {
        let mut o = BlockwiseFitOptions::default();
        o.outer_score_subsample = Some(Arc::new(OuterScoreSubsample::with_uniform_weight(
            even_rows, m, 0, 1.0,
        )));
        o
    };
    let raw = family
        .exact_newton_joint_hessian_directional_derivative_from_cache_with_options(
            &states,
            &d_beta_flat,
            &cache,
            &opts_raw,
        )
        .expect("raw")
        .expect("some");
    let factor = n as f64 / m as f64;
    let expected = &raw * factor;
    let rel = rel_diff_array2(&scaled, &expected);
    assert!(rel < 1e-12, "joint Hessian dH HT rel {}", rel);
}
#[test]
fn auto_outer_subsample_two_phase_converges_to_full_data_optimum() {
    // Exercises the auto-subsample phase counter: every call with a *new*
    // outer parameter vector rho must bump the atomic counter, while a
    // repeat call with an identical rho (a line-search retry) must not.
    // With `auto_outer_subsample` disabled the counter must never move.
    let n = 35_000usize;
    let family = make_block_psi_test_family(n);
    // Empty states/specs/derivative blocks: only the counter bookkeeping is
    // under test here, not the gradient values themselves.
    let states: Vec<ParameterBlockState> = Vec::new();
    let specs: Vec<ParameterBlockSpec> = Vec::new();
    let deriv_blocks: Vec<Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>> =
        Vec::new();
    let rho_dim = 3usize;
    let mut opts = BlockwiseFitOptions::default();
    opts.auto_outer_subsample = true;
    // Convenience reader for the family's phase counter.
    let counter = || {
        family
            .auto_subsample_phase_counter
            .load(std::sync::atomic::Ordering::SeqCst)
    };
    let mut distinct_calls = 0usize;
    for step in 0..15usize {
        // A fresh rho per step: the counter should track distinct rho values.
        let rho_step = Array1::<f64>::from_elem(rho_dim, step as f64 * 0.1);
        family
            .batched_outer_gradient_terms(
                &states,
                &specs,
                &deriv_blocks,
                &rho_step,
                &opts,
                None,
            )
            .expect("guard ok");
        distinct_calls += 1;
        assert_eq!(
            counter(),
            distinct_calls,
            "distinct-ρ call {step}: counter should equal number of distinct ρ"
        );
        if step == 3 || step == 7 {
            // Re-evaluate the identical rho (simulating a line-search retry);
            // this must be detected as a repeat and not bump the counter.
            let rho_retry = Array1::<f64>::from_elem(rho_dim, step as f64 * 0.1);
            family
                .batched_outer_gradient_terms(
                    &states,
                    &specs,
                    &deriv_blocks,
                    &rho_retry,
                    &opts,
                    None,
                )
                .expect("guard ok on retry");
            assert_eq!(
                counter(),
                distinct_calls,
                "line-search retry at step {step} must NOT bump counter"
            );
        }
    }
    assert_eq!(counter(), 15, "final counter should be 15 distinct ρ");
    // Control case: a fresh family with the feature left disabled; the
    // counter must stay untouched no matter how many rho values it sees.
    let family_off = make_block_psi_test_family(n);
    let opts_off = BlockwiseFitOptions::default();
    for step in 0..5 {
        let rho_step = Array1::<f64>::from_elem(rho_dim, step as f64 * 0.1);
        family_off
            .batched_outer_gradient_terms(
                &states,
                &specs,
                &deriv_blocks,
                &rho_step,
                &opts_off,
                None,
            )
            .expect("guard ok off");
    }
    assert_eq!(
        family_off
            .auto_subsample_phase_counter
            .load(std::sync::atomic::Ordering::SeqCst),
        0,
        "with auto_outer_subsample=false the counter must stay at 0"
    );
}
}