use faer::Side;
use gam_linalg::faer_ndarray::FaerEigh;
use gam_problem::diagnostics::KktRefusalDiagnosis;
use ndarray::{Array1, Array2, Array3};
use super::identifiability::SurvivalRowHessian;
/// Number of per-row primary quantities: the row Hessian tensor is checked to
/// be `[n, K_SURVIVAL, K_SURVIVAL]` and the row gradient `[n, K_SURVIVAL]`.
const K_SURVIVAL: usize = 4;
/// Step length passed as `step` by this module's tests; it scales the gate
/// threshold in `build_refusal_report_from_hessian`.
/// NOTE(review): presumably also the value production callers pass — confirm.
pub(crate) const KKT_PHANTOM_TRUST_RADIUS: f64 = 1.0e-3;
/// Relative tolerance: eigenvalues at or below `KKT_RANK_REL_TOL * max(lambda_max, 0)`
/// are counted towards the reported nullity.
pub(crate) const KKT_RANK_REL_TOL: f64 = 1.0e-10;
/// Relative tolerance: eigenvalues at or below `KKT_NEAR_NULL_REL_TOL * max(lambda_max, 0)`
/// are reported as near-null directions.
pub(crate) const KKT_NEAR_NULL_REL_TOL: f64 = 1.0e-8;
/// One eigen-direction of the assembled matrix whose eigenvalue falls at or
/// below the near-null threshold, together with the quantities used to decide
/// whether it is a "phantom" direction.
#[derive(Clone, Debug)]
pub(crate) struct NearNullDirection {
/// Eigenvalue of the assembled matrix along this direction.
pub(crate) eigenvalue: f64,
/// Absolute value of the inner product between the eigenvector and the score.
pub(crate) gradient_residual: f64,
/// `gradient_residual` divided by the curvature after flooring it at the rank tolerance.
pub(crate) newton_step: f64,
/// Floored curvature multiplied by the caller-supplied step.
pub(crate) gate_threshold: f64,
/// `true` when `gradient_residual <= gate_threshold`.
pub(crate) phantom_projectable: bool,
/// The eigenvector itself (owned copy of the corresponding eigenvector column).
pub(crate) eigenvector: Array1<f64>,
}
/// Spectral diagnostic of the assembled matrix and score, as produced by
/// `build_refusal_report_from_hessian`.
#[derive(Clone, Debug)]
pub(crate) struct SurvivalKktRefusalReport {
/// All eigenvalues, sorted from smallest to largest.
pub(crate) eigenvalues_ascending: Vec<f64>,
/// Smallest eigenvalue (0.0 when the matrix is empty).
pub(crate) lambda_min: f64,
/// Largest eigenvalue (0.0 when the matrix is empty).
pub(crate) lambda_max: f64,
/// `lambda_max / lambda_min` when `lambda_min > 0`, otherwise infinity.
pub(crate) condition_number: f64,
/// Number of eigenvalues at or below `rank_tol`.
pub(crate) nullity_at_tol: usize,
/// Absolute rank tolerance: `KKT_RANK_REL_TOL * max(lambda_max, 0)`.
pub(crate) rank_tol: f64,
/// The step value the report was built with.
pub(crate) step: f64,
/// Near-null directions in ascending eigenvalue order.
pub(crate) near_null: Vec<NearNullDirection>,
/// Overall classification; rank-deficient whenever `nullity_at_tol > 0`.
pub(crate) diagnosis: KktRefusalDiagnosis,
}
impl SurvivalKktRefusalReport {
    /// Number of near-null directions flagged as phantom-projectable.
    pub(crate) fn phantom_projectable_count(&self) -> usize {
        let mut count = 0usize;
        for direction in &self.near_null {
            if direction.phantom_projectable {
                count += 1;
            }
        }
        count
    }

    /// Number of near-null directions that are NOT phantom-projectable.
    pub(crate) fn real_nonstationary_count(&self) -> usize {
        // Every direction is either phantom or real, so the complement of the
        // phantom count is the real count.
        self.near_null.len() - self.phantom_projectable_count()
    }

    /// `true` only when at least one near-null direction exists and every one
    /// of them is phantom-projectable.
    pub(crate) fn all_near_null_are_phantom(&self) -> bool {
        if self.near_null.is_empty() {
            return false;
        }
        self.near_null.iter().all(|d| d.phantom_projectable)
    }

    /// One-line, human-readable digest of the report.
    ///
    /// The "carrying" fields describe the first (smallest-eigenvalue) near-null
    /// direction; they are NaN / `false` when there is none.
    pub(crate) fn summary(&self) -> String {
        let p = self.eigenvalues_ascending.len();
        let lead = self.near_null.first();

        let car_lambda = lead.map_or(f64::NAN, |d| d.eigenvalue);
        let car_resid = lead.map_or(f64::NAN, |d| d.gradient_residual);
        let car_newton = lead.map_or(f64::NAN, |d| d.newton_step);
        let car_gate = lead.map_or(f64::NAN, |d| d.gate_threshold);
        // Largest absolute component of the carrying eigenvector.
        let car_vinf = lead.map_or(f64::NAN, |d| {
            d.eigenvector
                .iter()
                .map(|v| v.abs())
                .fold(0.0_f64, |best, a| best.max(a))
        });
        let car_phantom = lead.map_or(false, |d| d.phantom_projectable);

        let near_null_total = self.near_null.len();
        let phantom_total = self.phantom_projectable_count();
        let real_total = self.real_nonstationary_count();

        format!(
            "p={p} lambda_min={:.4e} lambda_max={:.4e} cond={:.4e} \
             nullity@tol={} rank_tol={:.4e} step={:.4e} near_null={} \
             phantom={} real={} diagnosis={} | carrying: lambda={:.4e} \
             resid={:.4e} newton_step={:.4e} gate={:.4e} v_inf={:.4e} phantom={}",
            self.lambda_min,
            self.lambda_max,
            self.condition_number,
            self.nullity_at_tol,
            self.rank_tol,
            self.step,
            near_null_total,
            phantom_total,
            real_total,
            self.diagnosis.as_str(),
            car_lambda,
            car_resid,
            car_newton,
            car_gate,
            car_vinf,
            car_phantom,
        )
    }
}
/// Borrowed design matrices used to build the per-row Jacobian (a
/// `K_SURVIVAL x p_total` matrix) in
/// `assemble_joint_penalized_hessian_and_score`. Columns are laid out as
/// `[time | marginal | logslope]`.
#[derive(Clone, Copy)]
pub(crate) struct SurvivalEffectiveDesigns<'a> {
/// Time-block design feeding Jacobian row 0; its shape defines `n` and `p_time`.
pub(crate) dq0: &'a Array2<f64>,
/// Time-block design feeding Jacobian row 1.
pub(crate) dq1: &'a Array2<f64>,
/// Time-block design feeding Jacobian row 2.
pub(crate) dqd1: &'a Array2<f64>,
/// Marginal-block design feeding Jacobian rows 0 and 1; its width defines `p_marg`.
pub(crate) m_dq: &'a Array2<f64>,
/// Marginal-block design feeding Jacobian row 2.
pub(crate) m_dqd1: &'a Array2<f64>,
/// Log-slope-block design feeding Jacobian row 3; its width defines `p_log`.
pub(crate) g_dg: &'a Array2<f64>,
}
/// Borrowed per-row inputs forwarded, one row at a time, to
/// `super::row_primary_for_compiler` by
/// `survival_row_gradient_from_pilot_primary_state`. All vectors must have the
/// same length as `q0`.
/// NOTE(review): the statistical meaning of each field is defined by
/// `row_primary_for_compiler`, which is not visible here — confirm there.
#[derive(Clone, Copy)]
pub(crate) struct SurvivalPilotRows<'a> {
/// First positional row argument; its length defines the row count.
pub(crate) q0: &'a Array1<f64>,
/// Second positional row argument.
pub(crate) q1: &'a Array1<f64>,
/// Third positional row argument.
pub(crate) qd1: &'a Array1<f64>,
/// Fourth positional row argument.
pub(crate) g: &'a Array1<f64>,
/// Fifth positional row argument.
pub(crate) z: &'a Array1<f64>,
/// Sixth positional row argument (presumably observation weights — verify).
pub(crate) weights: &'a Array1<f64>,
/// Seventh positional row argument (presumably an event indicator — verify).
pub(crate) event: &'a Array1<f64>,
}
/// Scalar parameters forwarded unchanged, for every row, to
/// `super::row_primary_for_compiler`.
/// NOTE(review): their exact meaning is defined by that function — confirm there.
#[derive(Clone, Copy)]
pub(crate) struct SurvivalLinkParams {
/// Passed as the second-to-last argument of `row_primary_for_compiler`.
pub(crate) derivative_guard: f64,
/// Passed as the last argument of `row_primary_for_compiler`.
pub(crate) probit_scale: f64,
}
/// Assembles the dense joint matrix `M = S + Σ_i J_iᵀ H_i J_i` and the score
/// vector `g = Σ_i J_iᵀ r_i`.
///
/// For each row `i`, `J_i` is the `K_SURVIVAL x p_total` Jacobian built from
/// `designs` (columns ordered `[time | marginal | logslope]`), `H_i` is
/// `row_hess[i, .., ..]` and `r_i` is `row_grad[i, ..]`. `S` is `s_total`.
/// The returned matrix is explicitly symmetrized.
///
/// # Errors
///
/// Returns a descriptive `Err` when any design matrix, `row_hess`, `row_grad`
/// or `s_total` does not have the shape implied by `dq0`, `m_dq` and `g_dg`.
/// Row counts of *every* design are validated, so a short `m_dq` or `g_dg`
/// is reported as an error instead of panicking on an out-of-bounds index.
pub(crate) fn assemble_joint_penalized_hessian_and_score(
    designs: SurvivalEffectiveDesigns<'_>,
    row_hess: &Array3<f64>,
    row_grad: &Array2<f64>,
    s_total: &Array2<f64>,
) -> Result<(Array2<f64>, Array1<f64>), String> {
    let SurvivalEffectiveDesigns {
        dq0,
        dq1,
        dqd1,
        m_dq,
        m_dqd1,
        g_dg,
    } = designs;
    let n = dq0.nrows();
    let p_time = dq0.ncols();
    let p_marg = m_dq.ncols();
    let p_log = g_dg.ncols();
    let p_total = p_time + p_marg + p_log;

    // `dq0`, `m_dq` and `g_dg` define the block widths, but the row counts of
    // `m_dq` and `g_dg` still have to agree with `dq0`.
    for (name, rows, cols, want_rows, want_cols) in [
        ("dq1", dq1.nrows(), dq1.ncols(), n, p_time),
        ("dqd1", dqd1.nrows(), dqd1.ncols(), n, p_time),
        ("m_dq", m_dq.nrows(), m_dq.ncols(), n, p_marg),
        ("m_dqd1", m_dqd1.nrows(), m_dqd1.ncols(), n, p_marg),
        ("g_dg", g_dg.nrows(), g_dg.ncols(), n, p_log),
    ] {
        if rows != want_rows || cols != want_cols {
            return Err(format!(
                "kkt_refusal assembly: {name} is {rows}x{cols}, expected {want_rows}x{want_cols}"
            ));
        }
    }
    if row_hess.shape() != [n, K_SURVIVAL, K_SURVIVAL] {
        return Err(format!(
            "kkt_refusal assembly: row_hess is {:?}, expected [{n}, {K_SURVIVAL}, {K_SURVIVAL}]",
            row_hess.shape()
        ));
    }
    if row_grad.shape() != [n, K_SURVIVAL] {
        return Err(format!(
            "kkt_refusal assembly: row_grad is {:?}, expected [{n}, {K_SURVIVAL}]",
            row_grad.shape()
        ));
    }
    if s_total.shape() != [p_total, p_total] {
        return Err(format!(
            "kkt_refusal assembly: s_total is {:?}, expected [{p_total}, {p_total}]",
            s_total.shape()
        ));
    }

    // Start from the penalty and accumulate the data curvature on top of it.
    let mut m = s_total.clone();
    let mut g = Array1::<f64>::zeros(p_total);
    // Scratch buffers reused for every row: the Jacobian J_i and H_i * J_i.
    let mut j_i = Array2::<f64>::zeros((K_SURVIVAL, p_total));
    let mut hj = Array2::<f64>::zeros((K_SURVIVAL, p_total));
    let marg_off = p_time;
    let log_off = p_time + p_marg;

    for i in 0..n {
        j_i.fill(0.0);
        for c in 0..p_time {
            j_i[[0, c]] = dq0[[i, c]];
            j_i[[1, c]] = dq1[[i, c]];
            j_i[[2, c]] = dqd1[[i, c]];
        }
        for c in 0..p_marg {
            let gc = marg_off + c;
            // The marginal design feeds both of the first two primary slots.
            j_i[[0, gc]] = m_dq[[i, c]];
            j_i[[1, gc]] = m_dq[[i, c]];
            j_i[[2, gc]] = m_dqd1[[i, c]];
        }
        for c in 0..p_log {
            j_i[[3, log_off + c]] = g_dg[[i, c]];
        }

        // hj = H_i * J_i; every entry is overwritten, so no reset is needed.
        for a in 0..K_SURVIVAL {
            for col in 0..p_total {
                let mut acc = 0.0;
                for b in 0..K_SURVIVAL {
                    acc += row_hess[[i, a, b]] * j_i[[b, col]];
                }
                hj[[a, col]] = acc;
            }
        }

        // M += J_iᵀ (H_i J_i) and g += J_iᵀ r_i.
        for col in 0..p_total {
            for row in 0..p_total {
                let mut acc = 0.0;
                for a in 0..K_SURVIVAL {
                    acc += j_i[[a, row]] * hj[[a, col]];
                }
                m[[row, col]] += acc;
            }
            let mut gacc = 0.0;
            for a in 0..K_SURVIVAL {
                gacc += j_i[[a, col]] * row_grad[[i, a]];
            }
            g[col] += gacc;
        }
    }

    // Remove round-off asymmetry so the symmetric eigensolver sees a
    // consistent matrix regardless of which triangle it reads.
    for row in 0..p_total {
        for col in (row + 1)..p_total {
            let avg = 0.5 * (m[[row, col]] + m[[col, row]]);
            m[[row, col]] = avg;
            m[[col, row]] = avg;
        }
    }
    Ok((m, g))
}
/// Builds a [`SurvivalKktRefusalReport`] from an assembled symmetric matrix
/// `m`, a score vector `g` and a positive step length `step`.
///
/// The matrix is eigendecomposed; every eigen-direction whose eigenvalue is at
/// or below `KKT_NEAR_NULL_REL_TOL * max(lambda_max, 0)` is reported as a
/// near-null direction. Such a direction is "phantom-projectable" when the
/// score component along it does not exceed `max(eigenvalue, rank_tol) * step`.
///
/// # Errors
///
/// Returns `Err` when `m` is not square, `g` does not match its dimension,
/// `step` is not finite and strictly positive, or the eigendecomposition fails.
pub(crate) fn build_refusal_report_from_hessian(
    m: &Array2<f64>,
    g: &Array1<f64>,
    step: f64,
) -> Result<SurvivalKktRefusalReport, String> {
    let p = m.nrows();
    if m.ncols() != p {
        return Err(format!("kkt_refusal: M is {}x{}, must be square", p, m.ncols()));
    }
    if g.len() != p {
        return Err(format!("kkt_refusal: g len {} != M dim {p}", g.len()));
    }
    if !step.is_finite() || step <= 0.0 {
        return Err(format!("kkt_refusal: step must be finite and positive, got {step}"));
    }
    let (evals, evecs) = m
        .eigh(Side::Lower)
        .map_err(|e| format!("kkt_refusal: eigendecomposition of M failed: {e:?}"))?;

    // `total_cmp` is a genuine total order even if an eigenvalue is NaN. A
    // `partial_cmp(..).unwrap_or(Equal)` comparator is not, and `sort_by` may
    // panic or produce an unspecified order when handed one.
    let mut order: Vec<usize> = (0..p).collect();
    order.sort_by(|&a, &b| evals[a].total_cmp(&evals[b]));

    let eigenvalues_ascending: Vec<f64> = order.iter().map(|&i| evals[i]).collect();
    let lambda_min = eigenvalues_ascending.first().copied().unwrap_or(0.0);
    let lambda_max = eigenvalues_ascending.last().copied().unwrap_or(0.0);
    let condition_number = if lambda_min > 0.0 {
        lambda_max / lambda_min
    } else {
        f64::INFINITY
    };

    let rank_tol = KKT_RANK_REL_TOL * lambda_max.max(0.0);
    let nullity_at_tol = eigenvalues_ascending
        .iter()
        .filter(|&&v| v <= rank_tol)
        .count();

    // Curvature used by the gate never drops below the rank tolerance, so
    // zero or slightly negative eigenvalues do not produce a negative gate.
    let gate_floor_curvature = rank_tol;
    let near_null_threshold = KKT_NEAR_NULL_REL_TOL * lambda_max.max(0.0);

    let mut near_null: Vec<NearNullDirection> = Vec::new();
    for &idx in &order {
        let eigenvalue = evals[idx];
        // Eigenvalues are visited in ascending order, so the first one above
        // the threshold ends the scan.
        if eigenvalue > near_null_threshold {
            break;
        }
        let v = evecs.column(idx);
        let projection = v
            .iter()
            .zip(g.iter())
            .fold(0.0_f64, |acc, (&vk, &gk)| acc + vk * gk);
        let gradient_residual = projection.abs();
        let gate_curvature = eigenvalue.max(gate_floor_curvature);
        let gate_threshold = gate_curvature * step;
        // A zero residual needs no step at all; without this guard a zero
        // curvature floor (e.g. an all-zero matrix) would yield 0/0 = NaN.
        let newton_step = if gradient_residual == 0.0 {
            0.0
        } else {
            gradient_residual / gate_curvature
        };
        let phantom_projectable = gradient_residual <= gate_threshold;
        near_null.push(NearNullDirection {
            eigenvalue,
            gradient_residual,
            newton_step,
            gate_threshold,
            phantom_projectable,
            eigenvector: v.to_owned(),
        });
    }

    let diagnosis = if nullity_at_tol > 0 {
        KktRefusalDiagnosis::RankDeficientHPen
    } else {
        KktRefusalDiagnosis::PhantomMultiplierWithWellConditionedH
    };
    Ok(SurvivalKktRefusalReport {
        eigenvalues_ascending,
        lambda_min,
        lambda_max,
        condition_number,
        nullity_at_tol,
        rank_tol,
        step,
        near_null,
        diagnosis,
    })
}
/// Convenience entry point: assembles the joint matrix and score from the
/// designs and per-row derivatives, then turns them into a refusal report.
///
/// Any error from either stage is returned unchanged.
pub(crate) fn survival_kkt_refusal_report_from_designs(
    designs: SurvivalEffectiveDesigns<'_>,
    row_hess: &Array3<f64>,
    row_grad: &Array2<f64>,
    s_total: &Array2<f64>,
    step: f64,
) -> Result<SurvivalKktRefusalReport, String> {
    assemble_joint_penalized_hessian_and_score(designs, row_hess, row_grad, s_total)
        .and_then(|(hessian, score)| build_refusal_report_from_hessian(&hessian, &score, step))
}
/// Evaluates the per-row gradient for every row of `rows` and stacks the
/// results into an `n x K_SURVIVAL` matrix.
///
/// Each row is evaluated by `super::row_primary_for_compiler`; only the middle
/// element of its returned triple is kept.
///
/// # Errors
///
/// Returns `Err` when any input vector's length differs from `q0`'s, or when
/// the per-row evaluation fails.
pub(crate) fn survival_row_gradient_from_pilot_primary_state(
    rows: SurvivalPilotRows<'_>,
    link: SurvivalLinkParams,
) -> Result<Array2<f64>, String> {
    let n = rows.q0.len();

    // Report the first vector (in declaration order) whose length disagrees.
    let lengths = [
        ("q1", rows.q1.len()),
        ("qd1", rows.qd1.len()),
        ("g", rows.g.len()),
        ("z", rows.z.len()),
        ("weights", rows.weights.len()),
        ("event", rows.event.len()),
    ];
    if let Some((name, len)) = lengths.iter().copied().find(|&(_, l)| l != n) {
        return Err(format!(
            "survival_row_gradient: length mismatch q0={n}, {name}={len}"
        ));
    }

    let mut gradients = Array2::<f64>::zeros((n, K_SURVIVAL));
    for i in 0..n {
        let (_, grad, _) = super::row_primary_for_compiler(
            rows.q0[i],
            rows.q1[i],
            rows.qd1[i],
            rows.g[i],
            rows.z[i],
            rows.weights[i],
            rows.event[i],
            link.derivative_guard,
            link.probit_scale,
        )?;
        for (a, slot) in gradients.row_mut(i).iter_mut().enumerate() {
            *slot = grad[a];
        }
    }
    Ok(gradients)
}
/// Builds a block-diagonal `p_total x p_total` penalty with the time,
/// marginal and log-slope blocks placed consecutively along the diagonal.
///
/// # Errors
///
/// Returns `Err` naming the first block whose shape is not `width x width`
/// for its declared width.
pub(crate) fn assemble_unit_block_penalty(
    p_time: usize,
    p_marg: usize,
    p_log: usize,
    time_block: &Array2<f64>,
    marg_block: &Array2<f64>,
    log_block: &Array2<f64>,
) -> Result<Array2<f64>, String> {
    let marg_off = p_time;
    let log_off = p_time + p_marg;
    let p_total = log_off + p_log;
    let mut s = Array2::<f64>::zeros((p_total, p_total));

    let layout = [
        ("time", time_block, 0usize, p_time),
        ("marginal", marg_block, marg_off, p_marg),
        ("logslope", log_block, log_off, p_log),
    ];
    for (name, blk, off, width) in layout {
        let square_of_width = blk.nrows() == width && blk.ncols() == width;
        if !square_of_width {
            return Err(format!(
                "assemble_unit_block_penalty: {name} block is {}x{}, expected {width}x{width}",
                blk.nrows(),
                blk.ncols(),
            ));
        }
        // Copy the block onto its diagonal slot.
        for ((r, c), &value) in blk.indexed_iter() {
            s[[off + r, off + c]] += value;
        }
    }
    Ok(s)
}
/// Sums a list of dense `width x width` penalties into one matrix.
///
/// An empty list yields the zero matrix.
///
/// # Errors
///
/// Returns `Err` identifying the first penalty that is not `width x width`.
pub(crate) fn dense_block_penalty_from_dense_list(
    pens: &[Array2<f64>],
    width: usize,
) -> Result<Array2<f64>, String> {
    let mut total = Array2::<f64>::zeros((width, width));
    for (k, pen) in pens.iter().enumerate() {
        let shape_ok = pen.nrows() == width && pen.ncols() == width;
        if !shape_ok {
            return Err(format!(
                "dense_block_penalty_from_dense_list: penalty {k} is {}x{}, expected {width}x{width}",
                pen.nrows(),
                pen.ncols(),
            ));
        }
        total += pen;
    }
    Ok(total)
}
/// Expands a list of blockwise penalties into one dense `width x width`
/// matrix by adding each penalty's local matrix onto the diagonal block
/// addressed by its `col_range`.
///
/// # Errors
///
/// Returns `Err` when a `col_range` extends past `width`, or when a local
/// matrix is not square with the side length of its `col_range`.
pub(crate) fn dense_block_penalty_from_blockwise(
    pens: &[super::BlockwisePenalty],
    width: usize,
) -> Result<Array2<f64>, String> {
    let mut dense = Array2::<f64>::zeros((width, width));
    for (k, pen) in pens.iter().enumerate() {
        let r = &pen.col_range;
        if r.end > width {
            return Err(format!(
                "dense_block_penalty_from_blockwise: penalty {k} col_range {}..{} exceeds width {width}",
                r.start, r.end,
            ));
        }
        let bl = r.len();
        let local_matches = pen.local.nrows() == bl && pen.local.ncols() == bl;
        if !local_matches {
            return Err(format!(
                "dense_block_penalty_from_blockwise: penalty {k} local is {}x{} but col_range width is {bl}",
                pen.local.nrows(),
                pen.local.ncols(),
            ));
        }
        let base = r.start;
        for li in 0..bl {
            let gi = base + li;
            for lj in 0..bl {
                dense[[gi, base + lj]] += pen.local[[li, lj]];
            }
        }
    }
    Ok(dense)
}
/// Builds the refusal report at a pilot state: computes the per-row gradient
/// from `rows` and `link`, materializes the full per-row Hessian tensor from
/// `row_hess`, and delegates to `survival_kkt_refusal_report_from_designs`.
///
/// Errors from the gradient evaluation or from the downstream report builder
/// are returned unchanged.
pub(crate) fn survival_kkt_refusal_report_at_pilot(
    designs: SurvivalEffectiveDesigns<'_>,
    row_hess: &SurvivalRowHessian,
    rows: SurvivalPilotRows<'_>,
    link: SurvivalLinkParams,
    s_total: &Array2<f64>,
    step: f64,
) -> Result<SurvivalKktRefusalReport, String> {
    // The trait must be in scope for the `evaluate_full` method call below.
    use gam_identifiability::families::compiler::RowHessian;

    // Gradient first: if it fails, the Hessian tensor is never materialized.
    let pilot_gradient = survival_row_gradient_from_pilot_primary_state(rows, link)?;
    let dense_row_hessian = row_hess.evaluate_full();
    survival_kkt_refusal_report_from_designs(
        designs,
        &dense_row_hessian,
        &pilot_gradient,
        s_total,
        step,
    )
}
// Unit tests for the refusal report: each test builds a tiny two-coefficient
// problem (one marginal column, one log-slope column, no time columns) with a
// zero penalty and inspects the resulting spectrum and phantom classification.
#[cfg(test)]
mod tests {
use super::*;
use ndarray::Array3;
// Replicates the same K_SURVIVAL x K_SURVIVAL Hessian `h` for all `n` rows.
fn constant_row_hess(n: usize, h: &Array2<f64>) -> Array3<f64> {
let mut out = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
for i in 0..n {
for a in 0..K_SURVIVAL {
for b in 0..K_SURVIVAL {
out[[i, a, b]] = h[[a, b]];
}
}
}
out
}
// Row Hessian that is all-ones on the {1, 3} x {1, 3} sub-block and zero
// elsewhere, i.e. it couples primary slots 1 and 3 through a rank-one block.
fn coupled_q1_g_hessian() -> Array2<f64> {
let mut h = Array2::<f64>::zeros((K_SURVIVAL, K_SURVIVAL));
h[[1, 1]] = 1.0; h[[3, 3]] = 1.0; h[[1, 3]] = 1.0; h[[3, 1]] = 1.0;
h
}
// Both coefficients use the same basis column and the row gradient is equal on
// slots 1 and 3, so the flat direction must be classified as a phantom.
#[test]
fn confounded_flat_undriven_direction_is_a_projectable_phantom() {
let n = 64;
let p_time = 0;
let p_marg = 1;
let p_log = 1;
// Centered linear ramp shared by the marginal and log-slope designs.
let mut basis = Array2::<f64>::zeros((n, 1));
for i in 0..n {
basis[[i, 0]] = (i as f64 / n as f64) - 0.5;
}
let dq0 = Array2::<f64>::zeros((n, p_time));
let dq1 = Array2::<f64>::zeros((n, p_time));
let dqd1 = Array2::<f64>::zeros((n, p_time));
let m_dq = basis.clone();
let m_dqd1 = Array2::<f64>::zeros((n, p_marg));
let g_dg = basis.clone();
let row_hess = constant_row_hess(n, &coupled_q1_g_hessian());
let mut row_grad = Array2::<f64>::zeros((n, K_SURVIVAL));
for i in 0..n {
row_grad[[i, 1]] = basis[[i, 0]]; row_grad[[i, 3]] = basis[[i, 0]]; }
// No penalty: the spectrum comes from the data term alone.
let s_total = Array2::<f64>::zeros((p_marg + p_log, p_marg + p_log));
let report = survival_kkt_refusal_report_from_designs(
SurvivalEffectiveDesigns {
dq0: &dq0,
dq1: &dq1,
dqd1: &dqd1,
m_dq: &m_dq,
m_dqd1: &m_dqd1,
g_dg: &g_dg,
},
&row_hess,
&row_grad,
&s_total,
KKT_PHANTOM_TRUST_RADIUS,
)
.expect("report builds");
assert!(
report.lambda_min <= 1e-9 * report.lambda_max.max(1.0),
"expected a near-null λ_min, got {:.3e} (λ_max={:.3e})",
report.lambda_min,
report.lambda_max
);
assert!(
report.condition_number >= 1e8,
"expected huge condition number, got {:.3e}",
report.condition_number
);
assert!(
!report.near_null.is_empty(),
"expected at least one near-null direction"
);
assert!(
report.all_near_null_are_phantom(),
"flat + undriven confound must be a projectable phantom; report: {}",
report.summary()
);
}
// Same confounded design, but the row gradient has opposite signs on slots 1
// and 3, so the score has a large component along the flat direction.
#[test]
fn flat_but_driven_direction_is_real_nonstationarity_not_projected() {
let n = 64;
let p_marg = 1;
let p_log = 1;
let mut basis = Array2::<f64>::zeros((n, 1));
for i in 0..n {
basis[[i, 0]] = (i as f64 / n as f64) - 0.5;
}
let dq0 = Array2::<f64>::zeros((n, 0));
let dq1 = Array2::<f64>::zeros((n, 0));
let dqd1 = Array2::<f64>::zeros((n, 0));
let m_dq = basis.clone();
let m_dqd1 = Array2::<f64>::zeros((n, p_marg));
let g_dg = basis.clone();
let row_hess = constant_row_hess(n, &coupled_q1_g_hessian());
let mut row_grad = Array2::<f64>::zeros((n, K_SURVIVAL));
for i in 0..n {
row_grad[[i, 1]] = 50.0 * basis[[i, 0]]; row_grad[[i, 3]] = -50.0 * basis[[i, 0]]; }
let s_total = Array2::<f64>::zeros((p_marg + p_log, p_marg + p_log));
let report = survival_kkt_refusal_report_from_designs(
SurvivalEffectiveDesigns {
dq0: &dq0,
dq1: &dq1,
dqd1: &dqd1,
m_dq: &m_dq,
m_dqd1: &m_dqd1,
g_dg: &g_dg,
},
&row_hess,
&row_grad,
&s_total,
KKT_PHANTOM_TRUST_RADIUS,
)
.expect("report builds");
assert!(
!report.near_null.is_empty(),
"expected a near-null direction"
);
assert!(
report.real_nonstationary_count() >= 1,
"driven flat direction must be flagged real non-stationarity; report: {}",
report.summary()
);
assert!(
!report.all_near_null_are_phantom(),
"must refuse to project a driven flat direction; report: {}",
report.summary()
);
}
// The two designs use different basis columns (linear ramp vs. one sine
// period) and the gradient is zero; no near-null direction is expected.
#[test]
fn well_conditioned_problem_has_no_near_null_direction() {
let n = 64;
let p_marg = 1;
let p_log = 1;
let mut marg = Array2::<f64>::zeros((n, 1));
let mut logb = Array2::<f64>::zeros((n, 1));
for i in 0..n {
let t = i as f64 / n as f64;
marg[[i, 0]] = t - 0.5;
logb[[i, 0]] = (2.0 * std::f64::consts::PI * t).sin();
}
let dq0 = Array2::<f64>::zeros((n, 0));
let dq1 = Array2::<f64>::zeros((n, 0));
let dqd1 = Array2::<f64>::zeros((n, 0));
let m_dqd1 = Array2::<f64>::zeros((n, p_marg));
let row_hess = constant_row_hess(n, &coupled_q1_g_hessian());
let row_grad = Array2::<f64>::zeros((n, K_SURVIVAL));
let s_total = Array2::<f64>::zeros((p_marg + p_log, p_marg + p_log));
let report = survival_kkt_refusal_report_from_designs(
SurvivalEffectiveDesigns {
dq0: &dq0,
dq1: &dq1,
dqd1: &dqd1,
m_dq: &marg,
m_dqd1: &m_dqd1,
g_dg: &logb,
},
&row_hess,
&row_grad,
&s_total,
KKT_PHANTOM_TRUST_RADIUS,
)
.expect("report builds");
assert!(
report.condition_number < 1e6,
"well-conditioned design must have modest condition number, got {:.3e}",
report.condition_number
);
assert!(
report.near_null.is_empty(),
"well-conditioned design must have no near-null direction; report: {}",
report.summary()
);
// With no near-null direction at all, the "all phantom" predicate is false.
assert!(!report.all_near_null_are_phantom());
}
// Zero penalty, but the data term alone (ramp vs. cosine designs) provides
// curvature in every direction, so nothing should be reported as near-null.
#[test]
fn penalty_null_but_data_curved_direction_is_not_near_null() {
let n = 64;
let p_marg = 1;
let p_log = 1;
let mut marg = Array2::<f64>::zeros((n, 1));
let mut logb = Array2::<f64>::zeros((n, 1));
for i in 0..n {
let t = i as f64 / n as f64;
marg[[i, 0]] = t - 0.5;
logb[[i, 0]] = (3.0 * t).cos();
}
let dq0 = Array2::<f64>::zeros((n, 0));
let dq1 = Array2::<f64>::zeros((n, 0));
let dqd1 = Array2::<f64>::zeros((n, 0));
let m_dqd1 = Array2::<f64>::zeros((n, p_marg));
let row_hess = constant_row_hess(n, &coupled_q1_g_hessian());
let row_grad = Array2::<f64>::zeros((n, K_SURVIVAL));
let s_total = Array2::<f64>::zeros((p_marg + p_log, p_marg + p_log));
let report = survival_kkt_refusal_report_from_designs(
SurvivalEffectiveDesigns {
dq0: &dq0,
dq1: &dq1,
dqd1: &dqd1,
m_dq: &marg,
m_dqd1: &m_dqd1,
g_dg: &logb,
},
&row_hess,
&row_grad,
&s_total,
KKT_PHANTOM_TRUST_RADIUS,
)
.expect("report builds");
assert!(
report.near_null.is_empty(),
"data-identified directions must not be near-null; report: {}",
report.summary()
);
}
}