use super::family::{
append_deviation_function_penalty, require_probit_marginal_slope_link,
resolve_deviation_operator_orders,
};
use super::*;
/// Intermediate bundle assembled by `build_bms_flex_block_context` and consumed
/// (by destructuring) in `install_compiled_flex_block_into_runtime`.
///
/// It groups the dense anchor designs, the candidate flex design and the
/// operator/ordering/row-Hessian inputs that are handed to the identifiability
/// audit and compiler.
struct BmsFlexBlockContext {
/// Dense copies of every anchor with at least one column, parametric anchors
/// first and flex anchors after them, in the order they were supplied.
pub(super) anchor_dense_blocks: Vec<Array2<f64>>,
/// One tag per entry of `anchor_dense_blocks` (same order), recording whether
/// the block is parametric or a flex evaluation and how many columns it has.
pub(super) anchor_components: Vec<super::deviation_runtime::AnchorComponentTag>,
/// `n × d_total` matrix holding the anchor blocks side by side, column-wise,
/// in the same order as `anchor_dense_blocks`.
pub(super) n_train: Array2<f64>,
/// One dense-design operator per anchor block followed by the operator for
/// the candidate design (always last).
pub(super) operators:
Vec<std::sync::Arc<dyn crate::families::identifiability_compiler::RowJacobianOperator>>,
/// Block order parallel to `operators`: `Marginal` for each anchor block and
/// `LinkDev` for the trailing candidate block.
pub(super) ordering: Vec<crate::families::identifiability_compiler::BlockOrder>,
/// Row Hessian built from the training row weights.
pub(super) row_hess:
crate::families::bernoulli_marginal_slope_identifiability::BernoulliRowHessian,
/// Candidate design as returned by the candidate runtime at the training rows.
pub(super) candidate_design_dense: Array2<f64>,
/// Number of rows of the candidate design (training rows).
pub(super) n: usize,
/// Number of columns of the candidate design.
pub(super) p_candidate: usize,
/// Total number of anchor columns across all retained anchor blocks.
pub(super) d_total: usize,
}
/// Gathers everything the cross-block identifiability audit/compiler needs for
/// one candidate flex block.
///
/// The candidate design is evaluated at `candidate_arg_at_training_rows`; each
/// anchor with at least one column is copied into a dense block (parametric
/// anchors first, then flex anchors) and all of them are laid side by side in
/// `n_train`. One dense operator is created per anchor block (ordered as
/// `Marginal`) plus a trailing one for the candidate design (ordered as
/// `LinkDev`).
///
/// Returns `Ok(None)` when the anchors contribute no columns at all.
///
/// # Errors
///
/// Fails when the candidate design cannot be evaluated, when
/// `training_row_weights` has the wrong length or contains a non-finite or
/// negative entry, when an anchor's row count differs from the candidate's, or
/// when a parametric anchor cannot be converted to a dense matrix.
fn build_bms_flex_block_context(
    candidate: &DeviationPrepared,
    candidate_arg_at_training_rows: &Array1<f64>,
    parametric_anchors: &[(
        &DesignMatrix,
        super::deviation_runtime::ParametricAnchorBlock,
    )],
    flex_anchors: &[&Array2<f64>],
    training_row_weights: &Array1<f64>,
) -> Result<Option<BmsFlexBlockContext>, String> {
    use super::deviation_runtime::AnchorComponentTag;
    use crate::families::bernoulli_marginal_slope_identifiability::{
        BernoulliDenseDesignOperator, BernoulliRowHessian,
    };
    use crate::families::identifiability_compiler::{BlockOrder, RowJacobianOperator};

    let candidate_design = candidate.runtime.design(candidate_arg_at_training_rows)?;
    let (n, p_candidate) = candidate_design.dim();

    // Row weights must line up with the candidate rows and be usable as weights.
    if training_row_weights.len() != n {
        return Err(format!(
            "cross-block identifiability: training_row_weights length {} does not match candidate row count {}",
            training_row_weights.len(),
            n,
        ));
    }
    let first_bad_weight = training_row_weights
        .iter()
        .copied()
        .enumerate()
        .find(|&(_, w)| !w.is_finite() || w < 0.0);
    if let Some((i, w)) = first_bad_weight {
        return Err(format!(
            "cross-block identifiability: training_row_weights[{i}] = {w} is not finite/non-negative",
        ));
    }

    let mut anchor_dense_blocks: Vec<Array2<f64>> = Vec::new();
    let mut anchor_components: Vec<AnchorComponentTag> = Vec::new();
    let mut d_total = 0usize;

    // Parametric anchors come first; empty ones are skipped entirely.
    for (design, tag) in parametric_anchors {
        if design.nrows() != n {
            return Err(format!(
                "cross-block identifiability: parametric anchor has {} rows, candidate has {}",
                design.nrows(),
                n,
            ));
        }
        let ncols = design.ncols();
        if ncols == 0 {
            continue;
        }
        let dense = design
            .try_to_dense_arc("cross-block parametric anchor")?
            .as_ref()
            .clone();
        anchor_dense_blocks.push(dense);
        anchor_components.push(AnchorComponentTag::Parametric {
            block: *tag,
            ncols,
        });
        d_total += ncols;
    }

    // Flex anchors follow, again dropping the ones without columns.
    for &anchor in flex_anchors {
        if anchor.nrows() != n {
            return Err(format!(
                "cross-block identifiability: flex anchor has {} rows, candidate has {}",
                anchor.nrows(),
                n,
            ));
        }
        let ncols = anchor.ncols();
        if ncols == 0 {
            continue;
        }
        anchor_dense_blocks.push(anchor.clone());
        anchor_components.push(AnchorComponentTag::FlexEvaluation { ncols });
        d_total += ncols;
    }

    // Nothing to residualise against.
    if d_total == 0 {
        return Ok(None);
    }

    // Lay the anchor blocks side by side, left to right.
    let mut n_train = Array2::<f64>::zeros((n, d_total));
    let mut next_col = 0usize;
    for block in anchor_dense_blocks.iter() {
        let end_col = next_col + block.ncols();
        n_train.slice_mut(s![.., next_col..end_col]).assign(block);
        next_col = end_col;
    }

    // Anchor operators first (Marginal), candidate operator last (LinkDev).
    let block_count = anchor_dense_blocks.len() + 1;
    let mut operators: Vec<std::sync::Arc<dyn RowJacobianOperator>> =
        Vec::with_capacity(block_count);
    let mut ordering: Vec<BlockOrder> = Vec::with_capacity(block_count);
    operators.extend(anchor_dense_blocks.iter().map(
        |dense| -> std::sync::Arc<dyn RowJacobianOperator> {
            std::sync::Arc::new(BernoulliDenseDesignOperator::new(dense.clone()))
        },
    ));
    ordering.extend((0..anchor_dense_blocks.len()).map(|_| BlockOrder::Marginal));
    operators.push(std::sync::Arc::new(BernoulliDenseDesignOperator::new(
        candidate_design.clone(),
    )));
    ordering.push(BlockOrder::LinkDev);

    let row_hess = BernoulliRowHessian::from_row_weights(training_row_weights.clone());

    let context = BmsFlexBlockContext {
        anchor_dense_blocks,
        anchor_components,
        n_train,
        operators,
        ordering,
        row_hess,
        candidate_design_dense: candidate_design,
        n,
        p_candidate,
        d_total,
    };
    Ok(Some(context))
}
/// Result of `install_compiled_flex_block_into_runtime`.
#[derive(Debug)]
pub enum FlexCompileOutcome {
/// The candidate block is usable. Returned after the compiled
/// reparameterisation has been installed, and also when there was nothing to
/// do (the candidate design has no columns, or the anchors have no columns).
Reparameterised,
/// No candidate direction is left once the anchors are accounted for;
/// `reason` carries the human-readable explanation.
FullyAliased { reason: String },
}
/// Warning record about a cross-block identifiability problem.
///
/// NOTE(review): not constructed anywhere in this part of the file; the field
/// descriptions below are inferred from the names — confirm against the
/// producer.
#[derive(Clone, Debug)]
pub struct CrossBlockIdentifiabilityWarning {
/// Static label of the candidate block the warning refers to (presumably).
pub candidate_label: &'static str,
/// Free-form description of the anchors involved (presumably).
pub anchor_summary: String,
/// Free-form explanation of the problem.
pub reason: String,
}
pub(crate) fn install_compiled_flex_block_into_runtime(
candidate: &mut DeviationPrepared,
candidate_arg_at_training_rows: &Array1<f64>,
candidate_cfg: &DeviationBlockConfig,
parametric_anchors: &[(
&DesignMatrix,
super::deviation_runtime::ParametricAnchorBlock,
)],
flex_anchors: &[&Array2<f64>],
training_row_weights: &Array1<f64>,
) -> Result<FlexCompileOutcome, String> {
use crate::families::identifiability_compiler::compile;
use crate::solver::identifiability_audit::audit_identifiability_channel_aware;
let p_check = candidate
.runtime
.design(candidate_arg_at_training_rows)?
.ncols();
if p_check == 0 {
return Ok(FlexCompileOutcome::Reparameterised);
}
let ctx = match build_bms_flex_block_context(
candidate,
candidate_arg_at_training_rows,
parametric_anchors,
flex_anchors,
training_row_weights,
)? {
None => {
return Ok(FlexCompileOutcome::Reparameterised);
}
Some(c) => c,
};
let BmsFlexBlockContext {
anchor_dense_blocks,
anchor_components,
n_train,
operators,
ordering,
row_hess,
candidate_design_dense,
n,
p_candidate,
d_total,
} = ctx;
let audit = audit_identifiability_channel_aware(
&{
let mut specs = Vec::with_capacity(anchor_dense_blocks.len() + 1);
for (idx, dense) in anchor_dense_blocks.iter().enumerate() {
specs.push(crate::custom_family::ParameterBlockSpec {
name: format!("anchor_{idx}"),
design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
dense.clone(),
)),
offset: Array1::<f64>::zeros(n),
penalties: Vec::new(),
nullspace_dims: Vec::new(),
initial_log_lambdas: Array1::<f64>::zeros(0),
initial_beta: None,
gauge_priority: 200,
jacobian_callback: None,
stacked_design: None,
stacked_offset: None,
});
}
specs.push(crate::custom_family::ParameterBlockSpec {
name: "candidate_flex".to_string(),
design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
candidate_design_dense.clone(),
)),
offset: Array1::<f64>::zeros(n),
penalties: Vec::new(),
nullspace_dims: Vec::new(),
initial_log_lambdas: Array1::<f64>::zeros(0),
initial_beta: None,
gauge_priority: 100,
jacobian_callback: None,
stacked_design: None,
stacked_offset: None,
});
specs
},
&operators,
&row_hess,
)
.map_err(|e| format!("cross-block identifiability audit failed: {e}"))?;
if audit.fatal {
let candidate_block = audit.blocks.last();
let effective = candidate_block.map(|b| b.effective_dim).unwrap_or(0);
if effective == 0 {
let reason = format!(
"candidate flex basis ({p_candidate} cols) has zero directions remaining after \
W-metric residualisation against the anchor union ({d_total} anchor cols) at the \
{n} training rows. The channel-aware audit collapses every direction in \
span(C) — every direction in span(C) is reproducible by the anchor union up to \
numerical tolerance. Drop the flex block or remove the anchor term that reproduces \
its argument; knot count is NOT the relevant lever for this failure mode.",
);
return Ok(FlexCompileOutcome::FullyAliased { reason });
}
}
let compiled = compile(&operators, &row_hess, &ordering).map_err(|e| {
format!(
"cross-block identifiability: compile failed (n={n}, d_total={d_total}, p_c={p_candidate}): {e}",
)
})?;
let candidate_compiled = compiled
.blocks
.last()
.ok_or_else(|| "cross-block identifiability: compile returned no blocks".to_string())?;
let k_kept = candidate_compiled.t_lw.ncols();
if k_kept == 0 {
let reason = format!(
"candidate flex basis ({p_candidate} cols) has zero directions remaining after \
W-metric residualisation against the anchor union ({d_total} anchor cols) at the \
{n} training rows. The compiler's joint pre-fit audit collapses every direction in \
span(C) — every direction in span(C) is reproducible by the anchor union up to \
numerical tolerance. Drop the flex block or remove the anchor term that reproduces \
its argument; knot count is NOT the relevant lever for this failure mode.",
);
return Ok(FlexCompileOutcome::FullyAliased { reason });
}
{
let m = candidate_compiled
.anchor_correction
.as_ref()
.ok_or_else(|| {
"cross-block identifiability: compile returned no anchor_correction for the \
candidate block (expected for trailing block with non-empty anchor union)"
.to_string()
})?;
if m.nrows() != d_total || m.ncols() != k_kept {
return Err(format!(
"cross-block identifiability: anchor_correction shape {}×{} does not match \
expected d_total={d_total} × k_kept={k_kept}",
m.nrows(),
m.ncols(),
));
}
}
candidate.runtime.install_compiled_flex_block(
candidate_compiled,
anchor_components,
n_train,
)?;
let new_design = candidate
.runtime
.design_at_training_with_residual(candidate_arg_at_training_rows)?;
let new_p = new_design.ncols();
assert_eq!(new_p, k_kept);
assert_eq!(new_design.nrows(), n);
candidate.block.design =
DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(new_design));
candidate.block.penalties.clear();
candidate.block.nullspace_dims.clear();
let penalty_orders = resolve_deviation_operator_orders(candidate_cfg)?;
for order in penalty_orders {
append_deviation_function_penalty(&mut candidate.block, &candidate.runtime, order)?;
}
if candidate_cfg.double_penalty {
append_deviation_function_penalty(&mut candidate.block, &candidate.runtime, 0)?;
}
candidate.block.initial_beta = Some(Array1::zeros(new_p));
log::info!(
"[BMS cross-block identifiability] flex block reparameterised via compiler: \
kept {kept}/{p_candidate} directions (anchor union cols={d_total}, training rows={n}, \
joint_rank={joint_rank}, dropped_by_audit={dropped})",
kept = new_p,
p_candidate = p_candidate,
d_total = d_total,
n = n,
joint_rank = compiled.joint_rank,
dropped = compiled.dropped.len(),
);
Ok(FlexCompileOutcome::Reparameterised)
}
/// Returns a coefficient vector that satisfies the runtime's monotonicity
/// check, preferring `proposed` and otherwise backing off along the segment
/// from `current` towards `proposed`.
///
/// `current` must itself be feasible. If `proposed` passes the runtime's
/// monotonicity check it is returned unchanged. Otherwise the largest step
/// fraction allowed by the structural linear constraints is taken from
/// `current` towards `proposed`, and the resulting point is validated before
/// being returned.
///
/// # Errors
///
/// Fails on a length mismatch with the runtime's basis dimension, on any
/// non-finite coefficient, when `current` is infeasible, or when the shortened
/// step still fails validation. `label` prefixes every message.
pub(crate) fn project_monotone_feasible_beta(
    runtime: &DeviationRuntime,
    current: &Array1<f64>,
    proposed: &Array1<f64>,
    label: &str,
) -> Result<Array1<f64>, String> {
    let expected = runtime.basis_dim();
    if current.len() != expected {
        return Err(format!(
            "{label} monotone projection current length mismatch: current={}, expected={}",
            current.len(),
            expected
        ));
    }
    if proposed.len() != expected {
        return Err(format!(
            "{label} monotone projection length mismatch: proposed={}, expected={}",
            proposed.len(),
            expected
        ));
    }

    // Report the first non-finite entry, checking `current` before `proposed`.
    if let Some(idx) = current.iter().position(|value| !value.is_finite()) {
        return Err(format!("{label} current coefficient {idx} is non-finite"));
    }
    if let Some(idx) = proposed.iter().position(|value| !value.is_finite()) {
        return Err(format!("{label} coefficient {idx} is non-finite"));
    }

    // The starting point has to be feasible for the segment search to make sense.
    runtime.monotonicity_feasible(current, &format!("{label} current beta"))?;

    // Fast path: the proposal is already feasible.
    let proposed_label = format!("{label} proposed beta");
    match runtime.monotonicity_feasible(proposed, &proposed_label) {
        Ok(_) => return Ok(proposed.clone()),
        Err(_) => {}
    }

    // Otherwise shorten the step so that no structural constraint is crossed.
    let constraints = runtime.structural_monotonicity_constraints();
    let alpha = max_linear_constraint_segment_alpha(current, proposed, &constraints, label)?;
    let scaled_step = (proposed - current).mapv(|delta| delta * alpha);
    let projected = current + &scaled_step;
    let projected_label = format!("{label} projected beta");
    validate_monotone_structural_feasible(runtime, &projected, &projected_label)?;
    Ok(projected)
}
/// Checks that `beta` satisfies the runtime's structural monotonicity
/// constraints `A·beta - b >= 0` (with a `1e-10` tolerance) and then defers to
/// the runtime's own monotonicity check.
///
/// # Errors
///
/// Fails when `beta` has the wrong length, contains a non-finite entry, or
/// violates a constraint row (the row with the smallest slack is reported), or
/// when the runtime's monotonicity check fails. `label` prefixes every message.
pub(crate) fn validate_monotone_structural_feasible(
    runtime: &DeviationRuntime,
    beta: &Array1<f64>,
    label: &str,
) -> Result<(), String> {
    let constraints = runtime.structural_monotonicity_constraints();
    let expected = constraints.a.ncols();
    if beta.len() != expected {
        return Err(format!(
            "{label} structural monotonicity length mismatch: beta={}, expected={}",
            beta.len(),
            expected
        ));
    }

    if beta.iter().any(|value| !value.is_finite()) {
        let first_bad = beta
            .iter()
            .enumerate()
            .find(|(_, value)| !value.is_finite());
        let message = match first_bad {
            Some((idx, value)) => format!("{label} coefficient {idx} is non-finite ({value})"),
            None => format!("{label} coefficient is non-finite"),
        };
        return Err(message);
    }

    // Locate the tightest (smallest-slack) constraint row; ties keep the first.
    let slack = constraints.a.dot(beta) - &constraints.b;
    let (min_row, min_slack) = slack.iter().enumerate().fold(
        (0usize, f64::INFINITY),
        |best, (row, &value)| {
            if value < best.1 {
                (row, value)
            } else {
                best
            }
        },
    );
    if min_slack < -1e-10 {
        return Err(format!(
            "{label} violates structural monotonicity row {min_row}: slack={min_slack:.3e}; \
             deviation monotonicity must be enforced by analytic linear constraints, not post-update projection"
        ));
    }

    runtime.monotonicity_feasible(beta, label)
}
/// Computes the largest step fraction `alpha` in `[0, 1]` such that
/// `current + alpha * (proposed - current)` does not cross any constraint row
/// of `A·x - b >= 0` that the step moves towards.
///
/// For each row, the slack at `current` is divided by the rate at which the
/// step consumes it; rows the step does not move towards impose no limit.
///
/// # Errors
///
/// Fails when the vector lengths disagree with each other or with the number
/// of constraint columns, when `A` and `b` have different row counts, or when
/// `current` already violates a row by more than `1e-10`.
fn max_linear_constraint_segment_alpha(
    current: &Array1<f64>,
    proposed: &Array1<f64>,
    constraints: &LinearInequalityConstraints,
    label: &str,
) -> Result<f64, String> {
    let n_cols = constraints.a.ncols();
    if current.len() != proposed.len() || current.len() != n_cols {
        return Err(format!(
            "{label} linear-constraint segment dimension mismatch: current={}, proposed={}, constraints={}",
            current.len(),
            proposed.len(),
            n_cols
        ));
    }
    let n_rows = constraints.a.nrows();
    if n_rows != constraints.b.len() {
        return Err(format!(
            "{label} linear-constraint segment row mismatch: A rows={}, b len={}",
            n_rows,
            constraints.b.len()
        ));
    }

    let direction = proposed - current;
    let alpha = (0..n_rows).try_fold(1.0_f64, |alpha, row| -> Result<f64, String> {
        let a_row = constraints.a.row(row);
        let slack = a_row.dot(current) - constraints.b[row];
        if slack < -1e-10 {
            return Err(format!(
                "{label} current beta violates structural monotonicity row {row}: slack={slack:.3e}"
            ));
        }
        // Only rows whose slack shrinks along the step can limit it.
        let drift = a_row.dot(&direction);
        Ok(if drift < 0.0 {
            alpha.min((slack / -drift).clamp(0.0, 1.0))
        } else {
            alpha
        })
    })?;
    Ok(alpha.clamp(0.0, 1.0))
}
/// Validates a bernoulli-marginal-slope term specification against `data`.
///
/// Checks, in order: every per-row vector has as many entries as `data` has
/// rows; `y` is binary (within `1e-9` of 0 or 1); weights are finite and
/// non-negative; `z` and both offsets are finite; the base link passes the
/// probit marginal-slope requirement; and the frailty specification is
/// acceptable (no hazard multiplier, and any fixed Gaussian-shift sigma is
/// finite and non-negative).
///
/// # Errors
///
/// Returns the message of the first failed check.
pub(super) fn validate_spec(
    data: ArrayView2<'_, f64>,
    spec: &BernoulliMarginalSlopeTermSpec,
) -> Result<(), String> {
    let n = data.nrows();
    let lengths = [
        spec.y.len(),
        spec.weights.len(),
        spec.z.len(),
        spec.marginal_offset.len(),
        spec.logslope_offset.len(),
    ];
    if lengths.iter().any(|&len| len != n) {
        return Err(format!(
            "bernoulli-marginal-slope row mismatch: data={}, y={}, weights={}, z={}, marginal_offset={}, logslope_offset={}",
            n,
            lengths[0],
            lengths[1],
            lengths[2],
            lengths[3],
            lengths[4]
        ));
    }

    let y_is_binary = spec
        .y
        .iter()
        .all(|&yi| yi.is_finite() && ((yi - 0.0).abs() <= 1e-9 || (yi - 1.0).abs() <= 1e-9));
    if !y_is_binary {
        return Err("bernoulli-marginal-slope requires binary y in {0,1}".to_string());
    }

    let weights_ok = spec.weights.iter().all(|&w| w.is_finite() && w >= 0.0);
    if !weights_ok {
        return Err("bernoulli-marginal-slope requires finite non-negative weights".to_string());
    }

    if !spec.z.iter().all(|zi| zi.is_finite()) {
        return Err("bernoulli-marginal-slope requires finite z values".to_string());
    }
    if !spec.marginal_offset.iter().all(|value| value.is_finite()) {
        return Err("bernoulli-marginal-slope requires finite marginal offsets".to_string());
    }
    if !spec.logslope_offset.iter().all(|value| value.is_finite()) {
        return Err("bernoulli-marginal-slope requires finite logslope offsets".to_string());
    }

    require_probit_marginal_slope_link(&spec.base_link, "bernoulli-marginal-slope")?;
    spec.frailty.validate_for_marginal_slope()?;

    // Family-specific frailty restrictions on top of the generic validation.
    match &spec.frailty {
        FrailtySpec::HazardMultiplier { .. } => Err(
            "bernoulli-marginal-slope does not support FrailtySpec::HazardMultiplier"
                .to_string(),
        ),
        FrailtySpec::GaussianShift {
            sigma_fixed: Some(sigma),
        } if !sigma.is_finite() || *sigma < 0.0 => Err(format!(
            "bernoulli-marginal-slope requires GaussianShift sigma >= 0, got {sigma}"
        )),
        FrailtySpec::None | FrailtySpec::GaussianShift { .. } => Ok(()),
    }
}