use super::*;
impl SurvivalMarginalSlopeFamily {
    /// Exact per-row negative log-likelihood, gradient, and Hessian in the flex
    /// primary coordinates described by `primary`.
    ///
    /// Verifies the scalar flex exact-score geometry, reads the row's `g` value
    /// from `block_states[2].eta[row]`, resolves the flex score/link coefficient
    /// vectors and the row's influence offset, then delegates to
    /// `compute_row_flex_primary_gradient_hessian_from_parts` with the entry/exit
    /// quantities (`q0`, `q1`, `qd1`) carried by `q_geom`.
    ///
    /// # Errors
    /// Propagates failures from the geometry check, the coefficient/offset
    /// lookups, and the delegated `*_from_parts` computation.
    ///
    /// # Panics
    /// Indexes `block_states[2]` and its `eta[row]` directly, so it panics if
    /// either index is out of range.
    pub(crate) fn compute_row_flex_primary_gradient_hessian_exact(
        &self,
        row: usize,
        block_states: &[ParameterBlockState],
        q_geom: &SurvivalMarginalSlopeDynamicRow,
        primary: &FlexPrimarySlices,
    ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
        self.ensure_scalar_flex_exact_score_geometry(
            "compute_row_flex_primary_gradient_hessian_exact",
        )?;
        // NOTE(review): block index 2 is presumably the block whose linear
        // predictor supplies `g` for this family — confirm against block layout.
        let g = block_states[2].eta[row];
        let beta_h = self.flex_score_beta(block_states)?;
        let beta_w = self.flex_link_beta(block_states)?;
        let o_infl = self.influence_index_offset(row, block_states)?;
        self.compute_row_flex_primary_gradient_hessian_from_parts(
            row, q_geom.q0, q_geom.q1, q_geom.qd1, g, beta_h, beta_w, o_infl, primary,
        )
    }

    /// Exact per-row negative log-likelihood and gradient (no Hessian) in the
    /// flex primary coordinates described by `primary`.
    ///
    /// Performs the same setup as `compute_row_flex_primary_gradient_hessian_exact`
    /// but takes the gradient-only row geometry and delegates to
    /// `compute_row_flex_primary_gradient_from_parts`.
    ///
    /// # Errors
    /// Propagates failures from the geometry check, the coefficient/offset
    /// lookups, and the delegated `*_from_parts` computation.
    ///
    /// # Panics
    /// Indexes `block_states[2]` and its `eta[row]` directly, so it panics if
    /// either index is out of range.
    pub(crate) fn compute_row_flex_primary_gradient_exact(
        &self,
        row: usize,
        block_states: &[ParameterBlockState],
        q_geom: &SurvivalMarginalSlopeDynamicRowGradient,
        primary: &FlexPrimarySlices,
    ) -> Result<(f64, Array1<f64>), String> {
        self.ensure_scalar_flex_exact_score_geometry("compute_row_flex_primary_gradient_exact")?;
        let g = block_states[2].eta[row];
        let beta_h = self.flex_score_beta(block_states)?;
        let beta_w = self.flex_link_beta(block_states)?;
        let o_infl = self.influence_index_offset(row, block_states)?;
        self.compute_row_flex_primary_gradient_from_parts(
            row, q_geom.q0, q_geom.q1, q_geom.qd1, g, beta_h, beta_w, o_infl, primary,
        )
    }

    /// Row negative log-likelihood and gradient from already-extracted inputs.
    ///
    /// Steps:
    /// 1. reject the row if `qd1` violates the derivative guard;
    /// 2. solve the entry (`q0`) and exit (`q1`) intercepts;
    /// 3. build first-order timepoint quantities at both times;
    /// 4. reject a non-finite or non-positive exit `chi`;
    /// 5. evaluate `flex_row_nll_value_grad_hess` with every second-order
    ///    channel set to `None`.
    ///
    /// # Errors
    /// Returns a `MonotonicityViolation` or `NumericalFailure` (converted to
    /// `String`) for the two checks above. Also propagates any intercept-solve,
    /// timepoint, or NLL-evaluation error.
    pub(crate) fn compute_row_flex_primary_gradient_from_parts(
        &self,
        row: usize,
        q0: f64,
        q1: f64,
        qd1: f64,
        g: f64,
        beta_h: Option<&Array1<f64>>,
        beta_w: Option<&Array1<f64>>,
        o_infl: f64,
        primary: &FlexPrimarySlices,
    ) -> Result<(f64, Array1<f64>), String> {
        ensure_flex_row_derivative_guard(self, row, qd1)?;
        // NOTE(review): the slot tag presumably keys a per-row intercept cache
        // for the entry/exit solves — confirm in `solve_row_survival_intercept_with_slot`.
        let (a0, d0) = self.solve_row_survival_intercept_with_slot(
            q0,
            g,
            beta_h,
            beta_w,
            Some((row, SurvivalInterceptSlotKind::Entry)),
        )?;
        let (a1, d1) = self.solve_row_survival_intercept_with_slot(
            q1,
            g,
            beta_h,
            beta_w,
            Some((row, SurvivalInterceptSlotKind::Exit)),
        )?;
        let entry = self.compute_survival_timepoint_first_order_exact(
            row, primary, q0, primary.q0, a0, g, d0, beta_h, beta_w, o_infl,
        )?;
        let exit = self.compute_survival_timepoint_first_order_exact(
            row, primary, q1, primary.q1, a1, g, d1, beta_h, beta_w, o_infl,
        )?;
        ensure_flex_row_positive_chi1(row, exit.chi)?;
        // Only value and gradient are needed here: no Hessian channels are
        // supplied and the third return value is discarded.
        let (row_nll, grad, _) = self.flex_row_nll_value_grad_hess(
            row,
            primary,
            q1,
            qd1,
            crate::survival::marginal_slope::timepoint_exact::flex_jet::FlexRowJet2Channels {
                eta0_v: entry.eta,
                eta0_g: entry.eta_u.view(),
                eta0_h: None,
                eta1_v: exit.eta,
                eta1_g: exit.eta_u.view(),
                eta1_h: None,
                chi1_v: exit.chi,
                chi1_g: exit.chi_u.view(),
                chi1_h: None,
                d1_v: exit.d,
                d1_g: exit.d_u.view(),
                d1_h: None,
            },
        )?;
        Ok((row_nll, grad))
    }

    /// Row negative log-likelihood, gradient, and Hessian from already-extracted
    /// inputs.
    ///
    /// Runs the same guard and `chi` checks as
    /// `compute_row_flex_primary_gradient_from_parts`. It builds full exact jets
    /// (via `compute_survival_timepoint_exact_jet`) at entry and exit, and feeds
    /// the second-order channels (`*_uv`) into `flex_row_nll_value_grad_hess`.
    ///
    /// # Errors
    /// Returns a `MonotonicityViolation` or `NumericalFailure` (converted to
    /// `String`) for the guard and `chi` checks. Also propagates any
    /// intercept-solve, jet, or NLL-evaluation error.
    pub(crate) fn compute_row_flex_primary_gradient_hessian_from_parts(
        &self,
        row: usize,
        q0: f64,
        q1: f64,
        qd1: f64,
        g: f64,
        beta_h: Option<&Array1<f64>>,
        beta_w: Option<&Array1<f64>>,
        o_infl: f64,
        primary: &FlexPrimarySlices,
    ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
        ensure_flex_row_derivative_guard(self, row, qd1)?;
        // The second element of each intercept solve is unused on this path:
        // `compute_survival_timepoint_exact_jet` does not take it.
        let (a0, _d0) = self.solve_row_survival_intercept_with_slot(
            q0,
            g,
            beta_h,
            beta_w,
            Some((row, SurvivalInterceptSlotKind::Entry)),
        )?;
        let (a1, _d1) = self.solve_row_survival_intercept_with_slot(
            q1,
            g,
            beta_h,
            beta_w,
            Some((row, SurvivalInterceptSlotKind::Exit)),
        )?;
        let entry = self.compute_survival_timepoint_exact_jet(
            row, primary, q0, primary.q0, a0, g, beta_h, beta_w, o_infl,
        )?;
        let exit = self.compute_survival_timepoint_exact_jet(
            row, primary, q1, primary.q1, a1, g, beta_h, beta_w, o_infl,
        )?;
        ensure_flex_row_positive_chi1(row, exit.chi)?;
        let (row_nll, grad, hess) = self.flex_row_nll_value_grad_hess(
            row,
            primary,
            q1,
            qd1,
            crate::survival::marginal_slope::timepoint_exact::flex_jet::FlexRowJet2Channels {
                eta0_v: entry.eta,
                eta0_g: entry.eta_u.view(),
                eta0_h: Some(entry.eta_uv.view()),
                eta1_v: exit.eta,
                eta1_g: exit.eta_u.view(),
                eta1_h: Some(exit.eta_uv.view()),
                chi1_v: exit.chi,
                chi1_g: exit.chi_u.view(),
                chi1_h: Some(exit.chi_uv.view()),
                d1_v: exit.d,
                d1_g: exit.d_u.view(),
                d1_h: Some(exit.d_uv.view()),
            },
        )?;
        Ok((row_nll, grad, hess))
    }
}

/// Shared guard for both `*_from_parts` paths.
///
/// Fails with a `MonotonicityViolation` (converted to `String`) when
/// `survival_derivative_guard_violated(qd1, family.derivative_guard)` reports
/// that the exit-time derivative `qd1` of `row` breaks the family's guard.
fn ensure_flex_row_derivative_guard(
    family: &SurvivalMarginalSlopeFamily,
    row: usize,
    qd1: f64,
) -> Result<(), String> {
    if survival_derivative_guard_violated(qd1, family.derivative_guard) {
        return Err(SurvivalMarginalSlopeError::MonotonicityViolation {
            reason: format!(
                "survival marginal-slope monotonicity violated at row {row}: qd1={qd1:.3e} < guard={:.3e}",
                family.derivative_guard
            ),
        }
        .into());
    }
    Ok(())
}

/// Shared exit-`chi` check for both `*_from_parts` paths.
///
/// Fails with a `NumericalFailure` (converted to `String`) when `chi1` is NaN,
/// infinite, or `<= 0.0`.
fn ensure_flex_row_positive_chi1(row: usize, chi1: f64) -> Result<(), String> {
    if !chi1.is_finite() || chi1 <= 0.0 {
        return Err(SurvivalMarginalSlopeError::NumericalFailure {
            reason: format!(
                "survival marginal-slope row {row} produced non-positive observed chi1={chi1:.3e}"
            ),
        }
        .into());
    }
    Ok(())
}