use super::*;
use gam_solve::gauge::Gauge;
/// Linear inequality constraints that keep the structural time block's exit-time
/// derivative at or above `derivative_guard`.
///
/// The structural block uses exactly the same guard rows as every other time block,
/// so this simply forwards to [`time_derivative_guard_constraints`].
pub(crate) fn structural_time_coefficient_constraints(
    design_derivative_exit: &DesignMatrix,
    derivative_offset_exit: &Array1<f64>,
    derivative_guard: f64,
) -> Result<Option<LinearInequalityConstraints>, String> {
    let (design, offsets, guard) = (design_derivative_exit, derivative_offset_exit, derivative_guard);
    time_derivative_guard_constraints(design, offsets, guard)
}
/// Builds the time-derivative guard constraints under `LOCATION_SCALE_GUARD_POLICY`.
///
/// # Errors
/// Any [`GuardConstraintFailure`] reported by the builder is converted into a
/// `String` via [`map_guard_constraint_failure`].
pub(crate) fn time_derivative_guard_constraints(
    design_derivative_exit: &DesignMatrix,
    derivative_offset_exit: &Array1<f64>,
    derivative_guard: f64,
) -> Result<Option<LinearInequalityConstraints>, String> {
    let built = build_time_derivative_guard_constraints(
        design_derivative_exit,
        derivative_offset_exit,
        derivative_guard,
        LOCATION_SCALE_GUARD_POLICY,
    );
    built.map_err(map_guard_constraint_failure)
}
/// Translates a [`GuardConstraintFailure`] into the survival location-scale error
/// string reported to callers.
///
/// A row/offset count mismatch is a configuration problem (`InvalidConfiguration`).
/// Every other failure is a `ConstraintViolation`.
pub(crate) fn map_guard_constraint_failure(failure: GuardConstraintFailure) -> String {
    let error = match failure {
        GuardConstraintFailure::RowOffsetMismatch { rows, offsets } => {
            SurvivalLocationScaleError::InvalidConfiguration {
                reason: format!(
                    "time derivative guard constraints require matching rows/offsets: rows={rows}, offsets={offsets}"
                ),
            }
        }
        GuardConstraintFailure::GuardOutOfRange { guard, range } => {
            SurvivalLocationScaleError::ConstraintViolation {
                reason: format!("time derivative guard must be finite and {range}, got {guard}"),
            }
        }
        GuardConstraintFailure::NonFiniteOffset { row, offset } => {
            SurvivalLocationScaleError::ConstraintViolation {
                reason: format!(
                    "time derivative guard constraints require finite derivative offsets; found offset[{row}]={offset}"
                ),
            }
        }
        GuardConstraintFailure::NonFiniteDesign { row, col } => {
            SurvivalLocationScaleError::ConstraintViolation {
                reason: format!(
                    "time derivative guard constraints require finite derivative design entries; found row {row}, column {col}"
                ),
            }
        }
        GuardConstraintFailure::InfeasibleRow {
            row,
            offset,
            guard,
            no_time_coefficients,
        } => {
            // Two phrasings: a block with no time coefficients at all versus a row
            // whose derivative design entries are all zero.
            let reason = match no_time_coefficients {
                true => format!(
                    "time derivative guard is infeasible at row {row}: offset={offset:.3e} < guard={guard:.3e} with no time coefficients"
                ),
                false => format!(
                    "time derivative guard is infeasible at row {row}: zero derivative design row with offset={offset:.3e} < guard={guard:.3e}"
                ),
            };
            SurvivalLocationScaleError::ConstraintViolation { reason }
        }
    };
    error.into()
}
pub(crate) fn structural_time_initial_beta_guess(
design_derivative_exit: &Array2<f64>,
derivative_offset_exit: &Array1<f64>,
age_exit: &Array1<f64>,
derivative_guard: f64,
coefficient_lower_bounds: Option<&Array1<f64>>,
) -> Option<Array1<f64>> {
let n = design_derivative_exit.nrows();
let p = design_derivative_exit.ncols();
if p == 0 || n == 0 || derivative_offset_exit.len() != n || age_exit.len() != n {
return None;
}
let mut target = Array1::<f64>::zeros(n);
for i in 0..n {
let desired = 1.0 / age_exit[i].max(STRUCTURAL_GUESS_AGE_FLOOR);
target[i] = (desired - derivative_offset_exit[i]).max(0.0);
}
let xtx = gam_linalg::faer_ndarray::fast_ata(design_derivative_exit);
let xty = fast_atv(design_derivative_exit, &target);
let eps =
STRUCTURAL_GUESS_RIDGE_REL * (0..p).map(|i| xtx[[i, i]]).fold(0.0_f64, f64::max).max(1.0);
let mut lhs = xtx;
for i in 0..p {
lhs[[i, i]] += eps;
}
use gam_linalg::faer_ndarray::FaerCholesky;
let chol = lhs.cholesky(faer::Side::Lower).ok()?;
let mut beta_init = chol.solvevec(&xty);
if let Some(lower_bounds) = coefficient_lower_bounds
&& let Some(constraints) = lower_bound_constraints(lower_bounds)
{
beta_init = project_onto_linear_constraints(p, &constraints, Some(&beta_init)).ok()?;
}
let d_raw_init = fast_av(design_derivative_exit, &beta_init) + derivative_offset_exit;
if d_raw_init
.iter()
.all(|v| v.is_finite() && *v >= derivative_guard)
{
Some(beta_init)
} else {
None
}
}
/// Identifiability transform attached to a prepared time block.
///
/// In this file the wrapped [`Gauge`] is built from one of:
/// - an identity (structural path);
/// - a penalty null-space basis `z`;
/// - a pinned `z_c` / `z_t` pair;
/// - a `p × 0` transform (log-time-offset block).
#[derive(Clone, Debug)]
pub(crate) struct TimeIdentifiabilityTransform {
    /// Gauge relating the block's (possibly reduced) coefficients to the raw time
    /// coefficients. NOTE(review): the direction/shift semantics live in `Gauge`;
    /// confirm there.
    pub(crate) gauge: Gauge,
}
/// Time block after identifiability reduction, as produced by
/// `prepare_identified_time_block` / `location_logt_offset_time_block`.
#[derive(Clone, Debug)]
pub(crate) struct TimeBlockPrepared {
    /// Dense entry-time design in the block's (possibly reduced) coordinates.
    pub(crate) design_entry: Array2<f64>,
    /// Dense exit-time design in the block's (possibly reduced) coordinates.
    pub(crate) design_exit: Array2<f64>,
    /// Dense exit-time derivative design in the block's (possibly reduced) coordinates.
    pub(crate) design_derivative_exit: Array2<f64>,
    /// Per-coefficient lower bounds; only `Some` on the structural (unreduced) path.
    pub(crate) coefficient_lower_bounds: Option<Array1<f64>>,
    /// Derivative-guard rows, combined with the lower-bound rows on the structural path.
    pub(crate) linear_constraints: Option<LinearInequalityConstraints>,
    /// Smoothing penalties; empty on every reduced path.
    pub(crate) penalties: Vec<Array2<f64>>,
    /// Null-space dimensions matching `penalties`; empty on every reduced path.
    pub(crate) nullspace_dims: Vec<usize>,
    /// Starting coefficients in the block's coordinates, if any.
    pub(crate) initial_beta: Option<Array1<f64>>,
    /// Transform relating block coefficients to the raw time coefficients.
    pub(crate) transform: TimeIdentifiabilityTransform,
    /// Entry-time offset (absorbs the pinned shift on the rank-2 path).
    pub(crate) offset_entry: Array1<f64>,
    /// Exit-time offset (absorbs the pinned shift on the rank-2 path).
    pub(crate) offset_exit: Array1<f64>,
    /// Exit-time derivative offset (absorbs the pinned shift on the rank-2 path).
    pub(crate) derivative_offset_exit: Array1<f64>,
    /// Set only on the rank-2 pinned-slope path.
    pub(crate) pinned_free_row_constant: bool,
    /// Set only on the zero-column block. Presumably signals that log-time enters the
    /// location predictor as an offset instead — TODO confirm with consumers.
    pub(crate) location_log_time_offset: bool,
}
/// Converts per-coordinate lower bounds into linear inequality rows.
///
/// Delegates to `LinearInequalityConstraints::from_per_coordinate_lower_bounds`,
/// which yields `None` when it produces no constraint rows.
pub(crate) fn lower_bound_constraints(
    lower_bounds: &Array1<f64>,
) -> Option<LinearInequalityConstraints> {
    let build = LinearInequalityConstraints::from_per_coordinate_lower_bounds;
    build(lower_bounds)
}
/// Stacks two optional constraint systems `A·x >= b` vertically.
///
/// `None` operands are ignored. When both are present, the rows of `first` come
/// before the rows of `second`.
///
/// # Errors
/// - the two systems have a different number of columns;
/// - either system's `A` row count disagrees with its `b` length (previously this
///   panicked inside `assign`);
/// - `LinearInequalityConstraints::new` rejects the stacked system.
pub(crate) fn append_linear_constraints(
    first: Option<LinearInequalityConstraints>,
    second: Option<LinearInequalityConstraints>,
) -> Result<Option<LinearInequalityConstraints>, String> {
    let (lhs, rhs) = match (first, second) {
        (None, None) => return Ok(None),
        (Some(only), None) | (None, Some(only)) => return Ok(Some(only)),
        (Some(lhs), Some(rhs)) => (lhs, rhs),
    };
    if lhs.a.ncols() != rhs.a.ncols() {
        return Err(SurvivalLocationScaleError::DimensionMismatch {
            reason: format!(
                "time linear constraint width mismatch: left={}, right={}",
                lhs.a.ncols(),
                rhs.a.ncols()
            ),
        }
        .into());
    }
    // The fields are public, so a malformed system can reach here; without this check
    // the slice assignments below would panic on a shape mismatch.
    for (side, part) in [("left", &lhs), ("right", &rhs)] {
        if part.a.nrows() != part.b.len() {
            return Err(SurvivalLocationScaleError::DimensionMismatch {
                reason: format!(
                    "time linear constraint {side} row mismatch: A rows={}, b len={}",
                    part.a.nrows(),
                    part.b.len()
                ),
            }
            .into());
        }
    }
    let lhs_rows = lhs.a.nrows();
    let rows = lhs_rows + rhs.a.nrows();
    let cols = lhs.a.ncols();
    let mut a = Array2::<f64>::zeros((rows, cols));
    let mut b = Array1::<f64>::zeros(rows);
    a.slice_mut(s![..lhs_rows, ..]).assign(&lhs.a);
    a.slice_mut(s![lhs_rows.., ..]).assign(&rhs.a);
    b.slice_mut(s![..lhs_rows]).assign(&lhs.b);
    b.slice_mut(s![lhs_rows..]).assign(&rhs.b);
    LinearInequalityConstraints::new(a, b).map(Some)
}
/// Derives per-coefficient lower bounds for the structural (non-reduced) time block.
///
/// A derivative-design column with at least one entry above `1e-12` gets a lower
/// bound of `0.0`. Every other column stays unbounded (`f64::NEG_INFINITY`). The
/// bounds are meant to work together with a non-negative derivative basis and
/// offsets that already clear the guard: together they are intended to keep the
/// exit-time derivative at or above `lower_bound`.
///
/// Returns `Ok(None)` when the design has no columns, or when no column has positive
/// support. In the latter case a warning is logged if sub-tolerance entries
/// (`1e-30 < |v| <= 1e-12`) were seen.
///
/// # Errors
/// - the design row count differs from `derivative_offset_exit.len()`;
/// - `lower_bound` is not finite or not strictly positive (checked only when the
///   design has columns);
/// - any offset is non-finite or falls below `lower_bound` by more than `1e-12`;
/// - `extract_column` returns a column of unexpected length;
/// - any design entry is non-finite or below `-1e-12`.
pub(crate) fn structural_time_coefficient_lower_bounds(
    design_derivative_exit: &DesignMatrix,
    derivative_offset_exit: &Array1<f64>,
    lower_bound: f64,
) -> Result<Option<Array1<f64>>, String> {
    if design_derivative_exit.nrows() != derivative_offset_exit.len() {
        return Err(SurvivalLocationScaleError::InvalidConfiguration { reason: format!(
            "structural time coefficient bounds require matching rows/offsets: rows={}, offsets={}",
            design_derivative_exit.nrows(),
            derivative_offset_exit.len()
        ) }.into());
    }
    // No coefficients: nothing to bound (note: `lower_bound` is not validated here).
    if design_derivative_exit.ncols() == 0 {
        return Ok(None);
    }
    if !lower_bound.is_finite() || lower_bound <= 0.0 {
        return Err(SurvivalLocationScaleError::ConstraintViolation {
            reason: format!(
                "structural time coefficient lower bound must be finite and > 0, got {lower_bound}"
            ),
        }
        .into());
    }
    // Entries with |v| <= DERIVATIVE_TOL count as zero support; tiny negatives are tolerated.
    const DERIVATIVE_TOL: f64 = 1e-12;
    const FEASIBILITY_TOL: f64 = 1e-12;
    // Below this magnitude an entry is treated as an exact zero for diagnostics.
    const SUBTOL_NONZERO_FLOOR: f64 = 1e-30;
    // Number of leading columns whose max |entry| is reported in the warning.
    const DIAGNOSTIC_COLUMN_PREVIEW: usize = 8;
    let p = design_derivative_exit.ncols();
    let nrows = design_derivative_exit.nrows();
    let mut lower_bounds = Array1::from_elem(p, f64::NEG_INFINITY);
    let mut has_structural_support = false;
    // The offsets alone must already satisfy the guard on every row (up to FEASIBILITY_TOL),
    // so that beta = 0 on the bounded columns is feasible.
    for (row, &offset) in derivative_offset_exit.iter().enumerate() {
        if !offset.is_finite() {
            return Err(SurvivalLocationScaleError::ConstraintViolation { reason: format!(
                "structural time coefficient bounds require finite derivative offsets; found offset[{row}]={offset}"
            ) }.into());
        }
        if lower_bound - offset > FEASIBILITY_TOL {
            return Err(SurvivalLocationScaleError::ConstraintViolation { reason: format!(
                "structural time coefficient bounds require derivative offsets to encode the derivative guard at row {row}: offset={offset:.3e} < guard={lower_bound:.3e}"
            ) }.into());
        }
    }
    let mut col_maxes: Vec<(usize, f64)> = Vec::with_capacity(p.min(DIAGNOSTIC_COLUMN_PREVIEW));
    let mut total_subtol_nonzeros = 0_usize;
    for col in 0..p {
        let column = design_derivative_exit.extract_column(col);
        if column.len() != nrows {
            return Err(SurvivalLocationScaleError::DimensionMismatch { reason: format!(
                "structural time coefficient bounds: extract_column returned {} entries for column {col}, expected {nrows}",
                column.len()
            ) }.into());
        }
        let mut has_positive_support = false;
        let mut col_max = 0.0_f64;
        for (row, &value) in column.iter().enumerate() {
            if !value.is_finite() {
                return Err(SurvivalLocationScaleError::ConstraintViolation { reason: format!(
                    "structural time coefficient bounds require finite derivative design entries; found row {row}, column {col}"
                ) }.into());
            }
            if value < -DERIVATIVE_TOL {
                return Err(SurvivalLocationScaleError::ConstraintViolation { reason: format!(
                    "structural time coefficient bounds require a non-negative derivative basis at row {row}, column {col}; found {value:.3e}"
                ) }.into());
            }
            if value > DERIVATIVE_TOL {
                has_positive_support = true;
            }
            let abs_value = value.abs();
            if abs_value > col_max {
                col_max = abs_value;
            }
            if abs_value > SUBTOL_NONZERO_FLOOR && abs_value <= DERIVATIVE_TOL {
                total_subtol_nonzeros += 1;
            }
        }
        // Only columns that actually move the derivative receive a non-negativity bound.
        if has_positive_support {
            lower_bounds[col] = 0.0;
            has_structural_support = true;
        }
        if col < DIAGNOSTIC_COLUMN_PREVIEW {
            col_maxes.push((col, col_max));
        }
    }
    if !has_structural_support {
        // Warn only when near-zero (but not exactly zero) entries suggest a scaling issue.
        if total_subtol_nonzeros > 0 {
            log::warn!(
                "structural time coefficient bounds: no derivative-active column on this candidate's exit-time design ({} rows × {} cols, sub-tolerance nonzero entries ({:.0e} < |v| ≤ {:.0e}): {}, first-{} col max(|.|): {:?}); skipping the structural lower-bound ridge — fit may converge to a non-monotone-in-time hazard",
                nrows,
                p,
                SUBTOL_NONZERO_FLOOR,
                DERIVATIVE_TOL,
                total_subtol_nonzeros,
                DIAGNOSTIC_COLUMN_PREVIEW,
                col_maxes,
            );
        }
        return Ok(None);
    }
    Ok(Some(lower_bounds))
}
pub(crate) fn structural_time_coefficient_lower_bounds_with_monotone_time_wiggle(
design_derivative_exit: &DesignMatrix,
derivative_offset_exit: &Array1<f64>,
lower_bound: f64,
monotone_time_wiggle_ncols: usize,
) -> Result<Option<Array1<f64>>, String> {
let mut lower_bounds = structural_time_coefficient_lower_bounds(
design_derivative_exit,
derivative_offset_exit,
lower_bound,
)?;
let Some(bounds) = lower_bounds.as_mut() else {
return Ok(None);
};
if monotone_time_wiggle_ncols == 0 {
return Ok(lower_bounds);
}
if monotone_time_wiggle_ncols > bounds.len() {
return Err(SurvivalLocationScaleError::DimensionMismatch { reason: format!(
"structural time coefficient bounds cannot reserve {monotone_time_wiggle_ncols} monotone wiggle columns from {} coefficients",
bounds.len()
) }.into());
}
let tail_start = bounds.len() - monotone_time_wiggle_ncols;
for col in tail_start..bounds.len() {
if !bounds[col].is_finite() || bounds[col] < 0.0 {
bounds[col] = 0.0;
}
}
Ok(lower_bounds)
}
/// Projects a seed onto the polyhedron `{beta : A·beta >= b}`.
///
/// The seed is `beta0`, or the origin when `beta0` is `None`. Three strategies are
/// tried in order, and each result must pass a raw-violation gate of
/// `MONOTONE_CONE_FEASIBILITY_GATE_TOL`:
/// 1. `project_point_strictly_into_feasible_cone` (strict interior point);
/// 2. an exact active-set QP with identity Hessian (Euclidean projection);
/// 3. Dykstra's alternating-projection algorithm over deduplicated, non-degenerate rows.
///
/// If `A` has no rows, or its column count differs from `dim`, the seed is returned
/// unchanged. NOTE(review): the column-mismatch case silently skips projection; confirm
/// callers rely on this.
///
/// # Errors
/// - `beta0` length differs from `dim`;
/// - `A` rows differ from `b` length;
/// - no strategy produces a point within the feasibility gate. Non-finite slacks count
///   as infinite violation, so a NaN seed is never accepted.
pub fn project_onto_linear_constraints(
    dim: usize,
    constraints: &LinearInequalityConstraints,
    beta0: Option<&Array1<f64>>,
) -> Result<Array1<f64>, String> {
    if let Some(b0) = beta0
        && b0.len() != dim
    {
        return Err(SurvivalLocationScaleError::DimensionMismatch {
            reason: format!(
                "project_onto_linear_constraints: beta0 length {} does not match dim {dim}",
                b0.len()
            ),
        }
        .into());
    }
    if constraints.a.nrows() != constraints.b.len() {
        return Err(SurvivalLocationScaleError::DimensionMismatch {
            reason: format!(
                "project_onto_linear_constraints: constraint A has {} rows but b has length {}",
                constraints.a.nrows(),
                constraints.b.len()
            ),
        }
        .into());
    }
    let beta0_vec = beta0.cloned().unwrap_or_else(|| Array1::zeros(dim));
    if constraints.a.ncols() != dim || constraints.a.nrows() == 0 {
        return Ok(beta0_vec);
    }
    let n_rows = constraints.a.nrows();
    const DOWNSTREAM_FEASIBILITY_GATE_TOL: f64 = MONOTONE_CONE_FEASIBILITY_GATE_TOL;
    // Largest violation `max(0, b_i - a_i·beta)` over all rows, and the row attaining it.
    let worst_raw_violation = |b: &Array1<f64>| -> (f64, usize) {
        let mut worst = 0.0_f64;
        let mut worst_row = 0usize;
        for i in 0..n_rows {
            let slack = constraints.a.row(i).dot(b) - constraints.b[i];
            // `f64::max` would turn a NaN slack into 0 ("no violation"); treat any
            // non-finite slack as an unbounded violation instead.
            let viol = if slack.is_finite() {
                (-slack).max(0.0)
            } else {
                f64::INFINITY
            };
            if viol > worst {
                worst = viol;
                worst_row = i;
            }
        }
        (worst, worst_row)
    };
    if let Some(interior) = gam_solve::active_set::project_point_strictly_into_feasible_cone(
        &beta0_vec,
        constraints,
    ) && worst_raw_violation(&interior).0 <= DOWNSTREAM_FEASIBILITY_GATE_TOL
    {
        return Ok(interior);
    }
    // min ½‖beta - beta0‖² s.t. A·beta >= b, i.e. the exact Euclidean projection.
    let identity = Array2::<f64>::eye(dim);
    if let Ok((boundary, _active)) =
        gam_solve::active_set::solve_quadratic_with_linear_constraints(
            &identity,
            &beta0_vec,
            &beta0_vec,
            constraints,
            None,
        )
        && worst_raw_violation(&boundary).0 <= DOWNSTREAM_FEASIBILITY_GATE_TOL
    {
        return Ok(boundary);
    }
    // Dykstra safety net: drop degenerate rows and exact duplicate (row, rhs) pairs
    // (compared bitwise) so each half-space carries a single correction vector.
    let mut seen: std::collections::HashSet<Box<[u64]>> =
        std::collections::HashSet::with_capacity(n_rows);
    let mut unique_rows: Vec<usize> = Vec::with_capacity(n_rows);
    for i in 0..n_rows {
        let row_i = constraints.a.row(i);
        if row_i.dot(&row_i) <= DYKSTRA_ROW_DEGENERACY_FLOOR {
            continue;
        }
        let mut key: Vec<u64> = Vec::with_capacity(dim + 1);
        key.extend(row_i.iter().map(|v| v.to_bits()));
        key.push(constraints.b[i].to_bits());
        if seen.insert(key.into_boxed_slice()) {
            unique_rows.push(i);
        }
    }
    let mut beta = beta0_vec;
    let mut corrections = Array2::<f64>::zeros((unique_rows.len(), dim));
    let max_sweeps = DYKSTRA_PROJECTION_MAX_SWEEPS;
    for _ in 0..max_sweeps {
        for (slot, &i) in unique_rows.iter().enumerate() {
            // `unique_rows` already excludes rows with squared norm <= the degeneracy floor.
            let row = constraints.a.row(i);
            let row_norm_sq = row.dot(&row);
            let y = &beta + &corrections.row(slot);
            let slack = row.dot(&y) - constraints.b[i];
            if slack >= 0.0 {
                // y already lies in this half-space, so its projection is y itself:
                // the iterate moves to y and this half-space's correction is cleared.
                beta.assign(&y);
                corrections.row_mut(slot).fill(0.0);
                continue;
            }
            let step = -slack / row_norm_sq;
            let projected = &y + &(row.to_owned() * step);
            corrections.row_mut(slot).assign(&(&y - &projected));
            beta.assign(&projected);
        }
        // Stop on feasibility of the actual iterate. The pre-projection points `y` stay
        // infeasible for every active constraint, so measuring them never converges.
        if worst_raw_violation(&beta).0 <= DYKSTRA_PROJECTION_TOL {
            break;
        }
    }
    let (worst, worst_row) = worst_raw_violation(&beta);
    if worst > DOWNSTREAM_FEASIBILITY_GATE_TOL {
        return Err(SurvivalLocationScaleError::ConstraintViolation {
            reason: format!(
                "project_onto_linear_constraints could not certify a feasible projection of the \
                 seed onto the monotone time-derivative cone: worst raw violation {worst:.3e} at \
                 row {worst_row} ({} unique of {n_rows} guard rows). Both exact active-set \
                 projections (strict-interior and boundary) refused and the Dykstra safety net \
                 did not reach the downstream gate tol={DOWNSTREAM_FEASIBILITY_GATE_TOL:.1e}. This \
                 is a genuine feasibility failure of the constraint system, surfaced rather than \
                 silently returning an infeasible seed.",
                unique_rows.len(),
            ),
        }
        .into());
    }
    Ok(beta)
}
/// Checks that `beta` satisfies `A·beta >= b` up to a scale-aware tolerance.
///
/// Row `i` uses the tolerance
/// `max(CONSTRAINT_NONNEGATIVITY_REL_TOL * scale_i, MONOTONE_CONE_FEASIBILITY_GATE_TOL)`,
/// where `scale_i = max(Σ_j |a_ij·beta_j|, |b_i|, 1)`. When several rows are violated,
/// the most negative slack is reported.
///
/// # Errors
/// - dimension mismatch between `beta` and `A`, or between `A` and `b`;
/// - any row has a non-finite slack (e.g. a NaN coefficient). Previously such a beta
///   passed, because `NaN < -tol` is false;
/// - any row is violated beyond its tolerance.
pub(crate) fn validate_linear_constraints(
    label: &str,
    beta: &Array1<f64>,
    constraints: &LinearInequalityConstraints,
) -> Result<(), String> {
    if beta.len() != constraints.a.ncols() {
        return Err(SurvivalLocationScaleError::DimensionMismatch { reason: format!(
            "survival location-scale {label} constraint dimension mismatch: beta={}, constraints={}",
            beta.len(),
            constraints.a.ncols()
        ) }.into());
    }
    if constraints.a.nrows() != constraints.b.len() {
        return Err(SurvivalLocationScaleError::DimensionMismatch {
            reason: format!(
                "survival location-scale {label} constraint row mismatch: A rows={}, b len={}",
                constraints.a.nrows(),
                constraints.b.len()
            ),
        }
        .into());
    }
    // (row, slack, tol) of the most negative out-of-tolerance slack seen so far.
    let mut worst: Option<(usize, f64, f64)> = None;
    for (row, a_row) in constraints.a.outer_iter().enumerate() {
        let rhs = constraints.b[row];
        let slack = a_row.dot(beta) - rhs;
        if !slack.is_finite() {
            return Err(SurvivalLocationScaleError::ConstraintViolation { reason: format!(
                "survival location-scale {label} has a non-finite slack at represented linear constraint row {row}: slack={slack}"
            ) }.into());
        }
        let scale = a_row
            .iter()
            .zip(beta.iter())
            .map(|(a, b)| (a * b).abs())
            .sum::<f64>()
            .max(rhs.abs())
            .max(1.0);
        let tol =
            (CONSTRAINT_NONNEGATIVITY_REL_TOL * scale).max(MONOTONE_CONE_FEASIBILITY_GATE_TOL);
        if slack < -tol && worst.is_none_or(|(_, worst_slack, _)| slack < worst_slack) {
            worst = Some((row, slack, tol));
        }
    }
    if let Some((row, worst_slack, worst_tol)) = worst {
        return Err(SurvivalLocationScaleError::ConstraintViolation { reason: format!(
            "survival location-scale {label} violates represented linear constraint at row {row}: slack={worst_slack:.3e}, tol={worst_tol:.3e}"
        ) }.into());
    }
    Ok(())
}
/// Orthonormal basis of the null space of the summed time penalties.
///
/// Eigenvalues of `Σ S_k` at or below `100 · p · ε · max(1, max |λ|)` are treated as
/// zero. Returns `None` in any of these cases:
/// - `p == 0`, or there are no penalties;
/// - a penalty is not `p × p`;
/// - the eigendecomposition fails;
/// - the null space is empty, or spans the whole space.
pub(crate) fn time_parametric_null_space_basis(
    penalties: &[Array2<f64>],
    p: usize,
) -> Option<Array2<f64>> {
    if penalties.is_empty() || p == 0 {
        return None;
    }
    if !penalties.iter().all(|m| m.nrows() == p && m.ncols() == p) {
        return None;
    }
    let combined = penalties
        .iter()
        .fold(Array2::<f64>::zeros((p, p)), |acc, penalty| acc + penalty);
    let (eigenvalues, eigenvectors) = combined.eigh(faer::Side::Lower).ok()?;
    let spectral_scale = eigenvalues
        .iter()
        .map(|v| v.abs())
        .fold(0.0_f64, f64::max)
        .max(1.0);
    let cutoff = 100.0 * (p as f64) * f64::EPSILON * spectral_scale;
    let kept: Vec<usize> = eigenvalues
        .iter()
        .enumerate()
        .filter_map(|(k, &ev)| (ev <= cutoff).then_some(k))
        .collect();
    match kept.len() {
        0 => None,
        dim if dim >= p => None,
        _ => Some(eigenvectors.select(ndarray::Axis(1), &kept)),
    }
}
/// Whether the unreduced time block (identity basis) can be replaced by a plain
/// log-time baseline.
///
/// Requires a constant scale, no protected time-wiggle columns and at least one
/// exit-design column. The final decision comes from
/// [`reduced_warp_logt_baseline_usable`].
pub(crate) fn time_block_collapses_to_logt_baseline(
    constant_scale: bool,
    protected_timewiggle_cols: usize,
    design_exit: &Array2<f64>,
    log_time_exit: ndarray::ArrayView1<f64>,
) -> bool {
    let cols = design_exit.ncols();
    constant_scale
        && protected_timewiggle_cols == 0
        && cols > 0
        && reduced_warp_logt_baseline_usable(&Array2::<f64>::eye(cols), design_exit, log_time_exit)
}
/// Whether the time block reduces to a parametric form.
///
/// Requires a constant scale and no protected time-wiggle columns. Then either the
/// summed penalties must have a proper null space, or the block must collapse to a
/// log-time baseline.
pub(crate) fn time_block_reduces_to_parametric(
    time_penalties: &[Array2<f64>],
    time_ncols: usize,
    constant_scale: bool,
    protected_timewiggle_cols: usize,
    design_exit: &Array2<f64>,
    log_time_exit: ndarray::ArrayView1<f64>,
) -> bool {
    if !constant_scale || protected_timewiggle_cols != 0 {
        return false;
    }
    if time_parametric_null_space_basis(time_penalties, time_ncols).is_some() {
        return true;
    }
    time_block_collapses_to_logt_baseline(
        constant_scale,
        protected_timewiggle_cols,
        design_exit,
        log_time_exit,
    )
}
/// Number of smoothing parameters contributed by the time block.
///
/// This is zero when the block reduces to a parametric form, and one per time
/// penalty otherwise.
pub(crate) fn survival_time_rho_count(
    time_penalties: &[Array2<f64>],
    time_ncols: usize,
    constant_scale: bool,
    protected_timewiggle_cols: usize,
    design_exit: &Array2<f64>,
    log_time_exit: ndarray::ArrayView1<f64>,
) -> usize {
    let parametric = time_block_reduces_to_parametric(
        time_penalties,
        time_ncols,
        constant_scale,
        protected_timewiggle_cols,
        design_exit,
        log_time_exit,
    );
    match parametric {
        true => 0,
        false => time_penalties.len(),
    }
}
/// Whether the whole survival model sits in the reduced parametric AFT regime.
///
/// All of the following must hold:
/// - there is no link wiggle and no protected time wiggle;
/// - the time block contributes no smoothing parameters;
/// - the threshold and log-sigma blocks carry only full-rank (ridge-like) penalties.
pub(crate) fn survival_reduced_parametric_aft_regime(
    time_penalties: &[Array2<f64>],
    time_ncols: usize,
    constant_scale: bool,
    protected_timewiggle_cols: usize,
    threshold_nullspace_dims: &[usize],
    threshold_npenalties: usize,
    log_sigma_nullspace_dims: &[usize],
    log_sigma_npenalties: usize,
    has_linkwiggle: bool,
    design_exit: &Array2<f64>,
    log_time_exit: ndarray::ArrayView1<f64>,
) -> bool {
    let wiggles_present = has_linkwiggle || protected_timewiggle_cols > 0;
    if wiggles_present {
        return false;
    }
    let time_rho = survival_time_rho_count(
        time_penalties,
        time_ncols,
        constant_scale,
        protected_timewiggle_cols,
        design_exit,
        log_time_exit,
    );
    time_rho == 0
        && block_penalties_all_parametric_ridges(threshold_nullspace_dims, threshold_npenalties)
        && block_penalties_all_parametric_ridges(log_sigma_nullspace_dims, log_sigma_npenalties)
}
/// Whether every penalty of a block has a zero-dimensional null space.
///
/// A block with no penalties trivially qualifies. Otherwise there must be exactly one
/// null-space dimension per penalty, and all of them must be zero.
pub(crate) fn block_penalties_all_parametric_ridges(
    nullspace_dims: &[usize],
    npenalties: usize,
) -> bool {
    let dims_match = nullspace_dims.len() == npenalties;
    npenalties == 0 || (dims_match && !nullspace_dims.iter().any(|&dim| dim != 0))
}
/// OLS slope of `design_exit · direction` regressed on `log_time_exit`.
///
/// Returns `None` in any of these cases:
/// - the design is empty, or the lengths disagree;
/// - log-time has (numerically) no spread;
/// - the slope is non-finite;
/// - the slope is negligible relative to the largest fitted magnitude.
pub(crate) fn unit_log_time_slope(
    design_exit: &Array2<f64>,
    direction: &Array1<f64>,
    log_time_exit: ndarray::ArrayView1<f64>,
) -> Option<f64> {
    let n_obs = design_exit.nrows();
    if n_obs == 0 || log_time_exit.len() != n_obs {
        return None;
    }
    let fitted = design_exit.dot(direction);
    let centre = log_time_exit.sum() / n_obs as f64;
    // Centring x alone suffices: Σ (x - x̄) = 0 makes centring y redundant.
    let (sxx, sxy) = log_time_exit.iter().zip(fitted.iter()).fold(
        (0.0_f64, 0.0_f64),
        |(acc_xx, acc_xy), (&t, &fit)| {
            let dev = t - centre;
            (acc_xx + dev * dev, acc_xy + dev * fit)
        },
    );
    if !sxx.is_finite() || sxx <= f64::EPSILON {
        return None;
    }
    let magnitude = fitted.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs())).max(1.0);
    let slope = sxy / sxx;
    (slope.is_finite() && slope.abs() > f64::EPSILON * magnitude).then_some(slope)
}
/// Whether a rank-1 penalty null space `z` is a usable pure log-time direction.
///
/// Requires:
/// - `z` has a single column;
/// - entry and exit log-times match the design's row count and are all finite;
/// - the exit design along `z` has a non-negligible slope on log exit time.
pub(crate) fn rank1_reduced_time_warp_applies(
    z: &Array2<f64>,
    design_exit: &Array2<f64>,
    log_time_entry: ndarray::ArrayView1<f64>,
    log_time_exit: ndarray::ArrayView1<f64>,
) -> bool {
    let n_obs = design_exit.nrows();
    let lengths_ok = log_time_entry.len() == n_obs && log_time_exit.len() == n_obs;
    if z.ncols() != 1 || !lengths_ok {
        return false;
    }
    let all_finite = log_time_entry
        .iter()
        .chain(log_time_exit.iter())
        .all(|v| v.is_finite());
    if !all_finite {
        return false;
    }
    let only_direction = z.column(0).to_owned();
    unit_log_time_slope(design_exit, &only_direction, log_time_exit).is_some()
}
/// Output of `pin_reduced_time_warp_slope` for a two-column penalty null space.
pub(crate) struct PinnedTimeWarp {
    /// `p × 1` unit-norm raw-coefficient direction. Its exit-design image is the
    /// least-squares fit of an all-ones column.
    pub(crate) z_c: Array2<f64>,
    /// Raw-coefficient direction built from the null-space combination orthogonal to
    /// the one behind `z_c`. It is rescaled so its exit-design image has unit OLS slope
    /// on log exit time.
    pub(crate) z_t: Array1<f64>,
}
/// Splits a two-column penalty null-space basis `z` into a constant direction and a
/// unit log-time-slope direction.
///
/// Let `G = design_exit · z` (n × 2). The steps are:
/// 1. Solve the 2 × 2 normal equations `(GᵀG) a = Gᵀ1` for the combination `a` that
///    best reproduces an all-ones column.
/// 2. Set `z_c = z·a / ‖z·a‖`.
/// 3. Set `z_t = z·a⊥ / slope`, where `a⊥ = (-a1, a0)` and `slope` is the OLS slope
///    of `design_exit · z·a⊥` on `log_time_exit`.
///
/// Returns `None` when:
/// - `z` does not have exactly two columns, or there are no rows;
/// - `log_time_exit` has the wrong length;
/// - `GᵀG` is numerically singular;
/// - `z·a` is numerically zero;
/// - the orthogonal direction has no usable log-time slope.
pub(crate) fn pin_reduced_time_warp_slope(
    z: &Array2<f64>,
    design_exit: &Array2<f64>,
    log_time_exit: ndarray::ArrayView1<f64>,
) -> Option<PinnedTimeWarp> {
    let p = z.nrows();
    let n = design_exit.nrows();
    if z.ncols() != 2 || n == 0 || log_time_exit.len() != n {
        return None;
    }
    let g = design_exit.dot(z);
    // Entries of the 2 × 2 Gram matrix GᵀG.
    let m00 = g.column(0).dot(&g.column(0));
    let m01 = g.column(0).dot(&g.column(1));
    let m11 = g.column(1).dot(&g.column(1));
    // Right-hand side Gᵀ1 (column sums of G).
    let ones = Array1::<f64>::ones(n);
    let r0 = g.column(0).dot(&ones);
    let r1 = g.column(1).dot(&ones);
    let det = m00 * m11 - m01 * m01;
    // The singularity test is relative to the squared Gram scale.
    let gram_scale = m00.max(m11).max(1.0);
    if !det.is_finite() || det.abs() <= f64::EPSILON * gram_scale * gram_scale {
        return None;
    }
    // Closed-form 2 × 2 inverse applied to Gᵀ1.
    let a0 = (m11 * r0 - m01 * r1) / det;
    let a1 = (m00 * r1 - m01 * r0) / det;
    let z_c_raw = z.dot(&Array1::from(vec![a0, a1]));
    let z_c_norm = z_c_raw.dot(&z_c_raw).sqrt();
    if !z_c_norm.is_finite() || z_c_norm <= f64::EPSILON * (p as f64).sqrt() {
        return None;
    }
    let z_c_vec = &z_c_raw / z_c_norm;
    // Combination orthogonal to `a` within the two null-space coordinates.
    let a_perp = Array1::from(vec![-a1, a0]);
    let z_t_raw = z.dot(&a_perp);
    let slope = unit_log_time_slope(design_exit, &z_t_raw, log_time_exit)?;
    // Rescale so the exit-design image has unit slope on log exit time.
    let z_t = &z_t_raw / slope;
    let z_c = z_c_vec.insert_axis(ndarray::Axis(1));
    Some(PinnedTimeWarp { z_c, z_t })
}
/// Zero-column time block used when the time effect reduces to a log-time offset.
///
/// Every design has the original row count and no columns. All offsets are zero and
/// the seed is an empty vector. The gauge is a `p × 0` transform, and
/// `location_log_time_offset` is set.
pub(crate) fn location_logt_offset_time_block(
    design_entry: &Array2<f64>,
    design_exit: &Array2<f64>,
    design_derivative_exit: &Array2<f64>,
    p: usize,
) -> TimeBlockPrepared {
    let no_columns = |rows: usize| Array2::<f64>::zeros((rows, 0));
    let n_entry = design_entry.nrows();
    let n_exit = design_exit.nrows();
    let n_derivative = design_derivative_exit.nrows();
    let empty_gauge = Gauge::from_block_transforms(&[Array2::<f64>::zeros((p, 0))]);
    TimeBlockPrepared {
        design_entry: no_columns(n_entry),
        design_exit: no_columns(n_exit),
        design_derivative_exit: no_columns(n_derivative),
        coefficient_lower_bounds: None,
        linear_constraints: None,
        penalties: Vec::new(),
        nullspace_dims: Vec::new(),
        initial_beta: Some(Array1::<f64>::zeros(0)),
        transform: TimeIdentifiabilityTransform { gauge: empty_gauge },
        offset_entry: Array1::zeros(n_entry),
        offset_exit: Array1::zeros(n_exit),
        derivative_offset_exit: Array1::zeros(n_derivative),
        pinned_free_row_constant: false,
        location_log_time_offset: true,
    }
}
pub(crate) fn reduced_warp_logt_baseline_usable(
z: &Array2<f64>,
design_exit: &Array2<f64>,
log_time_exit: ndarray::ArrayView1<f64>,
) -> bool {
use gam_linalg::faer_ndarray::FaerCholesky;
let n = design_exit.nrows();
let r = z.ncols();
if n == 0 || r == 0 || log_time_exit.len() != n {
return false;
}
let g = design_exit.dot(z); let gtg = g.t().dot(&g); let gtl = g.t().dot(&log_time_exit); let scale = gtg
.diag()
.iter()
.fold(0.0_f64, |acc, &v| acc.max(v.abs()))
.max(1.0);
let mut ridged = gtg;
for i in 0..r {
ridged[[i, i]] += 1e-10 * scale;
}
let Ok(chol) = ridged.cholesky(faer::Side::Lower) else {
return false;
};
let c = chol.solvevec(>l);
if c.iter().any(|v| !v.is_finite()) {
return false;
}
let direction = z.dot(&c);
unit_log_time_slope(design_exit, &direction, log_time_exit).is_some()
}
/// Builds the identified time block used by the survival location-scale fit.
///
/// Depending on the penalty null space and the exit design, one of four shapes is
/// returned:
/// 1. A zero-column log-time-offset block (`location_logt_offset_time_block`). This is
///    used when a rank-1 null space is a pure log-time direction, or when the reduced
///    (or unreduced) design can represent log exit time.
/// 2. A one-coefficient block with a pinned shift (`pinned_free_row_constant = true`).
///    This is used when the null space has rank 2 and `pin_reduced_time_warp_slope`
///    succeeds.
/// 3. A penalty-free block reduced onto the null-space basis `z`, for any other
///    parametric reduction.
/// 4. The full structural block, with per-coefficient lower bounds plus
///    derivative-guard constraints, when no parametric reduction applies.
///
/// In every constrained path the seed `input.initial_beta` (mapped into block
/// coordinates) is projected onto the constraints when needed.
///
/// # Errors
/// - the monotonicity strategy is not a coordinate cone;
/// - guard constraints cannot be built, or a seed cannot be projected;
/// - the structural path's lower bounds come back as `None`.
pub(crate) fn prepare_identified_time_block(
    input: &TimeBlockInput,
    derivative_guard: f64,
    monotone_time_wiggle_ncols: usize,
    reduce_to_parametric: bool,
    log_time_entry: ndarray::ArrayView1<f64>,
    log_time_exit: ndarray::ArrayView1<f64>,
) -> Result<TimeBlockPrepared, String> {
    let p = input.design_exit.ncols();
    if !input.time_monotonicity.is_coordinate_cone() {
        return Err(SurvivalLocationScaleError::InvalidConfiguration { reason: format!(
            "time_block requires a coordinate-cone monotonicity strategy by construction; got {:?}",
            input.time_monotonicity
        ) }.into());
    }
    let design_entry = input.design_entry.to_dense();
    let design_exit = input.design_exit.to_dense();
    let design_derivative_exit = input.design_derivative_exit.to_dense();
    // Parametric reduction: restrict the block to the null space of the summed penalties.
    if reduce_to_parametric && let Some(z) = time_parametric_null_space_basis(&input.penalties, p) {
        let r = z.ncols();
        // Rank 1 and a pure log-time direction: represent the time effect as an offset.
        if r == 1
            && z.nrows() == p
            && rank1_reduced_time_warp_applies(&z, &design_exit, log_time_entry, log_time_exit)
        {
            return Ok(location_logt_offset_time_block(
                &design_entry,
                &design_exit,
                &design_derivative_exit,
                p,
            ));
        }
        // Rank 2: pin the unit log-time slope into the offsets and keep one free
        // coefficient along the constant direction z_c.
        if r == 2
            && z.nrows() == p
            && let Some(pinned) = pin_reduced_time_warp_slope(&z, &design_exit, log_time_exit)
        {
            let PinnedTimeWarp { z_c, z_t } = pinned;
            let gauge = Gauge::from_block_transform_with_shift(z_c.clone(), z_t);
            let (reduced_entry, offset_entry) =
                gauge.restrict_design_and_offset(&design_entry, &input.offset_entry);
            let (reduced_exit, offset_exit) =
                gauge.restrict_design_and_offset(&design_exit, &input.offset_exit);
            let (reduced_derivative_exit, derivative_offset_exit) = gauge
                .restrict_design_and_offset(&design_derivative_exit, &input.derivative_offset_exit);
            let reduced_derivative_design =
                DesignMatrix::Dense(DenseDesignMatrix::from(reduced_derivative_exit.clone()));
            // The guard is built against the shifted derivative offsets.
            let linear_constraints = time_derivative_guard_constraints(
                &reduced_derivative_design,
                &derivative_offset_exit,
                derivative_guard,
            )?;
            // Seed mapped into the single free coordinate via z_cᵀ·beta0.
            let initial_beta = match (linear_constraints.as_ref(), input.initial_beta.as_ref()) {
                (Some(constraints), Some(beta0)) => Some(project_onto_linear_constraints(
                    1,
                    constraints,
                    Some(&z_c.t().dot(beta0)),
                )?),
                (_, Some(beta0)) => Some(z_c.t().dot(beta0)),
                _ => None,
            };
            return Ok(TimeBlockPrepared {
                design_entry: reduced_entry,
                design_exit: reduced_exit,
                design_derivative_exit: reduced_derivative_exit,
                coefficient_lower_bounds: None,
                linear_constraints,
                penalties: Vec::new(),
                nullspace_dims: Vec::new(),
                initial_beta,
                transform: TimeIdentifiabilityTransform { gauge },
                offset_entry,
                offset_exit,
                derivative_offset_exit,
                pinned_free_row_constant: true,
                location_log_time_offset: false,
            });
        }
        // The null space can reproduce log exit time: fall back to the offset block.
        if reduced_warp_logt_baseline_usable(&z, &design_exit, log_time_exit) {
            return Ok(location_logt_offset_time_block(
                &design_entry,
                &design_exit,
                &design_derivative_exit,
                p,
            ));
        }
        // General reduction onto z: the penalties vanish on the null space, so none are kept.
        let reduced_entry = design_entry.dot(&z);
        let reduced_exit = design_exit.dot(&z);
        let reduced_derivative_exit = design_derivative_exit.dot(&z);
        let reduced_penalties: Vec<Array2<f64>> = Vec::new();
        let reduced_nullspace_dims: Vec<usize> = Vec::new();
        let reduced_derivative_design =
            DesignMatrix::Dense(DenseDesignMatrix::from(reduced_derivative_exit.clone()));
        let linear_constraints = time_derivative_guard_constraints(
            &reduced_derivative_design,
            &input.derivative_offset_exit,
            derivative_guard,
        )?;
        // Seed mapped into null-space coordinates via zᵀ·beta0.
        let initial_beta = match (linear_constraints.as_ref(), input.initial_beta.as_ref()) {
            (Some(constraints), Some(beta0)) => Some(project_onto_linear_constraints(
                r,
                constraints,
                Some(&z.t().dot(beta0)),
            )?),
            (_, Some(beta0)) => Some(z.t().dot(beta0)),
            _ => None,
        };
        return Ok(TimeBlockPrepared {
            design_entry: reduced_entry,
            design_exit: reduced_exit,
            design_derivative_exit: reduced_derivative_exit,
            coefficient_lower_bounds: None,
            linear_constraints,
            penalties: reduced_penalties,
            nullspace_dims: reduced_nullspace_dims,
            initial_beta,
            transform: TimeIdentifiabilityTransform {
                gauge: Gauge::from_block_transforms(&[z]),
            },
            offset_entry: input.offset_entry.clone(),
            offset_exit: input.offset_exit.clone(),
            derivative_offset_exit: input.derivative_offset_exit.clone(),
            pinned_free_row_constant: false,
            location_log_time_offset: false,
        });
    }
    // No usable null space, but the unreduced design can still reproduce log exit time.
    if reduce_to_parametric
        && time_block_collapses_to_logt_baseline(true, 0, &design_exit, log_time_exit)
    {
        return Ok(location_logt_offset_time_block(
            &design_entry,
            &design_exit,
            &design_derivative_exit,
            p,
        ));
    }
    // Structural path: keep the full basis and penalties, and constrain coefficients
    // by lower bounds plus derivative-guard rows.
    let penalties = input.penalties.clone();
    let coefficient_lower_bounds = structural_time_coefficient_lower_bounds_with_monotone_time_wiggle(
        &input.design_derivative_exit,
        &input.derivative_offset_exit,
        derivative_guard,
        monotone_time_wiggle_ncols,
    )?
    .ok_or_else(|| {
        "structural time block requires derivative offsets to encode the derivative guard and a non-negative derivative basis"
            .to_string()
    })?;
    let coefficient_constraints = lower_bound_constraints(&coefficient_lower_bounds);
    let derivative_constraints = time_derivative_guard_constraints(
        &input.design_derivative_exit,
        &input.derivative_offset_exit,
        derivative_guard,
    )?;
    let linear_constraints =
        append_linear_constraints(coefficient_constraints.clone(), derivative_constraints)?;
    // Cheap path first: clip the seed to its lower bounds. Project the original seed
    // only if the clipped one still violates a constraint.
    let initial_beta = match (linear_constraints.as_ref(), input.initial_beta.as_ref()) {
        (Some(constraints), Some(beta0)) => {
            let mut clipped = beta0.clone();
            for (value, &lower) in clipped.iter_mut().zip(coefficient_lower_bounds.iter()) {
                if lower.is_finite() && *value < lower {
                    *value = lower;
                }
            }
            if validate_linear_constraints("time initial beta", &clipped, constraints).is_ok() {
                Some(clipped)
            } else {
                Some(project_onto_linear_constraints(
                    p,
                    constraints,
                    Some(beta0),
                )?)
            }
        }
        (_, Some(beta0)) => Some(beta0.clone()),
        _ => None,
    };
    Ok(TimeBlockPrepared {
        design_entry,
        design_exit,
        design_derivative_exit,
        coefficient_lower_bounds: Some(coefficient_lower_bounds),
        linear_constraints,
        penalties,
        nullspace_dims: input.nullspace_dims.clone(),
        initial_beta,
        transform: TimeIdentifiabilityTransform {
            gauge: Gauge::identity(&[p]),
        },
        offset_entry: input.offset_entry.clone(),
        offset_exit: input.offset_exit.clone(),
        derivative_offset_exit: input.derivative_offset_exit.clone(),
        pinned_free_row_constant: false,
        location_log_time_offset: false,
    })
}