use super::*;
/// Maps a latent log-slope `g` onto the observed probit scale, i.e. returns
/// `g` multiplied by `probit_scale`.
#[inline]
pub(crate) fn rigid_observed_logslope(g: f64, probit_scale: f64) -> f64 {
    g * probit_scale
}
/// Rigid probit scale factor `sqrt(1 + (probit_scale * g)^2)`.
#[inline]
pub(crate) fn rigid_observed_scale(g: f64, probit_scale: f64) -> f64 {
    let scaled_slope = rigid_observed_logslope(g, probit_scale);
    let radicand = 1.0 + scaled_slope * scaled_slope;
    radicand.sqrt()
}
/// Rigid observed linear predictor
/// `q * sqrt(1 + (s*g)^2) + (s*g) * z` with `s = probit_scale`.
#[inline]
pub(crate) fn rigid_observed_eta(q: f64, g: f64, z: f64, probit_scale: f64) -> f64 {
    let location_term = q * rigid_observed_scale(g, probit_scale);
    let score_term = rigid_observed_logslope(g, probit_scale) * z;
    location_term + score_term
}
/// Per-row IRLS working weight evaluated at a pilot linear predictor.
///
/// For row `i` the weight is `k2 + w_i * d_i`, where `k2` is the second value
/// returned by `signed_probit_neglog_derivatives_up_to_fourth(-eta_i, w_i * (1 - d_i))`
/// and `w_i * d_i` is the extra curvature added for event rows.
///
/// # Errors
/// Returns an error when `eta_pilot`, `sample_weights` and `event` differ in
/// length, or propagates the first per-row failure of
/// `signed_probit_neglog_derivatives_up_to_fourth`.
pub(crate) fn survival_pilot_irls_row_metric_at_eta(
    eta_pilot: &Array1<f64>,
    sample_weights: &Array1<f64>,
    event: &Array1<f64>,
) -> Result<Array1<f64>, String> {
    let n = eta_pilot.len();
    let lengths_agree = sample_weights.len() == n && event.len() == n;
    if !lengths_agree {
        return Err(format!(
            "survival cross-block W metric: length mismatch eta={}, weights={}, event={}",
            n,
            sample_weights.len(),
            event.len(),
        ));
    }
    let metric = eta_pilot
        .iter()
        .zip(sample_weights.iter())
        .zip(event.iter())
        .map(|((&eta, &weight), &d)| -> Result<f64, String> {
            let censored_weight = weight * (1.0 - d);
            let (_, k2, _, _) =
                signed_probit_neglog_derivatives_up_to_fourth(-eta, censored_weight)?;
            let event_curvature = weight * d;
            Ok(k2 + event_curvature)
        })
        .collect::<Result<Vec<f64>, String>>()?;
    Ok(Array1::from_vec(metric))
}
/// Rigid pilot linear predictor for the first `n` rows.
///
/// Row `i` evaluates `rigid_observed_eta` at location
/// `offset_exit[i] + marginal_offset[i]` and log-slope
/// `baseline_slope + logslope_offset[i]`, with score `z_primary[i]`.
/// Panics (via ndarray indexing) if any input has fewer than `n` entries.
pub(crate) fn survival_rigid_pilot_eta(
    n: usize,
    z_primary: &Array1<f64>,
    offset_exit: &Array1<f64>,
    marginal_offset: &Array1<f64>,
    logslope_offset: &Array1<f64>,
    baseline_slope: f64,
    probit_scale: f64,
) -> Array1<f64> {
    let mut eta = Array1::<f64>::zeros(n);
    for (row, slot) in eta.iter_mut().enumerate() {
        let location = offset_exit[row] + marginal_offset[row];
        let log_slope = baseline_slope + logslope_offset[row];
        *slot = rigid_observed_eta(location, log_slope, z_primary[row], probit_scale);
    }
    eta
}
/// One-step IRLS pilot linear predictor for the non-rigid survival
/// marginal-slope model. Returns `(pilot_eta, beta_logslope)`.
///
/// Steps:
/// 1. Evaluate the rigid predictor `eta1 = rigid_observed_eta(q, g, z, probit_scale)`
///    per row.
///    - `q = offset_exit + marginal_offset`
///    - `g = baseline_slope + logslope_offset`
/// 2. Linearise `eta1` in `(q, g)` using closed-form chain factors, where
///    `s = probit_scale` and `c(g) = sqrt(1 + (s*g)^2)`:
///    - `d eta/d q = c(g)`
///    - `d eta/d g = q*s^2*g/c(g) + s*z`
/// 3. Accumulate a weighted Gram matrix and right-hand side over the joint
///    design `[chain_q * location | chain_g * logslope]`.
///    - This uses `fast_xt_diag_x` / `fast_atv` (presumably `X^T diag(h) X` and `X^T v`).
///    - Per-row gradient: `grad = -k1 + w*d*eta1`.
///    - Per-row curvature: `h = k2 + w*d`.
///    - `(k1, k2)` come from `signed_probit_neglog_derivatives_up_to_fourth(-eta1, w*(1-d))`.
/// 4. Add a ridge of `max(1e-6 * mean(diag), 1e-12)` and solve by Cholesky.
///    Then re-evaluate the rigid predictor at the updated `(q, g)`.
///
/// Output rules:
/// - Each row's change from `eta1` is capped at `max(4, 4*sd(eta1))`.
/// - A row whose capped proposal is non-finite keeps `eta1`.
/// - `beta_logslope` is the log-slope block of the step, or zeros if any entry
///   is non-finite.
/// - When both designs have zero columns, the rigid pilot is returned.
///
/// # Errors
/// - A row-count mismatch between `n` and any input.
/// - A per-row failure of `signed_probit_neglog_derivatives_up_to_fourth`.
/// - A design row-chunk extraction failure.
/// - A Cholesky failure of the ridged Gram matrix.
pub(crate) fn survival_nonrigid_pilot_eta(
    n: usize,
    location_anchor_design: &DesignMatrix,
    logslope_design: &DesignMatrix,
    z_primary: &Array1<f64>,
    offset_exit: &Array1<f64>,
    marginal_offset: &Array1<f64>,
    logslope_offset: &Array1<f64>,
    baseline_slope: f64,
    sample_weights: &Array1<f64>,
    event: &Array1<f64>,
    probit_scale: f64,
) -> Result<(Array1<f64>, Array1<f64>), String> {
    if location_anchor_design.nrows() != n
        || logslope_design.nrows() != n
        || z_primary.len() != n
        || offset_exit.len() != n
        || marginal_offset.len() != n
        || logslope_offset.len() != n
        || sample_weights.len() != n
        || event.len() != n
    {
        return Err(format!(
            "survival_nonrigid_pilot_eta: row-count mismatch (n={n}, location={}, logslope={}, \
             z={}, offset_exit={}, marginal_offset={}, logslope_offset={}, weights={}, event={})",
            location_anchor_design.nrows(),
            logslope_design.nrows(),
            z_primary.len(),
            offset_exit.len(),
            marginal_offset.len(),
            logslope_offset.len(),
            sample_weights.len(),
            event.len(),
        ));
    }
    let p_loc = location_anchor_design.ncols();
    let p_g = logslope_design.ncols();
    let p_joint = p_loc + p_g;
    if p_joint == 0 {
        // No coefficients to update: the pilot is the rigid predictor itself.
        return Ok((
            survival_rigid_pilot_eta(
                n,
                z_primary,
                offset_exit,
                marginal_offset,
                logslope_offset,
                baseline_slope,
                probit_scale,
            ),
            Array1::<f64>::zeros(p_g),
        ));
    }

    // Rigid starting point: location q, log-slope g and predictor eta1 per row.
    let mut q_exit = Array1::<f64>::zeros(n);
    let mut slope = Array1::<f64>::zeros(n);
    let mut eta1 = Array1::<f64>::zeros(n);
    for i in 0..n {
        q_exit[i] = offset_exit[i] + marginal_offset[i];
        slope[i] = baseline_slope + logslope_offset[i];
        eta1[i] = rigid_observed_eta(q_exit[i], slope[i], z_primary[i], probit_scale);
    }

    // Per-row chain-rule factors plus the exit-time working gradient/curvature.
    let mut chain_q = Array1::<f64>::zeros(n);
    let mut chain_g = Array1::<f64>::zeros(n);
    let mut grad_eta1 = Array1::<f64>::zeros(n);
    let mut hess_eta1 = Array1::<f64>::zeros(n);
    for i in 0..n {
        // eta = q*c(g) + s*g*z with c(g) = sqrt(1 + (s*g)^2), so both partials
        // are available in closed form. A central finite difference with
        // h = 1e-7 would lose roughly half of the f64 precision to cancellation.
        let observed_g = probit_scale * slope[i];
        let c_i = (1.0 + observed_g * observed_g).sqrt();
        chain_q[i] = c_i;
        chain_g[i] = q_exit[i] * probit_scale * observed_g / c_i + probit_scale * z_primary[i];
        let (k1, k2, _, _) = signed_probit_neglog_derivatives_up_to_fourth(
            -eta1[i],
            sample_weights[i] * (1.0 - event[i]),
        )
        .map_err(|e| format!("survival_nonrigid_pilot_eta: row {i}: {e}"))?;
        let event_w = sample_weights[i] * event[i];
        let grad = -k1 + event_w * eta1[i];
        let hess = k2 + event_w;
        // Rows with negative or non-finite curvature contribute no curvature.
        // This includes +inf, which `f64::max(0.0)` would have kept; it would
        // make the ridge and the Cholesky factor non-finite.
        hess_eta1[i] = if hess.is_finite() && hess >= 0.0 { hess } else { 0.0 };
        grad_eta1[i] = if grad.is_finite() { grad } else { 0.0 };
    }

    // Accumulate the normal equations in fixed-size row chunks. Only a
    // `min(PILOT_ROW_CHUNK, n) x p_joint` buffer of the weighted joint design
    // is allocated.
    let mut gram = Array2::<f64>::zeros((p_joint, p_joint));
    let mut rhs = Array1::<f64>::zeros(p_joint);
    const PILOT_ROW_CHUNK: usize = 4096;
    let mut x_chunk = Array2::<f64>::zeros((PILOT_ROW_CHUNK.min(n), p_joint));
    let mut chunk_start = 0usize;
    while chunk_start < n {
        let chunk_end = (chunk_start + PILOT_ROW_CHUNK).min(n);
        let rows = chunk_end - chunk_start;
        let loc_rows = location_anchor_design
            .try_row_chunk(chunk_start..chunk_end)
            .map_err(|e| format!("survival_nonrigid_pilot_eta: location anchor rows: {e}"))?;
        let g_rows = logslope_design
            .try_row_chunk(chunk_start..chunk_end)
            .map_err(|e| format!("survival_nonrigid_pilot_eta: logslope rows: {e}"))?;
        {
            let mut x_view = x_chunk.slice_mut(s![..rows, ..]);
            for local in 0..rows {
                let i = chunk_start + local;
                for j in 0..p_loc {
                    x_view[[local, j]] = chain_q[i] * loc_rows[[local, j]];
                }
                for j in 0..p_g {
                    x_view[[local, p_loc + j]] = chain_g[i] * g_rows[[local, j]];
                }
            }
        }
        let h_chunk = hess_eta1.slice(s![chunk_start..chunk_end]).to_owned();
        let neg_g_chunk = grad_eta1.slice(s![chunk_start..chunk_end]).mapv(|v| -v);
        if rows == x_chunk.nrows() {
            gram += &fast_xt_diag_x(&x_chunk, &h_chunk);
            rhs += &fast_atv(&x_chunk, &neg_g_chunk);
        } else {
            // Final partial chunk: only the first `rows` buffer rows are valid.
            let x_tail = x_chunk.slice(s![..rows, ..]).to_owned();
            gram += &fast_xt_diag_x(&x_tail, &h_chunk);
            rhs += &fast_atv(&x_tail, &neg_g_chunk);
        }
        chunk_start = chunk_end;
    }

    // Scale-aware ridge. `p_joint > 0` is guaranteed by the early return above.
    let avg_diag = (0..p_joint).map(|j| gram[[j, j]]).sum::<f64>() / (p_joint as f64);
    let ridge_eff = (1.0e-6 * avg_diag).max(1.0e-12);
    for j in 0..p_joint {
        gram[[j, j]] += ridge_eff;
    }
    let factor = gram
        .cholesky(faer::Side::Lower)
        .map_err(|e| format!("survival_nonrigid_pilot_eta: Cholesky failed: {e:?}"))?;
    let beta_step = factor.solvevec(&rhs);
    let mut beta_loc = Array1::<f64>::zeros(p_loc);
    let mut beta_g = Array1::<f64>::zeros(p_g);
    for j in 0..p_loc {
        beta_loc[j] = beta_step[j];
    }
    for j in 0..p_g {
        beta_g[j] = beta_step[p_loc + j];
    }
    let q_delta = location_anchor_design.apply(&beta_loc);
    let g_delta = logslope_design.apply(&beta_g);

    // Cap each row's change: at least 4, or 4 population standard deviations
    // of the rigid predictor when that is larger.
    let mut step_cap: f64 = 4.0;
    {
        let mean: f64 = eta1.iter().sum::<f64>() / (n as f64).max(1.0);
        let mut var: f64 = 0.0;
        for i in 0..n {
            let d = eta1[i] - mean;
            var += d * d;
        }
        let sd = (var / (n as f64).max(1.0)).sqrt();
        if sd.is_finite() && sd > 0.0 {
            step_cap = (4.0_f64).max(4.0 * sd);
        }
    }
    let mut pilot_eta = Array1::<f64>::zeros(n);
    for i in 0..n {
        let q_new = q_exit[i] + q_delta[i];
        let g_new = slope[i] + g_delta[i];
        let proposed = rigid_observed_eta(q_new, g_new, z_primary[i], probit_scale);
        let delta = proposed - eta1[i];
        let capped = if delta.abs() > step_cap {
            eta1[i] + step_cap.copysign(delta)
        } else {
            proposed
        };
        // Non-finite proposals fall back to the rigid predictor.
        pilot_eta[i] = if capped.is_finite() { capped } else { eta1[i] };
    }
    let beta_logslope = if beta_g.iter().all(|v| v.is_finite()) {
        beta_g
    } else {
        Array1::<f64>::zeros(p_g)
    };
    Ok((pilot_eta, beta_logslope))
}
/// Scale factor of the vector marginal-slope model.
///
/// Forwards to `marginal_slope_preserving_scale` with the same slopes,
/// covariance and probit scale, and propagates its error unchanged.
pub fn survival_marginal_slope_vector_scale(
    slopes: &[f64],
    covariance: &MarginalSlopeCovariance,
    probit_scale: f64,
) -> Result<f64, String> {
    let scale = marginal_slope_preserving_scale(slopes, covariance, probit_scale)?;
    Ok(scale)
}
/// Linear predictor of the vector marginal-slope model at location `q`.
///
/// # Errors
/// - `IncompatibleDimensions` when `z.len()` differs from `covariance.dim()`.
/// - Otherwise, any error from `marginal_slope_probit_eta`, prefixed with
///   "survival marginal-slope vector eta: ".
pub fn survival_marginal_slope_vector_eta(
    q: f64,
    z: &[f64],
    slopes: &[f64],
    covariance: &MarginalSlopeCovariance,
    probit_scale: f64,
) -> Result<f64, String> {
    let expected_dim = covariance.dim();
    if z.len() == expected_dim {
        return marginal_slope_probit_eta(q, z, slopes, covariance, probit_scale)
            .map_err(|err| format!("survival marginal-slope vector eta: {err}"));
    }
    let mismatch = SurvivalMarginalSlopeError::IncompatibleDimensions {
        reason: format!(
            "survival marginal-slope vector eta: score/covariance dimension mismatch: z={}, covariance={}",
            z.len(),
            expected_dim
        ),
    };
    Err(mismatch.into())
}
/// Adds row context to an error message.
///
/// ` at row {row}` is inserted just before the first `:` in `err`. When
/// `err` contains no colon, it is appended at the end.
pub(crate) fn with_row_context(err: String, row: usize) -> String {
    match err.split_once(':') {
        Some((head, tail)) => format!("{head} at row {row}:{tail}"),
        None => format!("{err} at row {row}"),
    }
}
/// Per-row negative log-likelihood of the vector survival marginal-slope model.
///
/// Definitions:
/// - `c = survival_marginal_slope_vector_scale(..)`
/// - `eta0`, `eta1` = `survival_marginal_slope_vector_eta` at `q0` and `q1`
/// - `ad1 = qd1 * c`
///
/// The returned value is
/// `weight * [(1 - event)*(-L(-eta1)) + L(-eta0) - event*log_phi(eta1) - event*ln(ad1)]`.
/// Here `log_phi(x) = -0.5*(x^2 + ln(2*pi))`. `L` is the first value from
/// `signed_probit_logcdf_and_mills_ratio` (presumably `log Phi`; confirm).
///
/// # Errors
/// - `MonotonicityViolation` when the derivative guard is violated.
/// - Errors from the scale / eta helpers.
/// - `NumericalFailure` when `ad1` is not finite and positive.
pub fn survival_marginal_slope_vector_neglog(
    q0: f64,
    q1: f64,
    qd1: f64,
    slopes: &[f64],
    z: &[f64],
    covariance: &MarginalSlopeCovariance,
    weight: f64,
    event: f64,
    derivative_guard: f64,
    probit_scale: f64,
) -> Result<f64, String> {
    if survival_derivative_guard_violated(qd1, derivative_guard) {
        let violation = SurvivalMarginalSlopeError::MonotonicityViolation {
            reason: format!(
                "survival marginal-slope monotonicity violated: qd1={qd1:.3e} < guard={derivative_guard:.3e}"
            ),
        };
        return Err(violation.into());
    }
    let scale = survival_marginal_slope_vector_scale(slopes, covariance, probit_scale)?;
    let eta_entry = survival_marginal_slope_vector_eta(q0, z, slopes, covariance, probit_scale)?;
    let eta_exit = survival_marginal_slope_vector_eta(q1, z, slopes, covariance, probit_scale)?;
    let ad1 = qd1 * scale;
    let derivative_ok = ad1.is_finite() && ad1 > 0.0;
    if !derivative_ok {
        let failure = SurvivalMarginalSlopeError::NumericalFailure {
            reason: format!(
                "survival marginal-slope transformed derivative must be positive, got {ad1}"
            ),
        };
        return Err(failure.into());
    }
    let (logcdf_neg_entry, _) = signed_probit_logcdf_and_mills_ratio(-eta_entry);
    let (logcdf_neg_exit, _) = signed_probit_logcdf_and_mills_ratio(-eta_exit);
    let log_density_exit = -0.5 * (eta_exit * eta_exit + std::f64::consts::TAU.ln());
    let censored_term = (1.0 - event) * (-logcdf_neg_exit);
    let density_term = event * log_density_exit;
    let jacobian_term = event * ad1.ln();
    Ok(weight * (censored_term + logcdf_neg_entry - density_term - jacobian_term))
}
/// Computes `Sigma * vector` for any supported covariance representation.
///
/// Each representation is handled as follows:
/// - `Diagonal(diag)`: element-wise `diag[i] * vector[i]`.
/// - `Full(cov)`: a dense row-by-row product.
/// - `LowRank(factor)`: `factor * (factor^T * vector)`.
///
/// # Errors
/// Propagates `covariance.validate(..)` failures. Returns
/// `IncompatibleDimensions` when `vector.len() != covariance.dim()`.
pub(crate) fn marginal_slope_covariance_matvec(
    covariance: &MarginalSlopeCovariance,
    vector: &[f64],
) -> Result<Vec<f64>, String> {
    covariance.validate("survival marginal-slope covariance matvec")?;
    let dim = covariance.dim();
    if vector.len() != dim {
        let mismatch = SurvivalMarginalSlopeError::IncompatibleDimensions {
            reason: format!(
                "survival marginal-slope covariance matvec dimension mismatch: vector={}, covariance={}",
                vector.len(),
                dim
            ),
        };
        return Err(mismatch.into());
    }
    let product: Vec<f64> = match covariance {
        MarginalSlopeCovariance::Diagonal(diag) => diag
            .iter()
            .zip(vector.iter())
            .map(|(&sigma, &v)| sigma * v)
            .collect(),
        MarginalSlopeCovariance::Full(cov) => (0..cov.nrows())
            .map(|row| {
                (0..cov.ncols()).fold(0.0, |acc, col| acc + cov[[row, col]] * vector[col])
            })
            .collect(),
        MarginalSlopeCovariance::LowRank(factor) => {
            let (rows, rank) = (factor.nrows(), factor.ncols());
            // Project onto the factor columns first: factor^T * vector.
            let projected: Vec<f64> = (0..rank)
                .map(|r| (0..rows).fold(0.0, |acc, k| acc + factor[[k, r]] * vector[k]))
                .collect();
            (0..rows)
                .map(|k| (0..rank).fold(0.0, |acc, r| acc + factor[[k, r]] * projected[r]))
                .collect()
        }
    };
    Ok(product)
}
/// Closed-form per-row negative log-likelihood, gradient and Hessian of the
/// vector (multi-score) survival marginal-slope model.
///
/// Primary coordinates are laid out as
/// `[q0, q1, qd1, slopes[0], ..., slopes[k-1]]`, so the gradient has length
/// `3 + k` and the Hessian is `(3 + k) x (3 + k)`.
///
/// Definitions:
/// - `c = survival_marginal_slope_vector_scale(slopes, covariance, probit_scale)`
/// - `linear = probit_scale * sum_a slopes[a] * z[a]`
/// - `eta0 = q0*c + linear`, `eta1 = q1*c + linear`, `ad1 = qd1*c`
///
/// The row value is
/// `nll = w * [(1-d)*(-L(-eta1)) + L(-eta0) - d*log_phi(eta1) - d*ln(ad1)]`.
/// Here `log_phi(x) = -0.5*(x^2 + ln(2*pi))`. `L` is the first value of
/// `signed_probit_logcdf_and_mills_ratio` (presumably `log Phi`; confirm).
///
/// # Errors
/// - `IncompatibleDimensions` when `z.len()` or `covariance.dim()` differs from `slopes.len()`.
/// - Errors from the scale computation and from `marginal_slope_covariance_matvec`.
/// - `MonotonicityViolation` when `survival_derivative_guard_violated(qd1, derivative_guard)`.
/// - `NumericalFailure` when `ad1` is not finite and positive.
/// - Errors from `signed_probit_neglog_derivatives_up_to_fourth`.
pub(crate) fn row_primary_closed_form_vector(
    q0: f64,
    q1: f64,
    qd1: f64,
    slopes: &[f64],
    z: &[f64],
    covariance: &MarginalSlopeCovariance,
    w: f64,
    d: f64,
    derivative_guard: f64,
    probit_scale: f64,
) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
    let k = slopes.len();
    if z.len() != k || covariance.dim() != k {
        return Err(SurvivalMarginalSlopeError::IncompatibleDimensions {
            reason: format!(
                "survival marginal-slope vector row dimension mismatch: slopes={}, z={}, covariance={}",
                k,
                z.len(),
                covariance.dim()
            ),
        }
        .into());
    }
    let c = survival_marginal_slope_vector_scale(slopes, covariance, probit_scale)?;
    // sigma_g = Sigma * slopes, reused by both derivative tables below.
    let sigma_g = marginal_slope_covariance_matvec(covariance, slopes)?;
    let s2 = probit_scale * probit_scale;
    // c1[a] = s^2 * (Sigma g)_a / c and c2[a][b] below are the first and second
    // partials of `c` w.r.t. the slopes. This assumes c = sqrt(1 + s^2 * g^T Sigma g);
    // TODO confirm against marginal_slope_preserving_scale.
    let mut c1 = vec![0.0; k];
    for a in 0..k {
        c1[a] = s2 * sigma_g[a] / c;
    }
    let mut c2 = Array2::<f64>::zeros((k, k));
    for a in 0..k {
        for b in 0..k {
            // Sigma_ab read directly from whichever representation is stored.
            let sigma_ab = match covariance {
                MarginalSlopeCovariance::Diagonal(diag) => {
                    if a == b {
                        diag[a]
                    } else {
                        0.0
                    }
                }
                MarginalSlopeCovariance::Full(cov) => cov[[a, b]],
                MarginalSlopeCovariance::LowRank(factor) => {
                    let mut value = 0.0;
                    for r in 0..factor.ncols() {
                        value += factor[[a, r]] * factor[[b, r]];
                    }
                    value
                }
            };
            c2[[a, b]] = s2 * sigma_ab / c - (s2 * sigma_g[a]) * (s2 * sigma_g[b]) / (c * c * c);
        }
    }
    // Score contribution shared by entry and exit predictors.
    let linear = probit_scale
        * slopes
            .iter()
            .zip(z.iter())
            .map(|(&g, &zi)| g * zi)
            .sum::<f64>();
    let eta0 = q0 * c + linear;
    let eta1 = q1 * c + linear;
    let ad1 = qd1 * c;
    if survival_derivative_guard_violated(qd1, derivative_guard) {
        return Err(SurvivalMarginalSlopeError::MonotonicityViolation {
            reason: format!(
                "survival marginal-slope monotonicity violated: qd1={qd1:.3e} < guard={derivative_guard:.3e}"
            ),
        }
        .into());
    }
    if !(ad1.is_finite() && ad1 > 0.0) {
        return Err(SurvivalMarginalSlopeError::NumericalFailure {
            reason: format!(
                "survival marginal-slope transformed derivative must be positive, got {ad1}"
            ),
        }
        .into());
    }
    let (logcdf_neg_eta0, _) = signed_probit_logcdf_and_mills_ratio(-eta0);
    let (logcdf_neg_eta1, _) = signed_probit_logcdf_and_mills_ratio(-eta1);
    // Log standard-normal density at eta1: -0.5 * (eta1^2 + ln(2*pi)).
    let log_phi_eta1 = -0.5 * (eta1 * eta1 + std::f64::consts::TAU.ln());
    let nll =
        w * ((1.0 - d) * (-logcdf_neg_eta1) + logcdf_neg_eta0 - d * log_phi_eta1 - d * ad1.ln());
    // The entry term passes weight `-w`, matching the `+ L(-eta0)` sign in `nll`.
    // This assumes the helper differentiates `weight * (-L(x))` at x = -eta0;
    // derivatives w.r.t. eta0 then pick up the sign flips applied below.
    let (e0_k1, e0_k2, _, _) = signed_probit_neglog_derivatives_up_to_fourth(-eta0, -w)?;
    let (e1_k1, e1_k2, _, _) = signed_probit_neglog_derivatives_up_to_fourth(-eta1, w * (1.0 - d))?;
    // Event density term w*d*0.5*eta1^2 contributes w*d*eta1 and w*d.
    let phi_u1 = w * d * eta1;
    let phi_u2 = w * d;
    // Jacobian term -w*d*ln(ad1) contributes w*d times the derivatives of -ln.
    let (nl_u1, nl_u2, _, _) = neglog_derivatives(ad1);
    let td_u1 = w * d * nl_u1;
    let td_u2 = w * d * nl_u2;
    // u1_* / u2_*: first / second derivatives of the row nll w.r.t. eta0, eta1, ad1.
    let u1_eta0 = -e0_k1;
    let u1_eta1 = -e1_k1 + phi_u1;
    let u1_ad1 = td_u1;
    let u2_eta0 = e0_k2;
    let u2_eta1 = e1_k2 + phi_u2;
    let u2_ad1 = td_u2;
    let dim = 3 + k;
    let mut grad = Array1::<f64>::zeros(dim);
    let mut hess = Array2::<f64>::zeros((dim, dim));
    // eta0, eta1 and ad1 each depend on exactly one of q0, q1, qd1 (all via c).
    // The (q*, q*) off-diagonal Hessian entries therefore stay zero.
    grad[0] = u1_eta0 * c;
    grad[1] = u1_eta1 * c;
    grad[2] = u1_ad1 * c;
    hess[[0, 0]] = u2_eta0 * c * c;
    hess[[1, 1]] = u2_eta1 * c * c;
    hess[[2, 2]] = u2_ad1 * c * c;
    for a in 0..k {
        let idx = 3 + a;
        // Partials of eta0, eta1 and ad1 w.r.t. slopes[a].
        let dlin = probit_scale * z[a];
        let deta0 = q0 * c1[a] + dlin;
        let deta1 = q1 * c1[a] + dlin;
        let dad1 = qd1 * c1[a];
        grad[idx] = u1_eta0 * deta0 + u1_eta1 * deta1 + u1_ad1 * dad1;
        // Mixed (q*, slope_a) entries: chain term plus the d(c)/d(slope_a) cross term.
        hess[[0, idx]] = u2_eta0 * c * deta0 + u1_eta0 * c1[a];
        hess[[idx, 0]] = hess[[0, idx]];
        hess[[1, idx]] = u2_eta1 * c * deta1 + u1_eta1 * c1[a];
        hess[[idx, 1]] = hess[[1, idx]];
        hess[[2, idx]] = u2_ad1 * c * dad1 + u1_ad1 * c1[a];
        hess[[idx, 2]] = hess[[2, idx]];
        for b in 0..k {
            let jdx = 3 + b;
            let dlin_b = probit_scale * z[b];
            let deta0_b = q0 * c1[b] + dlin_b;
            let deta1_b = q1 * c1[b] + dlin_b;
            let dad1_b = qd1 * c1[b];
            // Slope-slope block: Gauss-Newton products plus curvature of c (c2).
            hess[[idx, jdx]] = u2_eta0 * deta0 * deta0_b
                + u1_eta0 * q0 * c2[[a, b]]
                + u2_eta1 * deta1 * deta1_b
                + u1_eta1 * q1 * c2[[a, b]]
                + u2_ad1 * dad1 * dad1_b
                + u1_ad1 * qd1 * c2[[a, b]];
        }
    }
    Ok((nll, grad, hess))
}
/// Standardizes every column of `z` independently with
/// `standardize_latent_z_with_policy`, using the same `weights`, `context` and
/// `policy` for each column.
///
/// Returns the standardized matrix together with the normalization of the
/// first column.
///
/// # Errors
/// - `InvalidInput` when `z` has no columns.
/// - Otherwise, the first column-level standardization error.
pub(crate) fn standardize_latent_z_matrix_with_policy(
    z: &Array2<f64>,
    weights: &Array1<f64>,
    context: &str,
    policy: &LatentZPolicy,
) -> Result<(Array2<f64>, LatentZNormalization), String> {
    if z.ncols() == 0 {
        return Err(SurvivalMarginalSlopeError::InvalidInput {
            reason: format!("{context} requires at least one z column"),
        }
        .into());
    }
    let standardize_column = |col: usize| {
        let input = z.column(col).to_owned();
        standardize_latent_z_with_policy(&input, weights, context, policy)
    };
    let mut out = Array2::<f64>::zeros(z.raw_dim());
    // The first column's normalization is the one reported to the caller.
    let (leading_column, leading_norm) = standardize_column(0)?;
    out.column_mut(0).assign(&leading_column);
    for col in 1..z.ncols() {
        let (standardized, _) = standardize_column(col)?;
        out.column_mut(col).assign(&standardized);
    }
    Ok((out, leading_norm))
}
/// The rigid scale `c(g) = sqrt(1 + (s*g)^2)` (with `s = probit_scale`) and
/// its first four derivatives with respect to `g`.
///
/// Returned in order:
/// 1. `c`
/// 2. `s^2*g/c`
/// 3. `s^2/c^3`
/// 4. `-3*s^4*g/c^5`
/// 5. `s^4*(12*(s*g)^2 - 3)/c^7`
#[inline]
pub(crate) fn c_derivatives(g: f64, probit_scale: f64) -> (f64, f64, f64, f64, f64) {
    let s_sq = probit_scale * probit_scale;
    let s_quad = s_sq * s_sq;
    let sg = rigid_observed_logslope(g, probit_scale);
    let sg_sq = sg * sg;
    let c = (1.0 + sg_sq).sqrt();
    let c_pow2 = c * c;
    let c_pow3 = c_pow2 * c;
    let c_pow5 = c_pow3 * c_pow2;
    let c_pow7 = c_pow5 * c_pow2;
    (
        c,
        s_sq * g / c,
        s_sq / c_pow3,
        -3.0 * s_quad * g / c_pow5,
        s_quad * (12.0 * sg_sq - 3.0) / c_pow7,
    )
}
/// First four derivatives of `-ln(x)`: `(-1/x, 1/x^2, -2/x^3, 6/x^4)`.
///
/// `x` is first clamped to at least `1e-300`. NaN also maps to `1e-300`,
/// because `f64::max` ignores NaN, so the reciprocal is always finite.
#[inline]
pub(crate) fn neglog_derivatives(x: f64) -> (f64, f64, f64, f64) {
    let recip = 1.0 / x.max(1e-300);
    let recip_sq = recip * recip;
    let first = -recip;
    let second = recip_sq;
    let third = -2.0 * recip_sq * recip;
    let fourth = 6.0 * recip_sq * recip_sq;
    (first, second, third, fourth)
}
#[inline]
pub(crate) fn row_primary_closed_form(
q0: f64,
q1: f64,
qd1: f64,
g: f64,
z: f64,
w: f64,
d: f64,
derivative_guard: f64,
probit_scale: f64,
) -> Result<(f64, [f64; N_PRIMARY], [[f64; N_PRIMARY]; N_PRIMARY]), String> {
let (c, c1, c2, ..) = c_derivatives(g, probit_scale);
let observed_g = rigid_observed_logslope(g, probit_scale);
let eta0 = q0 * c + observed_g * z;
let eta1 = q1 * c + observed_g * z;
let ad1 = qd1 * c;
if survival_derivative_guard_violated(qd1, derivative_guard) {
return Err(SurvivalMarginalSlopeError::MonotonicityViolation {
reason: format!(
"survival marginal-slope monotonicity violated: qd1={qd1:.3e} < guard={derivative_guard:.3e}"
),
}
.into());
}
let (logcdf_neg_eta0, _) = signed_probit_logcdf_and_mills_ratio(-eta0);
let (logcdf_neg_eta1, _) = signed_probit_logcdf_and_mills_ratio(-eta1);
let log_phi_eta1 = -0.5 * (eta1 * eta1 + std::f64::consts::TAU.ln());
let log_ad1 = ad1.max(1e-300).ln();
let nll =
w * ((1.0 - d) * (-logcdf_neg_eta1) + logcdf_neg_eta0 - d * log_phi_eta1 - d * log_ad1);
let (e0_k1, e0_k2, _, _) = signed_probit_neglog_derivatives_up_to_fourth(-eta0, -w)?;
let (e1_k1, e1_k2, _, _) = signed_probit_neglog_derivatives_up_to_fourth(-eta1, w * (1.0 - d))?;
let phi_u1 = w * d * eta1;
let phi_u2 = w * d;
let (nl_u1, nl_u2, _, _) = neglog_derivatives(ad1);
let td_u1 = w * d * nl_u1;
let td_u2 = w * d * nl_u2;
let deta0_dq0 = c;
let deta0_dg = q0 * c1 + probit_scale * z;
let deta1_dq1 = c;
let deta1_dg = q1 * c1 + probit_scale * z;
let dad1_dqd1 = c;
let dad1_dg = qd1 * c1;
let u1_eta0 = -e0_k1;
let u1_eta1 = -e1_k1 + phi_u1;
let u1_ad1 = td_u1;
let mut grad = [0.0_f64; N_PRIMARY];
grad[0] = u1_eta0 * deta0_dq0; grad[1] = u1_eta1 * deta1_dq1; grad[2] = u1_ad1 * dad1_dqd1; grad[3] = u1_eta0 * deta0_dg + u1_eta1 * deta1_dg + u1_ad1 * dad1_dg;
let u2_eta0 = e0_k2;
let u2_eta1 = e1_k2 + phi_u2;
let u2_ad1 = td_u2;
let d2eta0_dq0dg = c1;
let d2eta1_dq1dg = c1;
let d2ad1_dqd1dg = c1;
let d2eta0_dg2 = q0 * c2;
let d2eta1_dg2 = q1 * c2;
let d2ad1_dg2 = qd1 * c2;
let mut hess = [[0.0_f64; N_PRIMARY]; N_PRIMARY];
hess[0][0] = u2_eta0 * deta0_dq0 * deta0_dq0;
hess[1][1] = u2_eta1 * deta1_dq1 * deta1_dq1;
hess[2][2] = u2_ad1 * dad1_dqd1 * dad1_dqd1;
hess[0][1] = 0.0;
hess[1][0] = 0.0;
hess[0][2] = 0.0;
hess[2][0] = 0.0;
hess[1][2] = 0.0;
hess[2][1] = 0.0;
hess[0][3] = u2_eta0 * deta0_dq0 * deta0_dg + u1_eta0 * d2eta0_dq0dg;
hess[3][0] = hess[0][3];
hess[1][3] = u2_eta1 * deta1_dq1 * deta1_dg + u1_eta1 * d2eta1_dq1dg;
hess[3][1] = hess[1][3];
hess[2][3] = u2_ad1 * dad1_dqd1 * dad1_dg + u1_ad1 * d2ad1_dqd1dg;
hess[3][2] = hess[2][3];
hess[3][3] = u2_eta0 * deta0_dg * deta0_dg
+ u1_eta0 * d2eta0_dg2
+ u2_eta1 * deta1_dg * deta1_dg
+ u1_eta1 * d2eta1_dg2
+ u2_ad1 * dad1_dg * dad1_dg
+ u1_ad1 * d2ad1_dg2;
Ok((nll, grad, hess))
}
/// Row primary evaluation entry point for the row compiler (presumably the
/// compiled/batched evaluation path; confirm at call sites).
///
/// Returns exactly what `row_primary_closed_form` returns for the same
/// arguments, including its errors.
pub(crate) fn row_primary_for_compiler(
    q0: f64,
    q1: f64,
    qd1: f64,
    g: f64,
    z: f64,
    w: f64,
    d: f64,
    derivative_guard: f64,
    probit_scale: f64,
) -> Result<(f64, [f64; N_PRIMARY], [[f64; N_PRIMARY]; N_PRIMARY]), String> {
    let (nll, gradient, hessian) =
        row_primary_closed_form(q0, q1, qd1, g, z, w, d, derivative_guard, probit_scale)?;
    Ok((nll, gradient, hessian))
}
/// Closed-form per-row negative log-likelihood, gradient and Hessian for the
/// shared-score survival marginal-slope model.
///
/// Primary coordinates are `[q0, q1, qd1, g]`; the code indexes 0..=3, so it
/// assumes `N_PRIMARY == 4`.
///
/// Definitions:
/// - `s = probit_scale`
/// - `c = sqrt(1 + s^2 * covariance_ones * g^2)`, from `c_derivatives` at the
///   effective scale `s * sqrt(covariance_ones)`
/// - `linear = s * g * z_sum`
/// - `eta0 = q0*c + linear`, `eta1 = q1*c + linear`, `ad1 = qd1*c`
///
/// `covariance_ones` is presumably the total covariance of the shared score
/// (e.g. `1^T Sigma 1`); confirm at call sites.
///
/// # Errors
/// - `InvalidInput` when `covariance_ones` is negative or non-finite.
/// - `MonotonicityViolation` when the derivative guard is violated.
/// - Errors from `signed_probit_neglog_derivatives_up_to_fourth`.
#[inline]
pub(crate) fn row_primary_closed_form_shared_score(
    q0: f64,
    q1: f64,
    qd1: f64,
    g: f64,
    z_sum: f64,
    covariance_ones: f64,
    w: f64,
    d: f64,
    derivative_guard: f64,
    probit_scale: f64,
) -> Result<(f64, [f64; N_PRIMARY], [[f64; N_PRIMARY]; N_PRIMARY]), String> {
    if !(covariance_ones.is_finite() && covariance_ones >= 0.0) {
        return Err(SurvivalMarginalSlopeError::InvalidInput {
            reason: format!(
                "survival marginal-slope shared-score covariance scale must be finite and non-negative, got {covariance_ones}"
            ),
        }
        .into());
    }
    // The covariance enters c only; the score term keeps the plain probit scale.
    let effective_scale = probit_scale * covariance_ones.sqrt();
    let (c, c1, c2, ..) = c_derivatives(g, effective_scale);
    let linear = rigid_observed_logslope(g, probit_scale) * z_sum;
    let linear_dg = probit_scale * z_sum;
    let eta0 = q0 * c + linear;
    let eta1 = q1 * c + linear;
    let ad1 = qd1 * c;
    if survival_derivative_guard_violated(qd1, derivative_guard) {
        return Err(SurvivalMarginalSlopeError::MonotonicityViolation {
            reason: format!(
                "survival marginal-slope monotonicity violated: qd1={qd1:.3e} < guard={derivative_guard:.3e}"
            ),
        }
        .into());
    }
    let (logcdf_neg_eta0, _) = signed_probit_logcdf_and_mills_ratio(-eta0);
    let (logcdf_neg_eta1, _) = signed_probit_logcdf_and_mills_ratio(-eta1);
    // Log standard-normal density at eta1: -0.5 * (eta1^2 + ln(2*pi)).
    let log_phi_eta1 = -0.5 * (eta1 * eta1 + std::f64::consts::TAU.ln());
    // Clamped (not rejected) so a zero transformed derivative gives a finite log.
    let log_ad1 = ad1.max(1e-300).ln();
    let nll =
        w * ((1.0 - d) * (-logcdf_neg_eta1) + logcdf_neg_eta0 - d * log_phi_eta1 - d * log_ad1);
    // The entry term uses weight `-w` to match the `+ L(-eta0)` sign in `nll`.
    let (e0_k1, e0_k2, _, _) = signed_probit_neglog_derivatives_up_to_fourth(-eta0, -w)?;
    let (e1_k1, e1_k2, _, _) = signed_probit_neglog_derivatives_up_to_fourth(-eta1, w * (1.0 - d))?;
    let phi_u1 = w * d * eta1;
    let phi_u2 = w * d;
    let (nl_u1, nl_u2, _, _) = neglog_derivatives(ad1);
    let td_u1 = w * d * nl_u1;
    let td_u2 = w * d * nl_u2;
    // First partials of eta0, eta1, ad1 w.r.t. their own q coordinate and g.
    let deta0_dq0 = c;
    let deta0_dg = q0 * c1 + linear_dg;
    let deta1_dq1 = c;
    let deta1_dg = q1 * c1 + linear_dg;
    let dad1_dqd1 = c;
    let dad1_dg = qd1 * c1;
    // u1_* / u2_*: first / second derivatives of the row nll w.r.t. eta0, eta1, ad1.
    let u1_eta0 = -e0_k1;
    let u1_eta1 = -e1_k1 + phi_u1;
    let u1_ad1 = td_u1;
    let mut grad = [0.0_f64; N_PRIMARY];
    grad[0] = u1_eta0 * deta0_dq0;
    grad[1] = u1_eta1 * deta1_dq1;
    grad[2] = u1_ad1 * dad1_dqd1;
    grad[3] = u1_eta0 * deta0_dg + u1_eta1 * deta1_dg + u1_ad1 * dad1_dg;
    let u2_eta0 = e0_k2;
    let u2_eta1 = e1_k2 + phi_u2;
    let u2_ad1 = td_u2;
    // Second partials: each predictor is linear in its q, so only the g terms curve.
    let d2eta0_dq0dg = c1;
    let d2eta1_dq1dg = c1;
    let d2ad1_dqd1dg = c1;
    let d2eta0_dg2 = q0 * c2;
    let d2eta1_dg2 = q1 * c2;
    let d2ad1_dg2 = qd1 * c2;
    // (q0, q1), (q0, qd1), (q1, qd1) entries stay zero: no shared dependence.
    let mut hess = [[0.0_f64; N_PRIMARY]; N_PRIMARY];
    hess[0][0] = u2_eta0 * deta0_dq0 * deta0_dq0;
    hess[1][1] = u2_eta1 * deta1_dq1 * deta1_dq1;
    hess[2][2] = u2_ad1 * dad1_dqd1 * dad1_dqd1;
    hess[0][3] = u2_eta0 * deta0_dq0 * deta0_dg + u1_eta0 * d2eta0_dq0dg;
    hess[3][0] = hess[0][3];
    hess[1][3] = u2_eta1 * deta1_dq1 * deta1_dg + u1_eta1 * d2eta1_dq1dg;
    hess[3][1] = hess[1][3];
    hess[2][3] = u2_ad1 * dad1_dqd1 * dad1_dg + u1_ad1 * d2ad1_dqd1dg;
    hess[3][2] = hess[2][3];
    hess[3][3] = u2_eta0 * deta0_dg * deta0_dg
        + u1_eta0 * d2eta0_dg2
        + u2_eta1 * deta1_dg * deta1_dg
        + u1_eta1 * d2eta1_dg2
        + u2_ad1 * dad1_dg * dad1_dg
        + u1_ad1 * d2ad1_dg2;
    Ok((nll, grad, hess))
}
/// Cached gradient and Hessian of one row's objective with respect to that
/// row's primary coordinates.
///
/// NOTE(review): presumably filled from one of the `row_primary_closed_form*`
/// functions; confirm at construction sites.
#[derive(Clone)]
pub(crate) struct RowPrimaryBase {
    /// Row gradient over the primary coordinates.
    pub(crate) gradient: Array1<f64>,
    /// Row Hessian over the primary coordinates.
    pub(crate) hessian: Array2<f64>,
}
/// Evaluation cache holding per-row primary gradient/Hessian bases.
pub(crate) struct EvalCache {
    /// One entry per row (presumably indexed by row number; confirm at use sites).
    pub(crate) row_bases: Vec<RowPrimaryBase>,
}