use std::sync::Arc;
use ndarray::{Array1, Array2, Array3};
use gam_identifiability::families::compiler::{
BlockOrder, RowHessian, RowJacobianOperator, scale_jacobian_by_sqrt_h_with,
};
use gam_problem::gauge::assemble_block_triangular_t;
use faer::Side;
use gam_linalg::faer_ndarray::FaerEigh;
use gam_linalg::matrix::{CoefficientTransformOperator, DenseDesignMatrix, DesignMatrix};
use gam_problem::{FamilyChannelHessian, PenaltyMatrix};
/// Number of per-row primary channels handled by the survival compiler operators.
///
/// Channel layout used by the operators in this file: 0 = `q0`, 1 = `q1`, 2 = `qd1`,
/// 3 = `g` (the only channel written by `LogslopeBlockOperator`).
const K_SURVIVAL: usize = 4;
/// Magnitude above which a coefficient makes `beta` "non-trivial" in
/// `SurvivalRowHessian::channel_hessian_at`: without family scalars, a non-trivial `beta` is an
/// error, while an all-below-threshold `beta` reuses the stored Hessian blocks.
const BETA_NONTRIVIAL_ABS_THRESHOLD: f64 = 1e-12;
/// Per-row `K_SURVIVAL × K_SURVIVAL` Hessian blocks (one per observation row), indexed as
/// `h[[row, a, b]]` over the primary channels.
///
/// Blocks built by `SurvivalRowHessian::from_pilot_primary_state` are PSD-clamped;
/// `SurvivalRowHessian::from_full` stores its input unchanged.
pub struct SurvivalRowHessian {
    /// Shape `(n_rows, K_SURVIVAL, K_SURVIVAL)`.
    h: Array3<f64>,
}
impl SurvivalRowHessian {
pub fn from_pilot_primary_state(
q0: &Array1<f64>,
q1: &Array1<f64>,
qd1: &Array1<f64>,
g: &Array1<f64>,
z: &Array1<f64>,
weights: &Array1<f64>,
event: &Array1<f64>,
derivative_guard: f64,
probit_scale: f64,
) -> Result<Self, String> {
let n = q0.len();
if [
q1.len(),
qd1.len(),
g.len(),
z.len(),
weights.len(),
event.len(),
]
.iter()
.any(|&l| l != n)
{
return Err(format!(
"SurvivalRowHessian: length mismatch \
q0={n}, q1={}, qd1={}, g={}, z={}, weights={}, event={}",
q1.len(),
qd1.len(),
g.len(),
z.len(),
weights.len(),
event.len()
));
}
let mut h_full = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
for i in 0..n {
let (_, _grad, hess) =
crate::survival::marginal_slope::row_primary_for_compiler(
q0[i],
q1[i],
qd1[i],
g[i],
z[i],
weights[i],
event[i],
derivative_guard,
probit_scale,
)?;
let mut h_i = Array2::<f64>::zeros((K_SURVIVAL, K_SURVIVAL));
for a in 0..K_SURVIVAL {
for b in 0..K_SURVIVAL {
h_i[[a, b]] = hess[a][b];
}
}
let clamped = psd_clamp_4x4(&h_i);
for a in 0..K_SURVIVAL {
for b in 0..K_SURVIVAL {
h_full[[i, a, b]] = clamped[[a, b]];
}
}
}
Ok(Self { h: h_full })
}
pub fn from_full(h: Array3<f64>) -> Self {
assert_eq!(h.shape()[1], K_SURVIVAL);
assert_eq!(h.shape()[2], K_SURVIVAL);
Self { h }
}
}
impl RowHessian for SurvivalRowHessian {
fn k(&self) -> usize {
K_SURVIVAL
}
fn nrows(&self) -> usize {
self.h.shape()[0]
}
fn fill_row(&self, row: usize, out: &mut [f64]) {
assert_eq!(out.len(), K_SURVIVAL * K_SURVIVAL);
for a in 0..K_SURVIVAL {
for b in 0..K_SURVIVAL {
out[a * K_SURVIVAL + b] = self.h[[row, a, b]];
}
}
}
fn evaluate_full(&self) -> Array3<f64> {
self.h.clone()
}
}
/// `FamilyChannelHessian` view of the stored per-row blocks: one subject per row, with
/// `K_SURVIVAL` outputs per subject.
impl FamilyChannelHessian for SurvivalRowHessian {
fn n_outputs(&self) -> usize {
K_SURVIVAL
}
fn n_subjects(&self) -> usize {
self.h.shape()[0]
}
/// Copies subject `i`'s `K_SURVIVAL × K_SURVIVAL` block into `out`, row-major.
///
/// # Panics
/// Panics if `out.len() != K_SURVIVAL * K_SURVIVAL` or `i` is out of range.
fn fill_subject(&self, i: usize, out: &mut [f64]) {
assert_eq!(out.len(), K_SURVIVAL * K_SURVIVAL);
for a in 0..K_SURVIVAL {
for b in 0..K_SURVIVAL {
out[a * K_SURVIVAL + b] = self.h[[i, a, b]];
}
}
}
/// Returns a clone of the full `(n, K_SURVIVAL, K_SURVIVAL)` array.
fn evaluate_full(&self) -> ndarray::Array3<f64> {
self.h.clone()
}
/// Returns the channel Hessian evaluated at `beta`.
///
/// `family_scalars` is downcast to `SurvivalMarginalSlopeFamilyScalars`:
/// * no usable scalars and some `|beta_j| > BETA_NONTRIVIAL_ABS_THRESHOLD`: error;
/// * no usable scalars and every `|beta_j|` at or below the threshold: the stored blocks,
///   wrapped in a `gam_problem::TensorChannelHessian`;
/// * scalars present: every row is re-evaluated from the scalars and PSD-clamped.
///
/// A `family_scalars` value of some other concrete type downcasts to `None` and is treated
/// exactly like a missing one.
///
/// # Errors
/// Returns an error in the non-trivial-`beta`-without-scalars case, or when the scalar
/// vectors do not all have one entry per stored row.
fn channel_hessian_at(
&self,
beta: &[f64],
family_scalars: Option<&Arc<dyn std::any::Any + Send + Sync>>,
) -> Result<Arc<dyn FamilyChannelHessian>, String> {
use crate::survival::marginal_slope::SurvivalMarginalSlopeFamilyScalars;
let scalars_opt =
family_scalars.and_then(|a| a.downcast_ref::<SurvivalMarginalSlopeFamilyScalars>());
let beta_nontrivial = beta
.iter()
.any(|&b| b.abs() > BETA_NONTRIVIAL_ABS_THRESHOLD);
match scalars_opt {
None if beta_nontrivial => {
Err(
"SurvivalRowHessian::channel_hessian_at: beta is non-trivial but \
family_scalars is None; supply SurvivalMarginalSlopeFamilyScalars \
via FamilyLinearizationState::family_scalars to evaluate W(β) \
correctly (same contract as T26 Jacobian callbacks)."
.to_string(),
)
}
// beta is (numerically) zero: the stored blocks are returned unchanged.
None => {
Ok(Arc::new(gam_problem::TensorChannelHessian {
h: self.h.clone(),
}))
}
Some(sc) => {
let n = self.h.shape()[0];
if sc.q0_i.len() != n
|| sc.q1_i.len() != n
|| sc.qd1_i.len() != n
|| sc.g_i.len() != n
|| sc.z_i.len() != n
{
return Err(format!(
"SurvivalRowHessian::channel_hessian_at: scalars length mismatch \
(expected n={n}, got q0={} q1={} qd1={} g={} z={})",
sc.q0_i.len(),
sc.q1_i.len(),
sc.qd1_i.len(),
sc.g_i.len(),
sc.z_i.len(),
));
}
let mut h_full = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
for i in 0..n {
let q0 = sc.q0_i[i];
let q1 = sc.q1_i[i];
let qd1 = sc.qd1_i[i];
let g = sc.g_i[i];
let z = sc.z_i[i];
// NOTE(review): in these argument positions `from_pilot_primary_state` passes the
// per-row `weights[i]` and `event[i]` and a caller-supplied derivative guard. Here
// every row uses weight 1.0, event 1.0 and the default guard. Confirm this is
// intended (e.g. that the scalars carry no weights/censoring and they are applied
// elsewhere); otherwise the re-evaluated blocks differ from the pilot ones.
match crate::survival::marginal_slope::row_primary_for_compiler(
q0, q1, qd1, g, z, 1.0, 1.0, crate::survival::marginal_slope::DEFAULT_SURVIVAL_MARGINAL_SLOPE_DERIVATIVE_GUARD,
sc.s, ) {
// Same copy-into-4×4 + PSD clamp as `from_pilot_primary_state`.
Ok((_nll, _grad, hess)) => {
let mut h_i = ndarray::Array2::<f64>::zeros((K_SURVIVAL, K_SURVIVAL));
for a in 0..K_SURVIVAL {
for b in 0..K_SURVIVAL {
h_i[[a, b]] = hess[a][b];
}
}
let clamped = psd_clamp_4x4(&h_i);
for a in 0..K_SURVIVAL {
for b in 0..K_SURVIVAL {
h_full[[i, a, b]] = clamped[[a, b]];
}
}
}
// Best-effort: a row whose re-evaluation fails keeps its stored block and the
// error is discarded. NOTE(review): consider surfacing or counting these failures.
Err(_) => {
for a in 0..K_SURVIVAL {
for b in 0..K_SURVIVAL {
h_full[[i, a, b]] = self.h[[i, a, b]];
}
}
}
}
}
Ok(Arc::new(SurvivalRowHessian::from_full(h_full)))
}
}
}
}
/// Projects a symmetric matrix onto the positive-semidefinite cone by zeroing its negative
/// eigenvalues: with `m = V · diag(λ) · Vᵀ`, returns `V · diag(max(λ, 0)) · Vᵀ`.
///
/// The eigendecomposition reads only the lower triangle of `m` (`Side::Lower`), so `m` is
/// assumed symmetric. Despite the name, any square size works. Each off-diagonal entry is
/// computed once and mirrored, so the result is exactly symmetric.
///
/// If the eigendecomposition fails, falls back to the diagonal of `m` with negative entries
/// clamped to zero (off-diagonal entries are dropped).
fn psd_clamp_4x4(m: &Array2<f64>) -> Array2<f64> {
    let k = m.nrows();
    let mut out = Array2::<f64>::zeros((k, k));
    let (evals, evecs) = match m.eigh(Side::Lower) {
        Ok(pair) => pair,
        Err(_) => {
            // Best-effort fallback: keep only the clamped diagonal.
            for i in 0..k {
                out[[i, i]] = m[[i, i]].max(0.0);
            }
            return out;
        }
    };
    // `f64::max` returns the non-NaN operand, so a NaN eigenvalue is also mapped to 0.
    let clamped: Vec<f64> = (0..k).map(|l| evals[l].max(0.0)).collect();
    for i in 0..k {
        for j in i..k {
            let mut acc = 0.0;
            for l in 0..k {
                acc += evecs[[i, l]] * clamped[l] * evecs[[j, l]];
            }
            out[[i, j]] = acc;
            out[[j, i]] = acc;
        }
    }
    out
}
/// Row Jacobian of the time block. Per row it maps a coefficient perturbation onto channels
/// 0–2 (`q0`, `q1`, `qd1`); channel 3 is always zero.
pub struct TimeBlockOperator {
    /// Channel-0 (`q0`) Jacobian, shape `(n_rows, p)`.
    dq0: Array2<f64>,
    /// Channel-1 (`q1`) Jacobian, same shape as `dq0`.
    dq1: Array2<f64>,
    /// Channel-2 (`qd1`) Jacobian, same shape as `dq0`.
    dqd1: Array2<f64>,
}
impl TimeBlockOperator {
    /// Bundles the three per-channel Jacobians.
    ///
    /// # Panics
    /// Panics unless `dq0`, `dq1` and `dqd1` all have the same shape.
    pub fn new(dq0: Array2<f64>, dq1: Array2<f64>, dqd1: Array2<f64>) -> Self {
        let reference_shape = dq0.dim();
        assert_eq!(reference_shape, dq1.dim());
        assert_eq!(reference_shape, dqd1.dim());
        TimeBlockOperator { dq0, dq1, dqd1 }
    }
}
impl RowJacobianOperator for TimeBlockOperator {
    fn k(&self) -> usize {
        K_SURVIVAL
    }

    fn ncols(&self) -> usize {
        self.dq0.ncols()
    }

    fn nrows(&self) -> usize {
        self.dq0.nrows()
    }

    /// Writes `[dq0·δβ, dq1·δβ, dqd1·δβ, 0]` for row `row` into `out`.
    ///
    /// # Panics
    /// Panics if `out` is not `K_SURVIVAL` long or `delta_beta` does not match `ncols()`.
    fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
        assert_eq!(out.len(), K_SURVIVAL);
        assert_eq!(delta_beta.len(), self.dq0.ncols());
        let project = |jacobian: &Array2<f64>| -> f64 {
            delta_beta
                .iter()
                .enumerate()
                .fold(0.0_f64, |sum, (j, &b)| sum + jacobian[[row, j]] * b)
        };
        out[0] = project(&self.dq0);
        out[1] = project(&self.dq1);
        out[2] = project(&self.dqd1);
        out[3] = 0.0;
    }

    /// Materializes the `(n_rows, p, K_SURVIVAL)` Jacobian; channel 3 is zero.
    fn evaluate_full(&self) -> Array3<f64> {
        let (n, p) = self.dq0.dim();
        Array3::<f64>::from_shape_fn((n, p, K_SURVIVAL), |(row, col, channel)| match channel {
            0 => self.dq0[[row, col]],
            1 => self.dq1[[row, col]],
            2 => self.dqd1[[row, col]],
            _ => 0.0,
        })
    }

    fn scaled_design_by_sqrt_h(&self, h_full: &Array3<f64>) -> Array2<f64> {
        let (n, p) = self.dq0.dim();
        scale_jacobian_by_sqrt_h_with(n, p, K_SURVIVAL, h_full, |row, col, channel| {
            match channel {
                0 => self.dq0[[row, col]],
                1 => self.dq1[[row, col]],
                2 => self.dqd1[[row, col]],
                _ => 0.0,
            }
        })
    }
}
/// Row Jacobian for a block whose Jacobian `dq` drives channels 0 and 1 (`q0` and `q1`, with
/// identical values) and `dqd1` drives channel 2 (`qd1`); channel 3 is always zero.
pub struct QChannelBlockOperator {
    /// Jacobian shared by channels 0 and 1, shape `(n_rows, p)`.
    dq: Array2<f64>,
    /// Channel-2 Jacobian, same shape as `dq`.
    dqd1: Array2<f64>,
}
impl QChannelBlockOperator {
    /// Bundles the shared `q` Jacobian with the `qd1` Jacobian.
    ///
    /// # Panics
    /// Panics unless `dq` and `dqd1` have the same shape.
    pub fn new(dq: Array2<f64>, dqd1: Array2<f64>) -> Self {
        let reference_shape = dq.dim();
        assert_eq!(reference_shape, dqd1.dim());
        QChannelBlockOperator { dq, dqd1 }
    }
}
impl RowJacobianOperator for QChannelBlockOperator {
    fn k(&self) -> usize {
        K_SURVIVAL
    }

    fn ncols(&self) -> usize {
        self.dq.ncols()
    }

    fn nrows(&self) -> usize {
        self.dq.nrows()
    }

    /// Writes `[dq·δβ, dq·δβ, dqd1·δβ, 0]` for row `row` into `out`.
    ///
    /// # Panics
    /// Panics if `out` is not `K_SURVIVAL` long or `delta_beta` does not match `ncols()`.
    fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
        assert_eq!(out.len(), K_SURVIVAL);
        assert_eq!(delta_beta.len(), self.dq.ncols());
        let (q_dot, qd_dot) = delta_beta.iter().enumerate().fold(
            (0.0_f64, 0.0_f64),
            |(q_sum, qd_sum), (j, &b)| {
                (
                    q_sum + self.dq[[row, j]] * b,
                    qd_sum + self.dqd1[[row, j]] * b,
                )
            },
        );
        out[0] = q_dot;
        out[1] = q_dot;
        out[2] = qd_dot;
        out[3] = 0.0;
    }

    /// Materializes the `(n_rows, p, K_SURVIVAL)` Jacobian; channel 3 is zero.
    fn evaluate_full(&self) -> Array3<f64> {
        let (n, p) = self.dq.dim();
        Array3::<f64>::from_shape_fn((n, p, K_SURVIVAL), |(row, col, channel)| match channel {
            0 | 1 => self.dq[[row, col]],
            2 => self.dqd1[[row, col]],
            _ => 0.0,
        })
    }

    fn scaled_design_by_sqrt_h(&self, h_full: &Array3<f64>) -> Array2<f64> {
        let (n, p) = self.dq.dim();
        scale_jacobian_by_sqrt_h_with(n, p, K_SURVIVAL, h_full, |row, col, channel| {
            match channel {
                0 | 1 => self.dq[[row, col]],
                2 => self.dqd1[[row, col]],
                _ => 0.0,
            }
        })
    }
}
/// Row Jacobian for the log-slope block: only channel 3 (`g`) is driven, by `dg`.
pub struct LogslopeBlockOperator {
    /// Channel-3 Jacobian, shape `(n_rows, p)`.
    dg: Array2<f64>,
}
impl LogslopeBlockOperator {
    /// Wraps the `g`-channel Jacobian.
    pub fn new(dg: Array2<f64>) -> Self {
        LogslopeBlockOperator { dg }
    }
}
impl RowJacobianOperator for LogslopeBlockOperator {
    fn k(&self) -> usize {
        K_SURVIVAL
    }

    fn ncols(&self) -> usize {
        self.dg.ncols()
    }

    fn nrows(&self) -> usize {
        self.dg.nrows()
    }

    /// Writes `[0, 0, 0, dg·δβ]` for row `row` into `out`.
    ///
    /// # Panics
    /// Panics if `out` is not `K_SURVIVAL` long or `delta_beta` does not match `ncols()`.
    fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
        assert_eq!(out.len(), K_SURVIVAL);
        assert_eq!(delta_beta.len(), self.dg.ncols());
        let g_dot = delta_beta
            .iter()
            .enumerate()
            .fold(0.0_f64, |sum, (j, &b)| sum + self.dg[[row, j]] * b);
        out[..3].fill(0.0);
        out[3] = g_dot;
    }

    /// Materializes the `(n_rows, p, K_SURVIVAL)` Jacobian; only channel 3 is non-zero.
    fn evaluate_full(&self) -> Array3<f64> {
        let (n, p) = self.dg.dim();
        Array3::<f64>::from_shape_fn((n, p, K_SURVIVAL), |(row, col, channel)| {
            if channel == 3 {
                self.dg[[row, col]]
            } else {
                0.0
            }
        })
    }

    fn scaled_design_by_sqrt_h(&self, h_full: &Array3<f64>) -> Array2<f64> {
        let (n, p) = self.dg.dim();
        scale_jacobian_by_sqrt_h_with(n, p, K_SURVIVAL, h_full, |row, col, channel| {
            match channel {
                3 => self.dg[[row, col]],
                _ => 0.0,
            }
        })
    }
}
/// Operators and matching block tags to hand to the identifiability compiler.
///
/// The two vectors are parallel: `operators[i]` is tagged by `ordering[i]`.
pub struct SurvivalCompilerInputs {
    pub operators: Vec<Arc<dyn RowJacobianOperator>>,
    pub ordering: Vec<BlockOrder>,
}
/// Whole-block identifiability maps returned by `compile_survival_parametric_designs`.
///
/// Each `v_*` is the compiler's `t_lw` for that block: rows index the block's raw columns,
/// columns index the kept (compiled) columns.
pub struct SurvivalParametricCompiled {
    pub v_time: Array2<f64>,
    pub v_marginal: Array2<f64>,
    pub v_logslope: Array2<f64>,
    /// Dropped column counts `(time, marginal, logslope)`: raw width minus kept width.
    pub drops_by_block: (usize, usize, usize),
}
/// Pairs a raw design with a coefficient transform `v` via `CoefficientTransformOperator`,
/// producing the design expressed in compiled coordinates.
///
/// Dense designs are used as-is; sparse designs are densified chunk-wise first. The result is
/// always a dense design backed by the transform operator.
///
/// # Errors
/// Fails if the raw design's column count differs from `v.nrows()`, if densification fails, or
/// if the transform operator rejects its inputs. Every message is prefixed with `context`.
fn wrap_design_with_transform(
    raw: DesignMatrix,
    v: &Array2<f64>,
    context: &str,
) -> Result<DesignMatrix, String> {
    let raw_cols = raw.ncols();
    let (v_rows, v_cols) = (v.nrows(), v.ncols());
    if raw_cols != v_rows {
        return Err(format!(
            "{context}: raw design has {raw_cols} cols but V has {v_rows} rows (V is {v_rows}×{v_cols})",
        ));
    }
    let dense_inner = match raw {
        DesignMatrix::Dense(inner) => inner,
        DesignMatrix::Sparse(_) => {
            let chunk_label = format!("{context} sparse→dense for V apply");
            let densified = raw
                .try_to_dense_by_chunks(&chunk_label)
                .map_err(|reason| format!("{context}: densify failed: {reason}"))?;
            DenseDesignMatrix::from(densified)
        }
    };
    let transformed = CoefficientTransformOperator::new(dense_inner, v.clone())
        .map_err(|reason| format!("{context}: CoefficientTransformOperator::new: {reason}"))?;
    Ok(DesignMatrix::Dense(DenseDesignMatrix::from(Arc::new(
        transformed,
    ))))
}
/// Per-term result of `compile_survival_parametric_designs_per_term`.
///
/// Each map is the compiler's `t_lw` for one partition term: rows index the term's raw
/// columns, columns index its kept columns.
pub struct SurvivalParametricCompiledPerTerm {
    /// Time-term maps, in time-partition order.
    pub v_time_per_term: Vec<Array2<f64>>,
    /// Marginal-term maps, in marginal-partition order.
    pub v_marginal_per_term: Vec<Array2<f64>>,
    /// Logslope-term maps, in logslope-partition order.
    pub v_logslope_per_term: Vec<Array2<f64>>,
    /// Compiler `r_lw` for every term in time → marginal → logslope order (`None` where the
    /// compiler emitted none); consumed by `assemble_block_triangular_t`.
    pub r_lw_per_term: Vec<Option<Array2<f64>>>,
    /// Dropped column counts summed per block: `(time, marginal, logslope)`.
    pub drops_by_block: (usize, usize, usize),
}
pub fn compile_survival_parametric_designs_per_term(
time_dq0: Array2<f64>,
time_dq1: Array2<f64>,
time_dqd1: Array2<f64>,
time_partition: &[std::ops::Range<usize>],
marginal_dq: Array2<f64>,
marginal_dqd1: Array2<f64>,
marginal_partition: &[std::ops::Range<usize>],
logslope_dg: Array2<f64>,
logslope_partition: &[std::ops::Range<usize>],
row_hess: &dyn RowHessian,
) -> Result<SurvivalParametricCompiledPerTerm, String> {
use gam_identifiability::families::compiler::compile;
let p_time = time_dq0.ncols();
let p_marg = marginal_dq.ncols();
let p_log = logslope_dg.ncols();
validate_partition(time_partition, p_time, "time")?;
validate_partition(marginal_partition, p_marg, "marginal")?;
validate_partition(logslope_partition, p_log, "logslope")?;
let mut operators: Vec<Arc<dyn RowJacobianOperator>> = Vec::new();
let mut ordering: Vec<BlockOrder> = Vec::new();
for range in time_partition {
let dq0 = time_dq0.slice(ndarray::s![.., range.clone()]).to_owned();
let dq1 = time_dq1.slice(ndarray::s![.., range.clone()]).to_owned();
let dqd1 = time_dqd1.slice(ndarray::s![.., range.clone()]).to_owned();
operators.push(Arc::new(TimeBlockOperator::new(dq0, dq1, dqd1)));
ordering.push(BlockOrder::Time);
}
for range in marginal_partition {
let dq = marginal_dq.slice(ndarray::s![.., range.clone()]).to_owned();
let dqd1 = marginal_dqd1
.slice(ndarray::s![.., range.clone()])
.to_owned();
operators.push(Arc::new(QChannelBlockOperator::new(dq, dqd1)));
ordering.push(BlockOrder::Marginal);
}
for range in logslope_partition {
let dg = logslope_dg.slice(ndarray::s![.., range.clone()]).to_owned();
operators.push(Arc::new(LogslopeBlockOperator::new(dg)));
ordering.push(BlockOrder::Logslope);
}
let compiled = compile(&operators, row_hess, &ordering).map_err(|e| {
format!("identifiability::families::compiler::compile (per-term) failed: {e}")
})?;
let blocks = compiled.blocks;
let n_time = time_partition.len();
let n_marg = marginal_partition.len();
let n_log = logslope_partition.len();
if blocks.len() != n_time + n_marg + n_log {
return Err(format!(
"per-term compile: expected {} compiled blocks (time={}, marg={}, log={}), got {}",
n_time + n_marg + n_log,
n_time,
n_marg,
n_log,
blocks.len(),
));
}
let mut iter = blocks.into_iter();
let mut v_time_per_term: Vec<Array2<f64>> = Vec::with_capacity(n_time);
let mut r_time_per_term: Vec<Option<Array2<f64>>> = Vec::with_capacity(n_time);
for _ in 0..n_time {
let blk = iter.next().unwrap();
v_time_per_term.push(blk.t_lw);
r_time_per_term.push(blk.r_lw);
}
let mut v_marginal_per_term: Vec<Array2<f64>> = Vec::with_capacity(n_marg);
let mut r_marginal_per_term: Vec<Option<Array2<f64>>> = Vec::with_capacity(n_marg);
for _ in 0..n_marg {
let blk = iter.next().unwrap();
v_marginal_per_term.push(blk.t_lw);
r_marginal_per_term.push(blk.r_lw);
}
let mut v_logslope_per_term: Vec<Array2<f64>> = Vec::with_capacity(n_log);
let mut r_logslope_per_term: Vec<Option<Array2<f64>>> = Vec::with_capacity(n_log);
for _ in 0..n_log {
let blk = iter.next().unwrap();
v_logslope_per_term.push(blk.t_lw);
r_logslope_per_term.push(blk.r_lw);
}
let mut r_lw_per_term: Vec<Option<Array2<f64>>> = Vec::with_capacity(n_time + n_marg + n_log);
r_lw_per_term.extend(r_time_per_term);
r_lw_per_term.extend(r_marginal_per_term);
r_lw_per_term.extend(r_logslope_per_term);
let drops_time: usize = time_partition
.iter()
.zip(v_time_per_term.iter())
.map(|(r, v)| r.len().saturating_sub(v.ncols()))
.sum();
let drops_marg: usize = marginal_partition
.iter()
.zip(v_marginal_per_term.iter())
.map(|(r, v)| r.len().saturating_sub(v.ncols()))
.sum();
let drops_log: usize = logslope_partition
.iter()
.zip(v_logslope_per_term.iter())
.map(|(r, v)| r.len().saturating_sub(v.ncols()))
.sum();
Ok(SurvivalParametricCompiledPerTerm {
v_time_per_term,
v_marginal_per_term,
v_logslope_per_term,
r_lw_per_term,
drops_by_block: (drops_time, drops_marg, drops_log),
})
}
/// Checks that `partition` tiles `[0, p_block)` with contiguous, non-empty ranges.
///
/// An empty partition is accepted only for an empty block (`p_block == 0`).
///
/// # Errors
/// Returns the first violation found, with each message prefixed by `label`. Checks run in
/// this order:
/// 1. an empty partition for a non-empty block;
/// 2. a non-zero start;
/// 3. a wrong final end;
/// 4. a gap/overlap or an empty range between neighbours (scanning left to right);
/// 5. an empty final range.
fn validate_partition(
    partition: &[std::ops::Range<usize>],
    p_block: usize,
    label: &str,
) -> Result<(), String> {
    let (first, last) = match (partition.first(), partition.last()) {
        (Some(first), Some(last)) => (first, last),
        _ if p_block == 0 => return Ok(()),
        _ => {
            return Err(format!(
                "{label} partition empty but block has p={p_block} columns"
            ))
        }
    };
    if first.start != 0 {
        return Err(format!(
            "{label} partition must start at 0, got start={}",
            first.start
        ));
    }
    if last.end != p_block {
        return Err(format!(
            "{label} partition must cover [0, {p_block}); last range ends at {}",
            last.end
        ));
    }
    // The partition is non-empty here, so `[1..]` cannot panic.
    for (prev, next) in partition.iter().zip(&partition[1..]) {
        if prev.end != next.start {
            return Err(format!(
                "{label} partition has gap/overlap between [{}..{}) and [{}..{})",
                prev.start, prev.end, next.start, next.end
            ));
        }
        if prev.is_empty() {
            return Err(format!(
                "{label} partition has empty range [{}..{})",
                prev.start, prev.end
            ));
        }
    }
    if last.is_empty() {
        return Err(format!("{label} partition's final range is empty"));
    }
    Ok(())
}
/// Derives a term partition of a block's columns `[0, p_block)` from its penalty ranges.
///
/// Each penalty range's start and end (clamped to `p_block`), plus `0` and `p_block`, become
/// cut points. Consecutive distinct cut points delimit the returned ranges, in ascending order.
/// With no penalty ranges the result is the single range `0..p_block`, or nothing when
/// `p_block == 0`.
pub fn extract_term_partition_from_penalty_ranges(
    p_block: usize,
    penalty_ranges: &[std::ops::Range<usize>],
) -> Vec<std::ops::Range<usize>> {
    let cuts: Vec<usize> = penalty_ranges
        .iter()
        .flat_map(|range| [range.start, range.end])
        .map(|boundary| boundary.min(p_block))
        .chain([0, p_block])
        .collect::<std::collections::BTreeSet<usize>>()
        .into_iter()
        .collect();
    cuts.iter()
        .zip(cuts.iter().skip(1))
        .filter(|(lo, hi)| lo < hi)
        .map(|(&lo, &hi)| lo..hi)
        .collect()
}
/// Pulls a block-local penalty back through a block's raw→compiled map `V`.
///
/// Let `S` be `pen.local` embedded at `pen.col_range` inside the block's raw coefficient
/// space, which has width `v_block.nrows()`. The result is the symmetrized `Vᵀ S V`, a dense
/// `v_block.ncols() × v_block.ncols()` penalty.
///
/// Only the rows of `V` inside `pen.col_range` meet a non-zero entry of the embedded `S`. The
/// product is therefore formed from that row slice, `V_rᵀ · S_local · V_r`, rather than from
/// a full raw-width embedding (which cost O(raw_p² · compiled_p) and a raw_p × raw_p
/// allocation).
///
/// # Errors
/// Returns an error in either of these cases:
/// * `pen.col_range` extends past the block's raw width;
/// * `pen.local` is not square with side equal to the range width.
pub fn pull_back_blockwise_penalty_through_block_v(
    pen: &gam_terms::smooth::BlockwisePenalty,
    v_block: &Array2<f64>,
) -> Result<PenaltyMatrix, String> {
    let raw_p = v_block.nrows();
    let compiled_p = v_block.ncols();
    let block_p = pen.col_range.len();
    let embed_start = pen.col_range.start;
    let embed_end = pen.col_range.end;
    if embed_end > raw_p {
        return Err(format!(
            "pull_back_blockwise_penalty_through_block_v: penalty col_range {embed_start}..{embed_end} \
             exceeds block raw width {raw_p}"
        ));
    }
    if pen.local.nrows() != block_p || pen.local.ncols() != block_p {
        return Err(format!(
            "pull_back_blockwise_penalty_through_block_v: penalty local is {}x{} but col_range \
             width is {block_p}",
            pen.local.nrows(),
            pen.local.ncols(),
        ));
    }
    if block_p == 0 {
        // Empty penalty support (this also covers a reversed `col_range`, whose `len()` is 0):
        // the pulled-back penalty is identically zero.
        return Ok(PenaltyMatrix::Dense(Array2::<f64>::zeros((
            compiled_p, compiled_p,
        ))));
    }
    // Rows of V outside `col_range` multiply zero rows/columns of the embedded penalty.
    let v_rows = v_block.slice(ndarray::s![embed_start..embed_end, ..]);
    let pulled = v_rows.t().dot(&pen.local.dot(&v_rows));
    // Symmetrize to remove floating-point asymmetry introduced by the triple product.
    let sym = (&pulled + &pulled.t()) * 0.5;
    Ok(PenaltyMatrix::Dense(sym))
}
/// Assembles the joint raw→compiled `CompiledMap` from per-term compiler output.
///
/// The joint matrix comes from `assemble_block_triangular_t`, applied to all per-term maps (in
/// time, marginal, logslope order) and their `r_lw` matrices. Block ranges are reported at
/// three-block granularity:
/// * a block's raw width is the sum of its terms' row counts;
/// * its compiled width is the sum of their column counts.
pub fn compiled_map_from_per_term(
    compiled: &SurvivalParametricCompiledPerTerm,
) -> gam_identifiability::families::compiler::CompiledMap {
    let blocks: [&[Array2<f64>]; 3] = [
        &compiled.v_time_per_term,
        &compiled.v_marginal_per_term,
        &compiled.v_logslope_per_term,
    ];
    let v_all: Vec<Array2<f64>> = blocks
        .iter()
        .flat_map(|terms| terms.iter().cloned())
        .collect();
    let raw_from_compiled = assemble_block_triangular_t(&v_all, &compiled.r_lw_per_term);

    let mut raw_block_ranges = Vec::with_capacity(blocks.len());
    let mut compiled_block_ranges = Vec::with_capacity(blocks.len());
    let (mut raw_cursor, mut kept_cursor) = (0usize, 0usize);
    for terms in &blocks {
        let raw_width: usize = terms.iter().map(|v| v.nrows()).sum();
        let kept_width: usize = terms.iter().map(|v| v.ncols()).sum();
        raw_block_ranges.push(raw_cursor..raw_cursor + raw_width);
        compiled_block_ranges.push(kept_cursor..kept_cursor + kept_width);
        raw_cursor += raw_width;
        kept_cursor += kept_width;
    }

    gam_identifiability::families::compiler::CompiledMap {
        raw_from_compiled,
        compiled_block_ranges,
        raw_block_ranges,
    }
}
/// Applies a three-block `CompiledMap` (time, marginal, logslope) to the raw survival designs
/// and pulls each block's penalties back into compiled coordinates.
///
/// Each block's map is the diagonal block `T[raw_b, compiled_b]` of `map.raw_from_compiled`;
/// off-diagonal blocks of `T` are not used here. Each design is wrapped with
/// `wrap_design_with_transform`, which densifies sparse designs first. Each penalty is pulled
/// back with `pull_back_blockwise_penalty_through_block_v`.
///
/// # Errors
/// Returns an error in any of these cases:
/// * the map does not describe exactly three blocks;
/// * its block ranges do not tile `T`'s rows/columns contiguously;
/// * a design's width does not match its block map;
/// * a penalty cannot be pulled back, or ends up with the wrong shape.
pub fn apply_compiled_map_to_designs(
    map: &gam_identifiability::families::compiler::CompiledMap,
    time_design_entry: DesignMatrix,
    time_design_exit: DesignMatrix,
    time_design_derivative_exit: DesignMatrix,
    marginal_design: DesignMatrix,
    logslope_design: DesignMatrix,
    time_penalties: &[gam_terms::smooth::BlockwisePenalty],
    marginal_penalties: &[gam_terms::smooth::BlockwisePenalty],
    logslope_penalties: &[gam_terms::smooth::BlockwisePenalty],
) -> Result<CompiledSurvivalDesignsVMExact, String> {
    if map.raw_block_ranges.len() != 3 || map.compiled_block_ranges.len() != 3 {
        return Err(format!(
            "apply_compiled_map_to_designs: expected exactly 3 blocks (time, marginal, logslope), \
             got {} raw / {} compiled",
            map.raw_block_ranges.len(),
            map.compiled_block_ranges.len(),
        ));
    }
    let time_raw = map.raw_block_ranges[0].clone();
    let marg_raw = map.raw_block_ranges[1].clone();
    let log_raw = map.raw_block_ranges[2].clone();
    let time_compiled = map.compiled_block_ranges[0].clone();
    let marg_compiled = map.compiled_block_ranges[1].clone();
    let log_compiled = map.compiled_block_ranges[2].clone();
    let t = &map.raw_from_compiled;
    let raw_total = t.nrows();
    let compiled_total = t.ncols();
    let expected_raw_total = log_raw.end;
    if raw_total != expected_raw_total {
        return Err(format!(
            "apply_compiled_map_to_designs: T has {raw_total} raw rows but block ranges sum to \
             {expected_raw_total}"
        ));
    }
    let expected_compiled_total = log_compiled.end;
    if compiled_total != expected_compiled_total {
        return Err(format!(
            "apply_compiled_map_to_designs: T has {compiled_total} compiled cols but block ranges \
             sum to {expected_compiled_total}"
        ));
    }
    // The ranges must tile [0, total) contiguously in time → marginal → logslope order.
    // Together with the end checks above, this keeps every slice of `T` below in bounds and
    // non-overlapping; out-of-order or reversed ranges would otherwise panic inside `slice`.
    for (axis, ranges) in [
        ("raw", &map.raw_block_ranges),
        ("compiled", &map.compiled_block_ranges),
    ] {
        let mut cursor = 0usize;
        for (idx, range) in ranges.iter().enumerate() {
            if range.start != cursor || range.end < range.start {
                return Err(format!(
                    "apply_compiled_map_to_designs: {axis} block range {idx} is {}..{}, expected \
                     a contiguous range starting at {cursor}",
                    range.start, range.end
                ));
            }
            cursor = range.end;
        }
    }
    let v_time = t
        .slice(ndarray::s![time_raw.clone(), time_compiled.clone()])
        .to_owned();
    let v_marg = t
        .slice(ndarray::s![marg_raw.clone(), marg_compiled.clone()])
        .to_owned();
    let v_log = t
        .slice(ndarray::s![log_raw.clone(), log_compiled.clone()])
        .to_owned();
    let time_entry_out =
        wrap_design_with_transform(time_design_entry, &v_time, "compiled-map: time entry")?;
    let time_exit_out =
        wrap_design_with_transform(time_design_exit, &v_time, "compiled-map: time exit")?;
    let time_deriv_out = wrap_design_with_transform(
        time_design_derivative_exit,
        &v_time,
        "compiled-map: time derivative_exit",
    )?;
    let marg_out = wrap_design_with_transform(marginal_design, &v_marg, "compiled-map: marginal")?;
    let log_out = wrap_design_with_transform(logslope_design, &v_log, "compiled-map: logslope")?;
    let pull_set = |pens: &[gam_terms::smooth::BlockwisePenalty],
                    v_block: &Array2<f64>,
                    channel: &str|
     -> Result<Vec<PenaltyMatrix>, String> {
        pens.iter()
            .map(|p| {
                pull_back_blockwise_penalty_through_block_v(p, v_block).map_err(|e| {
                    format!("apply_compiled_map_to_designs: {channel} penalty pullback: {e}")
                })
            })
            .collect()
    };
    let time_penalties = pull_set(time_penalties, &v_time, "time")?;
    let marginal_penalties = pull_set(marginal_penalties, &v_marg, "marginal")?;
    let logslope_penalties = pull_set(logslope_penalties, &v_log, "logslope")?;
    validate_block_penalty_shapes("time", time_exit_out.ncols(), &time_penalties)?;
    validate_block_penalty_shapes("marginal", marg_out.ncols(), &marginal_penalties)?;
    validate_block_penalty_shapes("logslope", log_out.ncols(), &logslope_penalties)?;
    Ok(CompiledSurvivalDesignsVMExact {
        time_design_entry: time_entry_out,
        time_design_exit: time_exit_out,
        time_design_derivative_exit: time_deriv_out,
        marginal_design: marg_out,
        logslope_design: log_out,
        time_penalties,
        marginal_penalties,
        logslope_penalties,
    })
}
/// Checks that every penalty of `block` is `width × width`.
///
/// # Errors
/// Reports the first penalty (by index) whose shape differs.
fn validate_block_penalty_shapes(
    block: &str,
    width: usize,
    penalties: &[PenaltyMatrix],
) -> Result<(), String> {
    let first_mismatch = penalties
        .iter()
        .map(|penalty| penalty.shape())
        .enumerate()
        .find(|(_, shape)| *shape != (width, width));
    match first_mismatch {
        None => Ok(()),
        Some((idx, (rows, cols))) => Err(format!(
            "apply_compiled_map_to_designs: {block} penalty {idx} must be {width}x{width}, got {rows}x{cols}"
        )),
    }
}
/// Compiles whole-block identifiability maps for the three survival parametric blocks.
///
/// Unlike `compile_survival_parametric_designs_per_term`, each block is a single compiler
/// operator. Returns each block's raw→kept map (`t_lw`) and the number of columns dropped per
/// block.
///
/// # Errors
/// Returns an error if the compiler fails or does not emit exactly three blocks.
///
/// # Panics
/// The operator constructors panic if `time_dq1`/`time_dqd1` differ in shape from `time_dq0`,
/// or if `marginal_dqd1` differs in shape from `marginal_dq`.
pub fn compile_survival_parametric_designs(
    time_dq0: Array2<f64>,
    time_dq1: Array2<f64>,
    time_dqd1: Array2<f64>,
    marginal_dq: Array2<f64>,
    marginal_dqd1: Array2<f64>,
    logslope_dg: Array2<f64>,
    row_hess: &dyn RowHessian,
) -> Result<SurvivalParametricCompiled, String> {
    use gam_identifiability::families::compiler::compile;
    let raw_widths = (time_dq0.ncols(), marginal_dq.ncols(), logslope_dg.ncols());
    let inputs = build_survival_compiler_inputs(
        time_dq0,
        time_dq1,
        time_dqd1,
        marginal_dq,
        marginal_dqd1,
        logslope_dg,
        None,
        None,
    );
    let operator_count = inputs.operators.len();
    if operator_count != 3 {
        return Err(format!(
            "compile_survival_parametric_designs: expected exactly 3 parametric operators \
             (time, marginal, logslope); got {}",
            operator_count,
        ));
    }
    let compiled = compile(&inputs.operators, row_hess, &inputs.ordering)
        .map_err(|e| format!("identifiability::families::compiler::compile failed: {e}"))?;
    let emitted = compiled.blocks.len();
    if emitted != 3 {
        return Err(format!(
            "compile_survival_parametric_designs: compiler emitted {} blocks; expected 3",
            emitted,
        ));
    }
    // Move the maps out of the compiled blocks (time, marginal, logslope order).
    let mut maps = compiled.blocks.into_iter().map(|blk| blk.t_lw);
    let mut next_map = || maps.next().expect("block count verified to be 3 above");
    let v_time = next_map();
    let v_marginal = next_map();
    let v_logslope = next_map();

    let (p_time_raw, p_marg_raw, p_log_raw) = raw_widths;
    let drops_by_block = (
        p_time_raw.saturating_sub(v_time.ncols()),
        p_marg_raw.saturating_sub(v_marginal.ncols()),
        p_log_raw.saturating_sub(v_logslope.ncols()),
    );
    Ok(SurvivalParametricCompiled {
        v_time,
        v_marginal,
        v_logslope,
        drops_by_block,
    })
}
/// Builds the compiler operators and block tags for the survival model.
///
/// Blocks appear in a fixed order:
/// 1. time;
/// 2. marginal;
/// 3. logslope;
/// 4. the optional score-warp block, when supplied;
/// 5. the optional link-deviation block, when supplied.
///
/// The two optional blocks are modelled as `QChannelBlockOperator`s.
///
/// # Panics
/// Panics if a block's paired Jacobians differ in shape (see the operator constructors).
pub fn build_survival_compiler_inputs(
    time_dq0: Array2<f64>,
    time_dq1: Array2<f64>,
    time_dqd1: Array2<f64>,
    marginal_dq: Array2<f64>,
    marginal_dqd1: Array2<f64>,
    logslope_dg: Array2<f64>,
    score_warp_dq_dqd1: Option<(Array2<f64>, Array2<f64>)>,
    link_dev_dq_dqd1: Option<(Array2<f64>, Array2<f64>)>,
) -> SurvivalCompilerInputs {
    let time_op: Arc<dyn RowJacobianOperator> =
        Arc::new(TimeBlockOperator::new(time_dq0, time_dq1, time_dqd1));
    let marginal_op: Arc<dyn RowJacobianOperator> =
        Arc::new(QChannelBlockOperator::new(marginal_dq, marginal_dqd1));
    let logslope_op: Arc<dyn RowJacobianOperator> =
        Arc::new(LogslopeBlockOperator::new(logslope_dg));
    let mut tagged: Vec<(Arc<dyn RowJacobianOperator>, BlockOrder)> = vec![
        (time_op, BlockOrder::Time),
        (marginal_op, BlockOrder::Marginal),
        (logslope_op, BlockOrder::Logslope),
    ];

    let optional_q_blocks = [
        (score_warp_dq_dqd1, BlockOrder::ScoreWarp),
        (link_dev_dq_dqd1, BlockOrder::LinkDev),
    ];
    for (jacobians, order) in optional_q_blocks {
        if let Some((dq, dqd1)) = jacobians {
            let op: Arc<dyn RowJacobianOperator> = Arc::new(QChannelBlockOperator::new(dq, dqd1));
            tagged.push((op, order));
        }
    }

    let (operators, ordering): (Vec<_>, Vec<_>) = tagged.into_iter().unzip();
    SurvivalCompilerInputs {
        operators,
        ordering,
    }
}
/// Survival designs and penalties expressed in compiled coordinates, as returned by
/// `apply_compiled_map_to_designs`.
pub struct CompiledSurvivalDesignsVMExact {
    /// Time entry design wrapped with the time block map.
    pub time_design_entry: DesignMatrix,
    /// Time exit design wrapped with the time block map.
    pub time_design_exit: DesignMatrix,
    /// Time derivative-at-exit design wrapped with the time block map.
    pub time_design_derivative_exit: DesignMatrix,
    /// Marginal design wrapped with the marginal block map.
    pub marginal_design: DesignMatrix,
    /// Logslope design wrapped with the logslope block map.
    pub logslope_design: DesignMatrix,
    /// Time penalties pulled back to the compiled time width (square, dense).
    pub time_penalties: Vec<PenaltyMatrix>,
    /// Marginal penalties pulled back to the compiled marginal width (square, dense).
    pub marginal_penalties: Vec<PenaltyMatrix>,
    /// Logslope penalties pulled back to the compiled logslope width (square, dense).
    pub logslope_penalties: Vec<PenaltyMatrix>,
}
#[cfg(test)]
mod tests {
use super::*;
use gam_problem::Gauge;
#[test]
fn psd_clamp_zeros_negative_eigenvalues() {
let mut m = Array2::<f64>::zeros((4, 4));
m[[0, 0]] = 2.0;
m[[1, 1]] = -1.0;
m[[2, 2]] = 0.5;
m[[3, 3]] = -0.25;
let clamped = psd_clamp_4x4(&m);
assert!((clamped[[0, 0]] - 2.0).abs() < 1e-12);
assert!(clamped[[1, 1]].abs() < 1e-12);
assert!((clamped[[2, 2]] - 0.5).abs() < 1e-12);
assert!(clamped[[3, 3]].abs() < 1e-12);
}
#[test]
fn time_block_operator_evaluate_full_shape() {
let n = 6;
let p = 3;
let dq0 = Array2::from_shape_fn((n, p), |(i, j)| (i + j) as f64);
let dq1 = Array2::from_shape_fn((n, p), |(i, j)| (i as f64) * 2.0 + j as f64);
let dqd1 = Array2::from_shape_fn((n, p), |(i, j)| 0.5 * ((i * j) as f64));
let op = TimeBlockOperator::new(dq0.clone(), dq1.clone(), dqd1.clone());
let full = op.evaluate_full();
assert_eq!(full.shape(), &[n, p, K_SURVIVAL]);
for i in 0..n {
for j in 0..p {
assert_eq!(full[[i, j, 0]], dq0[[i, j]]);
assert_eq!(full[[i, j, 1]], dq1[[i, j]]);
assert_eq!(full[[i, j, 2]], dqd1[[i, j]]);
assert_eq!(full[[i, j, 3]], 0.0);
}
}
}
#[test]
fn q_channel_block_apply_row_shares_q0_q1() {
let n = 5;
let p = 2;
let dq = Array2::from_shape_fn((n, p), |(i, j)| (i as f64) * (j as f64 + 1.0));
let dqd1 = Array2::from_shape_fn((n, p), |(i, j)| (j as f64) - (i as f64));
let op = QChannelBlockOperator::new(dq.clone(), dqd1.clone());
let mut out = [0.0_f64; K_SURVIVAL];
let delta = [1.0_f64, -0.5];
op.apply_row(3, &delta, &mut out);
let want_q = dq[[3, 0]] * 1.0 + dq[[3, 1]] * (-0.5);
let want_qd = dqd1[[3, 0]] * 1.0 + dqd1[[3, 1]] * (-0.5);
assert!((out[0] - want_q).abs() < 1e-12);
assert!((out[1] - want_q).abs() < 1e-12);
assert!((out[2] - want_qd).abs() < 1e-12);
assert_eq!(out[3], 0.0);
}
#[test]
fn logslope_block_writes_only_g_channel() {
let n = 4;
let p = 2;
let dg = Array2::from_shape_fn((n, p), |(i, j)| (i as f64) + 0.1 * (j as f64));
let op = LogslopeBlockOperator::new(dg.clone());
let mut out = [0.0_f64; K_SURVIVAL];
let delta = [2.0_f64, -1.0];
op.apply_row(1, &delta, &mut out);
assert_eq!(out[0], 0.0);
assert_eq!(out[1], 0.0);
assert_eq!(out[2], 0.0);
let want = dg[[1, 0]] * 2.0 + dg[[1, 1]] * (-1.0);
assert!((out[3] - want).abs() < 1e-12);
}
#[test]
fn extract_term_partition_simple_cases() {
let full = 0..5usize;
let part = extract_term_partition_from_penalty_ranges(5, &[]);
assert_eq!(part.as_slice(), std::slice::from_ref(&full));
let part = extract_term_partition_from_penalty_ranges(5, std::slice::from_ref(&full));
assert_eq!(part.as_slice(), std::slice::from_ref(&full));
let part = extract_term_partition_from_penalty_ranges(10, &[0..3, 6..10]);
assert_eq!(part, vec![0..3, 3..6, 6..10]);
let part = extract_term_partition_from_penalty_ranges(6, &[0..3, 0..3, 3..6]);
assert_eq!(part, vec![0..3, 3..6]);
let part = extract_term_partition_from_penalty_ranges(0, &[]);
assert!(part.is_empty());
}
#[test]
fn assemble_block_triangular_t_identity_when_v_eye_and_r_none() {
let v_a = Array2::<f64>::eye(2);
let v_b = Array2::<f64>::eye(2);
let t = assemble_block_triangular_t(&[v_a, v_b], &[None, None]);
assert_eq!(t.dim(), (4, 4));
let eye4 = Array2::<f64>::eye(4);
for i in 0..4 {
for j in 0..4 {
assert!((t[[i, j]] - eye4[[i, j]]).abs() < 1e-14);
}
}
}
#[test]
fn assemble_block_triangular_t_with_drops_and_nonzero_r() {
let mut v_a = Array2::<f64>::zeros((3, 2));
v_a[[0, 0]] = 1.0;
v_a[[1, 0]] = 0.5;
v_a[[2, 1]] = 1.0;
let v_b = Array2::<f64>::eye(2);
let r_ab =
Array2::<f64>::from_shape_fn((3, 2), |(i, j)| 1.0 + (i as f64) + 0.25 * (j as f64));
let t =
assemble_block_triangular_t(&[v_a.clone(), v_b.clone()], &[None, Some(r_ab.clone())]);
assert_eq!(t.dim(), (5, 4));
for i in 0..3 {
for j in 0..2 {
assert!((t[[i, j]] - v_a[[i, j]]).abs() < 1e-14);
}
}
for i in 0..2 {
for j in 0..2 {
assert!((t[[3 + i, 2 + j]] - v_b[[i, j]]).abs() < 1e-14);
}
}
for i in 0..3 {
for j in 0..2 {
assert!((t[[i, 2 + j]] + r_ab[[i, j]]).abs() < 1e-14);
}
}
for i in 0..2 {
for j in 0..2 {
assert_eq!(t[[3 + i, j]], 0.0);
}
}
}
#[test]
fn validate_partition_rejects_bad_partitions() {
let bad_start = 1..5usize;
let short_cover = 0..3usize;
let full_cover = 0..5usize;
assert!(validate_partition(std::slice::from_ref(&bad_start), 5, "test").is_err());
assert!(validate_partition(std::slice::from_ref(&short_cover), 5, "test").is_err());
assert!(validate_partition(&[0..2, 3..5], 5, "test").is_err());
assert!(validate_partition(&[0..3, 2..5], 5, "test").is_err());
assert!(validate_partition(&[0..0, 0..5], 5, "test").is_err());
assert!(validate_partition(&[], 0, "test").is_ok());
assert!(validate_partition(&[0..2, 2..5], 5, "test").is_ok());
assert!(validate_partition(std::slice::from_ref(&full_cover), 5, "test").is_ok());
}
#[test]
fn compiled_map_penalty_pullback_is_per_block_width_with_nonzero_residual() {
use gam_identifiability::families::compiler::CompiledMap;
use gam_terms::smooth::BlockwisePenalty;
let n = 10;
let v_time =
Array2::<f64>::from_shape_fn(
(3, 3),
|(i, j)| {
if i == j { 1.0 } else { 0.1 * ((i + j) as f64) }
},
);
let v_marg = Array2::<f64>::from_shape_fn((3, 2), |(i, j)| {
0.5 + 0.3 * (i as f64) - 0.2 * (j as f64)
});
let v_log = Array2::<f64>::from_shape_fn((2, 2), |(i, j)| if i == j { 1.2 } else { 0.4 });
let r_marg = Array2::<f64>::from_shape_fn((3, 2), |(i, j)| 0.7 - 0.1 * ((i + j) as f64));
let r_log =
Array2::<f64>::from_shape_fn((6, 2), |(i, j)| 0.3 + 0.05 * ((i * 2 + j) as f64));
let t = assemble_block_triangular_t(
&[v_time.clone(), v_marg.clone(), v_log.clone()],
&[None, Some(r_marg.clone()), Some(r_log.clone())],
);
assert_eq!(t.dim(), (8, 7), "joint raw 8 × joint compiled 7");
let map = CompiledMap {
raw_from_compiled: t.clone(),
compiled_block_ranges: vec![0..3, 3..5, 5..7],
raw_block_ranges: vec![0..3, 3..6, 6..8],
};
let raw_time_entry = DesignMatrix::Dense(DenseDesignMatrix::from(
Array2::<f64>::from_shape_fn((n, 3), |(i, j)| 1.0 + (i as f64) * 0.1 + (j as f64)),
));
let raw_time_exit = raw_time_entry.clone();
let raw_time_deriv = raw_time_entry.clone();
let raw_marg = DesignMatrix::Dense(DenseDesignMatrix::from(Array2::<f64>::from_shape_fn(
(n, 3),
|(i, j)| 0.2 * (i as f64) - 0.3 * (j as f64),
)));
let raw_log = DesignMatrix::Dense(DenseDesignMatrix::from(Array2::<f64>::from_shape_fn(
(n, 2),
|(i, j)| 0.5 + (i as f64) * (j as f64 + 1.0),
)));
let s_time =
Array2::<f64>::from_shape_fn(
(3, 3),
|(i, j)| if i == j { (i + 2) as f64 } else { 0.3 },
);
let s_marg =
Array2::<f64>::from_shape_fn(
(3, 3),
|(i, j)| if i == j { 1.5 + i as f64 } else { 0.2 },
);
let s_log = Array2::<f64>::from_shape_fn((2, 2), |(i, j)| if i == j { 2.0 } else { 0.5 });
let time_pens = vec![BlockwisePenalty::new(0..3, s_time.clone())];
let marg_pens = vec![BlockwisePenalty::new(0..3, s_marg.clone())];
let log_pens = vec![BlockwisePenalty::new(0..2, s_log.clone())];
let out = apply_compiled_map_to_designs(
&map,
raw_time_entry,
raw_time_exit,
raw_time_deriv,
raw_marg,
raw_log,
&time_pens,
&marg_pens,
&log_pens,
)
.expect("apply_compiled_map_to_designs must succeed");
assert_eq!(out.time_design_entry.ncols(), 3);
assert_eq!(out.marginal_design.ncols(), 2);
assert_eq!(out.logslope_design.ncols(), 2);
for s in &out.time_penalties {
assert_eq!(
s.as_dense_cow().dim(),
(3, 3),
"time penalty must be per-block 3×3, not joint-width"
);
}
for s in &out.marginal_penalties {
assert_eq!(
s.as_dense_cow().dim(),
(2, 2),
"marginal penalty must match reduced compiled width 2, not joint 7"
);
}
for s in &out.logslope_penalties {
assert_eq!(s.as_dense_cow().dim(), (2, 2));
}
let p_time_dense = out.time_penalties[0].as_dense_cow().into_owned();
let theta_time = Array1::<f64>::from_shape_fn(3, |k| 0.4 + 0.7 * (k as f64));
let gamma_time = v_time.dot(&theta_time);
let lhs = theta_time.dot(&p_time_dense.dot(&theta_time));
let rhs = gamma_time.dot(&s_time.dot(&gamma_time));
assert!(
(lhs - rhs).abs() < 1e-10,
"time-block per-block pullback must be exact: lhs={lhs}, rhs={rhs}"
);
let p_marg_dense = out.marginal_penalties[0].as_dense_cow().into_owned();
let want_marg = v_marg.t().dot(&s_marg.dot(&v_marg));
for i in 0..2 {
for j in 0..2 {
assert!(
(p_marg_dense[[i, j]] - want_marg[[i, j]]).abs() < 1e-12,
"marginal penalty must be V_margᵀ S_marg V_marg at ({i},{j})"
);
}
}
}
#[test]
fn compile_survival_parametric_designs_helper_attributes_drop_to_marginal() {
    // The time (q0/q1) and marginal designs both contain an all-ones column, so
    // exactly one joint direction is aliased. Under an identity per-row Hessian
    // the parametric compile helper must charge that single drop to the
    // marginal block and leave time and logslope at full width.
    let n: usize = 24;
    let (p_time, p_marginal, p_logslope): (usize, usize, usize) = (3, 3, 2);
    let grid: Vec<f64> = (0..n)
        .map(|i| -1.0 + 2.0 * (i as f64) / (n as f64 - 1.0))
        .collect();

    // Quadratic time basis on entry and exit; its x-derivative on qd1.
    let time_dq0 = Array2::<f64>::from_shape_fn((n, p_time), |(i, j)| match j {
        0 => 1.0,
        1 => grid[i],
        _ => grid[i] * grid[i],
    });
    let time_dq1 = time_dq0.clone();
    let time_dqd1 = Array2::<f64>::from_shape_fn((n, p_time), |(i, j)| match j {
        0 => 0.0,
        1 => 1.0,
        _ => 2.0 * grid[i],
    });

    // Marginal basis: intercept (the aliased direction), cubic, sine.
    let marg_dq = Array2::<f64>::from_shape_fn((n, p_marginal), |(i, j)| match j {
        0 => 1.0,
        1 => grid[i] * grid[i] * grid[i],
        _ => grid[i].sin(),
    });
    let marg_dqd1 = Array2::<f64>::zeros((n, p_marginal));

    // Logslope basis shares no column with the other two blocks.
    let log_dg = Array2::<f64>::from_shape_fn((n, p_logslope), |(i, j)| {
        if j == 0 {
            (2.0 * grid[i]).cos()
        } else {
            grid[i].tanh()
        }
    });

    // Identity K_SURVIVAL × K_SURVIVAL curvature on every row.
    let identity_rows = Array3::<f64>::from_shape_fn((n, K_SURVIVAL, K_SURVIVAL), |(_, a, b)| {
        if a == b {
            1.0
        } else {
            0.0
        }
    });
    let row_hess = SurvivalRowHessian::from_full(identity_rows);

    let out = compile_survival_parametric_designs(
        time_dq0, time_dq1, time_dqd1, marg_dq, marg_dqd1, log_dg, &row_hess,
    )
    .expect("Phase-4b parametric compile must succeed on single-direction alias");

    assert_eq!(out.v_time.ncols(), p_time, "time keeps all columns");
    assert_eq!(
        out.v_marginal.ncols(),
        p_marginal - 1,
        "marginal loses exactly the shared-constant direction"
    );
    assert_eq!(out.v_logslope.ncols(), p_logslope, "logslope is clean");
    assert_eq!(
        out.drops_by_block,
        (0, 1, 0),
        "attribution: zero from time/logslope, one from marginal",
    );
}
#[test]
fn compile_survival_three_block_with_shared_constant_drops_one_direction() {
    use gam_identifiability::families::compiler::compile;
    // Runs the generic compiler directly on the operators produced by
    // `build_survival_compiler_inputs`. Time and marginal both contain an
    // all-ones column, so exactly one joint direction aliases. With time first
    // in the ordering, the marginal block must absorb that drop. The compiled
    // `t_lw` blocks are then handed to `Gauge::from_v_and_r` to check that the
    // lift's block boundaries agree with the compiled widths.
    let n: usize = 32;
    let (p_time, p_marginal, p_logslope): (usize, usize, usize) = (3, 3, 2);
    let grid: Vec<f64> = (0..n)
        .map(|i| -1.0 + 2.0 * (i as f64) / (n as f64 - 1.0))
        .collect();

    let time_dq0 = Array2::<f64>::from_shape_fn((n, p_time), |(i, j)| match j {
        0 => 1.0,
        1 => grid[i],
        _ => grid[i] * grid[i],
    });
    let time_dq1 = time_dq0.clone();
    let time_dqd1 = Array2::<f64>::from_shape_fn((n, p_time), |(i, j)| match j {
        0 => 0.0,
        1 => 1.0,
        _ => 2.0 * grid[i],
    });
    let marg_dq = Array2::<f64>::from_shape_fn((n, p_marginal), |(i, j)| match j {
        0 => 1.0,
        1 => grid[i] * grid[i] * grid[i],
        _ => grid[i].sin(),
    });
    let marg_dqd1 = Array2::<f64>::zeros((n, p_marginal));
    let log_dg = Array2::<f64>::from_shape_fn((n, p_logslope), |(i, j)| {
        if j == 0 {
            (2.0 * grid[i]).cos()
        } else {
            grid[i].tanh()
        }
    });

    let inputs = build_survival_compiler_inputs(
        time_dq0, time_dq1, time_dqd1, marg_dq, marg_dqd1, log_dg, None, None,
    );
    let row_hess = SurvivalRowHessian::from_full(Array3::<f64>::from_shape_fn(
        (n, K_SURVIVAL, K_SURVIVAL),
        |(_, a, b)| if a == b { 1.0 } else { 0.0 },
    ));
    let compiled = compile(&inputs.operators, &row_hess, &inputs.ordering)
        .expect("survival 3-block compile must succeed; aliasing is single-direction");
    assert_eq!(compiled.blocks.len(), 3, "expected 3 CompiledBlocks");

    // Per-block kept widths: time full, marginal minus the intercept, logslope full.
    let time_t = &compiled.blocks[0].t_lw;
    assert_eq!(
        time_t.ncols(),
        p_time,
        "time block (first in ordering) must retain all {p_time} of its columns; V_time={:?}",
        time_t.dim(),
    );
    let marg_t = &compiled.blocks[1].t_lw;
    assert_eq!(
        marg_t.ncols(),
        p_marginal - 1,
        "marginal block must lose exactly the shared-constant direction; \
         V_marginal cols = {}, expected {}",
        marg_t.ncols(),
        p_marginal - 1,
    );
    let log_t = &compiled.blocks[2].t_lw;
    assert_eq!(
        log_t.ncols(),
        p_logslope,
        "logslope block (no shared direction) must retain all {p_logslope} columns",
    );

    // Joint accounting: exactly one direction removed overall.
    let raw_total = p_time + p_marginal + p_logslope;
    let kept_total = compiled
        .blocks
        .iter()
        .fold(0usize, |acc, b| acc + b.t_lw.ncols());
    assert_eq!(
        kept_total,
        raw_total - 1,
        "joint kept = raw_total − aliased; got {kept_total}, expected {}",
        raw_total - 1,
    );
    assert_eq!(
        compiled.joint_rank, kept_total,
        "CompiledBlocks::joint_rank must match the sum of per-block t_lw widths",
    );

    // Lift built from the compiled V's (no residualisation) must agree with the
    // compiled block widths in both coordinate systems.
    let v_per_term: Vec<Array2<f64>> = compiled.blocks.iter().map(|b| b.t_lw.clone()).collect();
    let r_per_term: Vec<Option<Array2<f64>>> = v_per_term.iter().map(|_| None).collect();
    let gauge = Gauge::from_v_and_r(&v_per_term, &r_per_term);

    let expected_reduced: Vec<usize> = std::iter::once(0)
        .chain(compiled.blocks.iter().scan(0usize, |acc, b| {
            *acc += b.t_lw.ncols();
            Some(*acc)
        }))
        .collect();
    let expected_raw: Vec<usize> = std::iter::once(0)
        .chain(compiled.blocks.iter().scan(0usize, |acc, b| {
            *acc += b.t_lw.nrows();
            Some(*acc)
        }))
        .collect();
    assert_eq!(
        *gauge.block_starts_reduced.last().unwrap(),
        compiled.joint_rank,
        "SMGS lift reduced dimension must equal the compiled joint_rank",
    );
    assert_eq!(
        gauge.block_starts_reduced, expected_reduced,
        "SMGS lift reduced block boundaries must match the compiled kept widths",
    );
    assert_eq!(
        gauge.block_starts_raw, expected_raw,
        "SMGS lift raw block boundaries must match the compiled per-block raw widths",
    );

    // Every kept direction must be finite and non-degenerate.
    for (bi, block) in compiled.blocks.iter().enumerate() {
        let t_lw = &block.t_lw;
        for j in 0..t_lw.ncols() {
            let col = t_lw.column(j);
            assert!(
                !col.iter().any(|v| !v.is_finite()),
                "block {bi} kept direction {j} has a non-finite entry",
            );
            let norm = col.dot(&col).sqrt();
            assert!(
                norm > 1e-10,
                "block {bi} kept direction {j} is degenerate (norm {norm:.3e})",
            );
        }
    }
}
#[test]
fn smgs_lift_via_t_identity_passes_through() {
    // With identity per-term V's and no residualisation (every R is None), the
    // assembled T must be the 5×5 identity. Raw and reduced block boundaries
    // must coincide, and lifting per-block coefficients must return them unchanged.
    let v0 = Array2::<f64>::eye(3);
    let v1 = Array2::<f64>::eye(2);
    let v_per_term = vec![v0, v1];
    let r_per_term: Vec<Option<Array2<f64>>> = vec![None, None];
    let lift = Gauge::from_v_and_r(&v_per_term, &r_per_term);
    assert_eq!(lift.t_full.dim(), (5, 5));
    assert_eq!(lift.block_starts_reduced, vec![0, 3, 5]);
    assert_eq!(lift.block_starts_raw, vec![0, 3, 5]);
    for i in 0..5 {
        for j in 0..5 {
            let want = if i == j { 1.0 } else { 0.0 };
            assert!((lift.t_full[[i, j]] - want).abs() < 1e-14);
        }
    }
    let theta_0 = Array1::from(vec![1.0_f64, -2.0, 3.5]);
    let theta_1 = Array1::from(vec![-0.5_f64, 7.0]);
    let lifted = lift.lift_block_betas(&[theta_0.clone(), theta_1.clone()]);
    assert_eq!(lifted.len(), 2);
    // `zip` stops at the shorter side, so pin the lengths first. Without this,
    // a lift that dropped trailing coefficients would still pass the element
    // checks below.
    assert_eq!(lifted[0].len(), theta_0.len());
    for (a, b) in theta_0.iter().zip(lifted[0].iter()) {
        assert!((a - b).abs() < 1e-14);
    }
    assert_eq!(lifted[1].len(), theta_1.len());
    for (a, b) in theta_1.iter().zip(lifted[1].iter()) {
        assert!((a - b).abs() < 1e-14);
    }
}
#[test]
fn smgs_lift_via_t_two_block_with_residualisation() {
    // Block a is an identity 3-column term. Block b keeps raw coordinates 0
    // and 2 of its own 3-wide block and is residualised through R_b. Lifting
    // must give θ_a − R_b θ_b for block a, and V_b θ_b for block b, which has
    // a zero in the middle slot.
    let selector_b = Array2::<f64>::from_shape_vec((3, 2), vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0])
        .expect("3×2 selector literal has 6 entries");
    let resid_b = Array2::<f64>::from_shape_vec((3, 2), vec![0.4, -0.1, 0.7, 1.3, -0.2, 0.5])
        .expect("3×2 residualisation literal has 6 entries");
    let lift = Gauge::from_v_and_r(
        &[Array2::<f64>::eye(3), selector_b],
        &[None, Some(resid_b.clone())],
    );
    assert_eq!(lift.t_full.dim(), (6, 5));
    assert_eq!(lift.block_starts_reduced, vec![0, 3, 5]);
    assert_eq!(lift.block_starts_raw, vec![0, 3, 6]);

    let theta_a = Array1::from(vec![1.0_f64, 2.0, -1.5]);
    let theta_b = Array1::from(vec![0.5_f64, -0.25]);
    let lifted = lift.lift_block_betas(&[theta_a.clone(), theta_b.clone()]);

    // Block a: residualised against block b's coefficients.
    let correction = resid_b.dot(&theta_b);
    let want_a = &theta_a - &correction;
    assert_eq!(lifted[0].len(), 3);
    for (got, want) in lifted[0].iter().zip(want_a.iter()) {
        assert!((got - want).abs() < 1e-12, "got {got}, want {want}");
    }

    // Block b: plain selector lift, middle raw coordinate untouched.
    let block_b = &lifted[1];
    assert_eq!(block_b.len(), 3);
    assert!((block_b[0] - theta_b[0]).abs() < 1e-12);
    assert!(block_b[1].abs() < 1e-12);
    assert!((block_b[2] - theta_b[1]).abs() < 1e-12);
}
#[test]
fn smgs_lift_covariance_identity_and_rank1_consistency() {
    // Part 1: an identity T must leave a covariance matrix unchanged.
    let identity_lift = Gauge::from_v_and_r(
        &[Array2::<f64>::eye(2), Array2::<f64>::eye(2)],
        &[None, None],
    );
    let cov = Array2::<f64>::from_shape_fn((4, 4), |(i, j)| {
        1.0 / (1.0 + (i as f64 - j as f64).abs())
    });
    let lifted_id = identity_lift.lift_covariance(&cov);
    assert_eq!(lifted_id.dim(), (4, 4));
    for ((i, j), &c) in cov.indexed_iter() {
        assert!(
            (lifted_id[[i, j]] - c).abs() < 1e-12,
            "identity-T covariance lift must be a no-op at [{i},{j}]",
        );
    }

    // Part 2: for a rank-1 covariance θθᵀ, `lift_covariance` must agree with
    // (Tθ)(Tθ)ᵀ, where Tθ is obtained independently via `lift_block_betas`.
    let selector_b = Array2::<f64>::from_shape_vec((3, 2), vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0])
        .expect("3×2 selector literal has 6 entries");
    let resid_b = Array2::<f64>::from_shape_vec((3, 2), vec![0.4, -0.1, 0.7, 1.3, -0.2, 0.5])
        .expect("3×2 residualisation literal has 6 entries");
    let lift = Gauge::from_v_and_r(&[Array2::<f64>::eye(3), selector_b], &[None, Some(resid_b)]);

    let theta_a = Array1::from(vec![1.0_f64, 2.0, -1.5]);
    let theta_b = Array1::from(vec![0.5_f64, -0.25]);
    let theta_full: Vec<f64> = theta_a.iter().chain(theta_b.iter()).copied().collect();
    let cov_rank1 = Array2::<f64>::from_shape_fn((5, 5), |(i, j)| theta_full[i] * theta_full[j]);

    let lifted_cov = lift.lift_covariance(&cov_rank1);
    let beta_raw: Vec<f64> = lift
        .lift_block_betas(&[theta_a, theta_b])
        .iter()
        .flat_map(|b| b.iter().copied())
        .collect();
    assert_eq!(lifted_cov.dim(), (6, 6));
    assert_eq!(beta_raw.len(), 6);
    for ((i, j), &got) in lifted_cov.indexed_iter() {
        let want = beta_raw[i] * beta_raw[j];
        assert!(
            (got - want).abs() < 1e-10,
            "rank-1 covariance pushforward must equal (Tθ)(Tθ)ᵀ at [{i},{j}]: got {}, want {want}",
            got,
        );
    }

    // Symmetry. The diagonal is trivially symmetric, and any non-finite
    // diagonal entry has already failed above, so the strict upper triangle
    // is enough.
    for i in 0..6 {
        for j in (i + 1)..6 {
            assert!((lifted_cov[[i, j]] - lifted_cov[[j, i]]).abs() < 1e-14);
        }
    }
}
#[test]
fn smgs_lift_via_t_zero_r_matches_per_block_v_lift() {
    // When every R is None, lifting through T must reduce to applying each
    // term's own V to its own coefficient vector. Both V's are non-square so
    // the check exercises the raw/reduced width difference.
    let v_a = Array2::<f64>::from_shape_vec((3, 2), vec![0.6, 0.0, -0.8, 0.3, 0.0, 0.9])
        .expect("3×2 literal has 6 entries");
    let v_b = Array2::<f64>::from_shape_vec(
        (4, 3),
        vec![1.0, 0.0, 0.0, 0.0, -0.4, 0.0, 0.2, 0.0, 0.7, 0.0, 0.0, -1.1],
    )
    .expect("4×3 literal has 12 entries");
    let theta_a = Array1::from(vec![0.3_f64, -1.4]);
    let theta_b = Array1::from(vec![2.1_f64, 0.0, -0.7]);

    // Reference lift computed directly per block, before the V's are moved.
    let reference = [v_a.dot(&theta_a), v_b.dot(&theta_b)];

    let lift = Gauge::from_v_and_r(&[v_a, v_b], &[None, None]);
    let via_t = lift.lift_block_betas(&[theta_a, theta_b]);

    for (block, want_block) in reference.iter().enumerate() {
        let got_block = &via_t[block];
        assert_eq!(got_block.len(), want_block.len());
        for (g, w) in got_block.iter().zip(want_block.iter()) {
            assert!((g - w).abs() < 1e-12);
        }
    }
}
#[test]
fn recompile_after_accept_diff_detection_pilot_curvature_trap() {
    // Pins that survival compile attribution depends on the per-row Hessian,
    // not only on the designs. Identical designs are compiled twice:
    //   * time: one all-ones column on q0 (q1/qd1 zero);
    //   * marginal: one all-ones column (qd1 zero);
    //   * logslope: no columns.
    // Under an identity per-row Hessian the marginal block keeps its column
    // (0 marginal drops). Under a Hessian curved only in channel 0 it loses it
    // (1 marginal drop).
    // NOTE(review): per the test name this presumably guards recompile-after-
    // accept logic that must not rely on design diffs alone — confirm against
    // the caller.
    let n = 6usize;
    let time_dq0 = Array2::<f64>::from_elem((n, 1), 1.0);
    let time_dq1 = Array2::<f64>::zeros((n, 1));
    let time_dqd1 = Array2::<f64>::zeros((n, 1));
    let marg_dq = Array2::<f64>::from_elem((n, 1), 1.0);
    let marg_dqd1 = Array2::<f64>::zeros((n, 1));
    let log_dg = Array2::<f64>::zeros((n, 0));

    // One single-column term each for time and marginal; logslope has no terms.
    let time_partition: Vec<std::ops::Range<usize>> = vec![0..1];
    let marg_partition: Vec<std::ops::Range<usize>> = vec![0..1];
    let log_partition: Vec<std::ops::Range<usize>> = Vec::new();

    // Structural curvature: identity on every row.
    let mut h_ident = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
    for i in 0..n {
        for k in 0..K_SURVIVAL {
            h_ident[[i, k, k]] = 1.0;
        }
    }
    let row_hess_ident = SurvivalRowHessian::from_full(h_ident);
    let compiled_ident = compile_survival_parametric_designs_per_term(
        time_dq0.clone(),
        time_dq1.clone(),
        time_dqd1.clone(),
        &time_partition,
        marg_dq.clone(),
        marg_dqd1.clone(),
        &marg_partition,
        log_dg.clone(),
        &log_partition,
        &row_hess_ident,
    )
    .expect("identity-H compile must succeed");

    // Data-adaptive curvature: only channel 0 carries weight.
    let mut h_q0_only = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
    for i in 0..n {
        h_q0_only[[i, 0, 0]] = 1.0;
    }
    let row_hess_q0 = SurvivalRowHessian::from_full(h_q0_only);
    let compiled_q0 = compile_survival_parametric_designs_per_term(
        time_dq0,
        time_dq1,
        time_dqd1,
        &time_partition,
        marg_dq,
        marg_dqd1,
        &marg_partition,
        log_dg,
        &log_partition,
        &row_hess_q0,
    )
    .expect("q0-only-H compile must succeed");

    // Headline check first, so a failure reports both attributions at once.
    assert_ne!(
        compiled_ident.drops_by_block, compiled_q0.drops_by_block,
        "structural-H and data-adaptive-H compiles must produce different \
         drops_by_block on the constructed pilot-curvature-trap design; \
         identity={:?} q0-only={:?}",
        compiled_ident.drops_by_block, compiled_q0.drops_by_block,
    );
    assert_eq!(
        compiled_ident.drops_by_block.1, 0,
        "identity-H marg drops expected 0, got {:?}",
        compiled_ident.drops_by_block,
    );
    assert_eq!(
        compiled_q0.drops_by_block.1, 1,
        "q0-only-H marg drops expected 1, got {:?}",
        compiled_q0.drops_by_block,
    );
}
#[test]
fn compiled_map_from_per_term_partitions_and_lift_round_trip() {
    use gam_identifiability::families::compiler::BlockOrder;
    // One term per block:
    //   * time keeps both raw columns (identity V);
    //   * marginal collapses 2 raw columns to 1;
    //   * logslope keeps its single column.
    // `compiled_map_from_per_term` must put each V on its own
    // (raw range, compiled range) diagonal block of `raw_from_compiled`, and
    // `Gauge::from_compiled_map` must rebuild the same T as
    // `Gauge::from_v_and_r` when given the identical V/R lists.
    let v_time = Array2::<f64>::eye(2);
    let v_marg = Array2::<f64>::from_shape_vec((2, 1), vec![1.0, 0.5])
        .expect("2×1 literal has 2 entries");
    let v_log = Array2::<f64>::eye(1);
    let r_marg = Array2::<f64>::from_shape_fn((2, 1), |(i, _)| 0.25 + i as f64);
    let r_log = Array2::<f64>::from_shape_fn((4, 1), |(i, _)| 0.1 * (i as f64 + 1.0));

    let per_term = SurvivalParametricCompiledPerTerm {
        v_time_per_term: vec![v_time.clone()],
        v_marginal_per_term: vec![v_marg.clone()],
        v_logslope_per_term: vec![v_log.clone()],
        r_lw_per_term: vec![None, Some(r_marg.clone()), Some(r_log.clone())],
        drops_by_block: (0, 1, 0),
    };
    let map = compiled_map_from_per_term(&per_term);
    assert_eq!(map.raw_block_ranges, vec![0..2, 2..4, 4..5]);
    assert_eq!(map.compiled_block_ranges, vec![0..2, 2..3, 3..4]);
    assert_eq!(map.raw_from_compiled.dim(), (5, 4));

    // Each diagonal block of the map must reproduce its term's V exactly.
    let diagonal_blocks = [
        (map.raw_from_compiled.slice(ndarray::s![0..2, 0..2]), &v_time),
        (map.raw_from_compiled.slice(ndarray::s![2..4, 2..3]), &v_marg),
        (map.raw_from_compiled.slice(ndarray::s![4..5, 3..4]), &v_log),
    ];
    for (got_block, want_block) in diagonal_blocks.iter() {
        for (got, want) in got_block.iter().zip(want_block.iter()) {
            assert!((got - want).abs() < 1e-14);
        }
    }

    // The gauge rebuilt from the map must match the directly-assembled one.
    let ordering = [BlockOrder::Time, BlockOrder::Marginal, BlockOrder::Logslope];
    let lift_from_map = Gauge::from_compiled_map(&map, &ordering);
    let lift_direct = Gauge::from_v_and_r(
        &[v_time, v_marg, v_log],
        &[None, Some(r_marg), Some(r_log)],
    );
    assert_eq!(lift_from_map.t_full.dim(), lift_direct.t_full.dim());
    for ((i, j), &from_map) in lift_from_map.t_full.indexed_iter() {
        let direct = lift_direct.t_full[[i, j]];
        assert!(
            (from_map - direct).abs() < 1e-14,
            "T mismatch at ({i},{j}): map={} direct={}",
            from_map,
            direct,
        );
    }
}
}