use std::sync::Arc;
use ndarray::{Array1, Array2, Array3};
use crate::families::custom_family::{FamilyChannelHessian, PenaltyMatrix};
use crate::families::identifiability_compiler::{
AnchorRowEvaluator, BlockOrder, RowHessian, RowJacobianOperator,
};
use crate::linalg::faer_ndarray::{FaerEigh, fast_ab};
use crate::linalg::matrix::{CoefficientTransformOperator, DenseDesignMatrix, DesignMatrix};
use faer::Side;
/// Number of per-row primary channels in the survival marginal-slope likelihood.
///
/// Every per-row Hessian and Jacobian operator in this file is `K_SURVIVAL` wide.
/// The operators below write channel 0 from `q0`, channel 1 from `q1`, channel 2 from
/// `qd1` and channel 3 from the log-slope design `g` (presumably matching the argument
/// order of `row_primary_for_compiler` — confirm there).
const K_SURVIVAL: usize = 4;
/// Dense per-row Hessian tensor for the survival marginal-slope family.
///
/// `h[[i, a, b]]` is entry `(a, b)` of row `i`'s `K_SURVIVAL × K_SURVIVAL` Hessian.
pub struct SurvivalRowHessian {
h: Array3<f64>,
}
impl SurvivalRowHessian {
pub fn from_pilot_primary_state(
q0: &Array1<f64>,
q1: &Array1<f64>,
qd1: &Array1<f64>,
g: &Array1<f64>,
z: &Array1<f64>,
weights: &Array1<f64>,
event: &Array1<f64>,
derivative_guard: f64,
probit_scale: f64,
) -> Result<Self, String> {
let n = q0.len();
if [
q1.len(),
qd1.len(),
g.len(),
z.len(),
weights.len(),
event.len(),
]
.iter()
.any(|&l| l != n)
{
return Err(format!(
"SurvivalRowHessian: length mismatch \
q0={n}, q1={}, qd1={}, g={}, z={}, weights={}, event={}",
q1.len(),
qd1.len(),
g.len(),
z.len(),
weights.len(),
event.len()
));
}
let mut h_full = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
for i in 0..n {
let (_, _grad, hess) =
crate::families::survival_marginal_slope::row_primary_for_compiler(
q0[i],
q1[i],
qd1[i],
g[i],
z[i],
weights[i],
event[i],
derivative_guard,
probit_scale,
)?;
let mut h_i = Array2::<f64>::zeros((K_SURVIVAL, K_SURVIVAL));
for a in 0..K_SURVIVAL {
for b in 0..K_SURVIVAL {
h_i[[a, b]] = hess[a][b];
}
}
let clamped = psd_clamp_4x4(&h_i);
for a in 0..K_SURVIVAL {
for b in 0..K_SURVIVAL {
h_full[[i, a, b]] = clamped[[a, b]];
}
}
}
Ok(Self { h: h_full })
}
pub fn from_full(h: Array3<f64>) -> Self {
assert_eq!(h.shape()[1], K_SURVIVAL);
assert_eq!(h.shape()[2], K_SURVIVAL);
Self { h }
}
}
impl RowHessian for SurvivalRowHessian {
fn k(&self) -> usize {
K_SURVIVAL
}
fn nrows(&self) -> usize {
self.h.shape()[0]
}
fn fill_row(&self, row: usize, out: &mut [f64]) {
assert_eq!(out.len(), K_SURVIVAL * K_SURVIVAL);
for a in 0..K_SURVIVAL {
for b in 0..K_SURVIVAL {
out[a * K_SURVIVAL + b] = self.h[[row, a, b]];
}
}
}
fn evaluate_full(&self) -> Array3<f64> {
self.h.clone()
}
}
impl FamilyChannelHessian for SurvivalRowHessian {
/// Channel count per subject (always `K_SURVIVAL`).
fn n_outputs(&self) -> usize {
K_SURVIVAL
}
/// Number of subjects (rows) in the stored tensor.
fn n_subjects(&self) -> usize {
self.h.shape()[0]
}
/// Copies subject `i`'s `K_SURVIVAL × K_SURVIVAL` Hessian into `out`, row-major.
fn fill_subject(&self, i: usize, out: &mut [f64]) {
assert_eq!(out.len(), K_SURVIVAL * K_SURVIVAL);
for a in 0..K_SURVIVAL {
for b in 0..K_SURVIVAL {
out[a * K_SURVIVAL + b] = self.h[[i, a, b]];
}
}
}
/// Returns an owned copy of the stored tensor.
fn evaluate_full(&self) -> ndarray::Array3<f64> {
self.h.clone()
}
/// Returns the channel Hessian evaluated at a new state.
///
/// - If `family_scalars` downcasts to `SurvivalMarginalSlopeFamilyScalars`, every row is
///   recomputed from the scalars' `q0_i`/`q1_i`/`qd1_i`/`g_i`/`z_i` and PSD-clamped.
///   Rows for which `row_primary_for_compiler` errors silently keep the pilot Hessian
///   stored in `self.h`.
/// - If no (or a non-matching) scalars object is supplied and any `|beta_j| > 1e-12`,
///   returns `Err`.
/// - Otherwise (no scalars, `beta` ~ 0) returns the pilot tensor wrapped in a
///   `TensorChannelHessian`.
///
/// `beta` itself is only inspected for non-triviality; the recomputation reads the
/// scalars, not `beta`.
fn channel_hessian_at(
&self,
beta: &[f64],
family_scalars: Option<&Arc<dyn std::any::Any + Send + Sync>>,
) -> Result<Arc<dyn FamilyChannelHessian>, String> {
use crate::families::survival_marginal_slope::SurvivalMarginalSlopeFamilyScalars;
// A scalars object of the wrong concrete type is treated the same as `None`.
let scalars_opt =
family_scalars.and_then(|a| a.downcast_ref::<SurvivalMarginalSlopeFamilyScalars>());
let beta_nontrivial = beta.iter().any(|&b| b.abs() > 1e-12);
match scalars_opt {
None if beta_nontrivial => {
Err(
"SurvivalRowHessian::channel_hessian_at: beta is non-trivial but \
family_scalars is None; supply SurvivalMarginalSlopeFamilyScalars \
via FamilyLinearizationState::family_scalars to evaluate W(β) \
correctly (same contract as T26 Jacobian callbacks)."
.to_string(),
)
}
None => {
Ok(Arc::new(
crate::families::custom_family::TensorChannelHessian { h: self.h.clone() },
))
}
Some(sc) => {
let n = self.h.shape()[0];
// Every per-subject scalar vector must cover exactly the pilot's rows.
if sc.q0_i.len() != n
|| sc.q1_i.len() != n
|| sc.qd1_i.len() != n
|| sc.g_i.len() != n
|| sc.z_i.len() != n
{
return Err(format!(
"SurvivalRowHessian::channel_hessian_at: scalars length mismatch \
(expected n={n}, got q0={} q1={} qd1={} g={} z={})",
sc.q0_i.len(),
sc.q1_i.len(),
sc.qd1_i.len(),
sc.g_i.len(),
sc.z_i.len(),
));
}
let mut h_full = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
for i in 0..n {
let q0 = sc.q0_i[i];
let q1 = sc.q1_i[i];
let qd1 = sc.qd1_i[i];
let g = sc.g_i[i];
let z = sc.z_i[i];
// NOTE(review): weight = 1.0, event = 1.0 and derivative_guard = 1e-6 are
// hard-coded here, whereas `from_pilot_primary_state` uses the caller's
// per-row weights/events and guard. Confirm this is intended (e.g. that the
// scalars already fold weights/events in) rather than a dropped input.
match crate::families::survival_marginal_slope::row_primary_for_compiler(
q0, q1, qd1, g, z, 1.0, 1.0, 1e-6, sc.s, ) {
Ok((_nll, _grad, hess)) => {
let mut h_i = ndarray::Array2::<f64>::zeros((K_SURVIVAL, K_SURVIVAL));
for a in 0..K_SURVIVAL {
for b in 0..K_SURVIVAL {
h_i[[a, b]] = hess[a][b];
}
}
let clamped = psd_clamp_4x4(&h_i);
for a in 0..K_SURVIVAL {
for b in 0..K_SURVIVAL {
h_full[[i, a, b]] = clamped[[a, b]];
}
}
}
// Best-effort: a row that fails to evaluate falls back to its pilot Hessian;
// the error is discarded.
Err(_) => {
for a in 0..K_SURVIVAL {
for b in 0..K_SURVIVAL {
h_full[[i, a, b]] = self.h[[i, a, b]];
}
}
}
}
}
Ok(Arc::new(SurvivalRowHessian::from_full(h_full)))
}
}
}
}
/// Projects a symmetric matrix onto the positive-semidefinite cone.
///
/// Takes the eigendecomposition `m = Q Λ Qᵀ` (reading the lower triangle) and rebuilds
/// `Q · max(Λ, 0) · Qᵀ`. If the eigensolver fails, falls back to a diagonal matrix
/// holding `max(m[i, i], 0)` and zeros elsewhere.
fn psd_clamp_4x4(m: &Array2<f64>) -> Array2<f64> {
    let dim = m.nrows();
    let Ok((eigenvalues, eigenvectors)) = m.eigh(Side::Lower) else {
        return Array2::from_shape_fn((dim, dim), |(r, c)| {
            if r == c {
                m[[r, c]].max(0.0)
            } else {
                0.0
            }
        });
    };
    // Negative eigenvalues are zeroed; the reconstruction sums over eigenpairs in index
    // order, starting from +0.0.
    let clamped: Vec<f64> = eigenvalues.iter().map(|&lambda| lambda.max(0.0)).collect();
    Array2::from_shape_fn((dim, dim), |(r, c)| {
        (0..dim).fold(0.0, |acc, l| {
            acc + eigenvectors[[r, l]] * clamped[l] * eigenvectors[[c, l]]
        })
    })
}
/// Evaluates one row's negative log-likelihood, 4-channel gradient and 4×4 Hessian.
///
/// Public forwarder to `survival_marginal_slope::row_primary_for_compiler`; every
/// argument is passed through unchanged and in the same order, and its result
/// (`(nll, gradient, hessian)`) or error is returned as-is.
pub fn survival_row_nll_grad_hess(
    q0: f64,
    q1: f64,
    qd1: f64,
    g: f64,
    z: f64,
    w: f64,
    d: f64,
    derivative_guard: f64,
    probit_scale: f64,
) -> Result<(f64, [f64; 4], [[f64; 4]; 4]), String> {
    use crate::families::survival_marginal_slope::row_primary_for_compiler as row_primary;
    row_primary(q0, q1, qd1, g, z, w, d, derivative_guard, probit_scale)
}
/// Jacobian operator for one time-block term.
///
/// A coefficient step moves channels 0, 1 and 2 through `dq0`, `dq1` and `dqd1`
/// respectively; channel 3 is never touched.
pub struct TimeBlockOperator {
    dq0: Array2<f64>,
    dq1: Array2<f64>,
    dqd1: Array2<f64>,
}
impl TimeBlockOperator {
    /// Creates the operator from three `(n_rows, p)` derivative matrices.
    ///
    /// # Panics
    /// Panics unless all three matrices share the same shape.
    pub fn new(dq0: Array2<f64>, dq1: Array2<f64>, dqd1: Array2<f64>) -> Self {
        let shape = dq0.dim();
        assert_eq!(shape, dq1.dim());
        assert_eq!(shape, dqd1.dim());
        Self { dq0, dq1, dqd1 }
    }
}
impl RowJacobianOperator for TimeBlockOperator {
fn k(&self) -> usize {
K_SURVIVAL
}
fn ncols(&self) -> usize {
self.dq0.ncols()
}
fn nrows(&self) -> usize {
self.dq0.nrows()
}
fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
assert_eq!(out.len(), K_SURVIVAL);
assert_eq!(delta_beta.len(), self.dq0.ncols());
let mut acc = [0.0_f64; K_SURVIVAL];
for (j, &b) in delta_beta.iter().enumerate() {
acc[0] += self.dq0[[row, j]] * b;
acc[1] += self.dq1[[row, j]] * b;
acc[2] += self.dqd1[[row, j]] * b;
}
out.copy_from_slice(&acc);
}
fn evaluate_full(&self) -> Array3<f64> {
let n = self.dq0.nrows();
let p = self.dq0.ncols();
let mut out = Array3::<f64>::zeros((n, p, K_SURVIVAL));
for i in 0..n {
for j in 0..p {
out[[i, j, 0]] = self.dq0[[i, j]];
out[[i, j, 1]] = self.dq1[[i, j]];
out[[i, j, 2]] = self.dqd1[[i, j]];
}
}
out
}
}
/// Jacobian operator for one marginal (q-channel) term.
///
/// The same `dq` derivative drives channels 0 and 1, `dqd1` drives channel 2, and
/// channel 3 is untouched.
pub struct QChannelBlockOperator {
    dq: Array2<f64>,
    dqd1: Array2<f64>,
}
impl QChannelBlockOperator {
    /// Creates the operator from two `(n_rows, p)` derivative matrices.
    ///
    /// # Panics
    /// Panics unless `dq` and `dqd1` share the same shape.
    pub fn new(dq: Array2<f64>, dqd1: Array2<f64>) -> Self {
        let shape = dq.dim();
        assert_eq!(shape, dqd1.dim());
        Self { dq, dqd1 }
    }
}
impl RowJacobianOperator for QChannelBlockOperator {
fn k(&self) -> usize {
K_SURVIVAL
}
fn ncols(&self) -> usize {
self.dq.ncols()
}
fn nrows(&self) -> usize {
self.dq.nrows()
}
fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
assert_eq!(out.len(), K_SURVIVAL);
assert_eq!(delta_beta.len(), self.dq.ncols());
let mut dq_acc = 0.0;
let mut dqd_acc = 0.0;
for (j, &b) in delta_beta.iter().enumerate() {
dq_acc += self.dq[[row, j]] * b;
dqd_acc += self.dqd1[[row, j]] * b;
}
out[0] = dq_acc;
out[1] = dq_acc;
out[2] = dqd_acc;
out[3] = 0.0;
}
fn evaluate_full(&self) -> Array3<f64> {
let n = self.dq.nrows();
let p = self.dq.ncols();
let mut out = Array3::<f64>::zeros((n, p, K_SURVIVAL));
for i in 0..n {
for j in 0..p {
let v = self.dq[[i, j]];
out[[i, j, 0]] = v;
out[[i, j, 1]] = v;
out[[i, j, 2]] = self.dqd1[[i, j]];
}
}
out
}
}
/// Jacobian operator for one log-slope term; only channel 3 responds.
pub struct LogslopeBlockOperator {
    dg: Array2<f64>,
}
impl LogslopeBlockOperator {
    /// Creates the operator from an `(n_rows, p)` log-slope derivative matrix.
    pub fn new(dg: Array2<f64>) -> Self {
        LogslopeBlockOperator { dg }
    }
}
impl RowJacobianOperator for LogslopeBlockOperator {
fn k(&self) -> usize {
K_SURVIVAL
}
fn ncols(&self) -> usize {
self.dg.ncols()
}
fn nrows(&self) -> usize {
self.dg.nrows()
}
fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
assert_eq!(out.len(), K_SURVIVAL);
assert_eq!(delta_beta.len(), self.dg.ncols());
let mut acc = 0.0;
for (j, &b) in delta_beta.iter().enumerate() {
acc += self.dg[[row, j]] * b;
}
out[0] = 0.0;
out[1] = 0.0;
out[2] = 0.0;
out[3] = acc;
}
fn evaluate_full(&self) -> Array3<f64> {
let n = self.dg.nrows();
let p = self.dg.ncols();
let mut out = Array3::<f64>::zeros((n, p, K_SURVIVAL));
for i in 0..n {
for j in 0..p {
out[[i, j, 3]] = self.dg[[i, j]];
}
}
out
}
}
/// Anchor evaluator for a parametric block whose design is materialised up front.
pub struct ParametricAnchorEvaluator {
    design: Array2<f64>,
}
impl ParametricAnchorEvaluator {
    /// Stores the design matrix returned by every later anchor query.
    pub fn new(design: Array2<f64>) -> Self {
        ParametricAnchorEvaluator { design }
    }
}
impl AnchorRowEvaluator for ParametricAnchorEvaluator {
    /// Returns a copy of the stored design.
    ///
    /// `predict_arg` is only checked for length; its values are not used.
    fn anchor_rows(&self, predict_arg: &Array1<f64>) -> Result<Array2<f64>, String> {
        let expected_rows = self.design.nrows();
        let given_len = predict_arg.len();
        if given_len == expected_rows {
            Ok(self.design.clone())
        } else {
            Err(format!(
                "ParametricAnchorEvaluator: predict_arg length {} must match \
                 materialised design rows {}",
                given_len, expected_rows
            ))
        }
    }

    /// Column count of the stored design.
    fn ncols(&self) -> usize {
        self.design.ncols()
    }
}
/// Anchor evaluator for a basis-expanded term after identifiability compilation.
///
/// Rows are `raw_basis(arg)` right-multiplied by `t_lw` (via `fast_ab`). When both
/// `anchor_correction` and `parent` are present, `parent.anchor_rows(arg)` times the
/// correction is subtracted.
pub struct CompiledFlexAnchorEvaluator {
    raw_basis: Arc<dyn Fn(&Array1<f64>) -> Result<Array2<f64>, String> + Send + Sync>,
    t_lw: Array2<f64>,
    anchor_correction: Option<Array2<f64>>,
    parent: Option<Arc<dyn AnchorRowEvaluator>>,
}
impl CompiledFlexAnchorEvaluator {
    /// Bundles the raw-basis closure, its compiled transform and the optional
    /// parent-based correction.
    pub fn new(
        raw_basis: Arc<dyn Fn(&Array1<f64>) -> Result<Array2<f64>, String> + Send + Sync>,
        t_lw: Array2<f64>,
        anchor_correction: Option<Array2<f64>>,
        parent: Option<Arc<dyn AnchorRowEvaluator>>,
    ) -> Self {
        CompiledFlexAnchorEvaluator {
            raw_basis,
            t_lw,
            anchor_correction,
            parent,
        }
    }
}
impl AnchorRowEvaluator for CompiledFlexAnchorEvaluator {
    /// Evaluates compiled anchor rows for `predict_arg`.
    ///
    /// The raw basis is evaluated first and rotated by `t_lw`. The parent correction is
    /// applied only when BOTH `anchor_correction` and `parent` are set; if either is
    /// missing, the rotated rows are returned unchanged.
    ///
    /// # Errors
    /// Propagates errors from the raw-basis closure or the parent evaluator.
    fn anchor_rows(&self, predict_arg: &Array1<f64>) -> Result<Array2<f64>, String> {
        let rotated = fast_ab(&(self.raw_basis)(predict_arg)?, &self.t_lw);
        let (Some(correction_map), Some(parent)) =
            (self.anchor_correction.as_ref(), self.parent.as_ref())
        else {
            return Ok(rotated);
        };
        let parent_rows = parent.anchor_rows(predict_arg)?;
        Ok(&rotated - &fast_ab(&parent_rows, correction_map))
    }

    /// Compiled width (column count of `t_lw`).
    fn ncols(&self) -> usize {
        self.t_lw.ncols()
    }
}
/// Operator list and block ordering handed to the identifiability compiler.
///
/// NOTE(review): not constructed in the visible code; presumably `ordering[i]` labels
/// `operators[i]` (as `compile_survival_parametric_designs_per_term` pushes them in
/// lock-step) — confirm against the caller.
pub struct SurvivalCompilerInputs {
pub operators: Vec<Arc<dyn RowJacobianOperator>>,
pub ordering: Vec<BlockOrder>,
}
/// Block-level compiled transforms for the time, marginal and log-slope blocks.
///
/// Each `v_*` must have one row per raw column of its block (checked in
/// `wrap_design_with_transform`); penalties are pulled back as `Vᵀ S V`.
pub struct SurvivalParametricCompiled {
pub v_time: Array2<f64>,
pub v_marginal: Array2<f64>,
pub v_logslope: Array2<f64>,
// Presumably (time, marginal, logslope) counts of raw columns removed by the
// compiler, mirroring `SurvivalParametricCompiledPerTerm::drops_by_block` — confirm.
pub drops_by_block: (usize, usize, usize),
}
/// Survival designs and penalties expressed in compiled coordinates, as produced by
/// `apply_survival_parametric_compile_to_designs`.
pub struct CompiledSurvivalDesigns {
// The three time designs all share the time block's transform.
pub time_design_entry: DesignMatrix,
pub time_design_exit: DesignMatrix,
pub time_design_derivative_exit: DesignMatrix,
pub marginal_design: DesignMatrix,
pub logslope_design: DesignMatrix,
// Symmetrised `Vᵀ S V` pull-backs, one per input penalty, labels preserved.
pub time_penalties: Vec<PenaltyMatrix>,
pub marginal_penalties: Vec<PenaltyMatrix>,
pub logslope_penalties: Vec<PenaltyMatrix>,
}
/// Re-expresses the raw survival designs and penalties in compiled coordinates using
/// the block-level transforms in `compiled`.
///
/// Every design is wrapped with its block's `V` via `wrap_design_with_transform` (the
/// three time designs share `v_time`), and every penalty is pulled back with
/// `pull_back_penalties`.
///
/// # Errors
/// Propagates the first shape or densification error from `wrap_design_with_transform`.
pub fn apply_survival_parametric_compile_to_designs(
    compiled: &SurvivalParametricCompiled,
    time_design_entry: DesignMatrix,
    time_design_exit: DesignMatrix,
    time_design_derivative_exit: DesignMatrix,
    marginal_design: DesignMatrix,
    logslope_design: DesignMatrix,
    time_penalties: &[PenaltyMatrix],
    marginal_penalties: &[PenaltyMatrix],
    logslope_penalties: &[PenaltyMatrix],
) -> Result<CompiledSurvivalDesigns, String> {
    let time_design_entry = wrap_design_with_transform(
        time_design_entry,
        &compiled.v_time,
        "survival time block design_entry",
    )?;
    let time_design_exit = wrap_design_with_transform(
        time_design_exit,
        &compiled.v_time,
        "survival time block design_exit",
    )?;
    let time_design_derivative_exit = wrap_design_with_transform(
        time_design_derivative_exit,
        &compiled.v_time,
        "survival time block design_derivative_exit",
    )?;
    let marginal_design = wrap_design_with_transform(
        marginal_design,
        &compiled.v_marginal,
        "survival marginal block design",
    )?;
    let logslope_design = wrap_design_with_transform(
        logslope_design,
        &compiled.v_logslope,
        "survival logslope block design",
    )?;

    Ok(CompiledSurvivalDesigns {
        time_design_entry,
        time_design_exit,
        time_design_derivative_exit,
        marginal_design,
        logslope_design,
        time_penalties: pull_back_penalties(time_penalties, &compiled.v_time),
        marginal_penalties: pull_back_penalties(marginal_penalties, &compiled.v_marginal),
        logslope_penalties: pull_back_penalties(logslope_penalties, &compiled.v_logslope),
    })
}
/// Wraps `raw` in a `CoefficientTransformOperator` that applies `v` to its coefficients.
///
/// A sparse design is densified first (chunked). The result is always a dense
/// `DesignMatrix` backed by the operator.
///
/// # Errors
/// Returns `Err` (prefixed with `context`) if `raw` has a different column count than
/// `v` has rows, if densification fails, or if the operator constructor rejects its
/// inputs.
fn wrap_design_with_transform(
    raw: DesignMatrix,
    v: &Array2<f64>,
    context: &str,
) -> Result<DesignMatrix, String> {
    let (v_rows, v_cols) = v.dim();
    if raw.ncols() != v_rows {
        return Err(format!(
            "{context}: raw design has {} cols but V has {} rows (V is {}×{})",
            raw.ncols(),
            v_rows,
            v_rows,
            v_cols,
        ));
    }

    let dense_inner = match raw {
        DesignMatrix::Dense(inner) => inner,
        DesignMatrix::Sparse(_) => raw
            .try_to_dense_by_chunks(&format!("{context} sparse→dense for V apply"))
            .map(DenseDesignMatrix::from)
            .map_err(|reason| format!("{context}: densify failed: {reason}"))?,
    };

    CoefficientTransformOperator::new(dense_inner, v.clone())
        .map(|op| DesignMatrix::Dense(DenseDesignMatrix::from(Arc::new(op))))
        .map_err(|reason| format!("{context}: CoefficientTransformOperator::new: {reason}"))
}
/// Pulls each penalty `S` back through `v` as the symmetrised `Vᵀ S V`.
///
/// Every output is a dense `PenaltyMatrix`; a precision label on the input is carried
/// over to the corresponding output.
fn pull_back_penalties(penalties: &[PenaltyMatrix], v: &Array2<f64>) -> Vec<PenaltyMatrix> {
    let v_transposed = v.t().to_owned();
    let mut pulled = Vec::with_capacity(penalties.len());
    for penalty in penalties {
        let label = penalty.precision_label().map(|s| s.to_string());
        let dense = penalty.as_dense_cow();
        let s_times_v = fast_ab(&dense.view().to_owned(), v);
        let vt_s_v = fast_ab(&v_transposed, &s_times_v);
        // Average with the transpose to remove round-off asymmetry.
        let symmetric = Array2::from_shape_fn(vt_s_v.dim(), |(i, j)| {
            0.5 * (vt_s_v[[i, j]] + vt_s_v[[j, i]])
        });
        let matrix = PenaltyMatrix::Dense(symmetric);
        pulled.push(if let Some(lbl) = label {
            matrix.with_precision_label(lbl)
        } else {
            matrix
        });
    }
    pulled
}
/// Per-term compiled transforms for the time, marginal and log-slope blocks, as
/// returned by `compile_survival_parametric_designs_per_term`.
pub struct SurvivalParametricCompiledPerTerm {
// One `V` per partition range of each block, in partition order.
pub v_time_per_term: Vec<Array2<f64>>,
pub v_marginal_per_term: Vec<Array2<f64>>,
pub v_logslope_per_term: Vec<Array2<f64>>,
// Cross-term `R` for every term, concatenated time, then marginal, then logslope.
pub r_lw_per_term: Vec<Option<Array2<f64>>>,
// Per block: sum over its terms of (raw width − compiled width).
pub drops_by_block: (usize, usize, usize),
}
impl SurvivalParametricCompiledPerTerm {
    /// Time-block transform: the per-term `V` matrices placed block-diagonally.
    pub fn v_time_block_diag(&self) -> Array2<f64> {
        block_diag_from(self.v_time_per_term.as_slice())
    }

    /// Marginal-block transform: the per-term `V` matrices placed block-diagonally.
    pub fn v_marginal_block_diag(&self) -> Array2<f64> {
        block_diag_from(self.v_marginal_per_term.as_slice())
    }

    /// Log-slope-block transform: the per-term `V` matrices placed block-diagonally.
    pub fn v_logslope_block_diag(&self) -> Array2<f64> {
        block_diag_from(self.v_logslope_per_term.as_slice())
    }
}
/// Assembles the block-diagonal matrix `diag(V_0, V_1, …)`.
///
/// The result has `Σ nrows` rows and `Σ ncols` columns; blocks with a zero dimension
/// still advance the offsets but write nothing.
fn block_diag_from(v_per_term: &[Array2<f64>]) -> Array2<f64> {
    let (total_rows, total_cols) = v_per_term
        .iter()
        .fold((0usize, 0usize), |(rows, cols), v| (rows + v.nrows(), cols + v.ncols()));
    let mut out = Array2::<f64>::zeros((total_rows, total_cols));
    let (mut r0, mut c0) = (0usize, 0usize);
    for block in v_per_term {
        let (height, width) = block.dim();
        if height != 0 && width != 0 {
            out.slice_mut(ndarray::s![r0..r0 + height, c0..c0 + width])
                .assign(block);
        }
        r0 += height;
        c0 += width;
    }
    out
}
/// Runs the identifiability compiler with one Jacobian operator per model term and
/// returns the per-term compiled transforms.
///
/// Each block's raw columns are split by its partition (checked with
/// `validate_partition`). Every partition range becomes one operator, emitted in the
/// fixed order time → marginal → logslope. The compiled blocks are read back in that
/// same order.
///
/// # Errors
/// Returns `Err` when:
/// - a partition is invalid;
/// - the time derivative matrices (`time_dq0`, `time_dq1`, `time_dqd1`) or the marginal
///   pair (`marginal_dq`, `marginal_dqd1`) disagree in shape;
/// - `compile` fails;
/// - `compile` returns an unexpected number of blocks.
#[allow(clippy::too_many_arguments)]
pub fn compile_survival_parametric_designs_per_term(
    time_dq0: Array2<f64>,
    time_dq1: Array2<f64>,
    time_dqd1: Array2<f64>,
    time_partition: &[std::ops::Range<usize>],
    marginal_dq: Array2<f64>,
    marginal_dqd1: Array2<f64>,
    marginal_partition: &[std::ops::Range<usize>],
    logslope_dg: Array2<f64>,
    logslope_partition: &[std::ops::Range<usize>],
    row_hess: &dyn RowHessian,
) -> Result<SurvivalParametricCompiledPerTerm, String> {
    use crate::families::identifiability_compiler::compile;

    /// Total raw columns dropped across a block's terms (raw width − compiled width).
    fn count_dropped(partition: &[std::ops::Range<usize>], kept: &[Array2<f64>]) -> usize {
        partition
            .iter()
            .zip(kept)
            .map(|(range, v)| range.len().saturating_sub(v.ncols()))
            .sum()
    }

    validate_partition(time_partition, time_dq0.ncols(), "time")?;
    validate_partition(marginal_partition, marginal_dq.ncols(), "marginal")?;
    validate_partition(logslope_partition, logslope_dg.ncols(), "logslope")?;

    // `TimeBlockOperator::new` / `QChannelBlockOperator::new` assert these shapes. Check
    // them up front so a malformed input surfaces as `Err` instead of a panic.
    if !time_partition.is_empty()
        && (time_dq1.dim() != time_dq0.dim() || time_dqd1.dim() != time_dq0.dim())
    {
        return Err(format!(
            "per-term compile: time derivative shapes disagree \
             (dq0={:?}, dq1={:?}, dqd1={:?})",
            time_dq0.dim(),
            time_dq1.dim(),
            time_dqd1.dim(),
        ));
    }
    if !marginal_partition.is_empty() && marginal_dqd1.dim() != marginal_dq.dim() {
        return Err(format!(
            "per-term compile: marginal derivative shapes disagree (dq={:?}, dqd1={:?})",
            marginal_dq.dim(),
            marginal_dqd1.dim(),
        ));
    }

    let n_time = time_partition.len();
    let n_marg = marginal_partition.len();
    let n_log = logslope_partition.len();
    let n_total = n_time + n_marg + n_log;

    let mut operators: Vec<Arc<dyn RowJacobianOperator>> = Vec::with_capacity(n_total);
    let mut ordering: Vec<BlockOrder> = Vec::with_capacity(n_total);
    for range in time_partition {
        operators.push(Arc::new(TimeBlockOperator::new(
            time_dq0.slice(ndarray::s![.., range.clone()]).to_owned(),
            time_dq1.slice(ndarray::s![.., range.clone()]).to_owned(),
            time_dqd1.slice(ndarray::s![.., range.clone()]).to_owned(),
        )));
        ordering.push(BlockOrder::Time);
    }
    for range in marginal_partition {
        operators.push(Arc::new(QChannelBlockOperator::new(
            marginal_dq.slice(ndarray::s![.., range.clone()]).to_owned(),
            marginal_dqd1.slice(ndarray::s![.., range.clone()]).to_owned(),
        )));
        ordering.push(BlockOrder::Marginal);
    }
    for range in logslope_partition {
        operators.push(Arc::new(LogslopeBlockOperator::new(
            logslope_dg.slice(ndarray::s![.., range.clone()]).to_owned(),
        )));
        ordering.push(BlockOrder::Logslope);
    }

    let compiled = compile(&operators, row_hess, &ordering)
        .map_err(|e| format!("identifiability_compiler::compile (per-term) failed: {e}"))?;
    let blocks = compiled.blocks;
    if blocks.len() != n_total {
        return Err(format!(
            "per-term compile: expected {} compiled blocks (time={}, marg={}, log={}), got {}",
            n_total,
            n_time,
            n_marg,
            n_log,
            blocks.len(),
        ));
    }

    // Blocks come back in operator order, so consume them in three consecutive runs.
    let mut blocks = blocks.into_iter();
    let (v_time_per_term, r_time_per_term): (Vec<Array2<f64>>, Vec<Option<Array2<f64>>>) =
        blocks.by_ref().take(n_time).map(|blk| (blk.t_lw, blk.r_lw)).unzip();
    let (v_marginal_per_term, r_marginal_per_term): (
        Vec<Array2<f64>>,
        Vec<Option<Array2<f64>>>,
    ) = blocks.by_ref().take(n_marg).map(|blk| (blk.t_lw, blk.r_lw)).unzip();
    let (v_logslope_per_term, r_logslope_per_term): (
        Vec<Array2<f64>>,
        Vec<Option<Array2<f64>>>,
    ) = blocks.by_ref().take(n_log).map(|blk| (blk.t_lw, blk.r_lw)).unzip();

    let r_lw_per_term: Vec<Option<Array2<f64>>> = r_time_per_term
        .into_iter()
        .chain(r_marginal_per_term)
        .chain(r_logslope_per_term)
        .collect();

    let drops_by_block = (
        count_dropped(time_partition, &v_time_per_term),
        count_dropped(marginal_partition, &v_marginal_per_term),
        count_dropped(logslope_partition, &v_logslope_per_term),
    );

    Ok(SurvivalParametricCompiledPerTerm {
        v_time_per_term,
        v_marginal_per_term,
        v_logslope_per_term,
        r_lw_per_term,
        drops_by_block,
    })
}
/// Checks that `partition` tiles `[0, p_block)` with contiguous, non-empty ranges.
///
/// An empty partition is accepted only for a zero-width block. Otherwise, in order:
/// 1. the first range must start at 0;
/// 2. the last range must end at `p_block`;
/// 3. each adjacent pair must touch (gaps/overlaps are checked before emptiness);
/// 4. the final range must be non-empty.
///
/// # Errors
/// Returns a message prefixed with `label` describing the first violation found.
fn validate_partition(
    partition: &[std::ops::Range<usize>],
    p_block: usize,
    label: &str,
) -> Result<(), String> {
    let (Some(first), Some(last)) = (partition.first(), partition.last()) else {
        return if p_block == 0 {
            Ok(())
        } else {
            Err(format!(
                "{label} partition empty but block has p={p_block} columns"
            ))
        };
    };
    if first.start != 0 {
        return Err(format!(
            "{label} partition must start at 0, got start={}",
            first.start
        ));
    }
    if last.end != p_block {
        return Err(format!(
            "{label} partition must cover [0, {p_block}); last range ends at {}",
            last.end
        ));
    }
    for pair in partition.windows(2) {
        let (left, right) = (&pair[0], &pair[1]);
        if left.end != right.start {
            return Err(format!(
                "{label} partition has gap/overlap between [{}..{}) and [{}..{})",
                left.start, left.end, right.start, right.end
            ));
        }
        if left.is_empty() {
            return Err(format!(
                "{label} partition has empty range [{}..{})",
                left.start, left.end
            ));
        }
    }
    if last.is_empty() {
        return Err(format!("{label} partition's final range is empty"));
    }
    Ok(())
}
/// Derives a contiguous term partition of `[0, p_block)` from penalty column ranges.
///
/// Every penalty start/end (clamped to `p_block`), plus 0 and `p_block`, becomes a cut
/// point. The result is the sequence of non-empty ranges between consecutive distinct
/// cuts. Note that overlapping penalty ranges are split at every boundary.
pub fn extract_term_partition_from_penalty_ranges(
    p_block: usize,
    penalty_ranges: &[std::ops::Range<usize>],
) -> Vec<std::ops::Range<usize>> {
    let mut cuts: Vec<usize> = Vec::with_capacity(2 * penalty_ranges.len() + 2);
    cuts.push(0);
    cuts.push(p_block);
    for range in penalty_ranges {
        cuts.push(range.start.min(p_block));
        cuts.push(range.end.min(p_block));
    }
    cuts.sort_unstable();
    cuts.dedup();
    cuts.windows(2)
        .filter(|pair| pair[0] < pair[1])
        .map(|pair| pair[0]..pair[1])
        .collect()
}
pub fn pull_back_blockwise_penalty_per_term(
pen: &crate::terms::smooth::BlockwisePenalty,
raw_partition: &[std::ops::Range<usize>],
v_per_term: &[Array2<f64>],
) -> Result<crate::terms::smooth::BlockwisePenalty, String> {
use crate::terms::smooth::BlockwisePenalty;
if raw_partition.len() != v_per_term.len() {
return Err(format!(
"pull_back_blockwise_penalty_per_term: partition len {} != v_per_term len {}",
raw_partition.len(),
v_per_term.len()
));
}
let mut term_idx = None;
for (idx, range) in raw_partition.iter().enumerate() {
if pen.col_range.start >= range.start && pen.col_range.end <= range.end {
term_idx = Some(idx);
break;
}
}
let term_idx = term_idx.ok_or_else(|| {
format!(
"pull_back_blockwise_penalty_per_term: penalty col_range {}..{} does not fit \
within any term partition (partition entries: {:?})",
pen.col_range.start, pen.col_range.end, raw_partition
)
})?;
let v_term = &v_per_term[term_idx];
let term_range = &raw_partition[term_idx];
let local_off_start = pen.col_range.start - term_range.start;
let term_p = term_range.len();
let mut embedded = Array2::<f64>::zeros((term_p, term_p));
for i in 0..pen.col_range.len() {
for j in 0..pen.col_range.len() {
embedded[[local_off_start + i, local_off_start + j]] = pen.local[[i, j]];
}
}
let temp = embedded.dot(v_term);
let pulled = v_term.t().dot(&temp);
let r = pulled.nrows();
let mut sym = Array2::<f64>::zeros((r, r));
for i in 0..r {
for j in 0..r {
sym[[i, j]] = 0.5 * (pulled[[i, j]] + pulled[[j, i]]);
}
}
let mut compiled_start = 0usize;
for v in v_per_term.iter().take(term_idx) {
compiled_start += v.ncols();
}
let compiled_end = compiled_start + v_term.ncols();
Ok(BlockwisePenalty::new(compiled_start..compiled_end, sym))
}
/// Assembles the full raw-from-compiled matrix `T` from per-term `V` and `R` pieces.
///
/// Term `b` owns raw rows `row_offsets[b]..` (width `V_b.nrows()`) and compiled columns
/// `col_offsets[b]..` (width `V_b.ncols()`). `T` has:
/// - `V_b` on the diagonal block of each term `b`;
/// - for `b >= 1` with `r_per_term[b] = Some(R_b)`, `−R_b` in column block `b`.
///
/// The rows of `R_b` are split by the raw widths of terms `0..b`, and each slice is
/// written (negated) into the rows of the corresponding earlier term. `T` is therefore
/// block upper-triangular. `r_per_term[0]` is never read.
///
/// # Panics
/// Panics if `v_per_term` and `r_per_term` differ in length, or if some `R_b` does not
/// have `V_b.ncols()` columns and `Σ_{a<b} V_a.nrows()` rows.
pub fn build_full_t_matrix(
v_per_term: &[Array2<f64>],
r_per_term: &[Option<Array2<f64>>],
) -> Array2<f64> {
assert_eq!(
v_per_term.len(),
r_per_term.len(),
"build_full_t_matrix: v_per_term len {} != r_per_term len {}",
v_per_term.len(),
r_per_term.len(),
);
// Raw (row) and compiled (column) widths per term, and their prefix-sum offsets.
let raw_widths: Vec<usize> = v_per_term.iter().map(|v| v.nrows()).collect();
let kept_widths: Vec<usize> = v_per_term.iter().map(|v| v.ncols()).collect();
let row_offsets: Vec<usize> = {
let mut o = Vec::with_capacity(raw_widths.len() + 1);
o.push(0);
for w in &raw_widths {
o.push(o.last().copied().unwrap_or(0) + w);
}
o
};
let col_offsets: Vec<usize> = {
let mut o = Vec::with_capacity(kept_widths.len() + 1);
o.push(0);
for w in &kept_widths {
o.push(o.last().copied().unwrap_or(0) + w);
}
o
};
let total_rows = row_offsets.last().copied().unwrap_or(0);
let total_cols = col_offsets.last().copied().unwrap_or(0);
let mut t = Array2::<f64>::zeros((total_rows, total_cols));
// Diagonal blocks: T[rows_b, cols_b] = V_b (zero-sized blocks are skipped).
for (b, v) in v_per_term.iter().enumerate() {
let r = v.nrows();
let c = v.ncols();
if r > 0 && c > 0 {
t.slice_mut(ndarray::s![
row_offsets[b]..row_offsets[b] + r,
col_offsets[b]..col_offsets[b] + c
])
.assign(v);
}
}
// Off-diagonal blocks: term 0 has no earlier terms, so start at b = 1.
for b in 1..v_per_term.len() {
let Some(r_stack) = r_per_term[b].as_ref() else {
continue;
};
let kept_b = kept_widths[b];
assert_eq!(
r_stack.ncols(),
kept_b,
"build_full_t_matrix: r_per_term[{b}] has {} cols, expected {}",
r_stack.ncols(),
kept_b,
);
let expected_rows: usize = raw_widths.iter().take(b).sum();
assert_eq!(
r_stack.nrows(),
expected_rows,
"build_full_t_matrix: r_per_term[{b}] has {} rows, expected {} (sum of raw_widths[0..{}])",
r_stack.nrows(),
expected_rows,
b,
);
// `local_row` walks down R_b, one earlier term's raw width at a time.
let mut local_row = 0usize;
for a in 0..b {
let r_a = raw_widths[a];
if r_a == 0 || kept_b == 0 {
local_row += r_a;
continue;
}
let block = r_stack.slice(ndarray::s![local_row..local_row + r_a, ..]);
let mut dst = t.slice_mut(ndarray::s![
row_offsets[a]..row_offsets[a] + r_a,
col_offsets[b]..col_offsets[b] + kept_b
]);
// T[rows_a, cols_b] = −R_b[slice for term a]
for i in 0..r_a {
for j in 0..kept_b {
dst[[i, j]] = -block[[i, j]];
}
}
local_row += r_a;
}
}
t
}
/// One term's contribution to the global raw-from-compiled matrix.
pub struct GlobalTBlock {
// Term transform: raw width rows × compiled width columns.
pub v: Array2<f64>,
// Cross-term block against ALL earlier terms in the flattened order (across blocks):
// rows = summed raw widths of earlier terms, cols = `v.ncols()`. Written negated by
// `build_full_t_matrix`; ignored for the very first term.
pub r_against_earlier: Option<Array2<f64>>,
}
/// Nested specification of the global `T`: outer `Vec` = model blocks, inner `Vec` =
/// terms within a block. Flattened in this order by `build_global_t_matrix`.
pub struct GlobalTSpec {
pub blocks: Vec<Vec<GlobalTBlock>>,
}
/// Flattens `spec` (blocks, then terms, in order) and assembles the full
/// raw-from-compiled matrix with `build_full_t_matrix`.
///
/// # Panics
/// Panics under the same shape conditions as `build_full_t_matrix`.
pub fn build_global_t_matrix(spec: &GlobalTSpec) -> Array2<f64> {
    let terms = spec.blocks.iter().flatten();
    let v_flat: Vec<Array2<f64>> = terms.clone().map(|term| term.v.clone()).collect();
    let r_flat: Vec<Option<Array2<f64>>> =
        terms.map(|term| term.r_against_earlier.clone()).collect();
    build_full_t_matrix(&v_flat, &r_flat)
}
/// Pulls a block-local penalty back through the full raw-from-compiled matrix `T`.
///
/// `pen.local` occupies raw rows/cols `anchor_offset + pen.col_range` of the raw space.
/// The embedded penalty `E` is zero outside those rows, so `Tᵀ E T = T_Sᵀ · local · T_S`,
/// where `T_S` is the matching row slice of `T`. Computing it this way avoids building
/// a dense `raw_total × raw_total` matrix. The `compiled_total × compiled_total` result
/// is symmetrised.
///
/// # Panics
/// Panics if the embed range runs past `t.nrows()`, or if `pen.local` is smaller than
/// `col_range.len()` square (only its top-left `col_range.len()` square is read).
pub fn pull_back_penalty_through_t(
    pen: &crate::terms::smooth::BlockwisePenalty,
    anchor_offset: usize,
    t: &Array2<f64>,
) -> PenaltyMatrix {
    let raw_total = t.nrows();
    let compiled_total = t.ncols();
    let block_p = pen.col_range.len();
    let embed_start = anchor_offset + pen.col_range.start;
    let embed_end = embed_start + block_p;
    assert!(
        embed_end <= raw_total,
        "pull_back_penalty_through_t: embed range {}..{} exceeds raw total {}",
        embed_start,
        embed_end,
        raw_total,
    );
    if block_p == 0 {
        return PenaltyMatrix::Dense(Array2::<f64>::zeros((compiled_total, compiled_total)));
    }
    let local = pen.local.slice(ndarray::s![..block_p, ..block_p]);
    let t_rows = t.slice(ndarray::s![embed_start..embed_end, ..]);
    let pulled = t_rows.t().dot(&local.dot(&t_rows));
    let sym = Array2::from_shape_fn((compiled_total, compiled_total), |(i, j)| {
        0.5 * (pulled[[i, j]] + pulled[[j, i]])
    });
    PenaltyMatrix::Dense(sym)
}
/// Survival designs and penalties re-expressed with per-term compiled transforms, as
/// produced by `apply_per_term_survival_parametric_compile_to_designs`.
pub struct CompiledSurvivalDesignsPerTerm {
pub time_design_entry: DesignMatrix,
pub time_design_exit: DesignMatrix,
pub time_design_derivative_exit: DesignMatrix,
pub marginal_design: DesignMatrix,
pub logslope_design: DesignMatrix,
// Each penalty is pulled back through its owning term's `V` and placed at that
// term's compiled column offset.
pub time_penalties: Vec<crate::terms::smooth::BlockwisePenalty>,
pub marginal_penalties: Vec<crate::terms::smooth::BlockwisePenalty>,
pub logslope_penalties: Vec<crate::terms::smooth::BlockwisePenalty>,
// Block-diagonal assemblies of the per-term `V`s used to wrap the designs above.
pub v_time: Array2<f64>,
pub v_marginal: Array2<f64>,
pub v_logslope: Array2<f64>,
}
/// Re-expresses raw survival designs and penalties using per-term compiled transforms.
///
/// Each block's per-term `V` matrices are assembled block-diagonally and used to wrap
/// that block's designs (the three time designs share `v_time`). Each penalty is pulled
/// back through its owning term with `pull_back_blockwise_penalty_per_term`.
///
/// # Errors
/// Propagates the first error from `wrap_design_with_transform` or from a penalty
/// pull-back.
#[allow(clippy::too_many_arguments)]
pub fn apply_per_term_survival_parametric_compile_to_designs(
    compiled: &SurvivalParametricCompiledPerTerm,
    time_partition: &[std::ops::Range<usize>],
    marginal_partition: &[std::ops::Range<usize>],
    logslope_partition: &[std::ops::Range<usize>],
    time_design_entry: DesignMatrix,
    time_design_exit: DesignMatrix,
    time_design_derivative_exit: DesignMatrix,
    marginal_design: DesignMatrix,
    logslope_design: DesignMatrix,
    time_penalties: &[crate::terms::smooth::BlockwisePenalty],
    marginal_penalties: &[crate::terms::smooth::BlockwisePenalty],
    logslope_penalties: &[crate::terms::smooth::BlockwisePenalty],
) -> Result<CompiledSurvivalDesignsPerTerm, String> {
    /// Pulls every penalty of one block back through its owning term's `V`.
    fn pull_all(
        pens: &[crate::terms::smooth::BlockwisePenalty],
        partition: &[std::ops::Range<usize>],
        v_per_term: &[Array2<f64>],
    ) -> Result<Vec<crate::terms::smooth::BlockwisePenalty>, String> {
        pens.iter()
            .map(|p| pull_back_blockwise_penalty_per_term(p, partition, v_per_term))
            .collect()
    }

    let v_time = compiled.v_time_block_diag();
    let v_marginal = compiled.v_marginal_block_diag();
    let v_logslope = compiled.v_logslope_block_diag();

    let time_design_entry = wrap_design_with_transform(
        time_design_entry,
        &v_time,
        "smgs per-term apply: time entry",
    )?;
    let time_design_exit = wrap_design_with_transform(
        time_design_exit,
        &v_time,
        "smgs per-term apply: time exit",
    )?;
    let time_design_derivative_exit = wrap_design_with_transform(
        time_design_derivative_exit,
        &v_time,
        "smgs per-term apply: time derivative_exit",
    )?;
    let marginal_design = wrap_design_with_transform(
        marginal_design,
        &v_marginal,
        "smgs per-term apply: marginal",
    )?;
    let logslope_design = wrap_design_with_transform(
        logslope_design,
        &v_logslope,
        "smgs per-term apply: logslope",
    )?;

    let time_penalties = pull_all(time_penalties, time_partition, &compiled.v_time_per_term)?;
    let marginal_penalties = pull_all(
        marginal_penalties,
        marginal_partition,
        &compiled.v_marginal_per_term,
    )?;
    let logslope_penalties = pull_all(
        logslope_penalties,
        logslope_partition,
        &compiled.v_logslope_per_term,
    )?;

    Ok(CompiledSurvivalDesignsPerTerm {
        time_design_entry,
        time_design_exit,
        time_design_derivative_exit,
        marginal_design,
        logslope_design,
        time_penalties,
        marginal_penalties,
        logslope_penalties,
        v_time,
        v_marginal,
        v_logslope,
    })
}
/// Per-block coefficient lift that uses only each block's own transform `V_b`.
#[derive(Debug, Clone)]
pub struct SmgsLiftPerBlockV {
    /// One transform per block; block `i` is lifted as `V_i · β_i`.
    pub v_per_block: Vec<Array2<f64>>,
}
impl SmgsLiftPerBlockV {
    /// Lifts each compiled block coefficient vector to raw coordinates, in place.
    ///
    /// Block `i` is replaced by `v_per_block[i] · β_i` whenever `v_per_block[i]` has
    /// exactly `β_i.len()` columns (square and non-square `V` are treated the same).
    /// A block is left unchanged if it has no transform (index past the end of
    /// `v_per_block`) or its transform's width does not match. No cross-block terms
    /// are applied.
    pub fn lift_block_betas(&self, block_betas: &mut [Array1<f64>]) {
        for (beta, v) in block_betas.iter_mut().zip(&self.v_per_block) {
            if v.ncols() == beta.len() {
                let raw = v.dot(&*beta);
                *beta = raw;
            }
        }
    }
}
/// Coefficient lift through a full raw-from-compiled matrix `T` (`β_raw = T · θ`),
/// which includes cross-block terms.
#[derive(Debug, Clone)]
pub struct SmgsLiftViaT {
// Raw-from-compiled matrix: total raw width × total compiled width.
pub t_full: Array2<f64>,
// Prefix offsets `[0, c_0, c_0 + c_1, …]` of the compiled block widths.
pub block_starts_compiled: Vec<usize>,
// Prefix offsets `[0, r_0, r_0 + r_1, …]` of the raw block widths.
pub block_starts_raw: Vec<usize>,
}
impl SmgsLiftViaT {
    /// Prefix-sum offsets `[0, w_0, w_0 + w_1, …]` for a sequence of widths.
    fn prefix_offsets(widths: impl IntoIterator<Item = usize>) -> Vec<usize> {
        let mut offsets = vec![0usize];
        let mut running = 0usize;
        for w in widths {
            running += w;
            offsets.push(running);
        }
        offsets
    }

    /// Block start offsets from ranges that must tile `[0, end)` contiguously.
    ///
    /// # Panics
    /// Panics if any range does not start where the previous one ended (the first must
    /// start at 0); otherwise the offsets would silently mis-slice the lift.
    fn starts_from_contiguous_ranges(ranges: &[std::ops::Range<usize>], what: &str) -> Vec<usize> {
        let mut starts = Vec::with_capacity(ranges.len() + 1);
        starts.push(0);
        for (b, range) in ranges.iter().enumerate() {
            let expected_start = starts[b];
            assert_eq!(
                range.start, expected_start,
                "SmgsLiftViaT::from_compiled_map: {what} block {b} starts at {}, expected \
                 {expected_start} (block ranges must be contiguous from 0)",
                range.start,
            );
            starts.push(range.end);
        }
        starts
    }

    /// Builds the lift from per-term `V` / `R` pieces via `build_full_t_matrix`.
    ///
    /// # Panics
    /// Panics under the same shape conditions as `build_full_t_matrix`.
    pub fn from_v_and_r(v_per_term: &[Array2<f64>], r_per_term: &[Option<Array2<f64>>]) -> Self {
        let block_starts_compiled = Self::prefix_offsets(v_per_term.iter().map(|v| v.ncols()));
        let block_starts_raw = Self::prefix_offsets(v_per_term.iter().map(|v| v.nrows()));
        let t_full = build_full_t_matrix(v_per_term, r_per_term);
        Self {
            t_full,
            block_starts_compiled,
            block_starts_raw,
        }
    }

    /// Wraps a pre-assembled `T` with explicit per-block raw and compiled widths.
    ///
    /// # Panics
    /// Panics if the two width lists differ in length or `T` is not
    /// `(Σ raw_widths) × (Σ compiled_widths)`.
    pub fn from_t(t_full: Array2<f64>, raw_widths: &[usize], compiled_widths: &[usize]) -> Self {
        assert_eq!(
            raw_widths.len(),
            compiled_widths.len(),
            "SmgsLiftViaT::from_t: raw_widths len {} != compiled_widths len {}",
            raw_widths.len(),
            compiled_widths.len(),
        );
        let total_raw: usize = raw_widths.iter().sum();
        let total_compiled: usize = compiled_widths.iter().sum();
        assert_eq!(
            t_full.dim(),
            (total_raw, total_compiled),
            "SmgsLiftViaT::from_t: T has shape {:?}, expected ({total_raw}, {total_compiled})",
            t_full.dim(),
        );
        let block_starts_raw = Self::prefix_offsets(raw_widths.iter().copied());
        let block_starts_compiled = Self::prefix_offsets(compiled_widths.iter().copied());
        Self {
            t_full,
            block_starts_compiled,
            block_starts_raw,
        }
    }

    /// Lifts per-block compiled coefficients to per-block raw coefficients.
    ///
    /// The blocks are concatenated into `θ`, `β_raw = T · θ` is formed, and `β_raw` is
    /// split back by the raw block offsets.
    ///
    /// # Panics
    /// Panics if the number of blocks or any block's length disagrees with the compiled
    /// offsets.
    pub fn lift_block_betas_via_t(&self, compiled_block_betas: &[Array1<f64>]) -> Vec<Array1<f64>> {
        let n_blocks = self.block_starts_compiled.len().saturating_sub(1);
        assert_eq!(
            compiled_block_betas.len(),
            n_blocks,
            "SmgsLiftViaT::lift_block_betas_via_t: got {} compiled block betas, expected {}",
            compiled_block_betas.len(),
            n_blocks,
        );
        for (b, beta) in compiled_block_betas.iter().enumerate() {
            let expected = self.block_starts_compiled[b + 1] - self.block_starts_compiled[b];
            assert_eq!(
                beta.len(),
                expected,
                "SmgsLiftViaT::lift_block_betas_via_t: block {b} has β of len {}, expected compiled width {}",
                beta.len(),
                expected,
            );
        }
        let total_compiled = *self.block_starts_compiled.last().unwrap_or(&0);
        let mut theta_full = Array1::<f64>::zeros(total_compiled);
        for (bounds, beta) in self
            .block_starts_compiled
            .windows(2)
            .zip(compiled_block_betas)
        {
            theta_full
                .slice_mut(ndarray::s![bounds[0]..bounds[1]])
                .assign(beta);
        }
        let beta_full = self.t_full.dot(&theta_full);
        (0..n_blocks)
            .map(|b| {
                let (r0, r1) = (self.block_starts_raw[b], self.block_starts_raw[b + 1]);
                beta_full.slice(ndarray::s![r0..r1]).to_owned()
            })
            .collect()
    }

    /// Builds the lift directly from an identifiability-compiler `CompiledMap`.
    ///
    /// # Panics
    /// Panics if:
    /// - the raw/compiled range counts or the `ordering` length disagree;
    /// - either range list is not contiguous from 0;
    /// - `raw_from_compiled` does not match the range totals.
    pub fn from_compiled_map(
        map: &crate::families::identifiability_compiler::CompiledMap,
        ordering: &[crate::families::identifiability_compiler::BlockOrder],
    ) -> Self {
        assert_eq!(
            map.raw_block_ranges.len(),
            map.compiled_block_ranges.len(),
            "SmgsLiftViaT::from_compiled_map: CompiledMap raw_block_ranges len {} != \
             compiled_block_ranges len {}",
            map.raw_block_ranges.len(),
            map.compiled_block_ranges.len(),
        );
        assert_eq!(
            map.raw_block_ranges.len(),
            ordering.len(),
            "SmgsLiftViaT::from_compiled_map: ordering len {} != block count {}",
            ordering.len(),
            map.raw_block_ranges.len(),
        );
        let block_starts_raw = Self::starts_from_contiguous_ranges(&map.raw_block_ranges, "raw");
        let block_starts_compiled =
            Self::starts_from_contiguous_ranges(&map.compiled_block_ranges, "compiled");
        let raw_total = *block_starts_raw.last().unwrap_or(&0);
        let compiled_total = *block_starts_compiled.last().unwrap_or(&0);
        // Same shape contract `from_t` enforces; a mismatch would otherwise mis-slice or
        // panic later inside `lift_block_betas_via_t`.
        assert_eq!(
            map.raw_from_compiled.dim(),
            (raw_total, compiled_total),
            "SmgsLiftViaT::from_compiled_map: T has shape {:?}, expected ({raw_total}, {compiled_total})",
            map.raw_from_compiled.dim(),
        );
        Self {
            t_full: map.raw_from_compiled.clone(),
            block_starts_compiled,
            block_starts_raw,
        }
    }
}
/// Applies a joint compiled identifiability map to the raw survival designs and penalties.
///
/// `map` must describe exactly three blocks in the order time, marginal, logslope.
///
/// - **Designs:** each block's raw design is wrapped with the diagonal block
///   `V_b = T[raw_b, compiled_b]` of `T = map.raw_from_compiled`. The three time designs
///   (entry, exit, derivative at exit) all share `V_time`. The designs are consumed.
/// - **Penalties:** every penalty is pulled back through the full `T` with
///   `pull_back_penalty_through_t`, anchored at the raw start offset of its block.
/// - **Transform:** `T` itself is returned as `t_full`.
///
/// NOTE(review): only the diagonal blocks of `T` reach the designs here, while the penalties
/// see the full `T` (including any off-diagonal cross-block terms). Confirm that downstream
/// consumers account for the off-diagonal blocks.
///
/// # Errors
/// Returns `Err` if:
/// - the map does not have exactly three raw and three compiled ranges;
/// - either range list is not a contiguous, in-order tiling starting at 0;
/// - the shape of `T` disagrees with the ranges;
/// - wrapping any design fails.
#[allow(clippy::too_many_arguments)]
pub fn apply_compiled_map_to_designs(
    map: &crate::families::identifiability_compiler::CompiledMap,
    time_design_entry: DesignMatrix,
    time_design_exit: DesignMatrix,
    time_design_derivative_exit: DesignMatrix,
    marginal_design: DesignMatrix,
    logslope_design: DesignMatrix,
    time_penalties: &[crate::terms::smooth::BlockwisePenalty],
    marginal_penalties: &[crate::terms::smooth::BlockwisePenalty],
    logslope_penalties: &[crate::terms::smooth::BlockwisePenalty],
) -> Result<CompiledSurvivalDesignsVMExact, String> {
    if map.raw_block_ranges.len() != 3 || map.compiled_block_ranges.len() != 3 {
        return Err(format!(
            "apply_compiled_map_to_designs: expected exactly 3 blocks (time, marginal, logslope), \
             got {} raw / {} compiled",
            map.raw_block_ranges.len(),
            map.compiled_block_ranges.len(),
        ));
    }
    // The checks below compare only the *last* range end against T's shape. Without
    // contiguity, gapped ranges would silently ignore raw columns and misplace penalty
    // anchors. Out-of-order ranges would make the ndarray slices below panic instead of
    // returning an error.
    for (label, ranges) in [
        ("raw", &map.raw_block_ranges),
        ("compiled", &map.compiled_block_ranges),
    ] {
        let mut expected_start = 0usize;
        for (idx, r) in ranges.iter().enumerate() {
            if r.start != expected_start || r.end < r.start {
                return Err(format!(
                    "apply_compiled_map_to_designs: {label} block range {idx} = {r:?} is not \
                     contiguous (expected it to start at {expected_start})"
                ));
            }
            expected_start = r.end;
        }
    }
    let time_raw = map.raw_block_ranges[0].clone();
    let marg_raw = map.raw_block_ranges[1].clone();
    let log_raw = map.raw_block_ranges[2].clone();
    let time_compiled = map.compiled_block_ranges[0].clone();
    let marg_compiled = map.compiled_block_ranges[1].clone();
    let log_compiled = map.compiled_block_ranges[2].clone();
    let t = &map.raw_from_compiled;
    let raw_total = t.nrows();
    let compiled_total = t.ncols();
    // With contiguous ranges the last end is exactly the summed block width.
    let expected_raw_total = log_raw.end;
    if raw_total != expected_raw_total {
        return Err(format!(
            "apply_compiled_map_to_designs: T has {raw_total} raw rows but block ranges sum to \
             {expected_raw_total}"
        ));
    }
    let expected_compiled_total = log_compiled.end;
    if compiled_total != expected_compiled_total {
        return Err(format!(
            "apply_compiled_map_to_designs: T has {compiled_total} compiled cols but block ranges \
             sum to {expected_compiled_total}"
        ));
    }
    // Diagonal blocks V_b = T[raw_b, compiled_b].
    let v_time = t
        .slice(ndarray::s![time_raw.clone(), time_compiled.clone()])
        .to_owned();
    let v_marg = t
        .slice(ndarray::s![marg_raw.clone(), marg_compiled.clone()])
        .to_owned();
    let v_log = t
        .slice(ndarray::s![log_raw.clone(), log_compiled.clone()])
        .to_owned();
    let time_entry_out =
        wrap_design_with_transform(time_design_entry, &v_time, "compiled-map: time entry")?;
    let time_exit_out =
        wrap_design_with_transform(time_design_exit, &v_time, "compiled-map: time exit")?;
    let time_deriv_out = wrap_design_with_transform(
        time_design_derivative_exit,
        &v_time,
        "compiled-map: time derivative_exit",
    )?;
    let marg_out = wrap_design_with_transform(marginal_design, &v_marg, "compiled-map: marginal")?;
    let log_out = wrap_design_with_transform(logslope_design, &v_log, "compiled-map: logslope")?;
    // Each block's penalties are embedded at that block's raw start before the pullback.
    let time_offset = time_raw.start;
    let marg_offset = marg_raw.start;
    let log_offset = log_raw.start;
    let pull_set = |pens: &[crate::terms::smooth::BlockwisePenalty],
                    anchor_offset: usize|
     -> Vec<PenaltyMatrix> {
        pens.iter()
            .map(|p| pull_back_penalty_through_t(p, anchor_offset, t))
            .collect()
    };
    Ok(CompiledSurvivalDesignsVMExact {
        time_design_entry: time_entry_out,
        time_design_exit: time_exit_out,
        time_design_derivative_exit: time_deriv_out,
        marginal_design: marg_out,
        logslope_design: log_out,
        time_penalties: pull_set(time_penalties, time_offset),
        marginal_penalties: pull_set(marginal_penalties, marg_offset),
        logslope_penalties: pull_set(logslope_penalties, log_offset),
        t_full: t.clone(),
    })
}
/// Runs the identifiability compiler on the three parametric survival blocks.
///
/// The blocks are ordered time → marginal → logslope, with no score-warp or
/// link-deviation blocks. The result holds:
/// - each block's compiled transform `t_lw`;
/// - the number of raw columns the compiler removed from each block, as a
///   `(time, marginal, logslope)` tuple.
///
/// # Errors
/// Returns `Err` if:
/// - the operator list does not have exactly three entries;
/// - the compiled block list does not have exactly three entries;
/// - `identifiability_compiler::compile` fails.
pub fn compile_survival_parametric_designs(
    time_dq0: Array2<f64>,
    time_dq1: Array2<f64>,
    time_dqd1: Array2<f64>,
    marginal_dq: Array2<f64>,
    marginal_dqd1: Array2<f64>,
    logslope_dg: Array2<f64>,
    row_hess: &dyn RowHessian,
) -> Result<SurvivalParametricCompiled, String> {
    // Raw widths must be read before the design arrays are moved into the operators.
    let raw_widths = [time_dq0.ncols(), marginal_dq.ncols(), logslope_dg.ncols()];
    let inputs = build_survival_compiler_inputs(
        time_dq0,
        time_dq1,
        time_dqd1,
        marginal_dq,
        marginal_dqd1,
        logslope_dg,
        None,
        None,
    );
    let operator_count = inputs.operators.len();
    if operator_count != 3 {
        return Err(format!(
            "compile_survival_parametric_designs: expected exactly 3 parametric operators \
             (time, marginal, logslope); got {}",
            operator_count,
        ));
    }
    let compiled = crate::families::identifiability_compiler::compile(
        &inputs.operators,
        row_hess,
        &inputs.ordering,
    )
    .map_err(|e| format!("identifiability_compiler::compile failed: {e}"))?;
    let block_count = compiled.blocks.len();
    if block_count != 3 {
        return Err(format!(
            "compile_survival_parametric_designs: compiler emitted {} blocks; expected 3",
            block_count,
        ));
    }
    let [v_time, v_marginal, v_logslope] = [0usize, 1, 2].map(|idx| compiled.blocks[idx].t_lw.clone());
    let drops_by_block = (
        raw_widths[0].saturating_sub(v_time.ncols()),
        raw_widths[1].saturating_sub(v_marginal.ncols()),
        raw_widths[2].saturating_sub(v_logslope.ncols()),
    );
    Ok(SurvivalParametricCompiled {
        v_time,
        v_marginal,
        v_logslope,
        drops_by_block,
    })
}
/// Assembles the row-Jacobian operators and block ordering fed to the identifiability
/// compiler.
///
/// The three mandatory blocks are:
/// - time, built as a `TimeBlockOperator` from `dq0` / `dq1` / `dqd1`;
/// - marginal, built as a `QChannelBlockOperator`;
/// - logslope, built as a `LogslopeBlockOperator`.
///
/// Optional score-warp and link-deviation blocks follow, in that order, each built as a
/// `QChannelBlockOperator` from its `(dq, dqd1)` pair. `operators[i]` is always paired with
/// `ordering[i]`.
pub fn build_survival_compiler_inputs(
    time_dq0: Array2<f64>,
    time_dq1: Array2<f64>,
    time_dqd1: Array2<f64>,
    marginal_dq: Array2<f64>,
    marginal_dqd1: Array2<f64>,
    logslope_dg: Array2<f64>,
    score_warp_dq_dqd1: Option<(Array2<f64>, Array2<f64>)>,
    link_dev_dq_dqd1: Option<(Array2<f64>, Array2<f64>)>,
) -> SurvivalCompilerInputs {
    let q_channel = |dq: Array2<f64>, dqd1: Array2<f64>| -> Arc<dyn RowJacobianOperator> {
        Arc::new(QChannelBlockOperator::new(dq, dqd1))
    };
    let time_op: Arc<dyn RowJacobianOperator> =
        Arc::new(TimeBlockOperator::new(time_dq0, time_dq1, time_dqd1));
    let logslope_op: Arc<dyn RowJacobianOperator> =
        Arc::new(LogslopeBlockOperator::new(logslope_dg));
    let mut blocks: Vec<(Arc<dyn RowJacobianOperator>, BlockOrder)> = Vec::with_capacity(5);
    blocks.push((time_op, BlockOrder::Time));
    blocks.push((q_channel(marginal_dq, marginal_dqd1), BlockOrder::Marginal));
    blocks.push((logslope_op, BlockOrder::Logslope));
    // `Option` is iterable, so absent optional blocks contribute nothing.
    blocks.extend(
        score_warp_dq_dqd1.map(|(dq, dqd1)| (q_channel(dq, dqd1), BlockOrder::ScoreWarp)),
    );
    blocks.extend(link_dev_dq_dqd1.map(|(dq, dqd1)| (q_channel(dq, dqd1), BlockOrder::LinkDev)));
    let (operators, ordering): (Vec<_>, Vec<_>) = blocks.into_iter().unzip();
    SurvivalCompilerInputs {
        operators,
        ordering,
    }
}
/// Survival designs and penalties expressed in compiled coordinates, as returned by
/// `apply_compiled_map_to_designs`.
///
/// Designs and penalties are transformed differently:
/// - Each design is the corresponding raw design wrapped with its block's diagonal
///   transform `T[raw_b, compiled_b]`.
/// - Each penalty is pulled back through the full joint transform `t_full`. Judging by
///   the `pull_back_penalty_through_t` tests in this file, that yields a dense square
///   matrix over the joint compiled coordinates rather than over a single block's columns.
pub struct CompiledSurvivalDesignsVMExact {
/// Time-block entry design, wrapped with the time block's diagonal transform.
pub time_design_entry: DesignMatrix,
/// Time-block exit design, wrapped with the time block's diagonal transform.
pub time_design_exit: DesignMatrix,
/// Time-block derivative-at-exit design, wrapped with the time block's diagonal transform.
pub time_design_derivative_exit: DesignMatrix,
/// Marginal-block design, wrapped with the marginal block's diagonal transform.
pub marginal_design: DesignMatrix,
/// Logslope-block design, wrapped with the logslope block's diagonal transform.
pub logslope_design: DesignMatrix,
/// Time penalties pulled back through `t_full` (anchored at the time block's raw start).
pub time_penalties: Vec<PenaltyMatrix>,
/// Marginal penalties pulled back through `t_full` (anchored at the marginal raw start).
pub marginal_penalties: Vec<PenaltyMatrix>,
/// Logslope penalties pulled back through `t_full` (anchored at the logslope raw start).
pub logslope_penalties: Vec<PenaltyMatrix>,
/// Joint raw-from-compiled transform `T` (raw rows × compiled columns); a copy of the
/// map's `raw_from_compiled`.
pub t_full: Array2<f64>,
}
/// Projects raw coefficients `β_raw` onto the compiled coordinates `θ` of `t_full`.
///
/// Solves the symmetrised normal equations `(Tᵀ K T) θ = Tᵀ K β_raw` with an
/// eigendecomposition pseudo-inverse. Eigenvalues with `|λ| ≤ p_comp · ε · max(max|λ|, 1)`
/// are treated as zero, so eigen-directions below that threshold receive a zero component.
/// Assuming `K` is symmetric positive semi-definite (not checked here), this is the
/// `K`-weighted least-squares fit of `T θ` to `β_raw`.
///
/// # Errors
/// Returns `Err` if:
/// - `k_struct` is not `p_raw × p_raw`, where `p_raw = t_full.nrows()`;
/// - `beta_raw` does not have `p_raw` entries;
/// - the eigendecomposition fails.
pub fn project_raw_beta_to_compiled(
    t_full: &Array2<f64>,
    k_struct: &Array2<f64>,
    beta_raw: &Array1<f64>,
) -> Result<Array1<f64>, String> {
    let (p_raw, p_comp) = (t_full.nrows(), t_full.ncols());
    if k_struct.nrows() != p_raw || k_struct.ncols() != p_raw {
        return Err(format!(
            "project_raw_beta_to_compiled: K^S shape {}x{} mismatches T rows {}",
            k_struct.nrows(),
            k_struct.ncols(),
            p_raw,
        ));
    }
    if beta_raw.len() != p_raw {
        return Err(format!(
            "project_raw_beta_to_compiled: beta_raw length {} mismatches T rows {}",
            beta_raw.len(),
            p_raw,
        ));
    }
    if p_comp == 0 {
        return Ok(Array1::<f64>::zeros(0));
    }
    if p_raw == 0 {
        return Ok(Array1::<f64>::zeros(p_comp));
    }
    // Gram matrix Tᵀ K T, symmetrised to wash out round-off asymmetry before eigh.
    let gram = {
        let k_t = fast_ab(k_struct, t_full);
        let raw_gram = crate::linalg::faer_ndarray::fast_atb(t_full, &k_t);
        Array2::<f64>::from_shape_fn((p_comp, p_comp), |(i, j)| {
            0.5 * (raw_gram[[i, j]] + raw_gram[[j, i]])
        })
    };
    let rhs = crate::linalg::faer_ndarray::fast_atv(
        t_full,
        &crate::linalg::faer_ndarray::fast_av(k_struct, beta_raw),
    );
    let (evals, evecs) = gram
        .eigh(Side::Lower)
        .map_err(|e| format!("project_raw_beta_to_compiled: eigh failed: {e:?}"))?;
    let largest = evals.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
    let tol = (p_comp as f64) * f64::EPSILON * largest.max(1.0);
    // Pseudo-inverse in the eigenbasis: divide by λ only where it is numerically nonzero.
    let coords = crate::linalg::faer_ndarray::fast_atv(&evecs, &rhs);
    let scaled: Array1<f64> = evals
        .iter()
        .zip(coords.iter())
        .map(|(&lam, &c)| if lam.abs() > tol { c / lam } else { 0.0 })
        .collect();
    Ok(crate::linalg::faer_ndarray::fast_av(&evecs, &scaled))
}
/// Raw (pre-compilation) designs for one coefficient block over the four primary channels.
///
/// Consumed by `materialise_compiled_primary_blocks`, which requires every channel to have
/// `n_rows` rows and exactly the block's raw width in columns.
#[derive(Debug, Clone)]
pub struct RawPrimaryBlockDesign {
/// Human-readable block name; copied onto the matching compiled block.
pub name: String,
/// Design for the `q0` channel.
pub q0: DesignMatrix,
/// Design for the `q1` channel.
pub q1: DesignMatrix,
/// Design for the `qd1` channel.
pub qd1: DesignMatrix,
/// Design for the `g` channel.
pub g: DesignMatrix,
}
/// Compiled designs for one coefficient block over the four primary channels.
///
/// `materialise_compiled_primary_blocks` always produces these as dense matrices. Each has
/// `n_rows` rows and the compiled block's width in columns.
#[derive(Debug, Clone)]
pub struct CompiledPrimaryBlockDesign {
/// Block name, copied from the raw block at the same index.
pub name: String,
/// Compiled design for the `q0` channel.
pub q0: DesignMatrix,
/// Compiled design for the `q1` channel.
pub q1: DesignMatrix,
/// Compiled design for the `qd1` channel.
pub qd1: DesignMatrix,
/// Compiled design for the `g` channel.
pub g: DesignMatrix,
}
/// Materialises dense compiled channel designs for every compiled block.
///
/// For each compiled block `b` and each primary channel (`q0`, `q1`, `qd1`, `g`) the
/// compiled design is
///
/// ```text
/// X_b = Σ_a X_a^raw · T[raw_block_ranges[a], compiled_block_ranges[b]]
/// ```
///
/// so off-diagonal blocks of `t_full` (cross-block coupling) are carried into the result.
/// Every output design is dense.
///
/// # Arguments
/// * `n_rows` – required row count of every raw channel design.
/// * `raw_channel_blocks` – raw per-block designs. In every channel, block `a` must have
///   `raw_block_ranges[a].len()` columns.
/// * `compiled_block_ranges`, `raw_block_ranges` – each block's columns inside the joint
///   compiled / raw coefficient vector. Each list must tile `0..total` contiguously and in
///   order; empty ranges are allowed.
/// * `t_full` – joint raw-from-compiled transform of shape `raw_total × compiled_total`.
///
/// # Errors
/// Returns `Err` if:
/// - the block counts disagree;
/// - the dimensions of `t_full` disagree with the range widths;
/// - either range list is not contiguous;
/// - a raw channel cannot be densified or has the wrong shape.
pub fn materialise_compiled_primary_blocks(
    n_rows: usize,
    raw_channel_blocks: &[RawPrimaryBlockDesign],
    compiled_block_ranges: &[std::ops::Range<usize>],
    raw_block_ranges: &[std::ops::Range<usize>],
    t_full: &Array2<f64>,
) -> Result<Vec<CompiledPrimaryBlockDesign>, String> {
    let n_blocks = raw_channel_blocks.len();
    if compiled_block_ranges.len() != n_blocks {
        return Err(format!(
            "materialise_compiled_primary_blocks: compiled_block_ranges len {} != raw_channel_blocks len {}",
            compiled_block_ranges.len(),
            n_blocks,
        ));
    }
    if raw_block_ranges.len() != n_blocks {
        return Err(format!(
            "materialise_compiled_primary_blocks: raw_block_ranges len {} != raw_channel_blocks len {}",
            raw_block_ranges.len(),
            n_blocks,
        ));
    }
    let raw_joint_width: usize = raw_block_ranges.iter().map(|r| r.len()).sum();
    let compiled_joint_width: usize = compiled_block_ranges.iter().map(|r| r.len()).sum();
    if t_full.nrows() != raw_joint_width {
        return Err(format!(
            "materialise_compiled_primary_blocks: t_full has {} rows but raw joint width is {raw_joint_width}",
            t_full.nrows(),
        ));
    }
    if t_full.ncols() != compiled_joint_width {
        return Err(format!(
            "materialise_compiled_primary_blocks: t_full has {} cols but compiled joint width is {compiled_joint_width}",
            t_full.ncols(),
        ));
    }
    // The accumulation treats the raw ranges as a partition of T's rows and the compiled
    // ranges as a partition of T's columns. Overlapping or gapped ranges can still have
    // widths that sum to T's dimensions, so require an in-order contiguous tiling. This also
    // guarantees every range end lies within the joint width.
    for (kind, ranges) in [
        ("raw_block_ranges", raw_block_ranges),
        ("compiled_block_ranges", compiled_block_ranges),
    ] {
        let mut expected_start = 0usize;
        for (idx, r) in ranges.iter().enumerate() {
            if r.start != expected_start || r.end < r.start {
                return Err(format!(
                    "materialise_compiled_primary_blocks: {kind}[{idx}] = {r:?} is not contiguous \
                     (expected it to start at {expected_start})"
                ));
            }
            expected_start = r.end;
        }
    }
    // Densifies one raw channel and checks it is `n_rows × raw_w`.
    let densify = |dm: &DesignMatrix,
                   raw_w: usize,
                   a: usize,
                   channel: &str|
     -> Result<Array2<f64>, String> {
        let dense = match dm {
            DesignMatrix::Dense(d) => d.to_dense(),
            DesignMatrix::Sparse(_) => dm
                .try_to_dense_by_chunks(&format!(
                    "materialise_compiled_primary_blocks: densify raw block {a} channel {channel}"
                ))
                .map_err(|e| {
                    format!(
                        "materialise_compiled_primary_blocks: densify raw block {a} channel {channel}: {e}"
                    )
                })?,
        };
        if dense.nrows() != n_rows {
            return Err(format!(
                "materialise_compiled_primary_blocks: raw block {a} channel {channel} has {} rows, expected {n_rows}",
                dense.nrows(),
            ));
        }
        if dense.ncols() != raw_w {
            return Err(format!(
                "materialise_compiled_primary_blocks: raw block {a} channel {channel} has {} cols, expected raw width {raw_w}",
                dense.ncols(),
            ));
        }
        Ok(dense)
    };
    // Densify every raw block once; channel order is [q0, q1, qd1, g].
    let raw_dense: Vec<[Array2<f64>; 4]> = raw_channel_blocks
        .iter()
        .zip(raw_block_ranges)
        .enumerate()
        .map(
            |(a, (blk, raw_range))| -> Result<[Array2<f64>; 4], String> {
                let raw_w = raw_range.len();
                Ok([
                    densify(&blk.q0, raw_w, a, "q0")?,
                    densify(&blk.q1, raw_w, a, "q1")?,
                    densify(&blk.qd1, raw_w, a, "qd1")?,
                    densify(&blk.g, raw_w, a, "g")?,
                ])
            },
        )
        .collect::<Result<_, String>>()?;
    let mut out: Vec<CompiledPrimaryBlockDesign> = Vec::with_capacity(n_blocks);
    for (b, comp_range) in compiled_block_ranges.iter().enumerate() {
        let p_b = comp_range.len();
        let mut acc: [Array2<f64>; 4] =
            std::array::from_fn(|_| Array2::<f64>::zeros((n_rows, p_b)));
        // Zero-width compiled blocks keep their empty (n_rows × 0) designs.
        if p_b > 0 {
            for (raw_range, raw_channels) in raw_block_ranges.iter().zip(&raw_dense) {
                if raw_range.is_empty() {
                    continue;
                }
                let t_ab = t_full
                    .slice(ndarray::s![raw_range.clone(), comp_range.clone()])
                    .to_owned();
                for (acc_c, raw_c) in acc.iter_mut().zip(raw_channels) {
                    *acc_c += &fast_ab(raw_c, &t_ab);
                }
            }
        }
        let [q0, q1, qd1, g] = acc;
        out.push(CompiledPrimaryBlockDesign {
            name: raw_channel_blocks[b].name.clone(),
            q0: DesignMatrix::Dense(DenseDesignMatrix::from(q0)),
            q1: DesignMatrix::Dense(DenseDesignMatrix::from(q1)),
            qd1: DesignMatrix::Dense(DenseDesignMatrix::from(qd1)),
            g: DesignMatrix::Dense(DenseDesignMatrix::from(g)),
        });
    }
    Ok(out)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn psd_clamp_zeros_negative_eigenvalues() {
let mut m = Array2::<f64>::zeros((4, 4));
m[[0, 0]] = 2.0;
m[[1, 1]] = -1.0;
m[[2, 2]] = 0.5;
m[[3, 3]] = -0.25;
let clamped = psd_clamp_4x4(&m);
assert!((clamped[[0, 0]] - 2.0).abs() < 1e-12);
assert!(clamped[[1, 1]].abs() < 1e-12);
assert!((clamped[[2, 2]] - 0.5).abs() < 1e-12);
assert!(clamped[[3, 3]].abs() < 1e-12);
}
#[test]
fn time_block_operator_evaluate_full_shape() {
let n = 6;
let p = 3;
let dq0 = Array2::from_shape_fn((n, p), |(i, j)| (i + j) as f64);
let dq1 = Array2::from_shape_fn((n, p), |(i, j)| (i as f64) * 2.0 + j as f64);
let dqd1 = Array2::from_shape_fn((n, p), |(i, j)| 0.5 * ((i * j) as f64));
let op = TimeBlockOperator::new(dq0.clone(), dq1.clone(), dqd1.clone());
let full = op.evaluate_full();
assert_eq!(full.shape(), &[n, p, K_SURVIVAL]);
for i in 0..n {
for j in 0..p {
assert_eq!(full[[i, j, 0]], dq0[[i, j]]);
assert_eq!(full[[i, j, 1]], dq1[[i, j]]);
assert_eq!(full[[i, j, 2]], dqd1[[i, j]]);
assert_eq!(full[[i, j, 3]], 0.0);
}
}
}
#[test]
fn q_channel_block_apply_row_shares_q0_q1() {
let n = 5;
let p = 2;
let dq = Array2::from_shape_fn((n, p), |(i, j)| (i as f64) * (j as f64 + 1.0));
let dqd1 = Array2::from_shape_fn((n, p), |(i, j)| (j as f64) - (i as f64));
let op = QChannelBlockOperator::new(dq.clone(), dqd1.clone());
let mut out = [0.0_f64; K_SURVIVAL];
let delta = [1.0_f64, -0.5];
op.apply_row(3, &delta, &mut out);
let want_q = dq[[3, 0]] * 1.0 + dq[[3, 1]] * (-0.5);
let want_qd = dqd1[[3, 0]] * 1.0 + dqd1[[3, 1]] * (-0.5);
assert!((out[0] - want_q).abs() < 1e-12);
assert!((out[1] - want_q).abs() < 1e-12);
assert!((out[2] - want_qd).abs() < 1e-12);
assert_eq!(out[3], 0.0);
}
#[test]
fn logslope_block_writes_only_g_channel() {
let n = 4;
let p = 2;
let dg = Array2::from_shape_fn((n, p), |(i, j)| (i as f64) + 0.1 * (j as f64));
let op = LogslopeBlockOperator::new(dg.clone());
let mut out = [0.0_f64; K_SURVIVAL];
let delta = [2.0_f64, -1.0];
op.apply_row(1, &delta, &mut out);
assert_eq!(out[0], 0.0);
assert_eq!(out[1], 0.0);
assert_eq!(out[2], 0.0);
let want = dg[[1, 0]] * 2.0 + dg[[1, 1]] * (-1.0);
assert!((out[3] - want).abs() < 1e-12);
}
#[test]
#[allow(clippy::single_range_in_vec_init)]
fn extract_term_partition_simple_cases() {
let part = extract_term_partition_from_penalty_ranges(5, &[]);
assert_eq!(part, vec![0..5]);
let part = extract_term_partition_from_penalty_ranges(5, &[0..5]);
assert_eq!(part, vec![0..5]);
let part = extract_term_partition_from_penalty_ranges(10, &[0..3, 6..10]);
assert_eq!(part, vec![0..3, 3..6, 6..10]);
let part = extract_term_partition_from_penalty_ranges(6, &[0..3, 0..3, 3..6]);
assert_eq!(part, vec![0..3, 3..6]);
let part = extract_term_partition_from_penalty_ranges(0, &[]);
assert!(part.is_empty());
}
#[test]
fn block_diag_from_assembles_correctly() {
let v1 = Array2::<f64>::from_shape_fn((3, 2), |(i, j)| (i * 2 + j + 1) as f64);
let v2 = Array2::<f64>::from_shape_fn((2, 2), |(i, j)| (10 + i * 2 + j) as f64);
let bd = block_diag_from(&[v1.clone(), v2.clone()]);
assert_eq!(bd.dim(), (5, 4));
for i in 0..3 {
for j in 0..2 {
assert_eq!(bd[[i, j]], v1[[i, j]]);
assert_eq!(bd[[i, 2 + j]], 0.0);
}
}
for i in 0..2 {
for j in 0..2 {
assert_eq!(bd[[3 + i, 2 + j]], v2[[i, j]]);
assert_eq!(bd[[3 + i, j]], 0.0);
}
}
}
#[test]
#[allow(clippy::single_range_in_vec_init)]
fn pull_back_blockwise_penalty_per_term_full_term_identity_v() {
use crate::terms::smooth::BlockwisePenalty;
let v_term = Array2::<f64>::eye(3);
let v_per_term = vec![v_term];
let partition = vec![0..3];
let local = Array2::<f64>::from_shape_fn((3, 3), |(i, j)| (i + j) as f64);
let pen = BlockwisePenalty::new(0..3, local.clone());
let out = pull_back_blockwise_penalty_per_term(&pen, &partition, &v_per_term)
.expect("identity-V pullback must succeed");
assert_eq!(out.col_range, 0..3);
for i in 0..3 {
for j in 0..3 {
assert!((out.local[[i, j]] - 0.5 * (local[[i, j]] + local[[j, i]])).abs() < 1e-12);
}
}
}
#[test]
#[allow(clippy::single_range_in_vec_init)]
fn pull_back_blockwise_penalty_per_term_drops_one_column() {
use crate::terms::smooth::BlockwisePenalty;
let mut v_term = Array2::<f64>::zeros((3, 2));
v_term[[1, 0]] = 1.0;
v_term[[2, 1]] = 1.0;
let partition = vec![0..3];
let local = Array2::<f64>::from_shape_fn((3, 3), |(i, j)| (i + j + 1) as f64);
let pen = BlockwisePenalty::new(0..3, local.clone());
let out = pull_back_blockwise_penalty_per_term(&pen, &partition, &[v_term])
.expect("selection-V pullback must succeed");
assert_eq!(out.col_range, 0..2);
let sym = |i: usize, j: usize| 0.5 * (local[[i, j]] + local[[j, i]]);
assert!((out.local[[0, 0]] - sym(1, 1)).abs() < 1e-12);
assert!((out.local[[0, 1]] - sym(1, 2)).abs() < 1e-12);
assert!((out.local[[1, 0]] - sym(2, 1)).abs() < 1e-12);
assert!((out.local[[1, 1]] - sym(2, 2)).abs() < 1e-12);
}
#[test]
fn pull_back_blockwise_penalty_per_term_routes_to_correct_term() {
use crate::terms::smooth::BlockwisePenalty;
let v0 = Array2::<f64>::eye(2);
let mut v1 = Array2::<f64>::zeros((3, 2));
v1[[0, 0]] = 1.0;
v1[[2, 1]] = 1.0;
let partition = vec![0..2, 2..5];
let v_per_term = vec![v0, v1];
let local1 = Array2::<f64>::from_shape_fn((3, 3), |(i, j)| (10 + i + j) as f64);
let pen1 = BlockwisePenalty::new(2..5, local1.clone());
let out1 = pull_back_blockwise_penalty_per_term(&pen1, &partition, &v_per_term)
.expect("term1 pullback must succeed");
assert_eq!(out1.col_range, 2..4);
let sym = |i: usize, j: usize| 0.5 * (local1[[i, j]] + local1[[j, i]]);
assert!((out1.local[[0, 0]] - sym(0, 0)).abs() < 1e-12);
assert!((out1.local[[0, 1]] - sym(0, 2)).abs() < 1e-12);
assert!((out1.local[[1, 0]] - sym(2, 0)).abs() < 1e-12);
assert!((out1.local[[1, 1]] - sym(2, 2)).abs() < 1e-12);
}
#[test]
fn build_full_t_matrix_identity_when_v_eye_and_r_none() {
let v_a = Array2::<f64>::eye(2);
let v_b = Array2::<f64>::eye(2);
let t = build_full_t_matrix(&[v_a, v_b], &[None, None]);
assert_eq!(t.dim(), (4, 4));
let eye4 = Array2::<f64>::eye(4);
for i in 0..4 {
for j in 0..4 {
assert!((t[[i, j]] - eye4[[i, j]]).abs() < 1e-14);
}
}
}
#[test]
fn build_full_t_matrix_with_drops_and_nonzero_r() {
let mut v_a = Array2::<f64>::zeros((3, 2));
v_a[[0, 0]] = 1.0;
v_a[[1, 0]] = 0.5;
v_a[[2, 1]] = 1.0;
let v_b = Array2::<f64>::eye(2);
let r_ab =
Array2::<f64>::from_shape_fn((3, 2), |(i, j)| 1.0 + (i as f64) + 0.25 * (j as f64));
let t = build_full_t_matrix(&[v_a.clone(), v_b.clone()], &[None, Some(r_ab.clone())]);
assert_eq!(t.dim(), (5, 4));
for i in 0..3 {
for j in 0..2 {
assert!((t[[i, j]] - v_a[[i, j]]).abs() < 1e-14);
}
}
for i in 0..2 {
for j in 0..2 {
assert!((t[[3 + i, 2 + j]] - v_b[[i, j]]).abs() < 1e-14);
}
}
for i in 0..3 {
for j in 0..2 {
assert!((t[[i, 2 + j]] + r_ab[[i, j]]).abs() < 1e-14);
}
}
for i in 0..2 {
for j in 0..2 {
assert_eq!(t[[3 + i, j]], 0.0);
}
}
}
#[test]
fn build_global_t_matrix_identity_when_all_v_eye_and_all_r_none() {
let mk_block = || GlobalTBlock {
v: Array2::<f64>::eye(2),
r_against_earlier: None,
};
let spec = GlobalTSpec {
blocks: vec![
vec![mk_block(), mk_block()],
vec![mk_block(), mk_block()],
vec![mk_block(), mk_block()],
],
};
let t = build_global_t_matrix(&spec);
let total: usize = 2 * 6;
assert_eq!(t.dim(), (total, total));
let eye = Array2::<f64>::eye(total);
for i in 0..total {
for j in 0..total {
assert!((t[[i, j]] - eye[[i, j]]).abs() < 1e-14);
}
}
}
#[test]
fn build_global_t_matrix_block_upper_triangular_with_cross_block_r() {
let r_block = Array2::<f64>::from_shape_fn((2, 2), |(i, j)| 1.0 + (i + 2 * j) as f64);
let spec = GlobalTSpec {
blocks: vec![
vec![GlobalTBlock {
v: Array2::<f64>::eye(2),
r_against_earlier: None,
}],
vec![GlobalTBlock {
v: Array2::<f64>::eye(2),
r_against_earlier: Some(r_block.clone()),
}],
],
};
let t = build_global_t_matrix(&spec);
assert_eq!(t.dim(), (4, 4));
for i in 0..2 {
for j in 0..2 {
let want_diag_a = if i == j { 1.0 } else { 0.0 };
let want_diag_b = if i == j { 1.0 } else { 0.0 };
assert!((t[[i, j]] - want_diag_a).abs() < 1e-14);
assert!((t[[2 + i, 2 + j]] - want_diag_b).abs() < 1e-14);
assert!((t[[i, 2 + j]] + r_block[[i, j]]).abs() < 1e-14);
assert_eq!(t[[2 + i, j]], 0.0);
}
}
}
#[test]
fn pull_back_penalty_through_t_identity_returns_zero_embedded_raw() {
use crate::terms::smooth::BlockwisePenalty;
let v_a = Array2::<f64>::eye(2);
let v_b = Array2::<f64>::eye(2);
let t = build_full_t_matrix(&[v_a, v_b], &[None, None]);
let local = Array2::<f64>::from_shape_fn((2, 2), |(i, j)| (i + 2 * j + 1) as f64);
let pen = BlockwisePenalty::new(0..2, local.clone());
let anchor_offset = 2usize;
let out = pull_back_penalty_through_t(&pen, anchor_offset, &t);
let PenaltyMatrix::Dense(dense) = out else {
panic!("expected PenaltyMatrix::Dense");
};
assert_eq!(dense.dim(), (4, 4));
let sym_local = |i: usize, j: usize| 0.5 * (local[[i, j]] + local[[j, i]]);
for i in 0..4 {
for j in 0..4 {
let want = if i >= 2 && j >= 2 {
sym_local(i - 2, j - 2)
} else {
0.0
};
assert!(
(dense[[i, j]] - want).abs() < 1e-14,
"mismatch at ({i},{j}): got {}, want {}",
dense[[i, j]],
want,
);
}
}
}
#[test]
fn pull_back_penalty_through_t_nontrivial_t_has_off_block_coupling() {
use crate::terms::smooth::BlockwisePenalty;
let v_a = Array2::<f64>::eye(2);
let v_b = Array2::<f64>::eye(2);
let r_ab = Array2::<f64>::from_shape_fn((2, 2), |(i, j)| 1.0 + (i + j) as f64);
let t = build_full_t_matrix(&[v_a, v_b], &[None, Some(r_ab.clone())]);
let mut local = Array2::<f64>::zeros((2, 2));
local[[0, 0]] = 2.0;
local[[1, 1]] = 3.0;
let pen = BlockwisePenalty::new(0..2, local.clone());
let out = pull_back_penalty_through_t(&pen, 0, &t);
let PenaltyMatrix::Dense(dense) = out else {
panic!("expected PenaltyMatrix::Dense");
};
assert_eq!(dense.dim(), (4, 4));
let mut any_nonzero = false;
for i in 0..2 {
for j in 0..2 {
if dense[[i, 2 + j]].abs() > 1e-10 {
any_nonzero = true;
}
}
}
assert!(any_nonzero, "expected nonzero off-block coupling");
let want = local.dot(&r_ab).map(|v| -v);
for i in 0..2 {
for j in 0..2 {
assert!(
(dense[[i, 2 + j]] - want[[i, j]]).abs() < 1e-12,
"off-block (a,b) mismatch at ({i},{j}): got {}, want {}",
dense[[i, 2 + j]],
want[[i, j]],
);
assert!((dense[[2 + j, i]] - want[[i, j]]).abs() < 1e-12);
}
}
}
#[test]
fn pull_back_penalty_through_t_round_trip_quadratic_form() {
use crate::terms::smooth::BlockwisePenalty;
let v_a = Array2::<f64>::from_shape_fn((3, 2), |(i, j)| {
((i + 1) as f64).sin() + 0.3 * (j as f64)
});
let v_b = Array2::<f64>::from_shape_fn((2, 2), |(i, j)| 1.0 + 0.1 * ((i * 2 + j) as f64));
let r_ab = Array2::<f64>::from_shape_fn((3, 2), |(i, j)| {
0.5 - 0.2 * (i as f64) + 0.1 * (j as f64)
});
let t = build_full_t_matrix(&[v_a.clone(), v_b.clone()], &[None, Some(r_ab.clone())]);
let raw_local =
Array2::<f64>::from_shape_fn(
(2, 2),
|(i, j)| {
if i == j { 2.5 + i as f64 } else { 0.4 }
},
);
let pen = BlockwisePenalty::new(0..2, raw_local.clone());
let anchor = 3usize;
let out = pull_back_penalty_through_t(&pen, anchor, &t);
let PenaltyMatrix::Dense(s_compiled) = out else {
panic!("expected Dense");
};
let raw_total = t.nrows();
let mut s_raw_emb = Array2::<f64>::zeros((raw_total, raw_total));
for i in 0..2 {
for j in 0..2 {
s_raw_emb[[anchor + i, anchor + j]] = raw_local[[i, j]];
}
}
let mut s_raw_sym = Array2::<f64>::zeros((raw_total, raw_total));
for i in 0..raw_total {
for j in 0..raw_total {
s_raw_sym[[i, j]] = 0.5 * (s_raw_emb[[i, j]] + s_raw_emb[[j, i]]);
}
}
let theta = Array1::<f64>::from_shape_fn(t.ncols(), |k| ((k as f64) * 0.7 - 0.3).cos());
let gamma = t.dot(&theta);
let lhs = theta.dot(&s_compiled.dot(&theta));
let rhs = gamma.dot(&s_raw_sym.dot(&gamma));
assert!(
(lhs - rhs).abs() < 1e-10,
"round-trip mismatch: lhs={lhs}, rhs={rhs}",
);
}
#[test]
#[allow(clippy::single_range_in_vec_init)]
fn validate_partition_rejects_bad_partitions() {
assert!(validate_partition(&[1..5], 5, "test").is_err());
assert!(validate_partition(&[0..3], 5, "test").is_err());
assert!(validate_partition(&[0..2, 3..5], 5, "test").is_err());
assert!(validate_partition(&[0..3, 2..5], 5, "test").is_err());
assert!(validate_partition(&[0..0, 0..5], 5, "test").is_err());
assert!(validate_partition(&[], 0, "test").is_ok());
assert!(validate_partition(&[0..2, 2..5], 5, "test").is_ok());
assert!(validate_partition(&[0..5], 5, "test").is_ok());
}
#[test]
fn apply_compile_produces_width_consistent_designs_and_penalties() {
use crate::families::custom_family::PenaltyMatrix;
use crate::linalg::matrix::DenseDesignMatrix;
let n = 16;
let p_time = 3;
let p_marginal = 3;
let p_logslope = 2;
let x: Vec<f64> = (0..n)
.map(|i| -1.0 + 2.0 * (i as f64) / (n as f64 - 1.0))
.collect();
let mut time_dq0 = Array2::<f64>::zeros((n, p_time));
let mut time_dq1 = Array2::<f64>::zeros((n, p_time));
let mut time_dqd1 = Array2::<f64>::zeros((n, p_time));
let mut marg_dq = Array2::<f64>::zeros((n, p_marginal));
let marg_dqd1 = Array2::<f64>::zeros((n, p_marginal));
let mut log_dg = Array2::<f64>::zeros((n, p_logslope));
for i in 0..n {
time_dq0[[i, 0]] = 1.0;
time_dq0[[i, 1]] = x[i];
time_dq0[[i, 2]] = x[i] * x[i];
time_dq1[[i, 0]] = 1.0;
time_dq1[[i, 1]] = x[i];
time_dq1[[i, 2]] = x[i] * x[i];
time_dqd1[[i, 0]] = 0.0;
time_dqd1[[i, 1]] = 1.0;
time_dqd1[[i, 2]] = 2.0 * x[i];
marg_dq[[i, 0]] = 1.0;
marg_dq[[i, 1]] = x[i] * x[i] * x[i];
marg_dq[[i, 2]] = x[i].sin();
log_dg[[i, 0]] = (2.0 * x[i]).cos();
log_dg[[i, 1]] = x[i].tanh();
}
let mut h_full = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
for i in 0..n {
for k in 0..K_SURVIVAL {
h_full[[i, k, k]] = 1.0;
}
}
let row_hess = SurvivalRowHessian::from_full(h_full);
let compiled = compile_survival_parametric_designs(
time_dq0.clone(),
time_dq1.clone(),
time_dqd1.clone(),
marg_dq.clone(),
marg_dqd1.clone(),
log_dg.clone(),
&row_hess,
)
.expect("compile must succeed");
let raw_time_entry = DesignMatrix::Dense(DenseDesignMatrix::from(time_dq0.clone()));
let raw_time_exit = DesignMatrix::Dense(DenseDesignMatrix::from(time_dq1.clone()));
let raw_time_deriv = DesignMatrix::Dense(DenseDesignMatrix::from(time_dqd1.clone()));
let raw_marg = DesignMatrix::Dense(DenseDesignMatrix::from(marg_dq.clone()));
let raw_log = DesignMatrix::Dense(DenseDesignMatrix::from(log_dg.clone()));
let time_pens = vec![PenaltyMatrix::Dense(Array2::<f64>::from_shape_fn(
(p_time, p_time),
|(i, j)| if i == j { (i + 1) as f64 } else { 0.0 },
))];
let marg_pens = vec![PenaltyMatrix::Dense(Array2::<f64>::from_shape_fn(
(p_marginal, p_marginal),
|(i, j)| if i == j { (i + 1) as f64 } else { 0.0 },
))];
let log_pens = vec![PenaltyMatrix::Dense(Array2::<f64>::from_shape_fn(
(p_logslope, p_logslope),
|(i, j)| if i == j { (i + 1) as f64 } else { 0.0 },
))];
let out = apply_survival_parametric_compile_to_designs(
&compiled,
raw_time_entry,
raw_time_exit,
raw_time_deriv,
raw_marg,
raw_log,
&time_pens,
&marg_pens,
&log_pens,
)
.expect("apply must succeed");
assert_eq!(out.time_design_entry.ncols(), compiled.v_time.ncols());
assert_eq!(out.time_design_exit.ncols(), compiled.v_time.ncols());
assert_eq!(
out.time_design_derivative_exit.ncols(),
compiled.v_time.ncols()
);
assert_eq!(out.marginal_design.ncols(), compiled.v_marginal.ncols());
assert_eq!(out.logslope_design.ncols(), compiled.v_logslope.ncols());
for s in &out.time_penalties {
let dense = s.as_dense_cow();
assert_eq!(
dense.dim(),
(compiled.v_time.ncols(), compiled.v_time.ncols())
);
}
for s in &out.marginal_penalties {
let dense = s.as_dense_cow();
assert_eq!(
dense.dim(),
(compiled.v_marginal.ncols(), compiled.v_marginal.ncols())
);
}
for s in &out.logslope_penalties {
let dense = s.as_dense_cow();
assert_eq!(
dense.dim(),
(compiled.v_logslope.ncols(), compiled.v_logslope.ncols())
);
}
assert_eq!(out.time_design_entry.nrows(), n);
assert_eq!(out.marginal_design.nrows(), n);
assert_eq!(out.logslope_design.nrows(), n);
}
#[test]
fn compile_survival_parametric_designs_helper_attributes_drop_to_marginal() {
let n = 24;
let p_time = 3;
let p_marginal = 3;
let p_logslope = 2;
let x: Vec<f64> = (0..n)
.map(|i| -1.0 + 2.0 * (i as f64) / (n as f64 - 1.0))
.collect();
let mut time_dq0 = Array2::<f64>::zeros((n, p_time));
let mut time_dq1 = Array2::<f64>::zeros((n, p_time));
let mut time_dqd1 = Array2::<f64>::zeros((n, p_time));
let mut marg_dq = Array2::<f64>::zeros((n, p_marginal));
let marg_dqd1 = Array2::<f64>::zeros((n, p_marginal));
let mut log_dg = Array2::<f64>::zeros((n, p_logslope));
for i in 0..n {
time_dq0[[i, 0]] = 1.0;
time_dq0[[i, 1]] = x[i];
time_dq0[[i, 2]] = x[i] * x[i];
time_dq1[[i, 0]] = 1.0;
time_dq1[[i, 1]] = x[i];
time_dq1[[i, 2]] = x[i] * x[i];
time_dqd1[[i, 0]] = 0.0;
time_dqd1[[i, 1]] = 1.0;
time_dqd1[[i, 2]] = 2.0 * x[i];
marg_dq[[i, 0]] = 1.0; marg_dq[[i, 1]] = x[i] * x[i] * x[i];
marg_dq[[i, 2]] = x[i].sin();
log_dg[[i, 0]] = (2.0 * x[i]).cos();
log_dg[[i, 1]] = x[i].tanh();
}
let mut h_full = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
for i in 0..n {
for k in 0..K_SURVIVAL {
h_full[[i, k, k]] = 1.0;
}
}
let row_hess = SurvivalRowHessian::from_full(h_full);
let out = compile_survival_parametric_designs(
time_dq0, time_dq1, time_dqd1, marg_dq, marg_dqd1, log_dg, &row_hess,
)
.expect("Phase-4b parametric compile must succeed on single-direction alias");
assert_eq!(out.v_time.ncols(), p_time, "time keeps all columns");
assert_eq!(
out.v_marginal.ncols(),
p_marginal - 1,
"marginal loses exactly the shared-constant direction"
);
assert_eq!(out.v_logslope.ncols(), p_logslope, "logslope is clean");
assert_eq!(
out.drops_by_block,
(0, 1, 0),
"attribution: zero from time/logslope, one from marginal",
);
}
#[test]
fn compile_survival_three_block_with_shared_constant_drops_one_direction() {
    // Time and marginal blocks both carry an all-ones column. The compiler must
    // keep every column of the time block (first in ordering), drop exactly one
    // direction from the marginal block, and leave the logslope block intact.
    use crate::families::identifiability_compiler::compile;

    let n = 32;
    let (p_time, p_marginal, p_logslope) = (3, 3, 2);
    let grid: Vec<f64> = (0..n)
        .map(|i| -1.0 + 2.0 * (i as f64) / (n as f64 - 1.0))
        .collect();

    // Time block: {1, x, x^2} on q0/q1, and its x-derivative {0, 1, 2x} on qd1.
    let time_dq0 = Array2::<f64>::from_shape_fn((n, p_time), |(i, j)| {
        let t = grid[i];
        match j {
            0 => 1.0,
            1 => t,
            _ => t * t,
        }
    });
    let time_dq1 = time_dq0.clone();
    let time_dqd1 = Array2::<f64>::from_shape_fn((n, p_time), |(i, j)| match j {
        0 => 0.0,
        1 => 1.0,
        _ => 2.0 * grid[i],
    });

    // Marginal block: {1, x^3, sin x}; its constant column aliases the time block's.
    let marg_dq = Array2::<f64>::from_shape_fn((n, p_marginal), |(i, j)| {
        let t = grid[i];
        match j {
            0 => 1.0,
            1 => t * t * t,
            _ => t.sin(),
        }
    });
    let marg_dqd1 = Array2::<f64>::zeros((n, p_marginal));

    // Logslope block: {cos 2x, tanh x}, sharing no direction with the others.
    let log_dg = Array2::<f64>::from_shape_fn((n, p_logslope), |(i, j)| {
        let t = grid[i];
        if j == 0 {
            (2.0 * t).cos()
        } else {
            t.tanh()
        }
    });

    let inputs = build_survival_compiler_inputs(
        time_dq0, time_dq1, time_dqd1, marg_dq, marg_dqd1, log_dg, None, None,
    );
    let unit_hessian =
        Array3::<f64>::from_shape_fn((n, K_SURVIVAL, K_SURVIVAL), |(_, a, b)| {
            if a == b {
                1.0
            } else {
                0.0
            }
        });
    let row_hess = SurvivalRowHessian::from_full(unit_hessian);
    let compiled = compile(&inputs.operators, &row_hess, &inputs.ordering)
        .expect("survival 3-block compile must succeed; aliasing is single-direction");
    assert_eq!(compiled.blocks.len(), 3, "expected 3 CompiledBlocks");

    let v_time = &compiled.blocks[0].t_lw;
    assert_eq!(
        v_time.ncols(),
        p_time,
        "time block (first in ordering) must retain all {p_time} of its columns; V_time={:?}",
        v_time.dim(),
    );

    let expected_marg = p_marginal - 1;
    let v_marg = &compiled.blocks[1].t_lw;
    assert_eq!(
        v_marg.ncols(),
        expected_marg,
        "marginal block must lose exactly the shared-constant direction; \
         V_marginal cols = {}, expected {}",
        v_marg.ncols(),
        expected_marg,
    );

    let v_log = &compiled.blocks[2].t_lw;
    assert_eq!(
        v_log.ncols(),
        p_logslope,
        "logslope block (no shared direction) must retain all {p_logslope} columns",
    );

    let raw_total = p_time + p_marginal + p_logslope;
    let kept_total = compiled
        .blocks
        .iter()
        .fold(0usize, |acc, blk| acc + blk.t_lw.ncols());
    assert_eq!(
        kept_total,
        raw_total - 1,
        "joint kept = raw_total − aliased; got {kept_total}, expected {}",
        raw_total - 1,
    );
    assert_eq!(
        compiled.joint_rank, kept_total,
        "CompiledBlocks::joint_rank must match the sum of per-block t_lw widths",
    );
}
#[test]
fn smgs_lift_via_t_identity_passes_through() {
    // Identity per-term V with no residualisation R: the assembled lift T must
    // be the identity, and lifting must hand every block back unchanged.
    let v0 = Array2::<f64>::eye(3);
    let v1 = Array2::<f64>::eye(2);
    let v_per_term = vec![v0, v1];
    let r_per_term: Vec<Option<Array2<f64>>> = vec![None, None];
    let lift = SmgsLiftViaT::from_v_and_r(&v_per_term, &r_per_term);
    assert_eq!(lift.t_full.dim(), (5, 5));
    assert_eq!(lift.block_starts_compiled, vec![0, 3, 5]);
    assert_eq!(lift.block_starts_raw, vec![0, 3, 5]);
    for i in 0..5 {
        for j in 0..5 {
            let want = if i == j { 1.0 } else { 0.0 };
            assert!((lift.t_full[[i, j]] - want).abs() < 1e-14);
        }
    }
    let theta_0 = Array1::from(vec![1.0_f64, -2.0, 3.5]);
    let theta_1 = Array1::from(vec![-0.5_f64, 7.0]);
    let lifted = lift.lift_block_betas_via_t(&[theta_0.clone(), theta_1.clone()]);
    assert_eq!(lifted.len(), 2);
    // `zip` stops at the shorter side, so pin each block's length first;
    // otherwise a truncated lift would pass the element-wise check vacuously.
    assert_eq!(lifted[0].len(), theta_0.len(), "lifted block 0 has wrong length");
    assert_eq!(lifted[1].len(), theta_1.len(), "lifted block 1 has wrong length");
    for (a, b) in theta_0.iter().zip(lifted[0].iter()) {
        assert!((a - b).abs() < 1e-14);
    }
    for (a, b) in theta_1.iter().zip(lifted[1].iter()) {
        assert!((a - b).abs() < 1e-14);
    }
}
#[test]
fn smgs_lift_via_t_two_block_with_residualisation() {
    // Block A is kept whole (V_a = I). Block B keeps raw columns 0 and 2 and is
    // residualised against A through R_b. The test therefore expects
    // beta_a = theta_a - R_b theta_b and beta_b = V_b theta_b.
    let v_a = Array2::<f64>::eye(3);
    let mut v_b = Array2::<f64>::zeros((3, 2));
    v_b[[0, 0]] = 1.0;
    v_b[[2, 1]] = 1.0;
    let mut r_b = Array2::<f64>::zeros((3, 2));
    r_b[[0, 0]] = 0.4;
    r_b[[0, 1]] = -0.1;
    r_b[[1, 0]] = 0.7;
    r_b[[1, 1]] = 1.3;
    r_b[[2, 0]] = -0.2;
    r_b[[2, 1]] = 0.5;
    // v_a / v_b are not needed afterwards, so move them instead of cloning.
    let lift = SmgsLiftViaT::from_v_and_r(&[v_a, v_b], &[None, Some(r_b.clone())]);
    assert_eq!(lift.t_full.dim(), (6, 5));
    assert_eq!(lift.block_starts_compiled, vec![0, 3, 5]);
    assert_eq!(lift.block_starts_raw, vec![0, 3, 6]);
    let theta_a = Array1::from(vec![1.0_f64, 2.0, -1.5]);
    let theta_b = Array1::from(vec![0.5_f64, -0.25]);
    let lifted = lift.lift_block_betas_via_t(&[theta_a.clone(), theta_b.clone()]);
    // One lifted vector per input block; a longer result would otherwise go unnoticed.
    assert_eq!(lifted.len(), 2, "expected one lifted vector per block");
    let r_theta_b = r_b.dot(&theta_b);
    let expected_a = &theta_a - &r_theta_b;
    assert_eq!(lifted[0].len(), 3);
    for (got, want) in lifted[0].iter().zip(expected_a.iter()) {
        assert!((got - want).abs() < 1e-12, "got {got}, want {want}");
    }
    assert_eq!(lifted[1].len(), 3);
    assert!((lifted[1][0] - theta_b[0]).abs() < 1e-12);
    assert!(lifted[1][1].abs() < 1e-12);
    assert!((lifted[1][2] - theta_b[1]).abs() < 1e-12);
}
#[test]
fn smgs_lift_via_t_zero_r_matches_per_block_v_lift() {
    // With no residualisation, the assembled-T lift must agree with a direct
    // V·theta product per block and with the `SmgsLiftPerBlockV` path.
    let mut v_a = Array2::<f64>::zeros((3, 2));
    v_a[[0, 0]] = 0.6;
    v_a[[1, 0]] = -0.8;
    v_a[[1, 1]] = 0.3;
    v_a[[2, 1]] = 0.9;
    let mut v_b = Array2::<f64>::zeros((4, 3));
    v_b[[0, 0]] = 1.0;
    v_b[[1, 1]] = -0.4;
    v_b[[2, 0]] = 0.2;
    v_b[[2, 2]] = 0.7;
    v_b[[3, 2]] = -1.1;
    let v_per_term = vec![v_a.clone(), v_b.clone()];
    let lift = SmgsLiftViaT::from_v_and_r(&v_per_term, &[None, None]);
    let theta_a = Array1::from(vec![0.3_f64, -1.4]);
    let theta_b = Array1::from(vec![2.1_f64, 0.0, -0.7]);
    let via_t = lift.lift_block_betas_via_t(&[theta_a.clone(), theta_b.clone()]);
    let ref_a = v_a.dot(&theta_a);
    let ref_b = v_b.dot(&theta_b);
    assert_eq!(via_t[0].len(), ref_a.len());
    for (g, w) in via_t[0].iter().zip(ref_a.iter()) {
        assert!((g - w).abs() < 1e-12);
    }
    assert_eq!(via_t[1].len(), ref_b.len());
    for (g, w) in via_t[1].iter().zip(ref_b.iter()) {
        assert!((g - w).abs() < 1e-12);
    }
    let per_block = SmgsLiftPerBlockV {
        v_per_block: v_per_term,
    };
    let mut block_betas = vec![theta_a, theta_b];
    per_block.lift_block_betas(&mut block_betas);
    // The per-block path must lift to the same raw widths; check lengths before
    // the `zip` comparisons, which would otherwise stop at the shorter side.
    assert_eq!(block_betas.len(), 2, "per-block lift must keep one vector per block");
    assert_eq!(
        block_betas[0].len(),
        via_t[0].len(),
        "per-block lift width mismatch for block 0",
    );
    assert_eq!(
        block_betas[1].len(),
        via_t[1].len(),
        "per-block lift width mismatch for block 1",
    );
    for (got, want) in via_t[0].iter().zip(block_betas[0].iter()) {
        assert!((got - want).abs() < 1e-12);
    }
    for (got, want) in via_t[1].iter().zip(block_betas[1].iter()) {
        assert!((got - want).abs() < 1e-12);
    }
}
#[test]
fn project_raw_beta_identity_t_and_k_is_identity() {
    // With T = I and K = I the projection has nothing to do: theta == beta_raw.
    let p = 5;
    let t_full = Array2::<f64>::eye(p);
    let k_struct = Array2::<f64>::eye(p);
    let beta_raw = Array1::from_vec(vec![0.3_f64, -1.1, 2.5, 0.0, 4.2]);
    let theta = project_raw_beta_to_compiled(&t_full, &k_struct, &beta_raw)
        .expect("identity projection should succeed");
    assert_eq!(theta.len(), p);
    for (i, want) in beta_raw.iter().copied().enumerate() {
        let got = theta[i];
        assert!(
            (got - want).abs() < 1e-12,
            "identity projection mismatch at {i}: got {} expected {}",
            got,
            want
        );
    }
}
#[test]
fn project_raw_beta_column_selector_recovers_kept_indices() {
    // T is a pure column selector (raw indices 0, 2, 4) and K is diagonal
    // positive, so the projection must pick out exactly the kept raw entries.
    let p_raw = 5;
    let kept = [0usize, 2, 4];
    let p_comp = kept.len();
    let mut t_full = Array2::<f64>::zeros((p_raw, p_comp));
    for (j, &i) in kept.iter().enumerate() {
        t_full[[i, j]] = 1.0;
    }
    let mut k_struct = Array2::<f64>::zeros((p_raw, p_raw));
    for i in 0..p_raw {
        k_struct[[i, i]] = (i as f64 + 1.0) * 0.7;
    }
    let beta_raw = Array1::from_vec(vec![1.5_f64, -0.4, 2.0, 9.9, -3.3]);
    let theta = project_raw_beta_to_compiled(&t_full, &k_struct, &beta_raw)
        .expect("selector projection should succeed");
    // The result lives in compiled coordinates: exactly one entry per kept column.
    assert_eq!(
        theta.len(),
        p_comp,
        "selector projection must return {p_comp} compiled coefficients",
    );
    for (j, &i) in kept.iter().enumerate() {
        assert!(
            (theta[j] - beta_raw[i]).abs() < 1e-10,
            "selector projection at compiled idx {j} (raw {i}): got {} expected {}",
            theta[j],
            beta_raw[i]
        );
    }
}
#[test]
fn project_raw_beta_roundtrip_on_structural_positive_subspace() {
    // Orthonormalise a fixed p_raw x p_comp matrix into T (Gram-Schmidt), build
    // K = T diag(d) T^T + 1e-6 I, and check that projecting beta_raw = T theta_true
    // recovers theta_true.
    let p_raw = 6;
    let p_comp = 3;
    let raw_cols = Array2::from_shape_fn((p_raw, p_comp), |(i, j)| {
        ((i as f64) * 0.31 + (j as f64) * 1.7 + 0.5).sin() + (j as f64)
    });
    // raw_cols is not used again, so orthonormalise it in place.
    let mut t_full = raw_cols;
    for j in 0..p_comp {
        // Remove components along the already-normalised columns 0..j.
        for k in 0..j {
            let mut dot = 0.0;
            for i in 0..p_raw {
                dot += t_full[[i, j]] * t_full[[i, k]];
            }
            for i in 0..p_raw {
                t_full[[i, j]] -= dot * t_full[[i, k]];
            }
        }
        let mut nrm = 0.0;
        for i in 0..p_raw {
            nrm += t_full[[i, j]].powi(2);
        }
        nrm = nrm.sqrt();
        assert!(nrm > 1e-10, "column {j} collapsed in QR");
        for i in 0..p_raw {
            t_full[[i, j]] /= nrm;
        }
    }
    let d = Array1::from_vec(vec![2.0_f64, 0.7, 1.3]);
    let mut k_struct = Array2::<f64>::zeros((p_raw, p_raw));
    for i in 0..p_raw {
        for j in 0..p_raw {
            let mut s = 0.0;
            for r in 0..p_comp {
                s += t_full[[i, r]] * d[r] * t_full[[j, r]];
            }
            k_struct[[i, j]] = s;
        }
        // Small ridge so K is strictly positive definite on the full raw space.
        k_struct[[i, i]] += 1e-6;
    }
    let theta_true = Array1::from_vec(vec![0.7_f64, -1.4, 2.1]);
    let beta_raw = crate::linalg::faer_ndarray::fast_av(&t_full, &theta_true);
    let theta = project_raw_beta_to_compiled(&t_full, &k_struct, &beta_raw)
        .expect("round-trip projection should succeed");
    // Pin the compiled width; element checks alone would miss an over-long result.
    assert_eq!(
        theta.len(),
        p_comp,
        "round-trip projection must return {p_comp} compiled coefficients",
    );
    for i in 0..p_comp {
        assert!(
            (theta[i] - theta_true[i]).abs() < 1e-8,
            "round-trip mismatch at {i}: got {} expected {}",
            theta[i],
            theta_true[i]
        );
    }
}
/// Builds a dense `RawPrimaryBlockDesign` for tests.
///
/// Each channel (`q0`, `q1`, `qd1`, `g`) takes the supplied matrix, or an
/// `n_rows x raw_w` zero matrix when `None`.
///
/// # Panics
///
/// Panics if a supplied channel matrix is not `n_rows x raw_w`. A mismatched
/// shape is a test-construction error, so it is reported here rather than deep
/// inside the code under test.
fn make_raw_block(
    name: &str,
    n_rows: usize,
    raw_w: usize,
    q0: Option<Array2<f64>>,
    q1: Option<Array2<f64>>,
    qd1: Option<Array2<f64>>,
    g: Option<Array2<f64>>,
) -> RawPrimaryBlockDesign {
    let channel = |label: &str, supplied: Option<Array2<f64>>| -> DesignMatrix {
        let dense = supplied.unwrap_or_else(|| Array2::<f64>::zeros((n_rows, raw_w)));
        assert_eq!(
            dense.dim(),
            (n_rows, raw_w),
            "make_raw_block({name}): channel {label} has shape {:?}, expected ({n_rows}, {raw_w})",
            dense.dim(),
        );
        DesignMatrix::Dense(DenseDesignMatrix::from(dense))
    };
    RawPrimaryBlockDesign {
        name: name.to_string(),
        q0: channel("q0", q0),
        q1: channel("q1", q1),
        qd1: channel("qd1", qd1),
        g: channel("g", g),
    }
}
/// Returns a dense copy of `dm`; panics with "expected Dense" when `dm` is the
/// sparse variant (the tests using this helper only expect dense designs).
fn dense_of(dm: &DesignMatrix) -> Array2<f64> {
    if let DesignMatrix::Dense(dense) = dm {
        dense.to_dense()
    } else {
        panic!("expected Dense")
    }
}
#[test]
fn materialise_compiled_primary_blocks_block_diagonal_t_no_cross_block() {
    // With T = I there is no cross-block mixing. Every compiled channel must
    // reproduce its own raw channel exactly, and a channel a block never had
    // (logslope q0) must stay identically zero.
    let n = 5;
    let raw_ranges: Vec<std::ops::Range<usize>> = vec![0..2, 2..4, 4..6];
    let compiled_ranges: Vec<std::ops::Range<usize>> = vec![0..2, 2..4, 4..6];
    let x_time_q0 = Array2::from_shape_fn((n, 2), |(i, j)| (i + j) as f64 + 1.0);
    let x_time_q1 = Array2::from_shape_fn((n, 2), |(i, j)| 2.0 * (i as f64) + j as f64);
    let x_time_qd1 = Array2::from_shape_fn((n, 2), |(i, j)| 0.5 * (i as f64) - j as f64);
    let x_marg_q0 = Array2::from_shape_fn((n, 2), |(i, j)| 0.3 + (i * 3 + j) as f64);
    let x_marg_q1 = x_marg_q0.clone();
    let x_log_g = Array2::from_shape_fn((n, 2), |(i, j)| 7.0 - (i as f64) + 0.1 * (j as f64));
    let raw_blocks = vec![
        make_raw_block(
            "time",
            n,
            2,
            Some(x_time_q0.clone()),
            Some(x_time_q1),
            Some(x_time_qd1),
            None,
        ),
        make_raw_block(
            "marginal",
            n,
            2,
            Some(x_marg_q0.clone()),
            Some(x_marg_q1),
            None,
            None,
        ),
        make_raw_block("logslope", n, 2, None, None, None, Some(x_log_g.clone())),
    ];
    let t_full = Array2::<f64>::eye(6);
    let out = materialise_compiled_primary_blocks(
        n,
        &raw_blocks,
        &compiled_ranges,
        &raw_ranges,
        &t_full,
    )
    .expect("materialise should succeed on block-diagonal T");
    assert_eq!(out.len(), 3);

    // Shape first (element checks alone would accept an over-sized output),
    // then exact element-wise agreement.
    let assert_passthrough = |label: &str, got: &Array2<f64>, want: &Array2<f64>| {
        assert_eq!(
            got.dim(),
            want.dim(),
            "block-diag T: {label} shape mismatch",
        );
        for ((i, j), &w) in want.indexed_iter() {
            let g = got[[i, j]];
            assert!(
                (g - w).abs() < 1e-12,
                "block-diag T: {label} mismatch at ({i},{j}): got {g} want {w}",
            );
        }
    };
    assert_passthrough("compiled marg q0", &dense_of(&out[1].q0), &x_marg_q0);
    assert_passthrough("compiled time q0", &dense_of(&out[0].q0), &x_time_q0);
    assert_passthrough("compiled logslope g", &dense_of(&out[2].g), &x_log_g);

    let comp_log_q0 = dense_of(&out[2].q0);
    assert_eq!(
        comp_log_q0.dim(),
        (n, 2),
        "block-diag T: compiled logslope q0 shape mismatch",
    );
    assert!(
        comp_log_q0.iter().all(|v| v.abs() < 1e-14),
        "block-diag T: logslope q0 must stay zero without cross-block coupling",
    );
}
#[test]
fn materialise_compiled_primary_blocks_with_anchor_subtraction_pulls_q0_into_logslope() {
    // T is the identity plus a -R coupling from the time rows into the
    // logslope columns. The raw logslope block has no q0, so its compiled q0
    // must equal X_time_q0 · (-R); its g channel (absent from time) must be
    // passed through unchanged.
    let n = 4;
    let raw_ranges: Vec<std::ops::Range<usize>> = vec![0..2, 2..4, 4..6];
    let compiled_ranges: Vec<std::ops::Range<usize>> = vec![0..2, 2..4, 4..6];
    let x_time_q0 = Array2::from_shape_fn((n, 2), |(i, j)| (i + 2 * j + 1) as f64);
    let x_marg_q0 = Array2::from_shape_fn((n, 2), |(i, j)| 0.5 + (i as f64) - (j as f64));
    let x_log_g = Array2::from_shape_fn((n, 2), |(i, j)| 1.0 + 0.25 * (i + j) as f64);
    let raw_blocks = vec![
        make_raw_block("time", n, 2, Some(x_time_q0.clone()), None, None, None),
        make_raw_block("marginal", n, 2, Some(x_marg_q0), None, None, None),
        make_raw_block("logslope", n, 2, None, None, None, Some(x_log_g.clone())),
    ];
    let r_time_log = Array2::from_shape_vec((2, 2), vec![0.7, -0.3, 0.1, 0.9]).unwrap();

    // Identity, with the time-row / logslope-column corner replaced by -R.
    let mut t_full = Array2::<f64>::eye(6);
    for ((row, col), &r) in r_time_log.indexed_iter() {
        t_full[[row, 4 + col]] = -r;
    }

    let out = materialise_compiled_primary_blocks(
        n,
        &raw_blocks,
        &compiled_ranges,
        &raw_ranges,
        &t_full,
    )
    .expect("materialise should succeed with anchor subtraction");

    let comp_log_q0 = dense_of(&out[2].q0);
    assert_eq!(comp_log_q0.shape(), &[n, 2]);
    let expected = fast_ab(&x_time_q0, &(-&r_time_log));

    // Guard against a degenerate construction where the expectation is all zeros.
    let largest = expected.iter().fold(0.0_f64, |acc, v| acc.max(v.abs()));
    assert!(
        largest > 1e-6,
        "test construction error: expected compiled logslope q0 to be nonzero",
    );

    for ((row, col), &want) in expected.indexed_iter() {
        let got = comp_log_q0[[row, col]];
        assert!(
            (got - want).abs() < 1e-12,
            "compiled logslope q0 mismatch at ({},{}): got {} want {}",
            row,
            col,
            got,
            want,
        );
    }

    let comp_log_g = dense_of(&out[2].g);
    for ((row, col), &want) in x_log_g.indexed_iter() {
        assert!((comp_log_g[[row, col]] - want).abs() < 1e-12);
    }
}
#[test]
fn recompile_after_accept_diff_detection_pilot_curvature_trap() {
    // The same one-column time/marginal design is compiled twice: once with an
    // identity row Hessian, and once with a Hessian whose only nonzero entry is
    // (0, 0). The test asserts that the marginal block drops nothing in the
    // first compile and exactly one column in the second, so `drops_by_block`
    // must differ between the two.
    let n = 6usize;
    let time_dq0 = Array2::<f64>::from_elem((n, 1), 1.0);
    let time_dq1 = Array2::<f64>::zeros((n, 1));
    let time_dqd1 = Array2::<f64>::zeros((n, 1));
    let marg_dq = Array2::<f64>::from_elem((n, 1), 1.0);
    let marg_dqd1 = Array2::<f64>::zeros((n, 1));
    let log_dg = Array2::<f64>::zeros((n, 0));
    let time_partition = vec![0..1usize];
    let marg_partition = vec![0..1usize];
    let log_partition: Vec<std::ops::Range<usize>> = Vec::new();

    // Same design for both compiles; only the row Hessian changes.
    let compile_with = |row_hess: &SurvivalRowHessian| {
        compile_survival_parametric_designs_per_term(
            time_dq0.clone(),
            time_dq1.clone(),
            time_dqd1.clone(),
            &time_partition,
            marg_dq.clone(),
            marg_dqd1.clone(),
            &marg_partition,
            log_dg.clone(),
            &log_partition,
            row_hess,
        )
    };
    let hessian_shape = (n, K_SURVIVAL, K_SURVIVAL);

    let row_hess_ident = SurvivalRowHessian::from_full(Array3::<f64>::from_shape_fn(
        hessian_shape,
        |(_, a, b)| if a == b { 1.0 } else { 0.0 },
    ));
    let compiled_ident = compile_with(&row_hess_ident).expect("identity-H compile must succeed");

    let row_hess_q0 = SurvivalRowHessian::from_full(Array3::<f64>::from_shape_fn(
        hessian_shape,
        |(_, a, b)| if a == 0 && b == 0 { 1.0 } else { 0.0 },
    ));
    let compiled_q0 = compile_with(&row_hess_q0).expect("q0-only-H compile must succeed");

    let (ident_drops, q0_drops) = (compiled_ident.drops_by_block, compiled_q0.drops_by_block);
    assert_ne!(
        ident_drops, q0_drops,
        "structural-H and data-adaptive-H compiles must produce different \
         drops_by_block on the constructed pilot-curvature-trap design; \
         identity={:?} q0-only={:?}",
        ident_drops, q0_drops,
    );
    assert_eq!(
        ident_drops.1, 0,
        "identity-H marg drops expected 0, got {:?}",
        ident_drops,
    );
    assert_eq!(
        q0_drops.1, 1,
        "q0-only-H marg drops expected 1, got {:?}",
        q0_drops,
    );
}
}