use super::*;
/// Inputs for a survival marginal-slope fit: per-row survival data (entry/exit times,
/// event indicators, weights), a latent score matrix `z`, marginal and log-slope term
/// specifications with their offsets, and the time-block designs.
///
/// `validate_spec` in this file checks the consistency rules noted on each field.
#[derive(Clone)]
pub struct SurvivalMarginalSlopeTermSpec {
/// Per-row entry times; its length defines the row count `n` that every other per-row
/// input is checked against.
pub age_entry: Array1<f64>,
/// Per-row exit times; `validate_spec` requires `age_exit[i] >= age_entry[i]`.
pub age_exit: Array1<f64>,
/// Per-row event indicators; every value must be exactly `0.0` or `1.0`, and non-empty
/// data must contain at least one event (`1.0`).
pub event_target: Array1<f64>,
/// Per-row weights; must be finite and non-negative.
pub weights: Array1<f64>,
/// Latent score matrix with `n` rows, at least one column, and finite entries.
pub z: Array2<f64>,
/// NOTE(review): inverse link for the base model; not referenced in this chunk — confirm.
pub base_link: InverseLink,
/// Term specification for the marginal predictor.
pub marginalspec: TermCollectionSpec,
/// Per-row finite offset paired with `marginalspec` (presumably added to the marginal
/// predictor — confirm at the use site).
pub marginal_offset: Array1<f64>,
/// Frailty model. Only `FrailtySpec::None` or `FrailtySpec::GaussianShift` with a finite
/// `sigma_fixed >= 0` pass `validate_spec`.
pub frailty: FrailtySpec,
/// Finite, strictly positive lower bound enforced on the exit-time derivative (through the
/// coordinate-cone offset audit or `time_derivative_guard_constraints`).
pub derivative_guard: f64,
/// Time-block designs at entry, exit, and exit-derivative. Each must have `n` rows and the
/// same column count, with a row-constraint or coordinate-cone monotonicity mode.
pub time_block: TimeBlockInput,
/// Optional time-wiggle block; must be cubic (degree 3) with a width consistent with its
/// knots and no wider than the time block.
pub timewiggle_block: Option<TimeWiggleBlockInput>,
/// Term specification for the log-slope predictor.
pub logslopespec: TermCollectionSpec,
/// NOTE(review): optional list of log-slope specifications; not referenced in this
/// chunk — confirm semantics at the use site.
pub logslopespecs: Option<Vec<TermCollectionSpec>>,
/// Per-row finite offset paired with `logslopespec`.
pub logslope_offset: Array1<f64>,
/// NOTE(review): optional score-warp deviation block config; not referenced here — confirm.
pub score_warp: Option<DeviationBlockConfig>,
/// NOTE(review): optional link-deviation block config; not referenced here — confirm.
pub link_dev: Option<DeviationBlockConfig>,
/// Optional matrix with `n` rows and finite entries (score-influence Jacobian).
pub score_influence_jacobian: Option<Array2<f64>>,
/// NOTE(review): policy for handling the latent `z`; not referenced here — confirm.
pub latent_z_policy: LatentZPolicy,
}
/// Default magnitude for a survival marginal-slope derivative guard.
/// NOTE(review): presumably the default for `SurvivalMarginalSlopeTermSpec::derivative_guard`;
/// confirm at the construction sites.
pub const DEFAULT_SURVIVAL_MARGINAL_SLOPE_DERIVATIVE_GUARD: f64 = 1.0e-6;

/// Absolute residual tolerance for survival intercept solves.
/// NOTE(review): consumed outside this chunk; meaning inferred from the name.
pub(crate) const SURVIVAL_INTERCEPT_ABS_RESIDUAL_TOL: f64 = 1.0e-12;

/// Relative residual tolerance applied in the tail regime of survival intercept solves.
/// NOTE(review): consumed outside this chunk; meaning inferred from the name.
pub(crate) const SURVIVAL_INTERCEPT_REL_TAIL_RESIDUAL_TOL: f64 = 1.0e-8;

/// Threshold associated with log-scale tail handling in survival intercept solves.
/// NOTE(review): consumed outside this chunk; meaning inferred from the name.
pub(crate) const SURVIVAL_INTERCEPT_LOG_TAIL_THRESHOLD: f64 = 1.0e-8;
/// Absolute slack allowed when comparing a derivative value `qd1` against `derivative_guard`.
///
/// The slack scales with `1 + max(|qd1|, |derivative_guard|)`. It is the larger of two bands:
/// - a solver band, `4 * ACTIVE_SET_PRIMAL_FEASIBILITY_TOL` times the scale;
/// - a rounding floor, `256 * f64::EPSILON` times the scale.
///
/// A value the active-set solver deems feasible is therefore not flagged as a violation.
#[inline]
pub(crate) fn survival_derivative_guard_tolerance(qd1: f64, derivative_guard: f64) -> f64 {
    let largest_magnitude = f64::max(qd1.abs(), derivative_guard.abs());
    let scale = 1.0 + largest_magnitude;
    // Keep the left-to-right product order so results stay bit-identical.
    let feasibility_band = 4.0 * gam_solve::pirls::ACTIVE_SET_PRIMAL_FEASIBILITY_TOL * scale;
    let rounding_band = 256.0 * f64::EPSILON * scale;
    f64::max(feasibility_band, rounding_band)
}
/// Returns `true` when `qd1` falls below `derivative_guard` by more than the tolerance
/// from `survival_derivative_guard_tolerance`.
///
/// Non-finite inputs (either `qd1` or `derivative_guard`) always count as a violation.
#[inline]
pub(crate) fn survival_derivative_guard_violated(qd1: f64, derivative_guard: f64) -> bool {
    if !(qd1.is_finite() && derivative_guard.is_finite()) {
        return true;
    }
    let slack = survival_derivative_guard_tolerance(qd1, derivative_guard);
    qd1 + slack < derivative_guard
}
/// Result bundle of a survival marginal-slope fit.
///
/// NOTE(review): nothing in this chunk constructs or reads this type. The field notes below
/// are inferred from names and types and should be confirmed against the fitting code.
pub struct SurvivalMarginalSlopeFitResult {
/// Underlying unified fit result.
pub fit: UnifiedFitResult,
/// Marginal term specification after resolution (presumably knots/bases fixed — confirm).
pub marginalspec_resolved: TermCollectionSpec,
/// Log-slope term specification after resolution (presumably — confirm).
pub logslopespec_resolved: TermCollectionSpec,
/// Design built for the marginal terms.
pub marginal_design: TermCollectionDesign,
/// Gaussian frailty standard deviation, when one applies (presumably the fixed
/// `GaussianShift` sigma — confirm).
pub gaussian_frailty_sd: Option<f64>,
/// Design built for the log-slope terms.
pub logslope_design: TermCollectionDesign,
/// NOTE(review): baseline slope value; semantics not visible here — confirm.
pub baseline_slope: f64,
/// NOTE(review): per-channel baseline offset residuals; semantics not visible here.
pub baseline_offset_residuals: OffsetChannelResiduals,
/// NOTE(review): per-channel baseline offset curvatures; semantics not visible here.
pub baseline_offset_curvatures: OffsetChannelCurvatures,
/// Normalization applied to the latent `z` (presumably needed to reproduce predictions).
pub z_normalization: LatentZNormalization,
/// Number of penalties belonging to the time block (presumably used to split penalty
/// vectors — confirm).
pub time_block_penalties_len: usize,
/// Runtime for the optional score-warp deviation block.
pub score_warp_runtime: Option<DeviationRuntime>,
/// Runtime for the optional link-deviation block.
pub link_dev_runtime: Option<DeviationRuntime>,
/// NOTE(review): width of the influence absorber, if any — confirm semantics.
pub influence_absorber_width: Option<usize>,
}
/// Validates a [`SurvivalMarginalSlopeTermSpec`] before fitting.
///
/// Checks run in this order, and the first failure is returned:
/// 1. Per-row inputs (`age_exit`, `event_target`, `weights`, `z`, both offsets) have the
///    same length as `age_entry`, and `z` has at least one column.
///    - Weights are finite and non-negative.
///    - The optional `score_influence_jacobian` has `n` rows of finite values.
///    - `z` and both offsets are finite.
/// 2. Frailty: only `FrailtySpec::None` or `GaussianShift` with a finite `sigma_fixed >= 0`
///    is accepted.
/// 3. Events and guard:
///    - Event indicators are exactly `0.0`/`1.0`.
///    - Non-empty data contains at least one event.
///    - `derivative_guard` is finite and `> 0`.
/// 4. Entry/exit ages are finite and `exit >= entry` on every row.
/// 5. Time block:
///    - Its three designs agree on rows (`n`) and columns.
///    - The monotonicity mode is row-constraint or coordinate-cone.
///    - In coordinate-cone mode, the guard audit passes.
///    - Any `initial_beta` is feasible.
/// 6. The optional time-wiggle block is cubic and its width is consistent.
///
/// Also emits an info-level "fit start" log line with the term counts.
///
/// # Errors
/// Returns a `String` converted from a `SurvivalMarginalSlopeError` describing the first
/// failed check. Errors raised by `FrailtySpec::validate_for_marginal_slope`,
/// `time_derivative_guard_constraints`, and `time_wiggle_basis_ncols` are propagated.
pub(crate) fn validate_spec(spec: &SurvivalMarginalSlopeTermSpec) -> Result<(), String> {
    let n = spec.age_entry.len();
    log::info!(
        "[survival-marginal-slope] fit start n={} marginal_terms={} logslope_terms={}",
        n,
        spec.marginalspec.linear_terms.len()
            + spec.marginalspec.random_effect_terms.len()
            + spec.marginalspec.smooth_terms.len(),
        spec.logslopespec.linear_terms.len()
            + spec.logslopespec.random_effect_terms.len()
            + spec.logslopespec.smooth_terms.len(),
    );
    validate_spec_row_inputs(spec, n)?;
    validate_spec_frailty(spec)?;
    validate_spec_events_and_guard(spec)?;
    validate_spec_ages(spec)?;
    validate_spec_time_block(spec, n)?;
    validate_spec_timewiggle(spec)?;
    Ok(())
}

/// Checks row-count agreement of per-row inputs and finiteness of weights, `z`, offsets,
/// and the optional score-influence Jacobian.
fn validate_spec_row_inputs(spec: &SurvivalMarginalSlopeTermSpec, n: usize) -> Result<(), String> {
    if spec.age_exit.len() != n
        || spec.event_target.len() != n
        || spec.weights.len() != n
        || spec.z.nrows() != n
        || spec.z.ncols() == 0
        || spec.marginal_offset.len() != n
        || spec.logslope_offset.len() != n
    {
        return Err(SurvivalMarginalSlopeError::IncompatibleDimensions {
            reason: format!(
                "survival-marginal-slope row mismatch: entry={}, exit={}, event={}, weights={}, z={}x{}, marginal_offset={}, logslope_offset={}",
                n,
                spec.age_exit.len(),
                spec.event_target.len(),
                spec.weights.len(),
                spec.z.nrows(),
                spec.z.ncols(),
                spec.marginal_offset.len(),
                spec.logslope_offset.len()
            ),
        }
        .into());
    }
    if spec.weights.iter().any(|&w| !w.is_finite() || w < 0.0) {
        return Err(SurvivalMarginalSlopeError::InvalidInput {
            reason: "survival-marginal-slope requires finite non-negative weights".to_string(),
        }
        .into());
    }
    if let Some(jac) = spec.score_influence_jacobian.as_ref() {
        if jac.nrows() != n {
            return Err(SurvivalMarginalSlopeError::IncompatibleDimensions {
                reason: format!(
                    "survival-marginal-slope score_influence_jacobian has {} rows, expected {n}",
                    jac.nrows()
                ),
            }
            .into());
        }
        if jac.iter().any(|&v| !v.is_finite()) {
            return Err(SurvivalMarginalSlopeError::InvalidInput {
                reason: "survival-marginal-slope score_influence_jacobian must be finite"
                    .to_string(),
            }
            .into());
        }
    }
    if spec.z.iter().any(|&zi| !zi.is_finite()) {
        return Err(SurvivalMarginalSlopeError::InvalidInput {
            reason: "survival-marginal-slope requires finite z values".to_string(),
        }
        .into());
    }
    if spec.marginal_offset.iter().any(|&value| !value.is_finite()) {
        return Err(SurvivalMarginalSlopeError::InvalidInput {
            reason: "survival-marginal-slope requires finite marginal offsets".to_string(),
        }
        .into());
    }
    if spec.logslope_offset.iter().any(|&value| !value.is_finite()) {
        return Err(SurvivalMarginalSlopeError::InvalidInput {
            reason: "survival-marginal-slope requires finite logslope offsets".to_string(),
        }
        .into());
    }
    Ok(())
}

/// Accepts only `FrailtySpec::None` or a fixed, finite, non-negative Gaussian-shift sigma.
fn validate_spec_frailty(spec: &SurvivalMarginalSlopeTermSpec) -> Result<(), String> {
    spec.frailty.validate_for_marginal_slope()?;
    match &spec.frailty {
        FrailtySpec::None => {}
        FrailtySpec::GaussianShift { sigma_fixed } => {
            let Some(sigma) = sigma_fixed else {
                return Err(SurvivalMarginalSlopeError::UnsupportedConfiguration {
                    reason:
                        "survival-marginal-slope requires GaussianShift sigma_fixed or FrailtySpec::None; learnable GaussianShift sigma is not implemented for the exact marginal-slope outer solver"
                            .to_string(),
                }
                .into());
            };
            if !sigma.is_finite() || *sigma < 0.0 {
                return Err(SurvivalMarginalSlopeError::InvalidInput {
                    reason: format!(
                        "survival-marginal-slope requires GaussianShift sigma >= 0, got {sigma}"
                    ),
                }
                .into());
            }
        }
        FrailtySpec::HazardMultiplier { .. } => {
            return Err(SurvivalMarginalSlopeError::InvalidInput {
                reason: "survival-marginal-slope does not support FrailtySpec::HazardMultiplier"
                    .to_string(),
            }
            .into());
        }
    }
    Ok(())
}

/// Checks that event indicators are binary with at least one event, and that the
/// derivative guard is finite and positive.
fn validate_spec_events_and_guard(spec: &SurvivalMarginalSlopeTermSpec) -> Result<(), String> {
    // `d != 0.0 && d != 1.0` is also true for NaN, so NaN indicators are rejected here.
    if spec.event_target.iter().any(|&d| d != 0.0 && d != 1.0) {
        return Err(SurvivalMarginalSlopeError::InvalidInput {
            reason: "survival-marginal-slope requires binary event indicators (0.0 or 1.0)"
                .to_string(),
        }
        .into());
    }
    if !spec.event_target.is_empty() && spec.event_target.iter().all(|&d| d == 0.0) {
        return Err(SurvivalMarginalSlopeError::InvalidInput {
            reason: "survival-marginal-slope requires at least one event (event==1); the supplied design is entirely censored (all event==0), which has no finite marginal-slope fit"
                .to_string(),
        }
        .into());
    }
    if !spec.derivative_guard.is_finite() || spec.derivative_guard <= 0.0 {
        return Err(SurvivalMarginalSlopeError::InvalidInput {
            reason: format!(
                "survival-marginal-slope requires derivative_guard > 0, got {}",
                spec.derivative_guard
            ),
        }
        .into());
    }
    Ok(())
}

/// Checks that every entry/exit time is finite and that exit is not before entry.
///
/// Assumes `age_entry` and `age_exit` have equal length (enforced by
/// `validate_spec_row_inputs`, which runs first).
fn validate_spec_ages(spec: &SurvivalMarginalSlopeTermSpec) -> Result<(), String> {
    for (row, (&entry, &exit)) in spec.age_entry.iter().zip(spec.age_exit.iter()).enumerate() {
        // Comparisons involving NaN are always false, so the ordering check alone would
        // let non-finite times through; reject them explicitly.
        if !entry.is_finite() || !exit.is_finite() {
            return Err(SurvivalMarginalSlopeError::InvalidInput {
                reason: format!(
                    "survival-marginal-slope row {row}: entry and exit times must be finite, got entry={entry}, exit={exit}"
                ),
            }
            .into());
        }
        if exit < entry {
            return Err(SurvivalMarginalSlopeError::InvalidInput {
                reason: format!(
                    "survival-marginal-slope row {row}: exit time ({exit}) < entry time ({entry})"
                ),
            }
            .into());
        }
    }
    Ok(())
}

/// Checks time-block design shapes, the monotonicity mode, the coordinate-cone audit, and
/// feasibility of any `initial_beta`.
fn validate_spec_time_block(spec: &SurvivalMarginalSlopeTermSpec, n: usize) -> Result<(), String> {
    let n_entry = spec.time_block.design_entry.nrows();
    let n_exit = spec.time_block.design_exit.nrows();
    let n_deriv = spec.time_block.design_derivative_exit.nrows();
    if n_entry != n || n_exit != n || n_deriv != n {
        return Err(SurvivalMarginalSlopeError::IncompatibleDimensions {
            reason: format!(
                "survival-marginal-slope time block design row mismatch: \
                 data={n}, design_entry={n_entry}, design_exit={n_exit}, design_derivative_exit={n_deriv}"
            ),
        }
        .into());
    }
    let p_entry = spec.time_block.design_entry.ncols();
    let p_exit = spec.time_block.design_exit.ncols();
    let p_deriv = spec.time_block.design_derivative_exit.ncols();
    if p_exit != p_entry || p_deriv != p_entry {
        return Err(SurvivalMarginalSlopeError::IncompatibleDimensions {
            reason: format!(
                "survival-marginal-slope time block design column mismatch: entry={p_entry}, exit={p_exit}, deriv={p_deriv}"
            ),
        }
        .into());
    }
    if !spec.time_block.time_monotonicity.requires_row_constraints()
        && !spec.time_block.time_monotonicity.is_coordinate_cone()
    {
        return Err(SurvivalMarginalSlopeError::UnsupportedConfiguration {
            reason: format!(
                "survival-marginal-slope requires a row-constraint or coordinate-cone time block; got {:?}",
                spec.time_block.time_monotonicity
            ),
        }
        .into());
    }
    if spec.time_block.time_monotonicity.is_coordinate_cone() {
        validate_spec_coordinate_cone(spec, n)?;
    }
    validate_spec_initial_beta(spec)
}

/// Coordinate-cone audit. Each row's derivative offset must be finite and at least the
/// guard (within 1e-12), and every derivative-design entry must be finite and nonnegative
/// (within 1e-12).
fn validate_spec_coordinate_cone(
    spec: &SurvivalMarginalSlopeTermSpec,
    n: usize,
) -> Result<(), String> {
    let offsets = &spec.time_block.derivative_offset_exit;
    // The per-row guard audit below is only meaningful if there is one offset per data
    // row; a shorter (or empty) offset array would otherwise skip rows silently.
    if offsets.len() != n {
        return Err(SurvivalMarginalSlopeError::IncompatibleDimensions {
            reason: format!(
                "survival-marginal-slope coordinate-cone time block derivative offset length mismatch: data={n}, derivative_offset_exit={}",
                offsets.len()
            ),
        }
        .into());
    }
    for (row, &offset) in offsets.iter().enumerate() {
        if !offset.is_finite() {
            return Err(SurvivalMarginalSlopeError::MonotonicityViolation {
                reason: format!(
                    "survival-marginal-slope coordinate-cone time block has non-finite derivative offset at row {row}: {offset}"
                ),
            }
            .into());
        }
        if offset < spec.derivative_guard - 1e-12 {
            return Err(SurvivalMarginalSlopeError::MonotonicityViolation {
                reason: format!(
                    "survival-marginal-slope coordinate-cone time block requires derivative offset >= guard at row {row}: offset={offset:.3e}, guard={:.3e}",
                    spec.derivative_guard
                ),
            }
            .into());
        }
    }
    let derivative_design = spec
        .time_block
        .design_derivative_exit
        .try_to_dense_by_chunks("survival marginal-slope coordinate-cone derivative audit")
        .map_err(|reason| SurvivalMarginalSlopeError::IncompatibleDimensions { reason })?;
    for ((row, col), &value) in derivative_design.indexed_iter() {
        if !value.is_finite() {
            return Err(SurvivalMarginalSlopeError::MonotonicityViolation {
                reason: format!(
                    "survival-marginal-slope coordinate-cone time block has non-finite derivative design entry at row {row}, col {col}: {value}"
                ),
            }
            .into());
        }
        if value < -1e-12 {
            return Err(SurvivalMarginalSlopeError::MonotonicityViolation {
                reason: format!(
                    "survival-marginal-slope coordinate-cone time block requires nonnegative derivative design entries; row {row}, col {col} = {value:.3e}"
                ),
            }
            .into());
        }
    }
    Ok(())
}

/// Checks that a supplied `time_block.initial_beta` has the right width and is feasible.
///
/// - Coordinate-cone mode: every coordinate must be finite and `>= 0` within 1e-12.
/// - Row-constraint mode: every derivative-guard constraint must have slack `>= -1e-10`.
fn validate_spec_initial_beta(spec: &SurvivalMarginalSlopeTermSpec) -> Result<(), String> {
    let Some(beta0) = &spec.time_block.initial_beta else {
        return Ok(());
    };
    if spec.time_block.time_monotonicity.is_coordinate_cone() {
        if spec.time_block.design_derivative_exit.ncols() != beta0.len() {
            return Err(SurvivalMarginalSlopeError::IncompatibleDimensions {
                reason: format!(
                    "survival-marginal-slope time_block initial_beta length mismatch under coordinate-cone monotonicity: got {}, expected {}",
                    beta0.len(),
                    spec.time_block.design_derivative_exit.ncols()
                ),
            }
            .into());
        }
        for (j, &g) in beta0.iter().enumerate() {
            if !g.is_finite() {
                return Err(SurvivalMarginalSlopeError::MonotonicityViolation {
                    reason: format!(
                        "survival-marginal-slope time_block initial_beta is non-finite at coordinate {j} under coordinate-cone monotonicity: got {g}"
                    ),
                }
                .into());
            }
            if g < -1e-12 {
                return Err(SurvivalMarginalSlopeError::MonotonicityViolation {
                    reason: format!(
                        "survival-marginal-slope time_block initial_beta violates β ≥ 0 at coordinate {j} under coordinate-cone monotonicity: got {g:.3e}"
                    ),
                }
                .into());
            }
        }
        return Ok(());
    }
    let derivative_constraints = time_derivative_guard_constraints(
        &spec.time_block.design_derivative_exit,
        &spec.time_block.derivative_offset_exit,
        spec.derivative_guard,
    )?;
    if let Some(constraints) = derivative_constraints.as_ref() {
        if beta0.len() != constraints.a.ncols() {
            return Err(SurvivalMarginalSlopeError::IncompatibleDimensions {
                reason: format!(
                    "survival-marginal-slope time_block initial_beta length mismatch: got {}, expected {}",
                    beta0.len(),
                    constraints.a.ncols()
                ),
            }
            .into());
        }
        for row in 0..constraints.a.nrows() {
            let slack = constraints.a.row(row).dot(beta0) - constraints.b[row];
            if slack < -1e-10 {
                return Err(SurvivalMarginalSlopeError::MonotonicityViolation {
                    reason: format!(
                        "survival-marginal-slope time_block initial_beta violates derivative guard constraint at row {row}: slack={slack:.3e}"
                    ),
                }
                .into());
            }
        }
    }
    Ok(())
}

/// Checks the optional time-wiggle block. It must be cubic, have at least one coefficient,
/// have metadata width matching the width derived from its knots, and fit inside the time
/// block's columns.
fn validate_spec_timewiggle(spec: &SurvivalMarginalSlopeTermSpec) -> Result<(), String> {
    let Some(timewiggle) = spec.timewiggle_block.as_ref() else {
        return Ok(());
    };
    if timewiggle.degree != 3 {
        return Err(SurvivalMarginalSlopeError::UnsupportedConfiguration {
            reason: format!(
                "survival-marginal-slope timewiggle requires cubic degree=3, got {}",
                timewiggle.degree
            ),
        }
        .into());
    }
    let derived_ncols = time_wiggle_basis_ncols(&timewiggle.knots, timewiggle.degree)?;
    if derived_ncols == 0 {
        return Err(SurvivalMarginalSlopeError::InvalidInput {
            reason: "survival-marginal-slope timewiggle requires at least one wiggle coefficient"
                .to_string(),
        }
        .into());
    }
    if timewiggle.ncols != derived_ncols {
        return Err(SurvivalMarginalSlopeError::IncompatibleDimensions {
            reason: format!(
                "survival-marginal-slope timewiggle metadata width mismatch: metadata={}, basis={derived_ncols}",
                timewiggle.ncols
            ),
        }
        .into());
    }
    if spec.time_block.design_exit.ncols() < derived_ncols {
        return Err(SurvivalMarginalSlopeError::IncompatibleDimensions {
            reason: format!(
                "survival-marginal-slope timewiggle requests {} tail columns but time block only has {} columns",
                derived_ncols,
                spec.time_block.design_exit.ncols()
            ),
        }
        .into());
    }
    Ok(())
}