use crate::custom_family::{
BlockWorkingSet, BlockwiseFitOptions, CustomFamily, CustomFamilyWarmStart,
ExactNewtonJointGradientEvaluation, ExactNewtonJointHessianWorkspace,
ExactNewtonJointPsiSecondOrderTerms, ExactNewtonJointPsiTerms, ExactNewtonJointPsiWorkspace,
FamilyEvaluation, ParameterBlockSpec, ParameterBlockState, build_block_spatial_psi_derivatives,
custom_family_outer_derivatives, evaluate_custom_family_joint_hyper_efs_shared,
evaluate_custom_family_joint_hyper_shared, fit_custom_family,
};
use crate::estimate::UnifiedFitResult;
use crate::estimate::reml::unified::HyperOperator;
use crate::families::gamlss::{ParameterBlockInput, initialize_monotone_wiggle_knots_from_seed};
use crate::families::lognormal_kernel::FrailtySpec;
use crate::families::marginal_slope_shared::{
CoeffSupport, ObservedDenestedCellPartials, SparsePrimaryCoeffJetView, add_optional_matrix,
add_optional_vector, add_two_surface_psi_outer,
build_denested_partition_cells as shared_denested_partition_cells, chunked_row_reduction,
eval_coeff4_at, is_sigma_aux_index as shared_is_sigma_aux_index,
observed_denested_cell_partials as shared_observed_denested_cell_partials, outer_row_indices,
outer_score_scale, probit_frailty_scale, probit_frailty_scale_multi_dir_jet,
psi_derivative_location, scale_coeff4,
};
use crate::families::row_kernel::{
RowKernel, RowKernelHessianWorkspace, build_row_kernel_cache, row_kernel_gradient,
row_kernel_hessian_dense, row_kernel_log_likelihood,
};
use crate::matrix::{DesignMatrix, SymmetricMatrix};
use crate::pirls::LinearInequalityConstraints;
use crate::probability::{
normal_cdf, normal_pdf, signed_probit_logcdf_and_mills_ratio, standard_normal_quantile,
};
use crate::smooth::{
ExactJointHyperSetup, SmoothBasisSpec, SpatialLengthScaleOptimizationOptions,
SpatialLogKappaCoords, TermCollectionDesign, TermCollectionSpec,
build_term_collection_designs_and_freeze_joint, optimize_spatial_length_scale_exact_joint,
spatial_length_scale_term_indices,
};
use crate::types::{InverseLink, LinkFunction, WigglePenaltyConfig};
use ndarray::{Array1, Array2, ArrayView2, s};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use serde::{Deserialize, Serialize};
use std::cell::RefCell;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
mod deviation_runtime;
pub(crate) mod exact_kernel;
pub use deviation_runtime::DeviationRuntime;
// NOTE(review): tolerance used elsewhere in this module (outside this chunk)
// to skip numerically negligible rows in Hessian-vector accumulations —
// confirm exact semantics at the use site.
const BMS_HV_ROW_SKIP_TAU: f64 = 1e-14;
/// Spline configuration for a structural deviation block (score warp or
/// link deviation). Convertible from [`WigglePenaltyConfig`].
#[derive(Clone, Debug)]
pub struct DeviationBlockConfig {
    /// B-spline degree; the deviation runtime only accepts 3 (cubic).
    pub degree: usize,
    /// Number of internal knots placed from the seed sample.
    pub num_internal_knots: usize,
    /// Largest operator order among `penalty_orders` (set by the
    /// `From<WigglePenaltyConfig>` conversion; defaults to 2 when empty).
    pub penalty_order: usize,
    /// Difference-penalty operator orders applied to the block.
    pub penalty_orders: Vec<usize>,
    /// Propagated from `WigglePenaltyConfig::double_penalty`; consumed by the
    /// fitting code outside this chunk.
    pub double_penalty: bool,
    /// Monotonicity margin handed to the deviation runtime constructors.
    pub monotonicity_eps: f64,
}
impl Default for DeviationBlockConfig {
fn default() -> Self {
WigglePenaltyConfig::cubic_triple_operator_default().into()
}
}
impl DeviationBlockConfig {
    /// Named alias for [`Default::default`]: the cubic triple-penalty setup.
    pub fn triple_penalty_default() -> Self {
        Default::default()
    }
}
impl From<WigglePenaltyConfig> for DeviationBlockConfig {
fn from(cfg: WigglePenaltyConfig) -> Self {
let penalty_order = *cfg.penalty_orders.iter().max().unwrap_or(&2);
Self {
degree: cfg.degree,
num_internal_knots: cfg.num_internal_knots,
penalty_order,
penalty_orders: cfg.penalty_orders,
double_penalty: cfg.double_penalty,
monotonicity_eps: cfg.monotonicity_eps,
}
}
}
/// A deviation block ready for fitting: the parameter-block input paired with
/// the runtime that produced (and can re-evaluate) its spline design.
#[derive(Clone)]
pub(crate) struct DeviationPrepared {
    /// Block-level inputs (design, offset, …) handed to the fitter.
    pub(crate) block: ParameterBlockInput,
    /// Spline runtime used to build the block's design from a seed vector.
    pub(crate) runtime: DeviationRuntime,
}
impl std::fmt::Debug for DeviationPrepared {
    /// Opaque debug output: the contained design/runtime are large, so only
    /// the type name is printed (non-exhaustively).
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter
            .debug_struct("DeviationPrepared")
            .finish_non_exhaustive()
    }
}
/// Inputs for fitting the Bernoulli marginal-slope model.
#[derive(Clone)]
pub struct BernoulliMarginalSlopeTermSpec {
    /// Response vector — presumably 0/1 outcomes stored as f64; confirm at
    /// the fit site.
    pub y: Array1<f64>,
    /// Per-observation prior weights.
    pub weights: Array1<f64>,
    /// Latent score per observation (standardized per `latent_z_policy`).
    pub z: Array1<f64>,
    /// Base link; downstream validation only accepts probit.
    pub base_link: InverseLink,
    /// Smooth-term specification for the marginal predictor.
    pub marginalspec: TermCollectionSpec,
    /// Smooth-term specification for the log-slope predictor.
    pub logslopespec: TermCollectionSpec,
    /// Offset for the marginal predictor.
    pub marginal_offset: Array1<f64>,
    /// Offset for the log-slope predictor.
    pub logslope_offset: Array1<f64>,
    /// Frailty (random-effect) specification.
    pub frailty: FrailtySpec,
    /// Optional score-warp deviation block configuration.
    pub score_warp: Option<DeviationBlockConfig>,
    /// Optional link-deviation block configuration.
    pub link_dev: Option<DeviationBlockConfig>,
    /// Validation/normalization policy for the latent z sample.
    pub latent_z_policy: LatentZPolicy,
}
impl BernoulliMarginalSlopeTermSpec {
    /// Builds a spec from a calibration protocol: the protocol supplies the
    /// base link, both deviation-block configs, and the latent-z policy; the
    /// remaining arguments are passed through unchanged.
    pub fn calibrated_probit(
        y: Array1<f64>,
        weights: Array1<f64>,
        z: Array1<f64>,
        marginalspec: TermCollectionSpec,
        logslopespec: TermCollectionSpec,
        marginal_offset: Array1<f64>,
        logslope_offset: Array1<f64>,
        frailty: FrailtySpec,
        protocol: crate::solver::protocol::MarginalSlopeCalibrationProtocol,
    ) -> Self {
        let latent_z_policy = protocol.latent_score.into_policy();
        Self {
            base_link: protocol.base_link,
            score_warp: Some(protocol.score_warp),
            link_dev: Some(protocol.link_deviation),
            latent_z_policy,
            y,
            weights,
            z,
            marginalspec,
            logslopespec,
            marginal_offset,
            logslope_offset,
            frailty,
        }
    }
}
/// Output of a Bernoulli marginal-slope fit: the unified fit result plus the
/// resolved specs/designs and the latent-z bookkeeping needed for scoring.
pub struct BernoulliMarginalSlopeFitResult {
    pub fit: UnifiedFitResult,
    /// Marginal term spec after resolution.
    pub marginalspec_resolved: TermCollectionSpec,
    /// Log-slope term spec after resolution.
    pub logslopespec_resolved: TermCollectionSpec,
    pub marginal_design: TermCollectionDesign,
    pub logslope_design: TermCollectionDesign,
    // NOTE(review): exact definition of the baseline values lives in the
    // fitting code outside this chunk — confirm there.
    pub baseline_marginal: f64,
    pub baseline_logslope: f64,
    /// Standardization applied to z during the fit.
    pub z_normalization: LatentZNormalization,
    /// Latent-z reference measure actually used.
    pub latent_measure: LatentMeasureKind,
    /// Runtime for the score-warp deviation block, when one was fitted.
    pub score_warp_runtime: Option<DeviationRuntime>,
    /// Runtime for the link-deviation block, when one was fitted.
    pub link_dev_runtime: Option<DeviationRuntime>,
    pub gaussian_frailty_sd: Option<f64>,
}
/// Severity applied to latent-z diagnostics. Enforcement happens in the
/// fitting code outside this chunk; variant names are self-describing
/// (fail / warn / skip) — confirm exact behavior at the use site.
#[derive(Clone, Debug)]
pub enum LatentZCheckMode {
    Strict,
    WarnOnly,
    Off,
}
/// How the latent z sample is standardized before fitting.
#[derive(Clone, Debug)]
pub enum LatentZNormalizationMode {
    /// Use z as provided.
    None,
    /// Estimate mean/sd from the weighted sample — NOTE(review): estimation
    /// happens outside this chunk; confirm there.
    FitWeighted,
    /// Apply a fixed, previously determined mean/sd.
    Frozen { mean: f64, sd: f64 },
}
/// Default number of compressed nodes in an empirical latent-z grid.
pub const DEFAULT_EMPIRICAL_LATENT_GRID_SIZE: usize = 65;
// Local empirical measure defaults: how many nearest centers are mixed per
// row, and the Gaussian kernel bandwidth used when a Duchon term supplies no
// positive finite length scale.
const DEFAULT_LOCAL_EMPIRICAL_LATENT_TOP_K: usize = 4;
const DEFAULT_LOCAL_EMPIRICAL_LATENT_BANDWIDTH: f64 = 1.0;
// Auto-detection thresholds used by `latent_z_is_standard_normal_enough`:
// caps on |skewness|, |excess kurtosis|, weighted KS distance to the normal
// CDF, and the largest |z| tolerated before rejecting normality.
const AUTO_Z_NORMAL_SKEW_TOL: f64 = 0.10;
const AUTO_Z_NORMAL_KURT_TOL: f64 = 0.25;
const AUTO_Z_NORMAL_KS_TOL: f64 = 0.025;
const AUTO_Z_NORMAL_MAX_ABS: f64 = 8.0;
/// Requested latent-z measure, resolved into a [`LatentMeasureKind`] by
/// `build_latent_measure_with_geometry`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LatentMeasureSpec {
    /// Standard normal when z passes the normality diagnostics; otherwise an
    /// empirical measure (local when geometry is available, else global).
    Auto { grid_size: usize },
    /// Force the standard-normal measure.
    StandardNormal,
    /// Force a global empirical measure compressed to `grid_size` nodes.
    GlobalEmpirical { grid_size: usize },
}
impl LatentMeasureSpec {
pub fn global_empirical_default() -> Self {
Self::GlobalEmpirical {
grid_size: DEFAULT_EMPIRICAL_LATENT_GRID_SIZE,
}
}
pub fn auto_default() -> Self {
Self::Auto {
grid_size: DEFAULT_EMPIRICAL_LATENT_GRID_SIZE,
}
}
}
impl Default for LatentMeasureSpec {
fn default() -> Self {
Self::auto_default()
}
}
/// A discrete latent-z distribution: `nodes[i]` carries probability
/// `weights[i]`; `validate_empirical_z_grid` requires the weights to be
/// positive and to sum to 1.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct EmpiricalZGrid {
    pub nodes: Vec<f64>,
    pub weights: Vec<f64>,
}
/// The latent-z reference measure used during fitting; serialized with a
/// kebab-case `kind` tag.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "kebab-case")]
pub enum LatentMeasureKind {
    /// z is treated as exactly standard normal.
    StandardNormal,
    /// One weighted node grid shared by all rows.
    GlobalEmpirical {
        nodes: Vec<f64>,
        weights: Vec<f64>,
    },
    /// Per-center grids mixed by proximity in a scaled feature space.
    LocalEmpirical {
        /// Data columns used to locate a row relative to the centers.
        feature_cols: Vec<usize>,
        /// Optional per-feature scales that divide the raw feature values.
        #[serde(default)]
        input_scales: Option<Vec<f64>>,
        /// Kernel centers in the (scaled) feature space.
        centers: Vec<Vec<f64>>,
        /// One compressed empirical grid per center.
        grids: Vec<EmpiricalZGrid>,
        /// Number of nearest centers mixed per row.
        top_k: usize,
        /// Gaussian kernel bandwidth in the scaled feature space.
        bandwidth: f64,
        /// Cached per-training-row center mixtures; not serialized.
        #[serde(skip)]
        train_row_mixtures: Arc<Vec<Vec<(usize, f64)>>>,
    },
}
impl Default for LatentMeasureKind {
fn default() -> Self {
Self::StandardNormal
}
}
impl LatentMeasureKind {
    /// Checks the structural invariants of this measure; `context` prefixes
    /// every error message.
    pub fn validate(&self, context: &str) -> Result<(), String> {
        match self {
            Self::StandardNormal => Ok(()),
            // Global grid: delegate to the shared grid validator.
            Self::GlobalEmpirical { nodes, weights } => {
                validate_empirical_z_grid(nodes, weights, context)
            }
            Self::LocalEmpirical {
                feature_cols,
                input_scales,
                centers,
                grids,
                top_k,
                bandwidth,
                ..
            } => {
                if feature_cols.is_empty() {
                    return Err(format!(
                        "{context} local empirical latent measure needs feature columns"
                    ));
                }
                if centers.is_empty() {
                    return Err(format!(
                        "{context} local empirical latent measure needs centers"
                    ));
                }
                // Exactly one grid per center.
                if centers.len() != grids.len() {
                    return Err(format!(
                        "{context} local empirical latent measure center/grid length mismatch: centers={}, grids={}",
                        centers.len(),
                        grids.len()
                    ));
                }
                // top_k selects nearest centers, so it must lie in 1..=len.
                if *top_k == 0 || *top_k > centers.len() {
                    return Err(format!(
                        "{context} local empirical latent measure top_k must be in 1..={}, got {top_k}",
                        centers.len()
                    ));
                }
                if !(*bandwidth).is_finite() || *bandwidth <= 0.0 {
                    return Err(format!(
                        "{context} local empirical latent measure bandwidth must be finite and positive, got {bandwidth}"
                    ));
                }
                // Optional per-feature scales must align with feature_cols
                // and be strictly positive.
                if let Some(scales) = input_scales.as_ref() {
                    if scales.len() != feature_cols.len() {
                        return Err(format!(
                            "{context} local empirical latent measure input scale dimension mismatch: scales={}, features={}",
                            scales.len(),
                            feature_cols.len()
                        ));
                    }
                    for (scale_idx, scale) in scales.iter().enumerate() {
                        if !(scale.is_finite() && *scale > 0.0) {
                            return Err(format!(
                                "{context} local empirical latent measure input scale {scale_idx} must be finite and positive, got {scale}"
                            ));
                        }
                    }
                }
                // Every center must live in the feature space and be finite.
                for (center_idx, center) in centers.iter().enumerate() {
                    if center.len() != feature_cols.len() {
                        return Err(format!(
                            "{context} local empirical latent center {center_idx} dimension mismatch: got {}, expected {}",
                            center.len(),
                            feature_cols.len()
                        ));
                    }
                    if center.iter().any(|value| !value.is_finite()) {
                        return Err(format!(
                            "{context} local empirical latent center {center_idx} has non-finite coordinates"
                        ));
                    }
                }
                // Each per-center grid must itself be a valid distribution.
                for (grid_idx, grid) in grids.iter().enumerate() {
                    validate_empirical_z_grid(
                        &grid.nodes,
                        &grid.weights,
                        &format!("{context} local empirical grid {grid_idx}"),
                    )?;
                }
                Ok(())
            }
        }
    }
    /// True for the two empirical variants (global or local).
    fn is_empirical(&self) -> bool {
        matches!(
            self,
            Self::GlobalEmpirical { .. } | Self::LocalEmpirical { .. }
        )
    }
    /// The discrete latent grid associated with training row `row`:
    /// `None` for the standard-normal measure, a clone of the shared grid for
    /// the global variant, and the combined per-row mixture of center grids
    /// for the local variant (an error if the mixture cache lacks the row).
    fn empirical_grid_for_training_row(
        &self,
        row: usize,
    ) -> Result<Option<EmpiricalZGrid>, String> {
        match self {
            Self::StandardNormal => Ok(None),
            Self::GlobalEmpirical { nodes, weights } => Ok(Some(EmpiricalZGrid {
                nodes: nodes.clone(),
                weights: weights.clone(),
            })),
            Self::LocalEmpirical {
                grids,
                train_row_mixtures,
                ..
            } => {
                let mixture = train_row_mixtures.get(row).ok_or_else(|| {
                    format!(
                        "local empirical latent measure is missing training mixture for row {row}"
                    )
                })?;
                Ok(Some(combine_empirical_grids(grids, mixture)?))
            }
        }
    }
}
/// Checks that an empirical latent-z grid is well formed: equal-length node
/// and weight vectors with at least two entries, finite nodes, strictly
/// positive finite weights, and weights summing to 1 (within 1e-8).
fn validate_empirical_z_grid(nodes: &[f64], weights: &[f64], context: &str) -> Result<(), String> {
    if nodes.len() != weights.len() {
        return Err(format!(
            "{context} empirical latent measure node/weight length mismatch: nodes={}, weights={}",
            nodes.len(),
            weights.len()
        ));
    }
    if nodes.len() < 2 {
        return Err(format!(
            "{context} empirical latent measure requires at least two nodes"
        ));
    }
    let mut total = 0.0_f64;
    // Check node/weight pairs in index order so the first offending entry
    // determines the error message.
    for idx in 0..nodes.len() {
        let (node, weight) = (nodes[idx], weights[idx]);
        if !node.is_finite() {
            return Err(format!(
                "{context} empirical latent measure node {idx} is non-finite ({node})"
            ));
        }
        if !(weight.is_finite() && weight > 0.0) {
            return Err(format!(
                "{context} empirical latent measure weight {idx} must be finite and positive, got {weight}"
            ));
        }
        total += weight;
    }
    let normalized = total.is_finite() && (total - 1.0).abs() <= 1e-8;
    if !normalized {
        return Err(format!(
            "{context} empirical latent measure weights must sum to 1, got {total}"
        ));
    }
    Ok(())
}
/// Flattens a weighted mixture of per-center grids into one normalized
/// empirical grid. Mixture weights must be finite and positive; the combined
/// node weights are renormalized to sum to 1 and then validated.
fn combine_empirical_grids(
    grids: &[EmpiricalZGrid],
    mixture: &[(usize, f64)],
) -> Result<EmpiricalZGrid, String> {
    if mixture.is_empty() {
        return Err("local empirical latent measure row mixture is empty".to_string());
    }
    let mut combined_nodes = Vec::new();
    let mut combined_weights = Vec::new();
    for &(grid_idx, grid_weight) in mixture {
        if !(grid_weight.is_finite() && grid_weight > 0.0) {
            return Err(format!(
                "local empirical latent mixture weight must be finite and positive, got {grid_weight}"
            ));
        }
        let Some(grid) = grids.get(grid_idx) else {
            return Err(format!(
                "local empirical latent mixture references missing grid {grid_idx}"
            ));
        };
        // Each node's mass is scaled by the mixture weight of its grid.
        for (&node, &weight) in grid.nodes.iter().zip(grid.weights.iter()) {
            combined_nodes.push(node);
            combined_weights.push(grid_weight * weight);
        }
    }
    let total: f64 = combined_weights.iter().sum();
    if !(total.is_finite() && total > 0.0) {
        return Err(
            "local empirical latent combined grid has non-positive total weight".to_string(),
        );
    }
    combined_weights.iter_mut().for_each(|weight| *weight /= total);
    validate_empirical_z_grid(
        &combined_nodes,
        &combined_weights,
        "local empirical latent combined grid",
    )?;
    Ok(EmpiricalZGrid {
        nodes: combined_nodes,
        weights: combined_weights,
    })
}
/// Policy governing latent-z validation, normalization, and measure choice.
#[derive(Clone, Debug)]
pub struct LatentZPolicy {
    /// What to do when the diagnostics fail.
    pub check_mode: LatentZCheckMode,
    /// How z is standardized before use.
    pub normalization: LatentZNormalizationMode,
    /// Which latent reference measure to request.
    pub latent_measure: LatentMeasureSpec,
    /// Mean tolerance in units of 1/sqrt(effective n).
    pub mean_tol_multiplier: f64,
    /// Sd tolerance in units of 1/sqrt(2(effective n - 1)).
    pub sd_tol_multiplier: f64,
    /// Hard cap on |weighted skewness|.
    pub max_abs_skew: f64,
    /// Hard cap on |weighted excess kurtosis|.
    pub max_abs_excess_kurtosis: f64,
}
impl LatentZPolicy {
pub fn frozen_transformation_normal() -> Self {
Self {
check_mode: LatentZCheckMode::WarnOnly,
normalization: LatentZNormalizationMode::Frozen { mean: 0.0, sd: 1.0 },
latent_measure: LatentMeasureSpec::auto_default(),
mean_tol_multiplier: 4.0,
sd_tol_multiplier: 4.0,
max_abs_skew: 4.0,
max_abs_excess_kurtosis: 20.0,
}
}
pub fn exploratory_fit_weighted() -> Self {
Self {
check_mode: LatentZCheckMode::WarnOnly,
normalization: LatentZNormalizationMode::FitWeighted,
latent_measure: LatentMeasureSpec::auto_default(),
mean_tol_multiplier: 8.0,
sd_tol_multiplier: 8.0,
max_abs_skew: 4.0,
max_abs_excess_kurtosis: 20.0,
}
}
pub fn empirical_fit_weighted() -> Self {
Self {
latent_measure: LatentMeasureSpec::global_empirical_default(),
..Self::exploratory_fit_weighted()
}
}
}
impl Default for LatentZPolicy {
fn default() -> Self {
Self::frozen_transformation_normal()
}
}
/// Affine standardization applied to latent z: `(z - mean) / sd`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct LatentZNormalization {
    pub mean: f64,
    pub sd: f64,
}
impl LatentZNormalization {
    /// Standardizes `z` as `(z - mean) / sd`, rejecting non-finite
    /// normalization parameters, near-degenerate sd (<= 1e-12), and
    /// non-finite inputs; `context` prefixes the error messages.
    pub fn apply(&self, z: &Array1<f64>, context: &str) -> Result<Array1<f64>, String> {
        let params_ok = self.mean.is_finite() && self.sd.is_finite() && self.sd > 1e-12;
        if !params_ok {
            return Err(format!(
                "{context} requires finite latent z normalization with sd > 1e-12; got mean={} sd={}",
                self.mean, self.sd
            ));
        }
        if !z.iter().all(|value| value.is_finite()) {
            return Err(format!("{context} requires finite z values"));
        }
        let (mean, sd) = (self.mean, self.sd);
        Ok(z.mapv(|zi| (zi - mean) / sd))
    }
}
/// Resolves the latent-z measure requested by `policy`.
///
/// For `Auto`: uses the standard normal when the weighted z sample passes
/// the normality diagnostics; otherwise prefers a local empirical measure
/// when both row data and a usable Duchon geometry (from `specs`) are
/// available, and falls back to a global empirical measure.
fn build_latent_measure_with_geometry(
    z: &Array1<f64>,
    weights: &Array1<f64>,
    policy: &LatentZPolicy,
    data: Option<ArrayView2<'_, f64>>,
    specs: &[&TermCollectionSpec],
) -> Result<LatentMeasureKind, String> {
    match policy.latent_measure {
        LatentMeasureSpec::Auto { grid_size } => {
            if latent_z_is_standard_normal_enough(z, weights, policy)? {
                Ok(LatentMeasureKind::StandardNormal)
            } else if let (Some(data), Some(geometry)) =
                (data, best_latent_score_law_geometry(specs))
            {
                log::warn!(
                    "latent z is not close to standard normal; using local empirical latent-z calibration internally"
                );
                build_local_empirical_latent_measure(
                    z,
                    weights,
                    data,
                    &geometry,
                    grid_size,
                    DEFAULT_LOCAL_EMPIRICAL_LATENT_TOP_K,
                )
            } else {
                log::warn!(
                    "latent z is not close to standard normal; using global empirical latent-z calibration internally"
                );
                build_global_empirical_latent_measure(z, weights, grid_size)
            }
        }
        LatentMeasureSpec::StandardNormal => Ok(LatentMeasureKind::StandardNormal),
        LatentMeasureSpec::GlobalEmpirical { grid_size } => {
            build_global_empirical_latent_measure(z, weights, grid_size)
        }
    }
}
/// Decides whether the weighted latent-z sample is close enough to standard
/// normal for `Auto` measure selection.
///
/// Computes weighted mean/sd/skewness/excess-kurtosis, a weighted KS
/// distance to the normal CDF, tail masses beyond |z| = 4 and 6, and the
/// sample maximum; all diagnostics must pass their tolerances for
/// `Ok(true)`. Degenerate moments yield `Ok(false)` rather than an error.
fn latent_z_is_standard_normal_enough(
    z: &Array1<f64>,
    weights: &Array1<f64>,
    policy: &LatentZPolicy,
) -> Result<bool, String> {
    if z.len() != weights.len() {
        return Err(format!(
            "latent-measure auto-detection length mismatch: z={}, weights={}",
            z.len(),
            weights.len()
        ));
    }
    let weight_sum = weights.iter().copied().sum::<f64>();
    let weight_sq_sum = weights.iter().map(|&w| w * w).sum::<f64>();
    if !(weight_sum.is_finite()
        && weight_sum > 0.0
        && weight_sq_sum.is_finite()
        && weight_sq_sum > 0.0)
    {
        return Err("latent-measure auto-detection requires positive finite weights".to_string());
    }
    // Kish effective sample size: (sum w)^2 / sum w^2.
    let effective_n = weight_sum * weight_sum / weight_sq_sum;
    if !(effective_n.is_finite() && effective_n > 1.0) {
        return Err(
            "latent-measure auto-detection requires at least two effective observations"
                .to_string(),
        );
    }
    // Weighted first and second central moments.
    let mean = z
        .iter()
        .zip(weights.iter())
        .map(|(&zi, &wi)| wi * zi)
        .sum::<f64>()
        / weight_sum;
    let var = z
        .iter()
        .zip(weights.iter())
        .map(|(&zi, &wi)| wi * (zi - mean) * (zi - mean))
        .sum::<f64>()
        / weight_sum;
    let sd = var.sqrt();
    if !(mean.is_finite() && sd.is_finite() && sd > 0.0) {
        return Ok(false);
    }
    // Weighted standardized third and fourth moments.
    let skew = z
        .iter()
        .zip(weights.iter())
        .map(|(&zi, &wi)| {
            let centered = (zi - mean) / sd;
            wi * centered.powi(3)
        })
        .sum::<f64>()
        / weight_sum;
    let excess_kurtosis = z
        .iter()
        .zip(weights.iter())
        .map(|(&zi, &wi)| {
            let centered = (zi - mean) / sd;
            wi * centered.powi(4)
        })
        .sum::<f64>()
        / weight_sum
        - 3.0;
    // Mean/sd tolerances shrink with effective sample size, scaled by the
    // policy multipliers.
    let mean_tol = policy.mean_tol_multiplier / effective_n.sqrt();
    let sd_tol = policy.sd_tol_multiplier / (2.0 * (effective_n - 1.0).max(1.0)).sqrt();
    let ks_to_normal = weighted_ks_to_standard_normal(z, weights, weight_sum)?;
    // Tail masses beyond |z| = 4 and 6, compared against twice the normal
    // two-sided tail probability plus a small absolute slack.
    let tail_mass_4 = weighted_tail_mass(z, weights, weight_sum, 4.0);
    let tail_mass_6 = weighted_tail_mass(z, weights, weight_sum, 6.0);
    let max_abs_z = z.iter().fold(0.0_f64, |acc, &zi| acc.max(zi.abs()));
    let normal_tail_4 = 2.0 * (1.0 - normal_cdf(4.0));
    let normal_tail_6 = 2.0 * (1.0 - normal_cdf(6.0));
    // Policy caps are further tightened by the module-level AUTO_Z_*
    // tolerances (the stricter of the two applies).
    Ok(mean.abs() <= mean_tol
        && (sd - 1.0).abs() <= sd_tol
        && skew.is_finite()
        && skew.abs() <= policy.max_abs_skew.min(AUTO_Z_NORMAL_SKEW_TOL)
        && excess_kurtosis.is_finite()
        && excess_kurtosis.abs() <= policy.max_abs_excess_kurtosis.min(AUTO_Z_NORMAL_KURT_TOL)
        && ks_to_normal.is_finite()
        && ks_to_normal <= AUTO_Z_NORMAL_KS_TOL
        && tail_mass_4 <= 2.0 * normal_tail_4 + 1e-5
        && tail_mass_6 <= 2.0 * normal_tail_6 + 1e-8
        && max_abs_z < AUTO_Z_NORMAL_MAX_ABS)
}
fn build_global_empirical_latent_measure(
z: &Array1<f64>,
weights: &Array1<f64>,
grid_size: usize,
) -> Result<LatentMeasureKind, String> {
let grid = build_empirical_z_grid(z, weights, grid_size, "empirical latent measure")?;
let measure = LatentMeasureKind::GlobalEmpirical {
nodes: grid.nodes,
weights: grid.weights,
};
measure.validate("empirical latent measure")?;
Ok(measure)
}
/// Weighted Kolmogorov–Smirnov distance between the empirical CDF of `z`
/// (weights normalized by `total_weight`) and the standard normal CDF.
/// Zero-weight rows are ignored; non-finite inputs are rejected.
fn weighted_ks_to_standard_normal(
    z: &Array1<f64>,
    weights: &Array1<f64>,
    total_weight: f64,
) -> Result<f64, String> {
    let mut pairs = Vec::<(f64, f64)>::with_capacity(z.len());
    for (&zi, &wi) in z.iter().zip(weights.iter()) {
        if !(zi.is_finite() && wi.is_finite() && wi >= 0.0) {
            return Err(
                "latent-measure KS diagnostic requires finite z and non-negative finite weights"
                    .to_string(),
            );
        }
        if wi > 0.0 {
            pairs.push((zi, wi));
        }
    }
    pairs.sort_by(|a, b| {
        a.0.partial_cmp(&b.0)
            .expect("validated latent z values are finite")
    });
    // At each step the empirical CDF jumps from `below` to `cumulative`;
    // the KS statistic is the largest gap to the normal CDF on either side.
    let mut cumulative = 0.0_f64;
    let mut ks = 0.0_f64;
    for (zi, wi) in pairs {
        let cdf = normal_cdf(zi);
        let below = cumulative;
        cumulative = below + wi / total_weight;
        ks = ks.max((cdf - below).abs()).max((cdf - cumulative).abs());
    }
    Ok(ks)
}
/// Fraction of total weight carried by observations with |z| beyond `cutoff`.
fn weighted_tail_mass(
    z: &Array1<f64>,
    weights: &Array1<f64>,
    total_weight: f64,
    cutoff: f64,
) -> f64 {
    let mut tail = 0.0;
    for (&zi, &wi) in z.iter().zip(weights.iter()) {
        if zi.abs() > cutoff {
            tail += wi;
        }
    }
    tail / total_weight
}
/// Compresses a weighted sample of latent scores into at most `grid_size`
/// nodes of (approximately) equal probability mass.
///
/// Zero-weight rows are dropped, the remainder are sorted by z, and a
/// single sweep assigns `total_weight / m` of mass to each bin, splitting a
/// row's weight across a bin boundary when needed. Each bin contributes its
/// weight-averaged z as a node and its mass fraction as the node weight.
/// The grid is then standardized to mean 0 / sd 1 and renormalized before
/// final validation.
fn build_empirical_z_grid(
    z: &Array1<f64>,
    weights: &Array1<f64>,
    grid_size: usize,
    context: &str,
) -> Result<EmpiricalZGrid, String> {
    if grid_size < 3 {
        return Err(format!(
            "empirical latent measure grid_size must be at least 3, got {grid_size}"
        ));
    }
    if z.len() != weights.len() {
        return Err(format!(
            "{context} length mismatch: z={}, weights={}",
            z.len(),
            weights.len()
        ));
    }
    // Keep only positive-weight rows; reject non-finite values outright.
    let mut pairs = Vec::<(f64, f64)>::with_capacity(z.len());
    for (idx, (&zi, &wi)) in z.iter().zip(weights.iter()).enumerate() {
        if !zi.is_finite() {
            return Err(format!(
                "{context} z value at row {idx} is non-finite ({zi})"
            ));
        }
        if !wi.is_finite() || wi < 0.0 {
            return Err(format!(
                "{context} weight at row {idx} must be finite and non-negative, got {wi}"
            ));
        }
        if wi > 0.0 {
            pairs.push((zi, wi));
        }
    }
    if pairs.len() < 2 {
        return Err(format!(
            "{context} requires at least two positive-weight rows"
        ));
    }
    pairs.sort_by(|left, right| {
        left.0
            .partial_cmp(&right.0)
            .expect("validated empirical latent z values are finite")
    });
    let total_weight = pairs.iter().map(|(_, weight)| *weight).sum::<f64>();
    if !(total_weight.is_finite() && total_weight > 0.0) {
        return Err(format!("{context} requires positive finite total weight"));
    }
    let m = grid_size.min(pairs.len());
    let mut nodes = Vec::with_capacity(m);
    let mut out_weights = Vec::with_capacity(m);
    let bin_weight_target = total_weight / (m as f64);
    // Sweep state: `cursor` is the current sorted row, `remaining` its
    // unconsumed weight (rows can be split across adjacent bins).
    let mut cursor = 0usize;
    let mut remaining = pairs[0].1;
    for _ in 0..m {
        let mut need = bin_weight_target;
        let mut bin_weight = 0.0;
        let mut bin_sum = 0.0;
        // Pull mass until this bin holds ~the target weight (the 1e-14
        // slack absorbs floating-point residue) or the rows run out.
        while need > 1e-14 * bin_weight_target && cursor < pairs.len() {
            let take = remaining.min(need);
            bin_sum += take * pairs[cursor].0;
            bin_weight += take;
            need -= take;
            remaining -= take;
            // Advance once this row's weight is (numerically) exhausted.
            if remaining <= 1e-14 * pairs[cursor].1 {
                cursor += 1;
                if cursor < pairs.len() {
                    remaining = pairs[cursor].1;
                }
            }
        }
        if bin_weight > 0.0 {
            nodes.push(bin_sum / bin_weight);
            out_weights.push(bin_weight / total_weight);
        }
    }
    if nodes.len() < 2 {
        return Err(format!(
            "{context} compression produced fewer than two nodes"
        ));
    }
    // Standardize nodes, then renormalize weights to sum to exactly 1.
    recenter_rescale_empirical_grid(&mut nodes, &out_weights);
    let total = out_weights.iter().sum::<f64>();
    if total.is_finite() && total > 0.0 {
        for weight in &mut out_weights {
            *weight /= total;
        }
    }
    validate_empirical_z_grid(&nodes, &out_weights, context)?;
    Ok(EmpiricalZGrid {
        nodes,
        weights: out_weights,
    })
}
/// Standardizes grid nodes in place to weighted mean 0 and weighted sd 1.
/// Leaves the nodes untouched when the total weight is non-positive or the
/// weighted sd is degenerate (non-finite or <= 1e-12).
fn recenter_rescale_empirical_grid(nodes: &mut [f64], weights: &[f64]) {
    let total: f64 = weights.iter().sum();
    if !(total.is_finite() && total > 0.0) {
        return;
    }
    let weighted_sum: f64 = nodes
        .iter()
        .zip(weights)
        .map(|(&node, &weight)| weight * node)
        .sum();
    let mean = weighted_sum / total;
    let weighted_sq_dev: f64 = nodes
        .iter()
        .zip(weights)
        .map(|(&node, &weight)| weight * (node - mean).powi(2))
        .sum();
    let sd = (weighted_sq_dev / total).sqrt();
    if !(sd.is_finite() && sd > 1e-12) {
        return;
    }
    for node in nodes.iter_mut() {
        *node = (*node - mean) / sd;
    }
}
/// Geometry borrowed from a Duchon smooth term that drives the local
/// empirical latent measure: which data columns to condition on, the kernel
/// centers, optional per-feature scales, and the kernel bandwidth.
#[derive(Clone)]
struct LatentScoreLawGeometry {
    feature_cols: Vec<usize>,
    // One row per center; columns match `feature_cols`.
    centers: Array2<f64>,
    input_scales: Option<Vec<f64>>,
    bandwidth: f64,
}
/// Scans the smooth terms in `specs` for a Duchon basis that can drive the
/// local empirical latent measure: at least two feature columns and
/// explicit user-provided centers (>= 2 rows, matching the feature
/// dimension). Among candidates, keeps the one conditioning on the most
/// features; the bandwidth is the term's positive finite length scale, or
/// the module default when absent.
fn best_latent_score_law_geometry(specs: &[&TermCollectionSpec]) -> Option<LatentScoreLawGeometry> {
    let mut best: Option<LatentScoreLawGeometry> = None;
    for spec in specs {
        for term in &spec.smooth_terms {
            // Only Duchon bases expose the feature columns / centers needed.
            let SmoothBasisSpec::Duchon {
                feature_cols,
                spec,
                input_scales,
            } = &term.basis
            else {
                continue;
            };
            if feature_cols.len() < 2 {
                continue;
            }
            // Require explicit centers of matching dimension.
            let crate::basis::CenterStrategy::UserProvided(centers) = &spec.center_strategy else {
                continue;
            };
            if centers.nrows() < 2 || centers.ncols() != feature_cols.len() {
                continue;
            }
            // Bandwidth: the term's positive finite length scale, else the
            // module default.
            let bandwidth = spec
                .length_scale
                .filter(|value| value.is_finite() && *value > 0.0)
                .unwrap_or(DEFAULT_LOCAL_EMPIRICAL_LATENT_BANDWIDTH);
            let candidate = LatentScoreLawGeometry {
                feature_cols: feature_cols.clone(),
                centers: centers.clone(),
                input_scales: input_scales.clone(),
                bandwidth,
            };
            // Prefer the candidate conditioning on the most features.
            if best
                .as_ref()
                .is_none_or(|current| candidate.feature_cols.len() > current.feature_cols.len())
            {
                best = Some(candidate);
            }
        }
    }
    best
}
/// Extracts and scales the conditioning features for the local empirical
/// latent measure: selects `geometry.feature_cols` from `data`, divides each
/// column by its optional input scale, and requires every value to be finite.
fn conditioning_matrix_from_geometry(
    data: ArrayView2<'_, f64>,
    geometry: &LatentScoreLawGeometry,
) -> Result<Array2<f64>, String> {
    let n_rows = data.nrows();
    let n_features = geometry.feature_cols.len();
    let mut conditioning = Array2::<f64>::zeros((n_rows, n_features));
    for (local_col, &feature_col) in geometry.feature_cols.iter().enumerate() {
        if feature_col >= data.ncols() {
            return Err(format!(
                "local empirical latent measure feature column {feature_col} is out of bounds for {} columns",
                data.ncols()
            ));
        }
        conditioning
            .column_mut(local_col)
            .assign(&data.column(feature_col));
    }
    if let Some(scales) = geometry.input_scales.as_deref() {
        if scales.len() != n_features {
            return Err(format!(
                "local empirical latent measure input scale dimension mismatch: scales={}, features={n_features}",
                scales.len()
            ));
        }
        for (col, &scale) in scales.iter().enumerate() {
            if !(scale.is_finite() && scale > 0.0) {
                return Err(format!(
                    "local empirical latent measure input scale {col} must be finite and positive, got {scale}"
                ));
            }
            conditioning
                .column_mut(col)
                .mapv_inplace(|value| value / scale);
        }
    }
    if !conditioning.iter().all(|value| value.is_finite()) {
        return Err(
            "local empirical latent measure conditioning values must be finite".to_string(),
        );
    }
    Ok(conditioning)
}
/// Gaussian-kernel mixture weights over the `top_k` centers nearest to
/// `point`. Kernel weights are floored at 1e-300 before normalization so
/// extremely distant centers cannot produce a zero-total mixture.
fn local_empirical_mixture_for_point(
    point: &[f64],
    centers: &[Vec<f64>],
    top_k: usize,
    bandwidth: f64,
) -> Result<Vec<(usize, f64)>, String> {
    if centers.is_empty() {
        return Err("local empirical latent measure has no centers".to_string());
    }
    if top_k == 0 {
        return Err("local empirical latent measure top_k must be positive".to_string());
    }
    if !(bandwidth.is_finite() && bandwidth > 0.0) {
        return Err(format!(
            "local empirical latent measure bandwidth must be finite and positive, got {bandwidth}"
        ));
    }
    // Squared Euclidean distance from `point` to each center.
    let mut distances = Vec::<(usize, f64)>::with_capacity(centers.len());
    for (idx, center) in centers.iter().enumerate() {
        if center.len() != point.len() {
            return Err(format!(
                "local empirical latent center {idx} dimension mismatch: center={}, point={}",
                center.len(),
                point.len()
            ));
        }
        let mut squared_distance = 0.0;
        for (&c, &x) in center.iter().zip(point) {
            let delta = x - c;
            squared_distance += delta * delta;
        }
        distances.push((idx, squared_distance));
    }
    distances.sort_by(|a, b| {
        a.1.partial_cmp(&b.1)
            .expect("validated local empirical distances are finite")
    });
    // Gaussian kernel over the nearest `keep` centers, then normalize.
    let keep = top_k.min(distances.len());
    let bw2 = bandwidth * bandwidth;
    let mut mixture = Vec::with_capacity(keep);
    let mut total = 0.0;
    for &(idx, squared_distance) in &distances[..keep] {
        let weight = (-0.5 * squared_distance / bw2).exp().max(1e-300);
        total += weight;
        mixture.push((idx, weight));
    }
    if !(total.is_finite() && total > 0.0) {
        return Err("local empirical latent mixture has non-positive total weight".to_string());
    }
    for entry in &mut mixture {
        entry.1 /= total;
    }
    Ok(mixture)
}
/// Builds the local (geometry-conditioned) empirical latent-z measure: one
/// compressed empirical grid per Duchon center, plus a cached per-training-row
/// mixture over the `top_k` nearest centers.
///
/// Fix: the source contained mojibake — `¢ers` (HTML-entity corruption of
/// `&cent` + `ers`) where `&centers` was intended, which does not compile.
fn build_local_empirical_latent_measure(
    z: &Array1<f64>,
    weights: &Array1<f64>,
    data: ArrayView2<'_, f64>,
    geometry: &LatentScoreLawGeometry,
    grid_size: usize,
    top_k: usize,
) -> Result<LatentMeasureKind, String> {
    let conditioning = conditioning_matrix_from_geometry(data, geometry)?;
    if conditioning.nrows() != z.len() {
        return Err(format!(
            "local empirical latent measure row mismatch: conditioning={}, z={}",
            conditioning.nrows(),
            z.len()
        ));
    }
    let centers = geometry
        .centers
        .outer_iter()
        .map(|row| row.to_vec())
        .collect::<Vec<_>>();
    // Clamp top_k into 1..=centers.len().
    let top_k = top_k.min(centers.len()).max(1);
    // One grid per center: rows are down-weighted by a Gaussian kernel in the
    // (scaled) conditioning space before compression.
    let mut grids = Vec::with_capacity(centers.len());
    for center in &centers {
        let local_weights = Array1::from_iter((0..conditioning.nrows()).map(|row| {
            let d2 = conditioning
                .row(row)
                .iter()
                .zip(center.iter())
                .map(|(&x, &c)| {
                    let delta = x - c;
                    delta * delta
                })
                .sum::<f64>();
            weights[row] * (-0.5 * d2 / (geometry.bandwidth * geometry.bandwidth)).exp()
        }));
        grids.push(build_empirical_z_grid(
            z,
            &local_weights,
            grid_size,
            "local empirical latent measure",
        )?);
    }
    // Cache each training row's nearest-center mixture for reuse at fit time.
    let train_row_mixtures = Arc::new(
        conditioning
            .outer_iter()
            .map(|row| {
                local_empirical_mixture_for_point(
                    row.as_slice().ok_or_else(|| {
                        "local empirical conditioning row is not contiguous".to_string()
                    })?,
                    &centers,
                    top_k,
                    geometry.bandwidth,
                )
            })
            .collect::<Result<Vec<_>, String>>()?,
    );
    let measure = LatentMeasureKind::LocalEmpirical {
        feature_cols: geometry.feature_cols.clone(),
        input_scales: geometry.input_scales.clone(),
        centers,
        grids,
        top_k,
        bandwidth: geometry.bandwidth,
        train_row_mixtures,
    };
    measure.validate("local empirical latent measure")?;
    Ok(measure)
}
/// Internal family state for the Bernoulli marginal-slope fit: shared data
/// vectors, resolved designs, optional deviation runtimes, and caches reused
/// across evaluations.
#[derive(Clone)]
struct BernoulliMarginalSlopeFamily {
    // Shared (Arc) copies of the training data.
    y: Arc<Array1<f64>>,
    weights: Arc<Array1<f64>>,
    z: Arc<Array1<f64>>,
    // Resolved latent-z reference measure.
    latent_measure: LatentMeasureKind,
    // NOTE(review): semantics of the optional Gaussian frailty sd live in the
    // fitting code outside this chunk — confirm there.
    gaussian_frailty_sd: Option<f64>,
    base_link: InverseLink,
    marginal_design: DesignMatrix,
    logslope_design: DesignMatrix,
    // Optional structural deviation runtimes (score warp / link deviation).
    score_warp: Option<DeviationRuntime>,
    link_dev: Option<DeviationRuntime>,
    policy: crate::resource::ResourcePolicy,
    // Shared caches for cell-moment computations (see `exact_kernel`).
    cell_moment_lru: Arc<exact_kernel::CellMomentLruCache>,
    cell_moment_cache_stats: Arc<exact_kernel::CellMomentCacheStats>,
    // Optional per-row intercept warm starts shared across evaluations.
    intercept_warm_starts: Option<Arc<BernoulliInterceptWarmStartCache>>,
}
/// One cached warm start: an intercept value, the primary point it was
/// computed at, and the intercept's derivative with respect to that point
/// (used for first-order extrapolation in `predictor_seed`).
#[derive(Clone)]
struct BernoulliInterceptPredictorWarmStart {
    intercept: f64,
    primary_point: Vec<f64>,
    intercept_primary_deriv: Vec<f64>,
}
/// Per-row warm-start storage shared across threads: raw intercept bits in
/// lock-free `AtomicU64`s (NaN bits mean "unset", see
/// `new_intercept_warm_start_cache`), plus mutex-guarded predictor entries
/// for first-order seeding.
struct BernoulliInterceptWarmStartCache {
    intercepts: Vec<AtomicU64>,
    predictors: Vec<Mutex<Option<BernoulliInterceptPredictorWarmStart>>>,
}
// Deref exposes the raw per-row `AtomicU64` intercept slots so callers can
// index them directly; predictor slots are reached via the inherent methods.
impl std::ops::Deref for BernoulliInterceptWarmStartCache {
    type Target = [AtomicU64];
    fn deref(&self) -> &Self::Target {
        &self.intercepts
    }
}
impl BernoulliInterceptWarmStartCache {
    /// First-order warm-start seed for `row`'s intercept: the cached value
    /// plus the directional correction `sum a_u * (new - old)` from the
    /// cached derivative. Returns `None` when the row has no cache entry,
    /// the lock is poisoned, dimensions mismatch, or anything is non-finite.
    fn predictor_seed(&self, row: usize, current_point: &[f64]) -> Option<f64> {
        let slot = self.predictors.get(row)?;
        let guard = slot.lock().ok()?;
        let warm = guard.clone()?;
        drop(guard);
        let dims_match = warm.primary_point.len() == current_point.len()
            && warm.intercept_primary_deriv.len() == current_point.len();
        if !dims_match || !warm.intercept.is_finite() {
            return None;
        }
        let mut correction = 0.0;
        for ((&a_u, &new), &old) in warm
            .intercept_primary_deriv
            .iter()
            .zip(current_point)
            .zip(warm.primary_point.iter())
        {
            correction += a_u * (new - old);
        }
        let seed = warm.intercept + correction;
        seed.is_finite().then_some(seed)
    }

    /// Stores a warm-start entry for `row`, silently discarding entries
    /// containing any non-finite value or targeting an out-of-range row.
    fn store_predictor(
        &self,
        row: usize,
        intercept: f64,
        primary_point: Vec<f64>,
        intercept_primary_deriv: Vec<f64>,
    ) {
        let all_finite = intercept.is_finite()
            && primary_point.iter().all(|value| value.is_finite())
            && intercept_primary_deriv.iter().all(|value| value.is_finite());
        if !all_finite {
            return;
        }
        let Some(slot) = self.predictors.get(row) else {
            return;
        };
        if let Ok(mut guard) = slot.lock() {
            *guard = Some(BernoulliInterceptPredictorWarmStart {
                intercept,
                primary_point,
                intercept_primary_deriv,
            });
        }
    }
}
/// Allocates an empty warm-start cache for `n` rows: intercept slots start
/// as NaN bit patterns (meaning "no cached value") and predictor slots as
/// `None`.
fn new_intercept_warm_start_cache(n: usize) -> Arc<BernoulliInterceptWarmStartCache> {
    let nan_bits = f64::NAN.to_bits();
    let intercepts = (0..n).map(|_| AtomicU64::new(nan_bits)).collect();
    let predictors = (0..n).map(|_| Mutex::new(None)).collect();
    Arc::new(BernoulliInterceptWarmStartCache {
        intercepts,
        predictors,
    })
}
/// Optional coefficient seeds for each parameter block; `None` means no hint
/// for that block. NOTE(review): consumed by fitting code outside this
/// chunk — confirm exact usage there.
#[derive(Clone, Default)]
struct ThetaHints {
    marginal_beta: Option<Array1<f64>>,
    logslope_beta: Option<Array1<f64>>,
    score_warp_beta: Option<Array1<f64>>,
    link_dev_beta: Option<Array1<f64>>,
}
/// Builds a score-warp deviation block anchored to the standard normal,
/// using `seed` for both knot placement and the design evaluation points.
pub(crate) fn build_score_warp_deviation_block_from_seed(
    seed: &Array1<f64>,
    cfg: &DeviationBlockConfig,
) -> Result<DeviationPrepared, String> {
    let anchor = DeviationAnchorKind::StandardNormal;
    build_deviation_block_from_knots_and_design_seed_with_anchor(seed, seed, cfg, anchor)
}
/// Builds a score-warp deviation block anchored to the weighted empirical
/// design distribution, using `seed` for both knots and design points.
pub(crate) fn build_score_warp_deviation_block_from_seed_empirical_anchor(
    seed: &Array1<f64>,
    weights: &Array1<f64>,
    cfg: &DeviationBlockConfig,
) -> Result<DeviationPrepared, String> {
    let anchor = DeviationAnchorKind::EmpiricalDesign {
        anchor_weights: weights,
    };
    build_deviation_block_from_knots_and_design_seed_with_anchor(seed, seed, cfg, anchor)
}
// Probabilities are clamped into [eps, 1 - eps] before probit inversion so
// the standard-normal quantile stays finite.
const BERNOULLI_LINK_PROBABILITY_EPS: f64 = 1e-12;
/// Value and first four eta-derivatives of the marginal mean `mu(eta)` and of
/// the internal probit score `q = Phi^{-1}(mu)`; filled in by
/// `bernoulli_marginal_link_map`.
#[derive(Clone, Copy, Debug)]
pub(crate) struct BernoulliMarginalLinkMap {
    pub mu: f64,
    pub mu1: f64,
    pub mu2: f64,
    pub mu3: f64,
    pub mu4: f64,
    pub q: f64,
    pub q1: f64,
    pub q2: f64,
    pub q3: f64,
    pub q4: f64,
}
/// Clamps a probability into the open unit interval used by the probit link,
/// keeping it at least `BERNOULLI_LINK_PROBABILITY_EPS` away from 0 and 1.
#[inline]
fn clamp_bernoulli_link_probability(probability: f64) -> f64 {
    let lo = BERNOULLI_LINK_PROBABILITY_EPS;
    let hi = 1.0 - BERNOULLI_LINK_PROBABILITY_EPS;
    probability.clamp(lo, hi)
}
/// Maps a probability to its probit-scale linear predictor after clamping it
/// away from 0/1; rejects non-probit base links.
pub(crate) fn bernoulli_marginal_slope_eta_from_probability(
    base_link: &InverseLink,
    probability: f64,
    context: &str,
) -> Result<f64, String> {
    require_probit_marginal_slope_link(base_link, context)?;
    let target = clamp_bernoulli_link_probability(probability);
    match standard_normal_quantile(target) {
        Ok(eta) => Ok(eta),
        Err(e) => Err(format!(
            "{context} failed to invert probit probability {target}: {e}"
        )),
    }
}
/// Evaluates the marginal mean `mu = Phi(eta)` (clamped away from 0/1) and
/// the internal probit score `q = Phi^{-1}(mu)`, together with their first
/// four derivatives in `eta`.
///
/// # Errors
/// Fails for non-probit base links, when the probit inversion fails, or
/// when the normal density at `q` is not strictly positive.
pub(crate) fn bernoulli_marginal_link_map(
    base_link: &InverseLink,
    eta: f64,
) -> Result<BernoulliMarginalLinkMap, String> {
    require_probit_marginal_slope_link(base_link, "bernoulli marginal-slope")?;
    let raw_mu = normal_cdf(eta);
    let mu = clamp_bernoulli_link_probability(raw_mu);
    let q = standard_normal_quantile(mu).map_err(|e| {
        format!("bernoulli marginal-slope probit target inversion failed at mu={mu}: {e}")
    })?;
    // In the saturated tails the clamped map is flat: report zero
    // derivatives for both mu and q.
    if raw_mu <= BERNOULLI_LINK_PROBABILITY_EPS || raw_mu >= 1.0 - BERNOULLI_LINK_PROBABILITY_EPS {
        return Ok(BernoulliMarginalLinkMap {
            mu,
            mu1: 0.0,
            mu2: 0.0,
            mu3: 0.0,
            mu4: 0.0,
            q,
            q1: 0.0,
            q2: 0.0,
            q3: 0.0,
            q4: 0.0,
        });
    }
    let phi_eta = normal_pdf(eta);
    let phi_q = normal_pdf(q);
    if !phi_q.is_finite() || phi_q <= 0.0 {
        return Err(format!(
            "bernoulli marginal-slope internal probit density must be positive, got phi(q)={phi_q} at eta={eta}, q={q}"
        ));
    }
    // Derivatives of mu = Phi(eta): successive derivatives of the normal
    // pdf, phi'(x) = -x phi(x), phi''(x) = (x^2 - 1) phi(x), etc.
    let mu1 = phi_eta;
    let mu2 = -eta * phi_eta;
    let mu3 = (eta * eta - 1.0) * phi_eta;
    let mu4 = -(eta.powi(3) - 3.0 * eta) * phi_eta;
    // Derivatives of q = Phi^{-1}(mu) by repeated inverse-function
    // differentiation of q' = mu' / phi(q), again using phi'(x) = -x phi(x).
    let q1 = mu1 / phi_q;
    let q2 = mu2 / phi_q + q * q1 * q1;
    let q3 = mu3 / phi_q + 3.0 * q * q1 * q2 - (q * q - 1.0) * q1.powi(3);
    let q4 =
        mu4 / phi_q + (q.powi(3) - 3.0 * q) * q1.powi(4) + 4.0 * q * q1 * q3 + 3.0 * q * q2 * q2
            - 6.0 * (q * q - 1.0) * q1 * q1 * q2;
    Ok(BernoulliMarginalLinkMap {
        mu,
        mu1,
        mu2,
        mu3,
        mu4,
        q,
        q1,
        q2,
        q3,
        q4,
    })
}
/// Fails unless the base link is the standard probit link — the calibrated
/// de-nested probit kernel in this module supports no other choice.
fn require_probit_marginal_slope_link(
    base_link: &InverseLink,
    context: &str,
) -> Result<(), String> {
    match base_link {
        InverseLink::Standard(LinkFunction::Probit) => Ok(()),
        _ => Err(format!(
            "{context} requires link(type=probit); non-probit marginal-slope base links are not supported by the calibrated de-nested probit kernel"
        )),
    }
}
/// Builds a link-deviation parameter block anchored on the weighted
/// empirical design (`anchor_weights` over `design_seed`) rather than the
/// standard-normal reference measure.
pub(crate) fn build_link_deviation_block_from_knots_design_seed_and_weights(
    knot_seed: &Array1<f64>,
    design_seed: &Array1<f64>,
    anchor_weights: &Array1<f64>,
    cfg: &DeviationBlockConfig,
) -> Result<DeviationPrepared, String> {
    let anchor = DeviationAnchorKind::EmpiricalDesign { anchor_weights };
    build_deviation_block_from_knots_and_design_seed_with_anchor(knot_seed, design_seed, cfg, anchor)
}
/// Reference (anchor) measure used when constructing a deviation block:
/// either the standard normal or a weighted empirical design.
enum DeviationAnchorKind<'a> {
    /// Anchor against the standard-normal distribution.
    StandardNormal,
    /// Anchor against the observed design seed with per-point weights.
    EmpiricalDesign { anchor_weights: &'a Array1<f64> },
}
/// Constructs a structural-deviation parameter block: a cubic monotone
/// spline runtime anchored per `anchor`, its design evaluated at
/// `design_seed`, and one integrated-derivative penalty per resolved order
/// (plus an optional order-0 "double penalty").
///
/// Errors if `cfg.degree != 3`, if no positive penalty order is configured,
/// or if runtime/design/penalty construction fails.
fn build_deviation_block_from_knots_and_design_seed_with_anchor(
    knot_seed: &Array1<f64>,
    design_seed: &Array1<f64>,
    cfg: &DeviationBlockConfig,
    anchor: DeviationAnchorKind<'_>,
) -> Result<DeviationPrepared, String> {
    // The runtime below is hard-wired for cubic splines.
    if cfg.degree != 3 {
        return Err(format!(
            "structural deviation runtime is cubic; degree must be 3, got {}",
            cfg.degree
        ));
    }
    let penalty_orders = resolve_deviation_operator_orders(cfg)?;
    let knots = initialize_monotone_wiggle_knots_from_seed(
        knot_seed.view(),
        cfg.degree,
        cfg.num_internal_knots,
    )?;
    let runtime = match anchor {
        DeviationAnchorKind::StandardNormal => {
            DeviationRuntime::try_new_standard_normal_anchor(knots, cfg.monotonicity_eps)?
        }
        DeviationAnchorKind::EmpiricalDesign { anchor_weights } => {
            DeviationRuntime::try_new_weighted_empirical_anchor(
                knots,
                cfg.monotonicity_eps,
                design_seed,
                anchor_weights,
            )?
        }
    };
    let design = runtime.design(design_seed)?;
    let p = design.ncols();
    if p == 0 {
        return Err("structural deviation basis has no free derivative controls".to_string());
    }
    let mut block = ParameterBlockInput {
        design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(design)),
        offset: Array1::zeros(design_seed.len()),
        penalties: Vec::new(),
        nullspace_dims: Vec::new(),
        initial_log_lambdas: None,
        initial_beta: Some(Array1::zeros(p)),
    };
    for order in penalty_orders {
        append_deviation_function_penalty(&mut block, &runtime, order)?;
    }
    // Optional ridge-like order-0 penalty that shrinks the whole deviation
    // toward zero on top of the derivative penalties.
    if cfg.double_penalty {
        append_deviation_function_penalty(&mut block, &runtime, 0)?;
    }
    Ok(DeviationPrepared { block, runtime })
}
/// Resolves the positive derivative orders for the deviation function
/// penalties, de-duplicated in first-seen order.
///
/// Uses `cfg.penalty_orders` when non-empty, otherwise the scalar
/// `cfg.penalty_order`. Zero entries are skipped; an order above the basis
/// degree or an empty result is an error.
fn resolve_deviation_operator_orders(cfg: &DeviationBlockConfig) -> Result<Vec<usize>, String> {
    let requested: &[usize] = if cfg.penalty_orders.is_empty() {
        std::slice::from_ref(&cfg.penalty_order)
    } else {
        &cfg.penalty_orders
    };
    let mut orders: Vec<usize> = Vec::new();
    for &order in requested.iter().filter(|&&o| o != 0) {
        if order > cfg.degree {
            return Err(format!(
                "deviation function penalty derivative order {order} exceeds basis degree {}",
                cfg.degree
            ));
        }
        if !orders.contains(&order) {
            orders.push(order);
        }
    }
    if orders.is_empty() {
        return Err(
            "deviation block requires at least one positive function-penalty derivative order"
                .to_string(),
        );
    }
    Ok(orders)
}
/// Appends one integrated-derivative penalty of the given order to `block`,
/// recording the penalty's nullspace dimension alongside it.
fn append_deviation_function_penalty(
    block: &mut ParameterBlockInput,
    runtime: &DeviationRuntime,
    derivative_order: usize,
) -> Result<(), String> {
    let (penalty, nullity) =
        runtime.integrated_derivative_penalty_with_nullity(derivative_order)?;
    block
        .penalties
        .push(crate::solver::estimate::PenaltySpec::Dense(penalty));
    block.nullspace_dims.push(nullity);
    Ok(())
}
/// Projects a proposed coefficient update back into the monotone-feasible
/// region along the segment from `current` (which must already be feasible)
/// to `proposed`.
///
/// If `proposed` is itself feasible it is returned unchanged; otherwise the
/// largest feasible step fraction on [0, 1] is found by bisection (64
/// halvings) and `current + t * (proposed - current)` is returned.
///
/// Errors on length mismatches, non-finite coefficients, or if `current`
/// itself violates monotonicity.
pub(crate) fn project_monotone_feasible_beta(
    runtime: &DeviationRuntime,
    current: &Array1<f64>,
    proposed: &Array1<f64>,
    label: &str,
) -> Result<Array1<f64>, String> {
    if current.len() != runtime.basis_dim() {
        return Err(format!(
            "{label} monotone projection current length mismatch: current={}, expected={}",
            current.len(),
            runtime.basis_dim()
        ));
    }
    if proposed.len() != runtime.basis_dim() {
        return Err(format!(
            "{label} monotone projection length mismatch: proposed={}, expected={}",
            proposed.len(),
            runtime.basis_dim()
        ));
    }
    for (idx, value) in current.iter().enumerate() {
        if !value.is_finite() {
            return Err(format!("{label} current coefficient {idx} is non-finite"));
        }
    }
    for (idx, value) in proposed.iter().enumerate() {
        if !value.is_finite() {
            return Err(format!("{label} coefficient {idx} is non-finite"));
        }
    }
    // The bisection below relies on the starting point being feasible.
    runtime.monotonicity_feasible(current, &format!("{label} current beta"))?;
    if runtime
        .monotonicity_feasible(proposed, &format!("{label} proposed beta"))
        .is_ok()
    {
        return Ok(proposed.clone());
    }
    let direction = proposed - current;
    // Invariant: `lo` is always a feasible step fraction (lo = 0 is
    // `current`), `hi` an infeasible one; bisect to tighten the bracket.
    let mut lo = 0.0_f64;
    let mut hi = 1.0_f64;
    for _ in 0..64 {
        let mid = 0.5 * (lo + hi);
        let candidate = current + &direction.mapv(|value| value * mid);
        if runtime
            .monotonicity_feasible(&candidate, &format!("{label} projected beta"))
            .is_ok()
        {
            lo = mid;
        } else {
            hi = mid;
        }
    }
    // Return the last known-feasible point of the bracket.
    Ok(current + &direction.mapv(|value| value * lo))
}
/// Validates a Bernoulli marginal-slope term spec against the data matrix:
/// row-count agreement across all per-observation vectors, binary y,
/// finite non-negative weights, finite z/offsets, a probit base link, and a
/// frailty spec admissible for the marginal-slope family.
fn validate_spec(
    data: ArrayView2<'_, f64>,
    spec: &BernoulliMarginalSlopeTermSpec,
) -> Result<(), String> {
    let n = data.nrows();
    if spec.y.len() != n
        || spec.weights.len() != n
        || spec.z.len() != n
        || spec.marginal_offset.len() != n
        || spec.logslope_offset.len() != n
    {
        return Err(format!(
            "bernoulli-marginal-slope row mismatch: data={}, y={}, weights={}, z={}, marginal_offset={}, logslope_offset={}",
            n,
            spec.y.len(),
            spec.weights.len(),
            spec.z.len(),
            spec.marginal_offset.len(),
            spec.logslope_offset.len()
        ));
    }
    // y must be (numerically) 0 or 1; a small tolerance absorbs storage noise.
    if spec
        .y
        .iter()
        .any(|&yi| !yi.is_finite() || ((yi - 0.0).abs() > 1e-9 && (yi - 1.0).abs() > 1e-9))
    {
        return Err("bernoulli-marginal-slope requires binary y in {0,1}".to_string());
    }
    if spec.weights.iter().any(|&w| !w.is_finite() || w < 0.0) {
        return Err("bernoulli-marginal-slope requires finite non-negative weights".to_string());
    }
    if spec.z.iter().any(|&zi| !zi.is_finite()) {
        return Err("bernoulli-marginal-slope requires finite z values".to_string());
    }
    if spec.marginal_offset.iter().any(|&value| !value.is_finite()) {
        return Err("bernoulli-marginal-slope requires finite marginal offsets".to_string());
    }
    if spec.logslope_offset.iter().any(|&value| !value.is_finite()) {
        return Err("bernoulli-marginal-slope requires finite logslope offsets".to_string());
    }
    require_probit_marginal_slope_link(&spec.base_link, "bernoulli-marginal-slope")?;
    spec.frailty.validate_for_marginal_slope()?;
    match &spec.frailty {
        FrailtySpec::None => {}
        FrailtySpec::GaussianShift { sigma_fixed } => {
            if let Some(sigma) = sigma_fixed
                && (!sigma.is_finite() || *sigma < 0.0)
            {
                return Err(format!(
                    "bernoulli-marginal-slope requires GaussianShift sigma >= 0, got {sigma}"
                ));
            }
        }
        // validate_for_marginal_slope above already rejects this variant.
        FrailtySpec::HazardMultiplier { .. } => unreachable!(),
    }
    Ok(())
}
/// Standardizes the latent score vector `z` for identification, under a
/// configurable policy, returning the standardized scores and the
/// normalization that was applied.
///
/// The weighted mean/sd of `z` are computed; the applied normalization is
/// chosen by `policy.normalization` (identity, fit-weighted, or frozen
/// values). Two diagnostic gates then run per `policy.check_mode`
/// (error / warn / ignore):
/// 1. mean/sd closeness to N(0,1), with tolerances scaled by the effective
///    sample size (Kish effective n = (sum w)^2 / sum w^2);
/// 2. weighted skewness and excess kurtosis of the standardized scores
///    against the policy's maxima.
///
/// Errors on length mismatch, non-positive total weight, effective n <= 1,
/// or degenerate sd.
pub(crate) fn standardize_latent_z_with_policy(
    z: &Array1<f64>,
    weights: &Array1<f64>,
    context: &str,
    policy: &LatentZPolicy,
) -> Result<(Array1<f64>, LatentZNormalization), String> {
    if z.len() != weights.len() {
        return Err(format!(
            "{context} latent-score normalization length mismatch: z={}, weights={}",
            z.len(),
            weights.len()
        ));
    }
    let weight_sum = weights.iter().copied().sum::<f64>();
    let weight_sq_sum = weights.iter().map(|&w| w * w).sum::<f64>();
    if !(weight_sum.is_finite()
        && weight_sum > 0.0
        && weight_sq_sum.is_finite()
        && weight_sq_sum > 0.0)
    {
        return Err(format!("{context} requires positive finite total weight"));
    }
    // Kish effective sample size for weighted data.
    let effective_n = weight_sum * weight_sum / weight_sq_sum;
    if !(effective_n.is_finite() && effective_n > 1.0) {
        return Err(format!(
            "{context} requires at least two effective observations for latent-score normalization"
        ));
    }
    let mean = z
        .iter()
        .zip(weights.iter())
        .map(|(&zi, &wi)| wi * zi)
        .sum::<f64>()
        / weight_sum;
    let var = z
        .iter()
        .zip(weights.iter())
        .map(|(&zi, &wi)| wi * (zi - mean) * (zi - mean))
        .sum::<f64>()
        / weight_sum;
    let sd = var.sqrt();
    if !(sd.is_finite() && sd > 1e-12) {
        return Err(format!(
            "{context} requires z with positive finite weighted standard deviation"
        ));
    }
    let target_norm = match policy.normalization {
        LatentZNormalizationMode::None => LatentZNormalization { mean: 0.0, sd: 1.0 },
        LatentZNormalizationMode::FitWeighted => LatentZNormalization { mean, sd },
        LatentZNormalizationMode::Frozen {
            mean: frozen_mean,
            sd: frozen_sd,
        } => LatentZNormalization {
            mean: frozen_mean,
            sd: frozen_sd,
        },
    };
    // Tolerances shrink with effective n: standard errors of the sample
    // mean and sd under an N(0,1) reference, scaled by policy multipliers.
    let mean_tol = policy.mean_tol_multiplier / effective_n.sqrt();
    let sd_tol = policy.sd_tol_multiplier / (2.0 * (effective_n - 1.0).max(1.0)).sqrt();
    let check_msg = || {
        format!(
            "{context} requires z to already be approximately latent N(0,1) before identification normalization; got mean={mean:.6e}, sd={sd:.6e}, effective_n={effective_n:.1}, allowed_mean={mean_tol:.3e}, allowed_sd={sd_tol:.3e}"
        )
    };
    if mean.abs() > mean_tol || (sd - 1.0).abs() > sd_tol {
        match policy.check_mode {
            LatentZCheckMode::Strict => return Err(check_msg()),
            LatentZCheckMode::WarnOnly => log::warn!("{}", check_msg()),
            LatentZCheckMode::Off => {}
        }
    }
    let normalization = target_norm;
    let z_std = normalization.apply(z, context)?;
    // Weighted third/fourth standardized moments of the normalized scores.
    let skew = z_std
        .iter()
        .zip(weights.iter())
        .map(|(&zi, &wi)| wi * zi.powi(3))
        .sum::<f64>()
        / weight_sum;
    let kurt = z_std
        .iter()
        .zip(weights.iter())
        .map(|(&zi, &wi)| wi * zi.powi(4))
        .sum::<f64>()
        / weight_sum
        - 3.0;
    if skew.abs() > policy.max_abs_skew || kurt.abs() > policy.max_abs_excess_kurtosis {
        let msg = format!(
            "{context} requires z to be approximately Gaussian after identification normalization; got skewness={skew:.3}, excess_kurtosis={kurt:.3}"
        );
        match policy.check_mode {
            LatentZCheckMode::Strict => return Err(msg),
            LatentZCheckMode::WarnOnly => log::warn!("{}", msg),
            LatentZCheckMode::Off => {}
        }
    }
    // Advisory-only warning at fixed thresholds, independent of the policy.
    if skew.abs() > 0.75 || kurt.abs() > 2.0 {
        log::warn!(
            "{context}: z has skewness={skew:.3} and excess kurtosis={kurt:.3}; latent-measure auto-selection will use empirical calibration unless stricter diagnostics pass"
        );
    }
    Ok((z_std, normalization))
}
/// Returns `seed` with two extra points appended: one below the minimum and
/// one above the maximum, each offset by `pad_fraction` times the
/// interquartile range (floored at `min_iqr`).
///
/// Seeds with fewer than four points are returned unchanged, since the
/// index quartiles are not meaningful there.
pub fn padded_deviation_seed(seed: &Array1<f64>, min_iqr: f64, pad_fraction: f64) -> Array1<f64> {
    let mut sorted = seed.to_vec();
    // total_cmp provides the total order sort_by requires even if a NaN
    // slips in; partial_cmp with an Equal fallback does not.
    sorted.sort_by(f64::total_cmp);
    if sorted.len() < 4 {
        return seed.clone();
    }
    let n = sorted.len();
    // Index-based quartiles; interpolation is unnecessary for padding.
    let q1 = sorted[n / 4];
    let q3 = sorted[3 * n / 4];
    let iqr = (q3 - q1).max(min_iqr);
    let pad = pad_fraction * iqr;
    let mut out = seed.to_vec();
    out.reserve_exact(2);
    out.push(sorted[0] - pad);
    out.push(sorted[n - 1] + pad);
    Array1::from_vec(out)
}
/// Fits a pooled two-parameter probit pilot P(y=1) = Phi(b0 + b1*z) by
/// damped Newton iteration on the weighted negative log-likelihood, and
/// returns `(b0 / sqrt(1 + b1^2), b1)`.
///
/// Initialization: b0 from the probit of weighted prevalence, b1 from the
/// weighted covariance/variance ratio. Up to 50 Newton steps, each with a
/// ridge-escalated 2x2 solve and a backtracking line search (25 halvings).
/// Near convergence, b1 is floored away from zero at magnitude 1e-6 with
/// its sign preserved.
///
/// Errors on length mismatch, non-positive total weight, non-finite
/// objective/gradient, Hessian-solve failure, or a failed line search that
/// is not at numerical convergence.
fn pooled_probit_baseline(
    y: &Array1<f64>,
    z: &Array1<f64>,
    weights: &Array1<f64>,
) -> Result<(f64, f64), String> {
    if y.len() != z.len() || y.len() != weights.len() {
        return Err(format!(
            "pooled bernoulli-marginal-slope pilot length mismatch: y={}, z={}, weights={}",
            y.len(),
            z.len(),
            weights.len()
        ));
    }
    let weight_sum = weights.iter().copied().sum::<f64>();
    if !weight_sum.is_finite() || weight_sum <= 0.0 {
        return Err(
            "pooled bernoulli-marginal-slope pilot requires positive finite total weight"
                .to_string(),
        );
    }
    let prevalence = y
        .iter()
        .zip(weights.iter())
        .map(|(&yi, &wi)| yi * wi)
        .sum::<f64>()
        / weight_sum;
    let prevalence = prevalence.clamp(1e-6, 1.0 - 1e-6);
    let z_mean = z
        .iter()
        .zip(weights.iter())
        .map(|(&zi, &wi)| zi * wi)
        .sum::<f64>()
        / weight_sum;
    let z_var = z
        .iter()
        .zip(weights.iter())
        .map(|(&zi, &wi)| wi * (zi - z_mean) * (zi - z_mean))
        .sum::<f64>()
        / weight_sum;
    let yz_cov = y
        .iter()
        .zip(z.iter())
        .zip(weights.iter())
        .map(|((&yi, &zi), &wi)| wi * (yi - prevalence) * (zi - z_mean))
        .sum::<f64>()
        / weight_sum;
    let mut beta0 = standard_normal_quantile(prevalence).map_err(|e| {
        format!("failed to initialize pooled bernoulli-marginal-slope pilot intercept: {e}")
    })?;
    let mut beta1 = if z_var > 1e-12 { yz_cov / z_var } else { 0.0 };
    // Weighted probit negative log-likelihood with its gradient and 2x2
    // Hessian, written via the signed margin m = s*eta, s = 2y-1, and the
    // Mills ratio lambda = phi(m)/Phi(m).
    let objective_grad_hess =
        |intercept: f64, slope: f64| -> Result<(f64, f64, f64, f64, f64, f64), String> {
            let mut obj = 0.0;
            let mut g0 = 0.0;
            let mut g1 = 0.0;
            let mut h00 = 0.0;
            let mut h01 = 0.0;
            let mut h11 = 0.0;
            for ((&yi, &zi), &wi) in y.iter().zip(z.iter()).zip(weights.iter()) {
                if wi == 0.0 {
                    continue;
                }
                let eta = intercept + slope * zi;
                let s = 2.0 * yi - 1.0;
                let margin = s * eta;
                let (logcdf, lambda) = signed_probit_logcdf_and_mills_ratio(margin);
                let g_eta = -wi * s * lambda;
                let h_eta = wi * lambda * (margin + lambda);
                obj -= wi * logcdf;
                g0 += g_eta;
                g1 += g_eta * zi;
                h00 += h_eta;
                h01 += h_eta * zi;
                h11 += h_eta * zi * zi;
            }
            Ok((obj, g0, g1, h00, h01, h11))
        };
    let mut obj_prev = f64::INFINITY;
    for _ in 0..50 {
        let (obj, g0, g1, h00, h01, h11) = objective_grad_hess(beta0, beta1)?;
        if !obj.is_finite() || !g0.is_finite() || !g1.is_finite() {
            return Err(
                "pooled bernoulli-marginal-slope pilot produced non-finite objective or gradient"
                    .to_string(),
            );
        }
        let grad_max = g0.abs().max(g1.abs());
        if grad_max < 1e-8 {
            break;
        }
        // Solve the 2x2 Newton system by Cramer's rule, escalating a
        // diagonal ridge tenfold until the solve is well-conditioned.
        let mut ridge = 1e-8;
        let (step0, step1) = loop {
            let h00_r = h00 + ridge;
            let h11_r = h11 + ridge;
            let det = h00_r * h11_r - h01 * h01;
            if det.is_finite() && det.abs() > 1e-18 {
                let s0 = (h11_r * g0 - h01 * g1) / det;
                let s1 = (-h01 * g0 + h00_r * g1) / det;
                if s0.is_finite() && s1.is_finite() {
                    break (s0, s1);
                }
            }
            ridge *= 10.0;
            if ridge > 1e6 {
                return Err(
                    "pooled bernoulli-marginal-slope pilot Hessian solve failed".to_string()
                );
            }
        };
        // Backtracking line search: accept any non-increasing objective.
        let mut accepted = false;
        let mut step_scale = 1.0;
        for _ in 0..25 {
            let cand0 = beta0 - step_scale * step0;
            let cand1 = beta1 - step_scale * step1;
            let (cand_obj, _, _, _, _, _) = objective_grad_hess(cand0, cand1)?;
            if cand_obj.is_finite() && cand_obj <= obj {
                beta0 = cand0;
                beta1 = cand1;
                obj_prev = cand_obj;
                accepted = true;
                break;
            }
            step_scale *= 0.5;
        }
        if !accepted {
            // A stalled search at an essentially unchanged objective counts
            // as convergence; anything else is a genuine failure.
            if (obj_prev - obj).abs() < 1e-10 {
                break;
            }
            return Err("pooled bernoulli-marginal-slope pilot line search failed".to_string());
        }
    }
    let a = beta0;
    // Keep the slope bounded away from zero so downstream ratios stay finite.
    let b = if beta1.abs() < 1e-6 {
        if beta1.is_sign_negative() {
            -1e-6
        } else {
            1e-6
        }
    } else {
        beta1
    };
    // Report the intercept on the marginal (rigid) scale: a / sqrt(1 + b^2).
    Ok((a / (1.0 + b * b).sqrt(), b))
}
/// Assembles the exact-joint hyperparameter setup for the two predictors:
/// the stacked rho (log-lambda) vector with box bounds, and the
/// concatenated spatial log-kappa coordinates (marginal block first, then
/// logslope) with data-derived lower/upper bounds.
///
/// rho layout: [marginal penalties | logslope penalties | extra_rho0], with
/// only the extra entries seeded to non-zero values and all entries bounded
/// to [-12, 12]. Initial log-kappa values are reseeded from `data` and
/// clamped into their bounds before the setup is returned.
fn joint_setup(
    data: ArrayView2<'_, f64>,
    marginalspec: &TermCollectionSpec,
    logslopespec: &TermCollectionSpec,
    marginal_penalties: usize,
    logslope_penalties: usize,
    extra_rho0: &[f64],
    kappa_options: &SpatialLengthScaleOptimizationOptions,
) -> ExactJointHyperSetup {
    let marginal_terms = spatial_length_scale_term_indices(marginalspec);
    let logslope_terms = spatial_length_scale_term_indices(logslopespec);
    let rho_dim = marginal_penalties + logslope_penalties + extra_rho0.len();
    let mut rho0vec = Array1::<f64>::zeros(rho_dim);
    // Seed only the extra tail entries; penalty entries start at rho = 0.
    for (idx, &value) in extra_rho0.iter().enumerate() {
        rho0vec[marginal_penalties + logslope_penalties + idx] = value;
    }
    let rho_lower = Array1::<f64>::from_elem(rho_dim, -12.0);
    let rho_upper = Array1::<f64>::from_elem(rho_dim, 12.0);
    let marginal_kappa = SpatialLogKappaCoords::from_length_scales_aniso(
        marginalspec,
        &marginal_terms,
        kappa_options,
    )
    .reseed_from_data(data, marginalspec, &marginal_terms, kappa_options);
    let logslope_kappa = SpatialLogKappaCoords::from_length_scales_aniso(
        logslopespec,
        &logslope_terms,
        kappa_options,
    )
    .reseed_from_data(data, logslopespec, &logslope_terms, kappa_options);
    // Concatenate the two coordinate vectors and their per-term dimensions
    // (marginal first, logslope second) into a single joint coordinate set.
    let mut values = marginal_kappa.as_array().to_vec();
    values.extend(logslope_kappa.as_array().iter());
    let marginal_dims = marginal_kappa.dims_per_term().to_vec();
    let logslope_dims = logslope_kappa.dims_per_term().to_vec();
    let mut dims = marginal_dims.clone();
    dims.extend(logslope_dims.iter().copied());
    let log_kappa0 = SpatialLogKappaCoords::new_with_dims(Array1::from_vec(values), dims.clone());
    let marginal_lower = SpatialLogKappaCoords::lower_bounds_aniso_from_data(
        data,
        marginalspec,
        &marginal_terms,
        &marginal_dims,
        kappa_options,
    );
    let logslope_lower = SpatialLogKappaCoords::lower_bounds_aniso_from_data(
        data,
        logslopespec,
        &logslope_terms,
        &logslope_dims,
        kappa_options,
    );
    let mut lower_vals = marginal_lower.as_array().to_vec();
    lower_vals.extend(logslope_lower.as_array().iter());
    let log_kappa_lower =
        SpatialLogKappaCoords::new_with_dims(Array1::from_vec(lower_vals), dims.clone());
    let marginal_upper = SpatialLogKappaCoords::upper_bounds_aniso_from_data(
        data,
        marginalspec,
        &marginal_terms,
        &marginal_dims,
        kappa_options,
    );
    let logslope_upper = SpatialLogKappaCoords::upper_bounds_aniso_from_data(
        data,
        logslopespec,
        &logslope_terms,
        &logslope_dims,
        kappa_options,
    );
    let mut upper_vals = marginal_upper.as_array().to_vec();
    upper_vals.extend(logslope_upper.as_array().iter());
    let log_kappa_upper = SpatialLogKappaCoords::new_with_dims(Array1::from_vec(upper_vals), dims);
    // Ensure the starting point respects the data-derived box constraints.
    let log_kappa0 = log_kappa0.clamp_to_bounds(&log_kappa_lower, &log_kappa_upper);
    ExactJointHyperSetup::new(
        rho0vec,
        rho_lower,
        rho_upper,
        log_kappa0,
        log_kappa_lower,
        log_kappa_upper,
    )
}
/// First four derivatives of the weighted negative log-probit likelihood
/// m -> -w * log Phi(m) at the signed margin, expressed through the Mills
/// ratio lambda = phi(m)/Phi(m).
///
/// Limit handling: zero weight or m = +inf contribute nothing; m = -inf
/// returns the (-inf, w, 0, 0) limit; NaN propagates as NaN in all slots.
#[inline]
fn signed_probit_neglog_derivatives_up_to_fourth_numeric(
    signed_margin: f64,
    weight: f64,
) -> (f64, f64, f64, f64) {
    if weight == 0.0 || signed_margin == f64::INFINITY {
        return (0.0, 0.0, 0.0, 0.0);
    }
    if signed_margin == f64::NEG_INFINITY {
        return (f64::NEG_INFINITY, weight, 0.0, 0.0);
    }
    if signed_margin.is_nan() {
        return (f64::NAN, f64::NAN, f64::NAN, f64::NAN);
    }
    let (_, lambda) = signed_probit_logcdf_and_mills_ratio(signed_margin);
    // Polynomials in (m, lambda) from repeated differentiation of
    // lambda'(m) = -lambda * (m + lambda).
    let k1 = -lambda;
    let k2 = lambda * (signed_margin + lambda);
    let k3 = lambda
        * (1.0
            - signed_margin * signed_margin
            - 3.0 * signed_margin * lambda
            - 2.0 * lambda * lambda);
    let k4 = lambda
        * ((signed_margin.powi(3) - 3.0 * signed_margin)
            + (7.0 * signed_margin * signed_margin - 4.0) * lambda
            + 12.0 * signed_margin * lambda * lambda
            + 6.0 * lambda.powi(3));
    (weight * k1, weight * k2, weight * k3, weight * k4)
}
/// Checked wrapper around the numeric probit derivative helper: zero weight
/// or a `+inf` margin yields all-zero derivatives, while any other
/// non-finite margin is rejected with an error instead of propagating
/// non-finite values into callers.
pub(crate) fn signed_probit_neglog_derivatives_up_to_fourth(
    signed_margin: f64,
    weight: f64,
) -> Result<(f64, f64, f64, f64), String> {
    if weight == 0.0 || signed_margin == f64::INFINITY {
        return Ok((0.0, 0.0, 0.0, 0.0));
    }
    if signed_margin.is_finite() {
        Ok(signed_probit_neglog_derivatives_up_to_fourth_numeric(
            signed_margin,
            weight,
        ))
    } else {
        Err(format!(
            "non-finite signed margin in exact probit derivative helper: {signed_margin}"
        ))
    }
}
/// Observed-scale log-slope under the rigid map: the latent log-slope
/// scaled by the probit scale factor.
#[inline]
fn rigid_observed_logslope(logslope: f64, probit_scale: f64) -> f64 {
    logslope * probit_scale
}
/// Rigid observed scale c = sqrt(1 + g_obs^2), where g_obs is the observed
/// log-slope for this (`logslope`, `probit_scale`) pair.
#[inline]
fn rigid_observed_scale(logslope: f64, probit_scale: f64) -> f64 {
    let g_obs = rigid_observed_logslope(logslope, probit_scale);
    (1.0 + g_obs * g_obs).sqrt()
}
/// Rigid intercept: the marginal linear predictor inflated by the observed
/// scale c(logslope).
#[inline]
fn rigid_intercept_from_marginal(marginal_eta: f64, logslope: f64, probit_scale: f64) -> f64 {
    let scale = rigid_observed_scale(logslope, probit_scale);
    marginal_eta * scale
}
/// Rigid intercept expressed on the pre-scale (latent) axis: the observed
/// intercept divided back by the probit scale factor.
#[inline]
fn rigid_prescale_intercept_from_marginal(
    marginal_eta: f64,
    logslope: f64,
    probit_scale: f64,
) -> f64 {
    let intercept = rigid_intercept_from_marginal(marginal_eta, logslope, probit_scale);
    intercept / probit_scale
}
/// Magnitude of the pre-scale intercept's sensitivity at `marginal_eta`:
/// probit_scale * phi(marginal_eta) / c(logslope).
#[inline]
fn rigid_prescale_intercept_derivative_abs(
    marginal_eta: f64,
    logslope: f64,
    probit_scale: f64,
) -> f64 {
    let scale = rigid_observed_scale(logslope, probit_scale);
    probit_scale * normal_pdf(marginal_eta) / scale
}
/// Full observed linear predictor under the rigid map:
/// intercept(marginal_eta, logslope) + g_obs * z.
#[inline]
fn rigid_observed_eta(marginal_eta: f64, logslope: f64, z: f64, probit_scale: f64) -> f64 {
    let intercept = rigid_intercept_from_marginal(marginal_eta, logslope, probit_scale);
    let slope = rigid_observed_logslope(logslope, probit_scale);
    intercept + slope * z
}
/// Value and first four derivatives of the standard-normal CDF at `x`:
/// Phi' = phi, and higher orders follow phi'(x) = -x phi(x).
fn unary_derivatives_normal_cdf(x: f64) -> [f64; 5] {
    let pdf = normal_pdf(x);
    [
        normal_cdf(x),
        pdf,
        -x * pdf,
        (x * x - 1.0) * pdf,
        (-x.powi(3) + 3.0 * x) * pdf,
    ]
}
/// Value and first four derivatives of the standard-normal density at `x`,
/// i.e. (-1)^n He_n(x) phi(x) via phi'(x) = -x phi(x).
fn unary_derivatives_normal_pdf(x: f64) -> [f64; 5] {
    let pdf = normal_pdf(x);
    [
        pdf,
        -x * pdf,
        (x * x - 1.0) * pdf,
        (-x.powi(3) + 3.0 * x) * pdf,
        (x.powi(4) - 6.0 * x * x + 3.0) * pdf,
    ]
}
/// Value and first four derivatives of x -> 1/x, with the input floored at
/// 1e-300 to keep the divisions finite.
fn unary_derivatives_reciprocal(x: f64) -> [f64; 5] {
    let base = x.max(1e-300);
    // Successive powers of the clamped input for the derivative divisors.
    let p2 = base * base;
    let p3 = p2 * base;
    let p4 = p3 * base;
    let p5 = p4 * base;
    [1.0 / base, -1.0 / p2, 2.0 / p3, -6.0 / p4, 24.0 / p5]
}
/// Evaluates the empirical calibration root function in the intercept `a`:
/// f(a) = sum_k w_k * Phi(a + g_obs * node_k) - target_mu, together with its
/// first (`f_a`) and second (`f_aa`) derivatives in `a`.
///
/// Errors if any of the three values is non-finite or if f_a is not
/// strictly positive (the root solve requires a monotone f).
fn empirical_rigid_calibration_eval(
    intercept: f64,
    target_mu: f64,
    slope: f64,
    probit_scale: f64,
    nodes: &[f64],
    weights: &[f64],
) -> Result<(f64, f64, f64), String> {
    let observed_slope = rigid_observed_logslope(slope, probit_scale);
    let mut f = -target_mu;
    let mut f_a = 0.0;
    let mut f_aa = 0.0;
    for (&node, &weight) in nodes.iter().zip(weights.iter()) {
        let eta = intercept + observed_slope * node;
        let pdf = normal_pdf(eta);
        f += weight * normal_cdf(eta);
        f_a += weight * pdf;
        // d/da [phi(eta)] = -eta * phi(eta).
        f_aa -= weight * eta * pdf;
    }
    if !(f.is_finite() && f_a.is_finite() && f_aa.is_finite() && f_a > 0.0) {
        return Err(format!(
            "empirical latent calibration produced invalid root state: f={f}, f_a={f_a}, f_aa={f_aa}"
        ));
    }
    Ok((f, f_a, f_aa))
}
/// Solves for the intercept `a` such that the weighted empirical mixture
/// sum_k w_k Phi(a + g_obs * node_k) equals `target_mu`, using a monotone
/// root solver seeded from `initial` or the rigid-map intercept at
/// `target_q`.
///
/// Errors if `target_mu` is outside (0,1) or the residual at the returned
/// root exceeds the absolute tolerance.
pub(crate) fn empirical_intercept_from_marginal(
    target_mu: f64,
    target_q: f64,
    slope: f64,
    probit_scale: f64,
    nodes: &[f64],
    weights: &[f64],
    initial: Option<f64>,
) -> Result<f64, String> {
    if !(target_mu.is_finite() && target_mu > 0.0 && target_mu < 1.0) {
        return Err(format!(
            "empirical latent calibration requires target mu in (0,1), got {target_mu}"
        ));
    }
    // Fall back to the rigid (closed-form) intercept as the starting guess.
    let seed =
        initial.unwrap_or_else(|| rigid_intercept_from_marginal(target_q, slope, probit_scale));
    let eval = |a: f64| {
        empirical_rigid_calibration_eval(a, target_mu, slope, probit_scale, nodes, weights)
    };
    // Tolerance scales with the target but never drops below 1e-13.
    let abs_tol = 1e-13_f64.max(4.0 * f64::EPSILON * target_mu.abs());
    let (root, _, f_best) = super::monotone_root::solve_monotone_root(
        eval,
        seed,
        "empirical latent intercept",
        abs_tol,
        64,
        48,
    )?;
    if f_best.abs() > abs_tol {
        return Err(format!(
            "empirical latent intercept solve failed: residual={f_best:.3e} at a={root:.6}, target mu={target_mu:.6}"
        ));
    }
    Ok(root)
}
/// Per-observation cache for the rigid probit kernel with linear predictor
/// eta(q, g) = q * c(g) + g_obs * z, where g_obs = probit_scale * g and
/// c(g) = sqrt(1 + g_obs^2). Built by [`RigidProbitKernel::new`].
struct RigidProbitKernel {
    // log Phi(m) at the signed margin m = s * eta (s = 2y - 1).
    logcdf: f64,
    // u1..u4: first..fourth derivatives of the weighted negative
    // log-likelihood with respect to eta (sign-folded via s, since s^2 = 1).
    u1: f64,
    u2: f64,
    u3: f64,
    u4: f64,
    // c1..c4: first..fourth derivatives of c(g) with respect to g.
    c1: f64,
    c2: f64,
    c3: f64,
    c4: f64,
    // First partials of eta: d eta/d q = c, d eta/d g = q*c1 + probit_scale*z.
    eta_q: f64,
    eta_g: f64,
}
impl RigidProbitKernel {
    /// Builds the per-row cache for response `y` in {0,1}, covariate `z`,
    /// weight `w`, internal probit coordinate `q`, latent log-slope `g`,
    /// and the fixed probit scale.
    ///
    /// Errors only if the margin handed to the derivative helper is
    /// non-finite with non-zero weight.
    #[inline]
    fn new(q: f64, g: f64, z: f64, y: f64, w: f64, probit_scale: f64) -> Result<Self, String> {
        // s = +1 for y = 1, -1 for y = 0.
        let s = 2.0 * y - 1.0;
        let observed_logslope = rigid_observed_logslope(g, probit_scale);
        let g2 = observed_logslope * observed_logslope;
        let c = (1.0 + g2).sqrt();
        // dc/dg = probit_scale * g_obs / c.
        let c1 = probit_scale * observed_logslope / c;
        let c_inv3 = 1.0 / (c * c * c);
        let c_inv5 = c_inv3 / (c * c);
        let c_inv7 = c_inv5 / (c * c);
        let eta = q * c + observed_logslope * z;
        let m = s * eta;
        let (logcdf, _) = signed_probit_logcdf_and_mills_ratio(m);
        let (k1, k2, k3, k4) = signed_probit_neglog_derivatives_up_to_fourth(m, w)?;
        Ok(Self {
            logcdf,
            // d^n/d eta^n = s^n d^n/d m^n; s^2 = 1 so only odd orders flip.
            u1: s * k1,
            u2: k2,
            u3: s * k3,
            u4: k4,
            c1,
            // Closed-form higher derivatives of c(g) = sqrt(1 + scale^2 g^2):
            // c'' = scale^2/c^3, c''' = -3 scale^3 g_obs/c^5,
            // c'''' = scale^4 (12 g_obs^2 - 3)/c^7.
            c2: probit_scale * probit_scale * c_inv3,
            c3: -3.0 * probit_scale.powi(3) * observed_logslope * c_inv5,
            c4: probit_scale.powi(4) * (12.0 * g2 - 3.0) * c_inv7,
            eta_q: c,
            eta_g: q * c1 + probit_scale * z,
        })
    }
    /// Symmetric 2x2 Hessian of the per-row objective in (q, g), combining
    /// u2 * (d eta)(d eta)^T with the u1-weighted second partials of eta
    /// (eta_qq = 0, eta_qg = c1, eta_gg = q*c2).
    #[inline]
    fn primary_hessian(&self, q: f64) -> [[f64; 2]; 2] {
        let h00 = self.u2 * self.eta_q * self.eta_q;
        let h01 = self.u2 * self.eta_q * self.eta_g + self.u1 * self.c1;
        let h11 = self.u2 * self.eta_g * self.eta_g + self.u1 * q * self.c2;
        [[h00, h01], [h01, h11]]
    }
    /// Third-derivative tensor of the objective contracted once with the
    /// direction (dq, dg), returned as a symmetric 2x2 matrix.
    #[inline]
    fn third_contracted(&self, q: f64, dq: f64, dg: f64) -> [[f64; 2]; 2] {
        // dd: directional derivative of eta; dd_* : partials of dd.
        let dd = self.eta_q * dq + self.eta_g * dg;
        let dd_q = self.c1 * dg;
        let dd_g = self.c1 * dq + q * self.c2 * dg;
        let dd_qg = self.c2 * dg;
        let dd_gg = self.c2 * dq + q * self.c3 * dg;
        let t00 = self.u3 * self.eta_q * self.eta_q * dd + self.u2 * 2.0 * self.eta_q * dd_q;
        let t01 = self.u3 * self.eta_q * self.eta_g * dd
            + self.u2 * (self.c1 * dd + self.eta_q * dd_g + self.eta_g * dd_q)
            + self.u1 * dd_qg;
        let t11 = self.u3 * self.eta_g * self.eta_g * dd
            + self.u2 * (q * self.c2 * dd + 2.0 * self.eta_g * dd_g)
            + self.u1 * dd_gg;
        [[t00, t01], [t01, t11]]
    }
    /// Fourth-derivative tensor contracted with two directions (uq, ug) and
    /// (vq, vg), returned as a symmetric 2x2 matrix over the remaining pair
    /// of indices.
    #[inline]
    fn fourth_contracted(&self, q: f64, uq: f64, ug: f64, vq: f64, vg: f64) -> [[f64; 2]; 2] {
        // Directional first derivatives of eta along u and v.
        let du = self.eta_q * uq + self.eta_g * ug;
        let dv = self.eta_q * vq + self.eta_g * vg;
        // Gradients (in q, g) of du and dv.
        let du_a = [self.c1 * ug, self.c1 * uq + q * self.c2 * ug];
        let dv_a = [self.c1 * vg, self.c1 * vq + q * self.c2 * vg];
        // Hessians (in q, g) of du and dv.
        let du_ab = [
            [0.0, self.c2 * ug],
            [self.c2 * ug, self.c2 * uq + q * self.c3 * ug],
        ];
        let dv_ab = [
            [0.0, self.c2 * vg],
            [self.c2 * vg, self.c2 * vq + q * self.c3 * vg],
        ];
        // Second directional derivative of eta along (u, v) and its
        // gradient/Hessian in (q, g).
        let dduv = self.c1 * (uq * vg + ug * vq) + q * self.c2 * ug * vg;
        let dduv_a = [
            self.c2 * ug * vg,
            self.c2 * (uq * vg + ug * vq) + q * self.c3 * ug * vg,
        ];
        let dduv_ab = [
            [0.0, self.c3 * ug * vg],
            [
                self.c3 * ug * vg,
                self.c3 * (uq * vg + ug * vq) + q * self.c4 * ug * vg,
            ],
        ];
        let eta_a = [self.eta_q, self.eta_g];
        let eta_ab = [[0.0, self.c1], [self.c1, q * self.c2]];
        let mut f = [[0.0f64; 2]; 2];
        // Faa di Bruno expansion of the fourth derivative of u(eta(q, g)),
        // assembled symmetrically over the two free indices a, b.
        for a in 0..2 {
            for b in a..2 {
                let val = self.u4 * eta_a[a] * eta_a[b] * du * dv
                    + self.u3
                        * (eta_ab[a][b] * du * dv
                            + du_a[a] * eta_a[b] * dv
                            + dv_a[a] * eta_a[b] * du
                            + du_a[b] * eta_a[a] * dv
                            + dv_a[b] * eta_a[a] * du
                            + dduv * eta_a[a] * eta_a[b])
                    + self.u2
                        * (eta_ab[a][b] * dduv
                            + du_a[a] * dv_a[b]
                            + dv_a[a] * du_a[b]
                            + du_ab[a][b] * dv
                            + dv_ab[a][b] * du
                            + eta_a[b] * dduv_a[a]
                            + eta_a[a] * dduv_a[b])
                    + self.u1 * dduv_ab[a][b];
                f[a][b] = val;
                f[b][a] = val;
            }
        }
        f
    }
}
/// Gradient of the per-row objective in the (eta, g) coordinates, pulling
/// the q-component back through q = q(eta) via the chain rule (q1 factor).
#[inline]
fn rigid_transformed_gradient(
    marginal: BernoulliMarginalLinkMap,
    kernel: &RigidProbitKernel,
) -> [f64; 2] {
    [
        kernel.u1 * kernel.eta_q * marginal.q1,
        kernel.u1 * kernel.eta_g,
    ]
}
/// Hessian of the per-row objective in (eta, g): the (q, g) Hessian pulled
/// back through q = q(eta), with the q2 curvature term on the (eta, eta)
/// entry.
#[inline]
fn rigid_transformed_hessian(
    marginal: BernoulliMarginalLinkMap,
    kernel: &RigidProbitKernel,
) -> [[f64; 2]; 2] {
    let h_q = kernel.primary_hessian(marginal.q);
    // First derivative of the objective w.r.t. q, needed for the chain rule.
    let grad_q = kernel.u1 * kernel.eta_q;
    [
        [
            h_q[0][0] * marginal.q1 * marginal.q1 + grad_q * marginal.q2,
            h_q[0][1] * marginal.q1,
        ],
        [h_q[1][0] * marginal.q1, h_q[1][1]],
    ]
}
/// Distinct components (f_qqq, f_qqg, f_qgg, f_ggg) of the symmetric third
/// derivative of the per-row objective in (q, g), read off from unit-
/// direction contractions of the kernel's third-derivative tensor.
#[inline]
fn rigid_internal_third_components(
    marginal: BernoulliMarginalLinkMap,
    kernel: &RigidProbitKernel,
) -> (f64, f64, f64, f64) {
    let q_dir = kernel.third_contracted(marginal.q, 1.0, 0.0);
    let g_dir = kernel.third_contracted(marginal.q, 0.0, 1.0);
    (q_dir[0][0], q_dir[0][1], q_dir[1][1], g_dir[1][1])
}
/// Third derivative of the per-row objective in (eta, g), contracted once
/// with the direction (d_eta, d_g) and returned as a symmetric 2x2 matrix.
///
/// The q-indexed components are pulled back through q = q(eta) by Faa di
/// Bruno: each q index contributes one q1 factor, with q2/q3 correction
/// terms from lower-order derivatives.
#[inline]
fn rigid_transformed_third_contracted(
    marginal: BernoulliMarginalLinkMap,
    kernel: &RigidProbitKernel,
    d_eta: f64,
    d_g: f64,
) -> [[f64; 2]; 2] {
    let h_q = kernel.primary_hessian(marginal.q);
    let grad_q = kernel.u1 * kernel.eta_q;
    let (f_qqq, f_qqg, f_qgg, f_ggg) = rigid_internal_third_components(marginal, kernel);
    // Pure-eta third derivative: q1^3 leading term plus chain-rule
    // corrections in q2 and q3.
    let f_etaetaeta = f_qqq * marginal.q1.powi(3)
        + 3.0 * h_q[0][0] * marginal.q1 * marginal.q2
        + grad_q * marginal.q3;
    let f_etaetag = f_qqg * marginal.q1 * marginal.q1 + h_q[0][1] * marginal.q2;
    let f_etagg = f_qgg * marginal.q1;
    [
        [
            f_etaetaeta * d_eta + f_etaetag * d_g,
            f_etaetag * d_eta + f_etagg * d_g,
        ],
        [
            f_etaetag * d_eta + f_etagg * d_g,
            f_etagg * d_eta + f_ggg * d_g,
        ],
    ]
}
/// Fourth derivative of the per-row objective in (eta, g), contracted with
/// the two directions (u_eta, u_g) and (v_eta, v_g), returned as a
/// symmetric 2x2 matrix over the remaining index pair.
///
/// The mixed (q, g) fourth-derivative components come from unit-direction
/// contractions of the kernel tensor; pure and partial eta components are
/// pulled back through q = q(eta) with the standard fourth-order Faa di
/// Bruno coefficients (1, 6, 3, 4, 1) over q1..q4.
#[inline]
fn rigid_transformed_fourth_contracted(
    marginal: BernoulliMarginalLinkMap,
    kernel: &RigidProbitKernel,
    u_eta: f64,
    u_g: f64,
    v_eta: f64,
    v_g: f64,
) -> [[f64; 2]; 2] {
    let h_q = kernel.primary_hessian(marginal.q);
    let grad_q = kernel.u1 * kernel.eta_q;
    let (f_qqq, f_qqg, f_qgg, _) = rigid_internal_third_components(marginal, kernel);
    // Unit-direction contractions selecting the distinct fourth-derivative
    // components in (q, g).
    let qq = kernel.fourth_contracted(marginal.q, 1.0, 0.0, 1.0, 0.0);
    let qg = kernel.fourth_contracted(marginal.q, 1.0, 0.0, 0.0, 1.0);
    let gg = kernel.fourth_contracted(marginal.q, 0.0, 1.0, 0.0, 1.0);
    let f_qqqq = qq[0][0];
    let f_qqqg = qq[0][1];
    let f_qqgg = qq[1][1];
    let f_qggg = qg[1][1];
    let f_gggg = gg[1][1];
    // d^4/d eta^4 of f(q(eta), .): Faa di Bruno with q1..q4.
    let f_eta4 = f_qqqq * marginal.q1.powi(4)
        + 6.0 * f_qqq * marginal.q1 * marginal.q1 * marginal.q2
        + 3.0 * h_q[0][0] * marginal.q2 * marginal.q2
        + 4.0 * h_q[0][0] * marginal.q1 * marginal.q3
        + grad_q * marginal.q4;
    let f_eta3g = f_qqqg * marginal.q1.powi(3)
        + 3.0 * f_qqg * marginal.q1 * marginal.q2
        + h_q[0][1] * marginal.q3;
    let f_eta2g2 = f_qqgg * marginal.q1 * marginal.q1 + f_qgg * marginal.q2;
    let f_etag3 = f_qggg * marginal.q1;
    [
        [
            f_eta4 * u_eta * v_eta + f_eta3g * (u_eta * v_g + u_g * v_eta) + f_eta2g2 * u_g * v_g,
            f_eta3g * u_eta * v_eta + f_eta2g2 * (u_eta * v_g + u_g * v_eta) + f_etag3 * u_g * v_g,
        ],
        [
            f_eta3g * u_eta * v_eta + f_eta2g2 * (u_eta * v_g + u_g * v_eta) + f_etag3 * u_g * v_g,
            f_eta2g2 * u_eta * v_eta + f_etag3 * (u_eta * v_g + u_g * v_eta) + f_gggg * u_g * v_g,
        ],
    ]
}
/// Value and first four derivatives of x -> sqrt(x), with the input floored
/// at 1e-300 to avoid dividing by zero.
pub(crate) fn unary_derivatives_sqrt(x: f64) -> [f64; 5] {
    let clamped = x.max(1e-300);
    let root = clamped.sqrt();
    let sq = clamped * clamped;
    let cube = sq * clamped;
    [
        root,
        0.5 / root,
        -0.25 / (clamped * root),
        3.0 / (8.0 * sq * root),
        -15.0 / (16.0 * cube * root),
    ]
}
/// Value and first four derivatives of x -> -w * log Phi(x), delegating
/// derivative polynomials to the numeric probit helper.
///
/// Limit handling mirrors that helper: zero weight or x = +inf contribute
/// nothing; x = -inf returns the [+inf, -inf, w, 0, 0] limit; NaN fills all
/// slots with NaN.
pub(crate) fn unary_derivatives_neglog_phi(x: f64, weight: f64) -> [f64; 5] {
    if weight == 0.0 || x == f64::INFINITY {
        return [0.0, 0.0, 0.0, 0.0, 0.0];
    }
    if x == f64::NEG_INFINITY {
        return [f64::INFINITY, f64::NEG_INFINITY, weight, 0.0, 0.0];
    }
    if x.is_nan() {
        return [f64::NAN; 5];
    }
    let (d1, d2, d3, d4) = signed_probit_neglog_derivatives_up_to_fourth_numeric(x, weight);
    let (log_cdf, _) = signed_probit_logcdf_and_mills_ratio(x);
    [-weight * log_cdf, d1, d2, d3, d4]
}
/// Value and first four derivatives of x -> ln(x), with the input floored
/// at 1e-300 to keep the logarithm and divisions finite.
pub(crate) fn unary_derivatives_log(x: f64) -> [f64; 5] {
    let base = x.max(1e-300);
    let pow2 = base * base;
    let pow3 = pow2 * base;
    let pow4 = pow3 * base;
    [base.ln(), 1.0 / base, -1.0 / pow2, 2.0 / pow3, -6.0 / pow4]
}
/// Value and first four derivatives of the standard-normal log-density:
/// log phi(x) = -x^2/2 - log(2*pi)/2, so the derivatives are -x, -1, 0, 0.
pub(crate) fn unary_derivatives_log_normal_pdf(x: f64) -> [f64; 5] {
    let half_log_two_pi = 0.5 * (2.0 * std::f64::consts::PI).ln();
    let value = -0.5 * x * x - half_log_two_pi;
    [value, -x, -1.0, 0.0, 0.0]
}
/// One row of block-local psi data: the owning block's index, the column
/// range it occupies in the joint layout, and the dense local values.
struct BlockPsiRow {
    // Index of the parameter block this row belongs to.
    block_idx: usize,
    // Column range of that block in the joint coefficient layout.
    range: std::ops::Range<usize>,
    // Dense values local to the block's columns.
    local_vec: Array1<f64>,
}
/// Contiguous column ranges of each parameter block in the joint
/// coefficient vector; `h`/`w` are present only when the corresponding
/// optional runtime exists. Built by `block_slices`.
#[derive(Clone)]
struct BlockSlices {
    marginal: std::ops::Range<usize>,
    logslope: std::ops::Range<usize>,
    // Optional score-warp block columns.
    h: Option<std::ops::Range<usize>>,
    // Optional link-deviation block columns.
    w: Option<std::ops::Range<usize>>,
    // Total joint coefficient count (one past the last occupied column).
    total: usize,
}
/// Lays out the contiguous coefficient ranges for each active parameter
/// block — marginal design, logslope design, then the optional score-warp
/// (`h`) and link-deviation (`w`) runtimes — in the family's joint
/// coefficient vector.
fn block_slices(family: &BernoulliMarginalSlopeFamily) -> BlockSlices {
    // Advance the layout cursor by `width` columns and return the span.
    fn advance(cursor: &mut usize, width: usize) -> std::ops::Range<usize> {
        let start = *cursor;
        *cursor = start + width;
        start..*cursor
    }
    let mut cursor = 0usize;
    let marginal = advance(&mut cursor, family.marginal_design.ncols());
    let logslope = advance(&mut cursor, family.logslope_design.ncols());
    let h = family
        .score_warp
        .as_ref()
        .map(|runtime| advance(&mut cursor, runtime.basis_dim()));
    let w = family
        .link_dev
        .as_ref()
        .map(|runtime| advance(&mut cursor, runtime.basis_dim()));
    BlockSlices {
        marginal,
        logslope,
        h,
        w,
        total: cursor,
    }
}
/// Per-row "primary" coordinate layout: scalar q at index `q`, scalar
/// logslope at index `logslope`, then the optional h/w coefficient ranges
/// packed contiguously. Built by `primary_slices` from a `BlockSlices`.
#[derive(Clone)]
struct PrimarySlices {
    q: usize,
    logslope: usize,
    h: Option<std::ops::Range<usize>>,
    w: Option<std::ops::Range<usize>>,
    // Total primary coordinate count.
    total: usize,
}
/// Derives the per-row primary coordinate layout from the block layout:
/// q at index 0, logslope at index 1, then the optional h/w ranges (same
/// widths as in `slices`) packed immediately after.
fn primary_slices(slices: &BlockSlices) -> PrimarySlices {
    let mut cursor = 2usize;
    let mut next = |len: usize| {
        let out = cursor..cursor + len;
        cursor = out.end;
        out
    };
    let h = slices.h.as_ref().map(|range| next(range.len()));
    let w = slices.w.as_ref().map(|range| next(range.len()));
    PrimarySlices {
        q: 0,
        logslope: 1,
        h,
        w,
        total: cursor,
    }
}
/// Accumulates the block-structured Hessian of the Bernoulli
/// marginal-slope objective: dense marginal/marginal, logslope/logslope,
/// and marginal-logslope cross blocks, plus an optional full dense
/// correction used only when h/w blocks are present.
struct BernoulliBlockHessianAccumulator {
    // Marginal-design block (p_m x p_m).
    h_mm: Array2<f64>,
    // Logslope-design block (p_g x p_g).
    h_gg: Array2<f64>,
    // Marginal x logslope cross block (p_m x p_g).
    h_mg: Array2<f64>,
    // Full (total x total) correction, allocated only when the optional
    // h/w blocks exist.
    dense_correction: Option<Array2<f64>>,
}
impl BernoulliBlockHessianAccumulator {
    /// Allocates zeroed accumulator blocks sized from `slices`; the dense
    /// correction is allocated only if an h or w block exists.
    fn new(slices: &BlockSlices) -> Self {
        let p_m = slices.marginal.len();
        let p_g = slices.logslope.len();
        let has_hw = slices.h.is_some() || slices.w.is_some();
        Self {
            h_mm: Array2::zeros((p_m, p_m)),
            h_gg: Array2::zeros((p_g, p_g)),
            h_mg: Array2::zeros((p_m, p_g)),
            dense_correction: if has_hw {
                Some(Array2::zeros((slices.total, slices.total)))
            } else {
                None
            },
        }
    }
    /// Pulls one row's 2x2 primary Hessian back onto the coefficient
    /// blocks: rank-1 symmetric updates on the marginal and logslope grams,
    /// a cross outer product when the off-diagonal is non-zero, and — if
    /// present — the h/w portion into the dense correction.
    fn add_pullback(
        &mut self,
        family: &BernoulliMarginalSlopeFamily,
        row: usize,
        slices: &BlockSlices,
        primary: &PrimarySlices,
        primary_hessian: &Array2<f64>,
    ) {
        let h = primary_hessian;
        family
            .marginal_design
            .syr_row_into(row, h[[0, 0]], &mut self.h_mm)
            .expect("marginal syr_row_into dimension mismatch");
        family
            .logslope_design
            .syr_row_into(row, h[[1, 1]], &mut self.h_gg)
            .expect("logslope syr_row_into dimension mismatch");
        // Skip the cross outer product entirely when the coupling is zero.
        if h[[0, 1]] != 0.0 {
            family
                .marginal_design
                .row_outer_into_view(
                    row,
                    &family.logslope_design,
                    h[[0, 1]],
                    self.h_mg.view_mut(),
                )
                .expect("marginal-logslope row_outer_into dimension mismatch");
        }
        // Only the h/w part of the pullback lands in the dense correction.
        if let Some(ref mut dc) = self.dense_correction {
            family.add_pullback_primary_hessian_hw_only(dc, row, slices, primary, h.view());
        }
    }
    /// Adds weighted Gram contributions for a chunk of rows:
    /// X^T diag(w_mm) X to the marginal block, X^T diag(w_mg) G to the
    /// cross block (skipped when all cross weights are zero), and
    /// G^T diag(w_gg) G to the logslope block.
    ///
    /// Errors if either design cannot materialize the row chunk.
    fn add_weighted_design_grams(
        &mut self,
        family: &BernoulliMarginalSlopeFamily,
        rows: std::ops::Range<usize>,
        w_mm: &Array1<f64>,
        w_mg: &Array1<f64>,
        w_gg: &Array1<f64>,
    ) -> Result<(), String> {
        let x = family
            .marginal_design
            .try_row_chunk(rows.clone())
            .map_err(|e| format!("bernoulli marginal_design try_row_chunk: {e}"))?;
        let g = family
            .logslope_design
            .try_row_chunk(rows)
            .map_err(|e| format!("bernoulli logslope_design try_row_chunk: {e}"))?;
        self.h_mm += &crate::faer_ndarray::fast_xt_diag_x(&x, w_mm);
        if w_mg.iter().any(|value| *value != 0.0) {
            self.h_mg += &crate::faer_ndarray::fast_xt_diag_y(&x, w_mg, &g);
        }
        self.h_gg += &crate::faer_ndarray::fast_xt_diag_x(&g, w_gg);
        Ok(())
    }
fn add_rank1_psi_cross(
&mut self,
family: &BernoulliMarginalSlopeFamily,
row: usize,
slices: &BlockSlices,
primary: &PrimarySlices,
psi_block_idx: usize,
psi_row: &Array1<f64>,
right_primary: &Array1<f64>,
) {
if right_primary[0] != 0.0 {
match psi_block_idx {
0 => {
for (idx, &value) in psi_row.iter().enumerate() {
if value == 0.0 {
continue;
}
let scale = right_primary[0] * value;
{
let mut col = self.h_mm.column_mut(idx);
family
.marginal_design
.axpy_row_into(row, scale, &mut col)
.expect("marginal axpy column mismatch");
}
{
let mut row_view = self.h_mm.row_mut(idx);
family
.marginal_design
.axpy_row_into(row, scale, &mut row_view)
.expect("marginal axpy row mismatch");
}
}
}
1 => {
for (idx, &value) in psi_row.iter().enumerate() {
if value == 0.0 {
continue;
}
let mut col = self.h_mg.column_mut(idx);
family
.marginal_design
.axpy_row_into(row, right_primary[0] * value, &mut col)
.expect("marginal axpy column mismatch");
}
}
_ => {}
}
}
if right_primary[1] != 0.0 {
match psi_block_idx {
0 => {
for (idx, &value) in psi_row.iter().enumerate() {
if value == 0.0 {
continue;
}
let mut row_view = self.h_mg.row_mut(idx);
family
.logslope_design
.axpy_row_into(row, right_primary[1] * value, &mut row_view)
.expect("logslope axpy row mismatch");
}
}
1 => {
for (idx, &value) in psi_row.iter().enumerate() {
if value == 0.0 {
continue;
}
let scale = right_primary[1] * value;
{
let mut col = self.h_gg.column_mut(idx);
family
.logslope_design
.axpy_row_into(row, scale, &mut col)
.expect("logslope axpy column mismatch");
}
{
let mut row_view = self.h_gg.row_mut(idx);
family
.logslope_design
.axpy_row_into(row, scale, &mut row_view)
.expect("logslope axpy row mismatch");
}
}
}
_ => {}
}
}
if let Some(ref mut dc) = self.dense_correction {
let psi_range = if psi_block_idx == 0 {
slices.marginal.clone()
} else {
slices.logslope.clone()
};
if let (Some(ph), Some(bh)) = (primary.h.as_ref(), slices.h.as_ref()) {
let h_part = right_primary.slice(ndarray::s![ph.start..ph.end]);
for (li, gi) in psi_range.clone().enumerate() {
for (lj, gj) in bh.clone().enumerate() {
let val = psi_row[li] * h_part[lj];
dc[[gi, gj]] += val;
dc[[gj, gi]] += val;
}
}
}
if let (Some(pw), Some(bw)) = (primary.w.as_ref(), slices.w.as_ref()) {
let w_part = right_primary.slice(ndarray::s![pw.start..pw.end]);
for (li, gi) in psi_range.enumerate() {
for (lj, gj) in bw.clone().enumerate() {
let val = psi_row[li] * w_part[lj];
dc[[gi, gj]] += val;
dc[[gj, gi]] += val;
}
}
}
}
}
fn add_psi_psi_outer(
&mut self,
block_i: usize,
psi_row_i: &Array1<f64>,
block_j: usize,
psi_row_j: &Array1<f64>,
alpha: f64,
) {
add_two_surface_psi_outer(
block_i,
psi_row_i,
block_j,
psi_row_j,
alpha,
0,
1,
&mut self.h_mm,
&mut self.h_gg,
&mut self.h_mg,
);
}
fn add(&mut self, other: &BernoulliBlockHessianAccumulator) {
self.h_mm += &other.h_mm;
self.h_gg += &other.h_gg;
self.h_mg += &other.h_mg;
if let (Some(ref mut dc), Some(ref odc)) = (
self.dense_correction.as_mut(),
other.dense_correction.as_ref(),
) {
dc.scaled_add(1.0, odc);
}
}
fn scale_assign(&mut self, scale: f64) {
if scale == 1.0 {
return;
}
self.h_mm.mapv_inplace(|v| v * scale);
self.h_gg.mapv_inplace(|v| v * scale);
self.h_mg.mapv_inplace(|v| v * scale);
if let Some(ref mut dc) = self.dense_correction {
dc.mapv_inplace(|v| v * scale);
}
}
fn to_dense(&self, slices: &BlockSlices) -> Array2<f64> {
let mut out = Array2::zeros((slices.total, slices.total));
out.slice_mut(s![slices.marginal.clone(), slices.marginal.clone()])
.assign(&self.h_mm);
out.slice_mut(s![slices.logslope.clone(), slices.logslope.clone()])
.assign(&self.h_gg);
out.slice_mut(s![slices.marginal.clone(), slices.logslope.clone()])
.assign(&self.h_mg);
out.slice_mut(s![slices.logslope.clone(), slices.marginal.clone()])
.assign(&self.h_mg.t());
if let Some(ref dc) = self.dense_correction {
out += dc;
}
out
}
fn into_operator(self, slices: &BlockSlices) -> BernoulliBlockHessianOperator {
BernoulliBlockHessianOperator {
h_mm: self.h_mm,
h_gg: self.h_gg,
h_mg: self.h_mg,
dense_correction: self.dense_correction,
marginal: slices.marginal.clone(),
logslope: slices.logslope.clone(),
total: slices.total,
}
}
}
/// Operator form of `BernoulliBlockHessianAccumulator`: the same three
/// blocks plus optional dense correction, bundled with the coefficient
/// slice layout so matrix-vector products can be applied blockwise.
struct BernoulliBlockHessianOperator {
    // Marginal × marginal block.
    h_mm: Array2<f64>,
    // Logslope × logslope block.
    h_gg: Array2<f64>,
    // Marginal × logslope cross block (mirror taken via transpose).
    h_mg: Array2<f64>,
    // Optional dense correction over the full coefficient vector.
    dense_correction: Option<Array2<f64>>,
    // Index range of the marginal coefficients in the full vector.
    marginal: std::ops::Range<usize>,
    // Index range of the logslope coefficients in the full vector.
    logslope: std::ops::Range<usize>,
    // Total coefficient dimension.
    total: usize,
}
impl HyperOperator for BernoulliBlockHessianOperator {
    /// Total coefficient dimension spanned by the operator.
    fn dim(&self) -> usize {
        self.total
    }
    /// Computes H·v blockwise: each diagonal block acts on its own slice,
    /// the cross block and its transpose couple the two slices, and the
    /// optional dense correction is applied over the full vector.
    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
        let v_marginal = v.slice(s![self.marginal.clone()]);
        let v_logslope = v.slice(s![self.logslope.clone()]);
        let mut product = Array1::zeros(self.total);
        {
            let mut out_marginal = product.slice_mut(s![self.marginal.clone()]);
            out_marginal += &self.h_mm.dot(&v_marginal);
            out_marginal += &self.h_mg.dot(&v_logslope);
        }
        {
            let mut out_logslope = product.slice_mut(s![self.logslope.clone()]);
            out_logslope += &self.h_mg.t().dot(&v_marginal);
            out_logslope += &self.h_gg.dot(&v_logslope);
        }
        if let Some(correction) = self.dense_correction.as_ref() {
            product += &correction.dot(v);
        }
        product
    }
    /// Computes vᵀ·H·u block by block without materializing H·u.
    fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
        let v_marginal = v.slice(s![self.marginal.clone()]);
        let v_logslope = v.slice(s![self.logslope.clone()]);
        let u_marginal = u.slice(s![self.marginal.clone()]);
        let u_logslope = u.slice(s![self.logslope.clone()]);
        let mut acc = v_marginal.dot(&self.h_mm.dot(&u_marginal));
        acc += v_logslope.dot(&self.h_gg.dot(&u_logslope));
        acc += v_marginal.dot(&self.h_mg.dot(&u_logslope));
        acc += v_logslope.dot(&self.h_mg.t().dot(&u_marginal));
        if let Some(correction) = self.dense_correction.as_ref() {
            acc += v.dot(&correction.dot(u));
        }
        acc
    }
    /// Materializes the full symmetric matrix, mirroring the cross block by
    /// transpose and adding the dense correction when present.
    fn to_dense(&self) -> Array2<f64> {
        let mut dense = Array2::zeros((self.total, self.total));
        dense
            .slice_mut(s![self.marginal.clone(), self.marginal.clone()])
            .assign(&self.h_mm);
        dense
            .slice_mut(s![self.logslope.clone(), self.logslope.clone()])
            .assign(&self.h_gg);
        dense
            .slice_mut(s![self.marginal.clone(), self.logslope.clone()])
            .assign(&self.h_mg);
        dense
            .slice_mut(s![self.logslope.clone(), self.marginal.clone()])
            .assign(&self.h_mg.t());
        if let Some(correction) = self.dense_correction.as_ref() {
            dense += correction;
        }
        dense
    }
    /// All blocks are materialized, so this operator is explicit.
    fn is_implicit(&self) -> bool {
        false
    }
}
/// Non-test builds: per-row dedup of cell moments is always enabled.
#[cfg(not(test))]
#[inline]
fn cell_moment_per_row_dedup_enabled() -> bool {
    true
}
/// Test builds: the flag is read from a process-wide atomic so tests can
/// toggle the dedup path via `set_cell_moment_per_row_dedup_enabled`.
#[cfg(test)]
fn cell_moment_per_row_dedup_enabled() -> bool {
    CELL_MOMENT_PER_ROW_DEDUP_ENABLED.load(Ordering::Relaxed)
}
// Backing flag for the test-build override; defaults to enabled, matching
// the non-test behavior.
#[cfg(test)]
static CELL_MOMENT_PER_ROW_DEDUP_ENABLED: std::sync::atomic::AtomicBool =
    std::sync::atomic::AtomicBool::new(true);
/// Test-only switch for `cell_moment_per_row_dedup_enabled`.
#[cfg(test)]
pub(crate) fn set_cell_moment_per_row_dedup_enabled(enabled: bool) {
    CELL_MOMENT_PER_ROW_DEDUP_ENABLED.store(enabled, Ordering::Relaxed);
}
/// A denested partition cell paired with its precomputed moment state, as
/// cached per training row.
#[derive(Clone)]
struct CachedDenestedCellMoments {
    // The partition cell this moment state was computed for.
    partition_cell: exact_kernel::DenestedPartitionCell,
    // Precomputed cell-moment state reused across evaluations.
    state: exact_kernel::CellMomentState,
}
/// Per-row cache of denested cell moments, valid up to `max_degree`.
#[derive(Clone)]
struct RowCellMomentsBundle {
    // Highest moment degree the cached states were built for.
    max_degree: usize,
    // One vector of cached cells per training row.
    rows: Vec<Vec<CachedDenestedCellMoments>>,
}
impl RowCellMomentsBundle {
    /// Returns the cached cells for `row` when the bundle was built with at
    /// least `required_degree` moments; `None` otherwise (debug builds also
    /// assert, since a shortfall indicates a construction bug).
    #[inline]
    fn row(&self, row: usize, required_degree: usize) -> Option<&[CachedDenestedCellMoments]> {
        debug_assert!(
            self.max_degree >= required_degree,
            "row cell moments bundle max_degree={} required_degree={}",
            self.max_degree,
            required_degree
        );
        if self.max_degree < required_degree {
            return None;
        }
        self.rows.get(row).map(Vec::as_slice)
    }
    /// Rough resident-memory estimate for a bundle of `n_rows` rows holding
    /// `n_cells` cells in total with moments up to `max_degree`. Uses
    /// saturating arithmetic so pathological inputs cannot overflow.
    fn estimated_resident_bytes(n_rows: usize, n_cells: usize, max_degree: usize) -> usize {
        let per_row_vec = std::mem::size_of::<Vec<CachedDenestedCellMoments>>();
        let per_cell_record = std::mem::size_of::<CachedDenestedCellMoments>();
        let per_moment = std::mem::size_of::<f64>();
        let row_vec_bytes = n_rows.saturating_mul(per_row_vec);
        let cell_record_bytes = n_cells.saturating_mul(per_cell_record);
        let moment_bytes = n_cells
            .saturating_mul(max_degree.saturating_add(1))
            .saturating_mul(per_moment);
        row_vec_bytes
            .saturating_add(cell_record_bytes)
            .saturating_add(moment_bytes)
    }
}
/// Per-row context for exact evaluation: the calibrated intercept root, the
/// calibration derivative `m_a`, whether the intercept fast path applies,
/// and optionally the row's degree-9 cell moments.
#[derive(Clone)]
struct BernoulliMarginalSlopeRowExactContext {
    // Calibrated intercept root for this row.
    intercept: f64,
    // Calibration derivative at the root (checked to be finite and > 0 by
    // consumers, e.g. `empirical_flex_neglog_jet`).
    m_a: f64,
    // Whether the simplified intercept fast path can be used for this row.
    intercept_fast_path: bool,
    // Cached degree-9 denested cell moments, when precomputed.
    degree9_cells: Option<Vec<CachedDenestedCellMoments>>,
}
/// Reusable per-row scratch buffers for the flexible (h/w) evaluation path.
/// All buffers are sized by the primary coefficient dimension; the matrices
/// are only reset when a Hessian is requested (see `reset`).
struct BernoulliMarginalSlopeFlexRowScratch {
    // First-order buffers (reset on every row).
    m_u: Array1<f64>,
    // Second-order buffer (reset only when the Hessian is needed).
    m_au: Array1<f64>,
    m_uv: Array2<f64>,
    a_u: Array1<f64>,
    a_uv: Array2<f64>,
    rho: Array1<f64>,
    tau: Array1<f64>,
    du: Array1<f64>,
    // Per-row objective gradient over the primary coordinates.
    grad: Array1<f64>,
    // Per-row objective Hessian over the primary coordinates.
    hess: Array2<f64>,
}
impl BernoulliMarginalSlopeFlexRowScratch {
    /// Allocates all scratch buffers, each sized by `primary_dim`.
    fn new(primary_dim: usize) -> Self {
        let vector = || Array1::zeros(primary_dim);
        let matrix = || Array2::zeros((primary_dim, primary_dim));
        Self {
            m_u: vector(),
            m_au: vector(),
            m_uv: matrix(),
            a_u: vector(),
            a_uv: matrix(),
            rho: vector(),
            tau: vector(),
            du: vector(),
            grad: vector(),
            hess: matrix(),
        }
    }
    /// Zeroes the first-order buffers; the second-order buffers are only
    /// cleared when a Hessian will be accumulated this row.
    fn reset(&mut self, need_hessian: bool) {
        for buffer in [
            &mut self.m_u,
            &mut self.a_u,
            &mut self.rho,
            &mut self.tau,
            &mut self.du,
            &mut self.grad,
        ] {
            buffer.fill(0.0);
        }
        if need_hessian {
            self.m_au.fill(0.0);
            for buffer in [&mut self.m_uv, &mut self.a_uv, &mut self.hess] {
                buffer.fill(0.0);
            }
        }
    }
}
/// Copies the `primary_range` sub-block of a row's objective derivatives
/// into the flexible-block accumulators. Note the deliberate sign
/// asymmetry: the gradient slice is subtracted while the Hessian slice is
/// added — this mirrors the gradient sign convention used by
/// `BernoulliExactNewtonAccumulator::add_pullback_block_diagonals`.
#[inline]
fn accumulate_flex_block_grad_hess(
    primary_range: &std::ops::Range<usize>,
    scratch: &BernoulliMarginalSlopeFlexRowScratch,
    grad: &mut Array1<f64>,
    hess: &mut Array2<f64>,
) {
    let start = primary_range.start;
    let end = primary_range.end;
    let src_g = scratch.grad.slice(s![start..end]);
    grad.scaled_add(-1.0, &src_g);
    let src_h = scratch.hess.slice(s![start..end, start..end]);
    *hess += &src_h;
}
use crate::families::jet_partitions::MultiDirJet;
// Coefficient-support masks selecting which coefficient groups a derivative
// computation touches: the primary block, the score-warp `h` block, and the
// link-deviation `w` block.
//
// B+H+W: primary, h, and w all included.
const COEFF_SUPPORT_BHW: CoeffSupport = CoeffSupport {
    include_primary: true,
    include_h: true,
    include_w: true,
};
// B+W: primary and w, without h.
const COEFF_SUPPORT_BW: CoeffSupport = CoeffSupport {
    include_primary: true,
    include_h: false,
    include_w: true,
};
// W only.
const COEFF_SUPPORT_W: CoeffSupport = CoeffSupport {
    include_primary: false,
    include_h: false,
    include_w: true,
};
/// Accumulates the exact-Newton objective pieces over rows: log-likelihood,
/// per-block gradients, and per-block (diagonal) Hessians. The `h`/`w`
/// entries are present only when the corresponding flexible blocks exist.
struct BernoulliExactNewtonAccumulator {
    // Running log-likelihood (negated row contributions subtracted in).
    ll: f64,
    grad_marginal: Array1<f64>,
    grad_logslope: Array1<f64>,
    hess_marginal: Array2<f64>,
    hess_logslope: Array2<f64>,
    // Flexible score-warp block, when configured.
    grad_h: Option<Array1<f64>>,
    // Flexible link-deviation block, when configured.
    grad_w: Option<Array1<f64>>,
    hess_h: Option<Array2<f64>>,
    hess_w: Option<Array2<f64>>,
}
impl BernoulliExactNewtonAccumulator {
    /// Allocates zeroed accumulators sized from `slices`; the `h`/`w`
    /// buffers are allocated only when those blocks exist.
    fn new(slices: &BlockSlices) -> Self {
        Self {
            ll: 0.0,
            grad_marginal: Array1::zeros(slices.marginal.len()),
            grad_logslope: Array1::zeros(slices.logslope.len()),
            hess_marginal: Array2::zeros((slices.marginal.len(), slices.marginal.len())),
            hess_logslope: Array2::zeros((slices.logslope.len(), slices.logslope.len())),
            grad_h: slices.h.as_ref().map(|range| Array1::zeros(range.len())),
            grad_w: slices.w.as_ref().map(|range| Array1::zeros(range.len())),
            hess_h: slices
                .h
                .as_ref()
                .map(|range| Array2::zeros((range.len(), range.len()))),
            hess_w: slices
                .w
                .as_ref()
                .map(|range| Array2::zeros((range.len(), range.len()))),
        }
    }
    /// Folds one row into the accumulators: subtracts the row's negative
    /// log-likelihood from `ll`, pulls the objective gradient back through
    /// the design rows (converted to score components by the family helper),
    /// applies symmetric rank-1 Hessian updates, and routes the flexible
    /// `h`/`w` sub-blocks through `accumulate_flex_block_grad_hess`.
    fn add_pullback_block_diagonals(
        &mut self,
        family: &BernoulliMarginalSlopeFamily,
        row: usize,
        primary: &PrimarySlices,
        row_neglog: f64,
        scratch: &BernoulliMarginalSlopeFlexRowScratch,
    ) -> Result<(), String> {
        self.ll -= row_neglog;
        {
            let mut marginal = self.grad_marginal.view_mut();
            family.marginal_design.axpy_row_into(
                row,
                BernoulliMarginalSlopeFamily::exact_newton_score_component_from_objective_gradient(
                    scratch.grad[0],
                ),
                &mut marginal,
            )?;
        }
        {
            let mut logslope = self.grad_logslope.view_mut();
            family.logslope_design.axpy_row_into(
                row,
                BernoulliMarginalSlopeFamily::exact_newton_score_component_from_objective_gradient(
                    scratch.grad[1],
                ),
                &mut logslope,
            )?;
        }
        family
            .marginal_design
            .syr_row_into(row, scratch.hess[[0, 0]], &mut self.hess_marginal)?;
        family
            .logslope_design
            .syr_row_into(row, scratch.hess[[1, 1]], &mut self.hess_logslope)?;
        if let (Some(primary_h), Some(grad_h), Some(hess_h)) = (
            primary.h.as_ref(),
            self.grad_h.as_mut(),
            self.hess_h.as_mut(),
        ) {
            accumulate_flex_block_grad_hess(primary_h, scratch, grad_h, hess_h);
        }
        if let (Some(primary_w), Some(grad_w), Some(hess_w)) = (
            primary.w.as_ref(),
            self.grad_w.as_mut(),
            self.hess_w.as_mut(),
        ) {
            accumulate_flex_block_grad_hess(primary_w, scratch, grad_w, hess_w);
        }
        Ok(())
    }
    /// Merges another accumulator (e.g. a per-thread partial) into this one.
    fn add(&mut self, other: &Self) {
        self.ll += other.ll;
        self.grad_marginal += &other.grad_marginal;
        self.grad_logslope += &other.grad_logslope;
        self.hess_marginal += &other.hess_marginal;
        self.hess_logslope += &other.hess_logslope;
        add_optional_vector(&mut self.grad_h, &other.grad_h);
        add_optional_vector(&mut self.grad_w, &other.grad_w);
        add_optional_matrix(&mut self.hess_h, &other.hess_h);
        add_optional_matrix(&mut self.hess_w, &other.hess_w);
    }
}
/// Accumulates `target += chunkᵀ · weights` via the faer-backed fast path.
fn add_weighted_chunk_gradient(
    chunk: &Array2<f64>,
    weights: &Array1<f64>,
    target: &mut Array1<f64>,
) {
    let contribution = crate::faer_ndarray::fast_atv(chunk, weights);
    *target += &contribution;
}
/// Builds a shared cell-moment LRU cache whose byte budget is taken from the
/// policy's single-materialization cap.
fn new_cell_moment_lru_cache(
    policy: &crate::resource::ResourcePolicy,
) -> Arc<exact_kernel::CellMomentLruCache> {
    Arc::new(exact_kernel::CellMomentLruCache::new(
        policy.max_single_materialization_bytes,
    ))
}
fn new_cell_moment_cache_stats() -> Arc<exact_kernel::CellMomentCacheStats> {
Arc::new(exact_kernel::CellMomentCacheStats::default())
}
/// Accumulates `target += chunkᵀ · diag(weights) · chunk` via the
/// faer-backed fast path.
fn add_weighted_chunk_gram(chunk: &Array2<f64>, weights: &Array1<f64>, target: &mut Array2<f64>) {
    let gram = crate::faer_ndarray::fast_xt_diag_x(chunk, weights);
    *target += &gram;
}
/// Rows processed per chunk when streaming row reductions.
const ROW_CHUNK_SIZE: usize = 1024;
/// Minimum row count before exact-work logging is considered worthwhile.
const EXACT_WORK_LOG_MIN_ROWS: usize = 50_000;
/// True when a problem with `n` rows is large enough to log exact work for.
#[inline]
fn log_exact_work(n: usize) -> bool {
    EXACT_WORK_LOG_MIN_ROWS <= n
}
/// Cached inputs for repeated exact evaluations: the block/primary slice
/// layouts, per-row contexts, and optional per-row cell moments and primary
/// Hessians.
#[derive(Clone)]
struct BernoulliMarginalSlopeExactEvalCache {
    // Coefficient-block layout of the full parameter vector.
    slices: BlockSlices,
    // Layout of the per-row primary coordinates.
    primary: PrimarySlices,
    // One exact-evaluation context per training row.
    row_contexts: Vec<BernoulliMarginalSlopeRowExactContext>,
    // Optional per-row denested cell moments.
    row_cell_moments: Option<RowCellMomentsBundle>,
    // Optional precomputed per-row primary Hessians.
    row_primary_hessians: Option<Array2<f64>>,
}
/// Row-kernel adapter for the rigid (non-flexible) Bernoulli marginal-slope
/// path: owns the family, the per-block linear-predictor states, and the
/// coefficient slice layout.
struct BernoulliRigidRowKernel {
    family: BernoulliMarginalSlopeFamily,
    // block_states[0] = marginal eta, block_states[1] = logslope eta
    // (see the indexing in the `RowKernel` impl).
    block_states: Vec<ParameterBlockState>,
    slices: BlockSlices,
}
impl BernoulliRigidRowKernel {
    /// Builds the kernel, deriving the coefficient slice layout from the
    /// family before taking ownership of it.
    fn new(family: BernoulliMarginalSlopeFamily, block_states: Vec<ParameterBlockState>) -> Self {
        let slices = block_slices(&family);
        Self {
            slices,
            family,
            block_states,
        }
    }
}
impl RowKernel<2> for BernoulliRigidRowKernel {
    fn n_rows(&self) -> usize {
        self.family.y.len()
    }
    fn n_coefficients(&self) -> usize {
        self.slices.total
    }
    /// Per-row evaluation: returns the negative log-likelihood plus the
    /// gradient and Hessian in the 2-dimensional (marginal, logslope)
    /// primary coordinates.
    fn row_kernel(&self, row: usize) -> Result<(f64, [f64; 2], [[f64; 2]; 2]), String> {
        let marginal_eta = self.block_states[0].eta[row];
        let marginal = self.family.marginal_link_map(marginal_eta)?;
        let g = self.block_states[1].eta[row];
        self.family
            .rigid_row_kernel_eval(row, marginal_eta, marginal, g)
    }
    /// J·dβ: projects a full coefficient step onto the row's two primary
    /// coordinates via the design rows.
    fn jacobian_action(&self, row: usize, d_beta: &[f64]) -> [f64; 2] {
        let d_beta = ndarray::ArrayView1::from(d_beta);
        [
            self.family
                .marginal_design
                .dot_row_view(row, d_beta.slice(s![self.slices.marginal.clone()])),
            self.family
                .logslope_design
                .dot_row_view(row, d_beta.slice(s![self.slices.logslope.clone()])),
        ]
    }
    /// Jᵀ·v: scatters the two primary components back into the matching
    /// coefficient slices of `out`.
    fn jacobian_transpose_action(&self, row: usize, v: &[f64; 2], out: &mut [f64]) {
        {
            let mut m = ndarray::ArrayViewMut1::from(&mut out[self.slices.marginal.clone()]);
            self.family
                .marginal_design
                .axpy_row_into(row, v[0], &mut m)
                .expect("marginal axpy dim mismatch");
        }
        {
            let mut g = ndarray::ArrayViewMut1::from(&mut out[self.slices.logslope.clone()]);
            self.family
                .logslope_design
                .axpy_row_into(row, v[1], &mut g)
                .expect("logslope axpy dim mismatch");
        }
    }
    /// Jᵀ·h·J: pulls the 2×2 primary Hessian back into the four coefficient
    /// blocks of `target`; both cross blocks are written explicitly, and the
    /// cross update is skipped when h[0][1] is exactly zero.
    fn add_pullback_hessian(&self, row: usize, h: &[[f64; 2]; 2], target: &mut Array2<f64>) {
        self.family
            .marginal_design
            .syr_row_into_view(
                row,
                h[0][0],
                target.slice_mut(s![
                    self.slices.marginal.clone(),
                    self.slices.marginal.clone()
                ]),
            )
            .expect("marginal syr dim mismatch");
        if h[0][1] != 0.0 {
            self.family
                .marginal_design
                .row_outer_into_view(
                    row,
                    &self.family.logslope_design,
                    h[0][1],
                    target.slice_mut(s![
                        self.slices.marginal.clone(),
                        self.slices.logslope.clone()
                    ]),
                )
                .expect("marginal-logslope outer dim mismatch");
            self.family
                .logslope_design
                .row_outer_into_view(
                    row,
                    &self.family.marginal_design,
                    h[0][1],
                    target.slice_mut(s![
                        self.slices.logslope.clone(),
                        self.slices.marginal.clone()
                    ]),
                )
                .expect("logslope-marginal outer dim mismatch");
        }
        self.family
            .logslope_design
            .syr_row_into_view(
                row,
                h[1][1],
                target.slice_mut(s![
                    self.slices.logslope.clone(),
                    self.slices.logslope.clone()
                ]),
            )
            .expect("logslope syr dim mismatch");
    }
    /// Adds only the diagonal of Jᵀ·h·J (squared design entries weighted by
    /// the primary Hessian diagonal); off-diagonal h entries are ignored.
    fn add_diagonal_quadratic(&self, row: usize, h: &[[f64; 2]; 2], diag: &mut [f64]) {
        {
            let mut md = ndarray::ArrayViewMut1::from(&mut diag[self.slices.marginal.clone()]);
            self.family
                .marginal_design
                .squared_axpy_row_into(row, h[0][0], &mut md)
                .expect("marginal squared_axpy dim mismatch");
        }
        {
            let mut gd = ndarray::ArrayViewMut1::from(&mut diag[self.slices.logslope.clone()]);
            self.family
                .logslope_design
                .squared_axpy_row_into(row, h[1][1], &mut gd)
                .expect("logslope squared_axpy dim mismatch");
        }
    }
    /// Third derivative contracted once with `dir` (primary coordinates).
    fn row_third_contracted(&self, row: usize, dir: &[f64; 2]) -> Result<[[f64; 2]; 2], String> {
        let marginal_eta = self.block_states[0].eta[row];
        let marginal = self.family.marginal_link_map(marginal_eta)?;
        let g = self.block_states[1].eta[row];
        self.family
            .rigid_row_third_contracted(row, marginal_eta, marginal, g, dir[0], dir[1])
    }
    /// Fourth derivative contracted with `dir_u` and `dir_v` (primary
    /// coordinates).
    fn row_fourth_contracted(
        &self,
        row: usize,
        dir_u: &[f64; 2],
        dir_v: &[f64; 2],
    ) -> Result<[[f64; 2]; 2], String> {
        let marginal_eta = self.block_states[0].eta[row];
        let marginal = self.family.marginal_link_map(marginal_eta)?;
        let g = self.block_states[1].eta[row];
        self.family.rigid_row_fourth_contracted(
            row,
            marginal_eta,
            marginal,
            g,
            dir_u[0],
            dir_u[1],
            dir_v[0],
            dir_v[1],
        )
    }
}
/// Workspace for exact-Newton joint Hessian evaluation: family, per-block
/// states, the shared eval cache, a counter of Hessian-vector products, and
/// the fit options in effect.
struct BernoulliMarginalSlopeExactNewtonJointHessianWorkspace {
    family: BernoulliMarginalSlopeFamily,
    block_states: Vec<ParameterBlockState>,
    cache: BernoulliMarginalSlopeExactEvalCache,
    // Number of matrix-vector applications performed (diagnostics).
    matvec_calls: AtomicUsize,
    options: BlockwiseFitOptions,
}
/// Workspace for exact-Newton joint psi (penalty-derivative) evaluation:
/// family, per-block states and specs, per-block psi derivative sets, the
/// shared eval cache, and fit options.
struct BernoulliMarginalSlopeExactNewtonJointPsiWorkspace {
    family: BernoulliMarginalSlopeFamily,
    block_states: Vec<ParameterBlockState>,
    specs: Vec<ParameterBlockSpec>,
    // Per-block psi derivatives, one inner Vec per parameter block.
    derivative_blocks: Vec<Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>>,
    cache: BernoulliMarginalSlopeExactEvalCache,
    options: BlockwiseFitOptions,
}
const BERNOULLI_MARGSLOPE_LINE_SEARCH_EARLY_EXIT_CHUNK_ROWS: usize = 10_000;
fn bernoulli_margslope_line_search_ll_with_early_exit<F>(
rows: &[usize],
scale: f64,
threshold: f64,
row_ll: F,
) -> Result<f64, String>
where
F: Fn(usize) -> Result<f64, String> + Sync,
{
if !threshold.is_finite() {
return Err(format!(
"bernoulli marginal-slope early-exit threshold must be finite, got {threshold}"
));
}
if !(scale.is_finite() && scale > 0.0) {
return Err(format!(
"bernoulli marginal-slope early-exit scale must be positive and finite, got {scale}"
));
}
let mut total_ll = 0.0;
for chunk in rows.chunks(BERNOULLI_MARGSLOPE_LINE_SEARCH_EARLY_EXIT_CHUNK_ROWS) {
let chunk_ll: f64 = chunk
.into_par_iter()
.try_fold(
|| 0.0,
|mut acc, &row| -> Result<_, String> {
acc += row_ll(row)?;
Ok(acc)
},
)
.try_reduce(
|| 0.0,
|left, right| -> Result<_, String> { Ok(left + right) },
)?;
total_ll += chunk_ll;
if -scale * total_ll > threshold {
return Err(format!(
"bernoulli marginal-slope line-search rejected early: partial_nll={} threshold={}",
-scale * total_ll,
threshold
));
}
}
Ok(total_ll * scale)
}
impl BernoulliMarginalSlopeFamily {
/// Probit frailty attenuation factor for this family's Gaussian frailty SD
/// (delegates to the shared marginal-slope helper).
#[inline]
fn probit_frailty_scale(&self) -> f64 {
    let frailty_sd = self.gaussian_frailty_sd;
    probit_frailty_scale(frailty_sd)
}
/// Solves the empirical calibration intercept for one row, warm-starting
/// from (and writing back to) an optional per-row atomic cache that stores
/// the f64 root as raw bits. Non-finite cached values are ignored.
fn empirical_rigid_intercept_for_row(
    &self,
    row: usize,
    marginal: BernoulliMarginalLinkMap,
    slope: f64,
    nodes: &[f64],
    measure_weights: &[f64],
) -> Result<f64, String> {
    // Read a finite warm start from the cache, if one exists for this row.
    let cached = self.intercept_warm_starts.as_ref().and_then(|cache| {
        let value = f64::from_bits(cache.get(row)?.load(Ordering::Relaxed));
        value.is_finite().then_some(value)
    });
    let root = empirical_intercept_from_marginal(
        marginal.mu,
        marginal.q,
        slope,
        self.probit_frailty_scale(),
        nodes,
        measure_weights,
        cached,
    )?;
    // Store the solved root back for subsequent calls on the same row.
    if let Some(cache) = self.intercept_warm_starts.as_ref()
        && let Some(slot) = cache.get(row)
    {
        slot.store(root.to_bits(), Ordering::Relaxed);
    }
    Ok(root)
}
/// Evaluates the calibration function F(a) = Σ w·Φ(a + s·b·z) − μ and its
/// intercept derivative F_a = Σ w·φ(a + s·b·z) as multi-direction jets over
/// the quadrature grid (`nodes`, `measure_weights`), where s is the probit
/// frailty scale.
fn empirical_rigid_calibration_jets(
    &self,
    intercept: &MultiDirJet,
    mu: &MultiDirJet,
    slope: &MultiDirJet,
    nodes: &[f64],
    measure_weights: &[f64],
) -> (MultiDirJet, MultiDirJet) {
    // coeffs.len() == 2^n_dirs, so trailing_zeros recovers the direction count.
    let n_dirs = intercept.coeffs.len().trailing_zeros() as usize;
    let observed_slope = slope.scale(self.probit_frailty_scale());
    let mut f = mu.scale(-1.0);
    let mut f_a = MultiDirJet::zero(n_dirs);
    for (&node, &weight) in nodes.iter().zip(measure_weights.iter()) {
        let eta = intercept.add(&observed_slope.scale(node));
        let cdf = eta.compose_unary(unary_derivatives_normal_cdf(eta.coeff(0)));
        let pdf = eta.compose_unary(unary_derivatives_normal_pdf(eta.coeff(0)));
        f = f.add(&cdf.scale(weight));
        f_a = f_a.add(&pdf.scale(weight));
    }
    (f, f_a)
}
/// Negative log-likelihood of one row as a multi-direction jet in the
/// (marginal_eta, slope) plane, for the empirical-measure rigid path.
/// The intercept is first solved numerically, then its derivative jet is
/// obtained by a fixed 6-step Newton iteration on the jets with the zeroth
/// coefficient re-pinned to the solved root each step (so only the
/// derivative coefficients iterate).
fn empirical_rigid_neglog_jet(
    &self,
    row: usize,
    marginal_eta: f64,
    marginal: BernoulliMarginalLinkMap,
    slope: f64,
    directions: &[[f64; 2]],
    nodes: &[f64],
    measure_weights: &[f64],
) -> Result<MultiDirJet, String> {
    let n_dirs = directions.len();
    let marginal_first = directions.iter().map(|dir| dir[0]).collect::<Vec<_>>();
    let slope_first = directions.iter().map(|dir| dir[1]).collect::<Vec<_>>();
    let marginal_eta_jet = MultiDirJet::linear(n_dirs, marginal_eta, &marginal_first);
    // μ(η) jet via the link map's derivatives up to fourth order.
    let mu_jet = marginal_eta_jet.compose_unary([
        marginal.mu,
        marginal.mu1,
        marginal.mu2,
        marginal.mu3,
        marginal.mu4,
    ]);
    let slope_jet = MultiDirJet::linear(n_dirs, slope, &slope_first);
    let intercept_root =
        self.empirical_rigid_intercept_for_row(row, marginal, slope, nodes, measure_weights)?;
    let mut intercept_jet = MultiDirJet::constant(n_dirs, intercept_root);
    for _ in 0..6 {
        let (f, f_a) = self.empirical_rigid_calibration_jets(
            &intercept_jet,
            &mu_jet,
            &slope_jet,
            nodes,
            measure_weights,
        );
        // Newton step on jets: a ← a − F / F_a.
        let inv_f_a = f_a.compose_unary(unary_derivatives_reciprocal(f_a.coeff(0)));
        intercept_jet = intercept_jet.add(&f.mul(&inv_f_a).scale(-1.0));
        // Keep the zeroth coefficient pinned at the numerically solved root.
        intercept_jet.coeffs[0] = intercept_root;
    }
    // Observed-row linear predictor and signed probit negative log-likelihood.
    let observed_slope = slope_jet.scale(self.probit_frailty_scale());
    let observed_eta = intercept_jet.add(&observed_slope.scale(self.z[row]));
    let signed = observed_eta.scale(2.0 * self.y[row] - 1.0);
    Ok(signed.compose_unary(unary_derivatives_neglog_phi(
        signed.coeff(0),
        self.weights[row],
    )))
}
/// Builds a first-order jet whose value is `base` and whose directional
/// first-order coefficients are component `idx` of each direction vector.
/// Errors when any direction is too short to index.
fn primary_component_jet(
    n_dirs: usize,
    base: f64,
    directions: &[Array1<f64>],
    idx: usize,
) -> Result<MultiDirJet, String> {
    let mut first = Vec::with_capacity(directions.len());
    for dir in directions {
        match dir.get(idx) {
            Some(&component) => first.push(component),
            None => {
                return Err(format!(
                    "bernoulli empirical flex direction length {} is too short for primary index {idx}",
                    dir.len()
                ));
            }
        }
    }
    Ok(MultiDirJet::linear(n_dirs, base, &first))
}
/// Evaluates the local span cubic c0 + c1·t + c2·t² + c3·t³, t = x − left,
/// as a jet in the same directions as `x`.
fn local_cubic_value_jet(cubic: exact_kernel::LocalSpanCubic, x: &MultiDirJet) -> MultiDirJet {
    // coeffs.len() == 2^n_dirs, so trailing_zeros recovers the direction count.
    let n_dirs = x.coeffs.len().trailing_zeros() as usize;
    let shifted = x.add(&MultiDirJet::constant(n_dirs, -cubic.left));
    let shifted_sq = shifted.mul(&shifted);
    let shifted_cu = shifted_sq.mul(&shifted);
    MultiDirJet::constant(n_dirs, cubic.c0)
        .add(&shifted.scale(cubic.c1))
        .add(&shifted_sq.scale(cubic.c2))
        .add(&shifted_cu.scale(cubic.c3))
}
/// First derivative of the local span cubic, c1 + 2·c2·t + 3·c3·t² with
/// t = x − left, evaluated as a jet.
fn local_cubic_first_derivative_jet(
    cubic: exact_kernel::LocalSpanCubic,
    x: &MultiDirJet,
) -> MultiDirJet {
    let n_dirs = x.coeffs.len().trailing_zeros() as usize;
    let shifted = x.add(&MultiDirJet::constant(n_dirs, -cubic.left));
    let shifted_sq = shifted.mul(&shifted);
    MultiDirJet::constant(n_dirs, cubic.c1)
        .add(&shifted.scale(2.0 * cubic.c2))
        .add(&shifted_sq.scale(3.0 * cubic.c3))
}
/// Flexible-path linear predictor at latent node `z` as a jet, together
/// with its intercept derivative jet: η = s·(a + b·z + b·h(z) + w(u)) with
/// u = a + b·z, where h is the score-warp deviation, w the link deviation,
/// and s the probit frailty scale. η_a = s·(1 + w′(u)).
/// NOTE(review): the link-deviation basis spans are selected at the jet's
/// zeroth-order value u0 only — the span choice itself is treated as
/// locally constant under differentiation; confirm against the basis
/// runtime's continuity guarantees.
fn empirical_flex_eta_and_eta_a_jet_at_z(
    &self,
    primary: &PrimarySlices,
    a_jet: &MultiDirJet,
    b_jet: &MultiDirJet,
    beta_h: Option<&Array1<f64>>,
    beta_w: Option<&Array1<f64>>,
    directions: &[Array1<f64>],
    z: f64,
) -> Result<(MultiDirJet, MultiDirJet), String> {
    let n_dirs = directions.len();
    let mut inside = a_jet.add(&b_jet.scale(z));
    if let Some(h_range) = primary.h.as_ref() {
        // Score-warp term b·h(z): basis values are constants in z; the jet
        // dependence comes through the h coefficients and b.
        let runtime = self.score_warp.as_ref().ok_or_else(|| {
            "empirical flex score-warp primary range without runtime".to_string()
        })?;
        let beta_h = beta_h.ok_or_else(|| {
            "empirical flex score-warp primary range without beta".to_string()
        })?;
        let mut h_jet = MultiDirJet::zero(n_dirs);
        Self::for_each_deviation_basis_cubic_at(
            runtime,
            h_range,
            z,
            "empirical flex score-warp",
            |local_idx, idx, basis_span| {
                let basis_value = basis_span.evaluate(z);
                let beta_jet =
                    Self::primary_component_jet(n_dirs, beta_h[local_idx], directions, idx)?;
                h_jet = h_jet.add(&beta_jet.scale(basis_value));
                Ok(())
            },
        )?;
        inside = inside.add(&b_jet.mul(&h_jet));
    }
    // u = a + b·z feeds the link-deviation basis (evaluated as jets, since
    // u itself depends on the coefficients).
    let u_jet = a_jet.add(&b_jet.scale(z));
    let mut w_jet = MultiDirJet::zero(n_dirs);
    let mut w_prime_jet = MultiDirJet::zero(n_dirs);
    if let Some(w_range) = primary.w.as_ref() {
        let runtime = self.link_dev.as_ref().ok_or_else(|| {
            "empirical flex link-deviation primary range without runtime".to_string()
        })?;
        let beta_w = beta_w.ok_or_else(|| {
            "empirical flex link-deviation primary range without beta".to_string()
        })?;
        let u0 = u_jet.coeff(0);
        Self::for_each_deviation_basis_cubic_at(
            runtime,
            w_range,
            u0,
            "empirical flex link-deviation",
            |local_idx, idx, basis_span| {
                let beta_jet =
                    Self::primary_component_jet(n_dirs, beta_w[local_idx], directions, idx)?;
                let basis_value = Self::local_cubic_value_jet(basis_span, &u_jet);
                let basis_derivative =
                    Self::local_cubic_first_derivative_jet(basis_span, &u_jet);
                w_jet = w_jet.add(&beta_jet.mul(&basis_value));
                w_prime_jet = w_prime_jet.add(&beta_jet.mul(&basis_derivative));
                Ok(())
            },
        )?;
    }
    let scale = self.probit_frailty_scale();
    let eta = inside.add(&w_jet).scale(scale);
    let eta_a = MultiDirJet::constant(n_dirs, 1.0)
        .add(&w_prime_jet)
        .scale(scale);
    Ok((eta, eta_a))
}
/// Flexible-path analogue of `empirical_rigid_calibration_jets`:
/// F = Σ w·Φ(η(z)) − μ and F_a = Σ w·φ(η(z))·η_a(z) over the quadrature
/// grid, with η and η_a from `empirical_flex_eta_and_eta_a_jet_at_z`.
fn empirical_flex_calibration_jets(
    &self,
    primary: &PrimarySlices,
    a_jet: &MultiDirJet,
    mu_jet: &MultiDirJet,
    b_jet: &MultiDirJet,
    beta_h: Option<&Array1<f64>>,
    beta_w: Option<&Array1<f64>>,
    directions: &[Array1<f64>],
    grid: &EmpiricalZGrid,
) -> Result<(MultiDirJet, MultiDirJet), String> {
    let n_dirs = directions.len();
    let mut f = mu_jet.scale(-1.0);
    let mut f_a = MultiDirJet::zero(n_dirs);
    for (&node, &weight) in grid.nodes.iter().zip(grid.weights.iter()) {
        let (eta, eta_a) = self.empirical_flex_eta_and_eta_a_jet_at_z(
            primary, a_jet, b_jet, beta_h, beta_w, directions, node,
        )?;
        let cdf = eta.compose_unary(unary_derivatives_normal_cdf(eta.coeff(0)));
        let pdf = eta.compose_unary(unary_derivatives_normal_pdf(eta.coeff(0)));
        f = f.add(&cdf.scale(weight));
        f_a = f_a.add(&pdf.mul(&eta_a).scale(weight));
    }
    Ok((f, f_a))
}
/// Negative log-likelihood of one row as a jet over arbitrary primary-space
/// directions (at most 6) for the flexible empirical path. The intercept
/// root comes precomputed in `row_ctx`; its derivative jet is obtained by a
/// fixed 6-step Newton iteration with the zeroth coefficient re-pinned each
/// step, mirroring `empirical_rigid_neglog_jet`.
fn empirical_flex_neglog_jet(
    &self,
    row: usize,
    primary: &PrimarySlices,
    q: f64,
    b: f64,
    beta_h: Option<&Array1<f64>>,
    beta_w: Option<&Array1<f64>>,
    row_ctx: &BernoulliMarginalSlopeRowExactContext,
    directions: &[Array1<f64>],
    grid: &EmpiricalZGrid,
) -> Result<MultiDirJet, String> {
    let n_dirs = directions.len();
    // The jet stores 2^n_dirs mixed-partial coefficients; cap the blow-up.
    if n_dirs > 6 {
        return Err(format!(
            "bernoulli empirical flex jet supports at most 6 directions, got {n_dirs}"
        ));
    }
    for dir in directions {
        if dir.len() != primary.total {
            return Err(format!(
                "bernoulli empirical flex direction length {} != primary dimension {}",
                dir.len(),
                primary.total
            ));
        }
    }
    if !(row_ctx.intercept.is_finite() && row_ctx.m_a.is_finite() && row_ctx.m_a > 0.0) {
        return Err("non-finite empirical flexible row context in jet contraction".to_string());
    }
    let marginal = self.marginal_link_map(q)?;
    let q_jet = Self::primary_component_jet(n_dirs, q, directions, primary.q)?;
    // μ(q) jet via the link map's derivatives up to fourth order.
    let mu_jet = q_jet.compose_unary([
        marginal.mu,
        marginal.mu1,
        marginal.mu2,
        marginal.mu3,
        marginal.mu4,
    ]);
    let b_jet = Self::primary_component_jet(n_dirs, b, directions, primary.logslope)?;
    let intercept_root = row_ctx.intercept;
    let mut a_jet = MultiDirJet::constant(n_dirs, intercept_root);
    for _ in 0..6 {
        let (f, f_a) = self.empirical_flex_calibration_jets(
            primary, &a_jet, &mu_jet, &b_jet, beta_h, beta_w, directions, grid,
        )?;
        if !(f_a.coeff(0).is_finite() && f_a.coeff(0) > 0.0) {
            return Err(format!(
                "empirical flex calibration jet has invalid F_a={}",
                f_a.coeff(0)
            ));
        }
        // Newton step a ← a − F / F_a on the jets; pin the zeroth
        // coefficient at the precomputed root.
        let inv_f_a = f_a.compose_unary(unary_derivatives_reciprocal(f_a.coeff(0)));
        a_jet = a_jet.add(&f.mul(&inv_f_a).scale(-1.0));
        a_jet.coeffs[0] = intercept_root;
    }
    // Evaluate η at the row's observed latent value and map through the
    // signed probit negative log-likelihood.
    let (eta_observed, _) = self.empirical_flex_eta_and_eta_a_jet_at_z(
        primary,
        &a_jet,
        &b_jet,
        beta_h,
        beta_w,
        directions,
        self.z[row],
    )?;
    let signed = eta_observed.scale(2.0 * self.y[row] - 1.0);
    Ok(signed.compose_unary(unary_derivatives_neglog_phi(
        signed.coeff(0),
        self.weights[row],
    )))
}
/// Standard basis vector e_idx in the `dim`-dimensional primary space.
fn unit_primary_direction(dim: usize, idx: usize) -> Array1<f64> {
    Array1::from_shape_fn(dim, |i| if i == idx { 1.0 } else { 0.0 })
}
/// Third derivative of the row's negative log-likelihood contracted once
/// with `dir`, recomputed by brute force: one 3-direction jet per (u, v)
/// pair of basis directions, reading the mixed partial at bitmask 1|2|4
/// (all three directions). Symmetric in (u, v), so only the upper triangle
/// is evaluated. Cost is O(r²) jet evaluations.
fn empirical_flex_row_third_contracted_recompute(
    &self,
    row: usize,
    primary: &PrimarySlices,
    q: f64,
    b: f64,
    beta_h: Option<&Array1<f64>>,
    beta_w: Option<&Array1<f64>>,
    row_ctx: &BernoulliMarginalSlopeRowExactContext,
    dir: &Array1<f64>,
    grid: &EmpiricalZGrid,
) -> Result<Array2<f64>, String> {
    let r = primary.total;
    if dir.len() != r {
        return Err(format!(
            "bernoulli empirical flex third contraction direction length {} != primary dimension {r}",
            dir.len()
        ));
    }
    // A zero contraction direction gives a zero matrix; skip the work.
    if dir.iter().all(|value| *value == 0.0) {
        return Ok(Array2::<f64>::zeros((r, r)));
    }
    let basis_dirs = (0..r)
        .map(|idx| Self::unit_primary_direction(r, idx))
        .collect::<Vec<_>>();
    let dir_owned = dir.to_owned();
    let mut out = Array2::<f64>::zeros((r, r));
    for u in 0..r {
        for v in u..r {
            let directions = vec![
                basis_dirs[u].clone(),
                basis_dirs[v].clone(),
                dir_owned.clone(),
            ];
            let jet = self.empirical_flex_neglog_jet(
                row,
                primary,
                q,
                b,
                beta_h,
                beta_w,
                row_ctx,
                &directions,
                grid,
            )?;
            // Mixed partial along all three directions.
            let val = jet.coeff(1 | 2 | 4);
            out[[u, v]] = val;
            out[[v, u]] = val;
        }
    }
    Ok(out)
}
/// Fourth derivative of the row's negative log-likelihood contracted with
/// `dir_u` and `dir_v`, recomputed by brute force: one 4-direction jet per
/// (p, q) basis pair, reading the mixed partial at bitmask 1|2|4|8.
/// Symmetric in (p, q); O(r²) jet evaluations.
fn empirical_flex_row_fourth_contracted_recompute(
    &self,
    row: usize,
    primary: &PrimarySlices,
    q: f64,
    b: f64,
    beta_h: Option<&Array1<f64>>,
    beta_w: Option<&Array1<f64>>,
    row_ctx: &BernoulliMarginalSlopeRowExactContext,
    dir_u: &Array1<f64>,
    dir_v: &Array1<f64>,
    grid: &EmpiricalZGrid,
) -> Result<Array2<f64>, String> {
    let r = primary.total;
    if dir_u.len() != r || dir_v.len() != r {
        return Err(format!(
            "bernoulli empirical flex fourth contraction direction lengths ({},{}) != primary dimension {r}",
            dir_u.len(),
            dir_v.len()
        ));
    }
    // Either contraction direction being zero gives a zero matrix.
    if dir_u.iter().all(|value| *value == 0.0) || dir_v.iter().all(|value| *value == 0.0) {
        return Ok(Array2::<f64>::zeros((r, r)));
    }
    let basis_dirs = (0..r)
        .map(|idx| Self::unit_primary_direction(r, idx))
        .collect::<Vec<_>>();
    let dir_u_owned = dir_u.to_owned();
    let dir_v_owned = dir_v.to_owned();
    let mut out = Array2::<f64>::zeros((r, r));
    for p in 0..r {
        for q_idx in p..r {
            let directions = vec![
                basis_dirs[p].clone(),
                basis_dirs[q_idx].clone(),
                dir_u_owned.clone(),
                dir_v_owned.clone(),
            ];
            let jet = self.empirical_flex_neglog_jet(
                row,
                primary,
                q,
                b,
                beta_h,
                beta_w,
                row_ctx,
                &directions,
                grid,
            )?;
            // Mixed partial along all four directions.
            let val = jet.coeff(1 | 2 | 4 | 8);
            out[[p, q_idx]] = val;
            out[[q_idx, p]] = val;
        }
    }
    Ok(out)
}
/// Rigid-path row evaluation: negative log-likelihood, gradient, and 2×2
/// Hessian in (marginal_eta, slope) coordinates. Rows without an empirical
/// grid use the closed-form `RigidProbitKernel`; rows with one use the jet
/// path. The four jet directions duplicate the two axes
/// ([1,0],[0,1],[1,0],[0,1]) so that pure second derivatives are available
/// as mixed partials of the duplicated pairs: coeff(1|4) = ∂²/∂q²,
/// coeff(2|8) = ∂²/∂g², coeff(1|2) = ∂²/∂q∂g.
fn rigid_row_kernel_eval(
    &self,
    row: usize,
    marginal_eta: f64,
    marginal: BernoulliMarginalLinkMap,
    slope: f64,
) -> Result<(f64, [f64; 2], [[f64; 2]; 2]), String> {
    match self.latent_measure.empirical_grid_for_training_row(row)? {
        None => {
            let kernel = RigidProbitKernel::new(
                marginal.q,
                slope,
                self.z[row],
                self.y[row],
                self.weights[row],
                self.probit_frailty_scale(),
            )?;
            Ok((
                -self.weights[row] * kernel.logcdf,
                rigid_transformed_gradient(marginal, &kernel),
                rigid_transformed_hessian(marginal, &kernel),
            ))
        }
        Some(grid) => {
            let jet = self.empirical_rigid_neglog_jet(
                row,
                marginal_eta,
                marginal,
                slope,
                &[[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]],
                &grid.nodes,
                &grid.weights,
            )?;
            Ok((
                jet.coeff(0),
                [jet.coeff(1), jet.coeff(2)],
                [
                    [jet.coeff(1 | 4), jet.coeff(1 | 2)],
                    [jet.coeff(1 | 2), jet.coeff(2 | 8)],
                ],
            ))
        }
    }
}
/// Third derivative of the rigid row objective contracted once with
/// (dir_q, dir_g). The empirical path appends the contraction direction as
/// a fifth jet direction (bit 16) on top of the duplicated-axis layout used
/// by `rigid_row_kernel_eval`, so e.g. coeff(1|4|16) = ∂³/∂q²·(direction).
fn rigid_row_third_contracted(
    &self,
    row: usize,
    marginal_eta: f64,
    marginal: BernoulliMarginalLinkMap,
    slope: f64,
    dir_q: f64,
    dir_g: f64,
) -> Result<[[f64; 2]; 2], String> {
    match self.latent_measure.empirical_grid_for_training_row(row)? {
        None => {
            let kernel = RigidProbitKernel::new(
                marginal.q,
                slope,
                self.z[row],
                self.y[row],
                self.weights[row],
                self.probit_frailty_scale(),
            )?;
            Ok(rigid_transformed_third_contracted(
                marginal, &kernel, dir_q, dir_g,
            ))
        }
        Some(grid) => {
            let jet = self.empirical_rigid_neglog_jet(
                row,
                marginal_eta,
                marginal,
                slope,
                &[
                    [1.0, 0.0],
                    [0.0, 1.0],
                    [1.0, 0.0],
                    [0.0, 1.0],
                    [dir_q, dir_g],
                ],
                &grid.nodes,
                &grid.weights,
            )?;
            Ok([
                [jet.coeff(1 | 4 | 16), jet.coeff(1 | 2 | 16)],
                [jet.coeff(1 | 2 | 16), jet.coeff(2 | 8 | 16)],
            ])
        }
    }
}
/// Fourth derivative of the rigid row objective contracted with (u_q, u_g)
/// and (v_q, v_g). The empirical path appends both contraction directions
/// as fifth and sixth jet directions (bits 16 and 32) on top of the
/// duplicated-axis layout, so e.g. coeff(1|4|16|32) = ∂⁴/∂q²·u·v.
fn rigid_row_fourth_contracted(
    &self,
    row: usize,
    marginal_eta: f64,
    marginal: BernoulliMarginalLinkMap,
    slope: f64,
    u_q: f64,
    u_g: f64,
    v_q: f64,
    v_g: f64,
) -> Result<[[f64; 2]; 2], String> {
    match self.latent_measure.empirical_grid_for_training_row(row)? {
        None => {
            let kernel = RigidProbitKernel::new(
                marginal.q,
                slope,
                self.z[row],
                self.y[row],
                self.weights[row],
                self.probit_frailty_scale(),
            )?;
            Ok(rigid_transformed_fourth_contracted(
                marginal, &kernel, u_q, u_g, v_q, v_g,
            ))
        }
        Some(grid) => {
            let jet = self.empirical_rigid_neglog_jet(
                row,
                marginal_eta,
                marginal,
                slope,
                &[
                    [1.0, 0.0],
                    [0.0, 1.0],
                    [1.0, 0.0],
                    [0.0, 1.0],
                    [u_q, u_g],
                    [v_q, v_g],
                ],
                &grid.nodes,
                &grid.weights,
            )?;
            Ok([
                [jet.coeff(1 | 4 | 16 | 32), jet.coeff(1 | 2 | 16 | 32)],
                [jet.coeff(1 | 2 | 16 | 32), jet.coeff(2 | 8 | 16 | 32)],
            ])
        }
    }
}
/// Unpenalized log-likelihood under the given block states, honoring the
/// row-subsampling and early-exit options (used by line searches).
///
/// When subsampling is active, `scale` reweights the subsample total back
/// to the full-data scale. Two regimes: the rigid kernel (no active
/// score-warp/link-deviation blocks) uses the closed-form row evaluation;
/// the flexible regime first solves each row's calibrated intercept and
/// then evaluates the denested-cell polynomial at `z[row]`.
pub(crate) fn log_likelihood_only_with_options(
    &self,
    block_states: &[ParameterBlockState],
    options: &BlockwiseFitOptions,
) -> Result<f64, String> {
    self.validate_exact_monotonicity(block_states)?;
    let flex_active = self.effective_flex_active(block_states)?;
    let n = self.y.len();
    let row_iter = outer_row_indices(options, n).to_vec();
    let scale = outer_score_scale(options, n);
    if !flex_active {
        // Rigid path: only the neg-log value is needed; the gradient and
        // Hessian from the row evaluation are discarded.
        let b = &block_states[1].eta;
        let row_ll = |i: usize| -> Result<f64, String> {
            let marginal_eta = block_states[0].eta[i];
            let marginal = self.marginal_link_map(marginal_eta)?;
            let (neglog, _, _) = self.rigid_row_kernel_eval(i, marginal_eta, marginal, b[i])?;
            Ok(-neglog)
        };
        if let Some(threshold) = options.early_exit_threshold {
            return bernoulli_margslope_line_search_ll_with_early_exit(
                &row_iter, scale, threshold, row_ll,
            );
        }
        // Parallel fallible sum over the (sub)sampled rows.
        let total: Result<f64, String> = row_iter
            .into_par_iter()
            .try_fold(
                || 0.0,
                |mut ll, i| -> Result<_, String> {
                    ll += row_ll(i)?;
                    Ok(ll)
                },
            )
            .try_reduce(
                || 0.0,
                |left, right| -> Result<_, String> { Ok(left + right) },
            );
        return total.map(|v| v * scale);
    }
    let beta_h = self.score_beta(block_states)?;
    let beta_w = self.link_beta(block_states)?;
    // Flexible path: solve the row intercept (no stats collection), then
    // evaluate the observed denested-cell polynomial at z[row] and score it
    // through the signed probit log-CDF.
    let row_ll = |row: usize| -> Result<f64, String> {
        let intercept = self
            .solve_row_intercept_base(
                row,
                block_states[0].eta[row],
                block_states[1].eta[row],
                beta_h,
                beta_w,
                None,
            )?
            .0;
        let slope = block_states[1].eta[row];
        let obs =
            self.observed_denested_cell_partials(row, intercept, slope, beta_h, beta_w)?;
        let s_i = eval_coeff4_at(&obs.coeff, self.z[row]);
        let signed = (2.0 * self.y[row] - 1.0) * s_i;
        let (log_cdf, _) = signed_probit_logcdf_and_mills_ratio(signed);
        Ok(self.weights[row] * log_cdf)
    };
    if let Some(threshold) = options.early_exit_threshold {
        return bernoulli_margslope_line_search_ll_with_early_exit(
            &row_iter, scale, threshold, row_ll,
        );
    }
    let total: Result<f64, String> = row_iter
        .into_par_iter()
        .try_fold(
            || 0.0,
            |mut ll, row| -> Result<_, String> {
                ll += row_ll(row)?;
                Ok(ll)
            },
        )
        .try_reduce(
            || 0.0,
            |left, right| -> Result<_, String> { Ok(left + right) },
        );
    total.map(|v| v * scale)
}
/// Whether `psi_index` addresses the auxiliary log-sigma hyperparameter;
/// delegates to the shared marginal-slope helper with this family's
/// Gaussian frailty SD.
fn is_sigma_aux_index(
    &self,
    derivative_blocks: &[Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>],
    psi_index: usize,
) -> bool {
    let frailty_sd = self.gaussian_frailty_sd;
    shared_is_sigma_aux_index(frailty_sd, derivative_blocks, psi_index)
}
/// Multi-directional jet of the probit frailty scale over `n_dirs` jet
/// slots, forwarded to the shared builder; `first_masks`/`second_masks`
/// mark which slots carry sigma perturbations. Errors when no GaussianShift
/// sigma is configured.
fn sigma_scale_jet(
    &self,
    n_dirs: usize,
    first_masks: &[usize],
    second_masks: &[usize],
) -> Result<MultiDirJet, String> {
    let missing_sigma_msg =
        "bernoulli marginal-slope log-sigma auxiliary requested without GaussianShift sigma";
    probit_frailty_scale_multi_dir_jet(
        self.gaussian_frailty_sd,
        missing_sigma_msg,
        n_dirs,
        first_masks,
        second_masks,
    )
}
/// Directional derivative of one row's rigid neg-log with the probit
/// frailty scale supplied as a jet, so sigma perturbations propagate
/// through the observed slope.
///
/// Each entry of `dirs` (at most 4) is a primary direction whose
/// components index (marginal eta, log-slope eta). The return value is
/// the fully mixed coefficient across all `k` jet slots.
fn row_neglog_directional_with_scale_jet(
    &self,
    row: usize,
    block_states: &[ParameterBlockState],
    dirs: &[Array1<f64>],
    scale_jet: &MultiDirJet,
) -> Result<f64, String> {
    let k = dirs.len();
    if k > 4 {
        return Err(format!(
            "bernoulli marginal-slope sigma row directional expects 0..=4 directions, got {k}"
        ));
    }
    // The scale jet must have one coefficient per subset of the k slots.
    if scale_jet.coeffs.len() != (1usize << k) {
        return Err(format!(
            "bernoulli marginal-slope sigma scale jet dimension mismatch: coeffs={}, dirs={k}",
            scale_jet.coeffs.len()
        ));
    }
    // First-order component of every direction at primary index `idx`.
    let first = |idx: usize| -> Vec<f64> { dirs.iter().map(|dir| dir[idx]).collect() };
    let marginal = self.marginal_link_map(block_states[0].eta[row])?;
    // q(eta): marginal link (derivatives up to 4th order) composed over the
    // linear marginal-eta jet.
    let eta_jet = MultiDirJet::linear(k, block_states[0].eta[row], &first(0));
    let q_jet = eta_jet.compose_unary([
        marginal.q,
        marginal.q1,
        marginal.q2,
        marginal.q3,
        marginal.q4,
    ]);
    // Observed slope b = scale * g; the sigma dependence lives in scale_jet.
    let g_jet = MultiDirJet::linear(k, block_states[1].eta[row], &first(1));
    let observed_g_jet = g_jet.mul(scale_jet);
    // c = sqrt(1 + b^2); observed eta = q * c + b * z.
    let one_plus_b2 = MultiDirJet::constant(k, 1.0).add(&observed_g_jet.mul(&observed_g_jet));
    let c_jet = one_plus_b2.compose_unary(unary_derivatives_sqrt(one_plus_b2.coeff(0)));
    let z_jet = MultiDirJet::constant(k, self.z[row]);
    let eta_observed_jet = q_jet.mul(&c_jet).add(&observed_g_jet.mul(&z_jet));
    // Signed probit argument (2y - 1) * eta pushed through the weighted
    // -log Phi; the top mask reads the fully mixed derivative.
    let signed_jet = eta_observed_jet.scale(2.0 * self.y[row] - 1.0);
    Ok(signed_jet
        .compose_unary(unary_derivatives_neglog_phi(
            signed_jet.coeff(0),
            self.weights[row],
        ))
        .coeff((1usize << k) - 1))
}
/// Per-row objective, primary-space gradient, and primary-space Hessian
/// with the sigma (frailty-scale) perturbation folded into the jet slots.
///
/// The primary coordinates are (marginal eta, log-slope eta), so the
/// gradient has length 2 and the Hessian is 2x2. `second_sigma` selects
/// whether the scale jet carries one or two sigma slots (first vs. second
/// log-sigma hyperderivative).
///
/// Returns `(objective, grad, hess)`.
fn row_sigma_primary_terms(
    &self,
    row: usize,
    block_states: &[ParameterBlockState],
    second_sigma: bool,
) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
    let primary_dim = 2usize;
    let zero = Array1::<f64>::zeros(primary_dim);
    // The scale jets depend only on the jet arity and `second_sigma`, not
    // on which primary direction is probed, so build each arity once
    // instead of re-deriving them inside the gradient/Hessian loops.
    let (obj_scale, grad_scale, hess_scale) = if second_sigma {
        (
            self.sigma_scale_jet(2, &[1, 2], &[3])?,
            self.sigma_scale_jet(3, &[1, 2], &[3])?,
            self.sigma_scale_jet(4, &[1, 2], &[3])?,
        )
    } else {
        (
            self.sigma_scale_jet(1, &[1], &[])?,
            self.sigma_scale_jet(2, &[1], &[])?,
            self.sigma_scale_jet(3, &[1], &[])?,
        )
    };
    // Leading zero directions reserve the sigma slots of the scale jet;
    // probed primary directions are appended after them.
    let base: Vec<Array1<f64>> = if second_sigma {
        vec![zero.clone(), zero.clone()]
    } else {
        vec![zero.clone()]
    };
    let objective =
        self.row_neglog_directional_with_scale_jet(row, block_states, &base, &obj_scale)?;
    let mut grad = Array1::<f64>::zeros(primary_dim);
    for a in 0..primary_dim {
        let mut da = Array1::<f64>::zeros(primary_dim);
        da[a] = 1.0;
        let mut dirs = base.clone();
        dirs.push(da);
        grad[a] =
            self.row_neglog_directional_with_scale_jet(row, block_states, &dirs, &grad_scale)?;
    }
    let mut hess = Array2::<f64>::zeros((primary_dim, primary_dim));
    for a in 0..primary_dim {
        let mut da = Array1::<f64>::zeros(primary_dim);
        da[a] = 1.0;
        // Symmetric: only the upper triangle is evaluated.
        for b in a..primary_dim {
            let mut db = Array1::<f64>::zeros(primary_dim);
            db[b] = 1.0;
            let mut dirs = base.clone();
            dirs.push(da.clone());
            dirs.push(db);
            let value =
                self.row_neglog_directional_with_scale_jet(row, block_states, &dirs, &hess_scale)?;
            hess[[a, b]] = value;
            hess[[b, a]] = value;
        }
    }
    Ok((objective, grad, hess))
}
/// Pull one row's primary-space gradient/Hessian back to coefficient space:
/// the gradient components are scattered through the marginal and logslope
/// design rows (axpy), and the 2x2 Hessian is handed to the block
/// accumulator's pullback.
fn accumulate_rigid_sigma_pullback(
    &self,
    row: usize,
    slices: &BlockSlices,
    primary_grad: &Array1<f64>,
    primary_hessian: &Array2<f64>,
    score: &mut Array1<f64>,
    hessian: &mut BernoulliBlockHessianAccumulator,
) -> Result<(), String> {
    {
        // grad[0] is the marginal-eta component.
        let mut marginal = score.slice_mut(s![slices.marginal.clone()]);
        self.marginal_design
            .axpy_row_into(row, primary_grad[0], &mut marginal)?;
    }
    {
        // grad[1] is the log-slope component.
        let mut logslope = score.slice_mut(s![slices.logslope.clone()]);
        self.logslope_design
            .axpy_row_into(row, primary_grad[1], &mut logslope)?;
    }
    hessian.add_pullback(self, row, slices, &primary_slices(slices), primary_hessian);
    Ok(())
}
/// Convenience wrapper: first-order sigma psi-terms under the default
/// blockwise fit options.
fn sigma_exact_joint_psi_terms(
    &self,
    block_states: &[ParameterBlockState],
    specs: &[ParameterBlockSpec],
) -> Result<Option<ExactNewtonJointPsiTerms>, String> {
    let default_options = BlockwiseFitOptions::default();
    self.sigma_exact_joint_psi_terms_with_options(block_states, specs, &default_options)
}
/// First-order (psi = log-sigma) terms for the exact-Newton joint update,
/// honoring row-subsampling options.
///
/// Only the rigid kernel is supported; flexible blocks produce an error.
/// Returns `Ok(None)` when no GaussianShift sigma is configured. The
/// Hessian is returned as an operator (dense matrix left empty).
pub(crate) fn sigma_exact_joint_psi_terms_with_options(
    &self,
    block_states: &[ParameterBlockState],
    _specs: &[ParameterBlockSpec],
    options: &BlockwiseFitOptions,
) -> Result<Option<ExactNewtonJointPsiTerms>, String> {
    if self.effective_flex_active(block_states)? {
        return Err(
            "bernoulli marginal-slope log-sigma hyperderivatives are implemented for the rigid probit marginal-slope kernel; flexible score/link kernels require the analytic denested cell-tensor sigma path"
                .to_string(),
        );
    }
    if self.gaussian_frailty_sd.is_none() {
        return Ok(None);
    }
    let slices = block_slices(self);
    let n = self.y.len();
    let row_iter = outer_row_indices(options, n).to_vec();
    let scale = outer_score_scale(options, n);
    // Chunked accumulation of (objective, score, Hessian accumulator)
    // over the selected rows; `second_sigma = false` requests the first
    // sigma derivative.
    let (mut objective_psi, mut score_psi, mut acc) = chunked_row_reduction(
        row_iter.as_slice(),
        || {
            (
                0.0,
                Array1::<f64>::zeros(slices.total),
                BernoulliBlockHessianAccumulator::new(&slices),
            )
        },
        |row, acc| -> Result<(), String> {
            let (obj, grad, hess) = self.row_sigma_primary_terms(row, block_states, false)?;
            acc.0 += obj;
            self.accumulate_rigid_sigma_pullback(
                row, &slices, &grad, &hess, &mut acc.1, &mut acc.2,
            )?;
            Ok(())
        },
        |total, chunk| {
            total.0 += chunk.0;
            total.1 += &chunk.1;
            total.2.add(&chunk.2);
        },
    )?;
    // Rescale subsample totals back to the full-data scale.
    if scale != 1.0 {
        objective_psi *= scale;
        score_psi.mapv_inplace(|v| v * scale);
        acc.scale_assign(scale);
    }
    Ok(Some(ExactNewtonJointPsiTerms {
        objective_psi,
        score_psi,
        hessian_psi: Array2::zeros((0, 0)),
        hessian_psi_operator: Some(Arc::new(acc.into_operator(&slices))),
    }))
}
/// Convenience wrapper: second-order sigma psi-terms under the default
/// blockwise fit options.
fn sigma_exact_joint_psisecond_order_terms(
    &self,
    block_states: &[ParameterBlockState],
) -> Result<Option<ExactNewtonJointPsiSecondOrderTerms>, String> {
    let default_options = BlockwiseFitOptions::default();
    self.sigma_exact_joint_psisecond_order_terms_with_options(block_states, &default_options)
}
/// Second-order (psi-psi = log-sigma twice) terms for the exact-Newton
/// joint update, honoring row-subsampling options.
///
/// Mirrors `sigma_exact_joint_psi_terms_with_options` but requests the
/// second sigma derivative (`second_sigma = true`). Rigid kernel only;
/// returns `Ok(None)` when no GaussianShift sigma is configured.
pub(crate) fn sigma_exact_joint_psisecond_order_terms_with_options(
    &self,
    block_states: &[ParameterBlockState],
    options: &BlockwiseFitOptions,
) -> Result<Option<ExactNewtonJointPsiSecondOrderTerms>, String> {
    if self.effective_flex_active(block_states)? {
        return Err(
            "bernoulli marginal-slope second log-sigma hyperderivatives are implemented for the rigid probit marginal-slope kernel; flexible score/link kernels require the analytic denested cell-tensor sigma path"
                .to_string(),
        );
    }
    if self.gaussian_frailty_sd.is_none() {
        return Ok(None);
    }
    let slices = block_slices(self);
    let n = self.y.len();
    let row_iter = outer_row_indices(options, n).to_vec();
    let scale = outer_score_scale(options, n);
    let (mut objective_psi_psi, mut score_psi_psi, mut acc) = chunked_row_reduction(
        row_iter.as_slice(),
        || {
            (
                0.0,
                Array1::<f64>::zeros(slices.total),
                BernoulliBlockHessianAccumulator::new(&slices),
            )
        },
        |row, acc| -> Result<(), String> {
            let (obj, grad, hess) = self.row_sigma_primary_terms(row, block_states, true)?;
            acc.0 += obj;
            self.accumulate_rigid_sigma_pullback(
                row, &slices, &grad, &hess, &mut acc.1, &mut acc.2,
            )?;
            Ok(())
        },
        |total, chunk| {
            total.0 += chunk.0;
            total.1 += &chunk.1;
            total.2.add(&chunk.2);
        },
    )?;
    // Rescale subsample totals back to the full-data scale.
    if scale != 1.0 {
        objective_psi_psi *= scale;
        score_psi_psi.mapv_inplace(|v| v * scale);
        acc.scale_assign(scale);
    }
    Ok(Some(ExactNewtonJointPsiSecondOrderTerms {
        objective_psi_psi,
        score_psi_psi,
        hessian_psi_psi: Array2::zeros((0, 0)),
        hessian_psi_psi_operator: Some(Box::new(acc.into_operator(&slices))),
    }))
}
/// Convenience wrapper: sigma Hessian directional derivative under the
/// default blockwise fit options.
fn sigma_exact_joint_psihessian_directional_derivative(
    &self,
    block_states: &[ParameterBlockState],
    d_beta_flat: &Array1<f64>,
) -> Result<Option<Array2<f64>>, String> {
    let default_options = BlockwiseFitOptions::default();
    self.sigma_exact_joint_psihessian_directional_derivative_with_options(
        block_states,
        d_beta_flat,
        &default_options,
    )
}
/// Directional derivative (with respect to log-sigma) of the coefficient
/// Hessian, contracted against the flat coefficient direction
/// `d_beta_flat`, honoring row-subsampling options.
///
/// Rigid kernel only; returns `Ok(None)` when no GaussianShift sigma is
/// configured. The result is densified from the block accumulator.
pub(crate) fn sigma_exact_joint_psihessian_directional_derivative_with_options(
    &self,
    block_states: &[ParameterBlockState],
    d_beta_flat: &Array1<f64>,
    options: &BlockwiseFitOptions,
) -> Result<Option<Array2<f64>>, String> {
    if self.effective_flex_active(block_states)? {
        return Err(
            "bernoulli marginal-slope log-sigma Hessian directional derivatives are implemented for the rigid probit marginal-slope kernel; flexible score/link kernels require the analytic denested cell-tensor sigma path"
                .to_string(),
        );
    }
    if self.gaussian_frailty_sd.is_none() {
        return Ok(None);
    }
    let slices = block_slices(self);
    if d_beta_flat.len() != slices.total {
        return Err(format!(
            "bernoulli marginal-slope d_beta length mismatch for sigma Hessian derivative: got {}, expected {}",
            d_beta_flat.len(),
            slices.total
        ));
    }
    let n = self.y.len();
    let primary = primary_slices(&slices);
    let row_iter = outer_row_indices(options, n).to_vec();
    let outer_scale = outer_score_scale(options, n);
    let mut acc = chunked_row_reduction(
        row_iter.as_slice(),
        || BernoulliBlockHessianAccumulator::new(&slices),
        |row, acc| -> Result<(), String> {
            let row_dir =
                self.row_primary_direction_from_flat(row, &slices, &primary, d_beta_flat)?;
            let zero = Array1::<f64>::zeros(primary.total);
            // The scale jet depends only on the jet arity, not on the probed
            // (a, b) pair, so build it once per row. (A gradient vector and
            // its arity-3 scale jet used to be computed here as well but were
            // never read; that dead work has been removed.)
            let hess_scale = self.sigma_scale_jet(4, &[1], &[])?;
            let mut hess = Array2::<f64>::zeros((primary.total, primary.total));
            for a in 0..primary.total {
                let mut da = Array1::<f64>::zeros(primary.total);
                da[a] = 1.0;
                // Symmetric: evaluate only the upper triangle.
                for b in a..primary.total {
                    let mut db = Array1::<f64>::zeros(primary.total);
                    db[b] = 1.0;
                    let value = self.row_neglog_directional_with_scale_jet(
                        row,
                        block_states,
                        &[zero.clone(), row_dir.clone(), da.clone(), db],
                        &hess_scale,
                    )?;
                    hess[[a, b]] = value;
                    hess[[b, a]] = value;
                }
            }
            acc.add_pullback(self, row, &slices, &primary, &hess);
            Ok(())
        },
        |total, chunk| {
            total.add(&chunk);
        },
    )?;
    // Rescale subsample totals back to the full-data scale.
    if outer_scale != 1.0 {
        acc.scale_assign(outer_scale);
    }
    Ok(Some(acc.into_operator(&slices).to_dense()))
}
/// Marginal link map (value plus derivatives) at `eta`, built from the
/// family's base link.
#[inline]
fn marginal_link_map(&self, eta: f64) -> Result<BernoulliMarginalLinkMap, String> {
    let base_link = &self.base_link;
    bernoulli_marginal_link_map(base_link, eta)
}
/// Convert one objective-gradient component to the exact-Newton score
/// convention (score = -grad of the objective).
#[inline]
fn exact_newton_score_component_from_objective_gradient(
    objective_gradient_component: f64,
) -> f64 {
    let score_component = -objective_gradient_component;
    score_component
}
/// Convert a full objective gradient to the exact-Newton score
/// (elementwise negation, consuming the input).
#[inline]
fn exact_newton_score_from_objective_gradient(objective_gradient: Array1<f64>) -> Array1<f64> {
    objective_gradient.mapv_into(|g| -g)
}
/// Identity pass-through: the objective Hessian already matches the
/// observed-information convention, so no sign flip is applied here
/// (contrast with the score helpers, which negate).
#[inline]
fn exact_newton_observed_information_from_objective_hessian(
    objective_hessian: Array2<f64>,
) -> Array2<f64> {
    let observed_information = objective_hessian;
    observed_information
}
/// Block index of the score-warp block (always 2 when present; blocks 0/1
/// are marginal and logslope).
#[inline]
fn score_block_index(&self) -> Option<usize> {
    if self.score_warp.is_some() { Some(2) } else { None }
}
/// Block index of the link-deviation block: it follows the marginal,
/// logslope, and (if present) score-warp blocks.
#[inline]
fn link_block_index(&self) -> Option<usize> {
    if self.link_dev.is_some() {
        Some(2 + usize::from(self.score_warp.is_some()))
    } else {
        None
    }
}
/// Resolve an optional block index to its state: `None` index means the
/// block is not configured (`Ok(None)`); a configured index that is out of
/// range is an error mentioning `label`.
fn optional_exact_block_state<'a>(
    &self,
    block_states: &'a [ParameterBlockState],
    block_idx: Option<usize>,
    label: &str,
) -> Result<Option<&'a ParameterBlockState>, String> {
    let Some(idx) = block_idx else {
        return Ok(None);
    };
    match block_states.get(idx) {
        Some(state) => Ok(Some(state)),
        None => Err(format!("missing {label} block state")),
    }
}
/// Score-warp block state, if that block is configured.
fn score_block_state<'a>(
    &self,
    block_states: &'a [ParameterBlockState],
) -> Result<Option<&'a ParameterBlockState>, String> {
    let idx = self.score_block_index();
    self.optional_exact_block_state(block_states, idx, "score-warp")
}
/// Link-deviation block state, if that block is configured.
fn link_block_state<'a>(
    &self,
    block_states: &'a [ParameterBlockState],
) -> Result<Option<&'a ParameterBlockState>, String> {
    let idx = self.link_block_index();
    self.optional_exact_block_state(block_states, idx, "link deviation")
}
/// Score-warp coefficient vector, if that block is configured.
fn score_beta<'a>(
    &self,
    block_states: &'a [ParameterBlockState],
) -> Result<Option<&'a Array1<f64>>, String> {
    let state = self.score_block_state(block_states)?;
    Ok(state.map(|s| &s.beta))
}
/// Link-deviation coefficient vector, if that block is configured.
fn link_beta<'a>(
    &self,
    block_states: &'a [ParameterBlockState],
) -> Result<Option<&'a Array1<f64>>, String> {
    let state = self.link_block_state(block_states)?;
    Ok(state.map(|s| &s.beta))
}
/// Validate that `block_states` matches this family's layout:
/// `[marginal, logslope, (score-warp), (link-deviation)]`, each with a
/// `beta` sized to its design/basis and an `eta` of length `n`.
/// Zero-column designs skip the beta-length check.
fn validate_exact_block_state_shapes(
    &self,
    block_states: &[ParameterBlockState],
) -> Result<(), String> {
    // Optional blocks contribute one slot each when their runtime exists.
    let expected_blocks =
        2usize + usize::from(self.score_warp.is_some()) + usize::from(self.link_dev.is_some());
    if block_states.len() != expected_blocks {
        return Err(format!(
            "bernoulli marginal-slope block count mismatch: got {}, expected {}",
            block_states.len(),
            expected_blocks
        ));
    }
    let n_rows = self.y.len();
    let marginal = &block_states[0];
    let marginal_ncols = self.marginal_design.ncols();
    if marginal_ncols > 0 && marginal.beta.len() != marginal_ncols {
        return Err(format!(
            "bernoulli marginal-slope marginal beta length mismatch: got {}, expected {}",
            marginal.beta.len(),
            marginal_ncols
        ));
    }
    if marginal.eta.len() != n_rows {
        return Err(format!(
            "bernoulli marginal-slope marginal eta length mismatch: got {}, expected {}",
            marginal.eta.len(),
            n_rows
        ));
    }
    let logslope = &block_states[1];
    let logslope_ncols = self.logslope_design.ncols();
    if logslope_ncols > 0 && logslope.beta.len() != logslope_ncols {
        return Err(format!(
            "bernoulli marginal-slope logslope beta length mismatch: got {}, expected {}",
            logslope.beta.len(),
            logslope_ncols
        ));
    }
    if logslope.eta.len() != n_rows {
        return Err(format!(
            "bernoulli marginal-slope logslope eta length mismatch: got {}, expected {}",
            logslope.eta.len(),
            n_rows
        ));
    }
    if let Some(runtime) = &self.score_warp {
        // The block-count check above guarantees this state exists.
        let score = self
            .score_block_state(block_states)?
            .expect("score-warp block should exist when runtime is present");
        if score.beta.len() != runtime.basis_dim() {
            return Err(format!(
                "bernoulli marginal-slope score-warp beta length mismatch: got {}, expected {}",
                score.beta.len(),
                runtime.basis_dim()
            ));
        }
        if score.eta.len() != n_rows {
            return Err(format!(
                "bernoulli marginal-slope score-warp eta length mismatch: got {}, expected {}",
                score.eta.len(),
                n_rows
            ));
        }
    }
    if let Some(runtime) = &self.link_dev {
        let link = self
            .link_block_state(block_states)?
            .expect("link-deviation block should exist when runtime is present");
        if link.beta.len() != runtime.basis_dim() {
            return Err(format!(
                "bernoulli marginal-slope link-deviation beta length mismatch: got {}, expected {}",
                link.beta.len(),
                runtime.basis_dim()
            ));
        }
        if link.eta.len() != n_rows {
            return Err(format!(
                "bernoulli marginal-slope link-deviation eta length mismatch: got {}, expected {}",
                link.eta.len(),
                n_rows
            ));
        }
    }
    Ok(())
}
/// Build the denested partition cells for a row with intercept `a` and
/// slope `b`, delegating to the shared helper with this family's
/// score-warp/link-deviation runtimes and probit frailty scale.
fn denested_partition_cells(
    &self,
    a: f64,
    b: f64,
    beta_h: Option<&Array1<f64>>,
    beta_w: Option<&Array1<f64>>,
) -> Result<Vec<exact_kernel::DenestedPartitionCell>, String> {
    shared_denested_partition_cells(
        a,
        b,
        self.score_warp.as_ref(),
        beta_h,
        self.link_dev.as_ref(),
        beta_w,
        self.probit_frailty_scale(),
    )
}
/// Evaluate a cell's moments up to `max_degree`, served through the
/// family's LRU moment cache with cache-statistics recording.
#[inline]
fn evaluate_cell_moments_lru(
    &self,
    cell: exact_kernel::DenestedCubicCell,
    max_degree: usize,
) -> Result<exact_kernel::CellMomentState, String> {
    exact_kernel::evaluate_cell_moments_cached(
        cell,
        max_degree,
        &self.cell_moment_lru,
        Some(&self.cell_moment_cache_stats),
    )
}
/// Visit every deviation-basis cubic span active at `value`, translating
/// the runtime's local basis index into the block's global coefficient
/// index (`primary_range.start + local_idx`).
///
/// Errors (tagged with `label`) when `primary_range` does not match the
/// runtime's basis dimension.
#[inline]
fn for_each_deviation_basis_cubic_at<F>(
    runtime: &DeviationRuntime,
    primary_range: &std::ops::Range<usize>,
    value: f64,
    label: &str,
    mut visit: F,
) -> Result<(), String>
where
    F: FnMut(usize, usize, exact_kernel::LocalSpanCubic) -> Result<(), String>,
{
    if primary_range.len() != runtime.basis_dim() {
        return Err(format!(
            "{label} primary range length {} does not match deviation basis dimension {}",
            primary_range.len(),
            runtime.basis_dim()
        ));
    }
    runtime.for_each_basis_cubic_at(value, |local_idx, basis_span| {
        visit(local_idx, primary_range.start + local_idx, basis_span)
    })
}
/// Calibration residual on the analytic (standard-normal) path:
/// `f(a)` accumulates each partition cell's degree-4 moment value minus the
/// marginal target `mu`, and `f_a` sums the a-partial of each cell's
/// coefficients contracted with its moments.
///
/// The second derivative is NOT computed here — the third return slot is
/// always 0.0; callers needing curvature must use the empirical-grid
/// variant.
fn evaluate_denested_calibration_newton(
    &self,
    a: f64,
    marginal_eta: f64,
    slope: f64,
    beta_h: Option<&Array1<f64>>,
    beta_w: Option<&Array1<f64>>,
) -> Result<(f64, f64, f64), String> {
    let marginal = self.marginal_link_map(marginal_eta)?;
    let cells = self.denested_partition_cells(a, slope, beta_h, beta_w)?;
    let scale = self.probit_frailty_scale();
    // Start from -mu so f = 0 at the calibrated intercept.
    let mut f = -marginal.mu;
    let mut f_a = 0.0;
    for partition_cell in cells {
        let cell = partition_cell.cell;
        let state = self.evaluate_cell_moments_lru(cell, 4)?;
        f += state.value;
        let (dc_da_raw, _) = exact_kernel::denested_cell_coefficient_partials(
            partition_cell.score_span,
            partition_cell.link_span,
            a,
            slope,
        );
        let dc_da = scale_coeff4(dc_da_raw, scale);
        f_a += exact_kernel::cell_first_derivative_from_moments(&dc_da, &state.moments)?;
    }
    Ok((f, f_a, 0.0))
}
/// Calibration residual via empirical latent-grid quadrature:
/// `f(a) = sum_j w_j * Phi(eta_j) - mu`, with
/// `f_a = sum w * phi(eta) * eta_a` and
/// `f_aa = sum w * phi(eta) * (eta_aa - eta * eta_a^2)`
/// (using phi'(x) = -x * phi(x)).
///
/// Rejects non-finite states and non-positive `f_a` so Newton iterates on
/// a strictly increasing root problem.
fn evaluate_empirical_grid_calibration_newton(
    &self,
    a: f64,
    marginal_eta: f64,
    slope: f64,
    beta_h: Option<&Array1<f64>>,
    beta_w: Option<&Array1<f64>>,
    grid: &EmpiricalZGrid,
) -> Result<(f64, f64, f64), String> {
    let marginal = self.marginal_link_map(marginal_eta)?;
    // Start from -mu so f = 0 at the calibrated intercept.
    let mut f = -marginal.mu;
    let mut f_a = 0.0;
    let mut f_aa = 0.0;
    for (&node, &weight) in grid.nodes.iter().zip(grid.weights.iter()) {
        let obs = self.observed_denested_cell_partials_at_z(node, a, slope, beta_h, beta_w)?;
        let eta = eval_coeff4_at(&obs.coeff, node);
        let eta_a = eval_coeff4_at(&obs.dc_da, node);
        let eta_aa = eval_coeff4_at(&obs.dc_daa, node);
        let pdf = normal_pdf(eta);
        f += weight * normal_cdf(eta);
        f_a += weight * pdf * eta_a;
        f_aa += weight * pdf * (eta_aa - eta * eta_a * eta_a);
    }
    if !(f.is_finite() && f_a.is_finite() && f_a > 0.0 && f_aa.is_finite()) {
        return Err(format!(
            "empirical latent denested calibration produced invalid root state: f={f}, f_a={f_a}, f_aa={f_aa}"
        ));
    }
    Ok((f, f_a, f_aa))
}
/// Evaluate the row's calibration triple `(f, f_a, f_aa)`, dispatching on
/// the row's latent measure: empirical-grid quadrature when a grid exists,
/// otherwise the analytic denested path (whose `f_aa` slot is 0.0).
fn evaluate_calibration_newton(
    &self,
    row: usize,
    a: f64,
    marginal_eta: f64,
    slope: f64,
    beta_h: Option<&Array1<f64>>,
    beta_w: Option<&Array1<f64>>,
) -> Result<(f64, f64, f64), String> {
    if let Some(grid) = self.latent_measure.empirical_grid_for_training_row(row)? {
        self.evaluate_empirical_grid_calibration_newton(
            a,
            marginal_eta,
            slope,
            beta_h,
            beta_w,
            &grid,
        )
    } else {
        self.evaluate_denested_calibration_newton(a, marginal_eta, slope, beta_h, beta_w)
    }
}
/// True when either flexible block (score-warp or link-deviation) is
/// configured.
fn flex_active(&self) -> bool {
    let has_score_warp = self.score_warp.is_some();
    let has_link_dev = self.link_dev.is_some();
    has_score_warp || has_link_dev
}
/// Like `flex_active`, but also verifies that every configured flexible
/// runtime has a matching block state (erroring otherwise).
fn effective_flex_active(&self, block_states: &[ParameterBlockState]) -> Result<bool, String> {
    let score_state_missing =
        self.score_warp.is_some() && self.score_beta(block_states)?.is_none();
    if score_state_missing {
        return Err("missing bernoulli score-warp block state".to_string());
    }
    let link_state_missing = self.link_dev.is_some() && self.link_beta(block_states)?.is_none();
    if link_state_missing {
        return Err("missing bernoulli link-deviation block state".to_string());
    }
    Ok(self.flex_active())
}
/// Validate block-state shapes, then check monotonicity feasibility of the
/// score-warp and link-deviation coefficient vectors against their
/// runtimes.
fn validate_exact_monotonicity(
    &self,
    block_states: &[ParameterBlockState],
) -> Result<(), String> {
    self.validate_exact_block_state_shapes(block_states)?;
    if let (Some(runtime), Some(score)) =
        (&self.score_warp, self.score_block_state(block_states)?)
    {
        runtime.monotonicity_feasible(
            &score.beta,
            "bernoulli marginal-slope score-warp deviation",
        )?;
    }
    if let (Some(runtime), Some(beta_w)) = (&self.link_dev, self.link_beta(block_states)?) {
        runtime.monotonicity_feasible(beta_w, "bernoulli marginal-slope link deviation")?;
    }
    Ok(())
}
/// Evaluate the link-deviation map and its first derivative at each entry
/// of `eta0`: with an active deviation, value = eta0 + B(eta0) * beta_w and
/// d/d eta = B'(eta0) * beta_w + 1 (the +1 from the identity baseline);
/// without one, the identity map with derivative 1 everywhere.
fn link_terms_value_d1(
    &self,
    eta0: &Array1<f64>,
    beta_w: Option<&Array1<f64>>,
) -> Result<(Array1<f64>, Array1<f64>), String> {
    if let (Some(runtime), Some(beta)) = (&self.link_dev, beta_w) {
        let basis = runtime.design(eta0)?;
        let d1 = runtime.first_derivative_design(eta0)?;
        Ok((eta0 + &basis.dot(beta), d1.dot(beta) + 1.0))
    } else {
        Ok((eta0.clone(), Array1::ones(eta0.len())))
    }
}
/// Closed-form seed for the row intercept solve: the rigid pre-scale
/// intercept, optionally corrected through a linearization of the link
/// deviation at that point.
///
/// With an active link deviation, l is linearized as l(a) ~ ell0 + ell1*a
/// at the rigid seed and inverted; the guard `ell1 > 1e-8` skips the
/// correction when the linearization is (near-)flat.
fn row_intercept_closed_form_seed(
    &self,
    marginal: BernoulliMarginalLinkMap,
    slope: f64,
    beta_w: Option<&Array1<f64>>,
) -> Result<f64, String> {
    let probit_scale = self.probit_frailty_scale();
    let a_rigid_pre_scale =
        rigid_intercept_from_marginal(marginal.q, slope, probit_scale) / probit_scale;
    if beta_w.is_some() {
        let v = Array1::from_vec(vec![a_rigid_pre_scale]);
        let (l_val, l_d1) = self.link_terms_value_d1(&v, beta_w)?;
        let ell1 = l_d1[0];
        if ell1 > 1e-8 {
            let ell0 = l_val[0] - ell1 * a_rigid_pre_scale;
            let observed_logslope = probit_scale * ell1 * slope;
            // Invert the linearized map against the rigid target
            // q * sqrt(1 + b_obs^2) / scale.
            return Ok(
                (marginal.q * (1.0 + observed_logslope * observed_logslope).sqrt()
                    / probit_scale
                    - ell0)
                    / ell1,
            );
        }
    }
    Ok(a_rigid_pre_scale)
}
/// Seed cold intercept warm-start slots with closed-form estimates before
/// a flexible solve.
///
/// No-op when the flexible blocks are inactive, the cache is absent, or
/// the cache is mis-sized. Slots are filled via compare-exchange from the
/// NaN sentinel, so warm values carried over from a previous PIRLS pass
/// are kept; seeded vs. kept counts are logged.
fn preseed_intercept_warm_starts(
    &self,
    block_states: &[ParameterBlockState],
) -> Result<(), String> {
    if !self.effective_flex_active(block_states)? {
        return Ok(());
    }
    let Some(cache) = self.intercept_warm_starts.as_ref() else {
        return Ok(());
    };
    let beta_w = self.link_beta(block_states)?;
    let n = self.y.len();
    if cache.len() != n {
        return Ok(());
    }
    let marginal_eta = &block_states[0].eta;
    let slope_eta = &block_states[1].eta;
    // Closed-form seeds computed in parallel; any per-row error aborts.
    let seeds: Result<Vec<f64>, String> = (0..n)
        .into_par_iter()
        .map(|row| {
            let marginal = self.marginal_link_map(marginal_eta[row])?;
            self.row_intercept_closed_form_seed(marginal, slope_eta[row], beta_w)
        })
        .collect();
    let seeds = seeds?;
    // NaN bit pattern marks a cold (never-written) slot.
    let nan_bits = f64::NAN.to_bits();
    let mut preseeded = 0usize;
    let mut kept_warm = 0usize;
    for (row, seed) in seeds.iter().enumerate() {
        let slot = &cache[row];
        if !seed.is_finite() {
            continue;
        }
        // Only replace the cold sentinel; an existing finite value wins.
        match slot.compare_exchange(
            nan_bits,
            seed.to_bits(),
            Ordering::Relaxed,
            Ordering::Relaxed,
        ) {
            Ok(_) => preseeded += 1,
            Err(prev) => {
                if f64::from_bits(prev).is_finite() {
                    kept_warm += 1;
                }
            }
        }
    }
    log::info!(
        "[bernoulli intercept warm-start] preseeded={} (cold), kept_warm={} (carried over from previous PIRLS)",
        preseeded,
        kept_warm,
    );
    Ok(())
}
/// Newton convergence test for the intercept solve: accept when the state
/// is finite with a nonzero derivative AND either the residual `f` is tiny
/// or the Newton step `|f / f_a|` is tiny relative to `1 + |a|`.
#[inline]
fn row_intercept_newton_is_converged(a: f64, f: f64, f_a: f64) -> bool {
    let well_posed = a.is_finite() && f.is_finite() && f_a.is_finite() && f_a != 0.0;
    if !well_posed {
        return false;
    }
    let newton_step = (f / f_a).abs();
    let residual_small = f.abs() <= 1e-10;
    let step_small = newton_step <= 1e-10 * (1.0 + a.abs());
    residual_small || step_small
}
}
/// Lock-free diagnostic counters for row-intercept solves. All fields are
/// relaxed atomics, so totals under concurrency are best-effort.
#[derive(Default)]
struct BernoulliInterceptSolveStats {
    // Rows resolved by the warm-start/predictor Newton probe.
    cached_short_circuit: AtomicUsize,
    // Rows resolved by the closed-form-seed Newton probe.
    closed_form_short_circuit: AtomicUsize,
    // Rows that fell through to the full monotone root solver.
    full_solver: AtomicUsize,
    // Histogram of |seed residual| by threshold bucket.
    seed_residual_le_1e12: AtomicUsize,
    seed_residual_le_1e10: AtomicUsize,
    seed_residual_le_1e8: AtomicUsize,
    // Seed residual within the solver's acceptance tolerance.
    seed_residual_le_abs_tol: AtomicUsize,
    // Seed residual above the acceptance tolerance.
    seed_residual_gt_abs_tol: AtomicUsize,
    // Running maximum of full-solver refine iterations.
    max_full_solver_iters: AtomicUsize,
}
impl BernoulliInterceptSolveStats {
    /// Bucket a seed-residual magnitude into the histogram counters;
    /// `abs_tol` is the solver's acceptance tolerance, and residuals above
    /// it land in the final bucket.
    fn record_seed_residual(&self, residual: f64, abs_tol: f64) {
        let abs = residual.abs();
        if abs <= 1e-12 {
            self.seed_residual_le_1e12.fetch_add(1, Ordering::Relaxed);
        } else if abs <= 1e-10 {
            self.seed_residual_le_1e10.fetch_add(1, Ordering::Relaxed);
        } else if abs <= 1e-8 {
            self.seed_residual_le_1e8.fetch_add(1, Ordering::Relaxed);
        } else if abs <= abs_tol {
            self.seed_residual_le_abs_tol
                .fetch_add(1, Ordering::Relaxed);
        } else {
            self.seed_residual_gt_abs_tol
                .fetch_add(1, Ordering::Relaxed);
        }
    }
    /// Count one full-solver invocation and fold its refine-iteration count
    /// into the running maximum.
    fn record_full_solver(&self, refine_iters: usize) {
        self.full_solver.fetch_add(1, Ordering::Relaxed);
        // `fetch_max` replaces the previous hand-rolled compare-exchange
        // loop with identical semantics: atomically keep the larger value.
        self.max_full_solver_iters
            .fetch_max(refine_iters, Ordering::Relaxed);
    }
}
impl BernoulliMarginalSlopeFamily {
/// Flatten the primary point used to key the intercept predictor cache:
/// `[q, b, beta_h..., beta_w...]` (absent coefficient blocks contribute
/// nothing).
fn intercept_primary_point(
    q: f64,
    b: f64,
    beta_h: Option<&Array1<f64>>,
    beta_w: Option<&Array1<f64>>,
) -> Vec<f64> {
    let extra = beta_h.map_or(0, |beta| beta.len()) + beta_w.map_or(0, |beta| beta.len());
    let mut point = Vec::with_capacity(2 + extra);
    point.push(q);
    point.push(b);
    point.extend(beta_h.into_iter().flat_map(|beta| beta.iter().copied()));
    point.extend(beta_w.into_iter().flat_map(|beta| beta.iter().copied()));
    point
}
#[inline]
fn cache_row_intercept(&self, row: usize, a: f64) {
if let Some(cache) = self.intercept_warm_starts.as_ref()
&& let Some(slot) = cache.get(row)
{
slot.store(a.to_bits(), Ordering::Relaxed);
}
}
/// Store the row's intercept together with its sensitivity `a_u` as a
/// linear predictor keyed by the primary point; skipped when the cache is
/// absent or the dimensions do not agree.
fn cache_row_intercept_predictor(
    &self,
    row: usize,
    a: f64,
    q: f64,
    b: f64,
    beta_h: Option<&Array1<f64>>,
    beta_w: Option<&Array1<f64>>,
    a_u: &Array1<f64>,
) {
    if let Some(cache) = self.intercept_warm_starts.as_ref() {
        let primary_point = Self::intercept_primary_point(q, b, beta_h, beta_w);
        if primary_point.len() == a_u.len() {
            cache.store_predictor(row, a, primary_point, a_u.iter().copied().collect());
        }
    }
}
/// L-infinity norm of an optional coefficient vector (0.0 when absent).
#[inline]
fn beta_linf(beta: Option<&Array1<f64>>) -> f64 {
    match beta {
        Some(values) => values.iter().fold(0.0_f64, |acc, &v| f64::max(acc, v.abs())),
        None => 0.0,
    }
}
/// First-order bound on the rigid-seed calibration residual when the
/// deviations are small:
/// phi(0) * scale * (|slope| * sup|H|_1 * ||beta_h||_inf
///                   + sup|W|_1 * ||beta_w||_inf),
/// where the sup-norms come from the deviation runtimes' value bases.
/// Used to decide whether the rigid closed form is already accurate enough
/// to skip the full intercept solve.
fn near_zero_deviation_residual_bound(
    &self,
    slope: f64,
    beta_h_linf: f64,
    beta_w_linf: f64,
) -> f64 {
    let score_basis_sup = self
        .score_warp
        .as_ref()
        .map(|runtime| runtime.value_basis_l1_sup_norm())
        .unwrap_or(0.0);
    let link_basis_sup = self
        .link_dev
        .as_ref()
        .map(|runtime| runtime.value_basis_l1_sup_norm())
        .unwrap_or(0.0);
    normal_pdf(0.0)
        * self.probit_frailty_scale()
        * (slope.abs() * score_basis_sup * beta_h_linf + link_basis_sup * beta_w_linf)
}
/// Solve the per-row intercept `a` so the calibration equation hits the
/// marginal target, trying in order: the exact rigid shortcut, a
/// near-zero-deviation rigid shortcut, a one-or-two-step Newton probe
/// seeded from the predictor/warm-start/closed-form value, and finally the
/// guarded monotone root solver (retried from the closed-form seed if a
/// cached seed led the solver astray).
///
/// Returns `(a, |f'(a)|, fast_path)` where `fast_path` marks the exact
/// rigid shortcut; the solved value is written back to the warm-start
/// cache.
fn solve_row_intercept_base(
    &self,
    row: usize,
    marginal_eta: f64,
    slope: f64,
    beta_h: Option<&Array1<f64>>,
    beta_w: Option<&Array1<f64>>,
    stats: Option<&BernoulliInterceptSolveStats>,
) -> Result<(f64, f64, bool), String> {
    let marginal = self.marginal_link_map(marginal_eta)?;
    let probit_scale = self.probit_frailty_scale();
    let target = marginal.mu;
    // Absolute tolerance scales with the target magnitude (floor 1e-8).
    let abs_tol = 1e-8_f64.max(1e-4 * target.abs());
    let rigid_a = rigid_prescale_intercept_from_marginal(marginal.q, slope, probit_scale);
    let rigid_abs_deriv =
        rigid_prescale_intercept_derivative_abs(marginal.q, slope, probit_scale);
    let beta_h_linf = Self::beta_linf(beta_h);
    let beta_w_linf = Self::beta_linf(beta_w);
    let exact_zero_deviation = beta_h_linf == 0.0 && beta_w_linf == 0.0;
    let standard_normal_law = matches!(self.latent_measure, LatentMeasureKind::StandardNormal);
    // Shortcut 1: zero deviations under the standard-normal law — the
    // rigid intercept is exact.
    if exact_zero_deviation && standard_normal_law {
        self.cache_row_intercept(row, rigid_a);
        return Ok((rigid_a, rigid_abs_deriv, true));
    }
    // Shortcut 2: deviations small enough that the first-order residual
    // bound fits inside the tolerance — verify the rigid intercept once.
    let near_zero_bound =
        self.near_zero_deviation_residual_bound(slope, beta_h_linf, beta_w_linf);
    let beta_linf_max = beta_h_linf.max(beta_w_linf);
    if standard_normal_law && near_zero_bound <= abs_tol && beta_linf_max <= f64::EPSILON.sqrt()
    {
        let (f_rigid, _, _) = self.evaluate_calibration_newton(
            row,
            rigid_a,
            marginal_eta,
            slope,
            beta_h,
            beta_w,
        )?;
        if f_rigid.abs() <= abs_tol {
            self.cache_row_intercept(row, rigid_a);
            return Ok((rigid_a, rigid_abs_deriv, true));
        }
    }
    let eval = |a: f64| -> Result<(f64, f64, f64), String> {
        self.evaluate_calibration_newton(row, a, marginal_eta, slope, beta_h, beta_w)
    };
    // Seed preference: linear predictor, then cached value, then closed
    // form.
    let a_closed_form = self.row_intercept_closed_form_seed(marginal, slope, beta_w)?;
    let current_primary_point =
        Self::intercept_primary_point(marginal_eta, slope, beta_h, beta_w);
    // NOTE: this call previously read `¤t_primary_point` — an
    // HTML-entity-mangled `&current_primary_point` — which did not compile.
    let predictor_a = self
        .intercept_warm_starts
        .as_ref()
        .and_then(|cache| cache.predictor_seed(row, &current_primary_point));
    let cached_a = self.intercept_warm_starts.as_ref().and_then(|cache| {
        let value = f64::from_bits(cache.get(row)?.load(Ordering::Relaxed));
        value.is_finite().then_some(value)
    });
    let a_init = predictor_a.or(cached_a).unwrap_or(a_closed_form);
    // Cheap probe: accept the seed, or one Newton step from it, when the
    // convergence test passes.
    let probe_result = (|| -> Result<(Option<(f64, f64, f64)>, f64), String> {
        let (f0, f_a0, _) = eval(a_init)?;
        let seed_residual = f0;
        if Self::row_intercept_newton_is_converged(a_init, f0, f_a0) {
            return Ok((Some((a_init, f_a0.abs(), f0)), seed_residual));
        }
        if f_a0.is_finite() && f_a0 != 0.0 {
            let a1 = a_init - f0 / f_a0;
            if a1.is_finite() {
                let (f1, f_a1, _) = eval(a1)?;
                if Self::row_intercept_newton_is_converged(a1, f1, f_a1) {
                    return Ok((Some((a1, f_a1.abs(), f1)), seed_residual));
                }
            }
        }
        Ok((None, seed_residual))
    })();
    // Probe failures are deliberately non-fatal: fall through to the full
    // solver below.
    if let Ok((accepted, seed_residual)) = &probe_result {
        if let Some(stats) = stats {
            stats.record_seed_residual(*seed_residual, abs_tol);
        }
        if let Some((a, abs_deriv, _)) = accepted {
            if let Some(stats) = stats {
                if predictor_a.is_some() || cached_a.is_some() {
                    stats.cached_short_circuit.fetch_add(1, Ordering::Relaxed);
                } else {
                    stats
                        .closed_form_short_circuit
                        .fetch_add(1, Ordering::Relaxed);
                }
            }
            self.cache_row_intercept(row, *a);
            return Ok((*a, *abs_deriv, false));
        }
    }
    let mut solve_result = super::monotone_root::solve_monotone_root_detailed(
        &eval,
        a_init,
        "bernoulli intercept",
        abs_tol,
        64,
        48,
    );
    // A stale warm-start seed can defeat the solver; retry once from the
    // closed-form seed before giving up.
    if (predictor_a.is_some() || cached_a.is_some()) && solve_result.is_err() {
        solve_result = super::monotone_root::solve_monotone_root_detailed(
            &eval,
            a_closed_form,
            "bernoulli intercept",
            abs_tol,
            64,
            48,
        );
    }
    let solve_solution = solve_result?;
    if let Some(stats) = stats {
        stats.record_full_solver(solve_solution.refine_iters);
    }
    let (a, abs_deriv, f_best) = (
        solve_solution.root,
        solve_solution.abs_deriv,
        solve_solution.residual,
    );
    if f_best.abs() > abs_tol {
        return Err(format!(
            "bernoulli marginal-slope intercept solve failed: \
             residual={f_best:.3e} at a={a:.6}, target mu={target:.6}"
        ));
    }
    self.cache_row_intercept(row, a);
    Ok((a, abs_deriv, false))
}
/// Test-only convenience: build the row context without collecting
/// intercept-solve statistics.
#[cfg(test)]
fn build_row_exact_context(
    &self,
    row: usize,
    block_states: &[ParameterBlockState],
) -> Result<BernoulliMarginalSlopeRowExactContext, String> {
    let stats = None;
    self.build_row_exact_context_with_stats(row, block_states, stats)
}
/// Build the per-row exact-evaluation context: the calibrated intercept,
/// its derivative magnitude (flexible path only; NaN on the rigid path),
/// and — for the flexible standard-normal path — the row's degree-9 cell
/// moments.
fn build_row_exact_context_with_stats(
    &self,
    row: usize,
    block_states: &[ParameterBlockState],
    stats: Option<&BernoulliInterceptSolveStats>,
) -> Result<BernoulliMarginalSlopeRowExactContext, String> {
    let marginal_eta = block_states[0].eta[row];
    let marginal = self.marginal_link_map(marginal_eta)?;
    let slope = block_states[1].eta[row];
    let beta_h = self.score_beta(block_states)?;
    let beta_w = self.link_beta(block_states)?;
    // Evaluated once (the check also validates block presence) and reused
    // for the degree-9 branch below.
    let flex_active = self.effective_flex_active(block_states)?;
    let (intercept, m_a, intercept_fast_path) = if flex_active {
        self.solve_row_intercept_base(row, marginal_eta, slope, beta_h, beta_w, stats)?
    } else {
        // Rigid kernel: the intercept is available in closed form, or via
        // the empirical grid when the row's latent law is not standard
        // normal.
        let intercept = match self.latent_measure.empirical_grid_for_training_row(row)? {
            None => {
                rigid_intercept_from_marginal(marginal.q, slope, self.probit_frailty_scale())
            }
            Some(grid) => self.empirical_rigid_intercept_for_row(
                row,
                marginal,
                slope,
                &grid.nodes,
                &grid.weights,
            )?,
        };
        (intercept, f64::NAN, false)
    };
    let degree9_cells = if flex_active
        && matches!(self.latent_measure, LatentMeasureKind::StandardNormal)
    {
        let cells = self.denested_partition_cells(intercept, slope, beta_h, beta_w)?;
        let dedup_enabled = cell_moment_per_row_dedup_enabled();
        let mut dedup: HashMap<exact_kernel::CellFingerprint, exact_kernel::CellMomentState> =
            HashMap::new();
        let mut out: Vec<CachedDenestedCellMoments> = Vec::with_capacity(cells.len());
        for partition_cell in cells.into_iter() {
            // Optional per-row dedup: identical cells share one moment
            // evaluation on top of the global LRU cache.
            let state: exact_kernel::CellMomentState = if dedup_enabled {
                let key = exact_kernel::CellFingerprint::new(partition_cell.cell);
                if let Some(existing) = dedup.get(&key) {
                    existing.clone()
                } else {
                    let computed = self.evaluate_cell_moments_lru(partition_cell.cell, 9)?;
                    dedup.insert(key, computed.clone());
                    computed
                }
            } else {
                self.evaluate_cell_moments_lru(partition_cell.cell, 9)?
            };
            out.push(CachedDenestedCellMoments {
                partition_cell,
                state,
            });
        }
        Some(out)
    } else {
        None
    };
    Ok(BernoulliMarginalSlopeRowExactContext {
        intercept,
        m_a,
        intercept_fast_path,
        degree9_cells,
    })
}
    /// Borrows the precomputed exact context for `row` from the cache.
    /// Panics (via indexing) if `row` is out of range of the cached contexts.
    #[inline]
    fn row_ctx(
        cache: &BernoulliMarginalSlopeExactEvalCache,
        row: usize,
    ) -> &BernoulliMarginalSlopeRowExactContext {
        &cache.row_contexts[row]
    }
    /// Builds the full exact-evaluation cache: per-row contexts (computed in
    /// parallel), an optional precomputed row cell-moment bundle, and
    /// diagnostic logging for the intercept solver and cell-moment caches.
    fn build_exact_eval_cache_with_order(
        &self,
        block_states: &[ParameterBlockState],
    ) -> Result<BernoulliMarginalSlopeExactEvalCache, String> {
        self.validate_exact_block_state_shapes(block_states)?;
        let slices = block_slices(self);
        let primary = primary_slices(&slices);
        let n = self.y.len();
        let flex_active = self.effective_flex_active(block_states)?;
        let started = std::time::Instant::now();
        if log_exact_work(n) {
            log::info!(
                "[BMS exact-cache] build start n={} p={} flex={}",
                n,
                slices.total,
                flex_active
            );
        }
        // Warm-start the per-row intercept solves before fanning out.
        self.preseed_intercept_warm_starts(block_states)?;
        if flex_active {
            exact_kernel::reset_tail_cell_moment_cache();
        }
        let stats = BernoulliInterceptSolveStats::default();
        // Snapshot LRU counters so the log below reports this cycle only.
        let cell_cache_before = self.cell_moment_cache_stats.snapshot();
        // Rows are independent; build all contexts in parallel.
        let row_contexts: Result<Vec<_>, String> = (0..n)
            .into_par_iter()
            .map(|row| self.build_row_exact_context_with_stats(row, block_states, Some(&stats)))
            .collect();
        let row_contexts = row_contexts?;
        let fast_path_rows = row_contexts
            .iter()
            .filter(|ctx| ctx.intercept_fast_path)
            .count();
        log::debug!(
            "[BMS exact-cache] row-intercept zero-deviation fast path rows={}/{}",
            fast_path_rows,
            n
        );
        if flex_active {
            log::debug!(
                "bernoulli marginal-slope intercept seed short-circuit: cached={}, closed_form={}, full_solver={}, max_full_solver_iters={}, seed_residual_bins={{<=1e-12:{}, <=1e-10:{}, <=1e-8:{}, <=abs_tol:{}, >abs_tol:{}}}",
                stats.cached_short_circuit.load(Ordering::Relaxed),
                stats.closed_form_short_circuit.load(Ordering::Relaxed),
                stats.full_solver.load(Ordering::Relaxed),
                stats.max_full_solver_iters.load(Ordering::Relaxed),
                stats.seed_residual_le_1e12.load(Ordering::Relaxed),
                stats.seed_residual_le_1e10.load(Ordering::Relaxed),
                stats.seed_residual_le_1e8.load(Ordering::Relaxed),
                stats.seed_residual_le_abs_tol.load(Ordering::Relaxed),
                stats.seed_residual_gt_abs_tol.load(Ordering::Relaxed),
            );
        }
        if flex_active {
            let (cell_hits, cell_misses, cell_hit_rate) = self
                .cell_moment_cache_stats
                .hit_rate_delta(cell_cache_before);
            log::info!(
                "[BMS cell-moment LRU] cycle hits={} misses={} hit_rate={:.1}% entries={} resident_mib={:.1}/{:.1}",
                cell_hits,
                cell_misses,
                100.0 * cell_hit_rate,
                self.cell_moment_lru.len(),
                self.cell_moment_lru.resident_bytes() as f64 / (1024.0 * 1024.0),
                self.cell_moment_lru.max_bytes() as f64 / (1024.0 * 1024.0),
            );
            let tail_stats = exact_kernel::tail_cell_moment_cache_stats();
            log::info!(
                "[BMS exact-cache] affine tail-cell memo: hits={} misses={} entries={} hit_rate={:.3}%",
                tail_stats.hits,
                tail_stats.misses,
                tail_stats.entries,
                100.0 * tail_stats.hit_rate(),
            );
        }
        if log_exact_work(n) {
            log::info!(
                "[BMS exact-cache] build done n={} p={} flex={} elapsed={:.3}s",
                n,
                slices.total,
                flex_active,
                started.elapsed().as_secs_f64()
            );
        }
        // Degree-21 moments serve higher-order evaluations downstream.
        let row_cell_moments =
            self.build_row_cell_moments_bundle(block_states, &row_contexts, 21)?;
        Ok(BernoulliMarginalSlopeExactEvalCache {
            slices,
            primary,
            row_contexts,
            row_cell_moments,
            // Populated lazily elsewhere; see build_row_primary_hessian_cache.
            row_primary_hessians: None,
        })
    }
    /// Optionally precomputes per-row denested-cell moments up to
    /// `max_degree`. Returns `Ok(None)` when the flexible path is inactive,
    /// the latent measure is not standard normal, or the estimated resident
    /// size would exceed the configured operator-cache byte budget.
    fn build_row_cell_moments_bundle(
        &self,
        block_states: &[ParameterBlockState],
        row_contexts: &[BernoulliMarginalSlopeRowExactContext],
        max_degree: usize,
    ) -> Result<Option<RowCellMomentsBundle>, String> {
        if !self.effective_flex_active(block_states)? {
            return Ok(None);
        }
        if !matches!(self.latent_measure, LatentMeasureKind::StandardNormal) {
            return Ok(None);
        }
        let n = self.y.len();
        let beta_h = self.score_beta(block_states)?;
        let beta_w = self.link_beta(block_states)?;
        // Partition each row's latent axis into denested cells (parallel).
        let partitions: Vec<Vec<exact_kernel::DenestedPartitionCell>> = (0..n)
            .into_par_iter()
            .map(|row| {
                self.denested_partition_cells(
                    row_contexts[row].intercept,
                    block_states[1].eta[row],
                    beta_h,
                    beta_w,
                )
            })
            .collect::<Result<Vec<_>, String>>()?;
        // Memory gate: skip the precompute entirely if it would not fit.
        let n_cells = partitions.iter().map(Vec::len).sum::<usize>();
        let estimated_bytes =
            RowCellMomentsBundle::estimated_resident_bytes(n, n_cells, max_degree);
        let limit_bytes = self.policy.max_operator_cache_bytes;
        if estimated_bytes > limit_bytes {
            log::warn!(
                "[BMS row-cell-moments] skip precompute n={} cells={} degree={} estimated_bytes={} limit_bytes={}",
                n,
                n_cells,
                max_degree,
                estimated_bytes,
                limit_bytes
            );
            return Ok(None);
        }
        let started = std::time::Instant::now();
        // Evaluate moments for every cell (parallel over rows), going through
        // the LRU so repeated cells are shared.
        let rows = partitions
            .into_par_iter()
            .map(|cells| {
                cells
                    .into_iter()
                    .map(|partition_cell| {
                        self.evaluate_cell_moments_lru(partition_cell.cell, max_degree)
                            .map(|state| CachedDenestedCellMoments {
                                partition_cell,
                                state,
                            })
                    })
                    .collect::<Result<Vec<_>, String>>()
            })
            .collect::<Result<Vec<_>, String>>()?;
        if log_exact_work(n) {
            log::info!(
                "[BMS row-cell-moments] precomputed n={} cells={} degree={} estimated_bytes={} elapsed={:.3}s",
                n,
                n_cells,
                max_degree,
                estimated_bytes,
                started.elapsed().as_secs_f64()
            );
        }
        Ok(Some(RowCellMomentsBundle { max_degree, rows }))
    }
    /// Cheap per-row magnitude bound used to rank rows by how much they can
    /// contribute to the Hessian, so negligible rows can be skipped.
    ///
    /// Combines the probit Mills-ratio first/second-derivative magnitudes at
    /// the observed margin with the marginal link derivative magnitudes and
    /// 1/|m_a|; a non-finite or zero `m_a` yields +inf (row never skipped).
    fn row_hessian_importance_bound(
        &self,
        row: usize,
        block_states: &[ParameterBlockState],
        row_ctx: &BernoulliMarginalSlopeRowExactContext,
    ) -> Result<f64, String> {
        let q = block_states[0].eta[row];
        let b = block_states[1].eta[row];
        let beta_h = self.score_beta(block_states)?;
        let beta_w = self.link_beta(block_states)?;
        let z_obs = self.z[row];
        // Linear predictor at the row's observed latent value.
        let eta_val = eval_coeff4_at(
            &self
                .observed_denested_cell_partials(row, row_ctx.intercept, b, beta_h, beta_w)?
                .coeff,
            z_obs,
        );
        let marginal = self.marginal_link_map(q)?;
        let s_y = 2.0 * self.y[row] - 1.0;
        let signed_margin = s_y * eta_val;
        let (_, lambda) = signed_probit_logcdf_and_mills_ratio(signed_margin);
        // |d1| and |d2| of the weighted probit log-likelihood in eta.
        let d1_abs = self.weights[row].abs() * lambda.abs();
        let d2 = self.weights[row].abs() * lambda.abs() * (signed_margin + lambda).abs();
        let m_a_recip = if row_ctx.m_a.is_finite() && row_ctx.m_a.abs() > 0.0 {
            row_ctx.m_a.abs().recip()
        } else {
            f64::INFINITY
        };
        Ok((d1_abs + d2) * (1.0 + marginal.mu1.abs() + marginal.mu2.abs() + m_a_recip))
    }
    /// Builds a per-row skip mask for Hessian assembly: rows whose importance
    /// bound falls below `tau` times the mean bound are marked skippable.
    /// Returns `None` when skipping is disabled (`tau <= 0`), the flexible
    /// path is inactive, or the mean bound is non-finite/non-positive.
    fn hessian_skip_mask_from_row_bounds(
        &self,
        block_states: &[ParameterBlockState],
        cache: &BernoulliMarginalSlopeExactEvalCache,
        tau: f64,
    ) -> Result<Option<Vec<bool>>, String> {
        if tau <= 0.0 || !self.effective_flex_active(block_states)? {
            return Ok(None);
        }
        let n = self.y.len();
        // Bounds are independent per row; compute in parallel.
        let bounds: Result<Vec<_>, String> = (0..n)
            .into_par_iter()
            .map(|row| {
                self.row_hessian_importance_bound(row, block_states, Self::row_ctx(cache, row))
            })
            .collect();
        let bounds = bounds?;
        // max(1) guards the n == 0 division.
        let mean = bounds.iter().copied().sum::<f64>() / n.max(1) as f64;
        if !mean.is_finite() || mean <= 0.0 {
            return Ok(None);
        }
        let threshold = tau * mean;
        Ok(Some(bounds.into_iter().map(|b| b < threshold).collect()))
    }
    /// Precomputes per-row primary Hessians for the flexible path, packed
    /// row-major as an `n x (r*r)` matrix. Rows that the importance skip mask
    /// marks as negligible are stored as all-zero. Returns `Ok(None)` when
    /// the flexible path is inactive.
    fn build_row_primary_hessian_cache(
        &self,
        block_states: &[ParameterBlockState],
        cache: &BernoulliMarginalSlopeExactEvalCache,
    ) -> Result<Option<Array2<f64>>, String> {
        if !self.effective_flex_active(block_states)? {
            return Ok(None);
        }
        let n = self.y.len();
        let primary = &cache.primary;
        let r = primary.total;
        let tau = BMS_HV_ROW_SKIP_TAU;
        let skip_mask = self.hessian_skip_mask_from_row_bounds(block_states, cache, tau)?;
        let skipped = skip_mask
            .as_ref()
            .map(|mask| mask.iter().filter(|&&skip| skip).count())
            .unwrap_or(0);
        if tau > 0.0 {
            log::info!(
                "[BMS Hessian-Hv row-skip] tau={:.3e} skipped={}/{} ({:.3}%)",
                tau,
                skipped,
                n,
                100.0 * skipped as f64 / n.max(1) as f64
            );
        }
        // Per-row Hessians are independent; compute in parallel and pack.
        let rows: Result<Vec<_>, String> = (0..n)
            .into_par_iter()
            .map(|row| {
                if skip_mask.as_ref().is_some_and(|mask| mask[row]) {
                    // Skipped rows contribute a zero Hessian block.
                    return Ok(vec![0.0; r * r]);
                }
                let row_ctx = Self::row_ctx(cache, row);
                let mut scratch = BernoulliMarginalSlopeFlexRowScratch::new(r);
                // Reuse precomputed degree-9 moments when the bundle has them.
                let row_moments = cache
                    .row_cell_moments
                    .as_ref()
                    .and_then(|bundle| bundle.row(row, 9));
                self.compute_row_analytic_flex_into_with_moments(
                    row,
                    block_states,
                    primary,
                    row_ctx,
                    row_moments,
                    true,
                    &mut scratch,
                )?;
                Ok(scratch.hess.into_raw_vec_and_offset().0)
            })
            .collect();
        let rows = rows?;
        let mut packed = Array2::<f64>::zeros((n, r * r));
        for (row, row_hess) in rows.into_iter().enumerate() {
            packed.row_mut(row).assign(&Array1::from(row_hess));
        }
        Ok(Some(packed))
    }
    /// Returns a view of the cached r-by-r primary Hessian for `row` from the
    /// packed `n x (r*r)` cache; `None` if the cache is absent, the row is
    /// out of range, or the packed slice cannot be addressed (checked
    /// arithmetic guards against overflow).
    #[inline]
    fn cached_row_primary_hessian<'a>(
        cache: &'a BernoulliMarginalSlopeExactEvalCache,
        row: usize,
    ) -> Option<ArrayView2<'a, f64>> {
        let rows = cache.row_primary_hessians.as_ref()?;
        let r = cache.primary.total;
        if row >= rows.nrows() {
            return None;
        }
        let width = r.checked_mul(r)?;
        let start = row.checked_mul(width)?;
        let end = start.checked_add(width)?;
        ArrayView2::from_shape((r, r), rows.as_slice()?.get(start..end)?).ok()
    }
    /// Builds the exact evaluation cache; currently a direct alias for
    /// `build_exact_eval_cache_with_order`.
    fn build_exact_eval_cache(
        &self,
        block_states: &[ParameterBlockState],
    ) -> Result<BernoulliMarginalSlopeExactEvalCache, String> {
        self.build_exact_eval_cache_with_order(block_states)
    }
fn row_primary_direction_from_flat(
&self,
row: usize,
slices: &BlockSlices,
primary: &PrimarySlices,
d_beta_flat: &Array1<f64>,
) -> Result<Array1<f64>, String> {
if d_beta_flat.len() != slices.total {
return Err(format!(
"bernoulli marginal-slope d_beta length mismatch: got {}, expected {}",
d_beta_flat.len(),
slices.total
));
}
let mut out = Array1::<f64>::zeros(primary.total);
out[primary.q] = self
.marginal_design
.dot_row_view(row, d_beta_flat.slice(s![slices.marginal.clone()]));
out[primary.logslope] = self
.logslope_design
.dot_row_view(row, d_beta_flat.slice(s![slices.logslope.clone()]));
if let (Some(block_range), Some(primary_range)) = (slices.h.as_ref(), primary.h.as_ref()) {
out.slice_mut(s![primary_range.start..primary_range.end])
.assign(&d_beta_flat.slice(s![block_range.clone()]).to_owned());
}
if let (Some(block_range), Some(primary_range)) = (slices.w.as_ref(), primary.w.as_ref()) {
out.slice_mut(s![primary_range.start..primary_range.end])
.assign(&d_beta_flat.slice(s![block_range.clone()]).to_owned());
}
Ok(out)
}
fn row_primary_psi_direction_from_map(
&self,
row: usize,
block_idx: usize,
psi_map: &crate::families::custom_family::PsiDesignMap,
block_states: &[ParameterBlockState],
primary: &PrimarySlices,
) -> Result<Array1<f64>, String> {
let mut out = Array1::<f64>::zeros(primary.total);
match block_idx {
0 => {
let x_row = psi_map
.row_vector(row)
.map_err(|e| format!("survival rowwise psi map: {e}"))?;
out[primary.q] = x_row.dot(&block_states[0].beta);
}
1 => {
let x_row = psi_map
.row_vector(row)
.map_err(|e| format!("survival rowwise psi map: {e}"))?;
out[primary.logslope] = x_row.dot(&block_states[1].beta);
}
_ => {
return Err(format!(
"bernoulli marginal-slope psi direction only supports spatial marginal/logslope blocks, got block {block_idx}"
));
}
}
Ok(out)
}
fn row_primary_psi_action_on_direction_from_map(
&self,
row: usize,
block_idx: usize,
psi_map: &crate::families::custom_family::PsiDesignMap,
slices: &BlockSlices,
d_beta_flat: &Array1<f64>,
primary: &PrimarySlices,
) -> Result<Array1<f64>, String> {
let mut out = Array1::<f64>::zeros(primary.total);
match block_idx {
0 => {
let x_row = psi_map
.row_vector(row)
.map_err(|e| format!("survival rowwise psi map: {e}"))?;
out[primary.q] =
x_row.dot(&d_beta_flat.slice(s![slices.marginal.clone()]).to_owned())
}
1 => {
let x_row = psi_map
.row_vector(row)
.map_err(|e| format!("survival rowwise psi map: {e}"))?;
out[primary.logslope] =
x_row.dot(&d_beta_flat.slice(s![slices.logslope.clone()]).to_owned())
}
_ => {
return Err(format!(
"bernoulli marginal-slope psi action only supports marginal/logslope blocks, got block {block_idx}"
));
}
}
Ok(out)
}
fn row_primary_psi_second_direction_from_map(
&self,
row: usize,
block_i: usize,
block_j: usize,
psi_map_ij: Option<&crate::families::custom_family::PsiDesignMap>,
block_states: &[ParameterBlockState],
primary: &PrimarySlices,
) -> Result<Array1<f64>, String> {
if block_i != block_j {
return Ok(Array1::<f64>::zeros(primary.total));
}
let psi_map_ij = psi_map_ij.expect("psi_map_ij must be provided when block_i == block_j");
let mut out = Array1::<f64>::zeros(primary.total);
match block_i {
0 => {
let x_row = psi_map_ij
.row_vector(row)
.map_err(|e| format!("survival rowwise psi map: {e}"))?;
out[primary.q] = x_row.dot(&block_states[0].beta);
}
1 => {
let x_row = psi_map_ij
.row_vector(row)
.map_err(|e| format!("survival rowwise psi map: {e}"))?;
out[primary.logslope] = x_row.dot(&block_states[1].beta);
}
_ => {
return Err(format!(
"bernoulli marginal-slope psi second direction only supports marginal/logslope blocks, got block {block_i}"
));
}
}
Ok(out)
}
fn pullback_primary_vector(
&self,
row: usize,
slices: &BlockSlices,
primary: &PrimarySlices,
primary_vec: &Array1<f64>,
) -> Result<Array1<f64>, String> {
let mut out = Array1::<f64>::zeros(slices.total);
{
let mut marginal = out.slice_mut(s![slices.marginal.clone()]);
self.marginal_design
.axpy_row_into(row, primary_vec[primary.q], &mut marginal)?;
}
{
let mut logslope = out.slice_mut(s![slices.logslope.clone()]);
self.logslope_design.axpy_row_into(
row,
primary_vec[primary.logslope],
&mut logslope,
)?;
}
if let Some(primary_h) = primary.h.as_ref() {
if let Some(block_h) = slices.h.as_ref() {
out.slice_mut(s![block_h.clone()]).assign(
&primary_vec
.slice(s![primary_h.start..primary_h.end])
.to_owned(),
);
}
}
if let Some(primary_w) = primary.w.as_ref() {
if let Some(block_w) = slices.w.as_ref() {
out.slice_mut(s![block_w.clone()]).assign(
&primary_vec
.slice(s![primary_w.start..primary_w.end])
.to_owned(),
);
}
}
Ok(out)
}
fn block_psi_row_from_map(
&self,
row: usize,
block_idx: usize,
psi_map: &crate::families::custom_family::PsiDesignMap,
slices: &BlockSlices,
) -> Result<BlockPsiRow, String> {
let (local_vec, range) = match block_idx {
0 => (
psi_map
.row_vector(row)
.map_err(|e| format!("survival rowwise psi map: {e}"))?,
slices.marginal.clone(),
),
1 => (
psi_map
.row_vector(row)
.map_err(|e| format!("survival rowwise psi map: {e}"))?,
slices.logslope.clone(),
),
_ => {
return Err(format!(
"bernoulli marginal-slope psi embedding only supports marginal/logslope blocks, got block {block_idx}"
));
}
};
Ok(BlockPsiRow {
block_idx,
range,
local_vec,
})
}
fn block_psi_second_row_from_map(
&self,
row: usize,
block_i: usize,
block_j: usize,
psi_map_ij: Option<&crate::families::custom_family::PsiDesignMap>,
slices: &BlockSlices,
) -> Result<Option<BlockPsiRow>, String> {
if block_i != block_j {
return Ok(None);
}
let psi_map_ij = psi_map_ij.expect("psi_map_ij must be provided when block_i == block_j");
let (local_vec, range) = match block_i {
0 => (
psi_map_ij
.row_vector(row)
.map_err(|e| format!("survival rowwise psi map: {e}"))?,
slices.marginal.clone(),
),
1 => (
psi_map_ij
.row_vector(row)
.map_err(|e| format!("survival rowwise psi map: {e}"))?,
slices.logslope.clone(),
),
_ => {
return Err(format!(
"bernoulli marginal-slope psi second embedding only supports marginal/logslope blocks, got block {block_i}"
));
}
};
Ok(Some(BlockPsiRow {
block_idx: block_i,
range,
local_vec,
}))
}
fn compute_row_primary_gradient_hessian(
&self,
row: usize,
block_states: &[ParameterBlockState],
primary: &PrimarySlices,
row_ctx: &BernoulliMarginalSlopeRowExactContext,
) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
if self.effective_flex_active(block_states)? {
let mut scratch = BernoulliMarginalSlopeFlexRowScratch::new(primary.total);
let neglog = self.compute_row_analytic_flex_into(
row,
block_states,
primary,
row_ctx,
true,
&mut scratch,
)?;
return Ok((neglog, scratch.grad, scratch.hess));
}
let marginal_eta = block_states[0].eta[row];
let marginal = self.marginal_link_map(marginal_eta)?;
let g = block_states[1].eta[row];
let (neglog, grad_pair, h) = self.rigid_row_kernel_eval(row, marginal_eta, marginal, g)?;
let mut grad = Array1::<f64>::zeros(2);
grad[0] = grad_pair[0];
grad[1] = grad_pair[1];
let mut hess = Array2::<f64>::zeros((2, 2));
hess[[0, 0]] = h[0][0];
hess[[0, 1]] = h[0][1];
hess[[1, 0]] = h[1][0];
hess[[1, 1]] = h[1][1];
Ok((neglog, grad, hess))
}
    /// Convenience wrapper over `compute_row_analytic_flex_into_with_moments`
    /// with no precomputed row cell moments.
    fn compute_row_analytic_flex_into(
        &self,
        row: usize,
        block_states: &[ParameterBlockState],
        primary: &PrimarySlices,
        row_ctx: &BernoulliMarginalSlopeRowExactContext,
        need_hessian: bool,
        scratch: &mut BernoulliMarginalSlopeFlexRowScratch,
    ) -> Result<f64, String> {
        self.compute_row_analytic_flex_into_with_moments(
            row,
            block_states,
            primary,
            row_ctx,
            None,
            need_hessian,
            scratch,
        )
    }
    /// Extracts the row's (q, b) predictors and the optional score/link betas
    /// from the block states, then evaluates the analytic flexible kernel via
    /// `compute_row_analytic_flex_from_parts_into`.
    fn compute_row_analytic_flex_into_with_moments(
        &self,
        row: usize,
        block_states: &[ParameterBlockState],
        primary: &PrimarySlices,
        row_ctx: &BernoulliMarginalSlopeRowExactContext,
        row_cell_moments: Option<&[CachedDenestedCellMoments]>,
        need_hessian: bool,
        scratch: &mut BernoulliMarginalSlopeFlexRowScratch,
    ) -> Result<f64, String> {
        let q = block_states[0].eta[row];
        let b = block_states[1].eta[row];
        let beta_h = self.score_beta(block_states)?;
        let beta_w = self.link_beta(block_states)?;
        self.compute_row_analytic_flex_from_parts_into(
            row,
            primary,
            q,
            b,
            beta_h,
            beta_w,
            row_ctx,
            row_cell_moments,
            need_hessian,
            scratch,
        )
    }
    /// Core analytic kernel for one row on the flexible path: returns the
    /// weighted signed-probit negative log-likelihood for the row and fills
    /// `scratch.grad` (always) and `scratch.hess` (when `need_hessian`) in
    /// primary coordinates.
    ///
    /// Outline:
    /// 1. accumulate intercept-map partials (f_u, f_au, f_uv, f_aa), either
    ///    by empirical-grid quadrature or from denested-cell moments;
    /// 2. implicit differentiation of the row intercept: a_u = -f_u / m_a;
    /// 3. evaluate the observed-data predictor jets at z_obs;
    /// 4. chain-rule through the signed probit log-CDF / Mills ratio.
    ///
    /// NOTE(review): divides by `row_ctx.m_a`; assumes it is finite and
    /// nonzero on this path — confirm callers guarantee this.
    fn compute_row_analytic_flex_from_parts_into(
        &self,
        row: usize,
        primary: &PrimarySlices,
        q: f64,
        b: f64,
        beta_h: Option<&Array1<f64>>,
        beta_w: Option<&Array1<f64>>,
        row_ctx: &BernoulliMarginalSlopeRowExactContext,
        row_cell_moments: Option<&[CachedDenestedCellMoments]>,
        need_hessian: bool,
        scratch: &mut BernoulliMarginalSlopeFlexRowScratch,
    ) -> Result<f64, String> {
        use crate::families::bernoulli_marginal_slope::exact_kernel as exact;
        let r = primary.total;
        scratch.reset(need_hessian);
        let a = row_ctx.intercept;
        // m_a: derivative of the intercept map; its reciprocal scales all
        // implicit derivatives below.
        let f_a = row_ctx.m_a;
        let y_i = self.y[row];
        let w_i = self.weights[row];
        // Sign convention: +1 for y = 1, -1 for y = 0.
        let s_y = 2.0 * y_i - 1.0;
        let marginal = self.marginal_link_map(q)?;
        let inv_ma = 1.0 / f_a;
        let h_range = primary.h.as_ref();
        let w_range = primary.w.as_ref();
        let score_runtime = self.score_warp.as_ref();
        let link_runtime = self.link_dev.as_ref();
        let scale = self.probit_frailty_scale();
        let zero_family = vec![[0.0; 4]; r];
        // Accumulators for the intercept-map partials (live in scratch).
        let f_u = &mut scratch.m_u;
        let f_au = &mut scratch.m_au;
        let f_uv = &mut scratch.m_uv;
        let mut f_aa = 0.0f64;
        // Per-cell/per-node cubic coefficient families, reused across cells.
        let mut coeff_u = vec![[0.0; 4]; r];
        let mut coeff_au = vec![[0.0; 4]; r];
        let mut coeff_bu = vec![[0.0; 4]; r];
        // Path A: empirical latent grid — accumulate partials by quadrature
        // over (node, weight) pairs.
        if let Some(empirical_grid) = self.latent_measure.empirical_grid_for_training_row(row)? {
            for (&node, &weight) in empirical_grid
                .nodes
                .iter()
                .zip(empirical_grid.weights.iter())
            {
                coeff_u.fill([0.0; 4]);
                coeff_au.fill([0.0; 4]);
                coeff_bu.fill([0.0; 4]);
                let obs = self.observed_denested_cell_partials_at_z(node, a, b, beta_h, beta_w)?;
                let eta = eval_coeff4_at(&obs.coeff, node);
                let eta_a = eval_coeff4_at(&obs.dc_da, node);
                let eta_aa = eval_coeff4_at(&obs.dc_daa, node);
                let phi = normal_pdf(eta);
                if need_hessian {
                    f_aa += weight * phi * (eta_aa - eta * eta_a * eta_a);
                }
                // Index 1 is the log-slope coordinate b.
                coeff_u[1] = obs.dc_db;
                if need_hessian {
                    coeff_au[1] = obs.dc_dab;
                    coeff_bu[1] = obs.dc_dbb;
                }
                if let (Some(h_range), Some(runtime)) = (h_range, score_runtime) {
                    Self::for_each_deviation_basis_cubic_at(
                        runtime,
                        h_range,
                        node,
                        "score-warp",
                        |_, idx, basis_span| {
                            coeff_u[idx] = scale_coeff4(
                                exact::score_basis_cell_coefficients(basis_span, b),
                                scale,
                            );
                            if need_hessian {
                                coeff_bu[idx] = scale_coeff4(
                                    exact::score_basis_cell_coefficients(basis_span, 1.0),
                                    scale,
                                );
                            }
                            Ok(())
                        },
                    )?;
                }
                if let (Some(w_range), Some(runtime)) = (w_range, link_runtime) {
                    // Link-wiggle bases are evaluated at u = a + b*z.
                    let u_node = a + b * node;
                    Self::for_each_deviation_basis_cubic_at(
                        runtime,
                        w_range,
                        u_node,
                        "link-wiggle",
                        |_, idx, basis_span| {
                            coeff_u[idx] = scale_coeff4(
                                exact::link_basis_cell_coefficients(basis_span, a, b),
                                scale,
                            );
                            if need_hessian {
                                let (dc_aw_raw, dc_bw_raw) =
                                    exact::link_basis_cell_coefficient_partials(basis_span, a, b);
                                coeff_au[idx] = scale_coeff4(dc_aw_raw, scale);
                                coeff_bu[idx] = scale_coeff4(dc_bw_raw, scale);
                            }
                            Ok(())
                        },
                    )?;
                }
                let eta_u = (0..r)
                    .map(|idx| eval_coeff4_at(&coeff_u[idx], node))
                    .collect::<Vec<_>>();
                for u in 1..r {
                    f_u[u] += weight * phi * eta_u[u];
                    if need_hessian {
                        let eta_au = eval_coeff4_at(&coeff_au[u], node);
                        f_au[u] += weight * phi * (eta_au - eta * eta_a * eta_u[u]);
                    }
                }
                if need_hessian {
                    let coeff_jet = SparsePrimaryCoeffJetView::new(
                        1,
                        h_range,
                        w_range,
                        &coeff_u,
                        &coeff_au,
                        &coeff_bu,
                        &zero_family,
                        &zero_family,
                        &zero_family,
                        &zero_family,
                        &zero_family,
                        &zero_family,
                        &zero_family,
                    );
                    // Symmetric second-order accumulation (upper triangle
                    // mirrored).
                    for u in 1..r {
                        for v in u..r {
                            let second_coeff = coeff_jet.pair_from_b_family(
                                coeff_jet.b_first,
                                u,
                                v,
                                COEFF_SUPPORT_BHW,
                            );
                            let eta_uv = eval_coeff4_at(&second_coeff, node);
                            let val = weight * phi * (eta_uv - eta * eta_u[u] * eta_u[v]);
                            f_uv[[u, v]] += val;
                            if u != v {
                                f_uv[[v, u]] += val;
                            }
                        }
                    }
                }
            }
        } else {
            // Path B: standard-normal latent measure — integrate via cached
            // (or freshly computed) denested-cell moments. Source priority:
            // explicit bundle > context degree-9 cache (Hessian only) >
            // compute-on-demand through the LRU.
            let owned_cells;
            let cached_cells: Vec<(
                exact::DenestedPartitionCell,
                std::borrow::Cow<'_, exact::CellMomentState>,
            )> = if let Some(cached) = row_cell_moments {
                debug_assert!(
                    !cached.is_empty(),
                    "row cell moments bundle was selected but row {row} has no cells"
                );
                cached
                    .iter()
                    .map(|entry| {
                        (
                            entry.partition_cell,
                            std::borrow::Cow::Borrowed(&entry.state),
                        )
                    })
                    .collect()
            } else if need_hessian && let Some(cached) = row_ctx.degree9_cells.as_ref() {
                cached
                    .iter()
                    .map(|entry| {
                        (
                            entry.partition_cell,
                            std::borrow::Cow::Borrowed(&entry.state),
                        )
                    })
                    .collect()
            } else {
                owned_cells = self.denested_partition_cells(a, b, beta_h, beta_w)?;
                owned_cells
                    .into_iter()
                    .map(|partition_cell| {
                        // Gradient-only needs degree 3; Hessian needs 9.
                        let degree = if need_hessian { 9 } else { 3 };
                        self.evaluate_cell_moments_lru(partition_cell.cell, degree)
                            .map(|state| (partition_cell, std::borrow::Cow::Owned(state)))
                    })
                    .collect::<Result<Vec<_>, String>>()?
            };
            for (partition_cell, state) in cached_cells {
                coeff_u.fill([0.0; 4]);
                coeff_au.fill([0.0; 4]);
                coeff_bu.fill([0.0; 4]);
                let cell = partition_cell.cell;
                // Interior probe point used to locate the active basis spans.
                let z_mid = exact::interval_probe_point(cell.left, cell.right)?;
                let u_mid = a + b * z_mid;
                let state: &exact::CellMomentState = &state;
                let (dc_da_raw, dc_db_raw) = exact::denested_cell_coefficient_partials(
                    partition_cell.score_span,
                    partition_cell.link_span,
                    a,
                    b,
                );
                let dc_da = scale_coeff4(dc_da_raw, scale);
                let dc_db = scale_coeff4(dc_db_raw, scale);
                coeff_u[1] = dc_db;
                coeff_au[1] = [0.0; 4];
                coeff_bu[1] = [0.0; 4];
                if need_hessian {
                    let (dc_daa_raw, dc_dab_raw, dc_dbb_raw) = exact::denested_cell_second_partials(
                        partition_cell.score_span,
                        partition_cell.link_span,
                        a,
                        b,
                    );
                    let dc_daa = scale_coeff4(dc_daa_raw, scale);
                    let dc_dab = scale_coeff4(dc_dab_raw, scale);
                    let dc_dbb = scale_coeff4(dc_dbb_raw, scale);
                    f_aa += exact::cell_second_derivative_from_moments(
                        cell,
                        &dc_da,
                        &dc_da,
                        &dc_daa,
                        &state.moments,
                    )?;
                    coeff_au[1] = dc_dab;
                    coeff_bu[1] = dc_dbb;
                }
                if let (Some(h_range), Some(runtime)) = (h_range, score_runtime) {
                    Self::for_each_deviation_basis_cubic_at(
                        runtime,
                        h_range,
                        z_mid,
                        "score-warp",
                        |_, idx, basis_span| {
                            coeff_u[idx] = scale_coeff4(
                                exact::score_basis_cell_coefficients(basis_span, b),
                                scale,
                            );
                            if need_hessian {
                                coeff_bu[idx] = scale_coeff4(
                                    exact::score_basis_cell_coefficients(basis_span, 1.0),
                                    scale,
                                );
                            }
                            Ok(())
                        },
                    )?;
                }
                if let (Some(w_range), Some(runtime)) = (w_range, link_runtime) {
                    Self::for_each_deviation_basis_cubic_at(
                        runtime,
                        w_range,
                        u_mid,
                        "link-wiggle",
                        |_, idx, basis_span| {
                            coeff_u[idx] = scale_coeff4(
                                exact::link_basis_cell_coefficients(basis_span, a, b),
                                scale,
                            );
                            if need_hessian {
                                let (dc_aw_raw, dc_bw_raw) =
                                    exact::link_basis_cell_coefficient_partials(basis_span, a, b);
                                coeff_au[idx] = scale_coeff4(dc_aw_raw, scale);
                                coeff_bu[idx] = scale_coeff4(dc_bw_raw, scale);
                            }
                            Ok(())
                        },
                    )?;
                }
                for u in 1..r {
                    f_u[u] +=
                        exact::cell_first_derivative_from_moments(&coeff_u[u], &state.moments)?;
                    if need_hessian {
                        f_au[u] += exact::cell_second_derivative_from_moments(
                            cell,
                            &dc_da,
                            &coeff_u[u],
                            &coeff_au[u],
                            &state.moments,
                        )?;
                    }
                }
                if need_hessian {
                    let coeff_jet = SparsePrimaryCoeffJetView::new(
                        1,
                        h_range,
                        w_range,
                        &coeff_u,
                        &coeff_au,
                        &coeff_bu,
                        &zero_family,
                        &zero_family,
                        &zero_family,
                        &zero_family,
                        &zero_family,
                        &zero_family,
                        &zero_family,
                    );
                    // Symmetric second-order accumulation (upper triangle
                    // mirrored).
                    for u in 1..r {
                        for v in u..r {
                            let second_coeff = coeff_jet.pair_from_b_family(
                                coeff_jet.b_first,
                                u,
                                v,
                                COEFF_SUPPORT_BHW,
                            );
                            let val = exact::cell_second_derivative_from_moments(
                                cell,
                                &coeff_jet.first[u],
                                &coeff_jet.first[v],
                                &second_coeff,
                                &state.moments,
                            )?;
                            f_uv[[u, v]] += val;
                            if u != v {
                                f_uv[[v, u]] += val;
                            }
                        }
                    }
                }
            }
        }
        // Marginal-constraint partials occupy the q coordinate (index 0).
        f_u[0] = -marginal.mu1;
        if need_hessian {
            f_uv[[0, 0]] = -marginal.mu2;
        }
        // Implicit differentiation of the intercept: a_u = -f_u / m_a.
        let a_u = &mut scratch.a_u;
        for u in 0..r {
            a_u[u] = -f_u[u] * inv_ma;
        }
        self.cache_row_intercept_predictor(row, a, q, b, beta_h, beta_w, a_u);
        // Second-order implicit derivatives of the intercept.
        let a_uv = &mut scratch.a_uv;
        if need_hessian {
            for u in 0..r {
                for v in u..r {
                    let val = -(f_uv[[u, v]]
                        + f_au[u] * a_u[v]
                        + f_au[v] * a_u[u]
                        + f_aa * a_u[u] * a_u[v])
                        * inv_ma;
                    a_uv[[u, v]] = val;
                    a_uv[[v, u]] = val;
                }
            }
        }
        // Observed-data predictor and its coefficient jets at z_obs.
        let z_obs = self.z[row];
        let u_obs = a + b * z_obs;
        let obs = self.observed_denested_cell_partials(row, a, b, beta_h, beta_w)?;
        let chi_obs = eval_coeff4_at(&obs.dc_da, z_obs);
        let eta_aa_obs = eval_coeff4_at(&obs.dc_daa, z_obs);
        let eta_val = eval_coeff4_at(&obs.coeff, z_obs);
        let mut g_u_fixed = vec![[0.0; 4]; r];
        let mut g_au_fixed = vec![[0.0; 4]; r];
        let mut g_bu_fixed = vec![[0.0; 4]; r];
        g_u_fixed[1] = obs.dc_db;
        g_au_fixed[1] = obs.dc_dab;
        g_bu_fixed[1] = obs.dc_dbb;
        if let (Some(h_range), Some(runtime)) = (h_range, score_runtime) {
            Self::for_each_deviation_basis_cubic_at(
                runtime,
                h_range,
                z_obs,
                "score-warp observed",
                |_, idx, basis_span| {
                    g_u_fixed[idx] =
                        scale_coeff4(exact::score_basis_cell_coefficients(basis_span, b), scale);
                    g_bu_fixed[idx] =
                        scale_coeff4(exact::score_basis_cell_coefficients(basis_span, 1.0), scale);
                    Ok(())
                },
            )?;
        }
        if let (Some(w_range), Some(runtime)) = (w_range, link_runtime) {
            Self::for_each_deviation_basis_cubic_at(
                runtime,
                w_range,
                u_obs,
                "link-wiggle observed",
                |_, idx, basis_span| {
                    g_u_fixed[idx] =
                        scale_coeff4(exact::link_basis_cell_coefficients(basis_span, a, b), scale);
                    let (dc_aw_raw, dc_bw_raw) =
                        exact::link_basis_cell_coefficient_partials(basis_span, a, b);
                    g_au_fixed[idx] = scale_coeff4(dc_aw_raw, scale);
                    g_bu_fixed[idx] = scale_coeff4(dc_bw_raw, scale);
                    Ok(())
                },
            )?;
        }
        let g_jet = SparsePrimaryCoeffJetView::new(
            1,
            h_range,
            w_range,
            &g_u_fixed,
            &g_au_fixed,
            &g_bu_fixed,
            &zero_family,
            &zero_family,
            &zero_family,
            &zero_family,
            &zero_family,
            &zero_family,
            &zero_family,
        );
        // rho: explicit first derivatives at z_obs; tau: mixed a-derivatives.
        let rho = &mut scratch.rho;
        let tau = &mut scratch.tau;
        rho.fill(0.0);
        tau.fill(0.0);
        for u in 1..r {
            rho[u] = eval_coeff4_at(&g_jet.first[u], z_obs);
            tau[u] = eval_coeff4_at(&g_jet.a_first[u], z_obs);
        }
        // Total derivative of the observed predictor (explicit + implicit
        // through the intercept).
        let eta_u = &mut scratch.grad;
        for u in 0..r {
            eta_u[u] = chi_obs * a_u[u] + rho[u];
        }
        // Signed probit log-likelihood pieces.
        let signed_margin = s_y * eta_val;
        let (log_cdf, lambda) = signed_probit_logcdf_and_mills_ratio(signed_margin);
        let neglog_val = -w_i * log_cdf;
        let d1_m = -w_i * lambda;
        let d2_m = w_i * lambda * (signed_margin + lambda);
        // Hessian assembly via the chain rule (symmetric fill).
        if need_hessian {
            let hess = &mut scratch.hess;
            hess.fill(0.0);
            for u in 0..r {
                for v in u..r {
                    let r_uv = eval_coeff4_at(
                        &g_jet.pair_from_b_family(g_jet.b_first, u, v, COEFF_SUPPORT_BHW),
                        z_obs,
                    );
                    let eta_uv = chi_obs * a_uv[[u, v]]
                        + eta_aa_obs * a_u[u] * a_u[v]
                        + tau[u] * a_u[v]
                        + a_u[u] * tau[v]
                        + r_uv;
                    let val = d2_m * eta_u[u] * eta_u[v] + d1_m * s_y * eta_uv;
                    hess[[u, v]] = val;
                    hess[[v, u]] = val;
                }
            }
        }
        // Scale eta_u into the gradient in place (scratch.grad aliases eta_u).
        eta_u.mapv_inplace(|eu| d1_m * s_y * eu);
        Ok(neglog_val)
    }
fn primary_point_from_block_states(
&self,
row: usize,
block_states: &[ParameterBlockState],
primary: &PrimarySlices,
) -> Result<Array1<f64>, String> {
let mut point = Array1::<f64>::zeros(primary.total);
point[primary.q] = block_states[0].eta[row];
point[primary.logslope] = block_states[1].eta[row];
if let Some(h_range) = primary.h.as_ref() {
let score = self
.score_block_state(block_states)?
.ok_or_else(|| "missing score-warp beta".to_string())?;
point
.slice_mut(s![h_range.start..h_range.end])
.assign(&score.beta);
}
if let Some(w_range) = primary.w.as_ref() {
let beta_w = self
.link_block_state(block_states)?
.ok_or_else(|| "missing link deviation beta".to_string())?;
point
.slice_mut(s![w_range.start..w_range.end])
.assign(&beta_w.beta);
}
Ok(point)
}
fn primary_point_components(
&self,
point: &Array1<f64>,
primary: &PrimarySlices,
) -> (f64, f64, Option<Array1<f64>>, Option<Array1<f64>>) {
let beta_h = primary
.h
.as_ref()
.map(|range| point.slice(s![range.start..range.end]).to_owned());
let beta_w = primary
.w
.as_ref()
.map(|range| point.slice(s![range.start..range.end]).to_owned());
(point[primary.q], point[primary.logslope], beta_h, beta_w)
}
fn observed_denested_cell_partials(
&self,
row: usize,
a: f64,
b: f64,
beta_h: Option<&Array1<f64>>,
beta_w: Option<&Array1<f64>>,
) -> Result<ObservedDenestedCellPartials, String> {
shared_observed_denested_cell_partials(
self.z[row],
a,
b,
self.score_warp.as_ref(),
beta_h,
self.link_dev.as_ref(),
beta_w,
self.probit_frailty_scale(),
)
}
fn observed_denested_cell_partials_at_z(
&self,
z_value: f64,
a: f64,
b: f64,
beta_h: Option<&Array1<f64>>,
beta_w: Option<&Array1<f64>>,
) -> Result<ObservedDenestedCellPartials, String> {
shared_observed_denested_cell_partials(
z_value,
a,
b,
self.score_warp.as_ref(),
beta_h,
self.link_dev.as_ref(),
beta_w,
self.probit_frailty_scale(),
)
}
/// Recomputes, for a single training `row`, the directional contraction of the
/// row's third-order log-likelihood derivative tensor against `dir`, returned
/// as a symmetric matrix over the primary parameters.
///
/// Paths taken, in order:
/// * Rigid path (flex inactive): delegates to `rigid_row_third_contracted` on
///   the two block etas and returns the 2x2 result.
/// * Zero direction: returns an all-zero `total x total` matrix immediately.
/// * Empirical latent grid present for this row: delegates to
///   `empirical_flex_row_third_contracted_recompute`.
/// * Otherwise: the full flexible computation below (cell-moment accumulation,
///   implicit intercept sensitivities, observed-point eta jets, final probit
///   contraction).
///
/// Returns `Err` when the row context is non-finite (or `m_a <= 0`), or when
/// any delegated evaluation fails.
fn row_primary_third_contracted_recompute(
    &self,
    row: usize,
    block_states: &[ParameterBlockState],
    cache: &BernoulliMarginalSlopeExactEvalCache,
    row_ctx: &BernoulliMarginalSlopeRowExactContext,
    dir: &Array1<f64>,
) -> Result<Array2<f64>, String> {
    // Rigid (non-flexible) path: the parameter space collapses to the two
    // block etas, so the contraction is a fixed 2x2 matrix.
    if !self.effective_flex_active(block_states)? {
        let marginal_eta = block_states[0].eta[row];
        let marginal = self.marginal_link_map(marginal_eta)?;
        let g = block_states[1].eta[row];
        let t =
            self.rigid_row_third_contracted(row, marginal_eta, marginal, g, dir[0], dir[1])?;
        let mut out = Array2::<f64>::zeros((2, 2));
        out[[0, 0]] = t[0][0];
        out[[0, 1]] = t[0][1];
        out[[1, 0]] = t[1][0];
        out[[1, 1]] = t[1][1];
        return Ok(out);
    }
    // A zero direction contracts to the zero matrix; skip all work.
    // (`abs() <= 0.0` matches only +/-0.0, same effect as `== 0.0`; the
    // sibling fourth-order method spells it as `== 0.0`.)
    if dir.iter().all(|value| value.abs() <= 0.0) {
        return Ok(Array2::<f64>::zeros((
            cache.primary.total,
            cache.primary.total,
        )));
    }
    // Guard against a degenerate row context before any expensive work.
    if !row_ctx.intercept.is_finite() || !row_ctx.m_a.is_finite() || row_ctx.m_a <= 0.0 {
        return Err(
            "non-finite flexible row context in third-order directional contraction"
                .to_string(),
        );
    }
    use crate::families::bernoulli_marginal_slope::exact_kernel as exact;
    let primary = &cache.primary;
    // Unpack the primary point at this row: q (marginal eta), b (slope), and
    // the optional score-warp (h) / link-wiggle (w) coefficient sub-vectors.
    let point = self.primary_point_from_block_states(row, block_states, primary)?;
    let (q, b, beta_h_owned, beta_w_owned) = self.primary_point_components(&point, primary);
    let beta_h = beta_h_owned.as_ref();
    let beta_w = beta_w_owned.as_ref();
    // Rows backed by an empirical latent grid use the dedicated grid routine.
    if let Some(grid) = self.latent_measure.empirical_grid_for_training_row(row)? {
        return self.empirical_flex_row_third_contracted_recompute(
            row, primary, q, b, beta_h, beta_w, row_ctx, dir, &grid,
        );
    }
    let a = row_ctx.intercept;
    // r is the full primary-parameter count; index 0 is the marginal eta
    // slot, index 1 the slope b, higher indices the h/w deviation bases
    // (layout inferred from the fixed writes to [1] and the h/w loops below
    // — confirm against `SparsePrimaryCoeffJetView`).
    let r = primary.total;
    let marginal = self.marginal_link_map(q)?;
    let h_range = primary.h.as_ref();
    let w_range = primary.w.as_ref();
    let score_runtime = self.score_warp.as_ref();
    let link_runtime = self.link_dev.as_ref();
    let scale = self.probit_frailty_scale();
    // Shared all-zero coefficient family for jet slots unused at third order.
    let zero_family = vec![[0.0; 4]; r];
    // Accumulators for the moment integrals F and their partials:
    //   f_a / f_aa      — first/second intercept partials,
    //   f_u / f_au / f_uv — parameter partials and mixed intercept terms,
    //   *_dir variants  — the same quantities contracted once with `dir`.
    let mut f_a = 0.0;
    let mut f_aa = 0.0;
    let mut f_a_dir = 0.0;
    let mut f_aa_dir = 0.0;
    let mut f_u = Array1::<f64>::zeros(r);
    let mut f_au = Array1::<f64>::zeros(r);
    let mut f_au_dir = Array1::<f64>::zeros(r);
    let mut f_uv = Array2::<f64>::zeros((r, r));
    let mut f_uv_dir = Array2::<f64>::zeros((r, r));
    // Accumulate cell-by-cell over the denested partition of the latent axis.
    let cells = self.denested_partition_cells(a, b, beta_h, beta_w)?;
    for partition_cell in cells {
        let cell = partition_cell.cell;
        // Probe point inside the cell, and its image u = a + b*z, used to
        // locate the active h/w basis spans for this cell.
        let z_mid = exact::interval_probe_point(cell.left, cell.right)?;
        let u_mid = a + b * z_mid;
        // Cached moments up to order 15 for this cell (third-order needs
        // fewer moments than the fourth-order sibling, which asks for 21).
        let state = self.evaluate_cell_moments_lru(cell, 15)?;
        // Raw (a, b) partials of the cell's cubic coefficient vector, then
        // rescaled by the probit frailty scale.
        let (dc_da_raw, dc_db_raw) = exact::denested_cell_coefficient_partials(
            partition_cell.score_span,
            partition_cell.link_span,
            a,
            b,
        );
        let (dc_daa_raw, dc_dab_raw, dc_dbb_raw) = exact::denested_cell_second_partials(
            partition_cell.score_span,
            partition_cell.link_span,
            a,
            b,
        );
        let denested_third = exact::denested_cell_third_partials(partition_cell.link_span);
        let dc_da = scale_coeff4(dc_da_raw, scale);
        let dc_db = scale_coeff4(dc_db_raw, scale);
        let dc_daa = scale_coeff4(dc_daa_raw, scale);
        let dc_dab = scale_coeff4(dc_dab_raw, scale);
        let dc_dbb = scale_coeff4(dc_dbb_raw, scale);
        // Only components .1/.2/.3 (aab, abb, bbb) of the third partials are
        // needed here; the pure-a component is not used at this order.
        let dc_daab = scale_coeff4(denested_third.1, scale);
        let dc_dabb = scale_coeff4(denested_third.2, scale);
        let dc_dbbb = scale_coeff4(denested_third.3, scale);
        // Per-parameter coefficient families (cubic in z) and their a/b
        // partials. Slot 1 is the slope b; h/w slots are filled below.
        let mut coeff_u = vec![[0.0; 4]; r];
        let mut coeff_au = vec![[0.0; 4]; r];
        let mut coeff_bu = vec![[0.0; 4]; r];
        let mut coeff_aau = vec![[0.0; 4]; r];
        let mut coeff_abu = vec![[0.0; 4]; r];
        let mut coeff_bbu = vec![[0.0; 4]; r];
        coeff_u[1] = dc_db;
        coeff_au[1] = dc_dab;
        coeff_bu[1] = dc_dbb;
        coeff_aau[1] = dc_daab;
        coeff_abu[1] = dc_dabb;
        coeff_bbu[1] = dc_dbbb;
        // Score-warp (h) bases act on z directly: their coefficients scale
        // with b, so d/db replaces b with 1.0; the a-partials stay zero.
        if let (Some(h_range), Some(runtime)) = (h_range, score_runtime) {
            Self::for_each_deviation_basis_cubic_at(
                runtime,
                h_range,
                z_mid,
                "score-warp third-direction",
                |_, idx, basis_span| {
                    coeff_u[idx] = scale_coeff4(
                        exact::score_basis_cell_coefficients(basis_span, b),
                        scale,
                    );
                    coeff_bu[idx] = scale_coeff4(
                        exact::score_basis_cell_coefficients(basis_span, 1.0),
                        scale,
                    );
                    Ok(())
                },
            )?;
        }
        // Link-wiggle (w) bases act on u = a + b*z, so they carry full
        // first- and second-order (a, b) partials.
        if let (Some(w_range), Some(runtime)) = (w_range, link_runtime) {
            Self::for_each_deviation_basis_cubic_at(
                runtime,
                w_range,
                u_mid,
                "link-wiggle third-direction",
                |_, idx, basis_span| {
                    coeff_u[idx] = scale_coeff4(
                        exact::link_basis_cell_coefficients(basis_span, a, b),
                        scale,
                    );
                    let (dc_aw_raw, dc_bw_raw) =
                        exact::link_basis_cell_coefficient_partials(basis_span, a, b);
                    let (dc_aaw_raw, dc_abw_raw, dc_bbw_raw) =
                        exact::link_basis_cell_second_partials(basis_span, a, b);
                    coeff_au[idx] = scale_coeff4(dc_aw_raw, scale);
                    coeff_bu[idx] = scale_coeff4(dc_bw_raw, scale);
                    coeff_aau[idx] = scale_coeff4(dc_aaw_raw, scale);
                    coeff_abu[idx] = scale_coeff4(dc_abw_raw, scale);
                    coeff_bbu[idx] = scale_coeff4(dc_bbw_raw, scale);
                    Ok(())
                },
            )?;
        }
        // Sparse view over the coefficient jet; the four trailing families
        // (third-order-in-b slots) are zero because third-order contraction
        // does not need them here.
        let coeff_jet = SparsePrimaryCoeffJetView::new(
            1,
            h_range,
            w_range,
            &coeff_u,
            &coeff_au,
            &coeff_bu,
            &coeff_aau,
            &coeff_abu,
            &coeff_bbu,
            &zero_family,
            &zero_family,
            &zero_family,
            &zero_family,
        );
        // Intercept partials of the cell integral.
        f_a += exact::cell_first_derivative_from_moments(&dc_da, &state.moments)?;
        f_aa += exact::cell_second_derivative_from_moments(
            cell,
            &dc_da,
            &dc_da,
            &dc_daa,
            &state.moments,
        )?;
        // Per-parameter first partials and mixed (a, u) second partials.
        for u in 1..r {
            f_u[u] +=
                exact::cell_first_derivative_from_moments(&coeff_jet.first[u], &state.moments)?;
            f_au[u] += exact::cell_second_derivative_from_moments(
                cell,
                &dc_da,
                &coeff_jet.first[u],
                &coeff_jet.a_first[u],
                &state.moments,
            )?;
        }
        // Direction-contracted coefficient families (one contraction with
        // `dir`, at increasing intercept-derivative order).
        let coeff_dir = coeff_jet.directional_family(coeff_jet.first, dir, COEFF_SUPPORT_BHW);
        let coeff_a_dir =
            coeff_jet.directional_family(coeff_jet.a_first, dir, COEFF_SUPPORT_BW);
        let coeff_aa_dir =
            coeff_jet.directional_family(coeff_jet.aa_first, dir, COEFF_SUPPORT_BW);
        f_a_dir += exact::cell_second_derivative_from_moments(
            cell,
            &dc_da,
            &coeff_dir,
            &coeff_a_dir,
            &state.moments,
        )?;
        f_aa_dir += exact::cell_third_derivative_from_moments(
            cell,
            &dc_da,
            &dc_da,
            &coeff_dir,
            &dc_daa,
            &coeff_a_dir,
            &coeff_a_dir,
            &coeff_aa_dir,
            &state.moments,
        )?;
        // Per-parameter families after one contraction with `dir`.
        let mut coeff_u_dir = vec![[0.0; 4]; r];
        let mut coeff_au_dir = vec![[0.0; 4]; r];
        for u in 1..r {
            coeff_u_dir[u] = coeff_jet.param_directional_from_b_family(
                coeff_jet.b_first,
                u,
                dir,
                COEFF_SUPPORT_BHW,
            );
            coeff_au_dir[u] = coeff_jet.param_directional_from_b_family(
                coeff_jet.ab_first,
                u,
                dir,
                COEFF_SUPPORT_BW,
            );
        }
        // Mixed (a, u, dir) third partials.
        for u in 1..r {
            f_au_dir[u] += exact::cell_third_derivative_from_moments(
                cell,
                &dc_da,
                &coeff_jet.first[u],
                &coeff_dir,
                &coeff_jet.a_first[u],
                &coeff_a_dir,
                &coeff_u_dir[u],
                &coeff_au_dir[u],
                &state.moments,
            )?;
        }
        // Pairwise (u, v) second partials and their dir-contracted thirds;
        // only the upper triangle is computed, then mirrored.
        for u in 1..r {
            for v in u..r {
                let second_coeff =
                    coeff_jet.pair_from_b_family(coeff_jet.b_first, u, v, COEFF_SUPPORT_BHW);
                let val = exact::cell_second_derivative_from_moments(
                    cell,
                    &coeff_jet.first[u],
                    &coeff_jet.first[v],
                    &second_coeff,
                    &state.moments,
                )?;
                f_uv[[u, v]] += val;
                if u != v {
                    f_uv[[v, u]] += val;
                }
                let third_coeff = coeff_jet.pair_directional_from_bb_family(
                    coeff_jet.bb_first,
                    u,
                    v,
                    dir,
                    COEFF_SUPPORT_BW,
                );
                let dir_val = exact::cell_third_derivative_from_moments(
                    cell,
                    &coeff_jet.first[u],
                    &coeff_jet.first[v],
                    &coeff_dir,
                    &second_coeff,
                    &coeff_u_dir[u],
                    &coeff_u_dir[v],
                    &third_coeff,
                    &state.moments,
                )?;
                f_uv_dir[[u, v]] += dir_val;
                if u != v {
                    f_uv_dir[[v, u]] += dir_val;
                }
            }
        }
    }
    // Slot 0 (the marginal-eta parameter) contributes through the marginal
    // link map's moments rather than the cell integrals.
    f_u[0] = -marginal.mu1;
    f_uv[[0, 0]] = -marginal.mu2;
    f_uv_dir[[0, 0]] = -dir[0] * marginal.mu3;
    // Implicit-function sensitivities of the intercept a with respect to the
    // parameters, from F(a(u), u) = const: a_u = -F_u / F_a, and the
    // corresponding second-order (a_uv) and dir-contracted (a_uv_dir) terms.
    let inv_f_a = 1.0 / f_a;
    let mut a_u = Array1::<f64>::zeros(r);
    for u in 0..r {
        a_u[u] = -f_u[u] * inv_f_a;
    }
    let mut a_uv = Array2::<f64>::zeros((r, r));
    for u in 0..r {
        for v in u..r {
            let val =
                -(f_uv[[u, v]] + f_au[u] * a_u[v] + f_au[v] * a_u[u] + f_aa * a_u[u] * a_u[v])
                    * inv_f_a;
            a_uv[[u, v]] = val;
            a_uv[[v, u]] = val;
        }
    }
    let a_dir = a_u.dot(dir);
    let a_u_dir = a_uv.dot(dir);
    let mut a_uv_dir = Array2::<f64>::zeros((r, r));
    for u in 0..r {
        for v in u..r {
            // Product-rule expansion of d/d(dir) applied to the a_uv formula.
            let n_dir = f_uv_dir[[u, v]]
                + f_au_dir[u] * a_u[v]
                + f_au[u] * a_u_dir[v]
                + f_au_dir[v] * a_u[u]
                + f_au[v] * a_u_dir[u]
                + f_aa_dir * a_u[u] * a_u[v]
                + f_aa * (a_u_dir[u] * a_u[v] + a_u[u] * a_u_dir[v]);
            let val = -(n_dir + f_a_dir * a_uv[[u, v]]) * inv_f_a;
            a_uv_dir[[u, v]] = val;
            a_uv_dir[[v, u]] = val;
        }
    }
    // --- Observed-point stage: build the eta jet at the row's observed z ---
    let z_obs = self.z[row];
    let u_obs = a + b * z_obs;
    let obs = self.observed_denested_cell_partials(row, a, b, beta_h, beta_w)?;
    let eta_val = eval_coeff4_at(&obs.coeff, z_obs);
    // Fixed (a held constant) coefficient families at the observed point,
    // mirroring the mid-cell construction above.
    let mut g_u_fixed = vec![[0.0; 4]; r];
    let mut g_au_fixed = vec![[0.0; 4]; r];
    let mut g_bu_fixed = vec![[0.0; 4]; r];
    let mut g_aau_fixed = vec![[0.0; 4]; r];
    let mut g_abu_fixed = vec![[0.0; 4]; r];
    let mut g_bbu_fixed = vec![[0.0; 4]; r];
    g_u_fixed[1] = obs.dc_db;
    g_au_fixed[1] = obs.dc_dab;
    g_bu_fixed[1] = obs.dc_dbb;
    g_aau_fixed[1] = obs.dc_daab;
    g_abu_fixed[1] = obs.dc_dabb;
    g_bbu_fixed[1] = obs.dc_dbbb;
    // NOTE(review): shadows the `scale` fetched earlier; redundant unless
    // `probit_frailty_scale` can change between the two calls — confirm.
    let scale = self.probit_frailty_scale();
    if let (Some(h_range), Some(runtime)) = (h_range, score_runtime) {
        Self::for_each_deviation_basis_cubic_at(
            runtime,
            h_range,
            z_obs,
            "score-warp third-direction observed",
            |_, idx, basis_span| {
                g_u_fixed[idx] =
                    scale_coeff4(exact::score_basis_cell_coefficients(basis_span, b), scale);
                g_bu_fixed[idx] =
                    scale_coeff4(exact::score_basis_cell_coefficients(basis_span, 1.0), scale);
                Ok(())
            },
        )?;
    }
    if let (Some(w_range), Some(runtime)) = (w_range, link_runtime) {
        Self::for_each_deviation_basis_cubic_at(
            runtime,
            w_range,
            u_obs,
            "link-wiggle third-direction observed",
            |_, idx, basis_span| {
                g_u_fixed[idx] =
                    scale_coeff4(exact::link_basis_cell_coefficients(basis_span, a, b), scale);
                let (dc_aw_raw, dc_bw_raw) =
                    exact::link_basis_cell_coefficient_partials(basis_span, a, b);
                let (dc_aaw_raw, dc_abw_raw, dc_bbw_raw) =
                    exact::link_basis_cell_second_partials(basis_span, a, b);
                g_au_fixed[idx] = scale_coeff4(dc_aw_raw, scale);
                g_bu_fixed[idx] = scale_coeff4(dc_bw_raw, scale);
                g_aau_fixed[idx] = scale_coeff4(dc_aaw_raw, scale);
                g_abu_fixed[idx] = scale_coeff4(dc_abw_raw, scale);
                g_bbu_fixed[idx] = scale_coeff4(dc_bbw_raw, scale);
                Ok(())
            },
        )?;
    }
    let g_jet = SparsePrimaryCoeffJetView::new(
        1,
        h_range,
        w_range,
        &g_u_fixed,
        &g_au_fixed,
        &g_bu_fixed,
        &g_aau_fixed,
        &g_abu_fixed,
        &g_bbu_fixed,
        &zero_family,
        &zero_family,
        &zero_family,
        &zero_family,
    );
    // Intercept partials of the observed eta coefficient, evaluated at z_obs.
    let g_a = eval_coeff4_at(&obs.dc_da, z_obs);
    let g_aa = eval_coeff4_at(&obs.dc_daa, z_obs);
    let g_aaa = eval_coeff4_at(&obs.dc_daaa, z_obs);
    let mut g_u = Array1::<f64>::zeros(r);
    let mut g_au = Array1::<f64>::zeros(r);
    let mut g_aau = Array1::<f64>::zeros(r);
    let mut g_uv = Array2::<f64>::zeros((r, r));
    let mut g_auv = Array2::<f64>::zeros((r, r));
    for u in 1..r {
        g_u[u] = eval_coeff4_at(&g_jet.first[u], z_obs);
        g_au[u] = eval_coeff4_at(&g_jet.a_first[u], z_obs);
        g_aau[u] = eval_coeff4_at(&g_jet.aa_first[u], z_obs);
    }
    for u in 1..r {
        for v in u..r {
            let second_coeff = g_jet.pair_from_b_family(g_jet.b_first, u, v, COEFF_SUPPORT_BHW);
            let val = eval_coeff4_at(&second_coeff, z_obs);
            g_uv[[u, v]] = val;
            g_uv[[v, u]] = val;
            let third_coeff = g_jet.pair_from_b_family(g_jet.ab_first, u, v, COEFF_SUPPORT_BW);
            let third_val = eval_coeff4_at(&third_coeff, z_obs);
            g_auv[[u, v]] = third_val;
            g_auv[[v, u]] = third_val;
        }
    }
    // Dir-contracted observed families and their scalar values at z_obs.
    let mut g_u_dir_fixed = vec![[0.0; 4]; r];
    let mut g_au_dir_fixed = vec![[0.0; 4]; r];
    let g_dir_fixed = g_jet.directional_family(g_jet.first, dir, COEFF_SUPPORT_BHW);
    let g_a_dir_fixed = g_jet.directional_family(g_jet.a_first, dir, COEFF_SUPPORT_BW);
    let g_aa_dir_fixed = g_jet.directional_family(g_jet.aa_first, dir, COEFF_SUPPORT_BW);
    let g_dir = eval_coeff4_at(&g_dir_fixed, z_obs);
    let g_a_dir = eval_coeff4_at(&g_a_dir_fixed, z_obs);
    let g_aa_dir = eval_coeff4_at(&g_aa_dir_fixed, z_obs);
    for u in 1..r {
        g_u_dir_fixed[u] =
            g_jet.param_directional_from_b_family(g_jet.b_first, u, dir, COEFF_SUPPORT_BHW);
        g_au_dir_fixed[u] =
            g_jet.param_directional_from_b_family(g_jet.ab_first, u, dir, COEFF_SUPPORT_BW);
    }
    let mut g_u_dir = Array1::<f64>::zeros(r);
    let mut g_uv_dir = Array2::<f64>::zeros((r, r));
    // NOTE(review): `g_u_dir` is populated here but never read afterwards in
    // this function — possibly dead code; confirm before removing.
    for u in 1..r {
        g_u_dir[u] = eval_coeff4_at(&g_u_dir_fixed[u], z_obs);
    }
    for u in 1..r {
        for v in u..r {
            let third_coeff = g_jet.pair_directional_from_bb_family(
                g_jet.bb_first,
                u,
                v,
                dir,
                COEFF_SUPPORT_BW,
            );
            let val = eval_coeff4_at(&third_coeff, z_obs);
            g_uv_dir[[u, v]] = val;
            g_uv_dir[[v, u]] = val;
        }
    }
    // Total derivatives of eta, combining explicit g-partials with the
    // implicit dependence through a: eta_u = g_a * a_u + g_u, etc.
    let eta_u = g_a * &a_u + &g_u;
    let mut eta_uv = Array2::<f64>::zeros((r, r));
    for u in 0..r {
        for v in u..r {
            let val = g_a * a_uv[[u, v]]
                + g_aa * a_u[u] * a_u[v]
                + g_au[u] * a_u[v]
                + g_au[v] * a_u[u]
                + g_uv[[u, v]];
            eta_uv[[u, v]] = val;
            eta_uv[[v, u]] = val;
        }
    }
    let eta_dir = g_a * a_dir + g_dir;
    let eta_u_dir = eta_uv.dot(dir);
    // Total directional derivatives of the g-partials themselves (chain rule
    // through a_dir plus the fixed-a directional parts).
    let dg_a_dir = g_aa * a_dir + g_a_dir;
    let dg_aa_dir = g_aaa * a_dir + g_aa_dir;
    let mut dg_au_dir = Array1::<f64>::zeros(r);
    let mut dg_uv_dir = Array2::<f64>::zeros((r, r));
    for u in 0..r {
        dg_au_dir[u] = g_aau[u] * a_dir + eval_coeff4_at(&g_au_dir_fixed[u], z_obs);
    }
    for u in 0..r {
        for v in u..r {
            let val = g_auv[[u, v]] * a_dir + g_uv_dir[[u, v]];
            dg_uv_dir[[u, v]] = val;
            dg_uv_dir[[v, u]] = val;
        }
    }
    // Third total derivative of eta, contracted once with `dir`.
    let mut eta_uv_dir = Array2::<f64>::zeros((r, r));
    for u in 0..r {
        for v in u..r {
            let val = dg_a_dir * a_uv[[u, v]]
                + g_a * a_uv_dir[[u, v]]
                + dg_aa_dir * a_u[u] * a_u[v]
                + g_aa * (a_u_dir[u] * a_u[v] + a_u[u] * a_u_dir[v])
                + dg_au_dir[u] * a_u[v]
                + g_au[u] * a_u_dir[v]
                + dg_au_dir[v] * a_u[u]
                + g_au[v] * a_u_dir[u]
                + dg_uv_dir[[u, v]];
            eta_uv_dir[[u, v]] = val;
            eta_uv_dir[[v, u]] = val;
        }
    }
    // --- Final contraction through the signed-probit negative log-lik ---
    let y_i = self.y[row];
    let w_i = self.weights[row];
    // Signed response s_y in {-1, +1} maps the Bernoulli outcome onto the
    // signed-probit argument m = s_y * eta.
    let s_y = 2.0 * y_i - 1.0;
    let m = s_y * eta_val;
    let (k1, k2, k3, _) = signed_probit_neglog_derivatives_up_to_fourth(m, w_i)?;
    // Odd-order derivatives pick up one factor of s_y; even orders pick up
    // s_y^2 = 1, so k2 is used unsigned.
    let u1 = s_y * k1;
    let u3 = s_y * k3;
    // Faà di Bruno-style assembly of the contracted third derivative of the
    // composed loss: k'''*e_u*e_v*e_dir + k''*(cross terms) + k'*e_uv_dir.
    let mut out = Array2::<f64>::zeros((r, r));
    for u in 0..r {
        for v in u..r {
            let val = u3 * eta_u[u] * eta_u[v] * eta_dir
                + k2 * (eta_uv[[u, v]] * eta_dir
                    + eta_u[u] * eta_u_dir[v]
                    + eta_u[v] * eta_u_dir[u])
                + u1 * eta_uv_dir[[u, v]];
            out[[u, v]] = val;
            out[[v, u]] = val;
        }
    }
    Ok(out)
}
fn row_primary_fourth_contracted_recompute_ordered(
&self,
row: usize,
block_states: &[ParameterBlockState],
cache: &BernoulliMarginalSlopeExactEvalCache,
row_ctx: &BernoulliMarginalSlopeRowExactContext,
dir_u: &Array1<f64>,
dir_v: &Array1<f64>,
) -> Result<Array2<f64>, String> {
let flex_active = self.effective_flex_active(block_states)?;
let expected_dir_len = if flex_active { cache.primary.total } else { 2 };
if dir_u.len() != expected_dir_len || dir_v.len() != expected_dir_len {
return Err(format!(
"bernoulli fourth contracted row {row}: direction lengths ({},{}) != {expected_dir_len}",
dir_u.len(),
dir_v.len()
));
}
if !flex_active {
let marginal_eta = block_states[0].eta[row];
let marginal = self.marginal_link_map(marginal_eta)?;
let g = block_states[1].eta[row];
let f = self.rigid_row_fourth_contracted(
row,
marginal_eta,
marginal,
g,
dir_u[0],
dir_u[1],
dir_v[0],
dir_v[1],
)?;
let mut out = Array2::<f64>::zeros((2, 2));
out[[0, 0]] = f[0][0];
out[[0, 1]] = f[0][1];
out[[1, 0]] = f[1][0];
out[[1, 1]] = f[1][1];
return Ok(out);
}
if dir_u.iter().all(|value| *value == 0.0) || dir_v.iter().all(|value| *value == 0.0) {
return Ok(Array2::<f64>::zeros((
cache.primary.total,
cache.primary.total,
)));
}
if !row_ctx.intercept.is_finite() || !row_ctx.m_a.is_finite() || row_ctx.m_a <= 0.0 {
return Err(
"non-finite flexible row context in fourth-order directional contraction"
.to_string(),
);
}
use crate::families::bernoulli_marginal_slope::exact_kernel as exact;
let primary = &cache.primary;
let point = self.primary_point_from_block_states(row, block_states, primary)?;
let (q, b, beta_h_owned, beta_w_owned) = self.primary_point_components(&point, primary);
let beta_h = beta_h_owned.as_ref();
let beta_w = beta_w_owned.as_ref();
if let Some(grid) = self.latent_measure.empirical_grid_for_training_row(row)? {
return self.empirical_flex_row_fourth_contracted_recompute(
row, primary, q, b, beta_h, beta_w, row_ctx, dir_u, dir_v, &grid,
);
}
let a = row_ctx.intercept;
let r = primary.total;
let marginal = self.marginal_link_map(q)?;
let h_range = primary.h.as_ref();
let w_range = primary.w.as_ref();
let score_runtime = self.score_warp.as_ref();
let link_runtime = self.link_dev.as_ref();
let scale = self.probit_frailty_scale();
let mut f_a = 0.0;
let mut f_aa = 0.0;
let mut f_u = Array1::<f64>::zeros(r);
let mut f_au = Array1::<f64>::zeros(r);
let mut f_uv = Array2::<f64>::zeros((r, r));
let mut f_a_u = 0.0;
let mut f_aa_u = 0.0;
let mut f_au_u = Array1::<f64>::zeros(r);
let mut f_uv_u = Array2::<f64>::zeros((r, r));
let mut f_a_v = 0.0;
let mut f_aa_v = 0.0;
let mut f_au_v = Array1::<f64>::zeros(r);
let mut f_uv_v = Array2::<f64>::zeros((r, r));
let mut f_a_uv = 0.0;
let mut f_aa_uv = 0.0;
let mut f_au_uv = Array1::<f64>::zeros(r);
let mut f_uv_uv = Array2::<f64>::zeros((r, r));
let cells = self.denested_partition_cells(a, b, beta_h, beta_w)?;
for partition_cell in cells {
let cell = partition_cell.cell;
let z_mid = exact::interval_probe_point(cell.left, cell.right)?;
let u_mid = a + b * z_mid;
let state = self.evaluate_cell_moments_lru(cell, 21)?;
let (dc_da_raw, dc_db_raw) = exact::denested_cell_coefficient_partials(
partition_cell.score_span,
partition_cell.link_span,
a,
b,
);
let (dc_daa_raw, dc_dab_raw, dc_dbb_raw) = exact::denested_cell_second_partials(
partition_cell.score_span,
partition_cell.link_span,
a,
b,
);
let denested_third = exact::denested_cell_third_partials(partition_cell.link_span);
let dc_da = scale_coeff4(dc_da_raw, scale);
let dc_db = scale_coeff4(dc_db_raw, scale);
let dc_daa = scale_coeff4(dc_daa_raw, scale);
let dc_dab = scale_coeff4(dc_dab_raw, scale);
let dc_dbb = scale_coeff4(dc_dbb_raw, scale);
let dc_daab = scale_coeff4(denested_third.1, scale);
let dc_dabb = scale_coeff4(denested_third.2, scale);
let dc_dbbb = scale_coeff4(denested_third.3, scale);
let mut coeff_u = vec![[0.0; 4]; r];
let mut coeff_au = vec![[0.0; 4]; r];
let mut coeff_bu = vec![[0.0; 4]; r];
let mut coeff_aau = vec![[0.0; 4]; r];
let mut coeff_abu = vec![[0.0; 4]; r];
let mut coeff_bbu = vec![[0.0; 4]; r];
let mut coeff_aaau = vec![[0.0; 4]; r];
let mut coeff_aabu = vec![[0.0; 4]; r];
let mut coeff_abbu = vec![[0.0; 4]; r];
let mut coeff_bbbu = vec![[0.0; 4]; r];
coeff_u[1] = dc_db;
coeff_au[1] = dc_dab;
coeff_bu[1] = dc_dbb;
coeff_aau[1] = dc_daab;
coeff_abu[1] = dc_dabb;
coeff_bbu[1] = dc_dbbb;
if let (Some(h_range), Some(runtime)) = (h_range, score_runtime) {
Self::for_each_deviation_basis_cubic_at(
runtime,
h_range,
z_mid,
"score-warp fourth-direction",
|_, idx, basis_span| {
coeff_u[idx] = scale_coeff4(
exact::score_basis_cell_coefficients(basis_span, b),
scale,
);
coeff_bu[idx] = scale_coeff4(
exact::score_basis_cell_coefficients(basis_span, 1.0),
scale,
);
Ok(())
},
)?;
}
if let (Some(w_range), Some(runtime)) = (w_range, link_runtime) {
Self::for_each_deviation_basis_cubic_at(
runtime,
w_range,
u_mid,
"link-wiggle fourth-direction",
|_, idx, basis_span| {
coeff_u[idx] = scale_coeff4(
exact::link_basis_cell_coefficients(basis_span, a, b),
scale,
);
let (dc_aw_raw, dc_bw_raw) =
exact::link_basis_cell_coefficient_partials(basis_span, a, b);
let (dc_aaw_raw, dc_abw_raw, dc_bbw_raw) =
exact::link_basis_cell_second_partials(basis_span, a, b);
let (dc_aaaw, dc_aabw, dc_abbw, dc_bbbw) =
exact::link_basis_cell_third_partials(basis_span);
coeff_au[idx] = scale_coeff4(dc_aw_raw, scale);
coeff_bu[idx] = scale_coeff4(dc_bw_raw, scale);
coeff_aau[idx] = scale_coeff4(dc_aaw_raw, scale);
coeff_abu[idx] = scale_coeff4(dc_abw_raw, scale);
coeff_bbu[idx] = scale_coeff4(dc_bbw_raw, scale);
coeff_aaau[idx] = scale_coeff4(dc_aaaw, scale);
coeff_aabu[idx] = scale_coeff4(dc_aabw, scale);
coeff_abbu[idx] = scale_coeff4(dc_abbw, scale);
coeff_bbbu[idx] = scale_coeff4(dc_bbbw, scale);
Ok(())
},
)?;
}
let coeff_jet = SparsePrimaryCoeffJetView::new(
1,
h_range,
w_range,
&coeff_u,
&coeff_au,
&coeff_bu,
&coeff_aau,
&coeff_abu,
&coeff_bbu,
&coeff_aaau,
&coeff_aabu,
&coeff_abbu,
&coeff_bbbu,
);
f_a += exact::cell_first_derivative_from_moments(&dc_da, &state.moments)?;
f_aa += exact::cell_second_derivative_from_moments(
cell,
&dc_da,
&dc_da,
&dc_daa,
&state.moments,
)?;
for u in 1..r {
f_u[u] +=
exact::cell_first_derivative_from_moments(&coeff_jet.first[u], &state.moments)?;
f_au[u] += exact::cell_second_derivative_from_moments(
cell,
&dc_da,
&coeff_jet.first[u],
&coeff_jet.a_first[u],
&state.moments,
)?;
}
let coeff_dir_u =
coeff_jet.directional_family(coeff_jet.first, dir_u, COEFF_SUPPORT_BHW);
let coeff_dir_v =
coeff_jet.directional_family(coeff_jet.first, dir_v, COEFF_SUPPORT_BHW);
let coeff_a_dir_u =
coeff_jet.directional_family(coeff_jet.a_first, dir_u, COEFF_SUPPORT_BW);
let coeff_a_dir_v =
coeff_jet.directional_family(coeff_jet.a_first, dir_v, COEFF_SUPPORT_BW);
let coeff_aa_dir_u =
coeff_jet.directional_family(coeff_jet.aa_first, dir_u, COEFF_SUPPORT_BW);
let coeff_aa_dir_v =
coeff_jet.directional_family(coeff_jet.aa_first, dir_v, COEFF_SUPPORT_BW);
f_a_u += exact::cell_second_derivative_from_moments(
cell,
&dc_da,
&coeff_dir_u,
&coeff_a_dir_u,
&state.moments,
)?;
f_a_v += exact::cell_second_derivative_from_moments(
cell,
&dc_da,
&coeff_dir_v,
&coeff_a_dir_v,
&state.moments,
)?;
f_aa_u += exact::cell_third_derivative_from_moments(
cell,
&dc_da,
&dc_da,
&coeff_dir_u,
&dc_daa,
&coeff_a_dir_u,
&coeff_a_dir_u,
&coeff_aa_dir_u,
&state.moments,
)?;
f_aa_v += exact::cell_third_derivative_from_moments(
cell,
&dc_da,
&dc_da,
&coeff_dir_v,
&dc_daa,
&coeff_a_dir_v,
&coeff_a_dir_v,
&coeff_aa_dir_v,
&state.moments,
)?;
let coeff_dir_uv = coeff_jet.mixed_directional_from_b_family(
coeff_jet.b_first,
dir_u,
dir_v,
COEFF_SUPPORT_BHW,
);
let coeff_a_dir_uv = coeff_jet.mixed_directional_from_b_family(
coeff_jet.ab_first,
dir_u,
dir_v,
COEFF_SUPPORT_BW,
);
let coeff_aa_dir_uv = coeff_jet.mixed_directional_from_b_family(
coeff_jet.aab_first,
dir_u,
dir_v,
COEFF_SUPPORT_W,
);
f_a_uv += exact::cell_third_derivative_from_moments(
cell,
&dc_da,
&coeff_dir_u,
&coeff_dir_v,
&coeff_a_dir_u,
&coeff_a_dir_v,
&coeff_dir_uv,
&coeff_a_dir_uv,
&state.moments,
)?;
f_aa_uv += exact::cell_fourth_derivative_from_moments(
cell,
&dc_da,
&dc_da,
&coeff_dir_u,
&coeff_dir_v,
&dc_daa,
&coeff_a_dir_u,
&coeff_a_dir_v,
&coeff_a_dir_u,
&coeff_a_dir_v,
&coeff_dir_uv,
&coeff_aa_dir_u,
&coeff_aa_dir_v,
&coeff_a_dir_uv,
&coeff_a_dir_uv,
&coeff_aa_dir_uv,
&state.moments,
)?;
let mut coeff_u_dir_u = vec![[0.0; 4]; r];
let mut coeff_u_dir_v = vec![[0.0; 4]; r];
let mut coeff_u_dir_uv = vec![[0.0; 4]; r];
let mut coeff_au_dir_u = vec![[0.0; 4]; r];
let mut coeff_au_dir_v = vec![[0.0; 4]; r];
let mut coeff_au_dir_uv = vec![[0.0; 4]; r];
for u in 1..r {
coeff_u_dir_u[u] = coeff_jet.param_directional_from_b_family(
coeff_jet.b_first,
u,
dir_u,
COEFF_SUPPORT_BHW,
);
coeff_u_dir_v[u] = coeff_jet.param_directional_from_b_family(
coeff_jet.b_first,
u,
dir_v,
COEFF_SUPPORT_BHW,
);
coeff_u_dir_uv[u] = coeff_jet.param_mixed_from_bb_family(
coeff_jet.bb_first,
u,
dir_u,
dir_v,
COEFF_SUPPORT_BW,
);
coeff_au_dir_u[u] = coeff_jet.param_directional_from_b_family(
coeff_jet.ab_first,
u,
dir_u,
COEFF_SUPPORT_BW,
);
coeff_au_dir_v[u] = coeff_jet.param_directional_from_b_family(
coeff_jet.ab_first,
u,
dir_v,
COEFF_SUPPORT_BW,
);
coeff_au_dir_uv[u] = coeff_jet.param_mixed_from_bb_family(
coeff_jet.abb_first,
u,
dir_u,
dir_v,
COEFF_SUPPORT_W,
);
}
for u in 1..r {
f_au_u[u] += exact::cell_third_derivative_from_moments(
cell,
&dc_da,
&coeff_u[u],
&coeff_dir_u,
&coeff_au[u],
&coeff_a_dir_u,
&coeff_u_dir_u[u],
&coeff_au_dir_u[u],
&state.moments,
)?;
f_au_v[u] += exact::cell_third_derivative_from_moments(
cell,
&dc_da,
&coeff_u[u],
&coeff_dir_v,
&coeff_au[u],
&coeff_a_dir_v,
&coeff_u_dir_v[u],
&coeff_au_dir_v[u],
&state.moments,
)?;
f_au_uv[u] += exact::cell_fourth_derivative_from_moments(
cell,
&dc_da,
&coeff_u[u],
&coeff_dir_u,
&coeff_dir_v,
&coeff_au[u],
&coeff_a_dir_u,
&coeff_a_dir_v,
&coeff_u_dir_u[u],
&coeff_u_dir_v[u],
&coeff_dir_uv,
&coeff_au_dir_u[u],
&coeff_au_dir_v[u],
&coeff_a_dir_uv,
&coeff_u_dir_uv[u],
&coeff_au_dir_uv[u],
&state.moments,
)?;
}
for u in 1..r {
for v in u..r {
let second_coeff =
coeff_jet.pair_from_b_family(coeff_jet.b_first, u, v, COEFF_SUPPORT_BHW);
let base_val = exact::cell_second_derivative_from_moments(
cell,
&coeff_jet.first[u],
&coeff_jet.first[v],
&second_coeff,
&state.moments,
)?;
f_uv[[u, v]] += base_val;
if u != v {
f_uv[[v, u]] += base_val;
}
let third_u = coeff_jet.pair_directional_from_bb_family(
coeff_jet.bb_first,
u,
v,
dir_u,
COEFF_SUPPORT_BW,
);
let third_v = coeff_jet.pair_directional_from_bb_family(
coeff_jet.bb_first,
u,
v,
dir_v,
COEFF_SUPPORT_BW,
);
let fourth_uv = coeff_jet.pair_mixed_from_bbb_family(
coeff_jet.bbb_first,
u,
v,
dir_u,
dir_v,
COEFF_SUPPORT_W,
);
let dir_u_val = exact::cell_third_derivative_from_moments(
cell,
&coeff_jet.first[u],
&coeff_jet.first[v],
&coeff_dir_u,
&second_coeff,
&coeff_u_dir_u[u],
&coeff_u_dir_u[v],
&third_u,
&state.moments,
)?;
let dir_v_val = exact::cell_third_derivative_from_moments(
cell,
&coeff_jet.first[u],
&coeff_jet.first[v],
&coeff_dir_v,
&second_coeff,
&coeff_u_dir_v[u],
&coeff_u_dir_v[v],
&third_v,
&state.moments,
)?;
let mix_val = exact::cell_fourth_derivative_from_moments(
cell,
&coeff_jet.first[u],
&coeff_jet.first[v],
&coeff_dir_u,
&coeff_dir_v,
&second_coeff,
&coeff_u_dir_u[u],
&coeff_u_dir_v[u],
&coeff_u_dir_u[v],
&coeff_u_dir_v[v],
&coeff_dir_uv,
&third_u,
&third_v,
&coeff_u_dir_uv[u],
&coeff_u_dir_uv[v],
&fourth_uv,
&state.moments,
)?;
f_uv_u[[u, v]] += dir_u_val;
f_uv_v[[u, v]] += dir_v_val;
f_uv_uv[[u, v]] += mix_val;
if u != v {
f_uv_u[[v, u]] += dir_u_val;
f_uv_v[[v, u]] += dir_v_val;
f_uv_uv[[v, u]] += mix_val;
}
}
}
}
f_u[0] = -marginal.mu1;
f_uv[[0, 0]] = -marginal.mu2;
f_uv_u[[0, 0]] = -dir_u[0] * marginal.mu3;
f_uv_v[[0, 0]] = -dir_v[0] * marginal.mu3;
f_uv_uv[[0, 0]] = -dir_u[0] * dir_v[0] * marginal.mu4;
let inv_f_a = 1.0 / f_a;
let mut a_u = Array1::<f64>::zeros(r);
for u in 0..r {
a_u[u] = -f_u[u] * inv_f_a;
}
let mut a_uv = Array2::<f64>::zeros((r, r));
for u in 0..r {
for v in u..r {
let val =
-(f_uv[[u, v]] + f_au[u] * a_u[v] + f_au[v] * a_u[u] + f_aa * a_u[u] * a_u[v])
* inv_f_a;
a_uv[[u, v]] = val;
a_uv[[v, u]] = val;
}
}
let a_u_dir_u = a_uv.dot(dir_u);
let a_u_dir_v = a_uv.dot(dir_v);
let mut a_uv_u = Array2::<f64>::zeros((r, r));
let mut a_uv_v = Array2::<f64>::zeros((r, r));
for u in 0..r {
for v in u..r {
let n_u = f_uv_u[[u, v]]
+ f_au_u[u] * a_u[v]
+ f_au[u] * a_u_dir_u[v]
+ f_au_u[v] * a_u[u]
+ f_au[v] * a_u_dir_u[u]
+ f_aa_u * a_u[u] * a_u[v]
+ f_aa * (a_u_dir_u[u] * a_u[v] + a_u[u] * a_u_dir_u[v]);
let val_u = -(n_u + f_a_u * a_uv[[u, v]]) * inv_f_a;
a_uv_u[[u, v]] = val_u;
a_uv_u[[v, u]] = val_u;
let n_v = f_uv_v[[u, v]]
+ f_au_v[u] * a_u[v]
+ f_au[u] * a_u_dir_v[v]
+ f_au_v[v] * a_u[u]
+ f_au[v] * a_u_dir_v[u]
+ f_aa_v * a_u[u] * a_u[v]
+ f_aa * (a_u_dir_v[u] * a_u[v] + a_u[u] * a_u_dir_v[v]);
let val_v = -(n_v + f_a_v * a_uv[[u, v]]) * inv_f_a;
a_uv_v[[u, v]] = val_v;
a_uv_v[[v, u]] = val_v;
}
}
let a_u_uv = a_uv_u.dot(dir_v);
let mut a_uv_uv = Array2::<f64>::zeros((r, r));
for u in 0..r {
for v in u..r {
let n_uv = f_uv_uv[[u, v]]
+ f_au_uv[u] * a_u[v]
+ f_au_u[u] * a_u_dir_v[v]
+ f_au_v[u] * a_u_dir_u[v]
+ f_au[u] * a_u_uv[v]
+ f_au_uv[v] * a_u[u]
+ f_au_u[v] * a_u_dir_v[u]
+ f_au_v[v] * a_u_dir_u[u]
+ f_au[v] * a_u_uv[u]
+ f_aa_uv * a_u[u] * a_u[v]
+ f_aa_u * (a_u_dir_v[u] * a_u[v] + a_u[u] * a_u_dir_v[v])
+ f_aa_v * (a_u_dir_u[u] * a_u[v] + a_u[u] * a_u_dir_u[v])
+ f_aa
* (a_u_uv[u] * a_u[v]
+ a_u_dir_u[u] * a_u_dir_v[v]
+ a_u_dir_v[u] * a_u_dir_u[v]
+ a_u[u] * a_u_uv[v]);
let val = -(n_uv
+ f_a_v * a_uv_u[[u, v]]
+ f_a_u * a_uv_v[[u, v]]
+ f_a_uv * a_uv[[u, v]])
* inv_f_a;
a_uv_uv[[u, v]] = val;
a_uv_uv[[v, u]] = val;
}
}
let z_obs = self.z[row];
let u_obs = a + b * z_obs;
let obs = self.observed_denested_cell_partials(row, a, b, beta_h, beta_w)?;
let eta_val = eval_coeff4_at(&obs.coeff, z_obs);
let mut g_u_fixed = vec![[0.0; 4]; r];
let mut g_au_fixed = vec![[0.0; 4]; r];
let mut g_bu_fixed = vec![[0.0; 4]; r];
let mut g_aau_fixed = vec![[0.0; 4]; r];
let mut g_abu_fixed = vec![[0.0; 4]; r];
let mut g_bbu_fixed = vec![[0.0; 4]; r];
let mut g_aaau_fixed = vec![[0.0; 4]; r];
let mut g_aabu_fixed = vec![[0.0; 4]; r];
let mut g_abbu_fixed = vec![[0.0; 4]; r];
let mut g_bbbu_fixed = vec![[0.0; 4]; r];
g_u_fixed[1] = obs.dc_db;
g_au_fixed[1] = obs.dc_dab;
g_bu_fixed[1] = obs.dc_dbb;
g_aau_fixed[1] = obs.dc_daab;
g_abu_fixed[1] = obs.dc_dabb;
g_bbu_fixed[1] = obs.dc_dbbb;
if let (Some(h_range), Some(runtime)) = (h_range, score_runtime) {
Self::for_each_deviation_basis_cubic_at(
runtime,
h_range,
z_obs,
"score-warp fourth-direction observed",
|_, idx, basis_span| {
g_u_fixed[idx] =
scale_coeff4(exact::score_basis_cell_coefficients(basis_span, b), scale);
g_bu_fixed[idx] =
scale_coeff4(exact::score_basis_cell_coefficients(basis_span, 1.0), scale);
Ok(())
},
)?;
}
if let (Some(w_range), Some(runtime)) = (w_range, link_runtime) {
Self::for_each_deviation_basis_cubic_at(
runtime,
w_range,
u_obs,
"link-wiggle fourth-direction observed",
|_, idx, basis_span| {
g_u_fixed[idx] =
scale_coeff4(exact::link_basis_cell_coefficients(basis_span, a, b), scale);
let (dc_aw_raw, dc_bw_raw) =
exact::link_basis_cell_coefficient_partials(basis_span, a, b);
let (dc_aaw_raw, dc_abw_raw, dc_bbw_raw) =
exact::link_basis_cell_second_partials(basis_span, a, b);
let (dc_aaaw, dc_aabw, dc_abbw, dc_bbbw) =
exact::link_basis_cell_third_partials(basis_span);
g_au_fixed[idx] = scale_coeff4(dc_aw_raw, scale);
g_bu_fixed[idx] = scale_coeff4(dc_bw_raw, scale);
g_aau_fixed[idx] = scale_coeff4(dc_aaw_raw, scale);
g_abu_fixed[idx] = scale_coeff4(dc_abw_raw, scale);
g_bbu_fixed[idx] = scale_coeff4(dc_bbw_raw, scale);
g_aaau_fixed[idx] = scale_coeff4(dc_aaaw, scale);
g_aabu_fixed[idx] = scale_coeff4(dc_aabw, scale);
g_abbu_fixed[idx] = scale_coeff4(dc_abbw, scale);
g_bbbu_fixed[idx] = scale_coeff4(dc_bbbw, scale);
Ok(())
},
)?;
}
let g_jet = SparsePrimaryCoeffJetView::new(
1,
h_range,
w_range,
&g_u_fixed,
&g_au_fixed,
&g_bu_fixed,
&g_aau_fixed,
&g_abu_fixed,
&g_bbu_fixed,
&g_aaau_fixed,
&g_aabu_fixed,
&g_abbu_fixed,
&g_bbbu_fixed,
);
let g_a = eval_coeff4_at(&obs.dc_da, z_obs);
let g_aa = eval_coeff4_at(&obs.dc_daa, z_obs);
let g_aaa = eval_coeff4_at(&obs.dc_daaa, z_obs);
let mut g_u = Array1::<f64>::zeros(r);
let mut g_au = Array1::<f64>::zeros(r);
let mut g_aau = Array1::<f64>::zeros(r);
let mut g_aaau = Array1::<f64>::zeros(r);
let mut g_uv = Array2::<f64>::zeros((r, r));
let mut g_auv = Array2::<f64>::zeros((r, r));
let mut g_aauv = Array2::<f64>::zeros((r, r));
for u in 1..r {
g_u[u] = eval_coeff4_at(&g_jet.first[u], z_obs);
g_au[u] = eval_coeff4_at(&g_jet.a_first[u], z_obs);
g_aau[u] = eval_coeff4_at(&g_jet.aa_first[u], z_obs);
g_aaau[u] = eval_coeff4_at(&g_jet.aaa_first[u], z_obs);
}
for u in 1..r {
for v in u..r {
let second_coeff = g_jet.pair_from_b_family(g_jet.b_first, u, v, COEFF_SUPPORT_BHW);
let val = eval_coeff4_at(&second_coeff, z_obs);
g_uv[[u, v]] = val;
g_uv[[v, u]] = val;
let third_coeff = g_jet.pair_from_b_family(g_jet.ab_first, u, v, COEFF_SUPPORT_BW);
let fourth_coeff = g_jet.pair_from_b_family(g_jet.aab_first, u, v, COEFF_SUPPORT_W);
let third_val = eval_coeff4_at(&third_coeff, z_obs);
let fourth_val = eval_coeff4_at(&fourth_coeff, z_obs);
g_auv[[u, v]] = third_val;
g_auv[[v, u]] = third_val;
g_aauv[[u, v]] = fourth_val;
g_aauv[[v, u]] = fourth_val;
}
}
let g_dir_u_fixed = g_jet.directional_family(g_jet.first, dir_u, COEFF_SUPPORT_BHW);
let g_dir_v_fixed = g_jet.directional_family(g_jet.first, dir_v, COEFF_SUPPORT_BHW);
let g_a_dir_u_fixed = g_jet.directional_family(g_jet.a_first, dir_u, COEFF_SUPPORT_BW);
let g_a_dir_v_fixed = g_jet.directional_family(g_jet.a_first, dir_v, COEFF_SUPPORT_BW);
let g_aa_dir_u_fixed = g_jet.directional_family(g_jet.aa_first, dir_u, COEFF_SUPPORT_BW);
let g_aa_dir_v_fixed = g_jet.directional_family(g_jet.aa_first, dir_v, COEFF_SUPPORT_BW);
let g_dir_uv_fixed =
g_jet.mixed_directional_from_b_family(g_jet.b_first, dir_u, dir_v, COEFF_SUPPORT_BHW);
let g_a_dir_uv_fixed =
g_jet.mixed_directional_from_b_family(g_jet.ab_first, dir_u, dir_v, COEFF_SUPPORT_BW);
let g_aa_dir_uv_fixed =
g_jet.mixed_directional_from_b_family(g_jet.aab_first, dir_u, dir_v, COEFF_SUPPORT_W);
let g_dir_u = eval_coeff4_at(&g_dir_u_fixed, z_obs);
let g_dir_v = eval_coeff4_at(&g_dir_v_fixed, z_obs);
let g_dir_uv = eval_coeff4_at(&g_dir_uv_fixed, z_obs);
let g_a_u_fixed = eval_coeff4_at(&g_a_dir_u_fixed, z_obs);
let g_a_v_fixed = eval_coeff4_at(&g_a_dir_v_fixed, z_obs);
let g_aa_u_fixed = eval_coeff4_at(&g_aa_dir_u_fixed, z_obs);
let g_aa_v_fixed = eval_coeff4_at(&g_aa_dir_v_fixed, z_obs);
let g_a_uv_fixed = eval_coeff4_at(&g_a_dir_uv_fixed, z_obs);
let g_aa_uv_fixed = eval_coeff4_at(&g_aa_dir_uv_fixed, z_obs);
let mut g_u_u_fixed = Array1::<f64>::zeros(r);
let mut g_u_v_fixed = Array1::<f64>::zeros(r);
let mut g_u_uv_fixed = Array1::<f64>::zeros(r);
let mut g_au_u_fixed = Array1::<f64>::zeros(r);
let mut g_au_v_fixed = Array1::<f64>::zeros(r);
let mut g_au_uv_fixed = Array1::<f64>::zeros(r);
let mut g_uv_u_fixed = Array2::<f64>::zeros((r, r));
let mut g_uv_v_fixed = Array2::<f64>::zeros((r, r));
let mut g_uv_uv_fixed = Array2::<f64>::zeros((r, r));
let mut g_auv_u_fixed = Array2::<f64>::zeros((r, r));
let mut g_auv_v_fixed = Array2::<f64>::zeros((r, r));
for u in 1..r {
let tmp_u =
g_jet.param_directional_from_b_family(g_jet.b_first, u, dir_u, COEFF_SUPPORT_BHW);
let tmp_v =
g_jet.param_directional_from_b_family(g_jet.b_first, u, dir_v, COEFF_SUPPORT_BHW);
let tmp_uv =
g_jet.param_mixed_from_bb_family(g_jet.bb_first, u, dir_u, dir_v, COEFF_SUPPORT_BW);
let tmp_au_u =
g_jet.param_directional_from_b_family(g_jet.ab_first, u, dir_u, COEFF_SUPPORT_BW);
let tmp_au_v =
g_jet.param_directional_from_b_family(g_jet.ab_first, u, dir_v, COEFF_SUPPORT_BW);
let tmp_au_uv =
g_jet.param_mixed_from_bb_family(g_jet.abb_first, u, dir_u, dir_v, COEFF_SUPPORT_W);
g_u_u_fixed[u] = eval_coeff4_at(&tmp_u, z_obs);
g_u_v_fixed[u] = eval_coeff4_at(&tmp_v, z_obs);
g_u_uv_fixed[u] = eval_coeff4_at(&tmp_uv, z_obs);
g_au_u_fixed[u] = eval_coeff4_at(&tmp_au_u, z_obs);
g_au_v_fixed[u] = eval_coeff4_at(&tmp_au_v, z_obs);
g_au_uv_fixed[u] = eval_coeff4_at(&tmp_au_uv, z_obs);
}
for u in 1..r {
for v in u..r {
let third_u = g_jet.pair_directional_from_bb_family(
g_jet.bb_first,
u,
v,
dir_u,
COEFF_SUPPORT_BW,
);
let third_v = g_jet.pair_directional_from_bb_family(
g_jet.bb_first,
u,
v,
dir_v,
COEFF_SUPPORT_BW,
);
let fourth_uv = g_jet.pair_mixed_from_bbb_family(
g_jet.bbb_first,
u,
v,
dir_u,
dir_v,
COEFF_SUPPORT_W,
);
let a_third_u = g_jet.pair_directional_from_bb_family(
g_jet.abb_first,
u,
v,
dir_u,
COEFF_SUPPORT_W,
);
let a_third_v = g_jet.pair_directional_from_bb_family(
g_jet.abb_first,
u,
v,
dir_v,
COEFF_SUPPORT_W,
);
let vu = eval_coeff4_at(&third_u, z_obs);
let vv = eval_coeff4_at(&third_v, z_obs);
let vuv = eval_coeff4_at(&fourth_uv, z_obs);
g_uv_u_fixed[[u, v]] = vu;
g_uv_v_fixed[[u, v]] = vv;
g_uv_uv_fixed[[u, v]] = vuv;
g_uv_u_fixed[[v, u]] = vu;
g_uv_v_fixed[[v, u]] = vv;
g_uv_uv_fixed[[v, u]] = vuv;
let atu = eval_coeff4_at(&a_third_u, z_obs);
let atv = eval_coeff4_at(&a_third_v, z_obs);
g_auv_u_fixed[[u, v]] = atu;
g_auv_v_fixed[[u, v]] = atv;
g_auv_u_fixed[[v, u]] = atu;
g_auv_v_fixed[[v, u]] = atv;
}
}
let eta_u = g_a * &a_u + &g_u;
let mut eta_uv = Array2::<f64>::zeros((r, r));
for u in 0..r {
for v in u..r {
let val = g_a * a_uv[[u, v]]
+ g_aa * a_u[u] * a_u[v]
+ g_au[u] * a_u[v]
+ g_au[v] * a_u[u]
+ g_uv[[u, v]];
eta_uv[[u, v]] = val;
eta_uv[[v, u]] = val;
}
}
let a_dir_u = a_u.dot(dir_u);
let a_dir_v = a_u.dot(dir_v);
let g_a_u = g_aa * a_dir_u + g_a_u_fixed;
let g_a_v = g_aa * a_dir_v + g_a_v_fixed;
let g_aa_u = g_aaa * a_dir_u + g_aa_u_fixed;
let g_aa_v = g_aaa * a_dir_v + g_aa_v_fixed;
let mut g_u_u = Array1::<f64>::zeros(r);
let mut g_u_v = Array1::<f64>::zeros(r);
let mut g_au_u = Array1::<f64>::zeros(r);
let mut g_au_v = Array1::<f64>::zeros(r);
for u in 0..r {
g_u_u[u] = g_au[u] * a_dir_u + g_u_u_fixed[u];
g_u_v[u] = g_au[u] * a_dir_v + g_u_v_fixed[u];
g_au_u[u] = g_aau[u] * a_dir_u + g_au_u_fixed[u];
g_au_v[u] = g_aau[u] * a_dir_v + g_au_v_fixed[u];
}
let mut eta_uv_u = Array2::<f64>::zeros((r, r));
let mut eta_uv_v = Array2::<f64>::zeros((r, r));
for u in 0..r {
for v in u..r {
let g_uv_u = g_auv[[u, v]] * a_dir_u + g_uv_u_fixed[[u, v]];
let g_uv_v = g_auv[[u, v]] * a_dir_v + g_uv_v_fixed[[u, v]];
let val_u = g_a_u * a_uv[[u, v]]
+ g_a * a_uv_u[[u, v]]
+ g_aa_u * a_u[u] * a_u[v]
+ g_aa * (a_u_dir_u[u] * a_u[v] + a_u[u] * a_u_dir_u[v])
+ g_au_u[u] * a_u[v]
+ g_au[u] * a_u_dir_u[v]
+ g_au_u[v] * a_u[u]
+ g_au[v] * a_u_dir_u[u]
+ g_uv_u;
eta_uv_u[[u, v]] = val_u;
eta_uv_u[[v, u]] = val_u;
let val_v = g_a_v * a_uv[[u, v]]
+ g_a * a_uv_v[[u, v]]
+ g_aa_v * a_u[u] * a_u[v]
+ g_aa * (a_u_dir_v[u] * a_u[v] + a_u[u] * a_u_dir_v[v])
+ g_au_v[u] * a_u[v]
+ g_au[u] * a_u_dir_v[v]
+ g_au_v[v] * a_u[u]
+ g_au[v] * a_u_dir_v[u]
+ g_uv_v;
eta_uv_v[[u, v]] = val_v;
eta_uv_v[[v, u]] = val_v;
}
}
let a_dir_uv = a_u_dir_u.dot(dir_v);
let g_a_uv = g_aaa * a_dir_u * a_dir_v
+ g_aa * a_dir_uv
+ g_aa_u_fixed * a_dir_v
+ g_aa_v_fixed * a_dir_u
+ g_a_uv_fixed;
let g_aa_uv = g_aaau.dot(dir_u) * a_dir_v
+ g_aaau.dot(dir_v) * a_dir_u
+ g_aaa * a_dir_uv
+ g_aa_uv_fixed;
let mut g_u_uv = Array1::<f64>::zeros(r);
let mut g_au_uv = Array1::<f64>::zeros(r);
for u in 0..r {
g_u_uv[u] = g_aau[u] * a_dir_u * a_dir_v
+ g_au[u] * a_dir_uv
+ g_au_u_fixed[u] * a_dir_v
+ g_au_v_fixed[u] * a_dir_u
+ g_u_uv_fixed[u];
let row_u_u = g_aauv.row(u).dot(dir_u);
let row_u_v = g_aauv.row(u).dot(dir_v);
g_au_uv[u] = g_aaau[u] * a_dir_u * a_dir_v
+ g_aau[u] * a_dir_uv
+ row_u_u * a_dir_v
+ row_u_v * a_dir_u
+ g_au_uv_fixed[u];
}
let mut eta_uv_uv = Array2::<f64>::zeros((r, r));
for u in 0..r {
for v in u..r {
let g_uv_uv = g_aauv[[u, v]] * a_dir_u * a_dir_v
+ g_auv[[u, v]] * a_dir_uv
+ g_auv_u_fixed[[u, v]] * a_dir_v
+ g_auv_v_fixed[[u, v]] * a_dir_u
+ g_uv_uv_fixed[[u, v]];
let val = g_a_uv * a_uv[[u, v]]
+ g_a_u * a_uv_v[[u, v]]
+ g_a_v * a_uv_u[[u, v]]
+ g_a * a_uv_uv[[u, v]]
+ g_aa_uv * a_u[u] * a_u[v]
+ g_aa_u * (a_u_dir_v[u] * a_u[v] + a_u[u] * a_u_dir_v[v])
+ g_aa_v * (a_u_dir_u[u] * a_u[v] + a_u[u] * a_u_dir_u[v])
+ g_aa
* (a_u_uv[u] * a_u[v]
+ a_u_dir_u[u] * a_u_dir_v[v]
+ a_u_dir_v[u] * a_u_dir_u[v]
+ a_u[u] * a_u_uv[v])
+ g_au_uv[u] * a_u[v]
+ g_au_u[u] * a_u_dir_v[v]
+ g_au_v[u] * a_u_dir_u[v]
+ g_au[u] * a_u_uv[v]
+ g_au_uv[v] * a_u[u]
+ g_au_u[v] * a_u_dir_v[u]
+ g_au_v[v] * a_u_dir_u[u]
+ g_au[v] * a_u_uv[u]
+ g_uv_uv;
eta_uv_uv[[u, v]] = val;
eta_uv_uv[[v, u]] = val;
}
}
let eta_dir_u = g_a * a_dir_u + g_dir_u;
let eta_dir_v = g_a * a_dir_v + g_dir_v;
let eta_u_dir_u = eta_uv.dot(dir_u);
let eta_u_dir_v = eta_uv.dot(dir_v);
let eta_dir_uv = g_a_v * a_dir_u + g_a_u_fixed * a_dir_v + g_a * a_dir_uv + g_dir_uv;
let eta_u_uv = eta_uv_u.dot(dir_v);
let y_i = self.y[row];
let w_i = self.weights[row];
let s_y = 2.0 * y_i - 1.0;
let m = s_y * eta_val;
let (k1, k2, k3, k4) = signed_probit_neglog_derivatives_up_to_fourth(m, w_i)?;
let u1 = s_y * k1;
let u3 = s_y * k3;
let mut out = Array2::<f64>::zeros((r, r));
for u in 0..r {
for v in u..r {
let a_term = eta_u[u] * eta_u[v] * eta_dir_u;
let a_term_v = eta_u_dir_v[u] * eta_u[v] * eta_dir_u
+ eta_u[u] * eta_u_dir_v[v] * eta_dir_u
+ eta_u[u] * eta_u[v] * eta_dir_uv;
let b_term = eta_uv_u[[u, v]];
let b_term_v = eta_uv_uv[[u, v]];
let c_term = eta_uv[[u, v]] * eta_dir_u
+ eta_u[u] * eta_u_dir_u[v]
+ eta_u[v] * eta_u_dir_u[u];
let c_term_v = eta_uv_v[[u, v]] * eta_dir_u
+ eta_uv[[u, v]] * eta_dir_uv
+ eta_u_dir_v[u] * eta_u_dir_u[v]
+ eta_u[u] * eta_u_uv[v]
+ eta_u_dir_v[v] * eta_u_dir_u[u]
+ eta_u[v] * eta_u_uv[u];
let val = k4 * eta_dir_v * a_term
+ u3 * a_term_v
+ u3 * eta_dir_v * c_term
+ k2 * c_term_v
+ k2 * eta_dir_v * b_term
+ u1 * b_term_v;
out[[u, v]] = val;
out[[v, u]] = val;
}
}
Ok(out)
}
/// Fourth-order directional contraction of the row's primary objective,
/// symmetrized over the `(dir_u, dir_v)` direction pair.
///
/// The ordered evaluation is returned directly on the rigid path; when the
/// flexible components are active, the `(dir_u, dir_v)` and `(dir_v, dir_u)`
/// evaluations are averaged entry-wise to restore symmetry in the two
/// contraction directions.
fn row_primary_fourth_contracted_recompute(
    &self,
    row: usize,
    block_states: &[ParameterBlockState],
    cache: &BernoulliMarginalSlopeExactEvalCache,
    row_ctx: &BernoulliMarginalSlopeRowExactContext,
    dir_u: &Array1<f64>,
    dir_v: &Array1<f64>,
) -> Result<Array2<f64>, String> {
    let forward = self.row_primary_fourth_contracted_recompute_ordered(
        row,
        block_states,
        cache,
        row_ctx,
        dir_u,
        dir_v,
    )?;
    if !self.effective_flex_active(block_states)? {
        // Rigid path: the ordered evaluation is already symmetric.
        return Ok(forward);
    }
    let reversed = self.row_primary_fourth_contracted_recompute_ordered(
        row,
        block_states,
        cache,
        row_ctx,
        dir_v,
        dir_u,
    )?;
    // Entry-wise average of the two direction orderings: 0.5 * (fwd + rev).
    let mut symmetric = forward;
    for (dst, src) in symmetric.iter_mut().zip(reversed.iter()) {
        *dst = 0.5 * (*dst + *src);
    }
    Ok(symmetric)
}
/// Scatters one row's primary-Hessian entries that involve the optional
/// h and w coefficient blocks into the global dense Hessian `target`.
///
/// Only h/w-related entries are written here: cross terms between the
/// marginal predictor (row 0 of `primary_hessian`) / log-slope predictor
/// (row 1) and each h/w coefficient are added as rank-1 design-row updates
/// (both the column and its symmetric row), and the h-h / h-w / w-h / w-w
/// sub-blocks are copied additively. The marginal/log-slope Gram part of
/// the Hessian is accumulated elsewhere.
///
/// `slices` gives each block's global column range in `target`; `primary`
/// gives each block's local range inside `primary_hessian`.
fn add_pullback_primary_hessian_hw_only(
&self,
target: &mut Array2<f64>,
row: usize,
slices: &BlockSlices,
primary: &PrimarySlices,
primary_hessian: ArrayView2<'_, f64>,
) {
let h = primary_hessian;
// --- h block: cross terms with the two linear predictors, then h-h. ---
if let (Some(primary_h), Some(block_h)) = (primary.h.as_ref(), slices.h.as_ref()) {
for (local_idx, global_idx) in block_h.clone().enumerate() {
// Marginal-predictor x h-coefficient cross term: add
// h_q * x_marginal[row] into column `global_idx` and its
// symmetric row. Skipped when the entry is exactly zero.
let h_q = h[[0, primary_h.start + local_idx]];
if h_q != 0.0 {
{
let mut col = target.slice_mut(s![slices.marginal.clone(), global_idx]);
self.marginal_design
.axpy_row_into(row, h_q, &mut col)
.expect("marginal axpy column mismatch");
}
{
let mut row_view =
target.slice_mut(s![global_idx, slices.marginal.clone()]);
self.marginal_design
.axpy_row_into(row, h_q, &mut row_view)
.expect("marginal axpy row mismatch");
}
}
// Log-slope-predictor x h-coefficient cross term (symmetric pair).
let h_g = h[[1, primary_h.start + local_idx]];
if h_g != 0.0 {
{
let mut col = target.slice_mut(s![slices.logslope.clone(), global_idx]);
self.logslope_design
.axpy_row_into(row, h_g, &mut col)
.expect("logslope axpy column mismatch");
}
{
let mut row_view =
target.slice_mut(s![global_idx, slices.logslope.clone()]);
self.logslope_design
.axpy_row_into(row, h_g, &mut row_view)
.expect("logslope axpy row mismatch");
}
}
}
// h-h diagonal sub-block, copied additively from the row Hessian.
target
.slice_mut(s![block_h.clone(), block_h.clone()])
.scaled_add(
1.0,
&h.slice(s![
primary_h.start..primary_h.end,
primary_h.start..primary_h.end
]),
);
}
// --- w block: same pattern as the h block above. ---
if let (Some(primary_w), Some(block_w)) = (primary.w.as_ref(), slices.w.as_ref()) {
for (local_idx, global_idx) in block_w.clone().enumerate() {
// Marginal-predictor x w-coefficient cross term.
let w_q = h[[0, primary_w.start + local_idx]];
if w_q != 0.0 {
{
let mut col = target.slice_mut(s![slices.marginal.clone(), global_idx]);
self.marginal_design
.axpy_row_into(row, w_q, &mut col)
.expect("marginal axpy column mismatch");
}
{
let mut row_view =
target.slice_mut(s![global_idx, slices.marginal.clone()]);
self.marginal_design
.axpy_row_into(row, w_q, &mut row_view)
.expect("marginal axpy row mismatch");
}
}
// Log-slope-predictor x w-coefficient cross term.
let w_g = h[[1, primary_w.start + local_idx]];
if w_g != 0.0 {
{
let mut col = target.slice_mut(s![slices.logslope.clone(), global_idx]);
self.logslope_design
.axpy_row_into(row, w_g, &mut col)
.expect("logslope axpy column mismatch");
}
{
let mut row_view =
target.slice_mut(s![global_idx, slices.logslope.clone()]);
self.logslope_design
.axpy_row_into(row, w_g, &mut row_view)
.expect("logslope axpy row mismatch");
}
}
}
// Off-diagonal h-w and w-h sub-blocks (only when both blocks exist).
if let (Some(primary_h), Some(block_h)) = (primary.h.as_ref(), slices.h.as_ref()) {
target
.slice_mut(s![block_h.clone(), block_w.clone()])
.scaled_add(
1.0,
&h.slice(s![
primary_h.start..primary_h.end,
primary_w.start..primary_w.end
]),
);
target
.slice_mut(s![block_w.clone(), block_h.clone()])
.scaled_add(
1.0,
&h.slice(s![
primary_w.start..primary_w.end,
primary_h.start..primary_h.end
]),
);
}
// w-w diagonal sub-block.
target
.slice_mut(s![block_w.clone(), block_w.clone()])
.scaled_add(
1.0,
&h.slice(s![
primary_w.start..primary_w.end,
primary_w.start..primary_w.end
]),
);
}
}
/// Builds the dense exact joint Hessian from the evaluation cache.
///
/// Rows are processed in parallel chunks of `ROW_CHUNK_SIZE`. Each row
/// contributes its 2x2 predictor weights (w_mm, w_mg, w_gg) — folded into
/// weighted design Gram matrices once per chunk — and, when the
/// accumulator carries a dense-correction buffer, the h/w cross terms via
/// `add_pullback_primary_hessian_hw_only`. Per-chunk accumulators are
/// merged with `BernoulliBlockHessianAccumulator::add` and densified at
/// the end.
fn exact_newton_joint_hessian_dense_from_cache(
&self,
block_states: &[ParameterBlockState],
cache: &BernoulliMarginalSlopeExactEvalCache,
) -> Result<Array2<f64>, String> {
let slices = &cache.slices;
let primary = &cache.primary;
let n = self.y.len();
let started = std::time::Instant::now();
if log_exact_work(n) {
log::info!(
"[BMS dense-H] build start n={} p={} source=cache",
n,
slices.total
);
}
let acc = (0..n.div_ceil(ROW_CHUNK_SIZE))
.into_par_iter()
.try_fold(
|| BernoulliBlockHessianAccumulator::new(slices),
|mut acc, chunk_idx| -> Result<_, String> {
let start = chunk_idx * ROW_CHUNK_SIZE;
let end = (start + ROW_CHUNK_SIZE).min(n);
let chunk_len = end - start;
// Per-row 2x2 Hessian weights gathered for this chunk.
let mut w_mm = Array1::<f64>::zeros(chunk_len);
let mut w_mg = Array1::<f64>::zeros(chunk_len);
let mut w_gg = Array1::<f64>::zeros(chunk_len);
let mut scratch = BernoulliMarginalSlopeFlexRowScratch::new(primary.total);
for (local, row) in (start..end).enumerate() {
// Prefer the cached per-row primary Hessian; otherwise
// recompute it (with second-order terms, hence `true`)
// into the reusable scratch buffer.
let hess_view =
if let Some(cached) = Self::cached_row_primary_hessian(cache, row) {
cached
} else {
let row_ctx = Self::row_ctx(cache, row);
let row_moments = cache
.row_cell_moments
.as_ref()
.and_then(|bundle| bundle.row(row, 9));
self.compute_row_analytic_flex_into_with_moments(
row,
block_states,
primary,
row_ctx,
row_moments,
true,
&mut scratch,
)?;
scratch.hess.view()
};
w_mm[local] = hess_view[[0, 0]];
w_mg[local] = hess_view[[0, 1]];
w_gg[local] = hess_view[[1, 1]];
// h/w cross terms go straight into the dense correction
// buffer when the accumulator carries one.
if let Some(ref mut dc) = acc.dense_correction {
self.add_pullback_primary_hessian_hw_only(
dc, row, slices, primary, hess_view,
);
}
}
acc.add_weighted_design_grams(self, start..end, &w_mm, &w_mg, &w_gg)?;
Ok(acc)
},
)
.try_reduce(
|| BernoulliBlockHessianAccumulator::new(slices),
|mut left, right| -> Result<_, String> {
left.add(&right);
Ok(left)
},
)?;
let dense = acc.to_dense(slices);
if log_exact_work(n) {
log::info!(
"[BMS dense-H] build done n={} p={} source=cache elapsed={:.3}s",
n,
slices.total,
started.elapsed().as_secs_f64()
);
}
Ok(dense)
}
/// Weighted log-likelihood evaluated through the exact cache.
///
/// When the flexible components are inactive this delegates to the
/// standard blockwise likelihood. Otherwise it sums, in parallel over all
/// rows, `weights[row] * log Phi(s_y * s_i)` where `s_i` is the observed
/// denested cell surface evaluated at `z[row]` and `s_y = 2*y - 1` signs
/// the probit argument.
fn log_likelihood_from_exact_cache(
    &self,
    block_states: &[ParameterBlockState],
    cache: &BernoulliMarginalSlopeExactEvalCache,
) -> Result<f64, String> {
    // Rigid path: the exact cache adds nothing, use the blockwise evaluator.
    if !self.effective_flex_active(block_states)? {
        return self
            .log_likelihood_only_with_options(block_states, &BlockwiseFitOptions::default());
    }
    let n_rows = self.y.len();
    let timer = std::time::Instant::now();
    if log_exact_work(n_rows) {
        log::info!(
            "[BMS exact-loglik] eval start n={} p={} source=cache",
            n_rows,
            cache.slices.total
        );
    }
    let beta_h = self.score_beta(block_states)?;
    let beta_w = self.link_beta(block_states)?;
    let log_likelihood = (0..n_rows)
        .into_par_iter()
        .try_fold(
            || 0.0,
            |acc, row| -> Result<_, String> {
                let row_ctx = Self::row_ctx(cache, row);
                let slope = block_states[1].eta[row];
                let obs = self.observed_denested_cell_partials(
                    row,
                    row_ctx.intercept,
                    slope,
                    beta_h,
                    beta_w,
                )?;
                // Surface value at the observed z, signed by the ±1 response.
                let surface = eval_coeff4_at(&obs.coeff, self.z[row]);
                let signed = (2.0 * self.y[row] - 1.0) * surface;
                let (log_cdf, _) = signed_probit_logcdf_and_mills_ratio(signed);
                Ok(acc + self.weights[row] * log_cdf)
            },
        )
        .try_reduce(
            || 0.0,
            |left, right| -> Result<_, String> { Ok(left + right) },
        )?;
    if log_exact_work(n_rows) {
        log::info!(
            "[BMS exact-loglik] eval done n={} p={} source=cache elapsed={:.3}s",
            n_rows,
            cache.slices.total,
            timer.elapsed().as_secs_f64()
        );
    }
    Ok(log_likelihood)
}
/// Evaluates the exact joint log-likelihood and its gradient from the cache.
///
/// Parallel chunked accumulation over rows. The accumulator tuple holds:
/// .0 = log-likelihood (negated per-row objective), .1 = marginal-block
/// gradient, .2 = log-slope-block gradient, .3/.4 = optional h/w-block
/// gradients (present only when those blocks exist). Per-row objective
/// gradients are mapped to score components and scattered through the
/// design rows; the pieces are finally assembled into one flat gradient.
fn exact_newton_joint_gradient_evaluation_from_cache(
&self,
block_states: &[ParameterBlockState],
cache: &BernoulliMarginalSlopeExactEvalCache,
) -> Result<ExactNewtonJointGradientEvaluation, String> {
let slices = &cache.slices;
let primary = &cache.primary;
let n = self.y.len();
let started = std::time::Instant::now();
if log_exact_work(n) {
log::info!(
"[BMS exact-gradient] eval start n={} p={} source=cache",
n,
slices.total
);
}
// Fresh accumulator tuple: (loglik, grad_marginal, grad_logslope,
// Option<grad_h>, Option<grad_w>).
let make_acc = || {
(
0.0_f64,
Array1::<f64>::zeros(slices.marginal.len()),
Array1::<f64>::zeros(slices.logslope.len()),
slices
.h
.as_ref()
.map(|range| Array1::<f64>::zeros(range.len())),
slices
.w
.as_ref()
.map(|range| Array1::<f64>::zeros(range.len())),
)
};
let (log_likelihood, grad_marginal, grad_logslope, grad_h, grad_w) = (0..n
.div_ceil(ROW_CHUNK_SIZE))
.into_par_iter()
.try_fold(make_acc, |mut acc, chunk_idx| -> Result<_, String> {
let start = chunk_idx * ROW_CHUNK_SIZE;
let end = (start + ROW_CHUNK_SIZE).min(n);
let mut scratch = BernoulliMarginalSlopeFlexRowScratch::new(primary.total);
for row in start..end {
let row_ctx = Self::row_ctx(cache, row);
let row_moments = cache
.row_cell_moments
.as_ref()
.and_then(|bundle| bundle.row(row, 3));
// Gradient-only evaluation (`false`): fills scratch.grad and
// returns the row's negative log-likelihood contribution.
let neglog = self.compute_row_analytic_flex_into_with_moments(
row,
block_states,
primary,
row_ctx,
row_moments,
false,
&mut scratch,
)?;
acc.0 -= neglog;
{
// Marginal predictor gradient component scattered via the
// marginal design row.
let mut marginal = acc.1.view_mut();
self.marginal_design.axpy_row_into(
row,
Self::exact_newton_score_component_from_objective_gradient(
scratch.grad[0],
),
&mut marginal,
)?;
}
{
// Log-slope predictor component via the log-slope design row.
let mut logslope = acc.2.view_mut();
self.logslope_design.axpy_row_into(
row,
Self::exact_newton_score_component_from_objective_gradient(
scratch.grad[1],
),
&mut logslope,
)?;
}
// h/w coefficient gradients accumulate directly (no design
// pullback) from their ranges of the row objective gradient.
if let (Some(primary_h), Some(grad_h)) = (primary.h.as_ref(), acc.3.as_mut()) {
for idx in 0..primary_h.len() {
grad_h[idx] +=
Self::exact_newton_score_component_from_objective_gradient(
scratch.grad[primary_h.start + idx],
);
}
}
if let (Some(primary_w), Some(grad_w)) = (primary.w.as_ref(), acc.4.as_mut()) {
for idx in 0..primary_w.len() {
grad_w[idx] +=
Self::exact_newton_score_component_from_objective_gradient(
scratch.grad[primary_w.start + idx],
);
}
}
}
Ok(acc)
})
.try_reduce(make_acc, |mut left, right| -> Result<_, String> {
left.0 += right.0;
left.1 += &right.1;
left.2 += &right.2;
if let (Some(lhs), Some(rhs)) = (left.3.as_mut(), right.3.as_ref()) {
*lhs += rhs;
}
if let (Some(lhs), Some(rhs)) = (left.4.as_mut(), right.4.as_ref()) {
*lhs += rhs;
}
Ok(left)
})?;
// Assemble the flat joint gradient from the per-block pieces.
let mut gradient = Array1::<f64>::zeros(slices.total);
gradient
.slice_mut(s![slices.marginal.clone()])
.assign(&grad_marginal);
gradient
.slice_mut(s![slices.logslope.clone()])
.assign(&grad_logslope);
if let (Some(range), Some(grad_h)) = (slices.h.as_ref(), grad_h.as_ref()) {
gradient.slice_mut(s![range.clone()]).assign(grad_h);
}
if let (Some(range), Some(grad_w)) = (slices.w.as_ref(), grad_w.as_ref()) {
gradient.slice_mut(s![range.clone()]).assign(grad_w);
}
if log_exact_work(n) {
log::info!(
"[BMS exact-gradient] eval done n={} p={} source=cache elapsed={:.3}s",
n,
slices.total,
started.elapsed().as_secs_f64()
);
}
Ok(ExactNewtonJointGradientEvaluation {
log_likelihood,
gradient,
})
}
/// Exact joint Hessian-vector product `H * direction`, accumulated
/// row-by-row from the evaluation cache in parallel chunks.
///
/// Rigid path (flex components inactive): only the 2x2 marginal/log-slope
/// row kernel contributes, so each row projects `direction` onto its two
/// design rows, applies the 2x2 Hessian, and scatters back. Flex path:
/// each row's dense primary Hessian (cached when available, otherwise
/// recomputed into scratch) is applied to the row-local direction and
/// pulled back to the flat parameter vector.
///
/// Note: chunk count uses `n.div_ceil(ROW_CHUNK_SIZE)`, consistent with
/// the other cache-based evaluators (previously a manual ceil-division).
fn exact_newton_joint_hessian_matvec_from_cache(
    &self,
    direction: &Array1<f64>,
    block_states: &[ParameterBlockState],
    cache: &BernoulliMarginalSlopeExactEvalCache,
) -> Result<Array1<f64>, String> {
    let slices = &cache.slices;
    let primary = &cache.primary;
    let n = self.y.len();
    if !self.effective_flex_active(block_states)? {
        let out = (0..n.div_ceil(ROW_CHUNK_SIZE))
            .into_par_iter()
            .try_fold(
                || Array1::<f64>::zeros(slices.total),
                |mut chunk_out, chunk_idx| -> Result<_, String> {
                    let start = chunk_idx * ROW_CHUNK_SIZE;
                    let end = (start + ROW_CHUNK_SIZE).min(n);
                    for row in start..end {
                        let marginal_eta = block_states[0].eta[row];
                        let marginal = self.marginal_link_map(marginal_eta)?;
                        let g = block_states[1].eta[row];
                        let (_, _, h) =
                            self.rigid_row_kernel_eval(row, marginal_eta, marginal, g)?;
                        // Project the direction onto this row of each design.
                        let v_q = self
                            .marginal_design
                            .dot_row_view(row, direction.slice(s![slices.marginal.clone()]));
                        let v_g = self
                            .logslope_design
                            .dot_row_view(row, direction.slice(s![slices.logslope.clone()]));
                        // Apply the 2x2 row-kernel Hessian, then scatter back.
                        let a_q = h[0][0] * v_q + h[0][1] * v_g;
                        let a_g = h[1][0] * v_q + h[1][1] * v_g;
                        {
                            let mut m = chunk_out.slice_mut(s![slices.marginal.clone()]);
                            self.marginal_design.axpy_row_into(row, a_q, &mut m)?;
                        }
                        {
                            let mut l = chunk_out.slice_mut(s![slices.logslope.clone()]);
                            self.logslope_design.axpy_row_into(row, a_g, &mut l)?;
                        }
                    }
                    Ok(chunk_out)
                },
            )
            .try_reduce(
                || Array1::<f64>::zeros(slices.total),
                |mut left, right| -> Result<_, String> {
                    left += &right;
                    Ok(left)
                },
            )?;
        return Ok(out);
    }
    let out = (0..n.div_ceil(ROW_CHUNK_SIZE))
        .into_par_iter()
        .try_fold(
            || Array1::<f64>::zeros(slices.total),
            |mut chunk_out, chunk_idx| -> Result<_, String> {
                let start = chunk_idx * ROW_CHUNK_SIZE;
                let end = (start + ROW_CHUNK_SIZE).min(n);
                let mut scratch = BernoulliMarginalSlopeFlexRowScratch::new(primary.total);
                for row in start..end {
                    let row_ctx = Self::row_ctx(cache, row);
                    let row_dir =
                        self.row_primary_direction_from_flat(row, slices, primary, direction)?;
                    let row_action =
                        if let Some(row_hess) = Self::cached_row_primary_hessian(cache, row) {
                            row_hess.dot(&row_dir)
                        } else {
                            // Cache miss: recompute the row Hessian (second
                            // order, hence `true`) into the scratch buffer.
                            let row_moments = cache
                                .row_cell_moments
                                .as_ref()
                                .and_then(|bundle| bundle.row(row, 9));
                            self.compute_row_analytic_flex_into_with_moments(
                                row,
                                block_states,
                                primary,
                                row_ctx,
                                row_moments,
                                true,
                                &mut scratch,
                            )?;
                            scratch.hess.dot(&row_dir)
                        };
                    chunk_out +=
                        &self.pullback_primary_vector(row, slices, primary, &row_action)?;
                }
                Ok(chunk_out)
            },
        )
        .try_reduce(
            || Array1::<f64>::zeros(slices.total),
            |mut left, right| -> Result<_, String> {
                left += &right;
                Ok(left)
            },
        )?;
    Ok(out)
}
#[cfg(test)]
/// Serial mirror of `exact_newton_joint_hessian_matvec_from_cache`, used by
/// tests to validate the chunked parallel accumulation against a simple
/// row-by-row sweep. Same rigid/flex branching, no chunking.
fn exact_newton_joint_hessian_matvec_from_cache_serial_reference(
    &self,
    direction: &Array1<f64>,
    block_states: &[ParameterBlockState],
    cache: &BernoulliMarginalSlopeExactEvalCache,
) -> Result<Array1<f64>, String> {
    let slices = &cache.slices;
    let primary = &cache.primary;
    let mut accum = Array1::<f64>::zeros(slices.total);
    if !self.effective_flex_active(block_states)? {
        // Rigid 2x2 row-kernel path.
        for row in 0..self.y.len() {
            let marginal_eta = block_states[0].eta[row];
            let marginal = self.marginal_link_map(marginal_eta)?;
            let g = block_states[1].eta[row];
            let (_, _, h) = self.rigid_row_kernel_eval(row, marginal_eta, marginal, g)?;
            // Project the direction onto this row of each design block.
            let proj_q = self
                .marginal_design
                .dot_row_view(row, direction.slice(s![slices.marginal.clone()]));
            let proj_g = self
                .logslope_design
                .dot_row_view(row, direction.slice(s![slices.logslope.clone()]));
            // Apply the 2x2 Hessian and scatter back through the designs.
            let act_q = h[0][0] * proj_q + h[0][1] * proj_g;
            let act_g = h[1][0] * proj_q + h[1][1] * proj_g;
            {
                let mut view = accum.slice_mut(s![slices.marginal.clone()]);
                self.marginal_design.axpy_row_into(row, act_q, &mut view)?;
            }
            {
                let mut view = accum.slice_mut(s![slices.logslope.clone()]);
                self.logslope_design.axpy_row_into(row, act_g, &mut view)?;
            }
        }
        return Ok(accum);
    }
    // Flex path: dense per-row primary Hessian applied to the row-local
    // direction, then pulled back to the flat parameter vector.
    let mut scratch = BernoulliMarginalSlopeFlexRowScratch::new(primary.total);
    for row in 0..self.y.len() {
        let row_ctx = Self::row_ctx(cache, row);
        let local_dir = self.row_primary_direction_from_flat(row, slices, primary, direction)?;
        let action = match Self::cached_row_primary_hessian(cache, row) {
            Some(row_hess) => row_hess.dot(&local_dir),
            None => {
                // Cache miss: recompute the row Hessian into scratch.
                let row_moments = cache
                    .row_cell_moments
                    .as_ref()
                    .and_then(|bundle| bundle.row(row, 9));
                self.compute_row_analytic_flex_into_with_moments(
                    row,
                    block_states,
                    primary,
                    row_ctx,
                    row_moments,
                    true,
                    &mut scratch,
                )?;
                scratch.hess.dot(&local_dir)
            }
        };
        accum += &self.pullback_primary_vector(row, slices, primary, &action)?;
    }
    Ok(accum)
}
/// Diagonal of the exact joint Hessian, accumulated from the evaluation
/// cache in parallel chunks.
///
/// Rigid path: squared marginal/log-slope design rows weighted by the
/// (0,0) and (1,1) entries of the 2x2 row-kernel Hessian. Flex path: the
/// marginal/log-slope diagonals use the same squared-design weighting from
/// the row's primary Hessian, while the h/w coefficient diagonals are
/// copied directly from that Hessian's diagonal.
///
/// Changes vs. the previous version: chunk count uses
/// `n.div_ceil(ROW_CHUNK_SIZE)` (consistent with the other cache-based
/// evaluators, replacing a manual ceil-division), and the cached-vs-scratch
/// Hessian is resolved into a single view once per row instead of
/// re-matching the `Option` for every diagonal entry.
fn exact_newton_joint_hessian_diagonal_from_cache(
    &self,
    block_states: &[ParameterBlockState],
    cache: &BernoulliMarginalSlopeExactEvalCache,
) -> Result<Array1<f64>, String> {
    let slices = &cache.slices;
    let primary = &cache.primary;
    let n = self.y.len();
    if !self.effective_flex_active(block_states)? {
        let diagonal = (0..n.div_ceil(ROW_CHUNK_SIZE))
            .into_par_iter()
            .try_fold(
                || Array1::<f64>::zeros(slices.total),
                |mut chunk_diag, chunk_idx| -> Result<_, String> {
                    let start = chunk_idx * ROW_CHUNK_SIZE;
                    let end = (start + ROW_CHUNK_SIZE).min(n);
                    for row in start..end {
                        let marginal_eta = block_states[0].eta[row];
                        let marginal = self.marginal_link_map(marginal_eta)?;
                        let g = block_states[1].eta[row];
                        let (_, _, h) =
                            self.rigid_row_kernel_eval(row, marginal_eta, marginal, g)?;
                        {
                            let mut m = chunk_diag.slice_mut(s![slices.marginal.clone()]);
                            self.marginal_design
                                .squared_axpy_row_into(row, h[0][0], &mut m)?;
                        }
                        {
                            let mut l = chunk_diag.slice_mut(s![slices.logslope.clone()]);
                            self.logslope_design
                                .squared_axpy_row_into(row, h[1][1], &mut l)?;
                        }
                    }
                    Ok(chunk_diag)
                },
            )
            .try_reduce(
                || Array1::<f64>::zeros(slices.total),
                |mut left, right| -> Result<_, String> {
                    left += &right;
                    Ok(left)
                },
            )?;
        return Ok(diagonal);
    }
    let diagonal = (0..n.div_ceil(ROW_CHUNK_SIZE))
        .into_par_iter()
        .try_fold(
            || Array1::<f64>::zeros(slices.total),
            |mut chunk_diag, chunk_idx| -> Result<_, String> {
                let start = chunk_idx * ROW_CHUNK_SIZE;
                let end = (start + ROW_CHUNK_SIZE).min(n);
                let mut scratch = BernoulliMarginalSlopeFlexRowScratch::new(primary.total);
                for row in start..end {
                    let row_ctx = Self::row_ctx(cache, row);
                    let cached_hess = Self::cached_row_primary_hessian(cache, row);
                    if cached_hess.is_none() {
                        // Cache miss: recompute the row Hessian (second
                        // order, hence `true`) into the scratch buffer.
                        let row_moments = cache
                            .row_cell_moments
                            .as_ref()
                            .and_then(|bundle| bundle.row(row, 9));
                        self.compute_row_analytic_flex_into_with_moments(
                            row,
                            block_states,
                            primary,
                            row_ctx,
                            row_moments,
                            true,
                            &mut scratch,
                        )?;
                    }
                    // Resolve the authoritative Hessian view for this row
                    // exactly once; all reads below go through it.
                    let hess = match cached_hess {
                        Some(row_hess) => row_hess,
                        None => scratch.hess.view(),
                    };
                    {
                        let mut marginal_diag =
                            chunk_diag.slice_mut(s![slices.marginal.clone()]);
                        self.marginal_design.squared_axpy_row_into(
                            row,
                            hess[[0, 0]],
                            &mut marginal_diag,
                        )?;
                    }
                    {
                        let mut logslope_diag =
                            chunk_diag.slice_mut(s![slices.logslope.clone()]);
                        self.logslope_design.squared_axpy_row_into(
                            row,
                            hess[[1, 1]],
                            &mut logslope_diag,
                        )?;
                    }
                    // h/w coefficient diagonals copied straight from the
                    // row Hessian diagonal.
                    if let (Some(primary_h), Some(block_h)) =
                        (primary.h.as_ref(), slices.h.as_ref())
                    {
                        for (local_idx, global_idx) in block_h.clone().enumerate() {
                            let ii = primary_h.start + local_idx;
                            chunk_diag[global_idx] += hess[[ii, ii]];
                        }
                    }
                    if let (Some(primary_w), Some(block_w)) =
                        (primary.w.as_ref(), slices.w.as_ref())
                    {
                        for (local_idx, global_idx) in block_w.clone().enumerate() {
                            let ii = primary_w.start + local_idx;
                            chunk_diag[global_idx] += hess[[ii, ii]];
                        }
                    }
                }
                Ok(chunk_diag)
            },
        )
        .try_reduce(
            || Array1::<f64>::zeros(slices.total),
            |mut left, right| -> Result<_, String> {
                left += &right;
                Ok(left)
            },
        )?;
    Ok(diagonal)
}
/// Convenience wrapper around
/// `exact_newton_joint_psi_terms_from_cache_with_options` using the
/// default `BlockwiseFitOptions` (full row set, unit outer score scale).
fn exact_newton_joint_psi_terms_from_cache(
&self,
block_states: &[ParameterBlockState],
derivative_blocks: &[Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>],
psi_index: usize,
cache: &BernoulliMarginalSlopeExactEvalCache,
) -> Result<Option<ExactNewtonJointPsiTerms>, String> {
self.exact_newton_joint_psi_terms_from_cache_with_options(
block_states,
derivative_blocks,
psi_index,
cache,
&BlockwiseFitOptions::default(),
)
}
/// First-order psi (smoothing-parameter) sensitivity terms of the exact
/// joint objective, computed from the evaluation cache.
///
/// Returns `Ok(None)` when `psi_index` does not map to a derivative block.
/// Only the marginal (block 0) and log-slope (block 1) blocks are
/// supported; other block indices are an error. Per row, the psi direction
/// is pushed through the primary gradient/Hessian and the contracted third
/// derivative, accumulating (in parallel over the rows selected by
/// `options`): the objective-psi scalar, the score-psi vector, and a
/// Hessian-psi operator assembled from rank-1 cross terms and pullbacks.
/// All three are rescaled by the outer score scale when subsampling.
pub(crate) fn exact_newton_joint_psi_terms_from_cache_with_options(
&self,
block_states: &[ParameterBlockState],
derivative_blocks: &[Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>],
psi_index: usize,
cache: &BernoulliMarginalSlopeExactEvalCache,
options: &BlockwiseFitOptions,
) -> Result<Option<ExactNewtonJointPsiTerms>, String> {
let slices = &cache.slices;
let primary = &cache.primary;
// Locate which block/local derivative this flat psi index refers to.
let Some((block_idx, local_idx)) = psi_derivative_location(derivative_blocks, psi_index)
else {
return Ok(None);
};
// Primary-coordinate index of the affected predictor: 0 = marginal,
// 1 = log-slope.
let idx_primary = if block_idx == 0 { 0 } else { 1 };
let n = self.y.len();
let deriv = &derivative_blocks[block_idx][local_idx];
let (p_psi, psi_label) = match block_idx {
0 => (
self.marginal_design.ncols(),
"BernoulliMarginalSlopeFamily marginal",
),
1 => (
self.logslope_design.ncols(),
"BernoulliMarginalSlopeFamily log-slope",
),
_ => {
return Err(format!(
"bernoulli marginal-slope psi terms only support marginal/logslope blocks, got block {block_idx}"
));
}
};
let psi_map = crate::families::custom_family::resolve_custom_family_x_psi_map(
deriv,
n,
p_psi,
0..n,
psi_label,
&self.policy,
)?;
// Row subset and compensating score scale from the fit options.
let row_iter = outer_row_indices(options, n).to_vec();
let outer_scale = outer_score_scale(options, n);
let (mut objective_psi, mut score_psi, mut block_acc) = row_iter
.into_par_iter()
.try_fold(
|| {
(
0.0f64,
Array1::<f64>::zeros(slices.total),
BernoulliBlockHessianAccumulator::new(slices),
)
},
|mut acc, row| -> Result<_, String> {
// Row-local psi direction in primary coordinates.
let dir = self.row_primary_psi_direction_from_map(
row,
block_idx,
&psi_map,
block_states,
primary,
)?;
let row_ctx = Self::row_ctx(cache, row);
let (_, f_pi, f_pipi) = self.compute_row_primary_gradient_hessian(
row,
block_states,
primary,
row_ctx,
)?;
// Third derivative contracted along the psi direction.
let third = self.row_primary_third_contracted_recompute(
row,
block_states,
cache,
row_ctx,
&dir,
)?;
let psi_row = self.block_psi_row_from_map(row, block_idx, &psi_map, slices)?;
// Objective sensitivity: gradient dotted with psi direction.
acc.0 += f_pi.dot(&dir);
// Score sensitivity: direct psi-row term plus Hessian pullback.
acc.1
.slice_mut(s![psi_row.range.clone()])
.scaled_add(f_pi[idx_primary], &psi_row.local_vec);
acc.1 +=
&self.pullback_primary_vector(row, slices, primary, &f_pipi.dot(&dir))?;
// Hessian sensitivity: rank-1 psi cross term plus third-order
// pullback.
let right_primary = f_pipi.row(idx_primary).to_owned();
acc.2.add_rank1_psi_cross(
self,
row,
slices,
primary,
block_idx,
&psi_row.local_vec,
&right_primary,
);
acc.2.add_pullback(self, row, slices, primary, &third);
Ok(acc)
},
)
.try_reduce(
|| {
(
0.0f64,
Array1::<f64>::zeros(slices.total),
BernoulliBlockHessianAccumulator::new(slices),
)
},
|mut left, right| -> Result<_, String> {
left.0 += right.0;
left.1 += &right.1;
left.2.add(&right.2);
Ok(left)
},
)?;
// Compensate for row subsampling, if any.
if outer_scale != 1.0 {
objective_psi *= outer_scale;
score_psi.mapv_inplace(|v| v * outer_scale);
block_acc.scale_assign(outer_scale);
}
// Hessian-psi is exposed as an operator; the dense matrix is left empty.
Ok(Some(ExactNewtonJointPsiTerms {
objective_psi,
score_psi,
hessian_psi: Array2::zeros((0, 0)),
hessian_psi_operator: Some(std::sync::Arc::new(block_acc.into_operator(slices))),
}))
}
/// Convenience wrapper around
/// `exact_newton_joint_psisecond_order_terms_from_cache_with_options`
/// using the default `BlockwiseFitOptions` (full row set, unit outer
/// score scale).
fn exact_newton_joint_psisecond_order_terms_from_cache(
&self,
block_states: &[ParameterBlockState],
derivative_blocks: &[Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>],
psi_i: usize,
psi_j: usize,
cache: &BernoulliMarginalSlopeExactEvalCache,
) -> Result<Option<ExactNewtonJointPsiSecondOrderTerms>, String> {
self.exact_newton_joint_psisecond_order_terms_from_cache_with_options(
block_states,
derivative_blocks,
psi_i,
psi_j,
cache,
&BlockwiseFitOptions::default(),
)
}
/// Second-order psi-psi sensitivity terms of the exact joint objective,
/// computed from the evaluation cache for the psi pair `(psi_i, psi_j)`.
///
/// Returns `Ok(None)` when either psi index does not map to a derivative
/// block. Only the marginal (block 0) and log-slope (block 1) blocks are
/// supported. Per row, the two psi directions (and, when both psi indices
/// hit the same block, the mixed second-order direction) are pushed through
/// the primary gradient, Hessian, contracted third derivatives, and the
/// contracted fourth derivative. The parallel accumulation produces the
/// objective psi-psi scalar, the score psi-psi vector, and a Hessian
/// psi-psi operator; all are rescaled by the outer score scale when rows
/// are subsampled.
pub(crate) fn exact_newton_joint_psisecond_order_terms_from_cache_with_options(
&self,
block_states: &[ParameterBlockState],
derivative_blocks: &[Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>],
psi_i: usize,
psi_j: usize,
cache: &BernoulliMarginalSlopeExactEvalCache,
options: &BlockwiseFitOptions,
) -> Result<Option<ExactNewtonJointPsiSecondOrderTerms>, String> {
let slices = &cache.slices;
let primary = &cache.primary;
// Locate both flat psi indices in the derivative blocks.
let Some((block_i, local_i)) = psi_derivative_location(derivative_blocks, psi_i) else {
return Ok(None);
};
let Some((block_j, local_j)) = psi_derivative_location(derivative_blocks, psi_j) else {
return Ok(None);
};
// Primary-coordinate index per block: 0 = marginal, 1 = log-slope.
let idx_i = if block_i == 0 { 0 } else { 1 };
let idx_j = if block_j == 0 { 0 } else { 1 };
let n = self.y.len();
let deriv_i = &derivative_blocks[block_i][local_i];
let deriv_j = &derivative_blocks[block_j][local_j];
let (p_psi_i, label_i) = match block_i {
0 => (
self.marginal_design.ncols(),
"BernoulliMarginalSlopeFamily marginal",
),
1 => (
self.logslope_design.ncols(),
"BernoulliMarginalSlopeFamily log-slope",
),
_ => {
return Err(format!(
"bernoulli marginal-slope psi second-order only supports marginal/logslope blocks, got block {block_i}"
));
}
};
let (p_psi_j, label_j) = match block_j {
0 => (
self.marginal_design.ncols(),
"BernoulliMarginalSlopeFamily marginal",
),
1 => (
self.logslope_design.ncols(),
"BernoulliMarginalSlopeFamily log-slope",
),
_ => {
return Err(format!(
"bernoulli marginal-slope psi second-order only supports marginal/logslope blocks, got block {block_j}"
));
}
};
let psi_map_i = crate::families::custom_family::resolve_custom_family_x_psi_map(
deriv_i,
n,
p_psi_i,
0..n,
label_i,
&self.policy,
)?;
let psi_map_j = crate::families::custom_family::resolve_custom_family_x_psi_map(
deriv_j,
n,
p_psi_j,
0..n,
label_j,
&self.policy,
)?;
// Mixed psi-psi map only exists when both psi indices hit the same block.
let psi_map_ij = if block_i == block_j {
Some(
crate::families::custom_family::resolve_custom_family_x_psi_psi_map(
deriv_i,
deriv_j,
local_j,
n,
p_psi_i,
0..n,
label_i,
&self.policy,
)?,
)
} else {
None
};
// Row subset and compensating score scale from the fit options.
let row_iter = outer_row_indices(options, n).to_vec();
let outer_scale = outer_score_scale(options, n);
let (mut objective_psi_psi, mut score_psi_psi, mut block_acc) = row_iter
.into_par_iter()
.try_fold(
|| {
(
0.0f64,
Array1::<f64>::zeros(slices.total),
BernoulliBlockHessianAccumulator::new(slices),
)
},
|mut acc, row| -> Result<_, String> {
{
// Row-local psi directions for i, j, and the mixed ij term.
let dir_i = self.row_primary_psi_direction_from_map(
row,
block_i,
&psi_map_i,
block_states,
primary,
)?;
let dir_j = self.row_primary_psi_direction_from_map(
row,
block_j,
&psi_map_j,
block_states,
primary,
)?;
let dir_ij = self.row_primary_psi_second_direction_from_map(
row,
block_i,
block_j,
psi_map_ij.as_ref(),
block_states,
primary,
)?;
let row_ctx = Self::row_ctx(cache, row);
let (_, f_pi, f_pipi) = self.compute_row_primary_gradient_hessian(
row,
block_states,
primary,
row_ctx,
)?;
// Third derivatives contracted along each psi direction,
// and the fourth derivative contracted along both.
let third_i = self.row_primary_third_contracted_recompute(
row,
block_states,
cache,
row_ctx,
&dir_i,
)?;
let third_j = self.row_primary_third_contracted_recompute(
row,
block_states,
cache,
row_ctx,
&dir_j,
)?;
let fourth = self.row_primary_fourth_contracted_recompute(
row,
block_states,
cache,
row_ctx,
&dir_i,
&dir_j,
)?;
let br_i = self.block_psi_row_from_map(row, block_i, &psi_map_i, slices)?;
let br_j = self.block_psi_row_from_map(row, block_j, &psi_map_j, slices)?;
let br_ij = self.block_psi_second_row_from_map(
row,
block_i,
block_j,
psi_map_ij.as_ref(),
slices,
)?;
// Objective psi-psi: dir_i' H dir_j plus gradient dotted with
// the mixed direction.
acc.0 += dir_i.dot(&f_pipi.dot(&dir_j)) + f_pi.dot(&dir_ij);
// Score psi-psi: direct mixed psi-row term (if present) ...
if let Some(ref bij) = br_ij {
let idx_ij = if bij.block_idx == 0 { 0 } else { 1 };
acc.1
.slice_mut(s![bij.range.clone()])
.scaled_add(f_pi[idx_ij], &bij.local_vec);
}
// ... Hessian-row contractions scattered onto each psi row ...
acc.1
.slice_mut(s![br_i.range.clone()])
.scaled_add(f_pipi.row(idx_i).dot(&dir_j), &br_i.local_vec);
acc.1
.slice_mut(s![br_j.range.clone()])
.scaled_add(f_pipi.row(idx_j).dot(&dir_i), &br_j.local_vec);
// ... plus Hessian/third-order pullbacks to the flat vector.
acc.1 += &self.pullback_primary_vector(
row,
slices,
primary,
&f_pipi.dot(&dir_ij),
)?;
acc.1 += &self.pullback_primary_vector(
row,
slices,
primary,
&third_i.dot(&dir_j),
)?;
// Hessian psi-psi: rank-1 cross term from the mixed psi row.
if let Some(ref bij) = br_ij {
let idx_ij = if bij.block_idx == 0 { 0 } else { 1 };
let right_primary_ij = f_pipi.row(idx_ij).to_owned();
acc.2.add_rank1_psi_cross(
self,
row,
slices,
primary,
bij.block_idx,
&bij.local_vec,
&right_primary_ij,
);
}
// Outer product of the two psi rows weighted by H[idx_i, idx_j].
let scalar_ij = f_pipi[[idx_i, idx_j]];
acc.2.add_psi_psi_outer(
block_i,
&br_i.local_vec,
block_j,
&br_j.local_vec,
scalar_ij,
);
// Rank-1 cross terms with the contracted third derivatives
// (i-row with third_j, j-row with third_i).
let right_primary_i = third_j.row(idx_i).to_owned();
acc.2.add_rank1_psi_cross(
self,
row,
slices,
primary,
block_i,
&br_i.local_vec,
&right_primary_i,
);
let right_primary_j = third_i.row(idx_j).to_owned();
acc.2.add_rank1_psi_cross(
self,
row,
slices,
primary,
block_j,
&br_j.local_vec,
&right_primary_j,
);
// Fourth-order pullback, and third-order pullback along the
// mixed direction.
acc.2.add_pullback(self, row, slices, primary, &fourth);
let third_ij = self.row_primary_third_contracted_recompute(
row,
block_states,
cache,
row_ctx,
&dir_ij,
)?;
acc.2.add_pullback(self, row, slices, primary, &third_ij);
}
Ok(acc)
},
)
.try_reduce(
|| {
(
0.0f64,
Array1::<f64>::zeros(slices.total),
BernoulliBlockHessianAccumulator::new(slices),
)
},
|mut left, right| -> Result<_, String> {
left.0 += right.0;
left.1 += &right.1;
left.2.add(&right.2);
Ok(left)
},
)?;
// Compensate for row subsampling, if any.
if outer_scale != 1.0 {
objective_psi_psi *= outer_scale;
score_psi_psi.mapv_inplace(|v| v * outer_scale);
block_acc.scale_assign(outer_scale);
}
// Hessian psi-psi is exposed as an operator; the dense matrix is empty.
Ok(Some(ExactNewtonJointPsiSecondOrderTerms {
objective_psi_psi,
score_psi_psi,
hessian_psi_psi: Array2::zeros((0, 0)),
hessian_psi_psi_operator: Some(Box::new(block_acc.into_operator(slices))),
}))
}
/// Convenience wrapper: psi-Hessian directional derivative with default
/// blockwise fit options (all rows, unit outer scale).
fn exact_newton_joint_psihessian_directional_derivative_from_cache(
&self,
block_states: &[ParameterBlockState],
derivative_blocks: &[Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>],
psi_index: usize,
d_beta_flat: &Array1<f64>,
cache: &BernoulliMarginalSlopeExactEvalCache,
) -> Result<Option<Array2<f64>>, String> {
let default_options = BlockwiseFitOptions::default();
self.exact_newton_joint_psihessian_directional_derivative_from_cache_with_options(
block_states,
derivative_blocks,
psi_index,
d_beta_flat,
cache,
&default_options,
)
}
/// Directional derivative (along `d_beta_flat`) of the joint coefficient
/// Hessian with respect to the hyperparameter `psi_index`, accumulated
/// row-by-row in parallel and returned as a dense matrix over the joint
/// layout in `cache.slices`.
///
/// Returns `Ok(None)` when `psi_index` does not resolve to any derivative
/// block. Errors for psi blocks other than marginal (0) / log-slope (1).
pub(crate) fn exact_newton_joint_psihessian_directional_derivative_from_cache_with_options(
&self,
block_states: &[ParameterBlockState],
derivative_blocks: &[Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>],
psi_index: usize,
d_beta_flat: &Array1<f64>,
cache: &BernoulliMarginalSlopeExactEvalCache,
options: &BlockwiseFitOptions,
) -> Result<Option<Array2<f64>>, String> {
let slices = &cache.slices;
let primary = &cache.primary;
// Locate which block / local column this psi index differentiates; if the
// index belongs to no derivative block there is nothing to compute.
let Some((block_idx, local_idx)) = psi_derivative_location(derivative_blocks, psi_index)
else {
return Ok(None);
};
// Map block 0 -> marginal axis, any other supported block -> log-slope
// axis of the 2-dimensional primary contraction (the match below rejects
// everything except blocks 0 and 1).
let idx_primary = if block_idx == 0 { 0 } else { 1 };
let n = self.y.len();
let deriv = &derivative_blocks[block_idx][local_idx];
let (p_psi, psi_label) = match block_idx {
0 => (
self.marginal_design.ncols(),
"BernoulliMarginalSlopeFamily marginal",
),
1 => (
self.logslope_design.ncols(),
"BernoulliMarginalSlopeFamily log-slope",
),
_ => {
return Err(format!(
"bernoulli marginal-slope psi hessian only supports marginal/logslope blocks, got block {block_idx}"
));
}
};
// Resolve the per-row design-derivative map for this psi parameter.
let psi_map = crate::families::custom_family::resolve_custom_family_x_psi_map(
deriv,
n,
p_psi,
0..n,
psi_label,
&self.policy,
)?;
let row_iter = outer_row_indices(options, n).to_vec();
let outer_scale = outer_score_scale(options, n);
// Parallel accumulation: each row contributes a rank-1 psi-cross term plus
// pullbacks of fourth-order (beta, psi-dir) and third-order (psi-action)
// contractions of the row objective.
let mut block_acc = row_iter
.into_par_iter()
.try_fold(
|| BernoulliBlockHessianAccumulator::new(slices),
|mut acc, row| -> Result<_, String> {
let row_dir =
self.row_primary_direction_from_flat(row, slices, primary, d_beta_flat)?;
let psi_dir = self.row_primary_psi_direction_from_map(
row,
block_idx,
&psi_map,
block_states,
primary,
)?;
let psi_action = self.row_primary_psi_action_on_direction_from_map(
row,
block_idx,
&psi_map,
slices,
d_beta_flat,
primary,
)?;
let row_ctx = Self::row_ctx(cache, row);
let third_beta = self.row_primary_third_contracted_recompute(
row,
block_states,
cache,
row_ctx,
&row_dir,
)?;
let fourth = self.row_primary_fourth_contracted_recompute(
row,
block_states,
cache,
row_ctx,
&row_dir,
&psi_dir,
)?;
let psi_row = self.block_psi_row_from_map(row, block_idx, &psi_map, slices)?;
let right_primary = third_beta.row(idx_primary).to_owned();
acc.add_rank1_psi_cross(
self,
row,
slices,
primary,
psi_row.block_idx,
&psi_row.local_vec,
&right_primary,
);
acc.add_pullback(self, row, slices, primary, &fourth);
let third_action = self.row_primary_third_contracted_recompute(
row,
block_states,
cache,
row_ctx,
&psi_action,
)?;
acc.add_pullback(self, row, slices, primary, &third_action);
Ok(acc)
},
)
.try_reduce(
|| BernoulliBlockHessianAccumulator::new(slices),
|mut left, right| -> Result<_, String> {
left.add(&right);
Ok(left)
},
)?;
// Subsample scaling (only != 1.0 when `options` restricts the row set).
if outer_scale != 1.0 {
block_acc.scale_assign(outer_scale);
}
Ok(Some(block_acc.to_dense(slices)))
}
/// Operator form of the psi-Hessian directional derivative: identical
/// per-row accumulation to the dense `..._from_cache_with_options` variant,
/// but the accumulated result is wrapped as a matrix-free `HyperOperator`
/// instead of being materialized into a dense matrix.
pub(crate) fn exact_newton_joint_psihessian_directional_derivative_operator_from_cache_with_options(
&self,
block_states: &[ParameterBlockState],
derivative_blocks: &[Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>],
psi_index: usize,
d_beta_flat: &Array1<f64>,
cache: &BernoulliMarginalSlopeExactEvalCache,
options: &BlockwiseFitOptions,
) -> Result<Option<Arc<dyn HyperOperator>>, String> {
let slices = &cache.slices;
let primary = &cache.primary;
// No derivative block owns this psi index -> no operator to build.
let Some((block_idx, local_idx)) = psi_derivative_location(derivative_blocks, psi_index)
else {
return Ok(None);
};
// Block 0 -> marginal axis, block 1 -> log-slope axis (others rejected below).
let idx_primary = if block_idx == 0 { 0 } else { 1 };
let n = self.y.len();
let deriv = &derivative_blocks[block_idx][local_idx];
let (p_psi, psi_label) = match block_idx {
0 => (
self.marginal_design.ncols(),
"BernoulliMarginalSlopeFamily marginal",
),
1 => (
self.logslope_design.ncols(),
"BernoulliMarginalSlopeFamily log-slope",
),
_ => {
return Err(format!(
"bernoulli marginal-slope psi hessian operator only supports marginal/logslope blocks, got block {block_idx}"
));
}
};
let psi_map = crate::families::custom_family::resolve_custom_family_x_psi_map(
deriv,
n,
p_psi,
0..n,
psi_label,
&self.policy,
)?;
let row_iter = outer_row_indices(options, n).to_vec();
let outer_scale = outer_score_scale(options, n);
// Row-parallel accumulation: rank-1 psi-cross term + pullbacks of the
// fourth-order (beta, psi) and third-order (psi-action) contractions.
let mut block_acc = row_iter
.into_par_iter()
.try_fold(
|| BernoulliBlockHessianAccumulator::new(slices),
|mut acc, row| -> Result<_, String> {
let row_dir =
self.row_primary_direction_from_flat(row, slices, primary, d_beta_flat)?;
let psi_dir = self.row_primary_psi_direction_from_map(
row,
block_idx,
&psi_map,
block_states,
primary,
)?;
let psi_action = self.row_primary_psi_action_on_direction_from_map(
row,
block_idx,
&psi_map,
slices,
d_beta_flat,
primary,
)?;
let row_ctx = Self::row_ctx(cache, row);
let third_beta = self.row_primary_third_contracted_recompute(
row,
block_states,
cache,
row_ctx,
&row_dir,
)?;
let fourth = self.row_primary_fourth_contracted_recompute(
row,
block_states,
cache,
row_ctx,
&row_dir,
&psi_dir,
)?;
let psi_row = self.block_psi_row_from_map(row, block_idx, &psi_map, slices)?;
let right_primary = third_beta.row(idx_primary).to_owned();
acc.add_rank1_psi_cross(
self,
row,
slices,
primary,
psi_row.block_idx,
&psi_row.local_vec,
&right_primary,
);
acc.add_pullback(self, row, slices, primary, &fourth);
let third_action = self.row_primary_third_contracted_recompute(
row,
block_states,
cache,
row_ctx,
&psi_action,
)?;
acc.add_pullback(self, row, slices, primary, &third_action);
Ok(acc)
},
)
.try_reduce(
|| BernoulliBlockHessianAccumulator::new(slices),
|mut left, right| -> Result<_, String> {
left.add(&right);
Ok(left)
},
)?;
// Scale compensates for any subsampled row set.
if outer_scale != 1.0 {
block_acc.scale_assign(outer_scale);
}
Ok(Some(
Arc::new(block_acc.into_operator(slices)) as Arc<dyn HyperOperator>
))
}
/// Convenience wrapper: joint-Hessian directional derivative with default
/// blockwise fit options.
fn exact_newton_joint_hessian_directional_derivative_from_cache(
&self,
block_states: &[ParameterBlockState],
d_beta_flat: &Array1<f64>,
cache: &BernoulliMarginalSlopeExactEvalCache,
) -> Result<Option<Array2<f64>>, String> {
let default_options = BlockwiseFitOptions::default();
self.exact_newton_joint_hessian_directional_derivative_from_cache_with_options(
block_states,
d_beta_flat,
cache,
&default_options,
)
}
/// Directional derivative of the joint coefficient Hessian along
/// `d_beta_flat`, returned dense over the joint layout in `cache.slices`.
///
/// Two paths: when the flex component is inactive, each row's third-order
/// term is evaluated in closed form via `rigid_row_third_contracted`;
/// otherwise the third-order contraction is recomputed from the exact
/// evaluation cache.
pub(crate) fn exact_newton_joint_hessian_directional_derivative_from_cache_with_options(
&self,
block_states: &[ParameterBlockState],
d_beta_flat: &Array1<f64>,
cache: &BernoulliMarginalSlopeExactEvalCache,
options: &BlockwiseFitOptions,
) -> Result<Option<Array2<f64>>, String> {
let slices = &cache.slices;
let primary = &cache.primary;
let n = self.y.len();
let row_iter = outer_row_indices(options, n).to_vec();
let outer_scale = outer_score_scale(options, n);
// Rigid path: per-row 2x2 third-order contraction from the closed-form
// kernel, pulled back onto the joint coefficient layout.
if !self.effective_flex_active(block_states)? {
let mut block_acc = row_iter
.into_par_iter()
.try_fold(
|| BernoulliBlockHessianAccumulator::new(slices),
|mut acc, row| -> Result<_, String> {
let marginal_eta = block_states[0].eta[row];
let marginal = self.marginal_link_map(marginal_eta)?;
let g = block_states[1].eta[row];
// Project the flat direction onto this row's marginal and
// log-slope linear predictors.
let dq = self
.marginal_design
.dot_row_view(row, d_beta_flat.slice(s![slices.marginal.clone()]));
let dg = self
.logslope_design
.dot_row_view(row, d_beta_flat.slice(s![slices.logslope.clone()]));
let t = self.rigid_row_third_contracted(
row,
marginal_eta,
marginal,
g,
dq,
dg,
)?;
let t_arr = Array2::from_shape_fn((2, 2), |(a, b)| t[a][b]);
acc.add_pullback(self, row, slices, primary, &t_arr);
Ok(acc)
},
)
.try_reduce(
|| BernoulliBlockHessianAccumulator::new(slices),
|mut left, right| {
left.add(&right);
Ok(left)
},
)?;
if outer_scale != 1.0 {
block_acc.scale_assign(outer_scale);
}
return Ok(Some(block_acc.to_dense(slices)));
}
// Flex path: recompute the third-order contraction per row from the cache.
let mut block_acc = row_iter
.into_par_iter()
.try_fold(
|| BernoulliBlockHessianAccumulator::new(slices),
|mut acc, row| -> Result<_, String> {
let row_dir =
self.row_primary_direction_from_flat(row, slices, primary, d_beta_flat)?;
let row_ctx = Self::row_ctx(cache, row);
let third = self.row_primary_third_contracted_recompute(
row,
block_states,
cache,
row_ctx,
&row_dir,
)?;
acc.add_pullback(self, row, slices, primary, &third);
Ok(acc)
},
)
.try_reduce(
|| BernoulliBlockHessianAccumulator::new(slices),
|mut left, right| -> Result<_, String> {
left.add(&right);
Ok(left)
},
)?;
// Compensate for subsampled rows, if any.
if outer_scale != 1.0 {
block_acc.scale_assign(outer_scale);
}
Ok(Some(block_acc.to_dense(slices)))
}
/// Operator form of the joint-Hessian directional derivative: same per-row
/// accumulation as the dense `..._with_options` variant, but the result is
/// returned as a matrix-free `HyperOperator`.
pub(crate) fn exact_newton_joint_hessian_directional_derivative_operator_from_cache_with_options(
&self,
block_states: &[ParameterBlockState],
d_beta_flat: &Array1<f64>,
cache: &BernoulliMarginalSlopeExactEvalCache,
options: &BlockwiseFitOptions,
) -> Result<Option<Arc<dyn HyperOperator>>, String> {
let slices = &cache.slices;
let primary = &cache.primary;
let n = self.y.len();
let row_iter = outer_row_indices(options, n).to_vec();
let outer_scale = outer_score_scale(options, n);
// Rigid path: closed-form per-row 2x2 third-order contraction.
if !self.effective_flex_active(block_states)? {
let mut block_acc = row_iter
.into_par_iter()
.try_fold(
|| BernoulliBlockHessianAccumulator::new(slices),
|mut acc, row| -> Result<_, String> {
let marginal_eta = block_states[0].eta[row];
let marginal = self.marginal_link_map(marginal_eta)?;
let g = block_states[1].eta[row];
// Direction projected onto this row's two linear predictors.
let dq = self
.marginal_design
.dot_row_view(row, d_beta_flat.slice(s![slices.marginal.clone()]));
let dg = self
.logslope_design
.dot_row_view(row, d_beta_flat.slice(s![slices.logslope.clone()]));
let t = self.rigid_row_third_contracted(
row,
marginal_eta,
marginal,
g,
dq,
dg,
)?;
let t_arr = Array2::from_shape_fn((2, 2), |(a, b)| t[a][b]);
acc.add_pullback(self, row, slices, primary, &t_arr);
Ok(acc)
},
)
.try_reduce(
|| BernoulliBlockHessianAccumulator::new(slices),
|mut left, right| -> Result<_, String> {
left.add(&right);
Ok(left)
},
)?;
if outer_scale != 1.0 {
block_acc.scale_assign(outer_scale);
}
return Ok(Some(
Arc::new(block_acc.into_operator(slices)) as Arc<dyn HyperOperator>
));
}
// Flex path: third-order contraction recomputed per row from the cache.
let mut block_acc = row_iter
.into_par_iter()
.try_fold(
|| BernoulliBlockHessianAccumulator::new(slices),
|mut acc, row| -> Result<_, String> {
let row_dir =
self.row_primary_direction_from_flat(row, slices, primary, d_beta_flat)?;
let row_ctx = Self::row_ctx(cache, row);
let third = self.row_primary_third_contracted_recompute(
row,
block_states,
cache,
row_ctx,
&row_dir,
)?;
acc.add_pullback(self, row, slices, primary, &third);
Ok(acc)
},
)
.try_reduce(
|| BernoulliBlockHessianAccumulator::new(slices),
|mut left, right| -> Result<_, String> {
left.add(&right);
Ok(left)
},
)?;
// Compensate for subsampled rows, if any.
if outer_scale != 1.0 {
block_acc.scale_assign(outer_scale);
}
Ok(Some(
Arc::new(block_acc.into_operator(slices)) as Arc<dyn HyperOperator>
))
}
/// Convenience wrapper: second directional derivative of the joint Hessian
/// (directions u and v) with default blockwise fit options.
fn exact_newton_joint_hessiansecond_directional_derivative_from_cache(
&self,
block_states: &[ParameterBlockState],
d_beta_u_flat: &Array1<f64>,
d_beta_v_flat: &Array1<f64>,
cache: &BernoulliMarginalSlopeExactEvalCache,
) -> Result<Option<Array2<f64>>, String> {
let default_options = BlockwiseFitOptions::default();
self.exact_newton_joint_hessiansecond_directional_derivative_from_cache_with_options(
block_states,
d_beta_u_flat,
d_beta_v_flat,
cache,
&default_options,
)
}
/// Second directional derivative of the joint coefficient Hessian along the
/// pair of directions `d_beta_u_flat` / `d_beta_v_flat`, returned dense.
///
/// Rigid path evaluates each row's 2x2 fourth-order contraction in closed
/// form; the flex path recomputes it from the exact evaluation cache.
pub(crate) fn exact_newton_joint_hessiansecond_directional_derivative_from_cache_with_options(
&self,
block_states: &[ParameterBlockState],
d_beta_u_flat: &Array1<f64>,
d_beta_v_flat: &Array1<f64>,
cache: &BernoulliMarginalSlopeExactEvalCache,
options: &BlockwiseFitOptions,
) -> Result<Option<Array2<f64>>, String> {
let slices = &cache.slices;
let primary = &cache.primary;
let n = self.y.len();
let make_acc = || BernoulliBlockHessianAccumulator::new(slices);
let row_iter = outer_row_indices(options, n).to_vec();
let outer_scale = outer_score_scale(options, n);
// Rigid path: closed-form per-row fourth-order term.
if !self.effective_flex_active(block_states)? {
let mut block_acc = row_iter
.into_par_iter()
.try_fold(make_acc, |mut acc, row| -> Result<_, String> {
let marginal_eta = block_states[0].eta[row];
let marginal = self.marginal_link_map(marginal_eta)?;
let g = block_states[1].eta[row];
// Both directions projected onto this row's marginal (q) and
// log-slope (g) linear predictors.
let uq = self
.marginal_design
.dot_row_view(row, d_beta_u_flat.slice(s![slices.marginal.clone()]));
let ug = self
.logslope_design
.dot_row_view(row, d_beta_u_flat.slice(s![slices.logslope.clone()]));
let vq = self
.marginal_design
.dot_row_view(row, d_beta_v_flat.slice(s![slices.marginal.clone()]));
let vg = self
.logslope_design
.dot_row_view(row, d_beta_v_flat.slice(s![slices.logslope.clone()]));
let f = self.rigid_row_fourth_contracted(
row,
marginal_eta,
marginal,
g,
uq,
ug,
vq,
vg,
)?;
let f_arr = Array2::from_shape_fn((2, 2), |(a, b)| f[a][b]);
acc.add_pullback(self, row, slices, primary, &f_arr);
Ok(acc)
})
.try_reduce(make_acc, |mut left, right| -> Result<_, String> {
left.add(&right);
Ok(left)
})?;
if outer_scale != 1.0 {
block_acc.scale_assign(outer_scale);
}
return Ok(Some(block_acc.to_dense(slices)));
}
// Flex path: fourth-order contraction recomputed per row from the cache.
let mut block_acc = row_iter
.into_par_iter()
.try_fold(make_acc, |mut acc, row| -> Result<_, String> {
let row_u =
self.row_primary_direction_from_flat(row, slices, primary, d_beta_u_flat)?;
let row_v =
self.row_primary_direction_from_flat(row, slices, primary, d_beta_v_flat)?;
let row_ctx = Self::row_ctx(cache, row);
let fourth = self.row_primary_fourth_contracted_recompute(
row,
block_states,
cache,
row_ctx,
&row_u,
&row_v,
)?;
acc.add_pullback(self, row, slices, primary, &fourth);
Ok(acc)
})
.try_reduce(make_acc, |mut left, right| -> Result<_, String> {
left.add(&right);
Ok(left)
})?;
// Compensate for subsampled rows, if any.
if outer_scale != 1.0 {
block_acc.scale_assign(outer_scale);
}
Ok(Some(block_acc.to_dense(slices)))
}
/// Operator form of the second directional Hessian derivative: identical
/// accumulation to the dense variant, returned as a matrix-free
/// `HyperOperator`.
pub(crate) fn exact_newton_joint_hessiansecond_directional_derivative_operator_from_cache_with_options(
&self,
block_states: &[ParameterBlockState],
d_beta_u_flat: &Array1<f64>,
d_beta_v_flat: &Array1<f64>,
cache: &BernoulliMarginalSlopeExactEvalCache,
options: &BlockwiseFitOptions,
) -> Result<Option<Arc<dyn HyperOperator>>, String> {
let slices = &cache.slices;
let primary = &cache.primary;
let n = self.y.len();
let make_acc = || BernoulliBlockHessianAccumulator::new(slices);
let row_iter = outer_row_indices(options, n).to_vec();
let outer_scale = outer_score_scale(options, n);
// Rigid path: closed-form per-row fourth-order term.
if !self.effective_flex_active(block_states)? {
let mut block_acc = row_iter
.into_par_iter()
.try_fold(make_acc, |mut acc, row| -> Result<_, String> {
let marginal_eta = block_states[0].eta[row];
let marginal = self.marginal_link_map(marginal_eta)?;
let g = block_states[1].eta[row];
// Both directions projected onto the row's linear predictors.
let uq = self
.marginal_design
.dot_row_view(row, d_beta_u_flat.slice(s![slices.marginal.clone()]));
let ug = self
.logslope_design
.dot_row_view(row, d_beta_u_flat.slice(s![slices.logslope.clone()]));
let vq = self
.marginal_design
.dot_row_view(row, d_beta_v_flat.slice(s![slices.marginal.clone()]));
let vg = self
.logslope_design
.dot_row_view(row, d_beta_v_flat.slice(s![slices.logslope.clone()]));
let f = self.rigid_row_fourth_contracted(
row,
marginal_eta,
marginal,
g,
uq,
ug,
vq,
vg,
)?;
let f_arr = Array2::from_shape_fn((2, 2), |(a, b)| f[a][b]);
acc.add_pullback(self, row, slices, primary, &f_arr);
Ok(acc)
})
.try_reduce(make_acc, |mut left, right| -> Result<_, String> {
left.add(&right);
Ok(left)
})?;
if outer_scale != 1.0 {
block_acc.scale_assign(outer_scale);
}
return Ok(Some(
Arc::new(block_acc.into_operator(slices)) as Arc<dyn HyperOperator>
));
}
// Flex path: fourth-order contraction recomputed per row from the cache.
let mut block_acc = row_iter
.into_par_iter()
.try_fold(make_acc, |mut acc, row| -> Result<_, String> {
let row_u =
self.row_primary_direction_from_flat(row, slices, primary, d_beta_u_flat)?;
let row_v =
self.row_primary_direction_from_flat(row, slices, primary, d_beta_v_flat)?;
let row_ctx = Self::row_ctx(cache, row);
let fourth = self.row_primary_fourth_contracted_recompute(
row,
block_states,
cache,
row_ctx,
&row_u,
&row_v,
)?;
acc.add_pullback(self, row, slices, primary, &fourth);
Ok(acc)
})
.try_reduce(make_acc, |mut left, right| -> Result<_, String> {
left.add(&right);
Ok(left)
})?;
// Compensate for subsampled rows, if any.
if outer_scale != 1.0 {
block_acc.scale_assign(outer_scale);
}
Ok(Some(
Arc::new(block_acc.into_operator(slices)) as Arc<dyn HyperOperator>
))
}
/// Flex-path family evaluation: chunked parallel pass over rows producing
/// the log-likelihood and per-block (marginal, log-slope, and optional h/w)
/// gradient / dense-Hessian working sets.
fn evaluate_flex_block_diagonals_from_cache(
&self,
block_states: &[ParameterBlockState],
slices: &BlockSlices,
cache: &BernoulliMarginalSlopeExactEvalCache,
) -> Result<FamilyEvaluation, String> {
let primary = cache.primary.clone();
let n = self.y.len();
let n_chunks = n.div_ceil(ROW_CHUNK_SIZE);
let reduced = (0..n_chunks)
.into_par_iter()
.try_fold(
|| BernoulliExactNewtonAccumulator::new(slices),
|mut acc, chunk_idx| -> Result<_, String> {
let start = chunk_idx * ROW_CHUNK_SIZE;
let end = (start + ROW_CHUNK_SIZE).min(n);
// One scratch buffer per chunk, reused across its rows.
let mut scratch = BernoulliMarginalSlopeFlexRowScratch::new(primary.total);
for row in start..end {
let row_ctx = Self::row_ctx(cache, row);
// NOTE(review): moment order 9 — assumed to match the cached
// cell-moment layout; confirm against the bundle builder.
let row_moments = cache
.row_cell_moments
.as_ref()
.and_then(|bundle| bundle.row(row, 9));
let row_neglog = self.compute_row_analytic_flex_into_with_moments(
row,
block_states,
&primary,
row_ctx,
row_moments,
true,
&mut scratch,
)?;
acc.add_pullback_block_diagonals(
self, row, &primary, row_neglog, &scratch,
)?;
}
Ok(acc)
},
)
.try_reduce(
|| BernoulliExactNewtonAccumulator::new(slices),
|mut left, right| -> Result<_, String> {
left.add(&right);
Ok(left)
},
)?;
// Unpack the reduced accumulator into per-block working sets.
let BernoulliExactNewtonAccumulator {
ll,
grad_marginal,
grad_logslope,
hess_marginal,
hess_logslope,
grad_h,
grad_w,
hess_h,
hess_w,
} = reduced;
let mut blockworking_sets = vec![
BlockWorkingSet::ExactNewton {
gradient: grad_marginal,
hessian: SymmetricMatrix::Dense(hess_marginal),
},
BlockWorkingSet::ExactNewton {
gradient: grad_logslope,
hessian: SymmetricMatrix::Dense(hess_logslope),
},
];
// Optional h / w blocks are emitted only when both their gradient and
// Hessian were accumulated.
if let (Some(gradient), Some(hessian)) = (grad_h, hess_h) {
blockworking_sets.push(BlockWorkingSet::ExactNewton {
gradient,
hessian: SymmetricMatrix::Dense(hessian),
});
}
if let (Some(gradient), Some(hessian)) = (grad_w, hess_w) {
blockworking_sets.push(BlockWorkingSet::ExactNewton {
gradient,
hessian: SymmetricMatrix::Dense(hessian),
});
}
Ok(FamilyEvaluation {
log_likelihood: ll,
blockworking_sets,
})
}
/// Blockwise exact-Newton evaluation: log-likelihood plus per-block
/// gradient / Hessian working sets for the marginal and log-slope
/// coefficient blocks (with zero-filled sets for the optional h / w
/// blocks in the rigid paths).
///
/// Three paths:
/// - rigid (flex inactive) with small joint dimension (< 512): row-kernel
///   cache plus a chunked parallel Gram accumulation of the per-row 2x2
///   Hessians;
/// - flex active: delegates to the cache-based flex evaluator;
/// - rigid, larger dimension: chunked parallel accumulation of both the
///   gradient and Hessian weights straight from the per-row rigid kernel.
///
/// Fix: the two chunk counts below now use `div_ceil`, matching the
/// `n.div_ceil(ROW_CHUNK_SIZE)` idiom already used elsewhere in this file
/// (previously a manual `(n + ROW_CHUNK_SIZE - 1) / ROW_CHUNK_SIZE`).
fn evaluate_blockwise_exact_newton(
&self,
block_states: &[ParameterBlockState],
) -> Result<FamilyEvaluation, String> {
let slices = block_slices(self);
let flex_active = self.effective_flex_active(block_states)?;
if !flex_active && slices.total < 512 {
let kern = BernoulliRigidRowKernel::new(self.clone(), block_states.to_vec());
let cache = build_row_kernel_cache(&kern)?;
let ll = row_kernel_log_likelihood(&cache);
let joint_gradient = Self::exact_newton_score_from_objective_gradient(
row_kernel_gradient(&kern, &cache),
);
let n = cache.n;
let p_marginal = slices.marginal.len();
let p_logslope = slices.logslope.len();
let make_pair = || {
(
Array2::<f64>::zeros((p_marginal, p_marginal)),
Array2::<f64>::zeros((p_logslope, p_logslope)),
)
};
// Gradient comes from the row-kernel cache above; only the two block
// Hessians are accumulated here (weighted Gram of each design chunk).
let (hess_marginal, hess_logslope) = (0..n.div_ceil(ROW_CHUNK_SIZE))
.into_par_iter()
.try_fold(
make_pair,
|(mut hm, mut hl), chunk_idx| -> Result<(Array2<f64>, Array2<f64>), String> {
let start = chunk_idx * ROW_CHUNK_SIZE;
let end = (start + ROW_CHUNK_SIZE).min(n);
let rows = end - start;
let marginal_chunk = self
.marginal_design
.try_row_chunk(start..end)
.map_err(|e| format!("bernoulli marginal_design try_row_chunk: {e}"))?;
let logslope_chunk = self
.logslope_design
.try_row_chunk(start..end)
.map_err(|e| format!("bernoulli logslope_design try_row_chunk: {e}"))?;
let mut hm_w = Array1::<f64>::zeros(rows);
let mut hl_w = Array1::<f64>::zeros(rows);
// Diagonal (per-block) entries of each cached 2x2 row Hessian.
for local_row in 0..rows {
let h = &cache.hessians[start + local_row];
hm_w[local_row] = h[0][0];
hl_w[local_row] = h[1][1];
}
add_weighted_chunk_gram(&marginal_chunk, &hm_w, &mut hm);
add_weighted_chunk_gram(&logslope_chunk, &hl_w, &mut hl);
Ok((hm, hl))
},
)
.try_reduce(
make_pair,
|(mut lhm, mut lhl),
(rhm, rhl)|
-> Result<(Array2<f64>, Array2<f64>), String> {
lhm += &rhm;
lhl += &rhl;
Ok((lhm, lhl))
},
)?;
let hess_marginal =
Self::exact_newton_observed_information_from_objective_hessian(hess_marginal);
let hess_logslope =
Self::exact_newton_observed_information_from_objective_hessian(hess_logslope);
let grad_marginal = joint_gradient.slice(s![slices.marginal.clone()]).to_owned();
let grad_logslope = joint_gradient.slice(s![slices.logslope.clone()]).to_owned();
let mut sets = vec![
BlockWorkingSet::ExactNewton {
gradient: grad_marginal,
hessian: SymmetricMatrix::Dense(hess_marginal),
},
BlockWorkingSet::ExactNewton {
gradient: grad_logslope,
hessian: SymmetricMatrix::Dense(hess_logslope),
},
];
// Rigid path contributes nothing to the optional h / w blocks; emit
// zero-filled working sets so the block count matches `slices`.
if let Some(range) = slices.h.as_ref() {
sets.push(BlockWorkingSet::ExactNewton {
gradient: Array1::zeros(range.len()),
hessian: SymmetricMatrix::Dense(Array2::zeros((range.len(), range.len()))),
});
}
if let Some(range) = slices.w.as_ref() {
sets.push(BlockWorkingSet::ExactNewton {
gradient: Array1::zeros(range.len()),
hessian: SymmetricMatrix::Dense(Array2::zeros((range.len(), range.len()))),
});
}
return Ok(FamilyEvaluation {
log_likelihood: ll,
blockworking_sets: sets,
});
}
if flex_active {
let cache = self.build_exact_eval_cache_with_order(block_states)?;
return self.evaluate_flex_block_diagonals_from_cache(block_states, &slices, &cache);
}
// Rigid path, large joint dimension: accumulate log-likelihood, gradient
// weights, and Hessian weights chunk-by-chunk from the per-row kernel.
let n = self.y.len();
let p_marginal = slices.marginal.len();
let p_logslope = slices.logslope.len();
let make_acc = || {
(
0.0_f64,
Array1::<f64>::zeros(p_marginal),
Array1::<f64>::zeros(p_logslope),
Array2::<f64>::zeros((p_marginal, p_marginal)),
Array2::<f64>::zeros((p_logslope, p_logslope)),
)
};
let (ll, grad_marginal, grad_logslope, hess_marginal, hess_logslope) =
(0..n.div_ceil(ROW_CHUNK_SIZE))
.into_par_iter()
.try_fold(
make_acc,
|(mut ll, mut gm, mut gl, mut hm, mut hl), chunk_idx| -> Result<_, String> {
let start = chunk_idx * ROW_CHUNK_SIZE;
let end = (start + ROW_CHUNK_SIZE).min(n);
let rows = end - start;
let marginal_chunk = self
.marginal_design
.try_row_chunk(start..end)
.map_err(|e| format!("bernoulli marginal_design try_row_chunk: {e}"))?;
let logslope_chunk = self
.logslope_design
.try_row_chunk(start..end)
.map_err(|e| format!("bernoulli logslope_design try_row_chunk: {e}"))?;
let mut gm_w = Array1::<f64>::zeros(rows);
let mut gl_w = Array1::<f64>::zeros(rows);
let mut hm_w = Array1::<f64>::zeros(rows);
let mut hl_w = Array1::<f64>::zeros(rows);
for local_row in 0..rows {
let row = start + local_row;
let marginal_eta = block_states[0].eta[row];
let marginal = self.marginal_link_map(marginal_eta)?;
let g = block_states[1].eta[row];
let (neglog, grad, h) =
self.rigid_row_kernel_eval(row, marginal_eta, marginal, g)?;
// The kernel returns a negative log-likelihood contribution.
ll -= neglog;
gm_w[local_row] =
Self::exact_newton_score_component_from_objective_gradient(grad[0]);
gl_w[local_row] =
Self::exact_newton_score_component_from_objective_gradient(grad[1]);
hm_w[local_row] = h[0][0];
hl_w[local_row] = h[1][1];
}
add_weighted_chunk_gradient(&marginal_chunk, &gm_w, &mut gm);
add_weighted_chunk_gradient(&logslope_chunk, &gl_w, &mut gl);
add_weighted_chunk_gram(&marginal_chunk, &hm_w, &mut hm);
add_weighted_chunk_gram(&logslope_chunk, &hl_w, &mut hl);
Ok((ll, gm, gl, hm, hl))
},
)
.try_reduce(
make_acc,
|(lll, mut lgm, mut lgl, mut lhm, mut lhl),
(rll, rgm, rgl, rhm, rhl)|
-> Result<_, String> {
lgm += &rgm;
lgl += &rgl;
lhm += &rhm;
lhl += &rhl;
Ok((lll + rll, lgm, lgl, lhm, lhl))
},
)?;
Ok(FamilyEvaluation {
log_likelihood: ll,
blockworking_sets: {
let mut sets = vec![
BlockWorkingSet::ExactNewton {
gradient: grad_marginal,
hessian: SymmetricMatrix::Dense(hess_marginal),
},
BlockWorkingSet::ExactNewton {
gradient: grad_logslope,
hessian: SymmetricMatrix::Dense(hess_logslope),
},
];
// Zero-filled working sets for the optional h / w blocks.
if let Some(range) = slices.h.as_ref() {
sets.push(BlockWorkingSet::ExactNewton {
gradient: Array1::zeros(range.len()),
hessian: SymmetricMatrix::Dense(Array2::zeros((range.len(), range.len()))),
});
}
if let Some(range) = slices.w.as_ref() {
sets.push(BlockWorkingSet::ExactNewton {
gradient: Array1::zeros(range.len()),
hessian: SymmetricMatrix::Dense(Array2::zeros((range.len(), range.len()))),
});
}
sets
},
})
}
}
impl CustomFamily for BernoulliMarginalSlopeFamily {
/// The joint Hessian depends on the current coefficients (beta), so callers
/// must re-evaluate it after each coefficient update.
fn exact_newton_joint_hessian_beta_dependent(&self) -> bool {
true
}
/// Selects the hard pseudo-determinant mode for log-determinant handling.
fn pseudo_logdet_mode(&self) -> crate::custom_family::PseudoLogdetMode {
crate::custom_family::PseudoLogdetMode::HardPseudo
}
/// Cost estimate for the coefficient-Hessian computation over `specs`.
///
/// When the joint matrix-free path is selected the estimate is n * p_total;
/// otherwise the coupled dense-cost helper provides it. All arithmetic is
/// saturating so extremely large models cannot overflow.
fn coefficient_hessian_cost(&self, specs: &[ParameterBlockSpec]) -> u64 {
let n = self.y.len() as u64;
// Saturating sum of per-block design widths.
let mut p_total: u64 = 0;
for spec in specs {
p_total = p_total.saturating_add(spec.design.ncols() as u64);
}
if crate::custom_family::use_joint_matrix_free_path(p_total as usize, n as usize) {
n.saturating_mul(p_total)
} else {
crate::custom_family::joint_coupled_coefficient_hessian_cost(n, specs)
}
}
/// Opt in to using the joint psi workspace for first-order hyper terms.
fn exact_newton_joint_psi_workspace_for_first_order_terms(&self) -> bool {
true
}
/// Validates monotonicity constraints, then runs the blockwise
/// exact-Newton evaluation.
fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
self.validate_exact_monotonicity(block_states)?;
self.evaluate_blockwise_exact_newton(block_states)
}
/// Log-likelihood only, using default blockwise fit options.
fn log_likelihood_only(&self, block_states: &[ParameterBlockState]) -> Result<f64, String> {
Self::log_likelihood_only_with_options(self, block_states, &BlockwiseFitOptions::default())
}
/// Trait entry point for options-aware log-likelihood; forwards to the
/// inherent method of the same name.
fn log_likelihood_only_with_options(
&self,
block_states: &[ParameterBlockState],
options: &BlockwiseFitOptions,
) -> Result<f64, String> {
Self::log_likelihood_only_with_options(self, block_states, options)
}
/// Log-likelihood evaluation may exit early (e.g. for line-search rejection).
fn supports_log_likelihood_early_exit(&self) -> bool {
true
}
/// Validates block/step shapes; this family imposes no feasibility bound on
/// the step length, so a successful call always yields `Ok(None)`.
fn max_feasible_step_size(
&self,
block_states: &[ParameterBlockState],
block_idx: usize,
delta: &Array1<f64>,
) -> Result<Option<f64>, String> {
self.validate_exact_block_state_shapes(block_states)?;
// Guard: the requested block must exist.
let Some(state) = block_states.get(block_idx) else {
return Err(format!(
"bernoulli marginal-slope block index {block_idx} is out of bounds for {} states",
block_states.len()
));
};
// Guard: the step must match the block's coefficient length.
let expected = state.beta.len();
if delta.len() != expected {
return Err(format!(
"bernoulli marginal-slope step length mismatch for block {block_idx}: delta={}, beta={}",
delta.len(),
expected
));
}
Ok(None)
}
/// An explicit (dense) joint Hessian is available for this family.
fn has_explicit_joint_hessian(&self) -> bool {
true
}
/// Dense joint Hessian, only for small problems (total joint dimension
/// < 512); returns `Ok(None)` otherwise so callers fall back to the
/// matrix-free path. The rigid path uses the row-kernel cache, the flex
/// path the exact evaluation cache.
fn exact_newton_joint_hessian(
&self,
block_states: &[ParameterBlockState],
) -> Result<Option<Array2<f64>>, String> {
let slices = block_slices(self);
if slices.total >= 512 {
return Ok(None);
}
if !self.effective_flex_active(block_states)? {
let kern = BernoulliRigidRowKernel::new(self.clone(), block_states.to_vec());
let cache = build_row_kernel_cache(&kern)?;
return Ok(Some(row_kernel_hessian_dense(&kern, &cache)));
}
let cache = self.build_exact_eval_cache_with_order(block_states)?;
self.exact_newton_joint_hessian_dense_from_cache(block_states, &cache)
.map(Some)
}
/// Joint log-likelihood and score in one pass: the rigid path uses the
/// row-kernel cache, the flex path the exact evaluation cache.
fn exact_newton_joint_gradient_evaluation(
&self,
block_states: &[ParameterBlockState],
_: &[ParameterBlockSpec],
) -> Result<Option<ExactNewtonJointGradientEvaluation>, String> {
self.validate_exact_monotonicity(block_states)?;
if !self.effective_flex_active(block_states)? {
let kern = BernoulliRigidRowKernel::new(self.clone(), block_states.to_vec());
let cache = build_row_kernel_cache(&kern)?;
return Ok(Some(ExactNewtonJointGradientEvaluation {
log_likelihood: row_kernel_log_likelihood(&cache),
// Convert the objective gradient to score convention.
gradient: Self::exact_newton_score_from_objective_gradient(row_kernel_gradient(
&kern, &cache,
)),
}));
}
let cache = self.build_exact_eval_cache_with_order(block_states)?;
self.exact_newton_joint_gradient_evaluation_from_cache(block_states, &cache)
.map(Some)
}
/// Hyperparameter optimization must use the joint outer path for this family.
fn requires_joint_outer_hyper_path(&self) -> bool {
true
}
/// Builds a joint-Hessian workspace with default fit options: a row-kernel
/// workspace when the flex component is inactive, otherwise the
/// family-specific exact-Newton workspace.
fn exact_newton_joint_hessian_workspace(
&self,
block_states: &[ParameterBlockState],
_: &[ParameterBlockSpec],
) -> Result<Option<Arc<dyn ExactNewtonJointHessianWorkspace>>, String> {
let flex_active = self.effective_flex_active(block_states)?;
if flex_active {
let ws = BernoulliMarginalSlopeExactNewtonJointHessianWorkspace::new(
self.clone(),
block_states.to_vec(),
BlockwiseFitOptions::default(),
)?;
return Ok(Some(Arc::new(ws)));
}
let kern = BernoulliRigidRowKernel::new(self.clone(), block_states.to_vec());
Ok(Some(Arc::new(RowKernelHessianWorkspace::new(kern)?)))
}
/// Options-aware variant of the joint-Hessian workspace constructor:
/// row-kernel workspace when the flex component is inactive, otherwise the
/// family-specific exact-Newton workspace carrying `options`.
fn exact_newton_joint_hessian_workspace_with_options(
&self,
block_states: &[ParameterBlockState],
_: &[ParameterBlockSpec],
options: &BlockwiseFitOptions,
) -> Result<Option<Arc<dyn ExactNewtonJointHessianWorkspace>>, String> {
match self.effective_flex_active(block_states)? {
false => {
let kern = BernoulliRigidRowKernel::new(self.clone(), block_states.to_vec());
Ok(Some(Arc::new(RowKernelHessianWorkspace::new(kern)?)))
}
true => {
let ws = BernoulliMarginalSlopeExactNewtonJointHessianWorkspace::new(
self.clone(),
block_states.to_vec(),
options.clone(),
)?;
Ok(Some(Arc::new(ws)))
}
}
}
/// Hessian-vector products for the inner coefficient problem are available.
fn inner_coefficient_hessian_hvp_available(&self, _specs: &[ParameterBlockSpec]) -> bool {
true
}
/// The joint workspace can evaluate gradients for the inner problem.
fn inner_joint_workspace_gradient_available(&self, _specs: &[ParameterBlockSpec]) -> bool {
true
}
/// The joint workspace can evaluate log-likelihoods for the inner problem.
fn inner_joint_workspace_log_likelihood_available(
&self,
_specs: &[ParameterBlockSpec],
) -> bool {
true
}
/// Directional derivative of the joint Hessian along `d_beta_flat`: the
/// rigid path uses the shared row-kernel helper (which needs a contiguous
/// direction slice), the flex path the cache-based implementation.
fn exact_newton_joint_hessian_directional_derivative(
&self,
block_states: &[ParameterBlockState],
d_beta_flat: &Array1<f64>,
) -> Result<Option<Array2<f64>>, String> {
if !self.effective_flex_active(block_states)? {
let kern = BernoulliRigidRowKernel::new(self.clone(), block_states.to_vec());
let sl = d_beta_flat.as_slice().ok_or("non-contiguous d_beta")?;
crate::families::row_kernel::row_kernel_directional_derivative(&kern, sl).map(Some)
} else {
let cache = self.build_exact_eval_cache(block_states)?;
self.exact_newton_joint_hessian_directional_derivative_from_cache(
block_states,
d_beta_flat,
&cache,
)
}
}
/// Second directional derivative of the joint Hessian along u and v: the
/// rigid path uses the shared row-kernel helper (requires contiguous
/// direction slices), the flex path the cache-based implementation.
fn exact_newton_joint_hessiansecond_directional_derivative(
&self,
block_states: &[ParameterBlockState],
d_beta_u_flat: &Array1<f64>,
d_beta_v_flat: &Array1<f64>,
) -> Result<Option<Array2<f64>>, String> {
if !self.effective_flex_active(block_states)? {
let kern = BernoulliRigidRowKernel::new(self.clone(), block_states.to_vec());
let su = d_beta_u_flat.as_slice().ok_or("non-contiguous d_beta_u")?;
let sv = d_beta_v_flat.as_slice().ok_or("non-contiguous d_beta_v")?;
crate::families::row_kernel::row_kernel_second_directional_derivative(&kern, su, sv)
.map(Some)
} else {
let cache = self.build_exact_eval_cache(block_states)?;
self.exact_newton_joint_hessiansecond_directional_derivative_from_cache(
block_states,
d_beta_u_flat,
d_beta_v_flat,
&cache,
)
}
}
/// First-order joint psi terms: log-sigma auxiliary indices are routed to
/// the sigma-specific implementation; all other psi indices go through the
/// cache-based path.
fn exact_newton_joint_psi_terms(
&self,
block_states: &[ParameterBlockState],
specs: &[ParameterBlockSpec],
derivative_blocks: &[Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>],
psi_index: usize,
) -> Result<Option<ExactNewtonJointPsiTerms>, String> {
if self.is_sigma_aux_index(derivative_blocks, psi_index) {
return self.sigma_exact_joint_psi_terms(block_states, specs);
}
let cache = self.build_exact_eval_cache(block_states)?;
self.exact_newton_joint_psi_terms_from_cache(
block_states,
derivative_blocks,
psi_index,
&cache,
)
}
/// Second-order psi terms for a pair of auxiliary hyperparameter coordinates.
///
/// Pure log-sigma pairs use the sigma path; mixed log-sigma/spatial pairs are
/// rejected (cross auxiliary terms are not implemented); pure spatial pairs go
/// through the exact evaluation cache.
fn exact_newton_joint_psisecond_order_terms(
    &self,
    block_states: &[ParameterBlockState],
    _: &[ParameterBlockSpec],
    derivative_blocks: &[Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>],
    psi_i: usize,
    psi_j: usize,
) -> Result<Option<ExactNewtonJointPsiSecondOrderTerms>, String> {
    let i_is_sigma = self.is_sigma_aux_index(derivative_blocks, psi_i);
    let j_is_sigma = self.is_sigma_aux_index(derivative_blocks, psi_j);
    if i_is_sigma && j_is_sigma {
        return self.sigma_exact_joint_psisecond_order_terms(block_states);
    }
    if i_is_sigma || j_is_sigma {
        // Exactly one coordinate is log-sigma: unsupported mixed derivative.
        return Err(
            "bernoulli marginal-slope mixed log-sigma/spatial psi second derivatives require cross auxiliary terms; only pure log-sigma second derivatives are supported"
                .to_string(),
        );
    }
    let eval_cache = self.build_exact_eval_cache(block_states)?;
    self.exact_newton_joint_psisecond_order_terms_from_cache(
        block_states,
        derivative_blocks,
        psi_i,
        psi_j,
        &eval_cache,
    )
}
/// Directional derivative of the psi-Hessian along `d_beta_flat` for one
/// auxiliary hyperparameter coordinate.
fn exact_newton_joint_psihessian_directional_derivative(
    &self,
    block_states: &[ParameterBlockState],
    _: &[ParameterBlockSpec],
    derivative_blocks: &[Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>],
    psi_index: usize,
    d_beta_flat: &Array1<f64>,
) -> Result<Option<Array2<f64>>, String> {
    if self.is_sigma_aux_index(derivative_blocks, psi_index) {
        // Log-sigma coordinate: dedicated sigma path.
        self.sigma_exact_joint_psihessian_directional_derivative(block_states, d_beta_flat)
    } else {
        // Spatial coordinate: evaluate through the exact cache.
        let eval_cache = self.build_exact_eval_cache(block_states)?;
        self.exact_newton_joint_psihessian_directional_derivative_from_cache(
            block_states,
            derivative_blocks,
            psi_index,
            d_beta_flat,
            &eval_cache,
        )
    }
}
/// Builds a psi workspace with default block-wise fit options.
///
/// Delegates to `exact_newton_joint_psi_workspace_with_options` so the two
/// entry points construct the workspace identically and cannot drift apart.
fn exact_newton_joint_psi_workspace(
    &self,
    block_states: &[ParameterBlockState],
    specs: &[ParameterBlockSpec],
    derivative_blocks: &[Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>],
) -> Result<Option<Arc<dyn ExactNewtonJointPsiWorkspace>>, String> {
    self.exact_newton_joint_psi_workspace_with_options(
        block_states,
        specs,
        derivative_blocks,
        &BlockwiseFitOptions::default(),
    )
}
/// Builds a psi workspace that owns clones of the family, block states,
/// specs, derivative blocks, and the supplied fit options.
fn exact_newton_joint_psi_workspace_with_options(
    &self,
    block_states: &[ParameterBlockState],
    specs: &[ParameterBlockSpec],
    derivative_blocks: &[Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>],
    options: &BlockwiseFitOptions,
) -> Result<Option<Arc<dyn ExactNewtonJointPsiWorkspace>>, String> {
    let workspace = BernoulliMarginalSlopeExactNewtonJointPsiWorkspace::new(
        self.clone(),
        block_states.to_vec(),
        specs.to_vec(),
        derivative_blocks.to_vec(),
        options.clone(),
    )?;
    Ok(Some(Arc::new(workspace)))
}
/// Linear inequality constraints for one parameter block: the score-warp and
/// link-deviation blocks carry structural monotonicity constraints; all other
/// blocks are unconstrained.
fn block_linear_constraints(
    &self,
    block_states: &[ParameterBlockState],
    block_idx: usize,
    spec: &ParameterBlockSpec,
) -> Result<Option<LinearInequalityConstraints>, String> {
    // Defensive guard on sentinel values; never hit in practice.
    if block_states.len() == usize::MAX
        || block_idx == usize::MAX
        || spec.design.ncols() == usize::MAX
    {
        return Err("unreachable bernoulli marginal-slope constraint state".to_string());
    }
    if self.score_block_index() == Some(block_idx) {
        let constraints = self
            .score_warp
            .as_ref()
            .map(DeviationRuntime::structural_monotonicity_constraints);
        return Ok(constraints);
    }
    if self.link_block_index() == Some(block_idx) {
        let constraints = self
            .link_dev
            .as_ref()
            .map(DeviationRuntime::structural_monotonicity_constraints);
        return Ok(constraints);
    }
    Ok(None)
}
/// Post-processes a proposed beta update for one block.
///
/// Score-warp and link-deviation blocks are projected back onto the
/// monotone-feasible set relative to the current beta; all other blocks pass
/// through unchanged.
fn post_update_block_beta(
    &self,
    block_states: &[ParameterBlockState],
    block_idx: usize,
    _: &ParameterBlockSpec,
    beta: Array1<f64>,
) -> Result<Array1<f64>, String> {
    self.validate_exact_block_state_shapes(block_states)?;
    if block_idx >= block_states.len() {
        return Err(format!(
            "post-update block index {} out of range for {} blocks",
            block_idx,
            block_states.len()
        ));
    }
    if self.score_block_index() == Some(block_idx) {
        if let (Some(runtime), Some(score)) =
            (&self.score_warp, self.score_block_state(block_states)?)
        {
            let existing = &score.beta;
            if existing.len() != beta.len() {
                return Err(format!(
                    "score-warp post-update beta length mismatch: current={}, proposed={}",
                    existing.len(),
                    beta.len()
                ));
            }
            return project_monotone_feasible_beta(runtime, existing, &beta, "score_warp_dev");
        }
    }
    if self.link_block_index() == Some(block_idx) {
        if let (Some(runtime), Some(link)) =
            (&self.link_dev, self.link_block_state(block_states)?)
        {
            let existing = &link.beta;
            if existing.len() != beta.len() {
                return Err(format!(
                    "link-deviation post-update beta length mismatch: current={}, proposed={}",
                    existing.len(),
                    beta.len()
                ));
            }
            return project_monotone_feasible_beta(runtime, existing, &beta, "link_dev");
        }
    }
    Ok(beta)
}
}
impl BernoulliMarginalSlopeExactNewtonJointHessianWorkspace {
    /// Constructs the workspace, eagerly materializing both the exact
    /// evaluation cache and the per-row primary Hessian cache.
    fn new(
        family: BernoulliMarginalSlopeFamily,
        block_states: Vec<ParameterBlockState>,
        options: BlockwiseFitOptions,
    ) -> Result<Self, String> {
        let mut cache = family.build_exact_eval_cache(&block_states)?;
        let row_primary = family.build_row_primary_hessian_cache(&block_states, &cache)?;
        cache.row_primary_hessians = row_primary;
        Ok(Self {
            family,
            block_states,
            cache,
            // Counts hessian_matvec invocations for throttled logging.
            matvec_calls: AtomicUsize::new(0),
            options,
        })
    }
}
impl ExactNewtonJointHessianWorkspace for BernoulliMarginalSlopeExactNewtonJointHessianWorkspace {
    /// Dense joint Hessian; declines (returns Ok(None)) once the flattened
    /// coefficient dimension reaches 512 so callers fall back to the
    /// matrix-free operator paths instead.
    fn hessian_dense(&self) -> Result<Option<Array2<f64>>, String> {
        if self.cache.slices.total >= 512 {
            return Ok(None);
        }
        self.family
            .exact_newton_joint_hessian_dense_from_cache(&self.block_states, &self.cache)
            .map(Some)
    }
    /// Joint log likelihood evaluated from the precomputed exact cache.
    fn joint_log_likelihood_evaluation(&self) -> Result<Option<f64>, String> {
        self.family
            .log_likelihood_from_exact_cache(&self.block_states, &self.cache)
            .map(Some)
    }
    /// Joint gradient evaluated from the precomputed exact cache.
    fn joint_gradient_evaluation(
        &self,
    ) -> Result<Option<ExactNewtonJointGradientEvaluation>, String> {
        self.family
            .exact_newton_joint_gradient_evaluation_from_cache(&self.block_states, &self.cache)
            .map(Some)
    }
    /// Hessian-vector product from the cache. Logs timing on the first three
    /// calls and then only at power-of-two call counts to keep the log quiet
    /// under iterative solvers.
    fn hessian_matvec(&self, beta_flat: &Array1<f64>) -> Result<Option<Array1<f64>>, String> {
        let call = self.matvec_calls.fetch_add(1, Ordering::Relaxed) + 1;
        let started = std::time::Instant::now();
        let result = self
            .family
            .exact_newton_joint_hessian_matvec_from_cache(
                beta_flat,
                &self.block_states,
                &self.cache,
            )
            .map(Some);
        if log_exact_work(self.family.y.len()) && (call <= 3 || call.is_power_of_two()) {
            log::info!(
                "[BMS Hessian-Hv] call={} n={} p={} elapsed={:.3}s",
                call,
                self.family.y.len(),
                self.cache.slices.total,
                started.elapsed().as_secs_f64()
            );
        }
        result
    }
    /// Diagonal of the joint Hessian from the cache.
    fn hessian_diagonal(&self) -> Result<Option<Array1<f64>>, String> {
        self.family
            .exact_newton_joint_hessian_diagonal_from_cache(&self.block_states, &self.cache)
            .map(Some)
    }
    /// Operator form of the first directional derivative along `d_beta_flat`.
    fn directional_derivative_operator(
        &self,
        d_beta_flat: &Array1<f64>,
    ) -> Result<Option<Arc<dyn HyperOperator>>, String> {
        self.family
            .exact_newton_joint_hessian_directional_derivative_operator_from_cache_with_options(
                &self.block_states,
                d_beta_flat,
                &self.cache,
                &self.options,
            )
    }
    /// Dense form of the first directional derivative along `d_beta_flat`.
    fn directional_derivative(
        &self,
        d_beta_flat: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        self.family
            .exact_newton_joint_hessian_directional_derivative_from_cache_with_options(
                &self.block_states,
                d_beta_flat,
                &self.cache,
                &self.options,
            )
    }
    /// Operator form of the second directional derivative along the pair
    /// (`d_beta_u_flat`, `d_beta_v_flat`).
    fn second_directional_derivative_operator(
        &self,
        d_beta_u_flat: &Array1<f64>,
        d_beta_v_flat: &Array1<f64>,
    ) -> Result<Option<Arc<dyn HyperOperator>>, String> {
        self.family
            .exact_newton_joint_hessiansecond_directional_derivative_operator_from_cache_with_options(
                &self.block_states,
                d_beta_u_flat,
                d_beta_v_flat,
                &self.cache,
                &self.options,
            )
    }
    /// Dense form of the second directional derivative along the pair
    /// (`d_beta_u_flat`, `d_beta_v_flat`).
    fn second_directional_derivative(
        &self,
        d_beta_u_flat: &Array1<f64>,
        d_beta_v_flat: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        self.family
            .exact_newton_joint_hessiansecond_directional_derivative_from_cache_with_options(
                &self.block_states,
                d_beta_u_flat,
                d_beta_v_flat,
                &self.cache,
                &self.options,
            )
    }
}
impl BernoulliMarginalSlopeExactNewtonJointPsiWorkspace {
    /// Constructs the psi workspace, precomputing the exact evaluation cache
    /// for the supplied block states.
    ///
    /// NOTE(review): unlike the Hessian workspace, `row_primary_hessians` is
    /// not materialized here — presumably the psi paths do not require it;
    /// confirm before relying on that field in this workspace.
    fn new(
        family: BernoulliMarginalSlopeFamily,
        block_states: Vec<ParameterBlockState>,
        specs: Vec<ParameterBlockSpec>,
        derivative_blocks: Vec<Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>>,
        options: BlockwiseFitOptions,
    ) -> Result<Self, String> {
        let cache = family.build_exact_eval_cache(&block_states)?;
        Ok(Self {
            family,
            block_states,
            specs,
            derivative_blocks,
            cache,
            options,
        })
    }
}
impl ExactNewtonJointPsiWorkspace for BernoulliMarginalSlopeExactNewtonJointPsiWorkspace {
    /// First-order psi terms: sigma path for log-sigma coordinates, cached
    /// exact path for spatial coordinates.
    fn first_order_terms(
        &self,
        psi_index: usize,
    ) -> Result<Option<ExactNewtonJointPsiTerms>, String> {
        let sigma_coordinate = self
            .family
            .is_sigma_aux_index(&self.derivative_blocks, psi_index);
        if sigma_coordinate {
            return self.family.sigma_exact_joint_psi_terms_with_options(
                &self.block_states,
                &self.specs,
                &self.options,
            );
        }
        self.family
            .exact_newton_joint_psi_terms_from_cache_with_options(
                &self.block_states,
                &self.derivative_blocks,
                psi_index,
                &self.cache,
                &self.options,
            )
    }
    /// Second-order psi terms: pure log-sigma pairs use the sigma path, mixed
    /// pairs are rejected, pure spatial pairs use the cached exact path.
    fn second_order_terms(
        &self,
        psi_i: usize,
        psi_j: usize,
    ) -> Result<Option<ExactNewtonJointPsiSecondOrderTerms>, String> {
        let i_is_sigma = self
            .family
            .is_sigma_aux_index(&self.derivative_blocks, psi_i);
        let j_is_sigma = self
            .family
            .is_sigma_aux_index(&self.derivative_blocks, psi_j);
        if i_is_sigma && j_is_sigma {
            return self
                .family
                .sigma_exact_joint_psisecond_order_terms_with_options(
                    &self.block_states,
                    &self.options,
                );
        }
        if i_is_sigma || j_is_sigma {
            // Exactly one coordinate is log-sigma: unsupported mixed derivative.
            return Err(
                "bernoulli marginal-slope mixed log-sigma/spatial psi second derivatives require cross auxiliary terms; only pure log-sigma second derivatives are supported"
                    .to_string(),
            );
        }
        self.family
            .exact_newton_joint_psisecond_order_terms_from_cache_with_options(
                &self.block_states,
                &self.derivative_blocks,
                psi_i,
                psi_j,
                &self.cache,
                &self.options,
            )
    }
    /// Directional derivative of the psi-Hessian: dense result for log-sigma
    /// coordinates, operator result for spatial coordinates.
    fn hessian_directional_derivative(
        &self,
        psi_index: usize,
        d_beta_flat: &Array1<f64>,
    ) -> Result<Option<crate::solver::estimate::reml::unified::DriftDerivResult>, String> {
        use crate::solver::estimate::reml::unified::DriftDerivResult;
        if self
            .family
            .is_sigma_aux_index(&self.derivative_blocks, psi_index)
        {
            let dense = self
                .family
                .sigma_exact_joint_psihessian_directional_derivative_with_options(
                    &self.block_states,
                    d_beta_flat,
                    &self.options,
                )?;
            return Ok(dense.map(DriftDerivResult::Dense));
        }
        let operator = self
            .family
            .exact_newton_joint_psihessian_directional_derivative_operator_from_cache_with_options(
                &self.block_states,
                &self.derivative_blocks,
                psi_index,
                d_beta_flat,
                &self.cache,
                &self.options,
            )?;
        Ok(operator.map(DriftDerivResult::Operator))
    }
}
/// Assembles a `ParameterBlockSpec` for one surface block, folding the scalar
/// baseline into the per-row offset.
fn build_blockspec(
    name: &str,
    design: &TermCollectionDesign,
    baseline: f64,
    offset: &Array1<f64>,
    rho: Array1<f64>,
    beta_hint: Option<Array1<f64>>,
) -> ParameterBlockSpec {
    // Shift every offset entry by the pooled baseline for this surface.
    let shifted_offset = offset + baseline;
    ParameterBlockSpec {
        name: name.to_owned(),
        design: design.design.clone(),
        offset: shifted_offset,
        penalties: design.penalties_as_penalty_matrix(),
        nullspace_dims: design.nullspace_dims.clone(),
        initial_log_lambdas: rho,
        initial_beta: beta_hint,
    }
}
/// Builds a deviation auxiliary block spec from a prepared deviation block,
/// installing the given log-lambdas and a monotone-feasible initial beta
/// (the supplied hint, or zeros, projected from the zero anchor).
pub(crate) fn build_deviation_aux_blockspec(
    name: &str,
    prepared: &DeviationPrepared,
    rho: Array1<f64>,
    beta_hint: Option<Array1<f64>>,
) -> Result<ParameterBlockSpec, String> {
    let mut block = prepared.block.clone();
    block.initial_log_lambdas = Some(rho);
    // Seed beta: caller hint when present, otherwise the zero vector.
    let seed = match beta_hint {
        Some(beta) => beta,
        None => Array1::<f64>::zeros(block.design.ncols()),
    };
    let anchor = Array1::<f64>::zeros(seed.len());
    block.initial_beta = Some(project_monotone_feasible_beta(
        &prepared.runtime,
        &anchor,
        &seed,
        name,
    )?);
    block.intospec(name)
}
/// Appends auxiliary deviation block specs (score-warp, then link deviation)
/// to `blocks`, consuming their log-lambda segments from `rho` via `cursor`.
///
/// `cursor` is advanced past each consumed segment so the caller's view of
/// how much of `rho` has been consumed stays correct. (Previously the
/// link-deviation branch sliced at `*cursor` without advancing it, leaving
/// the cursor stale after this call and making the two branches inconsistent.)
pub(crate) fn push_deviation_aux_blockspecs(
    blocks: &mut Vec<ParameterBlockSpec>,
    rho: &Array1<f64>,
    cursor: &mut usize,
    score_warp_prepared: Option<&DeviationPrepared>,
    link_dev_prepared: Option<&DeviationPrepared>,
    score_warp_beta_hint: Option<Array1<f64>>,
    link_dev_beta_hint: Option<Array1<f64>>,
) -> Result<(), String> {
    if let Some(prepared) = score_warp_prepared {
        let len = prepared.block.penalties.len();
        let rho_h = rho.slice(s![*cursor..*cursor + len]).to_owned();
        *cursor += len;
        blocks.push(build_deviation_aux_blockspec(
            "score_warp_dev",
            prepared,
            rho_h,
            score_warp_beta_hint,
        )?);
    }
    if let Some(prepared) = link_dev_prepared {
        let len = prepared.block.penalties.len();
        let rho_w = rho.slice(s![*cursor..*cursor + len]).to_owned();
        // Advance here too so callers (and any future segments appended after
        // this one) see a consistent cursor.
        *cursor += len;
        blocks.push(build_deviation_aux_blockspec(
            "link_dev",
            prepared,
            rho_w,
            link_dev_beta_hint,
        )?);
    }
    Ok(())
}
/// Runs the inner custom-family block-wise fit, normalizing errors to `String`.
fn inner_fit(
    family: &BernoulliMarginalSlopeFamily,
    blocks: &[ParameterBlockSpec],
    options: &BlockwiseFitOptions,
) -> Result<UnifiedFitResult, String> {
    match fit_custom_family(family, blocks, options) {
        Ok(fit) => Ok(fit),
        Err(err) => Err(err.to_string()),
    }
}
/// Fits the Bernoulli marginal-slope model end to end: standardizes the
/// latent score, seeds pooled baselines, builds the marginal/log-slope
/// surface blocks plus optional score-warp and link-deviation auxiliary
/// blocks, and drives the exact joint [rho, psi] outer optimization
/// (including spatial length scales and, when applicable, a learnable
/// Gaussian frailty sigma).
///
/// Returns the fitted result together with resolved specs/designs, baseline
/// values, the latent-z normalization, and the deviation runtimes.
pub fn fit_bernoulli_marginal_slope_terms(
    data: ArrayView2<'_, f64>,
    spec: BernoulliMarginalSlopeTermSpec,
    options: &BlockwiseFitOptions,
    kappa_options: &SpatialLengthScaleOptimizationOptions,
    policy: &crate::resource::ResourcePolicy,
) -> Result<BernoulliMarginalSlopeFitResult, String> {
    let mut spec = spec;
    let data_view = data;
    validate_spec(data_view, &spec)?;
    // Standardize the latent z so downstream bases see a stable scale; the
    // normalization is returned so predictions can reproduce it.
    let (z_standardized, z_normalization) = standardize_latent_z_with_policy(
        &spec.z,
        &spec.weights,
        "bernoulli-marginal-slope",
        &spec.latent_z_policy,
    )?;
    spec.z = z_standardized;
    let pilot_baseline = pooled_probit_baseline(&spec.y, &spec.z, &spec.weights)?;
    // Sigma is learnable only for GaussianShift frailty with no fixed value.
    let sigma_learnable = matches!(
        &spec.frailty,
        FrailtySpec::GaussianShift { sigma_fixed: None }
    );
    let initial_sigma = match &spec.frailty {
        FrailtySpec::GaussianShift {
            sigma_fixed: Some(s),
        } => Some(*s),
        // Learnable sigma starts from 0.5.
        FrailtySpec::GaussianShift { sigma_fixed: None } => Some(0.5),
        FrailtySpec::None => None,
        FrailtySpec::HazardMultiplier { .. } => {
            unreachable!("validate_spec rejects unsupported marginal-slope frailty")
        }
    };
    let probit_scale = probit_frailty_scale(initial_sigma);
    // Pooled pilot (intercept, slope) baselines mapped onto the model's link
    // scale; the slope is deflated by the probit frailty scale.
    let baseline = (
        bernoulli_marginal_slope_eta_from_probability(
            &spec.base_link,
            normal_cdf(pilot_baseline.0),
            "bernoulli marginal-slope baseline link inversion",
        )?,
        pilot_baseline.1 / probit_scale,
    );
    let (mut joint_designs, mut joint_specs) = build_term_collection_designs_and_freeze_joint(
        data_view,
        &[spec.marginalspec.clone(), spec.logslopespec.clone()],
    )
    .map_err(|e| e.to_string())?;
    // Order matters: marginal first, then log-slope.
    let marginal_design = joint_designs.remove(0);
    let logslope_design = joint_designs.remove(0);
    let marginalspec_boot = joint_specs.remove(0);
    let logslopespec_boot = joint_specs.remove(0);
    let latent_measure = build_latent_measure_with_geometry(
        &spec.z,
        &spec.weights,
        &spec.latent_z_policy,
        Some(data_view),
        &[&marginalspec_boot, &logslopespec_boot],
    )?;
    // Learnable sigma derivatives are only derived under the standard-normal
    // latent measure.
    if latent_measure.is_empirical() && sigma_learnable {
        return Err(
            "empirical latent-measure marginal-slope calibration requires fixed GaussianShift sigma; learnable sigma derivatives must be fit under the standard-normal latent measure"
                .to_string(),
        );
    }
    // Shared across every family instance built during the outer search.
    let y = Arc::new(spec.y.clone());
    let weights = Arc::new(spec.weights.clone());
    let z = Arc::new(spec.z.clone());
    // Optional score-warp deviation block; the empirical latent measure uses
    // a weighted-anchor construction.
    let score_warp_prepared = spec
        .score_warp
        .as_ref()
        .map(|cfg| {
            if matches!(latent_measure, LatentMeasureKind::StandardNormal) {
                build_score_warp_deviation_block_from_seed(&spec.z, cfg)
            } else {
                build_score_warp_deviation_block_from_seed_empirical_anchor(
                    &spec.z,
                    &spec.weights,
                    cfg,
                )
            }
        })
        .transpose()?;
    // Optional link-deviation block, seeded from the rigid observed eta at
    // the pilot baselines.
    let link_dev_prepared = spec
        .link_dev
        .as_ref()
        .map(|cfg| {
            let q0_seed = Array1::from_iter((0..spec.z.len()).map(|row| {
                let a0 = bernoulli_marginal_link_map(
                    &spec.base_link,
                    baseline.0 + spec.marginal_offset[row],
                )
                .expect("validated bernoulli marginal base link should produce finite pilot q")
                .q;
                let b0 = baseline.1 + spec.logslope_offset[row];
                rigid_observed_eta(a0, b0, spec.z[row], probit_scale)
            }));
            let link_dev_seed = padded_deviation_seed(&q0_seed, 1.0, 0.5);
            build_link_deviation_block_from_knots_design_seed_and_weights(
                &link_dev_seed,
                &q0_seed,
                &spec.weights,
                cfg,
            )
        })
        .transpose()?;
    // Zero-initialized log-lambdas for the auxiliary deviation blocks.
    let extra_rho0 = {
        let mut out = Vec::new();
        if let Some(ref prepared) = score_warp_prepared {
            out.extend(std::iter::repeat(0.0).take(prepared.block.penalties.len()));
        }
        if let Some(ref prepared) = link_dev_prepared {
            out.extend(std::iter::repeat(0.0).take(prepared.block.penalties.len()));
        }
        out
    };
    let setup = joint_setup(
        data_view,
        &marginalspec_boot,
        &logslopespec_boot,
        marginal_design.penalties.len(),
        logslope_design.penalties.len(),
        &extra_rho0,
        kappa_options,
    );
    // When sigma is learnable, append log(sigma) as an auxiliary coordinate
    // with bounds [log(0.01), log(5.0)].
    let setup = if sigma_learnable {
        setup.with_auxiliary(
            Array1::from_vec(vec![initial_sigma.expect("learnable sigma seed").ln()]),
            Array1::from_vec(vec![0.01_f64.ln()]),
            Array1::from_vec(vec![5.0_f64.ln()]),
        )
    } else {
        setup
    };
    // Interior mutability for state threaded through the outer-loop closures.
    let final_sigma_cell = std::cell::Cell::new(initial_sigma);
    let exact_warm_start = RefCell::new(None::<CustomFamilyWarmStart>);
    let hints = RefCell::new(ThetaHints::default());
    let score_warp_runtime = score_warp_prepared.as_ref().map(|p| p.runtime.clone());
    let link_dev_runtime = link_dev_prepared.as_ref().map(|p| p.runtime.clone());
    // Builds the block specs for a given rho, reusing the latest beta hints.
    let build_blocks = |rho: &Array1<f64>,
                        marginal_design: &TermCollectionDesign,
                        logslope_design: &TermCollectionDesign|
     -> Result<Vec<ParameterBlockSpec>, String> {
        let hints = hints.borrow();
        let mut cursor = 0usize;
        let rho_marginal = rho
            .slice(s![cursor..cursor + marginal_design.penalties.len()])
            .to_owned();
        cursor += marginal_design.penalties.len();
        let rho_logslope = rho
            .slice(s![cursor..cursor + logslope_design.penalties.len()])
            .to_owned();
        cursor += logslope_design.penalties.len();
        let mut blocks = vec![
            build_blockspec(
                "marginal_surface",
                marginal_design,
                baseline.0,
                &spec.marginal_offset,
                rho_marginal,
                hints.marginal_beta.clone(),
            ),
            build_blockspec(
                "logslope_surface",
                logslope_design,
                baseline.1,
                &spec.logslope_offset,
                rho_logslope,
                hints.logslope_beta.clone(),
            ),
        ];
        push_deviation_aux_blockspecs(
            &mut blocks,
            rho,
            &mut cursor,
            score_warp_prepared.as_ref(),
            link_dev_prepared.as_ref(),
            hints.score_warp_beta.clone(),
            hints.link_dev_beta.clone(),
        )?;
        Ok(blocks)
    };
    // Caches shared across all family instances built below.
    let intercept_warm_starts = new_intercept_warm_start_cache(y.len());
    let cell_moment_lru = new_cell_moment_lru_cache(policy);
    let cell_moment_cache_stats = new_cell_moment_cache_stats();
    let make_family = |marginal_design: &TermCollectionDesign,
                       logslope_design: &TermCollectionDesign,
                       sigma: Option<f64>|
     -> BernoulliMarginalSlopeFamily {
        BernoulliMarginalSlopeFamily {
            y: Arc::clone(&y),
            weights: Arc::clone(&weights),
            z: Arc::clone(&z),
            latent_measure: latent_measure.clone(),
            gaussian_frailty_sd: sigma,
            base_link: spec.base_link.clone(),
            marginal_design: marginal_design.design.clone(),
            logslope_design: logslope_design.design.clone(),
            score_warp: score_warp_runtime.clone(),
            link_dev: link_dev_runtime.clone(),
            policy: policy.clone(),
            cell_moment_lru: Arc::clone(&cell_moment_lru),
            cell_moment_cache_stats: Arc::clone(&cell_moment_cache_stats),
            intercept_warm_starts: Some(Arc::clone(&intercept_warm_starts)),
        }
    };
    let marginal_terms = spatial_length_scale_term_indices(&marginalspec_boot);
    let logslope_terms = spatial_length_scale_term_indices(&logslopespec_boot);
    let marginal_has_spatial = !marginal_terms.is_empty();
    let logslope_has_spatial = !logslope_terms.is_empty();
    // Analytic joint derivatives are available when either surface carries
    // spatial terms, or when there are no length-scale coordinates at all.
    let analytic_joint_derivatives_available =
        marginal_has_spatial || logslope_has_spatial || setup.log_kappa_dim() == 0;
    if setup.log_kappa_dim() > 0 && !analytic_joint_derivatives_available {
        return Err(
            "exact bernoulli marginal-slope spatial optimization requires analytic joint psi derivatives"
                .to_string(),
        );
    }
    let initial_rho = setup.theta0().slice(s![..setup.rho_dim()]).to_owned();
    let initial_blocks = build_blocks(&initial_rho, &marginal_design, &logslope_design)?;
    let initial_family = make_family(&marginal_design, &logslope_design, initial_sigma);
    // Probe which outer derivatives the custom family can supply analytically.
    let (joint_gradient, joint_hessian) =
        custom_family_outer_derivatives(&initial_family, &initial_blocks, options);
    let analytic_joint_gradient_available = analytic_joint_derivatives_available
        && matches!(
            joint_gradient,
            crate::solver::outer_strategy::Derivative::Analytic
        );
    let analytic_joint_hessian_available =
        analytic_joint_derivatives_available && joint_hessian.is_analytic();
    let kappa_options_ref: &SpatialLengthScaleOptimizationOptions = kappa_options;
    // Log-sigma (when learnable) sits after the rho and log-kappa coordinates.
    let sigma_from_theta = |theta: &Array1<f64>| -> Option<f64> {
        if sigma_learnable {
            Some(theta[setup.rho_dim() + setup.log_kappa_dim()].exp())
        } else {
            initial_sigma
        }
    };
    // Memoizes the derivative blocks for the most recent theta.
    let derivative_block_cache = RefCell::new(
        None::<(
            Array1<f64>,
            Arc<Vec<Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>>>,
        )>,
    );
    // Relative-tolerance equality used as the cache key comparison.
    let theta_matches = |left: &Array1<f64>, right: &Array1<f64>| -> bool {
        left.len() == right.len()
            && left
                .iter()
                .zip(right.iter())
                .all(|(lhs, rhs)| (*lhs - *rhs).abs() <= 1e-12 * (1.0 + lhs.abs().max(rhs.abs())))
    };
    let get_derivative_blocks = |theta: &Array1<f64>,
                                 specs: &[TermCollectionSpec],
                                 designs: &[TermCollectionDesign]|
     -> Result<
        Arc<Vec<Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>>>,
        String,
    > {
        if let Some((cached_theta, cached_blocks)) = derivative_block_cache.borrow().as_ref()
            && theta_matches(cached_theta, theta)
        {
            return Ok(Arc::clone(cached_blocks));
        }
        // Build fresh derivative blocks via an immediately-invoked closure.
        let built = |specs: &[TermCollectionSpec],
                     designs: &[TermCollectionDesign]|
         -> Result<
            Vec<Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>>,
            String,
        > {
            let marginal_psi_derivs = if marginal_has_spatial {
                build_block_spatial_psi_derivatives(data_view, &specs[0], &designs[0])?.ok_or_else(
                    || {
                        "bernoulli marginal-slope: marginal block has spatial terms \
                         but spatial psi derivatives are unavailable"
                            .to_string()
                    },
                )?
            } else {
                Vec::new()
            };
            let logslope_psi_derivs = if logslope_has_spatial {
                build_block_spatial_psi_derivatives(data_view, &specs[1], &designs[1])?.ok_or_else(
                    || {
                        "bernoulli marginal-slope: logslope block has spatial terms \
                         but spatial psi derivatives are unavailable"
                            .to_string()
                    },
                )?
            } else {
                Vec::new()
            };
            // One entry per parameter block, in block order; the auxiliary
            // deviation blocks carry no spatial psi derivatives.
            let mut derivative_blocks = vec![marginal_psi_derivs, logslope_psi_derivs];
            if score_warp_runtime.is_some() {
                derivative_blocks.push(Vec::new());
            }
            if link_dev_runtime.is_some() {
                derivative_blocks.push(Vec::new());
            }
            // Learnable sigma is flagged with an empty sentinel derivative
            // appended to the last block's list.
            if sigma_learnable {
                derivative_blocks
                    .last_mut()
                    .expect("bernoulli derivative block list is non-empty")
                    .push(crate::custom_family::CustomFamilyBlockPsiDerivative::new(
                        None,
                        Array2::zeros((0, 0)),
                        Array2::zeros((0, 0)),
                        None,
                        None,
                        None,
                        None,
                    ));
            }
            Ok(derivative_blocks)
        }(specs, designs)?;
        let built = Arc::new(built);
        derivative_block_cache.replace(Some((theta.clone(), Arc::clone(&built))));
        Ok(built)
    };
    // Outer optimization: the three closures are (inner fit), (joint
    // value/gradient/Hessian evaluation), and (EFS evaluation).
    let solved = optimize_spatial_length_scale_exact_joint(
        data_view,
        &[marginalspec_boot.clone(), logslopespec_boot.clone()],
        &[marginal_terms.clone(), logslope_terms.clone()],
        kappa_options_ref,
        &setup,
        crate::seeding::SeedRiskProfile::GeneralizedLinear,
        analytic_joint_gradient_available,
        analytic_joint_hessian_available,
        true,
        |theta, _: &[TermCollectionSpec], designs: &[TermCollectionDesign]| {
            let rho = theta.slice(s![..setup.rho_dim()]).to_owned();
            let blocks = build_blocks(&rho, &designs[0], &designs[1])?;
            let sigma = sigma_from_theta(theta);
            final_sigma_cell.set(sigma);
            let family = make_family(&designs[0], &designs[1], sigma);
            let fit = inner_fit(&family, &blocks, options)?;
            // Record the fitted betas as warm-start hints for the next theta;
            // block order is marginal, logslope, [score_warp], [link_dev].
            let mut hints_mut = hints.borrow_mut();
            let mut bidx = 0usize;
            if let Some(block) = fit.block_states.get(bidx) {
                hints_mut.marginal_beta = Some(block.beta.clone());
            }
            bidx += 1;
            if let Some(block) = fit.block_states.get(bidx) {
                hints_mut.logslope_beta = Some(block.beta.clone());
            }
            bidx += 1;
            if score_warp_prepared.is_some() {
                if let Some(block) = fit.block_states.get(bidx) {
                    hints_mut.score_warp_beta = Some(block.beta.clone());
                }
                bidx += 1;
            }
            if link_dev_prepared.is_some() {
                if let Some(block) = fit.block_states.get(bidx) {
                    hints_mut.link_dev_beta = Some(block.beta.clone());
                }
            }
            Ok(fit)
        },
        |theta, specs: &[TermCollectionSpec], designs: &[TermCollectionDesign], eval_mode| {
            use crate::solver::estimate::reml::unified::EvalMode;
            let rho = theta.slice(s![..setup.rho_dim()]).to_owned();
            let blocks = build_blocks(&rho, &designs[0], &designs[1])?;
            let sigma = sigma_from_theta(theta);
            final_sigma_cell.set(sigma);
            let family = make_family(&designs[0], &designs[1], sigma);
            let derivative_blocks = get_derivative_blocks(theta, specs, designs)?;
            // Downgrade Hessian requests when no analytic Hessian exists.
            let effective_mode = match eval_mode {
                EvalMode::ValueGradientHessian if !analytic_joint_hessian_available => {
                    EvalMode::ValueAndGradient
                }
                other => other,
            };
            let eval = evaluate_custom_family_joint_hyper_shared(
                &family,
                &blocks,
                options,
                &rho,
                derivative_blocks,
                exact_warm_start.borrow().as_ref(),
                effective_mode,
            )?;
            exact_warm_start.replace(Some(eval.warm_start.clone()));
            if !eval.inner_converged {
                return Err(
                    "exact bernoulli marginal-slope inner solve did not converge".to_string(),
                );
            }
            // An analytic Hessian was promised; a non-analytic result here is
            // a contract violation.
            if matches!(eval_mode, EvalMode::ValueGradientHessian)
                && analytic_joint_hessian_available
                && !eval.outer_hessian.is_analytic()
            {
                return Err(
                    "exact bernoulli marginal-slope joint [rho, psi] objective did not return an outer Hessian"
                        .to_string(),
                );
            }
            Ok((eval.objective, eval.gradient, eval.outer_hessian))
        },
        |theta, specs: &[TermCollectionSpec], designs: &[TermCollectionDesign]| {
            let rho = theta.slice(s![..setup.rho_dim()]).to_owned();
            let blocks = build_blocks(&rho, &designs[0], &designs[1])?;
            let sigma = sigma_from_theta(theta);
            final_sigma_cell.set(sigma);
            let family = make_family(&designs[0], &designs[1], sigma);
            let derivative_blocks = get_derivative_blocks(theta, specs, designs)?;
            let eval = evaluate_custom_family_joint_hyper_efs_shared(
                &family,
                &blocks,
                options,
                &rho,
                derivative_blocks,
                exact_warm_start.borrow().as_ref(),
            )?;
            exact_warm_start.replace(Some(eval.warm_start.clone()));
            if !eval.inner_converged {
                return Err(
                    "exact bernoulli marginal-slope EFS inner solve did not converge".to_string(),
                );
            }
            Ok(eval.efs_eval)
        },
    )?;
    // Unpack in block order: marginal first, then log-slope.
    let mut resolved_specs = solved.resolved_specs;
    let mut designs = solved.designs;
    Ok(BernoulliMarginalSlopeFitResult {
        fit: solved.fit,
        marginalspec_resolved: resolved_specs.remove(0),
        logslopespec_resolved: resolved_specs.remove(0),
        marginal_design: designs.remove(0),
        logslope_design: designs.remove(0),
        baseline_marginal: baseline.0,
        baseline_logslope: baseline.1,
        z_normalization,
        latent_measure,
        score_warp_runtime,
        link_dev_runtime,
        gaussian_frailty_sd: final_sigma_cell.get(),
    })
}
#[cfg(test)]
mod tests {
use super::*;
use crate::custom_family::{CustomFamily, ExactOuterDerivativeOrder};
use crate::families::bernoulli_marginal_slope::exact_kernel::{
DenestedCubicCell as ExactDenestedCubicCell, ExactCellBranch as ExactCellBranchShared,
LocalSpanCubic, branch_cell as branch_exact_cell, build_denested_partition_cells,
denested_cell_coefficient_partials as exact_denested_cell_coefficient_partials,
global_cubic_from_local as exact_global_cubic_from_local,
transformed_link_cubic as exact_transformed_link_cubic,
};
use ndarray::array;
// Probit inverse link used throughout the marginal-slope tests.
#[inline]
fn bernoulli_marginal_slope_probit_link() -> InverseLink {
    InverseLink::Standard(LinkFunction::Probit)
}
// A term collection with no linear, random-effect, or smooth terms.
fn empty_termspec() -> TermCollectionSpec {
    TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![],
    }
}
// A minimal block spec with an all-zero dense design of shape
// (n_rows, p), no penalties, and a zero initial beta.
fn dummy_blockspec(p: usize, n_rows: usize) -> ParameterBlockSpec {
    ParameterBlockSpec {
        name: "dummy".to_string(),
        design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::<f64>::zeros((n_rows, p)),
        )),
        offset: Array1::zeros(n_rows),
        penalties: vec![],
        nullspace_dims: vec![],
        initial_log_lambdas: Array1::zeros(0),
        initial_beta: Some(Array1::zeros(p)),
    }
}
// A block state with the given beta and a zero eta of length n_rows.
fn dummy_block_state(beta: Array1<f64>, n_rows: usize) -> ParameterBlockState {
    ParameterBlockState {
        beta,
        eta: Array1::zeros(n_rows),
    }
}
// Builds a deterministic FLEX Hessian-matvec fixture of size `n`:
// a family with both score-warp and link-deviation blocks active, the
// matching block states, a fully populated exact evaluation cache (including
// row primary Hessians), and a smooth pseudo-random direction vector sized to
// the flattened coefficient dimension.
fn flex_hessian_matvec_fixture(
    n: usize,
) -> Result<
    (
        BernoulliMarginalSlopeFamily,
        Vec<ParameterBlockState>,
        BernoulliMarginalSlopeExactEvalCache,
        Array1<f64>,
    ),
    String,
> {
    // Deterministic oscillatory latent score in place of random draws.
    let z = Array1::from_iter((0..n).map(|i| {
        let t = (i as f64 + 0.5) / n as f64;
        (10.0 * t).sin() + 0.3 * (31.0 * t).cos()
    }));
    // Deterministic 0/1 outcomes with roughly 43% positives.
    let y = Array1::from_iter((0..n).map(|i| if (i * 37 + 11) % 101 < 43 { 1.0 } else { 0.0 }));
    // Weights cycle in [0.75, 1.25].
    let weights = Array1::from_iter((0..n).map(|i| 0.75 + 0.5 * ((i % 7) as f64) / 6.0));
    // Two-column design: intercept and z.
    let design = Array2::from_shape_fn((n, 2), |(row, col)| match col {
        0 => 1.0,
        1 => z[row],
        _ => unreachable!(),
    });
    let cfg = DeviationBlockConfig {
        num_internal_knots: 4,
        ..DeviationBlockConfig::default()
    };
    let score_prepared = build_score_warp_deviation_block_from_seed(&z, &cfg)?;
    let q_seed = Array1::from_iter(z.iter().map(|zi| 0.05 + 0.2 * zi));
    let link_seed = padded_deviation_seed(&q_seed, 1.0, 0.5);
    let link_prepared = build_link_deviation_block_from_knots_design_seed_and_weights(
        &link_seed, &q_seed, &weights, &cfg,
    )?;
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(y),
        weights: Arc::new(weights),
        z: Arc::new(z.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            design.clone(),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(design)),
        score_warp: Some(score_prepared.runtime.clone()),
        link_dev: Some(link_prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    // Small nonzero surface betas with matching etas; deviation blocks start
    // at zero beta/eta.
    let marginal_beta = array![0.05, 0.08];
    let logslope_beta = array![-0.15, 0.04];
    let marginal_eta =
        Array1::from_iter(z.iter().map(|zi| marginal_beta[0] + marginal_beta[1] * zi));
    let logslope_eta =
        Array1::from_iter(z.iter().map(|zi| logslope_beta[0] + logslope_beta[1] * zi));
    let states = vec![
        ParameterBlockState {
            beta: marginal_beta,
            eta: marginal_eta,
        },
        ParameterBlockState {
            beta: logslope_beta,
            eta: logslope_eta,
        },
        ParameterBlockState {
            beta: Array1::zeros(score_prepared.block.design.ncols()),
            eta: Array1::zeros(n),
        },
        ParameterBlockState {
            beta: Array1::zeros(link_prepared.block.design.ncols()),
            eta: Array1::zeros(n),
        },
    ];
    let mut cache = family.build_exact_eval_cache(&states)?;
    cache.row_primary_hessians = family.build_row_primary_hessian_cache(&states, &cache)?;
    // Smooth deterministic direction over the flattened coefficient space.
    let direction = Array1::from_iter((0..cache.slices.total).map(|j| {
        let x = j as f64 + 1.0;
        0.03 * x.sin() + 0.01 * (0.37 * x).cos()
    }));
    Ok((family, states, cache, direction))
}
// Asserts element-wise closeness of two vectors under a relative tolerance;
// the denominator is floored at 1.0, so the check is absolute for small
// magnitudes and relative for large ones.
fn assert_allclose_relative(actual: &Array1<f64>, expected: &Array1<f64>, tol: f64) {
    assert_eq!(actual.len(), expected.len());
    for (idx, (&a, &e)) in actual.iter().zip(expected.iter()).enumerate() {
        let scale = 1.0_f64.max(a.abs()).max(e.abs());
        let rel = (a - e).abs() / scale;
        assert!(
            rel <= tol,
            "entry {idx}: actual={a:.17e}, expected={e:.17e}, rel={rel:.3e}, tol={tol:.3e}"
        );
    }
}
// Verifies the parallel chunked Hessian-matvec agrees with the serial
// reference implementation to near machine precision on the FLEX fixture.
#[test]
fn flex_hessian_matvec_parallel_chunks_match_serial_reference() {
    let (family, states, cache, direction) =
        flex_hessian_matvec_fixture(96).expect("flex Hv fixture");
    let serial = family
        .exact_newton_joint_hessian_matvec_from_cache_serial_reference(
            &direction, &states, &cache,
        )
        .expect("serial reference Hv");
    let parallel = family
        .exact_newton_joint_hessian_matvec_from_cache(&direction, &states, &cache)
        .expect("parallel chunked Hv");
    assert_allclose_relative(&parallel, &serial, 1.0e-13);
}
// Builds the fixture once with per-row cell-moment dedup disabled and once
// with it enabled, and checks the resulting row primary Hessians agree
// entry-by-entry within 1e-14, and are not trivially all zero.
#[test]
fn cell_moment_per_row_dedup_matches_undeduped_row_primary_hessian() {
    // NOTE: this toggles process-global state; the trailing set(true) below
    // presumably restores the default for other tests — confirm the default.
    super::set_cell_moment_per_row_dedup_enabled(false);
    let (_family_off, _states_off, cache_off, _dir_off) =
        flex_hessian_matvec_fixture(64).expect("flex fixture (dedup off)");
    let off_rows = cache_off
        .row_primary_hessians
        .clone()
        .expect("row primary hessians (dedup off)");
    super::set_cell_moment_per_row_dedup_enabled(true);
    let (_family_on, _states_on, cache_on, _dir_on) =
        flex_hessian_matvec_fixture(64).expect("flex fixture (dedup on)");
    let on_rows = cache_on
        .row_primary_hessians
        .clone()
        .expect("row primary hessians (dedup on)");
    super::set_cell_moment_per_row_dedup_enabled(true);
    assert_eq!(off_rows.shape(), on_rows.shape());
    let mut max_abs = 0.0_f64;
    for ((on, off), idx) in on_rows.iter().zip(off_rows.iter()).zip(0u64..) {
        let diff = (on - off).abs();
        if diff > max_abs {
            max_abs = diff;
        }
        assert!(
            diff <= 1.0e-14,
            "dedup mismatch at flat idx {idx}: on={on:.17e} off={off:.17e} diff={diff:.3e}"
        );
    }
    // Guard against a vacuous pass where both caches are identically zero.
    let any_nonzero = on_rows.iter().any(|v| v.abs() > 0.0);
    assert!(any_nonzero, "row primary hessians should not be all zero");
    let _ = max_abs;
}
// Ignored-by-default local benchmark comparing the serial and parallel
// Hessian-matvec paths at a biobank-sized fixture (n = 4096); also asserts
// they agree before timing.
#[test]
#[ignore = "criterion-style local timing for the biobank-shape FLEX Hv pattern"]
fn bench_flex_hessian_matvec_parallel_chunks_biobank_shape() {
    use criterion::{Criterion, black_box};
    use std::time::Duration;
    let (family, states, cache, direction) =
        flex_hessian_matvec_fixture(4096).expect("biobank-shape FLEX Hv fixture");
    // Sanity check equivalence before benchmarking either path.
    let serial = family
        .exact_newton_joint_hessian_matvec_from_cache_serial_reference(
            &direction, &states, &cache,
        )
        .expect("serial reference Hv");
    let parallel = family
        .exact_newton_joint_hessian_matvec_from_cache(&direction, &states, &cache)
        .expect("parallel chunked Hv");
    assert_allclose_relative(&parallel, &serial, 1.0e-13);
    // Short sampling budget: this is a local smoke benchmark, not CI timing.
    let mut criterion = Criterion::default()
        .sample_size(10)
        .warm_up_time(Duration::from_millis(200))
        .measurement_time(Duration::from_millis(500));
    criterion.bench_function("bernoulli_margslope_flex_hv_serial_reference_4096", |b| {
        b.iter(|| {
            black_box(
                family
                    .exact_newton_joint_hessian_matvec_from_cache_serial_reference(
                        black_box(&direction),
                        black_box(&states),
                        black_box(&cache),
                    )
                    .expect("serial reference Hv"),
            )
        })
    });
    criterion.bench_function("bernoulli_margslope_flex_hv_parallel_chunks_4096", |b| {
        b.iter(|| {
            black_box(
                family
                    .exact_newton_joint_hessian_matvec_from_cache(
                        black_box(&direction),
                        black_box(&states),
                        black_box(&cache),
                    )
                    .expect("parallel chunked Hv"),
            )
        })
    });
    criterion.final_summary();
}
#[test]
fn bernoulli_margslope_warm_start_cache_persists_across_eval_cache_builds() {
    // Verifies that the shared intercept warm-start cache is populated with
    // finite values after the first eval-cache build, survives into a second
    // build, and that warm-started intercepts match a cold (cache-less)
    // reference family to 1e-9.
    let n = 256usize;
    // Deterministic oscillating z signal and a 1-in-3 positive outcome vector.
    let z = Array1::from_iter((0..n).map(|i| {
        let t = (i as f64 + 0.5) / n as f64;
        (12.0 * t).sin() + 0.25 * (37.0 * t).cos()
    }));
    let y = Array1::from_iter((0..n).map(|i| if i % 3 == 0 { 1.0 } else { 0.0 }));
    let weights = Array1::ones(n);
    let cfg = DeviationBlockConfig {
        num_internal_knots: 4,
        ..DeviationBlockConfig::default()
    };
    let score_prepared = build_score_warp_deviation_block_from_seed(&z, &cfg)
        .expect("score-warp deviation block");
    let q_seed = Array1::from_iter(z.iter().map(|zi| 0.1 + 0.25 * zi));
    let link_seed = padded_deviation_seed(&q_seed, 1.0, 0.5);
    let link_prepared = build_link_deviation_block_from_knots_design_seed_and_weights(
        &link_seed, &q_seed, &weights, &cfg,
    )
    .expect("link-wiggle deviation block");
    let cache = new_intercept_warm_start_cache(n);
    // Factory so the warm and cold families differ only in the cache handle.
    let make_family =
        |cache: Option<Arc<BernoulliInterceptWarmStartCache>>| BernoulliMarginalSlopeFamily {
            y: Arc::new(y.clone()),
            weights: Arc::new(weights.clone()),
            z: Arc::new(z.clone()),
            latent_measure: LatentMeasureKind::StandardNormal,
            gaussian_frailty_sd: None,
            base_link: bernoulli_marginal_slope_probit_link(),
            marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
                Array2::ones((n, 1)),
            )),
            logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
                Array2::ones((n, 1)),
            )),
            score_warp: Some(score_prepared.runtime.clone()),
            link_dev: Some(link_prepared.runtime.clone()),
            policy: crate::resource::ResourcePolicy::default_library(),
            cell_moment_lru: new_cell_moment_lru_cache(
                &crate::resource::ResourcePolicy::default_library(),
            ),
            cell_moment_cache_stats: new_cell_moment_cache_stats(),
            intercept_warm_starts: cache,
        };
    let marginal_eta = Array1::from_iter((0..n).map(|i| 0.15 * ((i as f64) * 0.001).sin()));
    let slope_eta = Array1::from_iter((0..n).map(|i| 0.35 + 0.02 * ((i as f64) * 0.003).cos()));
    // Four blocks: marginal, log-slope, score-warp, link-deviation.
    let states = vec![
        ParameterBlockState {
            beta: array![0.0],
            eta: marginal_eta,
        },
        ParameterBlockState {
            beta: array![0.0],
            eta: slope_eta,
        },
        ParameterBlockState {
            beta: Array1::zeros(score_prepared.block.design.ncols()),
            eta: Array1::zeros(n),
        },
        ParameterBlockState {
            beta: Array1::zeros(link_prepared.block.design.ncols()),
            eta: Array1::zeros(n),
        },
    ];
    let warm_family = make_family(Some(Arc::clone(&cache)));
    let first = warm_family
        .build_exact_eval_cache(&states)
        .expect("first warm eval cache");
    // Cache slots are atomic f64 bit-patterns; every slot must now hold a
    // finite (non-NaN) converged intercept.
    let nan_bits = f64::NAN.to_bits();
    for slot in cache.iter() {
        let bits = slot.load(Ordering::Relaxed);
        let v = f64::from_bits(bits);
        assert!(
            v.is_finite(),
            "cache slot should be populated with converged intercept after first build"
        );
        assert_ne!(bits, nan_bits);
    }
    // Second build reuses the now-populated cache as warm starts.
    let second = warm_family
        .build_exact_eval_cache(&states)
        .expect("second warm eval cache");
    let cold_family = make_family(None);
    let cold = cold_family
        .build_exact_eval_cache(&states)
        .expect("cold reference eval cache");
    // Warm-started intercepts (both builds) must converge to the same values
    // as the cold reference row-by-row.
    for ((warm_a, warm_b), cold_ctx) in first
        .row_contexts
        .iter()
        .zip(second.row_contexts.iter())
        .zip(cold.row_contexts.iter())
    {
        assert!((warm_a.intercept - cold_ctx.intercept).abs() < 1e-9);
        assert!((warm_b.intercept - cold_ctx.intercept).abs() < 1e-9);
    }
}
#[test]
fn bernoulli_margslope_flex_ll_early_exit_is_exact_or_provably_rejected() {
    // With FLEX blocks (score-warp + link deviation) active, early exit must
    // be all-or-nothing: a permissive threshold reproduces the exact
    // log-likelihood, and a tight threshold returns an explicit line-search
    // rejection error rather than a silently truncated value.
    let n = 24_000usize;
    let z = Array1::from_iter((0..n).map(|i| {
        let t = (i as f64 + 0.5) / n as f64;
        (10.0 * t).sin() + 0.2 * (29.0 * t).cos()
    }));
    let y = Array1::from_iter((0..n).map(|i| if (i * 17 + 3) % 7 >= 4 { 1.0 } else { 0.0 }));
    let weights = Array1::ones(n);
    let cfg = DeviationBlockConfig {
        num_internal_knots: 4,
        ..DeviationBlockConfig::default()
    };
    let score_prepared = build_score_warp_deviation_block_from_seed(&z, &cfg)
        .expect("score-warp deviation block");
    let q_seed = Array1::from_iter(z.iter().map(|zi| 0.1 + 0.2 * zi));
    let link_seed = padded_deviation_seed(&q_seed, 1.0, 0.5);
    let link_prepared = build_link_deviation_block_from_knots_design_seed_and_weights(
        &link_seed, &q_seed, &weights, &cfg,
    )
    .expect("link-wiggle deviation block");
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(y),
        weights: Arc::new(weights),
        z: Arc::new(z),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::ones((n, 1)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::ones((n, 1)),
        )),
        score_warp: Some(score_prepared.runtime.clone()),
        link_dev: Some(link_prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    // Four blocks: marginal, log-slope, score-warp, link-deviation.
    let states = vec![
        ParameterBlockState {
            beta: array![0.0],
            eta: Array1::from_iter((0..n).map(|i| 0.08 * ((i as f64) * 0.002).sin())),
        },
        ParameterBlockState {
            beta: array![0.0],
            eta: Array1::from_iter((0..n).map(|i| 0.30 + 0.03 * ((i as f64) * 0.003).cos())),
        },
        ParameterBlockState {
            beta: Array1::zeros(score_prepared.block.design.ncols()),
            eta: Array1::zeros(n),
        },
        ParameterBlockState {
            beta: Array1::zeros(link_prepared.block.design.ncols()),
            eta: Array1::zeros(n),
        },
    ];
    // Reference: exact full-sweep log-likelihood with default options.
    let exact = family
        .log_likelihood_only_with_options(&states, &BlockwiseFitOptions::default())
        .expect("exact full FLEX ll");
    // A threshold looser than the true NLL by 1.0 must not change the result.
    let mut permissive = BlockwiseFitOptions::default();
    permissive.early_exit_threshold = Some((-exact) + 1.0);
    let accepted = family
        .log_likelihood_only_with_options(&states, &permissive)
        .expect("permissive threshold should compute full FLEX ll");
    let rel = ((accepted - exact) / exact.abs().max(1.0)).abs();
    assert!(
        rel < 1e-10,
        "accepted early-exit-enabled FLEX LL {accepted} differs from exact {exact} by rel {rel}"
    );
    // A threshold far below the true NLL must trigger an early rejection.
    let mut rejecting = BlockwiseFitOptions::default();
    rejecting.early_exit_threshold = Some(1e-6);
    let err = family
        .log_likelihood_only_with_options(&states, &rejecting)
        .expect_err("tight threshold should reject before the full FLEX row sweep");
    assert!(
        err.contains("line-search rejected early"),
        "unexpected early-exit error: {err}"
    );
}
#[test]
fn row_primary_fourth_contracted_rejects_bad_direction_lengths() {
    // The contracted fourth-derivative recompute must validate direction
    // vector lengths against the primary coefficient dimension (2 here:
    // marginal + log-slope) before doing any indexing.
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![1.0]),
        weights: Arc::new(array![1.0]),
        z: Arc::new(array![0.25]),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        // Zero-column designs: a single row with no coefficients.
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((1, 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((1, 0)),
        )),
        score_warp: None,
        link_dev: None,
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    let block_states = vec![
        ParameterBlockState {
            beta: Array1::zeros(0),
            eta: array![0.15],
        },
        ParameterBlockState {
            beta: Array1::zeros(0),
            eta: array![0.2],
        },
    ];
    let cache = family
        .build_exact_eval_cache(&block_states)
        .expect("exact eval cache");
    let row_ctx = family
        .build_row_exact_context(0, &block_states)
        .expect("row context");
    // First direction has length 1 (wrong); second has the expected length 2.
    let bad_dir = array![1.0];
    let good_dir = array![0.0, 1.0];
    let err = family
        .row_primary_fourth_contracted_recompute_ordered(
            0,
            &block_states,
            &cache,
            &row_ctx,
            &bad_dir,
            &good_dir,
        )
        .expect_err("bad direction length should be rejected before indexing");
    assert!(
        err.contains("direction lengths (1,2) != 2"),
        "unexpected error: {err}"
    );
}
/// Minimal `BernoulliMarginalSlopeTermSpec` fixture: probit base link, empty
/// marginal/log-slope term specs, zero offsets, and no frailty, score-warp,
/// or link-deviation blocks.
fn base_spec(
    y: Array1<f64>,
    weights: Array1<f64>,
    z: Array1<f64>,
) -> BernoulliMarginalSlopeTermSpec {
    let rows = y.len();
    BernoulliMarginalSlopeTermSpec {
        base_link: bernoulli_marginal_slope_probit_link(),
        marginalspec: empty_termspec(),
        logslopespec: empty_termspec(),
        marginal_offset: Array1::zeros(rows),
        logslope_offset: Array1::zeros(rows),
        frailty: FrailtySpec::None,
        score_warp: None,
        link_dev: None,
        latent_z_policy: LatentZPolicy::default(),
        y,
        weights,
        z,
    }
}
#[test]
fn bernoulli_marginal_link_map_zeroes_derivatives_on_clamped_tails() {
    // Deep in either tail the probability is clamped to its epsilon bound, so
    // the mapped mu/q values must hit the clamp exactly and all stored
    // derivatives of mu and q must be zeroed.
    let link = bernoulli_marginal_slope_probit_link();
    let lower = bernoulli_marginal_link_map(&link, -8.0).expect("lower tail map");
    let upper = bernoulli_marginal_link_map(&link, 8.0).expect("upper tail map");
    let lower_q = standard_normal_quantile(BERNOULLI_LINK_PROBABILITY_EPS).unwrap();
    let upper_q = standard_normal_quantile(1.0 - BERNOULLI_LINK_PROBABILITY_EPS).unwrap();
    assert_eq!(lower.mu, BERNOULLI_LINK_PROBABILITY_EPS);
    assert_eq!(upper.mu, 1.0 - BERNOULLI_LINK_PROBABILITY_EPS);
    assert!((lower.q - lower_q).abs() < 1e-12);
    assert!((upper.q - upper_q).abs() < 1e-12);
    // Every tail map must carry identically-zero first..fourth derivatives.
    for tail in [&lower, &upper] {
        assert_eq!([tail.mu1, tail.mu2, tail.mu3, tail.mu4], [0.0; 4]);
        assert_eq!([tail.q1, tail.q2, tail.q3, tail.q4], [0.0; 4]);
    }
}
#[test]
fn rigid_transformed_gradient_matches_negative_log_likelihood_derivative() {
    // Checks the analytic (eta, g) gradient of the rigid probit negative
    // log-likelihood against central finite differences, and a sign sanity
    // condition for y=1.
    let link = bernoulli_marginal_slope_probit_link();
    let eta = 0.25;
    let g = -0.15;
    let z = 0.7;
    let y = 1.0;
    let weight = 1.3;
    let probit_scale = 1.0;
    // Objective: weighted negative log-likelihood as a function of (eta, g).
    let objective = |eta_value: f64, g_value: f64| {
        let marginal = bernoulli_marginal_link_map(&link, eta_value).unwrap();
        let kernel =
            RigidProbitKernel::new(marginal.q, g_value, z, y, weight, probit_scale).unwrap();
        -weight * kernel.logcdf
    };
    let marginal = bernoulli_marginal_link_map(&link, eta).expect("marginal map");
    let kernel =
        RigidProbitKernel::new(marginal.q, g, z, y, weight, probit_scale).expect("kernel");
    let grad = rigid_transformed_gradient(marginal, &kernel);
    // Central differences with step 1e-6; tolerance 1e-7 allows for the
    // O(step^2) truncation plus rounding error.
    let step = 1e-6;
    let finite_eta = (objective(eta + step, g) - objective(eta - step, g)) / (2.0 * step);
    let finite_g = (objective(eta, g + step) - objective(eta, g - step)) / (2.0 * step);
    assert!(
        (grad[0] - finite_eta).abs() < 1e-7,
        "eta gradient {} != finite difference {}",
        grad[0],
        finite_eta
    );
    assert!(
        (grad[1] - finite_g).abs() < 1e-7,
        "g gradient {} != finite difference {}",
        grad[1],
        finite_g
    );
    // For a positive outcome, raising eta raises the success probability,
    // so the NLL must be decreasing in eta.
    assert!(
        grad[0] < 0.0,
        "y=1 probit nll should decrease as eta increases"
    );
}
/// L1 (Manhattan) distance between two `(f64, f64)` pairs.
fn pair_distance(lhs: (f64, f64), rhs: (f64, f64)) -> f64 {
    let d0 = (lhs.0 - rhs.0).abs();
    let d1 = (lhs.1 - rhs.1).abs();
    d0 + d1
}
/// Deterministic rigid (no score-warp, no link-deviation) Bernoulli
/// marginal-slope family with intercept-only designs for both the marginal
/// and log-slope blocks. Outcomes, weights, and the z grid are generated by
/// fixed modular-arithmetic formulas so every run is reproducible.
fn make_rigid_test_family(n: usize) -> BernoulliMarginalSlopeFamily {
    let outcomes: Array1<f64> =
        Array1::from_iter((0..n).map(|i| if (i * 31 + 7) % 5 >= 3 { 1.0 } else { 0.0 }));
    let row_weights: Array1<f64> =
        Array1::from_iter((0..n).map(|i| 0.5 + ((i * 13 + 4) % 7) as f64 * 0.1));
    let z_grid: Array1<f64> = Array1::from_iter(
        (0..n).map(|i| -1.5 + 3.0 * (((i * 17 + 5) % n) as f64 + 0.5) / (n as f64)),
    );
    // Single all-ones column: an intercept-only design for each block.
    let intercept_col = Array2::from_shape_fn((n, 1), |_| 1.0);
    BernoulliMarginalSlopeFamily {
        y: Arc::new(outcomes),
        weights: Arc::new(row_weights),
        z: Arc::new(z_grid),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            intercept_col.clone(),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(intercept_col)),
        score_warp: None,
        link_dev: None,
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    }
}
/// Two constant block states for the rigid fixture: block 0 holds the
/// marginal intercept `q`, block 1 the log-slope intercept `b`; each eta is
/// the intercept broadcast across all rows.
fn rigid_block_states(
    family: &BernoulliMarginalSlopeFamily,
    q: f64,
    b: f64,
) -> Vec<ParameterBlockState> {
    let rows = family.y.len();
    let constant_block = |value: f64| ParameterBlockState {
        beta: array![value],
        eta: Array1::from_elem(rows, value),
    };
    vec![constant_block(q), constant_block(b)]
}
#[test]
fn bernoulli_log_likelihood_subsample_full_equals_unsampled() {
    use crate::families::marginal_slope_shared::OuterScoreSubsample;
    // A subsample mask covering every row must reproduce the unsampled
    // log-likelihood to float precision (the Horvitz-Thompson scale is 1).
    let n = 200usize;
    let family = make_rigid_test_family(n);
    let states = rigid_block_states(&family, 0.3, 0.4);
    let baseline = family
        .log_likelihood_only(&states)
        .expect("baseline ll (no subsample)");
    let mut opts_full = BlockwiseFitOptions::default();
    // Mask = all n rows; the seed is irrelevant when nothing is dropped.
    opts_full.outer_score_subsample = Some(Arc::new(OuterScoreSubsample::new(
        (0..n).collect(),
        n,
        0xDEADBEEF,
    )));
    let with_full_mask = family
        .log_likelihood_only_with_options(&states, &opts_full)
        .expect("ll with mask=full");
    let rel = ((with_full_mask - baseline) / baseline.abs().max(1.0)).abs();
    assert!(
        rel < 1e-12,
        "subsample(mask=full) {} differs from baseline {} by rel {}",
        with_full_mask,
        baseline,
        rel
    );
}
#[test]
fn bernoulli_log_likelihood_subsample_half_scales_correctly() {
    use crate::families::marginal_slope_shared::OuterScoreSubsample;
    // A half (even-row) subsample must equal the raw even-row sum scaled by
    // n/m (Horvitz-Thompson), and that scaled value should land near the
    // full-data baseline.
    let n = 200usize;
    let family = make_rigid_test_family(n);
    let states = rigid_block_states(&family, 0.3, 0.4);
    let even_mask: Vec<usize> = (0..n).filter(|i| i % 2 == 0).collect();
    let m = even_mask.len();
    let mut opts_half = BlockwiseFitOptions::default();
    // Constructor applies the n/m weight scale automatically.
    opts_half.outer_score_subsample = Some(Arc::new(OuterScoreSubsample::new(
        even_mask.clone(),
        n,
        0xCAFE,
    )));
    let scaled = family
        .log_likelihood_only_with_options(&states, &opts_half)
        .expect("ll with mask=even");
    // Manually-built unscaled subsample (weight_scale = 1.0) gives the raw
    // even-row sum as a reference.
    let mut opts_even_unscaled = BlockwiseFitOptions::default();
    opts_even_unscaled.outer_score_subsample = Some(Arc::new(OuterScoreSubsample {
        mask: Arc::new(even_mask),
        n_full: m, weight_scale: 1.0,
        seed: 0,
    }));
    let raw_even_sum = family
        .log_likelihood_only_with_options(&states, &opts_even_unscaled)
        .expect("raw even-row ll sum");
    let expected_scaled = (n as f64 / m as f64) * raw_even_sum;
    let rel = ((scaled - expected_scaled) / expected_scaled.abs().max(1.0)).abs();
    assert!(
        rel < 1e-12,
        "scaled {} != 2*even_sum {} (rel {})",
        scaled,
        expected_scaled,
        rel
    );
    // Coarse estimator sanity: the HT-scaled value should be within 5% of
    // the full-data log-likelihood for this smooth fixture.
    let baseline = family.log_likelihood_only(&states).expect("baseline ll");
    let ht_rel = ((scaled - baseline) / baseline.abs().max(1.0)).abs();
    assert!(
        ht_rel < 0.05,
        "Horvitz-Thompson scaled {} not near baseline {} (rel {})",
        scaled,
        baseline,
        ht_rel
    );
}
fn make_sigma_aware_test_family(n: usize) -> BernoulliMarginalSlopeFamily {
let mut family = make_rigid_test_family(n);
family.gaussian_frailty_sd = Some(0.7);
family
}
/// Maximum elementwise relative difference between `a` and `b`, with each
/// entry's denominator floored at 1.0 so near-zero reference values do not
/// inflate the ratio. Returns 0.0 for empty inputs.
///
/// # Panics
/// Panics if the two arrays have different lengths.
fn rel_diff_array1(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
    // Fail loudly on mismatched lengths instead of panicking mid-index (when
    // `b` is shorter) or silently ignoring trailing entries (when longer).
    assert_eq!(a.len(), b.len(), "rel_diff_array1: length mismatch");
    a.iter()
        .zip(b.iter())
        .map(|(&av, &bv)| (av - bv).abs() / bv.abs().max(1.0))
        .fold(0.0_f64, f64::max)
}
/// Maximum elementwise relative difference between matrices `a` and `b`,
/// with each entry's denominator floored at 1.0. Returns 0.0 for empty
/// inputs.
///
/// # Panics
/// Panics if the two matrices have different shapes.
fn rel_diff_array2(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
    // Guard against shape mismatch: the original indexed `b` by `a`'s
    // indices, which silently ignored any extra rows/columns of `b`.
    assert_eq!(a.dim(), b.dim(), "rel_diff_array2: shape mismatch");
    a.indexed_iter()
        .map(|((i, j), &av)| {
            let bv = b[[i, j]];
            (av - bv).abs() / bv.abs().max(1.0)
        })
        .fold(0.0_f64, f64::max)
}
#[test]
fn bernoulli_sigma_psi_terms_subsample_full_equals_unsampled() {
    use crate::families::marginal_slope_shared::OuterScoreSubsample;
    // A full-coverage subsample mask must reproduce the unsampled sigma psi
    // terms — objective, score, and dense Hessian operator — to ~1e-12.
    let n = 200usize;
    let family = make_sigma_aware_test_family(n);
    let states = rigid_block_states(&family, 0.3, 0.4);
    let specs = vec![dummy_blockspec(1, n), dummy_blockspec(1, n)];
    let baseline = family
        .sigma_exact_joint_psi_terms(&states, &specs)
        .expect("baseline psi terms")
        .expect("baseline some");
    let mut opts_full = BlockwiseFitOptions::default();
    opts_full.outer_score_subsample = Some(Arc::new(OuterScoreSubsample::new(
        (0..n).collect(),
        n,
        0xDEADBEEF,
    )));
    let with_full = family
        .sigma_exact_joint_psi_terms_with_options(&states, &specs, &opts_full)
        .expect("psi terms with full mask")
        .expect("some");
    let obj_rel = ((with_full.objective_psi - baseline.objective_psi)
        / baseline.objective_psi.abs().max(1.0))
    .abs();
    assert!(obj_rel < 1e-12, "objective_psi rel {}", obj_rel);
    let score_rel = rel_diff_array1(&with_full.score_psi, &baseline.score_psi);
    assert!(score_rel < 1e-12, "score_psi rel {}", score_rel);
    // Densify both Hessian operators so every entry is compared.
    let h_full = with_full
        .hessian_psi_operator
        .as_ref()
        .expect("op")
        .to_dense();
    let h_baseline = baseline
        .hessian_psi_operator
        .as_ref()
        .expect("op")
        .to_dense();
    let h_rel = rel_diff_array2(&h_full, &h_baseline);
    assert!(h_rel < 1e-12, "hessian rel {}", h_rel);
}
#[test]
fn bernoulli_sigma_psi_terms_subsample_half_scales_correctly() {
    use crate::families::marginal_slope_shared::OuterScoreSubsample;
    // Half (even-row) subsample of the sigma psi terms must equal the raw
    // even-row quantities scaled by n/m for objective, score, and Hessian.
    let n = 200usize;
    let family = make_sigma_aware_test_family(n);
    let states = rigid_block_states(&family, 0.3, 0.4);
    let specs = vec![dummy_blockspec(1, n), dummy_blockspec(1, n)];
    let even_mask: Vec<usize> = (0..n).filter(|i| i % 2 == 0).collect();
    let m = even_mask.len();
    let mut opts_half = BlockwiseFitOptions::default();
    // Constructor applies the n/m Horvitz-Thompson weight scale.
    opts_half.outer_score_subsample = Some(Arc::new(OuterScoreSubsample::new(
        even_mask.clone(),
        n,
        0xCAFE,
    )));
    let scaled = family
        .sigma_exact_joint_psi_terms_with_options(&states, &specs, &opts_half)
        .expect("scaled")
        .expect("some");
    // Manually-built unscaled subsample (weight_scale = 1.0) as the raw
    // even-row reference.
    let mut opts_raw = BlockwiseFitOptions::default();
    opts_raw.outer_score_subsample = Some(Arc::new(OuterScoreSubsample {
        mask: Arc::new(even_mask),
        n_full: m,
        weight_scale: 1.0,
        seed: 0,
    }));
    let raw = family
        .sigma_exact_joint_psi_terms_with_options(&states, &specs, &opts_raw)
        .expect("raw")
        .expect("some");
    let factor = n as f64 / m as f64;
    let exp_obj = factor * raw.objective_psi;
    let obj_rel = ((scaled.objective_psi - exp_obj) / exp_obj.abs().max(1.0)).abs();
    assert!(obj_rel < 1e-12, "objective_psi rel {}", obj_rel);
    let exp_score = &raw.score_psi * factor;
    let score_rel = rel_diff_array1(&scaled.score_psi, &exp_score);
    assert!(score_rel < 1e-12, "score_psi rel {}", score_rel);
    let h_scaled = scaled.hessian_psi_operator.as_ref().expect("op").to_dense();
    let h_raw = raw.hessian_psi_operator.as_ref().expect("op").to_dense();
    let h_exp = &h_raw * factor;
    let h_rel = rel_diff_array2(&h_scaled, &h_exp);
    assert!(h_rel < 1e-12, "hessian rel {}", h_rel);
}
#[test]
fn bernoulli_sigma_psi_second_order_subsample_full_equals_unsampled() {
    use crate::families::marginal_slope_shared::OuterScoreSubsample;
    // Full-coverage subsample must reproduce the unsampled second-order
    // (psi-psi) objective and score terms to ~1e-12.
    let n = 200usize;
    let family = make_sigma_aware_test_family(n);
    let states = rigid_block_states(&family, 0.3, 0.4);
    let baseline = family
        .sigma_exact_joint_psisecond_order_terms(&states)
        .expect("baseline")
        .expect("some");
    let mut opts_full = BlockwiseFitOptions::default();
    opts_full.outer_score_subsample = Some(Arc::new(OuterScoreSubsample::new(
        (0..n).collect(),
        n,
        0xDEADBEEF,
    )));
    let with_full = family
        .sigma_exact_joint_psisecond_order_terms_with_options(&states, &opts_full)
        .expect("with full mask")
        .expect("some");
    let obj_rel = ((with_full.objective_psi_psi - baseline.objective_psi_psi)
        / baseline.objective_psi_psi.abs().max(1.0))
    .abs();
    assert!(obj_rel < 1e-12, "objective rel {}", obj_rel);
    let score_rel = rel_diff_array1(&with_full.score_psi_psi, &baseline.score_psi_psi);
    assert!(score_rel < 1e-12, "score rel {}", score_rel);
}
#[test]
fn bernoulli_sigma_psi_second_order_subsample_half_scales_correctly() {
    use crate::families::marginal_slope_shared::OuterScoreSubsample;
    // Half (even-row) subsample of the second-order (psi-psi) terms must
    // equal the raw even-row terms scaled by n/m.
    let n = 200usize;
    let family = make_sigma_aware_test_family(n);
    let states = rigid_block_states(&family, 0.3, 0.4);
    let even_mask: Vec<usize> = (0..n).filter(|i| i % 2 == 0).collect();
    let m = even_mask.len();
    let mut opts_half = BlockwiseFitOptions::default();
    // Constructor applies the n/m Horvitz-Thompson weight scale.
    opts_half.outer_score_subsample = Some(Arc::new(OuterScoreSubsample::new(
        even_mask.clone(),
        n,
        0xCAFE,
    )));
    let scaled = family
        .sigma_exact_joint_psisecond_order_terms_with_options(&states, &opts_half)
        .expect("scaled")
        .expect("some");
    // Manually-built unscaled subsample (weight_scale = 1.0) as reference.
    let mut opts_raw = BlockwiseFitOptions::default();
    opts_raw.outer_score_subsample = Some(Arc::new(OuterScoreSubsample {
        mask: Arc::new(even_mask),
        n_full: m,
        weight_scale: 1.0,
        seed: 0,
    }));
    let raw = family
        .sigma_exact_joint_psisecond_order_terms_with_options(&states, &opts_raw)
        .expect("raw")
        .expect("some");
    let factor = n as f64 / m as f64;
    let exp_obj = factor * raw.objective_psi_psi;
    let obj_rel = ((scaled.objective_psi_psi - exp_obj) / exp_obj.abs().max(1.0)).abs();
    assert!(obj_rel < 1e-12, "objective rel {}", obj_rel);
    let exp_score = &raw.score_psi_psi * factor;
    let score_rel = rel_diff_array1(&scaled.score_psi_psi, &exp_score);
    assert!(score_rel < 1e-12, "score rel {}", score_rel);
}
#[test]
fn bernoulli_sigma_psihessian_directional_derivative_subsample_full_equals_unsampled() {
    use crate::families::marginal_slope_shared::OuterScoreSubsample;
    // Full-coverage subsample must reproduce the unsampled psi-Hessian
    // directional derivative matrix to ~1e-12.
    let n = 200usize;
    let family = make_sigma_aware_test_family(n);
    let states = rigid_block_states(&family, 0.3, 0.4);
    // Fixed direction in the two primary coordinates (marginal, log-slope).
    let dir = array![0.1, -0.2];
    let baseline = family
        .sigma_exact_joint_psihessian_directional_derivative(&states, &dir)
        .expect("baseline")
        .expect("some");
    let mut opts_full = BlockwiseFitOptions::default();
    opts_full.outer_score_subsample = Some(Arc::new(OuterScoreSubsample::new(
        (0..n).collect(),
        n,
        0xDEADBEEF,
    )));
    let with_full = family
        .sigma_exact_joint_psihessian_directional_derivative_with_options(
            &states, &dir, &opts_full,
        )
        .expect("with full")
        .expect("some");
    let rel = rel_diff_array2(&with_full, &baseline);
    assert!(rel < 1e-12, "drift rel {}", rel);
}
#[test]
fn bernoulli_sigma_psihessian_directional_derivative_subsample_half_scales_correctly() {
    use crate::families::marginal_slope_shared::OuterScoreSubsample;
    // Half (even-row) subsample of the psi-Hessian directional derivative
    // must equal the raw even-row matrix scaled by n/m.
    let n = 200usize;
    let family = make_sigma_aware_test_family(n);
    let states = rigid_block_states(&family, 0.3, 0.4);
    let dir = array![0.1, -0.2];
    let even_mask: Vec<usize> = (0..n).filter(|i| i % 2 == 0).collect();
    let m = even_mask.len();
    let mut opts_half = BlockwiseFitOptions::default();
    // Constructor applies the n/m Horvitz-Thompson weight scale.
    opts_half.outer_score_subsample = Some(Arc::new(OuterScoreSubsample::new(
        even_mask.clone(),
        n,
        0xCAFE,
    )));
    let scaled = family
        .sigma_exact_joint_psihessian_directional_derivative_with_options(
            &states, &dir, &opts_half,
        )
        .expect("scaled")
        .expect("some");
    // Manually-built unscaled subsample (weight_scale = 1.0) as reference.
    let mut opts_raw = BlockwiseFitOptions::default();
    opts_raw.outer_score_subsample = Some(Arc::new(OuterScoreSubsample {
        mask: Arc::new(even_mask),
        n_full: m,
        weight_scale: 1.0,
        seed: 0,
    }));
    let raw = family
        .sigma_exact_joint_psihessian_directional_derivative_with_options(
            &states, &dir, &opts_raw,
        )
        .expect("raw")
        .expect("some");
    let factor = n as f64 / m as f64;
    let exp = &raw * factor;
    let rel = rel_diff_array2(&scaled, &exp);
    assert!(rel < 1e-12, "drift rel {}", rel);
}
#[test]
fn bernoulli_psi_workspace_with_options_threads_subsample_to_first_order() {
    use crate::custom_family::CustomFamilyBlockPsiDerivative;
    use crate::families::marginal_slope_shared::OuterScoreSubsample;
    // The joint psi workspace built with options must (a) feed the subsample
    // through to its first-order terms so they match the direct
    // `sigma_exact_joint_psi_terms_with_options` path exactly, and (b)
    // actually differ from the full-data workspace, proving the subsample was
    // threaded and not dropped.
    let n = 200usize;
    let family = make_sigma_aware_test_family(n);
    let states = rigid_block_states(&family, 0.3, 0.4);
    let specs = vec![dummy_blockspec(1, n), dummy_blockspec(1, n)];
    // One (empty-shaped) psi derivative in block 1; block 0 has none, so the
    // sigma psi is the last psi index overall.
    let derivative_blocks: Vec<Vec<CustomFamilyBlockPsiDerivative>> = vec![
        Vec::new(),
        vec![CustomFamilyBlockPsiDerivative::new(
            None,
            Array2::zeros((0, 0)),
            Array2::zeros((0, 0)),
            None,
            None,
            None,
            None,
        )],
    ];
    let even_mask: Vec<usize> = (0..n).filter(|i| i % 2 == 0).collect();
    let m = even_mask.len();
    let mut opts_half = BlockwiseFitOptions::default();
    opts_half.outer_score_subsample = Some(Arc::new(OuterScoreSubsample::new(
        even_mask.clone(),
        n,
        0xBEEF_CAFE,
    )));
    // Direct (non-workspace) sigma psi terms under the same subsample.
    let direct = family
        .sigma_exact_joint_psi_terms_with_options(&states, &specs, &opts_half)
        .expect("direct sigma terms with options")
        .expect("direct some");
    let ws = family
        .exact_newton_joint_psi_workspace_with_options(
            &states,
            &specs,
            &derivative_blocks,
            &opts_half,
        )
        .expect("workspace with options")
        .expect("workspace some");
    // Sigma psi index = last psi across all blocks.
    let psi_total: usize = derivative_blocks.iter().map(Vec::len).sum();
    let sigma_psi = psi_total - 1;
    let via_ws = ws
        .first_order_terms(sigma_psi)
        .expect("ws first_order_terms")
        .expect("some");
    // Bitwise equality expected: both paths should run the identical code.
    assert_eq!(via_ws.objective_psi, direct.objective_psi);
    let score_rel = rel_diff_array1(&via_ws.score_psi, &direct.score_psi);
    assert!(score_rel == 0.0, "score_psi diverged: rel {}", score_rel);
    let h_ws = via_ws
        .hessian_psi_operator
        .as_ref()
        .expect("ws hessian op")
        .to_dense();
    let h_direct = direct
        .hessian_psi_operator
        .as_ref()
        .expect("direct hessian op")
        .to_dense();
    let h_rel = rel_diff_array2(&h_ws, &h_direct);
    assert!(h_rel == 0.0, "hessian diverged: rel {}", h_rel);
    // Full-data workspace for the "subsample really threaded" check.
    let ws_full = family
        .exact_newton_joint_psi_workspace_with_options(
            &states,
            &specs,
            &derivative_blocks,
            &BlockwiseFitOptions::default(),
        )
        .expect("workspace full")
        .expect("some");
    let via_ws_full = ws_full
        .first_order_terms(sigma_psi)
        .expect("ws_full first_order_terms")
        .expect("some");
    // If the subsample were silently dropped, these would be identical.
    assert!(
        (via_ws.objective_psi - via_ws_full.objective_psi).abs() > 1e-9,
        "subsample objective {} too close to full-data {} (subsample not threaded?)",
        via_ws.objective_psi,
        via_ws_full.objective_psi
    );
    // Coarse magnitude check: the subsampled objective should stay within a
    // generous (2*n/m + 1) multiplicative band of the full-data objective.
    let m_f = m as f64;
    let n_f = n as f64;
    let ratio_bound = (n_f / m_f) * 2.0 + 1.0;
    let ratio = (via_ws.objective_psi.abs() + 1.0) / (via_ws_full.objective_psi.abs() + 1.0);
    assert!(
        ratio < ratio_bound && (1.0 / ratio) < ratio_bound,
        "subsample/full ratio {} outside coarse bound {}",
        ratio,
        ratio_bound
    );
}
/// Convenience wrapper: builds a link-deviation block where `seed` serves as
/// both the knot seed and the design points, with unit weights.
fn build_test_link_deviation_block_from_seed(
    seed: &Array1<f64>,
    cfg: &DeviationBlockConfig,
) -> Result<DeviationPrepared, String> {
    let unit_weights = Array1::ones(seed.len());
    build_link_deviation_block_from_knots_design_seed_and_weights(seed, seed, &unit_weights, cfg)
}
#[test]
fn score_warp_basis_is_orthogonal_to_standard_normal_location_and_scale() {
    // Every score-warp basis column must have (numerically) zero mean and
    // zero first moment against z under the standard normal measure,
    // verified with a 51-point Gauss-Hermite rule.
    let seed = array![-2.0, -1.0, 0.0, 1.0, 2.0];
    let prepared = build_score_warp_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 5,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build standard-normal anchored score-warp");
    let rule = crate::quadrature::compute_gauss_hermite_n(51);
    // Change of variables: Gauss-Hermite integrates against exp(-x^2), so
    // z = sqrt(2) * node maps it to the standard normal density ...
    let z = Array1::from_iter(
        rule.nodes
            .iter()
            .map(|&node| std::f64::consts::SQRT_2 * node),
    );
    let design = prepared
        .runtime
        .design(&z)
        .expect("score-warp quadrature design");
    // ... with weights scaled by 1/sqrt(pi).
    let inv_sqrt_pi = std::f64::consts::PI.sqrt().recip();
    for basis_idx in 0..design.ncols() {
        let mut mean_moment = 0.0;
        let mut scale_moment = 0.0;
        for row in 0..design.nrows() {
            let weight = rule.weights[row] * inv_sqrt_pi;
            mean_moment += weight * design[[row, basis_idx]];
            scale_moment += weight * z[row] * design[[row, basis_idx]];
        }
        assert!(
            mean_moment.abs() <= 1e-10,
            "score-warp basis column {basis_idx} has nonzero standard-normal mean moment {mean_moment}"
        );
        assert!(
            scale_moment.abs() <= 1e-10,
            "score-warp basis column {basis_idx} has nonzero standard-normal scale moment {scale_moment}"
        );
    }
}
#[test]
fn link_deviation_basis_is_orthogonal_to_weighted_training_index_moments() {
    // Every link-deviation basis column must have (numerically) zero
    // weighted mean and zero weighted first moment against q, where the
    // weights are the (normalized) training weights.
    let q = array![-2.0, -0.8, -0.1, 0.4, 1.3, 2.1];
    let weights = array![0.2, 1.7, 0.5, 2.3, 0.8, 1.1];
    let prepared = build_link_deviation_block_from_knots_design_seed_and_weights(
        &q,
        &q,
        &weights,
        &DeviationBlockConfig {
            num_internal_knots: 5,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build weighted empirical anchored link deviation");
    let design = prepared
        .runtime
        .design(&q)
        .expect("link-deviation training design");
    // Normalize weights to a probability measure over the training rows.
    let total_weight: f64 = weights.iter().copied().sum();
    for basis_idx in 0..design.ncols() {
        let mut mean_moment = 0.0;
        let mut scale_moment = 0.0;
        for row in 0..design.nrows() {
            let weight = weights[row] / total_weight;
            mean_moment += weight * design[[row, basis_idx]];
            scale_moment += weight * q[row] * design[[row, basis_idx]];
        }
        assert!(
            mean_moment.abs() <= 1e-10,
            "link-deviation basis column {basis_idx} has nonzero weighted mean moment {mean_moment}"
        );
        assert!(
            scale_moment.abs() <= 1e-10,
            "link-deviation basis column {basis_idx} has nonzero weighted scale moment {scale_moment}"
        );
    }
}
#[test]
fn bernoulli_marginal_slope_rejects_nonprobit_base_link() {
    // Both spec validation and the inverse-link helper must reject a
    // non-probit (here: logit) base link with an explicit error message.
    let y = array![0.0, 1.0];
    let weights = array![1.0, 1.0];
    let z = array![-0.4, 0.9];
    let design =
        DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(array![[1.0], [1.0]]));
    let spec = BernoulliMarginalSlopeTermSpec {
        y,
        weights,
        z,
        // Deliberately wrong: logit instead of the required probit link.
        base_link: InverseLink::Standard(LinkFunction::Logit),
        marginalspec: empty_termspec(),
        logslopespec: empty_termspec(),
        marginal_offset: Array1::zeros(2),
        logslope_offset: Array1::zeros(2),
        frailty: FrailtySpec::None,
        score_warp: None,
        link_dev: None,
        latent_z_policy: LatentZPolicy::default(),
    };
    let err = validate_spec(design.to_dense().view(), &spec)
        .expect_err("non-probit marginal-slope link should be rejected");
    assert!(err.contains("requires link(type=probit)"));
    // The probability->eta inverse must enforce the same constraint.
    let err = bernoulli_marginal_slope_eta_from_probability(
        &InverseLink::Standard(LinkFunction::Logit),
        0.5,
        "test logit inverse",
    )
    .expect_err("non-probit marginal-slope inverse should be rejected");
    assert!(err.contains("requires link(type=probit)"));
}
/// Expands weighted `(y, z)` rows into unweighted ones: a row with integer
/// weight `w` is repeated `w` times in the output. Panics (via assert) if a
/// weight is not integral to within 1e-12.
fn expand_integer_weight_rows(
    y: &Array1<f64>,
    z: &Array1<f64>,
    weights: &Array1<f64>,
) -> (Array1<f64>, Array1<f64>) {
    let mut y_out: Vec<f64> = Vec::new();
    let mut z_out: Vec<f64> = Vec::new();
    for i in 0..y.len() {
        let w = weights[i];
        let reps = w as usize;
        assert!(
            (w - reps as f64).abs() < 1e-12,
            "test helper expects integer weights, got {}",
            w
        );
        y_out.extend(std::iter::repeat(y[i]).take(reps));
        z_out.extend(std::iter::repeat(z[i]).take(reps));
    }
    (Array1::from_vec(y_out), Array1::from_vec(z_out))
}
#[test]
fn link_dev_without_score_warp_exposes_structural_derivative_lower_bounds() {
    // With only a link-deviation block active (no score-warp, zero-column
    // primary designs), checks the coefficient slice layout, that analytic
    // row evaluation stays finite, and that monotonicity constraints are
    // exposed only for the link block.
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let prepared = build_test_link_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build link deviation block");
    let link_dim = prepared
        .block
        .initial_beta
        .as_ref()
        .expect("link block initial beta")
        .len();
    let beta_link = Array1::from_iter((0..link_dim).map(|idx| 0.1 * (idx as f64 + 1.0)));
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(Array1::zeros(seed.len())),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        // Zero-column designs: no marginal or log-slope coefficients at all.
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: None,
        link_dev: Some(prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    // Blocks: marginal, log-slope, link-deviation (no score-warp block).
    let block_states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(beta_link.clone(), seed.len()),
    ];
    // Coefficient slice layout: with empty primary designs, only the link
    // slice remains and it starts at 0.
    let slices = block_slices(&family);
    assert!(slices.h.is_none(), "score-warp slice should be absent");
    let link_slice = slices.w.as_ref().expect("link slice");
    assert_eq!(
        slices.marginal.len(),
        0,
        "zero-column marginal design should not contribute coefficient coordinates"
    );
    assert_eq!(
        slices.logslope.len(),
        0,
        "zero-column logslope design should not contribute coefficient coordinates"
    );
    assert_eq!(
        link_slice.start, 0,
        "link-only coefficients should start at 0"
    );
    assert_eq!(link_slice.len(), link_dim);
    // Primary layout reserves 2 leading coordinates before the link slice.
    let primary = primary_slices(&slices);
    assert!(primary.h.is_none(), "primary h slice should be absent");
    let primary_w = primary.w.as_ref().expect("primary link slice");
    assert_eq!(primary_w.start, 2, "primary link slice should start at 2");
    assert_eq!(primary.total, 2 + link_dim);
    // Analytic per-row evaluation must produce finite nll/gradient/Hessian.
    family
        .build_exact_eval_cache(&block_states)
        .expect("eval cache");
    let row_ctx = family
        .build_row_exact_context(0, &block_states)
        .expect("row context");
    let (nll, grad, hess) = family
        .compute_row_primary_gradient_hessian(0, &block_states, &primary, &row_ctx)
        .expect("analytic flex eval");
    assert!(nll.is_finite(), "neglog should be finite for link-dev-only");
    assert!(
        grad.iter().all(|v| v.is_finite()),
        "gradient should be finite"
    );
    assert!(
        hess.iter().all(|v| v.is_finite()),
        "Hessian should be finite"
    );
    // Monotonicity constraints: absent for non-link blocks, present (with
    // at least link_dim raw derivative-control rows) for the link block.
    let dummy_spec = dummy_blockspec(link_dim, seed.len());
    assert!(
        family
            .block_linear_constraints(&block_states, 1, &dummy_spec)
            .expect("non-link constraint lookup")
            .is_none(),
        "non-link block should not expose auxiliary monotonicity constraints"
    );
    let constraints = family
        .block_linear_constraints(&block_states, 2, &dummy_spec)
        .expect("link constraint lookup")
        .expect("link constraints");
    assert_eq!(constraints.a.ncols(), link_dim);
    assert_eq!(constraints.b.len(), constraints.a.nrows());
    assert!(
        constraints.a.nrows() >= link_dim,
        "anchored link constraints should be expressed in raw derivative-control rows"
    );
    // Lower bounds: every constraint row uses (monotonicity_eps - 1).
    assert_eq!(
        constraints.b,
        Array1::<f64>::from_elem(
            constraints.a.nrows(),
            prepared.runtime.monotonicity_eps() - 1.0
        )
    );
}
// With both deviation blocks present but all deviation coefficients zero, the
// row-intercept solver should take the closed-form rigid fast path and agree
// exactly with the de-nested Newton calibration residual and derivative.
#[test]
fn zero_deviation_intercept_fast_path_matches_denested_calibration() {
    let seed = Array1::from_iter((0..25).map(|i| -2.4 + 4.8 * i as f64 / 24.0));
    let score_prepared = build_score_warp_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 8,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build score-warp block");
    let link_prepared = build_test_link_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 8,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build link-deviation block");
    let n = seed.len();
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(Array1::zeros(n)),
        weights: Arc::new(Array1::ones(n)),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: Some(0.65),
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((n, 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((n, 0)),
        )),
        score_warp: Some(score_prepared.runtime.clone()),
        link_dev: Some(link_prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: Some(new_intercept_warm_start_cache(n)),
    };
    let marginal_eta = 0.35;
    let slope = -0.8;
    // Zero coefficient vectors for both deviation blocks: the predictor is
    // rigid, so the closed-form intercept should be exact.
    let beta_h = Array1::zeros(score_prepared.runtime.basis_dim());
    let beta_w = Array1::zeros(link_prepared.runtime.basis_dim());
    let marginal = family
        .marginal_link_map(marginal_eta)
        .expect("marginal map");
    let scale = family.probit_frailty_scale();
    let rigid_a = rigid_prescale_intercept_from_marginal(marginal.q, slope, scale);
    let (f_rigid, f_a_rigid, _) = family
        .evaluate_denested_calibration_newton(
            rigid_a,
            marginal_eta,
            slope,
            Some(&beta_h),
            Some(&beta_w),
        )
        .expect("denested zero-deviation calibration");
    assert!(
        f_rigid.abs() <= 5e-13,
        "closed-form rigid intercept residual should be at machine epsilon, got {f_rigid}"
    );
    let analytic_deriv = rigid_prescale_intercept_derivative_abs(marginal.q, slope, scale);
    assert!(
        (f_a_rigid - analytic_deriv).abs() <= 5e-13,
        "denested derivative {f_a_rigid} != analytic {analytic_deriv}"
    );
    // The solver must report the fast path and reproduce the closed-form
    // intercept and derivative bit-for-bit (exact equality below).
    let (a_fast, deriv_fast, fast_path) = family
        .solve_row_intercept_base(0, marginal_eta, slope, Some(&beta_h), Some(&beta_w), None)
        .expect("zero-deviation solve");
    assert!(
        fast_path,
        "zero coefficients should take analytic fast path"
    );
    assert_eq!(a_fast, rigid_a);
    assert_eq!(deriv_fast, analytic_deriv);
}
// The exact-evaluation cache must size its coefficient slices from the
// deviation runtimes' basis dimensions, not from the widths of the dummy beta
// vectors supplied for the zero-column marginal/log-slope design blocks.
#[test]
fn exact_layout_ignores_dummy_beta_widths_for_empty_design_blocks() {
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let score_prepared = build_score_warp_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build score-warp block");
    let link_prepared = build_test_link_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build link deviation block");
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(Array1::zeros(seed.len())),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: Some(score_prepared.runtime.clone()),
        link_dev: Some(link_prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    // Blocks 0 and 1 carry single-entry dummy betas even though their designs
    // have zero columns — the layout below must not pick up that width.
    let block_states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(
            Array1::zeros(score_prepared.runtime.basis_dim()),
            seed.len(),
        ),
        dummy_block_state(Array1::zeros(link_prepared.runtime.basis_dim()), seed.len()),
    ];
    let cache = family
        .build_exact_eval_cache(&block_states)
        .expect("exact eval cache");
    // h (score warp) starts at 0, w (link) follows immediately after it.
    assert_eq!(cache.slices.marginal.len(), 0);
    assert_eq!(cache.slices.logslope.len(), 0);
    assert_eq!(cache.slices.h.as_ref().expect("h slice").start, 0);
    assert_eq!(
        cache.slices.w.as_ref().expect("w slice").start,
        score_prepared.runtime.basis_dim()
    );
    assert_eq!(
        cache.slices.total,
        score_prepared.runtime.basis_dim() + link_prepared.runtime.basis_dim()
    );
    // Primary layout: q at index 0, logslope at 1, then h and w slices.
    assert_eq!(cache.primary.q, 0);
    assert_eq!(cache.primary.logslope, 1);
    assert_eq!(cache.primary.h.as_ref().expect("primary h").start, 2);
    assert_eq!(
        cache.primary.w.as_ref().expect("primary w").start,
        2 + score_prepared.runtime.basis_dim()
    );
}
// A score-warp-only family must expose structural monotonicity constraints
// for the score-warp block: one lower-bound row per raw derivative control,
// with the bound value derived from the runtime's monotonicity epsilon.
#[test]
fn score_warp_block_exposes_structural_derivative_lower_bounds() {
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let prepared = build_score_warp_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build score-warp block");
    let score_dim = prepared
        .block
        .initial_beta
        .as_ref()
        .expect("score-warp initial beta")
        .len();
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(Array1::zeros(seed.len())),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: Some(prepared.runtime.clone()),
        link_dev: None,
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    let block_states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(Array1::zeros(score_dim), seed.len()),
    ];
    // Block index 2 holds the score-warp coefficients.
    let dummy_spec = dummy_blockspec(score_dim, seed.len());
    let constraints = family
        .block_linear_constraints(&block_states, 2, &dummy_spec)
        .expect("constraint lookup")
        .expect("score-warp constraints");
    assert_eq!(constraints.a.ncols(), score_dim);
    assert_eq!(constraints.b.len(), constraints.a.nrows());
    assert!(
        constraints.a.nrows() >= score_dim,
        "anchored score-warp constraints should be expressed in raw derivative-control rows"
    );
    // Every row shares the same lower bound: monotonicity_eps - 1.
    assert_eq!(
        constraints.b,
        Array1::<f64>::from_elem(
            constraints.a.nrows(),
            prepared.runtime.monotonicity_eps() - 1.0
        )
    );
}
// Proposing a grossly infeasible score-warp step must be projected back onto
// the monotonicity-feasible set by post_update_block_beta — the returned beta
// must satisfy the runtime's feasibility check and differ from the proposal.
#[test]
fn post_update_block_beta_projects_score_warp_to_feasible_step() {
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let prepared = build_score_warp_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build score-warp block");
    let score_dim = prepared.block.design.ncols();
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(Array1::zeros(seed.len())),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: Some(prepared.runtime.clone()),
        link_dev: None,
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    // Start from the feasible zero vector and propose a single wildly
    // negative coefficient that violates the derivative lower bounds.
    let current = Array1::<f64>::zeros(score_dim);
    let mut proposed = current.clone();
    proposed[0] = -128.0;
    let block_states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(current.clone(), seed.len()),
    ];
    let spec = dummy_blockspec(score_dim, seed.len());
    let updated = family
        .post_update_block_beta(&block_states, 2, &spec, proposed.clone())
        .expect("projected beta");
    // The projected beta must pass the runtime feasibility guard and must
    // actually have been moved off the infeasible proposal.
    prepared
        .runtime
        .monotonicity_feasible(&updated, "projected score-warp")
        .expect("post-update beta should remain feasible");
    assert_ne!(updated, proposed);
}
// The structural deviation basis must be a genuine cubic spline: cubic degree
// reported by the runtime and its value spans, and at least one span carrying
// a non-negligible cubic coefficient.
#[test]
fn structural_deviation_runtime_is_piecewise_cubic() {
    let anchor_points = array![-1.0, 0.0, 1.0];
    let config = DeviationBlockConfig {
        ..DeviationBlockConfig::default()
    };
    let built = build_score_warp_deviation_block_from_seed(&anchor_points, &config)
        .expect("structural deviation basis");
    let runtime = &built.runtime;
    assert_eq!(runtime.degree(), 3);
    assert_eq!(runtime.value_span_degree(), 3);
    let curvature_present = runtime.span_c3().iter().any(|coeff| coeff.abs() > 1e-12);
    assert!(
        curvature_present,
        "structural deviation basis must expose true cubic span coefficients"
    );
}
// The deviation spline must be C2: the local cubics on adjacent spans must
// agree in value, first derivative, and second derivative at every internal
// breakpoint.
#[test]
fn structural_deviation_runtime_is_c2_at_internal_breakpoints() {
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let prepared = build_score_warp_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("structural deviation basis");
    let dim = prepared.block.design.ncols();
    // Non-trivial coefficients so continuity is tested on a real function,
    // not on the identically-zero spline.
    let beta = Array1::from_iter((0..dim).map(|idx| 0.015 * (idx as f64 + 1.0)));
    let n_spans = prepared.runtime.breakpoints().len().saturating_sub(1);
    // Internal breakpoints only: span_idx is the right span, span_idx - 1 the
    // left span, meeting at breakpoints[span_idx].
    for span_idx in 1..n_spans {
        let left_cubic = prepared
            .runtime
            .local_cubic_on_span(&beta, span_idx - 1)
            .expect("left span cubic");
        let right_cubic = prepared
            .runtime
            .local_cubic_on_span(&beta, span_idx)
            .expect("right span cubic");
        let knot = prepared.runtime.breakpoints()[span_idx];
        assert!(
            (left_cubic.evaluate(knot) - right_cubic.evaluate(knot)).abs() <= 1e-10,
            "deviation value should be continuous at breakpoint {span_idx}"
        );
        assert!(
            (left_cubic.first_derivative(knot) - right_cubic.first_derivative(knot)).abs()
                <= 1e-10,
            "deviation first derivative should be continuous at breakpoint {span_idx}"
        );
        assert!(
            (left_cubic.second_derivative(knot) - right_cubic.second_derivative(knot)).abs()
                <= 1e-10,
            "deviation second derivative should be continuous at breakpoint {span_idx}"
        );
    }
}
// Requesting any spline degree other than 3 must fail with an explicit
// diagnostic rather than silently building a different basis.
#[test]
fn structural_deviation_rejects_noncubic_degree() {
    let anchor_points = array![-1.0, 0.0, 1.0];
    let config = DeviationBlockConfig {
        degree: 4,
        ..DeviationBlockConfig::default()
    };
    let message = build_score_warp_deviation_block_from_seed(&anchor_points, &config)
        .expect_err("structural deviation block should reject non-cubic degree");
    assert!(message.contains("degree must be 3"));
}
// The runtime must be able to regenerate the training design matrix from the
// original seed points to within tight numerical tolerance.
#[test]
fn deviation_runtime_replays_exact_training_design() {
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let prepared = build_score_warp_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build deviation block");
    let replayed = prepared.runtime.design(&seed).expect("replayed design");
    let trained = prepared.block.design.to_dense();
    assert_eq!(replayed.dim(), trained.dim());
    // Entry-by-entry comparison against the stored training design.
    for ((i, j), &replayed_entry) in replayed.indexed_iter() {
        assert!(
            (replayed_entry - trained[[i, j]]).abs() <= 1e-10,
            "training-basis replay mismatch at ({i},{j})"
        );
    }
}
// The structural monotonicity constraint system must have the expected shape
// (3 rows per span, one per quadratic derivative Bernstein control), the
// expected shared lower bound, and must agree with the runtime's feasibility
// guard: zero deviation is feasible, a constructed derivative-violating beta
// is not.
#[test]
fn structural_constraints_match_exact_monotonicity_guard() {
    let seed = array![-1.0, 0.0, 1.0, 2.0];
    let prepared =
        build_score_warp_deviation_block_from_seed(&seed, &DeviationBlockConfig::default())
            .expect("build deviation block");
    let constraints = prepared.runtime.structural_monotonicity_constraints();
    let dim = constraints.a.ncols();
    assert_eq!(dim, prepared.runtime.basis_dim());
    // Three constraint rows per span.
    assert_eq!(
        constraints.a.nrows(),
        3 * prepared.runtime.breakpoints().len().saturating_sub(1)
    );
    assert_eq!(
        constraints.b,
        Array1::<f64>::from_elem(
            constraints.a.nrows(),
            prepared.runtime.monotonicity_eps() - 1.0
        )
    );
    let feasible = Array1::<f64>::zeros(dim);
    prepared
        .runtime
        .monotonicity_feasible(&feasible, "feasible structural beta")
        .expect("zero deviation should be feasible");
    // Build an infeasible beta: pick any nonzero row of the first-derivative
    // design and scale it so that the derivative at that point is -2, well
    // below the monotonicity lower bound.
    let d1_design = prepared
        .runtime
        .first_derivative_design(&seed)
        .expect("derivative design");
    let row_idx = (0..d1_design.nrows())
        .find(|&idx| d1_design.row(idx).dot(&d1_design.row(idx)) > 0.0)
        .expect("derivative design should include a nonzero row");
    let derivative_row = d1_design.row(row_idx).to_owned();
    let row_norm_sq = derivative_row.dot(&derivative_row);
    let infeasible = derivative_row.mapv(|value| -2.0 * value / row_norm_sq);
    assert!(
        prepared
            .runtime
            .monotonicity_feasible(&infeasible, "infeasible structural beta")
            .is_err()
    );
}
// The constraint rows must be the quadratic Bernstein controls of the spline
// derivative on each span: the first and last controls equal the derivative
// at the span endpoints, and the Bernstein combination 0.25*b0 + 0.5*b1 +
// 0.25*b2 reconstructs the derivative at the span midpoint.
#[test]
fn structural_constraints_are_quadratic_derivative_bernstein_controls() {
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let prepared = build_score_warp_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build deviation block");
    let constraints = prepared.runtime.structural_monotonicity_constraints();
    // A sign-varying beta centered around the middle basis index, so the
    // derivative is non-trivial across spans.
    let beta = Array1::from_iter((0..prepared.runtime.basis_dim()).map(|idx| {
        let centered = idx as f64 - 0.5 * (prepared.runtime.basis_dim() as f64 - 1.0);
        0.025 * centered
    }));
    // A * beta yields the Bernstein controls, three per span.
    let controls = constraints.a.dot(&beta);
    let n_spans = prepared.runtime.breakpoints().len().saturating_sub(1);
    for span_idx in 0..n_spans {
        let cubic = prepared
            .runtime
            .local_cubic_on_span(&beta, span_idx)
            .expect("local cubic");
        let left = cubic.left;
        let right = cubic.right;
        let mid = 0.5 * (left + right);
        let b0 = controls[3 * span_idx];
        let b1 = controls[3 * span_idx + 1];
        let b2 = controls[3 * span_idx + 2];
        assert!(
            (b0 - cubic.first_derivative(left)).abs() <= 1e-10,
            "left Bernstein control should equal derivative at span start"
        );
        assert!(
            (b2 - cubic.first_derivative(right)).abs() <= 1e-10,
            "right Bernstein control should equal derivative at span end"
        );
        // Quadratic Bernstein basis at t = 0.5: (0.25, 0.5, 0.25).
        let midpoint_from_bernstein = 0.25 * b0 + 0.5 * b1 + 0.25 * b2;
        assert!(
            (midpoint_from_bernstein - cubic.first_derivative(mid)).abs() <= 1e-10,
            "quadratic Bernstein controls should reconstruct derivative at span midpoint"
        );
    }
}
// Deviation penalties must be integrated derivative (function-space) Gram
// matrices — one per requested order plus the order-0 double penalty — with
// matching nullspace dimensions, and the L2 double penalty must not degrade
// to a coefficient-space identity.
#[test]
fn deviation_penalties_are_integrated_function_penalties() {
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let prepared = build_score_warp_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            penalty_order: 2,
            penalty_orders: vec![1, 2, 3],
            double_penalty: true,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build deviation block");
    // Requested orders 1..3, then order 0 appended by double_penalty.
    let expected_orders = [1, 2, 3, 0];
    assert_eq!(prepared.block.penalties.len(), expected_orders.len());
    for ((penalty, &nullity), &order) in prepared
        .block
        .penalties
        .iter()
        .zip(prepared.block.nullspace_dims.iter())
        .zip(expected_orders.iter())
    {
        let crate::solver::estimate::PenaltySpec::Dense(actual) = penalty else {
            panic!("deviation penalties should be dense local Gram matrices");
        };
        let (expected, expected_nullity) = prepared
            .runtime
            .integrated_derivative_penalty_with_nullity(order)
            .expect("integrated function penalty");
        assert_eq!(nullity, expected_nullity);
        assert_eq!(actual.dim(), expected.dim());
        for i in 0..actual.nrows() {
            for j in 0..actual.ncols() {
                assert!(
                    (actual[[i, j]] - expected[[i, j]]).abs() <= 1e-10,
                    "penalty order {order} mismatch at ({i},{j}): got {}, expected {}",
                    actual[[i, j]],
                    expected[[i, j]]
                );
            }
        }
    }
    // penalties[1] is the order-2 penalty from the list above; verify the
    // stored matrix is a genuine integrated L2 Gram matrix by checking it is
    // measurably different from the identity.
    let crate::solver::estimate::PenaltySpec::Dense(l2_penalty) = &prepared.block.penalties[1]
    else {
        panic!("deviation double penalty should be dense");
    };
    let mut max_identity_diff = 0.0_f64;
    for i in 0..l2_penalty.nrows() {
        for j in 0..l2_penalty.ncols() {
            let identity = if i == j { 1.0 } else { 0.0 };
            max_identity_diff = max_identity_diff.max((l2_penalty[[i, j]] - identity).abs());
        }
    }
    assert!(
        max_identity_diff > 1e-6,
        "deviation double penalty must be integrated L2, not coefficient identity"
    );
}
// The per-span local cubic must reproduce the full basis evaluation (value,
// first and second derivatives) at the span endpoints and midpoint, and
// point-lookup (local_cubic_at) must select the matching span — or a constant
// (saturated) cubic outside the support.
#[test]
fn local_cubic_span_reconstructs_deviation_exactly() {
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let prepared = build_score_warp_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build deviation block");
    let dim = prepared.block.design.ncols();
    let beta = Array1::from_iter((0..dim).map(|idx| 0.025 * (idx as f64 + 1.0)));
    let n_spans = prepared.runtime.breakpoints().len().saturating_sub(1);
    let support_left = prepared.runtime.breakpoints()[0];
    let support_right =
        prepared.runtime.breakpoints()[prepared.runtime.breakpoints().len() - 1];
    for span_idx in 0..n_spans {
        let cubic = prepared
            .runtime
            .local_cubic_on_span(&beta, span_idx)
            .expect("local cubic coefficients");
        let left = cubic.left;
        let right = cubic.right;
        // Probe the span at both ends and the midpoint.
        let x_eval = array![left, 0.5 * (left + right), right];
        let value_design = prepared.runtime.design(&x_eval).expect("value design");
        let d1_design = prepared
            .runtime
            .first_derivative_design(&x_eval)
            .expect("first derivative design");
        let d2_design = prepared
            .runtime
            .second_derivative_design(&x_eval)
            .expect("second derivative design");
        let expected = value_design.dot(&beta);
        let expected_d1 = d1_design.dot(&beta);
        let expected_d2 = d2_design.dot(&beta);
        for i in 0..x_eval.len() {
            let x = x_eval[i];
            assert!(
                (cubic.evaluate(x) - expected[i]).abs() < 1e-10,
                "span {span_idx}, x={x:.6}: cubic value mismatch"
            );
            assert!(
                (cubic.first_derivative(x) - expected_d1[i]).abs() < 1e-10,
                "span {span_idx}, x={x:.6}: cubic first-derivative mismatch"
            );
            assert!(
                (cubic.second_derivative(x) - expected_d2[i]).abs() < 1e-10,
                "span {span_idx}, x={x:.6}: cubic second-derivative mismatch"
            );
            let selected = prepared
                .runtime
                .local_cubic_at(&beta, x)
                .expect("lookup cubic at x");
            if x < support_left || x > support_right {
                // Outside the support the lookup must return a constant
                // cubic (all non-constant coefficients ~ 0) that still
                // matches the saturated basis value.
                assert!(selected.c1.abs() < 1e-12);
                assert!(selected.c2.abs() < 1e-12);
                assert!(selected.c3.abs() < 1e-12);
                assert!((selected.evaluate(x) - expected[i]).abs() < 1e-10);
            } else {
                // At a shared breakpoint (the left endpoint of an interior
                // span) the lookup may resolve to the previous span.
                let expected_span_idx = if i == 0 && span_idx > 0 {
                    span_idx - 1
                } else {
                    span_idx
                };
                let expected_cubic = prepared
                    .runtime
                    .local_cubic_on_span(&beta, expected_span_idx)
                    .expect("expected lookup cubic on span");
                assert_eq!(selected, expected_cubic);
            }
        }
    }
}
// The per-basis span cubic must reproduce a single basis column (value and
// first derivative) on its span, and basis_cubic_at must select the same
// cubic inside the support or a constant (saturated) cubic outside it.
#[test]
fn basis_span_cubic_reconstructs_basis_column_exactly() {
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let prepared = build_score_warp_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build deviation block");
    // Exercise the first basis function on the first span (index 0).
    let basis_idx = 0usize;
    let support_left = prepared.runtime.breakpoints()[0];
    let support_right =
        prepared.runtime.breakpoints()[prepared.runtime.breakpoints().len() - 1];
    let cubic = prepared
        .runtime
        .basis_span_cubic(0, basis_idx)
        .expect("basis span cubic");
    // Probe at span endpoints and midpoint.
    let x_eval = array![cubic.left, 0.5 * (cubic.left + cubic.right), cubic.right];
    let design = prepared.runtime.design(&x_eval).expect("basis design");
    let d1 = prepared
        .runtime
        .first_derivative_design(&x_eval)
        .expect("basis d1 design");
    for i in 0..x_eval.len() {
        let x = x_eval[i];
        assert!((cubic.evaluate(x) - design[[i, basis_idx]]).abs() < 1e-10);
        assert!((cubic.first_derivative(x) - d1[[i, basis_idx]]).abs() < 1e-10);
        let selected = prepared
            .runtime
            .basis_cubic_at(basis_idx, x)
            .expect("basis cubic at x");
        if x < support_left || x > support_right {
            // Outside the support: constant cubic matching the saturated value.
            assert!(selected.c1.abs() < 1e-12);
            assert!(selected.c2.abs() < 1e-12);
            assert!(selected.c3.abs() < 1e-12);
            assert!((selected.evaluate(x) - design[[i, basis_idx]]).abs() < 1e-10);
        } else {
            let expected_span_idx = 0;
            let expected_cubic = prepared
                .runtime
                .basis_span_cubic(expected_span_idx, basis_idx)
                .expect("expected basis span cubic");
            assert_eq!(selected, expected_cubic);
        }
    }
}
// Outside the breakpoint support the deviation must saturate: lookups beyond
// either boundary return constant cubics (zero c1/c2/c3), and the constant is
// the same no matter how far past the boundary the query point lies.
#[test]
fn deviation_runtime_saturates_outside_support() {
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let prepared = build_score_warp_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build deviation block");
    let dim = prepared.block.design.ncols();
    let beta = Array1::from_iter((0..dim).map(|idx| 0.02 * (idx as f64 + 1.0)));
    let left = prepared.runtime.breakpoints()[0];
    let right = prepared.runtime.breakpoints()[prepared.runtime.breakpoints().len() - 1];
    // Query just past and far past each support boundary.
    let left_tail_near = prepared
        .runtime
        .local_cubic_at(&beta, left - 0.25)
        .expect("left tail");
    let left_tail_far = prepared
        .runtime
        .local_cubic_at(&beta, left - 3.0)
        .expect("left far tail");
    let right_tail_near = prepared
        .runtime
        .local_cubic_at(&beta, right + 0.25)
        .expect("right tail");
    let right_tail_far = prepared
        .runtime
        .local_cubic_at(&beta, right + 3.0)
        .expect("right far tail");
    for cubic in [
        left_tail_near,
        left_tail_far,
        right_tail_near,
        right_tail_far,
    ] {
        assert!(cubic.c1.abs() < 1e-12);
        assert!(cubic.c2.abs() < 1e-12);
        assert!(cubic.c3.abs() < 1e-12);
    }
    // The saturation constant must not depend on the query distance.
    assert!((left_tail_near.c0 - left_tail_far.c0).abs() < 1e-12);
    assert!((right_tail_near.c0 - right_tail_far.c0).abs() < 1e-12);
}
// Replaying the seed through the runtime must reproduce the anchored training
// design entry-for-entry.
#[test]
fn deviation_runtime_replays_the_exact_training_basis() {
    let seed = array![-2.0, -1.0, -0.25, 0.25, 1.0, 2.0];
    let prepared = build_score_warp_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build deviation block");
    let replayed = prepared
        .runtime
        .design(&seed)
        .expect("replay anchored deviation design");
    let trained = prepared.block.design.to_dense();
    assert_eq!(replayed.dim(), trained.dim());
    for ((i, j), &value) in replayed.indexed_iter() {
        assert!(
            (value - trained[[i, j]]).abs() <= 1e-10,
            "replayed anchored deviation design mismatch at ({i},{j})"
        );
    }
}
// The de-nested microcell partition must refine the score-warp spans, tile
// the real line contiguously with infinite outer boundaries, and move its
// link-induced breakpoints when the intercept `a` changes (link breaks live
// at u = a + b*z, so they depend on a).
#[test]
fn denested_microcells_follow_score_and_link_breaks() {
    let score_seed = array![-2.0, -1.0, 0.0, 1.0, 2.0];
    let link_seed = array![-1.5, -0.5, 0.5, 1.5];
    let score_prepared = build_score_warp_deviation_block_from_seed(
        &score_seed,
        &DeviationBlockConfig {
            num_internal_knots: 3,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build score warp block");
    let link_prepared = build_test_link_deviation_block_from_seed(
        &link_seed,
        &DeviationBlockConfig {
            num_internal_knots: 3,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build link deviation block");
    let beta_h = Array1::from_iter(
        (0..score_prepared.block.design.ncols()).map(|idx| 0.02 * (idx as f64 + 1.0)),
    );
    let beta_w = Array1::from_iter(
        (0..link_prepared.block.design.ncols()).map(|idx| 0.015 * (idx as f64 + 1.0)),
    );
    // Same slope (b = 0.9), two different intercepts (a = 0.25 vs 0.55).
    let exact_cells_a0 = build_denested_partition_cells(
        0.25,
        0.9,
        score_prepared
            .runtime
            .breakpoints()
            .as_slice()
            .expect("score breaks"),
        link_prepared
            .runtime
            .breakpoints()
            .as_slice()
            .expect("link breaks"),
        |z| score_prepared.runtime.local_cubic_at(&beta_h, z),
        |u| link_prepared.runtime.local_cubic_at(&beta_w, u),
    )
    .expect("exact module microcells for a=0.25");
    let exact_cells_a1 = build_denested_partition_cells(
        0.55,
        0.9,
        score_prepared
            .runtime
            .breakpoints()
            .as_slice()
            .expect("score breaks"),
        link_prepared
            .runtime
            .breakpoints()
            .as_slice()
            .expect("link breaks"),
        |z| score_prepared.runtime.local_cubic_at(&beta_h, z),
        |u| link_prepared.runtime.local_cubic_at(&beta_w, u),
    )
    .expect("exact module microcells for a=0.55");
    assert!(
        exact_cells_a0.len() >= score_prepared.runtime.breakpoints().len().saturating_sub(1),
        "microcell partition should refine the score spans"
    );
    // Adjacent cells must share boundaries exactly.
    assert!(
        exact_cells_a0
            .windows(2)
            .all(|w| (w[0].cell.right - w[1].cell.left).abs() <= 1e-12),
        "microcells should tile the partition contiguously"
    );
    assert!(exact_cells_a0.first().unwrap().cell.left.is_infinite());
    assert!(exact_cells_a0.last().unwrap().cell.right.is_infinite());
    assert!(
        exact_cells_a0
            .iter()
            .zip(exact_cells_a1.iter())
            .any(|(lhs, rhs)| (lhs.cell.left - rhs.cell.left).abs() > 1e-10),
        "changing the intercept should move at least one link-induced breakpoint"
    );
}
// Each microcell's polynomial eta must agree with the direct de-nested
// predictor a + b*z + b*h(z) + w(a + b*z) at a probe point inside the cell.
#[test]
fn denested_microcell_eta_matches_direct_denested_formula() {
    let score_seed = array![-2.0, -1.0, 0.0, 1.0, 2.0];
    let link_seed = array![-1.5, -0.5, 0.5, 1.5];
    let score_prepared = build_score_warp_deviation_block_from_seed(
        &score_seed,
        &DeviationBlockConfig {
            num_internal_knots: 3,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build score warp block");
    let link_prepared = build_test_link_deviation_block_from_seed(
        &link_seed,
        &DeviationBlockConfig {
            num_internal_knots: 3,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build link deviation block");
    let beta_h = Array1::from_iter(
        (0..score_prepared.block.design.ncols()).map(|idx| 0.01 * (idx as f64 + 1.0)),
    );
    let beta_w = Array1::from_iter(
        (0..link_prepared.block.design.ncols()).map(|idx| 0.02 * (idx as f64 + 1.0)),
    );
    let a = 0.35;
    let b = -0.7;
    let cells = build_denested_partition_cells(
        a,
        b,
        score_prepared
            .runtime
            .breakpoints()
            .as_slice()
            .expect("score breaks"),
        link_prepared
            .runtime
            .breakpoints()
            .as_slice()
            .expect("link breaks"),
        |z| score_prepared.runtime.local_cubic_at(&beta_h, z),
        |u| link_prepared.runtime.local_cubic_at(&beta_w, u),
    )
    .expect("microcells");
    for cell in &cells {
        // Probe inside the (possibly half-infinite) cell interval.
        let z = exact_kernel::interval_probe_point(cell.cell.left, cell.cell.right)
            .expect("finite microcell probe");
        // Direct de-nested evaluation: score warp h(z) and the link
        // deviation w evaluated at the inner predictor a + b*z.
        let h = score_prepared
            .runtime
            .design(&array![z])
            .expect("score design")
            .row(0)
            .dot(&beta_h);
        let link = link_prepared
            .runtime
            .design(&array![a + b * z])
            .expect("link design")
            .row(0)
            .dot(&beta_w);
        let expected = a + b * z + b * h + link;
        assert!(
            (cell.cell.eta(z) - expected).abs() < 1e-10,
            "microcell eta should equal the direct de-nested predictor at z={z:.6}"
        );
    }
}
// Converting a span-local cubic to global monomial coefficients must preserve
// the polynomial's values across the span (endpoints and interior probes).
#[test]
fn local_cubic_global_transform_reconstructs_same_function() {
    let span_cubic = exact_kernel::LocalSpanCubic {
        left: -1.3,
        right: 0.7,
        c0: 0.4,
        c1: -0.2,
        c2: 0.15,
        c3: -0.05,
    };
    let (g0, g1, g2, g3) = exact_global_cubic_from_local(LocalSpanCubic {
        left: span_cubic.left,
        right: span_cubic.right,
        c0: span_cubic.c0,
        c1: span_cubic.c1,
        c2: span_cubic.c2,
        c3: span_cubic.c3,
    });
    let probes = [-1.3, -0.8, -0.1, 0.5, 0.7];
    for x in probes {
        let local_value = span_cubic.evaluate(x);
        let global_value = g0 + g1 * x + g2 * x * x + g3 * x * x * x;
        assert!(
            (local_value - global_value).abs() < 1e-12,
            "global cubic transform should preserve the span polynomial at x={x}"
        );
    }
}
// Branch classification must key off the cell's quadratic and cubic
// coefficients: negligible c2/c3 is Affine, a real c2 with negligible c3 is
// Quartic, and a real c3 is Sextic.
#[test]
fn denested_branch_selection_uses_normalized_cell_coefficients() {
    let affine = ExactDenestedCubicCell {
        left: -1.0,
        right: 1.0,
        c0: 0.2,
        c1: -0.4,
        c2: 1e-13,
        c3: -1e-13,
    };
    let quartic = ExactDenestedCubicCell {
        c2: 2e-4,
        c3: 1e-13,
        ..affine
    };
    let sextic = ExactDenestedCubicCell {
        c2: 2e-4,
        c3: 5e-3,
        ..affine
    };
    let expectations = [
        (affine, ExactCellBranchShared::Affine, "affine branch"),
        (quartic, ExactCellBranchShared::Quartic, "quartic branch"),
        (sextic, ExactCellBranchShared::Sextic, "sextic branch"),
    ];
    for (cell, expected_branch, label) in expectations {
        assert_eq!(branch_exact_cell(cell).expect(label), expected_branch);
    }
}
// The analytic partial derivatives of the de-nested cell coefficients with
// respect to the intercept `a` and slope `b` must match central finite
// differences of the directly assembled coefficients.
#[test]
fn denested_cell_coefficient_partials_match_finite_differences() {
    let score_span = exact_kernel::LocalSpanCubic {
        left: -0.75,
        right: 0.25,
        c0: 0.08,
        c1: -0.03,
        c2: 0.02,
        c3: -0.01,
    };
    let link_span = exact_kernel::LocalSpanCubic {
        left: -0.6,
        right: 0.9,
        c0: -0.05,
        c1: 0.04,
        c2: -0.02,
        c3: 0.015,
    };
    let a = 0.3;
    let b = -0.7;
    let eps = 1e-6;
    // Reference assembly of the de-nested coefficients in z:
    // eta(z) = aa + bb*z + bb*h(z) + w(aa + bb*z), collected as a cubic.
    // Note the h-coefficients do not depend on (aa, bb); only the transformed
    // link cubic does.
    let coeffs = |aa: f64, bb: f64| {
        let (h0, h1, h2, h3) = exact_global_cubic_from_local(LocalSpanCubic {
            left: score_span.left,
            right: score_span.right,
            c0: score_span.c0,
            c1: score_span.c1,
            c2: score_span.c2,
            c3: score_span.c3,
        });
        let (d0, d1, d2, d3) = exact_transformed_link_cubic(
            LocalSpanCubic {
                left: link_span.left,
                right: link_span.right,
                c0: link_span.c0,
                c1: link_span.c1,
                c2: link_span.c2,
                c3: link_span.c3,
            },
            aa,
            bb,
        );
        [
            aa + bb * h0 + d0,
            bb + bb * h1 + d1,
            bb * h2 + d2,
            bb * h3 + d3,
        ]
    };
    let (dc_da, dc_db) = exact_denested_cell_coefficient_partials(
        LocalSpanCubic {
            left: score_span.left,
            right: score_span.right,
            c0: score_span.c0,
            c1: score_span.c1,
            c2: score_span.c2,
            c3: score_span.c3,
        },
        LocalSpanCubic {
            left: link_span.left,
            right: link_span.right,
            c0: link_span.c0,
            c1: link_span.c1,
            c2: link_span.c2,
            c3: link_span.c3,
        },
        a,
        b,
    );
    // Central finite differences in a and b.
    let plus_a = coeffs(a + eps, b);
    let minus_a = coeffs(a - eps, b);
    let plus_b = coeffs(a, b + eps);
    let minus_b = coeffs(a, b - eps);
    for j in 0..4 {
        let fd_a = (plus_a[j] - minus_a[j]) / (2.0 * eps);
        let fd_b = (plus_b[j] - minus_b[j]) / (2.0 * eps);
        assert!(
            (dc_da[j] - fd_a).abs() < 1e-6,
            "dc/da mismatch at coefficient {j}: analytic={}, fd={fd_a}",
            dc_da[j]
        );
        assert!(
            (dc_db[j] - fd_b).abs() < 1e-6,
            "dc/db mismatch at coefficient {j}: analytic={}, fd={fd_b}",
            dc_db[j]
        );
    }
}
// A piecewise-cubic link deviation makes the observed de-nested predictor's
// third derivative in `a` nonzero; the family's observed partials must carry
// that term and match the kernel's per-span third-partial computation exactly.
#[test]
fn observed_denested_partials_include_third_a_derivative_for_piecewise_cubic_link() {
    let z = array![-0.8, 0.2, 1.1];
    let link_seed = array![-2.0, -0.5, 0.0, 0.5, 2.0];
    let link_prepared = build_test_link_deviation_block_from_seed(
        &link_seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("link block");
    let beta_w = Array1::from_iter(
        (0..link_prepared.block.design.ncols()).map(|idx| 0.01 * (idx as f64 + 1.0)),
    );
    // Link-deviation-only family (no score warp) over three observations.
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 1.0]),
        weights: Arc::new(array![1.0, 0.7, 1.3]),
        z: Arc::new(z.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        score_warp: None,
        link_dev: Some(link_prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    let a = 0.35;
    let b = 0.6;
    let row = 1usize;
    let obs = family
        .observed_denested_cell_partials(row, a, b, None, Some(&beta_w))
        .expect("observed denested partials");
    // The third a-partial must come from the link span active at the
    // observed inner predictor u = a + b*z.
    let u_obs = a + b * z[row];
    let link_span = link_prepared
        .runtime
        .local_cubic_at(&beta_w, u_obs)
        .expect("local cubic at observed point");
    let expected_daaa = exact_kernel::denested_cell_third_partials(link_span).0;
    // Exact equality: both sides are expected to run the same computation.
    assert_eq!(obs.dc_daaa, expected_daaa);
    assert!(
        eval_coeff4_at(&obs.dc_daaa, z[row]).abs() > 1e-12,
        "piecewise-cubic link spans should contribute a third a-derivative"
    );
}
// A weighted pooled-probit pilot fit must agree with the fit obtained by
// literally duplicating each row according to its integer weight, and the
// test data must be informative enough that an unweighted fit differs.
#[test]
fn pooled_probit_baseline_matches_expanded_integer_weight_fit() {
    let y = array![0.0, 1.0, 0.0, 1.0];
    let z = array![-1.5, -0.2, 0.4, 1.4];
    let weights = array![25.0, 2.0, 1.0, 20.0];
    let weighted = pooled_probit_baseline(&y, &z, &weights).expect("weighted baseline");
    let uniform = Array1::ones(y.len());
    let unweighted = pooled_probit_baseline(&y, &z, &uniform).expect("unweighted baseline");
    let (y_expanded, z_expanded) = expand_integer_weight_rows(&y, &z, &weights);
    let expanded_uniform = Array1::ones(y_expanded.len());
    let expanded = pooled_probit_baseline(&y_expanded, &z_expanded, &expanded_uniform)
        .expect("expanded baseline");
    // Sanity: the weighting must actually matter for this data set.
    assert!(
        pair_distance(expanded, unweighted) > 1e-2,
        "test data should distinguish weighted from unweighted seeding"
    );
    assert!(
        pair_distance(weighted, expanded) < 1e-8,
        "weighted pilot baseline should match the expanded integer-weight fit"
    );
}
// Spec validation must reject both NaN and negative observation weights with
// the same "finite non-negative weights" diagnostic.
#[test]
fn validate_spec_rejects_nonfinite_or_negative_weights() {
    let data = Array2::<f64>::zeros((3, 0));
    let y = array![0.0, 1.0, 0.0];
    let z = array![-1.0, 0.0, 1.0];
    // A NaN weight is non-finite and must be refused.
    let nan_err = validate_spec(
        data.view(),
        &base_spec(y.clone(), array![1.0, f64::NAN, 1.0], z.clone()),
    )
    .expect_err("non-finite weights should be rejected");
    assert!(nan_err.contains("finite non-negative weights"));
    // A negative weight is finite but still invalid.
    let neg_err = validate_spec(data.view(), &base_spec(y, array![1.0, -0.5, 1.0], z))
        .expect_err("negative weights should be rejected");
    assert!(neg_err.contains("finite non-negative weights"));
}
// Spec validation must refuse non-finite z values with a "finite z values"
// diagnostic.
#[test]
fn validate_spec_rejects_nonfinite_z_values() {
    let data = Array2::<f64>::zeros((3, 0));
    let spec = base_spec(
        array![0.0, 1.0, 0.0],
        array![1.0, 1.0, 1.0],
        array![-1.0, f64::INFINITY, 1.0],
    );
    let err = validate_spec(data.view(), &spec).expect_err("non-finite z should be rejected");
    assert!(err.contains("finite z values"));
}
// A GaussianShift frailty whose sigma is left unfixed (i.e. learnable) is a
// valid configuration and must pass spec validation.
#[test]
fn validate_spec_accepts_learnable_gaussian_shift_sigma() {
    let data = Array2::<f64>::zeros((3, 0));
    let spec = {
        let mut s = base_spec(
            array![0.0, 1.0, 0.0],
            array![1.0, 1.0, 1.0],
            array![-1.0, 0.0, 1.0],
        );
        s.frailty = FrailtySpec::GaussianShift { sigma_fixed: None };
        s
    };
    validate_spec(data.view(), &spec).expect("learnable GaussianShift sigma should validate");
}
// The signed-probit helper must map the two infinite margins to their exact
// boundary values and propagate NaN through both outputs.
#[test]
fn signed_probit_helpers_handle_nonfinite_boundaries_explicitly() {
    // +inf margin: log CDF saturates at 0 and the Mills ratio vanishes.
    assert_eq!(
        signed_probit_logcdf_and_mills_ratio(f64::INFINITY),
        (0.0, 0.0)
    );
    // -inf margin: log CDF diverges to -inf while the Mills ratio blows up.
    assert_eq!(
        signed_probit_logcdf_and_mills_ratio(f64::NEG_INFINITY),
        (f64::NEG_INFINITY, f64::INFINITY)
    );
    // NaN input poisons both outputs.
    let (logcdf_nan, lambda_nan) = signed_probit_logcdf_and_mills_ratio(f64::NAN);
    assert!(logcdf_nan.is_nan());
    assert!(lambda_nan.is_nan());
}
// The exact derivative helper accepts +inf (collapsing to the zero tail) but
// must reject -inf and NaN margins with a "non-finite signed margin" error.
#[test]
fn signed_probit_exact_derivative_helper_rejects_invalid_nonfinite_margins() {
    let tail = signed_probit_neglog_derivatives_up_to_fourth(f64::INFINITY, 2.5)
        .expect("+inf should use the zero tail");
    assert_eq!(tail, (0.0, 0.0, 0.0, 0.0));
    // Both invalid margins should fail the same way.
    let invalid = [
        (
            f64::NEG_INFINITY,
            "-inf should be rejected in the exact derivative path",
        ),
        (
            f64::NAN,
            "NaN should be rejected in the exact derivative path",
        ),
    ];
    for (margin, label) in invalid {
        let err = signed_probit_neglog_derivatives_up_to_fourth(margin, 2.5).expect_err(label);
        assert!(err.contains("non-finite signed margin"));
    }
}
// Boundary behavior of the unary -log(phi) derivative array: +inf collapses
// to all zeros, -inf keeps its divergent leading terms with the slope in the
// third slot, and NaN poisons every entry.
#[test]
fn unary_neglog_phi_preserves_negative_infinity_and_nan_boundaries() {
    let slope = 1.75;
    let pos_terms = unary_derivatives_neglog_phi(f64::INFINITY, slope);
    assert_eq!(pos_terms, [0.0, 0.0, 0.0, 0.0, 0.0]);
    let neg_terms = unary_derivatives_neglog_phi(f64::NEG_INFINITY, slope);
    assert_eq!(
        neg_terms,
        [f64::INFINITY, f64::NEG_INFINITY, slope, 0.0, 0.0]
    );
    assert!(
        unary_derivatives_neglog_phi(f64::NAN, slope)
            .iter()
            .all(|value| value.is_nan())
    );
}
// A family carrying a flexible score-warp block must still declare exact
// second-order outer derivatives and opt into the first-order-terms psi
// workspace.
#[test]
fn flexible_family_exposes_exact_outer_derivative_path() {
    let seed = array![-1.0, 0.0, 1.0];
    let score_prepared = build_score_warp_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 3,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("score warp block");
    let family =
        BernoulliMarginalSlopeFamily {
            y: Arc::new(array![0.0, 1.0, 0.0]),
            weights: Arc::new(Array1::ones(3)),
            z: Arc::new(seed.clone()),
            latent_measure: LatentMeasureKind::StandardNormal,
            gaussian_frailty_sd: None,
            base_link: bernoulli_marginal_slope_probit_link(),
            marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
                array![[1.0], [1.0], [1.0]],
            )),
            logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
                array![[1.0], [1.0], [1.0]],
            )),
            // Score warp present, link deviation absent.
            score_warp: Some(score_prepared.runtime),
            link_dev: None,
            policy: crate::resource::ResourcePolicy::default_library(),
            cell_moment_lru: new_cell_moment_lru_cache(
                &crate::resource::ResourcePolicy::default_library(),
            ),
            cell_moment_cache_stats: new_cell_moment_cache_stats(),
            intercept_warm_starts: None,
        };
    // Three blocks (marginal, log-slope, score-warp) with token dimensions;
    // the declaration should not depend on exact block sizes.
    let specs = vec![
        dummy_blockspec(1, 3),
        dummy_blockspec(1, 3),
        dummy_blockspec(2, 3),
    ];
    assert_eq!(
        family.exact_outer_derivative_order(&specs, &BlockwiseFitOptions::default()),
        ExactOuterDerivativeOrder::Second
    );
    assert!(family.exact_newton_joint_psi_workspace_for_first_order_terms());
}
// The exact-outer-order capability decision must stay Second across problem
// scales: a small toy problem, a biobank-scale design, a large coefficient
// Hessian cost with many penalties, and a huge penalty count with outer HVP
// support. It must never be cost-demoted to a lower order.
#[test]
fn exact_outer_order_stays_second_order_at_biobank_work_scale() {
    use crate::custom_family::{
        default_coefficient_hessian_cost, exact_outer_order_from_capability,
    };
    use crate::matrix::DesignMatrix;
    use ndarray::Array2;
    // Scenario 1: small problem (10 rows x 8 cols, 2 blocks, 2 penalties each).
    let small_design = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
        Array2::zeros((10, 8)),
    ));
    let small_specs: Vec<ParameterBlockSpec> = (0..2)
        .map(|i| ParameterBlockSpec {
            name: format!("block_{i}"),
            design: small_design.clone(),
            offset: ndarray::Array1::zeros(10),
            penalties: (0..2)
                .map(|_| crate::custom_family::PenaltyMatrix::Dense(Array2::zeros((8, 8))))
                .collect(),
            nullspace_dims: vec![0; 2],
            initial_log_lambdas: ndarray::Array1::zeros(2),
            initial_beta: None,
        })
        .collect();
    let small_cost = default_coefficient_hessian_cost(&small_specs);
    assert_eq!(
        exact_outer_order_from_capability(&small_specs, small_cost),
        ExactOuterDerivativeOrder::Second
    );
    // Scenario 2: large design (5000 rows x 500 cols, 10 penalties per block).
    let big_design = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
        Array2::zeros((5_000, 500)),
    ));
    let big_specs: Vec<ParameterBlockSpec> = (0..2)
        .map(|i| ParameterBlockSpec {
            name: format!("block_{i}"),
            design: big_design.clone(),
            offset: ndarray::Array1::zeros(5_000),
            penalties: (0..10)
                .map(|_| crate::custom_family::PenaltyMatrix::Dense(Array2::zeros((500, 500))))
                .collect(),
            nullspace_dims: vec![0; 10],
            initial_log_lambdas: ndarray::Array1::zeros(10),
            initial_beta: None,
        })
        .collect();
    let big_cost = default_coefficient_hessian_cost(&big_specs);
    assert_eq!(
        exact_outer_order_from_capability(&big_specs, big_cost),
        ExactOuterDerivativeOrder::Second
    );
    // Scenario 3: tiny design but 22 penalties and an externally supplied cost
    // of 400k * 300 * 300 — roughly biobank working scale.
    let ctn_specs = vec![ParameterBlockSpec {
        name: "ctn".to_string(),
        design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(Array2::zeros((
            1, 1,
        )))),
        offset: ndarray::Array1::zeros(1),
        penalties: (0..22)
            .map(|_| crate::custom_family::PenaltyMatrix::Dense(Array2::zeros((1, 1))))
            .collect(),
        nullspace_dims: vec![0; 22],
        initial_log_lambdas: ndarray::Array1::zeros(22),
        initial_beta: None,
    }];
    let ctn_cost: u64 = 400_000u64 * 300 * 300;
    assert_eq!(
        exact_outer_order_from_capability(&ctn_specs, ctn_cost),
        ExactOuterDerivativeOrder::Second
    );
    // Even without outer-HVP support the order must not be demoted by cost.
    use crate::custom_family::exact_outer_order_with_outer_hvp;
    assert_eq!(
        exact_outer_order_with_outer_hvp(&ctn_specs, ctn_cost, false),
        ExactOuterDerivativeOrder::Second,
        "exact outer order must not be cost-demoted"
    );
    // Scenario 4: 5000 penalties with outer HVP support — still Second.
    let huge_k_specs = vec![ParameterBlockSpec {
        name: "k".to_string(),
        design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(Array2::zeros((
            1, 1,
        )))),
        offset: ndarray::Array1::zeros(1),
        penalties: (0..5_000)
            .map(|_| crate::custom_family::PenaltyMatrix::Dense(Array2::zeros((1, 1))))
            .collect(),
        nullspace_dims: vec![0; 5_000],
        initial_log_lambdas: ndarray::Array1::zeros(5_000),
        initial_beta: None,
    }];
    assert_eq!(
        exact_outer_order_with_outer_hvp(&huge_k_specs, 0, true),
        ExactOuterDerivativeOrder::Second,
        "outer HVP support keeps the exact second-order declaration regardless of K"
    );
}
// The family's coefficient-Hessian cost must use the joint coupled formula
// n * (p_marg + p_log)^2, whereas the default is block-diagonal
// n * (p_marg^2 + p_log^2); the joint cost is strictly larger whenever both
// blocks are non-empty.
#[test]
fn bernoulli_marginal_slope_coefficient_cost_uses_joint_coupled_formula() {
    use crate::custom_family::default_coefficient_hessian_cost;
    use crate::matrix::DesignMatrix;
    let n = 1000usize;
    let p_marg = 20usize;
    let p_log = 8usize;
    let marg_design = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
        Array2::zeros((n, p_marg)),
    ));
    let log_design = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
        Array2::zeros((n, p_log)),
    ));
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(Array1::zeros(n)),
        weights: Arc::new(Array1::from_elem(n, 1.0)),
        z: Arc::new(Array1::zeros(n)),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: marg_design.clone(),
        logslope_design: log_design.clone(),
        score_warp: None,
        link_dev: None,
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    // One spec per block, each with a single dense penalty of matching size.
    let specs = vec![
        ParameterBlockSpec {
            name: "marginal".to_string(),
            design: marg_design,
            offset: Array1::zeros(n),
            penalties: vec![crate::custom_family::PenaltyMatrix::Dense(Array2::zeros((
                p_marg, p_marg,
            )))],
            nullspace_dims: vec![0],
            initial_log_lambdas: Array1::zeros(1),
            initial_beta: None,
        },
        ParameterBlockSpec {
            name: "logslope".to_string(),
            design: log_design,
            offset: Array1::zeros(n),
            penalties: vec![crate::custom_family::PenaltyMatrix::Dense(Array2::zeros((
                p_log, p_log,
            )))],
            nullspace_dims: vec![0],
            initial_log_lambdas: Array1::zeros(1),
            initial_beta: None,
        },
    ];
    // Joint coupled cost: all (p_marg + p_log)^2 cross terms per row.
    let p_total = (p_marg + p_log) as u64;
    let expected_joint = (n as u64) * p_total * p_total;
    // Default block-diagonal cost: only within-block terms per row.
    let expected_block_diag = (n as u64) * ((p_marg * p_marg + p_log * p_log) as u64);
    assert_eq!(family.coefficient_hessian_cost(&specs), expected_joint);
    assert_eq!(
        default_coefficient_hessian_cost(&specs),
        expected_block_diag
    );
    assert!(family.coefficient_hessian_cost(&specs) > default_coefficient_hessian_cost(&specs));
}
// For a rigid family (no score warp, no link deviation) on a single-row
// problem, the fast-path `evaluate` gradients and Hessians must match the
// exact per-row derivative path to tight tolerance.
#[test]
fn rigid_fast_path_matches_loglik_finite_differences() {
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![1.0]),
        weights: Arc::new(array![1.2]),
        z: Arc::new(array![0.3]),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(array![[
            1.0
        ]])),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(array![[
            1.0
        ]])),
        score_warp: None,
        link_dev: None,
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    // Intercept-only designs: beta and eta coincide for each block.
    let states_at = |q: f64, g: f64| {
        vec![
            ParameterBlockState {
                beta: array![q],
                eta: array![q],
            },
            ParameterBlockState {
                beta: array![g],
                eta: array![g],
            },
        ]
    };
    let q = 0.4;
    let g = 0.7;
    let block_states = states_at(q, g);
    // Fast path: pull gradient/Hessian entries out of the evaluated working sets.
    let eval = family
        .evaluate(&block_states)
        .expect("rigid family evaluation");
    let grad_q = match &eval.blockworking_sets[0] {
        BlockWorkingSet::ExactNewton { gradient, .. } => gradient[0],
        BlockWorkingSet::Diagonal { .. } => {
            panic!("expected exact-newton marginal block")
        }
    };
    let grad_g = match &eval.blockworking_sets[1] {
        BlockWorkingSet::ExactNewton { gradient, .. } => gradient[0],
        BlockWorkingSet::Diagonal { .. } => {
            panic!("expected exact-newton log-slope block")
        }
    };
    let hess_qq = match &eval.blockworking_sets[0] {
        BlockWorkingSet::ExactNewton { hessian, .. } => match hessian {
            SymmetricMatrix::Dense(h) => h[[0, 0]],
            _ => panic!("expected dense marginal Hessian"),
        },
        BlockWorkingSet::Diagonal { .. } => {
            panic!("expected exact-newton marginal block")
        }
    };
    let hess_gg = match &eval.blockworking_sets[1] {
        BlockWorkingSet::ExactNewton { hessian, .. } => match hessian {
            SymmetricMatrix::Dense(h) => h[[0, 0]],
            _ => panic!("expected dense log-slope Hessian"),
        },
        BlockWorkingSet::Diagonal { .. } => {
            panic!("expected exact-newton log-slope block")
        }
    };
    // Exact path: recompute the row-0 primary gradient/Hessian from the cache.
    let cache = family
        .build_exact_eval_cache(&block_states)
        .expect("rigid exact eval cache");
    let row_ctx = family
        .build_row_exact_context(0, &block_states)
        .expect("rigid row context");
    let (_, primary_grad, primary_hess) = family
        .compute_row_primary_gradient_hessian(0, &block_states, &cache.primary, &row_ctx)
        .expect("rigid exact row derivatives");
    // Objective-gradient entries are converted to score components before
    // comparison, matching the fast path's sign convention.
    let expected_score_q =
        BernoulliMarginalSlopeFamily::exact_newton_score_component_from_objective_gradient(
            primary_grad[0],
        );
    let expected_score_g =
        BernoulliMarginalSlopeFamily::exact_newton_score_component_from_objective_gradient(
            primary_grad[1],
        );
    assert!(
        (grad_q - expected_score_q).abs() < 1e-10,
        "marginal gradient mismatch: fast={grad_q:.12e}, exact={expected_score_q:.12e}"
    );
    assert!(
        (grad_g - expected_score_g).abs() < 1e-10,
        "logslope gradient mismatch: fast={grad_g:.12e}, exact={expected_score_g:.12e}"
    );
    assert!(
        (hess_qq - primary_hess[[0, 0]]).abs() < 1e-10,
        "marginal Hessian mismatch: fast={hess_qq:.12e}, exact={:.12e}",
        primary_hess[[0, 0]]
    );
    assert!(
        (hess_gg - primary_hess[[1, 1]]).abs() < 1e-10,
        "logslope Hessian mismatch: fast={hess_gg:.12e}, exact={:.12e}",
        primary_hess[[1, 1]]
    );
}
// With zero-column marginal and log-slope designs, only the link-deviation
// ("w") coefficients vary; every per-row primary gradient must be finite and
// every per-row primary Hessian finite and numerically symmetric.
#[test]
fn w_only_gradient_hessian_finite_and_symmetric() {
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let prepared = build_test_link_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build link deviation block");
    let link_dim = prepared
        .block
        .initial_beta
        .as_ref()
        .expect("link initial beta")
        .len();
    // Nonzero, distinct link coefficients so the deviation is active.
    let beta_link = Array1::from_iter((0..link_dim).map(|idx| 0.05 * (idx as f64 + 1.0)));
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        // Zero-column designs: marginal/log-slope blocks contribute no columns.
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: None,
        link_dev: Some(prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    let block_states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(beta_link.clone(), seed.len()),
    ];
    let slices = block_slices(&family);
    assert!(slices.h.is_none(), "score-warp absent → no h slice");
    let primary = primary_slices(&slices);
    assert!(primary.h.is_none(), "primary h absent");
    // Two scalar dummy-block entries plus the link coefficients.
    assert_eq!(primary.total, 2 + link_dim);
    for row in 0..seed.len() {
        let row_ctx = family
            .build_row_exact_context(row, &block_states)
            .unwrap_or_else(|e| panic!("row {row}: build_row_exact_context failed: {e}"));
        let (_, grad, hess) = family
            .compute_row_primary_gradient_hessian(row, &block_states, &primary, &row_ctx)
            .unwrap_or_else(|e| {
                panic!("row {row}: compute_row_primary_gradient_hessian failed: {e}")
            });
        assert_eq!(
            grad.len(),
            primary.total,
            "row {row}: gradient length mismatch"
        );
        assert_eq!(
            hess.dim(),
            (primary.total, primary.total),
            "row {row}: hessian shape mismatch"
        );
        assert!(
            grad.iter().all(|v| v.is_finite()),
            "row {row}: non-finite gradient entry: {grad:?}"
        );
        assert!(
            hess.iter().all(|v| v.is_finite()),
            "row {row}: non-finite hessian entry"
        );
        // Symmetry over the strict lower triangle.
        for a in 0..primary.total {
            for b in 0..a {
                let diff = (hess[[a, b]] - hess[[b, a]]).abs();
                assert!(
                    diff < 1e-10,
                    "row {row}: hessian asymmetry at ({a},{b}): \
                     H[{a},{b}]={:.6e} vs H[{b},{a}]={:.6e}, diff={diff:.3e}",
                    hess[[a, b]],
                    hess[[b, a]]
                );
            }
        }
    }
}
// Mirror of the w-only test for the score-warp ("h") block: with zero-column
// marginal and log-slope designs, per-row primary gradients must be finite
// and per-row primary Hessians finite and numerically symmetric.
#[test]
fn h_only_gradient_hessian_finite_and_symmetric() {
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let prepared = build_score_warp_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build score-warp block");
    let score_dim = prepared
        .block
        .initial_beta
        .as_ref()
        .expect("score-warp initial beta")
        .len();
    // Nonzero, distinct score-warp coefficients so the warp is active.
    let beta_score = Array1::from_iter((0..score_dim).map(|idx| 0.04 * (idx as f64 + 1.0)));
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        // Zero-column designs: marginal/log-slope blocks contribute no columns.
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: Some(prepared.runtime.clone()),
        link_dev: None,
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    let block_states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(beta_score, seed.len()),
    ];
    let slices = block_slices(&family);
    assert!(slices.w.is_none(), "link-dev absent → no w slice");
    let primary = primary_slices(&slices);
    assert!(primary.w.is_none(), "primary w absent");
    // Two scalar dummy-block entries plus the score-warp coefficients.
    assert_eq!(primary.total, 2 + score_dim);
    for row in 0..seed.len() {
        let row_ctx = family
            .build_row_exact_context(row, &block_states)
            .unwrap_or_else(|e| panic!("row {row}: build_row_exact_context failed: {e}"));
        let (_, grad, hess) = family
            .compute_row_primary_gradient_hessian(row, &block_states, &primary, &row_ctx)
            .unwrap_or_else(|e| {
                panic!("row {row}: compute_row_primary_gradient_hessian failed: {e}")
            });
        assert_eq!(
            grad.len(),
            primary.total,
            "row {row}: gradient length mismatch"
        );
        assert_eq!(
            hess.dim(),
            (primary.total, primary.total),
            "row {row}: hessian shape mismatch"
        );
        assert!(
            grad.iter().all(|v| v.is_finite()),
            "row {row}: non-finite gradient entry: {grad:?}"
        );
        assert!(
            hess.iter().all(|v| v.is_finite()),
            "row {row}: non-finite hessian entry"
        );
        // Symmetry over the strict lower triangle.
        for a in 0..primary.total {
            for b in 0..a {
                let diff = (hess[[a, b]] - hess[[b, a]]).abs();
                assert!(
                    diff < 1e-10,
                    "row {row}: hessian asymmetry at ({a},{b}): \
                     H[{a},{b}]={:.6e} vs H[{b},{a}]={:.6e}, diff={diff:.3e}",
                    hess[[a, b]],
                    hess[[b, a]]
                );
            }
        }
    }
}
// The exact-Newton third/fourth directional Hessian derivatives must exist
// (Some), be finite, be nonzero, and be symmetric when the only flexible
// block is the link deviation ("w") and the directions live in the w slice.
#[test]
fn w_only_exact_outer_directional_derivatives_are_present_and_finite() {
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let prepared = build_test_link_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build link deviation block");
    let link_dim = prepared
        .block
        .initial_beta
        .as_ref()
        .expect("link initial beta")
        .len();
    let beta_link = Array1::from_iter((0..link_dim).map(|idx| 0.05 * (idx as f64 + 1.0)));
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        // Zero-column designs: marginal/log-slope blocks contribute no columns.
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: None,
        link_dev: Some(prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    let block_states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(beta_link, seed.len()),
    ];
    let slices = block_slices(&family);
    let total = slices.total;
    // Directions supported only on the w slice (first two w coefficients).
    let mut dir_u = Array1::<f64>::zeros(total);
    let mut dir_v = Array1::<f64>::zeros(total);
    let w_range = slices.w.as_ref().expect("w slice");
    dir_u[w_range.start] = 0.15;
    if w_range.len() > 1 {
        dir_u[w_range.start + 1] = -0.07;
    }
    dir_v[w_range.start] = 0.09;
    if w_range.len() > 1 {
        dir_v[w_range.start + 1] = 0.03;
    }
    // Third derivative contracted once with dir_u; fourth contracted with both.
    let third = family
        .exact_newton_joint_hessian_directional_derivative(&block_states, &dir_u)
        .expect("w-only third directional derivative")
        .expect("w-only third directional derivative matrix");
    let fourth = family
        .exact_newton_joint_hessiansecond_directional_derivative(&block_states, &dir_u, &dir_v)
        .expect("w-only fourth directional derivative")
        .expect("w-only fourth directional derivative matrix");
    assert_eq!(third.dim(), (total, total));
    assert_eq!(fourth.dim(), (total, total));
    assert!(third.iter().all(|value| value.is_finite()));
    assert!(fourth.iter().all(|value| value.is_finite()));
    let max_abs_third = third
        .iter()
        .fold(0.0_f64, |acc, value| acc.max(value.abs()));
    let max_abs_fourth = fourth
        .iter()
        .fold(0.0_f64, |acc, value| acc.max(value.abs()));
    assert!(
        max_abs_third > 1e-10,
        "expected nonzero w-only third directional derivative"
    );
    assert!(
        max_abs_fourth > 1e-10,
        "expected nonzero w-only fourth directional derivative"
    );
    // Both contracted matrices must be symmetric.
    for i in 0..total {
        for j in 0..i {
            assert!((third[[i, j]] - third[[j, i]]).abs() < 1e-8);
            assert!((fourth[[i, j]] - fourth[[j, i]]).abs() < 1e-8);
        }
    }
}
// Score-warp ("h") counterpart of the directional-derivative test, with
// nontrivial marginal/log-slope intercepts and directions spread across all
// slices: third/fourth directional derivatives must be present, finite,
// nonzero, and symmetric.
#[test]
fn h_only_exact_outer_directional_derivatives_are_present_and_finite() {
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let prepared = build_score_warp_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build score-warp block");
    let score_dim = prepared
        .block
        .initial_beta
        .as_ref()
        .expect("score-warp initial beta")
        .len();
    let beta_score = Array1::from_iter((0..score_dim).map(|idx| 0.04 * (idx as f64 + 1.0)));
    // Single-column all-ones designs: intercept-only marginal and log-slope.
    let scalar_design = || {
        DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(Array2::from_elem(
            (seed.len(), 1),
            1.0,
        )))
    };
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: scalar_design(),
        logslope_design: scalar_design(),
        score_warp: Some(prepared.runtime.clone()),
        link_dev: None,
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    // Intercept blocks carry constant eta equal to their scalar beta.
    let block_states = vec![
        ParameterBlockState {
            beta: array![0.25],
            eta: Array1::from_elem(seed.len(), 0.25),
        },
        ParameterBlockState {
            beta: array![0.15],
            eta: Array1::from_elem(seed.len(), 0.15),
        },
        ParameterBlockState {
            beta: beta_score,
            eta: Array1::zeros(seed.len()),
        },
    ];
    let slices = block_slices(&family);
    let total = slices.total;
    // Directions with support on marginal, log-slope, and h slices.
    let mut dir_u = Array1::<f64>::zeros(total);
    let mut dir_v = Array1::<f64>::zeros(total);
    dir_u[slices.marginal.start] = -0.35;
    dir_u[slices.logslope.start] = 0.28;
    let h_range = slices.h.as_ref().expect("h slice");
    dir_u[h_range.start] = 0.12;
    if h_range.len() > 1 {
        dir_u[h_range.start + 1] = -0.06;
    }
    dir_v[slices.marginal.start] = 0.18;
    dir_v[slices.logslope.start] = -0.22;
    dir_v[h_range.start] = 0.07;
    if h_range.len() > 1 {
        dir_v[h_range.start + 1] = 0.05;
    }
    let third = family
        .exact_newton_joint_hessian_directional_derivative(&block_states, &dir_u)
        .expect("h-only third directional derivative")
        .expect("h-only third directional derivative matrix");
    let fourth = family
        .exact_newton_joint_hessiansecond_directional_derivative(&block_states, &dir_u, &dir_v)
        .expect("h-only fourth directional derivative")
        .expect("h-only fourth directional derivative matrix");
    assert_eq!(third.dim(), (total, total));
    assert_eq!(fourth.dim(), (total, total));
    assert!(third.iter().all(|value| value.is_finite()));
    assert!(fourth.iter().all(|value| value.is_finite()));
    let max_abs_third = third
        .iter()
        .fold(0.0_f64, |acc, value| acc.max(value.abs()));
    let max_abs_fourth = fourth
        .iter()
        .fold(0.0_f64, |acc, value| acc.max(value.abs()));
    assert!(
        max_abs_third > 1e-10,
        "expected nonzero h-only third directional derivative"
    );
    assert!(
        max_abs_fourth > 1e-10,
        "expected nonzero h-only fourth directional derivative"
    );
    // Both contracted matrices must be symmetric.
    for i in 0..total {
        for j in 0..i {
            assert!((third[[i, j]] - third[[j, i]]).abs() < 1e-8);
            assert!((fourth[[i, j]] - fourth[[j, i]]).abs() < 1e-8);
        }
    }
}
// Row-level counterpart of the h-only directional test: the per-row third and
// fourth contracted recomputations must be finite and symmetric for every row,
// and nonzero somewhere across the rows.
#[test]
fn h_only_row_primary_higher_order_contractions_are_finite_and_symmetric() {
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let prepared = build_score_warp_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build score-warp block");
    let score_dim = prepared
        .block
        .initial_beta
        .as_ref()
        .expect("score-warp initial beta")
        .len();
    let beta_score = Array1::from_iter((0..score_dim).map(|idx| 0.04 * (idx as f64 + 1.0)));
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        // Zero-column designs: marginal/log-slope blocks contribute no columns.
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: Some(prepared.runtime.clone()),
        link_dev: None,
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    let block_states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(beta_score, seed.len()),
    ];
    let cache = family
        .build_exact_eval_cache(&block_states)
        .expect("exact eval cache");
    let total = cache.primary.total;
    // Directions with support on q, logslope, and the first two h entries.
    let mut dir_u = Array1::<f64>::zeros(total);
    let mut dir_v = Array1::<f64>::zeros(total);
    dir_u[cache.primary.q] = -0.35;
    dir_u[cache.primary.logslope] = 0.28;
    let h_range = cache.primary.h.as_ref().expect("h slice");
    dir_u[h_range.start] = 0.12;
    if h_range.len() > 1 {
        dir_u[h_range.start + 1] = -0.06;
    }
    dir_v[cache.primary.q] = 0.18;
    dir_v[cache.primary.logslope] = -0.22;
    dir_v[h_range.start] = 0.07;
    if h_range.len() > 1 {
        dir_v[h_range.start + 1] = 0.05;
    }
    // Track the largest magnitude seen over all rows so the "nonzero" check
    // below only requires some row to contribute.
    let mut max_abs_third = 0.0_f64;
    let mut max_abs_fourth = 0.0_f64;
    for row in 0..seed.len() {
        let row_ctx = family
            .build_row_exact_context(row, &block_states)
            .unwrap_or_else(|e| panic!("row {row}: build_row_exact_context failed: {e}"));
        let third = family
            .row_primary_third_contracted_recompute(
                row,
                &block_states,
                &cache,
                &row_ctx,
                &dir_u,
            )
            .unwrap_or_else(|e| {
                panic!("row {row}: row_primary_third_contracted_recompute failed: {e}")
            });
        let fourth = family
            .row_primary_fourth_contracted_recompute(
                row,
                &block_states,
                &cache,
                &row_ctx,
                &dir_u,
                &dir_v,
            )
            .unwrap_or_else(|e| {
                panic!("row {row}: row_primary_fourth_contracted_recompute failed: {e}")
            });
        assert_eq!(third.dim(), (total, total));
        assert_eq!(fourth.dim(), (total, total));
        assert!(third.iter().all(|value| value.is_finite()));
        assert!(fourth.iter().all(|value| value.is_finite()));
        max_abs_third = max_abs_third.max(
            third
                .iter()
                .fold(0.0_f64, |acc, value| acc.max(value.abs())),
        );
        max_abs_fourth = max_abs_fourth.max(
            fourth
                .iter()
                .fold(0.0_f64, |acc, value| acc.max(value.abs())),
        );
        // Per-row symmetry of both contracted matrices.
        for i in 0..total {
            for j in 0..i {
                assert!((third[[i, j]] - third[[j, i]]).abs() < 1e-8);
                assert!((fourth[[i, j]] - fourth[[j, i]]).abs() < 1e-8);
            }
        }
    }
    assert!(
        max_abs_third > 1e-10,
        "expected nonzero h-only third contraction"
    );
    assert!(
        max_abs_fourth > 1e-10,
        "expected nonzero h-only fourth contraction"
    );
}
// Row-level counterpart of the w-only directional test: the per-row third and
// fourth contracted recomputations must be finite and symmetric for every row,
// and nonzero somewhere across the rows.
#[test]
fn w_only_row_primary_higher_order_contractions_are_finite_and_symmetric() {
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let prepared = build_test_link_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build link deviation block");
    let link_dim = prepared
        .block
        .initial_beta
        .as_ref()
        .expect("link initial beta")
        .len();
    let beta_link = Array1::from_iter((0..link_dim).map(|idx| 0.05 * (idx as f64 + 1.0)));
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        // Zero-column designs: marginal/log-slope blocks contribute no columns.
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: None,
        link_dev: Some(prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    let block_states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(beta_link, seed.len()),
    ];
    let cache = family
        .build_exact_eval_cache(&block_states)
        .expect("exact eval cache");
    let total = cache.primary.total;
    // Directions with support on q, logslope, and the first two w entries.
    let mut dir_u = Array1::<f64>::zeros(total);
    let mut dir_v = Array1::<f64>::zeros(total);
    dir_u[cache.primary.q] = 0.4;
    dir_u[cache.primary.logslope] = -0.3;
    let w_range = cache.primary.w.as_ref().expect("w slice");
    dir_u[w_range.start] = 0.15;
    if w_range.len() > 1 {
        dir_u[w_range.start + 1] = -0.07;
    }
    dir_v[cache.primary.q] = -0.2;
    dir_v[cache.primary.logslope] = 0.25;
    dir_v[w_range.start] = 0.09;
    if w_range.len() > 1 {
        dir_v[w_range.start + 1] = 0.03;
    }
    // Track the largest magnitude seen over all rows so the "nonzero" check
    // below only requires some row to contribute.
    let mut max_abs_third = 0.0_f64;
    let mut max_abs_fourth = 0.0_f64;
    for row in 0..seed.len() {
        let row_ctx = family
            .build_row_exact_context(row, &block_states)
            .unwrap_or_else(|e| panic!("row {row}: build_row_exact_context failed: {e}"));
        let third = family
            .row_primary_third_contracted_recompute(
                row,
                &block_states,
                &cache,
                &row_ctx,
                &dir_u,
            )
            .unwrap_or_else(|e| {
                panic!("row {row}: row_primary_third_contracted_recompute failed: {e}")
            });
        let fourth = family
            .row_primary_fourth_contracted_recompute(
                row,
                &block_states,
                &cache,
                &row_ctx,
                &dir_u,
                &dir_v,
            )
            .unwrap_or_else(|e| {
                panic!("row {row}: row_primary_fourth_contracted_recompute failed: {e}")
            });
        assert_eq!(third.dim(), (total, total));
        assert_eq!(fourth.dim(), (total, total));
        assert!(third.iter().all(|value| value.is_finite()));
        assert!(fourth.iter().all(|value| value.is_finite()));
        max_abs_third = max_abs_third.max(
            third
                .iter()
                .fold(0.0_f64, |acc, value| acc.max(value.abs())),
        );
        max_abs_fourth = max_abs_fourth.max(
            fourth
                .iter()
                .fold(0.0_f64, |acc, value| acc.max(value.abs())),
        );
        // Per-row symmetry of both contracted matrices.
        for i in 0..total {
            for j in 0..i {
                assert!((third[[i, j]] - third[[j, i]]).abs() < 1e-8);
                assert!((fourth[[i, j]] - fourth[[j, i]]).abs() < 1e-8);
            }
        }
    }
    assert!(
        max_abs_third > 1e-10,
        "expected nonzero w-only third contraction"
    );
    assert!(
        max_abs_fourth > 1e-10,
        "expected nonzero w-only fourth contraction"
    );
}
// With both flex blocks active (score warp + link deviation), the per-row
// third- and fourth-order contracted derivatives must be finite, symmetric
// matrices, and not identically zero for directions touching every block.
#[test]
fn dual_flex_row_primary_higher_order_contractions_are_finite_and_symmetric() {
    let z = array![-0.8, 0.2, 1.1];
    let y = array![0.0, 1.0, 1.0];
    let weights = array![1.0, 0.7, 1.3];
    // Score-warp deviation block seeded from the covariate values.
    let warp = build_score_warp_deviation_block_from_seed(
        &z,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("score warp block");
    // Link-deviation block seeded from a wider grid.
    let link_seed = array![-2.0, -0.5, 0.0, 0.5, 2.0];
    let dev = build_test_link_deviation_block_from_seed(
        &link_seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("link block");
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(y.clone()),
        weights: Arc::new(weights.clone()),
        z: Arc::new(z.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        score_warp: Some(warp.runtime.clone()),
        link_dev: Some(dev.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    // One state per parameter block: marginal intercept, log-slope intercept,
    // then the two flex blocks with small nonzero coefficients.
    let states = vec![
        ParameterBlockState {
            beta: array![0.25],
            eta: Array1::from_elem(z.len(), 0.25),
        },
        ParameterBlockState {
            beta: array![0.6],
            eta: Array1::from_elem(z.len(), 0.6),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..warp.block.design.ncols()).map(|idx| 0.015 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(z.len()),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..dev.block.design.ncols()).map(|idx| 0.01 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(z.len()),
        },
    ];
    let cache = family
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let total = cache.primary.total;
    // Probe directions with weight in every coefficient block.
    let mut u = Array1::<f64>::zeros(total);
    let mut v = Array1::<f64>::zeros(total);
    u[cache.primary.q] = 0.7;
    u[cache.primary.logslope] = -0.2;
    let h_range = cache.primary.h.as_ref().expect("h slice");
    u[h_range.start] = 0.1;
    if h_range.len() > 1 {
        u[h_range.start + 1] = -0.05;
    }
    let w_range = cache.primary.w.as_ref().expect("w slice");
    u[w_range.start] = 0.08;
    v[cache.primary.q] = -0.4;
    v[cache.primary.logslope] = 0.3;
    v[h_range.start] = -0.03;
    v[w_range.start] = 0.06;
    if w_range.len() > 1 {
        v[w_range.start + 1] = -0.02;
    }
    let mut max_abs_third = 0.0_f64;
    let mut max_abs_fourth = 0.0_f64;
    for row in 0..z.len() {
        let ctx = family
            .build_row_exact_context(row, &states)
            .unwrap_or_else(|e| panic!("row {row}: build_row_exact_context failed: {e}"));
        let third = family
            .row_primary_third_contracted_recompute(row, &states, &cache, &ctx, &u)
            .unwrap_or_else(|e| {
                panic!("row {row}: row_primary_third_contracted_recompute failed: {e}")
            });
        let fourth = family
            .row_primary_fourth_contracted_recompute(row, &states, &cache, &ctx, &u, &v)
            .unwrap_or_else(|e| {
                panic!("row {row}: row_primary_fourth_contracted_recompute failed: {e}")
            });
        assert_eq!(third.dim(), (total, total));
        assert_eq!(fourth.dim(), (total, total));
        assert!(third.iter().all(|value| value.is_finite()));
        assert!(fourth.iter().all(|value| value.is_finite()));
        // Track the largest magnitude seen so an all-zero result is rejected below.
        max_abs_third = third
            .iter()
            .map(|value| value.abs())
            .fold(max_abs_third, f64::max);
        max_abs_fourth = fourth
            .iter()
            .map(|value| value.abs())
            .fold(max_abs_fourth, f64::max);
        // Both contractions must be symmetric matrices (check lower triangle).
        for i in 0..total {
            for j in 0..i {
                assert!((third[[i, j]] - third[[j, i]]).abs() < 1e-8);
                assert!((fourth[[i, j]] - fourth[[j, i]]).abs() < 1e-8);
            }
        }
    }
    assert!(
        max_abs_third > 1e-10,
        "expected nonzero dual-flex third contraction"
    );
    assert!(
        max_abs_fourth > 1e-10,
        "expected nonzero dual-flex fourth contraction"
    );
}
// The contractions are multilinear in their direction arguments, so a zero
// direction must produce an exactly-zero matrix (not merely a tiny one),
// even when both flex blocks are active.
#[test]
fn dual_flex_row_primary_higher_order_zero_direction_returns_zero() {
    let z = array![-0.8, 0.2, 1.1];
    let y = array![0.0, 1.0, 1.0];
    let weights = array![1.0, 0.7, 1.3];
    let warp = build_score_warp_deviation_block_from_seed(
        &z,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("score warp block");
    let link_seed = array![-2.0, -0.5, 0.0, 0.5, 2.0];
    let dev = build_test_link_deviation_block_from_seed(
        &link_seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("link block");
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(y.clone()),
        weights: Arc::new(weights.clone()),
        z: Arc::new(z.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        score_warp: Some(warp.runtime.clone()),
        link_dev: Some(dev.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    let states = vec![
        ParameterBlockState {
            beta: array![0.25],
            eta: Array1::from_elem(z.len(), 0.25),
        },
        ParameterBlockState {
            beta: array![0.6],
            eta: Array1::from_elem(z.len(), 0.6),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..warp.block.design.ncols()).map(|idx| 0.015 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(z.len()),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..dev.block.design.ncols()).map(|idx| 0.01 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(z.len()),
        },
    ];
    let cache = family
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let zero = Array1::<f64>::zeros(cache.primary.total);
    for row in 0..z.len() {
        let ctx = family
            .build_row_exact_context(row, &states)
            .unwrap_or_else(|e| panic!("row {row}: build_row_exact_context failed: {e}"));
        let third = family
            .row_primary_third_contracted_recompute(row, &states, &cache, &ctx, &zero)
            .unwrap_or_else(|e| {
                panic!("row {row}: row_primary_third_contracted_recompute failed: {e}")
            });
        let fourth = family
            .row_primary_fourth_contracted_recompute(row, &states, &cache, &ctx, &zero, &zero)
            .unwrap_or_else(|e| {
                panic!("row {row}: row_primary_fourth_contracted_recompute failed: {e}")
            });
        // Exact zero, not a tolerance: multilinearity guarantees it.
        assert!(
            third.iter().all(|value| *value == 0.0),
            "row {row}: expected zero third contraction for zero direction"
        );
        assert!(
            fourth.iter().all(|value| *value == 0.0),
            "row {row}: expected zero fourth contraction for zero directions"
        );
    }
}
// Same multilinearity check as the dual-flex variant, but with only the
// score-warp (h) block active: a zero direction must yield exactly zero
// third- and fourth-order contractions.
#[test]
fn h_only_row_primary_higher_order_zero_direction_returns_zero() {
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let warp = build_score_warp_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build score-warp block");
    let score_dim = warp
        .block
        .initial_beta
        .as_ref()
        .expect("score-warp initial beta")
        .len();
    // Two intercept-only dummy blocks, then the score-warp coefficients.
    let states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(
            Array1::from_iter((0..score_dim).map(|idx| 0.04 * (idx as f64 + 1.0))),
            seed.len(),
        ),
    ];
    // Empty (0-column) marginal/logslope designs keep only q/logslope/h active.
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: Some(warp.runtime.clone()),
        link_dev: None,
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    let cache = family
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let zero = Array1::<f64>::zeros(cache.primary.total);
    for row in 0..seed.len() {
        let ctx = family
            .build_row_exact_context(row, &states)
            .unwrap_or_else(|e| panic!("row {row}: build_row_exact_context failed: {e}"));
        let third = family
            .row_primary_third_contracted_recompute(row, &states, &cache, &ctx, &zero)
            .unwrap_or_else(|e| {
                panic!("row {row}: row_primary_third_contracted_recompute failed: {e}")
            });
        let fourth = family
            .row_primary_fourth_contracted_recompute(row, &states, &cache, &ctx, &zero, &zero)
            .unwrap_or_else(|e| {
                panic!("row {row}: row_primary_fourth_contracted_recompute failed: {e}")
            });
        assert!(
            third.iter().all(|value| *value == 0.0),
            "row {row}: expected zero h-only third contraction for zero direction"
        );
        assert!(
            fourth.iter().all(|value| *value == 0.0),
            "row {row}: expected zero h-only fourth contraction for zero directions"
        );
    }
}
// Mirror of the h-only zero-direction test with only the link-deviation (w)
// block active: a zero direction must yield exactly zero contractions.
#[test]
fn w_only_row_primary_higher_order_zero_direction_returns_zero() {
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let dev = build_test_link_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build link deviation block");
    let link_dim = dev
        .block
        .initial_beta
        .as_ref()
        .expect("link initial beta")
        .len();
    // Two intercept-only dummy blocks, then the link-deviation coefficients.
    let states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(
            Array1::from_iter((0..link_dim).map(|idx| 0.05 * (idx as f64 + 1.0))),
            seed.len(),
        ),
    ];
    // Empty (0-column) marginal/logslope designs keep only q/logslope/w active.
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: None,
        link_dev: Some(dev.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    let cache = family
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let zero = Array1::<f64>::zeros(cache.primary.total);
    for row in 0..seed.len() {
        let ctx = family
            .build_row_exact_context(row, &states)
            .unwrap_or_else(|e| panic!("row {row}: build_row_exact_context failed: {e}"));
        let third = family
            .row_primary_third_contracted_recompute(row, &states, &cache, &ctx, &zero)
            .unwrap_or_else(|e| {
                panic!("row {row}: row_primary_third_contracted_recompute failed: {e}")
            });
        let fourth = family
            .row_primary_fourth_contracted_recompute(row, &states, &cache, &ctx, &zero, &zero)
            .unwrap_or_else(|e| {
                panic!("row {row}: row_primary_fourth_contracted_recompute failed: {e}")
            });
        assert!(
            third.iter().all(|value| *value == 0.0),
            "row {row}: expected zero w-only third contraction for zero direction"
        );
        assert!(
            fourth.iter().all(|value| *value == 0.0),
            "row {row}: expected zero w-only fourth contraction for zero directions"
        );
    }
}
// Outer-level counterpart of the row-level zero-direction test: the joint
// Hessian directional derivatives must be exactly zero for zero directions.
#[test]
fn dual_flex_exact_outer_zero_direction_returns_zero() {
    let z = array![-0.8, 0.2, 1.1];
    let y = array![0.0, 1.0, 1.0];
    let weights = array![1.0, 0.7, 1.3];
    let warp = build_score_warp_deviation_block_from_seed(
        &z,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("score warp block");
    let link_seed = array![-2.0, -0.5, 0.0, 0.5, 2.0];
    let dev = build_test_link_deviation_block_from_seed(
        &link_seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("link block");
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(y.clone()),
        weights: Arc::new(weights.clone()),
        z: Arc::new(z.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        score_warp: Some(warp.runtime.clone()),
        link_dev: Some(dev.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    let states = vec![
        ParameterBlockState {
            beta: array![0.25],
            eta: Array1::from_elem(z.len(), 0.25),
        },
        ParameterBlockState {
            beta: array![0.6],
            eta: Array1::from_elem(z.len(), 0.6),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..warp.block.design.ncols()).map(|idx| 0.015 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(z.len()),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..dev.block.design.ncols()).map(|idx| 0.01 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(z.len()),
        },
    ];
    let slices = block_slices(&family);
    let zero = Array1::<f64>::zeros(slices.total);
    // Both outer entry points return Result<Option<_>>; unwrap both layers.
    let third = family
        .exact_newton_joint_hessian_directional_derivative(&states, &zero)
        .expect("dual-flex third directional derivative")
        .expect("dual-flex third directional derivative matrix");
    let fourth = family
        .exact_newton_joint_hessiansecond_directional_derivative(&states, &zero, &zero)
        .expect("dual-flex fourth directional derivative")
        .expect("dual-flex fourth directional derivative matrix");
    // Exact zero, not a tolerance: the derivatives are multilinear in direction.
    assert!(
        third.iter().all(|value| *value == 0.0),
        "expected zero dual-flex third directional derivative for zero direction"
    );
    assert!(
        fourth.iter().all(|value| *value == 0.0),
        "expected zero dual-flex fourth directional derivative for zero directions"
    );
}
// Mixed directional derivatives commute: the outer fourth-order directional
// derivative evaluated at (u, v) must equal its evaluation at (v, u).
#[test]
fn dual_flex_exact_outer_fourth_direction_swap_is_symmetric() {
    let z = array![-0.8, 0.2, 1.1];
    let y = array![0.0, 1.0, 1.0];
    let weights = array![1.0, 0.7, 1.3];
    let warp = build_score_warp_deviation_block_from_seed(
        &z,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("score warp block");
    let link_seed = array![-2.0, -0.5, 0.0, 0.5, 2.0];
    let dev = build_test_link_deviation_block_from_seed(
        &link_seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("link block");
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(y.clone()),
        weights: Arc::new(weights.clone()),
        z: Arc::new(z.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        score_warp: Some(warp.runtime.clone()),
        link_dev: Some(dev.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    let states = vec![
        ParameterBlockState {
            beta: array![0.25],
            eta: Array1::from_elem(z.len(), 0.25),
        },
        ParameterBlockState {
            beta: array![0.6],
            eta: Array1::from_elem(z.len(), 0.6),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..warp.block.design.ncols()).map(|idx| 0.015 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(z.len()),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..dev.block.design.ncols()).map(|idx| 0.01 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(z.len()),
        },
    ];
    let slices = block_slices(&family);
    let total = slices.total;
    // Probe directions with weight in every coefficient block.
    let mut u = Array1::<f64>::zeros(total);
    let mut v = Array1::<f64>::zeros(total);
    u[slices.marginal.start] = 0.7;
    u[slices.logslope.start] = -0.2;
    let h_range = slices.h.as_ref().expect("h slice");
    u[h_range.start] = 0.1;
    if h_range.len() > 1 {
        u[h_range.start + 1] = -0.05;
    }
    let w_range = slices.w.as_ref().expect("w slice");
    u[w_range.start] = 0.08;
    v[slices.marginal.start] = -0.4;
    v[slices.logslope.start] = 0.3;
    v[h_range.start] = -0.03;
    v[w_range.start] = 0.06;
    if w_range.len() > 1 {
        v[w_range.start + 1] = -0.02;
    }
    let forward = family
        .exact_newton_joint_hessiansecond_directional_derivative(&states, &u, &v)
        .expect("dual-flex fourth directional derivative")
        .expect("dual-flex fourth directional derivative matrix");
    let swapped = family
        .exact_newton_joint_hessiansecond_directional_derivative(&states, &v, &u)
        .expect("dual-flex swapped fourth directional derivative")
        .expect("dual-flex swapped fourth directional derivative matrix");
    assert_eq!(forward.dim(), (total, total));
    assert_eq!(swapped.dim(), (total, total));
    // Elementwise comparison of the (u,v) and (v,u) evaluations.
    for ((i, j), value) in forward.indexed_iter() {
        assert!(
            (value - swapped[[i, j]]).abs() < 1e-8,
            "fourth directional derivative should be symmetric in direction arguments at ({i},{j})"
        );
    }
}
// Row-level counterpart of the outer swap test: per-row fourth-order
// contractions must be symmetric in the two direction arguments.
#[test]
fn dual_flex_row_primary_fourth_direction_swap_is_symmetric() {
    let z = array![-0.8, 0.2, 1.1];
    let y = array![0.0, 1.0, 1.0];
    let weights = array![1.0, 0.7, 1.3];
    let warp = build_score_warp_deviation_block_from_seed(
        &z,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("score warp block");
    let link_seed = array![-2.0, -0.5, 0.0, 0.5, 2.0];
    let dev = build_test_link_deviation_block_from_seed(
        &link_seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("link block");
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(y.clone()),
        weights: Arc::new(weights.clone()),
        z: Arc::new(z.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        score_warp: Some(warp.runtime.clone()),
        link_dev: Some(dev.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    let states = vec![
        ParameterBlockState {
            beta: array![0.25],
            eta: Array1::from_elem(z.len(), 0.25),
        },
        ParameterBlockState {
            beta: array![0.6],
            eta: Array1::from_elem(z.len(), 0.6),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..warp.block.design.ncols()).map(|idx| 0.015 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(z.len()),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..dev.block.design.ncols()).map(|idx| 0.01 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(z.len()),
        },
    ];
    let cache = family
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let total = cache.primary.total;
    // Probe directions with weight in every coefficient block.
    let mut u = Array1::<f64>::zeros(total);
    let mut v = Array1::<f64>::zeros(total);
    u[cache.primary.q] = 0.7;
    u[cache.primary.logslope] = -0.2;
    let h_range = cache.primary.h.as_ref().expect("h slice");
    u[h_range.start] = 0.1;
    if h_range.len() > 1 {
        u[h_range.start + 1] = -0.05;
    }
    let w_range = cache.primary.w.as_ref().expect("w slice");
    u[w_range.start] = 0.08;
    v[cache.primary.q] = -0.4;
    v[cache.primary.logslope] = 0.3;
    v[h_range.start] = -0.03;
    v[w_range.start] = 0.06;
    if w_range.len() > 1 {
        v[w_range.start + 1] = -0.02;
    }
    for row in 0..z.len() {
        let ctx = family
            .build_row_exact_context(row, &states)
            .unwrap_or_else(|e| panic!("row {row}: build_row_exact_context failed: {e}"));
        let forward = family
            .row_primary_fourth_contracted_recompute(row, &states, &cache, &ctx, &u, &v)
            .unwrap_or_else(|e| panic!("row {row}: forward fourth contraction failed: {e}"));
        let swapped = family
            .row_primary_fourth_contracted_recompute(row, &states, &cache, &ctx, &v, &u)
            .unwrap_or_else(|e| panic!("row {row}: swapped fourth contraction failed: {e}"));
        assert_eq!(forward.dim(), (total, total));
        assert_eq!(swapped.dim(), (total, total));
        // Elementwise comparison of the (u,v) and (v,u) evaluations.
        for ((i, j), value) in forward.indexed_iter() {
            assert!(
                (value - swapped[[i, j]]).abs() < 1e-8,
                "row {row}: fourth contraction should be symmetric in direction arguments at ({i},{j})"
            );
        }
    }
}
// Multilinearity sign rules: the third contraction is odd in its single
// direction (flip the direction, flip the sign), and the fourth contraction
// is linear in each direction separately (flip u only, flip the sign).
#[test]
fn dual_flex_row_primary_higher_order_direction_sign_rules_hold() {
    let z = array![-0.8, 0.2, 1.1];
    let y = array![0.0, 1.0, 1.0];
    let weights = array![1.0, 0.7, 1.3];
    let warp = build_score_warp_deviation_block_from_seed(
        &z,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("score warp block");
    let link_seed = array![-2.0, -0.5, 0.0, 0.5, 2.0];
    let dev = build_test_link_deviation_block_from_seed(
        &link_seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("link block");
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(y.clone()),
        weights: Arc::new(weights.clone()),
        z: Arc::new(z.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        score_warp: Some(warp.runtime.clone()),
        link_dev: Some(dev.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    let states = vec![
        ParameterBlockState {
            beta: array![0.25],
            eta: Array1::from_elem(z.len(), 0.25),
        },
        ParameterBlockState {
            beta: array![0.6],
            eta: Array1::from_elem(z.len(), 0.6),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..warp.block.design.ncols()).map(|idx| 0.015 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(z.len()),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..dev.block.design.ncols()).map(|idx| 0.01 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(z.len()),
        },
    ];
    let cache = family
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let total = cache.primary.total;
    // Probe directions with weight in every coefficient block.
    let mut u = Array1::<f64>::zeros(total);
    let mut v = Array1::<f64>::zeros(total);
    u[cache.primary.q] = 0.7;
    u[cache.primary.logslope] = -0.2;
    let h_range = cache.primary.h.as_ref().expect("h slice");
    u[h_range.start] = 0.1;
    if h_range.len() > 1 {
        u[h_range.start + 1] = -0.05;
    }
    let w_range = cache.primary.w.as_ref().expect("w slice");
    u[w_range.start] = 0.08;
    v[cache.primary.q] = -0.4;
    v[cache.primary.logslope] = 0.3;
    v[h_range.start] = -0.03;
    v[w_range.start] = 0.06;
    if w_range.len() > 1 {
        v[w_range.start + 1] = -0.02;
    }
    let neg_u = u.mapv(|value| -value);
    for row in 0..z.len() {
        let ctx = family
            .build_row_exact_context(row, &states)
            .unwrap_or_else(|e| panic!("row {row}: build_row_exact_context failed: {e}"));
        let third = family
            .row_primary_third_contracted_recompute(row, &states, &cache, &ctx, &u)
            .unwrap_or_else(|e| panic!("row {row}: third contraction failed: {e}"));
        let third_neg = family
            .row_primary_third_contracted_recompute(row, &states, &cache, &ctx, &neg_u)
            .unwrap_or_else(|e| panic!("row {row}: negated third contraction failed: {e}"));
        let fourth = family
            .row_primary_fourth_contracted_recompute(row, &states, &cache, &ctx, &u, &v)
            .unwrap_or_else(|e| panic!("row {row}: fourth contraction failed: {e}"));
        let fourth_neg_u = family
            .row_primary_fourth_contracted_recompute(row, &states, &cache, &ctx, &neg_u, &v)
            .unwrap_or_else(|e| panic!("row {row}: negated-u fourth contraction failed: {e}"));
        for ((i, j), value) in third.indexed_iter() {
            // Odd in the single direction: T(-u) == -T(u).
            assert!(
                (third_neg[[i, j]] + value).abs() < 1e-8,
                "row {row}: third contraction should be odd in its direction at ({i},{j})"
            );
            // Linear in u: F(-u, v) == -F(u, v).
            assert!(
                (fourth_neg_u[[i, j]] + fourth[[i, j]]).abs() < 1e-8,
                "row {row}: fourth contraction should be linear in dir_u sign at ({i},{j})"
            );
        }
    }
}
// With only the score-warp (h) block active, the per-row fourth-order
// contraction must still be symmetric in its two direction arguments.
#[test]
fn h_only_row_primary_fourth_direction_swap_is_symmetric() {
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let warp = build_score_warp_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build score-warp block");
    let score_dim = warp
        .block
        .initial_beta
        .as_ref()
        .expect("score-warp initial beta")
        .len();
    // Two intercept-only dummy blocks, then the score-warp coefficients.
    let states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(
            Array1::from_iter((0..score_dim).map(|idx| 0.04 * (idx as f64 + 1.0))),
            seed.len(),
        ),
    ];
    // Empty (0-column) marginal/logslope designs keep only q/logslope/h active.
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: Some(warp.runtime.clone()),
        link_dev: None,
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    let cache = family
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let total = cache.primary.total;
    // Probe directions touching q, logslope, and the h block.
    let mut u = Array1::<f64>::zeros(total);
    let mut v = Array1::<f64>::zeros(total);
    u[cache.primary.q] = -0.35;
    u[cache.primary.logslope] = 0.28;
    let h_range = cache.primary.h.as_ref().expect("h slice");
    u[h_range.start] = 0.12;
    if h_range.len() > 1 {
        u[h_range.start + 1] = -0.06;
    }
    v[cache.primary.q] = 0.18;
    v[cache.primary.logslope] = -0.22;
    v[h_range.start] = 0.07;
    if h_range.len() > 1 {
        v[h_range.start + 1] = 0.05;
    }
    for row in 0..seed.len() {
        let ctx = family
            .build_row_exact_context(row, &states)
            .unwrap_or_else(|e| panic!("row {row}: build_row_exact_context failed: {e}"));
        let forward = family
            .row_primary_fourth_contracted_recompute(row, &states, &cache, &ctx, &u, &v)
            .unwrap_or_else(|e| panic!("row {row}: forward fourth contraction failed: {e}"));
        let swapped = family
            .row_primary_fourth_contracted_recompute(row, &states, &cache, &ctx, &v, &u)
            .unwrap_or_else(|e| panic!("row {row}: swapped fourth contraction failed: {e}"));
        // Elementwise comparison of the (u,v) and (v,u) evaluations.
        for ((i, j), value) in forward.indexed_iter() {
            assert!(
                (value - swapped[[i, j]]).abs() < 1e-8,
                "row {row}: h-only fourth contraction should be symmetric in direction arguments at ({i},{j})"
            );
        }
    }
}
// With only the link-deviation (w) block active, the per-row fourth-order
// contraction must still be symmetric in its two direction arguments.
#[test]
fn w_only_row_primary_fourth_direction_swap_is_symmetric() {
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let dev = build_test_link_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build link deviation block");
    let link_dim = dev
        .block
        .initial_beta
        .as_ref()
        .expect("link initial beta")
        .len();
    // Two intercept-only dummy blocks, then the link-deviation coefficients.
    let states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(
            Array1::from_iter((0..link_dim).map(|idx| 0.05 * (idx as f64 + 1.0))),
            seed.len(),
        ),
    ];
    // Empty (0-column) marginal/logslope designs keep only q/logslope/w active.
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: None,
        link_dev: Some(dev.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    let cache = family
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let total = cache.primary.total;
    // Probe directions touching q, logslope, and the w block.
    let mut u = Array1::<f64>::zeros(total);
    let mut v = Array1::<f64>::zeros(total);
    u[cache.primary.q] = 0.4;
    u[cache.primary.logslope] = -0.3;
    let w_range = cache.primary.w.as_ref().expect("w slice");
    u[w_range.start] = 0.15;
    if w_range.len() > 1 {
        u[w_range.start + 1] = -0.07;
    }
    v[cache.primary.q] = -0.2;
    v[cache.primary.logslope] = 0.25;
    v[w_range.start] = 0.09;
    if w_range.len() > 1 {
        v[w_range.start + 1] = 0.03;
    }
    for row in 0..seed.len() {
        let ctx = family
            .build_row_exact_context(row, &states)
            .unwrap_or_else(|e| panic!("row {row}: build_row_exact_context failed: {e}"));
        let forward = family
            .row_primary_fourth_contracted_recompute(row, &states, &cache, &ctx, &u, &v)
            .unwrap_or_else(|e| panic!("row {row}: forward fourth contraction failed: {e}"));
        let swapped = family
            .row_primary_fourth_contracted_recompute(row, &states, &cache, &ctx, &v, &u)
            .unwrap_or_else(|e| panic!("row {row}: swapped fourth contraction failed: {e}"));
        // Elementwise comparison of the (u,v) and (v,u) evaluations.
        for ((i, j), value) in forward.indexed_iter() {
            assert!(
                (value - swapped[[i, j]]).abs() < 1e-8,
                "row {row}: w-only fourth contraction should be symmetric in direction arguments at ({i},{j})"
            );
        }
    }
}
// Sign rules with only the score-warp (h) block active: the third contraction
// is odd in its direction, and the fourth contraction is invariant when both
// of its directions are negated simultaneously.
#[test]
fn h_only_row_primary_higher_order_direction_sign_rules_hold() {
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let warp = build_score_warp_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build score-warp block");
    let score_dim = warp
        .block
        .initial_beta
        .as_ref()
        .expect("score-warp initial beta")
        .len();
    // Two intercept-only dummy blocks, then the score-warp coefficients.
    let states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(
            Array1::from_iter((0..score_dim).map(|idx| 0.04 * (idx as f64 + 1.0))),
            seed.len(),
        ),
    ];
    // Empty (0-column) marginal/logslope designs keep only q/logslope/h active.
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: Some(warp.runtime.clone()),
        link_dev: None,
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    let cache = family
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let total = cache.primary.total;
    // Probe direction touching q, logslope, and the h block.
    let mut dir = Array1::<f64>::zeros(total);
    dir[cache.primary.q] = -0.35;
    dir[cache.primary.logslope] = 0.28;
    let h_range = cache.primary.h.as_ref().expect("h slice");
    dir[h_range.start] = 0.12;
    if h_range.len() > 1 {
        dir[h_range.start + 1] = -0.06;
    }
    let neg_dir = dir.mapv(|value| -value);
    for row in 0..seed.len() {
        let ctx = family
            .build_row_exact_context(row, &states)
            .unwrap_or_else(|e| panic!("row {row}: build_row_exact_context failed: {e}"));
        let third = family
            .row_primary_third_contracted_recompute(row, &states, &cache, &ctx, &dir)
            .unwrap_or_else(|e| panic!("row {row}: third contraction failed: {e}"));
        let third_neg = family
            .row_primary_third_contracted_recompute(row, &states, &cache, &ctx, &neg_dir)
            .unwrap_or_else(|e| panic!("row {row}: negated third contraction failed: {e}"));
        let fourth = family
            .row_primary_fourth_contracted_recompute(row, &states, &cache, &ctx, &dir, &dir)
            .unwrap_or_else(|e| panic!("row {row}: fourth contraction failed: {e}"));
        let fourth_neg = family
            .row_primary_fourth_contracted_recompute(row, &states, &cache, &ctx, &neg_dir, &neg_dir)
            .unwrap_or_else(|e| {
                panic!("row {row}: doubly-negated fourth contraction failed: {e}")
            });
        for ((i, j), value) in third.indexed_iter() {
            // Odd in the single direction: T(-d) == -T(d).
            assert!(
                (third_neg[[i, j]] + value).abs() < 1e-8,
                "row {row}: h-only third contraction should be odd at ({i},{j})"
            );
            // Bilinear in (u, v): F(-d, -d) == F(d, d).
            assert!(
                (fourth_neg[[i, j]] - fourth[[i, j]]).abs() < 1e-8,
                "row {row}: h-only fourth contraction should be invariant under flipping both directions at ({i},{j})"
            );
        }
    }
}
#[test]
// Sign rules for the link-deviation-only configuration (only the "w" flexible
// block is active; score_warp is None):
//  * the per-row third-order contraction is odd in its direction, so
//    T(-d) == -T(d);
//  * the per-row fourth-order contraction with the same direction in both
//    slots is even under flipping both arguments, so F(-d, -d) == F(d, d).
fn w_only_row_primary_higher_order_direction_sign_rules_hold() {
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let prepared = build_test_link_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build link deviation block");
    let link_dim = prepared
        .block
        .initial_beta
        .as_ref()
        .expect("link initial beta")
        .len();
    // Three dummy parameter blocks; the third carries small nonzero
    // link-deviation coefficients so higher-order terms are nontrivial.
    let block_states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(
            Array1::from_iter((0..link_dim).map(|idx| 0.05 * (idx as f64 + 1.0))),
            seed.len(),
        ),
    ];
    // Family with the link deviation enabled only, and zero-column
    // marginal/logslope designs so those parts contribute nothing extra.
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: None,
        link_dev: Some(prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    let cache = family
        .build_exact_eval_cache(&block_states)
        .expect("exact eval cache");
    let total = cache.primary.total;
    // Probe direction with components in the q, logslope, and w slices only.
    let mut dir = Array1::<f64>::zeros(total);
    dir[cache.primary.q] = 0.4;
    dir[cache.primary.logslope] = -0.3;
    let w_range = cache.primary.w.as_ref().expect("w slice");
    dir[w_range.start] = 0.15;
    if w_range.len() > 1 {
        dir[w_range.start + 1] = -0.07;
    }
    let neg_dir = dir.mapv(|value| -value);
    // Check the sign rules row by row against the recompute entry points.
    for row in 0..seed.len() {
        let row_ctx = family
            .build_row_exact_context(row, &block_states)
            .unwrap_or_else(|e| panic!("row {row}: build_row_exact_context failed: {e}"));
        let third = family
            .row_primary_third_contracted_recompute(row, &block_states, &cache, &row_ctx, &dir)
            .unwrap_or_else(|e| panic!("row {row}: third contraction failed: {e}"));
        let third_neg = family
            .row_primary_third_contracted_recompute(
                row,
                &block_states,
                &cache,
                &row_ctx,
                &neg_dir,
            )
            .unwrap_or_else(|e| panic!("row {row}: negated third contraction failed: {e}"));
        let fourth = family
            .row_primary_fourth_contracted_recompute(
                row,
                &block_states,
                &cache,
                &row_ctx,
                &dir,
                &dir,
            )
            .unwrap_or_else(|e| panic!("row {row}: fourth contraction failed: {e}"));
        let fourth_neg = family
            .row_primary_fourth_contracted_recompute(
                row,
                &block_states,
                &cache,
                &row_ctx,
                &neg_dir,
                &neg_dir,
            )
            .unwrap_or_else(|e| {
                panic!("row {row}: doubly-negated fourth contraction failed: {e}")
            });
        // Entry-wise comparisons with an absolute tolerance of 1e-8.
        for i in 0..total {
            for j in 0..total {
                assert!(
                    (third_neg[[i, j]] + third[[i, j]]).abs() < 1e-8,
                    "row {row}: w-only third contraction should be odd at ({i},{j})"
                );
                assert!(
                    (fourth_neg[[i, j]] - fourth[[i, j]]).abs() < 1e-8,
                    "row {row}: w-only fourth contraction should be invariant under flipping both directions at ({i},{j})"
                );
            }
        }
    }
}
#[test]
// Sign rules for the exact-outer directional derivatives with both flexible
// blocks (score warp and link deviation) active:
//  * the third directional derivative is odd in its direction, so
//    D3[-u] == -D3[u];
//  * the fourth directional derivative is linear in the sign of its first
//    direction, so D4[-u, v] == -D4[u, v].
fn dual_flex_exact_outer_direction_sign_rules_hold() {
    let z = array![-0.8, 0.2, 1.1];
    let y = array![0.0, 1.0, 1.0];
    let weights = array![1.0, 0.7, 1.3];
    let score_prepared = build_score_warp_deviation_block_from_seed(
        &z,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("score warp block");
    let link_seed = array![-2.0, -0.5, 0.0, 0.5, 2.0];
    let link_prepared = build_test_link_deviation_block_from_seed(
        &link_seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("link block");
    // Family with both flexible blocks enabled and intercept-only (single
    // column of ones) marginal/logslope designs.
    let family =
        BernoulliMarginalSlopeFamily {
            y: Arc::new(y.clone()),
            weights: Arc::new(weights.clone()),
            z: Arc::new(z.clone()),
            latent_measure: LatentMeasureKind::StandardNormal,
            gaussian_frailty_sd: None,
            base_link: bernoulli_marginal_slope_probit_link(),
            marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
                array![[1.0], [1.0], [1.0]],
            )),
            logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
                array![[1.0], [1.0], [1.0]],
            )),
            score_warp: Some(score_prepared.runtime.clone()),
            link_dev: Some(link_prepared.runtime.clone()),
            policy: crate::resource::ResourcePolicy::default_library(),
            cell_moment_lru: new_cell_moment_lru_cache(
                &crate::resource::ResourcePolicy::default_library(),
            ),
            cell_moment_cache_stats: new_cell_moment_cache_stats(),
            intercept_warm_starts: None,
        };
    // Four parameter blocks: marginal, logslope, score-warp and link-deviation
    // coefficients, each at a small nonzero state.
    let block_states = vec![
        ParameterBlockState {
            beta: array![0.25],
            eta: Array1::from_elem(z.len(), 0.25),
        },
        ParameterBlockState {
            beta: array![0.6],
            eta: Array1::from_elem(z.len(), 0.6),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..score_prepared.block.design.ncols()).map(|idx| 0.015 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(z.len()),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..link_prepared.block.design.ncols()).map(|idx| 0.01 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(z.len()),
        },
    ];
    let slices = block_slices(&family);
    let total = slices.total;
    // Probe directions touching all four coefficient slices.
    let mut dir_u = Array1::<f64>::zeros(total);
    let mut dir_v = Array1::<f64>::zeros(total);
    dir_u[slices.marginal.start] = 0.7;
    dir_u[slices.logslope.start] = -0.2;
    let h_range = slices.h.as_ref().expect("h slice");
    dir_u[h_range.start] = 0.1;
    if h_range.len() > 1 {
        dir_u[h_range.start + 1] = -0.05;
    }
    let w_range = slices.w.as_ref().expect("w slice");
    dir_u[w_range.start] = 0.08;
    dir_v[slices.marginal.start] = -0.4;
    dir_v[slices.logslope.start] = 0.3;
    dir_v[h_range.start] = -0.03;
    dir_v[w_range.start] = 0.06;
    if w_range.len() > 1 {
        dir_v[w_range.start + 1] = -0.02;
    }
    let neg_dir_u = dir_u.mapv(|value| -value);
    let third = family
        .exact_newton_joint_hessian_directional_derivative(&block_states, &dir_u)
        .expect("dual-flex third directional derivative")
        .expect("dual-flex third directional derivative matrix");
    let third_neg = family
        .exact_newton_joint_hessian_directional_derivative(&block_states, &neg_dir_u)
        .expect("dual-flex negated third directional derivative")
        .expect("dual-flex negated third directional derivative matrix");
    let fourth = family
        .exact_newton_joint_hessiansecond_directional_derivative(&block_states, &dir_u, &dir_v)
        .expect("dual-flex fourth directional derivative")
        .expect("dual-flex fourth directional derivative matrix");
    let fourth_neg_u = family
        .exact_newton_joint_hessiansecond_directional_derivative(
            &block_states,
            &neg_dir_u,
            &dir_v,
        )
        .expect("dual-flex negated-u fourth directional derivative")
        .expect("dual-flex negated-u fourth directional derivative matrix");
    // All returned matrices must be (total x total).
    assert_eq!(third.dim(), (total, total));
    assert_eq!(third_neg.dim(), (total, total));
    assert_eq!(fourth.dim(), (total, total));
    assert_eq!(fourth_neg_u.dim(), (total, total));
    // Entry-wise sign checks with an absolute tolerance of 1e-8.
    for i in 0..total {
        for j in 0..total {
            assert!(
                (third_neg[[i, j]] + third[[i, j]]).abs() < 1e-8,
                "third directional derivative should be odd in its direction at ({i},{j})"
            );
            assert!(
                (fourth_neg_u[[i, j]] + fourth[[i, j]]).abs() < 1e-8,
                "fourth directional derivative should be linear in dir_u sign at ({i},{j})"
            );
        }
    }
}
#[test]
// With both flexible blocks active, the fourth directional derivative is
// bilinear in its two directions, so flipping the sign of BOTH directions
// must leave it unchanged: D4[-u, -v] == D4[u, v].
fn dual_flex_exact_outer_fourth_double_sign_flip_is_invariant() {
    let z = array![-0.8, 0.2, 1.1];
    let y = array![0.0, 1.0, 1.0];
    let weights = array![1.0, 0.7, 1.3];
    let score_prepared = build_score_warp_deviation_block_from_seed(
        &z,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("score warp block");
    let link_seed = array![-2.0, -0.5, 0.0, 0.5, 2.0];
    let link_prepared = build_test_link_deviation_block_from_seed(
        &link_seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("link block");
    // Family with both flexible blocks enabled and intercept-only designs.
    let family =
        BernoulliMarginalSlopeFamily {
            y: Arc::new(y.clone()),
            weights: Arc::new(weights.clone()),
            z: Arc::new(z.clone()),
            latent_measure: LatentMeasureKind::StandardNormal,
            gaussian_frailty_sd: None,
            base_link: bernoulli_marginal_slope_probit_link(),
            marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
                array![[1.0], [1.0], [1.0]],
            )),
            logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
                array![[1.0], [1.0], [1.0]],
            )),
            score_warp: Some(score_prepared.runtime.clone()),
            link_dev: Some(link_prepared.runtime.clone()),
            policy: crate::resource::ResourcePolicy::default_library(),
            cell_moment_lru: new_cell_moment_lru_cache(
                &crate::resource::ResourcePolicy::default_library(),
            ),
            cell_moment_cache_stats: new_cell_moment_cache_stats(),
            intercept_warm_starts: None,
        };
    // Four parameter blocks: marginal, logslope, score-warp and link-deviation
    // coefficients, each at a small nonzero state.
    let block_states = vec![
        ParameterBlockState {
            beta: array![0.25],
            eta: Array1::from_elem(z.len(), 0.25),
        },
        ParameterBlockState {
            beta: array![0.6],
            eta: Array1::from_elem(z.len(), 0.6),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..score_prepared.block.design.ncols()).map(|idx| 0.015 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(z.len()),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..link_prepared.block.design.ncols()).map(|idx| 0.01 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(z.len()),
        },
    ];
    let slices = block_slices(&family);
    let total = slices.total;
    // Probe directions touching all four coefficient slices.
    let mut dir_u = Array1::<f64>::zeros(total);
    let mut dir_v = Array1::<f64>::zeros(total);
    dir_u[slices.marginal.start] = 0.7;
    dir_u[slices.logslope.start] = -0.2;
    let h_range = slices.h.as_ref().expect("h slice");
    dir_u[h_range.start] = 0.1;
    if h_range.len() > 1 {
        dir_u[h_range.start + 1] = -0.05;
    }
    let w_range = slices.w.as_ref().expect("w slice");
    dir_u[w_range.start] = 0.08;
    dir_v[slices.marginal.start] = -0.4;
    dir_v[slices.logslope.start] = 0.3;
    dir_v[h_range.start] = -0.03;
    dir_v[w_range.start] = 0.06;
    if w_range.len() > 1 {
        dir_v[w_range.start + 1] = -0.02;
    }
    let neg_dir_u = dir_u.mapv(|value| -value);
    let neg_dir_v = dir_v.mapv(|value| -value);
    let forward = family
        .exact_newton_joint_hessiansecond_directional_derivative(&block_states, &dir_u, &dir_v)
        .expect("dual-flex fourth directional derivative")
        .expect("dual-flex fourth directional derivative matrix");
    let flipped = family
        .exact_newton_joint_hessiansecond_directional_derivative(
            &block_states,
            &neg_dir_u,
            &neg_dir_v,
        )
        .expect("dual-flex doubly-negated fourth directional derivative")
        .expect("dual-flex doubly-negated fourth directional derivative matrix");
    assert_eq!(forward.dim(), (total, total));
    assert_eq!(flipped.dim(), (total, total));
    // Entry-wise comparison with an absolute tolerance of 1e-8.
    for i in 0..total {
        for j in 0..total {
            assert!(
                (forward[[i, j]] - flipped[[i, j]]).abs() < 1e-8,
                "fourth directional derivative should be invariant under flipping both directions at ({i},{j})"
            );
        }
    }
}
#[test]
// With both flexible blocks (score warp and link deviation) active, the
// exact-outer third directional derivative must be additive in its direction
// argument: D3[u + v] == D3[u] + D3[v], entry-wise to within 1e-8.
fn dual_flex_exact_outer_third_direction_is_linear() {
    let latent = array![-0.8, 0.2, 1.1];
    let outcomes = array![0.0, 1.0, 1.0];
    let obs_weights = array![1.0, 0.7, 1.3];
    let score_prepared = build_score_warp_deviation_block_from_seed(
        &latent,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("score warp block");
    let link_seed = array![-2.0, -0.5, 0.0, 0.5, 2.0];
    let link_prepared = build_test_link_deviation_block_from_seed(
        &link_seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("link block");
    // Family with both flexible blocks enabled and intercept-only designs.
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(outcomes.clone()),
        weights: Arc::new(obs_weights.clone()),
        z: Arc::new(latent.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        score_warp: Some(score_prepared.runtime.clone()),
        link_dev: Some(link_prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    let n = latent.len();
    // Four parameter blocks: marginal, logslope, score-warp and link-deviation
    // coefficients, each at a small nonzero state.
    let block_states = vec![
        ParameterBlockState {
            beta: array![0.25],
            eta: Array1::from_elem(n, 0.25),
        },
        ParameterBlockState {
            beta: array![0.6],
            eta: Array1::from_elem(n, 0.6),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..score_prepared.block.design.ncols()).map(|idx| 0.015 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(n),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..link_prepared.block.design.ncols()).map(|idx| 0.01 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(n),
        },
    ];
    let slices = block_slices(&family);
    let total = slices.total;
    let h_range = slices.h.as_ref().expect("h slice");
    let w_range = slices.w.as_ref().expect("w slice");
    // Probe directions touching all four coefficient slices.
    let mut dir_u = Array1::<f64>::zeros(total);
    dir_u[slices.marginal.start] = 0.7;
    dir_u[slices.logslope.start] = -0.2;
    dir_u[h_range.start] = 0.1;
    if h_range.len() > 1 {
        dir_u[h_range.start + 1] = -0.05;
    }
    dir_u[w_range.start] = 0.08;
    let mut dir_v = Array1::<f64>::zeros(total);
    dir_v[slices.marginal.start] = -0.4;
    dir_v[slices.logslope.start] = 0.3;
    dir_v[h_range.start] = -0.03;
    dir_v[w_range.start] = 0.06;
    if w_range.len() > 1 {
        dir_v[w_range.start + 1] = -0.02;
    }
    let dir_sum = &dir_u + &dir_v;
    // Shared evaluator; panic messages match the per-call labels.
    let eval_third = |dir: &Array1<f64>, err_msg: &str, none_msg: &str| {
        family
            .exact_newton_joint_hessian_directional_derivative(&block_states, dir)
            .expect(err_msg)
            .expect(none_msg)
    };
    let third_u = eval_third(
        &dir_u,
        "dual-flex third directional derivative u",
        "dual-flex third directional derivative u matrix",
    );
    let third_v = eval_third(
        &dir_v,
        "dual-flex third directional derivative v",
        "dual-flex third directional derivative v matrix",
    );
    let third_sum = eval_third(
        &dir_sum,
        "dual-flex third directional derivative sum",
        "dual-flex third directional derivative sum matrix",
    );
    for i in 0..total {
        for j in 0..total {
            let expected = third_u[[i, j]] + third_v[[i, j]];
            assert!(
                (third_sum[[i, j]] - expected).abs() < 1e-8,
                "third directional derivative should be linear in its direction at ({i},{j})"
            );
        }
    }
}
#[test]
// Per-row counterpart of the exact-outer linearity test: with both flexible
// blocks active, the per-row third-order contraction must be additive in its
// direction argument: T(u + v) == T(u) + T(v) for every row.
fn dual_flex_row_primary_third_direction_is_linear() {
    let z = array![-0.8, 0.2, 1.1];
    let y = array![0.0, 1.0, 1.0];
    let weights = array![1.0, 0.7, 1.3];
    let score_prepared = build_score_warp_deviation_block_from_seed(
        &z,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("score warp block");
    let link_seed = array![-2.0, -0.5, 0.0, 0.5, 2.0];
    let link_prepared = build_test_link_deviation_block_from_seed(
        &link_seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("link block");
    // Family with both flexible blocks enabled and intercept-only designs.
    let family =
        BernoulliMarginalSlopeFamily {
            y: Arc::new(y.clone()),
            weights: Arc::new(weights.clone()),
            z: Arc::new(z.clone()),
            latent_measure: LatentMeasureKind::StandardNormal,
            gaussian_frailty_sd: None,
            base_link: bernoulli_marginal_slope_probit_link(),
            marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
                array![[1.0], [1.0], [1.0]],
            )),
            logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
                array![[1.0], [1.0], [1.0]],
            )),
            score_warp: Some(score_prepared.runtime.clone()),
            link_dev: Some(link_prepared.runtime.clone()),
            policy: crate::resource::ResourcePolicy::default_library(),
            cell_moment_lru: new_cell_moment_lru_cache(
                &crate::resource::ResourcePolicy::default_library(),
            ),
            cell_moment_cache_stats: new_cell_moment_cache_stats(),
            intercept_warm_starts: None,
        };
    // Four parameter blocks: marginal, logslope, score-warp and link-deviation
    // coefficients, each at a small nonzero state.
    let block_states = vec![
        ParameterBlockState {
            beta: array![0.25],
            eta: Array1::from_elem(z.len(), 0.25),
        },
        ParameterBlockState {
            beta: array![0.6],
            eta: Array1::from_elem(z.len(), 0.6),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..score_prepared.block.design.ncols()).map(|idx| 0.015 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(z.len()),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..link_prepared.block.design.ncols()).map(|idx| 0.01 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(z.len()),
        },
    ];
    let cache = family
        .build_exact_eval_cache(&block_states)
        .expect("exact eval cache");
    let total = cache.primary.total;
    // Probe directions touching the q, logslope, h and w slices.
    let mut dir_u = Array1::<f64>::zeros(total);
    let mut dir_v = Array1::<f64>::zeros(total);
    dir_u[cache.primary.q] = 0.7;
    dir_u[cache.primary.logslope] = -0.2;
    let h_range = cache.primary.h.as_ref().expect("h slice");
    dir_u[h_range.start] = 0.1;
    if h_range.len() > 1 {
        dir_u[h_range.start + 1] = -0.05;
    }
    let w_range = cache.primary.w.as_ref().expect("w slice");
    dir_u[w_range.start] = 0.08;
    dir_v[cache.primary.q] = -0.4;
    dir_v[cache.primary.logslope] = 0.3;
    dir_v[h_range.start] = -0.03;
    dir_v[w_range.start] = 0.06;
    if w_range.len() > 1 {
        dir_v[w_range.start + 1] = -0.02;
    }
    let dir_sum = &dir_u + &dir_v;
    // Verify additivity row by row against the recompute entry point.
    for row in 0..z.len() {
        let row_ctx = family
            .build_row_exact_context(row, &block_states)
            .unwrap_or_else(|e| panic!("row {row}: build_row_exact_context failed: {e}"));
        let third_u = family
            .row_primary_third_contracted_recompute(
                row,
                &block_states,
                &cache,
                &row_ctx,
                &dir_u,
            )
            .unwrap_or_else(|e| panic!("row {row}: third contraction u failed: {e}"));
        let third_v = family
            .row_primary_third_contracted_recompute(
                row,
                &block_states,
                &cache,
                &row_ctx,
                &dir_v,
            )
            .unwrap_or_else(|e| panic!("row {row}: third contraction v failed: {e}"));
        let third_sum = family
            .row_primary_third_contracted_recompute(
                row,
                &block_states,
                &cache,
                &row_ctx,
                &dir_sum,
            )
            .unwrap_or_else(|e| panic!("row {row}: third contraction sum failed: {e}"));
        for i in 0..total {
            for j in 0..total {
                let expected = third_u[[i, j]] + third_v[[i, j]];
                assert!(
                    (third_sum[[i, j]] - expected).abs() < 1e-8,
                    "row {row}: third contraction should be linear in its direction at ({i},{j})"
                );
            }
        }
    }
}
#[test]
// With only the score-warp ("h") block active, the per-row third-order
// contraction must be additive in its direction: T(u + v) == T(u) + T(v).
fn h_only_row_primary_third_direction_is_linear() {
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let prepared = build_score_warp_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build score-warp block");
    let score_dim = prepared
        .block
        .initial_beta
        .as_ref()
        .expect("score-warp initial beta")
        .len();
    // Three dummy parameter blocks; the third carries small nonzero
    // score-warp coefficients.
    let block_states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(
            Array1::from_iter((0..score_dim).map(|idx| 0.04 * (idx as f64 + 1.0))),
            seed.len(),
        ),
    ];
    // Family with the score-warp deviation only and zero-column designs.
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: Some(prepared.runtime.clone()),
        link_dev: None,
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    let cache = family
        .build_exact_eval_cache(&block_states)
        .expect("exact eval cache");
    let total = cache.primary.total;
    // Two probe directions with components in the q, logslope and h slices.
    let mut dir_u = Array1::<f64>::zeros(total);
    let mut dir_v = Array1::<f64>::zeros(total);
    dir_u[cache.primary.q] = -0.35;
    dir_u[cache.primary.logslope] = 0.28;
    let h_range = cache.primary.h.as_ref().expect("h slice");
    dir_u[h_range.start] = 0.12;
    if h_range.len() > 1 {
        dir_u[h_range.start + 1] = -0.06;
    }
    dir_v[cache.primary.q] = 0.18;
    dir_v[cache.primary.logslope] = -0.22;
    dir_v[h_range.start] = 0.07;
    if h_range.len() > 1 {
        dir_v[h_range.start + 1] = 0.05;
    }
    let dir_sum = &dir_u + &dir_v;
    // Verify additivity row by row against the recompute entry point.
    for row in 0..seed.len() {
        let row_ctx = family
            .build_row_exact_context(row, &block_states)
            .unwrap_or_else(|e| panic!("row {row}: build_row_exact_context failed: {e}"));
        let third_u = family
            .row_primary_third_contracted_recompute(
                row,
                &block_states,
                &cache,
                &row_ctx,
                &dir_u,
            )
            .unwrap_or_else(|e| panic!("row {row}: third contraction u failed: {e}"));
        let third_v = family
            .row_primary_third_contracted_recompute(
                row,
                &block_states,
                &cache,
                &row_ctx,
                &dir_v,
            )
            .unwrap_or_else(|e| panic!("row {row}: third contraction v failed: {e}"));
        let third_sum = family
            .row_primary_third_contracted_recompute(
                row,
                &block_states,
                &cache,
                &row_ctx,
                &dir_sum,
            )
            .unwrap_or_else(|e| panic!("row {row}: third contraction sum failed: {e}"));
        for i in 0..total {
            for j in 0..total {
                let expected = third_u[[i, j]] + third_v[[i, j]];
                assert!(
                    (third_sum[[i, j]] - expected).abs() < 1e-8,
                    "row {row}: h-only third contraction should be linear at ({i},{j})"
                );
            }
        }
    }
}
#[test]
// With only the link-deviation ("w") block active, the per-row third-order
// contraction must be additive in its direction: T(u + v) == T(u) + T(v).
fn w_only_row_primary_third_direction_is_linear() {
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let prepared = build_test_link_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build link deviation block");
    let link_dim = prepared
        .block
        .initial_beta
        .as_ref()
        .expect("link initial beta")
        .len();
    // Three dummy parameter blocks; the third carries small nonzero
    // link-deviation coefficients.
    let block_states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(
            Array1::from_iter((0..link_dim).map(|idx| 0.05 * (idx as f64 + 1.0))),
            seed.len(),
        ),
    ];
    // Family with the link deviation only and zero-column designs.
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: None,
        link_dev: Some(prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    let cache = family
        .build_exact_eval_cache(&block_states)
        .expect("exact eval cache");
    let total = cache.primary.total;
    // Two probe directions with components in the q, logslope and w slices.
    let mut dir_u = Array1::<f64>::zeros(total);
    let mut dir_v = Array1::<f64>::zeros(total);
    dir_u[cache.primary.q] = 0.4;
    dir_u[cache.primary.logslope] = -0.3;
    let w_range = cache.primary.w.as_ref().expect("w slice");
    dir_u[w_range.start] = 0.15;
    if w_range.len() > 1 {
        dir_u[w_range.start + 1] = -0.07;
    }
    dir_v[cache.primary.q] = -0.2;
    dir_v[cache.primary.logslope] = 0.25;
    dir_v[w_range.start] = 0.09;
    if w_range.len() > 1 {
        dir_v[w_range.start + 1] = 0.03;
    }
    let dir_sum = &dir_u + &dir_v;
    // Verify additivity row by row against the recompute entry point.
    for row in 0..seed.len() {
        let row_ctx = family
            .build_row_exact_context(row, &block_states)
            .unwrap_or_else(|e| panic!("row {row}: build_row_exact_context failed: {e}"));
        let third_u = family
            .row_primary_third_contracted_recompute(
                row,
                &block_states,
                &cache,
                &row_ctx,
                &dir_u,
            )
            .unwrap_or_else(|e| panic!("row {row}: third contraction u failed: {e}"));
        let third_v = family
            .row_primary_third_contracted_recompute(
                row,
                &block_states,
                &cache,
                &row_ctx,
                &dir_v,
            )
            .unwrap_or_else(|e| panic!("row {row}: third contraction v failed: {e}"));
        let third_sum = family
            .row_primary_third_contracted_recompute(
                row,
                &block_states,
                &cache,
                &row_ctx,
                &dir_sum,
            )
            .unwrap_or_else(|e| panic!("row {row}: third contraction sum failed: {e}"));
        for i in 0..total {
            for j in 0..total {
                let expected = third_u[[i, j]] + third_v[[i, j]];
                assert!(
                    (third_sum[[i, j]] - expected).abs() < 1e-8,
                    "row {row}: w-only third contraction should be linear at ({i},{j})"
                );
            }
        }
    }
}
#[test]
// With both flexible blocks active, the fourth directional derivative must be
// linear (additive) in its FIRST direction with the second held fixed:
// D4[u + v, w] == D4[u, w] + D4[v, w].
fn dual_flex_exact_outer_fourth_first_direction_is_linear() {
    let z = array![-0.8, 0.2, 1.1];
    let y = array![0.0, 1.0, 1.0];
    let weights = array![1.0, 0.7, 1.3];
    let score_prepared = build_score_warp_deviation_block_from_seed(
        &z,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("score warp block");
    let link_seed = array![-2.0, -0.5, 0.0, 0.5, 2.0];
    let link_prepared = build_test_link_deviation_block_from_seed(
        &link_seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("link block");
    // Family with both flexible blocks enabled and intercept-only designs.
    let family =
        BernoulliMarginalSlopeFamily {
            y: Arc::new(y.clone()),
            weights: Arc::new(weights.clone()),
            z: Arc::new(z.clone()),
            latent_measure: LatentMeasureKind::StandardNormal,
            gaussian_frailty_sd: None,
            base_link: bernoulli_marginal_slope_probit_link(),
            marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
                array![[1.0], [1.0], [1.0]],
            )),
            logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
                array![[1.0], [1.0], [1.0]],
            )),
            score_warp: Some(score_prepared.runtime.clone()),
            link_dev: Some(link_prepared.runtime.clone()),
            policy: crate::resource::ResourcePolicy::default_library(),
            cell_moment_lru: new_cell_moment_lru_cache(
                &crate::resource::ResourcePolicy::default_library(),
            ),
            cell_moment_cache_stats: new_cell_moment_cache_stats(),
            intercept_warm_starts: None,
        };
    // Four parameter blocks: marginal, logslope, score-warp and link-deviation
    // coefficients, each at a small nonzero state.
    let block_states = vec![
        ParameterBlockState {
            beta: array![0.25],
            eta: Array1::from_elem(z.len(), 0.25),
        },
        ParameterBlockState {
            beta: array![0.6],
            eta: Array1::from_elem(z.len(), 0.6),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..score_prepared.block.design.ncols()).map(|idx| 0.015 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(z.len()),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..link_prepared.block.design.ncols()).map(|idx| 0.01 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(z.len()),
        },
    ];
    let slices = block_slices(&family);
    let total = slices.total;
    // Three probe directions: dir_u and dir_v vary in the first slot,
    // dir_w is the fixed second direction.
    let mut dir_u = Array1::<f64>::zeros(total);
    let mut dir_v = Array1::<f64>::zeros(total);
    let mut dir_w = Array1::<f64>::zeros(total);
    dir_u[slices.marginal.start] = 0.7;
    dir_u[slices.logslope.start] = -0.2;
    let h_range = slices.h.as_ref().expect("h slice");
    dir_u[h_range.start] = 0.1;
    if h_range.len() > 1 {
        dir_u[h_range.start + 1] = -0.05;
    }
    let w_range = slices.w.as_ref().expect("w slice");
    dir_u[w_range.start] = 0.08;
    dir_v[slices.marginal.start] = -0.4;
    dir_v[slices.logslope.start] = 0.3;
    dir_v[h_range.start] = -0.03;
    dir_v[w_range.start] = 0.06;
    if w_range.len() > 1 {
        dir_v[w_range.start + 1] = -0.02;
    }
    dir_w[slices.marginal.start] = 0.11;
    dir_w[slices.logslope.start] = -0.09;
    dir_w[h_range.start] = 0.04;
    dir_w[w_range.start] = -0.05;
    let dir_sum = &dir_u + &dir_v;
    let fourth_u = family
        .exact_newton_joint_hessiansecond_directional_derivative(&block_states, &dir_u, &dir_w)
        .expect("dual-flex fourth directional derivative u,w")
        .expect("dual-flex fourth directional derivative u,w matrix");
    let fourth_v = family
        .exact_newton_joint_hessiansecond_directional_derivative(&block_states, &dir_v, &dir_w)
        .expect("dual-flex fourth directional derivative v,w")
        .expect("dual-flex fourth directional derivative v,w matrix");
    let fourth_sum = family
        .exact_newton_joint_hessiansecond_directional_derivative(
            &block_states,
            &dir_sum,
            &dir_w,
        )
        .expect("dual-flex fourth directional derivative (u+v),w")
        .expect("dual-flex fourth directional derivative (u+v),w matrix");
    // Entry-wise additivity check with an absolute tolerance of 1e-8.
    for i in 0..total {
        for j in 0..total {
            let expected = fourth_u[[i, j]] + fourth_v[[i, j]];
            assert!(
                (fourth_sum[[i, j]] - expected).abs() < 1e-8,
                "fourth directional derivative should be linear in its first direction at ({i},{j})"
            );
        }
    }
}
#[test]
// With only the score-warp ("h") block active, the exact-outer third
// directional derivative must be additive in its direction argument:
// D3[u + v] == D3[u] + D3[v], entry-wise to within 1e-8.
fn h_only_exact_outer_third_direction_is_linear() {
    let knot_seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let warp_block = build_score_warp_deviation_block_from_seed(
        &knot_seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build score-warp block");
    let warp_dim = warp_block
        .block
        .initial_beta
        .as_ref()
        .expect("score-warp initial beta")
        .len();
    let n = knot_seed.len();
    // Three dummy parameter blocks; the third carries small nonzero
    // score-warp coefficients.
    let block_states = vec![
        dummy_block_state(array![0.0], n),
        dummy_block_state(array![0.0], n),
        dummy_block_state(
            Array1::from_iter((0..warp_dim).map(|idx| 0.04 * (idx as f64 + 1.0))),
            n,
        ),
    ];
    // Family with the score-warp deviation only and zero-column designs.
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(n)),
        z: Arc::new(knot_seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((n, 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((n, 0)),
        )),
        score_warp: Some(warp_block.runtime.clone()),
        link_dev: None,
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    let slices = block_slices(&family);
    let total = slices.total;
    let h_range = slices.h.as_ref().expect("h slice");
    // Two probe directions with components only in the h slice.
    let mut dir_u = Array1::<f64>::zeros(total);
    let mut dir_v = Array1::<f64>::zeros(total);
    dir_u[h_range.start] = 0.12;
    dir_v[h_range.start] = 0.07;
    if h_range.len() > 1 {
        dir_u[h_range.start + 1] = -0.06;
        dir_v[h_range.start + 1] = 0.05;
    }
    let dir_sum = &dir_u + &dir_v;
    // Shared evaluator; panic messages match the per-call labels.
    let eval_third = |dir: &Array1<f64>, err_msg: &str, none_msg: &str| {
        family
            .exact_newton_joint_hessian_directional_derivative(&block_states, dir)
            .expect(err_msg)
            .expect(none_msg)
    };
    let third_u = eval_third(
        &dir_u,
        "h-only third directional derivative u",
        "h-only third directional derivative u matrix",
    );
    let third_v = eval_third(
        &dir_v,
        "h-only third directional derivative v",
        "h-only third directional derivative v matrix",
    );
    let third_sum = eval_third(
        &dir_sum,
        "h-only third directional derivative sum",
        "h-only third directional derivative sum matrix",
    );
    for i in 0..total {
        for j in 0..total {
            let expected = third_u[[i, j]] + third_v[[i, j]];
            assert!(
                (third_sum[[i, j]] - expected).abs() < 1e-8,
                "h-only third directional derivative should be linear at ({i},{j})"
            );
        }
    }
}
#[test]
fn w_only_exact_outer_third_direction_is_linear() {
    // The map dir -> D^3 psi[dir] must be linear in its direction argument;
    // verify additivity with both directions supported only on the
    // link-deviation ("w") coefficient slice.
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let config = DeviationBlockConfig {
        num_internal_knots: 4,
        ..DeviationBlockConfig::default()
    };
    let prepared = build_test_link_deviation_block_from_seed(&seed, &config)
        .expect("build link deviation block");
    let link_dim = prepared
        .block
        .initial_beta
        .as_ref()
        .expect("link initial beta")
        .len();
    // Blocks: marginal intercept, log-slope intercept, then the w coefficients.
    let w_beta = Array1::from_iter((0..link_dim).map(|idx| 0.05 * (idx as f64 + 1.0)));
    let block_states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(w_beta, seed.len()),
    ];
    // Only the link deviation is active; the marginal/log-slope designs are empty.
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: None,
        link_dev: Some(prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    let slices = block_slices(&family);
    let total = slices.total;
    let w_range = slices.w.as_ref().expect("w slice");
    // Two independent directions, each supported on the w slice only.
    let mut u = Array1::<f64>::zeros(total);
    let mut v = Array1::<f64>::zeros(total);
    u[w_range.start] = 0.15;
    v[w_range.start] = 0.09;
    if w_range.len() > 1 {
        u[w_range.start + 1] = -0.07;
        v[w_range.start + 1] = 0.03;
    }
    let combined = &u + &v;
    // One evaluation helper so the three calls stay structurally identical.
    let eval_third = |direction: &Array1<f64>, err: &str, err_matrix: &str| {
        family
            .exact_newton_joint_hessian_directional_derivative(&block_states, direction)
            .expect(err)
            .expect(err_matrix)
    };
    let third_u = eval_third(
        &u,
        "w-only third directional derivative u",
        "w-only third directional derivative u matrix",
    );
    let third_v = eval_third(
        &v,
        "w-only third directional derivative v",
        "w-only third directional derivative v matrix",
    );
    let third_sum = eval_third(
        &combined,
        "w-only third directional derivative sum",
        "w-only third directional derivative sum matrix",
    );
    // Additivity entry by entry: D^3[u + v] == D^3[u] + D^3[v].
    for ((i, j), entry) in third_sum.indexed_iter() {
        let expected = third_u[[i, j]] + third_v[[i, j]];
        assert!(
            (entry - expected).abs() < 1e-8,
            "w-only third directional derivative should be linear at ({i},{j})"
        );
    }
}
#[test]
fn h_only_exact_outer_direction_sign_rules_hold() {
    // Parity checks for the exact-Newton outer derivatives when the direction
    // lives only on the score-warp ("h") slice: the third directional
    // derivative should be an odd function of the direction, and the fourth
    // should be invariant when both direction arguments are negated together.
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let prepared = build_score_warp_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build score-warp block");
    let score_dim = prepared
        .block
        .initial_beta
        .as_ref()
        .expect("score-warp initial beta")
        .len();
    // Blocks: marginal intercept, log-slope intercept, then the h coefficients.
    let block_states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(
            Array1::from_iter((0..score_dim).map(|idx| 0.04 * (idx as f64 + 1.0))),
            seed.len(),
        ),
    ];
    // Only the score-warp deviation is active (no link deviation); the
    // marginal and log-slope designs are empty so the h slice is the only
    // flexible part of the joint parameter vector.
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: Some(prepared.runtime.clone()),
        link_dev: None,
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    let slices = block_slices(&family);
    let total = slices.total;
    // Direction supported only on the h slice.
    let mut dir = Array1::<f64>::zeros(total);
    let h_range = slices.h.as_ref().expect("h slice");
    dir[h_range.start] = 0.12;
    if h_range.len() > 1 {
        dir[h_range.start + 1] = -0.06;
    }
    let neg_dir = dir.mapv(|value| -value);
    let third = family
        .exact_newton_joint_hessian_directional_derivative(&block_states, &dir)
        .expect("h-only third directional derivative")
        .expect("h-only third directional derivative matrix");
    let third_neg = family
        .exact_newton_joint_hessian_directional_derivative(&block_states, &neg_dir)
        .expect("h-only negated third directional derivative")
        .expect("h-only negated third directional derivative matrix");
    let fourth = family
        .exact_newton_joint_hessiansecond_directional_derivative(&block_states, &dir, &dir)
        .expect("h-only fourth directional derivative")
        .expect("h-only fourth directional derivative matrix");
    let fourth_neg = family
        .exact_newton_joint_hessiansecond_directional_derivative(
            &block_states,
            &neg_dir,
            &neg_dir,
        )
        .expect("h-only doubly-negated fourth directional derivative")
        .expect("h-only doubly-negated fourth directional derivative matrix");
    for i in 0..total {
        for j in 0..total {
            // Odd symmetry: D^3[-d] == -D^3[d], so the sum must vanish.
            assert!(
                (third_neg[[i, j]] + third[[i, j]]).abs() < 1e-8,
                "h-only third directional derivative should be odd at ({i},{j})"
            );
            // Even symmetry: D^4[-d, -d] == D^4[d, d].
            assert!(
                (fourth_neg[[i, j]] - fourth[[i, j]]).abs() < 1e-8,
                "h-only fourth directional derivative should be invariant under flipping both directions at ({i},{j})"
            );
        }
    }
}
#[test]
fn w_only_exact_outer_direction_sign_rules_hold() {
    // Sign-parity checks with the direction supported only on the
    // link-deviation ("w") slice: D^3 is odd in its direction, and D^4 is
    // even under a simultaneous flip of both direction arguments.
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let config = DeviationBlockConfig {
        num_internal_knots: 4,
        ..DeviationBlockConfig::default()
    };
    let prepared = build_test_link_deviation_block_from_seed(&seed, &config)
        .expect("build link deviation block");
    let link_dim = prepared
        .block
        .initial_beta
        .as_ref()
        .expect("link initial beta")
        .len();
    // Blocks: marginal intercept, log-slope intercept, then the w coefficients.
    let w_beta = Array1::from_iter((0..link_dim).map(|idx| 0.05 * (idx as f64 + 1.0)));
    let block_states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(w_beta, seed.len()),
    ];
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: None,
        link_dev: Some(prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    let slices = block_slices(&family);
    let total = slices.total;
    let w_range = slices.w.as_ref().expect("w slice");
    let mut dir = Array1::<f64>::zeros(total);
    dir[w_range.start] = 0.15;
    if w_range.len() > 1 {
        dir[w_range.start + 1] = -0.07;
    }
    // Element-wise negation of the probe direction.
    let neg_dir = -&dir;
    let third = family
        .exact_newton_joint_hessian_directional_derivative(&block_states, &dir)
        .expect("w-only third directional derivative")
        .expect("w-only third directional derivative matrix");
    let third_neg = family
        .exact_newton_joint_hessian_directional_derivative(&block_states, &neg_dir)
        .expect("w-only negated third directional derivative")
        .expect("w-only negated third directional derivative matrix");
    let fourth = family
        .exact_newton_joint_hessiansecond_directional_derivative(&block_states, &dir, &dir)
        .expect("w-only fourth directional derivative")
        .expect("w-only fourth directional derivative matrix");
    let fourth_neg = family
        .exact_newton_joint_hessiansecond_directional_derivative(&block_states, &neg_dir, &neg_dir)
        .expect("w-only doubly-negated fourth directional derivative")
        .expect("w-only doubly-negated fourth directional derivative matrix");
    for i in 0..total {
        for j in 0..total {
            // Odd parity: D^3[-d] + D^3[d] should be (numerically) zero.
            let odd_residual = third_neg[[i, j]] + third[[i, j]];
            assert!(
                odd_residual.abs() < 1e-8,
                "w-only third directional derivative should be odd at ({i},{j})"
            );
            // Even parity: D^4[-d, -d] should equal D^4[d, d].
            let even_residual = fourth_neg[[i, j]] - fourth[[i, j]];
            assert!(
                even_residual.abs() < 1e-8,
                "w-only fourth directional derivative should be invariant under flipping both directions at ({i},{j})"
            );
        }
    }
}
#[test]
fn h_only_exact_outer_fourth_direction_swap_is_symmetric() {
    // The fourth directional derivative takes two direction arguments; by
    // symmetry of mixed partials they must commute. Check this with both
    // directions supported only on the score-warp ("h") slice.
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let config = DeviationBlockConfig {
        num_internal_knots: 4,
        ..DeviationBlockConfig::default()
    };
    let prepared = build_score_warp_deviation_block_from_seed(&seed, &config)
        .expect("build score-warp block");
    let score_dim = prepared
        .block
        .initial_beta
        .as_ref()
        .expect("score-warp initial beta")
        .len();
    // Blocks: marginal intercept, log-slope intercept, then the h coefficients.
    let h_beta = Array1::from_iter((0..score_dim).map(|idx| 0.04 * (idx as f64 + 1.0)));
    let block_states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(h_beta, seed.len()),
    ];
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: Some(prepared.runtime.clone()),
        link_dev: None,
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    let slices = block_slices(&family);
    let total = slices.total;
    // Build the two probe directions inside one scope so the h-range borrow
    // does not outlive the setup.
    let (u, v) = {
        let h_range = slices.h.as_ref().expect("h slice");
        let mut u = Array1::<f64>::zeros(total);
        let mut v = Array1::<f64>::zeros(total);
        u[h_range.start] = 0.12;
        v[h_range.start] = 0.07;
        if h_range.len() > 1 {
            u[h_range.start + 1] = -0.06;
            v[h_range.start + 1] = 0.05;
        }
        (u, v)
    };
    let forward = family
        .exact_newton_joint_hessiansecond_directional_derivative(&block_states, &u, &v)
        .expect("h-only fourth directional derivative")
        .expect("h-only fourth directional derivative matrix");
    let swapped = family
        .exact_newton_joint_hessiansecond_directional_derivative(&block_states, &v, &u)
        .expect("h-only swapped fourth directional derivative")
        .expect("h-only swapped fourth directional derivative matrix");
    // D^4[u, v] == D^4[v, u] entry by entry.
    for ((i, j), value) in forward.indexed_iter() {
        assert!(
            (value - swapped[[i, j]]).abs() < 1e-8,
            "h-only fourth directional derivative should be symmetric in direction arguments at ({i},{j})"
        );
    }
}
#[test]
fn w_only_exact_outer_fourth_direction_swap_is_symmetric() {
    // The fourth directional derivative takes two direction arguments; this
    // verifies they commute when both directions live on the link-deviation
    // ("w") coefficient slice.
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let prepared = build_test_link_deviation_block_from_seed(
        &seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build link deviation block");
    let link_dim = prepared
        .block
        .initial_beta
        .as_ref()
        .expect("link initial beta")
        .len();
    // Blocks: marginal intercept, log-slope intercept, then the w coefficients.
    let block_states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(
            Array1::from_iter((0..link_dim).map(|idx| 0.05 * (idx as f64 + 1.0))),
            seed.len(),
        ),
    ];
    // Only the link deviation is active; marginal/log-slope designs are empty.
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: None,
        link_dev: Some(prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    let slices = block_slices(&family);
    let total = slices.total;
    // Two distinct probe directions, each supported only on the w slice.
    let mut dir_u = Array1::<f64>::zeros(total);
    let mut dir_v = Array1::<f64>::zeros(total);
    let w_range = slices.w.as_ref().expect("w slice");
    dir_u[w_range.start] = 0.15;
    if w_range.len() > 1 {
        dir_u[w_range.start + 1] = -0.07;
    }
    dir_v[w_range.start] = 0.09;
    if w_range.len() > 1 {
        dir_v[w_range.start + 1] = 0.03;
    }
    let forward = family
        .exact_newton_joint_hessiansecond_directional_derivative(&block_states, &dir_u, &dir_v)
        .expect("w-only fourth directional derivative")
        .expect("w-only fourth directional derivative matrix");
    let swapped = family
        .exact_newton_joint_hessiansecond_directional_derivative(&block_states, &dir_v, &dir_u)
        .expect("w-only swapped fourth directional derivative")
        .expect("w-only swapped fourth directional derivative matrix");
    for i in 0..total {
        for j in 0..total {
            // Symmetry of mixed partials: D^4[u, v] == D^4[v, u].
            assert!(
                (forward[[i, j]] - swapped[[i, j]]).abs() < 1e-8,
                "w-only fourth directional derivative should be symmetric in direction arguments at ({i},{j})"
            );
        }
    }
}
#[test]
fn h_only_exact_outer_zero_direction_returns_zero() {
    // A zero direction vector must yield exactly-zero third and fourth
    // directional derivatives — no numerical noise is tolerated.
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let config = DeviationBlockConfig {
        num_internal_knots: 4,
        ..DeviationBlockConfig::default()
    };
    let prepared = build_score_warp_deviation_block_from_seed(&seed, &config)
        .expect("build score-warp block");
    let score_dim = prepared
        .block
        .initial_beta
        .as_ref()
        .expect("score-warp initial beta")
        .len();
    // Blocks: marginal intercept, log-slope intercept, then the h coefficients.
    let h_beta = Array1::from_iter((0..score_dim).map(|idx| 0.04 * (idx as f64 + 1.0)));
    let block_states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(h_beta, seed.len()),
    ];
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: Some(prepared.runtime.clone()),
        link_dev: None,
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    let zero = Array1::<f64>::zeros(block_slices(&family).total);
    let third = family
        .exact_newton_joint_hessian_directional_derivative(&block_states, &zero)
        .expect("h-only third directional derivative")
        .expect("h-only third directional derivative matrix");
    let fourth = family
        .exact_newton_joint_hessiansecond_directional_derivative(&block_states, &zero, &zero)
        .expect("h-only fourth directional derivative")
        .expect("h-only fourth directional derivative matrix");
    // Exact-zero comparison is deliberate: linearity in the direction means
    // the zero direction must map to the zero matrix, not merely a small one.
    assert!(
        third.iter().all(|entry| *entry == 0.0),
        "expected zero h-only third directional derivative for zero direction"
    );
    assert!(
        fourth.iter().all(|entry| *entry == 0.0),
        "expected zero h-only fourth directional derivative for zero directions"
    );
}
#[test]
fn w_only_exact_outer_zero_direction_returns_zero() {
    // A zero direction vector must yield exactly-zero third and fourth
    // directional derivatives for the link-deviation ("w") configuration.
    let seed = array![-1.5, -0.5, 0.0, 0.5, 1.5];
    let config = DeviationBlockConfig {
        num_internal_knots: 4,
        ..DeviationBlockConfig::default()
    };
    let prepared = build_test_link_deviation_block_from_seed(&seed, &config)
        .expect("build link deviation block");
    let link_dim = prepared
        .block
        .initial_beta
        .as_ref()
        .expect("link initial beta")
        .len();
    // Blocks: marginal intercept, log-slope intercept, then the w coefficients.
    let w_beta = Array1::from_iter((0..link_dim).map(|idx| 0.05 * (idx as f64 + 1.0)));
    let block_states = vec![
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(array![0.0], seed.len()),
        dummy_block_state(w_beta, seed.len()),
    ];
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(array![0.0, 1.0, 0.0, 1.0, 0.0]),
        weights: Arc::new(Array1::ones(seed.len())),
        z: Arc::new(seed.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            Array2::zeros((seed.len(), 0)),
        )),
        score_warp: None,
        link_dev: Some(prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    let zero = Array1::<f64>::zeros(block_slices(&family).total);
    let third = family
        .exact_newton_joint_hessian_directional_derivative(&block_states, &zero)
        .expect("w-only third directional derivative")
        .expect("w-only third directional derivative matrix");
    let fourth = family
        .exact_newton_joint_hessiansecond_directional_derivative(&block_states, &zero, &zero)
        .expect("w-only fourth directional derivative")
        .expect("w-only fourth directional derivative matrix");
    // Exact-zero comparison is deliberate: the zero direction must map to
    // the zero matrix, not merely to something small.
    assert!(
        third.iter().all(|entry| *entry == 0.0),
        "expected zero w-only third directional derivative for zero direction"
    );
    assert!(
        fourth.iter().all(|entry| *entry == 0.0),
        "expected zero w-only fourth directional derivative for zero directions"
    );
}
#[test]
fn w_only_gradient_matches_loglik_finite_differences() {
    // Cross-checks the analytic per-block gradients reported by `evaluate`
    // against central finite differences of `log_likelihood_only`, with the
    // link-deviation ("w") block active.
    let z = array![-0.8, 0.2, 1.1];
    let y = array![0.0, 1.0, 1.0];
    let weights = array![1.0, 0.7, 1.3];
    let link_seed = array![-2.0, -0.5, 0.0, 0.5, 2.0];
    let link_prepared = build_test_link_deviation_block_from_seed(
        &link_seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("link block");
    // Intercept-only marginal and log-slope designs: each of those blocks
    // carries a single scalar coefficient (q and b below).
    let family =
        BernoulliMarginalSlopeFamily {
            y: Arc::new(y.clone()),
            weights: Arc::new(weights.clone()),
            z: Arc::new(z.clone()),
            latent_measure: LatentMeasureKind::StandardNormal,
            gaussian_frailty_sd: None,
            base_link: bernoulli_marginal_slope_probit_link(),
            marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
                array![[1.0], [1.0], [1.0]],
            )),
            logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
                array![[1.0], [1.0], [1.0]],
            )),
            score_warp: None,
            link_dev: Some(link_prepared.runtime.clone()),
            policy: crate::resource::ResourcePolicy::default_library(),
            cell_moment_lru: new_cell_moment_lru_cache(
                &crate::resource::ResourcePolicy::default_library(),
            ),
            cell_moment_cache_stats: new_cell_moment_cache_stats(),
            intercept_warm_starts: None,
        };
    let beta_w = Array1::from_iter(
        (0..link_prepared.block.design.ncols()).map(|idx| 0.01 * (idx as f64 + 1.0)),
    );
    // Builds the three block states for marginal intercept `q`, log-slope
    // intercept `b`, and the w coefficient vector `bw`. The intercept-only
    // designs make eta constant across observations for q and b.
    let states_at = |q: f64, b: f64, bw: Array1<f64>| {
        vec![
            ParameterBlockState {
                beta: array![q],
                eta: Array1::from_elem(z.len(), q),
            },
            ParameterBlockState {
                beta: array![b],
                eta: Array1::from_elem(z.len(), b),
            },
            ParameterBlockState {
                beta: bw,
                eta: Array1::zeros(z.len()),
            },
        ]
    };
    let q0 = 0.25;
    let b0 = 0.6;
    let block_states = states_at(q0, b0, beta_w.clone());
    let eval = family.evaluate(&block_states).expect("family evaluation");
    // Extract the first analytic gradient entry from each exact-Newton block.
    let grad_q = match &eval.blockworking_sets[0] {
        BlockWorkingSet::ExactNewton { gradient, .. } => gradient[0],
        _ => panic!("expected exact-newton q block"),
    };
    let grad_b = match &eval.blockworking_sets[1] {
        BlockWorkingSet::ExactNewton { gradient, .. } => gradient[0],
        _ => panic!("expected exact-newton b block"),
    };
    let grad_w0 = match &eval.blockworking_sets[2] {
        BlockWorkingSet::ExactNewton { gradient, .. } => gradient[0],
        _ => panic!("expected exact-newton w block"),
    };
    // Central finite difference of the log-likelihood in the named
    // coordinate: "q", "b", or "w0" (the first w coefficient).
    let fd = |which: &str, eps: f64| match which {
        "q" => {
            let plus = family
                .log_likelihood_only(&states_at(q0 + eps, b0, beta_w.clone()))
                .expect("ll plus q");
            let minus = family
                .log_likelihood_only(&states_at(q0 - eps, b0, beta_w.clone()))
                .expect("ll minus q");
            (plus - minus) / (2.0 * eps)
        }
        "b" => {
            let plus = family
                .log_likelihood_only(&states_at(q0, b0 + eps, beta_w.clone()))
                .expect("ll plus b");
            let minus = family
                .log_likelihood_only(&states_at(q0, b0 - eps, beta_w.clone()))
                .expect("ll minus b");
            (plus - minus) / (2.0 * eps)
        }
        "w0" => {
            let mut plus_w = beta_w.clone();
            plus_w[0] += eps;
            let mut minus_w = beta_w.clone();
            minus_w[0] -= eps;
            let plus = family
                .log_likelihood_only(&states_at(q0, b0, plus_w))
                .expect("ll plus w");
            let minus = family
                .log_likelihood_only(&states_at(q0, b0, minus_w))
                .expect("ll minus w");
            (plus - minus) / (2.0 * eps)
        }
        _ => panic!("unknown derivative target"),
    };
    // Tolerance 2e-4 absorbs both FD truncation error (O(eps^2)) and the
    // quadrature/caching noise inside the likelihood evaluation.
    let eps = 1e-5;
    assert!((grad_q - fd("q", eps)).abs() < 2e-4);
    assert!((grad_b - fd("b", eps)).abs() < 2e-4);
    assert!((grad_w0 - fd("w0", eps)).abs() < 2e-4);
}
#[test]
fn h_only_gradient_matches_loglik_finite_differences() {
    // Cross-checks the analytic per-block gradients reported by `evaluate`
    // against central finite differences of `log_likelihood_only`, with the
    // score-warp ("h") block active.
    let z = array![-0.8, 0.2, 1.1];
    let y = array![0.0, 1.0, 1.0];
    let weights = array![1.0, 0.7, 1.3];
    let score_prepared = build_score_warp_deviation_block_from_seed(
        &z,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("score warp block");
    // Intercept-only marginal and log-slope designs: each of those blocks
    // carries a single scalar coefficient (q and b below).
    let family =
        BernoulliMarginalSlopeFamily {
            y: Arc::new(y.clone()),
            weights: Arc::new(weights.clone()),
            z: Arc::new(z.clone()),
            latent_measure: LatentMeasureKind::StandardNormal,
            gaussian_frailty_sd: None,
            base_link: bernoulli_marginal_slope_probit_link(),
            marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
                array![[1.0], [1.0], [1.0]],
            )),
            logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
                array![[1.0], [1.0], [1.0]],
            )),
            score_warp: Some(score_prepared.runtime.clone()),
            link_dev: None,
            policy: crate::resource::ResourcePolicy::default_library(),
            cell_moment_lru: new_cell_moment_lru_cache(
                &crate::resource::ResourcePolicy::default_library(),
            ),
            cell_moment_cache_stats: new_cell_moment_cache_stats(),
            intercept_warm_starts: None,
        };
    let beta_h = Array1::from_iter(
        (0..score_prepared.block.design.ncols()).map(|idx| 0.015 * (idx as f64 + 1.0)),
    );
    // Builds the three block states for marginal intercept `q`, log-slope
    // intercept `b`, and the h coefficient vector `bh`. The intercept-only
    // designs make eta constant across observations for q and b.
    let states_at = |q: f64, b: f64, bh: Array1<f64>| {
        vec![
            ParameterBlockState {
                beta: array![q],
                eta: Array1::from_elem(z.len(), q),
            },
            ParameterBlockState {
                beta: array![b],
                eta: Array1::from_elem(z.len(), b),
            },
            ParameterBlockState {
                beta: bh,
                eta: Array1::zeros(z.len()),
            },
        ]
    };
    let q0 = 0.25;
    let b0 = 0.6;
    let block_states = states_at(q0, b0, beta_h.clone());
    let eval = family.evaluate(&block_states).expect("family evaluation");
    // Extract the first analytic gradient entry from each exact-Newton block.
    let grad_q = match &eval.blockworking_sets[0] {
        BlockWorkingSet::ExactNewton { gradient, .. } => gradient[0],
        _ => panic!("expected exact-newton q block"),
    };
    let grad_b = match &eval.blockworking_sets[1] {
        BlockWorkingSet::ExactNewton { gradient, .. } => gradient[0],
        _ => panic!("expected exact-newton b block"),
    };
    let grad_h0 = match &eval.blockworking_sets[2] {
        BlockWorkingSet::ExactNewton { gradient, .. } => gradient[0],
        _ => panic!("expected exact-newton h block"),
    };
    // Central finite difference of the log-likelihood in the named
    // coordinate: "q", "b", or "h0" (the first h coefficient).
    let fd = |which: &str, eps: f64| match which {
        "q" => {
            let plus = family
                .log_likelihood_only(&states_at(q0 + eps, b0, beta_h.clone()))
                .expect("ll plus q");
            let minus = family
                .log_likelihood_only(&states_at(q0 - eps, b0, beta_h.clone()))
                .expect("ll minus q");
            (plus - minus) / (2.0 * eps)
        }
        "b" => {
            let plus = family
                .log_likelihood_only(&states_at(q0, b0 + eps, beta_h.clone()))
                .expect("ll plus b");
            let minus = family
                .log_likelihood_only(&states_at(q0, b0 - eps, beta_h.clone()))
                .expect("ll minus b");
            (plus - minus) / (2.0 * eps)
        }
        "h0" => {
            let mut plus_h = beta_h.clone();
            plus_h[0] += eps;
            let mut minus_h = beta_h.clone();
            minus_h[0] -= eps;
            let plus = family
                .log_likelihood_only(&states_at(q0, b0, plus_h))
                .expect("ll plus h");
            let minus = family
                .log_likelihood_only(&states_at(q0, b0, minus_h))
                .expect("ll minus h");
            (plus - minus) / (2.0 * eps)
        }
        _ => panic!("unknown derivative target"),
    };
    // Tolerance 2e-4 absorbs both FD truncation error and quadrature noise.
    let eps = 1e-5;
    assert!((grad_q - fd("q", eps)).abs() < 2e-4);
    assert!((grad_b - fd("b", eps)).abs() < 2e-4);
    assert!((grad_h0 - fd("h0", eps)).abs() < 2e-4);
}
#[test]
fn flexible_denested_gradient_matches_loglik_finite_differences() {
    // Cross-checks the analytic per-block gradients against central finite
    // differences of the log-likelihood with BOTH flexible deviation blocks
    // (score-warp "h" and link-deviation "w") active simultaneously.
    let z = array![-0.8, 0.2, 1.1];
    let y = array![0.0, 1.0, 1.0];
    let weights = array![1.0, 0.7, 1.3];
    let score_prepared = build_score_warp_deviation_block_from_seed(
        &z,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("score warp block");
    let link_seed = array![-2.0, -0.5, 0.0, 0.5, 2.0];
    let link_prepared = build_test_link_deviation_block_from_seed(
        &link_seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("link block");
    // Intercept-only marginal and log-slope designs: each of those blocks
    // carries a single scalar coefficient (q and b below).
    let family =
        BernoulliMarginalSlopeFamily {
            y: Arc::new(y.clone()),
            weights: Arc::new(weights.clone()),
            z: Arc::new(z.clone()),
            latent_measure: LatentMeasureKind::StandardNormal,
            gaussian_frailty_sd: None,
            base_link: bernoulli_marginal_slope_probit_link(),
            marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
                array![[1.0], [1.0], [1.0]],
            )),
            logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
                array![[1.0], [1.0], [1.0]],
            )),
            score_warp: Some(score_prepared.runtime.clone()),
            link_dev: Some(link_prepared.runtime.clone()),
            policy: crate::resource::ResourcePolicy::default_library(),
            cell_moment_lru: new_cell_moment_lru_cache(
                &crate::resource::ResourcePolicy::default_library(),
            ),
            cell_moment_cache_stats: new_cell_moment_cache_stats(),
            intercept_warm_starts: None,
        };
    let beta_h = Array1::from_iter(
        (0..score_prepared.block.design.ncols()).map(|idx| 0.015 * (idx as f64 + 1.0)),
    );
    let beta_w = Array1::from_iter(
        (0..link_prepared.block.design.ncols()).map(|idx| 0.01 * (idx as f64 + 1.0)),
    );
    // Builds the four block states for marginal intercept `q`, log-slope
    // intercept `b`, the h coefficients `bh`, and the w coefficients `bw`.
    let states_at = |q: f64, b: f64, bh: Array1<f64>, bw: Array1<f64>| {
        vec![
            ParameterBlockState {
                beta: array![q],
                eta: Array1::from_elem(z.len(), q),
            },
            ParameterBlockState {
                beta: array![b],
                eta: Array1::from_elem(z.len(), b),
            },
            ParameterBlockState {
                beta: bh,
                eta: Array1::zeros(z.len()),
            },
            ParameterBlockState {
                beta: bw,
                eta: Array1::zeros(z.len()),
            },
        ]
    };
    let q0 = 0.25;
    let b0 = 0.6;
    let block_states = states_at(q0, b0, beta_h.clone(), beta_w.clone());
    let eval = family.evaluate(&block_states).expect("family evaluation");
    // Extract the first analytic gradient entry from each exact-Newton block.
    let grad_q = match &eval.blockworking_sets[0] {
        BlockWorkingSet::ExactNewton { gradient, .. } => gradient[0],
        _ => panic!("expected exact-newton q block"),
    };
    let grad_b = match &eval.blockworking_sets[1] {
        BlockWorkingSet::ExactNewton { gradient, .. } => gradient[0],
        _ => panic!("expected exact-newton b block"),
    };
    let grad_h0 = match &eval.blockworking_sets[2] {
        BlockWorkingSet::ExactNewton { gradient, .. } => gradient[0],
        _ => panic!("expected exact-newton h block"),
    };
    let grad_w0 = match &eval.blockworking_sets[3] {
        BlockWorkingSet::ExactNewton { gradient, .. } => gradient[0],
        _ => panic!("expected exact-newton w block"),
    };
    // Central finite difference of the log-likelihood in the named
    // coordinate: "q", "b", "h0", or "w0" (first coefficient of each block).
    let fd = |which: &str, eps: f64| match which {
        "q" => {
            let plus = family
                .log_likelihood_only(&states_at(q0 + eps, b0, beta_h.clone(), beta_w.clone()))
                .expect("ll plus q");
            let minus = family
                .log_likelihood_only(&states_at(q0 - eps, b0, beta_h.clone(), beta_w.clone()))
                .expect("ll minus q");
            (plus - minus) / (2.0 * eps)
        }
        "b" => {
            let plus = family
                .log_likelihood_only(&states_at(q0, b0 + eps, beta_h.clone(), beta_w.clone()))
                .expect("ll plus b");
            let minus = family
                .log_likelihood_only(&states_at(q0, b0 - eps, beta_h.clone(), beta_w.clone()))
                .expect("ll minus b");
            (plus - minus) / (2.0 * eps)
        }
        "h0" => {
            let mut plus_h = beta_h.clone();
            plus_h[0] += eps;
            let mut minus_h = beta_h.clone();
            minus_h[0] -= eps;
            let plus = family
                .log_likelihood_only(&states_at(q0, b0, plus_h, beta_w.clone()))
                .expect("ll plus h");
            let minus = family
                .log_likelihood_only(&states_at(q0, b0, minus_h, beta_w.clone()))
                .expect("ll minus h");
            (plus - minus) / (2.0 * eps)
        }
        "w0" => {
            let mut plus_w = beta_w.clone();
            plus_w[0] += eps;
            let mut minus_w = beta_w.clone();
            minus_w[0] -= eps;
            let plus = family
                .log_likelihood_only(&states_at(q0, b0, beta_h.clone(), plus_w))
                .expect("ll plus w");
            let minus = family
                .log_likelihood_only(&states_at(q0, b0, beta_h.clone(), minus_w))
                .expect("ll minus w");
            (plus - minus) / (2.0 * eps)
        }
        _ => panic!("unknown derivative target"),
    };
    // Tolerance 2e-4 absorbs both FD truncation error and quadrature noise.
    let eps = 1e-5;
    assert!((grad_q - fd("q", eps)).abs() < 2e-4);
    assert!((grad_b - fd("b", eps)).abs() < 2e-4);
    assert!((grad_h0 - fd("h0", eps)).abs() < 2e-4);
    assert!((grad_w0 - fd("w0", eps)).abs() < 2e-4);
}
#[test]
fn flexible_exact_outer_directional_derivatives_are_present_and_finite() {
    // Smoke test with BOTH flexible blocks (score-warp "h" and link-deviation
    // "w") active and directions spanning every slice: the third and fourth
    // directional derivatives must exist, be finite, be nontrivially nonzero,
    // and be symmetric matrices.
    let z = array![-0.8, 0.2, 1.1];
    let y = array![0.0, 1.0, 1.0];
    let weights = array![1.0, 0.7, 1.3];
    let score_prepared = build_score_warp_deviation_block_from_seed(
        &z,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("score warp block");
    let link_seed = array![-2.0, -0.5, 0.0, 0.5, 2.0];
    let link_prepared = build_test_link_deviation_block_from_seed(
        &link_seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("link block");
    // Intercept-only marginal and log-slope designs plus both deviations.
    let family =
        BernoulliMarginalSlopeFamily {
            y: Arc::new(y.clone()),
            weights: Arc::new(weights.clone()),
            z: Arc::new(z.clone()),
            latent_measure: LatentMeasureKind::StandardNormal,
            gaussian_frailty_sd: None,
            base_link: bernoulli_marginal_slope_probit_link(),
            marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
                array![[1.0], [1.0], [1.0]],
            )),
            logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
                array![[1.0], [1.0], [1.0]],
            )),
            score_warp: Some(score_prepared.runtime.clone()),
            link_dev: Some(link_prepared.runtime.clone()),
            policy: crate::resource::ResourcePolicy::default_library(),
            cell_moment_lru: new_cell_moment_lru_cache(
                &crate::resource::ResourcePolicy::default_library(),
            ),
            cell_moment_cache_stats: new_cell_moment_cache_stats(),
            intercept_warm_starts: None,
        };
    let beta_h = Array1::from_iter(
        (0..score_prepared.block.design.ncols()).map(|idx| 0.015 * (idx as f64 + 1.0)),
    );
    let beta_w = Array1::from_iter(
        (0..link_prepared.block.design.ncols()).map(|idx| 0.01 * (idx as f64 + 1.0)),
    );
    // Blocks: marginal intercept, log-slope intercept, h and w coefficients.
    let block_states = vec![
        ParameterBlockState {
            beta: array![0.25],
            eta: Array1::from_elem(z.len(), 0.25),
        },
        ParameterBlockState {
            beta: array![0.6],
            eta: Array1::from_elem(z.len(), 0.6),
        },
        ParameterBlockState {
            beta: beta_h.clone(),
            eta: Array1::zeros(z.len()),
        },
        ParameterBlockState {
            beta: beta_w.clone(),
            eta: Array1::zeros(z.len()),
        },
    ];
    let slices = block_slices(&family);
    let total = slices.total;
    // Two directions with support on every available slice so all blocks
    // contribute to the derivatives.
    let mut dir_u = Array1::<f64>::zeros(total);
    let mut dir_v = Array1::<f64>::zeros(total);
    dir_u[slices.marginal.start] = 0.7;
    dir_u[slices.logslope.start] = -0.2;
    if let Some(h_range) = slices.h.as_ref() {
        dir_u[h_range.start] = 0.1;
        if h_range.len() > 1 {
            dir_u[h_range.start + 1] = -0.05;
        }
    }
    if let Some(w_range) = slices.w.as_ref() {
        dir_u[w_range.start] = 0.08;
    }
    dir_v[slices.marginal.start] = -0.4;
    dir_v[slices.logslope.start] = 0.3;
    if let Some(h_range) = slices.h.as_ref() {
        dir_v[h_range.start] = -0.03;
    }
    if let Some(w_range) = slices.w.as_ref() {
        dir_v[w_range.start] = 0.06;
        if w_range.len() > 1 {
            dir_v[w_range.start + 1] = -0.02;
        }
    }
    let third = family
        .exact_newton_joint_hessian_directional_derivative(&block_states, &dir_u)
        .expect("flex third directional derivative")
        .expect("flex third directional derivative matrix");
    let fourth = family
        .exact_newton_joint_hessiansecond_directional_derivative(&block_states, &dir_u, &dir_v)
        .expect("flex fourth directional derivative")
        .expect("flex fourth directional derivative matrix");
    // Both derivatives are (total x total) matrices with finite entries.
    assert_eq!(third.dim(), (total, total));
    assert_eq!(fourth.dim(), (total, total));
    assert!(third.iter().all(|value| value.is_finite()));
    assert!(fourth.iter().all(|value| value.is_finite()));
    // Guard against a trivially-zero (i.e. unimplemented) derivative path.
    let max_abs_third = third
        .iter()
        .fold(0.0_f64, |acc, value| acc.max(value.abs()));
    let max_abs_fourth = fourth
        .iter()
        .fold(0.0_f64, |acc, value| acc.max(value.abs()));
    assert!(
        max_abs_third > 1e-10,
        "expected nonzero dual-flex third directional derivative"
    );
    assert!(
        max_abs_fourth > 1e-10,
        "expected nonzero dual-flex fourth directional derivative"
    );
    // Both matrices must be symmetric (only the strict lower triangle is
    // compared against its transpose).
    for i in 0..total {
        for j in 0..i {
            assert!((third[[i, j]] - third[[j, i]]).abs() < 1e-8);
            assert!((fourth[[i, j]] - fourth[[j, i]]).abs() < 1e-8);
        }
    }
}
#[test]
fn flexible_evaluate_block_diagonals_match_joint_exact_oracle() {
    // Oracle check: the per-block gradients/Hessians returned by `evaluate`
    // must agree with the corresponding diagonal slices of the exact joint
    // gradient/Hessian when all four blocks (marginal, log-slope,
    // score-warp, link-deviation) are active.
    let z = array![-1.1, -0.25, 0.35, 1.2];
    let y = array![0.0, 1.0, 0.0, 1.0];
    let weights = array![1.0, 0.8, 1.3, 0.7];
    let score_prepared = build_score_warp_deviation_block_from_seed(
        &z,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("score warp block");
    let link_seed = array![-1.8, -0.7, 0.1, 0.9, 1.7];
    let link_prepared = build_test_link_deviation_block_from_seed(
        &link_seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("link block");
    // Two-column (intercept + covariate) designs for the dense blocks.
    let marginal_x = array![[1.0, -0.4], [1.0, 0.2], [1.0, 0.7], [1.0, 1.1]];
    let logslope_x = array![[1.0, 0.3], [1.0, -0.6], [1.0, 0.5], [1.0, -1.0]];
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(y.clone()),
        weights: Arc::new(weights.clone()),
        z: Arc::new(z.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            marginal_x.clone(),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            logslope_x.clone(),
        )),
        score_warp: Some(score_prepared.runtime.clone()),
        link_dev: Some(link_prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    // Small nonzero coefficients per block; the two deviation blocks get
    // zero-filled eta vectors.
    let beta_m = array![0.18, -0.07];
    let beta_g = array![0.42, 0.05];
    let beta_h = Array1::from_iter(
        (0..score_prepared.block.design.ncols()).map(|idx| 0.006 * (idx as f64 + 1.0)),
    );
    let beta_w = Array1::from_iter(
        (0..link_prepared.block.design.ncols()).map(|idx| -0.004 * (idx as f64 + 1.0)),
    );
    let block_states = vec![
        ParameterBlockState {
            beta: beta_m.clone(),
            eta: marginal_x.dot(&beta_m),
        },
        ParameterBlockState {
            beta: beta_g.clone(),
            eta: logslope_x.dot(&beta_g),
        },
        ParameterBlockState {
            beta: beta_h,
            eta: Array1::zeros(z.len()),
        },
        ParameterBlockState {
            beta: beta_w,
            eta: Array1::zeros(z.len()),
        },
    ];
    let eval = family.evaluate(&block_states).expect("block evaluation");
    let joint_hessian = family
        .exact_newton_joint_hessian(&block_states)
        .expect("joint hessian result")
        .expect("dense joint hessian");
    let joint_gradient = family
        .exact_newton_joint_gradient_evaluation(&block_states, &[])
        .expect("joint gradient result")
        .expect("joint gradient");
    let slices = block_slices(&family);
    let ranges = vec![
        slices.marginal.clone(),
        slices.logslope.clone(),
        slices.h.clone().expect("score-warp block"),
        slices.w.clone().expect("link-deviation block"),
    ];
    // Log-likelihoods from the three evaluation paths must coincide.
    assert!((eval.log_likelihood - joint_gradient.log_likelihood).abs() < 2e-12);
    assert!(
        (eval.log_likelihood
            - family
                .log_likelihood_only(&block_states)
                .expect("log likelihood only"))
        .abs()
            < 2e-12
    );
    // Each block's working-set gradient/Hessian must equal the matching
    // slice of the joint quantities, entry by entry.
    for (block_idx, range) in ranges.iter().enumerate() {
        let (gradient, hessian) = match &eval.blockworking_sets[block_idx] {
            BlockWorkingSet::ExactNewton { gradient, hessian } => (gradient, hessian),
            _ => panic!("expected exact-newton block {block_idx}"),
        };
        let expected_gradient = joint_gradient.gradient.slice(s![range.clone()]);
        for idx in 0..gradient.len() {
            assert!(
                (gradient[idx] - expected_gradient[idx]).abs() < 2e-10,
                "gradient mismatch block {block_idx} idx {idx}: got {} expected {}",
                gradient[idx],
                expected_gradient[idx]
            );
        }
        let dense_hessian = match hessian {
            SymmetricMatrix::Dense(h) => h,
            _ => panic!("expected dense hessian block {block_idx}"),
        };
        let expected_hessian = joint_hessian.slice(s![range.clone(), range.clone()]);
        for i in 0..dense_hessian.nrows() {
            for j in 0..dense_hessian.ncols() {
                assert!(
                    (dense_hessian[[i, j]] - expected_hessian[[i, j]]).abs() < 2e-9,
                    "hessian mismatch block {block_idx} ({i},{j}): got {} expected {}",
                    dense_hessian[[i, j]],
                    expected_hessian[[i, j]]
                );
            }
        }
    }
}
#[test]
fn latent_z_normalization_accepts_finite_sample_gaussian_scores() {
    // A modest batch of roughly standard-normal scores with unit weights
    // should pass standardization, replay to an identical vector, and come
    // out with (weighted) mean 0 and standard deviation 1.
    let scores = array![
        -0.85, -0.12, 0.31, 1.04, -1.21, 0.56, 0.77, -0.44, 1.33, -0.09, 0.28, -0.67
    ];
    let unit_weights = Array1::from_elem(12, 1.0);
    let policy = LatentZPolicy::exploratory_fit_weighted();
    let (normalized, transform) = standardize_latent_z_with_policy(
        &scores,
        &unit_weights,
        "bernoulli-marginal-slope",
        &policy,
    )
    .expect("normalize z");
    let replay = transform
        .apply(&scores, "bernoulli-marginal-slope replay")
        .expect("replay normalized z");
    let count = normalized.len() as f64;
    let mean = normalized.sum() / count;
    let second_moment = normalized.iter().fold(0.0_f64, |acc, v| acc + v * v) / count;
    assert_eq!(replay, normalized);
    assert!(mean.abs() < 1e-12);
    assert!((second_moment.sqrt() - 1.0).abs() < 1e-12);
}
#[test]
fn latent_z_normalization_rejects_extreme_non_gaussian_scores() {
    // A two-point-outlier sample must trip the strict Gaussian diagnostic
    // and be rejected with the expected message fragment.
    let spiky_scores = array![0.0, 0.0, 0.0, 0.0, 10.0, -10.0];
    let unit_weights = Array1::from_elem(6, 1.0);
    let strict_policy = LatentZPolicy {
        check_mode: LatentZCheckMode::Strict,
        ..LatentZPolicy::default()
    };
    let failure = standardize_latent_z_with_policy(
        &spiky_scores,
        &unit_weights,
        "bernoulli-marginal-slope",
        &strict_policy,
    )
    .expect_err("expected non-gaussian rejection");
    assert!(failure.contains("approximately latent N(0,1)"));
}
#[test]
fn auto_latent_measure_uses_empirical_grid_for_bad_normal_diagnostics() {
    // Even with checks off and no normalization, the Auto latent-measure
    // spec must detect the non-Gaussian sample and fall back to a global
    // empirical grid of the requested size with weights summing to one.
    let spiky_scores = array![0.0, 0.0, 0.0, 0.0, 10.0, -10.0];
    let unit_weights = Array1::from_elem(6, 1.0);
    let policy = LatentZPolicy {
        check_mode: LatentZCheckMode::Off,
        normalization: LatentZNormalizationMode::None,
        latent_measure: LatentMeasureSpec::Auto { grid_size: 5 },
        ..LatentZPolicy::default()
    };
    let measure = build_latent_measure_with_geometry(&spiky_scores, &unit_weights, &policy, None, &[])
        .expect("auto latent measure");
    match measure {
        LatentMeasureKind::GlobalEmpirical { nodes, weights } => {
            assert_eq!(nodes.len(), 5);
            assert_eq!(weights.len(), 5);
            let total_mass: f64 = weights.iter().sum();
            assert!((total_mass - 1.0).abs() <= 1e-12);
        }
        LatentMeasureKind::StandardNormal => {
            panic!("bad normal diagnostics must select empirical latent measure")
        }
        LatentMeasureKind::LocalEmpirical { .. } => {
            panic!("auto latent measure without geometry must select global empirical")
        }
    }
}
#[test]
fn empirical_intercept_calibrates_marginal_probability() {
    // Solving for the intercept against an empirical latent law must make
    // the mixture marginal hit the requested success probability.
    let nodes = vec![-2.0, -0.25, 0.5, 3.0];
    let masses = vec![0.2, 0.3, 0.1, 0.4];
    let target_q = -0.35;
    let target_mu = normal_cdf(target_q);
    let slope = 0.8;
    let scale = 0.9;
    let intercept = empirical_intercept_from_marginal(
        target_mu, target_q, slope, scale, &nodes, &masses, None,
    )
    .expect("empirical intercept");
    // Recompute the mixture marginal at the solved intercept.
    let mut recovered = 0.0;
    for (&node, &mass) in nodes.iter().zip(masses.iter()) {
        recovered += mass * normal_cdf(intercept + scale * slope * node);
    }
    assert!((recovered - target_mu).abs() <= 1e-10);
}
#[test]
fn skewed_rigid_empirical_grid_calibrates_marginal_probability() {
    // Compress a skewed weighted sample into a 7-node empirical grid and
    // verify the calibrated intercept still reproduces the target marginal.
    let z = array![-1.15, -1.05, -0.95, -0.8, -0.65, -0.45, 0.1, 0.9, 2.4, 4.7];
    let sample_weights = array![1.0, 1.0, 1.0, 0.9, 0.9, 0.8, 0.5, 0.35, 0.2, 0.1];
    let grid = build_empirical_z_grid(&z, &sample_weights, 7, "test skewed grid").expect("grid");
    let target_q = 0.25;
    let target_mu = normal_cdf(target_q);
    let slope = 1.35;
    let scale = 0.82;
    let intercept = empirical_intercept_from_marginal(
        target_mu,
        target_q,
        slope,
        scale,
        &grid.nodes,
        &grid.weights,
        None,
    )
    .expect("empirical intercept");
    // Recompute the grid-mixture marginal at the solved intercept.
    let mut recovered = 0.0;
    for (&node, &mass) in grid.nodes.iter().zip(grid.weights.iter()) {
        recovered += mass * normal_cdf(intercept + scale * slope * node);
    }
    assert!((recovered - target_mu).abs() <= 1e-10);
}
#[test]
fn gaussian_rigid_intercept_miscalibrates_skewed_empirical_law() {
    // Under a skewed empirical latent law the Gaussian closed-form
    // intercept must visibly miss the target marginal, while the empirical
    // solver hits it to high precision.
    let nodes = vec![-0.95, -0.7, -0.45, -0.2, 0.4, 1.3, 3.1];
    let masses = vec![0.28, 0.22, 0.17, 0.13, 0.1, 0.07, 0.03];
    let target_q = -0.15;
    let target_mu = normal_cdf(target_q);
    let slope = 1.4;
    let scale = 0.9;
    // Mixture marginal induced by a candidate intercept.
    let marginal_for = |intercept: f64| -> f64 {
        nodes
            .iter()
            .zip(masses.iter())
            .map(|(&node, &mass)| mass * normal_cdf(intercept + scale * slope * node))
            .sum()
    };
    let gaussian_intercept = rigid_intercept_from_marginal(target_q, slope, scale);
    assert!(
        (marginal_for(gaussian_intercept) - target_mu).abs() > 1e-3,
        "skewed empirical law should not be calibrated by Gaussian identity"
    );
    let empirical_intercept = empirical_intercept_from_marginal(
        target_mu, target_q, slope, scale, &nodes, &masses, None,
    )
    .expect("empirical intercept");
    assert!((marginal_for(empirical_intercept) - target_mu).abs() <= 1e-10);
}
#[test]
fn auto_latent_measure_preserves_standard_normal_fast_path() {
    // A large sample taken at exact normal quantiles must keep the
    // StandardNormal fast path under the Auto latent-measure policy, and
    // the rigid Gaussian intercept identity should then hold exactly.
    let n = 2001usize;
    let quantile_scores = Array1::from_iter((0..n).map(|idx| {
        let p = (idx as f64 + 0.5) / n as f64;
        standard_normal_quantile(p).expect("normal quantile")
    }));
    let unit_weights = Array1::ones(n);
    let policy = LatentZPolicy {
        latent_measure: LatentMeasureSpec::Auto { grid_size: 17 },
        ..LatentZPolicy::default()
    };
    let measure = build_latent_measure_with_geometry(&quantile_scores, &unit_weights, &policy, None, &[])
        .expect("measure");
    assert!(matches!(measure, LatentMeasureKind::StandardNormal));
    let slope = 0.7;
    let scale = 0.85;
    let target_q = 0.4;
    let expected_intercept = target_q * (1.0 + (scale * slope).powi(2)).sqrt();
    assert_eq!(
        rigid_intercept_from_marginal(target_q, slope, scale),
        expected_intercept
    );
}
#[test]
fn score_warp_empirical_anchor_zeroes_empirical_intercept_and_slope_components() {
    // The empirically anchored score-warp basis must be orthogonal, under
    // the sample weights, to both the constant and the linear-in-z trends:
    // any fitted warp has zero weighted mean and zero weighted z-slope.
    let z = array![-1.4, -0.9, -0.35, 0.1, 0.55, 1.2, 2.8];
    let sample_weights = array![1.0, 0.8, 1.2, 0.7, 0.9, 0.6, 0.25];
    let prepared = build_score_warp_deviation_block_from_seed_empirical_anchor(
        &z,
        &sample_weights,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("empirical score warp");
    let design = prepared.runtime.design(&z).expect("score warp design");
    let beta = Array1::from_iter((0..design.ncols()).map(|idx| 0.1 * (idx as f64 + 1.0)));
    let warp = design.dot(&beta);
    let total_weight = sample_weights.sum();
    // Accumulate both weighted moments in one pass over the rows.
    let mut weighted_mean = 0.0;
    let mut weighted_slope = 0.0;
    for ((&value, &z_value), &weight) in warp.iter().zip(z.iter()).zip(sample_weights.iter()) {
        weighted_mean += weight * value;
        weighted_slope += weight * z_value * value;
    }
    weighted_mean /= total_weight;
    weighted_slope /= total_weight;
    assert!(weighted_mean.abs() <= 1e-10);
    assert!(weighted_slope.abs() <= 1e-10);
}
#[test]
fn flexible_family_exposes_exact_newton_workspaces() {
    // When both flexible deviation blocks (score-warp and link-deviation)
    // are attached, the family must report Some(..) for both the exact
    // joint Hessian workspace and the exact joint psi workspace.
    let z = array![-0.8, 0.2, 1.1];
    let y = array![0.0, 1.0, 1.0];
    let weights = array![1.0, 0.7, 1.3];
    let score_prepared = build_score_warp_deviation_block_from_seed(
        &z,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("score warp block");
    let link_seed = array![-2.0, -0.5, 0.0, 0.5, 2.0];
    let link_prepared = build_test_link_deviation_block_from_seed(
        &link_seed,
        &DeviationBlockConfig {
            num_internal_knots: 4,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("link block");
    // Intercept-only dense designs; all flexibility lives in the two
    // deviation blocks.
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(y.clone()),
        weights: Arc::new(weights.clone()),
        z: Arc::new(z.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            array![[1.0], [1.0], [1.0]],
        )),
        score_warp: Some(score_prepared.runtime.clone()),
        link_dev: Some(link_prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    // States for all four blocks; the deviation blocks get zero-filled eta.
    let block_states = vec![
        ParameterBlockState {
            beta: array![0.25],
            eta: Array1::from_elem(z.len(), 0.25),
        },
        ParameterBlockState {
            beta: array![0.6],
            eta: Array1::from_elem(z.len(), 0.6),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..score_prepared.block.design.ncols()).map(|idx| 0.015 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(z.len()),
        },
        ParameterBlockState {
            beta: Array1::from_iter(
                (0..link_prepared.block.design.ncols()).map(|idx| 0.01 * (idx as f64 + 1.0)),
            ),
            eta: Array1::zeros(z.len()),
        },
    ];
    let specs = vec![
        dummy_blockspec(1, z.len()),
        dummy_blockspec(1, z.len()),
        dummy_blockspec(score_prepared.block.design.ncols(), z.len()),
        dummy_blockspec(link_prepared.block.design.ncols(), z.len()),
    ];
    // No psi-derivative inputs for any block in this smoke test.
    let derivative_blocks = vec![Vec::new(), Vec::new(), Vec::new(), Vec::new()];
    assert!(
        family
            .exact_newton_joint_hessian_workspace(&block_states, &specs)
            .expect("flex hessian workspace")
            .is_some()
    );
    assert!(
        family
            .exact_newton_joint_psi_workspace(&block_states, &specs, &derivative_blocks)
            .expect("flex psi workspace")
            .is_some()
    );
}
#[test]
fn sigma_exact_joint_psi_terms_returns_analytic_terms() {
    // The Gaussian-frailty sigma hyperparameter must yield finite analytic
    // psi terms (objective, score, Hessian operator), second-order terms,
    // and a directional Hessian derivative; the analytic objective is then
    // validated against a central finite difference in tau = ln(sigma).
    let z = array![-0.8, 0.2, 1.1];
    let y = array![0.0, 1.0, 1.0];
    let weights = array![1.0, 0.7, 1.3];
    let sigma = 0.7;
    // Factory so the finite-difference check below can rebuild the family
    // at perturbed sigma values.
    let make_family =
        |sigma: f64| BernoulliMarginalSlopeFamily {
            y: Arc::new(y.clone()),
            weights: Arc::new(weights.clone()),
            z: Arc::new(z.clone()),
            latent_measure: LatentMeasureKind::StandardNormal,
            gaussian_frailty_sd: Some(sigma),
            base_link: bernoulli_marginal_slope_probit_link(),
            marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
                array![[1.0], [1.0], [1.0]],
            )),
            logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
                array![[1.0], [1.0], [1.0]],
            )),
            score_warp: None,
            link_dev: None,
            policy: crate::resource::ResourcePolicy::default_library(),
            cell_moment_lru: new_cell_moment_lru_cache(
                &crate::resource::ResourcePolicy::default_library(),
            ),
            cell_moment_cache_stats: new_cell_moment_cache_stats(),
            intercept_warm_starts: None,
        };
    let family = make_family(sigma);
    let block_states = vec![
        ParameterBlockState {
            beta: array![0.25],
            eta: Array1::from_elem(z.len(), 0.25),
        },
        ParameterBlockState {
            beta: array![0.6],
            eta: Array1::from_elem(z.len(), 0.6),
        },
    ];
    let specs = vec![dummy_blockspec(1, z.len()), dummy_blockspec(1, z.len())];
    let terms = family
        .sigma_exact_joint_psi_terms(&block_states, &specs)
        .expect("analytic sigma psi terms")
        .expect("sigma terms present");
    // First-order terms: finite objective, one score entry per block, and
    // a 2x2 dense Hessian operator.
    assert!(terms.objective_psi.is_finite());
    assert_eq!(terms.score_psi.len(), 2);
    assert!(terms.score_psi.iter().all(|value| value.is_finite()));
    assert_eq!(
        terms
            .hessian_psi_operator
            .as_ref()
            .expect("sigma Hessian operator")
            .to_dense()
            .dim(),
        (2, 2)
    );
    let second = family
        .sigma_exact_joint_psisecond_order_terms(&block_states)
        .expect("analytic second sigma terms")
        .expect("second sigma terms present");
    assert!(second.objective_psi_psi.is_finite());
    assert_eq!(second.score_psi_psi.len(), 2);
    let drift = family
        .sigma_exact_joint_psihessian_directional_derivative(&block_states, &array![0.1, -0.2])
        .expect("analytic sigma Hessian directional derivative")
        .expect("sigma drift present");
    assert_eq!(drift.dim(), (2, 2));
    assert!(drift.iter().all(|value| value.is_finite()));
    // Central finite difference of the negative log-likelihood with
    // respect to tau = ln(sigma) must match the analytic objective_psi.
    let tau = sigma.ln();
    let eps = 1e-5;
    let ll_plus = make_family((tau + eps).exp())
        .log_likelihood_only(&block_states)
        .expect("ll plus sigma");
    let ll_minus = make_family((tau - eps).exp())
        .log_likelihood_only(&block_states)
        .expect("ll minus sigma");
    let objective_fd = -(ll_plus - ll_minus) / (2.0 * eps);
    assert!((terms.objective_psi - objective_fd).abs() < 1e-5);
}
/// Builds a deterministic Bernoulli marginal-slope family fixture of size
/// `n` with single-column marginal and log-slope designs.
///
/// All inputs are generated from fixed modular-arithmetic recurrences, so
/// the fixture is reproducible without a RNG. No score-warp,
/// link-deviation, or Gaussian-frailty components are attached.
fn make_block_psi_test_family(n: usize) -> BernoulliMarginalSlopeFamily {
    // Pseudo-random binary outcomes (roughly 2-in-5 ones).
    let y: Array1<f64> =
        Array1::from_iter((0..n).map(|i| if (i * 31 + 7) % 5 >= 3 { 1.0 } else { 0.0 }));
    // Strictly positive observation weights in [0.5, 1.1].
    let weights: Array1<f64> =
        Array1::from_iter((0..n).map(|i| 0.5 + ((i * 13 + 4) % 7) as f64 * 0.1));
    // Latent scores scattered over roughly [-1.5, 1.5].
    let z: Array1<f64> = Array1::from_iter(
        (0..n).map(|i| -1.5 + 3.0 * (((i * 17 + 5) % n) as f64 + 0.5) / (n as f64)),
    );
    // Single-column designs with bounded positive entries.
    let marginal_design = Array2::from_shape_fn((n, 1), |(i, _)| {
        0.3 + 0.4 * (((i * 29 + 11) % n) as f64) / (n as f64)
    });
    let logslope_design = Array2::from_shape_fn((n, 1), |(i, _)| {
        0.2 + 0.5 * (((i * 37 + 9) % n) as f64) / (n as f64)
    });
    BernoulliMarginalSlopeFamily {
        y: Arc::new(y),
        weights: Arc::new(weights),
        z: Arc::new(z),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: None,
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            marginal_design,
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            logslope_design,
        )),
        score_warp: None,
        link_dev: None,
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    }
}
/// Builds the two coefficient-block states (marginal, log-slope) with
/// linear predictors computed from the family's stored dense designs.
fn block_psi_test_block_states(
    family: &BernoulliMarginalSlopeFamily,
    m_beta: f64,
    g_beta: f64,
) -> Vec<ParameterBlockState> {
    let marginal_coeff = array![m_beta];
    let logslope_coeff = array![g_beta];
    let marginal_eta = family.marginal_design.to_dense().to_owned().dot(&marginal_coeff);
    let logslope_eta = family.logslope_design.to_dense().to_owned().dot(&logslope_coeff);
    vec![
        ParameterBlockState {
            beta: marginal_coeff,
            eta: marginal_eta,
        },
        ParameterBlockState {
            beta: logslope_coeff,
            eta: logslope_eta,
        },
    ]
}
/// One psi-derivative entry for the marginal block, none for log-slope.
fn block_psi_test_marginal_derivative_blocks(
    n: usize,
) -> Vec<Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>> {
    // Deterministic single-column derivative design.
    let x_psi = Array2::from_shape_fn((n, 1), |(i, _)| {
        0.4 + 0.3 * (((i * 41 + 13) % n) as f64) / (n as f64)
    });
    let marginal_entry = crate::custom_family::CustomFamilyBlockPsiDerivative::new(
        None,
        x_psi,
        Array2::zeros((1, 1)),
        None,
        None,
        None,
        None,
    );
    vec![vec![marginal_entry], Vec::new()]
}
fn block_psi_test_dual_derivative_blocks(
n: usize,
) -> Vec<Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>> {
let x_psi_m = Array2::from_shape_fn((n, 1), |(i, _)| {
0.4 + 0.3 * (((i * 41 + 13) % n) as f64) / (n as f64)
});
let x_psi_g = Array2::from_shape_fn((n, 1), |(i, _)| {
0.2 + 0.5 * (((i * 43 + 17) % n) as f64) / (n as f64)
});
vec![
vec![crate::custom_family::CustomFamilyBlockPsiDerivative::new(
None,
x_psi_m,
Array2::zeros((1, 1)),
None,
None,
None,
None,
)],
vec![crate::custom_family::CustomFamilyBlockPsiDerivative::new(
None,
x_psi_g,
Array2::zeros((1, 1)),
None,
None,
None,
None,
)],
]
}
#[test]
fn bernoulli_psi_terms_from_cache_subsample_full_equals_unsampled() {
    use crate::families::marginal_slope_shared::OuterScoreSubsample;
    // A subsample mask covering every row must reproduce the unsampled
    // psi terms exactly (relative differences at floating-point level).
    let n = 200usize;
    let family = make_block_psi_test_family(n);
    let states = block_psi_test_block_states(&family, 0.15, 0.25);
    let derivative_blocks = block_psi_test_marginal_derivative_blocks(n);
    let cache = family
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let baseline = family
        .exact_newton_joint_psi_terms_from_cache(&states, &derivative_blocks, 0, &cache)
        .expect("baseline psi terms")
        .expect("some");
    // Construct the options in one expression instead of mutating a
    // default value (avoids clippy::field_reassign_with_default).
    let opts_full = BlockwiseFitOptions {
        outer_score_subsample: Some(Arc::new(OuterScoreSubsample::new(
            (0..n).collect(),
            n,
            0xDEADBEEF,
        ))),
        ..BlockwiseFitOptions::default()
    };
    let with_full = family
        .exact_newton_joint_psi_terms_from_cache_with_options(
            &states,
            &derivative_blocks,
            0,
            &cache,
            &opts_full,
        )
        .expect("with full")
        .expect("some");
    let obj_rel = ((with_full.objective_psi - baseline.objective_psi)
        / baseline.objective_psi.abs().max(1.0))
    .abs();
    assert!(obj_rel < 1e-12, "objective_psi rel {}", obj_rel);
    let score_rel = rel_diff_array1(&with_full.score_psi, &baseline.score_psi);
    assert!(score_rel < 1e-12, "score_psi rel {}", score_rel);
}
#[test]
fn bernoulli_psi_terms_from_cache_subsample_half_scales_correctly() {
    use crate::families::marginal_slope_shared::OuterScoreSubsample;
    // Subsampling the even rows via the standard constructor must equal
    // the raw (weight_scale = 1) evaluation multiplied by n / m.
    let n = 200usize;
    let family = make_block_psi_test_family(n);
    let states = block_psi_test_block_states(&family, 0.15, 0.25);
    let derivative_blocks = block_psi_test_marginal_derivative_blocks(n);
    let cache = family
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let even_mask: Vec<usize> = (0..n).filter(|i| i % 2 == 0).collect();
    let m = even_mask.len();
    // Construct the options in one expression instead of mutating a
    // default value (avoids clippy::field_reassign_with_default).
    let opts_half = BlockwiseFitOptions {
        outer_score_subsample: Some(Arc::new(OuterScoreSubsample::new(
            even_mask.clone(),
            n,
            0xCAFE,
        ))),
        ..BlockwiseFitOptions::default()
    };
    let scaled = family
        .exact_newton_joint_psi_terms_from_cache_with_options(
            &states,
            &derivative_blocks,
            0,
            &cache,
            &opts_half,
        )
        .expect("scaled")
        .expect("some");
    // Raw reference: same mask but unit weight scale, so the expected
    // subsampled values are recovered by multiplying with n / m.
    let opts_raw = BlockwiseFitOptions {
        outer_score_subsample: Some(Arc::new(OuterScoreSubsample {
            mask: Arc::new(even_mask),
            n_full: m,
            weight_scale: 1.0,
            seed: 0,
        })),
        ..BlockwiseFitOptions::default()
    };
    let raw = family
        .exact_newton_joint_psi_terms_from_cache_with_options(
            &states,
            &derivative_blocks,
            0,
            &cache,
            &opts_raw,
        )
        .expect("raw")
        .expect("some");
    let factor = n as f64 / m as f64;
    let exp_obj = factor * raw.objective_psi;
    let obj_rel = ((scaled.objective_psi - exp_obj) / exp_obj.abs().max(1.0)).abs();
    assert!(obj_rel < 1e-12, "objective_psi rel {}", obj_rel);
    let exp_score = &raw.score_psi * factor;
    let score_rel = rel_diff_array1(&scaled.score_psi, &exp_score);
    assert!(score_rel < 1e-12, "score_psi rel {}", score_rel);
}
#[test]
fn bernoulli_psi_second_order_terms_from_cache_subsample_full_equals_unsampled() {
    use crate::families::marginal_slope_shared::OuterScoreSubsample;
    // A subsample covering all rows must leave the second-order psi terms
    // unchanged relative to the unsampled evaluation.
    let n = 200usize;
    let family = make_block_psi_test_family(n);
    let states = block_psi_test_block_states(&family, 0.15, 0.25);
    let derivative_blocks = block_psi_test_dual_derivative_blocks(n);
    let cache = family
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let baseline = family
        .exact_newton_joint_psisecond_order_terms_from_cache(
            &states,
            &derivative_blocks,
            0,
            1,
            &cache,
        )
        .expect("baseline psi second-order")
        .expect("some");
    // Construct the options in one expression instead of mutating a
    // default value (avoids clippy::field_reassign_with_default).
    let opts_full = BlockwiseFitOptions {
        outer_score_subsample: Some(Arc::new(OuterScoreSubsample::new(
            (0..n).collect(),
            n,
            0xDEADBEEF,
        ))),
        ..BlockwiseFitOptions::default()
    };
    let with_full = family
        .exact_newton_joint_psisecond_order_terms_from_cache_with_options(
            &states,
            &derivative_blocks,
            0,
            1,
            &cache,
            &opts_full,
        )
        .expect("with full")
        .expect("some");
    let obj_rel = ((with_full.objective_psi_psi - baseline.objective_psi_psi)
        / baseline.objective_psi_psi.abs().max(1.0))
    .abs();
    assert!(obj_rel < 1e-12, "objective rel {}", obj_rel);
    let score_rel = rel_diff_array1(&with_full.score_psi_psi, &baseline.score_psi_psi);
    assert!(score_rel < 1e-12, "score rel {}", score_rel);
}
#[test]
fn bernoulli_psi_second_order_terms_from_cache_subsample_half_scales_correctly() {
    use crate::families::marginal_slope_shared::OuterScoreSubsample;
    // Second-order psi terms for the even-row subsample must equal the
    // raw (weight_scale = 1) terms multiplied by n / m.
    let n = 200usize;
    let family = make_block_psi_test_family(n);
    let states = block_psi_test_block_states(&family, 0.15, 0.25);
    let derivative_blocks = block_psi_test_dual_derivative_blocks(n);
    let cache = family
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let even_mask: Vec<usize> = (0..n).filter(|i| i % 2 == 0).collect();
    let m = even_mask.len();
    // Construct the options in one expression instead of mutating a
    // default value (avoids clippy::field_reassign_with_default).
    let opts_half = BlockwiseFitOptions {
        outer_score_subsample: Some(Arc::new(OuterScoreSubsample::new(
            even_mask.clone(),
            n,
            0xCAFE,
        ))),
        ..BlockwiseFitOptions::default()
    };
    let scaled = family
        .exact_newton_joint_psisecond_order_terms_from_cache_with_options(
            &states,
            &derivative_blocks,
            0,
            1,
            &cache,
            &opts_half,
        )
        .expect("scaled")
        .expect("some");
    // Raw reference with unit weight scale over the same mask.
    let opts_raw = BlockwiseFitOptions {
        outer_score_subsample: Some(Arc::new(OuterScoreSubsample {
            mask: Arc::new(even_mask),
            n_full: m,
            weight_scale: 1.0,
            seed: 0,
        })),
        ..BlockwiseFitOptions::default()
    };
    let raw = family
        .exact_newton_joint_psisecond_order_terms_from_cache_with_options(
            &states,
            &derivative_blocks,
            0,
            1,
            &cache,
            &opts_raw,
        )
        .expect("raw")
        .expect("some");
    let factor = n as f64 / m as f64;
    let exp_obj = factor * raw.objective_psi_psi;
    let obj_rel = ((scaled.objective_psi_psi - exp_obj) / exp_obj.abs().max(1.0)).abs();
    assert!(obj_rel < 1e-12, "objective rel {}", obj_rel);
    let exp_score = &raw.score_psi_psi * factor;
    let score_rel = rel_diff_array1(&scaled.score_psi_psi, &exp_score);
    assert!(score_rel < 1e-12, "score rel {}", score_rel);
}
#[test]
fn bernoulli_psihessian_directional_derivative_from_cache_subsample_full_equals_unsampled() {
    use crate::families::marginal_slope_shared::OuterScoreSubsample;
    // A full-coverage subsample must reproduce the unsampled psi-Hessian
    // directional derivative exactly.
    let n = 200usize;
    let family = make_block_psi_test_family(n);
    let states = block_psi_test_block_states(&family, 0.15, 0.25);
    let derivative_blocks = block_psi_test_marginal_derivative_blocks(n);
    let cache = family
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let slices = &cache.slices;
    // Direction perturbing the first coefficient of each block.
    let mut d_beta_flat = Array1::<f64>::zeros(slices.total);
    d_beta_flat[slices.marginal.start] = 0.05;
    d_beta_flat[slices.logslope.start] = -0.04;
    let baseline = family
        .exact_newton_joint_psihessian_directional_derivative_from_cache(
            &states,
            &derivative_blocks,
            0,
            &d_beta_flat,
            &cache,
        )
        .expect("baseline")
        .expect("some");
    // Construct the options in one expression instead of mutating a
    // default value (avoids clippy::field_reassign_with_default).
    let opts_full = BlockwiseFitOptions {
        outer_score_subsample: Some(Arc::new(OuterScoreSubsample::new(
            (0..n).collect(),
            n,
            0xDEADBEEF,
        ))),
        ..BlockwiseFitOptions::default()
    };
    let with_full = family
        .exact_newton_joint_psihessian_directional_derivative_from_cache_with_options(
            &states,
            &derivative_blocks,
            0,
            &d_beta_flat,
            &cache,
            &opts_full,
        )
        .expect("with full")
        .expect("some");
    let rel = rel_diff_array2(&with_full, &baseline);
    assert!(rel < 1e-12, "drift rel {}", rel);
}
#[test]
fn bernoulli_psihessian_directional_derivative_from_cache_subsample_half_scales_correctly() {
    use crate::families::marginal_slope_shared::OuterScoreSubsample;
    // The even-row subsampled psi-Hessian directional derivative must
    // equal the raw (weight_scale = 1) matrix multiplied by n / m.
    let n = 200usize;
    let family = make_block_psi_test_family(n);
    let states = block_psi_test_block_states(&family, 0.15, 0.25);
    let derivative_blocks = block_psi_test_marginal_derivative_blocks(n);
    let cache = family
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let slices = &cache.slices;
    // Direction perturbing the first coefficient of each block.
    let mut d_beta_flat = Array1::<f64>::zeros(slices.total);
    d_beta_flat[slices.marginal.start] = 0.05;
    d_beta_flat[slices.logslope.start] = -0.04;
    let even_mask: Vec<usize> = (0..n).filter(|i| i % 2 == 0).collect();
    let m = even_mask.len();
    // Construct the options in one expression instead of mutating a
    // default value (avoids clippy::field_reassign_with_default).
    let opts_half = BlockwiseFitOptions {
        outer_score_subsample: Some(Arc::new(OuterScoreSubsample::new(
            even_mask.clone(),
            n,
            0xCAFE,
        ))),
        ..BlockwiseFitOptions::default()
    };
    let scaled = family
        .exact_newton_joint_psihessian_directional_derivative_from_cache_with_options(
            &states,
            &derivative_blocks,
            0,
            &d_beta_flat,
            &cache,
            &opts_half,
        )
        .expect("scaled")
        .expect("some");
    // Raw reference with unit weight scale over the same mask.
    let opts_raw = BlockwiseFitOptions {
        outer_score_subsample: Some(Arc::new(OuterScoreSubsample {
            mask: Arc::new(even_mask),
            n_full: m,
            weight_scale: 1.0,
            seed: 0,
        })),
        ..BlockwiseFitOptions::default()
    };
    let raw = family
        .exact_newton_joint_psihessian_directional_derivative_from_cache_with_options(
            &states,
            &derivative_blocks,
            0,
            &d_beta_flat,
            &cache,
            &opts_raw,
        )
        .expect("raw")
        .expect("some");
    let factor = n as f64 / m as f64;
    let exp = &raw * factor;
    let rel = rel_diff_array2(&scaled, &exp);
    assert!(rel < 1e-12, "drift rel {}", rel);
}
#[test]
fn bernoulli_psihessian_operator_from_cache_subsample_full_equals_unsampled() {
    use crate::families::marginal_slope_shared::OuterScoreSubsample;
    // The operator form of the psi-Hessian directional derivative must be
    // invariant under a subsample that covers every row.
    let n = 200usize;
    let family = make_block_psi_test_family(n);
    let states = block_psi_test_block_states(&family, 0.15, 0.25);
    let derivative_blocks = block_psi_test_marginal_derivative_blocks(n);
    let cache = family
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let slices = &cache.slices;
    // Direction perturbing the first coefficient of each block.
    let mut d_beta_flat = Array1::<f64>::zeros(slices.total);
    d_beta_flat[slices.marginal.start] = 0.05;
    d_beta_flat[slices.logslope.start] = -0.04;
    let baseline = family
        .exact_newton_joint_psihessian_directional_derivative_operator_from_cache_with_options(
            &states,
            &derivative_blocks,
            0,
            &d_beta_flat,
            &cache,
            &BlockwiseFitOptions::default(),
        )
        .expect("baseline operator")
        .expect("some");
    let baseline_dense = baseline.to_dense();
    // Construct the options in one expression instead of mutating a
    // default value (avoids clippy::field_reassign_with_default).
    let opts_full = BlockwiseFitOptions {
        outer_score_subsample: Some(Arc::new(OuterScoreSubsample::new(
            (0..n).collect(),
            n,
            0xDEADBEEF,
        ))),
        ..BlockwiseFitOptions::default()
    };
    let with_full = family
        .exact_newton_joint_psihessian_directional_derivative_operator_from_cache_with_options(
            &states,
            &derivative_blocks,
            0,
            &d_beta_flat,
            &cache,
            &opts_full,
        )
        .expect("with full")
        .expect("some");
    let with_full_dense = with_full.to_dense();
    let rel = rel_diff_array2(&with_full_dense, &baseline_dense);
    assert!(rel < 1e-12, "operator drift rel {}", rel);
}
#[test]
fn bernoulli_psihessian_operator_from_cache_subsample_half_scales_correctly() {
    use crate::families::marginal_slope_shared::OuterScoreSubsample;
    // The densified even-row subsampled operator must equal the raw
    // (weight_scale = 1) operator multiplied by n / m.
    let n = 200usize;
    let family = make_block_psi_test_family(n);
    let states = block_psi_test_block_states(&family, 0.15, 0.25);
    let derivative_blocks = block_psi_test_marginal_derivative_blocks(n);
    let cache = family
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let slices = &cache.slices;
    // Direction perturbing the first coefficient of each block.
    let mut d_beta_flat = Array1::<f64>::zeros(slices.total);
    d_beta_flat[slices.marginal.start] = 0.05;
    d_beta_flat[slices.logslope.start] = -0.04;
    let even_mask: Vec<usize> = (0..n).filter(|i| i % 2 == 0).collect();
    let m = even_mask.len();
    // Construct the options in one expression instead of mutating a
    // default value (avoids clippy::field_reassign_with_default).
    let opts_half = BlockwiseFitOptions {
        outer_score_subsample: Some(Arc::new(OuterScoreSubsample::new(
            even_mask.clone(),
            n,
            0xCAFE,
        ))),
        ..BlockwiseFitOptions::default()
    };
    let scaled = family
        .exact_newton_joint_psihessian_directional_derivative_operator_from_cache_with_options(
            &states,
            &derivative_blocks,
            0,
            &d_beta_flat,
            &cache,
            &opts_half,
        )
        .expect("scaled")
        .expect("some");
    let scaled_dense = scaled.to_dense();
    // Raw reference with unit weight scale over the same mask.
    let opts_raw = BlockwiseFitOptions {
        outer_score_subsample: Some(Arc::new(OuterScoreSubsample {
            mask: Arc::new(even_mask),
            n_full: m,
            weight_scale: 1.0,
            seed: 0,
        })),
        ..BlockwiseFitOptions::default()
    };
    let raw = family
        .exact_newton_joint_psihessian_directional_derivative_operator_from_cache_with_options(
            &states,
            &derivative_blocks,
            0,
            &d_beta_flat,
            &cache,
            &opts_raw,
        )
        .expect("raw")
        .expect("some");
    let raw_dense = raw.to_dense();
    let factor = n as f64 / m as f64;
    let exp = &raw_dense * factor;
    let rel = rel_diff_array2(&scaled_dense, &exp);
    assert!(rel < 1e-12, "operator drift rel {}", rel);
}
#[test]
fn bernoulli_jointhessian_directional_derivative_from_cache_subsample_full_equals_unsampled() {
    use crate::families::marginal_slope_shared::OuterScoreSubsample;
    // The cached joint-Hessian directional derivative must match between
    // the no-options path and a full-coverage subsample.
    let n = 200usize;
    let family = make_block_psi_test_family(n);
    let states = block_psi_test_block_states(&family, 0.15, 0.25);
    let cache = family
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let slices = &cache.slices;
    // Direction perturbing the first coefficient of each block.
    let mut d_beta_flat = Array1::<f64>::zeros(slices.total);
    d_beta_flat[slices.marginal.start] = 0.05;
    d_beta_flat[slices.logslope.start] = -0.04;
    let baseline = family
        .exact_newton_joint_hessian_directional_derivative_from_cache(
            &states,
            &d_beta_flat,
            &cache,
        )
        .expect("baseline")
        .expect("some");
    // Construct the options in one expression instead of mutating a
    // default value (avoids clippy::field_reassign_with_default).
    let opts_full = BlockwiseFitOptions {
        outer_score_subsample: Some(Arc::new(OuterScoreSubsample::new(
            (0..n).collect(),
            n,
            0xDEADBEEF,
        ))),
        ..BlockwiseFitOptions::default()
    };
    let with_full = family
        .exact_newton_joint_hessian_directional_derivative_from_cache_with_options(
            &states,
            &d_beta_flat,
            &cache,
            &opts_full,
        )
        .expect("with full")
        .expect("some");
    let rel = rel_diff_array2(&with_full, &baseline);
    assert!(rel < 1e-12, "joint Hessian dH drift rel {}", rel);
}
/// Builds a four-block (marginal, log-slope, score-warp, link-deviation)
/// Bernoulli marginal-slope fixture of size `n` with a small Gaussian
/// frailty, plus matching parameter-block states.
///
/// All data are generated deterministically (modular recurrences and
/// linspace seeds), so tests built on this fixture are reproducible.
fn make_flex_hvp_cache_test_family(
    n: usize,
) -> (BernoulliMarginalSlopeFamily, Vec<ParameterBlockState>) {
    // Seeds for the two deviation blocks; kept at least 6 points long.
    let score_seed = Array1::linspace(-2.0, 2.0, n.max(6));
    let link_seed = Array1::linspace(-1.8, 1.8, n.max(6));
    let score_prepared = build_score_warp_deviation_block_from_seed(
        &score_seed,
        &DeviationBlockConfig {
            num_internal_knots: 3,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build score warp block");
    let link_prepared = build_test_link_deviation_block_from_seed(
        &link_seed,
        &DeviationBlockConfig {
            num_internal_knots: 3,
            ..DeviationBlockConfig::default()
        },
    )
    .expect("build link deviation block");
    // Deterministic outcomes, positive weights, and latent scores on
    // roughly [-1.7, 1.7].
    let y: Array1<f64> =
        Array1::from_iter((0..n).map(|i| if (i * 17 + 3) % 7 >= 4 { 1.0 } else { 0.0 }));
    let weights: Array1<f64> =
        Array1::from_iter((0..n).map(|i| 0.75 + ((i * 11 + 5) % 5) as f64 * 0.05));
    let z: Array1<f64> =
        Array1::from_iter((0..n).map(|i| -1.7 + 3.4 * (i as f64 + 0.5) / n as f64));
    // Two-column designs: an intercept plus a bounded pseudo-random
    // covariate.
    let marginal_x = Array2::from_shape_fn((n, 2), |(i, j)| {
        if j == 0 {
            1.0
        } else {
            -0.4 + 0.8 * ((i * 19 + 7) % n) as f64 / n as f64
        }
    });
    let logslope_x = Array2::from_shape_fn((n, 2), |(i, j)| {
        if j == 0 {
            1.0
        } else {
            0.3 - 0.6 * ((i * 23 + 11) % n) as f64 / n as f64
        }
    });
    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(y),
        weights: Arc::new(weights),
        z: Arc::new(z.clone()),
        latent_measure: LatentMeasureKind::StandardNormal,
        gaussian_frailty_sd: Some(0.15),
        base_link: bernoulli_marginal_slope_probit_link(),
        marginal_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            marginal_x.clone(),
        )),
        logslope_design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
            logslope_x.clone(),
        )),
        score_warp: Some(score_prepared.runtime.clone()),
        link_dev: Some(link_prepared.runtime.clone()),
        policy: crate::resource::ResourcePolicy::default_library(),
        cell_moment_lru: new_cell_moment_lru_cache(
            &crate::resource::ResourcePolicy::default_library(),
        ),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
    };
    // Small coefficient vectors per block; the two deviation blocks get
    // zero-filled eta vectors.
    let beta_m = array![0.12, -0.04];
    let beta_g = array![0.35, 0.03];
    let beta_h = Array1::from_iter(
        (0..score_prepared.runtime.basis_dim()).map(|idx| 0.0015 * (idx as f64 + 1.0)),
    );
    let beta_w = Array1::from_iter(
        (0..link_prepared.runtime.basis_dim()).map(|idx| -0.001 * (idx as f64 + 1.0)),
    );
    let states = vec![
        ParameterBlockState {
            beta: beta_m.clone(),
            eta: marginal_x.dot(&beta_m),
        },
        ParameterBlockState {
            beta: beta_g.clone(),
            eta: logslope_x.dot(&beta_g),
        },
        ParameterBlockState {
            beta: beta_h,
            eta: Array1::zeros(z.len()),
        },
        ParameterBlockState {
            beta: beta_w,
            eta: Array1::zeros(z.len()),
        },
    ];
    (family, states)
}
/// A paired old/trial log-likelihood delta computed on a fixed half-sample
/// must at least agree in sign with the delta computed on the full sample.
#[test]
fn bernoulli_flex_paired_subsample_ll_delta_sign_matches_full_ll() {
    use crate::families::marginal_slope_shared::OuterScoreSubsample;
    let n = 96usize;
    let (family, base_states) = make_flex_hvp_cache_test_family(n);
    // Nudge one marginal and one log-slope coefficient, keeping each
    // block's eta consistent with its design matrix.
    let mut trial_states = base_states.clone();
    trial_states[0].beta[0] += 0.015;
    trial_states[0].eta += 0.015;
    trial_states[1].beta[1] -= 0.01;
    // Second column of the log-slope design, as constructed in the helper.
    let logslope_col =
        Array1::from_iter((0..n).map(|i| 0.3 - 0.6 * ((i * 23 + 11) % n) as f64 / n as f64));
    trial_states[1].eta.scaled_add(-0.01, &logslope_col);
    let full_old = family
        .log_likelihood_only(&base_states)
        .expect("full old ll");
    let full_trial = family
        .log_likelihood_only(&trial_states)
        .expect("full trial ll");
    // Repeat the paired comparison through an every-other-row subsample.
    let mut opts = BlockwiseFitOptions::default();
    opts.outer_score_subsample = Some(Arc::new(OuterScoreSubsample::new(
        (0..n).step_by(2).collect(),
        n,
        0x5EED5EED,
    )));
    let sub_old = family
        .log_likelihood_only_with_options(&base_states, &opts)
        .expect("subsample old ll");
    let sub_trial = family
        .log_likelihood_only_with_options(&trial_states, &opts)
        .expect("subsample trial ll");
    let full_delta = full_trial - full_old;
    let sub_delta = sub_trial - sub_old;
    assert!(
        full_delta.abs() > 1e-8,
        "synthetic beta-pair should produce a non-degenerate full-LL delta: {full_delta}"
    );
    assert_eq!(
        full_delta.is_sign_positive(),
        sub_delta.is_sign_positive(),
        "paired subsample LL delta sign ({sub_delta}) should match full LL delta sign ({full_delta})"
    );
}
/// Hessian-vector products and Hessian diagonals must be numerically
/// identical whether or not the optional per-row primary-Hessian cache is
/// populated.
#[test]
fn bernoulli_flex_hvp_cache_matches_uncached_path_small_case() {
    let (family, states) = make_flex_hvp_cache_test_family(12);
    // Full cache: exact-eval cache plus the per-row primary Hessian cache.
    let mut full_cache = family
        .build_exact_eval_cache(&states)
        .expect("cached exact eval cache");
    full_cache.row_primary_hessians = family
        .build_row_primary_hessian_cache(&states, &full_cache)
        .expect("row Hessian cache");
    // Stripped cache: same slices/primary/row contexts, but both optional
    // row caches absent, forcing the on-the-fly path.
    let stripped_cache = BernoulliMarginalSlopeExactEvalCache {
        slices: full_cache.slices.clone(),
        primary: full_cache.primary.clone(),
        row_contexts: full_cache.row_contexts.clone(),
        row_cell_moments: None,
        row_primary_hessians: None,
    };
    let dim = full_cache.slices.total;
    let direction = Array1::from_iter((0..dim).map(|idx| 0.02 * ((idx % 5) as f64 - 2.0)));
    let hv_with_cache = family
        .exact_newton_joint_hessian_matvec_from_cache(&direction, &states, &full_cache)
        .expect("cached Hv");
    let hv_without_cache = family
        .exact_newton_joint_hessian_matvec_from_cache(&direction, &states, &stripped_cache)
        .expect("uncached Hv");
    let rel = rel_diff_array1(&hv_with_cache, &hv_without_cache);
    assert!(rel < 5e-11, "cached Hv drift rel {rel}");
    let diag_with_cache = family
        .exact_newton_joint_hessian_diagonal_from_cache(&states, &full_cache)
        .expect("cached diag");
    let diag_without_cache = family
        .exact_newton_joint_hessian_diagonal_from_cache(&states, &stripped_cache)
        .expect("uncached diag");
    let rel_diag = rel_diff_array1(&diag_with_cache, &diag_without_cache);
    assert!(rel_diag < 5e-11, "cached diag drift rel {rel_diag}");
}
/// Ignored timing smoke test: the cached matvec loop should beat the
/// uncached one on a biobank-shaped pattern. Run manually with `--ignored`;
/// wall-clock comparisons are inherently machine-dependent.
#[test]
#[ignore]
fn bernoulli_flex_hvp_cache_timing_biobank_shape_pattern() {
    use std::time::Instant;
    let (family, states) = make_flex_hvp_cache_test_family(96);
    let mut full_cache = family
        .build_exact_eval_cache(&states)
        .expect("cached exact eval cache");
    full_cache.row_primary_hessians = family
        .build_row_primary_hessian_cache(&states, &full_cache)
        .expect("row Hessian cache");
    // Same cache contents minus the optional per-row caches.
    let stripped_cache = BernoulliMarginalSlopeExactEvalCache {
        slices: full_cache.slices.clone(),
        primary: full_cache.primary.clone(),
        row_contexts: full_cache.row_contexts.clone(),
        row_cell_moments: None,
        row_primary_hessians: None,
    };
    let dim = full_cache.slices.total;
    let directions: Vec<_> = (0..4)
        .map(|rep| {
            Array1::from_iter(
                (0..dim).map(|idx| 0.01 * (((idx * 13 + rep * 7) % 11) as f64 - 5.0)),
            )
        })
        .collect();
    // Drive the identical set of matvecs through a given cache and time it.
    let time_matvecs = |cache: &BernoulliMarginalSlopeExactEvalCache, label: &str| {
        let t0 = Instant::now();
        for direction in &directions {
            let _ = family
                .exact_newton_joint_hessian_matvec_from_cache(direction, &states, cache)
                .expect(label);
        }
        t0.elapsed()
    };
    let uncached_elapsed = time_matvecs(&stripped_cache, "uncached Hv");
    let cached_elapsed = time_matvecs(&full_cache, "cached Hv");
    eprintln!("flex Hv cache timing: uncached={uncached_elapsed:?} cached={cached_elapsed:?}");
    assert!(
        cached_elapsed < uncached_elapsed,
        "expected cached Hv loop to beat uncached: cached={cached_elapsed:?} uncached={uncached_elapsed:?}"
    );
}
/// The operator form of the joint-Hessian directional derivative must
/// densify to the same matrix as the direct dense evaluation.
#[test]
fn bernoulli_jointhessian_directional_operator_matches_dense_small_case() {
    let n = 17usize;
    let family = make_block_psi_test_family(n);
    let states = block_psi_test_block_states(&family, 0.15, 0.25);
    let cache = family
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    // Perturb only the leading marginal and log-slope coefficients.
    let mut direction = Array1::<f64>::zeros(cache.slices.total);
    direction[cache.slices.marginal.start] = 0.05;
    direction[cache.slices.logslope.start] = -0.04;
    let default_opts = BlockwiseFitOptions::default();
    let dense_dh = family
        .exact_newton_joint_hessian_directional_derivative_from_cache_with_options(
            &states,
            &direction,
            &cache,
            &default_opts,
        )
        .expect("dense dH")
        .expect("dH present");
    let operator_dh = family
        .exact_newton_joint_hessian_directional_derivative_operator_from_cache_with_options(
            &states,
            &direction,
            &cache,
            &default_opts,
        )
        .expect("operator dH")
        .expect("dH operator present");
    let rel = rel_diff_array2(&operator_dh.to_dense(), &dense_dh);
    assert!(rel < 1e-12, "operator dH rel {rel}");
}
/// The operator form of the joint-Hessian second directional derivative
/// must densify to the same matrix as the direct dense evaluation.
#[test]
fn bernoulli_jointhessian_second_directional_operator_matches_dense_small_case() {
    let n = 17usize;
    let family = make_block_psi_test_family(n);
    let states = block_psi_test_block_states(&family, 0.15, 0.25);
    let cache = family
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let total = cache.slices.total;
    // Two independent perturbation directions for the second derivative.
    let mut dir_u = Array1::<f64>::zeros(total);
    dir_u[cache.slices.marginal.start] = 0.05;
    dir_u[cache.slices.logslope.start] = -0.04;
    let mut dir_v = Array1::<f64>::zeros(total);
    dir_v[cache.slices.marginal.start] = -0.03;
    dir_v[cache.slices.logslope.start] = 0.02;
    let default_opts = BlockwiseFitOptions::default();
    let dense_d2h = family
        .exact_newton_joint_hessiansecond_directional_derivative_from_cache_with_options(
            &states,
            &dir_u,
            &dir_v,
            &cache,
            &default_opts,
        )
        .expect("dense d2H")
        .expect("d2H present");
    let operator_d2h = family
        .exact_newton_joint_hessiansecond_directional_derivative_operator_from_cache_with_options(
            &states,
            &dir_u,
            &dir_v,
            &cache,
            &default_opts,
        )
        .expect("operator d2H")
        .expect("d2H operator present");
    let rel = rel_diff_array2(&operator_d2h.to_dense(), &dense_d2h);
    assert!(rel < 1e-12, "operator d2H rel {rel}");
}
/// Even at large n (> 50k rows) the declared outer-derivative strategy
/// must stay on the analytic gradient / either-form Hessian route.
#[test]
fn bernoulli_large_scale_outer_derivatives_keep_analytic_hessian_route() {
    use crate::solver::outer_strategy::{DeclaredHessianForm, Derivative};
    let n = 50_001usize;
    let family = make_block_psi_test_family(n);
    let block_specs = vec![dummy_blockspec(1, n), dummy_blockspec(1, n)];
    let (gradient, hessian) =
        custom_family_outer_derivatives(&family, &block_specs, &BlockwiseFitOptions::default());
    assert_eq!(gradient, Derivative::Analytic);
    assert_eq!(hessian, DeclaredHessianForm::Either);
}
/// A half subsample built via `OuterScoreSubsample::new` must equal the
/// raw (unit-weight) half-subsample result scaled by n / m, confirming the
/// Horvitz–Thompson-style weight handling in the joint-Hessian directional
/// derivative.
#[test]
fn bernoulli_jointhessian_directional_derivative_from_cache_subsample_half_scales_correctly() {
    use crate::families::marginal_slope_shared::OuterScoreSubsample;
    let n = 200usize;
    let family = make_block_psi_test_family(n);
    let states = block_psi_test_block_states(&family, 0.15, 0.25);
    let cache = family
        .build_exact_eval_cache(&states)
        .expect("exact eval cache");
    let mut direction = Array1::<f64>::zeros(cache.slices.total);
    direction[cache.slices.marginal.start] = 0.05;
    direction[cache.slices.logslope.start] = -0.04;
    // Keep every even row: exactly half the sample.
    let even_rows: Vec<usize> = (0..n).step_by(2).collect();
    let m = even_rows.len();
    // Constructor path: declares the true n, so scaling is applied.
    let mut opts_scaled = BlockwiseFitOptions::default();
    opts_scaled.outer_score_subsample = Some(Arc::new(OuterScoreSubsample::new(
        even_rows.clone(),
        n,
        0xCAFE,
    )));
    let scaled_dh = family
        .exact_newton_joint_hessian_directional_derivative_from_cache_with_options(
            &states,
            &direction,
            &cache,
            &opts_scaled,
        )
        .expect("scaled")
        .expect("some");
    // Literal path: n_full = m and weight_scale = 1, so no scaling.
    let mut opts_unscaled = BlockwiseFitOptions::default();
    opts_unscaled.outer_score_subsample = Some(Arc::new(OuterScoreSubsample {
        mask: Arc::new(even_rows),
        n_full: m,
        weight_scale: 1.0,
        seed: 0,
    }));
    let raw_dh = family
        .exact_newton_joint_hessian_directional_derivative_from_cache_with_options(
            &states,
            &direction,
            &cache,
            &opts_unscaled,
        )
        .expect("raw")
        .expect("some");
    let expected = &raw_dh * (n as f64 / m as f64);
    let rel = rel_diff_array2(&scaled_dh, &expected);
    assert!(rel < 1e-12, "joint Hessian dH HT rel {rel}");
}
/// Ignored scaling benchmark: times `sigma_exact_joint_psi_terms_with_options`
/// over a ladder of sample sizes, both on the full data and through an
/// auto-sized outer-score subsample, then fits a power law to each series
/// and extrapolates against a per-call time budget. Run manually with
/// `--ignored`; it only prints `[MS-SCALING]` report lines and makes no
/// assertions.
#[test]
#[ignore]
fn margslope_sigma_psi_scaling_law() {
    use crate::families::marginal_slope_shared::{OuterScoreSubsample, auto_outer_subsample_k};
    use std::time::Instant;
    // Sample-size ladder and the budget (seconds) used for extrapolation.
    let ns: Vec<usize> = vec![5_000, 10_000, 25_000, 50_000, 100_000, 200_000, 320_000];
    let per_call_budget = 80.0_f64;
    const REPS: usize = 3;
    eprintln!("\n[MS-SCALING] header n full_s subsample_s K speedup");
    // (n, median seconds) points for the two power-law fits below.
    let mut full_pts: Vec<(f64, f64)> = Vec::new();
    let mut sub_pts: Vec<(f64, f64)> = Vec::new();
    for &n in &ns {
        let family = make_sigma_aware_test_family(n);
        let states = rigid_block_states(&family, 0.3, 0.4);
        let specs = vec![dummy_blockspec(1, n), dummy_blockspec(1, n)];
        // Full-sample timings: REPS runs, keep the median.
        let opts_full = BlockwiseFitOptions::default();
        let mut full_samples: Vec<f64> = Vec::with_capacity(REPS);
        for _ in 0..REPS {
            let t0 = Instant::now();
            let _ = family
                .sigma_exact_joint_psi_terms_with_options(&states, &specs, &opts_full)
                .expect("sigma terms full")
                .expect("some");
            full_samples.push(t0.elapsed().as_secs_f64());
        }
        full_samples.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        let t_full = full_samples[full_samples.len() / 2];
        // Subsample timings: evenly strided mask of the auto-chosen size k.
        let k = auto_outer_subsample_k(n);
        let mask: Vec<usize> = (0..n).step_by((n / k).max(1)).take(k).collect();
        let mut opts_sub = BlockwiseFitOptions::default();
        opts_sub.outer_score_subsample =
            Some(Arc::new(OuterScoreSubsample::new(mask, n, 0xC0FFEE_5EED)));
        let mut sub_samples: Vec<f64> = Vec::with_capacity(REPS);
        for _ in 0..REPS {
            let t0 = Instant::now();
            let _ = family
                .sigma_exact_joint_psi_terms_with_options(&states, &specs, &opts_sub)
                .expect("sigma terms sub")
                .expect("some");
            sub_samples.push(t0.elapsed().as_secs_f64());
        }
        sub_samples.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        let t_sub = sub_samples[sub_samples.len() / 2];
        let speedup = if t_sub > 0.0 {
            t_full / t_sub
        } else {
            f64::NAN
        };
        eprintln!(
            "[MS-SCALING] row n={n} full_s={t_full:.5} subsample_s={t_sub:.5} K={k} speedup={speedup:.2}x"
        );
        full_pts.push((n as f64, t_full));
        sub_pts.push((n as f64, t_sub));
    }
    eprintln!();
    // Fit and extrapolate each series against the per-call budget.
    report_power_law(
        "[MS-SCALING-FULL]",
        &full_pts,
        &[("n=320k", 320_000.0), ("n=1M", 1_000_000.0)],
        per_call_budget,
    );
    eprintln!();
    report_power_law(
        "[MS-SCALING-SUBSAMPLE]",
        &sub_pts,
        &[("n=320k", 320_000.0), ("n=1M", 1_000_000.0)],
        per_call_budget,
    );
}
/// Fits a power law `y ≈ a · x^alpha` to `points` by ordinary least squares
/// in log-log space, prints the fit quality to stderr, and — only when the
/// fit is trustworthy — extrapolates the predicted value at each target in
/// `extrapolate` against the per-call budget `budget_y` (same units as y).
///
/// Expects strictly positive coordinates (their logs are taken). Refuses to
/// fit with fewer than 3 points or when the log-x values are degenerate or
/// non-finite (slope undefined), and refuses to extrapolate when R² < 0.85,
/// the largest absolute log-residual exceeds 0.5, or either statistic is NaN.
fn report_power_law(
    tag: &str,
    points: &[(f64, f64)],
    extrapolate: &[(&str, f64)],
    budget_y: f64,
) {
    if points.len() < 3 {
        eprintln!("{tag} INSUFFICIENT DATA: {} points (need ≥3)", points.len());
        return;
    }
    let logs: Vec<(f64, f64)> = points.iter().map(|(x, y)| (x.ln(), y.ln())).collect();
    let n = logs.len() as f64;
    let sx: f64 = logs.iter().map(|(x, _)| x).sum();
    let sy: f64 = logs.iter().map(|(_, y)| y).sum();
    let sxx: f64 = logs.iter().map(|(x, _)| x * x).sum();
    let sxy: f64 = logs.iter().map(|(x, y)| x * y).sum();
    // OLS slope denominator. It vanishes when all log-x values coincide and
    // is NaN when any coordinate was non-positive (ln of <= 0); bail out
    // explicitly instead of propagating NaN/inf into the report.
    let denom = n * sxx - sx * sx;
    if !denom.is_finite() || denom.abs() <= f64::EPSILON * n * sxx.abs().max(1.0) {
        eprintln!("{tag} REFUSING FIT (degenerate or non-finite x-values: slope undefined)");
        return;
    }
    let alpha = (n * sxy - sx * sy) / denom;
    let log_a = (sy - alpha * sx) / n;
    let a = log_a.exp();
    // Fit quality in log space: R² and the worst-case log residual.
    let mean_y = sy / n;
    let ss_tot: f64 = logs.iter().map(|(_, y)| (y - mean_y).powi(2)).sum();
    let ss_res: f64 = logs
        .iter()
        .map(|(x, y)| {
            let pred = log_a + alpha * x;
            (y - pred).powi(2)
        })
        .sum();
    let r2 = if ss_tot > 0.0 {
        1.0 - ss_res / ss_tot
    } else {
        0.0
    };
    let max_abs_log_resid: f64 = logs
        .iter()
        .map(|(x, y)| (y - (log_a + alpha * x)).abs())
        .fold(0.0_f64, f64::max);
    eprintln!(
        "{tag} fit: y ≈ {:.3e} · x^{:.3} | R²={:.4} max|log-resid|={:.3} (×{:.2})",
        a,
        alpha,
        r2,
        max_abs_log_resid,
        max_abs_log_resid.exp()
    );
    // Negated comparisons so a NaN R² or residual FAILS the gate. The old
    // form `r2 < 0.85 || resid > 0.5` silently passed on NaN, because both
    // NaN comparisons evaluate to false.
    if !(r2 >= 0.85) || !(max_abs_log_resid <= 0.5) {
        eprintln!("{tag} REFUSING EXTRAPOLATION (R²<0.85 or max log-resid >0.5)");
        return;
    }
    eprintln!("{tag} budget per call: {:.1}s", budget_y);
    for (label, x_target) in extrapolate {
        let pred = a * x_target.powf(alpha);
        let verdict = if pred <= budget_y {
            format!("FITS ({:.0}× headroom)", budget_y / pred)
        } else {
            format!("OVER by {:.1}× ({:.0}s)", pred / budget_y, pred)
        };
        eprintln!(
            "{tag} extrap @ {label} (x={:.1e}): pred={:.4}s → {}",
            x_target, pred, verdict
        );
    }
}
}