use super::family::*;
use super::gradient_paths::*;
use super::hessian_paths::{new_cell_moment_cache_stats, new_cell_moment_lru_cache};
use super::install_flex::validate_spec;
use super::*;
use crate::faer_ndarray::{FaerEigh, fast_ab, fast_atb, fast_xt_diag_x};
use crate::families::marginal_slope_orthogonal::INFLUENCE_ABSORBER_FIXED_LOG_LAMBDA;
use faer::Side;
/// Diagnostic threshold on the largest absolute marginal-block coefficient (|β|∞).
/// The runaway diagnostics in this file report a separation-scale direction once a
/// coefficient magnitude is greater than or equal to this value.
const BMS_PROBIT_SEPARATION_BETA_INF: f64 = 40.0;
/// Data captured by the effective-Jacobian callback of the marginal block.
///
/// `effective_jacobian_at` rescales row `i` of `marginal_dense` by
/// `sqrt(1 + (s * g_i)^2)`, where `g_i` is the logslope predictor of that row.
pub struct BmsMarginalJacobian {
/// Dense marginal design whose rows are rescaled by the callback.
pub marginal_dense: Arc<Array2<f64>>,
/// Dense logslope design, used to evaluate the per-row logslope predictor.
pub logslope_dense: Arc<Array2<f64>>,
/// Per-row marginal offset (stored here; not read by this type's Jacobian).
pub offset_m: Array1<f64>,
/// Per-row logslope offset added to the logslope predictor.
pub offset_s: Array1<f64>,
/// Index in the coefficient vector at which the logslope coefficients start.
pub p_marginal: usize,
}
impl BmsMarginalJacobian {
/// Builds the callback state by storing every argument in the field of the same name.
/// No validation of shapes or lengths is performed here.
pub fn new(
marginal_dense: Arc<Array2<f64>>,
logslope_dense: Arc<Array2<f64>>,
offset_m: Array1<f64>,
offset_s: Array1<f64>,
p_marginal: usize,
) -> Self {
Self {
marginal_dense,
logslope_dense,
offset_m,
offset_s,
p_marginal,
}
}
}
impl BlockEffectiveJacobian for BmsMarginalJacobian {
    /// Effective Jacobian of the marginal block at the given linearization state.
    ///
    /// Row `i` of the result equals row `i` of `marginal_dense` multiplied by
    /// `c_i = sqrt(1 + (s * g_i)^2)`, where `s = state.probit_frailty_scale` and
    /// `g_i = offset_s[i] + logslope_dense[i, ..k] · β_s[..k]`. `β_s` is the part of
    /// `state.beta` starting at `p_marginal`, and `k` is the smaller of the logslope
    /// design width and the number of coefficients available.
    fn effective_jacobian_at(
        &self,
        state: &FamilyLinearizationState<'_>,
    ) -> Result<Array2<f64>, String> {
        let frailty_scale = state.probit_frailty_scale;
        let coefficients = state.beta;
        let split = self.p_marginal;

        // Logslope coefficients live after the marginal ones; tolerate a short vector.
        let slope_tail = if coefficients.len() > split {
            &coefficients[split..]
        } else {
            &[][..]
        };
        let slope_used = self.logslope_dense.ncols().min(slope_tail.len());
        let slope_beta = ArrayView1::from(&slope_tail[..slope_used]);

        let (rows, cols) = self.marginal_dense.dim();
        let mut jacobian = Array2::<f64>::zeros((rows, cols));
        for (i, mut jacobian_row) in jacobian.outer_iter_mut().enumerate() {
            let slope_eta = self.offset_s[i]
                + self
                    .logslope_dense
                    .row(i)
                    .slice(ndarray::s![..slope_used])
                    .dot(&slope_beta);
            let scaled = frailty_scale * slope_eta;
            let stretch = (1.0 + scaled * scaled).sqrt();
            jacobian_row.assign(&self.marginal_dense.row(i).mapv(|x| stretch * x));
        }
        Ok(jacobian)
    }

    /// This block contributes a single output.
    fn n_outputs(&self) -> usize {
        1
    }
}
/// Data captured by the effective-Jacobian callback of the logslope block.
///
/// `effective_jacobian_at` rescales row `i` of `logslope_dense` by
/// `q_i * s^2 * g_i / sqrt(1 + (s * g_i)^2) + s * z_i`, where `q_i` and `g_i` are the
/// marginal and logslope predictors of that row.
pub struct BmsLogslopeJacobian {
/// Dense marginal design, used to evaluate the per-row marginal predictor.
pub marginal_dense: Arc<Array2<f64>>,
/// Dense logslope design whose rows are rescaled by the callback.
pub logslope_dense: Arc<Array2<f64>>,
/// Per-row marginal offset added to the marginal predictor.
pub offset_m: Array1<f64>,
/// Per-row logslope offset added to the logslope predictor.
pub offset_s: Array1<f64>,
/// Per-row value `z_i` entering the Jacobian factor as `s * z_i`.
pub z: Arc<Array1<f64>>,
/// Index in the coefficient vector at which the logslope coefficients start.
pub p_marginal: usize,
}
impl BmsLogslopeJacobian {
/// Builds the callback state by storing every argument in the field of the same name.
/// No validation of shapes or lengths is performed here.
pub fn new(
marginal_dense: Arc<Array2<f64>>,
logslope_dense: Arc<Array2<f64>>,
offset_m: Array1<f64>,
offset_s: Array1<f64>,
z: Arc<Array1<f64>>,
p_marginal: usize,
) -> Self {
Self {
marginal_dense,
logslope_dense,
offset_m,
offset_s,
z,
p_marginal,
}
}
}
impl BlockEffectiveJacobian for BmsLogslopeJacobian {
    /// Effective Jacobian of the logslope block at the given linearization state.
    ///
    /// With `s = state.probit_frailty_scale`, marginal predictor
    /// `q_i = offset_m[i] + marginal_dense[i, ..] · β_m` and logslope predictor
    /// `g_i = offset_s[i] + logslope_dense[i, ..] · β_s`, row `i` of the result is row `i`
    /// of `logslope_dense` multiplied by
    /// `q_i * s^2 * g_i / sqrt(1 + (s * g_i)^2) + s * z_i`.
    /// Both dot products are truncated to the coefficients actually present in `state.beta`.
    fn effective_jacobian_at(
        &self,
        state: &FamilyLinearizationState<'_>,
    ) -> Result<Array2<f64>, String> {
        let coefficients = state.beta;
        let frailty_scale = state.probit_frailty_scale;
        let split = self.p_marginal;

        // Leading `split` entries are marginal coefficients (fewer if the vector is short).
        let marginal_used = split.min(coefficients.len());
        let marginal_beta = ArrayView1::from(&coefficients[..marginal_used]);

        // Everything after `split` belongs to the logslope block.
        let slope_tail = if coefficients.len() > split {
            &coefficients[split..]
        } else {
            &[][..]
        };
        let slope_width = self.logslope_dense.ncols();
        let slope_used = slope_width.min(slope_tail.len());
        let slope_beta = ArrayView1::from(&slope_tail[..slope_used]);

        let rows = self.logslope_dense.nrows();
        let mut jacobian = Array2::<f64>::zeros((rows, slope_width));
        for (i, mut jacobian_row) in jacobian.outer_iter_mut().enumerate() {
            let marginal_eta = self.offset_m[i]
                + self
                    .marginal_dense
                    .row(i)
                    .slice(ndarray::s![..marginal_used])
                    .dot(&marginal_beta);
            let slope_eta = self.offset_s[i]
                + self
                    .logslope_dense
                    .row(i)
                    .slice(ndarray::s![..slope_used])
                    .dot(&slope_beta);
            let scaled = frailty_scale * slope_eta;
            let stretch = (1.0 + scaled * scaled).sqrt();
            let latent = self.z[i];
            let factor = marginal_eta * frailty_scale * frailty_scale * slope_eta / stretch
                + frailty_scale * latent;
            jacobian_row.assign(&self.logslope_dense.row(i).mapv(|x| factor * x));
        }
        Ok(jacobian)
    }

    /// This block contributes a single output.
    fn n_outputs(&self) -> usize {
        1
    }
}
/// Appends the optional influence columns to the right of the dense marginal design.
///
/// Returns a cheap clone of the input `Arc` when there are no influence columns.
///
/// # Errors
/// Fails when the influence columns and the marginal design disagree on the row count.
fn widen_marginal_dense_with_influence(
    marginal_dense: &Arc<Array2<f64>>,
    influence_columns: Option<&Array2<f64>>,
) -> Result<Arc<Array2<f64>>, String> {
    match influence_columns {
        None => Ok(Arc::clone(marginal_dense)),
        Some(extra) => {
            let (n, base_cols) = marginal_dense.dim();
            if extra.nrows() != n {
                return Err(format!(
                    "influence block: residualised columns have {} rows, marginal design has {n}",
                    extra.nrows()
                ));
            }
            let mut combined = Array2::<f64>::zeros((n, base_cols + extra.ncols()));
            combined
                .slice_mut(s![.., ..base_cols])
                .assign(marginal_dense.as_ref());
            combined.slice_mut(s![.., base_cols..]).assign(extra);
            Ok(Arc::new(combined))
        }
    }
}
/// Relative eigenvalue cutoff for the reduced logslope basis: it is multiplied by the
/// largest diagonal entry of the weighted logslope Gram matrix to obtain the absolute
/// tolerance below which eigen-directions are dropped.
const LOGSLOPE_REDUCED_BASIS_RELATIVE_TOL: f64 = 1.0e-6;
/// Linear map between a reduced logslope coefficient vector and the original one.
#[derive(Debug, Clone)]
pub(super) struct ReducedLogslopeReparam {
// Shape is (original columns, reduced columns); original β = transform · reduced β.
transform: Array2<f64>,
}
impl ReducedLogslopeReparam {
    /// Number of columns of the original (unreduced) logslope design.
    #[inline]
    pub(super) fn original_cols(&self) -> usize {
        self.transform.nrows()
    }

    /// Number of columns of the reduced logslope design.
    #[inline]
    pub(super) fn reduced_cols(&self) -> usize {
        self.transform.ncols()
    }

    /// Maps reduced-basis coefficients back to the original logslope basis.
    ///
    /// # Errors
    /// Fails when `beta_reduced` does not have exactly `reduced_cols()` entries.
    pub(super) fn recover_original_logslope_beta(
        &self,
        beta_reduced: &Array1<f64>,
    ) -> Result<Array1<f64>, String> {
        if beta_reduced.len() == self.reduced_cols() {
            Ok(self.transform.dot(beta_reduced))
        } else {
            Err(format!(
                "reduced logslope reparam: β' length ({}) != reduced width ({})",
                beta_reduced.len(),
                self.reduced_cols()
            ))
        }
    }
}
/// Tries to build a reduced basis for the logslope coefficients.
///
/// Returns `Ok(None)` when no reduction applies: either design has zero columns, the
/// weighted logslope Gram scale is not a positive finite number, or the number of retained
/// eigen-directions is zero or equal to the full logslope width.
///
/// `z`, the offsets, the baselines and `probit_scale` are only validated (length and
/// finiteness) in this function; the transform itself is computed from the two designs
/// and `row_metric`.
///
/// # Errors
/// Fails on row-count mismatches, non-finite inputs, a negative row-metric entry, a
/// non-positive `probit_scale`, a failure of the orthogonal reparameterization or of the
/// eigendecomposition, and on non-finite intermediate results.
fn build_reduced_logslope_reparam(
marginal_design: &TermCollectionDesign,
logslope_design: &TermCollectionDesign,
z: &Array1<f64>,
row_metric: &Array1<f64>,
marginal_offset: &Array1<f64>,
logslope_offset: &Array1<f64>,
marginal_baseline: f64,
logslope_baseline: f64,
probit_scale: f64,
) -> Result<Option<ReducedLogslopeReparam>, String> {
let marginal = marginal_design
.design
.try_to_dense_arc("build_reduced_logslope_reparam::marginal")?;
let logslope = logslope_design
.design
.try_to_dense_arc("build_reduced_logslope_reparam::logslope")?;
let n = marginal.nrows();
// Every per-row input must match the marginal design's row count.
if logslope.nrows() != n
|| z.len() != n
|| row_metric.len() != n
|| marginal_offset.len() != n
|| logslope_offset.len() != n
{
return Err(format!(
"reduced logslope reparam row mismatch: marginal={}, logslope={}, z={}, row_metric={}, marginal_offset={}, logslope_offset={}",
marginal.nrows(),
logslope.nrows(),
z.len(),
row_metric.len(),
marginal_offset.len(),
logslope_offset.len(),
));
}
let p_m = marginal.ncols();
let p_g = logslope.ncols();
// Nothing to reduce when either block is empty.
if p_m == 0 || p_g == 0 {
return Ok(None);
}
if !marginal_baseline.is_finite()
|| !logslope_baseline.is_finite()
|| !probit_scale.is_finite()
|| probit_scale <= 0.0
|| z.iter().any(|v| !v.is_finite())
|| row_metric.iter().any(|v| !v.is_finite() || *v < 0.0)
|| marginal_offset.iter().any(|v| !v.is_finite())
|| logslope_offset.iter().any(|v| !v.is_finite())
{
return Err(
"reduced logslope reparam requires finite pilot geometry and finite non-negative row metric"
.to_string(),
);
}
// NOTE(review): presumably this orthogonalizes the logslope columns against the marginal
// columns under `row_metric`; the helper is defined elsewhere — confirm.
let reparam = crate::solver::orthogonal_reparam::OrthogonalReparam::build_unconditional(
marginal.view(),
logslope.view(),
row_metric,
)?;
let c_tilde = reparam.reparameterized_confound().to_owned();
// Weighted Gram of the reparameterized columns, symmetrized before the eigensolve.
let stt = fast_xt_diag_x(&c_tilde, row_metric);
let stt = (&stt + &stt.t()) * 0.5;
if stt.iter().any(|v| !v.is_finite()) {
return Err("reduced logslope reparam: C̃ W-Gram produced non-finite entries".to_string());
}
// Scale reference: largest diagonal entry of the weighted Gram of the raw logslope design.
let raw_gram = fast_xt_diag_x(&logslope, row_metric);
let raw_scale = (0..p_g).map(|i| raw_gram[[i, i]]).fold(0.0_f64, f64::max);
let (evals, evecs) = stt
.eigh(Side::Lower)
.map_err(|e| format!("reduced logslope reparam: eigendecomposition failed: {e:?}"))?;
if !raw_scale.is_finite() || raw_scale <= 0.0 {
return Ok(None);
}
let tol = raw_scale * LOGSLOPE_REDUCED_BASIS_RELATIVE_TOL;
// Keep directions whose eigenvalue exceeds the tolerance, largest eigenvalue first.
let mut kept: Vec<usize> = (0..evals.len()).filter(|&i| evals[i] > tol).collect();
kept.sort_by(|&a, &b| {
evals[b]
.partial_cmp(&evals[a])
.unwrap_or(std::cmp::Ordering::Equal)
});
let r = kept.len();
// Keeping everything or nothing means there is no useful reduction.
if r == p_g || r == 0 {
return Ok(None);
}
// Columns of the transform are the retained eigenvectors.
let mut transform = Array2::<f64>::zeros((p_g, r));
for (out_col, &src) in kept.iter().enumerate() {
transform.column_mut(out_col).assign(&evecs.column(src));
}
if transform.iter().any(|v| !v.is_finite()) {
return Err(
"reduced logslope reparam: reduced transform produced non-finite entries".to_string(),
);
}
Ok(Some(ReducedLogslopeReparam { transform }))
}
/// Rewrites the logslope design and its penalties in the reduced basis of `reparam`.
///
/// The design becomes `G · T`. Each blockwise penalty is embedded in the full
/// `p_g × p_g` space, projected to `Tᵀ S T`, symmetrized, and re-registered over the full
/// reduced column range `0..r`. Its nullspace dimension is recomputed as `r` minus the
/// number of eigenvalues above a relative tolerance. Term-level metadata (ranges, smooth
/// descriptions, constraints) is emitted empty.
///
/// # Errors
/// Fails when the design cannot be densified, when its width differs from the transform's
/// original width, or when a penalty eigendecomposition fails.
fn reparameterize_logslope_design_reduced(
    logslope_design: &TermCollectionDesign,
    reparam: &ReducedLogslopeReparam,
) -> Result<TermCollectionDesign, String> {
    let dense = logslope_design
        .design
        .try_to_dense_arc("reparameterize_logslope_design_reduced::logslope")?;
    let p_g = dense.ncols();
    if p_g != reparam.original_cols() {
        return Err(format!(
            "reduced logslope reparam width mismatch: design has {p_g} cols, transform expects {}",
            reparam.original_cols()
        ));
    }
    let basis = &reparam.transform;
    let reduced_width = reparam.reduced_cols();
    let reduced_design = fast_ab(&dense, basis);

    let penalty_count = logslope_design.penalties.len();
    let mut reduced_penalties: Vec<crate::terms::smooth::BlockwisePenalty> =
        Vec::with_capacity(penalty_count);
    let mut reduced_nullspace_dims: Vec<usize> = Vec::with_capacity(penalty_count);
    for block in &logslope_design.penalties {
        // Embed the local penalty into the full logslope coefficient space.
        let mut embedded = Array2::<f64>::zeros((p_g, p_g));
        embedded
            .slice_mut(s![block.col_range.clone(), block.col_range.clone()])
            .assign(&block.local);

        // Project to the reduced basis and remove round-off asymmetry.
        let right_projected = fast_ab(&embedded, basis);
        let projected = fast_atb(basis, &right_projected);
        let projected = (&projected + &projected.t()) * 0.5;

        let (evals, _) = projected
            .eigh(Side::Lower)
            .map_err(|e| format!("reduced logslope penalty eigendecomposition failed: {e:?}"))?;
        let largest = evals.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
        let cutoff = (largest * 1.0e-12).max(f64::EPSILON);
        let numerical_rank = evals.iter().filter(|&&v| v.abs() > cutoff).count();

        reduced_penalties.push(crate::terms::smooth::BlockwisePenalty::new(
            0..reduced_width,
            projected,
        ));
        reduced_nullspace_dims.push(reduced_width.saturating_sub(numerical_rank));
    }

    let empty_smooth = crate::terms::smooth::SmoothDesign {
        term_designs: Vec::new(),
        penalties: Vec::new(),
        nullspace_dims: Vec::new(),
        penaltyinfo: Vec::new(),
        dropped_penaltyinfo: Vec::new(),
        terms: Vec::new(),
        coefficient_lower_bounds: None,
        linear_constraints: None,
    };
    Ok(TermCollectionDesign {
        design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(reduced_design)),
        penalties: reduced_penalties,
        nullspace_dims: reduced_nullspace_dims,
        penaltyinfo: Vec::new(),
        dropped_penaltyinfo: Vec::new(),
        coefficient_lower_bounds: None,
        linear_constraints: None,
        intercept_range: 0..0,
        linear_ranges: Vec::new(),
        random_effect_ranges: Vec::new(),
        random_effect_levels: Vec::new(),
        smooth: empty_smooth,
    })
}
/// Assembles the penalty list, nullspace dimensions and log-λ vector of the marginal block.
///
/// The design's blockwise penalties are lifted to the widened dimension (design columns
/// plus influence columns). When influence columns are present, an identity ridge over
/// those trailing columns is appended with a fixed log-λ, a nullspace dimension of zero
/// and a matching log-λ entry.
fn marginal_penalties_with_influence_ridge(
    design: &TermCollectionDesign,
    rho_marginal: &Array1<f64>,
    influence_columns: Option<&Array2<f64>>,
    influence_ridge_log_lambda: f64,
) -> Result<(Vec<PenaltyMatrix>, Vec<usize>, Array1<f64>), String> {
    let base_cols = design.design.ncols();
    let influence_cols = influence_columns.map_or(0, |cols| cols.ncols());
    let total_dim = base_cols + influence_cols;

    let mut penalties = Vec::<PenaltyMatrix>::new();
    for block in design.penalties.iter() {
        penalties.push(PenaltyMatrix::from_blockwise(block.clone(), total_dim));
    }
    let mut nullspace_dims = design.nullspace_dims.clone();
    let mut log_lambdas = rho_marginal.to_vec();

    if influence_cols > 0 {
        let ridge = PenaltyMatrix::Blockwise {
            local: Array2::<f64>::eye(influence_cols),
            col_range: base_cols..total_dim,
            total_dim,
        }
        .with_fixed_log_lambda(influence_ridge_log_lambda);
        penalties.push(ridge);
        nullspace_dims.push(0);
        log_lambdas.push(influence_ridge_log_lambda);
    }

    Ok((penalties, nullspace_dims, Array1::from_vec(log_lambdas)))
}
/// Resizes an optional coefficient hint to the widened marginal width.
///
/// A hint of the right length is passed through untouched. Otherwise the leading entries
/// are copied into a zero vector of length `p_marginal_widened` (shorter hints are
/// zero-padded, longer ones truncated). `None` stays `None`.
fn widen_marginal_beta_hint(
    beta_hint: Option<Array1<f64>>,
    p_marginal_widened: usize,
) -> Option<Array1<f64>> {
    let hint = beta_hint?;
    if hint.len() == p_marginal_widened {
        return Some(hint);
    }
    let shared = hint.len().min(p_marginal_widened);
    let mut resized = Array1::<f64>::zeros(p_marginal_widened);
    resized
        .slice_mut(s![..shared])
        .assign(&hint.slice(s![..shared]));
    Some(resized)
}
/// Returns the entry with the largest absolute value, with that absolute value in the
/// third slot of the tuple.
///
/// Entries whose magnitude is not finite (NaN or ±∞) are ignored. When several entries
/// share the largest magnitude, the last one in iteration order is returned. Yields
/// `None` when no entry has a finite magnitude.
fn argmax_by_abs<I>(values: I) -> Option<(String, usize, f64)>
where
    I: IntoIterator<Item = (String, usize, f64)>,
{
    let mut best: Option<(String, usize, f64)> = None;
    for (label, idx, value) in values {
        let magnitude = value.abs();
        if !magnitude.is_finite() {
            continue;
        }
        // `>=` so that a later entry wins a tie against the current best.
        let takes_over = match &best {
            Some((_, _, current)) => magnitude >= *current,
            None => true,
        };
        if takes_over {
            best = Some((label, idx, magnitude));
        }
    }
    best
}
/// Largest-magnitude coefficient among the intercept and the linear terms that are not
/// double-penalized, read from an explicit coefficient vector.
///
/// Columns outside `beta` are skipped. The result is `(label, column, |β|)`.
fn marginal_parametric_argmax_from_beta(
    beta: &Array1<f64>,
    design: &TermCollectionDesign,
    spec: &TermCollectionSpec,
) -> Option<(String, usize, f64)> {
    let width = beta.len();

    // The intercept only counts when it occupies exactly one column.
    let intercept = (design.intercept_range.len() == 1)
        .then(|| design.intercept_range.start)
        .filter(|&col| col < width)
        .map(|col| ("intercept".to_string(), col, beta[col]));

    let linear = spec
        .linear_terms
        .iter()
        .zip(design.linear_ranges.iter())
        .filter(|(term, _)| !term.double_penalty)
        .flat_map(|(_, (name, range))| {
            range
                .clone()
                .filter(move |&col| col < width)
                .map(move |col| (name.clone(), col, beta[col]))
        });

    argmax_by_abs(intercept.into_iter().chain(linear))
}
/// Largest-magnitude coefficient among the intercept and the linear terms that are not
/// double-penalized, read from block 0 of a warm start.
///
/// Each range contributes at most one candidate (its own per-range arg-max as reported by
/// the warm start). The result is `(label, column, |β|)`.
fn marginal_parametric_argmax_from_warm_start(
    warm_start: &CustomFamilyWarmStart,
    design: &TermCollectionDesign,
    spec: &TermCollectionSpec,
) -> Option<(String, usize, f64)> {
    let mut candidates: Vec<(String, usize, f64)> = Vec::new();

    // The intercept only counts when it occupies exactly one column.
    if design.intercept_range.len() == 1 {
        candidates.extend(
            warm_start
                .block_beta_abs_argmax_in_range(0, design.intercept_range.clone())
                .map(|(idx, abs)| ("intercept".to_string(), idx, abs)),
        );
    }

    let unpenalized_linear = spec
        .linear_terms
        .iter()
        .zip(design.linear_ranges.iter())
        .filter(|(term, _)| !term.double_penalty);
    for (_, (name, range)) in unpenalized_linear {
        candidates.extend(
            warm_start
                .block_beta_abs_argmax_in_range(0, range.clone())
                .map(|(idx, abs)| (name.clone(), idx, abs)),
        );
    }

    argmax_by_abs(candidates)
}
/// Largest-magnitude coefficient over the whole marginal block, read from an explicit
/// coefficient vector.
///
/// Candidates come from the intercept, linear terms, random effects, smooth terms
/// (labelled `smooth '<name>'`) and any coefficients beyond the design width, which are
/// labelled as the fixed-ridge influence absorber. Columns outside `beta` are skipped.
fn marginal_full_argmax_from_beta(
    beta: &Array1<f64>,
    design: &TermCollectionDesign,
) -> Option<(String, usize, f64)> {
    let width = beta.len();
    let design_cols = design.design.ncols();
    let mut candidates: Vec<(String, usize, f64)> = Vec::new();

    if design.intercept_range.len() == 1 {
        let col = design.intercept_range.start;
        if col < width {
            candidates.push(("intercept".to_string(), col, beta[col]));
        }
    }

    for (name, range) in &design.linear_ranges {
        candidates.extend(
            range
                .clone()
                .filter(|&col| col < width)
                .map(|col| (name.clone(), col, beta[col])),
        );
    }
    for (name, range) in &design.random_effect_ranges {
        candidates.extend(
            range
                .clone()
                .filter(|&col| col < width)
                .map(|col| (name.clone(), col, beta[col])),
        );
    }

    // Smooth columns occupy the tail of the design; term ranges are relative to its start.
    let smooth_start = design_cols.saturating_sub(design.smooth.total_smooth_cols());
    for term in &design.smooth.terms {
        let label = format!("smooth '{}'", term.name);
        let first = smooth_start + term.coeff_range.start;
        let last = smooth_start + term.coeff_range.end;
        candidates.extend(
            (first..last)
                .filter(|&col| col < width)
                .map(|col| (label.clone(), col, beta[col])),
        );
    }

    // Anything past the design width belongs to the appended influence columns.
    candidates.extend(
        (design_cols..width)
            .map(|col| ("fixed-ridge influence absorber".to_string(), col, beta[col])),
    );

    argmax_by_abs(candidates)
}
/// Largest-magnitude coefficient over the whole marginal block, read from block 0 of a
/// warm start.
///
/// Returns `None` when the warm start has no length for block 0. Each range (intercept,
/// linear, random effect, smooth, trailing influence columns) contributes at most one
/// candidate, namely the per-range arg-max reported by the warm start.
fn marginal_full_argmax_from_warm_start(
    warm_start: &CustomFamilyWarmStart,
    design: &TermCollectionDesign,
) -> Option<(String, usize, f64)> {
    let block_width = warm_start.block_beta_len(0)?;
    let design_cols = design.design.ncols();
    let mut candidates: Vec<(String, usize, f64)> = Vec::new();

    if design.intercept_range.len() == 1 {
        candidates.extend(
            warm_start
                .block_beta_abs_argmax_in_range(0, design.intercept_range.clone())
                .map(|(idx, abs)| ("intercept".to_string(), idx, abs)),
        );
    }

    for (name, range) in &design.linear_ranges {
        candidates.extend(
            warm_start
                .block_beta_abs_argmax_in_range(0, range.clone())
                .map(|(idx, abs)| (name.clone(), idx, abs)),
        );
    }
    for (name, range) in &design.random_effect_ranges {
        candidates.extend(
            warm_start
                .block_beta_abs_argmax_in_range(0, range.clone())
                .map(|(idx, abs)| (name.clone(), idx, abs)),
        );
    }

    // Smooth columns occupy the tail of the design; term ranges are relative to its start.
    let smooth_start = design_cols.saturating_sub(design.smooth.total_smooth_cols());
    for term in &design.smooth.terms {
        let first = smooth_start + term.coeff_range.start;
        let last = smooth_start + term.coeff_range.end;
        candidates.extend(
            warm_start
                .block_beta_abs_argmax_in_range(0, first..last)
                .map(|(idx, abs)| (format!("smooth '{}'", term.name), idx, abs)),
        );
    }

    // Columns past the design width belong to the appended influence columns.
    if block_width > design_cols {
        candidates.extend(
            warm_start
                .block_beta_abs_argmax_in_range(0, design_cols..block_width)
                .map(|(idx, abs)| ("fixed-ridge influence absorber".to_string(), idx, abs)),
        );
    }

    argmax_by_abs(candidates)
}
fn bernoulli_marginal_slope_runaway_error_from_argmax(
parametric_argmax: Option<(String, usize, f64)>,
block_argmax: Option<(String, usize, f64)>,
inner_status: &str,
eval_label: &str,
) -> Option<String> {
let (label, local_col, beta_abs, explanation) = if let Some((label, local_col, beta_abs)) =
parametric_argmax
&& beta_abs >= BMS_PROBIT_SEPARATION_BETA_INF
{
(
label,
local_col,
beta_abs,
"an unpenalized parametric marginal direction has no stable finite probit optimum",
)
} else if let Some((label, local_col, beta_abs)) = block_argmax
&& beta_abs >= BMS_PROBIT_SEPARATION_BETA_INF
{
(
label,
local_col,
beta_abs,
"a marginal smooth direction is trading off against the logslope surface; this is the under-constrained marginal/logslope coupling that appears when the score is correlated with the shared surface covariates",
)
} else {
return None;
};
if beta_abs < BMS_PROBIT_SEPARATION_BETA_INF {
return None;
}
Some(format!(
"bernoulli marginal-slope probit marginal/logslope runaway detected in block \
'marginal_surface' during {eval_label}: term '{label}' \
(local column {local_col}) has \
|β|∞={beta_abs:.3e} (diagnostic threshold \
{BMS_PROBIT_SEPARATION_BETA_INF:.1}). The joint design is identifiable; \
{explanation}. {inner_status}. The robust Jeffreys curvature path is \
already installed for this fit, so this diagnostic means the current \
coupled surface still exposes a separation-scale direction rather than \
a request for an external bias-reduction prior. Reduce or \
reparameterize the coupled marginal/logslope surface, or use a \
lower-dimensional logslope interaction. This is not a \
Matérn/Duchon polynomial-nullspace or cross-block gauge-priority \
failure."
))
}
/// Runs the runaway diagnostic against a warm start.
///
/// Collects the parametric and full-block arg-max candidates from block 0 of
/// `warm_start` and forwards them, together with a sentence describing whether the inner
/// solve converged, to `bernoulli_marginal_slope_runaway_error_from_argmax`.
fn bernoulli_marginal_slope_runaway_error(
    warm_start: &CustomFamilyWarmStart,
    design: &TermCollectionDesign,
    spec: &TermCollectionSpec,
    inner_converged: bool,
    eval_label: &str,
) -> Option<String> {
    let inner_status = match inner_converged {
        true => "the inner solve reached a KKT certificate at a separation-scale coefficient",
        false => "the inner solve failed while already carrying a separation-scale coefficient",
    };
    let parametric = marginal_parametric_argmax_from_warm_start(warm_start, design, spec);
    let whole_block = marginal_full_argmax_from_warm_start(warm_start, design);
    bernoulli_marginal_slope_runaway_error_from_argmax(
        parametric,
        whole_block,
        inner_status,
        eval_label,
    )
}
// Unit tests for the runaway diagnostic message, the spatial joint-setup rho dimension,
// and a test-local marginal/logslope overlap penalty helper.
#[cfg(test)]
mod runaway_tests {
use super::*;
use crate::faer_ndarray::{FaerArrayView, factorize_symmetricwith_fallback, fast_xt_diag_y};
// Test-only helper: builds the weighted cross-product between the effective marginal
// columns and their ridge-regularized projection onto the effective logslope columns,
// both evaluated at the pilot point (offsets plus baselines). Returns `Ok(None)` when
// either design is empty, the effective logslope columns vanish, the Gram scale is not
// positive and finite, or the resulting penalty is numerically zero.
fn marginal_logslope_overlap_penalty(
marginal_design: &DesignMatrix,
logslope_design: &DesignMatrix,
z: &Array1<f64>,
row_metric: &Array1<f64>,
marginal_offset: &Array1<f64>,
logslope_offset: &Array1<f64>,
marginal_baseline: f64,
logslope_baseline: f64,
probit_scale: f64,
) -> Result<Option<Array2<f64>>, String> {
let marginal =
marginal_design.try_to_dense_arc("marginal_logslope_overlap_penalty::marginal")?;
let logslope =
logslope_design.try_to_dense_arc("marginal_logslope_overlap_penalty::logslope")?;
let n = marginal.nrows();
if logslope.nrows() != n
|| z.len() != n
|| row_metric.len() != n
|| marginal_offset.len() != n
|| logslope_offset.len() != n
{
return Err(format!(
"marginal/logslope overlap penalty row mismatch: marginal={}, logslope={}, z={}, row_metric={}, marginal_offset={}, logslope_offset={}",
marginal.nrows(),
logslope.nrows(),
z.len(),
row_metric.len(),
marginal_offset.len(),
logslope_offset.len(),
));
}
let p_m = marginal.ncols();
let p_g = logslope.ncols();
if p_m == 0 || p_g == 0 {
return Ok(None);
}
if !marginal_baseline.is_finite()
|| !logslope_baseline.is_finite()
|| !probit_scale.is_finite()
|| probit_scale <= 0.0
|| z.iter().any(|v| !v.is_finite())
|| row_metric.iter().any(|v| !v.is_finite() || *v < 0.0)
|| marginal_offset.iter().any(|v| !v.is_finite())
|| logslope_offset.iter().any(|v| !v.is_finite())
{
return Err(
"marginal/logslope overlap penalty requires finite pilot geometry and finite non-negative row metric"
.to_string(),
);
}
// Per-row scaling uses the same factors as the two Jacobian callbacks, evaluated at
// the pilot predictors q_i and g_i.
let mut marginal_effective = Array2::<f64>::zeros((n, p_m));
let mut effective_logslope = Array2::<f64>::zeros((n, p_g));
for i in 0..n {
let q_i = marginal_offset[i] + marginal_baseline;
let g_i = logslope_offset[i] + logslope_baseline;
let sg = probit_scale * g_i;
let c_i = (1.0 + sg * sg).sqrt();
let logslope_factor =
q_i * probit_scale * probit_scale * g_i / c_i + probit_scale * z[i];
for j in 0..p_m {
marginal_effective[[i, j]] = c_i * marginal[[i, j]];
}
for j in 0..p_g {
effective_logslope[[i, j]] = logslope_factor * logslope[[i, j]];
}
}
if effective_logslope.iter().all(|v| v.abs() <= f64::EPSILON) {
return Ok(None);
}
let mut gram = fast_xt_diag_x(&effective_logslope, row_metric);
let gram_scale = gram.diag().iter().copied().fold(0.0_f64, f64::max);
if !gram_scale.is_finite() || gram_scale <= 0.0 {
return Ok(None);
}
// Small diagonal ridge, relative to the Gram scale, added before factorizing.
let projection_ridge = (gram_scale * 1.0e-10).max(f64::EPSILON);
for i in 0..p_g {
gram[[i, i]] += projection_ridge;
}
let cross = fast_xt_diag_y(&effective_logslope, row_metric, &marginal_effective);
let gram_view = FaerArrayView::new(&gram);
let factor = factorize_symmetricwith_fallback(gram_view.as_ref(), Side::Lower)
.map_err(|e| format!("marginal/logslope overlap Gram factorization failed: {e}"))?;
let rhsview = FaerArrayView::new(&cross);
let coeffs_mat = factor.solve(rhsview.as_ref());
let coeffs = Array2::from_shape_fn((p_g, p_m), |(i, j)| coeffs_mat[(i, j)]);
let projected_marginal = fast_ab(&effective_logslope, &coeffs);
let mut penalty = fast_xt_diag_y(&marginal_effective, row_metric, &projected_marginal);
penalty = (&penalty + &penalty.t()) * 0.5;
let max_abs = penalty.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
if !max_abs.is_finite() || max_abs <= 1.0e-12 {
return Ok(None);
}
Ok(Some(penalty))
}
// With empty term specs and counts (2, 3, one auxiliary entry) the setup must report
// a rho dimension of 6.
#[test]
fn spatial_joint_setup_counts_only_learned_penalties_in_rho() {
let data = Array2::<f64>::zeros((3, 1));
let empty_terms = TermCollectionSpec {
linear_terms: Vec::new(),
random_effect_terms: Vec::new(),
smooth_terms: Vec::new(),
};
let setup = joint_setup(
data.view(),
&empty_terms,
&empty_terms,
2,
3,
&[0.4],
&SpatialLengthScaleOptimizationOptions::default(),
);
assert_eq!(
setup.rho_dim(),
6,
"BMS spatial setup rho must contain only learned marginal/logslope/auxiliary penalties; fixed physical ridges are carried by PenaltyMatrix::Fixed"
);
}
// Marginal column equals z and the logslope column is constant, so the effective
// logslope column equals the marginal column and the penalty is its squared norm
// (0 + 1 + 4 + 9 = 14).
#[test]
fn overlap_penalty_targets_score_weighted_logslope_span() {
let marginal = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
Array2::from_shape_vec((4, 1), vec![0.0, 1.0, 2.0, 3.0]).unwrap(),
));
let logslope = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
Array2::from_shape_vec((4, 1), vec![1.0, 1.0, 1.0, 1.0]).unwrap(),
));
let z = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0]);
let row_metric = Array1::ones(4);
let offsets = Array1::zeros(4);
let penalty = marginal_logslope_overlap_penalty(
&marginal,
&logslope,
&z,
&row_metric,
&offsets,
&offsets,
0.0,
0.0,
1.0,
)
.expect("overlap penalty should build")
.expect("marginal signal lies in the pilot logslope Jacobian span");
assert_eq!(penalty.dim(), (1, 1));
assert!((penalty[[0, 0]] - 14.0).abs() < 1.0e-6);
}
// The alternating marginal column sums to zero against the constant effective logslope
// column, so the overlap vanishes and no penalty is returned.
#[test]
fn overlap_penalty_skips_weight_orthogonal_channels() {
let marginal = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
Array2::from_shape_vec((4, 1), vec![-1.0, 1.0, -1.0, 1.0]).unwrap(),
));
let logslope = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
Array2::from_shape_vec((4, 1), vec![1.0, 1.0, 1.0, 1.0]).unwrap(),
));
let z = Array1::ones(4);
let row_metric = Array1::ones(4);
let offsets = Array1::zeros(4);
let penalty = marginal_logslope_overlap_penalty(
&marginal,
&logslope,
&z,
&row_metric,
&offsets,
&offsets,
0.0,
0.0,
1.0,
)
.expect("overlap penalty should build");
assert!(penalty.is_none());
}
// Both candidates exceed the threshold; the parametric one must be named.
#[test]
fn runaway_diagnostic_names_unpenalized_parametric_direction_first() {
let msg = bernoulli_marginal_slope_runaway_error_from_argmax(
Some(("sex".to_string(), 1, 52.0)),
Some(("smooth 'matern(PC1,PC2,PC3)'".to_string(), 7, 49.0)),
"inner status",
"unit-test eval",
)
.expect("parametric runaway should be diagnosed");
assert!(msg.contains("term 'sex'"));
assert!(msg.contains("unpenalized parametric marginal direction"));
assert!(msg.contains("robust Jeffreys curvature path is already installed"));
assert!(!msg.contains("explicit declared separation/bias-reduction prior"));
assert!(msg.contains("not a Matérn/Duchon polynomial-nullspace"));
}
// Only the block candidate exceeds the threshold; the coupling explanation is expected.
#[test]
fn runaway_diagnostic_names_marginal_logslope_coupling_when_smooth_runs_away() {
let msg = bernoulli_marginal_slope_runaway_error_from_argmax(
Some(("sex".to_string(), 1, 2.0)),
Some(("smooth 'marginal_surface[0]'".to_string(), 6, 51.4)),
"inner status",
"unit-test eval",
)
.expect("smooth runaway should be diagnosed");
assert!(msg.contains("marginal/logslope runaway"));
assert!(msg.contains("smooth 'marginal_surface[0]'"));
assert!(msg.contains("score is correlated with the shared surface covariates"));
assert!(msg.contains("not a Matérn/Duchon polynomial-nullspace"));
}
}
/// Builds the `marginal_surface` parameter block.
///
/// The block's design is the dense marginal design widened by the optional influence
/// columns, and its offset is `offset + baseline`. Penalties come from
/// `marginal_penalties_with_influence_ridge`, and the initial coefficients are the hint
/// resized to `p_marginal`. A `BmsMarginalJacobian` callback, which also needs the
/// logslope design and its total offset, is attached.
///
/// # Errors
/// Propagates failures from densifying either design, from widening the marginal design,
/// and from assembling the penalties.
fn build_marginal_blockspec_bms(
    design: &TermCollectionDesign,
    baseline: f64,
    offset: &Array1<f64>,
    rho: Array1<f64>,
    beta_hint: Option<Array1<f64>>,
    logslope_design: &TermCollectionDesign,
    logslope_offset: &Array1<f64>,
    logslope_baseline: f64,
    p_marginal: usize,
    influence_columns: Option<&Array2<f64>>,
    influence_ridge_log_lambda: f64,
) -> Result<ParameterBlockSpec, String> {
    let marginal_total_offset = offset + baseline;
    let logslope_total_offset = logslope_offset + logslope_baseline;

    let widened_design = {
        let raw = design
            .design
            .try_to_dense_arc("build_marginal_blockspec_bms::marginal")?;
        widen_marginal_dense_with_influence(&raw, influence_columns)?
    };
    let slope_dense = logslope_design
        .design
        .try_to_dense_arc("build_marginal_blockspec_bms::logslope")?;

    let jacobian: Arc<dyn BlockEffectiveJacobian> = Arc::new(BmsMarginalJacobian {
        marginal_dense: Arc::clone(&widened_design),
        logslope_dense: slope_dense,
        offset_m: marginal_total_offset.clone(),
        offset_s: logslope_total_offset,
        p_marginal,
    });

    let (penalties, nullspace_dims, initial_log_lambdas) =
        marginal_penalties_with_influence_ridge(
            design,
            &rho,
            influence_columns,
            influence_ridge_log_lambda,
        )?;

    let block_design = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
        (*widened_design).clone(),
    ));
    Ok(ParameterBlockSpec {
        name: "marginal_surface".to_string(),
        design: block_design,
        offset: marginal_total_offset,
        penalties,
        nullspace_dims,
        initial_log_lambdas,
        initial_beta: widen_marginal_beta_hint(beta_hint, p_marginal),
        gauge_priority: 150,
        jacobian_callback: Some(jacobian),
        stacked_design: None,
        stacked_offset: None,
    })
}
/// Builds the `logslope_surface` parameter block.
///
/// The block's design is the dense logslope design and its offset is `offset + baseline`.
/// Penalties, nullspace dimensions, log-λ values and the coefficient hint are taken over
/// unchanged. A `BmsLogslopeJacobian` callback is attached; it uses the marginal design
/// widened by the optional influence columns, the marginal total offset and `z`.
///
/// # Errors
/// Propagates failures from densifying either design and from widening the marginal
/// design.
fn build_logslope_blockspec_bms(
    design: &TermCollectionDesign,
    baseline: f64,
    offset: &Array1<f64>,
    rho: Array1<f64>,
    beta_hint: Option<Array1<f64>>,
    marginal_design: &TermCollectionDesign,
    marginal_offset: &Array1<f64>,
    marginal_baseline: f64,
    z: Arc<Array1<f64>>,
    p_marginal: usize,
    influence_columns: Option<&Array2<f64>>,
) -> Result<ParameterBlockSpec, String> {
    let logslope_total_offset = offset + baseline;
    let marginal_total_offset = marginal_offset + marginal_baseline;

    let widened_marginal = {
        let raw = marginal_design
            .design
            .try_to_dense_arc("build_logslope_blockspec_bms::marginal")?;
        widen_marginal_dense_with_influence(&raw, influence_columns)?
    };
    let slope_dense = design
        .design
        .try_to_dense_arc("build_logslope_blockspec_bms::logslope")?;

    let jacobian: Arc<dyn BlockEffectiveJacobian> = Arc::new(BmsLogslopeJacobian {
        marginal_dense: widened_marginal,
        logslope_dense: Arc::clone(&slope_dense),
        offset_m: marginal_total_offset,
        offset_s: logslope_total_offset.clone(),
        z,
        p_marginal,
    });

    let block_design =
        DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from((*slope_dense).clone()));
    Ok(ParameterBlockSpec {
        name: "logslope_surface".to_string(),
        design: block_design,
        offset: logslope_total_offset,
        penalties: design.penalties_as_penalty_matrix(),
        nullspace_dims: design.nullspace_dims.clone(),
        initial_log_lambdas: rho,
        initial_beta: beta_hint,
        gauge_priority: 120,
        jacobian_callback: Some(jacobian),
        stacked_design: None,
        stacked_offset: None,
    })
}
/// Builds the parameter block spec of an auxiliary deviation block.
///
/// Starts from a clone of `prepared.block`, installs `rho` as the initial log-λ values,
/// and sets the initial coefficients to the result of `project_monotone_feasible_beta`
/// called with a zero vector and the hint (or a zero vector of the design width when no
/// hint is given). The gauge priority is 60 for `"link_dev"`, 80 for `"score_warp_dev"`
/// and 70 for any other name.
///
/// # Errors
/// Propagates failures from the projection and from converting the block into a spec.
pub(crate) fn build_deviation_aux_blockspec(
    name: &str,
    prepared: &DeviationPrepared,
    rho: Array1<f64>,
    beta_hint: Option<Array1<f64>>,
) -> Result<ParameterBlockSpec, String> {
    let mut block = prepared.block.clone();
    block.initial_log_lambdas = Some(rho);

    let start = beta_hint.unwrap_or_else(|| Array1::<f64>::zeros(block.design.ncols()));
    let origin = Array1::<f64>::zeros(start.len());
    let feasible = project_monotone_feasible_beta(&prepared.runtime, &origin, &start, name)?;
    block.initial_beta = Some(feasible);

    let mut spec = block.intospec(name)?;
    spec.gauge_priority = if name == "link_dev" {
        60
    } else if name == "score_warp_dev" {
        80
    } else {
        70
    };
    Ok(spec)
}
/// Appends the optional `score_warp_dev` and `link_dev` blocks to `blocks`.
///
/// For each prepared block that is present, in that order, the next
/// `prepared.block.penalties.len()` entries of `rho` starting at `*cursor` become the
/// block's initial log-λ values, and `*cursor` is advanced past them. On success `*cursor`
/// therefore points just past the last entry consumed, so the caller can keep slicing
/// `rho` from there.
///
/// # Errors
/// Fails when `rho` does not hold enough entries for a block, and propagates failures from
/// `build_deviation_aux_blockspec`.
pub(crate) fn push_deviation_aux_blockspecs(
    blocks: &mut Vec<ParameterBlockSpec>,
    rho: &Array1<f64>,
    cursor: &mut usize,
    score_warp_prepared: Option<&DeviationPrepared>,
    link_dev_prepared: Option<&DeviationPrepared>,
    score_warp_beta_hint: Option<Array1<f64>>,
    link_dev_beta_hint: Option<Array1<f64>>,
) -> Result<(), String> {
    let pending = [
        ("score_warp_dev", score_warp_prepared, score_warp_beta_hint),
        ("link_dev", link_dev_prepared, link_dev_beta_hint),
    ];
    for (name, prepared, beta_hint) in pending {
        let Some(prepared) = prepared else {
            continue;
        };
        let start = *cursor;
        let end = start + prepared.block.penalties.len();
        if end > rho.len() {
            return Err(format!(
                "deviation block '{name}': rho has {} entries but entries {start}..{end} are required",
                rho.len()
            ));
        }
        let block_rho = rho.slice(s![start..end]).to_owned();
        // Advance for every block. The original left the cursor untouched after
        // `link_dev`, so the caller's shared cursor went stale.
        *cursor = end;
        blocks.push(build_deviation_aux_blockspec(
            name, prepared, block_rho, beta_hint,
        )?);
    }
    Ok(())
}
/// Runs `fit_custom_family` on a private copy of `options`.
///
/// The copy has the outer Hessian disabled and its outer tolerance raised to at least
/// `2.0e-5`. The caller's options are not modified, and the fit error is returned as its
/// string form.
fn inner_fit(
    family: &BernoulliMarginalSlopeFamily,
    blocks: &[ParameterBlockSpec],
    options: &BlockwiseFitOptions,
) -> Result<UnifiedFitResult, String> {
    let mut tuned = options.clone();
    tuned.use_outer_hessian = false;
    let relaxed_tol = tuned.outer_tol.max(2.0e-5);
    tuned.outer_tol = relaxed_tol;
    match fit_custom_family(family, blocks, &tuned) {
        Ok(fit) => Ok(fit),
        Err(err) => Err(err.to_string()),
    }
}
pub fn fit_bernoulli_marginal_slope_terms(
data: ArrayView2<'_, f64>,
spec: BernoulliMarginalSlopeTermSpec,
options: &BlockwiseFitOptions,
kappa_options: &SpatialLengthScaleOptimizationOptions,
policy: &crate::resource::ResourcePolicy,
) -> Result<BernoulliMarginalSlopeFitResult, String> {
let mut spec = spec;
let data_view = data;
validate_spec(data_view, &spec)?;
let mut effective_kappa_options = kappa_options.clone();
let kappa_locked_marginal = crate::smooth::all_spatial_terms_kappa_fixed(&spec.marginalspec);
let kappa_locked_logslope = crate::smooth::all_spatial_terms_kappa_fixed(&spec.logslopespec);
if effective_kappa_options.enabled && kappa_locked_marginal && kappa_locked_logslope {
log::info!(
"[BMS spatial] disabling κ/ψ optimization: every spatial term has an \
explicit length_scale and no anisotropy; user-supplied kernel scale is fixed"
);
effective_kappa_options.enabled = false;
}
let flex_spatial_pilot_path = (spec.score_warp.is_some() || spec.link_dev.is_some())
&& spec.y.len() >= BMS_FLEX_SPATIAL_OUTER_PILOT_ROW_THRESHOLD
&& effective_kappa_options.enabled;
if flex_spatial_pilot_path {
let marginal_terms = spatial_length_scale_term_indices(&spec.marginalspec);
let logslope_terms = spatial_length_scale_term_indices(&spec.logslopespec);
let marginal_updates = apply_spatial_anisotropy_pilot_initializer(
data_view,
&mut spec.marginalspec,
&marginal_terms,
effective_kappa_options.pilot_subsample_threshold,
&effective_kappa_options,
);
let logslope_updates = apply_spatial_anisotropy_pilot_initializer(
data_view,
&mut spec.logslopespec,
&logslope_terms,
effective_kappa_options.pilot_subsample_threshold,
&effective_kappa_options,
);
effective_kappa_options.enabled = false;
log::info!(
"[BMS spatial] n={} flex=true pilot_geometry_updates={} iterative_spatial_outer=false reason=large-flex-spatial-pilot",
spec.y.len(),
marginal_updates + logslope_updates,
);
}
let (z_standardized, z_normalization) = standardize_latent_z_with_policy(
&spec.z,
&spec.weights,
"bernoulli-marginal-slope",
&spec.latent_z_policy,
)?;
spec.z = z_standardized;
let sigma_learnable = matches!(
&spec.frailty,
FrailtySpec::GaussianShift { sigma_fixed: None }
);
let initial_sigma = match &spec.frailty {
FrailtySpec::GaussianShift {
sigma_fixed: Some(s),
} => Some(*s),
FrailtySpec::GaussianShift { sigma_fixed: None } => Some(0.5),
FrailtySpec::None => None,
FrailtySpec::HazardMultiplier { .. } => {
return Err(
"internal: validate_spec should have rejected unsupported marginal-slope frailty"
.to_string(),
);
}
};
let probit_scale = probit_frailty_scale(initial_sigma);
let (_raw_joint_designs, mut joint_specs) = build_term_collection_designs_and_freeze_joint(
data_view,
&[spec.marginalspec.clone(), spec.logslopespec.clone()],
)
.map_err(|e| e.to_string())?;
let marginalspec_boot = joint_specs.remove(0);
let logslopespec_boot = joint_specs.remove(0);
let (mut joint_designs, _) = build_term_collection_designs_and_freeze_joint(
data_view,
&[marginalspec_boot.clone(), logslopespec_boot.clone()],
)
.map_err(|e| format!("failed to rebuild frozen probe BMS joint designs: {e}"))?;
let marginal_design = joint_designs.remove(0);
let logslope_design = joint_designs.remove(0);
let (latent_measure, latent_z_calibration) =
build_latent_measure_with_geometry(&spec.z, &spec.weights, &spec.latent_z_policy)?;
if latent_measure.is_empirical() && sigma_learnable {
return Err("empirical latent-measure marginal-slope calibration requires fixed GaussianShift sigma; learnable sigma derivatives must be fit under the standard-normal latent measure"
.to_string());
}
let y = Arc::new(spec.y.clone());
let weights = Arc::new(spec.weights.clone());
let z = match &latent_z_calibration {
LatentMeasureCalibration::None => Arc::new(spec.z.clone()),
LatentMeasureCalibration::RankInverseNormal(cal) => {
Arc::new(cal.apply_to_training(&spec.z)?)
}
};
let z_train = z.as_ref();
let pilot_baseline = pooled_probit_baseline(&spec.y, z_train, &spec.weights)?;
let baseline = (
bernoulli_marginal_slope_eta_from_probability(
&spec.base_link,
normal_cdf(pilot_baseline.0),
"bernoulli marginal-slope baseline link inversion",
)?,
pilot_baseline.1 / probit_scale,
);
let rigid_pilot_eta = rigid_pooled_probit_pilot_eta(
&spec.base_link,
z_train,
&spec.marginal_offset,
&spec.logslope_offset,
baseline.0,
baseline.1,
probit_scale,
)?;
let cross_block_pilot_w_score_warp =
pilot_irls_hessian_row_metric_at_eta(&rigid_pilot_eta, &spec.weights);
let influence_columns = if let Some(jac) = spec
.score_influence_jacobian
.as_ref()
.filter(|j| j.ncols() > 0)
{
let marginal_dense_for_proj = marginal_design
.design
.try_to_dense_arc("bernoulli marginal-slope influence-block marginal projection")?;
let marginal_dense = marginal_dense_for_proj.as_ref();
if jac.nrows() != marginal_dense.nrows() {
return Err(format!(
"influence block: Jacobian has {} rows, marginal design has {}",
jac.nrows(),
marginal_dense.nrows()
));
}
let rigid_logslope_at_rows = &spec.logslope_offset + baseline.1;
let residualized =
crate::families::marginal_slope_orthogonal::residualized_influence_block(
jac,
z_train,
&rigid_logslope_at_rows,
probit_scale,
marginal_dense.view(),
&cross_block_pilot_w_score_warp,
)?;
Some(residualized)
} else {
None
};
let mut cross_block_warnings: Vec<CrossBlockIdentifiabilityWarning> = Vec::new();
let score_warp_prepared = if let Some(cfg) = spec.score_warp.as_ref() {
use super::deviation_runtime::ParametricAnchorBlock;
let mut prepared = build_score_warp_deviation_block_from_seed(z_train, cfg)?;
let outcome = install_compiled_flex_block_into_runtime(
&mut prepared,
z_train,
cfg,
&[
(&marginal_design.design, ParametricAnchorBlock::Marginal),
(&logslope_design.design, ParametricAnchorBlock::Logslope),
],
&[],
&cross_block_pilot_w_score_warp,
)?;
match outcome {
FlexCompileOutcome::Reparameterised => Some(prepared),
FlexCompileOutcome::FullyAliased { reason } => {
cross_block_warnings.push(CrossBlockIdentifiabilityWarning {
candidate_label: "score_warp",
anchor_summary: "marginal+logslope".to_string(),
reason,
});
Some(prepared)
}
}
} else {
None
};
let link_dev_prepared = if let Some(cfg) = spec.link_dev.as_ref() {
let eta_pilot = pilot_eta_for_link_dev_orthogonalisation(
&spec.base_link,
&spec.y,
z_train,
&spec.weights,
&marginal_design.design,
&spec.marginal_offset,
&spec.logslope_offset,
baseline.0,
baseline.1,
probit_scale,
)?;
let link_dev_seed = padded_deviation_seed(&eta_pilot, 1.0, 0.5);
let mut prepared = build_link_deviation_block_from_knots_design_seed_and_weights(
&link_dev_seed,
&eta_pilot,
cfg,
)?;
let score_warp_anchor_design = score_warp_prepared
.as_ref()
.map(|sw| sw.runtime.design_at_training_with_residual(z_train))
.transpose()?;
use super::deviation_runtime::ParametricAnchorBlock;
let parametric_anchors: [(&DesignMatrix, ParametricAnchorBlock); 2] = [
(&marginal_design.design, ParametricAnchorBlock::Marginal),
(&logslope_design.design, ParametricAnchorBlock::Logslope),
];
let flex_anchor_slot: Option<&Array2<f64>> = score_warp_anchor_design.as_ref();
let flex_anchors: Vec<&Array2<f64>> = flex_anchor_slot.into_iter().collect();
let cross_block_pilot_w_link_dev =
pilot_irls_hessian_row_metric_at_eta(&eta_pilot, &spec.weights);
let outcome = install_compiled_flex_block_into_runtime(
&mut prepared,
&eta_pilot,
cfg,
¶metric_anchors,
&flex_anchors,
&cross_block_pilot_w_link_dev,
)?;
match outcome {
FlexCompileOutcome::Reparameterised => Some(prepared),
FlexCompileOutcome::FullyAliased { reason } => {
cross_block_warnings.push(CrossBlockIdentifiabilityWarning {
candidate_label: "link_deviation",
anchor_summary: "marginal+logslope+score_warp".to_string(),
reason,
});
Some(prepared)
}
}
} else {
None
};
let extra_rho0 = {
let mut out = Vec::new();
if let Some(ref prepared) = score_warp_prepared {
out.extend(std::iter::repeat_n(0.0, prepared.block.penalties.len()));
}
if let Some(ref prepared) = link_dev_prepared {
out.extend(std::iter::repeat_n(0.0, prepared.block.penalties.len()));
}
out
};
let logslope_reduced_reparam: Option<ReducedLogslopeReparam> = build_reduced_logslope_reparam(
&marginal_design,
&logslope_design,
z.as_ref(),
&cross_block_pilot_w_score_warp,
&spec.marginal_offset,
&spec.logslope_offset,
baseline.0,
baseline.1,
probit_scale,
)?;
// Maps a logslope design into the reduced parameterisation when one was built;
// with no reduced reparameterisation available the design is returned as a clone.
let reduce_logslope_design =
    |design: &TermCollectionDesign| -> Result<TermCollectionDesign, String> {
        let Some(reparam) = logslope_reduced_reparam.as_ref() else {
            return Ok(design.clone());
        };
        reparameterize_logslope_design_reduced(design, reparam)
    };
let marginal_penalty_count = marginal_design.penalties.len();
let setup = joint_setup(
data_view,
&marginalspec_boot,
&logslopespec_boot,
marginal_penalty_count,
logslope_design.penalties.len(),
&extra_rho0,
&effective_kappa_options,
);
let setup = if sigma_learnable {
setup.with_auxiliary(
Array1::from_vec(vec![initial_sigma.expect("learnable sigma seed").ln()]),
Array1::from_vec(vec![0.01_f64.ln()]),
Array1::from_vec(vec![5.0_f64.ln()]),
)
} else {
setup
};
let final_sigma_cell = std::cell::Cell::new(initial_sigma);
let exact_warm_start = RefCell::new(None::<CustomFamilyWarmStart>);
let runaway_error = RefCell::new(None::<String>);
let pending_beta_seed = RefCell::new(None::<Array1<f64>>);
let hints = RefCell::new(ThetaHints::default());
let score_warp_runtime = score_warp_prepared.as_ref().map(|p| p.runtime.clone());
let link_dev_runtime = link_dev_prepared.as_ref().map(|p| p.runtime.clone());
// Assembles the ordered list of parameter block specs for one set of designs.
//
// `rho` is read front to back through `cursor`: first one entry per marginal
// penalty, then one entry per penalty of the *reduced* logslope design. `rho` and
// the advanced cursor are then handed to `push_deviation_aux_blockspecs`, which
// appends the optional score-warp / link-deviation blocks.
//
// Returns the blocks in the order [marginal, logslope, aux...]; any error from the
// helper builders is propagated as a `String`.
let build_blocks = |rho: &Array1<f64>,
marginal_design: &TermCollectionDesign,
logslope_design: &TermCollectionDesign|
-> Result<Vec<ParameterBlockSpec>, String> {
// Shared borrow of the hint cell; shadows the outer `RefCell` for this call.
let hints = hints.borrow();
let mut cursor = 0usize;
// From here on `logslope_design` refers to the reduced design, so the penalty
// count used for slicing `rho` is the reduced one.
let logslope_design_reduced = reduce_logslope_design(logslope_design)?;
let logslope_design = &logslope_design_reduced;
let rho_marginal = rho
.slice(s![cursor..cursor + marginal_design.penalties.len()])
.to_owned();
cursor += marginal_design.penalties.len();
let rho_logslope = rho
.slice(s![cursor..cursor + logslope_design.penalties.len()])
.to_owned();
cursor += logslope_design.penalties.len();
// Marginal coefficient width: design columns plus any residualized influence columns.
let p_m = marginal_design.design.ncols()
+ influence_columns.as_ref().map(|z| z.ncols()).unwrap_or(0);
let mut blocks = vec![
build_marginal_blockspec_bms(
marginal_design,
baseline.0,
&spec.marginal_offset,
rho_marginal,
hints.marginal_beta.clone(),
logslope_design,
&spec.logslope_offset,
baseline.1,
p_m,
influence_columns.as_ref(),
INFLUENCE_ABSORBER_FIXED_LOG_LAMBDA,
)?,
build_logslope_blockspec_bms(
logslope_design,
baseline.1,
&spec.logslope_offset,
rho_logslope,
hints.logslope_beta.clone(),
marginal_design,
&spec.marginal_offset,
baseline.0,
Arc::clone(&z),
p_m,
influence_columns.as_ref(),
)?,
];
// The helper receives `&mut cursor`, so it continues reading `rho` where the
// logslope slice ended.
push_deviation_aux_blockspecs(
&mut blocks,
rho,
&mut cursor,
score_warp_prepared.as_ref(),
link_dev_prepared.as_ref(),
hints.score_warp_beta.clone(),
hints.link_dev_beta.clone(),
)?;
Ok(blocks)
};
let intercept_warm_starts = new_intercept_warm_start_cache(y.len());
let cell_moment_lru = new_cell_moment_lru_cache(policy);
let cell_moment_cache_stats = new_cell_moment_cache_stats();
// Builds a `BernoulliMarginalSlopeFamily` for the given designs and frailty `sigma`.
//
// The closure returns the family directly rather than a `Result`, so the fallible
// steps below (`try_to_dense_arc`, `widen_marginal_dense_with_influence`,
// `reduce_logslope_design`) panic via `expect` on failure.
let make_family = |marginal_design: &TermCollectionDesign,
logslope_design: &TermCollectionDesign,
sigma: Option<f64>|
-> BernoulliMarginalSlopeFamily {
// With influence columns present the marginal design is densified and widened
// with them; otherwise the design is cloned unchanged.
let kernel_marginal_design = match influence_columns.as_ref() {
Some(z_infl) => {
let raw = marginal_design
.design
.try_to_dense_arc("make_family::widened-marginal")
.expect("dense marginal design for influence widening");
let widened = widen_marginal_dense_with_influence(&raw, Some(z_infl))
.expect("widen marginal design with influence columns");
DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from((*widened).clone()))
}
None => marginal_design.design.clone(),
};
// The family is given the reduced logslope design, matching what `build_blocks` uses.
let kernel_logslope_design = reduce_logslope_design(logslope_design)
.expect("reduce logslope design for family construction")
.design;
BernoulliMarginalSlopeFamily {
// Data vectors and caches are shared across calls through `Arc::clone`.
y: Arc::clone(&y),
weights: Arc::clone(&weights),
z: Arc::clone(&z),
latent_measure: latent_measure.clone(),
gaussian_frailty_sd: sigma,
base_link: spec.base_link.clone(),
marginal_design: kernel_marginal_design,
logslope_design: kernel_logslope_design,
score_warp: score_warp_runtime.clone(),
link_dev: link_dev_runtime.clone(),
policy: policy.clone(),
cell_moment_lru: Arc::clone(&cell_moment_lru),
cell_moment_cache_stats: Arc::clone(&cell_moment_cache_stats),
intercept_warm_starts: Some(Arc::clone(&intercept_warm_starts)),
// Unlike the caches above, these two are created fresh for every family instance.
auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
auto_subsample_last_rho: Arc::new(Mutex::new(None)),
}
};
let marginal_terms = spatial_length_scale_term_indices(&marginalspec_boot);
let logslope_terms = spatial_length_scale_term_indices(&logslopespec_boot);
let marginal_has_spatial = !marginal_terms.is_empty();
let logslope_has_spatial = !logslope_terms.is_empty();
let analytic_joint_derivatives_available =
marginal_has_spatial || logslope_has_spatial || setup.log_kappa_dim() == 0;
if setup.log_kappa_dim() > 0 && !analytic_joint_derivatives_available {
return Err("exact bernoulli marginal-slope spatial optimization requires analytic joint psi derivatives"
.to_string());
}
let initial_rho = setup.theta0().slice(s![..setup.rho_dim()]).to_owned();
let initial_blocks = build_blocks(&initial_rho, &marginal_design, &logslope_design)?;
let initial_family = make_family(&marginal_design, &logslope_design, initial_sigma);
let (joint_gradient, joint_hessian) =
custom_family_outer_derivatives(&initial_family, &initial_blocks, options);
let analytic_joint_gradient_available = analytic_joint_derivatives_available
&& matches!(
joint_gradient,
crate::solver::outer_strategy::Derivative::Analytic
);
let analytic_joint_hessian_available =
analytic_joint_derivatives_available && joint_hessian.is_analytic();
let kappa_options_ref: &SpatialLengthScaleOptimizationOptions = &effective_kappa_options;
// Resolves the frailty sigma for a given outer parameter vector: when sigma is
// learnable it is the exponential of the entry that follows the rho and log-kappa
// coordinates; otherwise the initial value is returned unchanged.
let sigma_from_theta = |theta: &Array1<f64>| -> Option<f64> {
    if !sigma_learnable {
        return initial_sigma;
    }
    let log_sigma_index = setup.rho_dim() + setup.log_kappa_dim();
    Some(theta[log_sigma_index].exp())
};
let derivative_block_cache = RefCell::new(
None::<(
Array1<f64>,
Arc<Vec<Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>>>,
)>,
);
// Element-wise comparison of two parameter vectors: equal lengths and every pair
// within 1e-12, scaled by one plus the larger magnitude of the pair.
let theta_matches = |left: &Array1<f64>, right: &Array1<f64>| -> bool {
    if left.len() != right.len() {
        return false;
    }
    left.iter().zip(right.iter()).all(|(&a, &b)| {
        let scale = 1.0 + a.abs().max(b.abs());
        (a - b).abs() <= 1e-12 * scale
    })
};
// Returns the per-block psi-derivative lists for `theta`, memoised on the last
// `theta` seen.
//
// The result has one inner `Vec` per block, in the order [marginal, logslope,
// score-warp (if present), link-deviation (if present)]. Errors from building the
// spatial derivatives are propagated as `String`.
let get_derivative_blocks = |theta: &Array1<f64>,
specs: &[TermCollectionSpec],
designs: &[TermCollectionDesign]|
-> Result<
Arc<Vec<Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>>>,
String,
> {
// Cache hit: the stored theta matches under `theta_matches`, so the shared
// `Arc` is handed back without rebuilding. The key is `theta` only; `specs`
// and `designs` are not compared.
if let Some((cached_theta, cached_blocks)) = derivative_block_cache.borrow().as_ref()
&& theta_matches(cached_theta, theta)
{
return Ok(Arc::clone(cached_blocks));
}
// Immediately-invoked closure: it is called with `(specs, designs)` right after
// its body, so `built` is the resulting vector rather than the closure itself.
let built = |specs: &[TermCollectionSpec],
designs: &[TermCollectionDesign]|
-> Result<
Vec<Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>>,
String,
> {
// A block with spatial terms must yield derivatives; a `None` from the
// builder is turned into an error rather than an empty list.
let marginal_psi_derivs = if marginal_has_spatial {
build_block_spatial_psi_derivatives(data_view, &specs[0], &designs[0])?.ok_or_else(
|| {
"bernoulli marginal-slope: marginal block has spatial terms \
but spatial psi derivatives are unavailable"
.to_string()
},
)?
} else {
Vec::new()
};
let logslope_psi_derivs = if logslope_has_spatial {
build_block_spatial_psi_derivatives(data_view, &specs[1], &designs[1])?.ok_or_else(
|| {
"bernoulli marginal-slope: logslope block has spatial terms \
but spatial psi derivatives are unavailable"
.to_string()
},
)?
} else {
Vec::new()
};
let mut derivative_blocks = vec![marginal_psi_derivs, logslope_psi_derivs];
// The optional deviation blocks contribute an empty derivative list each.
if score_warp_runtime.is_some() {
derivative_blocks.push(Vec::new());
}
if link_dev_runtime.is_some() {
derivative_blocks.push(Vec::new());
}
// With a learnable sigma, one extra derivative entry built from empty (0x0)
// arrays is appended to whichever block is last in the list.
// NOTE(review): presumably a placeholder slot for the log-sigma coordinate — confirm
// against the consumer of these derivative blocks.
if sigma_learnable {
derivative_blocks
.last_mut()
.expect("bernoulli derivative block list is non-empty")
.push(crate::custom_family::CustomFamilyBlockPsiDerivative::new(
None,
Array2::zeros((0, 0)),
Array2::zeros((0, 0)),
None,
None,
None,
None,
));
}
Ok(derivative_blocks)
}(specs, designs)?;
// Replace the single-entry cache with the freshly built blocks.
let built = Arc::new(built);
derivative_block_cache.replace(Some((theta.clone(), Arc::clone(&built))));
Ok(built)
};
let outer_policy = {
let psi_dim = setup.theta0().len() - setup.rho_dim();
initial_family.outer_derivative_policy(&initial_blocks, psi_dim, options)
};
let exact_spatial_outer_tol = kappa_options_ref.rel_tol.max(1e-6);
let solved = optimize_spatial_length_scale_exact_joint(
data_view,
&[marginalspec_boot.clone(), logslopespec_boot.clone()],
&[marginal_terms.clone(), logslope_terms.clone()],
kappa_options_ref,
&setup,
crate::seeding::SeedRiskProfile::GeneralizedLinear,
analytic_joint_gradient_available,
analytic_joint_hessian_available,
true,
None,
outer_policy,
|theta, specs: &[TermCollectionSpec], designs: &[TermCollectionDesign]| {
if let Some(err) = runaway_error.borrow().as_ref().cloned() {
return Err(err);
}
assert_eq!(
specs.len(),
designs.len(),
"spatial joint optimizer must supply one spec per design",
);
let rho = theta.slice(s![..setup.rho_dim()]).to_owned();
let blocks = build_blocks(&rho, &designs[0], &designs[1])?;
let sigma = sigma_from_theta(theta);
final_sigma_cell.set(sigma);
let family = make_family(&designs[0], &designs[1], sigma);
let fit = inner_fit(&family, &blocks, options)?;
if let Some(block) = fit.block_states.first()
&& let Some(err) = bernoulli_marginal_slope_runaway_error_from_argmax(
marginal_parametric_argmax_from_beta(&block.beta, &designs[0], &specs[0]),
marginal_full_argmax_from_beta(&block.beta, &designs[0]),
"the final inner solve produced a separation-scale coefficient",
"final fit",
)
{
runaway_error.replace(Some(err.clone()));
return Err(err);
}
let mut hints_mut = hints.borrow_mut();
let mut bidx = 0usize;
if let Some(block) = fit.block_states.get(bidx) {
hints_mut.marginal_beta = Some(block.beta.clone());
}
bidx += 1;
if let Some(block) = fit.block_states.get(bidx) {
hints_mut.logslope_beta = Some(block.beta.clone());
}
bidx += 1;
if score_warp_prepared.is_some() {
if let Some(block) = fit.block_states.get(bidx) {
hints_mut.score_warp_beta = Some(block.beta.clone());
}
bidx += 1;
}
if link_dev_prepared.is_some()
&& let Some(block) = fit.block_states.get(bidx)
{
hints_mut.link_dev_beta = Some(block.beta.clone());
}
Ok(fit)
},
|theta,
specs: &[TermCollectionSpec],
designs: &[TermCollectionDesign],
eval_mode,
row_set: &crate::families::row_kernel::RowSet| {
if let Some(err) = runaway_error.borrow().as_ref().cloned() {
return Err(err);
}
use crate::solver::estimate::reml::unified::EvalMode;
let row_set_rows = match row_set {
crate::families::row_kernel::RowSet::All => spec.y.len(),
crate::families::row_kernel::RowSet::Subsample { rows, .. } => rows.len(),
};
log::debug!(
"[BMS exact outer eval] mode={:?} row_set_rows={}",
eval_mode,
row_set_rows
);
let rho = theta.slice(s![..setup.rho_dim()]).to_owned();
let blocks = build_blocks(&rho, &designs[0], &designs[1])?;
if let Some(beta_seed) = pending_beta_seed.borrow_mut().take() {
let widths: Vec<usize> = blocks.iter().map(|b| b.design.ncols()).collect();
match CustomFamilyWarmStart::from_cached_beta(&widths, &beta_seed) {
Ok(ws) => {
exact_warm_start.replace(Some(ws));
}
Err(e) => {
log::warn!(
"[BMS] outer ρ-cache β-warm-start rejected: {e}; falling back to cold β"
);
}
}
}
let sigma = sigma_from_theta(theta);
final_sigma_cell.set(sigma);
let family = make_family(&designs[0], &designs[1], sigma);
let derivative_blocks = get_derivative_blocks(theta, specs, designs)?;
let effective_mode = match eval_mode {
EvalMode::ValueGradientHessian if !analytic_joint_hessian_available => {
EvalMode::ValueAndGradient
}
other => other,
};
let eval = evaluate_custom_family_joint_hyper_shared(
&family,
&blocks,
&joint_hyper_options_for_outer_tolerance(options, exact_spatial_outer_tol),
&rho,
derivative_blocks,
exact_warm_start.borrow().as_ref(),
effective_mode,
)?;
if let Some(err) = bernoulli_marginal_slope_runaway_error(
&eval.warm_start,
&designs[0],
&specs[0],
eval.inner_converged,
"exact outer evaluation",
) {
runaway_error.replace(Some(err.clone()));
return Err(err);
}
exact_warm_start.replace(Some(eval.warm_start.clone()));
if !eval.inner_converged {
return Err(
"exact bernoulli marginal-slope inner solve did not converge".to_string(),
);
}
if matches!(eval_mode, EvalMode::ValueGradientHessian)
&& analytic_joint_hessian_available
&& !eval.outer_hessian.is_analytic()
{
return Err("exact bernoulli marginal-slope joint [rho, psi] objective did not return an outer Hessian"
.to_string());
}
Ok((eval.objective, eval.gradient, eval.outer_hessian))
},
|theta, specs: &[TermCollectionSpec], designs: &[TermCollectionDesign]| {
if let Some(err) = runaway_error.borrow().as_ref().cloned() {
return Err(err);
}
let rho = theta.slice(s![..setup.rho_dim()]).to_owned();
let blocks = build_blocks(&rho, &designs[0], &designs[1])?;
if let Some(beta_seed) = pending_beta_seed.borrow_mut().take() {
let widths: Vec<usize> = blocks.iter().map(|b| b.design.ncols()).collect();
match CustomFamilyWarmStart::from_cached_beta(&widths, &beta_seed) {
Ok(ws) => {
exact_warm_start.replace(Some(ws));
}
Err(e) => {
log::warn!(
"[BMS] outer ρ-cache β-warm-start rejected (efs): {e}; falling back to cold β"
);
}
}
}
let sigma = sigma_from_theta(theta);
final_sigma_cell.set(sigma);
let family = make_family(&designs[0], &designs[1], sigma);
let derivative_blocks = get_derivative_blocks(theta, specs, designs)?;
let eval = evaluate_custom_family_joint_hyper_efs_shared(
&family,
&blocks,
&joint_hyper_options_for_outer_tolerance(options, exact_spatial_outer_tol),
&rho,
derivative_blocks,
exact_warm_start.borrow().as_ref(),
)?;
if let Some(err) = bernoulli_marginal_slope_runaway_error(
&eval.warm_start,
&designs[0],
&specs[0],
eval.inner_converged,
"EFS outer evaluation",
) {
runaway_error.replace(Some(err.clone()));
return Err(err);
}
exact_warm_start.replace(Some(eval.warm_start.clone()));
if !eval.inner_converged {
return Err(
"exact bernoulli marginal-slope EFS inner solve did not converge".to_string(),
);
}
Ok(eval.efs_eval)
},
crate::families::marginal_slope_shared::make_beta_seed_validator(&pending_beta_seed),
)?;
let mut resolved_specs = solved.resolved_specs;
let mut designs = solved.designs;
let mut solved_fit = solved.fit;
if let Some(reparam) = logslope_reduced_reparam.as_ref() {
let r = reparam.reduced_cols();
if let Some(block) = solved_fit.blocks.get_mut(1)
&& block.beta.len() == r
{
block.beta = reparam.recover_original_logslope_beta(&block.beta)?;
}
if let Some(state) = solved_fit.block_states.get_mut(1)
&& state.beta.len() == r
{
state.beta = reparam.recover_original_logslope_beta(&state.beta)?;
}
}
let latent_z_rank_int_calibration = match latent_z_calibration {
LatentMeasureCalibration::None => None,
LatentMeasureCalibration::RankInverseNormal(cal) => Some(cal),
};
Ok(BernoulliMarginalSlopeFitResult {
fit: solved_fit,
marginalspec_resolved: resolved_specs.remove(0),
logslopespec_resolved: resolved_specs.remove(0),
marginal_design: designs.remove(0),
logslope_design: designs.remove(0),
baseline_marginal: baseline.0,
baseline_logslope: baseline.1,
z_normalization,
latent_measure,
score_warp_runtime,
link_dev_runtime,
gaussian_frailty_sd: final_sigma_cell.get(),
cross_block_warnings,
latent_z_rank_int_calibration,
})
}