use super::*;
use crate::families::jet_scalar::{JetScalar, Order2, OneSeed, TwoSeed};
/// Scalar per-row inputs for the exact survival location-scale row likelihood.
///
/// `log_likelihood` evaluates
/// `w * (event_mix(d, logphi1 + log_g, log_s1) - log_s0)`. `nll_index_tower` and
/// `sls_row_nll` pass the remaining fields to `compose_unary` as the
/// higher-order jet coefficients of each term. The terms are taken with respect
/// to the entry index `u0`, the exit index `u1`, and the derivative channel `g`.
#[derive(Clone, Copy, Debug)]
pub(crate) struct SurvivalExactRowKernel {
    /// Row weight; scales every term of the row likelihood.
    pub(crate) w: f64,
    /// Event indicator.
    ///
    /// Exactly `1.0` keeps only the event terms and exactly `0.0` keeps only
    /// the censored term. Other values blend the two linearly.
    pub(crate) d: f64,
    /// Value at the entry index `u0` that is subtracted in `log_likelihood`
    /// (named as the log survival).
    pub(crate) log_s0: f64,
    /// `-r0, -dr0, -ddr0, -dddr0` are the coefficients that follow `log_s0`
    /// in the `u0` jet handed to `compose_unary`.
    pub(crate) r0: f64,
    pub(crate) dr0: f64,
    pub(crate) ddr0: f64,
    pub(crate) dddr0: f64,
    /// Exit-index (`u1`) survival value used for the censored share of a row.
    pub(crate) log_s1: f64,
    /// `-r1, -dr1, -ddr1, -dddr1` follow `log_s1` in the `u1` survival jet.
    pub(crate) r1: f64,
    pub(crate) dr1: f64,
    pub(crate) ddr1: f64,
    pub(crate) dddr1: f64,
    /// Exit-index (`u1`) log-density value for the event share.
    ///
    /// It is followed by its jet coefficients, which are passed to
    /// `compose_unary` without negation.
    pub(crate) logphi1: f64,
    pub(crate) dlogphi1: f64,
    pub(crate) d2logphi1: f64,
    pub(crate) d3logphi1: f64,
    pub(crate) d4logphi1: f64,
    /// Log of the derivative-channel term `g` for the event share.
    ///
    /// It is followed by its jet coefficients, which are passed without
    /// negation.
    pub(crate) log_g: f64,
    pub(crate) d_log_g: f64,
    pub(crate) d2_log_g: f64,
    pub(crate) d3_log_g: f64,
    pub(crate) d4_log_g: f64,
}
/// Blends an event-row value with a censored-row value by the event indicator `d`.
///
/// Returns `event_val` when `d` is exactly `1.0` and `censored_val` when it is
/// exactly `0.0`. These short-circuits keep the unused side out of the result
/// entirely, so a `0 * inf` product cannot occur. Any other `d` gives the linear
/// mix `d * event_val + (1 - d) * censored_val`.
#[inline]
pub(crate) fn event_mix(d: f64, event_val: f64, censored_val: f64) -> f64 {
    match d {
        x if x == 1.0 => event_val,
        x if x == 0.0 => censored_val,
        _ => {
            let censored_share = 1.0 - d;
            d * event_val + censored_share * censored_val
        }
    }
}
impl SurvivalExactRowKernel {
    /// Weighted row log-likelihood.
    ///
    /// The value is the exit term minus the entry value `log_s0`. The exit
    /// term is the event log density plus `log_g`, or the exit survival value,
    /// blended by `d`.
    #[inline]
    pub(crate) fn log_likelihood(self) -> f64 {
        let exit_term = event_mix(self.d, self.logphi1 + self.log_g, self.log_s1);
        self.w * (exit_term - self.log_s0)
    }

    /// Row negative log-likelihood as a three-variable jet tower.
    ///
    /// The variables are `0` = entry index `u0`, `1` = exit index `u1`, and
    /// `2` = derivative channel `g`, each seeded at `0.0`. Every term composes
    /// its stored jet coefficients onto its variable. The censored and event
    /// terms are skipped entirely when their weight is exactly zero.
    #[inline]
    pub(crate) fn nll_index_tower(self) -> crate::families::jet_tower::Tower4<3> {
        use crate::families::jet_tower::Tower4;
        let entry_index = Tower4::<3>::variable(0.0, 0);
        let exit_index = Tower4::<3>::variable(0.0, 1);
        let deriv_channel = Tower4::<3>::variable(0.0, 2);

        let entry_jet = [self.log_s0, -self.r0, -self.dr0, -self.ddr0, -self.dddr0];
        let exit_survival_jet = [self.log_s1, -self.r1, -self.dr1, -self.ddr1, -self.dddr1];
        let exit_density_jet = [
            self.logphi1,
            self.dlogphi1,
            self.d2logphi1,
            self.d3logphi1,
            self.d4logphi1,
        ];
        let log_g_jet = [
            self.log_g,
            self.d_log_g,
            self.d2_log_g,
            self.d3_log_g,
            self.d4_log_g,
        ];

        // The entry term enters the NLL with a positive sign.
        let mut total = entry_index.compose_unary(entry_jet).scale(self.w);

        let censored_weight = self.w * (1.0 - self.d);
        if censored_weight != 0.0 {
            let censored_term = exit_index
                .compose_unary(exit_survival_jet)
                .scale(-censored_weight);
            total = total + censored_term;
        }

        let event_weight = self.w * self.d;
        if event_weight != 0.0 {
            let density_term = exit_index
                .compose_unary(exit_density_jet)
                .scale(-event_weight);
            let log_g_term = deriv_channel.compose_unary(log_g_jet).scale(-event_weight);
            total = total + density_term + log_g_term;
        }

        total
    }
}
/// Per-row derivative arrays for the survival location-scale likelihood, as
/// assembled by `collect_joint_quantities_rescaled`.
///
/// Fields `d1_q` through `d_h_d` are written row by row from
/// `row_derivatives_rescaled`; rows for which that returns `None` stay zero.
/// The remaining fields are moved out of the `build_dynamic_geometry` result.
pub(crate) struct SurvivalJointQuantities {
    // Per-row derivative terms copied from the row derivatives. The names
    // suggest first- to third-order derivatives with respect to q, q at entry,
    // q at exit, and the exit qdot. NOTE(review): confirm against
    // `row_derivatives_rescaled`.
    pub(crate) d1_q: Array1<f64>,
    pub(crate) d2_q: Array1<f64>,
    pub(crate) d3_q: Array1<f64>,
    pub(crate) d1_q0: Array1<f64>,
    pub(crate) d2_q0: Array1<f64>,
    pub(crate) d3_q0: Array1<f64>,
    pub(crate) d1_q1: Array1<f64>,
    pub(crate) d2_q1: Array1<f64>,
    pub(crate) d3_q1: Array1<f64>,
    pub(crate) d1_qdot1: Array1<f64>,
    pub(crate) d2_qdot1: Array1<f64>,
    // Time-channel second-order terms; `time_channel_nll_curvatures` negates
    // these elementwise to produce NLL curvatures.
    pub(crate) h_time_h0: Array1<f64>,
    pub(crate) h_time_h1: Array1<f64>,
    pub(crate) h_time_d: Array1<f64>,
    // Further per-row time-channel terms taken from the same row derivatives.
    pub(crate) d_h_h0: Array1<f64>,
    pub(crate) d_h_h1: Array1<f64>,
    pub(crate) d_h_d: Array1<f64>,
    // Exit-side q derivatives with respect to the threshold (`t`) and
    // log-sigma (`ls`) predictors, moved from the dynamic geometry.
    pub(crate) dq_t: Array1<f64>,
    pub(crate) dq_ls: Array1<f64>,
    pub(crate) d2q_tls: Array1<f64>,
    pub(crate) d2q_ls: Array1<f64>,
    pub(crate) d3q_tls_ls: Array1<f64>,
    pub(crate) d3q_ls: Array1<f64>,
    // Entry-side counterparts of the fields above; always `Some` when built by
    // `collect_joint_quantities_rescaled`.
    pub(crate) dq_t_entry: Option<Array1<f64>>,
    pub(crate) dq_ls_entry: Option<Array1<f64>>,
    pub(crate) d2q_tls_entry: Option<Array1<f64>>,
    pub(crate) d2q_ls_entry: Option<Array1<f64>>,
    pub(crate) d3q_tls_ls_entry: Option<Array1<f64>>,
    pub(crate) d3q_ls_entry: Option<Array1<f64>>,
    // qdot derivative terms moved from the dynamic geometry.
    // `SurvivalLsRowKernel::row_primary_values` recovers the log-sigma
    // derivative channel as `dqdot_t / inv_sigma_exit`.
    pub(crate) dqdot_t: Array1<f64>,
    pub(crate) dqdot_ls: Array1<f64>,
    pub(crate) dqdot_td: Array1<f64>,
    pub(crate) dqdot_lsd: Array1<f64>,
    pub(crate) d2qdot_tt: Array1<f64>,
    pub(crate) d2qdot_tls: Array1<f64>,
    pub(crate) d2qdot_ttd: Array1<f64>,
    pub(crate) d2qdot_tlsd: Array1<f64>,
    pub(crate) d2qdot_ls: Array1<f64>,
    pub(crate) d2qdot_lstd: Array1<f64>,
    pub(crate) d2qdot_lslsd: Array1<f64>,
}
/// Negative-log-likelihood curvatures of the three time channels.
///
/// Produced by `SurvivalJointQuantities::time_channel_nll_curvatures` as the
/// elementwise negation of the stored `h_time_*` arrays.
pub(crate) struct TimeChannelNllCurvatures {
    /// Negated `h_time_h0`.
    pub(crate) h0: Array1<f64>,
    /// Negated `h_time_h1`.
    pub(crate) h1: Array1<f64>,
    /// Negated `h_time_d`.
    pub(crate) d: Array1<f64>,
}
impl SurvivalJointQuantities {
pub(crate) fn time_channel_nll_curvatures(&self) -> TimeChannelNllCurvatures {
TimeChannelNllCurvatures {
h0: -&self.h_time_h0,
h1: -&self.h_time_h1,
d: -&self.h_time_d,
}
}
}
/// ψ-direction derivative data for the threshold (`t`) and log-sigma (`ls`)
/// blocks.
///
/// Returned (wrapped in `Option`) by `exact_newton_joint_psi_direction` for a
/// single ψ index, with separate exit-row and entry-row components.
pub(crate) struct SurvivalJointPsiDirection {
    // Dense ψ-derivative designs per block and side. NOTE(review): presumably
    // `None` when the block has no ψ dependence or only an operator form
    // (`*_action`) is available; confirm against the constructor.
    pub(crate) x_t_exit_psi: Option<Array2<f64>>,
    pub(crate) x_t_entry_psi: Option<Array2<f64>>,
    pub(crate) x_ls_exit_psi: Option<Array2<f64>>,
    pub(crate) x_ls_entry_psi: Option<Array2<f64>>,
    // Per-row ψ-derivative vectors per block and side. Presumably the
    // ψ-derivative of the corresponding linear predictor; TODO confirm.
    pub(crate) z_t_exit_psi: Array1<f64>,
    pub(crate) z_t_entry_psi: Array1<f64>,
    pub(crate) z_ls_exit_psi: Array1<f64>,
    pub(crate) z_ls_entry_psi: Array1<f64>,
    // Optional operator ("action") representations of the ψ-derivative designs.
    pub(crate) x_t_exit_action: Option<CustomFamilyPsiDesignAction>,
    pub(crate) x_t_entry_action: Option<CustomFamilyPsiDesignAction>,
    pub(crate) x_ls_exit_action: Option<CustomFamilyPsiDesignAction>,
    pub(crate) x_ls_entry_action: Option<CustomFamilyPsiDesignAction>,
}
/// Splits a ψ-derivative design into its first two `n`-row segments.
///
/// For a time-varying block the design is row-stacked and must have `2 * n`
/// or `3 * n` rows. Rows `0..n` and `n..2n` are returned, matching the
/// exit-then-entry eta stacking used by `validate_joint_states`; any third
/// segment is ignored. A non-time-varying design must have exactly `n` rows
/// and is cloned into both slots.
///
/// # Errors
/// Returns a `DimensionMismatch` message when the row count is not accepted.
pub(crate) fn split_survival_psi_design(
    x_psi: &Array2<f64>,
    n: usize,
    time_varying: bool,
    label: &str,
) -> Result<(Array2<f64>, Array2<f64>), String> {
    let nrows = x_psi.nrows();
    if !time_varying {
        if nrows != n {
            return Err(SurvivalLocationScaleError::DimensionMismatch {
                reason: format!(
                    "{label} psi design row mismatch: got {}, expected {}",
                    nrows, n
                ),
            }
            .into());
        }
        return Ok((x_psi.clone(), x_psi.clone()));
    }

    let (two_n, three_n) = (2 * n, 3 * n);
    let accepted = nrows == two_n || nrows == three_n;
    if !accepted {
        return Err(SurvivalLocationScaleError::DimensionMismatch {
            reason: format!(
                "{label} stacked psi design row mismatch: got {}, expected {} or {}",
                nrows, two_n, three_n,
            ),
        }
        .into());
    }
    let first_segment = x_psi.slice(s![0..n, ..]).to_owned();
    let second_segment = x_psi.slice(s![n..two_n, ..]).to_owned();
    Ok((first_segment, second_segment))
}
/// Number of scalar channels per row consumed by `sls_row_nll`.
///
/// There are three groups of three:
/// - time channels: entry `h`, exit `h`, exit `hdot`;
/// - threshold channels: exit, entry, derivative;
/// - log-sigma channels: exit, entry, derivative.
pub(crate) const SLS_ROW_K: usize = 9;
/// Adapter exposing the survival location-scale row likelihood through the
/// generic `RowKernel<SLS_ROW_K>` interface.
pub(crate) struct SurvivalLsRowKernel<'a> {
    /// Family supplying the designs and the per-row exact kernels.
    pub(crate) family: &'a SurvivalLocationScaleFamily,
    /// Joint quantities; `row_primary_values` reads `dqdot_t` from here.
    pub(crate) q: &'a SurvivalJointQuantities,
    /// Dynamic geometry: time Jacobians and per-row `h`/`q`/`inv_sigma` values.
    pub(crate) dynamic: &'a SurvivalDynamicGeometry,
    /// Forwarded to `exact_row_kernel_rescaled` for every row.
    pub(crate) deriv_log_scale: f64,
    /// Cumulative coefficient offsets from `joint_block_offsets`.
    ///
    /// Block `b` occupies `offsets[b]..offsets[b + 1]`.
    pub(crate) offsets: Vec<usize>,
}
impl SurvivalLsRowKernel<'_> {
    /// Position of the time block in `offsets`.
    pub(crate) const THRESHOLD_BLOCK_TIME: usize = 0;
    /// Position of the threshold block in `offsets`.
    pub(crate) const THRESHOLD_BLOCK_THR: usize = 1;
    /// Position of the log-sigma block in `offsets`.
    pub(crate) const THRESHOLD_BLOCK_LS: usize = 2;

    /// Returns the dedicated entry-side design when configured, otherwise the
    /// shared (exit) design.
    #[inline]
    pub(crate) fn entry_design<'b>(
        opt: &'b Option<DesignMatrix>,
        fallback: &'b DesignMatrix,
    ) -> &'b DesignMatrix {
        match opt {
            Some(design) => design,
            None => fallback,
        }
    }

    /// Coefficient block that owns row channel `ch`.
    ///
    /// Returns `None` for indices past the last channel.
    pub(crate) fn channel_block(&self, ch: usize) -> Option<usize> {
        let block = match ch {
            0..=2 => Self::THRESHOLD_BLOCK_TIME,
            3..=5 => Self::THRESHOLD_BLOCK_THR,
            6..=8 => Self::THRESHOLD_BLOCK_LS,
            _ => return None,
        };
        Some(block)
    }

    /// Dense Jacobian row of channel `ch` at `row`, restricted to its block.
    ///
    /// Derivative channels (5 and 8) return `None` when the family has no
    /// derivative design, as do indices past the last channel.
    pub(crate) fn channel_row(&self, ch: usize, row: usize) -> Option<Array1<f64>> {
        let family = self.family;
        let geometry = self.dynamic;
        let dense_row = match ch {
            0 => geometry.time_jac_entry.row(row).to_owned(),
            1 => geometry.time_jac_exit.row(row).to_owned(),
            2 => geometry.time_jac_deriv.row(row).to_owned(),
            3 => design_dense_row(&family.x_threshold, row),
            4 => design_dense_row(
                Self::entry_design(&family.x_threshold_entry, &family.x_threshold),
                row,
            ),
            5 => {
                return family
                    .x_threshold_deriv
                    .as_ref()
                    .map(|d| design_dense_row(d, row));
            }
            6 => design_dense_row(&family.x_log_sigma, row),
            7 => design_dense_row(
                Self::entry_design(&family.x_log_sigma_entry, &family.x_log_sigma),
                row,
            ),
            8 => {
                return family
                    .x_log_sigma_deriv
                    .as_ref()
                    .map(|d| design_dense_row(d, row));
            }
            _ => return None,
        };
        Some(dense_row)
    }

    /// Channel values at `row`, in the layout expected by `sls_row_nll`.
    ///
    /// The threshold channels and the log-sigma derivative are recovered by
    /// inverting `q = -eta_t * inv_sigma` and its time derivative.
    pub(crate) fn row_primary_values(&self, row: usize) -> [f64; SLS_ROW_K] {
        let geometry = self.dynamic;
        let inv_sigma_exit = geometry.inv_sigma_exit[row];
        let eta_t_exit = -geometry.q_exit[row] / inv_sigma_exit;
        let eta_ls_deriv = self.q.dqdot_t[row] / inv_sigma_exit;
        let eta_t_deriv =
            eta_t_exit * eta_ls_deriv - geometry.qdot_exit[row] / inv_sigma_exit;
        [
            geometry.h_entry[row],
            geometry.h_exit[row],
            geometry.hdot_exit[row],
            eta_t_exit,
            -geometry.q_entry[row] / geometry.inv_sigma_entry[row],
            eta_t_deriv,
            geometry.eta_ls_exit[row],
            geometry.eta_ls_entry[row],
            eta_ls_deriv,
        ]
    }

    /// Channel values plus the exact row kernel for `row`.
    ///
    /// # Errors
    /// Propagates the family's kernel error, or reports a row with no exact
    /// kernel.
    fn row_nll_inputs(
        &self,
        row: usize,
    ) -> Result<([f64; SLS_ROW_K], SurvivalExactRowKernel), String> {
        let primary = self.row_primary_values(row);
        let geometry = self.dynamic;
        let state = self.family.row_predictor_state(
            geometry.h_entry[row],
            geometry.h_exit[row],
            geometry.hdot_exit[row],
            geometry.q_entry[row],
            geometry.q_exit[row],
            geometry.qdot_exit[row],
        );
        match self
            .family
            .exact_row_kernel_rescaled(row, state, self.deriv_log_scale)?
        {
            Some(kernel) => Ok((primary, kernel)),
            None => Err(format!("survival location-scale row {row} has no exact kernel")),
        }
    }
}
/// Generic row negative log-likelihood of the survival location-scale model.
///
/// It is evaluated on any `JetScalar` over the nine row channels, which
/// follow the layout of `SurvivalLsRowKernel::row_primary_values`:
/// - `vars[0..3]`: time channels (entry `h`, exit `h`, exit `hdot`);
/// - `vars[3..6]`: threshold channels (exit, entry, derivative);
/// - `vars[6..9]`: log-sigma channels (exit, entry, derivative).
///
/// With `inv_sigma = exp(-eta_ls)` on each side, the row indices are:
/// - `u0 = h_entry - eta_t_entry * inv_sigma_entry`;
/// - `u1 = h_exit - eta_t_exit * inv_sigma_exit`;
/// - `g = hdot + inv_sigma_exit * (eta_t_exit * eta_ls_deriv - eta_t_deriv)`.
///
/// The exact per-row jets stored in `kernel` are composed onto these indices.
/// This function currently never returns `Err`.
pub(crate) fn sls_row_nll<S: JetScalar<SLS_ROW_K>>(
    vars: &[S; SLS_ROW_K],
    kernel: &SurvivalExactRowKernel,
) -> Result<S, String> {
    // inv_sigma_entry = exp(-eta_ls_entry)
    let inv_sigma_entry = vars[7].neg().exp();
    // u0 = h_entry - eta_t_entry * inv_sigma_entry
    let u0 = vars[0].sub(&vars[4].mul(&inv_sigma_entry));
    // inv_sigma_exit = exp(-eta_ls_exit)
    let inv_sigma_exit = vars[6].neg().exp();
    // u1 = h_exit - eta_t_exit * inv_sigma_exit
    let u1 = vars[1].sub(&vars[3].mul(&inv_sigma_exit));
    // g = hdot + inv_sigma_exit * (eta_t_exit * eta_ls_deriv - eta_t_deriv)
    let g = vars[2].add(&inv_sigma_exit.mul(&vars[3].mul(&vars[8]).sub(&vars[5])));
    // Entry term: +w times the `log_s0` jet composed at u0.
    let mut nll = u0
        .compose_unary([
            kernel.log_s0,
            -kernel.r0,
            -kernel.dr0,
            -kernel.ddr0,
            -kernel.dddr0,
        ])
        .scale(kernel.w);
    // Censored share: -w * (1 - d) times the `log_s1` jet at u1.
    // Skipped entirely when its weight is exactly zero.
    let censored_weight = kernel.w * (1.0 - kernel.d);
    if censored_weight != 0.0 {
        nll = nll.add(
            &u1.compose_unary([
                kernel.log_s1,
                -kernel.r1,
                -kernel.dr1,
                -kernel.ddr1,
                -kernel.dddr1,
            ])
            .scale(-censored_weight),
        );
    }
    // Event share: -w * d times the sum of the `logphi1` jet at u1 and the
    // `log_g` jet at g. Skipped entirely when its weight is exactly zero.
    let event_weight = kernel.w * kernel.d;
    if event_weight != 0.0 {
        nll = nll
            .add(
                &u1.compose_unary([
                    kernel.logphi1,
                    kernel.dlogphi1,
                    kernel.d2logphi1,
                    kernel.d3logphi1,
                    kernel.d4logphi1,
                ])
                .scale(-event_weight),
            )
            .add(
                &g.compose_unary([
                    kernel.log_g,
                    kernel.d_log_g,
                    kernel.d2_log_g,
                    kernel.d3_log_g,
                    kernel.d4_log_g,
                ])
                .scale(-event_weight),
            );
    }
    Ok(nll)
}
/// Materialises row `row` of a design matrix as a dense vector of width
/// `d.ncols()`.
///
/// # Panics
/// Panics if `axpy_row_into` reports an error. The buffer is sized to the
/// design width precisely so that this should not happen.
pub(crate) fn design_dense_row(d: &DesignMatrix, row: usize) -> Array1<f64> {
    let width = d.ncols();
    let mut dense = Array1::<f64>::zeros(width);
    let mut target = dense.view_mut();
    let outcome = d.axpy_row_into(row, 1.0, &mut target);
    outcome.expect("design_dense_row: ncols-sized buffer matches design width");
    dense
}
/// Adds `alpha * jac.row(row)` into `out` elementwise.
///
/// Only the overlapping prefix of the two lengths is touched. A zero `alpha`
/// is a no-op, so non-finite Jacobian entries are not propagated in that case.
#[inline]
pub(crate) fn axpy_dense_row_into(jac: &Array2<f64>, row: usize, alpha: f64, out: &mut [f64]) {
    if alpha != 0.0 {
        out.iter_mut()
            .zip(jac.row(row))
            .for_each(|(slot, &coef)| *slot += alpha * coef);
    }
}
pub(crate) fn row_set_from_survival_mask(
row_mask: Option<&Array1<f64>>,
n: usize,
) -> crate::families::row_kernel::RowSet {
let Some(mask) = row_mask else {
return crate::families::row_kernel::RowSet::All;
};
let rows = mask
.iter()
.enumerate()
.filter_map(|(index, &weight)| {
(weight != 0.0).then_some(crate::outer_subsample::WeightedOuterRow {
index,
weight,
stratum: 0,
})
})
.collect::<Vec<_>>();
crate::families::row_kernel::RowSet::Subsample {
rows: Arc::new(rows),
n_full: n,
}
}
/// `RowKernel` view of the survival location-scale likelihood.
///
/// Each row's NLL is a function of the nine channels from `row_primary_values`
/// (see `sls_row_nll`). The channel-to-coefficient Jacobian rows come from the
/// time Jacobians and from the threshold and log-sigma designs.
impl crate::families::row_kernel::RowKernel<SLS_ROW_K> for SurvivalLsRowKernel<'_> {
    fn n_rows(&self) -> usize {
        self.family.n
    }
    /// Total coefficient count: the last cumulative block offset. This
    /// includes a link-wiggle block when one is configured.
    fn n_coefficients(&self) -> usize {
        *self.offsets.last().expect("offsets has block bounds")
    }
    /// Value, channel gradient, and channel Hessian of the row NLL.
    ///
    /// Obtained by evaluating `sls_row_nll` on second-order jets, with one
    /// seeded variable per channel.
    fn row_kernel(
        &self,
        row: usize,
    ) -> Result<(f64, [f64; SLS_ROW_K], [[f64; SLS_ROW_K]; SLS_ROW_K]), String> {
        let (p, kernel) = self.row_nll_inputs(row)?;
        let vars: [Order2<SLS_ROW_K>; SLS_ROW_K] =
            std::array::from_fn(|a| Order2::variable(p[a], a));
        let out = sls_row_nll(&vars, &kernel)?;
        Ok((out.value(), out.g(), out.h()))
    }
    /// Channel-space image `J * d_beta` of a coefficient direction.
    ///
    /// Only the time, threshold, and log-sigma blocks
    /// (`offsets[0]..offsets[3]`) are read. Missing derivative designs
    /// contribute `0.0`.
    fn jacobian_action(&self, row: usize, d_beta: &[f64]) -> [f64; SLS_ROW_K] {
        let d_beta = ndarray::ArrayView1::from(d_beta);
        let d_time = d_beta.slice(s![self.offsets[0]..self.offsets[1]]);
        let d_thr = d_beta.slice(s![self.offsets[1]..self.offsets[2]]);
        let d_ls = d_beta.slice(s![self.offsets[2]..self.offsets[3]]);
        let fam = self.family;
        // Entry channels fall back to the exit design when no entry design exists.
        let t_entry = Self::entry_design(&fam.x_threshold_entry, &fam.x_threshold);
        let ls_entry = Self::entry_design(&fam.x_log_sigma_entry, &fam.x_log_sigma);
        let ch5 = fam
            .x_threshold_deriv
            .as_ref()
            .map_or(0.0, |d| d.dot_row_view(row, d_thr));
        let ch8 = fam
            .x_log_sigma_deriv
            .as_ref()
            .map_or(0.0, |d| d.dot_row_view(row, d_ls));
        [
            self.dynamic.time_jac_entry.row(row).dot(&d_time),
            self.dynamic.time_jac_exit.row(row).dot(&d_time),
            self.dynamic.time_jac_deriv.row(row).dot(&d_time),
            fam.x_threshold.dot_row_view(row, d_thr),
            t_entry.dot_row_view(row, d_thr),
            ch5,
            fam.x_log_sigma.dot_row_view(row, d_ls),
            ls_entry.dot_row_view(row, d_ls),
            ch8,
        ]
    }
    /// Accumulates `J^T v` for this row into `out`, block by block.
    fn jacobian_transpose_action(&self, row: usize, v: &[f64; SLS_ROW_K], out: &mut [f64]) {
        let fam = self.family;
        {
            // Channels 0..3 -> time block.
            let time = &mut out[self.offsets[0]..self.offsets[1]];
            axpy_dense_row_into(&self.dynamic.time_jac_entry, row, v[0], time);
            axpy_dense_row_into(&self.dynamic.time_jac_exit, row, v[1], time);
            axpy_dense_row_into(&self.dynamic.time_jac_deriv, row, v[2], time);
        }
        {
            // Channels 3..6 -> threshold block.
            let mut thr = ndarray::ArrayViewMut1::from(&mut out[self.offsets[1]..self.offsets[2]]);
            fam.x_threshold
                .axpy_row_into(row, v[3], &mut thr)
                .expect("threshold exit axpy");
            Self::entry_design(&fam.x_threshold_entry, &fam.x_threshold)
                .axpy_row_into(row, v[4], &mut thr)
                .expect("threshold entry axpy");
            if let Some(d) = fam.x_threshold_deriv.as_ref() {
                d.axpy_row_into(row, v[5], &mut thr)
                    .expect("threshold deriv axpy");
            }
        }
        {
            // Channels 6..9 -> log-sigma block.
            let mut ls = ndarray::ArrayViewMut1::from(&mut out[self.offsets[2]..self.offsets[3]]);
            fam.x_log_sigma
                .axpy_row_into(row, v[6], &mut ls)
                .expect("log_sigma exit axpy");
            Self::entry_design(&fam.x_log_sigma_entry, &fam.x_log_sigma)
                .axpy_row_into(row, v[7], &mut ls)
                .expect("log_sigma entry axpy");
            if let Some(d) = fam.x_log_sigma_deriv.as_ref() {
                d.axpy_row_into(row, v[8], &mut ls)
                    .expect("log_sigma deriv axpy");
            }
        }
    }
    /// Adds `J^T h J` for this row into the dense coefficient matrix `target`.
    ///
    /// Several channels map into the same block and therefore accumulate into
    /// the same cells. The loop order (a, b, ia, ib) fixes the floating-point
    /// summation order.
    fn add_pullback_hessian(
        &self,
        row: usize,
        h: &[[f64; SLS_ROW_K]; SLS_ROW_K],
        target: &mut Array2<f64>,
    ) {
        // (block, dense Jacobian row) per channel; `None` for channels without a design.
        let rows: Vec<Option<(usize, Array1<f64>)>> = (0..SLS_ROW_K)
            .map(|ch| self.channel_block(ch).zip(self.channel_row(ch, row)))
            .collect();
        for a in 0..SLS_ROW_K {
            let Some((ba, ra)) = rows[a].as_ref() else {
                continue;
            };
            let off_a = self.offsets[*ba];
            for b in 0..SLS_ROW_K {
                let hab = h[a][b];
                if hab == 0.0 {
                    continue;
                }
                let Some((bb, rb)) = rows[b].as_ref() else {
                    continue;
                };
                let off_b = self.offsets[*bb];
                for (ia, &va) in ra.iter().enumerate() {
                    if va == 0.0 {
                        continue;
                    }
                    // Rank-one update: target[off_a + ia, off_b + ib] += h_ab * ra[ia] * rb[ib].
                    let w = hab * va;
                    let mut trow = target.row_mut(off_a + ia);
                    for (ib, &vb) in rb.iter().enumerate() {
                        trow[off_b + ib] += w * vb;
                    }
                }
            }
        }
    }
    /// Adds the diagonal of `J^T h J` for this row into `diag`.
    ///
    /// Channel pairs from different blocks touch disjoint coefficients and so
    /// contribute nothing to the diagonal; they are skipped.
    fn add_diagonal_quadratic(
        &self,
        row: usize,
        h: &[[f64; SLS_ROW_K]; SLS_ROW_K],
        diag: &mut [f64],
    ) {
        let rows: Vec<Option<(usize, Array1<f64>)>> = (0..SLS_ROW_K)
            .map(|ch| self.channel_block(ch).zip(self.channel_row(ch, row)))
            .collect();
        for a in 0..SLS_ROW_K {
            let Some((ba, ra)) = rows[a].as_ref() else {
                continue;
            };
            for b in 0..SLS_ROW_K {
                let hab = h[a][b];
                if hab == 0.0 {
                    continue;
                }
                let Some((bb, rb)) = rows[b].as_ref() else {
                    continue;
                };
                if ba != bb {
                    continue;
                }
                let off = self.offsets[*ba];
                for (k, (&va, &vb)) in ra.iter().zip(rb.iter()).enumerate() {
                    diag[off + k] += hab * va * vb;
                }
            }
        }
    }
    /// Channel-space matrix from `contracted_third` of a one-seed jet, where
    /// each channel is seeded with direction component `dir[a]`.
    ///
    /// Presumably this is the third-derivative tensor of the row NLL
    /// contracted with `dir`.
    fn row_third_contracted(
        &self,
        row: usize,
        dir: &[f64; SLS_ROW_K],
    ) -> Result<[[f64; SLS_ROW_K]; SLS_ROW_K], String> {
        let (p, kernel) = self.row_nll_inputs(row)?;
        let vars: [OneSeed<SLS_ROW_K>; SLS_ROW_K] =
            std::array::from_fn(|a| OneSeed::seed_direction(p[a], a, dir[a]));
        Ok(sls_row_nll(&vars, &kernel)?.contracted_third())
    }
    /// Channel-space matrix from `contracted_fourth` of a two-seed jet, where
    /// each channel is seeded with the components of `dir_u` and `dir_v`.
    fn row_fourth_contracted(
        &self,
        row: usize,
        dir_u: &[f64; SLS_ROW_K],
        dir_v: &[f64; SLS_ROW_K],
    ) -> Result<[[f64; SLS_ROW_K]; SLS_ROW_K], String> {
        let (p, kernel) = self.row_nll_inputs(row)?;
        let vars: [TwoSeed<SLS_ROW_K>; SLS_ROW_K] =
            std::array::from_fn(|a| TwoSeed::seed(p[a], a, dir_u[a], dir_v[a]));
        Ok(sls_row_nll(&vars, &kernel)?.contracted_fourth())
    }
}
impl SurvivalLocationScaleFamily {
/// Index of the time block in `block_states` and in the joint coefficient layout.
pub(crate) const BLOCK_TIME: usize = 0;
/// Index of the threshold block.
pub(crate) const BLOCK_THRESHOLD: usize = 1;
/// Index of the log-sigma block.
pub(crate) const BLOCK_LOG_SIGMA: usize = 2;
/// Index of the link-wiggle block, which is present only when `x_link_wiggle`
/// is set.
pub(crate) const BLOCK_LINK_WIGGLE: usize = 3;
/// Row-count threshold related to parallel evaluation.
///
/// Not used in this view. NOTE(review): confirm the exact comparison at its
/// call site.
pub(crate) const EVALUATE_PARALLEL_ROW_THRESHOLD: usize = 1024;
/// Whether the row-kernel path may assemble the joint Hessian. This is always
/// `false` for this family.
#[inline]
pub(crate) fn row_kernel_joint_hessian_supported(&self) -> bool {
    false
}
/// Whether the row-kernel directional (third/fourth-order) path is usable.
///
/// It is disabled when a link-wiggle block exists. `SurvivalLsRowKernel` only
/// maps the time, threshold, and log-sigma blocks, which is presumably the
/// reason.
#[inline]
pub(crate) fn row_kernel_directional_supported(&self) -> bool {
    self.x_link_wiggle.is_none()
}
/// Builds a row kernel over `q` and `dynamic` with a zero `deriv_log_scale`.
pub(crate) fn survival_ls_row_kernel<'a>(
    &'a self,
    q: &'a SurvivalJointQuantities,
    dynamic: &'a SurvivalDynamicGeometry,
) -> SurvivalLsRowKernel<'a> {
    let zero_log_scale = 0.0;
    self.survival_ls_row_kernel_rescaled(q, dynamic, zero_log_scale)
}
/// Builds a row kernel that borrows this family, `q`, and `dynamic`.
///
/// The joint block offsets are computed once up front, and `deriv_log_scale`
/// is forwarded to every per-row exact kernel.
pub(crate) fn survival_ls_row_kernel_rescaled<'a>(
    &'a self,
    q: &'a SurvivalJointQuantities,
    dynamic: &'a SurvivalDynamicGeometry,
    deriv_log_scale: f64,
) -> SurvivalLsRowKernel<'a> {
    let offsets = self.joint_block_offsets();
    SurvivalLsRowKernel {
        offsets,
        deriv_log_scale,
        dynamic,
        q,
        family: self,
    }
}
/// Column range of the time-wiggle coefficients, i.e. the trailing
/// `time_wiggle_ncols` columns of the exit time design.
///
/// The range is clamped to the design width.
#[inline]
pub(crate) fn time_wiggle_range(&self) -> std::ops::Range<usize> {
    let end = self.x_time_exit.ncols();
    let start = end.saturating_sub(self.time_wiggle_ncols);
    start..end
}
/// Returns the configured derivative guard.
///
/// # Panics
/// Panics unless the guard is finite and strictly positive.
#[inline]
pub(crate) fn time_derivative_lower_bound(&self) -> f64 {
    let guard = self.derivative_guard;
    let is_valid = guard.is_finite() && guard > 0.0;
    assert!(
        is_valid,
        "survival location-scale derivative guard must be finite and positive: derivative_guard={}",
        guard
    );
    guard
}
/// Largest feasible fraction of `delta` for the time block under its linear
/// constraints.
///
/// Returns `Ok(None)` when the family has no time constraints.
///
/// # Errors
/// - Dimension mismatches between `beta`, `delta`, and the constraints.
/// - A current `beta` that already violates a constraint.
pub(crate) fn max_feasible_time_step(
    &self,
    beta: &Array1<f64>,
    delta: &Array1<f64>,
) -> Result<Option<f64>, String> {
    let constraints = match self.time_linear_constraints.as_ref() {
        Some(constraints) => constraints,
        None => return Ok(None),
    };
    let fraction = crate::families::marginal_slope_shared::feasible_step_fraction(
        constraints,
        beta,
        delta,
        |beta_len, delta_len, expected| {
            SurvivalLocationScaleError::DimensionMismatch { reason: format!(
                "survival location-scale time-step constraint dimension mismatch: beta={beta_len}, delta={delta_len}, constraints={expected}"
            ) }.into()
        },
        |row, slack| {
            SurvivalLocationScaleError::ConstraintViolation { reason: format!(
                "survival location-scale current time block violates linear constraint at row {row}: slack={slack:.3e}"
            ) }.into()
        },
    )?;
    Ok(Some(fraction))
}
/// Largest step fraction in `[0, 1]` that keeps every link-wiggle coefficient
/// non-negative along `delta`.
///
/// Returns `Some(1.0)` when the full step stays feasible. Otherwise it returns
/// the limiting ratio backed off by a factor of `0.995`.
///
/// # Errors
/// - `beta` and `delta` have different lengths.
/// - A current coefficient lies below `-CONSTRAINT_NONNEGATIVITY_REL_TOL`.
pub(crate) fn max_feasible_link_wiggle_step(
    &self,
    beta: &Array1<f64>,
    delta: &Array1<f64>,
) -> Result<Option<f64>, String> {
    if beta.len() != delta.len() {
        return Err(SurvivalLocationScaleError::DimensionMismatch {
            reason: format!(
                "survival location-scale linkwiggle-step dimension mismatch: beta={}, delta={}",
                beta.len(),
                delta.len()
            ),
        }
        .into());
    }
    let mut alpha = 1.0f64;
    for (j, (&slack, &drift)) in beta.iter().zip(delta.iter()).enumerate() {
        if slack < -CONSTRAINT_NONNEGATIVITY_REL_TOL {
            return Err(SurvivalLocationScaleError::ConstraintViolation { reason: format!(
                "survival location-scale current linkwiggle block violates nonnegativity at coefficient {j}: beta={slack:.3e}"
            ) }.into());
        }
        // Only coefficients moving toward zero can limit the step.
        if drift < 0.0 {
            let ratio = (slack / -drift).clamp(0.0, 1.0);
            alpha = alpha.min(ratio);
        }
    }
    let step = if alpha >= 1.0 {
        1.0
    } else {
        (0.995 * alpha).clamp(0.0, 1.0)
    };
    Ok(Some(step))
}
/// Number of parameter blocks: time, threshold, and log-sigma, plus one more
/// when a link wiggle is configured.
#[inline]
pub(crate) fn expected_blocks(&self) -> usize {
    let wiggle_blocks = usize::from(self.x_link_wiggle.is_some());
    3 + wiggle_blocks
}
/// Coefficient width of each joint block, in block order.
///
/// The order is time, threshold, log-sigma, then the link wiggle when present.
#[inline]
pub(crate) fn joint_block_dims(&self) -> Vec<usize> {
    let core = [
        self.x_time_entry.ncols(),
        self.x_threshold.ncols(),
        self.x_log_sigma.ncols(),
    ];
    let wiggle = self.x_link_wiggle.as_ref().map(|xw| xw.ncols());
    core.into_iter().chain(wiggle).collect()
}
/// Checks that `specs` has one entry per joint block and that each spec's
/// design width matches the expected block width.
///
/// # Errors
/// Returns a `DimensionMismatch` message prefixed with `context` for the first
/// mismatch found.
pub(crate) fn validate_joint_specs(
    &self,
    specs: &[ParameterBlockSpec],
    context: &str,
) -> Result<(), String> {
    let dims = self.joint_block_dims();
    if specs.len() != dims.len() {
        return Err(SurvivalLocationScaleError::DimensionMismatch {
            reason: format!(
                "{context} expects {} specs, got {}",
                dims.len(),
                specs.len()
            ),
        }
        .into());
    }
    specs
        .iter()
        .zip(&dims)
        .enumerate()
        .try_for_each(|(block_idx, (spec, &expected_width))| -> Result<(), String> {
            let width = spec.design.ncols();
            if width == expected_width {
                return Ok(());
            }
            Err(SurvivalLocationScaleError::DimensionMismatch {
                reason: format!(
                    "{context} spec {block_idx} width mismatch: got {width}, expected {expected_width}"
                ),
            }
            .into())
        })
}
/// Cumulative block offsets: `[0, d0, d0 + d1, ...]`, with one more entry than
/// there are blocks.
///
/// Block `b` spans `offsets[b]..offsets[b + 1]`.
#[inline]
pub(crate) fn joint_block_offsets(&self) -> Vec<usize> {
    let running_totals = self
        .joint_block_dims()
        .into_iter()
        .scan(0usize, |total, width| {
            *total += width;
            Some(*total)
        });
    std::iter::once(0).chain(running_totals).collect()
}
/// Link-wiggle basis matrices and derivatives of the wiggled index with
/// respect to `q0`.
///
/// Returns `Ok(None)` when no wiggle knots or degree are configured. The first
/// derivative includes the identity term (`+ 1.0`).
///
/// # Errors
/// - Basis construction failures.
/// - Any basis width that differs from `beta_w.len()`.
pub(crate) fn wiggle_geometry(
    &self,
    q0: ndarray::ArrayView1<'_, f64>,
    beta_w: ndarray::ArrayView1<'_, f64>,
) -> Result<Option<SurvivalWiggleGeometry>, String> {
    let (knots, degree) = match (self.wiggle_knots.as_ref(), self.wiggle_degree) {
        (Some(knots), Some(degree)) => (knots, degree),
        _ => return Ok(None),
    };
    let basis = survival_wiggle_basis_with_options(q0, knots, degree, BasisOptions::value())?;
    let basis_d1 =
        survival_wiggle_basis_with_options(q0, knots, degree, BasisOptions::first_derivative())?;
    let basis_d2 =
        survival_wiggle_basis_with_options(q0, knots, degree, BasisOptions::second_derivative())?;
    let basis_d3 = survival_wiggle_third_basis(q0, knots, degree)?;

    let width = beta_w.len();
    let widths_agree = [
        basis.ncols(),
        basis_d1.ncols(),
        basis_d2.ncols(),
        basis_d3.ncols(),
    ]
    .iter()
    .all(|&cols| cols == width);
    if !widths_agree {
        return Err(SurvivalLocationScaleError::DimensionMismatch {
            reason: format!(
                "survival linkwiggle basis/beta mismatch: B={} B'={} B''={} B'''={} betaw={}",
                basis.ncols(),
                basis_d1.ncols(),
                basis_d2.ncols(),
                basis_d3.ncols(),
                beta_w.len()
            ),
        }
        .into());
    }

    let dq_dq0 = fast_av(&basis_d1, &beta_w) + 1.0;
    let d2q_dq02 = fast_av(&basis_d2, &beta_w);
    let d3q_dq03 = fast_av(&basis_d3, &beta_w);
    Ok(Some(SurvivalWiggleGeometry {
        basis,
        basis_d1,
        basis_d2,
        dq_dq0,
        d2q_dq02,
        d3q_dq03,
    }))
}
/// Time-wiggle basis matrices (derivative orders 0 through 3, evaluated at
/// `h0`) and the resulting derivatives of the wiggled time index.
///
/// Returns `Ok(None)` when no time-wiggle knots or degree are configured. The
/// first derivative includes the identity term (`+ 1.0`).
///
/// # Errors
/// - Basis construction failures.
/// - Any basis width that differs from `beta_w.len()`.
pub(crate) fn time_wiggle_geometry(
    &self,
    h0: ndarray::ArrayView1<'_, f64>,
    beta_w: ndarray::ArrayView1<'_, f64>,
) -> Result<Option<SurvivalWiggleGeometry>, String> {
    let (knots, degree) = match (self.time_wiggle_knots.as_ref(), self.time_wiggle_degree) {
        (Some(knots), Some(degree)) => (knots, degree),
        _ => return Ok(None),
    };
    let basis_of_order =
        |order| monotone_wiggle_basis_with_derivative_order(h0, knots, degree, order);
    let basis = basis_of_order(0)?;
    let basis_d1 = basis_of_order(1)?;
    let basis_d2 = basis_of_order(2)?;
    let basis_d3 = basis_of_order(3)?;

    let width = beta_w.len();
    let widths_agree = [
        basis.ncols(),
        basis_d1.ncols(),
        basis_d2.ncols(),
        basis_d3.ncols(),
    ]
    .iter()
    .all(|&cols| cols == width);
    if !widths_agree {
        return Err(SurvivalLocationScaleError::DimensionMismatch {
            reason: format!(
                "survival timewiggle basis/beta mismatch: B={} B'={} B''={} B'''={} betaw={}",
                basis.ncols(),
                basis_d1.ncols(),
                basis_d2.ncols(),
                basis_d3.ncols(),
                beta_w.len()
            ),
        }
        .into());
    }

    let dq_dq0 = fast_av(&basis_d1, &beta_w) + 1.0;
    let d2q_dq02 = fast_av(&basis_d2, &beta_w);
    let d3q_dq03 = fast_av(&basis_d3, &beta_w);
    Ok(Some(SurvivalWiggleGeometry {
        basis,
        basis_d1,
        basis_d2,
        dq_dq0,
        d2q_dq02,
        d3q_dq03,
    }))
}
/// Validates the block count and eta lengths and returns borrowed per-row
/// views.
///
/// The returned tuple is, in order:
/// 1. time eta at entry;
/// 2. time eta at exit;
/// 3. time eta derivative;
/// 4. threshold eta at exit;
/// 5. log-sigma eta at exit;
/// 6. threshold eta at entry;
/// 7. log-sigma eta at entry;
/// 8. optional threshold eta derivative;
/// 9. optional log-sigma eta derivative;
/// 10. optional link-wiggle eta.
///
/// A time-varying threshold or log-sigma block (one with an entry design)
/// carries a stacked `3 * n` eta of exit, entry, and derivative segments. A
/// static block carries `n` values, which are reused for both exit and entry.
///
/// # Errors
/// Returns the first block-count or length mismatch, checked in the order:
/// block count, time, threshold, log-sigma, wiggle.
pub(crate) fn validate_joint_states<'a>(
    &self,
    block_states: &'a [ParameterBlockState],
) -> Result<
    (
        ndarray::ArrayView1<'a, f64>,
        ndarray::ArrayView1<'a, f64>,
        ndarray::ArrayView1<'a, f64>,
        ndarray::ArrayView1<'a, f64>,
        ndarray::ArrayView1<'a, f64>,
        ndarray::ArrayView1<'a, f64>,
        ndarray::ArrayView1<'a, f64>,
        Option<ndarray::ArrayView1<'a, f64>>,
        Option<ndarray::ArrayView1<'a, f64>>,
        Option<&'a Array1<f64>>,
    ),
    String,
> {
    crate::families::block_layout::block_count::validate_block_count::<
        SurvivalLocationScaleError,
    >(
        "SurvivalLocationScaleFamily",
        self.expected_blocks(),
        block_states.len(),
    )?;
    let n = self.n;
    let (two_n, three_n) = (2 * n, 3 * n);
    let eta_time = &block_states[Self::BLOCK_TIME].eta;
    let eta_t_raw = &block_states[Self::BLOCK_THRESHOLD].eta;
    let eta_ls_raw = &block_states[Self::BLOCK_LOG_SIGMA].eta;
    let etaw = if self.x_link_wiggle.is_some() {
        Some(&block_states[Self::BLOCK_LINK_WIGGLE].eta)
    } else {
        None
    };

    if eta_time.len() != three_n {
        return Err(SurvivalLocationScaleError::DimensionMismatch {
            reason: format!(
                "survival location-scale time eta length mismatch: got {}, expected {}",
                eta_time.len(),
                three_n
            ),
        }
        .into());
    }

    let (eta_t_exit, eta_t_entry, eta_t_deriv_exit) = match &self.x_threshold_entry {
        Some(_) => {
            if eta_t_raw.len() != three_n {
                return Err(SurvivalLocationScaleError::DimensionMismatch {
                    reason: format!(
                        "time-varying threshold eta length mismatch: got {}, expected {}",
                        eta_t_raw.len(),
                        three_n
                    ),
                }
                .into());
            }
            (
                eta_t_raw.slice(s![0..n]),
                eta_t_raw.slice(s![n..two_n]),
                Some(eta_t_raw.slice(s![two_n..three_n])),
            )
        }
        None => {
            if eta_t_raw.len() != n {
                return Err(SurvivalLocationScaleError::DimensionMismatch {
                    reason: format!(
                        "threshold eta length mismatch: got {}, expected {n}",
                        eta_t_raw.len()
                    ),
                }
                .into());
            }
            (eta_t_raw.slice(s![0..n]), eta_t_raw.slice(s![0..n]), None)
        }
    };

    let (eta_ls_exit, eta_ls_entry, eta_ls_deriv_exit) = match &self.x_log_sigma_entry {
        Some(_) => {
            if eta_ls_raw.len() != three_n {
                return Err(SurvivalLocationScaleError::DimensionMismatch {
                    reason: format!(
                        "time-varying log-sigma eta length mismatch: got {}, expected {}",
                        eta_ls_raw.len(),
                        three_n
                    ),
                }
                .into());
            }
            (
                eta_ls_raw.slice(s![0..n]),
                eta_ls_raw.slice(s![n..two_n]),
                Some(eta_ls_raw.slice(s![two_n..three_n])),
            )
        }
        None => {
            if eta_ls_raw.len() != n {
                return Err(SurvivalLocationScaleError::DimensionMismatch {
                    reason: format!(
                        "log-sigma eta length mismatch: got {}, expected {n}",
                        eta_ls_raw.len()
                    ),
                }
                .into());
            }
            (eta_ls_raw.slice(s![0..n]), eta_ls_raw.slice(s![0..n]), None)
        }
    };

    match etaw {
        Some(w) if w.len() != n => {
            return Err(SurvivalLocationScaleError::DimensionMismatch {
                reason: format!(
                    "survival location-scale wiggle eta length mismatch: got {}, expected {n}",
                    w.len()
                ),
            }
            .into());
        }
        _ => {}
    }

    Ok((
        eta_time.slice(s![0..n]),
        eta_time.slice(s![n..two_n]),
        eta_time.slice(s![two_n..three_n]),
        eta_t_exit,
        eta_ls_exit,
        eta_t_entry,
        eta_ls_entry,
        eta_t_deriv_exit,
        eta_ls_deriv_exit,
        etaw,
    ))
}
/// Collects the joint quantities with a zero `deriv_log_scale`.
pub(crate) fn collect_joint_quantities(
    &self,
    block_states: &[ParameterBlockState],
) -> Result<SurvivalJointQuantities, String> {
    let zero_log_scale = 0.0;
    self.collect_joint_quantities_rescaled(block_states, zero_log_scale)
}
/// Evaluates the per-row derivative quantities in parallel and bundles them
/// with the dynamic geometry's `dq*` / `dqdot*` arrays.
///
/// `deriv_log_scale` is forwarded unchanged to `row_derivatives_rescaled`.
/// Rows for which that call returns `None` keep all-zero entries.
///
/// # Errors
/// Propagates errors from `build_dynamic_geometry` and from any row's
/// `row_derivatives_rescaled`.
pub(crate) fn collect_joint_quantities_rescaled(
    &self,
    block_states: &[ParameterBlockState],
    deriv_log_scale: f64,
) -> Result<SurvivalJointQuantities, String> {
    let n = self.n;
    let dynamic = self.build_dynamic_geometry(block_states)?;
    // Freshly allocated length-`n` outputs, so each `as_mut_ptr` below points
    // at `n` contiguous elements.
    let mut d1_q = Array1::<f64>::zeros(n);
    let mut d2_q = Array1::<f64>::zeros(n);
    let mut d3_q = Array1::<f64>::zeros(n);
    let mut d1_q0 = Array1::<f64>::zeros(n);
    let mut d2_q0 = Array1::<f64>::zeros(n);
    let mut d3_q0 = Array1::<f64>::zeros(n);
    let mut d1_q1 = Array1::<f64>::zeros(n);
    let mut d2_q1 = Array1::<f64>::zeros(n);
    let mut d3_q1 = Array1::<f64>::zeros(n);
    let mut d1_qdot1 = Array1::<f64>::zeros(n);
    let mut d2_qdot1 = Array1::<f64>::zeros(n);
    let mut h_time_h0 = Array1::<f64>::zeros(n);
    let mut h_time_h1 = Array1::<f64>::zeros(n);
    let mut h_time_d = Array1::<f64>::zeros(n);
    let mut d_h_h0 = Array1::<f64>::zeros(n);
    let mut d_h_h1 = Array1::<f64>::zeros(n);
    let mut d_h_d = Array1::<f64>::zeros(n);
    // Raw-pointer wrapper so each rayon task can write its own row index into
    // the output arrays without sharing a `&mut` borrow across threads.
    #[derive(Clone, Copy)]
    struct SendPtr(*mut f64);
    // SAFETY: the pointer is dereferenced only through `write`, and every
    // parallel task writes a distinct index, so no element is accessed by two
    // threads.
    unsafe impl Send for SendPtr {}
    unsafe impl Sync for SendPtr {}
    impl SendPtr {
        // Contract: `i` must be in bounds of the array the pointer came from,
        // that array must still be alive, and no other thread may access
        // index `i` concurrently.
        #[inline(always)]
        unsafe fn write(self, i: usize, v: f64) {
            // SAFETY: guaranteed by the caller per the contract above.
            unsafe { *self.0.add(i) = v };
        }
    }
    // The arrays are not otherwise read or written until the parallel loop
    // below has finished.
    let p_d1_q = SendPtr(d1_q.as_mut_ptr());
    let p_d2_q = SendPtr(d2_q.as_mut_ptr());
    let p_d3_q = SendPtr(d3_q.as_mut_ptr());
    let p_d1_q0 = SendPtr(d1_q0.as_mut_ptr());
    let p_d2_q0 = SendPtr(d2_q0.as_mut_ptr());
    let p_d3_q0 = SendPtr(d3_q0.as_mut_ptr());
    let p_d1_q1 = SendPtr(d1_q1.as_mut_ptr());
    let p_d2_q1 = SendPtr(d2_q1.as_mut_ptr());
    let p_d3_q1 = SendPtr(d3_q1.as_mut_ptr());
    let p_d1_qdot1 = SendPtr(d1_qdot1.as_mut_ptr());
    let p_d2_qdot1 = SendPtr(d2_qdot1.as_mut_ptr());
    let p_h_time_h0 = SendPtr(h_time_h0.as_mut_ptr());
    let p_h_time_h1 = SendPtr(h_time_h1.as_mut_ptr());
    let p_h_time_d = SendPtr(h_time_d.as_mut_ptr());
    let p_d_h_h0 = SendPtr(d_h_h0.as_mut_ptr());
    let p_d_h_h1 = SendPtr(d_h_h1.as_mut_ptr());
    let p_d_h_d = SendPtr(d_h_d.as_mut_ptr());
    let dyn_ref = &dynamic;
    (0..n)
        .into_par_iter()
        .try_for_each(move |i| -> Result<(), String> {
            let state = self.row_predictor_state(
                dyn_ref.h_entry[i],
                dyn_ref.h_exit[i],
                dyn_ref.hdot_exit[i],
                dyn_ref.q_entry[i],
                dyn_ref.q_exit[i],
                dyn_ref.qdot_exit[i],
            );
            // Rows without derivatives keep their zero-initialised entries.
            let Some(row) = self.row_derivatives_rescaled(i, state, deriv_log_scale)? else {
                return Ok(());
            };
            // SAFETY: `i` comes from `0..n`, so it is in bounds for every
            // length-`n` output array. Each `i` is handled by exactly one task,
            // so the writes are disjoint. `try_for_each` returns only after all
            // tasks finish, while the arrays are still alive.
            unsafe {
                p_d1_q.write(i, row.d1_q);
                p_d2_q.write(i, row.d2_q);
                p_d3_q.write(i, row.d3_q);
                p_d1_q0.write(i, row.d1_q0);
                p_d2_q0.write(i, row.d2_q0);
                p_d3_q0.write(i, row.d3_q0);
                p_d1_q1.write(i, row.d1_q1);
                p_d2_q1.write(i, row.d2_q1);
                p_d3_q1.write(i, row.d3_q1);
                p_d1_qdot1.write(i, row.d1_qdot1);
                p_d2_qdot1.write(i, row.d2_qdot1);
                p_h_time_h0.write(i, row.h_time_h0);
                p_h_time_h1.write(i, row.h_time_h1);
                p_h_time_d.write(i, row.h_time_d);
                p_d_h_h0.write(i, row.d_h_h0);
                p_d_h_h1.write(i, row.d_h_h1);
                p_d_h_d.write(i, row.d_h_d);
            }
            Ok(())
        })?;
    // Move the geometry arrays into the result. The entry-side fields are
    // always populated here.
    Ok(SurvivalJointQuantities {
        d1_q,
        d2_q,
        d3_q,
        d1_q0,
        d2_q0,
        d3_q0,
        d1_q1,
        d2_q1,
        d3_q1,
        d1_qdot1,
        d2_qdot1,
        h_time_h0,
        h_time_h1,
        h_time_d,
        d_h_h0,
        d_h_h1,
        d_h_d,
        dq_t: dynamic.dq_t_exit,
        dq_ls: dynamic.dq_ls_exit,
        d2q_tls: dynamic.d2q_tls_exit,
        d2q_ls: dynamic.d2q_ls_exit,
        d3q_tls_ls: dynamic.d3q_tls_ls_exit,
        d3q_ls: dynamic.d3q_ls_exit,
        dq_t_entry: Some(dynamic.dq_t_entry),
        dq_ls_entry: Some(dynamic.dq_ls_entry),
        d2q_tls_entry: Some(dynamic.d2q_tls_entry),
        d2q_ls_entry: Some(dynamic.d2q_ls_entry),
        d3q_tls_ls_entry: Some(dynamic.d3q_tls_ls_entry),
        d3q_ls_entry: Some(dynamic.d3q_ls_entry),
        dqdot_t: dynamic.dqdot_t,
        dqdot_ls: dynamic.dqdot_ls,
        dqdot_td: dynamic.dqdot_td,
        dqdot_lsd: dynamic.dqdot_lsd,
        d2qdot_tt: dynamic.d2qdot_tt,
        d2qdot_tls: dynamic.d2qdot_tls,
        d2qdot_ttd: dynamic.d2qdot_ttd,
        d2qdot_tlsd: dynamic.d2qdot_tlsd,
        d2qdot_ls: dynamic.d2qdot_ls,
        d2qdot_lstd: dynamic.d2qdot_lstd,
        d2qdot_lslsd: dynamic.d2qdot_lslsd,
    })
}
/// Per-row time-channel NLL gradients and diagonal curvatures.
///
/// Gradients are reported for the entry, exit, and derivative channels; the
/// `right` channel is always zero. Rows without derivatives contribute zeros.
/// With no block states at all, a warning is logged and all-zero residuals and
/// curvatures are returned.
///
/// # Errors
/// Propagates geometry construction and per-row derivative errors.
pub(crate) fn offset_channel_geometry(
    &self,
    block_states: &[ParameterBlockState],
) -> Result<(OffsetChannelResiduals, OffsetChannelCurvatures), String> {
    let n = self.n;
    let zeros = || Array1::<f64>::zeros(n);
    if block_states.is_empty() {
        log::warn!(
            "SurvivalLocationScaleFamily::offset_channel_geometry: \
             block_states is empty (degraded fit, likely ARC \
             deterministic-replay stall); returning zero residuals + \
             curvatures (n={n})"
        );
        return Ok((
            OffsetChannelResiduals {
                exit: zeros(),
                entry: zeros(),
                derivative: zeros(),
                right: zeros(),
            },
            OffsetChannelCurvatures {
                rows: vec![[[0.0_f64; 3]; 3]; n],
            },
        ));
    }

    let dynamic = self.build_dynamic_geometry(block_states)?;
    let per_row = (0..n)
        .into_par_iter()
        .map(|i| -> Result<(usize, [f64; 3], [[f64; 3]; 3]), String> {
            let state = self.row_predictor_state(
                dynamic.h_entry[i],
                dynamic.h_exit[i],
                dynamic.hdot_exit[i],
                dynamic.q_entry[i],
                dynamic.q_exit[i],
                dynamic.qdot_exit[i],
            );
            let Some(row) = self.row_derivatives(i, state)? else {
                return Ok((i, [0.0; 3], [[0.0; 3]; 3]));
            };
            let gradient = row.time_channel_nll_gradient();
            let diag = row.time_channel_nll_curvature_diag();
            let curvature = [
                [diag[0], 0.0, 0.0],
                [0.0, diag[1], 0.0],
                [0.0, 0.0, diag[2]],
            ];
            Ok((i, gradient, curvature))
        })
        .collect::<Result<Vec<_>, String>>()?;

    let mut entry = zeros();
    let mut exit = zeros();
    let mut derivative = zeros();
    let mut curvatures = vec![[[0.0_f64; 3]; 3]; n];
    for (i, gradient, curvature) in per_row {
        let [r_entry, r_exit, r_deriv] = gradient;
        entry[i] = r_entry;
        exit[i] = r_exit;
        derivative[i] = r_deriv;
        curvatures[i] = curvature;
    }
    Ok((
        OffsetChannelResiduals {
            exit,
            entry,
            derivative,
            right: zeros(),
        },
        OffsetChannelCurvatures { rows: curvatures },
    ))
}
/// Gradient of the weighted data-fit NLL with respect to the inverse-link
/// shape parameters θ.
///
/// Returns `Ok(None)` when there are no block states or the link exposes no θ.
///
/// For each positively weighted row:
/// - the entry index contributes `+w * dlog S(u0)/dθ`;
/// - the exit index contributes `-w` times `dlog S(u1)/dθ` (censored rows),
///   `dlog f(u1)/dθ` (event rows), or their `d`-weighted blend.
///
/// Here `S = 1 - mu`, floored at `f64::MIN_POSITIVE`, and `f` is the link's
/// `d1`.
///
/// # Errors
/// - Link evaluation failures.
/// - A non-finite or non-positive density at an evaluated point.
/// - Invalid event targets.
pub(crate) fn link_param_data_fit_gradient(
    &self,
    block_states: &[ParameterBlockState],
) -> Result<Option<Array1<f64>>, String> {
    use crate::solver::mixture_link::{InverseLinkKernel, LinkParamPartials};
    let n = self.n;
    if block_states.is_empty() {
        return Ok(None);
    }
    let probe = self
        .inverse_link
        .param_partials(0.0)
        .map_err(|e| format!("inverse-link param partials probe failed: {e}"))?;
    let n_theta = match &probe {
        Some(LinkParamPartials::Sas(_)) => 2,
        Some(LinkParamPartials::Mixture(m)) => m.djet_drho.len(),
        None => 0,
    };
    if n_theta == 0 {
        return Ok(None);
    }
    let dynamic = self.build_dynamic_geometry(block_states)?;

    // dlog S / dθ = -(dmu/dθ) / S, with S = 1 - mu floored at MIN_POSITIVE.
    let dlog_survival_dtheta = |u: f64| -> Result<Vec<f64>, String> {
        let partials = self
            .inverse_link
            .param_partials(u)
            .map_err(|e| format!("inverse-link survival param partials failed: {e}"))?
            .ok_or_else(|| "inverse-link reported no param partials".to_string())?;
        let jet = self
            .inverse_link
            .jet(u)
            .map_err(|e| format!("inverse-link jet failed at u={u}: {e}"))?;
        let survival = (1.0 - jet.mu).clamp(f64::MIN_POSITIVE, 1.0);
        let to_log_survival = |dmu: f64| -dmu / survival;
        let sensitivities = match partials {
            LinkParamPartials::Sas(p) => vec![
                to_log_survival(p.djet_depsilon.mu),
                to_log_survival(p.djet_dlog_delta.mu),
            ],
            LinkParamPartials::Mixture(p) => {
                p.djet_drho.iter().map(|j| to_log_survival(j.mu)).collect()
            }
        };
        Ok(sensitivities)
    };

    // dlog f / dθ = (d f/dθ) / f, requiring a finite positive density.
    let dlog_pdf_dtheta = |u: f64| -> Result<Vec<f64>, String> {
        let partials = self
            .inverse_link
            .param_partials(u)
            .map_err(|e| format!("inverse-link pdf param partials failed: {e}"))?
            .ok_or_else(|| "inverse-link reported no param partials".to_string())?;
        let jet = self
            .inverse_link
            .jet(u)
            .map_err(|e| format!("inverse-link jet failed at u={u}: {e}"))?;
        let f = jet.d1;
        if !(f.is_finite() && f > 0.0) {
            return Err(format!(
                "inverse-link pdf (d1) must be finite positive for θ-gradient, got {f} at u={u}"
            ));
        }
        let to_log_pdf = |dd1: f64| dd1 / f;
        let sensitivities = match partials {
            LinkParamPartials::Sas(p) => vec![
                to_log_pdf(p.djet_depsilon.d1),
                to_log_pdf(p.djet_dlog_delta.d1),
            ],
            LinkParamPartials::Mixture(p) => {
                p.djet_drho.iter().map(|j| to_log_pdf(j.d1)).collect()
            }
        };
        Ok(sensitivities)
    };

    let mut grad = Array1::<f64>::zeros(n_theta);
    for i in 0..n {
        let w = self.w[i];
        if w <= 0.0 {
            continue;
        }
        let d = self.validated_event_target(i)?;
        let u0 = dynamic.h_entry[i] + dynamic.q_entry[i];
        let u1 = dynamic.h_exit[i] + dynamic.q_exit[i];

        let entry_sens = dlog_survival_dtheta(u0)?;
        for k in 0..n_theta {
            grad[k] += w * entry_sens[k];
        }

        let exit_sens: Vec<f64> = if d <= 0.0 {
            dlog_survival_dtheta(u1)?
        } else if d >= 1.0 {
            dlog_pdf_dtheta(u1)?
        } else {
            let survival_sens = dlog_survival_dtheta(u1)?;
            let pdf_sens = dlog_pdf_dtheta(u1)?;
            (0..n_theta)
                .map(|k| d * pdf_sens[k] + (1.0 - d) * survival_sens[k])
                .collect()
        };
        for k in 0..n_theta {
            grad[k] -= w * exit_sens[k];
        }
    }
    Ok(Some(grad))
}
/// Lowers the hyperparameter derivative at flattened position `psi_index` into
/// a joint psi direction for the survival location-scale family.
///
/// Derivatives are counted across all blocks in order, so `psi_index`
/// selects one derivative overall. Only the threshold and log-sigma blocks
/// are lowered here. A derivative that belongs to any other block yields
/// `Ok(None)`, as does a `psi_index` past the last derivative.
///
/// When a block has an entry design (time-varying), the psi design is
/// resolved over `3 * n` rows. Rows `0..n` are used as the exit design and
/// rows `n..2n` as the entry design. Otherwise the single `n`-row design
/// serves both exit and entry. Fields of the block that is not being lowered
/// are zero vectors or `None`.
///
/// # Errors
/// Errors when the block-state or derivative-list counts differ from
/// `expected_blocks()`. Also propagates failures from psi-map resolution,
/// row slicing and dense splitting, and errors when the resolver returns a
/// second-order map.
pub(crate) fn exact_newton_joint_psi_direction(
    &self,
    block_states: &[ParameterBlockState],
    derivative_blocks: &[Vec<CustomFamilyBlockPsiDerivative>],
    psi_index: usize,
) -> Result<Option<SurvivalJointPsiDirection>, String> {
    let expected = self.expected_blocks();
    if block_states.len() != expected || derivative_blocks.len() != expected {
        return Err(SurvivalLocationScaleError::DimensionMismatch {
            reason: format!(
                "SurvivalLocationScaleFamily joint psi direction expects {} blocks and derivative lists, got {} and {}",
                expected,
                block_states.len(),
                derivative_blocks.len()
            ),
        }
        .into());
    }
    let n = self.n;
    let beta_t = &block_states[Self::BLOCK_THRESHOLD].beta;
    let beta_ls = &block_states[Self::BLOCK_LOG_SIGMA].beta;

    // Find the derivative whose position in the flattened list is `psi_index`.
    let Some((block_idx, deriv)) = derivative_blocks
        .iter()
        .enumerate()
        .flat_map(|(block, derivs)| derivs.iter().map(move |entry| (block, entry)))
        .nth(psi_index)
    else {
        return Ok(None);
    };

    // Per-block lowering parameters; other blocks carry no design derivative here.
    let (is_threshold, cols, time_varying, beta, label, second_variant_msg) = match block_idx {
        Self::BLOCK_THRESHOLD => (
            true,
            self.x_threshold.ncols(),
            self.x_threshold_entry.is_some(),
            beta_t,
            "SurvivalLocationScaleFamily threshold",
            "SurvivalLocationScaleFamily threshold: unexpected Second variant from _psi_map",
        ),
        Self::BLOCK_LOG_SIGMA => (
            false,
            self.x_log_sigma.ncols(),
            self.x_log_sigma_entry.is_some(),
            beta_ls,
            "SurvivalLocationScaleFamily log-sigma",
            "SurvivalLocationScaleFamily log-sigma: unexpected Second variant from _psi_map",
        ),
        _ => return Ok(None),
    };

    let total_rows = if time_varying { 3 * n } else { n };
    let mapped = resolve_custom_family_x_psi_map(
        deriv,
        total_rows,
        cols,
        0..total_rows,
        label,
        &self.policy,
    )?;
    // (z_exit, z_entry, dense_exit, dense_entry, action_exit, action_entry)
    let (z_exit, z_entry, dense_exit, dense_entry, action_exit, action_entry) = match mapped {
        PsiDesignMap::First { action } => {
            if time_varying {
                let exit_action = action.slice_rows(0..n)?;
                let entry_action = action.slice_rows(n..2 * n)?;
                (
                    exit_action.forward_mul(beta.view()),
                    entry_action.forward_mul(beta.view()),
                    None,
                    None,
                    Some(exit_action),
                    Some(entry_action),
                )
            } else {
                // A single design is shared by the exit and entry channels.
                let shared = action.forward_mul(beta.view());
                (
                    shared.clone(),
                    shared,
                    None,
                    None,
                    Some(action.clone()),
                    Some(action),
                )
            }
        }
        PsiDesignMap::Dense { matrix } => {
            let (exit, entry) = split_survival_psi_design(&matrix, n, time_varying, label)?;
            (
                fast_av(&exit, beta),
                fast_av(&entry, beta),
                Some(exit),
                Some(entry),
                None,
                None,
            )
        }
        PsiDesignMap::Zero { .. } => (
            Array1::<f64>::zeros(n),
            Array1::<f64>::zeros(n),
            None,
            None,
            None,
            None,
        ),
        PsiDesignMap::Second { .. } => {
            return Err(SurvivalLocationScaleError::DimensionMismatch {
                reason: second_variant_msg.to_string(),
            }
            .into());
        }
    };

    let zeros = || Array1::<f64>::zeros(n);
    let direction = if is_threshold {
        SurvivalJointPsiDirection {
            x_t_exit_psi: dense_exit,
            x_t_entry_psi: dense_entry,
            x_ls_exit_psi: None,
            x_ls_entry_psi: None,
            z_t_exit_psi: z_exit,
            z_t_entry_psi: z_entry,
            z_ls_exit_psi: zeros(),
            z_ls_entry_psi: zeros(),
            x_t_exit_action: action_exit,
            x_t_entry_action: action_entry,
            x_ls_exit_action: None,
            x_ls_entry_action: None,
        }
    } else {
        SurvivalJointPsiDirection {
            x_t_exit_psi: None,
            x_t_entry_psi: None,
            x_ls_exit_psi: dense_exit,
            x_ls_entry_psi: dense_entry,
            z_t_exit_psi: zeros(),
            z_t_entry_psi: zeros(),
            z_ls_exit_psi: z_exit,
            z_ls_entry_psi: z_entry,
            x_t_exit_action: None,
            x_t_entry_action: None,
            x_ls_exit_action: action_exit,
            x_ls_entry_action: action_entry,
        }
    };
    Ok(Some(direction))
}
/// Hazard-type ratio `r = f / S` and its first derivative `dr/du`.
///
/// With `S = 1 - F` and `f = F'`, we have `dS/du = -f`, so
/// `d(f/S)/du = f'/S + f^2/S^2 = f'/S + r^2`.
///
/// Returns `(r, dr/du)`.
pub(crate) fn survival_ratio_first_derivative(f: f64, fp: f64, s: f64) -> (f64, f64) {
    let ratio = f / s;
    let slope = fp / s + ratio * ratio;
    (ratio, slope)
}
/// Second derivative of the hazard-type ratio `r = f / S`.
///
/// Differentiating `dr = f'/S + r^2` gives
/// `d2r = 2 r dr + f''/S + f' f / S^2`.
pub(crate) fn survival_ratiosecond_derivative(
    r: f64,
    dr: f64,
    f: f64,
    fp: f64,
    fpp: f64,
    s: f64,
) -> f64 {
    let s_sq = s * s;
    let quotient_terms = fpp / s + fp * f / s_sq;
    2.0 * r * dr + quotient_terms
}
/// Third derivative of the hazard-type ratio `r = f / S`.
///
/// Differentiating the second-derivative expression once more gives
/// `d3r = 2 dr^2 + 2 r d2r + f'''/S + 2 f'' f/S^2 + f'^2/S^2 + 2 f' f^2/S^3`.
/// The terms are summed left to right in exactly that order.
pub(crate) fn survival_ratio_third_derivative(
    r: f64,
    dr: f64,
    ddr: f64,
    f: f64,
    fp: f64,
    fpp: f64,
    fppp: f64,
    s: f64,
) -> f64 {
    let s_sq = s * s;
    let s_cube = s_sq * s;
    let terms = [
        2.0 * dr * dr,
        2.0 * r * ddr,
        fppp / s,
        2.0 * fpp * f / s_sq,
        fp * fp / s_sq,
        2.0 * fp * f * f / s_cube,
    ];
    terms[1..].iter().fold(terms[0], |acc, term| acc + term)
}
/// Log-density `log f(eta)` of the inverse link and its first four
/// derivatives with respect to `eta`, returned as `(log f, d1, d2, d3, d4)`.
///
/// Closed forms are used for the standard probit, logit, cloglog and
/// identity links. The identity link has constant unit density, so every
/// entry is zero. Any other link is handled from its pdf derivatives
/// `f, f', f'', f''', f''''`, by repeatedly differentiating `f'/f`.
///
/// Only the cloglog branch consults `deriv_log_scale`. There the four
/// derivatives are multiplied by `exp(-deriv_log_scale)`, and the
/// exponential is evaluated as `exp(eta - deriv_log_scale)`.
///
/// # Errors
/// For non-standard links: when the jet, third or fourth pdf derivative
/// cannot be evaluated, or when the pdf is not finite and positive.
pub(crate) fn exact_log_pdf_derivatives_rescaled(
    inverse_link: &InverseLink,
    eta: f64,
    deriv_log_scale: f64,
) -> Result<(f64, f64, f64, f64, f64), String> {
    match inverse_link {
        InverseLink::Standard(StandardLink::Probit) => {
            // log phi(eta) = -eta^2/2 - log sqrt(2 pi); derivatives above the second vanish.
            let log_norm = 0.5 * (2.0 * std::f64::consts::PI).ln();
            Ok((-0.5 * eta * eta - log_norm, -eta, -1.0, 0.0, 0.0))
        }
        InverseLink::Standard(StandardLink::Logit) => {
            // log f = log mu + log(1 - mu), with f = mu (1 - mu).
            let mu = crate::solver::mixture_link::component_inverse_link_jet(
                crate::types::LinkComponent::Logit,
                eta,
            )
            .mu;
            let var = mu * (1.0 - mu);
            let skew = 1.0 - 2.0 * mu;
            Ok((
                -softplus(eta) - softplus(-eta),
                skew,
                -2.0 * var,
                -2.0 * var * skew,
                -2.0 * var * (1.0 - 6.0 * var),
            ))
        }
        InverseLink::Standard(StandardLink::CLogLog) => {
            // log f = eta - e^eta; the derivatives are scaled by exp(-deriv_log_scale).
            let t_val = eta.exp();
            let t_deriv = (eta - deriv_log_scale).exp();
            let deriv_scale = (-deriv_log_scale).exp();
            Ok((
                eta - t_val,
                deriv_scale - t_deriv,
                -t_deriv,
                -t_deriv,
                -t_deriv,
            ))
        }
        InverseLink::Standard(StandardLink::Identity) => Ok((0.0, 0.0, 0.0, 0.0, 0.0)),
        _ => {
            let jet = inverse_link_jet_for_inverse_link(inverse_link, eta)
                .map_err(|e| format!("inverse link evaluation failed at eta={eta}: {e}"))?;
            let f = jet.d1;
            if !(f.is_finite() && f > 0.0) {
                return Err(SurvivalLocationScaleError::NumericalFailure {
                    reason: format!(
                        "inverse-link pdf must be finite and positive, got {f} at eta={eta}"
                    ),
                }
                .into());
            }
            let (f1, f2) = (jet.d2, jet.d3);
            let f3 = inverse_link_pdfthird_derivative_for_inverse_link(inverse_link, eta)
                .map_err(|e| {
                    format!("inverse link third-derivative evaluation failed at eta={eta}: {e}")
                })?;
            let f4 = inverse_link_pdffourth_derivative(inverse_link, eta)?;
            let d1 = f1 / f;
            let d2 = f2 / f - d1 * d1;
            let d3 = f3 / f - 3.0 * f1 * f2 / (f * f) + 2.0 * f1.powi(3) / f.powi(3);
            let f_sq = f.powi(2);
            let d4 = f4 / f - 4.0 * f1 * f3 / f_sq - 3.0 * f2 * f2 / f_sq
                + 12.0 * f1.powi(2) * f2 / f.powi(3)
                - 6.0 * f1.powi(4) / f.powi(4);
            Ok((f.ln(), d1, d2, d3, d4))
        }
    }
}
/// Log-survival `log S(eta)` plus the hazard-type ratio
/// `r = f / S = -d log S / d eta` and its first three derivatives.
///
/// Returned as `(log S, r, r', r'', r''')`. Negating the last four entries
/// gives the first four derivatives of `log S`. That is how
/// `SurvivalExactRowKernel::nll_index_tower` consumes them.
///
/// # Errors
/// - Identity link: `S = 1 - eta` is not finite and positive.
/// - Non-standard links: the survival probability is outside `(0, 1]`, or
///   the jet or third pdf derivative cannot be evaluated.
pub(crate) fn exact_survival_neglog_derivatives_fourth_rescaled(
    inverse_link: &InverseLink,
    eta: f64,
) -> Result<(f64, f64, f64, f64, f64), String> {
    match inverse_link {
        InverseLink::Standard(StandardLink::Probit) => {
            let (log_surv, hazard, d_hazard, d2_hazard, d3_hazard) =
                probit_log_survival_and_ratio_derivatives(eta);
            Ok((log_surv, hazard, d_hazard, d2_hazard, d3_hazard))
        }
        InverseLink::Standard(StandardLink::Logit) => {
            // S = 1 - mu, so log S = -softplus(eta) and r = f / S = mu.
            let mu = crate::solver::mixture_link::component_inverse_link_jet(
                crate::types::LinkComponent::Logit,
                eta,
            )
            .mu;
            let var = mu * (1.0 - mu);
            Ok((
                -softplus(eta),
                mu,
                var,
                var * (1.0 - 2.0 * mu),
                var * (1.0 - 6.0 * var),
            ))
        }
        InverseLink::Standard(StandardLink::CLogLog) => {
            // S = exp(-e^eta): log S = -e^eta and r = e^eta is its own derivative.
            let hazard = eta.exp();
            Ok((-hazard, hazard, hazard, hazard, hazard))
        }
        InverseLink::Standard(StandardLink::Identity) => {
            // S = 1 - eta, so r = 1/S and the k-th derivative of r is k! / S^(k+1).
            let s = 1.0 - eta;
            if s.is_finite() && s > 0.0 {
                let inv = s.recip();
                Ok((s.ln(), inv, inv * inv, 2.0 * inv.powi(3), 6.0 * inv.powi(4)))
            } else {
                Err(SurvivalLocationScaleError::NumericalFailure {
                    reason: format!("identity-link survival invalid at eta={eta}: S={s}"),
                }
                .into())
            }
        }
        _ => {
            let jet = inverse_link_jet_for_inverse_link(inverse_link, eta)
                .map_err(|e| format!("inverse link evaluation failed at eta={eta}: {e}"))?;
            let s = inverse_link_survival_probvalue(inverse_link, eta);
            let valid_probability = s.is_finite() && s > 0.0 && s <= 1.0;
            if !valid_probability {
                return Err(SurvivalLocationScaleError::NumericalFailure {
                    reason: format!(
                        "inverse-link survival probability must lie in (0,1], got {s} at eta={eta}"
                    ),
                }
                .into());
            }
            let pdf_d3 = inverse_link_pdfthird_derivative_for_inverse_link(inverse_link, eta)
                .map_err(|e| {
                    format!("inverse link third-derivative evaluation failed at eta={eta}: {e}")
                })?;
            let (pdf, pdf_d1, pdf_d2) = (jet.d1, jet.d2, jet.d3);
            let (r, dr) = Self::survival_ratio_first_derivative(pdf, pdf_d1, s);
            let ddr = Self::survival_ratiosecond_derivative(r, dr, pdf, pdf_d1, pdf_d2, s);
            let dddr = Self::survival_ratio_third_derivative(
                r, dr, ddr, pdf, pdf_d1, pdf_d2, pdf_d3, s,
            );
            Ok((s.ln(), r, dr, ddr, dddr))
        }
    }
}
/// Complementary log-log survival and log-pdf tuples at the exit index `u1`.
///
/// Survival part: `S = exp(-e^u)`, giving `(log S, r, r', r'', r''')` with
/// every ratio entry equal to `e^u`.
/// Log-pdf part: `log f = u - e^u`, giving `(log f, d1, d2, d3, d4)`. The
/// derivatives are multiplied by `exp(-deriv_log_scale)`, and the
/// exponential is evaluated as `exp(u1 - deriv_log_scale)`.
#[inline]
pub(crate) fn clglog_exit_pair(
    u1: f64,
    deriv_log_scale: f64,
) -> ((f64, f64, f64, f64, f64), (f64, f64, f64, f64, f64)) {
    let hazard = u1.exp();
    let scaled_hazard = (u1 - deriv_log_scale).exp();
    let scale = (-deriv_log_scale).exp();
    (
        (-hazard, hazard, hazard, hazard, hazard),
        (
            u1 - hazard,
            scale - scaled_hazard,
            -scaled_hazard,
            -scaled_hazard,
            -scaled_hazard,
        ),
    )
}
/// `ln x` and its first four derivatives `1/x, -1/x^2, 2/x^3, -6/x^4`.
///
/// # Panics
/// Panics if `x` is not finite and strictly positive.
pub(crate) fn logwith_derivatives_positive(x: f64) -> (f64, f64, f64, f64, f64) {
    assert!(
        x.is_finite() && x > 0.0,
        "log derivative kernel requires finite positive x: x={x}"
    );
    let inv = 1.0 / x;
    let second = -inv * inv;
    let third = 2.0 * inv * inv * inv;
    let fourth = -6.0 * inv * inv * inv * inv;
    (x.ln(), inv, second, third, fourth)
}
/// Packs one row's predictor values into a `SurvivalPredictorState`.
///
/// The time-derivative predictor `g` is formed with
/// `compensated_difference(d_raw, -qdot1)`. That is `d_raw + qdot1`, assuming
/// the helper returns `a - b`; confirm against its definition. The helper's
/// roundoff slack and operand scale are stored alongside `g` for later
/// tolerance and diagnostic use.
pub(crate) fn row_predictor_state(
    &self,
    h0: f64,
    h1: f64,
    d_raw: f64,
    q0: f64,
    q1: f64,
    qdot1: f64,
) -> SurvivalPredictorState {
    let time_slope = compensated_difference(d_raw, -qdot1);
    SurvivalPredictorState {
        g: time_slope.value,
        g_roundoff_slack: time_slope.roundoff_slack,
        g_operand_scale: time_slope.operand_scale,
        h0,
        h1,
        q0,
        q1,
    }
}
/// Reads the event indicator `y[row]` and checks that it lies in `[0, 1]`.
///
/// Fractional values are accepted.
///
/// # Errors
/// Returns a constraint-violation error when the value is NaN, infinite, or
/// outside `[0, 1]`.
#[inline]
pub(crate) fn validated_event_target(&self, row: usize) -> Result<f64, String> {
    let d = self.y[row];
    // `RangeInclusive::contains` is false for NaN and ±inf, so this one test
    // also rejects non-finite targets.
    if (0.0..=1.0).contains(&d) {
        Ok(d)
    } else {
        Err(SurvivalLocationScaleError::ConstraintViolation {
            reason: format!(
                "survival location-scale event target must lie in [0,1] at row {row}, got {d}"
            ),
        }
        .into())
    }
}
/// Exact per-row kernel without derivative rescaling.
///
/// Equivalent to `exact_row_kernel_rescaled` with `deriv_log_scale = 0`.
pub(crate) fn exact_row_kernel(
    &self,
    row: usize,
    state: SurvivalPredictorState,
) -> Result<Option<SurvivalExactRowKernel>, String> {
    const NO_RESCALE: f64 = 0.0;
    self.exact_row_kernel_rescaled(row, state, NO_RESCALE)
}
/// Builds the exact per-row survival kernel for `row` at predictor `state`.
///
/// The kernel holds three groups of values:
/// - `log S` and the ratio derivatives at the entry index `u0 = h0 + q0`;
/// - `log S`, the ratio derivatives and the `log f` derivatives at the exit
///   index `u1 = h1 + q1`;
/// - `log g` and its derivatives for the guarded time derivative `g`.
///
/// `deriv_log_scale` is forwarded to the exit log-pdf evaluation. For the
/// cloglog link it multiplies the log-pdf derivatives by
/// `exp(-deriv_log_scale)`.
///
/// Returns `Ok(None)`, skipping the row, in two cases:
/// - the row weight is non-positive;
/// - any ratio or log-pdf derivative is non-finite (a debug message is logged).
///
/// # Errors
/// - The event target is invalid.
/// - Inverse-link evaluation fails.
/// - The time derivative is NaN.
/// - The time derivative is still non-positive after the roundoff tolerances
///   below are applied (monotonicity violation).
pub(crate) fn exact_row_kernel_rescaled(
&self,
row: usize,
state: SurvivalPredictorState,
deriv_log_scale: f64,
) -> Result<Option<SurvivalExactRowKernel>, String> {
let w = self.w[row];
if w <= 0.0 {
return Ok(None);
}
let d = self.validated_event_target(row)?;
let u0 = state.h0 + state.q0;
let u1 = state.h1 + state.q1;
let (log_s0, r0, dr0, ddr0, dddr0) =
Self::exact_survival_neglog_derivatives_fourth_rescaled(&self.inverse_link, u0)
.map_err(|e| {
format!("inverse-link survival evaluation failed at row {row} entry: {e}")
})?;
// Exit: CLogLog uses the closed-form pair helper; other links go through
// the generic survival and log-pdf evaluators.
let ((log_s1, r1, dr1, ddr1, dddr1), (logphi1, dlogphi1, d2logphi1, d3logphi1, d4logphi1)) =
if matches!(
&self.inverse_link,
InverseLink::Standard(StandardLink::CLogLog)
) {
Self::clglog_exit_pair(u1, deriv_log_scale)
} else {
let surv =
Self::exact_survival_neglog_derivatives_fourth_rescaled(&self.inverse_link, u1)
.map_err(|e| {
format!(
"inverse-link survival evaluation failed at row {row} exit: {e}"
)
})?;
let pdf = Self::exact_log_pdf_derivatives_rescaled(
&self.inverse_link,
u1,
deriv_log_scale,
)
.map_err(|e| {
format!("inverse-link log-pdf evaluation failed at row {row} exit: {e}")
})?;
(surv, pdf)
};
// Only derivative entries are checked; the log values (log_s0, log_s1,
// logphi1) are not. Both the survival and the pdf exit pieces are checked
// regardless of the event indicator `d`.
if !(r0.is_finite()
&& dr0.is_finite()
&& ddr0.is_finite()
&& dddr0.is_finite()
&& r1.is_finite()
&& dr1.is_finite()
&& ddr1.is_finite()
&& dddr1.is_finite()
&& dlogphi1.is_finite()
&& d2logphi1.is_finite()
&& d3logphi1.is_finite()
&& d4logphi1.is_finite())
{
log::debug!(
"skipping row {row}: survival derivatives non-finite \
(u0={u0:.2e}, u1={u1:.2e})"
);
return Ok(None);
}
// Lower bound for the time derivative. Its value comes from
// `time_derivative_lower_bound` (defined elsewhere); presumably a small
// non-negative floor — confirm.
let guard = self.time_derivative_lower_bound();
let mut g = state.g;
// NaN cannot be repaired; infinities are saturated to the largest finite
// magnitudes so later arithmetic stays finite.
if g.is_nan() {
return Err(SurvivalLocationScaleError::NumericalFailure { reason: format!(
"survival location-scale time derivative is non-finite at row {row}: d_eta/dt={g}"
) }.into());
}
if g == f64::INFINITY {
g = f64::MAX;
} else if g == f64::NEG_INFINITY {
g = f64::MIN;
}
// Relative tolerance scaled by the largest predictor magnitude. The
// effective slack is the larger of this and the compensated-difference
// roundoff bound carried in `state`.
let legacy_slack = MONOTONICITY_GUARD_SLACK_REL
* (1.0
+ state
.h0
.abs()
.max(state.h1.abs())
.max(state.q0.abs())
.max(state.q1.abs()));
let roundoff_slack = state.g_roundoff_slack.max(legacy_slack);
// Within `roundoff_slack` below the bound: treat as roundoff and snap up.
if g < guard && g >= guard - roundoff_slack {
g = guard;
}
// Positive but below the bound: lift to the bound.
if g > 0.0 && g < guard {
g = guard;
}
// Non-positive, but no more negative than -(guard + slack): treated as
// cancellation noise and lifted to the bound.
let cancellation_floor = guard + roundoff_slack;
if g <= 0.0 && g >= -cancellation_floor {
g = guard;
}
// Anything still non-positive is reported as a monotonicity violation.
if g <= 0.0 {
return Err(SurvivalLocationScaleError::ConstraintViolation {
reason: format!(
"survival location-scale monotonicity violated at row {row}: \
d_eta/dt={g:.3e} <= 0 (lower_bound={guard:.3e}) \
(operand_scale={:.3e}, roundoff_slack={roundoff_slack:.3e})",
state.g_operand_scale
),
}
.into());
}
// Here g is finite and > 0, so the log kernel's precondition holds.
let (log_g, d_log_g, d2_log_g, d3_log_g, d4_log_g) = Self::logwith_derivatives_positive(g);
Ok(Some(SurvivalExactRowKernel {
w,
d,
log_s0,
r0,
dr0,
ddr0,
dddr0,
log_s1,
r1,
dr1,
ddr1,
dddr1,
logphi1,
dlogphi1,
d2logphi1,
d3logphi1,
d4logphi1,
log_g,
d_log_g,
d2_log_g,
d3_log_g,
d4_log_g,
}))
}
/// Per-row log-likelihood derivatives without derivative rescaling.
///
/// Equivalent to `row_derivatives_rescaled` with `deriv_log_scale = 0`.
pub(crate) fn row_derivatives(
    &self,
    row: usize,
    state: SurvivalPredictorState,
) -> Result<Option<SurvivalRowDerivatives>, String> {
    const NO_RESCALE: f64 = 0.0;
    self.row_derivatives_rescaled(row, state, NO_RESCALE)
}
/// Per-row log-likelihood value and derivatives with respect to the entry
/// index, the exit index and the time derivative.
///
/// The derivatives are read off the NLL tower built by
/// `SurvivalExactRowKernel::nll_index_tower`, whose variables are
/// 0 = entry, 1 = exit and 2 = time derivative. They are negated to
/// log-likelihood sign. Only diagonal tower entries are used. The `d*_q`
/// fields are the sums of the entry and exit contributions.
///
/// Returns `Ok(None)` when the row kernel is skipped (non-positive weight or
/// non-finite link derivatives).
pub(crate) fn row_derivatives_rescaled(
    &self,
    row: usize,
    state: SurvivalPredictorState,
    deriv_log_scale: f64,
) -> Result<Option<SurvivalRowDerivatives>, String> {
    const ENTRY: usize = 0;
    const EXIT: usize = 1;
    const TIME: usize = 2;
    let kernel = match self.exact_row_kernel_rescaled(row, state, deriv_log_scale)? {
        Some(kernel) => kernel,
        None => return Ok(None),
    };
    let tower = kernel.nll_index_tower();
    // The tower stores NLL derivatives; flip each to log-likelihood sign.
    let (d1_q0, d2_q0, d3_q0) = (
        -tower.g[ENTRY],
        -tower.h[ENTRY][ENTRY],
        -tower.t3[ENTRY][ENTRY][ENTRY],
    );
    let (d1_q1, d2_q1, d3_q1) = (
        -tower.g[EXIT],
        -tower.h[EXIT][EXIT],
        -tower.t3[EXIT][EXIT][EXIT],
    );
    let (d1_qdot1, d2_qdot1, d3_qdot1) = (
        -tower.g[TIME],
        -tower.h[TIME][TIME],
        -tower.t3[TIME][TIME][TIME],
    );
    Ok(Some(SurvivalRowDerivatives {
        ll: kernel.log_likelihood(),
        d1_q: d1_q0 + d1_q1,
        d2_q: d2_q0 + d2_q1,
        d3_q: d3_q0 + d3_q1,
        d1_q0,
        d2_q0,
        d3_q0,
        d1_q1,
        d2_q1,
        d3_q1,
        d1_qdot1,
        d2_qdot1,
        grad_time_eta_h0: d1_q0,
        grad_time_eta_h1: d1_q1,
        grad_time_eta_d: d1_qdot1,
        h_time_h0: d2_q0,
        h_time_h1: d2_q1,
        h_time_d: d2_qdot1,
        d_h_h0: d3_q0,
        d_h_h1: d3_q1,
        d_h_d: d3_qdot1,
    }))
}
}
/// Chain-rule derivatives of `q = -eta_t * inv_sigma`, where
/// `inv_sigma = exp_sigma_inverse_from_eta_scalar(eta_ls)`.
///
/// Returns `(-inv_sigma, -q, inv_sigma, q, -inv_sigma, -q)`.
/// If `inv_sigma = exp(-eta_ls)` (confirm against the helper), these are
/// `dq/dt`, `dq/dls`, `d2q/dt dls`, `d2q/dls2`, `d3q/dt dls2` and
/// `d3q/dls3`, with `t = eta_t` and `ls = eta_ls`.
#[inline]
pub(crate) fn q_chain_derivs_scalar(eta_t: f64, eta_ls: f64) -> (f64, f64, f64, f64, f64, f64) {
    let inv_sigma = exp_sigma_inverse_from_eta_scalar(eta_ls);
    // `neg_q` is exactly `-q`: negating twice is bit-exact.
    let neg_q = safe_product(eta_t, inv_sigma);
    let q = -neg_q;
    (-inv_sigma, neg_q, inv_sigma, q, -inv_sigma, neg_q)
}